bitwarden_state/
registry.rs1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::{Arc, OnceLock, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use serde::{de::DeserializeOwned, Serialize};
9use thiserror::Error;
10
11use crate::{
12 repository::{Repository, RepositoryItem, RepositoryItemData},
13 sdk_managed::{Database, DatabaseConfiguration, SystemDatabase},
14};
15
16pub struct StateRegistry {
19 sdk_managed: RwLock<Vec<RepositoryItemData>>,
20 client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
21
22 database: OnceLock<SystemDatabase>,
23}
24
25impl std::fmt::Debug for StateRegistry {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("StateRegistry").finish()
28 }
29}
30
31#[allow(missing_docs)]
32#[bitwarden_error(flat)]
33#[derive(Debug, Error)]
34pub enum StateRegistryError {
35 #[error("Database is already initialized")]
36 DatabaseAlreadyInitialized,
37 #[error("Database is not initialized")]
38 DatabaseNotInitialized,
39
40 #[error(transparent)]
41 Database(#[from] crate::sdk_managed::DatabaseError),
42}
43
44#[derive(Debug, Error)]
46#[error("Repository not found for the requested type")]
47pub struct RepositoryNotFoundError;
48
49impl StateRegistry {
50 #[allow(clippy::new_without_default)]
52 pub fn new() -> Self {
53 StateRegistry {
54 client_managed: RwLock::new(HashMap::new()),
55 database: OnceLock::new(),
56 sdk_managed: RwLock::new(Vec::new()),
57 }
58 }
59
60 pub async fn initialize_database(
72 &self,
73 configuration: DatabaseConfiguration,
74 repositories: Vec<RepositoryItemData>,
75 ) -> Result<(), StateRegistryError> {
76 if self.database.get().is_some() {
77 return Err(StateRegistryError::DatabaseAlreadyInitialized);
78 }
79 let _ = self
80 .database
81 .set(SystemDatabase::initialize(configuration, &repositories).await?);
82
83 *self
84 .sdk_managed
85 .write()
86 .expect("RwLock should not be poisoned") = repositories.clone();
87
88 Ok(())
89 }
90
91 pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
93 self.client_managed
94 .write()
95 .expect("RwLock should not be poisoned")
96 .insert(TypeId::of::<T>(), Box::new(value));
97 }
98
99 pub fn get_client_managed<T: RepositoryItem>(
101 &self,
102 ) -> Result<Arc<dyn Repository<T>>, RepositoryNotFoundError> {
103 self.client_managed
104 .read()
105 .expect("RwLock should not be poisoned")
106 .get(&TypeId::of::<T>())
107 .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
108 .map(Arc::clone)
109 .ok_or(RepositoryNotFoundError)
110 }
111
112 pub fn get_sdk_managed<T: RepositoryItem + Serialize + DeserializeOwned>(
114 &self,
115 ) -> Result<impl Repository<T>, StateRegistryError> {
116 if let Some(db) = self.database.get() {
117 Ok(db.get_repository::<T>()?)
118 } else {
119 Err(StateRegistryError::DatabaseNotInitialized)
120 }
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    /// Implements `Repository<$ty>` for the single-field wrapper `$name`.
    ///
    /// Only `get` is functional: it ignores the key and returns the wrapped value inside a
    /// `TestItem`. The other operations are not exercised by these tests.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    let stored = self.0.clone();
                    Ok(Some(TestItem(stored)))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    /// Repository whose items wrap a `usize`.
    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    /// Repository whose items wrap a `String`.
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    /// Repository whose items wrap a byte vector.
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    /// Generic item type returned by the test repositories.
    #[derive(PartialEq, Eq, Debug)]
    struct TestItem<T>(T);

    register_repository_item!(TestItem<usize>, "TestItem_usize");
    register_repository_item!(TestItem<String>, "TestItem_String");
    register_repository_item!(TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    /// Looks up the client-managed repository for `T` and reads an item from it.
    ///
    /// Panics if nothing is registered for `T` or if the read fails.
    async fn fetch<T: RepositoryItem>(registry: &StateRegistry) -> Option<T> {
        let repository = registry.get_client_managed::<T>().unwrap();
        repository.get(String::new()).await.unwrap()
    }

    #[tokio::test]
    async fn test_repository_map() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let registry = StateRegistry::new();

        // Fresh registry: no item type resolves yet.
        assert!(registry.get_client_managed::<TestItem<usize>>().is_err());
        assert!(registry.get_client_managed::<TestItem<String>>().is_err());
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        // Only the `usize` repository is reachable after registering `a`.
        registry.register_client_managed(a.clone());
        assert_eq!(fetch(&registry).await, Some(TestItem(a.0)));
        assert!(registry.get_client_managed::<TestItem<String>>().is_err());
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        // Registering `b` adds the `String` repository without disturbing `a`.
        registry.register_client_managed(b.clone());
        assert_eq!(fetch(&registry).await, Some(TestItem(a.0)));
        assert_eq!(fetch(&registry).await, Some(TestItem(b.0.clone())));
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        // With `c` registered, all three item types resolve to their own repository.
        registry.register_client_managed(c.clone());
        assert_eq!(fetch(&registry).await, Some(TestItem(a.0)));
        assert_eq!(fetch(&registry).await, Some(TestItem(b.0.clone())));
        assert_eq!(fetch(&registry).await, Some(TestItem(c.0.clone())));
    }
}