use std::collections::BTreeSet;
use rand::{distributions::Distribution, seq::SliceRandom, RngCore};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors reported when a password generation request fails validation.
#[derive(Debug, Error)]
pub enum PasswordError {
    /// Every character set (lowercase, uppercase, numbers, special) was disabled.
    #[error("No character set enabled")]
    NoCharacterSetEnabled,
    /// The requested length is below 4, or is smaller than the sum of the effective
    /// per-set minimums.
    #[error("Invalid password length")]
    InvalidLength,
}
/// Options describing the password to generate.
#[derive(Serialize, Deserialize, Debug, JsonSchema)]
#[serde(rename_all = "camelCase", deny_unknown_fields)]
#[cfg_attr(feature = "uniffi", derive(uniffi::Record))]
pub struct PasswordGeneratorRequest {
    /// Include lowercase letters (a-z).
    pub lowercase: bool,
    /// Include uppercase letters (A-Z).
    pub uppercase: bool,
    /// Include digits (0-9).
    pub numbers: bool,
    /// Include special characters: ! @ # $ % ^ & *
    pub special: bool,
    /// Total password length. Validation rejects values below 4 and values smaller than
    /// the sum of the effective per-set minimums.
    pub length: u8,
    /// When true, the ambiguous characters I, O, l, 0 and 1 are never used.
    pub avoid_ambiguous: bool,
    /// Minimum number of lowercase letters. While `lowercase` is enabled, `None` and
    /// `Some(0)` are both treated as 1; ignored while `lowercase` is disabled.
    pub min_lowercase: Option<u8>,
    /// Minimum number of uppercase letters. While `uppercase` is enabled, `None` and
    /// `Some(0)` are both treated as 1; ignored while `uppercase` is disabled.
    pub min_uppercase: Option<u8>,
    /// Minimum number of digits. While `numbers` is enabled, `None` and `Some(0)` are
    /// both treated as 1; ignored while `numbers` is disabled.
    pub min_number: Option<u8>,
    /// Minimum number of special characters. While `special` is enabled, `None` and
    /// `Some(0)` are both treated as 1; ignored while `special` is disabled.
    pub min_special: Option<u8>,
}
/// Password length used by `PasswordGeneratorRequest::default()`.
const DEFAULT_PASSWORD_LENGTH: u8 = 16;

impl Default for PasswordGeneratorRequest {
    /// A `DEFAULT_PASSWORD_LENGTH`-character request drawing from lowercase, uppercase and
    /// digits (special characters off), allowing ambiguous characters, with no explicit
    /// per-set minimums.
    fn default() -> Self {
        let no_minimum: Option<u8> = None;
        Self {
            length: DEFAULT_PASSWORD_LENGTH,
            avoid_ambiguous: false,
            lowercase: true,
            uppercase: true,
            numbers: true,
            special: false,
            min_lowercase: no_minimum,
            min_uppercase: no_minimum,
            min_number: no_minimum,
            min_special: no_minimum,
        }
    }
}
/// Uppercase letters removed from the uppercase set when `avoid_ambiguous` is enabled.
const UPPER_CHARS_AMBIGUOUS: &[char] = &['I', 'O'];
/// Lowercase letters removed from the lowercase set when `avoid_ambiguous` is enabled.
const LOWER_CHARS_AMBIGUOUS: &[char] = &['l'];
/// Digits removed from the number set when `avoid_ambiguous` is enabled.
const NUMBER_CHARS_AMBIGUOUS: &[char] = &['0', '1'];
/// The complete special-character set. No ambiguous subset is defined for it, so
/// `avoid_ambiguous` does not affect special characters.
const SPECIAL_CHARS: &[char] = &['!', '@', '#', '$', '%', '^', '&', '*'];
/// A de-duplicated pool of characters a password may be drawn from.
///
/// Backed by a `BTreeSet`, so characters are kept in ascending order and iteration order
/// (and therefore index-based sampling) is deterministic.
#[derive(Clone, Default)]
struct CharSet(BTreeSet<char>);

impl CharSet {
    /// Adds every character yielded by `other`.
    pub fn include(mut self, other: impl IntoIterator<Item = char>) -> Self {
        self.0.extend(other);
        self
    }

    /// Adds every character yielded by `other`, but only when `predicate` is true.
    pub fn include_if(self, predicate: bool, other: impl IntoIterator<Item = char>) -> Self {
        if predicate {
            self.include(other)
        } else {
            self
        }
    }

    /// Removes every character referenced by `other`, but only when `predicate` is true.
    pub fn exclude_if<'a>(
        mut self,
        predicate: bool,
        other: impl IntoIterator<Item = &'a char>,
    ) -> Self {
        if !predicate {
            return self;
        }
        let unwanted: BTreeSet<char> = other.into_iter().copied().collect();
        self.0.retain(|c| !unwanted.contains(c));
        self
    }
}

impl<'a> IntoIterator for &'a CharSet {
    type Item = char;
    type IntoIter = std::iter::Copied<std::collections::btree_set::Iter<'a, char>>;

    /// Yields the characters by value, in ascending order.
    fn into_iter(self) -> Self::IntoIter {
        let CharSet(chars) = self;
        chars.iter().copied()
    }
}
impl Distribution<char> for CharSet {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> char {
let idx = rng.gen_range(0..self.0.len());
*self.0.iter().nth(idx).expect("Valid index")
}
}
/// Validated generator configuration, produced by `PasswordGeneratorRequest::validate_options`.
///
/// Each tuple pairs a character set with the number of characters to draw from it.
struct PasswordGeneratorOptions {
    /// Lowercase set and its effective minimum count (0 when lowercase is disabled).
    pub(super) lower: (CharSet, usize),
    /// Uppercase set and its effective minimum count (0 when uppercase is disabled).
    pub(super) upper: (CharSet, usize),
    /// Digit set and its effective minimum count (0 when numbers are disabled).
    pub(super) number: (CharSet, usize),
    /// Special-character set and its effective minimum count (0 when special is disabled).
    pub(super) special: (CharSet, usize),
    /// Union of all four sets above, and the number of characters left over once the
    /// minimums are satisfied (`length` minus the sum of the four minimum counts).
    pub(super) all: (CharSet, usize),
    /// Total number of characters in the generated password.
    pub(super) length: usize,
}
impl PasswordGeneratorRequest {
    /// Checks the request and resolves it into concrete character sets and per-set counts.
    ///
    /// # Errors
    /// - [`PasswordError::NoCharacterSetEnabled`] when every character set is disabled.
    /// - [`PasswordError::InvalidLength`] when `length` is below 4, or when the sum of the
    ///   effective per-set minimums exceeds `length`.
    fn validate_options(self) -> Result<PasswordGeneratorOptions, PasswordError> {
        let any_set_enabled = self.lowercase || self.uppercase || self.numbers || self.special;
        if !any_set_enabled {
            return Err(PasswordError::NoCharacterSetEnabled);
        }
        if self.length < 4 {
            return Err(PasswordError::InvalidLength);
        }

        // An enabled set always contributes at least one character (an explicit minimum of
        // zero is raised to one); a disabled set contributes none, whatever was requested.
        let minimum = |requested: Option<u8>, enabled: bool| -> usize {
            if !enabled {
                return 0;
            }
            requested.map_or(1, usize::from).max(1)
        };

        let min_lowercase = minimum(self.min_lowercase, self.lowercase);
        let min_uppercase = minimum(self.min_uppercase, self.uppercase);
        let min_number = minimum(self.min_number, self.numbers);
        let min_special = minimum(self.min_special, self.special);

        let length = usize::from(self.length);
        let required: usize = [min_lowercase, min_uppercase, min_number, min_special]
            .iter()
            .sum();
        if required > length {
            return Err(PasswordError::InvalidLength);
        }

        let skip_ambiguous = self.avoid_ambiguous;
        let lower_chars = CharSet::default()
            .include_if(self.lowercase, 'a'..='z')
            .exclude_if(skip_ambiguous, LOWER_CHARS_AMBIGUOUS);
        let upper_chars = CharSet::default()
            .include_if(self.uppercase, 'A'..='Z')
            .exclude_if(skip_ambiguous, UPPER_CHARS_AMBIGUOUS);
        let number_chars = CharSet::default()
            .include_if(self.numbers, '0'..='9')
            .exclude_if(skip_ambiguous, NUMBER_CHARS_AMBIGUOUS);
        let special_chars =
            CharSet::default().include_if(self.special, SPECIAL_CHARS.iter().copied());

        // The "fill" pool is every enabled character; it supplies whatever the minimums
        // leave unassigned.
        let all_chars = CharSet::default()
            .include(&lower_chars)
            .include(&upper_chars)
            .include(&number_chars)
            .include(&special_chars);

        Ok(PasswordGeneratorOptions {
            lower: (lower_chars, min_lowercase),
            upper: (upper_chars, min_uppercase),
            number: (number_chars, min_number),
            special: (special_chars, min_special),
            all: (all_chars, length - required),
            length,
        })
    }
}
/// Generates a password for `input` using `rand::thread_rng()`.
///
/// # Errors
/// Returns the [`PasswordError`] produced when `input` fails validation; no random
/// generator is created in that case.
pub(crate) fn password(input: PasswordGeneratorRequest) -> Result<String, PasswordError> {
    input
        .validate_options()
        .map(|options| password_with_rng(rand::thread_rng(), options))
}
/// Builds a password from already-validated `options`, drawing all randomness from `rng`.
///
/// Each set contributes exactly its configured count of characters; the result is then
/// shuffled so characters are not grouped by set.
fn password_with_rng(mut rng: impl RngCore, options: PasswordGeneratorOptions) -> String {
    // The draw order is significant: with a seeded RNG it fixes the exact output, so it
    // must remain all -> upper -> lower -> number -> special.
    let groups = [
        &options.all,
        &options.upper,
        &options.lower,
        &options.number,
        &options.special,
    ];

    let mut chars: Vec<char> = Vec::with_capacity(options.length);
    for (pool, count) in groups {
        for _ in 0..*count {
            chars.push(pool.sample(&mut rng));
        }
    }

    chars.shuffle(&mut rng);
    chars.into_iter().collect()
}
#[cfg(test)]
mod test {
    use std::collections::BTreeSet;

    use rand::SeedableRng;

    use super::*;

    /// Collects borrowed characters (e.g. one of the `*_CHARS` constants) into a set.
    fn ref_to_set<'a>(chars: impl IntoIterator<Item = &'a char>) -> BTreeSet<char> {
        chars.into_iter().copied().collect()
    }

    /// Collects owned characters (e.g. a range or a `&CharSet`) into a set.
    fn to_set(chars: impl IntoIterator<Item = char>) -> BTreeSet<char> {
        chars.into_iter().collect()
    }

    #[test]
    fn test_password_gen_all_charsets_enabled() {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed([0u8; 32]);
        let options = PasswordGeneratorRequest {
            lowercase: true,
            uppercase: true,
            numbers: true,
            special: true,
            avoid_ambiguous: false,
            ..Default::default()
        }
        .validate_options()
        .unwrap();
        assert_eq!(to_set(&options.lower.0), to_set('a'..='z'));
        assert_eq!(to_set(&options.upper.0), to_set('A'..='Z'));
        assert_eq!(to_set(&options.number.0), to_set('0'..='9'));
        assert_eq!(to_set(&options.special.0), ref_to_set(SPECIAL_CHARS));
        let pass = password_with_rng(&mut rng, options);
        assert_eq!(pass, "Z!^B5r%hUa23dFM@");
    }

    #[test]
    fn test_password_gen_only_letters_enabled() {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed([0u8; 32]);
        let options = PasswordGeneratorRequest {
            lowercase: true,
            uppercase: true,
            numbers: false,
            special: false,
            avoid_ambiguous: false,
            ..Default::default()
        }
        .validate_options()
        .unwrap();
        assert_eq!(to_set(&options.lower.0), to_set('a'..='z'));
        assert_eq!(to_set(&options.upper.0), to_set('A'..='Z'));
        assert_eq!(to_set(&options.number.0), to_set([]));
        assert_eq!(to_set(&options.special.0), to_set([]));
        let pass = password_with_rng(&mut rng, options);
        assert_eq!(pass, "NQiFrGufQMiNUAmj");
    }

    #[test]
    fn test_password_gen_only_numbers_and_lower_enabled_no_ambiguous() {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed([0u8; 32]);
        let options = PasswordGeneratorRequest {
            lowercase: true,
            uppercase: false,
            numbers: true,
            special: false,
            avoid_ambiguous: true,
            ..Default::default()
        }
        .validate_options()
        .unwrap();
        assert!(to_set(&options.lower.0).is_subset(&to_set('a'..='z')));
        assert!(to_set(&options.lower.0).is_disjoint(&ref_to_set(LOWER_CHARS_AMBIGUOUS)));
        assert!(to_set(&options.number.0).is_subset(&to_set('0'..='9')));
        assert!(to_set(&options.number.0).is_disjoint(&ref_to_set(NUMBER_CHARS_AMBIGUOUS)));
        assert_eq!(to_set(&options.upper.0), to_set([]));
        assert_eq!(to_set(&options.special.0), to_set([]));
        let pass = password_with_rng(&mut rng, options);
        assert_eq!(pass, "mnjabfz5ct272prf");
    }

    #[test]
    fn test_password_gen_only_upper_and_special_enabled_no_ambiguous() {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed([0u8; 32]);
        let options = PasswordGeneratorRequest {
            lowercase: false,
            uppercase: true,
            numbers: false,
            special: true,
            avoid_ambiguous: true,
            ..Default::default()
        }
        .validate_options()
        .unwrap();
        assert!(to_set(&options.upper.0).is_subset(&to_set('A'..='Z')));
        assert!(to_set(&options.upper.0).is_disjoint(&ref_to_set(UPPER_CHARS_AMBIGUOUS)));
        assert_eq!(to_set(&options.special.0), ref_to_set(SPECIAL_CHARS));
        assert_eq!(to_set(&options.lower.0), to_set([]));
        assert_eq!(to_set(&options.number.0), to_set([]));
        let pass = password_with_rng(&mut rng, options);
        assert_eq!(pass, "B*GBQANS%UZPQD!K");
    }

    #[test]
    fn test_password_gen_minimum_limits() {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed([0u8; 32]);
        let options = PasswordGeneratorRequest {
            lowercase: true,
            uppercase: true,
            numbers: true,
            special: true,
            avoid_ambiguous: false,
            length: 24,
            min_lowercase: Some(5),
            min_uppercase: Some(5),
            min_number: Some(5),
            min_special: Some(5),
        }
        .validate_options()
        .unwrap();
        assert_eq!(to_set(&options.lower.0), to_set('a'..='z'));
        assert_eq!(to_set(&options.upper.0), to_set('A'..='Z'));
        assert_eq!(to_set(&options.number.0), to_set('0'..='9'));
        assert_eq!(to_set(&options.special.0), ref_to_set(SPECIAL_CHARS));
        assert_eq!(options.lower.1, 5);
        assert_eq!(options.upper.1, 5);
        assert_eq!(options.number.1, 5);
        assert_eq!(options.special.1, 5);
        let pass = password_with_rng(&mut rng, options);
        assert_eq!(pass, "236q5!a#R%PG5rI%k1!*@uRt");
    }

    #[test]
    fn test_password_gen_no_charset_enabled() {
        let result = PasswordGeneratorRequest {
            lowercase: false,
            uppercase: false,
            numbers: false,
            special: false,
            ..Default::default()
        }
        .validate_options();
        assert!(matches!(result, Err(PasswordError::NoCharacterSetEnabled)));
    }

    #[test]
    fn test_password_gen_length_below_four() {
        // With the default three enabled sets the minimums sum to exactly 3, so only the
        // explicit "at least 4" rule can reject this request.
        let result = PasswordGeneratorRequest {
            length: 3,
            ..Default::default()
        }
        .validate_options();
        assert!(matches!(result, Err(PasswordError::InvalidLength)));
    }

    #[test]
    fn test_password_gen_minimums_exceed_length() {
        // 4 lowercase + 4 uppercase + 1 digit (numbers are enabled by default) > 8.
        let result = PasswordGeneratorRequest {
            length: 8,
            min_lowercase: Some(4),
            min_uppercase: Some(4),
            ..Default::default()
        }
        .validate_options();
        assert!(matches!(result, Err(PasswordError::InvalidLength)));
    }

    #[test]
    fn test_password_gen_minimum_normalization() {
        let options = PasswordGeneratorRequest {
            lowercase: true,
            uppercase: true,
            numbers: false,
            special: false,
            avoid_ambiguous: false,
            length: 10,
            min_lowercase: Some(0),
            min_uppercase: None,
            min_number: Some(3),
            min_special: Some(2),
        }
        .validate_options()
        .unwrap();
        // Enabled sets are raised to at least one; disabled sets ignore their minimum.
        assert_eq!(options.lower.1, 1);
        assert_eq!(options.upper.1, 1);
        assert_eq!(options.number.1, 0);
        assert_eq!(options.special.1, 0);
        assert_eq!(options.all.1, 8);
        assert_eq!(options.length, 10);

        let mut rng = rand_chacha::ChaCha8Rng::from_seed([0u8; 32]);
        let pass = password_with_rng(&mut rng, options);
        assert_eq!(pass.chars().count(), 10);
        assert!(pass.chars().all(|c| c.is_ascii_alphabetic()));
        assert!(pass.chars().any(|c| c.is_ascii_lowercase()));
        assert!(pass.chars().any(|c| c.is_ascii_uppercase()));
    }

    #[test]
    fn test_password_default_request() {
        let pass = password(PasswordGeneratorRequest::default()).unwrap();
        assert_eq!(pass.chars().count(), usize::from(DEFAULT_PASSWORD_LENGTH));
        assert!(pass.chars().all(|c| c.is_ascii_alphanumeric()));
        assert!(pass.chars().any(|c| c.is_ascii_lowercase()));
        assert!(pass.chars().any(|c| c.is_ascii_uppercase()));
        assert!(pass.chars().any(|c| c.is_ascii_digit()));
    }
}