bitwarden_state/
registry.rs1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::{Arc, OnceLock, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use thiserror::Error;
9
10use crate::{
11 repository::{Repository, RepositoryItem, RepositoryItemData, RepositoryMigrations},
12 sdk_managed::{Database, DatabaseConfiguration, SystemDatabase},
13};
14
15pub struct StateRegistry {
18 sdk_managed: RwLock<Vec<RepositoryItemData>>,
19 client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
20
21 database: OnceLock<SystemDatabase>,
22}
23
24impl std::fmt::Debug for StateRegistry {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("StateRegistry").finish()
27 }
28}
29
30#[allow(missing_docs)]
31#[bitwarden_error(flat)]
32#[derive(Debug, Error)]
33pub enum StateRegistryError {
34 #[error("Database is already initialized")]
35 DatabaseAlreadyInitialized,
36 #[error("Database is not initialized")]
37 DatabaseNotInitialized,
38
39 #[error(transparent)]
40 Database(#[from] crate::sdk_managed::DatabaseError),
41}
42
43impl StateRegistry {
44 #[allow(clippy::new_without_default)]
46 pub fn new() -> Self {
47 StateRegistry {
48 client_managed: RwLock::new(HashMap::new()),
49 database: OnceLock::new(),
50 sdk_managed: RwLock::new(Vec::new()),
51 }
52 }
53
54 pub async fn initialize_database(
66 &self,
67 configuration: DatabaseConfiguration,
68 migrations: RepositoryMigrations,
69 ) -> Result<(), StateRegistryError> {
70 if self.database.get().is_some() {
71 return Err(StateRegistryError::DatabaseAlreadyInitialized);
72 }
73 let _ = self
74 .database
75 .set(SystemDatabase::initialize(configuration, migrations.clone()).await?);
76
77 *self
78 .sdk_managed
79 .write()
80 .expect("RwLock should not be poisoned") = migrations.into_repository_items();
81
82 Ok(())
83 }
84
85 pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
87 self.client_managed
88 .write()
89 .expect("RwLock should not be poisoned")
90 .insert(TypeId::of::<T>(), Box::new(value));
91 }
92
93 fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
95 self.client_managed
96 .read()
97 .expect("RwLock should not be poisoned")
98 .get(&TypeId::of::<T>())
99 .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
100 .map(Arc::clone)
101 }
102
103 fn get_sdk_managed<T: RepositoryItem>(
105 &self,
106 ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
107 self.database
108 .get()
109 .map(|db| db.get_repository::<T>())
110 .ok_or(StateRegistryError::DatabaseNotInitialized)
111 }
112
113 pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
126 where
127 T: RepositoryItem,
128 {
129 if let Some(repo) = self.get_client_managed::<T>() {
130 return Ok(repo);
131 }
132
133 self.get_sdk_managed::<T>()
134 }
135}
136
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    /// Implements `Repository<$ty>` for a single-field test struct `$name`.
    ///
    /// `get` ignores the key and always returns a `TestItem` wrapping a clone of the field.
    /// All other methods are `unimplemented!()`.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    register_repository_item!(TestItem<usize>, "TestItem_usize");
    register_repository_item!(TestItem<String>, "TestItem_String");
    register_repository_item!(TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_repository_map() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        /// Fetches the client-managed repository for `T` (panicking if absent) and reads
        /// an arbitrary key from it.
        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T> {
            map.get_client_managed::<T>()
                .unwrap()
                .get(String::new())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    #[tokio::test]
    async fn test_register_client_managed_replaces_previous() {
        let registry = StateRegistry::new();
        let first = Arc::new(TestA(1));
        let second = Arc::new(TestA(2));

        registry.register_client_managed::<TestItem<usize>>(first);
        registry.register_client_managed::<TestItem<usize>>(second);

        // The later registration for the same item type wins.
        let repo = registry.get::<TestItem<usize>>().unwrap();
        assert_eq!(repo.get(String::new()).await.unwrap(), Some(TestItem(2)));
    }

    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new();
        let test_repo = Arc::new(TestA(12345));

        registry.register_client_managed(test_repo.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(12345)));
    }

    #[tokio::test]
    async fn test_fallback_neither_available() {
        let registry = StateRegistry::new();
        let result = registry.get::<TestItem<usize>>();
        // No client-managed repository and no database: the SDK-managed fallback reports
        // that the database is missing.
        assert!(matches!(
            result,
            Err(StateRegistryError::DatabaseNotInitialized)
        ));
    }
}