Skip to main content

bitwarden_state/
registry.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    sync::{Arc, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use thiserror::Error;
9
10use crate::{
11    repository::{Repository, RepositoryItem, RepositoryMigrations},
12    sdk_managed::{Database, DatabaseConfiguration, DatabaseError, MemoryDatabase, SystemDatabase},
13    settings::{Key, Setting, SettingItem},
14};
15
16/// A registry that contains repositories for different types of items.
17/// These repositories can be either managed by the client or by the SDK itself.
18pub struct StateRegistry {
19    database: SystemDatabase,
20    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
21}
22
23impl std::fmt::Debug for StateRegistry {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("StateRegistry").finish()
26    }
27}
28
29#[allow(missing_docs)]
30#[bitwarden_error(flat)]
31#[derive(Debug, Error)]
32pub enum StateRegistryError {
33    #[error("Database is not initialized")]
34    DatabaseNotInitialized,
35
36    #[error(transparent)]
37    Database(#[from] DatabaseError),
38}
39
40impl StateRegistry {
41    /// Creates a new `StateRegistry` backed by an in-memory database.
42    pub fn new_with_memory_db() -> Self {
43        StateRegistry {
44            database: SystemDatabase::Memory(MemoryDatabase::new()),
45            client_managed: RwLock::new(HashMap::new()),
46        }
47    }
48
49    /// Creates a new `StateRegistry` backed by a database.
50    pub async fn new_with_db(
51        configuration: DatabaseConfiguration,
52        migrations: RepositoryMigrations,
53    ) -> Result<Self, DatabaseError> {
54        let database = SystemDatabase::initialize(configuration, migrations.clone()).await?;
55        Ok(StateRegistry {
56            database,
57            client_managed: RwLock::new(HashMap::new()),
58        })
59    }
60
61    /// Get a handle to a setting by its type-safe key.
62    pub fn setting<T>(&self, key: Key<T>) -> Result<Setting<T>, StateRegistryError> {
63        let repo = self.get::<SettingItem>()?;
64        Ok(Setting::new(repo, key))
65    }
66
67    /// Registers a client-managed repository into the map, associating it with its type.
68    pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
69        self.client_managed
70            .write()
71            .expect("RwLock should not be poisoned")
72            .insert(TypeId::of::<T>(), Box::new(value));
73    }
74
75    /// Retrieves a client-managed repository from the map given its type.
76    fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
77        self.client_managed
78            .read()
79            .expect("RwLock should not be poisoned")
80            .get(&TypeId::of::<T>())
81            .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
82            .map(Arc::clone)
83    }
84
85    /// Retrieves a SDK-managed repository from the database.
86    fn get_sdk_managed<T: RepositoryItem>(
87        &self,
88    ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
89        Ok(self.database.get_repository::<T>())
90    }
91
92    /// Get a repository with fallback: prefer client-managed, fall back to SDK-managed.
93    ///
94    /// This method first attempts to retrieve a client-managed repository. If not found,
95    /// it falls back to an SDK-managed repository. Both are returned as `Arc<dyn Repository<T>>`.
96    ///
97    /// # Errors
98    /// This method never fails, but returns a Result for backwards compatibility.
99    pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
100    where
101        T: RepositoryItem,
102    {
103        if let Some(repo) = self.get_client_managed::<T>() {
104            return Ok(repo);
105        }
106
107        self.get_sdk_managed::<T>()
108    }
109
110    /// Wipes all state from this registry, and deletes any files or databases associated with it.
111    /// Intended to be used during logout, where the Client will be dropped right after.
112    ///
113    /// # Warning
114    ///
115    /// This closes the SDK-managed database and deletes persistent storage (SQLite file + WAL/SHM,
116    /// IndexedDB database). Outstanding [`Repository`] handles will return
117    /// [`DatabaseError::Closed`] on subsequent operations. Client-managed repositories are also
118    /// cleared.
119    pub async fn wipe(&self) -> Result<(), DatabaseError> {
120        // Clear client-managed first so a failure in the persistent-store wipe
121        // still releases the in-memory Arc references.
122        self.client_managed
123            .write()
124            .expect("RwLock should not be poisoned")
125            .clear();
126        self.database.wipe().await
127    }
128}
129
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
        sdk_managed::DatabaseError,
    };

    // Implements a stub `Repository<$ty>` for `$name`. Its `get` always returns a `TestItem`
    // wrapping a clone of the stub's inner value; every other method is unimplemented.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn set_bulk(
                    &self,
                    _values: Vec<(String, $ty)>,
                ) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_bulk(&self, _keys: Vec<String>) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_all(&self) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    register_repository_item!(String => TestItem<usize>, "TestItem_usize");
    register_repository_item!(String => TestItem<String>, "TestItem_String");
    register_repository_item!(String => TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_state_registry() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new_with_memory_db();

        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T>
        where
            T::Key: Default,
        {
            map.get_client_managed::<T>()
                .unwrap()
                .get(Default::default())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new_with_memory_db();
        let test_repo = Arc::new(TestA(12345));

        registry.register_client_managed(test_repo.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(12345)));
    }

    #[tokio::test]
    async fn test_fallback_sdk_managed_when_not_registered() {
        let registry = StateRegistry::new_with_memory_db();
        assert!(registry.get_client_managed::<TestItem<String>>().is_none());

        // With no client-managed repository, `get` must hand out a working SDK-managed one.
        let repo = registry.get::<TestItem<String>>().unwrap();
        repo.set("key".to_string(), TestItem("value".to_string()))
            .await
            .unwrap();

        assert_eq!(
            repo.get("key".to_string()).await.unwrap(),
            Some(TestItem("value".to_string()))
        );
    }

    #[tokio::test]
    async fn test_register_client_managed_replaces_previous() {
        let registry = StateRegistry::new_with_memory_db();
        registry.register_client_managed(Arc::new(TestA(1)));
        registry.register_client_managed(Arc::new(TestA(2)));

        let repo = registry.get::<TestItem<usize>>().unwrap();
        assert_eq!(repo.get(String::new()).await.unwrap(), Some(TestItem(2)));
    }

    #[tokio::test]
    async fn test_new_with_memory_db_sync() {
        // The constructor itself is synchronous (no `.await`).
        let registry = StateRegistry::new_with_memory_db();
        // The SDK-managed database must be usable right after synchronous construction.
        let repo = registry.get::<TestItem<usize>>().unwrap();
        // Nothing has been stored yet: the lookup succeeds and finds no value.
        assert_eq!(repo.get(String::new()).await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_wipe_disconnects_outstanding_repository_handles() {
        let registry = StateRegistry::new_with_memory_db();
        let repo = registry.get::<TestItem<usize>>().unwrap();
        repo.set(String::new(), TestItem(42usize)).await.unwrap();

        registry.wipe().await.unwrap();

        assert!(matches!(
            repo.get(String::new()).await,
            Err(RepositoryError::Database(DatabaseError::Closed))
        ));
        assert!(matches!(
            repo.list().await,
            Err(RepositoryError::Database(DatabaseError::Closed))
        ));
    }

    #[tokio::test]
    async fn test_wipe_clears_client_managed() {
        let registry = StateRegistry::new_with_memory_db();
        registry.register_client_managed(Arc::new(TestA(99)));

        registry.wipe().await.unwrap();

        // Client-managed is gone; falls through to SDK-managed (now closed).
        let repo = registry.get::<TestItem<usize>>().unwrap();
        assert!(matches!(
            repo.get(String::new()).await,
            Err(RepositoryError::Database(DatabaseError::Closed))
        ));
    }

    #[tokio::test]
    async fn test_wipe_is_idempotent() {
        let registry = StateRegistry::new_with_memory_db();
        registry.wipe().await.unwrap();
        registry.wipe().await.unwrap();
    }

    #[tokio::test]
    async fn test_setting_on_memory_db() {
        use crate::register_setting_key;
        register_setting_key!(const TEST_SETTING: String = "test_registry_setting_key");

        let registry = StateRegistry::new_with_memory_db();
        let setting = registry.setting(TEST_SETTING).unwrap();

        // Value must not exist initially
        assert_eq!(setting.get().await.unwrap(), None::<String>);

        // Update and read back
        setting.update("hello".to_string()).await.unwrap();
        assert_eq!(setting.get().await.unwrap(), Some("hello".to_string()));

        // Delete and confirm gone
        setting.delete().await.unwrap();
        assert_eq!(setting.get().await.unwrap(), None::<String>);
    }
}