1use std::{path::PathBuf, sync::Arc};
2
3use serde::{de::DeserializeOwned, ser::Serialize};
4use tokio::sync::Mutex;
5
6use crate::{
7 repository::{
8 RepositoryItem, RepositoryMigrationStep, RepositoryMigrations, validate_registry_name,
9 },
10 sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
11};
12
13#[derive(Clone)]
15pub struct SqliteDatabase {
16 conn: Arc<Mutex<Option<rusqlite::Connection>>>,
17 path: PathBuf,
18}
19
20fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
21 if validate_registry_name(name) {
22 Ok(name)
23 } else {
24 Err(DatabaseError::Internal(
25 rusqlite::Error::InvalidParameterName(name.to_string()).to_string(),
26 ))
27 }
28}
29
30impl SqliteDatabase {
31 async fn with_conn<R>(
32 &self,
33 f: impl FnOnce(&rusqlite::Connection) -> Result<R, DatabaseError>,
34 ) -> Result<R, DatabaseError> {
35 let guard = self.conn.lock().await;
36 let conn = guard.as_ref().ok_or(DatabaseError::Closed)?;
37 f(conn)
38 }
39
40 async fn with_tx<R>(
41 &self,
42 f: impl FnOnce(&rusqlite::Transaction) -> Result<R, DatabaseError>,
43 ) -> Result<R, DatabaseError> {
44 let mut guard = self.conn.lock().await;
45 let conn = guard.as_mut().ok_or(DatabaseError::Closed)?;
46 let tx = conn.transaction()?;
47 let result = f(&tx)?;
48 tx.commit()?;
49 Ok(result)
50 }
51
52 fn initialize_internal(
53 mut db: rusqlite::Connection,
54 path: PathBuf,
55 migrations: RepositoryMigrations,
56 ) -> Result<Self, DatabaseError> {
57 db.pragma_update(None, "journal_mode", "WAL")?;
59
60 let transaction = db.transaction()?;
61
62 for step in &migrations.steps {
63 match step {
64 RepositoryMigrationStep::Add(data) => {
65 transaction.execute(
69 &format!(
70 "CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
71 validate_identifier(data.name())?,
72 ),
73 [],
74 )?;
75 }
76 RepositoryMigrationStep::Remove(data) => {
77 transaction.execute(
81 &format!(
82 "DROP TABLE IF EXISTS \"{}\";",
83 validate_identifier(data.name())?,
84 ),
85 [],
86 )?;
87 }
88 }
89 }
90
91 transaction.commit()?;
92 Ok(SqliteDatabase {
93 conn: Arc::new(Mutex::new(Some(db))),
94 path,
95 })
96 }
97}
98
99impl Database for SqliteDatabase {
100 async fn initialize(
101 configuration: DatabaseConfiguration,
102 migrations: RepositoryMigrations,
103 ) -> Result<Self, DatabaseError> {
104 let DatabaseConfiguration::Sqlite {
105 db_name,
106 folder_path: mut path,
107 } = configuration
108 else {
109 return Err(DatabaseError::UnsupportedConfiguration(configuration));
110 };
111 path.push(format!("{db_name}.sqlite"));
112
113 let db = rusqlite::Connection::open(&path)?;
114 Self::initialize_internal(db, path, migrations)
115 }
116
117 async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
118 &self,
119 key: &str,
120 ) -> Result<Option<T>, DatabaseError> {
121 self.with_conn(|conn| {
122 let mut stmt = conn.prepare(&format!(
125 "SELECT value FROM \"{}\" WHERE key = ?1",
126 validate_identifier(T::NAME)?
127 ))?;
128 let mut rows = stmt.query([key])?;
129
130 if let Some(row) = rows.next()? {
131 let value = row.get::<_, String>(0)?;
132 Ok(Some(serde_json::from_str(&value)?))
133 } else {
134 Ok(None)
135 }
136 })
137 .await
138 }
139
140 async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
141 &self,
142 ) -> Result<Vec<T>, DatabaseError> {
143 self.with_conn(|conn| {
144 let mut stmt = conn.prepare(&format!(
147 "SELECT key, value FROM \"{}\"",
148 validate_identifier(T::NAME)?
149 ))?;
150 let rows = stmt.query_map([], |row| row.get(1))?;
151
152 let mut results = Vec::new();
153 for row in rows {
154 let value: String = row?;
155 results.push(serde_json::from_str(&value)?);
156 }
157 Ok(results)
158 })
159 .await
160 }
161
162 async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
163 &self,
164 key: &str,
165 value: T,
166 ) -> Result<(), DatabaseError> {
167 let value = serde_json::to_string(&value)?;
168 self.with_tx(|tx| {
169 tx.execute(
172 &format!(
173 "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
174 validate_identifier(T::NAME)?,
175 ),
176 [key, &value],
177 )?;
178 Ok(())
179 })
180 .await
181 }
182
183 async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
184 &self,
185 values: Vec<(String, T)>,
186 ) -> Result<(), DatabaseError> {
187 self.with_tx(|tx| {
188 let sql = format!(
191 "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
192 validate_identifier(T::NAME)?,
193 );
194 for (key, value) in values {
195 let value = serde_json::to_string(&value)?;
196 tx.execute(&sql, [&key, &value])?;
197 }
198 Ok(())
199 })
200 .await
201 }
202
203 async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
204 &self,
205 key: &str,
206 ) -> Result<(), DatabaseError> {
207 self.with_tx(|tx| {
208 tx.execute(
211 &format!(
212 "DELETE FROM \"{}\" WHERE key = ?1",
213 validate_identifier(T::NAME)?
214 ),
215 [key],
216 )?;
217 Ok(())
218 })
219 .await
220 }
221
222 async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
223 &self,
224 keys: Vec<String>,
225 ) -> Result<(), DatabaseError> {
226 self.with_tx(|tx| {
227 let sql = format!(
230 "DELETE FROM \"{}\" WHERE key = ?1",
231 validate_identifier(T::NAME)?
232 );
233 for key in keys {
234 tx.execute(&sql, [&key])?;
235 }
236 Ok(())
237 })
238 .await
239 }
240
241 async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
242 &self,
243 ) -> Result<(), DatabaseError> {
244 self.with_tx(|tx| {
245 tx.execute(
248 &format!("DELETE FROM \"{}\"", validate_identifier(T::NAME)?),
249 [],
250 )?;
251 Ok(())
252 })
253 .await
254 }
255
256 async fn wipe(&self) -> Result<(), DatabaseError> {
257 drop(self.conn.lock().await.take());
260
261 let mut result = Ok(());
264 for p in [
265 self.path.clone(),
266 self.path.with_extension("sqlite-wal"),
267 self.path.with_extension("sqlite-shm"),
268 ] {
269 match std::fs::remove_file(&p) {
270 Ok(()) => {}
271 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
272 Err(e) => {
273 tracing::warn!("Failed to delete {p:?} during wipe: {e}");
274 result = Err(DatabaseError::Internal(format!(
275 "Failed to delete {}: {e}",
276 p.display()
277 )));
278 }
279 }
280 }
281 result
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    /// Opens an in-memory database and applies `steps` to it.
    fn open_in_memory(steps: Vec<RepositoryMigrationStep>) -> SqliteDatabase {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        SqliteDatabase::initialize_internal(conn, PathBuf::new(), RepositoryMigrations::new(steps))
            .unwrap()
    }

    /// Creates an empty directory for one test under the system temp dir.
    ///
    /// The process id keeps concurrent runs of the test binary apart, and any leftovers
    /// from an aborted earlier run are removed so a test never observes stale files.
    fn fresh_temp_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Builds a SQLite configuration for `db_name` inside `folder`.
    fn sqlite_config(db_name: &str, folder: &std::path::Path) -> DatabaseConfiguration {
        DatabaseConfiguration::Sqlite {
            db_name: db_name.to_string(),
            folder_path: folder.to_path_buf(),
        }
    }

    #[tokio::test]
    async fn test_sqlite_integration() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(String => TestA, "TestItem_A");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestB(usize);
        register_repository_item!(String => TestB, "TestItem_B");

        // Removing a table that does not exist yet, and removing it again after adding
        // it, must both be accepted by the migration runner.
        let db = open_in_memory(vec![
            RepositoryMigrationStep::Remove(TestB::data()),
            RepositoryMigrationStep::Add(TestA::data()),
            RepositoryMigrationStep::Add(TestB::data()),
            RepositoryMigrationStep::Remove(TestB::data()),
        ]);

        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));

        db.remove::<TestA>("key1").await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_sqlite_database_path_construction() {
        let temp_dir = fresh_temp_dir("bitwarden_state_test");

        SqliteDatabase::initialize(
            sqlite_config("test_db", &temp_dir),
            RepositoryMigrations::new(vec![]),
        )
        .await
        .unwrap();

        assert!(temp_dir.join("test_db.sqlite").exists());

        std::fs::remove_dir_all(&temp_dir).ok();
    }

    #[tokio::test]
    async fn test_sqlite_wipe_deletes_files_and_closes_handles() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct WipeItem(usize);
        register_repository_item!(String => WipeItem, "WipeItem_sqlite_wipe");

        let temp_dir = fresh_temp_dir("bitwarden_state_wipe_test");

        let db = SqliteDatabase::initialize(
            sqlite_config("wipe_db", &temp_dir),
            RepositoryMigrations::new(vec![RepositoryMigrationStep::Add(WipeItem::data())]),
        )
        .await
        .unwrap();
        let clone = db.clone();
        db.set("key1", WipeItem(7)).await.unwrap();

        let sqlite_path = temp_dir.join("wipe_db.sqlite");
        assert!(sqlite_path.exists());

        db.wipe().await.unwrap();

        assert!(!sqlite_path.exists());
        assert!(!temp_dir.join("wipe_db.sqlite-wal").exists());
        assert!(!temp_dir.join("wipe_db.sqlite-shm").exists());
        // The clone shares the connection slot, so it must observe the close as well.
        assert!(matches!(
            clone.get::<WipeItem>("key1").await,
            Err(DatabaseError::Closed)
        ));

        std::fs::remove_dir_all(&temp_dir).ok();
    }

    #[tokio::test]
    async fn test_sqlite_persistence_across_reopens() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct PersistedItem(String);
        register_repository_item!(String => PersistedItem, "PersistedItem_reopen");

        let temp_dir = fresh_temp_dir("bitwarden_state_persistence_test");
        let config = || sqlite_config("persist_db", &temp_dir);
        let migrations =
            || RepositoryMigrations::new(vec![RepositoryMigrationStep::Add(PersistedItem::data())]);

        let db = SqliteDatabase::initialize(config(), migrations())
            .await
            .unwrap();
        db.set("k", PersistedItem("hello".to_string()))
            .await
            .unwrap();
        drop(db);

        // A second handle opened on the same file must see the previously stored value.
        let db = SqliteDatabase::initialize(config(), migrations())
            .await
            .unwrap();
        assert_eq!(
            db.get::<PersistedItem>("k").await.unwrap(),
            Some(PersistedItem("hello".to_string()))
        );

        db.wipe().await.unwrap();
        std::fs::remove_dir_all(&temp_dir).ok();
    }

    #[tokio::test]
    async fn test_sqlite_bulk_operations() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct BulkItem(u32);
        register_repository_item!(String => BulkItem, "BulkItem_sqlite");

        let db = open_in_memory(vec![RepositoryMigrationStep::Add(BulkItem::data())]);

        db.set_bulk(vec![
            ("a".to_string(), BulkItem(1)),
            ("b".to_string(), BulkItem(2)),
            ("c".to_string(), BulkItem(3)),
        ])
        .await
        .unwrap();

        // `list` gives no ordering guarantee, so sort before comparing.
        let mut list = db.list::<BulkItem>().await.unwrap();
        list.sort_by_key(|item| item.0);
        assert_eq!(list, vec![BulkItem(1), BulkItem(2), BulkItem(3)]);

        db.remove_bulk::<BulkItem>(vec!["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(db.get::<BulkItem>("a").await.unwrap(), None);
        assert_eq!(db.get::<BulkItem>("b").await.unwrap(), None);
        assert_eq!(db.get::<BulkItem>("c").await.unwrap(), Some(BulkItem(3)));

        db.remove_all::<BulkItem>().await.unwrap();
        assert_eq!(db.list::<BulkItem>().await.unwrap(), Vec::<BulkItem>::new());
    }

    #[tokio::test]
    async fn test_sqlite_cross_type_isolation() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct AlphaItem(String);
        register_repository_item!(String => AlphaItem, "AlphaItem_sqlite");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct BetaItem(u64);
        register_repository_item!(String => BetaItem, "BetaItem_sqlite");

        let db = open_in_memory(vec![
            RepositoryMigrationStep::Add(AlphaItem::data()),
            RepositoryMigrationStep::Add(BetaItem::data()),
        ]);

        // The same key is used for both types; each type has its own table.
        db.set("key", AlphaItem("alpha".to_string())).await.unwrap();
        db.set("key", BetaItem(42)).await.unwrap();

        assert_eq!(
            db.get::<AlphaItem>("key").await.unwrap(),
            Some(AlphaItem("alpha".to_string()))
        );
        assert_eq!(db.get::<BetaItem>("key").await.unwrap(), Some(BetaItem(42)));

        // Clearing one type must leave the other untouched.
        db.remove_all::<AlphaItem>().await.unwrap();
        assert_eq!(db.get::<AlphaItem>("key").await.unwrap(), None);
        assert_eq!(db.get::<BetaItem>("key").await.unwrap(), Some(BetaItem(42)));
    }

    #[tokio::test]
    async fn test_sqlite_set_overwrites_existing_key() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct OverwriteItem(u32);
        register_repository_item!(String => OverwriteItem, "OverwriteItem_sqlite");

        let db = open_in_memory(vec![RepositoryMigrationStep::Add(OverwriteItem::data())]);

        db.set("k", OverwriteItem(1)).await.unwrap();
        db.set("k", OverwriteItem(2)).await.unwrap();

        assert_eq!(
            db.get::<OverwriteItem>("k").await.unwrap(),
            Some(OverwriteItem(2))
        );
        assert_eq!(db.list::<OverwriteItem>().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_sqlite_wipe_with_missing_files_is_ok() {
        let temp_dir = fresh_temp_dir("bitwarden_state_wipe_missing_test");

        let db = SqliteDatabase::initialize(
            sqlite_config("missing_db", &temp_dir),
            RepositoryMigrations::new(vec![]),
        )
        .await
        .unwrap();

        db.wipe().await.unwrap();
        // The second wipe finds neither a connection nor any files and must still succeed.
        db.wipe().await.unwrap();

        std::fs::remove_dir_all(&temp_dir).ok();
    }
}