bitwarden_encoding/
b64url.rs

1use std::str::FromStr;
2
3use data_encoding::BASE64URL_NOPAD;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7/// Base64URL encoded data
8///
9/// Is indifferent about padding when decoding, but always produces padding when encoding.
10#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
11#[serde(try_from = "&str", into = "String")]
12pub struct B64Url(Vec<u8>);
13
14impl B64Url {
15    /// Returns a byte slice of the inner vector.
16    pub fn as_bytes(&self) -> &[u8] {
17        &self.0
18    }
19
20    /// Returns the inner byte vector.
21    pub fn into_bytes(self) -> Vec<u8> {
22        self.0
23    }
24}
25
26impl From<Vec<u8>> for B64Url {
27    fn from(src: Vec<u8>) -> Self {
28        Self(src)
29    }
30}
31impl From<&[u8]> for B64Url {
32    fn from(src: &[u8]) -> Self {
33        Self(src.to_vec())
34    }
35}
36
37impl From<B64Url> for Vec<u8> {
38    fn from(src: B64Url) -> Self {
39        src.0
40    }
41}
42
43impl From<B64Url> for String {
44    fn from(src: B64Url) -> Self {
45        String::from(&src)
46    }
47}
48
49impl From<&B64Url> for String {
50    fn from(src: &B64Url) -> Self {
51        BASE64URL_NOPAD.encode(&src.0)
52    }
53}
54
55impl std::fmt::Display for B64Url {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.write_str(String::from(self).as_str())
58    }
59}
60
61/// An error returned when a string is not base64 decodable.
62#[derive(Debug, Error)]
63#[error("Data isn't base64url encoded")]
64pub struct NotB64UrlEncodedError;
65
/// Lenient base64url encoding used to decode incoming text.
///
/// Uses the URL-safe alphabet (`-` and `_` as the last two symbols), declares no padding
/// character, and has `check_trailing_bits` disabled.
const BASE64URL_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
    symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
    padding: None,
    check_trailing_bits: false,
};
/// Padding character; trailing occurrences are stripped from the input before it is decoded.
const BASE64URL_PADDING: &str = "=";
72
73impl TryFrom<String> for B64Url {
74    type Error = NotB64UrlEncodedError;
75
76    fn try_from(value: String) -> Result<Self, Self::Error> {
77        Self::try_from(value.as_str())
78    }
79}
80
81impl TryFrom<&str> for B64Url {
82    type Error = NotB64UrlEncodedError;
83
84    fn try_from(value: &str) -> Result<Self, Self::Error> {
85        let sane_string = value.trim_end_matches(BASE64URL_PADDING);
86        BASE64URL_PERMISSIVE
87            .decode(sane_string.as_bytes())
88            .map(Self)
89            .map_err(|_| NotB64UrlEncodedError)
90    }
91}
92
93impl FromStr for B64Url {
94    type Err = NotB64UrlEncodedError;
95
96    fn from_str(s: &str) -> Result<Self, Self::Err> {
97        Self::try_from(s)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrapping a `Vec<u8>` and unwrapping it again yields the same bytes.
    #[test]
    fn test_b64url_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64url = B64Url::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64url), data);
    }

    /// Wrapping a borrowed slice copies the bytes unchanged.
    #[test]
    fn test_b64url_from_slice() {
        let data = b"Hello";
        let b64url = B64Url::from(data.as_slice());
        assert_eq!(b64url.as_bytes(), data);
    }

    /// Encoding never emits `=` padding, even when the padded form would need it
    /// (13 input bytes would be padded with `==`).
    #[test]
    fn test_b64url_encoding_omits_padding() {
        let data = b"Hello, World!";
        let b64url = B64Url::from(data.as_slice());
        let encoded = String::from(&b64url);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ");
    }

    /// Input carrying `==` padding decodes successfully.
    #[test]
    fn test_b64url_decoding_with_padding() {
        let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
        let b64url = B64Url::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    /// The same input without padding decodes to the same bytes.
    #[test]
    fn test_b64url_decoding_without_padding() {
        let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
        let b64url = B64Url::try_from(encoded_without_padding).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    /// The decoder is configured with `check_trailing_bits: false`, so non-zero trailing bits in
    /// the final symbol are ignored rather than rejected (`SGVsbG8` is the canonical form).
    #[test]
    fn test_b64url_decoding_ignores_trailing_bits() {
        let non_canonical = "SGVsbG9";
        let b64url = B64Url::try_from(non_canonical).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello");
    }

    /// Round trip for data whose length (32 bytes) is not a multiple of three, i.e. whose padded
    /// encoding would end in `=`.
    #[test]
    fn test_b64url_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let b64url = B64Url::from(original.as_slice());
        let encoded = String::from(&b64url);
        let decoded = B64Url::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    /// Round trip for data whose length (9 bytes) is a multiple of three, so no padding would
    /// be needed in any case.
    #[test]
    fn test_b64url_round_trip_without_padding() {
        let original = b"Test data";
        let b64url = B64Url::from(original.as_slice());
        let encoded = String::from(&b64url);
        let decoded = B64Url::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    /// `Display` prints the unpadded encoding.
    #[test]
    fn test_b64url_display() {
        let data = b"Hello";
        let b64url = B64Url::from(data.as_slice());
        assert_eq!(b64url.to_string(), "SGVsbG8");
    }

    /// Characters outside the base64url alphabet are rejected.
    #[test]
    fn test_b64url_invalid_encoding() {
        let invalid_b64url = "This is not base64url!@#$";
        let result = B64Url::try_from(invalid_b64url);
        assert!(result.is_err());
    }

    /// The empty string decodes to zero bytes.
    #[test]
    fn test_b64url_empty_string() {
        let empty = "";
        let b64url = B64Url::try_from(empty).unwrap();
        assert_eq!(b64url.as_bytes().len(), 0);
    }

    /// A single trailing `=` is stripped before decoding (complements the `==` case above).
    #[test]
    fn test_b64url_padding_removal() {
        let encoded_with_padding = "SGVsbG8=";
        let b64url = B64Url::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello");
    }

    /// Serde serializes to the encoded string and deserializes back to the same bytes.
    #[test]
    fn test_b64url_serialization() {
        let data = b"serialization test";
        let b64url = B64Url::from(data.as_slice());

        let serialized = serde_json::to_string(&b64url).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64Url = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64url.as_bytes(), deserialized.as_bytes());
    }

    /// The error's `Display` text is stable.
    #[test]
    fn test_not_b64url_encoded_error_display() {
        let error = NotB64UrlEncodedError;
        assert_eq!(error.to_string(), "Data isn't base64url encoded");
    }

    /// `str::parse` goes through `FromStr` and accepts padded input.
    #[test]
    fn test_b64url_from_str() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let b64url: B64Url = encoded.parse().unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    /// Equal byte contents compare equal and hash identically; different contents do not compare
    /// equal.
    #[test]
    fn test_b64url_eq_and_hash() {
        let data1 = b"test data";
        let data2 = b"test data";
        let data3 = b"different data";

        let b64url_1 = B64Url::from(data1.as_slice());
        let b64url_2 = B64Url::from(data2.as_slice());
        let b64url_3 = B64Url::from(data3.as_slice());

        assert_eq!(b64url_1, b64url_2);
        assert_ne!(b64url_1, b64url_3);

        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();

        b64url_1.hash(&mut hasher1);
        b64url_2.hash(&mut hasher2);

        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}