bitwarden_state/
registry.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    sync::{Arc, RwLock},
5};
6
7use crate::repository::{Repository, RepositoryItem};
8
/// Backing map of [`StateRegistry`].
///
/// The key is the `TypeId` of a repository's item type. The value is that repository, stored
/// type-erased so repositories for different item types can live in one map.
type ClientManagedMap = HashMap<TypeId, Box<dyn Any + Send + Sync>>;

/// Holds the repositories available for each kind of stored item.
///
/// Repositories may be provided by the client application or managed by the SDK itself. The
/// registry currently keeps only the client-provided ones.
pub struct StateRegistry {
    /// Client-registered repositories, behind a `RwLock`.
    ///
    /// Lookups take a shared read lock; registrations take the exclusive write lock.
    client_managed: RwLock<ClientManagedMap>,
}
14
15impl std::fmt::Debug for StateRegistry {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.debug_struct("StateRegistry").finish()
18    }
19}
20
21impl StateRegistry {
22    /// Creates a new empty `StateRegistry`.
23    #[allow(clippy::new_without_default)]
24    pub fn new() -> Self {
25        StateRegistry {
26            client_managed: RwLock::new(HashMap::new()),
27        }
28    }
29
30    /// Registers a client-managed repository into the map, associating it with its type.
31    pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
32        self.client_managed
33            .write()
34            .expect("RwLock should not be poisoned")
35            .insert(TypeId::of::<T>(), Box::new(value));
36    }
37
38    /// Retrieves a client-managed repository from the map given its type.
39    pub fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
40        self.client_managed
41            .read()
42            .expect("RwLock should not be poisoned")
43            .get(&TypeId::of::<T>())
44            .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
45            .map(Arc::clone)
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    /// Implements `Repository<$ty>` for the tuple struct `$name`.
    ///
    /// Only `get` is functional: it ignores the key and returns `TestItem` wrapping a clone of
    /// the struct's inner value. All other methods are `unimplemented!()`.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    /// A second repository for `TestItem<usize>`, used to check that re-registering for the
    /// same item type replaces the earlier repository.
    #[derive(PartialEq, Eq, Debug)]
    struct TestAReplacement(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestItem<T>(T);

    register_repository_item!(TestItem<usize>, "TestItem<usize>");
    register_repository_item!(TestItem<String>, "TestItem<String>");
    register_repository_item!(TestItem<Vec<u8>>, "TestItem<Vec<u8>>");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);
    impl_repository!(TestAReplacement, TestItem<usize>);

    /// Fetches an item through the repository registered for `T`.
    ///
    /// Panics if no repository is registered for `T` or if its `get` returns an error.
    async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T> {
        map.get_client_managed::<T>()
            .unwrap()
            .get(String::new())
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_repository_map() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    /// Registering a second repository for an item type that already has one must replace it,
    /// not keep the first registration.
    #[tokio::test]
    async fn test_register_replaces_existing_repository() {
        let map = StateRegistry::new();

        map.register_client_managed::<TestItem<usize>>(Arc::new(TestA(1)));
        assert_eq!(get(&map).await, Some(TestItem(1usize)));

        map.register_client_managed::<TestItem<usize>>(Arc::new(TestAReplacement(2)));
        assert_eq!(get(&map).await, Some(TestItem(2usize)));

        // Other item types remain unregistered.
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
    }
}