bitwarden_state/
repository.rs1use std::any::TypeId;
2
3use serde::{Serialize, de::DeserializeOwned};
4
5use crate::registry::StateRegistryError;
6
7#[derive(thiserror::Error, Debug)]
9pub enum RepositoryError {
10 #[error("Internal error: {0}")]
12 Internal(String),
13
14 #[error(transparent)]
16 Serde(#[from] serde_json::Error),
17
18 #[error(transparent)]
20 Database(#[from] crate::sdk_managed::DatabaseError),
21
22 #[error(transparent)]
24 StateRegistry(#[from] StateRegistryError),
25}
26
27#[async_trait::async_trait]
30pub trait Repository<V: RepositoryItem>: Send + Sync {
31 async fn get(&self, key: V::Key) -> Result<Option<V>, RepositoryError>;
33 async fn list(&self) -> Result<Vec<V>, RepositoryError>;
35 async fn set(&self, key: V::Key, value: V) -> Result<(), RepositoryError>;
37 async fn set_bulk(&self, values: Vec<(V::Key, V)>) -> Result<(), RepositoryError>;
39 async fn remove(&self, key: V::Key) -> Result<(), RepositoryError>;
41 async fn remove_bulk(&self, keys: Vec<V::Key>) -> Result<(), RepositoryError>;
43 async fn remove_all(&self) -> Result<(), RepositoryError>;
45
46 async fn replace_all(&self, values: Vec<(V::Key, V)>) -> Result<(), RepositoryError> {
51 self.remove_all().await?;
52 self.set_bulk(values).await
53 }
54}
55
56pub trait RepositoryItem: Internal + Serialize + DeserializeOwned + Send + Sync + 'static {
63 const NAME: &'static str;
65
66 type Key: ToString + Send + Sync + 'static;
68
69 fn type_id() -> TypeId {
71 TypeId::of::<Self>()
72 }
73
74 fn data() -> RepositoryItemData {
76 RepositoryItemData::new::<Self>()
77 }
78}
79
/// Metadata describing one repository item type: its [`TypeId`] and its
/// registered name.
///
/// The value is small and `Copy`, so it can be passed around by value.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file;
// confirm it is still required.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug)]
pub struct RepositoryItemData {
    /// `TypeId` of the described item type.
    type_id: TypeId,
    /// The described type's `RepositoryItem::NAME`.
    name: &'static str,
}
87
88impl RepositoryItemData {
89 pub fn new<T: RepositoryItem>() -> Self {
91 Self {
92 type_id: TypeId::of::<T>(),
93 name: T::NAME,
94 }
95 }
96
97 pub fn type_id(&self) -> TypeId {
99 self.type_id
100 }
101 pub fn name(&self) -> &'static str {
104 self.name
105 }
106}
107
/// Returns `true` when `name` consists solely of ASCII letters (`a-z`, `A-Z`)
/// and underscores.
///
/// Digits, whitespace, punctuation and any non-ASCII byte make the name
/// invalid. The empty string contains no offending byte and is accepted.
///
/// This is a `const fn` so that `register_repository_item!` can call it from
/// a compile-time assertion. It walks the bytes with an index loop because
/// iterators cannot be used in a `const fn`.
pub const fn validate_registry_name(name: &str) -> bool {
    let bytes = name.as_bytes();
    let mut idx = 0;
    while idx < bytes.len() {
        match bytes[idx] {
            b'a'..=b'z' | b'A'..=b'Z' | b'_' => idx += 1,
            _ => return false,
        }
    }
    true
}
125
126#[derive(Debug, Clone)]
128pub struct RepositoryMigrations {
129 pub(crate) steps: Vec<RepositoryMigrationStep>,
130 #[allow(dead_code)]
132 pub(crate) version: u32,
133}
134
135#[derive(Debug, Clone, Copy)]
137pub enum RepositoryMigrationStep {
138 Add(RepositoryItemData),
140 Remove(RepositoryItemData),
142}
143
144impl RepositoryMigrations {
145 pub fn new(steps: Vec<RepositoryMigrationStep>) -> Self {
148 Self {
149 version: steps.len() as u32,
150 steps,
151 }
152 }
153
154 pub fn into_repository_items(self) -> Vec<RepositoryItemData> {
156 let mut map = std::collections::HashMap::new();
157 for step in self.steps {
158 match step {
159 RepositoryMigrationStep::Add(data) => {
160 map.insert(data.type_id, data);
161 }
162 RepositoryMigrationStep::Remove(data) => {
163 map.remove(&data.type_id);
164 }
165 }
166 }
167 map.into_values().collect()
168 }
169}
170
/// Registers a type as a repository item.
///
/// Invocation form: `register_repository_item!(KeyType => ItemType, "Name");`
///
/// The expansion implements the hidden `Internal` marker trait and
/// `RepositoryItem` for `ItemType`, with `KeyType` as `RepositoryItem::Key`
/// and `"Name"` as `RepositoryItem::NAME`. It also asserts at compile time
/// that the name passes `validate_registry_name`, so an invalid name is a
/// build error rather than a runtime failure.
#[macro_export]
macro_rules! register_repository_item {
    ($keyty:ty => $ty:ty, $name:literal) => {
        // An anonymous constant gives the assertion a const context, so it is
        // evaluated while compiling the crate that invokes the macro.
        const _: () = {
            impl $crate::repository::___internal::Internal for $ty {}

            impl $crate::repository::RepositoryItem for $ty {
                type Key = $keyty;
                const NAME: &'static str = $name;
            }

            assert!(
                $crate::repository::validate_registry_name($name),
                concat!(
                    "Repository name '",
                    $name,
                    "' must contain only alphabetic characters and underscores"
                )
            );
        };
    };
}
193
/// Implementation detail of `register_repository_item!`.
///
/// The module has to be `pub` so the macro's expansion can name it from other
/// crates, but it is hidden from the generated documentation.
#[doc(hidden)]
pub mod ___internal {

    /// Marker supertrait of `RepositoryItem`; it has no methods.
    pub trait Internal {}
}
pub(crate) use ___internal::Internal;
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Names made only of ASCII letters and underscores are accepted; names
    /// containing a hyphen, space, dot or digit are rejected.
    #[test]
    fn test_validate_name() {
        let accepted = ["valid", "Valid_Name"];
        for name in accepted {
            assert!(validate_registry_name(name));
        }

        let rejected = ["Invalid-Name", "Invalid Name", "Invalid.Name", "Invalid123"];
        for name in rejected {
            assert!(!validate_registry_name(name));
        }
    }
}