Skip to main content

bitwarden_state/sdk_managed/
sqlite.rs

1use std::{path::PathBuf, sync::Arc};
2
3use serde::{de::DeserializeOwned, ser::Serialize};
4use tokio::sync::Mutex;
5
6use crate::{
7    repository::{
8        RepositoryItem, RepositoryMigrationStep, RepositoryMigrations, validate_registry_name,
9    },
10    sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
11};
12
13// TODO: Use connection pooling with r2d2 and r2d2_sqlite?
14#[derive(Clone)]
15pub struct SqliteDatabase {
16    conn: Arc<Mutex<Option<rusqlite::Connection>>>,
17    path: PathBuf,
18}
19
20fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
21    if validate_registry_name(name) {
22        Ok(name)
23    } else {
24        Err(DatabaseError::Internal(
25            rusqlite::Error::InvalidParameterName(name.to_string()).to_string(),
26        ))
27    }
28}
29
30impl SqliteDatabase {
31    async fn with_conn<R>(
32        &self,
33        f: impl FnOnce(&rusqlite::Connection) -> Result<R, DatabaseError>,
34    ) -> Result<R, DatabaseError> {
35        let guard = self.conn.lock().await;
36        let conn = guard.as_ref().ok_or(DatabaseError::Closed)?;
37        f(conn)
38    }
39
40    async fn with_tx<R>(
41        &self,
42        f: impl FnOnce(&rusqlite::Transaction) -> Result<R, DatabaseError>,
43    ) -> Result<R, DatabaseError> {
44        let mut guard = self.conn.lock().await;
45        let conn = guard.as_mut().ok_or(DatabaseError::Closed)?;
46        let tx = conn.transaction()?;
47        let result = f(&tx)?;
48        tx.commit()?;
49        Ok(result)
50    }
51
52    fn initialize_internal(
53        mut db: rusqlite::Connection,
54        path: PathBuf,
55        migrations: RepositoryMigrations,
56    ) -> Result<Self, DatabaseError> {
57        // Set WAL mode for better concurrency
58        db.pragma_update(None, "journal_mode", "WAL")?;
59
60        let transaction = db.transaction()?;
61
62        for step in &migrations.steps {
63            match step {
64                RepositoryMigrationStep::Add(data) => {
65                    // SAFETY: SQLite tables cannot use ?, but `reg.name()` is not user controlled
66                    // and is validated to only contain valid characters, so
67                    // it's safe to interpolate here.
68                    transaction.execute(
69                        &format!(
70                            "CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
71                            validate_identifier(data.name())?,
72                        ),
73                        [],
74                    )?;
75                }
76                RepositoryMigrationStep::Remove(data) => {
77                    // SAFETY: SQLite tables cannot use ?, but `reg.name()` is not user controlled
78                    // and is validated to only contain valid characters, so
79                    // it's safe to interpolate here.
80                    transaction.execute(
81                        &format!(
82                            "DROP TABLE IF EXISTS \"{}\";",
83                            validate_identifier(data.name())?,
84                        ),
85                        [],
86                    )?;
87                }
88            }
89        }
90
91        transaction.commit()?;
92        Ok(SqliteDatabase {
93            conn: Arc::new(Mutex::new(Some(db))),
94            path,
95        })
96    }
97}
98
99impl Database for SqliteDatabase {
100    async fn initialize(
101        configuration: DatabaseConfiguration,
102        migrations: RepositoryMigrations,
103    ) -> Result<Self, DatabaseError> {
104        let DatabaseConfiguration::Sqlite {
105            db_name,
106            folder_path: mut path,
107        } = configuration
108        else {
109            return Err(DatabaseError::UnsupportedConfiguration(configuration));
110        };
111        path.push(format!("{db_name}.sqlite"));
112
113        let db = rusqlite::Connection::open(&path)?;
114        Self::initialize_internal(db, path, migrations)
115    }
116
117    async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
118        &self,
119        key: &str,
120    ) -> Result<Option<T>, DatabaseError> {
121        self.with_conn(|conn| {
122            // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
123            // validated to only contain valid characters, so it's safe to interpolate here.
124            let mut stmt = conn.prepare(&format!(
125                "SELECT value FROM \"{}\" WHERE key = ?1",
126                validate_identifier(T::NAME)?
127            ))?;
128            let mut rows = stmt.query([key])?;
129
130            if let Some(row) = rows.next()? {
131                let value = row.get::<_, String>(0)?;
132                Ok(Some(serde_json::from_str(&value)?))
133            } else {
134                Ok(None)
135            }
136        })
137        .await
138    }
139
140    async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
141        &self,
142    ) -> Result<Vec<T>, DatabaseError> {
143        self.with_conn(|conn| {
144            // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
145            // validated to only contain valid characters, so it's safe to interpolate here.
146            let mut stmt = conn.prepare(&format!(
147                "SELECT key, value FROM \"{}\"",
148                validate_identifier(T::NAME)?
149            ))?;
150            let rows = stmt.query_map([], |row| row.get(1))?;
151
152            let mut results = Vec::new();
153            for row in rows {
154                let value: String = row?;
155                results.push(serde_json::from_str(&value)?);
156            }
157            Ok(results)
158        })
159        .await
160    }
161
162    async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
163        &self,
164        key: &str,
165        value: T,
166    ) -> Result<(), DatabaseError> {
167        let value = serde_json::to_string(&value)?;
168        self.with_tx(|tx| {
169            // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
170            // validated to only contain valid characters, so it's safe to interpolate here.
171            tx.execute(
172                &format!(
173                    "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
174                    validate_identifier(T::NAME)?,
175                ),
176                [key, &value],
177            )?;
178            Ok(())
179        })
180        .await
181    }
182
183    async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
184        &self,
185        values: Vec<(String, T)>,
186    ) -> Result<(), DatabaseError> {
187        self.with_tx(|tx| {
188            // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
189            // validated to only contain valid characters, so it's safe to interpolate here.
190            let sql = format!(
191                "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
192                validate_identifier(T::NAME)?,
193            );
194            for (key, value) in values {
195                let value = serde_json::to_string(&value)?;
196                tx.execute(&sql, [&key, &value])?;
197            }
198            Ok(())
199        })
200        .await
201    }
202
203    async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
204        &self,
205        key: &str,
206    ) -> Result<(), DatabaseError> {
207        self.with_tx(|tx| {
208            // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
209            // validated to only contain valid characters, so it's safe to interpolate here.
210            tx.execute(
211                &format!(
212                    "DELETE FROM \"{}\" WHERE key = ?1",
213                    validate_identifier(T::NAME)?
214                ),
215                [key],
216            )?;
217            Ok(())
218        })
219        .await
220    }
221
222    async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
223        &self,
224        keys: Vec<String>,
225    ) -> Result<(), DatabaseError> {
226        self.with_tx(|tx| {
227            // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
228            // validated to only contain valid characters, so it's safe to interpolate here.
229            let sql = format!(
230                "DELETE FROM \"{}\" WHERE key = ?1",
231                validate_identifier(T::NAME)?
232            );
233            for key in keys {
234                tx.execute(&sql, [&key])?;
235            }
236            Ok(())
237        })
238        .await
239    }
240
241    async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
242        &self,
243    ) -> Result<(), DatabaseError> {
244        self.with_tx(|tx| {
245            // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
246            // validated to only contain valid characters, so it's safe to interpolate here.
247            tx.execute(
248                &format!("DELETE FROM \"{}\"", validate_identifier(T::NAME)?),
249                [],
250            )?;
251            Ok(())
252        })
253        .await
254    }
255
256    async fn wipe(&self) -> Result<(), DatabaseError> {
257        // Take and drop the connection under the lock so OS file handles close
258        // before we attempt to remove the file (matters on Windows).
259        drop(self.conn.lock().await.take());
260
261        // Attempt to remove every file before returning, so one failure doesn't
262        // skip the rest. `NotFound` is ignored and considered success.
263        let mut result = Ok(());
264        for p in [
265            self.path.clone(),
266            self.path.with_extension("sqlite-wal"),
267            self.path.with_extension("sqlite-shm"),
268        ] {
269            match std::fs::remove_file(&p) {
270                Ok(()) => {}
271                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
272                Err(e) => {
273                    tracing::warn!("Failed to delete {p:?} during wipe: {e}");
274                    result = Err(DatabaseError::Internal(format!(
275                        "Failed to delete {}: {e}",
276                        p.display()
277                    )));
278                }
279            }
280        }
281        result
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    /// Opens an in-memory database with `steps` applied as migrations.
    fn open_in_memory(steps: Vec<RepositoryMigrationStep>) -> SqliteDatabase {
        SqliteDatabase::initialize_internal(
            rusqlite::Connection::open_in_memory().unwrap(),
            PathBuf::new(),
            RepositoryMigrations::new(steps),
        )
        .unwrap()
    }

    /// Returns an empty scratch directory under the system temp dir that is unique to this
    /// process.
    ///
    /// A fixed directory name would let concurrent test runs (e.g. two `cargo test`
    /// invocations) share and clobber each other's database files. It would also let files left
    /// by an aborted run leak into the next one. The process id keeps runs apart, and any stale
    /// directory with the same name is removed first.
    fn fresh_temp_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!("{name}_{}", std::process::id()));
        std::fs::remove_dir_all(&dir).ok();
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[tokio::test]
    async fn test_sqlite_integration() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(String => TestA, "TestItem_A");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestB(usize);
        register_repository_item!(String => TestB, "TestItem_B");

        let db = open_in_memory(vec![
            // Test that deleting a table that doesn't exist is fine
            RepositoryMigrationStep::Remove(TestB::data()),
            RepositoryMigrationStep::Add(TestA::data()),
            RepositoryMigrationStep::Add(TestB::data()),
            // Test that deleting a table that does exist is also fine
            RepositoryMigrationStep::Remove(TestB::data()),
        ]);

        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));

        db.remove::<TestA>("key1").await.unwrap();

        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_sqlite_database_path_construction() {
        let temp_dir = fresh_temp_dir("bitwarden_state_test");

        let config = DatabaseConfiguration::Sqlite {
            db_name: "test_db".to_string(),
            folder_path: temp_dir.clone(),
        };

        SqliteDatabase::initialize(config, RepositoryMigrations::new(vec![]))
            .await
            .unwrap();

        assert!(temp_dir.join("test_db.sqlite").exists());

        std::fs::remove_dir_all(&temp_dir).ok();
    }

    #[tokio::test]
    async fn test_sqlite_wipe_deletes_files_and_closes_handles() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct WipeItem(usize);
        register_repository_item!(String => WipeItem, "WipeItem_sqlite_wipe");

        let temp_dir = fresh_temp_dir("bitwarden_state_wipe_test");
        let config = DatabaseConfiguration::Sqlite {
            db_name: "wipe_db".to_string(),
            folder_path: temp_dir.clone(),
        };

        let db = SqliteDatabase::initialize(
            config,
            RepositoryMigrations::new(vec![RepositoryMigrationStep::Add(WipeItem::data())]),
        )
        .await
        .unwrap();
        let clone = db.clone();
        db.set("key1", WipeItem(7)).await.unwrap();
        let sqlite_path = temp_dir.join("wipe_db.sqlite");
        assert!(sqlite_path.exists());

        db.wipe().await.unwrap();

        assert!(!sqlite_path.exists());
        assert!(!temp_dir.join("wipe_db.sqlite-wal").exists());
        assert!(!temp_dir.join("wipe_db.sqlite-shm").exists());
        // The clone shares the connection slot, so it must observe the close too.
        assert!(matches!(
            clone.get::<WipeItem>("key1").await,
            Err(DatabaseError::Closed)
        ));

        std::fs::remove_dir_all(&temp_dir).ok();
    }

    #[tokio::test]
    async fn test_sqlite_persistence_across_reopens() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct PersistedItem(String);
        register_repository_item!(String => PersistedItem, "PersistedItem_reopen");

        let temp_dir = fresh_temp_dir("bitwarden_state_persistence_test");
        let config = || DatabaseConfiguration::Sqlite {
            db_name: "persist_db".to_string(),
            folder_path: temp_dir.clone(),
        };
        let migrations =
            || RepositoryMigrations::new(vec![RepositoryMigrationStep::Add(PersistedItem::data())]);

        let db = SqliteDatabase::initialize(config(), migrations())
            .await
            .unwrap();
        db.set("k", PersistedItem("hello".to_string()))
            .await
            .unwrap();
        drop(db);

        let db = SqliteDatabase::initialize(config(), migrations())
            .await
            .unwrap();
        assert_eq!(
            db.get::<PersistedItem>("k").await.unwrap(),
            Some(PersistedItem("hello".to_string()))
        );

        db.wipe().await.unwrap();
        std::fs::remove_dir_all(&temp_dir).ok();
    }

    #[tokio::test]
    async fn test_sqlite_bulk_operations() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct BulkItem(u32);
        register_repository_item!(String => BulkItem, "BulkItem_sqlite");

        let db = open_in_memory(vec![RepositoryMigrationStep::Add(BulkItem::data())]);

        db.set_bulk(vec![
            ("a".to_string(), BulkItem(1)),
            ("b".to_string(), BulkItem(2)),
            ("c".to_string(), BulkItem(3)),
        ])
        .await
        .unwrap();

        let mut list = db.list::<BulkItem>().await.unwrap();
        list.sort_by_key(|item| item.0);
        assert_eq!(list, vec![BulkItem(1), BulkItem(2), BulkItem(3)]);

        db.remove_bulk::<BulkItem>(vec!["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(db.get::<BulkItem>("a").await.unwrap(), None);
        assert_eq!(db.get::<BulkItem>("b").await.unwrap(), None);
        assert_eq!(db.get::<BulkItem>("c").await.unwrap(), Some(BulkItem(3)));

        db.remove_all::<BulkItem>().await.unwrap();
        assert_eq!(db.list::<BulkItem>().await.unwrap(), Vec::<BulkItem>::new());
    }

    #[tokio::test]
    async fn test_sqlite_bulk_operations_with_empty_input() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct EmptyBulkItem(u32);
        register_repository_item!(String => EmptyBulkItem, "EmptyBulkItem_sqlite");

        let db = open_in_memory(vec![RepositoryMigrationStep::Add(EmptyBulkItem::data())]);
        db.set("keep", EmptyBulkItem(1)).await.unwrap();

        // Empty batches must succeed and leave existing rows untouched.
        db.set_bulk::<EmptyBulkItem>(vec![]).await.unwrap();
        db.remove_bulk::<EmptyBulkItem>(vec![]).await.unwrap();

        assert_eq!(
            db.list::<EmptyBulkItem>().await.unwrap(),
            vec![EmptyBulkItem(1)]
        );
    }

    #[tokio::test]
    async fn test_sqlite_cross_type_isolation() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct AlphaItem(String);
        register_repository_item!(String => AlphaItem, "AlphaItem_sqlite");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct BetaItem(u64);
        register_repository_item!(String => BetaItem, "BetaItem_sqlite");

        let db = open_in_memory(vec![
            RepositoryMigrationStep::Add(AlphaItem::data()),
            RepositoryMigrationStep::Add(BetaItem::data()),
        ]);

        db.set("key", AlphaItem("alpha".to_string())).await.unwrap();
        db.set("key", BetaItem(42)).await.unwrap();

        assert_eq!(
            db.get::<AlphaItem>("key").await.unwrap(),
            Some(AlphaItem("alpha".to_string()))
        );
        assert_eq!(db.get::<BetaItem>("key").await.unwrap(), Some(BetaItem(42)));

        db.remove_all::<AlphaItem>().await.unwrap();
        assert_eq!(db.get::<AlphaItem>("key").await.unwrap(), None);
        // BetaItem must be unaffected.
        assert_eq!(db.get::<BetaItem>("key").await.unwrap(), Some(BetaItem(42)));
    }

    #[tokio::test]
    async fn test_sqlite_set_overwrites_existing_key() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct OverwriteItem(u32);
        register_repository_item!(String => OverwriteItem, "OverwriteItem_sqlite");

        let db = open_in_memory(vec![RepositoryMigrationStep::Add(OverwriteItem::data())]);

        db.set("k", OverwriteItem(1)).await.unwrap();
        db.set("k", OverwriteItem(2)).await.unwrap();

        assert_eq!(
            db.get::<OverwriteItem>("k").await.unwrap(),
            Some(OverwriteItem(2))
        );
        assert_eq!(db.list::<OverwriteItem>().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_sqlite_wipe_with_missing_files_is_ok() {
        let temp_dir = fresh_temp_dir("bitwarden_state_wipe_missing_test");
        let config = DatabaseConfiguration::Sqlite {
            db_name: "missing_db".to_string(),
            folder_path: temp_dir.clone(),
        };

        let db = SqliteDatabase::initialize(config, RepositoryMigrations::new(vec![]))
            .await
            .unwrap();

        // Wiping twice exercises the missing-file path on the second call.
        db.wipe().await.unwrap();
        db.wipe().await.unwrap();

        std::fs::remove_dir_all(&temp_dir).ok();
    }
}