bitwarden_state/
registry.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    sync::{Arc, OnceLock, RwLock},
5};
6
7use bitwarden_error::bitwarden_error;
8use serde::{de::DeserializeOwned, Serialize};
9use thiserror::Error;
10
11use crate::{
12    repository::{Repository, RepositoryItem, RepositoryItemData},
13    sdk_managed::{Database, DatabaseConfiguration, SystemDatabase},
14};
15
16/// A registry that contains repositories for different types of items.
17/// These repositories can be either managed by the client or by the SDK itself.
18pub struct StateRegistry {
19    sdk_managed: RwLock<Vec<RepositoryItemData>>,
20    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
21
22    database: OnceLock<SystemDatabase>,
23}
24
25impl std::fmt::Debug for StateRegistry {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("StateRegistry").finish()
28    }
29}
30
31#[allow(missing_docs)]
32#[bitwarden_error(flat)]
33#[derive(Debug, Error)]
34pub enum StateRegistryError {
35    #[error("Database is already initialized")]
36    DatabaseAlreadyInitialized,
37    #[error("Database is not initialized")]
38    DatabaseNotInitialized,
39
40    #[error(transparent)]
41    Database(#[from] crate::sdk_managed::DatabaseError),
42}
43
44/// Repository not found.
45#[derive(Debug, Error)]
46#[error("Repository not found for the requested type")]
47pub struct RepositoryNotFoundError;
48
49impl StateRegistry {
50    /// Creates a new empty `StateRegistry`.
51    #[allow(clippy::new_without_default)]
52    pub fn new() -> Self {
53        StateRegistry {
54            client_managed: RwLock::new(HashMap::new()),
55            database: OnceLock::new(),
56            sdk_managed: RwLock::new(Vec::new()),
57        }
58    }
59
60    // TODO: Ideally we'd do this in new, but that would mean making the client initialization
61    // async.
62    // TODO: This function needs to be provided some configuration to know where to open the
63    // database. For Sqlite:
64    // - A folder path where the files will be stored.
65    // - A user ID to create a unique database file per user?
66    //
67    // For WASM indexedDB:
68    // - A database name to use for the indexedDB (Some prefix to avoid conflicts + user ID?)
69
70    /// Initializes the database used for sdk-managed repositories.
71    pub async fn initialize_database(
72        &self,
73        configuration: DatabaseConfiguration,
74        repositories: Vec<RepositoryItemData>,
75    ) -> Result<(), StateRegistryError> {
76        if self.database.get().is_some() {
77            return Err(StateRegistryError::DatabaseAlreadyInitialized);
78        }
79        let _ = self
80            .database
81            .set(SystemDatabase::initialize(configuration, &repositories).await?);
82
83        *self
84            .sdk_managed
85            .write()
86            .expect("RwLock should not be poisoned") = repositories.clone();
87
88        Ok(())
89    }
90
91    /// Registers a client-managed repository into the map, associating it with its type.
92    pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
93        self.client_managed
94            .write()
95            .expect("RwLock should not be poisoned")
96            .insert(TypeId::of::<T>(), Box::new(value));
97    }
98
99    /// Retrieves a client-managed repository from the map given its type.
100    pub fn get_client_managed<T: RepositoryItem>(
101        &self,
102    ) -> Result<Arc<dyn Repository<T>>, RepositoryNotFoundError> {
103        self.client_managed
104            .read()
105            .expect("RwLock should not be poisoned")
106            .get(&TypeId::of::<T>())
107            .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
108            .map(Arc::clone)
109            .ok_or(RepositoryNotFoundError)
110    }
111
112    /// Retrieves a SDK-managed repository from the database.
113    pub fn get_sdk_managed<T: RepositoryItem + Serialize + DeserializeOwned>(
114        &self,
115    ) -> Result<impl Repository<T>, StateRegistryError> {
116        if let Some(db) = self.database.get() {
117            Ok(db.get_repository::<T>()?)
118        } else {
119            Err(StateRegistryError::DatabaseNotInitialized)
120        }
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    /// Implements a minimal `Repository<$ty>` for the newtype `$name`.
    ///
    /// Only `get` is functional: it ignores the key and returns the newtype's inner value
    /// wrapped in `TestItem`, so `$ty` must be `TestItem<_>` of the inner field's type.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug)]
    struct TestItem<T>(T);

    register_repository_item!(TestItem<usize>, "TestItem_usize");
    register_repository_item!(TestItem<String>, "TestItem_String");
    register_repository_item!(TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_repository_map() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T> {
            map.get_client_managed::<T>()
                .unwrap()
                .get(String::new())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_err());
        assert!(map.get_client_managed::<TestItem<String>>().is_err());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_err());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    /// Registering a second repository for the same item type replaces the first one.
    #[tokio::test]
    async fn test_register_client_managed_replaces_existing() {
        let map = StateRegistry::new();

        map.register_client_managed::<TestItem<usize>>(Arc::new(TestA(1)));
        map.register_client_managed::<TestItem<usize>>(Arc::new(TestA(2)));

        let repository = map.get_client_managed::<TestItem<usize>>().unwrap();
        assert_eq!(
            repository.get(String::new()).await.unwrap(),
            Some(TestItem(2))
        );
    }
}
204}