bitwarden_state/sdk_managed/
sqlite.rs1use std::sync::Arc;
2
3use serde::{de::DeserializeOwned, ser::Serialize};
4use tokio::sync::Mutex;
5
6use crate::{
7 repository::{
8 RepositoryItem, RepositoryMigrationStep, RepositoryMigrations, validate_registry_name,
9 },
10 sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
11};
12
13#[derive(Clone)]
15pub struct SqliteDatabase(Arc<Mutex<rusqlite::Connection>>);
16
17fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
18 if validate_registry_name(name) {
19 Ok(name)
20 } else {
21 Err(DatabaseError::Internal(
22 rusqlite::Error::InvalidParameterName(name.to_string()),
23 ))
24 }
25}
26
27impl SqliteDatabase {
28 fn initialize_internal(
29 mut db: rusqlite::Connection,
30 migrations: RepositoryMigrations,
31 ) -> Result<Self, DatabaseError> {
32 db.pragma_update(None, "journal_mode", "WAL")?;
34
35 let transaction = db.transaction()?;
36
37 for step in &migrations.steps {
38 match step {
39 RepositoryMigrationStep::Add(data) => {
40 transaction.execute(
44 &format!(
45 "CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
46 validate_identifier(data.name())?,
47 ),
48 [],
49 )?;
50 }
51 RepositoryMigrationStep::Remove(data) => {
52 transaction.execute(
56 &format!(
57 "DROP TABLE IF EXISTS \"{}\";",
58 validate_identifier(data.name())?,
59 ),
60 [],
61 )?;
62 }
63 }
64 }
65
66 transaction.commit()?;
67 Ok(SqliteDatabase(Arc::new(Mutex::new(db))))
68 }
69}
70
71impl Database for SqliteDatabase {
72 async fn initialize(
73 configuration: DatabaseConfiguration,
74 migrations: RepositoryMigrations,
75 ) -> Result<Self, DatabaseError> {
76 let DatabaseConfiguration::Sqlite {
77 db_name,
78 folder_path: mut path,
79 } = configuration
80 else {
81 return Err(DatabaseError::UnsupportedConfiguration(configuration));
82 };
83 path.push(format!("{db_name}.sqlite"));
84
85 let db = rusqlite::Connection::open(path)?;
86 Self::initialize_internal(db, migrations)
87 }
88
89 async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
90 &self,
91 key: &str,
92 ) -> Result<Option<T>, DatabaseError> {
93 let conn = self.0.lock().await;
94
95 let mut stmt = conn.prepare(&format!(
98 "SELECT value FROM \"{}\" WHERE key = ?1",
99 validate_identifier(T::NAME)?
100 ))?;
101 let mut rows = stmt.query([key])?;
102
103 if let Some(row) = rows.next()? {
104 let value = row.get::<_, String>(0)?;
105
106 Ok(Some(serde_json::from_str(&value)?))
107 } else {
108 Ok(None)
109 }
110 }
111
112 async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
113 &self,
114 ) -> Result<Vec<T>, DatabaseError> {
115 let conn = self.0.lock().await;
116
117 let mut stmt = conn.prepare(&format!(
120 "SELECT key, value FROM \"{}\"",
121 validate_identifier(T::NAME)?
122 ))?;
123 let rows = stmt.query_map([], |row| row.get(1))?;
124
125 let mut results = Vec::new();
126 for row in rows {
127 let value: String = row?;
128 let value: T = serde_json::from_str(&value)?;
129 results.push(value);
130 }
131
132 Ok(results)
133 }
134
135 async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
136 &self,
137 key: &str,
138 value: T,
139 ) -> Result<(), DatabaseError> {
140 let mut conn = self.0.lock().await;
141 let transaction = conn.transaction()?;
142
143 let value = serde_json::to_string(&value)?;
144
145 transaction.execute(
148 &format!(
149 "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
150 validate_identifier(T::NAME)?,
151 ),
152 [key, &value],
153 )?;
154
155 transaction.commit()?;
156 Ok(())
157 }
158
159 async fn set_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
160 &self,
161 values: Vec<(String, T)>,
162 ) -> Result<(), DatabaseError> {
163 let mut conn = self.0.lock().await;
164 let transaction = conn.transaction()?;
165
166 let sql = format!(
169 "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
170 validate_identifier(T::NAME)?,
171 );
172 for (key, value) in values {
173 let value = serde_json::to_string(&value)?;
174 transaction.execute(&sql, [&key, &value])?;
175 }
176
177 transaction.commit()?;
178 Ok(())
179 }
180
181 async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
182 &self,
183 key: &str,
184 ) -> Result<(), DatabaseError> {
185 let mut conn = self.0.lock().await;
186 let transaction = conn.transaction()?;
187
188 transaction.execute(
191 &format!(
192 "DELETE FROM \"{}\" WHERE key = ?1",
193 validate_identifier(T::NAME)?
194 ),
195 [key],
196 )?;
197
198 transaction.commit()?;
199 Ok(())
200 }
201
202 async fn remove_bulk<T: Serialize + DeserializeOwned + RepositoryItem>(
203 &self,
204 keys: Vec<String>,
205 ) -> Result<(), DatabaseError> {
206 let mut conn = self.0.lock().await;
207 let transaction = conn.transaction()?;
208
209 let sql = format!(
212 "DELETE FROM \"{}\" WHERE key = ?1",
213 validate_identifier(T::NAME)?
214 );
215 for key in keys {
216 transaction.execute(&sql, [&key])?;
217 }
218
219 transaction.commit()?;
220 Ok(())
221 }
222
223 async fn remove_all<T: Serialize + DeserializeOwned + RepositoryItem>(
224 &self,
225 ) -> Result<(), DatabaseError> {
226 let mut conn = self.0.lock().await;
227 let transaction = conn.transaction()?;
228
229 transaction.execute(
232 &format!("DELETE FROM \"{}\"", validate_identifier(T::NAME)?),
233 [],
234 )?;
235
236 transaction.commit()?;
237 Ok(())
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    #[tokio::test]
    async fn test_sqlite_integration() {
        let db = rusqlite::Connection::open_in_memory().unwrap();

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(String => TestA, "TestItem_A");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestB(usize);
        register_repository_item!(String => TestB, "TestItem_B");

        // Removing a table that doesn't exist yet must be a no-op, and a table that is added and
        // then removed within the same migration must not survive it.
        let steps = vec![
            RepositoryMigrationStep::Remove(TestB::data()),
            RepositoryMigrationStep::Add(TestA::data()),
            RepositoryMigrationStep::Add(TestB::data()),
            RepositoryMigrationStep::Remove(TestB::data()),
        ];
        let migrations = RepositoryMigrations::new(steps);

        let db = SqliteDatabase::initialize_internal(db, migrations).unwrap();

        assert!(
            db.list::<TestB>().await.is_err(),
            "TestB's table should have been dropped by the migration"
        );
        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));

        db.remove::<TestA>("key1").await.unwrap();

        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_sqlite_bulk_operations() {
        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestBulk(usize);
        register_repository_item!(String => TestBulk, "TestItem_Bulk");

        let conn = rusqlite::Connection::open_in_memory().unwrap();
        let migrations =
            RepositoryMigrations::new(vec![RepositoryMigrationStep::Add(TestBulk::data())]);
        let db = SqliteDatabase::initialize_internal(conn, migrations).unwrap();

        // Empty batches are accepted and leave the table untouched.
        db.set_bulk::<TestBulk>(Vec::new()).await.unwrap();
        db.remove_bulk::<TestBulk>(Vec::new()).await.unwrap();
        assert!(db.list::<TestBulk>().await.unwrap().is_empty());

        db.set_bulk(vec![
            ("a".to_string(), TestBulk(1)),
            ("b".to_string(), TestBulk(2)),
            ("c".to_string(), TestBulk(3)),
        ])
        .await
        .unwrap();

        // `list` has no ORDER BY, so compare independently of row order.
        let mut all = db.list::<TestBulk>().await.unwrap();
        all.sort_by_key(|item| item.0);
        assert_eq!(all, vec![TestBulk(1), TestBulk(2), TestBulk(3)]);

        db.remove_bulk::<TestBulk>(vec!["a".to_string(), "c".to_string()])
            .await
            .unwrap();
        assert_eq!(db.list::<TestBulk>().await.unwrap(), vec![TestBulk(2)]);

        db.remove_all::<TestBulk>().await.unwrap();
        assert!(db.list::<TestBulk>().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_sqlite_database_path_construction() {
        // Include the process id so concurrent test runs don't share, and then delete, each
        // other's directory.
        let temp_dir = std::env::temp_dir()
            .join(format!("bitwarden_state_test_{}", std::process::id()));
        std::fs::create_dir_all(&temp_dir).unwrap();

        let config = DatabaseConfiguration::Sqlite {
            db_name: "test_db".to_string(),
            folder_path: temp_dir.clone(),
        };

        SqliteDatabase::initialize(config, RepositoryMigrations::new(vec![]))
            .await
            .unwrap();

        assert!(temp_dir.join("test_db.sqlite").exists());

        std::fs::remove_dir_all(&temp_dir).ok();
    }
}
298}