bitwarden_encoding/
b64.rs

1use std::str::FromStr;
2
3use data_encoding::BASE64;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6#[cfg(feature = "wasm")]
7use tsify::Tsify;
8
9use crate::FromStrVisitor;
10
11/// Base64 encoded data
12///
13/// Is indifferent about padding when decoding, but always produces padding when encoding.
14#[cfg(feature = "wasm")]
15#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
16#[serde(into = "String")]
17#[derive(Tsify)]
18#[tsify(into_wasm_abi, from_wasm_abi)]
19pub struct B64(#[tsify(type = "string")] Vec<u8>);
20
21/// Base64 encoded data
22///
23/// Is indifferent about padding when decoding, but always produces padding when encoding.
24#[cfg(not(feature = "wasm"))]
25#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
26#[serde(into = "String")]
27pub struct B64(Vec<u8>);
28
29impl B64 {
30    /// Returns a byte slice of the inner vector.
31    pub fn as_bytes(&self) -> &[u8] {
32        &self.0
33    }
34}
35
36// We manually implement this to handle both `String` and `&str`
37impl<'de> Deserialize<'de> for B64 {
38    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
39    where
40        D: serde::Deserializer<'de>,
41    {
42        deserializer.deserialize_str(FromStrVisitor::new())
43    }
44}
45
46impl From<Vec<u8>> for B64 {
47    fn from(src: Vec<u8>) -> Self {
48        Self(src)
49    }
50}
51impl From<&[u8]> for B64 {
52    fn from(src: &[u8]) -> Self {
53        Self(src.to_vec())
54    }
55}
56
57impl From<B64> for Vec<u8> {
58    fn from(src: B64) -> Self {
59        src.0
60    }
61}
62
63impl From<B64> for String {
64    fn from(src: B64) -> Self {
65        String::from(&src)
66    }
67}
68
69impl From<&B64> for String {
70    fn from(src: &B64) -> Self {
71        BASE64.encode(&src.0)
72    }
73}
74
75impl std::fmt::Display for B64 {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.write_str(String::from(self).as_str())
78    }
79}
80
81/// An error returned when a string is not base64 decodable.
82#[derive(Debug, Error)]
83#[error("Data isn't base64 encoded")]
84pub struct NotB64Encoded;
85
86const BASE64_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
87    symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
88    padding: None,
89    check_trailing_bits: false,
90};
91const BASE64_PADDING: &str = "=";
92
93impl TryFrom<String> for B64 {
94    type Error = NotB64Encoded;
95
96    fn try_from(value: String) -> Result<Self, Self::Error> {
97        Self::try_from(value.as_str())
98    }
99}
100
101impl TryFrom<&str> for B64 {
102    type Error = NotB64Encoded;
103
104    fn try_from(value: &str) -> Result<Self, Self::Error> {
105        let sane_string = value.trim_end_matches(BASE64_PADDING);
106        BASE64_PERMISSIVE
107            .decode(sane_string.as_bytes())
108            .map(Self)
109            .map_err(|_| NotB64Encoded)
110    }
111}
112
113impl FromStr for B64 {
114    type Err = NotB64Encoded;
115
116    fn from_str(s: &str) -> Result<Self, Self::Err> {
117        Self::try_from(s)
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_b64_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64 = B64::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64), data);
    }

    #[test]
    fn test_b64_from_slice() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.as_bytes(), data);
    }

    #[test]
    fn test_b64_encoding_with_padding() {
        let data = b"Hello, World!";
        let b64 = B64::from(data.as_slice());
        let encoded = String::from(&b64);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
        assert!(encoded.contains('='));
    }

    #[test]
    fn test_b64_decoding_with_padding() {
        let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
        let b64 = B64::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_decoding_without_padding() {
        let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
        let b64 = B64::try_from(encoded_without_padding).unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let b64 = B64::from(original.as_slice());
        let encoded = String::from(&b64);
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64_round_trip_without_padding() {
        let original = b"Test data";
        let b64 = B64::from(original.as_slice());
        let encoded = String::from(&b64);
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64_display() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.to_string(), "SGVsbG8=");
    }

    #[test]
    fn test_b64_invalid_encoding() {
        let invalid_b64 = "This is not base64!@#$";
        let result = B64::try_from(invalid_b64);
        assert!(result.is_err());
    }

    #[test]
    fn test_b64_empty_string() {
        let empty = "";
        let b64 = B64::try_from(empty).unwrap();
        assert_eq!(b64.as_bytes().len(), 0);
    }

    /// Padding is optional for both the one-`=` and the two-`=` case, and the padded and
    /// unpadded spellings decode to the same value.
    #[test]
    fn test_b64_padding_removal() {
        let cases = [
            ("SGVsbG8=", "SGVsbG8", b"Hello".as_slice()),
            ("SGVsbG8sIFdvcmxkIQ==", "SGVsbG8sIFdvcmxkIQ", b"Hello, World!".as_slice()),
        ];
        for (padded, unpadded, expected) in cases {
            let from_padded = B64::try_from(padded).unwrap();
            let from_unpadded = B64::try_from(unpadded).unwrap();
            assert_eq!(from_padded.as_bytes(), expected);
            assert_eq!(from_padded, from_unpadded);
        }
    }

    /// "SGVsbG9=" differs from the canonical "SGVsbG8=" only in the unused trailing bits of
    /// the last symbol; the permissive decoder ignores them.
    #[test]
    fn test_b64_non_canonical_trailing_bits_accepted() {
        let b64 = B64::try_from("SGVsbG9=").unwrap();
        assert_eq!(b64.as_bytes(), b"Hello");
    }

    #[test]
    fn test_b64_serialization() {
        let data = b"serialization test";
        let b64 = B64::from(data.as_slice());

        let serialized = serde_json::to_string(&b64).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64 = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64.as_bytes(), deserialized.as_bytes());
    }

    #[test]
    fn test_not_b64_encoded_error_display() {
        let error = NotB64Encoded;
        assert_eq!(error.to_string(), "Data isn't base64 encoded");
    }

    #[test]
    fn test_b64_from_str() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let b64: B64 = encoded.parse().unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_eq_and_hash() {
        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        let data1 = b"test data";
        let data2 = b"test data";
        let data3 = b"different data";

        let b64_1 = B64::from(data1.as_slice());
        let b64_2 = B64::from(data2.as_slice());
        let b64_3 = B64::from(data3.as_slice());

        assert_eq!(b64_1, b64_2);
        assert_ne!(b64_1, b64_3);

        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();

        b64_1.hash(&mut hasher1);
        b64_2.hash(&mut hasher2);

        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}