Skip to main content

bitwarden_state/sdk_managed/
memory.rs

1use std::{
2    any::TypeId,
3    collections::HashMap,
4    sync::{Arc, Mutex},
5};
6
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::{
10    repository::{RepositoryItem, RepositoryMigrations},
11    sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
12};
13
/// In-memory database backend implementing the [`Database`] trait.
///
/// Stores data in process RAM using a [`TypeId`]-keyed nested HashMap:
/// - the outer map is keyed by the item type;
/// - the inner map is keyed by the item's string key;
/// - every value is kept as its serialized JSON string.
///
/// Intended for testing, development, and cross-platform use cases where
/// persistent storage is not required.
///
/// Cloning is cheap. Every clone shares the same underlying store, so a write made
/// through one handle is visible through all of them. Data lives until the last clone
/// is dropped, or until [`Database::wipe`] is called. After a wipe, every clone reports
/// [`DatabaseError::Closed`].
#[derive(Clone)]
pub struct MemoryDatabase(Arc<Mutex<Option<Store>>>);

/// Backing storage: item type -> (item key -> serialized JSON value).
///
/// [`MemoryDatabase`] wraps this in an `Option`, where `None` marks a wiped (closed)
/// database.
type Store = HashMap<TypeId, HashMap<String, String>>;
25
26impl MemoryDatabase {
27    /// Create a new, empty in-memory database.
28    pub fn new() -> Self {
29        MemoryDatabase(Arc::new(Mutex::new(Some(HashMap::new()))))
30    }
31
32    fn with_store<R>(
33        &self,
34        f: impl FnOnce(&Store) -> Result<R, DatabaseError>,
35    ) -> Result<R, DatabaseError> {
36        let guard = self.0.lock().expect("Mutex is not poisoned");
37        let store = guard.as_ref().ok_or(DatabaseError::Closed)?;
38        f(store)
39    }
40
41    fn with_store_mut<R>(
42        &self,
43        f: impl FnOnce(&mut Store) -> Result<R, DatabaseError>,
44    ) -> Result<R, DatabaseError> {
45        let mut guard = self.0.lock().expect("Mutex is not poisoned");
46        let store = guard.as_mut().ok_or(DatabaseError::Closed)?;
47        f(store)
48    }
49}
50
51impl Database for MemoryDatabase {
52    async fn initialize(
53        configuration: DatabaseConfiguration,
54        _migrations: RepositoryMigrations,
55    ) -> Result<Self, DatabaseError> {
56        let DatabaseConfiguration::Memory = configuration else {
57            return Err(DatabaseError::UnsupportedConfiguration(configuration));
58        };
59        Ok(MemoryDatabase::new())
60    }
61
62    async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
63        &self,
64        key: &str,
65    ) -> Result<Option<T>, DatabaseError> {
66        self.with_store(
67            |store| match store.get(&TypeId::of::<T>()).and_then(|m| m.get(key)) {
68                Some(json) => Ok(Some(serde_json::from_str(json)?)),
69                None => Ok(None),
70            },
71        )
72    }
73
74    async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
75        &self,
76    ) -> Result<Vec<T>, DatabaseError> {
77        self.with_store(|store| match store.get(&TypeId::of::<T>()) {
78            None => Ok(vec![]),
79            Some(type_map) => {
80                let mut results = Vec::with_capacity(type_map.len());
81                for json in type_map.values() {
82                    results.push(serde_json::from_str(json)?);
83                }
84                Ok(results)
85            }
86        })
87    }
88
89    async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
90        &self,
91        key: &str,
92        value: T,
93    ) -> Result<(), DatabaseError> {
94        let json = serde_json::to_string(&value)?;
95        self.with_store_mut(|store| {
96            store
97                .entry(TypeId::of::<T>())
98                .or_default()
99                .insert(key.to_string(), json);
100            Ok(())
101        })
102    }
103
104    async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
105        &self,
106        values: Vec<(String, T)>,
107    ) -> Result<(), DatabaseError> {
108        self.with_store_mut(|store| {
109            let type_map = store.entry(TypeId::of::<T>()).or_default();
110            for (key, value) in values {
111                let json = serde_json::to_string(&value)?;
112                type_map.insert(key, json);
113            }
114            Ok(())
115        })
116    }
117
118    async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
119        &self,
120        key: &str,
121    ) -> Result<(), DatabaseError> {
122        self.with_store_mut(|store| {
123            if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
124                type_map.remove(key);
125            }
126            Ok(())
127        })
128    }
129
130    async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
131        &self,
132        keys: Vec<String>,
133    ) -> Result<(), DatabaseError> {
134        self.with_store_mut(|store| {
135            if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
136                for key in keys {
137                    type_map.remove(&key);
138                }
139            }
140            Ok(())
141        })
142    }
143
144    async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
145        &self,
146    ) -> Result<(), DatabaseError> {
147        self.with_store_mut(|store| {
148            store.remove(&TypeId::of::<T>());
149            Ok(())
150        })
151    }
152
153    async fn wipe(&self) -> Result<(), DatabaseError> {
154        *self.0.lock().expect("Mutex is not poisoned") = None;
155        Ok(())
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeA(String);
    register_repository_item!(String => TypeA, "MemTypeA");

    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeB(u64);
    register_repository_item!(String => TypeB, "MemTypeB");

    /// Shorthand for building a `TypeA` from a string slice.
    fn type_a(text: &str) -> TypeA {
        TypeA(String::from(text))
    }

    #[tokio::test]
    async fn test_memory_database_get_set_remove() {
        let database = MemoryDatabase::new();

        let before = database.get::<TypeA>("key1").await.unwrap();
        assert_eq!(before, None);

        database.set("key1", type_a("hello")).await.unwrap();
        let stored = database.get::<TypeA>("key1").await.unwrap();
        assert_eq!(stored, Some(type_a("hello")));

        database.remove::<TypeA>("key1").await.unwrap();
        let after = database.get::<TypeA>("key1").await.unwrap();
        assert_eq!(after, None);
    }

    #[tokio::test]
    async fn test_memory_database_type_isolation() {
        let database = MemoryDatabase::new();
        // Both types use the identical string key; neither write may clobber the other.
        database.set("key1", type_a("value_a")).await.unwrap();
        database.set("key1", TypeB(42)).await.unwrap();

        let got_a = database.get::<TypeA>("key1").await.unwrap();
        let got_b = database.get::<TypeB>("key1").await.unwrap();
        assert_eq!(got_a, Some(type_a("value_a")));
        assert_eq!(got_b, Some(TypeB(42)));
    }

    #[tokio::test]
    async fn test_memory_database_clone_shares_store() {
        let original = MemoryDatabase::new();
        let shared = original.clone();

        original.set("key1", type_a("shared")).await.unwrap();

        let seen = shared.get::<TypeA>("key1").await.unwrap();
        assert_eq!(seen, Some(type_a("shared")));
    }

    #[tokio::test]
    async fn test_memory_database_list() {
        let database = MemoryDatabase::new();
        for (key, text) in [("a", "1"), ("b", "2")] {
            database.set(key, type_a(text)).await.unwrap();
        }
        database.set("c", TypeB(99)).await.unwrap();

        let mut listed = database.list::<TypeA>().await.unwrap();
        // HashMap iteration order is unspecified, so sort before comparing.
        listed.sort_by(|left, right| left.0.cmp(&right.0));
        assert_eq!(listed, vec![type_a("1"), type_a("2")]);

        // The TypeB entry lives in its own bucket and must not leak into the TypeA list.
        let listed_b = database.list::<TypeB>().await.unwrap();
        assert_eq!(listed_b, vec![TypeB(99)]);
    }

    #[tokio::test]
    async fn test_memory_database_set_bulk() {
        let database = MemoryDatabase::new();
        let batch = vec![
            (String::from("x"), type_a("v1")),
            (String::from("y"), type_a("v2")),
        ];
        database.set_bulk(batch).await.unwrap();

        for (key, expected) in [("x", "v1"), ("y", "v2")] {
            let found = database.get::<TypeA>(key).await.unwrap();
            assert_eq!(found, Some(type_a(expected)));
        }
    }

    #[tokio::test]
    async fn test_memory_database_remove_bulk() {
        let database = MemoryDatabase::new();
        for (key, text) in [("a", "1"), ("b", "2"), ("c", "3")] {
            database.set(key, type_a(text)).await.unwrap();
        }

        let doomed = vec![String::from("a"), String::from("b")];
        database.remove_bulk::<TypeA>(doomed).await.unwrap();

        for key in ["a", "b"] {
            assert_eq!(database.get::<TypeA>(key).await.unwrap(), None);
        }
        let survivor = database.get::<TypeA>("c").await.unwrap();
        assert_eq!(survivor, Some(type_a("3")));
    }

    #[tokio::test]
    async fn test_memory_database_remove_all() {
        let database = MemoryDatabase::new();
        database.set("a", type_a("1")).await.unwrap();
        database.set("b", type_a("2")).await.unwrap();
        database.set("z", TypeB(5)).await.unwrap();

        database.remove_all::<TypeA>().await.unwrap();

        assert!(database.list::<TypeA>().await.unwrap().is_empty());
        // Clearing TypeA must leave the TypeB bucket untouched.
        let remaining_b = database.list::<TypeB>().await.unwrap();
        assert_eq!(remaining_b, vec![TypeB(5)]);
    }

    #[tokio::test]
    async fn test_memory_database_initialize_memory_config() {
        let migrations = RepositoryMigrations::new(vec![]);
        let database = MemoryDatabase::initialize(DatabaseConfiguration::Memory, migrations)
            .await
            .unwrap();

        database.set("k", type_a("v")).await.unwrap();

        let found = database.get::<TypeA>("k").await.unwrap();
        assert_eq!(found, Some(type_a("v")));
    }

    #[tokio::test]
    async fn test_memory_database_wipe_disconnects_clones() {
        let database = MemoryDatabase::new();
        let sibling = database.clone();
        database.set("k", type_a("v")).await.unwrap();

        database.wipe().await.unwrap();

        let own_read = database.get::<TypeA>("k").await;
        assert!(matches!(own_read, Err(DatabaseError::Closed)));

        let sibling_read = sibling.get::<TypeA>("k").await;
        assert!(matches!(sibling_read, Err(DatabaseError::Closed)));

        let sibling_write = sibling.set("k", type_a("v")).await;
        assert!(matches!(sibling_write, Err(DatabaseError::Closed)));
    }

    #[tokio::test]
    async fn test_memory_database_initialize_rejects_non_memory_config() {
        let sqlite = DatabaseConfiguration::Sqlite {
            db_name: "ignored".to_string(),
            folder_path: std::path::PathBuf::from("/tmp"),
        };

        let outcome = MemoryDatabase::initialize(sqlite, RepositoryMigrations::new(vec![])).await;

        assert!(matches!(
            outcome,
            Err(DatabaseError::UnsupportedConfiguration(_))
        ));
    }
}