bitwarden_state/sdk_managed/
sqlite.rs1use std::sync::Arc;
2
3use serde::{de::DeserializeOwned, ser::Serialize};
4use tokio::sync::Mutex;
5
6use crate::{
7 repository::{
8 RepositoryItem, RepositoryMigrationStep, RepositoryMigrations, validate_registry_name,
9 },
10 sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
11};
12
/// SQLite-backed implementation of [`Database`].
///
/// Holds a single `rusqlite::Connection` behind an `Arc<tokio::sync::Mutex<_>>`: cloning a
/// `SqliteDatabase` shares the same underlying connection rather than opening a new one, and
/// access to the connection is serialized through the async mutex.
#[derive(Clone)]
pub struct SqliteDatabase(Arc<Mutex<rusqlite::Connection>>);
16
17fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
18 if validate_registry_name(name) {
19 Ok(name)
20 } else {
21 Err(DatabaseError::Internal(
22 rusqlite::Error::InvalidParameterName(name.to_string()),
23 ))
24 }
25}
26
27impl SqliteDatabase {
28 fn initialize_internal(
29 mut db: rusqlite::Connection,
30 migrations: RepositoryMigrations,
31 ) -> Result<Self, DatabaseError> {
32 db.pragma_update(None, "journal_mode", "WAL")?;
34
35 let transaction = db.transaction()?;
36
37 for step in &migrations.steps {
38 match step {
39 RepositoryMigrationStep::Add(data) => {
40 transaction.execute(
44 &format!(
45 "CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
46 validate_identifier(data.name())?,
47 ),
48 [],
49 )?;
50 }
51 RepositoryMigrationStep::Remove(data) => {
52 transaction.execute(
56 &format!(
57 "DROP TABLE IF EXISTS \"{}\";",
58 validate_identifier(data.name())?,
59 ),
60 [],
61 )?;
62 }
63 }
64 }
65
66 transaction.commit()?;
67 Ok(SqliteDatabase(Arc::new(Mutex::new(db))))
68 }
69}
70
71impl Database for SqliteDatabase {
72 async fn initialize(
73 configuration: DatabaseConfiguration,
74 migrations: RepositoryMigrations,
75 ) -> Result<Self, DatabaseError> {
76 let DatabaseConfiguration::Sqlite {
77 db_name,
78 folder_path: mut path,
79 } = configuration
80 else {
81 return Err(DatabaseError::UnsupportedConfiguration(configuration));
82 };
83 path.set_file_name(format!("{db_name}.sqlite"));
84
85 let db = rusqlite::Connection::open(path)?;
86 Self::initialize_internal(db, migrations)
87 }
88
89 async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
90 &self,
91 key: &str,
92 ) -> Result<Option<T>, DatabaseError> {
93 let conn = self.0.lock().await;
94
95 let mut stmt = conn.prepare(&format!(
98 "SELECT value FROM \"{}\" WHERE key = ?1",
99 validate_identifier(T::NAME)?
100 ))?;
101 let mut rows = stmt.query([key])?;
102
103 if let Some(row) = rows.next()? {
104 let value = row.get::<_, String>(0)?;
105
106 Ok(Some(serde_json::from_str(&value)?))
107 } else {
108 Ok(None)
109 }
110 }
111
112 async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
113 &self,
114 ) -> Result<Vec<T>, DatabaseError> {
115 let conn = self.0.lock().await;
116
117 let mut stmt = conn.prepare(&format!(
120 "SELECT key, value FROM \"{}\"",
121 validate_identifier(T::NAME)?
122 ))?;
123 let rows = stmt.query_map([], |row| row.get(1))?;
124
125 let mut results = Vec::new();
126 for row in rows {
127 let value: String = row?;
128 let value: T = serde_json::from_str(&value)?;
129 results.push(value);
130 }
131
132 Ok(results)
133 }
134
135 async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
136 &self,
137 key: &str,
138 value: T,
139 ) -> Result<(), DatabaseError> {
140 let mut conn = self.0.lock().await;
141 let transaction = conn.transaction()?;
142
143 let value = serde_json::to_string(&value)?;
144
145 transaction.execute(
148 &format!(
149 "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
150 validate_identifier(T::NAME)?,
151 ),
152 [key, &value],
153 )?;
154
155 transaction.commit()?;
156 Ok(())
157 }
158
159 async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
160 &self,
161 key: &str,
162 ) -> Result<(), DatabaseError> {
163 let mut conn = self.0.lock().await;
164 let transaction = conn.transaction()?;
165
166 transaction.execute(
169 &format!(
170 "DELETE FROM \"{}\" WHERE key = ?1",
171 validate_identifier(T::NAME)?
172 ),
173 [key],
174 )?;
175
176 transaction.commit()?;
177 Ok(())
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    /// End-to-end check of migrations plus get/list/set/remove on an in-memory database.
    #[tokio::test]
    async fn test_sqlite_integration() {
        let conn = rusqlite::Connection::open_in_memory().unwrap();

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(TestA, "TestItem_A");

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestB(usize);
        register_repository_item!(TestB, "TestItem_B");

        // Removing a table that does not exist yet must be a no-op. The final `Remove` must
        // leave no `TestItem_B` table behind.
        let steps = vec![
            RepositoryMigrationStep::Remove(TestB::data()),
            RepositoryMigrationStep::Add(TestA::data()),
            RepositoryMigrationStep::Add(TestB::data()),
            RepositoryMigrationStep::Remove(TestB::data()),
        ];
        let migrations = RepositoryMigrations::new(steps);

        let db = SqliteDatabase::initialize_internal(conn, migrations).unwrap();

        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));
        assert_eq!(db.list::<TestA>().await.unwrap(), vec![TestA(42)]);

        // Setting an existing key replaces the stored value instead of adding a second row.
        db.set("key1", TestA(7)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(7)));
        assert_eq!(db.list::<TestA>().await.unwrap(), vec![TestA(7)]);

        assert_eq!(db.get::<TestA>("missing").await.unwrap(), None);

        db.remove::<TestA>("key1").await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);
        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        // The last migration step dropped TestB's table, so querying it must fail.
        assert!(db.get::<TestB>("key1").await.is_err());
    }
}