bitwarden_encoding/
b64url.rs

1use std::str::FromStr;
2
3use data_encoding::BASE64URL_NOPAD;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7/// Base64URL encoded data
8///
9/// Is indifferent about padding when decoding, but always produces padding when encoding.
10#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
11#[serde(try_from = "&str", into = "String")]
12pub struct B64Url(Vec<u8>);
13
14impl B64Url {
15    /// Returns a byte slice of the inner vector.
16    pub fn as_bytes(&self) -> &[u8] {
17        &self.0
18    }
19}
20
21impl From<Vec<u8>> for B64Url {
22    fn from(src: Vec<u8>) -> Self {
23        Self(src)
24    }
25}
26impl From<&[u8]> for B64Url {
27    fn from(src: &[u8]) -> Self {
28        Self(src.to_vec())
29    }
30}
31
32impl From<B64Url> for Vec<u8> {
33    fn from(src: B64Url) -> Self {
34        src.0
35    }
36}
37
38impl From<B64Url> for String {
39    fn from(src: B64Url) -> Self {
40        String::from(&src)
41    }
42}
43
44impl From<&B64Url> for String {
45    fn from(src: &B64Url) -> Self {
46        BASE64URL_NOPAD.encode(&src.0)
47    }
48}
49
50impl std::fmt::Display for B64Url {
51    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52        f.write_str(String::from(self).as_str())
53    }
54}
55
56/// An error returned when a string is not base64 decodable.
57#[derive(Debug, Error)]
58#[error("Data isn't base64url encoded")]
59pub struct NotB64UrlEncoded;
60
/// Lenient base64url decoder used when parsing text into a `B64Url`.
///
/// - Uses the URL-safe alphabet (`-` and `_` as the last two symbols).
/// - Has no padding character, so any trailing `=` must be stripped before calling `decode`.
/// - With `check_trailing_bits: false`, it does not require the unused low bits of the final
///   character to be zero. Non-canonical inputs such as `SGVsbG9` are therefore accepted and
///   decode to the same bytes as their canonical form `SGVsbG8`.
const BASE64URL_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
    symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
    padding: None,
    check_trailing_bits: false,
};
/// Padding character removed from the end of input before it is decoded.
const BASE64URL_PADDING: &str = "=";
67
68impl TryFrom<String> for B64Url {
69    type Error = NotB64UrlEncoded;
70
71    fn try_from(value: String) -> Result<Self, Self::Error> {
72        Self::try_from(value.as_str())
73    }
74}
75
76impl TryFrom<&str> for B64Url {
77    type Error = NotB64UrlEncoded;
78
79    fn try_from(value: &str) -> Result<Self, Self::Error> {
80        let sane_string = value.trim_end_matches(BASE64URL_PADDING);
81        BASE64URL_PERMISSIVE
82            .decode(sane_string.as_bytes())
83            .map(Self)
84            .map_err(|_| NotB64UrlEncoded)
85    }
86}
87
88impl FromStr for B64Url {
89    type Err = NotB64UrlEncoded;
90
91    fn from_str(s: &str) -> Result<Self, Self::Err> {
92        Self::try_from(s)
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_b64url_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64url = B64Url::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64url), data);
    }

    #[test]
    fn test_b64url_from_slice() {
        let data = b"Hello";
        let b64url = B64Url::from(data.as_slice());
        assert_eq!(b64url.as_bytes(), data);
    }

    /// Encoding never emits `=` padding, even when the input length is not a multiple of 3.
    #[test]
    fn test_b64url_encoding_omits_padding() {
        let data = b"Hello, World!";
        let b64url = B64Url::from(data.as_slice());
        let encoded = String::from(&b64url);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ");
        assert!(!encoded.contains('='));
    }

    #[test]
    fn test_b64url_decoding_with_padding() {
        let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
        let b64url = B64Url::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64url_decoding_without_padding() {
        let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
        let b64url = B64Url::try_from(encoded_without_padding).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    /// Any number of trailing `=` characters is stripped before decoding.
    #[test]
    fn test_b64url_decoding_ignores_extra_padding() {
        for encoded in ["SGVsbG8=", "SGVsbG8==", "SGVsbG8===="] {
            let b64url = B64Url::try_from(encoded).unwrap();
            assert_eq!(b64url.as_bytes(), b"Hello");
        }
    }

    /// `=` is only tolerated at the end of the input.
    #[test]
    fn test_b64url_padding_in_the_middle_is_rejected() {
        assert!(B64Url::try_from("SGVs=bG8").is_err());
    }

    /// Non-zero unused bits in the final character are accepted (permissive decoding).
    #[test]
    fn test_b64url_decoding_ignores_trailing_bits() {
        let b64url = B64Url::try_from("SGVsbG9").unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello");
    }

    /// Bytes whose standard base64 form uses `+` and `/` encode to `-` and `_` instead.
    #[test]
    fn test_b64url_uses_url_safe_alphabet() {
        let b64url = B64Url::from(&[0xfb_u8, 0xff][..]);
        assert_eq!(String::from(&b64url), "-_8");
        assert_eq!(B64Url::try_from("-_8").unwrap(), b64url);
    }

    #[test]
    fn test_b64url_rejects_standard_base64_alphabet() {
        assert!(B64Url::try_from("-_+/").is_err());
        assert!(B64Url::try_from("+/8").is_err());
    }

    #[test]
    fn test_b64url_rejects_impossible_length() {
        // A single base64 character carries only 6 bits, which cannot form a whole byte.
        assert!(B64Url::try_from("A").is_err());
    }

    #[test]
    fn test_b64url_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let b64url = B64Url::from(original.as_slice());
        let encoded = String::from(&b64url);
        let decoded = B64Url::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64url_round_trip_without_padding() {
        let original = b"Test data";
        let b64url = B64Url::from(original.as_slice());
        let encoded = String::from(&b64url);
        let decoded = B64Url::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64url_display() {
        let data = b"Hello";
        let b64url = B64Url::from(data.as_slice());
        assert_eq!(b64url.to_string(), "SGVsbG8");
        assert_eq!(b64url.to_string(), String::from(&b64url));
        assert_eq!(String::from(b64url.clone()), "SGVsbG8");
    }

    #[test]
    fn test_b64url_invalid_encoding() {
        let invalid_b64url = "This is not base64url!@#$";
        let result = B64Url::try_from(invalid_b64url);
        assert!(result.is_err());
    }

    #[test]
    fn test_b64url_empty_string() {
        let empty = "";
        let b64url = B64Url::try_from(empty).unwrap();
        assert_eq!(b64url.as_bytes().len(), 0);
        assert_eq!(String::from(&b64url), "");
    }

    #[test]
    fn test_b64url_try_from_owned_string() {
        let b64url = B64Url::try_from(String::from("SGVsbG8")).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello");
        assert!(B64Url::try_from(String::from("not base64url!")).is_err());
    }

    #[test]
    fn test_b64url_serialization() {
        let data = b"serialization test";
        let b64url = B64Url::from(data.as_slice());

        let serialized = serde_json::to_string(&b64url).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64Url = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64url.as_bytes(), deserialized.as_bytes());
    }

    #[test]
    fn test_b64url_deserialization_accepts_padding_and_rejects_garbage() {
        let padded: B64Url = serde_json::from_str("\"SGVsbG8=\"").unwrap();
        assert_eq!(padded.as_bytes(), b"Hello");

        assert!(serde_json::from_str::<B64Url>("\"not base64url!\"").is_err());
    }

    #[test]
    fn test_not_b64url_encoded_error_display() {
        let error = NotB64UrlEncoded;
        assert_eq!(error.to_string(), "Data isn't base64url encoded");
    }

    #[test]
    fn test_b64url_from_str() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let b64url: B64Url = encoded.parse().unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64url_eq_and_hash() {
        let data1 = b"test data";
        let data2 = b"test data";
        let data3 = b"different data";

        let b64url_1 = B64Url::from(data1.as_slice());
        let b64url_2 = B64Url::from(data2.as_slice());
        let b64url_3 = B64Url::from(data3.as_slice());

        assert_eq!(b64url_1, b64url_2);
        assert_ne!(b64url_1, b64url_3);

        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();

        b64url_1.hash(&mut hasher1);
        b64url_2.hash(&mut hasher2);

        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}