bitwarden_state/sdk_managed/
sqlite.rs1use std::sync::Arc;
2
3use serde::{de::DeserializeOwned, ser::Serialize};
4use tokio::sync::Mutex;
5
6use crate::{
7 repository::{validate_registry_name, RepositoryItem, RepositoryItemData},
8 sdk_managed::{Database, DatabaseConfiguration, DatabaseError},
9};
10
11#[derive(Clone)]
13pub struct SqliteDatabase(Arc<Mutex<rusqlite::Connection>>);
14
15fn validate_identifier(name: &'static str) -> Result<&'static str, DatabaseError> {
16 if validate_registry_name(name) {
17 Ok(name)
18 } else {
19 Err(DatabaseError::Internal(
20 rusqlite::Error::InvalidParameterName(name.to_string()),
21 ))
22 }
23}
24
25impl SqliteDatabase {
26 fn initialize_internal(
27 mut db: rusqlite::Connection,
28 registrations: &[RepositoryItemData],
29 ) -> Result<Self, DatabaseError> {
30 db.pragma_update(None, "journal_mode", "WAL")?;
32
33 let transaction = db.transaction()?;
34
35 for reg in registrations {
36 transaction.execute(
40 &format!(
41 "CREATE TABLE IF NOT EXISTS \"{}\" (key TEXT PRIMARY KEY, value TEXT NOT NULL);",
42 validate_identifier(reg.name())?,
43 ),
44 [],
45 )?;
46 }
47
48 transaction.commit()?;
49 Ok(SqliteDatabase(Arc::new(Mutex::new(db))))
50 }
51}
52
53impl Database for SqliteDatabase {
54 async fn initialize(
55 configuration: DatabaseConfiguration,
56 registrations: &[RepositoryItemData],
57 ) -> Result<Self, DatabaseError> {
58 let DatabaseConfiguration::Sqlite {
59 db_name,
60 folder_path: mut path,
61 } = configuration
62 else {
63 return Err(DatabaseError::UnsupportedConfiguration(configuration));
64 };
65 path.set_file_name(format!("{db_name}.sqlite"));
66
67 let db = rusqlite::Connection::open(path)?;
68 Self::initialize_internal(db, registrations)
69 }
70
71 async fn get<T: Serialize + DeserializeOwned + RepositoryItem>(
72 &self,
73 key: &str,
74 ) -> Result<Option<T>, DatabaseError> {
75 let conn = self.0.lock().await;
76
77 let mut stmt = conn.prepare(&format!(
80 "SELECT value FROM \"{}\" WHERE key = ?1",
81 validate_identifier(T::NAME)?
82 ))?;
83 let mut rows = stmt.query([key])?;
84
85 if let Some(row) = rows.next()? {
86 let value = row.get::<_, String>(0)?;
87
88 Ok(Some(serde_json::from_str(&value)?))
89 } else {
90 Ok(None)
91 }
92 }
93
94 async fn list<T: Serialize + DeserializeOwned + RepositoryItem>(
95 &self,
96 ) -> Result<Vec<T>, DatabaseError> {
97 let conn = self.0.lock().await;
98
99 let mut stmt = conn.prepare(&format!(
102 "SELECT key, value FROM \"{}\"",
103 validate_identifier(T::NAME)?
104 ))?;
105 let rows = stmt.query_map([], |row| row.get(1))?;
106
107 let mut results = Vec::new();
108 for row in rows {
109 let value: String = row?;
110 let value: T = serde_json::from_str(&value)?;
111 results.push(value);
112 }
113
114 Ok(results)
115 }
116
117 async fn set<T: Serialize + DeserializeOwned + RepositoryItem>(
118 &self,
119 key: &str,
120 value: T,
121 ) -> Result<(), DatabaseError> {
122 let mut conn = self.0.lock().await;
123 let transaction = conn.transaction()?;
124
125 let value = serde_json::to_string(&value)?;
126
127 transaction.execute(
130 &format!(
131 "INSERT OR REPLACE INTO \"{}\" (key, value) VALUES (?1, ?2)",
132 validate_identifier(T::NAME)?,
133 ),
134 [key, &value],
135 )?;
136
137 transaction.commit()?;
138 Ok(())
139 }
140
141 async fn remove<T: Serialize + DeserializeOwned + RepositoryItem>(
142 &self,
143 key: &str,
144 ) -> Result<(), DatabaseError> {
145 let mut conn = self.0.lock().await;
146 let transaction = conn.transaction()?;
147
148 transaction.execute(
151 &format!(
152 "DELETE FROM \"{}\" WHERE key = ?1",
153 validate_identifier(T::NAME)?
154 ),
155 [key],
156 )?;
157
158 transaction.commit()?;
159 Ok(())
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::register_repository_item;

    /// End-to-end round trip against an in-memory SQLite database:
    /// table creation, insert, lookup, listing, overwrite, and deletion.
    #[tokio::test]
    async fn test_sqlite_integration() {
        let conn = rusqlite::Connection::open_in_memory().unwrap();

        #[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct TestA(usize);
        register_repository_item!(TestA, "TestItem_A");

        let registrations = [TestA::data()];

        let db = SqliteDatabase::initialize_internal(conn, &registrations).unwrap();

        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());

        db.set("key1", TestA(42)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(42)));
        assert_eq!(db.list::<TestA>().await.unwrap(), vec![TestA(42)]);

        // `set` uses INSERT OR REPLACE, so writing the same key overwrites it.
        db.set("key1", TestA(7)).await.unwrap();
        assert_eq!(db.get::<TestA>("key1").await.unwrap(), Some(TestA(7)));

        db.remove::<TestA>("key1").await.unwrap();

        assert_eq!(db.get::<TestA>("key1").await.unwrap(), None);
        assert_eq!(db.list::<TestA>().await.unwrap(), Vec::<TestA>::new());
    }
}