1use std::{cmp::max, pin::Pin};
2
3use generic_array::GenericArray;
4use typenum::U32;
5
6use super::Aes256CbcHmacKey;
7use crate::{CryptoError, Result, util::hkdf_expand};
8
9pub(super) fn stretch_key(key: &Pin<Box<GenericArray<u8, U32>>>) -> Aes256CbcHmacKey {
13 Aes256CbcHmacKey {
14 enc_key: hkdf_expand(key, Some("enc")).expect("HKDF expand to succeed"),
16 mac_key: hkdf_expand(key, Some("mac")).expect("HKDF expand to succeed"),
17 }
18}
19
20pub(crate) fn pad_bytes(bytes: &mut Vec<u8>, min_length: usize) -> Result<(), CryptoError> {
25 let pad_bytes = min_length.saturating_sub(bytes.len()).max(1);
27 if pad_bytes > 255 {
30 return Err(CryptoError::InvalidPadding);
31 }
32 let padded_length = max(min_length, bytes.len() + 1);
33 bytes.resize(padded_length, pad_bytes as u8);
34 Ok(())
35}
36
37pub(crate) fn unpad_bytes(padded_bytes: &[u8]) -> Result<&[u8], CryptoError> {
41 let pad_len = *padded_bytes.last().ok_or(CryptoError::InvalidPadding)? as usize;
42 if pad_len == 0 || pad_len > padded_bytes.len() {
44 return Err(CryptoError::InvalidPadding);
45 }
46 Ok(padded_bytes[..(padded_bytes.len() - pad_len)].as_ref())
47}
48
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `stretch_key` against fixed known-answer vectors.
    #[test]
    fn test_stretch_kdf_key() {
        let key = Box::pin(
            [
                31, 79, 104, 226, 150, 71, 177, 90, 194, 80, 172, 209, 17, 129, 132, 81, 138, 167,
                69, 167, 254, 149, 2, 27, 39, 197, 64, 42, 22, 195, 86, 75,
            ]
            .into(),
        );
        let stretched = stretch_key(&key);

        assert_eq!(
            [
                111, 31, 178, 45, 238, 152, 37, 114, 143, 215, 124, 83, 135, 173, 195, 23, 142,
                134, 120, 249, 61, 132, 163, 182, 113, 197, 189, 204, 188, 21, 237, 96
            ],
            stretched.enc_key.as_slice()
        );
        assert_eq!(
            [
                221, 127, 206, 234, 101, 27, 202, 38, 86, 52, 34, 28, 78, 28, 185, 16, 48, 61, 127,
                166, 209, 247, 194, 87, 232, 26, 48, 85, 193, 249, 179, 155
            ],
            stretched.mac_key.as_slice()
        );
    }

    /// 256 bytes of padding cannot be encoded in a single pad byte.
    #[test]
    fn test_pad_bytes_256_error() {
        let mut bytes: Vec<u8> = Vec::new();
        let result = pad_bytes(&mut bytes, 256);
        assert!(matches!(result, Err(CryptoError::InvalidPadding)));
        // A rejected call must leave the input untouched.
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_pad_bytes_roundtrip() {
        let original_bytes = vec![1u8; 10];
        let mut padded_bytes = original_bytes.clone();
        let mut expected = vec![1u8; 12];
        expected[10] = 2;
        expected[11] = 2;

        pad_bytes(&mut padded_bytes, 12).expect("Padding failed");
        assert_eq!(expected, padded_bytes);

        let unpadded_bytes = unpad_bytes(&padded_bytes).expect("Unpadding failed");
        assert_eq!(original_bytes, unpadded_bytes);
    }

    #[test]
    fn test_pad_bytes_roundtrip_empty() {
        let mut bytes: Vec<u8> = Vec::new();
        pad_bytes(&mut bytes, 32).expect("Padding failed");
        assert_eq!(bytes.len(), 32);

        let unpadded = unpad_bytes(&bytes).expect("Unpadding failed");
        assert!(unpadded.is_empty());
    }

    #[test]
    fn test_unpad_bytes_invalid_empty() {
        let data: [u8; 0] = [];
        let result = unpad_bytes(&data);
        assert!(matches!(result, Err(CryptoError::InvalidPadding)));
    }

    #[test]
    fn test_unpad_bytes_invalid_too_large() {
        // The final byte claims 5 bytes of padding, but only 4 bytes exist.
        let data = [1u8, 2, 3, 5];
        let result = unpad_bytes(&data);
        assert!(matches!(result, Err(CryptoError::InvalidPadding)));
    }

    #[test]
    fn test_unpad_bytes_invalid_0_padding() {
        // A pad length of zero is never valid.
        let data = [1u8, 2, 3, 0];
        let result = unpad_bytes(&data);
        assert!(matches!(result, Err(CryptoError::InvalidPadding)));
    }

    /// Exhaustively checks every data size in 0..=1024 against every minimum length in
    /// 2..=1024:
    /// - when more than 255 bytes of padding would be needed, padding must fail and leave the
    ///   input unchanged;
    /// - otherwise the padded buffer has the expected length and unpads back to the original.
    #[test]
    fn test_pad_and_unpad_bytes_range_0_to_1024() {
        // One buffer reused across all ~1M cases instead of allocating a fresh Vec per case.
        let mut data: Vec<u8> = Vec::new();

        for data_size in 0..=1024usize {
            let original = vec![0x12u8; data_size];

            for padding_size in 2..=1024usize {
                data.clone_from(&original);
                let result = pad_bytes(&mut data, padding_size);

                if data_size <= padding_size && padding_size - data_size > 255 {
                    assert!(
                        matches!(result, Err(CryptoError::InvalidPadding)),
                        "Expected InvalidPadding error at size {} and padding {}, but got {:?}",
                        data_size,
                        padding_size,
                        result
                    );
                    assert_eq!(
                        data, original,
                        "Input modified on error at size {} and padding {}",
                        data_size, padding_size
                    );
                    continue;
                }

                result.unwrap_or_else(|err| {
                    panic!(
                        "Padding failed at size {} and padding {}: {:?}",
                        data_size, padding_size, err
                    )
                });
                assert_eq!(
                    data.len(),
                    padding_size.max(data_size + 1),
                    "Unexpected padded length at size {} and padding {}",
                    data_size,
                    padding_size
                );

                let unpadded = unpad_bytes(&data).expect("Unpadding failed");
                assert_eq!(
                    unpadded, original,
                    "Failed at size {} and padding {}",
                    data_size, padding_size
                );
            }
        }
    }
}