1use std::sync::Arc;
2
3use bitwarden_api_api::models::SyncResponseModel;
4use bitwarden_core::{
5 Client,
6 client::{ApiConfigurations, FromClientPart},
7};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use tokio::sync::Mutex;
11
12use crate::{SyncErrorHandler, SyncHandler, SyncHandlerError, registry::HandlerRegistry};
13
14#[allow(missing_docs)]
15#[derive(Debug, Error)]
16pub enum SyncError {
17 #[error(transparent)]
18 Api(#[from] bitwarden_core::ApiError),
19
20 #[error("Sync event handler failed: {0}")]
21 HandlerFailed(#[source] SyncHandlerError),
22}
23
24#[allow(missing_docs)]
25#[derive(Serialize, Deserialize, Debug, Clone)]
26#[serde(rename_all = "camelCase", deny_unknown_fields)]
27pub struct SyncRequest {
28 pub exclude_subdomains: Option<bool>,
30}
31
32pub struct SyncClient {
37 api_configurations: Arc<ApiConfigurations>,
38 sync_handlers: HandlerRegistry<dyn SyncHandler>,
39 error_handlers: HandlerRegistry<dyn SyncErrorHandler>,
40 sync_lock: Mutex<()>,
41}
42
43impl SyncClient {
44 pub fn new(client: Client) -> Self {
46 Self {
47 api_configurations: client.get_part(),
48 sync_handlers: HandlerRegistry::new(),
49 error_handlers: HandlerRegistry::new(),
50 sync_lock: Mutex::new(()),
51 }
52 }
53
54 pub fn register_sync_handler(&self, handler: Arc<dyn SyncHandler>) {
59 self.sync_handlers.register(handler);
60 }
61
62 pub fn register_error_handler(&self, handler: Arc<dyn SyncErrorHandler>) {
68 self.error_handlers.register(handler);
69 }
70
71 pub async fn sync(&self, request: SyncRequest) -> Result<SyncResponseModel, SyncError> {
85 let _guard = self.sync_lock.lock().await;
87
88 let result = async {
89 let response = self.perform_sync(&request).await?;
90 self.run_handlers(&response).await?;
91 Ok(response)
92 }
93 .await;
94
95 if let Err(ref error) = result {
96 self.run_error_handlers(error).await;
97 }
98
99 result
100 }
101
102 async fn run_handlers(&self, response: &SyncResponseModel) -> Result<(), SyncError> {
110 let handlers = self.sync_handlers.handlers();
111
112 for handler in &handlers {
113 handler
114 .on_sync(response)
115 .await
116 .map_err(SyncError::HandlerFailed)?;
117 }
118
119 for handler in &handlers {
120 handler.on_sync_complete().await;
121 }
122
123 Ok(())
124 }
125
126 async fn run_error_handlers(&self, error: &SyncError) {
130 for handler in &self.error_handlers.handlers() {
131 handler.on_error(error).await;
132 }
133 }
134
135 async fn perform_sync(&self, input: &SyncRequest) -> Result<SyncResponseModel, SyncError> {
137 let sync = self
138 .api_configurations
139 .api_client
140 .sync_api()
141 .get(input.exclude_subdomains)
142 .await
143 .map_err(|e| SyncError::Api(e.into()))?;
144
145 Ok(sync)
146 }
147}
148
149pub trait SyncClientExt {
154 fn sync(&self) -> SyncClient;
156}
157
158impl SyncClientExt for Client {
159 fn sync(&self) -> SyncClient {
160 SyncClient::new(self.clone())
161 }
162}
163
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;

    /// Shared record of handler names, appended to in invocation order.
    type CallLog = Arc<Mutex<Vec<String>>>;

    /// Sync handler that records its name on every `on_sync` call and can be told to fail.
    struct TestHandler {
        name: String,
        execution_log: CallLog,
        should_fail: bool,
    }

    #[async_trait::async_trait]
    impl SyncHandler for TestHandler {
        async fn on_sync(&self, _response: &SyncResponseModel) -> Result<(), SyncHandlerError> {
            // Record the call first so that failing handlers still show up in the log.
            self.execution_log.lock().unwrap().push(self.name.clone());

            if !self.should_fail {
                return Ok(());
            }
            Err("Handler failed".into())
        }
    }

    /// Error handler that records its name every time it is notified.
    struct TestErrorHandler {
        name: String,
        error_log: CallLog,
    }

    #[async_trait::async_trait]
    impl SyncErrorHandler for TestErrorHandler {
        async fn on_error(&self, _error: &SyncError) {
            self.error_log.lock().unwrap().push(self.name.clone());
        }
    }

    /// Creates an empty call log.
    fn empty_log() -> CallLog {
        Arc::new(Mutex::new(Vec::new()))
    }

    /// Creates a [`TestHandler`] called `name` that records into `log`.
    fn recording_handler(name: &str, log: &CallLog, should_fail: bool) -> Arc<TestHandler> {
        Arc::new(TestHandler {
            name: name.to_string(),
            execution_log: Arc::clone(log),
            should_fail,
        })
    }

    /// Creates a [`TestErrorHandler`] called `name` that records into `log`.
    fn recording_error_handler(name: &str, log: &CallLog) -> Arc<TestErrorHandler> {
        Arc::new(TestErrorHandler {
            name: name.to_string(),
            error_log: Arc::clone(log),
        })
    }

    /// Builds a [`SyncClient`] around the given (mocked) API client, bypassing `SyncClient::new`.
    fn test_client(api_client: bitwarden_api_api::apis::ApiClient) -> SyncClient {
        let dummy_config = bitwarden_api_api::Configuration {
            base_path: String::new(),
            client: reqwest_middleware::ClientBuilder::new(reqwest::Client::new()).build(),
        };
        let api_configurations = ApiConfigurations {
            api_client,
            identity_client: bitwarden_api_identity::apis::ApiClient::new_mocked(|_| {}),
            api_config: dummy_config.clone(),
            identity_config: dummy_config,
            device_type: bitwarden_core::client::DeviceType::SDK,
        };

        SyncClient {
            api_configurations: Arc::new(api_configurations),
            sync_handlers: HandlerRegistry::new(),
            error_handlers: HandlerRegistry::new(),
            sync_lock: tokio::sync::Mutex::new(()),
        }
    }

    #[tokio::test]
    async fn test_handlers_execute_in_registration_order() {
        let client = test_client(bitwarden_api_api::apis::ApiClient::new_mocked(|_| {}));
        let log = empty_log();

        for name in ["first", "second", "third"] {
            client.register_sync_handler(recording_handler(name, &log, false));
        }

        let response = SyncResponseModel::default();
        client.run_handlers(&response).await.unwrap();

        assert_eq!(
            *log.lock().unwrap(),
            vec!["first", "second", "third"],
            "Handlers should execute in registration order"
        );
    }

    #[tokio::test]
    async fn test_handler_error_stops_subsequent_handlers() {
        let client = test_client(bitwarden_api_api::apis::ApiClient::new_mocked(|_| {}));
        let log = empty_log();

        // Only the middle handler fails.
        for (name, should_fail) in [("first", false), ("second", true), ("third", false)] {
            client.register_sync_handler(recording_handler(name, &log, should_fail));
        }

        let response = SyncResponseModel::default();
        let result = client.run_handlers(&response).await;

        assert!(result.is_err(), "Should return error when handler fails");
        assert_eq!(
            *log.lock().unwrap(),
            vec!["first", "second"],
            "Third handler should not execute after second handler fails"
        );
    }

    #[tokio::test]
    async fn test_sync_success_calls_handlers_and_returns_response() {
        let client = test_client(bitwarden_api_api::apis::ApiClient::new_mocked(|mock| {
            mock.sync_api
                .expect_get()
                .returning(|_| Ok(SyncResponseModel::default()));
        }));
        let sync_log = empty_log();
        let error_log = empty_log();

        client.register_sync_handler(recording_handler("handler", &sync_log, false));
        client.register_error_handler(recording_error_handler("error_handler", &error_log));

        let request = SyncRequest {
            exclude_subdomains: None,
        };
        let result = client.sync(request).await;

        assert!(result.is_ok(), "Sync should succeed");
        assert_eq!(
            *sync_log.lock().unwrap(),
            vec!["handler"],
            "Sync handler should be called on success"
        );
        assert!(
            error_log.lock().unwrap().is_empty(),
            "Error handlers should not be called on success"
        );
    }

    #[tokio::test]
    async fn test_sync_error_notifies_error_handlers() {
        let client = test_client(bitwarden_api_api::apis::ApiClient::new_mocked(|mock| {
            mock.sync_api
                .expect_get()
                .returning(|_| Err(std::io::Error::other("test error").into()));
        }));
        let error_log = empty_log();

        for name in ["first", "second"] {
            client.register_error_handler(recording_error_handler(name, &error_log));
        }

        let request = SyncRequest {
            exclude_subdomains: None,
        };
        let result = client.sync(request).await;

        assert!(result.is_err());
        assert_eq!(
            *error_log.lock().unwrap(),
            vec!["first", "second"],
            "All error handlers should be called on sync failure"
        );
    }
}