1use std::{cmp::max, pin::Pin};
2
3use generic_array::GenericArray;
4use typenum::U32;
5
6use super::Aes256CbcHmacKey;
7use crate::{util::hkdf_expand, CryptoError, Result};
8
9pub(super) fn stretch_key(key: &Pin<Box<GenericArray<u8, U32>>>) -> Result<Aes256CbcHmacKey> {
13 Ok(Aes256CbcHmacKey {
14 enc_key: hkdf_expand(key, Some("enc"))?,
15 mac_key: hkdf_expand(key, Some("mac"))?,
16 })
17}
18
19pub(crate) fn pad_bytes(bytes: &mut Vec<u8>, min_length: usize) -> Result<(), CryptoError> {
24 let pad_bytes = min_length.saturating_sub(bytes.len()).max(1);
26 if pad_bytes > 255 {
29 return Err(CryptoError::InvalidPadding);
30 }
31 let padded_length = max(min_length, bytes.len() + 1);
32 bytes.resize(padded_length, pad_bytes as u8);
33 Ok(())
34}
35
36pub(crate) fn unpad_bytes(padded_bytes: &[u8]) -> Result<&[u8], CryptoError> {
40 let pad_len = *padded_bytes.last().ok_or(CryptoError::InvalidPadding)? as usize;
41 if pad_len == 0 || pad_len > padded_bytes.len() {
43 return Err(CryptoError::InvalidPadding);
44 }
45 Ok(padded_bytes[..(padded_bytes.len() - pad_len)].as_ref())
46}
47
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stretch_kdf_key() {
        let key = Box::pin(
            [
                31, 79, 104, 226, 150, 71, 177, 90, 194, 80, 172, 209, 17, 129, 132, 81, 138, 167,
                69, 167, 254, 149, 2, 27, 39, 197, 64, 42, 22, 195, 86, 75,
            ]
            .into(),
        );
        let stretched = stretch_key(&key).unwrap();

        assert_eq!(
            [
                111, 31, 178, 45, 238, 152, 37, 114, 143, 215, 124, 83, 135, 173, 195, 23, 142,
                134, 120, 249, 61, 132, 163, 182, 113, 197, 189, 204, 188, 21, 237, 96
            ],
            stretched.enc_key.as_slice()
        );
        assert_eq!(
            [
                221, 127, 206, 234, 101, 27, 202, 38, 86, 52, 34, 28, 78, 28, 185, 16, 48, 61, 127,
                166, 209, 247, 194, 87, 232, 26, 48, 85, 193, 249, 179, 155
            ],
            stretched.mac_key.as_slice()
        );
    }

    #[test]
    fn test_pad_bytes_256_error() {
        // 256 padding bytes would be needed, which cannot be encoded in a single byte.
        let mut bytes: Vec<u8> = Vec::new();
        let result = pad_bytes(&mut bytes, 256);
        assert!(matches!(result, Err(CryptoError::InvalidPadding)));
        // A failed call must leave the input untouched.
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_pad_bytes_roundtrip() {
        let original_bytes = vec![1u8; 10];
        let mut cloned_bytes = original_bytes.clone();
        let mut encoded_bytes = vec![1u8; 12];
        encoded_bytes[10] = 2;
        encoded_bytes[11] = 2;
        pad_bytes(&mut cloned_bytes, 12).expect("Padding failed");
        assert_eq!(encoded_bytes, cloned_bytes);
        let unpadded_bytes = unpad_bytes(&cloned_bytes).unwrap();
        assert_eq!(original_bytes, unpadded_bytes);
    }

    #[test]
    fn test_pad_bytes_roundtrip_empty() {
        let original_bytes = Vec::new();
        let mut cloned_bytes = original_bytes.clone();
        pad_bytes(&mut cloned_bytes, 32).expect("Padding failed");
        // An empty input is padded entirely with the pad length itself.
        assert_eq!(cloned_bytes, vec![32u8; 32]);
        let unpadded = unpad_bytes(&cloned_bytes).unwrap();
        assert_eq!(Vec::<u8>::new(), unpadded);
    }

    #[test]
    fn test_pad_bytes_input_at_least_min_length() {
        // Inputs already at or above `min_length` still gain exactly one padding byte.
        let mut bytes = vec![9u8; 16];
        pad_bytes(&mut bytes, 16).expect("Padding failed");
        assert_eq!(bytes.len(), 17);
        assert_eq!(bytes[16], 1);
        assert_eq!(unpad_bytes(&bytes).unwrap(), &[9u8; 16][..]);
    }

    #[test]
    fn test_unpad_bytes_invalid_empty() {
        let data: Vec<u8> = vec![];
        let result = unpad_bytes(&data);
        assert!(matches!(result, Err(CryptoError::InvalidPadding)));
    }

    #[test]
    fn test_unpad_bytes_invalid_too_large() {
        // The final byte claims 5 bytes of padding, but the input is only 4 bytes long.
        let data = vec![1, 2, 3, 5];
        let result = unpad_bytes(&data);
        assert!(matches!(result, Err(CryptoError::InvalidPadding)));
    }

    #[test]
    fn test_unpad_bytes_invalid_0_padding() {
        // A zero pad length is never produced by `pad_bytes`.
        let data = vec![1, 2, 3, 0];
        let result = unpad_bytes(&data);
        assert!(matches!(result, Err(CryptoError::InvalidPadding)));
    }

    #[test]
    fn test_pad_and_unpad_bytes_range_0_to_1024() {
        // Walk the (data_size, padding_size) grid directly rather than collecting and cloning
        // the full ~1M-entry case list, and reuse a single buffer for every case.
        let mut data: Vec<u8> = Vec::new();
        for data_size in 0..=1024usize {
            let original = vec![0x12u8; data_size];
            for padding_size in 2..=1024usize {
                data.clear();
                data.extend_from_slice(&original);
                let result = pad_bytes(&mut data, padding_size);

                // More than 255 padding bytes cannot be encoded in the one-byte pad length.
                if data_size <= padding_size && padding_size - data_size > 255 {
                    assert!(
                        matches!(result, Err(CryptoError::InvalidPadding)),
                        "Expected InvalidPadding error at size {} and padding {}, but got {:?}",
                        data_size,
                        padding_size,
                        result
                    );
                    continue;
                }

                result.expect("Padding failed");

                let expected_len = padding_size.max(data_size + 1);
                assert_eq!(
                    data.len(),
                    expected_len,
                    "Wrong padded length at size {} and padding {}",
                    data_size,
                    padding_size
                );
                let pad_len = expected_len - data_size;
                assert!(
                    data[data_size..].iter().all(|&b| usize::from(b) == pad_len),
                    "Wrong padding bytes at size {} and padding {}",
                    data_size,
                    padding_size
                );

                let unpadded = unpad_bytes(&data).expect("Unpadding failed");
                assert_eq!(
                    unpadded, original,
                    "Failed at size {} and padding {}",
                    data_size, padding_size
                );
            }
        }
    }
}