1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::{Arc, OnceLock, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use thiserror::Error;
9
10use crate::{
11 repository::{Repository, RepositoryItem, RepositoryItemData, RepositoryMigrations},
12 sdk_managed::{Database, DatabaseConfiguration, MemoryDatabase, SystemDatabase},
13 settings::{Key, Setting, SettingItem},
14};
15
16pub struct StateRegistry {
19 sdk_managed: RwLock<Vec<RepositoryItemData>>,
20 client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
21
22 database: OnceLock<SystemDatabase>,
23}
24
25impl std::fmt::Debug for StateRegistry {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("StateRegistry").finish()
28 }
29}
30
/// Errors returned by [`StateRegistry`] operations.
// NOTE(review): the reason for this allow is not visible here; the variants are now
// documented, so it may only be needed for items generated by `bitwarden_error`.
// Confirm before removing it.
#[allow(missing_docs)]
#[bitwarden_error(flat)]
#[derive(Debug, Error)]
pub enum StateRegistryError {
    /// Returned by `StateRegistry::initialize_database` when a database has already
    /// been set on the registry.
    #[error("Database is already initialized")]
    DatabaseAlreadyInitialized,
    /// Returned when an SDK-managed repository is requested but no database has been
    /// set on the registry.
    #[error("Database is not initialized")]
    DatabaseNotInitialized,

    /// An error from the SDK-managed database layer. `transparent` forwards its
    /// `Display` output and `source` unchanged.
    #[error(transparent)]
    Database(#[from] crate::sdk_managed::DatabaseError),
}
43
44impl StateRegistry {
45 #[allow(clippy::new_without_default)]
47 pub fn new() -> Self {
48 StateRegistry {
49 client_managed: RwLock::new(HashMap::new()),
50 database: OnceLock::new(),
51 sdk_managed: RwLock::new(Vec::new()),
52 }
53 }
54
55 pub fn new_with_memory_db() -> Self {
57 let registry = Self::new();
58 let _ = registry
62 .database
63 .set(SystemDatabase::Memory(MemoryDatabase::new()));
64 registry
65 }
66
67 pub async fn initialize_database(
79 &self,
80 configuration: DatabaseConfiguration,
81 migrations: RepositoryMigrations,
82 ) -> Result<(), StateRegistryError> {
83 if self.database.get().is_some() {
84 return Err(StateRegistryError::DatabaseAlreadyInitialized);
85 }
86 let _ = self
87 .database
88 .set(SystemDatabase::initialize(configuration, migrations.clone()).await?);
89
90 *self
91 .sdk_managed
92 .write()
93 .expect("RwLock should not be poisoned") = migrations.into_repository_items();
94
95 Ok(())
96 }
97
98 pub fn setting<T>(&self, key: Key<T>) -> Result<Setting<T>, StateRegistryError> {
100 let repo = self.get::<SettingItem>()?;
101 Ok(Setting::new(repo, key))
102 }
103
104 pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
106 self.client_managed
107 .write()
108 .expect("RwLock should not be poisoned")
109 .insert(TypeId::of::<T>(), Box::new(value));
110 }
111
112 fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
114 self.client_managed
115 .read()
116 .expect("RwLock should not be poisoned")
117 .get(&TypeId::of::<T>())
118 .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
119 .map(Arc::clone)
120 }
121
122 fn get_sdk_managed<T: RepositoryItem>(
124 &self,
125 ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
126 self.database
127 .get()
128 .map(|db| db.get_repository::<T>())
129 .ok_or(StateRegistryError::DatabaseNotInitialized)
130 }
131
132 pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
145 where
146 T: RepositoryItem,
147 {
148 if let Some(repo) = self.get_client_managed::<T>() {
149 return Ok(repo);
150 }
151
152 self.get_sdk_managed::<T>()
153 }
154}
155
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    // Implements `Repository<$ty>` for a tuple struct `$name(inner)`. Only `get` is
    // functional: it ignores the key and returns `Some(TestItem(inner.clone()))`.
    // Every other method panics via `unimplemented!()`.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn set_bulk(
                    &self,
                    _values: Vec<(String, $ty)>,
                ) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_bulk(&self, _keys: Vec<String>) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_all(&self) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    register_repository_item!(String => TestItem<usize>, "TestItem_usize");
    register_repository_item!(String => TestItem<String>, "TestItem_String");
    register_repository_item!(String => TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_state_registry() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        // Fetches the client-managed repository for `T` and reads its default key.
        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T>
        where
            T::Key: Default,
        {
            map.get_client_managed::<T>()
                .unwrap()
                .get(Default::default())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new();
        let test_repo = Arc::new(TestA(12345));

        registry.register_client_managed(test_repo.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(12345)));
    }

    #[tokio::test]
    async fn test_client_managed_takes_precedence_over_database() {
        // With both a database and a client-managed repository available, `get` must
        // return the client-managed one.
        let registry = StateRegistry::new_with_memory_db();
        let client_repo = Arc::new(TestA(777));
        registry.register_client_managed(client_repo.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(777)));
    }

    #[tokio::test]
    async fn test_register_client_managed_replaces_previous() {
        let registry = StateRegistry::new();
        let first = Arc::new(TestA(1));
        let second = Arc::new(TestA(2));

        registry.register_client_managed(first.clone());
        registry.register_client_managed(second.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(2)));
    }

    #[tokio::test]
    async fn test_fallback_neither_available() {
        let registry = StateRegistry::new();
        let result = registry.get::<TestItem<usize>>();
        assert!(matches!(
            result,
            Err(StateRegistryError::DatabaseNotInitialized)
        ));
    }

    #[tokio::test]
    async fn test_new_with_memory_db_sync() {
        let registry = StateRegistry::new_with_memory_db();
        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_setting_on_memory_db() {
        use crate::register_setting_key;
        register_setting_key!(const TEST_SETTING: String = "test_registry_setting_key");

        let registry = StateRegistry::new_with_memory_db();
        let setting = registry.setting(TEST_SETTING).unwrap();

        assert_eq!(setting.get().await.unwrap(), None::<String>);

        setting.update("hello".to_string()).await.unwrap();
        assert_eq!(setting.get().await.unwrap(), Some("hello".to_string()));

        setting.delete().await.unwrap();
        assert_eq!(setting.get().await.unwrap(), None::<String>);
    }
}