// bitwarden_state/sdk_managed/sqlite.rs

1use std::sync::Arc;
2
3use serde::{de::DeserializeOwned, ser::Serialize};
4use tokio::sync::Mutex;
5
6use crate::{
7    repository::{
8        RepositoryItem, RepositoryMigrationStep, RepositoryMigrations, validate_registry_name,
9    },
10    sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
11};
12
// TODO: Use connection pooling with r2d2 and r2d2_sqlite?
/// SQLite-backed [`Database`] implementation that keeps one table per registered
/// repository item type.
///
/// Every operation goes through a single `rusqlite::Connection` guarded by an async
/// [`Mutex`], so database access is serialized. Cloning is cheap: all clones share the same
/// connection through the [`Arc`].
#[derive(Clone)]
pub struct SqliteDatabase(Arc<Mutex<rusqlite::Connection>>);
16
17fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
18    if validate_registry_name(name) {
19        Ok(name)
20    } else {
21        Err(DatabaseError::Internal(
22            rusqlite::Error::InvalidParameterName(name.to_string()),
23        ))
24    }
25}
26
27impl SqliteDatabase {
28    fn initialize_internal(
29        mut db: rusqlite::Connection,
30        migrations: RepositoryMigrations,
31    ) -> Result<Self, DatabaseError> {
32        // Set WAL mode for better concurrency
33        db.pragma_update(None, "journal_mode", "WAL")?;
34
35        let transaction = db.transaction()?;
36
37        for step in &migrations.steps {
38            match step {
39                RepositoryMigrationStep::Add(data) => {
40                    // SAFETY: SQLite tables cannot use ?, but `reg.name()` is not user controlled
41                    // and is validated to only contain valid characters, so
42                    // it's safe to interpolate here.
43                    transaction.execute(
44                        &format!(
45                            "CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
46                            validate_identifier(data.name())?,
47                        ),
48                        [],
49                    )?;
50                }
51                RepositoryMigrationStep::Remove(data) => {
52                    // SAFETY: SQLite tables cannot use ?, but `reg.name()` is not user controlled
53                    // and is validated to only contain valid characters, so
54                    // it's safe to interpolate here.
55                    transaction.execute(
56                        &format!(
57                            "DROP TABLE IF EXISTS \"{}\";",
58                            validate_identifier(data.name())?,
59                        ),
60                        [],
61                    )?;
62                }
63            }
64        }
65
66        transaction.commit()?;
67        Ok(SqliteDatabase(Arc::new(Mutex::new(db))))
68    }
69}
70
71impl Database for SqliteDatabase {
72    async fn initialize(
73        configuration: DatabaseConfiguration,
74        migrations: RepositoryMigrations,
75    ) -> Result<Self, DatabaseError> {
76        let DatabaseConfiguration::Sqlite {
77            db_name,
78            folder_path: mut path,
79        } = configuration
80        else {
81            return Err(DatabaseError::UnsupportedConfiguration(configuration));
82        };
83        path.push(format!("{db_name}.sqlite"));
84
85        let db = rusqlite::Connection::open(path)?;
86        Self::initialize_internal(db, migrations)
87    }
88
89    async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
90        &self,
91        key: &str,
92    ) -> Result<Option<T>, DatabaseError> {
93        let conn = self.0.lock().await;
94
95        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
96        // validated to only contain valid characters, so it's safe to interpolate here.
97        let mut stmt = conn.prepare(&format!(
98            "SELECT value FROM \"{}\" WHERE key = ?1",
99            validate_identifier(T::NAME)?
100        ))?;
101        let mut rows = stmt.query([key])?;
102
103        if let Some(row) = rows.next()? {
104            let value = row.get::<_, String>(0)?;
105
106            Ok(Some(serde_json::from_str(&value)?))
107        } else {
108            Ok(None)
109        }
110    }
111
112    async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
113        &self,
114    ) -> Result<Vec<T>, DatabaseError> {
115        let conn = self.0.lock().await;
116
117        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
118        // validated to only contain valid characters, so it's safe to interpolate here.
119        let mut stmt = conn.prepare(&format!(
120            "SELECT key, value FROM \"{}\"",
121            validate_identifier(T::NAME)?
122        ))?;
123        let rows = stmt.query_map([], |row| row.get(1))?;
124
125        let mut results = Vec::new();
126        for row in rows {
127            let value: String = row?;
128            let value: T = serde_json::from_str(&value)?;
129            results.push(value);
130        }
131
132        Ok(results)
133    }
134
135    async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
136        &self,
137        key: &str,
138        value: T,
139    ) -> Result<(), DatabaseError> {
140        let mut conn = self.0.lock().await;
141        let transaction = conn.transaction()?;
142
143        let value = serde_json::to_string(&value)?;
144
145        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
146        // validated to only contain valid characters, so it's safe to interpolate here.
147        transaction.execute(
148            &format!(
149                "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
150                validate_identifier(T::NAME)?,
151            ),
152            [key, &value],
153        )?;
154
155        transaction.commit()?;
156        Ok(())
157    }
158
159    async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
160        &self,
161        values: Vec<(String, T)>,
162    ) -> Result<(), DatabaseError> {
163        let mut conn = self.0.lock().await;
164        let transaction = conn.transaction()?;
165
166        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
167        // validated to only contain valid characters, so it's safe to interpolate here.
168        let sql = format!(
169            "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
170            validate_identifier(T::NAME)?,
171        );
172        for (key, value) in values {
173            let value = serde_json::to_string(&value)?;
174            transaction.execute(&sql, [&key, &value])?;
175        }
176
177        transaction.commit()?;
178        Ok(())
179    }
180
181    async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
182        &self,
183        key: &str,
184    ) -> Result<(), DatabaseError> {
185        let mut conn = self.0.lock().await;
186        let transaction = conn.transaction()?;
187
188        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
189        // validated to only contain valid characters, so it's safe to interpolate here.
190        transaction.execute(
191            &format!(
192                "DELETE FROM \"{}\" WHERE key = ?1",
193                validate_identifier(T::NAME)?
194            ),
195            [key],
196        )?;
197
198        transaction.commit()?;
199        Ok(())
200    }
201
202    async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
203        &self,
204        keys: Vec<String>,
205    ) -> Result<(), DatabaseError> {
206        let mut conn = self.0.lock().await;
207        let transaction = conn.transaction()?;
208
209        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
210        // validated to only contain valid characters, so it's safe to interpolate here.
211        let sql = format!(
212            "DELETE FROM \"{}\" WHERE key = ?1",
213            validate_identifier(T::NAME)?
214        );
215        for key in keys {
216            transaction.execute(&sql, [&key])?;
217        }
218
219        transaction.commit()?;
220        Ok(())
221    }
222
223    async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
224        &self,
225    ) -> Result<(), DatabaseError> {
226        let mut conn = self.0.lock().await;
227        let transaction = conn.transaction()?;
228
229        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
230        // validated to only contain valid characters, so it's safe to interpolate here.
231        transaction.execute(
232            &format!("DELETE FROM \"{}\"", validate_identifier(T::NAME)?),
233            [],
234        )?;
235
236        transaction.commit()?;
237        Ok(())
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    #[tokio::test]
    async fn test_sqlite_integration() {
        let db = rusqlite::Connection::open_in_memory().unwrap();

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(String => TestA, "TestItem_A");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestB(usize);
        register_repository_item!(String => TestB, "TestItem_B");

        let steps = vec![
            // Test that deleting a table that doesn't exist is fine
            RepositoryMigrationStep::Remove(TestB::data()),
            RepositoryMigrationStep::Add(TestA::data()),
            RepositoryMigrationStep::Add(TestB::data()),
            // Test that deleting a table that does exist is also fine
            RepositoryMigrationStep::Remove(TestB::data()),
        ];
        let migrations = RepositoryMigrations::new(steps);

        let db = SqliteDatabase::initialize_internal(db, migrations).unwrap();

        // The final `Remove` step must actually have dropped TestB's table.
        assert!(db.list::<TestB>().await.is_err());

        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));

        db.remove::<TestA>("key1").await.unwrap();

        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);

        // Bulk insert, then overwrite one of the keys with a single `set`.
        db.set_bulk(vec![
            ("a".to_string(), TestA(1)),
            ("b".to_string(), TestA(2)),
            ("c".to_string(), TestA(3)),
        ])
        .await
        .unwrap();
        db.set("b", TestA(20)).await.unwrap();

        // `list` has no ORDER BY, so compare contents without relying on order.
        let listed = db.list::<TestA>().await.unwrap();
        assert_eq!(listed.len(), 3);
        for item in [TestA(1), TestA(20), TestA(3)] {
            assert!(listed.contains(&item), "missing {item:?} in {listed:?}");
        }

        db.remove_bulk::<TestA>(vec!["a".to_string(), "c".to_string()])
            .await
            .unwrap();
        assert_eq!(db.list::<TestA>().await.unwrap(), vec![TestA(20)]);

        // Empty bulk operations are no-ops.
        db.set_bulk::<TestA>(Vec::new()).await.unwrap();
        db.remove_bulk::<TestA>(Vec::new()).await.unwrap();
        assert_eq!(db.list::<TestA>().await.unwrap(), vec![TestA(20)]);

        db.remove_all::<TestA>().await.unwrap();
        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());
    }

    #[tokio::test]
    async fn test_sqlite_database_path_construction() {
        // Use a per-process directory so concurrent test runs don't clobber each other.
        let temp_dir =
            std::env::temp_dir().join(format!("bitwarden_state_test_{}", std::process::id()));
        std::fs::create_dir_all(&temp_dir).unwrap();

        let config = DatabaseConfiguration::Sqlite {
            db_name: "test_db".to_string(),
            folder_path: temp_dir.clone(),
        };

        SqliteDatabase::initialize(config, RepositoryMigrations::new(vec![]))
            .await
            .unwrap();

        assert!(temp_dir.join("test_db.sqlite").exists());

        std::fs::remove_dir_all(&temp_dir).ok();
    }
}