bitwarden_encoding/
b64.rs1use std::str::FromStr;
2
3use data_encoding::BASE64;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6#[cfg(feature = "wasm")]
7use tsify::Tsify;
8
9use crate::FromStrVisitor;
10
11#[cfg(feature = "wasm")]
15#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
16#[serde(into = "String")]
17#[derive(Tsify)]
18#[tsify(into_wasm_abi, from_wasm_abi)]
19pub struct B64(#[tsify(type = "string")] Vec<u8>);
20
21#[cfg(not(feature = "wasm"))]
25#[derive(Debug, Serialize, Clone, Hash, PartialEq, Eq)]
26#[serde(into = "String")]
27pub struct B64(Vec<u8>);
28
29impl B64 {
30 pub fn as_bytes(&self) -> &[u8] {
32 &self.0
33 }
34
35 pub fn into_bytes(self) -> Vec<u8> {
37 self.0
38 }
39}
40
41impl<'de> Deserialize<'de> for B64 {
43 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44 where
45 D: serde::Deserializer<'de>,
46 {
47 deserializer.deserialize_str(FromStrVisitor::new())
48 }
49}
50
51impl From<Vec<u8>> for B64 {
52 fn from(src: Vec<u8>) -> Self {
53 Self(src)
54 }
55}
56impl From<&[u8]> for B64 {
57 fn from(src: &[u8]) -> Self {
58 Self(src.to_vec())
59 }
60}
61
62impl From<B64> for Vec<u8> {
63 fn from(src: B64) -> Self {
64 src.0
65 }
66}
67
68impl From<B64> for String {
69 fn from(src: B64) -> Self {
70 String::from(&src)
71 }
72}
73
74impl From<&B64> for String {
75 fn from(src: &B64) -> Self {
76 BASE64.encode(&src.0)
77 }
78}
79
80impl std::fmt::Display for B64 {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.write_str(String::from(self).as_str())
83 }
84}
85
86#[derive(Debug, Error)]
88#[error("Data isn't base64 encoded")]
89pub struct NotB64EncodedError;
90
91const BASE64_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
92 symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
93 padding: None,
94 check_trailing_bits: false,
95};
96const BASE64_PADDING: &str = "=";
97
98impl TryFrom<String> for B64 {
99 type Error = NotB64EncodedError;
100
101 fn try_from(value: String) -> Result<Self, Self::Error> {
102 Self::try_from(value.as_str())
103 }
104}
105
106impl TryFrom<&str> for B64 {
107 type Error = NotB64EncodedError;
108
109 fn try_from(value: &str) -> Result<Self, Self::Error> {
110 let sane_string = value.trim_end_matches(BASE64_PADDING);
111 BASE64_PERMISSIVE
112 .decode(sane_string.as_bytes())
113 .map(Self)
114 .map_err(|_| NotB64EncodedError)
115 }
116}
117
118impl FromStr for B64 {
119 type Err = NotB64EncodedError;
120
121 fn from_str(s: &str) -> Result<Self, Self::Err> {
122 Self::try_from(s)
123 }
124}
125
#[cfg(test)]
mod tests {
    use std::{
        collections::hash_map::DefaultHasher,
        hash::{Hash, Hasher},
    };

    use super::*;

    /// Hashes `value` with a freshly created `DefaultHasher` so two values can be compared.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_b64_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64 = B64::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64), data);
    }

    #[test]
    fn test_b64_from_slice() {
        let data = b"Hello";
        let b64 = B64::from(data.as_slice());
        assert_eq!(b64.as_bytes(), data);
    }

    #[test]
    fn test_b64_into_bytes() {
        let b64 = B64::from(b"Hello".as_slice());
        assert_eq!(b64.into_bytes(), b"Hello".to_vec());
    }

    #[test]
    fn test_b64_into_string() {
        let b64 = B64::from(b"Hello".as_slice());
        assert_eq!(String::from(b64), "SGVsbG8=");
    }

    #[test]
    fn test_b64_encoding_with_padding() {
        let b64 = B64::from(b"Hello, World!".as_slice());
        assert_eq!(String::from(&b64), "SGVsbG8sIFdvcmxkIQ==");
    }

    #[test]
    fn test_b64_decoding_with_padding() {
        let b64 = B64::try_from("SGVsbG8sIFdvcmxkIQ==").unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_decoding_without_padding() {
        let b64 = B64::try_from("SGVsbG8sIFdvcmxkIQ").unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_padding_removal() {
        // Every trailing '=' is stripped before decoding, so surplus padding is tolerated.
        let b64 = B64::try_from("SGVsbG8sIFdvcmxkIQ====").unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_non_canonical_trailing_bits_accepted() {
        // "SGVsbG8=" is the canonical encoding of "Hello". Changing the last symbol to '9' only
        // sets unused trailing bits, which the permissive decoder ignores.
        let b64 = B64::try_from("SGVsbG9=").unwrap();
        assert_eq!(b64.as_bytes(), b"Hello");
    }

    #[test]
    fn test_b64_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let encoded = String::from(&B64::from(original.as_slice()));
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64_round_trip_without_padding() {
        let original = b"Test data";
        let encoded = String::from(&B64::from(original.as_slice()));
        let decoded = B64::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64_display() {
        let b64 = B64::from(b"Hello".as_slice());
        assert_eq!(b64.to_string(), "SGVsbG8=");
    }

    #[test]
    fn test_b64_invalid_encoding() {
        assert!(B64::try_from("This is not base64!@#$").is_err());
    }

    #[test]
    fn test_b64_invalid_length_rejected() {
        // A length of 1 mod 4 symbols can never be produced by base64.
        assert!(B64::try_from("SGVsb").is_err());
    }

    #[test]
    fn test_b64_url_safe_alphabet_rejected() {
        // Only the standard '+' and '/' symbols are accepted, not the URL-safe '-' and '_'.
        assert!(B64::try_from("ab-_").is_err());
        assert!(B64::try_from("ab+/").is_ok());
    }

    #[test]
    fn test_b64_empty_string() {
        let b64 = B64::try_from("").unwrap();
        assert!(b64.as_bytes().is_empty());
    }

    #[test]
    fn test_b64_serialization() {
        let b64 = B64::from(b"serialization test".as_slice());

        let serialized = serde_json::to_string(&b64).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64 = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64.as_bytes(), deserialized.as_bytes());
    }

    #[test]
    fn test_b64_deserialization_accepts_unpadded() {
        let b64: B64 = serde_json::from_str("\"SGVsbG8\"").unwrap();
        assert_eq!(b64.as_bytes(), b"Hello");
    }

    #[test]
    fn test_b64_deserialization_rejects_invalid() {
        assert!(serde_json::from_str::<B64>("\"not base64!\"").is_err());
        // Non-string JSON is rejected outright because only strings are requested.
        assert!(serde_json::from_str::<B64>("42").is_err());
    }

    #[test]
    fn test_not_b64_encoded_error_display() {
        assert_eq!(NotB64EncodedError.to_string(), "Data isn't base64 encoded");
    }

    #[test]
    fn test_b64_from_str() {
        let b64: B64 = "SGVsbG8sIFdvcmxkIQ==".parse().unwrap();
        assert_eq!(b64.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64_eq_and_hash() {
        let b64_1 = B64::from(b"test data".as_slice());
        let b64_2 = B64::from(b"test data".as_slice());
        let b64_3 = B64::from(b"different data".as_slice());

        assert_eq!(b64_1, b64_2);
        assert_ne!(b64_1, b64_3);
        assert_eq!(hash_of(&b64_1), hash_of(&b64_2));
    }
}