Skip to main content

bitwarden_state/sdk_managed/
memory.rs

1use std::{
2    any::TypeId,
3    collections::HashMap,
4    sync::{Arc, Mutex},
5};
6
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::{
10    repository::{RepositoryItem, RepositoryMigrations},
11    sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
12};
13
/// In-memory database backend implementing the [`Database`] trait.
///
/// Stores data in process RAM using a [`TypeId`]-keyed nested HashMap: the outer map is keyed
/// by the item type, the inner map by the item's string key, and values are kept as serialized
/// JSON strings. Intended for testing, development, and cross-platform use cases where
/// persistent storage is not required.
///
/// Cloning is cheap and yields a handle to the *same* underlying store (the map lives behind an
/// [`Arc`]), so writes made through one clone are visible through every other clone.
///
/// All data is lost when the last handle to the store is dropped.
#[derive(Clone, Default)]
pub struct MemoryDatabase(Arc<Mutex<HashMap<TypeId, HashMap<String, String>>>>);

impl MemoryDatabase {
    /// Create a new, empty in-memory database.
    ///
    /// Equivalent to [`MemoryDatabase::default`]; each call creates an independent store.
    pub fn new() -> Self {
        Self::default()
    }
}
30
31impl Database for MemoryDatabase {
32    async fn initialize(
33        configuration: DatabaseConfiguration,
34        _migrations: RepositoryMigrations,
35    ) -> Result<Self, DatabaseError> {
36        let DatabaseConfiguration::Memory = configuration else {
37            return Err(DatabaseError::UnsupportedConfiguration(configuration));
38        };
39        Ok(MemoryDatabase::new())
40    }
41
42    async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
43        &self,
44        key: &str,
45    ) -> Result<Option<T>, DatabaseError> {
46        let store = self.0.lock().expect("Mutex is not poisoned");
47        let type_map = store.get(&TypeId::of::<T>());
48        match type_map.and_then(|m| m.get(key)) {
49            Some(json) => Ok(Some(serde_json::from_str(json)?)),
50            None => Ok(None),
51        }
52    }
53
54    async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
55        &self,
56    ) -> Result<Vec<T>, DatabaseError> {
57        let store = self.0.lock().expect("Mutex is not poisoned");
58        match store.get(&TypeId::of::<T>()) {
59            None => Ok(vec![]),
60            Some(type_map) => {
61                let mut results = Vec::with_capacity(type_map.len());
62                for json in type_map.values() {
63                    results.push(serde_json::from_str(json)?);
64                }
65                Ok(results)
66            }
67        }
68    }
69
70    async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
71        &self,
72        key: &str,
73        value: T,
74    ) -> Result<(), DatabaseError> {
75        let json = serde_json::to_string(&value)?;
76        let mut store = self.0.lock().expect("Mutex is not poisoned");
77        store
78            .entry(TypeId::of::<T>())
79            .or_default()
80            .insert(key.to_string(), json);
81        Ok(())
82    }
83
84    async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
85        &self,
86        values: Vec<(String, T)>,
87    ) -> Result<(), DatabaseError> {
88        let mut store = self.0.lock().expect("Mutex is not poisoned");
89        let type_map = store.entry(TypeId::of::<T>()).or_default();
90        for (key, value) in values {
91            let json = serde_json::to_string(&value)?;
92            type_map.insert(key, json);
93        }
94        Ok(())
95    }
96
97    async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
98        &self,
99        key: &str,
100    ) -> Result<(), DatabaseError> {
101        let mut store = self.0.lock().expect("Mutex is not poisoned");
102        if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
103            type_map.remove(key);
104        }
105        Ok(())
106    }
107
108    async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
109        &self,
110        keys: Vec<String>,
111    ) -> Result<(), DatabaseError> {
112        let mut store = self.0.lock().expect("Mutex is not poisoned");
113        if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
114            for key in keys {
115                type_map.remove(&key);
116            }
117        }
118        Ok(())
119    }
120
121    async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
122        &self,
123    ) -> Result<(), DatabaseError> {
124        let mut store = self.0.lock().expect("Mutex is not poisoned");
125        store.remove(&TypeId::of::<T>());
126        Ok(())
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeA(String);
    register_repository_item!(String => TypeA, "MemTypeA");

    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeB(u64);
    register_repository_item!(String => TypeB, "MemTypeB");

    #[tokio::test]
    async fn test_memory_database_get_set_remove() {
        let db = MemoryDatabase::new();
        assert_eq!(db.get::<TypeA>("key1").await.unwrap(), None);
        db.set("key1", TypeA("hello".to_string())).await.unwrap();
        assert_eq!(
            db.get::<TypeA>("key1").await.unwrap(),
            Some(TypeA("hello".to_string()))
        );
        db.remove::<TypeA>("key1").await.unwrap();
        assert_eq!(db.get::<TypeA>("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_memory_database_set_overwrites_existing_key() {
        let db = MemoryDatabase::new();
        db.set("key1", TypeA("first".to_string())).await.unwrap();
        db.set("key1", TypeA("second".to_string())).await.unwrap();
        assert_eq!(
            db.get::<TypeA>("key1").await.unwrap(),
            Some(TypeA("second".to_string()))
        );
        // Overwriting must replace the entry, not add a second one.
        assert_eq!(db.list::<TypeA>().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_memory_database_type_isolation() {
        let db = MemoryDatabase::new();
        // Same string key for both types — must not interfere
        db.set("key1", TypeA("value_a".to_string())).await.unwrap();
        db.set("key1", TypeB(42)).await.unwrap();
        assert_eq!(
            db.get::<TypeA>("key1").await.unwrap(),
            Some(TypeA("value_a".to_string()))
        );
        assert_eq!(db.get::<TypeB>("key1").await.unwrap(), Some(TypeB(42)));
    }

    #[tokio::test]
    async fn test_memory_database_remove_is_type_scoped() {
        let db = MemoryDatabase::new();
        db.set("key1", TypeA("a".to_string())).await.unwrap();
        db.set("key1", TypeB(1)).await.unwrap();
        db.remove::<TypeA>("key1").await.unwrap();
        assert_eq!(db.get::<TypeA>("key1").await.unwrap(), None);
        // Removing the TypeA entry must leave the TypeB entry with the same key intact.
        assert_eq!(db.get::<TypeB>("key1").await.unwrap(), Some(TypeB(1)));
    }

    #[tokio::test]
    async fn test_memory_database_clone_shares_store() {
        let db1 = MemoryDatabase::new();
        let db2 = db1.clone();
        db1.set("key1", TypeA("shared".to_string())).await.unwrap();
        assert_eq!(
            db2.get::<TypeA>("key1").await.unwrap(),
            Some(TypeA("shared".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_database_list() {
        let db = MemoryDatabase::new();
        db.set("a", TypeA("1".to_string())).await.unwrap();
        db.set("b", TypeA("2".to_string())).await.unwrap();
        db.set("c", TypeB(99)).await.unwrap();
        // List order is unspecified, so sort before comparing. TypeB must not appear here.
        let mut list_a = db.list::<TypeA>().await.unwrap();
        list_a.sort_by_key(|x| x.0.clone());
        assert_eq!(list_a, vec![TypeA("1".to_string()), TypeA("2".to_string())]);
        // Likewise, the TypeB list must contain only the TypeB item.
        assert_eq!(db.list::<TypeB>().await.unwrap(), vec![TypeB(99)]);
    }

    #[tokio::test]
    async fn test_memory_database_set_bulk() {
        let db = MemoryDatabase::new();
        db.set_bulk(vec![
            ("x".to_string(), TypeA("v1".to_string())),
            ("y".to_string(), TypeA("v2".to_string())),
        ])
        .await
        .unwrap();
        assert_eq!(
            db.get::<TypeA>("x").await.unwrap(),
            Some(TypeA("v1".to_string()))
        );
        assert_eq!(
            db.get::<TypeA>("y").await.unwrap(),
            Some(TypeA("v2".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_database_remove_bulk() {
        let db = MemoryDatabase::new();
        db.set("a", TypeA("1".to_string())).await.unwrap();
        db.set("b", TypeA("2".to_string())).await.unwrap();
        db.set("c", TypeA("3".to_string())).await.unwrap();
        db.remove_bulk::<TypeA>(vec!["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(db.get::<TypeA>("a").await.unwrap(), None);
        assert_eq!(db.get::<TypeA>("b").await.unwrap(), None);
        assert_eq!(
            db.get::<TypeA>("c").await.unwrap(),
            Some(TypeA("3".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_database_remove_all() {
        let db = MemoryDatabase::new();
        db.set("a", TypeA("1".to_string())).await.unwrap();
        db.set("b", TypeA("2".to_string())).await.unwrap();
        db.set("z", TypeB(5)).await.unwrap();
        db.remove_all::<TypeA>().await.unwrap();
        assert_eq!(db.list::<TypeA>().await.unwrap(), vec![]);
        // TypeB must be unaffected
        assert_eq!(db.list::<TypeB>().await.unwrap(), vec![TypeB(5)]);
    }

    #[tokio::test]
    async fn test_memory_database_removes_on_empty_store_are_noops() {
        let db = MemoryDatabase::new();
        // Removing from a type that was never written must succeed without error.
        db.remove::<TypeA>("missing").await.unwrap();
        db.remove_bulk::<TypeA>(vec!["missing".to_string()])
            .await
            .unwrap();
        db.remove_all::<TypeA>().await.unwrap();
        assert!(db.list::<TypeA>().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_memory_database_initialize_memory_config() {
        let db = MemoryDatabase::initialize(
            DatabaseConfiguration::Memory,
            RepositoryMigrations::new(vec![]),
        )
        .await
        .unwrap();
        db.set("k", TypeA("v".to_string())).await.unwrap();
        assert_eq!(
            db.get::<TypeA>("k").await.unwrap(),
            Some(TypeA("v".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_database_initialize_rejects_non_memory_config() {
        let result = MemoryDatabase::initialize(
            DatabaseConfiguration::Sqlite {
                db_name: "ignored".to_string(),
                folder_path: std::path::PathBuf::from("/tmp"),
            },
            RepositoryMigrations::new(vec![]),
        )
        .await;
        assert!(matches!(
            result,
            Err(DatabaseError::UnsupportedConfiguration(_))
        ));
    }
}