bitwarden_encoding/
b64.rs1use std::str::FromStr;
2
3use data_encoding::BASE64;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6#[cfg(feature = "wasm")]
7use tsify::Tsify;
8
9use crate::FromStrVisitor;
10
11#[cfg(feature = "wasm")]
15#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
16#[serde(into = "String")]
17#[derive(Tsify)]
18#[tsify(into_wasm_abi, from_wasm_abi)]
19pub struct B64(#[tsify(type = "String")] Vec<u8>);
20
21#[cfg(not(feature = "wasm"))]
25#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
26#[serde(into = "String")]
27pub struct B64(Vec<u8>);
28
29impl<'de> Deserialize<'de> for B64 {
31 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
32 where
33 D: serde::Deserializer<'de>,
34 {
35 deserializer.deserialize_str(FromStrVisitor::new())
36 }
37}
38
39impl From<Vec<u8>> for B64 {
40 fn from(src: Vec<u8>) -> Self {
41 Self(src)
42 }
43}
44impl From<&[u8]> for B64 {
45 fn from(src: &[u8]) -> Self {
46 Self(src.to_vec())
47 }
48}
49
50impl From<B64> for Vec<u8> {
51 fn from(src: B64) -> Self {
52 src.0
53 }
54}
55
56impl AsRef<[u8]> for B64 {
57 fn as_ref(&self) -> &[u8] {
58 &self.0
59 }
60}
61
62impl From<B64> for String {
63 fn from(src: B64) -> Self {
64 String::from(&src)
65 }
66}
67
68impl From<&B64> for String {
69 fn from(src: &B64) -> Self {
70 BASE64.encode(&src.0)
71 }
72}
73
74impl std::fmt::Display for B64 {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.write_str(String::from(self).as_str())
77 }
78}
79
80#[derive(Debug, Error)]
82#[error("Data isn't base64 encoded")]
83pub struct NotB64Encoded;
84
85const BASE64_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
86 symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
87 padding: None,
88 check_trailing_bits: false,
89};
90const BASE64_PADDING: &str = "=";
91
92impl TryFrom<String> for B64 {
93 type Error = NotB64Encoded;
94
95 fn try_from(value: String) -> Result<Self, Self::Error> {
96 let sane_string = value.trim_end_matches(BASE64_PADDING);
97 BASE64_PERMISSIVE
98 .decode(sane_string.as_bytes())
99 .map(Self)
100 .map_err(|_| NotB64Encoded)
101 }
102}
103
104impl TryFrom<&str> for B64 {
105 type Error = NotB64Encoded;
106
107 fn try_from(value: &str) -> Result<Self, Self::Error> {
108 let sane_string = value.trim_end_matches(BASE64_PADDING);
109 BASE64_PERMISSIVE
110 .decode(sane_string.as_bytes())
111 .map(Self)
112 .map_err(|_| NotB64Encoded)
113 }
114}
115
116impl FromStr for B64 {
117 type Err = NotB64Encoded;
118
119 fn from_str(s: &str) -> Result<Self, Self::Err> {
120 Self::try_from(s)
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_b64_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64 = B64::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64), data);
    }

    #[test]
    fn test_b64_from_slice() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.as_ref(), data);
    }

    #[test]
    fn test_b64_encoding_with_padding() {
        let data = b"Hello, World!";
        let b64 = B64::from(data.as_slice());
        let encoded = String::from(&b64);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
        assert!(encoded.contains('='));
    }

    #[test]
    fn test_b64_decoding_with_padding() {
        let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
        let b64 = B64::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64.as_ref(), b"Hello, World!");
    }

    #[test]
    fn test_b64_decoding_without_padding() {
        let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
        let b64 = B64::try_from(encoded_without_padding).unwrap();
        assert_eq!(b64.as_ref(), b"Hello, World!");
    }

    #[test]
    fn test_b64_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let b64 = B64::from(original.as_slice());
        let encoded = String::from(&b64);
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_ref(), original);
    }

    #[test]
    fn test_b64_round_trip_without_padding() {
        let original = b"Test data";
        let b64 = B64::from(original.as_slice());
        let encoded = String::from(&b64);
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_ref(), original);
    }

    #[test]
    fn test_b64_display() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.to_string(), "SGVsbG8=");
    }

    #[test]
    fn test_b64_display_matches_string_conversion() {
        let b64 = B64::from(b"any bytes \x00\xff".as_slice());
        assert_eq!(b64.to_string(), String::from(&b64));
    }

    #[test]
    fn test_b64_invalid_encoding() {
        let invalid_b64 = "This is not base64!@#$";
        let result = B64::try_from(invalid_b64);
        assert!(result.is_err());
    }

    #[test]
    fn test_b64_empty_string() {
        let empty = "";
        let b64 = B64::try_from(empty).unwrap();
        assert_eq!(b64.as_ref().len(), 0);
    }

    /// Decoding strips *all* trailing `=`, so excess padding is tolerated.
    #[test]
    fn test_b64_padding_removal() {
        let over_padded = "SGVsbG8=====";
        let b64 = B64::try_from(over_padded).unwrap();
        assert_eq!(b64.as_ref(), b"Hello");
    }

    /// `check_trailing_bits: false` accepts non-canonical final symbols
    /// (`9` instead of the canonical `8` differs only in the discarded low bits).
    #[test]
    fn test_b64_non_canonical_trailing_bits_accepted() {
        let b64 = B64::try_from("SGVsbG9").unwrap();
        assert_eq!(b64.as_ref(), b"Hello");
    }

    #[test]
    fn test_b64_try_from_string() {
        let b64 = B64::try_from(String::from("SGVsbG8sIFdvcmxkIQ==")).unwrap();
        assert_eq!(b64.as_ref(), b"Hello, World!");

        assert!(B64::try_from(String::from("This is not base64!@#$")).is_err());
    }

    #[test]
    fn test_b64_serialization() {
        let data = b"serialization test";
        let b64 = B64::from(data.as_slice());

        let serialized = serde_json::to_string(&b64).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64 = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64.as_ref(), deserialized.as_ref());
    }

    #[test]
    fn test_b64_deserialize_rejects_invalid() {
        let result = serde_json::from_str::<B64>("\"not base64!\"");
        assert!(result.is_err());
    }

    #[test]
    fn test_not_b64_encoded_error_display() {
        let error = NotB64Encoded;
        assert_eq!(error.to_string(), "Data isn't base64 encoded");
    }

    #[test]
    fn test_b64_from_str() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let b64: B64 = encoded.parse().unwrap();
        assert_eq!(b64.as_ref(), b"Hello, World!");
    }

    #[test]
    fn test_b64_from_str_invalid() {
        assert!("@@@@".parse::<B64>().is_err());
    }

    #[test]
    fn test_b64_eq_and_hash() {
        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        let data1 = b"test data";
        let data2 = b"test data";
        let data3 = b"different data";

        let b64_1 = B64::from(data1.as_slice());
        let b64_2 = B64::from(data2.as_slice());
        let b64_3 = B64::from(data3.as_slice());

        assert_eq!(b64_1, b64_2);
        assert_ne!(b64_1, b64_3);

        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();

        b64_1.hash(&mut hasher1);
        b64_2.hash(&mut hasher2);

        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}