bitwarden_state/sdk_managed/sqlite.rs

1use std::sync::Arc;
2
3use serde::{de::DeserializeOwned, ser::Serialize};
4use tokio::sync::Mutex;
5
6use crate::{
7    repository::{validate_registry_name, RepositoryItem, RepositoryItemData},
8    sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
9};
10
// TODO: Use connection pooling with r2d2 and r2d2_sqlite?
/// SQLite-backed implementation of the SDK-managed [`Database`] trait.
///
/// Wraps a single `rusqlite::Connection` in an `Arc<tokio::sync::Mutex<_>>`. Cloning a
/// `SqliteDatabase` therefore shares the same underlying connection. Each database operation
/// in this file acquires the mutex first, so operations on clones are serialized.
#[derive(Clone)]
pub struct SqliteDatabase(Arc<Mutex<rusqlite::Connection>>);
14
15fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
16    if validate_registry_name(name) {
17        Ok(name)
18    } else {
19        Err(DatabaseError::Internal(
20            rusqlite::Error::InvalidParameterName(name.to_string()),
21        ))
22    }
23}
24
25impl SqliteDatabase {
26    fn initialize_internal(
27        mut db: rusqlite::Connection,
28        registrations: &[RepositoryItemData],
29    ) -> Result<Self, DatabaseError> {
30        // Set WAL mode for better concurrency
31        db.pragma_update(None, "journal_mode", "WAL")?;
32
33        let transaction = db.transaction()?;
34
35        for reg in registrations {
36            // SAFETY: SQLite tables cannot use ?, but `reg.name()` is not user controlled and
37            // is validated to only contain valid characters, so it's safe to
38            // interpolate here.
39            transaction.execute(
40                &format!(
41                    "CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
42                    validate_identifier(reg.name())?,
43                ),
44                [],
45            )?;
46        }
47
48        transaction.commit()?;
49        Ok(SqliteDatabase(Arc::new(Mutex::new(db))))
50    }
51}
52
53impl Database for SqliteDatabase {
54    async fn initialize(
55        configuration: DatabaseConfiguration,
56        registrations: &[RepositoryItemData],
57    ) -> Result<Self, DatabaseError> {
58        let DatabaseConfiguration::Sqlite {
59            db_name,
60            folder_path: mut path,
61        } = configuration
62        else {
63            return Err(DatabaseError::UnsupportedConfiguration(configuration));
64        };
65        path.set_file_name(format!("{db_name}.sqlite"));
66
67        let db = rusqlite::Connection::open(path)?;
68        Self::initialize_internal(db, registrations)
69    }
70
71    async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
72        &self,
73        key: &str,
74    ) -> Result<Option<T>, DatabaseError> {
75        let conn = self.0.lock().await;
76
77        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
78        // validated to only contain valid characters, so it's safe to interpolate here.
79        let mut stmt = conn.prepare(&format!(
80            "SELECT value FROM \"{}\" WHERE key = ?1",
81            validate_identifier(T::NAME)?
82        ))?;
83        let mut rows = stmt.query([key])?;
84
85        if let Some(row) = rows.next()? {
86            let value = row.get::<_, String>(0)?;
87
88            Ok(Some(serde_json::from_str(&value)?))
89        } else {
90            Ok(None)
91        }
92    }
93
94    async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
95        &self,
96    ) -> Result<Vec<T>, DatabaseError> {
97        let conn = self.0.lock().await;
98
99        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
100        // validated to only contain valid characters, so it's safe to interpolate here.
101        let mut stmt = conn.prepare(&format!(
102            "SELECT key, value FROM \"{}\"",
103            validate_identifier(T::NAME)?
104        ))?;
105        let rows = stmt.query_map([], |row| row.get(1))?;
106
107        let mut results = Vec::new();
108        for row in rows {
109            let value: String = row?;
110            let value: T = serde_json::from_str(&value)?;
111            results.push(value);
112        }
113
114        Ok(results)
115    }
116
117    async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
118        &self,
119        key: &str,
120        value: T,
121    ) -> Result<(), DatabaseError> {
122        let mut conn = self.0.lock().await;
123        let transaction = conn.transaction()?;
124
125        let value = serde_json::to_string(&value)?;
126
127        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
128        // validated to only contain valid characters, so it's safe to interpolate here.
129        transaction.execute(
130            &format!(
131                "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
132                validate_identifier(T::NAME)?,
133            ),
134            [key, &value],
135        )?;
136
137        transaction.commit()?;
138        Ok(())
139    }
140
141    async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
142        &self,
143        key: &str,
144    ) -> Result<(), DatabaseError> {
145        let mut conn = self.0.lock().await;
146        let transaction = conn.transaction()?;
147
148        // SAFETY: SQLite tables cannot use ?, but `T::NAME` is not user controlled and is
149        // validated to only contain valid characters, so it's safe to interpolate here.
150        transaction.execute(
151            &format!(
152                "DELETE FROM \"{}\" WHERE key = ?1",
153                validate_identifier(T::NAME)?
154            ),
155            [key],
156        )?;
157
158        transaction.commit()?;
159        Ok(())
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    /// Opens a fresh in-memory SQLite database with tables for `registrations`.
    fn open_in_memory(registrations: &[RepositoryItemData]) -> SqliteDatabase {
        let db = rusqlite::Connection::open_in_memory().unwrap();
        SqliteDatabase::initialize_internal(db, registrations).unwrap()
    }

    #[tokio::test]
    async fn test_sqlite_integration() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(TestA, "TestItem_A");

        let db = open_in_memory(&[TestA::data()]);

        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));

        db.remove::<TestA>("key1").await.unwrap();

        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_sqlite_set_overwrites_and_list_returns_all() {
        #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(TestA, "TestItem_A");

        let db = open_in_memory(&[TestA::data()]);

        db.set("key1", TestA(1)).await.unwrap();
        db.set("key2", TestA(2)).await.unwrap();
        // Writing an existing key replaces the previous value.
        db.set("key1", TestA(3)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(3)));

        // `list` does not guarantee an order, so compare sorted results.
        let mut all = db.list::<TestA>().await.unwrap();
        all.sort();
        assert_eq!(all, vec![TestA(2), TestA(3)]);

        // Removing a key that was never stored succeeds and leaves other rows untouched.
        db.remove::<TestA>("missing").await.unwrap();
        assert_eq!(db.list::<TestA>().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_sqlite_tables_are_isolated() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(TestA, "TestItem_A");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestB(usize);
        register_repository_item!(TestB, "TestItem_B");

        let db = open_in_memory(&[TestA::data(), TestB::data()]);

        db.set("shared", TestA(7)).await.unwrap();

        // The same key in a different item type's table is unaffected.
        assert_eq!(db.get::<TestB>("shared").await.unwrap(), None);
        assert_eq!(db.list::<TestB>().await.unwrap(), Vec::<TestB>::new());
        assert_eq!(db.get::<TestA>("shared").await.unwrap(), Some(TestA(7)));
    }
}
189}