bitwarden_encoding/
b64url.rs

1use std::str::FromStr;
2
3use data_encoding::BASE64URL;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7/// Base64URL encoded data
8///
9/// Is indifferent about padding when decoding, but always produces padding when encoding.
10#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
11#[serde(try_from = "&str", into = "String")]
12pub struct B64Url(Vec<u8>);
13
14impl From<Vec<u8>> for B64Url {
15    fn from(src: Vec<u8>) -> Self {
16        Self(src)
17    }
18}
19impl From<&[u8]> for B64Url {
20    fn from(src: &[u8]) -> Self {
21        Self(src.to_vec())
22    }
23}
24
25impl From<B64Url> for Vec<u8> {
26    fn from(src: B64Url) -> Self {
27        src.0
28    }
29}
30
31impl AsRef<[u8]> for B64Url {
32    fn as_ref(&self) -> &[u8] {
33        &self.0
34    }
35}
36
37impl From<B64Url> for String {
38    fn from(src: B64Url) -> Self {
39        String::from(&src)
40    }
41}
42
43impl From<&B64Url> for String {
44    fn from(src: &B64Url) -> Self {
45        BASE64URL.encode(&src.0)
46    }
47}
48
49impl std::fmt::Display for B64Url {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.write_str(String::from(self).as_str())
52    }
53}
54
55/// An error returned when a string is not base64 decodable.
56#[derive(Debug, Error)]
57#[error("Data isn't base64url encoded")]
58pub struct NotB64UrlEncoded;
59
60const BASE64URL_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
61    symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
62    padding: None,
63    check_trailing_bits: false,
64};
65const BASE64URL_PADDING: &str = "=";
66
67impl TryFrom<&str> for B64Url {
68    type Error = NotB64UrlEncoded;
69
70    fn try_from(value: &str) -> Result<Self, Self::Error> {
71        let sane_string = value.trim_end_matches(BASE64URL_PADDING);
72        BASE64URL_PERMISSIVE
73            .decode(sane_string.as_bytes())
74            .map(Self)
75            .map_err(|_| NotB64UrlEncoded)
76    }
77}
78
79impl FromStr for B64Url {
80    type Err = NotB64UrlEncoded;
81
82    fn from_str(s: &str) -> Result<Self, Self::Err> {
83        Self::try_from(s)
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// An owned vector survives a round trip through the wrapper unchanged.
    #[test]
    fn test_b64url_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64url = B64Url::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64url), data);
    }

    /// A borrowed slice is copied into the wrapper.
    #[test]
    fn test_b64url_from_slice() {
        let data = b"Hello";
        let b64url = B64Url::from(data.as_slice());
        assert_eq!(b64url.as_ref(), data);
    }

    /// Encoding always emits padding when the input length requires it.
    #[test]
    fn test_b64url_encoding_with_padding() {
        let data = b"Hello, World!";
        let b64url = B64Url::from(data.as_slice());
        let encoded = String::from(&b64url);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ==");
        assert!(encoded.contains('='));
    }

    #[test]
    fn test_b64url_decoding_with_padding() {
        let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
        let b64url = B64Url::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64url.as_ref(), b"Hello, World!");
    }

    #[test]
    fn test_b64url_decoding_without_padding() {
        let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
        let b64url = B64Url::try_from(encoded_without_padding).unwrap();
        assert_eq!(b64url.as_ref(), b"Hello, World!");
    }

    /// Round trip for a 32-byte input, whose encoding must end in padding.
    #[test]
    fn test_b64url_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let b64url = B64Url::from(original.as_slice());
        let encoded = String::from(&b64url);
        // Guard the premise of the test: this input really does produce padding.
        assert!(encoded.ends_with('='));
        let decoded = B64Url::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_ref(), original);
    }

    /// Round trip for a 9-byte input, whose encoding needs no padding.
    #[test]
    fn test_b64url_round_trip_without_padding() {
        let original = b"Test data";
        let b64url = B64Url::from(original.as_slice());
        let encoded = String::from(&b64url);
        // Guard the premise of the test: this input produces no padding at all.
        assert!(!encoded.contains('='));
        let decoded = B64Url::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_ref(), original);
    }

    #[test]
    fn test_b64url_display() {
        let data = b"Hello";
        let b64url = B64Url::from(data.as_slice());
        assert_eq!(b64url.to_string(), "SGVsbG8=");
    }

    #[test]
    fn test_b64url_invalid_encoding() {
        let invalid_b64url = "This is not base64url!@#$";
        let result = B64Url::try_from(invalid_b64url);
        assert!(result.is_err());
    }

    #[test]
    fn test_b64url_empty_string() {
        let empty = "";
        let b64url = B64Url::try_from(empty).unwrap();
        assert_eq!(b64url.as_ref().len(), 0);
    }

    /// Padded and unpadded spellings of the same data decode to equal values.
    ///
    /// Uses a single-`=` input so it covers a case distinct from
    /// `test_b64url_decoding_with_padding` (which uses `==`).
    #[test]
    fn test_b64url_padding_removal() {
        let padded = B64Url::try_from("SGVsbG8=").unwrap();
        let unpadded = B64Url::try_from("SGVsbG8").unwrap();
        assert_eq!(padded.as_ref(), b"Hello");
        assert_eq!(padded, unpadded);
    }

    #[test]
    fn test_b64url_serialization() {
        let data = b"serialization test";
        let b64url = B64Url::from(data.as_slice());

        let serialized = serde_json::to_string(&b64url).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64Url = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64url.as_ref(), deserialized.as_ref());
    }

    #[test]
    fn test_not_b64url_encoded_error_display() {
        let error = NotB64UrlEncoded;
        assert_eq!(error.to_string(), "Data isn't base64url encoded");
    }

    #[test]
    fn test_b64url_from_str() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let b64url: B64Url = encoded.parse().unwrap();
        assert_eq!(b64url.as_ref(), b"Hello, World!");
    }

    /// Equal contents compare equal and hash identically; different contents compare unequal.
    #[test]
    fn test_b64url_eq_and_hash() {
        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        let data1 = b"test data";
        let data2 = b"test data";
        let data3 = b"different data";

        let b64url_1 = B64Url::from(data1.as_slice());
        let b64url_2 = B64Url::from(data2.as_slice());
        let b64url_3 = B64Url::from(data3.as_slice());

        assert_eq!(b64url_1, b64url_2);
        assert_ne!(b64url_1, b64url_3);

        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();

        b64url_1.hash(&mut hasher1);
        b64url_2.hash(&mut hasher2);

        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}