bitwarden_state/
registry.rs1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::{Arc, RwLock},
5};
6
7use crate::repository::{Repository, RepositoryItem};
8
/// Type-erased storage: maps the `TypeId` of an item type to a boxed repository handle for it.
type ClientManagedMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// Registry of client-managed repositories, looked up by the item type they store.
///
/// Entries are type-erased behind `dyn Any + Send + Sync` and guarded by an `RwLock`, so the
/// registry can be shared by reference while repositories are registered and retrieved.
pub struct StateRegistry {
    /// Repositories registered through `register_client_managed`, keyed by item `TypeId`.
    client_managed: RwLock<ClientManagedMap>,
}
14
15impl std::fmt::Debug for StateRegistry {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 f.debug_struct("StateRegistry").finish()
18 }
19}
20
21impl StateRegistry {
22 #[allow(clippy::new_without_default)]
24 pub fn new() -> Self {
25 StateRegistry {
26 client_managed: RwLock::new(HashMap::new()),
27 }
28 }
29
30 pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
32 self.client_managed
33 .write()
34 .expect("RwLock should not be poisoned")
35 .insert(TypeId::of::<T>(), Box::new(value));
36 }
37
38 pub fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
40 self.client_managed
41 .read()
42 .expect("RwLock should not be poisoned")
43 .get(&TypeId::of::<T>())
44 .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
45 .map(Arc::clone)
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    /// Implements `Repository<$ty>` for the single-field test struct `$name`.
    ///
    /// Only `get` is functional: it ignores the key and returns a `TestItem` wrapping a clone
    /// of the struct's field. The other methods are not exercised by these tests.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    // One repository type per item type, so each registry slot can be told apart by its payload.
    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug)]
    struct TestItem<T>(T);

    register_repository_item!(TestItem<usize>, "TestItem<usize>");
    register_repository_item!(TestItem<String>, "TestItem<String>");
    register_repository_item!(TestItem<Vec<u8>>, "TestItem<Vec<u8>>");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    /// Each item type gets its own slot: registering one repository never makes another
    /// item type's lookup succeed, and earlier registrations stay reachable.
    #[tokio::test]
    async fn test_repository_map() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        // Looks up the repository for `T` (which must be registered) and calls `get` on it.
        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T> {
            map.get_client_managed::<T>()
                .unwrap()
                .get(String::new())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    /// Registering a second repository for the same item type replaces the first one and
    /// leaves repositories registered for other item types untouched.
    #[tokio::test]
    async fn test_register_replaces_existing_repository() {
        let map = StateRegistry::new();
        map.register_client_managed::<TestItem<String>>(Arc::new(TestB("other".to_string())));

        map.register_client_managed::<TestItem<usize>>(Arc::new(TestA(1)));
        map.register_client_managed::<TestItem<usize>>(Arc::new(TestA(2)));

        let usize_repo = map
            .get_client_managed::<TestItem<usize>>()
            .expect("repository for TestItem<usize> should be registered");
        assert_eq!(usize_repo.get(String::new()).await.unwrap(), Some(TestItem(2)));

        let string_repo = map
            .get_client_managed::<TestItem<String>>()
            .expect("repository for TestItem<String> should be registered");
        assert_eq!(
            string_repo.get(String::new()).await.unwrap(),
            Some(TestItem("other".to_string()))
        );
    }
}