1use std::sync::Arc;
2
3use bitwarden_api_api::models::SyncResponseModel;
4use bitwarden_core::{
5 Client,
6 client::{ApiConfigurations, FromClientPart},
7};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tokio::sync::Mutex;
11
12use crate::{SyncErrorHandler, SyncHandler, SyncHandlerError, registry::HandlerRegistry};
13
14#[allow(missing_docs)]
15#[derive(Debug, Error)]
16pub enum SyncError {
17 #[error(transparent)]
18 Api(#[from] bitwarden_core::ApiError),
19
20 #[error("Sync event handler failed: {0}")]
21 HandlerFailed(#[source] SyncHandlerError),
22}
23
24#[allow(missing_docs)]
25#[derive(Serialize, Deserialize, Debug, Clone)]
26#[serde(rename_all = "camelCase", deny_unknown_fields)]
27pub struct SyncRequest {
28 pub exclude_subdomains: Option<bool>,
30}
31
32pub struct SyncClient {
37 api_configurations: Arc<ApiConfigurations>,
38 sync_handlers: HandlerRegistry<dyn SyncHandler>,
39 error_handlers: HandlerRegistry<dyn SyncErrorHandler>,
40 sync_lock: Mutex<()>,
41}
42
43impl SyncClient {
44 pub fn new(client: Client) -> Self {
46 Self {
47 api_configurations: client
48 .get_part()
49 .expect("ApiConfigurations should never fail"),
50 sync_handlers: HandlerRegistry::new(),
51 error_handlers: HandlerRegistry::new(),
52 sync_lock: Mutex::new(()),
53 }
54 }
55
56 pub fn register_sync_handler(&self, handler: Arc<dyn SyncHandler>) {
61 self.sync_handlers.register(handler);
62 }
63
64 pub fn register_error_handler(&self, handler: Arc<dyn SyncErrorHandler>) {
70 self.error_handlers.register(handler);
71 }
72
73 pub async fn sync(&self, request: SyncRequest) -> Result<SyncResponseModel, SyncError> {
87 let _guard = self.sync_lock.lock().await;
89
90 let result = async {
91 let response = self.perform_sync(&request).await?;
92 self.run_handlers(&response).await?;
93 Ok(response)
94 }
95 .await;
96
97 if let Err(ref error) = result {
98 self.run_error_handlers(error).await;
99 }
100
101 result
102 }
103
104 async fn run_handlers(&self, response: &SyncResponseModel) -> Result<(), SyncError> {
112 let handlers = self.sync_handlers.handlers();
113
114 for handler in &handlers {
115 handler
116 .on_sync(response)
117 .await
118 .map_err(SyncError::HandlerFailed)?;
119 }
120
121 for handler in &handlers {
122 handler.on_sync_complete().await;
123 }
124
125 Ok(())
126 }
127
128 async fn run_error_handlers(&self, error: &SyncError) {
132 for handler in &self.error_handlers.handlers() {
133 handler.on_error(error).await;
134 }
135 }
136
137 async fn perform_sync(&self, input: &SyncRequest) -> Result<SyncResponseModel, SyncError> {
139 let sync = self
140 .api_configurations
141 .api_client
142 .sync_api()
143 .get(input.exclude_subdomains)
144 .await
145 .map_err(|e| SyncError::Api(e.into()))?;
146
147 Ok(sync)
148 }
149}
150
151pub trait SyncClientExt {
156 fn sync(&self) -> SyncClient;
158}
159
160impl SyncClientExt for Client {
161 fn sync(&self) -> SyncClient {
162 SyncClient::new(self.clone())
163 }
164}
165
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;

    /// Shared log that test handlers append their names to, in call order.
    type CallLog = Arc<Mutex<Vec<String>>>;

    /// Creates an empty call log.
    fn new_log() -> CallLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// The request used by every end-to-end `sync` test.
    fn default_request() -> SyncRequest {
        SyncRequest {
            exclude_subdomains: None,
        }
    }

    /// Sync handler that records its name on each `on_sync` call and optionally fails.
    struct TestHandler {
        name: String,
        execution_log: CallLog,
        should_fail: bool,
    }

    impl TestHandler {
        fn new(name: &str, execution_log: &CallLog, should_fail: bool) -> Arc<Self> {
            Arc::new(Self {
                name: name.to_string(),
                execution_log: Arc::clone(execution_log),
                should_fail,
            })
        }
    }

    #[async_trait::async_trait]
    impl SyncHandler for TestHandler {
        async fn on_sync(&self, _response: &SyncResponseModel) -> Result<(), SyncHandlerError> {
            self.execution_log.lock().unwrap().push(self.name.clone());
            if self.should_fail {
                Err("Handler failed".into())
            } else {
                Ok(())
            }
        }
    }

    /// Error handler that records its name every time it is notified.
    struct TestErrorHandler {
        name: String,
        error_log: CallLog,
    }

    impl TestErrorHandler {
        fn new(name: &str, error_log: &CallLog) -> Arc<Self> {
            Arc::new(Self {
                name: name.to_string(),
                error_log: Arc::clone(error_log),
            })
        }
    }

    #[async_trait::async_trait]
    impl SyncErrorHandler for TestErrorHandler {
        async fn on_error(&self, _error: &SyncError) {
            self.error_log.lock().unwrap().push(self.name.clone());
        }
    }

    /// Builds a `SyncClient` around `api_client` directly, bypassing `SyncClient::new`.
    fn test_client(api_client: bitwarden_api_api::apis::ApiClient) -> SyncClient {
        let dummy_config = bitwarden_api_api::Configuration {
            base_path: String::new(),
            client: reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build(),
        };
        SyncClient {
            api_configurations: Arc::new(ApiConfigurations {
                api_client,
                identity_client: bitwarden_api_identity::apis::ApiClient::new_mocked(|_| {}),
                api_config: dummy_config.clone(),
                identity_config: dummy_config,
                device_type: bitwarden_core::client::DeviceType::SDK,
            }),
            sync_handlers: HandlerRegistry::new(),
            error_handlers: HandlerRegistry::new(),
            sync_lock: tokio::sync::Mutex::new(()),
        }
    }

    /// Mocked API client whose sync endpoint always succeeds with an empty response.
    fn ok_api_client() -> bitwarden_api_api::apis::ApiClient {
        bitwarden_api_api::apis::ApiClient::new_mocked(|mock| {
            mock.sync_api
                .expect_get()
                .returning(|_| Ok(SyncResponseModel::default()));
        })
    }

    #[tokio::test]
    async fn test_handlers_execute_in_registration_order() {
        let client = test_client(bitwarden_api_api::apis::ApiClient::new_mocked(|_| {}));
        let log = new_log();

        client.register_sync_handler(TestHandler::new("first", &log, false));
        client.register_sync_handler(TestHandler::new("second", &log, false));
        client.register_sync_handler(TestHandler::new("third", &log, false));

        let response = SyncResponseModel::default();
        client.run_handlers(&response).await.unwrap();

        assert_eq!(
            *log.lock().unwrap(),
            vec!["first", "second", "third"],
            "Handlers should execute in registration order"
        );
    }

    #[tokio::test]
    async fn test_handler_error_stops_subsequent_handlers() {
        let client = test_client(bitwarden_api_api::apis::ApiClient::new_mocked(|_| {}));
        let log = new_log();

        client.register_sync_handler(TestHandler::new("first", &log, false));
        client.register_sync_handler(TestHandler::new("second", &log, true));
        client.register_sync_handler(TestHandler::new("third", &log, false));

        let response = SyncResponseModel::default();
        let result = client.run_handlers(&response).await;

        assert!(
            matches!(result, Err(SyncError::HandlerFailed(_))),
            "Should return HandlerFailed when a handler fails"
        );
        assert_eq!(
            *log.lock().unwrap(),
            vec!["first", "second"],
            "Third handler should not execute after second handler fails"
        );
    }

    #[tokio::test]
    async fn test_sync_success_calls_handlers_and_returns_response() {
        let client = test_client(ok_api_client());
        let sync_log = new_log();
        let error_log = new_log();

        client.register_sync_handler(TestHandler::new("handler", &sync_log, false));
        client.register_error_handler(TestErrorHandler::new("error_handler", &error_log));

        let result = client.sync(default_request()).await;

        assert!(result.is_ok(), "Sync should succeed");
        assert_eq!(
            *sync_log.lock().unwrap(),
            vec!["handler"],
            "Sync handler should be called on success"
        );
        assert!(
            error_log.lock().unwrap().is_empty(),
            "Error handlers should not be called on success"
        );
    }

    #[tokio::test]
    async fn test_sync_error_notifies_error_handlers() {
        let client = test_client(bitwarden_api_api::apis::ApiClient::new_mocked(|mock| {
            mock.sync_api.expect_get().returning(|_| {
                Err(bitwarden_api_api::Error::Io(std::io::Error::other(
                    "test error",
                )))
            });
        }));
        let error_log = new_log();

        client.register_error_handler(TestErrorHandler::new("first", &error_log));
        client.register_error_handler(TestErrorHandler::new("second", &error_log));

        let result = client.sync(default_request()).await;

        assert!(
            matches!(result, Err(SyncError::Api(_))),
            "API failures should surface as SyncError::Api"
        );
        assert_eq!(
            *error_log.lock().unwrap(),
            vec!["first", "second"],
            "All error handlers should be called on sync failure"
        );
    }

    #[tokio::test]
    async fn test_sync_handler_failure_notifies_error_handlers() {
        let client = test_client(ok_api_client());
        let sync_log = new_log();
        let error_log = new_log();

        client.register_sync_handler(TestHandler::new("failing", &sync_log, true));
        client.register_error_handler(TestErrorHandler::new("error_handler", &error_log));

        let result = client.sync(default_request()).await;

        assert!(
            matches!(result, Err(SyncError::HandlerFailed(_))),
            "A failing sync handler should make sync return HandlerFailed"
        );
        assert_eq!(
            *sync_log.lock().unwrap(),
            vec!["failing"],
            "The failing handler should have been invoked"
        );
        assert_eq!(
            *error_log.lock().unwrap(),
            vec!["error_handler"],
            "Error handlers should be notified when a sync handler fails"
        );
    }
}