bitwarden_state/
registry.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    sync::{Arc, RwLock},
5};
6
7use thiserror::Error;
8
9use crate::repository::{Repository, RepositoryItem};
10
/// Type-indexed storage for the repositories that hold items of different types.
///
/// Repositories may be managed by the client or by the SDK itself; this struct holds the
/// client-managed ones, each looked up by the [`TypeId`] of the item type it serves.
pub struct StateRegistry {
    /// Client-managed repositories keyed by item type. Values are boxed as `dyn Any` so that
    /// repositories for different item types can live in the same map.
    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
}
16
17impl std::fmt::Debug for StateRegistry {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("StateRegistry").finish()
20    }
21}
22
23/// Repository not found.
24#[derive(Debug, Error)]
25#[error("Repository not found for the requested type")]
26pub struct RepositoryNotFoundError;
27
28impl StateRegistry {
29    /// Creates a new empty `StateRegistry`.
30    #[allow(clippy::new_without_default)]
31    pub fn new() -> Self {
32        StateRegistry {
33            client_managed: RwLock::new(HashMap::new()),
34        }
35    }
36
37    /// Registers a client-managed repository into the map, associating it with its type.
38    pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
39        self.client_managed
40            .write()
41            .expect("RwLock should not be poisoned")
42            .insert(TypeId::of::<T>(), Box::new(value));
43    }
44
45    /// Retrieves a client-managed repository from the map given its type.
46    pub fn get_client_managed<T: RepositoryItem>(
47        &self,
48    ) -> Result<Arc<dyn Repository<T>>, RepositoryNotFoundError> {
49        self.client_managed
50            .read()
51            .expect("RwLock should not be poisoned")
52            .get(&TypeId::of::<T>())
53            .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
54            .map(Arc::clone)
55            .ok_or(RepositoryNotFoundError)
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    // Implements `Repository<$ty>` for the tuple struct `$name`. Only `get` is functional: it
    // ignores the key and returns the wrapped value as a `TestItem`. The remaining operations
    // are not exercised by the test below and are left unimplemented.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    // Three repository implementations, each serving a distinct `TestItem<_>` type.
    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug)]
    struct TestItem<T>(T);

    register_repository_item!(TestItem<usize>, "TestItem<usize>");
    register_repository_item!(TestItem<String>, "TestItem<String>");
    register_repository_item!(TestItem<Vec<u8>>, "TestItem<Vec<u8>>");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_repository_map() {
        // Looks up the repository registered for `T` and returns what its `get` yields.
        async fn fetch<T: RepositoryItem>(registry: &StateRegistry) -> Option<T> {
            let repository = registry.get_client_managed::<T>().unwrap();
            repository.get(String::new()).await.unwrap()
        }

        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let registry = StateRegistry::new();

        // Nothing is registered yet, so every lookup fails.
        assert!(registry.get_client_managed::<TestItem<usize>>().is_err());
        assert!(registry.get_client_managed::<TestItem<String>>().is_err());
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        // Registering `a` makes only the `usize` item type resolvable.
        registry.register_client_managed(a.clone());
        assert_eq!(
            fetch::<TestItem<usize>>(&registry).await,
            Some(TestItem(a.0))
        );
        assert!(registry.get_client_managed::<TestItem<String>>().is_err());
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        // Registering `b` adds the `String` item type without disturbing `a`.
        registry.register_client_managed(b.clone());
        assert_eq!(
            fetch::<TestItem<usize>>(&registry).await,
            Some(TestItem(a.0))
        );
        assert_eq!(
            fetch::<TestItem<String>>(&registry).await,
            Some(TestItem(b.0.clone()))
        );
        assert!(registry.get_client_managed::<TestItem<Vec<u8>>>().is_err());

        // With `c` registered, all three item types resolve to their own repository.
        registry.register_client_managed(c.clone());
        assert_eq!(
            fetch::<TestItem<usize>>(&registry).await,
            Some(TestItem(a.0))
        );
        assert_eq!(
            fetch::<TestItem<String>>(&registry).await,
            Some(TestItem(b.0.clone()))
        );
        assert_eq!(
            fetch::<TestItem<Vec<u8>>>(&registry).await,
            Some(TestItem(c.0.clone()))
        );
    }
}