bitwarden_state/sdk_managed/
memory.rs1use std::{
2 any::TypeId,
3 collections::HashMap,
4 sync::{Arc, Mutex},
5};
6
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::{
10 repository::{RepositoryItem, RepositoryMigrations},
11 sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
12};
13
/// An in-memory database backend.
///
/// Values are kept as JSON strings, partitioned first by the [`TypeId`] of the stored item
/// type and then by key, so the same key used with two different item types never collides.
///
/// Cloning a `MemoryDatabase` is cheap: every clone is a handle to the *same* shared store
/// (an `Arc<Mutex<..>>`). Nothing is persisted; the data is dropped together with the last
/// handle.
#[derive(Clone, Default)]
pub struct MemoryDatabase(Arc<Mutex<HashMap<TypeId, HashMap<String, String>>>>);

impl MemoryDatabase {
    /// Creates a new, empty in-memory database with its own (unshared) store.
    pub fn new() -> Self {
        Self::default()
    }
}
30
31impl Database for MemoryDatabase {
32 async fn initialize(
33 configuration: DatabaseConfiguration,
34 _migrations: RepositoryMigrations,
35 ) -> Result<Self, DatabaseError> {
36 let DatabaseConfiguration::Memory = configuration else {
37 return Err(DatabaseError::UnsupportedConfiguration(configuration));
38 };
39 Ok(MemoryDatabase::new())
40 }
41
42 async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
43 &self,
44 key: &str,
45 ) -> Result<Option<T>, DatabaseError> {
46 let store = self.0.lock().expect("Mutex is not poisoned");
47 let type_map = store.get(&TypeId::of::<T>());
48 match type_map.and_then(|m| m.get(key)) {
49 Some(json) => Ok(Some(serde_json::from_str(json)?)),
50 None => Ok(None),
51 }
52 }
53
54 async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
55 &self,
56 ) -> Result<Vec<T>, DatabaseError> {
57 let store = self.0.lock().expect("Mutex is not poisoned");
58 match store.get(&TypeId::of::<T>()) {
59 None => Ok(vec![]),
60 Some(type_map) => {
61 let mut results = Vec::with_capacity(type_map.len());
62 for json in type_map.values() {
63 results.push(serde_json::from_str(json)?);
64 }
65 Ok(results)
66 }
67 }
68 }
69
70 async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
71 &self,
72 key: &str,
73 value: T,
74 ) -> Result<(), DatabaseError> {
75 let json = serde_json::to_string(&value)?;
76 let mut store = self.0.lock().expect("Mutex is not poisoned");
77 store
78 .entry(TypeId::of::<T>())
79 .or_default()
80 .insert(key.to_string(), json);
81 Ok(())
82 }
83
84 async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
85 &self,
86 values: Vec<(String, T)>,
87 ) -> Result<(), DatabaseError> {
88 let mut store = self.0.lock().expect("Mutex is not poisoned");
89 let type_map = store.entry(TypeId::of::<T>()).or_default();
90 for (key, value) in values {
91 let json = serde_json::to_string(&value)?;
92 type_map.insert(key, json);
93 }
94 Ok(())
95 }
96
97 async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
98 &self,
99 key: &str,
100 ) -> Result<(), DatabaseError> {
101 let mut store = self.0.lock().expect("Mutex is not poisoned");
102 if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
103 type_map.remove(key);
104 }
105 Ok(())
106 }
107
108 async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
109 &self,
110 keys: Vec<String>,
111 ) -> Result<(), DatabaseError> {
112 let mut store = self.0.lock().expect("Mutex is not poisoned");
113 if let Some(type_map) = store.get_mut(&TypeId::of::<T>()) {
114 for key in keys {
115 type_map.remove(&key);
116 }
117 }
118 Ok(())
119 }
120
121 async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
122 &self,
123 ) -> Result<(), DatabaseError> {
124 let mut store = self.0.lock().expect("Mutex is not poisoned");
125 store.remove(&TypeId::of::<T>());
126 Ok(())
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeA(String);
    register_repository_item!(String => TypeA, "MemTypeA");

    #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
    struct TypeB(u64);
    register_repository_item!(String => TypeB, "MemTypeB");

    #[tokio::test]
    async fn test_memory_database_get_set_remove() {
        let db = MemoryDatabase::new();
        assert_eq!(db.get::<TypeA>("key1").await.unwrap(), None);
        db.set("key1", TypeA("hello".to_string())).await.unwrap();
        assert_eq!(
            db.get::<TypeA>("key1").await.unwrap(),
            Some(TypeA("hello".to_string()))
        );
        db.remove::<TypeA>("key1").await.unwrap();
        assert_eq!(db.get::<TypeA>("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_memory_database_set_overwrites_existing_value() {
        let db = MemoryDatabase::new();
        db.set("key1", TypeA("old".to_string())).await.unwrap();
        db.set("key1", TypeA("new".to_string())).await.unwrap();
        assert_eq!(
            db.get::<TypeA>("key1").await.unwrap(),
            Some(TypeA("new".to_string()))
        );
        assert_eq!(db.list::<TypeA>().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_memory_database_operations_on_missing_data_are_noops() {
        let db = MemoryDatabase::new();
        assert_eq!(db.list::<TypeA>().await.unwrap(), vec![]);
        db.remove::<TypeA>("missing").await.unwrap();
        db.remove_bulk::<TypeA>(vec!["missing".to_string()])
            .await
            .unwrap();
        db.remove_all::<TypeA>().await.unwrap();
        assert_eq!(db.get::<TypeA>("missing").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_memory_database_type_isolation() {
        let db = MemoryDatabase::new();
        db.set("key1", TypeA("value_a".to_string())).await.unwrap();
        db.set("key1", TypeB(42)).await.unwrap();
        assert_eq!(
            db.get::<TypeA>("key1").await.unwrap(),
            Some(TypeA("value_a".to_string()))
        );
        assert_eq!(db.get::<TypeB>("key1").await.unwrap(), Some(TypeB(42)));
    }

    #[tokio::test]
    async fn test_memory_database_clone_shares_store() {
        let db1 = MemoryDatabase::new();
        let db2 = db1.clone();
        db1.set("key1", TypeA("shared".to_string())).await.unwrap();
        assert_eq!(
            db2.get::<TypeA>("key1").await.unwrap(),
            Some(TypeA("shared".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_database_list() {
        let db = MemoryDatabase::new();
        db.set("a", TypeA("1".to_string())).await.unwrap();
        db.set("b", TypeA("2".to_string())).await.unwrap();
        db.set("c", TypeB(99)).await.unwrap();
        // `list` order is unspecified, so sort before comparing.
        let mut list_a = db.list::<TypeA>().await.unwrap();
        list_a.sort_by_key(|x| x.0.clone());
        assert_eq!(list_a, vec![TypeA("1".to_string()), TypeA("2".to_string())]);
        assert_eq!(db.list::<TypeB>().await.unwrap(), vec![TypeB(99)]);
    }

    #[tokio::test]
    async fn test_memory_database_set_bulk() {
        let db = MemoryDatabase::new();
        db.set_bulk(vec![
            ("x".to_string(), TypeA("v1".to_string())),
            ("y".to_string(), TypeA("v2".to_string())),
        ])
        .await
        .unwrap();
        assert_eq!(
            db.get::<TypeA>("x").await.unwrap(),
            Some(TypeA("v1".to_string()))
        );
        assert_eq!(
            db.get::<TypeA>("y").await.unwrap(),
            Some(TypeA("v2".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_database_remove_bulk() {
        let db = MemoryDatabase::new();
        db.set("a", TypeA("1".to_string())).await.unwrap();
        db.set("b", TypeA("2".to_string())).await.unwrap();
        db.set("c", TypeA("3".to_string())).await.unwrap();
        db.remove_bulk::<TypeA>(vec!["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(db.get::<TypeA>("a").await.unwrap(), None);
        assert_eq!(db.get::<TypeA>("b").await.unwrap(), None);
        assert_eq!(
            db.get::<TypeA>("c").await.unwrap(),
            Some(TypeA("3".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_database_remove_all() {
        let db = MemoryDatabase::new();
        db.set("a", TypeA("1".to_string())).await.unwrap();
        db.set("b", TypeA("2".to_string())).await.unwrap();
        db.set("z", TypeB(5)).await.unwrap();
        db.remove_all::<TypeA>().await.unwrap();
        assert_eq!(db.list::<TypeA>().await.unwrap(), vec![]);
        assert_eq!(db.list::<TypeB>().await.unwrap(), vec![TypeB(5)]);
    }

    #[tokio::test]
    async fn test_memory_database_initialize_memory_config() {
        let db = MemoryDatabase::initialize(
            DatabaseConfiguration::Memory,
            RepositoryMigrations::new(vec![]),
        )
        .await
        .unwrap();
        db.set("k", TypeA("v".to_string())).await.unwrap();
        assert_eq!(
            db.get::<TypeA>("k").await.unwrap(),
            Some(TypeA("v".to_string()))
        );
    }

    #[tokio::test]
    async fn test_memory_database_initialize_rejects_non_memory_config() {
        let result = MemoryDatabase::initialize(
            DatabaseConfiguration::Sqlite {
                db_name: "ignored".to_string(),
                folder_path: std::path::PathBuf::from("/tmp"),
            },
            RepositoryMigrations::new(vec![]),
        )
        .await;
        assert!(matches!(
            result,
            Err(DatabaseError::UnsupportedConfiguration(_))
        ));
    }
}