1use std::{pin::Pin, str::FromStr};
2
3use ::aes::cipher::{ArrayLength, Unsigned};
4use generic_array::GenericArray;
5use hmac::digest::OutputSizeUser;
6use rand::{
7 distributions::{Alphanumeric, DistString, Distribution, Standard},
8 Rng,
9};
10use zeroize::{Zeroize, Zeroizing};
11
12use crate::{CryptoError, Result};
13
/// HMAC instantiated with SHA-256; used as the PRF for [`pbkdf2`].
pub(crate) type PbkdfSha256Hmac = hmac::Hmac<sha2::Sha256>;
/// Output length in bytes of [`PbkdfSha256Hmac`], computed at compile time from the type-level
/// `OutputSize` so the PBKDF2 output array length always matches the HMAC output.
pub(crate) const PBKDF_SHA256_HMAC_OUT_SIZE: usize =
    <<PbkdfSha256Hmac as OutputSizeUser>::OutputSize as Unsigned>::USIZE;
17
18pub(crate) fn hkdf_expand<T: ArrayLength<u8>>(
20 prk: &[u8],
21 info: Option<&str>,
22) -> Result<Pin<Box<GenericArray<u8, T>>>> {
23 let hkdf = hkdf::Hkdf::<sha2::Sha256>::from_prk(prk).map_err(|_| CryptoError::InvalidKeyLen)?;
24 let mut key = Box::<GenericArray<u8, T>>::default();
25
26 let i = info.map(|i| i.as_bytes()).unwrap_or(&[]);
27 hkdf.expand(i, &mut key)
28 .map_err(|_| CryptoError::InvalidKeyLen)?;
29
30 Ok(Box::into_pin(key))
31}
32
33pub fn generate_random_bytes<T>() -> Zeroizing<T>
35where
36 Standard: Distribution<T>,
37 T: Zeroize,
38{
39 Zeroizing::new(rand::thread_rng().r#gen::<T>())
40}
41
42pub fn generate_random_alphanumeric(len: usize) -> String {
47 Alphanumeric.sample_string(&mut rand::thread_rng(), len)
48}
49
50pub fn pbkdf2(password: &[u8], salt: &[u8], rounds: u32) -> [u8; PBKDF_SHA256_HMAC_OUT_SIZE] {
52 pbkdf2::pbkdf2_array::<PbkdfSha256Hmac, PBKDF_SHA256_HMAC_OUT_SIZE>(password, salt, rounds)
53 .expect("hash is a valid fixed size")
54}
55
56pub struct FromStrVisitor<T>(std::marker::PhantomData<T>);
58impl<T> FromStrVisitor<T> {
59 pub fn new() -> Self {
61 Self::default()
62 }
63}
64impl<T> Default for FromStrVisitor<T> {
65 fn default() -> Self {
66 Self(Default::default())
67 }
68}
69impl<T: FromStr> serde::de::Visitor<'_> for FromStrVisitor<T>
70where
71 T::Err: std::fmt::Debug,
72{
73 type Value = T;
74
75 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
76 write!(f, "a valid string")
77 }
78
79 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
80 where
81 E: serde::de::Error,
82 {
83 T::from_str(v).map_err(|e| E::custom(format!("{e:?}")))
84 }
85}
86
#[cfg(test)]
mod tests {
    use typenum::U64;

    use super::*;

    /// Known-answer test: a fixed 32-byte PRK with `info = "info"` must expand to this exact
    /// 64-byte output.
    #[test]
    fn test_hkdf_expand() {
        let prk = &[
            23, 152, 120, 41, 214, 16, 156, 133, 71, 226, 178, 135, 208, 255, 66, 101, 189, 70,
            173, 30, 39, 215, 175, 236, 38, 180, 180, 62, 196, 4, 159, 70,
        ];
        let info = Some("info");

        let result: Pin<Box<GenericArray<u8, U64>>> = hkdf_expand(prk, info).unwrap();

        let expected_output: [u8; 64] = [
            6, 114, 42, 38, 87, 231, 30, 109, 30, 255, 104, 129, 255, 94, 92, 108, 124, 145, 215,
            208, 17, 60, 135, 22, 70, 158, 40, 53, 45, 182, 8, 63, 65, 87, 239, 234, 185, 227, 153,
            122, 115, 205, 144, 56, 102, 149, 92, 139, 217, 102, 119, 57, 37, 57, 251, 178, 18, 52,
            94, 77, 132, 215, 239, 100,
        ];

        assert_eq!(result.as_slice(), expected_output);
    }

    /// A PRK shorter than the SHA-256 output length must be rejected through the
    /// `from_prk` error path, not silently expanded.
    #[test]
    fn test_hkdf_expand_rejects_short_prk() {
        let short_prk = [0u8; 16];

        let result = hkdf_expand::<U64>(&short_prk, None);

        assert!(matches!(result, Err(CryptoError::InvalidKeyLen)));
    }
}
112}