Skip to main content

bitwarden_encoding/
b64url.rs

1use std::str::FromStr;
2
3use data_encoding::BASE64URL_NOPAD;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use crate::chunked::chunked_encode;
8
9/// Base64URL encoded data
10///
11/// Is indifferent about padding when decoding, but always produces padding when encoding.
12#[derive(Debug, Serialize, Deserialize, Clone, Hash, PartialEq, Eq)]
13#[serde(try_from = "&str", into = "String")]
14pub struct B64Url(Vec<u8>);
15
16impl B64Url {
17    /// Returns a byte slice of the inner vector.
18    pub fn as_bytes(&self) -> &[u8] {
19        &self.0
20    }
21
22    /// Returns the inner byte vector.
23    pub fn into_bytes(self) -> Vec<u8> {
24        self.0
25    }
26}
27
28impl From<Vec<u8>> for B64Url {
29    fn from(src: Vec<u8>) -> Self {
30        Self(src)
31    }
32}
33impl From<&[u8]> for B64Url {
34    fn from(src: &[u8]) -> Self {
35        Self(src.to_vec())
36    }
37}
38
39impl From<B64Url> for Vec<u8> {
40    fn from(src: B64Url) -> Self {
41        src.0
42    }
43}
44
45impl From<B64Url> for String {
46    fn from(src: B64Url) -> Self {
47        String::from(&src)
48    }
49}
50
51impl From<&B64Url> for String {
52    fn from(src: &B64Url) -> Self {
53        chunked_encode(&BASE64URL_NOPAD, &src.0)
54    }
55}
56
57impl std::fmt::Display for B64Url {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.write_str(String::from(self).as_str())
60    }
61}
62
63/// An error returned when a string is not base64 decodable.
64#[derive(Debug, Error)]
65#[error("Data isn't base64url encoded")]
66pub struct NotB64UrlEncodedError;
67
68const BASE64URL_PERMISSIVE: data_encoding::Encoding = data_encoding_macro::new_encoding! {
69    symbols: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_",
70    padding: None,
71    check_trailing_bits: false,
72};
73const BASE64URL_PADDING: &str = "=";
74
75impl TryFrom<String> for B64Url {
76    type Error = NotB64UrlEncodedError;
77
78    fn try_from(value: String) -> Result<Self, Self::Error> {
79        Self::try_from(value.as_str())
80    }
81}
82
83impl TryFrom<&str> for B64Url {
84    type Error = NotB64UrlEncodedError;
85
86    fn try_from(value: &str) -> Result<Self, Self::Error> {
87        let sane_string = value.trim_end_matches(BASE64URL_PADDING);
88        BASE64URL_PERMISSIVE
89            .decode(sane_string.as_bytes())
90            .map(Self)
91            .map_err(|_| NotB64UrlEncodedError)
92    }
93}
94
95impl FromStr for B64Url {
96    type Err = NotB64UrlEncodedError;
97
98    fn from_str(s: &str) -> Result<Self, Self::Err> {
99        Self::try_from(s)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_b64url_from_vec() {
        let data = vec![72, 101, 108, 108, 111];
        let b64url = B64Url::from(data.clone());
        assert_eq!(Vec::<u8>::from(b64url), data);
    }

    #[test]
    fn test_b64url_from_slice() {
        let data = b"Hello";
        let b64url = B64Url::from(data.as_slice());
        assert_eq!(b64url.as_bytes(), data);
    }

    // Encoding never emits `=` padding, even though 13 input bytes would need two padding
    // characters in padded base64.
    #[test]
    fn test_b64url_encoding_omits_padding() {
        let data = b"Hello, World!";
        let b64url = B64Url::from(data.as_slice());
        let encoded = String::from(&b64url);
        assert_eq!(encoded, "SGVsbG8sIFdvcmxkIQ");
        assert!(!encoded.contains('='));
    }

    #[test]
    fn test_b64url_decoding_with_padding() {
        let encoded_with_padding = "SGVsbG8sIFdvcmxkIQ==";
        let b64url = B64Url::try_from(encoded_with_padding).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64url_decoding_without_padding() {
        let encoded_without_padding = "SGVsbG8sIFdvcmxkIQ";
        let b64url = B64Url::try_from(encoded_without_padding).unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    // 32 bytes is not a multiple of 3, so padded base64 would need a trailing `=`.
    #[test]
    fn test_b64url_round_trip_with_padding() {
        let original = b"Test data that requires padding!";
        let b64url = B64Url::from(original.as_slice());
        let encoded = String::from(&b64url);
        let decoded = B64Url::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    // 9 bytes is a multiple of 3, so even padded base64 would need no padding.
    #[test]
    fn test_b64url_round_trip_without_padding() {
        let original = b"Test data";
        let b64url = B64Url::from(original.as_slice());
        let encoded = String::from(&b64url);
        let decoded = B64Url::try_from(encoded.as_str()).unwrap();
        assert_eq!(decoded.as_bytes(), original);
    }

    #[test]
    fn test_b64url_display() {
        let data = b"Hello";
        let b64url = B64Url::from(data.as_slice());
        assert_eq!(b64url.to_string(), "SGVsbG8");
    }

    #[test]
    fn test_b64url_invalid_encoding() {
        let invalid_b64url = "This is not base64url!@#$";
        let result = B64Url::try_from(invalid_b64url);
        assert!(result.is_err());
    }

    #[test]
    fn test_b64url_rejects_misplaced_padding_and_bad_length() {
        // `=` is only tolerated at the very end of the input.
        assert!(B64Url::try_from("QQ==QQ").is_err());
        // A single symbol carries only 6 bits and cannot encode a whole byte.
        assert!(B64Url::try_from("Q").is_err());
    }

    #[test]
    fn test_b64url_url_safe_alphabet() {
        // Standard base64 encodes these bytes as "+/8="; base64url uses '-' and '_' instead.
        let bytes = [0xfb_u8, 0xff];
        let b64url = B64Url::from(bytes.as_slice());
        assert_eq!(b64url.to_string(), "-_8");
        assert_eq!(B64Url::try_from("-_8").unwrap().as_bytes(), &bytes[..]);
        assert!(B64Url::try_from("+/8=").is_err());
    }

    #[test]
    fn test_b64url_empty_string() {
        let empty = "";
        let b64url = B64Url::try_from(empty).unwrap();
        assert_eq!(b64url.as_bytes().len(), 0);
    }

    #[test]
    fn test_b64url_padding_removal() {
        // Trailing padding is stripped whether it is present, absent, one or two characters.
        assert_eq!(B64Url::try_from("QQ==").unwrap().as_bytes(), b"A");
        assert_eq!(B64Url::try_from("QQ").unwrap().as_bytes(), b"A");
        assert_eq!(B64Url::try_from("QUI=").unwrap().as_bytes(), b"AB");
        assert_eq!(B64Url::try_from("QUI").unwrap().as_bytes(), b"AB");
    }

    #[test]
    fn test_b64url_accepts_non_canonical_trailing_bits() {
        // "QR" leaves non-zero unused low bits; the permissive decoder ignores them.
        assert_eq!(B64Url::try_from("QR").unwrap().as_bytes(), b"A");
    }

    #[test]
    fn test_b64url_serialization() {
        let data = b"serialization test";
        let b64url = B64Url::from(data.as_slice());

        let serialized = serde_json::to_string(&b64url).unwrap();
        assert_eq!(serialized, "\"c2VyaWFsaXphdGlvbiB0ZXN0\"");

        let deserialized: B64Url = serde_json::from_str(&serialized).unwrap();
        assert_eq!(b64url.as_bytes(), deserialized.as_bytes());
    }

    #[test]
    fn test_not_b64url_encoded_error_display() {
        let error = NotB64UrlEncodedError;
        assert_eq!(error.to_string(), "Data isn't base64url encoded");
    }

    #[test]
    fn test_b64url_from_str() {
        let encoded = "SGVsbG8sIFdvcmxkIQ==";
        let b64url: B64Url = encoded.parse().unwrap();
        assert_eq!(b64url.as_bytes(), b"Hello, World!");
    }

    #[test]
    fn test_b64url_eq_and_hash() {
        use std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        };

        let data1 = b"test data";
        let data2 = b"test data";
        let data3 = b"different data";

        let b64url_1 = B64Url::from(data1.as_slice());
        let b64url_2 = B64Url::from(data2.as_slice());
        let b64url_3 = B64Url::from(data3.as_slice());

        assert_eq!(b64url_1, b64url_2);
        assert_ne!(b64url_1, b64url_3);

        let mut hasher1 = DefaultHasher::new();
        let mut hasher2 = DefaultHasher::new();

        b64url_1.hash(&mut hasher1);
        b64url_2.hash(&mut hasher2);

        assert_eq!(hasher1.finish(), hasher2.finish());
    }
}