1use std::{
2 cell::Cell,
3 sync::{RwLockReadGuard, RwLockWriteGuard},
4};
5
6use zeroize::Zeroizing;
7
8use super::KeyStoreInner;
9use crate::{
10 derive_shareable_key, error::UnsupportedOperation, store::backend::StoreBackend,
11 AsymmetricCryptoKey, CryptoError, EncString, KeyId, KeyIds, Result, SymmetricCryptoKey,
12 UnsignedSharedKey,
13};
14
/// A context for working with the keys of a key store.
///
/// It holds a lock guard over the store's global keys (read-only or read-write, see
/// `GlobalKeys`) plus two context-local stores, one for symmetric and one for asymmetric keys.
/// The key accessors on this type route ids whose `is_local()` returns `true` to the local
/// stores and all other ids to the global store.
#[must_use]
pub struct KeyStoreContext<'a, Ids: KeyIds> {
    /// Guard over the shared `KeyStoreInner`; it decides whether global keys may be modified.
    pub(super) global_keys: GlobalKeys<'a, Ids>,

    /// Keys that exist only inside this context, kept apart from the global store.
    pub(super) local_symmetric_keys: Box<dyn StoreBackend<Ids::Symmetric>>,
    pub(super) local_asymmetric_keys: Box<dyn StoreBackend<Ids::Asymmetric>>,

    // Marker only: `Cell<()>` makes this type `!Sync` and `RwLockReadGuard` makes it `!Send`,
    // so the context cannot be shared with or moved to another thread regardless of what the
    // other fields allow.
    pub(super) _phantom: std::marker::PhantomData<(Cell<()>, RwLockReadGuard<'static, ()>)>,
}
73
/// Access to a key store's global keys through a held `RwLock` guard.
///
/// The variant records how the lock was taken. Mutable access via `get_mut` is only granted
/// in the `ReadWrite` case.
pub(crate) enum GlobalKeys<'a, Ids: KeyIds> {
    /// Shared read access. Attempts to mutate yield `CryptoError::ReadOnlyKeyStore`.
    ReadOnly(RwLockReadGuard<'a, KeyStoreInner<Ids>>),
    /// Exclusive write access.
    ReadWrite(RwLockWriteGuard<'a, KeyStoreInner<Ids>>),
}
83
84impl<Ids: KeyIds> GlobalKeys<'_, Ids> {
85 pub fn get(&self) -> &KeyStoreInner<Ids> {
86 match self {
87 GlobalKeys::ReadOnly(keys) => keys,
88 GlobalKeys::ReadWrite(keys) => keys,
89 }
90 }
91
92 pub fn get_mut(&mut self) -> Result<&mut KeyStoreInner<Ids>> {
93 match self {
94 GlobalKeys::ReadOnly(_) => Err(CryptoError::ReadOnlyKeyStore),
95 GlobalKeys::ReadWrite(keys) => Ok(keys),
96 }
97 }
98}
99
100impl<Ids: KeyIds> KeyStoreContext<'_, Ids> {
101 pub fn clear_local(&mut self) {
105 self.local_symmetric_keys.clear();
106 self.local_asymmetric_keys.clear();
107 }
108
109 pub fn retain_symmetric_keys(&mut self, f: fn(Ids::Symmetric) -> bool) {
112 if let Ok(keys) = self.global_keys.get_mut() {
113 keys.symmetric_keys.retain(f);
114 }
115 self.local_symmetric_keys.retain(f);
116 }
117
118 pub fn retain_asymmetric_keys(&mut self, f: fn(Ids::Asymmetric) -> bool) {
121 if let Ok(keys) = self.global_keys.get_mut() {
122 keys.asymmetric_keys.retain(f);
123 }
124 self.local_asymmetric_keys.retain(f);
125 }
126
127 pub fn unwrap_symmetric_key(
140 &mut self,
141 encryption_key: Ids::Symmetric,
142 new_key_id: Ids::Symmetric,
143 encrypted_key: &EncString,
144 ) -> Result<Ids::Symmetric> {
145 let mut new_key_material =
146 self.decrypt_data_with_symmetric_key(encryption_key, encrypted_key)?;
147
148 #[allow(deprecated)]
149 self.set_symmetric_key(
150 new_key_id,
151 SymmetricCryptoKey::try_from(new_key_material.as_mut_slice())?,
152 )?;
153
154 Ok(new_key_id)
156 }
157
158 pub fn wrap_symmetric_key(
167 &self,
168 wrapping_key: Ids::Symmetric,
169 key_to_wrap: Ids::Symmetric,
170 ) -> Result<EncString> {
171 use SymmetricCryptoKey::*;
172
173 let wrapping_key_instance = self.get_symmetric_key(wrapping_key)?;
174 let key_to_wrap_instance = self.get_symmetric_key(key_to_wrap)?;
175 match (wrapping_key_instance, key_to_wrap_instance) {
181 (Aes256CbcHmacKey(_), Aes256CbcHmacKey(_) | Aes256CbcKey(_)) => self
182 .encrypt_data_with_symmetric_key(
183 wrapping_key,
184 key_to_wrap_instance.to_encoded().as_slice(),
185 ),
186 _ => Err(CryptoError::OperationNotSupported(
187 UnsupportedOperation::EncryptionNotImplementedForKey,
188 )),
189 }
190 }
191
192 pub fn decapsulate_key_unsigned(
202 &mut self,
203 decapsulation_key: Ids::Asymmetric,
204 new_key_id: Ids::Symmetric,
205 encapsulated_shared_key: &UnsignedSharedKey,
206 ) -> Result<Ids::Symmetric> {
207 let decapsulation_key = self.get_asymmetric_key(decapsulation_key)?;
208 let decapsulated_key =
209 encapsulated_shared_key.decapsulate_key_unsigned(decapsulation_key)?;
210
211 #[allow(deprecated)]
212 self.set_symmetric_key(new_key_id, decapsulated_key)?;
213
214 Ok(new_key_id)
216 }
217
218 pub fn encapsulate_key_unsigned(
227 &self,
228 encapsulation_key: Ids::Asymmetric,
229 shared_key: Ids::Symmetric,
230 ) -> Result<UnsignedSharedKey> {
231 UnsignedSharedKey::encapsulate_key_unsigned(
232 self.get_symmetric_key(shared_key)?,
233 self.get_asymmetric_key(encapsulation_key)?,
234 )
235 }
236
237 pub fn has_symmetric_key(&self, key_id: Ids::Symmetric) -> bool {
239 self.get_symmetric_key(key_id).is_ok()
240 }
241
242 pub fn has_asymmetric_key(&self, key_id: Ids::Asymmetric) -> bool {
244 self.get_asymmetric_key(key_id).is_ok()
245 }
246
247 pub fn generate_symmetric_key(&mut self, key_id: Ids::Symmetric) -> Result<Ids::Symmetric> {
249 let key = SymmetricCryptoKey::make_aes256_cbc_hmac_key();
250 #[allow(deprecated)]
251 self.set_symmetric_key(key_id, key)?;
252 Ok(key_id)
253 }
254
255 pub fn derive_shareable_key(
260 &mut self,
261 key_id: Ids::Symmetric,
262 secret: Zeroizing<[u8; 16]>,
263 name: &str,
264 info: Option<&str>,
265 ) -> Result<Ids::Symmetric> {
266 #[allow(deprecated)]
267 self.set_symmetric_key(
268 key_id,
269 SymmetricCryptoKey::Aes256CbcHmacKey(derive_shareable_key(secret, name, info)),
270 )?;
271 Ok(key_id)
272 }
273
274 #[deprecated(note = "This function should ideally never be used outside this crate")]
275 pub fn dangerous_get_symmetric_key(
276 &self,
277 key_id: Ids::Symmetric,
278 ) -> Result<&SymmetricCryptoKey> {
279 self.get_symmetric_key(key_id)
280 }
281
282 #[deprecated(note = "This function should ideally never be used outside this crate")]
283 pub fn dangerous_get_asymmetric_key(
284 &self,
285 key_id: Ids::Asymmetric,
286 ) -> Result<&AsymmetricCryptoKey> {
287 self.get_asymmetric_key(key_id)
288 }
289
290 fn get_symmetric_key(&self, key_id: Ids::Symmetric) -> Result<&SymmetricCryptoKey> {
291 if key_id.is_local() {
292 self.local_symmetric_keys.get(key_id)
293 } else {
294 self.global_keys.get().symmetric_keys.get(key_id)
295 }
296 .ok_or_else(|| crate::CryptoError::MissingKeyId(format!("{key_id:?}")))
297 }
298
299 fn get_asymmetric_key(&self, key_id: Ids::Asymmetric) -> Result<&AsymmetricCryptoKey> {
300 if key_id.is_local() {
301 self.local_asymmetric_keys.get(key_id)
302 } else {
303 self.global_keys.get().asymmetric_keys.get(key_id)
304 }
305 .ok_or_else(|| crate::CryptoError::MissingKeyId(format!("{key_id:?}")))
306 }
307
308 #[deprecated(note = "This function should ideally never be used outside this crate")]
309 pub fn set_symmetric_key(
310 &mut self,
311 key_id: Ids::Symmetric,
312 key: SymmetricCryptoKey,
313 ) -> Result<()> {
314 if key_id.is_local() {
315 self.local_symmetric_keys.upsert(key_id, key);
316 } else {
317 self.global_keys
318 .get_mut()?
319 .symmetric_keys
320 .upsert(key_id, key);
321 }
322 Ok(())
323 }
324
325 #[deprecated(note = "This function should ideally never be used outside this crate")]
326 pub fn set_asymmetric_key(
327 &mut self,
328 key_id: Ids::Asymmetric,
329 key: AsymmetricCryptoKey,
330 ) -> Result<()> {
331 if key_id.is_local() {
332 self.local_asymmetric_keys.upsert(key_id, key);
333 } else {
334 self.global_keys
335 .get_mut()?
336 .asymmetric_keys
337 .upsert(key_id, key);
338 }
339 Ok(())
340 }
341
342 pub(crate) fn decrypt_data_with_symmetric_key(
343 &self,
344 key: Ids::Symmetric,
345 data: &EncString,
346 ) -> Result<Vec<u8>> {
347 let key = self.get_symmetric_key(key)?;
348
349 match (data, key) {
350 (EncString::Aes256Cbc_B64 { iv, data }, SymmetricCryptoKey::Aes256CbcKey(key)) => {
351 crate::aes::decrypt_aes256(iv, data.clone(), &key.enc_key)
352 }
353 (
354 EncString::Aes256Cbc_HmacSha256_B64 { iv, mac, data },
355 SymmetricCryptoKey::Aes256CbcHmacKey(key),
356 ) => crate::aes::decrypt_aes256_hmac(iv, mac, data.clone(), &key.mac_key, &key.enc_key),
357 _ => Err(CryptoError::InvalidKey),
358 }
359 }
360
361 pub(crate) fn encrypt_data_with_symmetric_key(
362 &self,
363 key: Ids::Symmetric,
364 data: &[u8],
365 ) -> Result<EncString> {
366 let key = self.get_symmetric_key(key)?;
367 match key {
368 SymmetricCryptoKey::Aes256CbcKey(_) => Err(CryptoError::OperationNotSupported(
369 UnsupportedOperation::EncryptionNotImplementedForKey,
370 )),
371 SymmetricCryptoKey::Aes256CbcHmacKey(key) => EncString::encrypt_aes256_hmac(data, key),
372 SymmetricCryptoKey::XChaCha20Poly1305Key(key) => {
373 EncString::encrypt_xchacha20_poly1305(data, key)
374 }
375 }
376 }
377}
378
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use crate::{
        store::{tests::DataView, KeyStore},
        traits::tests::{TestIds, TestSymmKey},
        Decryptable, Encryptable, SymmetricCryptoKey,
    };

    /// A key set through a mutable context is visible to a later context, and data encrypted
    /// with it decrypts back to the original plaintext.
    #[test]
    fn test_set_keys_for_encryption() {
        let store: KeyStore<TestIds> = KeyStore::default();

        let key_a0_id = TestSymmKey::A(0);
        let key_a0 = SymmetricCryptoKey::make_aes256_cbc_hmac_key();

        // The mutable context is a temporary, dropped at the end of this statement and so
        // before the next context is created.
        store
            .context_mut()
            .set_symmetric_key(key_a0_id, key_a0)
            .unwrap();

        let mut ctx = store.context();
        assert!(ctx.has_symmetric_key(key_a0_id));

        let data = DataView("Hello, World!".to_string(), key_a0_id);
        let encrypted = data.encrypt(&mut ctx, key_a0_id).unwrap();
        let decrypted = encrypted.decrypt(&mut ctx, key_a0_id).unwrap();
        assert_eq!(decrypted.0, "Hello, World!");
    }

    /// Wrapping a key and unwrapping it into a new slot yields a usable copy of that key.
    #[test]
    fn test_key_encryption() {
        let store: KeyStore<TestIds> = KeyStore::default();

        let mut ctx = store.context();

        let key_1_id = TestSymmKey::C(1);
        ctx.set_symmetric_key(key_1_id, SymmetricCryptoKey::make_aes256_cbc_hmac_key())
            .unwrap();
        assert!(ctx.has_symmetric_key(key_1_id));

        let key_2_id = TestSymmKey::C(2);
        ctx.set_symmetric_key(key_2_id, SymmetricCryptoKey::make_aes256_cbc_hmac_key())
            .unwrap();
        assert!(ctx.has_symmetric_key(key_2_id));

        // Wrap key 2 with key 1, then unwrap the result into an empty slot.
        let key_2_enc = ctx.wrap_symmetric_key(key_1_id, key_2_id).unwrap();

        let new_key_id = TestSymmKey::C(3);
        assert!(!ctx.has_symmetric_key(new_key_id));

        ctx.unwrap_symmetric_key(key_1_id, new_key_id, &key_2_enc)
            .unwrap();
        assert!(ctx.has_symmetric_key(new_key_id));

        // Data encrypted with the original key must decrypt identically with the unwrapped copy.
        let data = DataView("Hello, World!".to_string(), key_2_id);
        let encrypted = data.encrypt(&mut ctx, key_2_id).unwrap();

        let decrypted1 = encrypted.decrypt(&mut ctx, key_2_id).unwrap();
        let decrypted2 = encrypted.decrypt(&mut ctx, new_key_id).unwrap();

        assert_eq!(decrypted1.0, "Hello, World!");
        assert_eq!(decrypted1.0, decrypted2.0);
    }
}