Skip to main content

bitwarden_crypto/safe/
helpers.rs

1use std::fmt::DebugStruct;
2
3use ciborium::Value;
4
5use crate::{
6    KEY_ID_SIZE,
7    cose::{
8        CONTAINED_KEY_ID, ContentNamespace, SAFE_CONTENT_NAMESPACE, SAFE_OBJECT_NAMESPACE,
9        SafeObjectNamespace, extract_bytes, extract_integer,
10    },
11    keys::KeyId,
12};
13
14#[derive(Debug)]
15pub(super) enum ExtractionError {
16    MissingNamespace,
17    InvalidNamespace,
18}
19
20pub(super) fn extract_safe_object_namespace(
21    header: &coset::Header,
22) -> Result<SafeObjectNamespace, ExtractionError> {
23    match extract_integer(header, SAFE_OBJECT_NAMESPACE, "safe object namespace") {
24        Ok(value) => value
25            .try_into()
26            .map_err(|_| ExtractionError::InvalidNamespace),
27        Err(_) => Err(ExtractionError::MissingNamespace),
28    }
29}
30
31pub(super) fn extract_safe_content_namespace<T: ContentNamespace>(
32    header: &coset::Header,
33) -> Result<T, ExtractionError> {
34    match extract_integer(header, SAFE_CONTENT_NAMESPACE, "safe content namespace") {
35        Ok(value) => value
36            .try_into()
37            .map_err(|_| ExtractionError::InvalidNamespace),
38        Err(_) => Err(ExtractionError::MissingNamespace),
39    }
40}
41
42pub(super) fn debug_fmt<C: ContentNamespace>(
43    debug_struct: &mut DebugStruct,
44    header: &coset::Header,
45) {
46    if let Ok(object_namespace) = extract_safe_object_namespace(header) {
47        debug_struct.field("object_namespace", &object_namespace);
48    }
49    if let Ok(content_namespace) = extract_safe_content_namespace::<C>(header) {
50        debug_struct.field("content_namespace", &content_namespace);
51    }
52}
53
54fn set_header_value(header: &mut coset::Header, label: i64, value: Value) {
55    if let Some((_, existing_value)) =
56        header
57            .rest
58            .iter_mut()
59            .find(|(existing_label, _)| matches!(existing_label, coset::Label::Int(existing) if *existing == label))
60    {
61        *existing_value = value;
62    } else {
63        header.rest.push((coset::Label::Int(label), value));
64    }
65}
66
67pub(super) fn set_safe_namespaces<T: ContentNamespace>(
68    header: &mut coset::Header,
69    object_namespace: SafeObjectNamespace,
70    content_namespace: T,
71) {
72    set_header_value(
73        header,
74        SAFE_OBJECT_NAMESPACE,
75        Value::from(i128::from(object_namespace)),
76    );
77    set_header_value(
78        header,
79        SAFE_CONTENT_NAMESPACE,
80        Value::from(content_namespace.into()),
81    );
82}
83
84/// Validates the provided header contains the expected object and content namespace.
85/// For backward compatibility, missing values are OK, but incorrect values are not.
86/// The validation happens individually for both namespace layers, and either one
87/// missing with the other being present is OK.
88pub(super) fn validate_safe_namespaces<T: ContentNamespace>(
89    header: &coset::Header,
90    expected_object_namespace: SafeObjectNamespace,
91    expected_content_namespace: T,
92) -> Result<(), ExtractionError> {
93    match extract_safe_object_namespace(header) {
94        Ok(ns) if ns == expected_object_namespace => (),
95        // If the namespace is present but doesn't match, return an error immediately.
96        Ok(_) => return Err(ExtractionError::InvalidNamespace),
97        // If the namespace is missing, do not validate for backward compatibility
98        Err(ExtractionError::MissingNamespace) => (),
99        // If the namespace is present but invalid (e.g., not an integer or out of range), return an
100        // error.
101        Err(ExtractionError::InvalidNamespace) => return Err(ExtractionError::InvalidNamespace),
102    }
103
104    match extract_safe_content_namespace::<T>(header) {
105        Ok(ns) if ns == expected_content_namespace => Ok(()),
106        // If the namespace is present but doesn't match, return an error immediately.
107        Ok(_) => Err(ExtractionError::InvalidNamespace),
108        // If the namespace is missing, do not validate for backward compatibility
109        Err(ExtractionError::MissingNamespace) => Ok(()),
110        // If the namespace is present but invalid (e.g., not an integer or out of range), return an
111        // error.
112        Err(ExtractionError::InvalidNamespace) => Err(ExtractionError::InvalidNamespace),
113    }
114}
115
116/// Extract the contained key ID from a COSE header, if present.
117pub(super) fn extract_contained_key_id(header: &coset::Header) -> Result<Option<KeyId>, ()> {
118    let key_id_bytes = extract_bytes(header, CONTAINED_KEY_ID, "key id");
119
120    if let Ok(bytes) = key_id_bytes {
121        let key_id_array: [u8; KEY_ID_SIZE] = bytes.as_slice().try_into().map_err(|_| ())?;
122        Ok(Some(KeyId::from(key_id_array)))
123    } else {
124        Ok(None)
125    }
126}
127
#[cfg(test)]
mod tests {
    use ciborium::Value;

    use super::*;
    use crate::{cose::SAFE_OBJECT_NAMESPACE, safe::DataEnvelopeNamespace};

    /// Counts the entries in `header.rest` stored under integer label `label`.
    fn count_label(header: &coset::Header, label: i64) -> usize {
        header
            .rest
            .iter()
            .filter(|(existing_label, _)| {
                matches!(existing_label, coset::Label::Int(existing) if *existing == label)
            })
            .count()
    }

    /// Removes every entry stored under integer label `label` from `header.rest`.
    fn remove_label(header: &mut coset::Header, label: i64) {
        header.rest.retain(|(existing_label, _)| {
            !matches!(existing_label, coset::Label::Int(existing) if *existing == label)
        });
    }

    /// Builds a header carrying the `DataEnvelope` / `ExampleNamespace` namespace pair.
    fn header_with_example_namespaces() -> coset::Header {
        let mut header = coset::HeaderBuilder::new().build();
        set_safe_namespaces(
            &mut header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        header
    }

    fn extract_safe_namespaces<T: ContentNamespace>(
        header: &coset::Header,
    ) -> Result<(SafeObjectNamespace, T), ExtractionError> {
        let object_namespace = extract_safe_object_namespace(header)?;
        let content_namespace = extract_safe_content_namespace(header)?;

        Ok((object_namespace, content_namespace))
    }

    #[test]
    fn set_safe_namespaces_sets_both_namespace_labels() {
        let header = header_with_example_namespaces();

        let extracted = extract_safe_namespaces::<DataEnvelopeNamespace>(&header);
        assert!(matches!(
            extracted,
            Ok((
                SafeObjectNamespace::DataEnvelope,
                DataEnvelopeNamespace::ExampleNamespace
            ))
        ));
    }

    #[test]
    fn set_safe_namespaces_overwrites_existing_namespace_values() {
        let mut header = coset::HeaderBuilder::new()
            .value(SAFE_OBJECT_NAMESPACE, Value::from(999_i64))
            .value(SAFE_CONTENT_NAMESPACE, Value::from(999_i64))
            .build();

        set_safe_namespaces(
            &mut header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );

        assert_eq!(count_label(&header, SAFE_OBJECT_NAMESPACE), 1);
        assert_eq!(count_label(&header, SAFE_CONTENT_NAMESPACE), 1);
        assert!(matches!(
            extract_safe_namespaces::<DataEnvelopeNamespace>(&header),
            Ok((
                SafeObjectNamespace::DataEnvelope,
                DataEnvelopeNamespace::ExampleNamespace
            ))
        ));
    }

    #[test]
    fn extract_safe_namespaces_fails_when_namespace_missing() {
        let header = coset::HeaderBuilder::new().build();

        assert!(matches!(
            extract_safe_namespaces::<DataEnvelopeNamespace>(&header),
            Err(ExtractionError::MissingNamespace)
        ));
    }

    #[test]
    fn extract_safe_namespaces_fails_when_namespace_invalid() {
        let header = coset::HeaderBuilder::new()
            .value(
                SAFE_OBJECT_NAMESPACE,
                Value::from(SafeObjectNamespace::DataEnvelope as i64),
            )
            .value(SAFE_CONTENT_NAMESPACE, Value::from(999_i64))
            .build();

        assert!(matches!(
            extract_safe_namespaces::<DataEnvelopeNamespace>(&header),
            Err(ExtractionError::InvalidNamespace)
        ));
    }

    #[test]
    fn validate_safe_namespaces_allows_missing_labels_for_backwards_compat() {
        let header = coset::HeaderBuilder::new().build();

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn validate_safe_namespaces_accepts_matching_namespaces() {
        let header = header_with_example_namespaces();

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn validate_safe_namespaces_accepts_object_namespace_only() {
        let mut header = header_with_example_namespaces();
        remove_label(&mut header, SAFE_CONTENT_NAMESPACE);

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn validate_safe_namespaces_accepts_content_namespace_only() {
        let mut header = header_with_example_namespaces();
        remove_label(&mut header, SAFE_OBJECT_NAMESPACE);

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn validate_safe_namespaces_rejects_namespace_mismatch() {
        let header = header_with_example_namespaces();

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace2,
        );
        assert!(matches!(result, Err(ExtractionError::InvalidNamespace)));
    }

    #[test]
    fn validate_safe_namespaces_rejects_invalid_content_namespace_value() {
        let header = coset::HeaderBuilder::new()
            .value(
                SAFE_OBJECT_NAMESPACE,
                Value::from(SafeObjectNamespace::DataEnvelope as i64),
            )
            .value(SAFE_CONTENT_NAMESPACE, Value::from(999_i64))
            .build();

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(matches!(result, Err(ExtractionError::InvalidNamespace)));
    }

    #[test]
    fn extract_contained_key_id_returns_none_when_missing() {
        let header = coset::HeaderBuilder::new().build();

        assert!(matches!(extract_contained_key_id(&header), Ok(None)));
    }

    #[test]
    fn extract_contained_key_id_rejects_wrong_length() {
        let header = coset::HeaderBuilder::new()
            .value(CONTAINED_KEY_ID, Value::Bytes(vec![0_u8; KEY_ID_SIZE + 1]))
            .build();

        assert!(extract_contained_key_id(&header).is_err());
    }
}