1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::{Arc, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use thiserror::Error;
9
10use crate::{
11 repository::{Repository, RepositoryItem, RepositoryMigrations},
12 sdk_managed::{Database, DatabaseConfiguration, DatabaseError, MemoryDatabase, SystemDatabase},
13 settings::{Key, Setting, SettingItem},
14};
15
/// Registry that resolves which [`Repository`] to use for each [`RepositoryItem`] type.
///
/// Client-managed repositories registered at runtime take precedence. When none is registered
/// for a type, lookups fall back to the SDK-managed repository provided by the registry's
/// database.
pub struct StateRegistry {
    /// Database that supplies the SDK-managed repositories.
    database: SystemDatabase,
    /// Client-managed repositories, keyed by the [`TypeId`] of the item type they store.
    /// Each value is an `Arc<dyn Repository<T>>` boxed as [`Any`] so it can be downcast back
    /// to the repository type for `T` on lookup.
    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}
22
23impl std::fmt::Debug for StateRegistry {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 f.debug_struct("StateRegistry").finish()
26 }
27}
28
/// Errors returned by [`StateRegistry`] operations.
// NOTE(review): `missing_docs` is allowed here presumably because items generated by
// `#[bitwarden_error(flat)]` would otherwise trip the lint — confirm before removing it.
#[allow(missing_docs)]
#[bitwarden_error(flat)]
#[derive(Debug, Error)]
pub enum StateRegistryError {
    /// The database has not been initialized.
    // NOTE(review): not constructed anywhere in this file; presumably used elsewhere.
    #[error("Database is not initialized")]
    DatabaseNotInitialized,

    /// An error from the database layer. Display and `source()` are forwarded to the inner
    /// [`DatabaseError`] (`#[error(transparent)]`), and `#[from]` lets `?` convert it.
    #[error(transparent)]
    Database(#[from] DatabaseError),
}
39
40impl StateRegistry {
41 pub fn new_with_memory_db() -> Self {
43 StateRegistry {
44 database: SystemDatabase::Memory(MemoryDatabase::new()),
45 client_managed: RwLock::new(HashMap::new()),
46 }
47 }
48
49 pub async fn new_with_db(
51 configuration: DatabaseConfiguration,
52 migrations: RepositoryMigrations,
53 ) -> Result<Self, DatabaseError> {
54 let database = SystemDatabase::initialize(configuration, migrations.clone()).await?;
55 Ok(StateRegistry {
56 database,
57 client_managed: RwLock::new(HashMap::new()),
58 })
59 }
60
61 pub fn setting<T>(&self, key: Key<T>) -> Result<Setting<T>, StateRegistryError> {
63 let repo = self.get::<SettingItem>()?;
64 Ok(Setting::new(repo, key))
65 }
66
67 pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
69 self.client_managed
70 .write()
71 .expect("RwLock should not be poisoned")
72 .insert(TypeId::of::<T>(), Box::new(value));
73 }
74
75 fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
77 self.client_managed
78 .read()
79 .expect("RwLock should not be poisoned")
80 .get(&TypeId::of::<T>())
81 .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
82 .map(Arc::clone)
83 }
84
85 fn get_sdk_managed<T: RepositoryItem>(
87 &self,
88 ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
89 Ok(self.database.get_repository::<T>())
90 }
91
92 pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
100 where
101 T: RepositoryItem,
102 {
103 if let Some(repo) = self.get_client_managed::<T>() {
104 return Ok(repo);
105 }
106
107 self.get_sdk_managed::<T>()
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    // Implements `Repository<$ty>` for a single-field test struct. Only `get` is functional:
    // it ignores the key and always returns `TestItem(self.0.clone())`. Every other method is
    // `unimplemented!()` because the tests below only ever call `get` on these repositories.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn set_bulk(
                    &self,
                    _values: Vec<(String, $ty)>,
                ) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_bulk(&self, _keys: Vec<String>) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_all(&self) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    use serde::{Deserialize, Serialize};

    // Test repositories, one per item type. Each holds the value its `get` returns.
    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    // Generic item wrapper. Each instantiation has its own `TypeId`, so each one occupies a
    // separate client-managed slot in the registry.
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    // Register each `TestItem<_>` instantiation as a repository item under a unique name.
    register_repository_item!(String => TestItem<usize>, "TestItem_usize");
    register_repository_item!(String => TestItem<String>, "TestItem_String");
    register_repository_item!(String => TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    // Registering a client-managed repository for one item type makes it retrievable, and it
    // does not affect lookups for the other item types.
    #[tokio::test]
    async fn test_state_registry() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new_with_memory_db();

        // Fetches the default key from the client-managed repository for `T`. It panics (via
        // `unwrap`) if no repository is registered for `T` or if the lookup fails. `T` is
        // inferred from the value each call is compared against.
        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T>
        where
            T::Key: Default,
        {
            map.get_client_managed::<T>()
                .unwrap()
                .get(Default::default())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    // When a client-managed repository is registered, the public `get` returns it instead of
    // falling back to the SDK-managed database.
    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new_with_memory_db();
        let test_repo = Arc::new(TestA(12345));

        registry.register_client_managed(test_repo.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(12345)));
    }

    // With nothing registered, `get` falls back to the in-memory database, and a lookup
    // through that repository succeeds.
    #[tokio::test]
    async fn test_new_with_memory_db_sync() {
        let registry = StateRegistry::new_with_memory_db();
        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await;
        assert!(result.is_ok());
    }

    // Round-trips a setting through the in-memory database: it starts unset, can be updated,
    // and reads back as unset after deletion.
    #[tokio::test]
    async fn test_setting_on_memory_db() {
        use crate::register_setting_key;
        register_setting_key!(const TEST_SETTING: String = "test_registry_setting_key");

        let registry = StateRegistry::new_with_memory_db();
        let setting = registry.setting(TEST_SETTING).unwrap();

        assert_eq!(setting.get().await.unwrap(), None::<String>);

        setting.update("hello".to_string()).await.unwrap();
        assert_eq!(setting.get().await.unwrap(), Some("hello".to_string()));

        setting.delete().await.unwrap();
        assert_eq!(setting.get().await.unwrap(), None::<String>);
    }
}