1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::{Arc, OnceLock, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use thiserror::Error;
9
10use crate::{
11 repository::{Repository, RepositoryItem, RepositoryItemData, RepositoryMigrations},
12 sdk_managed::{Database, DatabaseConfiguration, SystemDatabase},
13};
14
/// Registry of state repositories, looked up by the repository item type they store.
///
/// Repositories come from two sources:
/// - client-managed: supplied by the caller through `register_client_managed`;
/// - SDK-managed: served by the internal [`SystemDatabase`] once `initialize_database`
///   has completed.
pub struct StateRegistry {
    /// Repository items recorded from the migrations passed to `initialize_database`.
    sdk_managed: RwLock<Vec<RepositoryItemData>>,
    /// Client-provided repositories keyed by `TypeId::of::<T>()`. Each value is an
    /// `Arc<dyn Repository<T>>` boxed as `Any`, so that repositories for different item
    /// types can share one map.
    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,

    /// The SDK-managed database. It is set by `initialize_database` and stays empty until
    /// then.
    database: OnceLock<SystemDatabase>,
}
23
24impl std::fmt::Debug for StateRegistry {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("StateRegistry").finish()
27 }
28}
29
/// Errors returned by [`StateRegistry`] operations.
// NOTE(review): presumably kept for items generated by `#[bitwarden_error]`; confirm
// before removing.
#[allow(missing_docs)]
#[bitwarden_error(flat)]
#[derive(Debug, Error)]
pub enum StateRegistryError {
    /// Returned by `initialize_database` when a database has already been set up.
    #[error("Database is already initialized")]
    DatabaseAlreadyInitialized,
    /// Returned when an SDK-managed repository is requested before the database exists.
    #[error("Database is not initialized")]
    DatabaseNotInitialized,

    /// Error propagated unchanged from the SDK-managed database layer.
    #[error(transparent)]
    Database(#[from] crate::sdk_managed::DatabaseError),
}
42
43impl StateRegistry {
44 #[allow(clippy::new_without_default)]
46 pub fn new() -> Self {
47 StateRegistry {
48 client_managed: RwLock::new(HashMap::new()),
49 database: OnceLock::new(),
50 sdk_managed: RwLock::new(Vec::new()),
51 }
52 }
53
54 pub async fn initialize_database(
66 &self,
67 configuration: DatabaseConfiguration,
68 migrations: RepositoryMigrations,
69 ) -> Result<(), StateRegistryError> {
70 if self.database.get().is_some() {
71 return Err(StateRegistryError::DatabaseAlreadyInitialized);
72 }
73 let _ = self
74 .database
75 .set(SystemDatabase::initialize(configuration, migrations.clone()).await?);
76
77 *self
78 .sdk_managed
79 .write()
80 .expect("RwLock should not be poisoned") = migrations.into_repository_items();
81
82 Ok(())
83 }
84
85 pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
87 self.client_managed
88 .write()
89 .expect("RwLock should not be poisoned")
90 .insert(TypeId::of::<T>(), Box::new(value));
91 }
92
93 fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
95 self.client_managed
96 .read()
97 .expect("RwLock should not be poisoned")
98 .get(&TypeId::of::<T>())
99 .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
100 .map(Arc::clone)
101 }
102
103 fn get_sdk_managed<T: RepositoryItem>(
105 &self,
106 ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
107 self.database
108 .get()
109 .map(|db| db.get_repository::<T>())
110 .ok_or(StateRegistryError::DatabaseNotInitialized)
111 }
112
113 pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
126 where
127 T: RepositoryItem,
128 {
129 if let Some(repo) = self.get_client_managed::<T>() {
130 return Ok(repo);
131 }
132
133 self.get_sdk_managed::<T>()
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    // Implements `Repository<$ty>` for the tuple struct `$name`. `get` ignores the key and
    // always returns `TestItem(self.0.clone())`. The remaining operations are not exercised
    // by these tests and panic via `unimplemented!()`.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn set_bulk(
                    &self,
                    _values: Vec<(String, $ty)>,
                ) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_bulk(&self, _keys: Vec<String>) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_all(&self) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    use serde::{Deserialize, Serialize};

    // Fake repositories. Each one wraps the payload that its `get` returns.
    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    // Generic item type. Each instantiation below is a distinct repository item type.
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    register_repository_item!(String => TestItem<usize>, "TestItem_usize");
    register_repository_item!(String => TestItem<String>, "TestItem_String");
    register_repository_item!(String => TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    // Repositories are registered one at a time. After each step, item types registered
    // earlier must still resolve, and item types not yet registered must stay absent.
    #[tokio::test]
    async fn test_repository_map() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        // Looks up the client-managed repository for `T` and reads the value stored under
        // the default key. It panics if no repository is registered for `T`. `T` is
        // inferred from the `assert_eq!` it is compared in.
        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T>
        where
            T::Key: Default,
        {
            map.get_client_managed::<T>()
                .unwrap()
                .get(Default::default())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    // `get` returns the client-managed repository when one is registered. No database is
    // needed in that case.
    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new();
        let test_repo = Arc::new(TestA(12345));

        registry.register_client_managed(test_repo.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(12345)));
    }

    // With no client-managed repository and no database, `get` reports
    // `DatabaseNotInitialized`.
    #[tokio::test]
    async fn test_fallback_neither_available() {
        let registry = StateRegistry::new();
        let result = registry.get::<TestItem<usize>>();
        assert!(matches!(
            result,
            Err(StateRegistryError::DatabaseNotInitialized)
        ));
    }
}
259}