bitwarden_state/
repository.rs1use std::{any::TypeId, sync::Arc};
2
3use serde::{Serialize, de::DeserializeOwned};
4
5use crate::registry::StateRegistryError;
6
7#[derive(thiserror::Error, Debug)]
9pub enum RepositoryError {
10 #[error("Internal error: {0}")]
12 Internal(String),
13
14 #[error(transparent)]
16 Serde(#[from] serde_json::Error),
17
18 #[error(transparent)]
20 Database(#[from] crate::sdk_managed::DatabaseError),
21
22 #[error(transparent)]
24 StateRegistry(#[from] StateRegistryError),
25}
26
27pub trait RepositoryOption<V: RepositoryItem> {
30 fn require(&self) -> Result<&Arc<dyn Repository<V>>, RepositoryError>;
33}
34
35impl<V: RepositoryItem> RepositoryOption<V> for Option<Arc<dyn Repository<V>>> {
36 fn require(&self) -> Result<&Arc<dyn Repository<V>>, RepositoryError> {
37 self.as_ref()
38 .ok_or(StateRegistryError::DatabaseNotInitialized.into())
39 }
40}
41
42#[async_trait::async_trait]
45pub trait Repository<V: RepositoryItem>: Send + Sync {
46 async fn get(&self, key: V::Key) -> Result<Option<V>, RepositoryError>;
48 async fn list(&self) -> Result<Vec<V>, RepositoryError>;
50 async fn set(&self, key: V::Key, value: V) -> Result<(), RepositoryError>;
52 async fn set_bulk(&self, values: Vec<(V::Key, V)>) -> Result<(), RepositoryError>;
54 async fn remove(&self, key: V::Key) -> Result<(), RepositoryError>;
56 async fn remove_bulk(&self, keys: Vec<V::Key>) -> Result<(), RepositoryError>;
58 async fn remove_all(&self) -> Result<(), RepositoryError>;
60
61 async fn replace_all(&self, values: Vec<(V::Key, V)>) -> Result<(), RepositoryError> {
66 self.remove_all().await?;
67 self.set_bulk(values).await
68 }
69}
70
71pub trait RepositoryItem: Internal + Serialize + DeserializeOwned + Send + Sync + 'static {
78 const NAME: &'static str;
80
81 type Key: ToString + Send + Sync + 'static;
83
84 fn type_id() -> TypeId {
86 TypeId::of::<Self>()
87 }
88
89 fn data() -> RepositoryItemData {
91 RepositoryItemData::new::<Self>()
92 }
93}
94
/// Type-erased description of a [`RepositoryItem`]: its [`TypeId`] and its name.
// NOTE(review): the `dead_code` allowance presumably covers build configurations
// in which a field is never read — confirm before removing it.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug)]
pub struct RepositoryItemData {
    /// `TypeId` of the described item type.
    type_id: TypeId,
    /// The item type's `RepositoryItem::NAME`.
    name: &'static str,
}
102
103impl RepositoryItemData {
104 pub fn new<T: RepositoryItem>() -> Self {
106 Self {
107 type_id: TypeId::of::<T>(),
108 name: T::NAME,
109 }
110 }
111
112 pub fn type_id(&self) -> TypeId {
114 self.type_id
115 }
116 pub fn name(&self) -> &'static str {
119 self.name
120 }
121}
122
/// Checks whether `name` is acceptable as a repository registry name.
///
/// A name is accepted when every byte is an ASCII letter (`a`-`z`, `A`-`Z`) or an
/// underscore. Digits, whitespace, punctuation and all non-ASCII characters are
/// rejected. The empty string contains no offending byte and is therefore accepted.
///
/// This is a `const fn` so that `register_repository_item!` can call it from a
/// compile-time assertion.
pub const fn validate_registry_name(name: &str) -> bool {
    let bytes = name.as_bytes();
    // `for` loops and iterator adaptors are unavailable in `const fn`, so the
    // bytes are walked with a manual index.
    let mut idx = 0;
    while idx < bytes.len() {
        let byte = bytes[idx];
        if !(byte.is_ascii_alphabetic() || byte == b'_') {
            return false;
        }
        idx += 1;
    }
    true
}
140
141#[derive(Debug, Clone)]
143pub struct RepositoryMigrations {
144 pub(crate) steps: Vec<RepositoryMigrationStep>,
145 #[allow(dead_code)]
147 pub(crate) version: u32,
148}
149
150#[derive(Debug, Clone, Copy)]
152pub enum RepositoryMigrationStep {
153 Add(RepositoryItemData),
155 Remove(RepositoryItemData),
157}
158
159impl RepositoryMigrations {
160 pub fn new(steps: Vec<RepositoryMigrationStep>) -> Self {
163 Self {
164 version: steps.len() as u32,
165 steps,
166 }
167 }
168
169 pub fn into_repository_items(self) -> Vec<RepositoryItemData> {
171 let mut map = std::collections::HashMap::new();
172 for step in self.steps {
173 match step {
174 RepositoryMigrationStep::Add(data) => {
175 map.insert(data.type_id, data);
176 }
177 RepositoryMigrationStep::Remove(data) => {
178 map.remove(&data.type_id);
179 }
180 }
181 }
182 map.into_values().collect()
183 }
184}
185
/// Registers a type as a repository item.
///
/// Invoked as `register_repository_item!(KeyType => ItemType, "Name")`, the macro
/// implements the internal marker trait and `RepositoryItem` for `ItemType`, with
/// `KeyType` as its key type and the literal as its `NAME`.
///
/// The name is checked with `validate_registry_name` inside a `const` item, so an
/// invalid name is rejected at compile time.
#[macro_export]
macro_rules! register_repository_item {
    ($keyty:ty => $ty:ty, $name:literal) => {
        // An anonymous constant keeps the generated items and the assertion out
        // of the caller's namespace.
        const _: () = {
            impl $crate::repository::___internal::Internal for $ty {}

            impl $crate::repository::RepositoryItem for $ty {
                const NAME: &'static str = $name;
                type Key = $keyty;
            }

            // Evaluated during constant evaluation: a bad name is a build error.
            assert!(
                $crate::repository::validate_registry_name($name),
                concat!(
                    "Repository name '",
                    $name,
                    "' must contain only alphabetic characters and underscores"
                )
            );
        };
    };
}
208
/// Support items referenced by the expansion of `register_repository_item!`.
///
/// The module is public so the macro can name it through `$crate`, and hidden
/// from the generated documentation.
#[doc(hidden)]
pub mod ___internal {
    /// Marker supertrait of `RepositoryItem`; `register_repository_item!`
    /// implements it for each registered type.
    pub trait Internal {}
}

// Short crate-internal name for the marker trait.
pub(crate) use ___internal::Internal;
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Letters and underscores pass; hyphens, spaces, dots and digits do not.
    #[test]
    fn test_validate_name() {
        for accepted in ["valid", "Valid_Name"] {
            assert!(validate_registry_name(accepted));
        }

        for rejected in ["Invalid-Name", "Invalid Name", "Invalid.Name", "Invalid123"] {
            assert!(!validate_registry_name(rejected));
        }
    }
}