1use std::sync::Arc;
2
3use bitwarden_state::registry::StateRegistry;
4
5use super::Client;
6use crate::{
7 UserId,
8 auth::auth_tokens::TokenHandler,
9 client::{
10 ClientBuilder, get_host_platform_info,
11 persisted_state::{ACCOUNT_CRYPTO_STATE, BASE_URLS, BaseUrls, USER_EMAIL, USER_ID},
12 },
13 key_management::account_cryptographic_state::WrappedAccountCryptographicState,
14};
15
16#[derive(Debug, thiserror::Error)]
18pub enum RehydrationError {
19 #[error("Required state value not found in registry: {0}")]
21 MissingState(String),
22 #[error("State access error: {0}")]
24 State(#[from] bitwarden_state::SettingsError),
25}
26
27pub struct SaveStateData {
32 pub user_id: UserId,
34 pub email: String,
36 pub urls: BaseUrls,
38 pub crypto_state: WrappedAccountCryptographicState,
40}
41
42impl Client {
43 pub async fn save_to_state(
48 data: SaveStateData,
49 reg: &StateRegistry,
50 ) -> Result<(), RehydrationError> {
51 reg.setting(BASE_URLS)
52 .map_err(|e| RehydrationError::State(e.into()))?
53 .update(data.urls)
54 .await
55 .map_err(RehydrationError::State)?;
56 reg.setting(USER_ID)
57 .map_err(|e| RehydrationError::State(e.into()))?
58 .update(data.user_id)
59 .await
60 .map_err(RehydrationError::State)?;
61 reg.setting(USER_EMAIL)
62 .map_err(|e| RehydrationError::State(e.into()))?
63 .update(data.email)
64 .await
65 .map_err(RehydrationError::State)?;
66 reg.setting(ACCOUNT_CRYPTO_STATE)
67 .map_err(|e| RehydrationError::State(e.into()))?
68 .update(data.crypto_state)
69 .await
70 .map_err(RehydrationError::State)?;
71 Ok(())
72 }
73
74 pub async fn load_from_state(
78 token_handler: Arc<dyn TokenHandler>,
79 registry: StateRegistry,
80 ) -> Result<Self, RehydrationError> {
81 let base_urls: BaseUrls = registry
82 .setting(BASE_URLS)
83 .map_err(|e| RehydrationError::State(e.into()))?
84 .get()
85 .await
86 .map_err(RehydrationError::State)?
87 .ok_or_else(|| RehydrationError::MissingState("BASE_URLS".to_string()))?;
88
89 let user_id: UserId = registry
90 .setting(USER_ID)
91 .map_err(|e| RehydrationError::State(e.into()))?
92 .get()
93 .await
94 .map_err(RehydrationError::State)?
95 .ok_or_else(|| RehydrationError::MissingState("USER_ID".to_string()))?;
96
97 let platform = get_host_platform_info();
98 let settings = crate::ClientSettings {
99 identity_url: base_urls.identity_url,
100 api_url: base_urls.api_url,
101 user_agent: platform.user_agent.clone(),
102 device_type: platform.device_type,
103 device_identifier: platform.device_identifier.clone(),
104 bitwarden_client_version: platform.bitwarden_client_version.clone(),
105 bitwarden_package_type: platform.bitwarden_package_type.clone(),
106 };
107
108 let client = ClientBuilder::new()
109 .with_settings(settings)
110 .with_token_handler(token_handler)
111 .with_state(registry)
112 .build();
113
114 client
115 .internal
116 .init_user_id(user_id)
117 .await
118 .expect("user ID cannot already be set on a freshly built client");
119
120 if let Err(e) = client.flags().fetch(false).await {
123 tracing::warn!(
124 "Failed to refresh feature flags on startup; using previously stored flags: {e}"
125 );
126 }
127
128 Ok(client)
129 }
130}
131
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Once};

    use bitwarden_crypto::{
        KeyStore, PublicKeyEncryptionAlgorithm, SignatureAlgorithm, SymmetricKeyAlgorithm,
    };
    use bitwarden_state::registry::StateRegistry;

    use super::*;
    use crate::{
        DeviceType, HostPlatformInfo, UserId,
        auth::auth_tokens::NoopTokenHandler,
        client::persisted_state::{ACCOUNT_CRYPTO_STATE, BASE_URLS, BaseUrls, USER_EMAIL, USER_ID},
        key_management::{
            KeySlotIds, SecurityState,
            account_cryptographic_state::WrappedAccountCryptographicState,
        },
    };

    static INIT: Once = Once::new();

    /// Initializes the global host platform info at most once per test binary, since
    /// `load_from_state` reads it through `get_host_platform_info`.
    fn ensure_platform_info() {
        INIT.call_once(|| {
            crate::init_host_platform_info(HostPlatformInfo {
                user_agent: "rehydration-tests".to_string(),
                device_type: DeviceType::SDK,
                device_identifier: None,
                bitwarden_client_version: None,
                bitwarden_package_type: None,
            });
        });
    }

    fn test_user_id() -> UserId {
        "d5b1fde2-a1e3-4c5b-9e0f-1a2b3c4d5e6f".parse().unwrap()
    }

    /// Base URLs pointing at a local port where presumably nothing listens, so the feature-flag
    /// refresh in `load_from_state` is expected to fail quickly instead of reaching a real
    /// server.
    fn test_base_urls() -> BaseUrls {
        BaseUrls {
            identity_url: "http://127.0.0.1:1".to_string(),
            api_url: "http://127.0.0.1:1".to_string(),
        }
    }

    /// Builds a V2 account crypto state from freshly generated keys.
    fn test_crypto_state() -> WrappedAccountCryptographicState {
        let store: KeyStore<KeySlotIds> = KeyStore::default();
        let mut ctx = store.context_mut();
        let user_key = ctx.make_symmetric_key(SymmetricKeyAlgorithm::XChaCha20Poly1305);
        let private_key_id = ctx.make_private_key(PublicKeyEncryptionAlgorithm::RsaOaepSha1);
        let signing_key_id = ctx.make_signing_key(SignatureAlgorithm::Ed25519);
        let signed_public_key = ctx
            .make_signed_public_key(private_key_id, signing_key_id)
            .unwrap();
        let security_state = SecurityState::new();
        let signed_security_state = security_state.sign(signing_key_id, &mut ctx).unwrap();
        let wrapped_private = ctx.wrap_private_key(user_key, private_key_id).unwrap();
        let wrapped_signing = ctx.wrap_signing_key(user_key, signing_key_id).unwrap();
        WrappedAccountCryptographicState::V2 {
            private_key: wrapped_private,
            signed_public_key: Some(signed_public_key),
            signing_key: wrapped_signing,
            security_state: signed_security_state,
        }
    }

    /// A syntactically valid address on the RFC 2606 reserved `example.com` domain.
    fn test_email() -> String {
        "test@example.com".to_string()
    }

    fn test_save_data() -> SaveStateData {
        SaveStateData {
            user_id: test_user_id(),
            email: test_email(),
            urls: test_base_urls(),
            crypto_state: test_crypto_state(),
        }
    }

    /// Asserts that `result` is a `MissingState` error whose payload mentions `key`.
    fn assert_missing_state(result: Result<Client, RehydrationError>, key: &str) {
        match result {
            Err(RehydrationError::MissingState(s)) => {
                assert!(
                    s.contains(key),
                    "Error message should mention {key}, got: {s}"
                );
            }
            Err(e) => panic!("Expected MissingState error for {key}, got: {e:?}"),
            Ok(_) => panic!("Expected MissingState error for {key}, got Ok"),
        }
    }

    #[tokio::test]
    async fn save_to_state_writes_all_settings() {
        let reg = StateRegistry::new_with_memory_db();
        let data = test_save_data();
        let expected_user_id = data.user_id;
        let expected_email = data.email.clone();
        let expected_urls_identity = data.urls.identity_url.clone();
        let expected_urls_api = data.urls.api_url.clone();

        Client::save_to_state(data, &reg).await.unwrap();

        let base_urls: BaseUrls = reg
            .setting(BASE_URLS)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("BASE_URLS should be present");
        assert_eq!(base_urls.identity_url, expected_urls_identity);
        assert_eq!(base_urls.api_url, expected_urls_api);

        let user_id: UserId = reg
            .setting(USER_ID)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("USER_ID should be present");
        assert_eq!(user_id, expected_user_id);

        let email: String = reg
            .setting(USER_EMAIL)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("USER_EMAIL should be present");
        assert_eq!(email, expected_email);

        let crypto_state: WrappedAccountCryptographicState = reg
            .setting(ACCOUNT_CRYPTO_STATE)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("ACCOUNT_CRYPTO_STATE should be present");
        assert!(
            matches!(crypto_state, WrappedAccountCryptographicState::V2 { .. }),
            "Expected V2 crypto state"
        );
    }

    #[tokio::test]
    async fn load_from_state_restores_user_id() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        let data = test_save_data();
        let expected_user_id = data.user_id;

        Client::save_to_state(data, &reg).await.unwrap();

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let client = Client::load_from_state(token_handler, reg).await.unwrap();

        assert_eq!(
            client.internal.get_user_id(),
            Some(expected_user_id),
            "Restored client should have the saved user ID"
        );
    }

    #[tokio::test]
    async fn load_from_state_missing_base_urls_returns_error() {
        ensure_platform_info();

        // An empty registry: BASE_URLS is the first setting read, so it must be reported.
        let reg = StateRegistry::new_with_memory_db();
        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let result = Client::load_from_state(token_handler, reg).await;

        assert_missing_state(result, "BASE_URLS");
    }

    #[tokio::test]
    async fn load_from_state_missing_user_id_returns_error() {
        ensure_platform_info();

        // Only BASE_URLS is written, so USER_ID is the missing setting.
        let reg = StateRegistry::new_with_memory_db();
        reg.setting(BASE_URLS)
            .unwrap()
            .update(test_base_urls())
            .await
            .unwrap();

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let result = Client::load_from_state(token_handler, reg).await;

        assert_missing_state(result, "USER_ID");
    }
}