bitwarden_encoding/
b64.rs

1use std::str::FromStr;
2
3use data_encoding::BASE64;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6#[cfg(feature = "wasm")]
7use tsify::Tsify;
8
9use crate::FromStrVisitor;
10
11/// Base64 encoded data
12///
13/// Is indifferent about padding when decoding, but always produces padding when encoding.
14#[cfg(feature = "wasm")]
15#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
16#[serde(into = "String")]
17#[derive(Tsify)]
18#[tsify(into_wasm_abi, from_wasm_abi)]
19pub struct B64(#[tsify(type = "String")] Vec<u8>);
20
21/// Base64 encoded data
22///
23/// Is indifferent about padding when decoding, but always produces padding when encoding.
24#[cfg(not(feature = "wasm"))]
25#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
26#[serde(into = "String")]
27pub struct B64(Vec<u8>);
28
29// We manually implement this to handle both `String` and `&str`
30impl<'de> Deserialize<'de> for B64 {
31    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
32    where
33        D: serde::Deserializer<'de>,
34    {
35        deserializer.deserialize_str(FromStrVisitor::new())
36    }
37}
38
39impl From<Vec<u8>> for B64 {
40    fn from(src: Vec<u8>) -> Self {
41        Self(src)
42    }
43}
44impl From<&[u8]> for B64 {
45    fn from(src: &[u8]) -> Self {
46        Self(src.to_vec())
47    }
48}
49
50impl From<B64> for Vec<u8> {
51    fn from(src: B64) -> Self {
52        src.0
53    }
54}
55
56impl AsRef<[u8]> for B64 {
57    fn as_ref(&self) -> &[u8] {
58        &self.0
59    }
60}
61
62impl From<B64> for String {
63    fn from(src: B64) -> Self {
64        String::from(&src)
65    }
66}
67
68impl From<&B64> for String {
69    fn from(src: &B64) -> Self {
70        BASE64.encode(&src.0)
71    }
72}
73
74impl std::fmt::Display for B64 {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.write_str(String::from(self).as_str())
77    }
78}
79
80/// An error returned when a string is not base64 decodable.
81#[derive(Debug, Error)]
82#[error("Data isn't base64 encoded")]
83pub struct NotB64Encoded;
84
85const BASE64_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
86    symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
87    padding: None,
88    check_trailing_bits: false,
89};
90const BASE64_PADDING: &str = "=";
91
92impl TryFrom<String> for B64 {
93    type Error = NotB64Encoded;
94
95    fn try_from(value: String) -> Result<Self, Self::Error> {
96        let sane_string = value.trim_end_matches(BASE64_PADDING);
97        BASE64_PERMISSIVE
98            .decode(sane_string.as_bytes())
99            .map(Self)
100            .map_err(|_| NotB64Encoded)
101    }
102}
103
104impl TryFrom<&str> for B64 {
105    type Error = NotB64Encoded;
106
107    fn try_from(value: &str) -> Result<Self, Self::Error> {
108        let sane_string = value.trim_end_matches(BASE64_PADDING);
109        BASE64_PERMISSIVE
110            .decode(sane_string.as_bytes())
111            .map(Self)
112            .map_err(|_| NotB64Encoded)
113    }
114}
115
116impl FromStr for B64 {
117    type Err = NotB64Encoded;
118
119    fn from_str(s: &str) -> Result<Self, Self::Err> {
120        Self::try_from(s)
121    }
122}
123
#[cfg(test)]
mod tests {
    use std::{
        collections::hash_map::DefaultHasher,
        hash::{Hash, Hasher},
    };

    use super::*;

    #[test]
    fn test_b64_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64 = B64::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64), data);
    }

    #[test]
    fn test_b64_from_slice() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.as_ref(), data);
    }

    #[test]
    fn test_b64_encoding_with_padding() {
        let data = b"Hello, World!";
        let b64 = B64::from(data.as_slice());
        let encoded = String::from(&b64);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
        assert!(encoded.contains('='));
    }

    #[test]
    fn test_b64_string_conversions_agree() {
        // The owned conversion, borrowed conversion and `Display` must all produce the same text.
        let b64 = B64::from(b"Hello".as_slice());
        assert_eq!(String::from(&b64), "SGVsbG8=");
        assert_eq!(String::from(b64.clone()), b64.to_string());
    }

    #[test]
    fn test_b64_decoding_with_padding() {
        let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
        let b64 = B64::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64.as_ref(), b"Hello, World!");
    }

    #[test]
    fn test_b64_decoding_without_padding() {
        let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
        let b64 = B64::try_from(encoded_without_padding).unwrap();
        assert_eq!(b64.as_ref(), b"Hello, World!");
    }

    #[test]
    fn test_b64_try_from_string() {
        // The owned-string conversion must behave exactly like the borrowed one.
        let b64 = B64::try_from(String::from("SGVsbG8sIFdvcmxkIQ==")).unwrap();
        assert_eq!(b64.as_ref(), b"Hello, World!");
        assert!(B64::try_from(String::from("This is not base64!@#$")).is_err());
    }

    #[test]
    fn test_b64_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let b64 = B64::from(original.as_slice());
        let encoded = String::from(&b64);
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_ref(), original);
    }

    #[test]
    fn test_b64_round_trip_without_padding() {
        let original = b"Test data";
        let b64 = B64::from(original.as_slice());
        let encoded = String::from(&b64);
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_ref(), original);
    }

    #[test]
    fn test_b64_display() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.to_string(), "SGVsbG8=");
    }

    #[test]
    fn test_b64_invalid_encoding() {
        let invalid_b64 = "This is not base64!@#$";
        let result = B64::try_from(invalid_b64);
        assert!(result.is_err());
    }

    #[test]
    fn test_b64_rejects_url_safe_alphabet() {
        // `-` and `_` belong to the URL-safe alphabet, not the standard one.
        assert!(B64::try_from("ab-_").is_err());
    }

    #[test]
    fn test_b64_rejects_impossible_length() {
        // Five symbols (30 bits) cannot encode a whole number of bytes.
        assert!(B64::try_from("SGVsb").is_err());
    }

    #[test]
    fn test_b64_empty_string() {
        let empty = "";
        let b64 = B64::try_from(empty).unwrap();
        assert_eq!(b64.as_ref().len(), 0);
    }

    #[test]
    fn test_b64_padding_removal() {
        // Padded and unpadded spellings of the same data decode identically.
        let padded = B64::try_from("SGVsbG8=").unwrap();
        let unpadded = B64::try_from("SGVsbG8").unwrap();
        assert_eq!(padded, unpadded);
        assert_eq!(padded.as_ref(), b"Hello");
    }

    #[test]
    fn test_b64_serialization() {
        let data = b"serialization test";
        let b64 = B64::from(data.as_slice());

        let serialized = serde_json::to_string(&b64).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64 = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64.as_ref(), deserialized.as_ref());
    }

    #[test]
    fn test_b64_deserialization_rejects_invalid_input() {
        let result: Result<B64, _> = serde_json::from_str("\"not base64!\"");
        assert!(result.is_err());
    }

    #[test]
    fn test_not_b64_encoded_error_display() {
        let error = NotB64Encoded;
        assert_eq!(error.to_string(), "Data isn't base64 encoded");
    }

    #[test]
    fn test_b64_from_str() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let b64: B64 = encoded.parse().unwrap();
        assert_eq!(b64.as_ref(), b"Hello, World!");
    }

    #[test]
    fn test_b64_eq_and_hash() {
        let data1 = b"test data";
        let data2 = b"test data";
        let data3 = b"different data";

        let b64_1 = B64::from(data1.as_slice());
        let b64_2 = B64::from(data2.as_slice());
        let b64_3 = B64::from(data3.as_slice());

        assert_eq!(b64_1, b64_2);
        assert_ne!(b64_1, b64_3);

        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();

        b64_1.hash(&mut hasher1);
        b64_2.hash(&mut hasher2);

        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}