bitwarden_state/
registry.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    sync::{Arc, OnceLock, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use thiserror::Error;
9
10use crate::{
11    repository::{Repository, RepositoryItem, RepositoryItemData, RepositoryMigrations},
12    sdk_managed::{Database, DatabaseConfiguration, SystemDatabase},
13};
14
15/// A registry that contains repositories for different types of items.
16/// These repositories can be either managed by the client or by the SDK itself.
17pub struct StateRegistry {
18    sdk_managed: RwLock<Vec<RepositoryItemData>>,
19    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
20
21    database: OnceLock<SystemDatabase>,
22}
23
24impl std::fmt::Debug for StateRegistry {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("StateRegistry").finish()
27    }
28}
29
30#[allow(missing_docs)]
31#[bitwarden_error(flat)]
32#[derive(Debug, Error)]
33pub enum StateRegistryError {
34    #[error("Database is already initialized")]
35    DatabaseAlreadyInitialized,
36    #[error("Database is not initialized")]
37    DatabaseNotInitialized,
38
39    #[error(transparent)]
40    Database(#[from] crate::sdk_managed::DatabaseError),
41}
42
43impl StateRegistry {
44    /// Creates a new empty `StateRegistry`.
45    #[allow(clippy::new_without_default)]
46    pub fn new() -> Self {
47        StateRegistry {
48            client_managed: RwLock::new(HashMap::new()),
49            database: OnceLock::new(),
50            sdk_managed: RwLock::new(Vec::new()),
51        }
52    }
53
54    // TODO: Ideally we'd do this in new, but that would mean making the client initialization
55    // async.
56    // TODO: This function needs to be provided some configuration to know where to open the
57    // database. For Sqlite:
58    // - A folder path where the files will be stored.
59    // - A user ID to create a unique database file per user?
60    //
61    // For WASM indexedDB:
62    // - A database name to use for the indexedDB (Some prefix to avoid conflicts + user ID?)
63
64    /// Initializes the database used for sdk-managed repositories.
65    pub async fn initialize_database(
66        &self,
67        configuration: DatabaseConfiguration,
68        migrations: RepositoryMigrations,
69    ) -> Result<(), StateRegistryError> {
70        if self.database.get().is_some() {
71            return Err(StateRegistryError::DatabaseAlreadyInitialized);
72        }
73        let _ = self
74            .database
75            .set(SystemDatabase::initialize(configuration, migrations.clone()).await?);
76
77        *self
78            .sdk_managed
79            .write()
80            .expect("RwLock should not be poisoned") = migrations.into_repository_items();
81
82        Ok(())
83    }
84
85    /// Registers a client-managed repository into the map, associating it with its type.
86    pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
87        self.client_managed
88            .write()
89            .expect("RwLock should not be poisoned")
90            .insert(TypeId::of::<T>(), Box::new(value));
91    }
92
93    /// Retrieves a client-managed repository from the map given its type.
94    fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
95        self.client_managed
96            .read()
97            .expect("RwLock should not be poisoned")
98            .get(&TypeId::of::<T>())
99            .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
100            .map(Arc::clone)
101    }
102
103    /// Retrieves a SDK-managed repository from the database.
104    fn get_sdk_managed<T: RepositoryItem>(
105        &self,
106    ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
107        self.database
108            .get()
109            .map(|db| db.get_repository::<T>())
110            .ok_or(StateRegistryError::DatabaseNotInitialized)
111    }
112
113    /// Get a repository with fallback: prefer client-managed, fall back to SDK-managed.
114    ///
115    /// This method first attempts to retrieve a client-managed repository. If not found,
116    /// it falls back to an SDK-managed repository. Both are returned as `Arc<dyn Repository<T>>`.
117    ///
118    /// # Type Requirements
119    /// - `T` must implement `RepositoryItem` (for both types)
120    ///
121    /// # Errors
122    /// Returns `StateRegistryError` when:
123    /// - Client-managed repository is not registered, AND
124    /// - SDK-managed repository cannot be retrieved (e.g., database not initialized)
125    pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
126    where
127        T: RepositoryItem,
128    {
129        if let Some(repo) = self.get_client_managed::<T>() {
130            return Ok(repo);
131        }
132
133        self.get_sdk_managed::<T>()
134    }
135}
136
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    /// Item type stored by the test repositories; the payload type is what distinguishes
    /// the three registrations below.
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    register_repository_item!(TestItem<usize>, "TestItem_usize");
    register_repository_item!(TestItem<String>, "TestItem_String");
    register_repository_item!(TestItem<Vec<u8>>, "TestItem_Vec");

    // Repository stand-ins, one per registered item type. Each one just holds the value it
    // hands back from `get`.
    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);

    /// Implements `Repository<$ty>` for `$name`. Only `get` is functional: it ignores the
    /// key and returns a `TestItem` wrapping a clone of the repository's inner value.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_repository_map() {
        /// Looks up the client-managed repository for `T` and reads its item.
        async fn fetch<T: RepositoryItem>(registry: &StateRegistry) -> Option<T> {
            let repository = registry.get_client_managed::<T>().unwrap();
            repository.get(String::new()).await.unwrap()
        }

        let repo_usize = Arc::new(TestA(145832));
        let repo_string = Arc::new(TestB("test".to_string()));
        let repo_bytes = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let registry = StateRegistry::new();

        // Nothing is registered yet.
        assert!(registry.get_client_managed::<TestItem<usize>>().is_none());
        assert!(registry.get_client_managed::<TestItem<String>>().is_none());
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        // After registering the first repository only its item type resolves.
        registry.register_client_managed(repo_usize.clone());
        assert_eq!(fetch(&registry).await, Some(TestItem(repo_usize.0)));
        assert!(registry.get_client_managed::<TestItem<String>>().is_none());
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        // The second registration does not disturb the first.
        registry.register_client_managed(repo_string.clone());
        assert_eq!(fetch(&registry).await, Some(TestItem(repo_usize.0)));
        assert_eq!(
            fetch(&registry).await,
            Some(TestItem(repo_string.0.clone()))
        );
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        // With all three registered, every item type resolves to its own repository.
        registry.register_client_managed(repo_bytes.clone());
        assert_eq!(fetch(&registry).await, Some(TestItem(repo_usize.0)));
        assert_eq!(
            fetch(&registry).await,
            Some(TestItem(repo_string.0.clone()))
        );
        assert_eq!(
            fetch(&registry).await,
            Some(TestItem(repo_bytes.0.clone()))
        );
    }

    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new();
        let client_repo = Arc::new(TestA(12345));
        registry.register_client_managed(client_repo.clone());

        // A registered client-managed repository is returned without needing a database.
        let resolved = registry.get::<TestItem<usize>>().unwrap();
        let item = resolved.get(String::new()).await.unwrap();

        assert_eq!(item, Some(TestItem(12345)));
    }

    #[tokio::test]
    async fn test_fallback_neither_available() {
        // No client-managed repository is registered and the database is never initialized.
        let registry = StateRegistry::new();

        let outcome = registry.get::<TestItem<usize>>();
        assert!(matches!(
            outcome,
            Err(StateRegistryError::DatabaseNotInitialized)
        ));
    }
}