1use std::fmt::DebugStruct;
2
3use ciborium::Value;
4
5use crate::{
6 KEY_ID_SIZE,
7 cose::{
8 CONTAINED_KEY_ID, ContentNamespace, SAFE_CONTENT_NAMESPACE, SAFE_OBJECT_NAMESPACE,
9 SafeObjectNamespace, extract_bytes, extract_integer,
10 },
11 keys::KeyId,
12};
13
/// Failure modes reported by the namespace helpers in this module.
#[derive(Debug)]
pub(super) enum ExtractionError {
    /// The integer lookup for the namespace label failed. The extractors in this module map
    /// every error from `extract_integer` to this variant.
    MissingNamespace,
    /// An integer was read for the label but could not be converted into the requested
    /// namespace type. `validate_safe_namespaces` also returns this variant when the extracted
    /// namespace differs from the expected one.
    InvalidNamespace,
}
19
20pub(super) fn extract_safe_object_namespace(
21 header: &coset::Header,
22) -> Result<SafeObjectNamespace, ExtractionError> {
23 match extract_integer(header, SAFE_OBJECT_NAMESPACE, "safe object namespace") {
24 Ok(value) => value
25 .try_into()
26 .map_err(|_| ExtractionError::InvalidNamespace),
27 Err(_) => Err(ExtractionError::MissingNamespace),
28 }
29}
30
31pub(super) fn extract_safe_content_namespace<T: ContentNamespace>(
32 header: &coset::Header,
33) -> Result<T, ExtractionError> {
34 match extract_integer(header, SAFE_CONTENT_NAMESPACE, "safe content namespace") {
35 Ok(value) => value
36 .try_into()
37 .map_err(|_| ExtractionError::InvalidNamespace),
38 Err(_) => Err(ExtractionError::MissingNamespace),
39 }
40}
41
42pub(super) fn debug_fmt<C: ContentNamespace>(
43 debug_struct: &mut DebugStruct,
44 header: &coset::Header,
45) {
46 if let Ok(object_namespace) = extract_safe_object_namespace(header) {
47 debug_struct.field("object_namespace", &object_namespace);
48 }
49 if let Ok(content_namespace) = extract_safe_content_namespace::<C>(header) {
50 debug_struct.field("content_namespace", &content_namespace);
51 }
52}
53
54fn set_header_value(header: &mut coset::Header, label: i64, value: Value) {
55 if let Some((_, existing_value)) =
56 header
57 .rest
58 .iter_mut()
59 .find(|(existing_label, _)| matches!(existing_label, coset::Label::Int(existing) if *existing == label))
60 {
61 *existing_value = value;
62 } else {
63 header.rest.push((coset::Label::Int(label), value));
64 }
65}
66
67pub(super) fn set_safe_namespaces<T: ContentNamespace>(
68 header: &mut coset::Header,
69 object_namespace: SafeObjectNamespace,
70 content_namespace: T,
71) {
72 set_header_value(
73 header,
74 SAFE_OBJECT_NAMESPACE,
75 Value::from(i128::from(object_namespace)),
76 );
77 set_header_value(
78 header,
79 SAFE_CONTENT_NAMESPACE,
80 Value::from(content_namespace.into()),
81 );
82}
83
84pub(super) fn validate_safe_namespaces<T: ContentNamespace>(
89 header: &coset::Header,
90 expected_object_namespace: SafeObjectNamespace,
91 expected_content_namespace: T,
92) -> Result<(), ExtractionError> {
93 match extract_safe_object_namespace(header) {
94 Ok(ns) if ns == expected_object_namespace => (),
95 Ok(_) => return Err(ExtractionError::InvalidNamespace),
97 Err(ExtractionError::MissingNamespace) => (),
99 Err(ExtractionError::InvalidNamespace) => return Err(ExtractionError::InvalidNamespace),
102 }
103
104 match extract_safe_content_namespace::<T>(header) {
105 Ok(ns) if ns == expected_content_namespace => Ok(()),
106 Ok(_) => Err(ExtractionError::InvalidNamespace),
108 Err(ExtractionError::MissingNamespace) => Ok(()),
110 Err(ExtractionError::InvalidNamespace) => Err(ExtractionError::InvalidNamespace),
113 }
114}
115
116pub(super) fn extract_contained_key_id(header: &coset::Header) -> Result<Option<KeyId>, ()> {
118 let key_id_bytes = extract_bytes(header, CONTAINED_KEY_ID, "key id");
119
120 if let Ok(bytes) = key_id_bytes {
121 let key_id_array: [u8; KEY_ID_SIZE] = bytes.as_slice().try_into().map_err(|_| ())?;
122 Ok(Some(KeyId::from(key_id_array)))
123 } else {
124 Ok(None)
125 }
126}
127
#[cfg(test)]
mod tests {
    use ciborium::Value;

    use super::*;
    use crate::{cose::SAFE_OBJECT_NAMESPACE, safe::DataEnvelopeNamespace};

    /// Returns how many entries of `header.rest` carry the integer label `label`.
    fn count_label(header: &coset::Header, label: i64) -> usize {
        header
            .rest
            .iter()
            .filter(|entry| matches!(entry.0, coset::Label::Int(found) if found == label))
            .count()
    }

    /// Extracts the object namespace and then the content namespace, failing on the first
    /// error.
    fn extract_safe_namespaces<T: ContentNamespace>(
        header: &coset::Header,
    ) -> Result<(SafeObjectNamespace, T), ExtractionError> {
        extract_safe_object_namespace(header).and_then(|object_namespace| {
            extract_safe_content_namespace::<T>(header)
                .map(|content_namespace| (object_namespace, content_namespace))
        })
    }

    /// Stamps `header` with the `DataEnvelope` / `ExampleNamespace` pair used by several tests.
    fn stamp_example_namespaces(header: &mut coset::Header) {
        set_safe_namespaces(
            header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
    }

    /// Asserts that `header` yields exactly the `DataEnvelope` / `ExampleNamespace` pair.
    fn assert_example_namespaces(header: &coset::Header) {
        assert!(matches!(
            extract_safe_namespaces::<DataEnvelopeNamespace>(header),
            Ok((
                SafeObjectNamespace::DataEnvelope,
                DataEnvelopeNamespace::ExampleNamespace
            ))
        ));
    }

    #[test]
    fn set_safe_namespaces_sets_both_namespace_labels() {
        let mut header = coset::HeaderBuilder::new().build();

        stamp_example_namespaces(&mut header);

        assert_example_namespaces(&header);
    }

    #[test]
    fn set_safe_namespaces_overwrites_existing_namespace_values() {
        // Start from a header that already holds a value under both labels.
        let mut header = coset::HeaderBuilder::new()
            .value(SAFE_OBJECT_NAMESPACE, Value::from(999_i64))
            .value(SAFE_CONTENT_NAMESPACE, Value::from(999_i64))
            .build();

        stamp_example_namespaces(&mut header);

        // The labels must have been replaced in place, not appended a second time.
        for label in [SAFE_OBJECT_NAMESPACE, SAFE_CONTENT_NAMESPACE] {
            assert_eq!(count_label(&header, label), 1);
        }
        assert_example_namespaces(&header);
    }

    #[test]
    fn extract_safe_namespaces_fails_when_namespace_missing() {
        let header = coset::HeaderBuilder::new().build();

        let outcome = extract_safe_namespaces::<DataEnvelopeNamespace>(&header);

        assert!(matches!(outcome, Err(ExtractionError::MissingNamespace)));
    }

    #[test]
    fn extract_safe_namespaces_fails_when_namespace_invalid() {
        // Valid object namespace, but a content namespace value the test expects to be rejected.
        let header = coset::HeaderBuilder::new()
            .value(
                SAFE_OBJECT_NAMESPACE,
                Value::from(SafeObjectNamespace::DataEnvelope as i64),
            )
            .value(SAFE_CONTENT_NAMESPACE, Value::from(999_i64))
            .build();

        let outcome = extract_safe_namespaces::<DataEnvelopeNamespace>(&header);

        assert!(matches!(outcome, Err(ExtractionError::InvalidNamespace)));
    }

    #[test]
    fn validate_safe_namespaces_allows_missing_labels_for_backwards_compat() {
        let header = coset::HeaderBuilder::new().build();

        let outcome = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );

        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_safe_namespaces_rejects_namespace_mismatch() {
        let mut header = coset::HeaderBuilder::new().build();
        stamp_example_namespaces(&mut header);

        // Same object namespace, different content namespace.
        let outcome = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace2,
        );

        assert!(matches!(outcome, Err(ExtractionError::InvalidNamespace)));
    }
}