bitwarden_encoding/
b64.rs

1use std::str::FromStr;
2
3use data_encoding::BASE64;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6#[cfg(feature = "wasm")]
7use tsify::Tsify;
8
9use crate::FromStrVisitor;
10
11/// Base64 encoded data
12///
13/// Is indifferent about padding when decoding, but always produces padding when encoding.
14#[cfg(feature = "wasm")]
15#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
16#[serde(into = "String")]
17#[derive(Tsify)]
18#[tsify(into_wasm_abi, from_wasm_abi)]
19pub struct B64(#[tsify(type = "string")] Vec<u8>);
20
21/// Base64 encoded data
22///
23/// Is indifferent about padding when decoding, but always produces padding when encoding.
24#[cfg(not(feature = "wasm"))]
25#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
26#[serde(into = "String")]
27pub struct B64(Vec<u8>);
28
29impl B64 {
30    /// Returns a byte slice of the inner vector.
31    pub fn as_bytes(&self) -> &[u8] {
32        &self.0
33    }
34
35    /// Returns the inner byte vector.
36    pub fn into_bytes(self) -> Vec<u8> {
37        self.0
38    }
39}
40
41// We manually implement this to handle both `String` and `&str`
42impl<'de> Deserialize<'de> for B64 {
43    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44    where
45        D: serde::Deserializer<'de>,
46    {
47        deserializer.deserialize_str(FromStrVisitor::new())
48    }
49}
50
51impl From<Vec<u8>> for B64 {
52    fn from(src: Vec<u8>) -> Self {
53        Self(src)
54    }
55}
56impl From<&[u8]> for B64 {
57    fn from(src: &[u8]) -> Self {
58        Self(src.to_vec())
59    }
60}
61
62impl From<B64> for Vec<u8> {
63    fn from(src: B64) -> Self {
64        src.0
65    }
66}
67
68impl From<B64> for String {
69    fn from(src: B64) -> Self {
70        String::from(&src)
71    }
72}
73
74impl From<&B64> for String {
75    fn from(src: &B64) -> Self {
76        BASE64.encode(&src.0)
77    }
78}
79
80impl std::fmt::Display for B64 {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        f.write_str(String::from(self).as_str())
83    }
84}
85
86/// An error returned when a string is not base64 decodable.
87#[derive(Debug, Error)]
88#[error("Data isn't base64 encoded")]
89pub struct NotB64EncodedError;
90
91const BASE64_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
92    symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
93    padding: None,
94    check_trailing_bits: false,
95};
96const BASE64_PADDING: &str = "=";
97
98impl TryFrom<String> for B64 {
99    type Error = NotB64EncodedError;
100
101    fn try_from(value: String) -> Result<Self, Self::Error> {
102        Self::try_from(value.as_str())
103    }
104}
105
106impl TryFrom<&str> for B64 {
107    type Error = NotB64EncodedError;
108
109    fn try_from(value: &str) -> Result<Self, Self::Error> {
110        let sane_string = value.trim_end_matches(BASE64_PADDING);
111        BASE64_PERMISSIVE
112            .decode(sane_string.as_bytes())
113            .map(Self)
114            .map_err(|_| NotB64EncodedError)
115    }
116}
117
118impl FromStr for B64 {
119    type Err = NotB64EncodedError;
120
121    fn from_str(s: &str) -> Result<Self, Self::Err> {
122        Self::try_from(s)
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_b64_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64 = B64::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64), data);
    }

    #[test]
    fn test_b64_from_slice() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.as_bytes(), data);
    }

    #[test]
    fn test_b64_encoding_with_padding() {
        let data = b"Hello, World!";
        let b64 = B64::from(data.as_slice());
        let encoded = String::from(&b64);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
        assert!(encoded.contains('='));
    }

    #[test]
    fn test_b64_decoding_with_padding() {
        let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
        let b64 = B64::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_decoding_without_padding() {
        let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
        let b64 = B64::try_from(encoded_without_padding).unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let b64 = B64::from(original.as_slice());
        let encoded = String::from(&b64);
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64_round_trip_without_padding() {
        let original = b"Test data";
        let b64 = B64::from(original.as_slice());
        let encoded = String::from(&b64);
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64_display() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.to_string(), "SGVsbG8=");
    }

    #[test]
    fn test_b64_invalid_encoding() {
        let invalid_b64 = "This is not base64!@#$";
        let result = B64::try_from(invalid_b64);
        assert!(result.is_err());
    }

    #[test]
    fn test_b64_empty_string() {
        let empty = "";
        let b64 = B64::try_from(empty).unwrap();
        assert_eq!(b64.as_bytes().len(), 0);
    }

    #[test]
    fn test_b64_padding_removal() {
        // Every trailing `=` is stripped before decoding, so surplus padding is tolerated...
        let over_padded = "SGVsbG8sIFdvcmxkIQ====";
        let b64 = B64::try_from(over_padded).unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");

        // ...and input that is nothing but padding decodes to no bytes at all.
        let only_padding = "==";
        let b64 = B64::try_from(only_padding).unwrap();
        assert!(b64.as_bytes().is_empty());
    }

    #[test]
    fn test_b64_rejects_inner_padding() {
        // Only *trailing* padding is stripped; an `=` elsewhere is not a base64 symbol.
        assert!(B64::try_from("SGVs=bG8").is_err());
    }

    #[test]
    fn test_b64_ignores_nonzero_trailing_bits() {
        // The permissive decoder does not check trailing bits, so this non-canonical encoding
        // of "Hello" (canonical form: "SGVsbG8=") is still accepted.
        let b64 = B64::try_from("SGVsbG9=").unwrap();
        assert_eq!(b64.as_bytes(), b"Hello");
    }

    #[test]
    fn test_b64_serialization() {
        let data = b"serialization test";
        let b64 = B64::from(data.as_slice());

        let serialized = serde_json::to_string(&b64).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64 = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64.as_bytes(), deserialized.as_bytes());
    }

    #[test]
    fn test_not_b64_encoded_error_display() {
        let error = NotB64EncodedError;
        assert_eq!(error.to_string(), "Data isn't base64 encoded");
    }

    #[test]
    fn test_b64_from_str() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let b64: B64 = encoded.parse().unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_eq_and_hash() {
        let data1 = b"test data";
        let data2 = b"test data";
        let data3 = b"different data";

        let b64_1 = B64::from(data1.as_slice());
        let b64_2 = B64::from(data2.as_slice());
        let b64_3 = B64::from(data3.as_slice());

        assert_eq!(b64_1, b64_2);
        assert_ne!(b64_1, b64_3);

        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();

        b64_1.hash(&mut hasher1);
        b64_2.hash(&mut hasher2);

        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}