bitwarden_state/sdk_managed/
sqlite.rs

1use std::sync::Arc;
2
3use serde::{de::DeserializeOwned, ser::Serialize};
4use tokio::sync::Mutex;
5
6use crate::{
7    repository::{
8        RepositoryItem, RepositoryMigrationStep, RepositoryMigrations, validate_registry_name,
9    },
10    sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
11};
12
13// TODO: Use connection pooling with r2d2 and r2d2_sqlite?
14#[derive(Clone)]
15pub struct SqliteDatabase(Arc<Mutex<rusqlite::Connection>>);
16
17fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
18    if validate_registry_name(name) {
19        Ok(name)
20    } else {
21        Err(DatabaseError::Internal(
22            rusqlite::Error::InvalidParameterName(name.to_string()),
23        ))
24    }
25}
26
27impl SqliteDatabase {
28    fn initialize_internal(
29        mut db: rusqlite::Connection,
30        migrations: RepositoryMigrations,
31    ) -> Result<Self, DatabaseError> {
32        // Set WAL mode for better concurrency
33        db.pragma_update(None, "journal_mode", "WAL")?;
34
35        let transaction = db.transaction()?;
36
37        for step in &migrations.steps {
38            match step {
39                RepositoryMigrationStep::Add(data) => {
40                    // SAFETY: SQLite tables cannot use ?, but `reg.name()` is not user controlled
41                    // and is validated to only contain valid characters, so
42                    // it's safe to interpolate here.
43                    transaction.execute(
44                        &format!(
45                            "CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
46                            validate_identifier(data.name())?,
47                        ),
48                        [],
49                    )?;
50                }
51                RepositoryMigrationStep::Remove(data) => {
52                    // SAFETY: SQLite tables cannot use ?, but `reg.name()` is not user controlled
53                    // and is validated to only contain valid characters, so
54                    // it's safe to interpolate here.
55                    transaction.execute(
56                        &format!(
57                            "DROP TABLE IF EXISTS \"{}\";",
58                            validate_identifier(data.name())?,
59                        ),
60                        [],
61                    )?;
62                }
63            }
64        }
65
66        transaction.commit()?;
67        Ok(SqliteDatabase(Arc::new(Mutex::new(db))))
68    }
69}
70
71impl Database for SqliteDatabase {
72    async fn initialize(
73        configuration: DatabaseConfiguration,
74        migrations: RepositoryMigrations,
75    ) -> Result<Self, DatabaseError> {
76        let DatabaseConfiguration::Sqlite {
77            db_name,
78            folder_path: mut path,
79        } = configuration
80        else {
81            return Err(DatabaseError::UnsupportedConfiguration(configuration));
82        };
83        path.set_file_name(format!("{db_name}.sqlite"));
84
85        let db = rusqlite::Connection::open(path)?;
86        Self::initialize_internal(db, migrations)
87    }
88
89    async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
90        &self,
91        key: &str,
92    ) -> Result<Option<T>, DatabaseError> {
93        let conn = self.0.lock().await;
94
95        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
96        // validated to only contain valid characters, so it's safe to interpolate here.
97        let mut stmt = conn.prepare(&format!(
98            "SELECT value FROM \"{}\" WHERE key = ?1",
99            validate_identifier(T::NAME)?
100        ))?;
101        let mut rows = stmt.query([key])?;
102
103        if let Some(row) = rows.next()? {
104            let value = row.get::<_, String>(0)?;
105
106            Ok(Some(serde_json::from_str(&value)?))
107        } else {
108            Ok(None)
109        }
110    }
111
112    async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
113        &self,
114    ) -> Result<Vec<T>, DatabaseError> {
115        let conn = self.0.lock().await;
116
117        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
118        // validated to only contain valid characters, so it's safe to interpolate here.
119        let mut stmt = conn.prepare(&format!(
120            "SELECT key, value FROM \"{}\"",
121            validate_identifier(T::NAME)?
122        ))?;
123        let rows = stmt.query_map([], |row| row.get(1))?;
124
125        let mut results = Vec::new();
126        for row in rows {
127            let value: String = row?;
128            let value: T = serde_json::from_str(&value)?;
129            results.push(value);
130        }
131
132        Ok(results)
133    }
134
135    async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
136        &self,
137        key: &str,
138        value: T,
139    ) -> Result<(), DatabaseError> {
140        let mut conn = self.0.lock().await;
141        let transaction = conn.transaction()?;
142
143        let value = serde_json::to_string(&value)?;
144
145        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
146        // validated to only contain valid characters, so it's safe to interpolate here.
147        transaction.execute(
148            &format!(
149                "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
150                validate_identifier(T::NAME)?,
151            ),
152            [key, &value],
153        )?;
154
155        transaction.commit()?;
156        Ok(())
157    }
158
159    async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
160        &self,
161        key: &str,
162    ) -> Result<(), DatabaseError> {
163        let mut conn = self.0.lock().await;
164        let transaction = conn.transaction()?;
165
166        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
167        // validated to only contain valid characters, so it's safe to interpolate here.
168        transaction.execute(
169            &format!(
170                "DELETE FROM \"{}\" WHERE key = ?1",
171                validate_identifier(T::NAME)?
172            ),
173            [key],
174        )?;
175
176        transaction.commit()?;
177        Ok(())
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    /// End-to-end check of migrations plus get/set/list/remove on an in-memory database.
    #[tokio::test]
    async fn test_sqlite_integration() {
        let db = rusqlite::Connection::open_in_memory().unwrap();

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(TestA, "TestItem_A");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestB(usize);
        register_repository_item!(TestB, "TestItem_B");

        let steps = vec![
            // Test that deleting a table that doesn't exist is fine
            RepositoryMigrationStep::Remove(TestB::data()),
            RepositoryMigrationStep::Add(TestA::data()),
            RepositoryMigrationStep::Add(TestB::data()),
            // Test that deleting a table that does exist is also fine
            RepositoryMigrationStep::Remove(TestB::data()),
        ];
        let migrations = RepositoryMigrations::new(steps);

        let db = SqliteDatabase::initialize_internal(db, migrations).unwrap();

        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));

        db.remove::<TestA>("key1").await.unwrap();

        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);

        // `set` replaces an existing value, and `list` returns every stored item.
        db.set("key1", TestA(1)).await.unwrap();
        db.set("key2", TestA(2)).await.unwrap();
        db.set("key1", TestA(3)).await.unwrap();
        let mut items = db.list::<TestA>().await.unwrap();
        // Row order from a SELECT without ORDER BY is unspecified, so compare sorted.
        items.sort_by_key(|item| item.0);
        assert_eq!(items, vec![TestA(2), TestA(3)]);

        // The final migration step dropped TestB's table, so querying it must fail.
        assert!(db.get::<TestB>("key1").await.is_err());
    }
}