1use std::pin::Pin;
2
3use ::aes::cipher::{ArrayLength, Unsigned};
4use generic_array::GenericArray;
5use hmac::digest::OutputSizeUser;
6use rand::{
7 Rng,
8 distributions::{Alphanumeric, DistString, Distribution, Standard},
9};
10use zeroize::{Zeroize, Zeroizing};
11
12use crate::Result;
13
14pub(crate) type PbkdfSha256Hmac = hmac::Hmac<sha2::Sha256>;
15pub(crate) const PBKDF_SHA256_HMAC_OUT_SIZE: usize =
16 <<PbkdfSha256Hmac as OutputSizeUser>::OutputSize as Unsigned>::USIZE;
17
/// Errors returned by `hkdf_expand`.
///
/// Besides `Debug`, the enum derives the value-semantics traits so callers can compare errors
/// with `==` / `assert_eq!` and copy them freely.
///
/// NOTE(review): `InvalidInputLegth` is misspelled; the name is kept because renaming a
/// variant is a breaking change for every caller that matches on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum HkdfExpandError {
    /// The pseudorandom key was rejected by `hkdf::Hkdf::from_prk` (e.g. it is too short).
    InvalidInputLegth,
    /// The requested output length was rejected by `hkdf::Hkdf::expand` (e.g. it is too long).
    InvalidOutputLength,
}
23
24pub(crate) fn hkdf_expand<T: ArrayLength<u8>>(
26 prk: &[u8],
27 info: Option<&str>,
28) -> Result<Pin<Box<GenericArray<u8, T>>>, HkdfExpandError> {
29 let hkdf = hkdf::Hkdf::<sha2::Sha256>::from_prk(prk)
30 .map_err(|_| HkdfExpandError::InvalidInputLegth)?;
31 let mut key = Box::<GenericArray<u8, T>>::default();
32
33 let i = info.map(|i| i.as_bytes()).unwrap_or(&[]);
34 hkdf.expand(i, &mut key)
35 .map_err(|_| HkdfExpandError::InvalidOutputLength)?;
36
37 Ok(Box::into_pin(key))
38}
39
40pub fn generate_random_bytes<T>() -> Zeroizing<T>
42where
43 Standard: Distribution<T>,
44 T: Zeroize,
45{
46 Zeroizing::new(rand::thread_rng().r#gen::<T>())
47}
48
49pub fn generate_random_alphanumeric(len: usize) -> String {
54 Alphanumeric.sample_string(&mut rand::thread_rng(), len)
55}
56
57pub fn pbkdf2(password: &[u8], salt: &[u8], rounds: u32) -> [u8; PBKDF_SHA256_HMAC_OUT_SIZE] {
59 pbkdf2::pbkdf2_array::<PbkdfSha256Hmac, PBKDF_SHA256_HMAC_OUT_SIZE>(password, salt, rounds)
60 .expect("hash is a valid fixed size")
61}
62
#[cfg(test)]
mod tests {
    use typenum::U64;

    use super::*;

    /// 32-byte pseudorandom key accepted by `hkdf_expand`, shared by the tests below.
    const VALID_PRK: [u8; 32] = [
        23, 152, 120, 41, 214, 16, 156, 133, 71, 226, 178, 135, 208, 255, 66, 101, 189, 70,
        173, 30, 39, 215, 175, 236, 38, 180, 180, 62, 196, 4, 159, 70,
    ];

    /// HKDF `info` value passed by every test.
    const INFO: Option<&str> = Some("info");

    #[test]
    fn test_hkdf_expand() {
        // Expected 64-byte output for VALID_PRK with info "info".
        const EXPECTED: [u8; 64] = [
            6, 114, 42, 38, 87, 231, 30, 109, 30, 255, 104, 129, 255, 94, 92, 108, 124, 145, 215,
            208, 17, 60, 135, 22, 70, 158, 40, 53, 45, 182, 8, 63, 65, 87, 239, 234, 185, 227, 153,
            122, 115, 205, 144, 56, 102, 149, 92, 139, 217, 102, 119, 57, 37, 57, 251, 178, 18, 52,
            94, 77, 132, 215, 239, 100,
        ];

        let okm: Pin<Box<GenericArray<u8, U64>>> = hkdf_expand(&VALID_PRK, INFO).unwrap();

        assert_eq!(okm.as_slice(), EXPECTED);
    }

    #[test]
    fn test_hkdf_expand_invalid_input_length() {
        // A 5-byte key is too short to be accepted as an HKDF-SHA256 PRK.
        let short_prk: &[u8] = &[1, 2, 3, 4, 5];

        let outcome: Result<Pin<Box<GenericArray<u8, U64>>>, HkdfExpandError> =
            hkdf_expand(short_prk, INFO);

        assert!(matches!(outcome, Err(HkdfExpandError::InvalidInputLegth)));
    }

    #[test]
    fn test_hkdf_expand_invalid_output_length() {
        // 8192 bytes is more output than HKDF-SHA256 expand allows.
        type TooLarge = typenum::U8192;

        let outcome: Result<Pin<Box<GenericArray<u8, TooLarge>>>, HkdfExpandError> =
            hkdf_expand(&VALID_PRK, INFO);

        assert!(matches!(outcome, Err(HkdfExpandError::InvalidOutputLength)));
    }
}
116}