Skip to main content

bitwarden_state/
registry.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    sync::{Arc, OnceLock, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use thiserror::Error;
9
10use crate::{
11    repository::{Repository, RepositoryItem, RepositoryItemData, RepositoryMigrations},
12    sdk_managed::{Database, DatabaseConfiguration, MemoryDatabase, SystemDatabase},
13    settings::{Key, Setting, SettingItem},
14};
15
16/// A registry that contains repositories for different types of items.
17/// These repositories can be either managed by the client or by the SDK itself.
18pub struct StateRegistry {
19    sdk_managed: RwLock<Vec<RepositoryItemData>>,
20    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
21
22    database: OnceLock<SystemDatabase>,
23}
24
25impl std::fmt::Debug for StateRegistry {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("StateRegistry").finish()
28    }
29}
30
31#[allow(missing_docs)]
32#[bitwarden_error(flat)]
33#[derive(Debug, Error)]
34pub enum StateRegistryError {
35    #[error("Database is already initialized")]
36    DatabaseAlreadyInitialized,
37    #[error("Database is not initialized")]
38    DatabaseNotInitialized,
39
40    #[error(transparent)]
41    Database(#[from] crate::sdk_managed::DatabaseError),
42}
43
44impl StateRegistry {
45    /// Creates a new empty `StateRegistry`.
46    #[allow(clippy::new_without_default)]
47    pub fn new() -> Self {
48        StateRegistry {
49            client_managed: RwLock::new(HashMap::new()),
50            database: OnceLock::new(),
51            sdk_managed: RwLock::new(Vec::new()),
52        }
53    }
54
55    /// Creates a new `StateRegistry` backed by an in-memory database.
56    pub fn new_with_memory_db() -> Self {
57        let registry = Self::new();
58        // OnceLock::set returns Err only if already set.
59        // new() guarantees the OnceLock is unset. We ignore the result
60        // because there is no failure scenario here.
61        let _ = registry
62            .database
63            .set(SystemDatabase::Memory(MemoryDatabase::new()));
64        registry
65    }
66
67    // TODO: Ideally we'd do this in new, but that would mean making the client initialization
68    // async.
69    // TODO: This function needs to be provided some configuration to know where to open the
70    // database. For Sqlite:
71    // - A folder path where the files will be stored.
72    // - A user ID to create a unique database file per user?
73    //
74    // For WASM indexedDB:
75    // - A database name to use for the indexedDB (Some prefix to avoid conflicts + user ID?)
76
77    /// Initializes the database used for sdk-managed repositories.
78    pub async fn initialize_database(
79        &self,
80        configuration: DatabaseConfiguration,
81        migrations: RepositoryMigrations,
82    ) -> Result<(), StateRegistryError> {
83        if self.database.get().is_some() {
84            return Err(StateRegistryError::DatabaseAlreadyInitialized);
85        }
86        let _ = self
87            .database
88            .set(SystemDatabase::initialize(configuration, migrations.clone()).await?);
89
90        *self
91            .sdk_managed
92            .write()
93            .expect("RwLock should not be poisoned") = migrations.into_repository_items();
94
95        Ok(())
96    }
97
98    /// Get a handle to a setting by its type-safe key.
99    pub fn setting<T>(&self, key: Key<T>) -> Result<Setting<T>, StateRegistryError> {
100        let repo = self.get::<SettingItem>()?;
101        Ok(Setting::new(repo, key))
102    }
103
104    /// Registers a client-managed repository into the map, associating it with its type.
105    pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
106        self.client_managed
107            .write()
108            .expect("RwLock should not be poisoned")
109            .insert(TypeId::of::<T>(), Box::new(value));
110    }
111
112    /// Retrieves a client-managed repository from the map given its type.
113    fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
114        self.client_managed
115            .read()
116            .expect("RwLock should not be poisoned")
117            .get(&TypeId::of::<T>())
118            .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
119            .map(Arc::clone)
120    }
121
122    /// Retrieves a SDK-managed repository from the database.
123    fn get_sdk_managed<T: RepositoryItem>(
124        &self,
125    ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
126        self.database
127            .get()
128            .map(|db| db.get_repository::<T>())
129            .ok_or(StateRegistryError::DatabaseNotInitialized)
130    }
131
132    /// Get a repository with fallback: prefer client-managed, fall back to SDK-managed.
133    ///
134    /// This method first attempts to retrieve a client-managed repository. If not found,
135    /// it falls back to an SDK-managed repository. Both are returned as `Arc<dyn Repository<T>>`.
136    ///
137    /// # Type Requirements
138    /// - `T` must implement `RepositoryItem` (for both types)
139    ///
140    /// # Errors
141    /// Returns `StateRegistryError` when:
142    /// - Client-managed repository is not registered, AND
143    /// - SDK-managed repository cannot be retrieved (e.g., database not initialized)
144    pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
145    where
146        T: RepositoryItem,
147    {
148        if let Some(repo) = self.get_client_managed::<T>() {
149            return Ok(repo);
150        }
151
152        self.get_sdk_managed::<T>()
153    }
154}
155
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    /// Implements `Repository<$ty>` for the single-field test struct `$name`.
    ///
    /// Only `get` is functional: it ignores the key and returns the struct's wrapped value
    /// inside a `TestItem`. Every other method is `unimplemented!()`.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn set_bulk(
                    &self,
                    _values: Vec<(String, $ty)>,
                ) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_bulk(&self, _keys: Vec<String>) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_all(&self) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    register_repository_item!(String => TestItem<usize>, "TestItem_usize");
    register_repository_item!(String => TestItem<String>, "TestItem_String");
    register_repository_item!(String => TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_state_registry() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        /// Fetches the default-keyed value from the client-managed repository for `T`,
        /// panicking if no repository is registered or the lookup fails.
        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T>
        where
            T::Key: Default,
        {
            map.get_client_managed::<T>()
                .unwrap()
                .get(Default::default())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new();
        let test_repo = Arc::new(TestA(12345));

        registry.register_client_managed(test_repo.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(12345)));
    }

    // Nothing is awaited here, so no async runtime is needed.
    #[test]
    fn test_fallback_neither_available() {
        let registry = StateRegistry::new();
        // Neither a client-managed repository is registered nor a database initialized.

        let result = registry.get::<TestItem<usize>>();
        assert!(matches!(
            result,
            Err(StateRegistryError::DatabaseNotInitialized)
        ));
    }

    #[tokio::test]
    async fn test_new_with_memory_db_sync() {
        // Construct in a sync context (no `.await` on the constructor itself).
        let registry = StateRegistry::new_with_memory_db();
        // The database must be usable through `get` right after sync construction.
        // (TestItem<usize> is registered as a repository item in this test module.)
        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await;
        // Nothing was stored under this key: the lookup must succeed and find no value.
        assert!(matches!(result, Ok(None)), "unexpected result: {result:?}");
    }

    #[tokio::test]
    async fn test_setting_on_memory_db() {
        use crate::register_setting_key;
        register_setting_key!(const TEST_SETTING: String = "test_registry_setting_key");

        let registry = StateRegistry::new_with_memory_db();
        let setting = registry.setting(TEST_SETTING).unwrap();

        // The value must not exist initially.
        assert_eq!(setting.get().await.unwrap(), None::<String>);

        // Update and read back.
        setting.update("hello".to_string()).await.unwrap();
        assert_eq!(setting.get().await.unwrap(), Some("hello".to_string()));

        // Delete and confirm it is gone.
        setting.delete().await.unwrap();
        assert_eq!(setting.get().await.unwrap(), None::<String>);
    }
}
310}