bitwarden_state/sdk_managed/
memory.rs1use std::{
2 any::TypeId,
3 collections::HashMap,
4 sync::{Arc, Mutex},
5};
6
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::{
10 repository::{RepositoryItem, RepositoryMigrations},
11 sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
12};
13
/// Backing storage: serialized JSON strings keyed by item key, grouped by the Rust
/// [`TypeId`] of the item type they were written as.
type Store = HashMap<TypeId, HashMap<String, String>>;

/// A [`Database`](crate::sdk_managed::Database) implementation that keeps everything in process
/// memory.
///
/// The handle wraps an `Arc<Mutex<..>>`, so cloning is cheap and every clone shares the same
/// underlying store. The inner `Option` holds the store while the database is open and is
/// `None` once it has been wiped.
#[derive(Clone)]
pub struct MemoryDatabase(Arc<Mutex<Option<Store>>>);
25
26impl MemoryDatabase {
27 pub fn new() -> Self {
29 MemoryDatabase(Arc::new(Mutex::new(Some(HashMap::new()))))
30 }
31
32 fn with_store<R>(
33 &self,
34 f: impl FnOnce(&Store) -> Result<R, DatabaseError>,
35 ) -> Result<R, DatabaseError> {
36 let guard = self.0.lock().expect("Mutex is not poisoned");
37 let store = guard.as_ref().ok_or(DatabaseError::Closed)?;
38 f(store)
39 }
40
41 fn with_store_mut<R>(
42 &self,
43 f: impl FnOnce(&mut Store) -> Result<R, DatabaseError>,
44 ) -> Result<R, DatabaseError> {
45 let mut guard = self.0.lock().expect("Mutex is not poisoned");
46 let store = guard.as_mut().ok_or(DatabaseError::Closed)?;
47 f(store)
48 }
49}
50
51impl Database for MemoryDatabase {
52 async fn initialize(
53 configuration: DatabaseConfiguration,
54 _migrations: RepositoryMigrations,
55 ) -> Result<Self, DatabaseError> {
56 let DatabaseConfiguration::Memory = configuration else {
57 return Err(DatabaseError::UnsupportedConfiguration(configuration));
58 };
59 Ok(MemoryDatabase::new())
60 }
61
62 async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
63 &self,
64 key: &str,
65 ) -> Result<Option<T>, DatabaseError> {
66 self.with_store(
67 |store| match store.get(&TypeId::of::<T>()).and_then(|m| m.get(key)) {
68 Some(json) => Ok(Some(serde_json::from_str(json)?)),
69 None => Ok(None),
70 },
71 )
72 }
73
74 async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
75 &self,
76 ) -> Result<Vec<T>, DatabaseError> {
77 self.with_store(|store| match store.get(&TypeId::of::<T>()) {
78 None => Ok(vec![]),
79 Some(type_map) => {
80 let mut results = Vec::with_capacity(type_map.len());
81 for json in type_map.values() {
82 results.push(serde_json::from_str(json)?);
83 }
84 Ok(results)
85 }
86 })
87 }
88
89 async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
90 &self,
91 key: &str,
92 value: T,
93 ) -> Result<(), DatabaseError> {
94 let json = serde_json::to_string(&value)?;
95 self.with_store_mut(|store| {
96 store
97 .entry(TypeId::of::<T>())
98 .or_default()
99 .insert(key.to_string(), json);
100 Ok(())
101 })
102 }
103
104 async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
105 &self,
106 values: Vec<(String, T)>,
107 ) -> Result<(), DatabaseError> {
108 self.with_store_mut(|store| {
109 let type_map = store.entry(TypeId::of::<T>()).or_default();
110 for (key, value) in values {
111 let json = serde_json::to_string(&value)?;
112 type_map.insert(key, json);
113 }
114 Ok(())
115 })
116 }
117
118 async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
119 &self,
120 key: &str,
121 ) -> Result<(), DatabaseError> {
122 self.with_store_mut(|store| {
123 if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
124 type_map.remove(key);
125 }
126 Ok(())
127 })
128 }
129
130 async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
131 &self,
132 keys: Vec<String>,
133 ) -> Result<(), DatabaseError> {
134 self.with_store_mut(|store| {
135 if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
136 for key in keys {
137 type_map.remove(&key);
138 }
139 }
140 Ok(())
141 })
142 }
143
144 async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
145 &self,
146 ) -> Result<(), DatabaseError> {
147 self.with_store_mut(|store| {
148 store.remove(&TypeId::of::<T>());
149 Ok(())
150 })
151 }
152
153 async fn wipe(&self) -> Result<(), DatabaseError> {
154 *self.0.lock().expect("Mutex is not poisoned") = None;
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeA(String);
    register_repository_item!(String => TypeA, "MemTypeA");

    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeB(u64);
    register_repository_item!(String => TypeB, "MemTypeB");

    /// Builds a `TypeA` holding an owned copy of `text`.
    fn type_a(text: &str) -> TypeA {
        TypeA(text.to_owned())
    }

    #[tokio::test]
    async fn test_memory_database_get_set_remove() {
        let database = MemoryDatabase::new();

        let before = database.get::<TypeA>("key1").await.unwrap();
        assert_eq!(before, None);

        database.set("key1", type_a("hello")).await.unwrap();
        let after_set = database.get::<TypeA>("key1").await.unwrap();
        assert_eq!(after_set, Some(type_a("hello")));

        database.remove::<TypeA>("key1").await.unwrap();
        let after_remove = database.get::<TypeA>("key1").await.unwrap();
        assert_eq!(after_remove, None);
    }

    #[tokio::test]
    async fn test_memory_database_type_isolation() {
        let database = MemoryDatabase::new();
        // The same key is written for two item types; neither write may clobber the other.
        database.set("key1", type_a("value_a")).await.unwrap();
        database.set("key1", TypeB(42)).await.unwrap();

        let as_a = database.get::<TypeA>("key1").await.unwrap();
        let as_b = database.get::<TypeB>("key1").await.unwrap();
        assert_eq!(as_a, Some(type_a("value_a")));
        assert_eq!(as_b, Some(TypeB(42)));
    }

    #[tokio::test]
    async fn test_memory_database_clone_shares_store() {
        let first = MemoryDatabase::new();
        let second = first.clone();
        first.set("key1", type_a("shared")).await.unwrap();

        let seen = second.get::<TypeA>("key1").await.unwrap();
        assert_eq!(seen, Some(type_a("shared")));
    }

    #[tokio::test]
    async fn test_memory_database_list() {
        let database = MemoryDatabase::new();
        for (key, text) in [("a", "1"), ("b", "2")] {
            database.set(key, type_a(text)).await.unwrap();
        }
        database.set("c", TypeB(99)).await.unwrap();

        // HashMap iteration order is unspecified, so sort before comparing.
        let mut listed_a = database.list::<TypeA>().await.unwrap();
        listed_a.sort_by(|left, right| left.0.cmp(&right.0));
        assert_eq!(listed_a, vec![type_a("1"), type_a("2")]);

        let listed_b = database.list::<TypeB>().await.unwrap();
        assert_eq!(listed_b, vec![TypeB(99)]);
    }

    #[tokio::test]
    async fn test_memory_database_set_bulk() {
        let database = MemoryDatabase::new();
        let batch = vec![
            ("x".to_string(), type_a("v1")),
            ("y".to_string(), type_a("v2")),
        ];
        database.set_bulk(batch).await.unwrap();

        for (key, expected) in [("x", "v1"), ("y", "v2")] {
            let got = database.get::<TypeA>(key).await.unwrap();
            assert_eq!(got, Some(type_a(expected)));
        }
    }

    #[tokio::test]
    async fn test_memory_database_remove_bulk() {
        let database = MemoryDatabase::new();
        for (key, text) in [("a", "1"), ("b", "2"), ("c", "3")] {
            database.set(key, type_a(text)).await.unwrap();
        }

        let to_remove = vec!["a".to_string(), "b".to_string()];
        database.remove_bulk::<TypeA>(to_remove).await.unwrap();

        for gone in ["a", "b"] {
            assert_eq!(database.get::<TypeA>(gone).await.unwrap(), None);
        }
        let kept = database.get::<TypeA>("c").await.unwrap();
        assert_eq!(kept, Some(type_a("3")));
    }

    #[tokio::test]
    async fn test_memory_database_remove_all() {
        let database = MemoryDatabase::new();
        database.set("a", type_a("1")).await.unwrap();
        database.set("b", type_a("2")).await.unwrap();
        database.set("z", TypeB(5)).await.unwrap();

        database.remove_all::<TypeA>().await.unwrap();

        assert!(database.list::<TypeA>().await.unwrap().is_empty());
        assert_eq!(database.list::<TypeB>().await.unwrap(), vec![TypeB(5)]);
    }

    #[tokio::test]
    async fn test_memory_database_initialize_memory_config() {
        let database = MemoryDatabase::initialize(
            DatabaseConfiguration::Memory,
            RepositoryMigrations::new(vec![]),
        )
        .await
        .unwrap();

        database.set("k", type_a("v")).await.unwrap();
        let stored = database.get::<TypeA>("k").await.unwrap();
        assert_eq!(stored, Some(type_a("v")));
    }

    #[tokio::test]
    async fn test_memory_database_wipe_disconnects_clones() {
        let database = MemoryDatabase::new();
        let other_handle = database.clone();
        database.set("k", type_a("v")).await.unwrap();

        database.wipe().await.unwrap();

        let own_read = database.get::<TypeA>("k").await;
        assert!(matches!(own_read, Err(DatabaseError::Closed)));
        let clone_read = other_handle.get::<TypeA>("k").await;
        assert!(matches!(clone_read, Err(DatabaseError::Closed)));
        let clone_write = other_handle.set("k", type_a("v")).await;
        assert!(matches!(clone_write, Err(DatabaseError::Closed)));
    }

    #[tokio::test]
    async fn test_memory_database_initialize_rejects_non_memory_config() {
        let sqlite = DatabaseConfiguration::Sqlite {
            db_name: "ignored".to_string(),
            folder_path: std::path::PathBuf::from("/tmp"),
        };
        let outcome = MemoryDatabase::initialize(sqlite, RepositoryMigrations::new(vec![])).await;
        assert!(matches!(
            outcome,
            Err(DatabaseError::UnsupportedConfiguration(_))
        ));
    }
}