1use std::sync::Arc;
2
3use bitwarden_state::registry::StateRegistry;
4
5use super::Client;
6use crate::{
7 UserId,
8 auth::auth_tokens::TokenHandler,
9 client::{
10 ClientBuilder, get_host_platform_info,
11 persisted_state::{ACCOUNT_CRYPTO_STATE, BASE_URLS, BaseUrls, USER_EMAIL, USER_ID},
12 },
13 key_management::account_cryptographic_state::WrappedAccountCryptographicState,
14};
15
/// Errors returned by [`Client::save_to_state`] and [`Client::load_from_state`].
#[derive(Debug, thiserror::Error)]
pub enum RehydrationError {
    /// A setting that [`Client::load_from_state`] requires had no stored value. The payload is
    /// the name of the missing setting (for example `"BASE_URLS"` or `"USER_ID"`).
    #[error("Required state value not found in registry: {0}")]
    MissingState(String),
    /// Obtaining a setting from the registry, reading it, or updating it failed.
    #[error("State access error: {0}")]
    State(#[from] bitwarden_state::SettingsError),
}
26
/// Account data that [`Client::save_to_state`] writes into the state registry.
pub struct SaveStateData {
    /// Written to the `USER_ID` setting.
    pub user_id: UserId,
    /// Written to the `USER_EMAIL` setting.
    pub email: String,
    /// Written to the `BASE_URLS` setting.
    pub urls: BaseUrls,
    /// Written to the `ACCOUNT_CRYPTO_STATE` setting.
    pub crypto_state: WrappedAccountCryptographicState,
}
41
42impl Client {
43 pub async fn save_to_state(
48 data: SaveStateData,
49 reg: &StateRegistry,
50 ) -> Result<(), RehydrationError> {
51 reg.setting(BASE_URLS)
52 .map_err(|e| RehydrationError::State(e.into()))?
53 .update(data.urls)
54 .await
55 .map_err(RehydrationError::State)?;
56 reg.setting(USER_ID)
57 .map_err(|e| RehydrationError::State(e.into()))?
58 .update(data.user_id)
59 .await
60 .map_err(RehydrationError::State)?;
61 reg.setting(USER_EMAIL)
62 .map_err(|e| RehydrationError::State(e.into()))?
63 .update(data.email)
64 .await
65 .map_err(RehydrationError::State)?;
66 reg.setting(ACCOUNT_CRYPTO_STATE)
67 .map_err(|e| RehydrationError::State(e.into()))?
68 .update(data.crypto_state)
69 .await
70 .map_err(RehydrationError::State)?;
71 Ok(())
72 }
73
74 pub async fn load_from_state(
78 token_handler: Arc<dyn TokenHandler>,
79 registry: StateRegistry,
80 ) -> Result<Self, RehydrationError> {
81 let base_urls: BaseUrls = registry
82 .setting(BASE_URLS)
83 .map_err(|e| RehydrationError::State(e.into()))?
84 .get()
85 .await
86 .map_err(RehydrationError::State)?
87 .ok_or_else(|| RehydrationError::MissingState("BASE_URLS".to_string()))?;
88
89 let user_id: UserId = registry
90 .setting(USER_ID)
91 .map_err(|e| RehydrationError::State(e.into()))?
92 .get()
93 .await
94 .map_err(RehydrationError::State)?
95 .ok_or_else(|| RehydrationError::MissingState("USER_ID".to_string()))?;
96
97 let settings =
98 get_host_platform_info().to_client_settings(base_urls.api_url, base_urls.identity_url);
99
100 let client = ClientBuilder::new()
101 .with_settings(settings)
102 .with_token_handler(token_handler)
103 .with_state(registry)
104 .build();
105
106 client
107 .internal
108 .init_user_id(user_id)
109 .await
110 .expect("user ID cannot already be set on a freshly built client");
111
112 if let Err(e) = client.flags().fetch(false).await {
115 tracing::warn!(
116 "Failed to refresh feature flags on startup; using previously stored flags: {e}"
117 );
118 }
119
120 Ok(client)
121 }
122}
123
#[cfg(test)]
mod tests {
    //! Tests for saving client state to a registry and rebuilding a client from it.

    use std::sync::{Arc, Once};

    use bitwarden_crypto::{
        KeyStore, PublicKeyEncryptionAlgorithm, SignatureAlgorithm, SymmetricKeyAlgorithm,
    };
    use bitwarden_state::registry::StateRegistry;

    use super::*;
    use crate::{
        DeviceType, HostPlatformInfo, UserId,
        auth::auth_tokens::NoopTokenHandler,
        client::persisted_state::{ACCOUNT_CRYPTO_STATE, BASE_URLS, BaseUrls, USER_EMAIL, USER_ID},
        key_management::{
            KeySlotIds, SecurityState,
            account_cryptographic_state::WrappedAccountCryptographicState,
        },
    };

    /// Ensures the host platform info is initialized at most once across all tests.
    static INIT: Once = Once::new();

    /// Initializes the host platform info on first call; later calls are no-ops.
    ///
    /// Must run before any test that calls `Client::load_from_state`, which reads the host
    /// platform info.
    fn ensure_platform_info() {
        INIT.call_once(|| {
            crate::init_host_platform_info(HostPlatformInfo {
                user_agent: "rehydration-tests".to_string(),
                device_type: DeviceType::SDK,
                device_identifier: None,
                bitwarden_client_version: None,
                bitwarden_package_type: None,
            });
        });
    }

    /// Fixed user ID used by the fixtures.
    fn test_user_id() -> UserId {
        "d5b1fde2-a1e3-4c5b-9e0f-1a2b3c4d5e6f".parse().unwrap()
    }

    /// Base URLs used by the fixtures.
    fn test_base_urls() -> BaseUrls {
        // Loopback on port 1 — presumably chosen so the feature-flag fetch performed by
        // `load_from_state` fails fast instead of reaching a real server.
        BaseUrls {
            identity_url: "http://127.0.0.1:1".to_string(),
            api_url: "http://127.0.0.1:1".to_string(),
        }
    }

    /// Builds a `V2` wrapped account cryptographic state from freshly generated keys.
    fn test_crypto_state() -> WrappedAccountCryptographicState {
        let store: KeyStore<KeySlotIds> = KeyStore::default();
        let mut ctx = store.context_mut();

        // Generate the user key plus the asymmetric and signing key pairs.
        let user_key = ctx.make_symmetric_key(SymmetricKeyAlgorithm::XChaCha20Poly1305);
        let private_key_id = ctx.make_private_key(PublicKeyEncryptionAlgorithm::RsaOaepSha1);
        let signing_key_id = ctx.make_signing_key(SignatureAlgorithm::Ed25519);

        let signed_public_key = ctx
            .make_signed_public_key(private_key_id, signing_key_id)
            .unwrap();
        let security_state = SecurityState::new();
        let signed_security_state = security_state.sign(signing_key_id, &mut ctx).unwrap();

        // Wrap both private keys with the user key.
        let wrapped_private = ctx.wrap_private_key(user_key, private_key_id).unwrap();
        let wrapped_signing = ctx.wrap_signing_key(user_key, signing_key_id).unwrap();

        WrappedAccountCryptographicState::V2 {
            private_key: wrapped_private,
            signed_public_key: Some(signed_public_key),
            signing_key: wrapped_signing,
            security_state: signed_security_state,
        }
    }

    /// Email value used by the fixtures.
    fn test_email() -> String {
        "[email protected]".to_string()
    }

    /// Complete `SaveStateData` assembled from the fixture helpers above.
    fn test_save_data() -> SaveStateData {
        SaveStateData {
            user_id: test_user_id(),
            email: test_email(),
            urls: test_base_urls(),
            crypto_state: test_crypto_state(),
        }
    }

    /// Asserts that `result` is `Err(RehydrationError::MissingState(_))` whose payload
    /// mentions `key`.
    fn assert_missing_state(result: Result<Client, RehydrationError>, key: &str) {
        match result {
            Err(RehydrationError::MissingState(s)) => {
                assert!(
                    s.contains(key),
                    "Error message should mention {key}, got: {s}"
                );
            }
            Err(e) => panic!("Expected MissingState error for {key}, got: {e:?}"),
            Ok(_) => panic!("Expected MissingState error for {key}, got Ok"),
        }
    }

    /// `save_to_state` must write every one of the four settings.
    #[tokio::test]
    async fn save_to_state_writes_all_settings() {
        let reg = StateRegistry::new_with_memory_db();
        let data = test_save_data();
        let expected_user_id = data.user_id;
        let expected_email = data.email.clone();
        let expected_urls_identity = data.urls.identity_url.clone();
        let expected_urls_api = data.urls.api_url.clone();

        Client::save_to_state(data, &reg).await.unwrap();

        let base_urls: BaseUrls = reg
            .setting(BASE_URLS)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("BASE_URLS should be present");
        assert_eq!(base_urls.identity_url, expected_urls_identity);
        assert_eq!(base_urls.api_url, expected_urls_api);

        let user_id: UserId = reg
            .setting(USER_ID)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("USER_ID should be present");
        assert_eq!(user_id, expected_user_id);

        let email: String = reg
            .setting(USER_EMAIL)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("USER_EMAIL should be present");
        assert_eq!(email, expected_email);

        let crypto_state: WrappedAccountCryptographicState = reg
            .setting(ACCOUNT_CRYPTO_STATE)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("ACCOUNT_CRYPTO_STATE should be present");
        assert!(
            matches!(crypto_state, WrappedAccountCryptographicState::V2 { .. }),
            "Expected V2 crypto state"
        );
    }

    /// A client rebuilt from saved state must report the saved user ID.
    #[tokio::test]
    async fn load_from_state_restores_user_id() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        let data = test_save_data();
        let expected_user_id = data.user_id;

        Client::save_to_state(data, &reg).await.unwrap();

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let client = Client::load_from_state(token_handler, reg).await.unwrap();

        assert_eq!(
            client.internal.get_user_id(),
            Some(expected_user_id),
            "Restored client should have the saved user ID"
        );
    }

    /// With an empty registry, loading must fail on the first required setting, `BASE_URLS`.
    #[tokio::test]
    async fn load_from_state_missing_base_urls_returns_error() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let result = Client::load_from_state(token_handler, reg).await;

        assert_missing_state(result, "BASE_URLS");
    }

    /// With only `BASE_URLS` stored, loading must fail on the missing `USER_ID`.
    #[tokio::test]
    async fn load_from_state_missing_user_id_returns_error() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        reg.setting(BASE_URLS)
            .unwrap()
            .update(test_base_urls())
            .await
            .unwrap();

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let result = Client::load_from_state(token_handler, reg).await;

        assert_missing_state(result, "USER_ID");
    }
}