bitwarden_state/
registry.rs1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::{Arc, RwLock},
5};
6
7use thiserror::Error;
8
9use crate::repository::{Repository, RepositoryItem};
10
/// Type-erased storage for client-managed repositories.
///
/// Keys are `TypeId`s of repository item types; values are boxed `Arc<dyn Repository<T>>`
/// handles that must be downcast back to their concrete type on lookup.
type ErasedRepositories = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// Registry of repositories, looked up by the item type they store.
///
/// Repositories are held behind an `RwLock`, so registration and lookup only need `&self`.
pub struct StateRegistry {
    /// Client-managed repositories, keyed by item type.
    client_managed: RwLock<ErasedRepositories>,
}
16
17impl std::fmt::Debug for StateRegistry {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("StateRegistry").finish()
20 }
21}
22
23#[derive(Debug, Error)]
25#[error("Repository not found for the requested type")]
26pub struct RepositoryNotFoundError;
27
28impl StateRegistry {
29 #[allow(clippy::new_without_default)]
31 pub fn new() -> Self {
32 StateRegistry {
33 client_managed: RwLock::new(HashMap::new()),
34 }
35 }
36
37 pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
39 self.client_managed
40 .write()
41 .expect("RwLock should not be poisoned")
42 .insert(TypeId::of::<T>(), Box::new(value));
43 }
44
45 pub fn get_client_managed<T: RepositoryItem>(
47 &self,
48 ) -> Result<Arc<dyn Repository<T>>, RepositoryNotFoundError> {
49 self.client_managed
50 .read()
51 .expect("RwLock should not be poisoned")
52 .get(&TypeId::of::<T>())
53 .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
54 .map(Arc::clone)
55 .ok_or(RepositoryNotFoundError)
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    // Implements `Repository<$ty>` for the tuple struct `$name`.
    //
    // Only `get` is functional: it ignores the key and returns `TestItem` wrapping a clone
    // of the struct's inner value. The other methods are never exercised by these tests
    // and panic via `unimplemented!()`.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    // Three repositories, one per item type, so lookups by type can be told apart.
    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug)]
    struct TestItem<T>(T);

    // NOTE(review): presumably implements `RepositoryItem` for each item type under the
    // given name — confirm against the macro definition.
    register_repository_item!(TestItem<usize>, "TestItem<usize>");
    register_repository_item!(TestItem<String>, "TestItem<String>");
    register_repository_item!(TestItem<Vec<u8>>, "TestItem<Vec<u8>>");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_repository_map() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        // Looks up the repository for `T` (inferred from the comparison) and fetches its item.
        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T> {
            map.get_client_managed::<T>()
                .unwrap()
                .get(String::new())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_err());
        assert!(map.get_client_managed::<TestItem<String>>().is_err());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_err());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    // Registering a second repository for the same item type replaces the first.
    #[tokio::test]
    async fn test_register_replaces_existing() {
        let map = StateRegistry::new();

        map.register_client_managed::<TestItem<usize>>(Arc::new(TestA(1)));
        map.register_client_managed::<TestItem<usize>>(Arc::new(TestA(2)));

        let repository = map.get_client_managed::<TestItem<usize>>().unwrap();
        assert_eq!(
            repository.get(String::new()).await.unwrap(),
            Some(TestItem(2))
        );

        // Other item types remain unregistered.
        assert!(map.get_client_managed::<TestItem<String>>().is_err());
    }
}