1use std::sync::Arc;
2
3use bitwarden_state::registry::StateRegistry;
4
5use super::Client;
6use crate::{
7 UserId,
8 auth::auth_tokens::TokenHandler,
9 client::{
10 ClientBuilder, get_host_platform_info,
11 persisted_state::{ACCOUNT_CRYPTO_STATE, BASE_URLS, BaseUrls, USER_ID},
12 },
13 key_management::account_cryptographic_state::WrappedAccountCryptographicState,
14};
15
16#[derive(Debug, thiserror::Error)]
18pub enum RehydrationError {
19 #[error("Required state value not found in registry: {0}")]
21 MissingState(String),
22 #[error("State access error: {0}")]
24 State(#[from] bitwarden_state::SettingsError),
25}
26
27pub struct SaveStateData {
32 pub user_id: UserId,
34 pub urls: BaseUrls,
36 pub crypto_state: WrappedAccountCryptographicState,
38}
39
40impl Client {
41 pub async fn save_to_state(
46 data: SaveStateData,
47 reg: &StateRegistry,
48 ) -> Result<(), RehydrationError> {
49 reg.setting(BASE_URLS)
50 .map_err(|e| RehydrationError::State(e.into()))?
51 .update(data.urls)
52 .await
53 .map_err(RehydrationError::State)?;
54 reg.setting(USER_ID)
55 .map_err(|e| RehydrationError::State(e.into()))?
56 .update(data.user_id)
57 .await
58 .map_err(RehydrationError::State)?;
59 reg.setting(ACCOUNT_CRYPTO_STATE)
60 .map_err(|e| RehydrationError::State(e.into()))?
61 .update(data.crypto_state)
62 .await
63 .map_err(RehydrationError::State)?;
64 Ok(())
65 }
66
67 pub async fn load_from_state(
71 token_handler: Arc<dyn TokenHandler>,
72 registry: StateRegistry,
73 ) -> Result<Self, RehydrationError> {
74 let base_urls: BaseUrls = registry
75 .setting(BASE_URLS)
76 .map_err(|e| RehydrationError::State(e.into()))?
77 .get()
78 .await
79 .map_err(RehydrationError::State)?
80 .ok_or_else(|| RehydrationError::MissingState("BASE_URLS".to_string()))?;
81
82 let user_id: UserId = registry
83 .setting(USER_ID)
84 .map_err(|e| RehydrationError::State(e.into()))?
85 .get()
86 .await
87 .map_err(RehydrationError::State)?
88 .ok_or_else(|| RehydrationError::MissingState("USER_ID".to_string()))?;
89
90 let platform = get_host_platform_info();
91 let settings = crate::ClientSettings {
92 identity_url: base_urls.identity_url,
93 api_url: base_urls.api_url,
94 user_agent: platform.user_agent.clone(),
95 device_type: platform.device_type,
96 device_identifier: platform.device_identifier.clone(),
97 bitwarden_client_version: platform.bitwarden_client_version.clone(),
98 bitwarden_package_type: platform.bitwarden_package_type.clone(),
99 };
100
101 let client = ClientBuilder::new()
102 .with_settings(settings)
103 .with_token_handler(token_handler)
104 .with_state(registry)
105 .build();
106
107 client
108 .internal
109 .init_user_id(user_id)
110 .await
111 .expect("user ID cannot already be set on a freshly built client");
112
113 Ok(client)
114 }
115}
116
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Once};

    use bitwarden_crypto::{
        KeyStore, PublicKeyEncryptionAlgorithm, SignatureAlgorithm, SymmetricKeyAlgorithm,
    };
    use bitwarden_state::registry::StateRegistry;

    use super::*;
    use crate::{
        DeviceType, HostPlatformInfo, UserId,
        auth::auth_tokens::NoopTokenHandler,
        client::persisted_state::{ACCOUNT_CRYPTO_STATE, BASE_URLS, BaseUrls, USER_ID},
        key_management::{
            KeySlotIds, SecurityState,
            account_cryptographic_state::WrappedAccountCryptographicState,
        },
    };

    static INIT: Once = Once::new();

    /// Initialises the host platform info once for the tests in this module.
    fn ensure_platform_info() {
        INIT.call_once(|| {
            crate::init_host_platform_info(HostPlatformInfo {
                user_agent: "rehydration-tests".to_string(),
                device_type: DeviceType::SDK,
                device_identifier: None,
                bitwarden_client_version: None,
                bitwarden_package_type: None,
            });
        });
    }

    fn test_user_id() -> UserId {
        "d5b1fde2-a1e3-4c5b-9e0f-1a2b3c4d5e6f".parse().unwrap()
    }

    fn test_base_urls() -> BaseUrls {
        BaseUrls {
            identity_url: "https://identity.example.com".to_string(),
            api_url: "https://api.example.com".to_string(),
        }
    }

    /// Builds a V2 wrapped crypto state from freshly generated keys.
    fn test_crypto_state() -> WrappedAccountCryptographicState {
        let store: KeyStore<KeySlotIds> = KeyStore::default();
        let mut ctx = store.context_mut();
        let user_key = ctx.make_symmetric_key(SymmetricKeyAlgorithm::XChaCha20Poly1305);
        let private_key_id = ctx.make_private_key(PublicKeyEncryptionAlgorithm::RsaOaepSha1);
        let signing_key_id = ctx.make_signing_key(SignatureAlgorithm::Ed25519);
        let signed_public_key = ctx
            .make_signed_public_key(private_key_id, signing_key_id)
            .unwrap();
        let security_state = SecurityState::new();
        let signed_security_state = security_state.sign(signing_key_id, &mut ctx).unwrap();
        let wrapped_private = ctx.wrap_private_key(user_key, private_key_id).unwrap();
        let wrapped_signing = ctx.wrap_signing_key(user_key, signing_key_id).unwrap();
        WrappedAccountCryptographicState::V2 {
            private_key: wrapped_private,
            signed_public_key: Some(signed_public_key),
            signing_key: wrapped_signing,
            security_state: signed_security_state,
        }
    }

    fn test_save_data() -> SaveStateData {
        SaveStateData {
            user_id: test_user_id(),
            urls: test_base_urls(),
            crypto_state: test_crypto_state(),
        }
    }

    /// Asserts that `result` is a `MissingState` error whose message mentions `key`.
    fn assert_missing_state(result: Result<Client, RehydrationError>, key: &str) {
        match result {
            Err(RehydrationError::MissingState(s)) => {
                assert!(
                    s.contains(key),
                    "Error message should mention {key}, got: {s}"
                );
            }
            Err(e) => panic!("Expected MissingState error for {key}, got: {e:?}"),
            Ok(_) => panic!("Expected MissingState error for {key}, got Ok"),
        }
    }

    #[tokio::test]
    async fn save_to_state_writes_all_settings() {
        let reg = StateRegistry::new_with_memory_db();
        let data = test_save_data();
        let expected_user_id = data.user_id;
        let expected_urls_identity = data.urls.identity_url.clone();
        let expected_urls_api = data.urls.api_url.clone();

        Client::save_to_state(data, &reg).await.unwrap();

        let base_urls: BaseUrls = reg
            .setting(BASE_URLS)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("BASE_URLS should be present");
        assert_eq!(base_urls.identity_url, expected_urls_identity);
        assert_eq!(base_urls.api_url, expected_urls_api);

        let user_id: UserId = reg
            .setting(USER_ID)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("USER_ID should be present");
        assert_eq!(user_id, expected_user_id);

        let crypto_state: WrappedAccountCryptographicState = reg
            .setting(ACCOUNT_CRYPTO_STATE)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("ACCOUNT_CRYPTO_STATE should be present");
        assert!(
            matches!(crypto_state, WrappedAccountCryptographicState::V2 { .. }),
            "Expected V2 crypto state"
        );
    }

    #[tokio::test]
    async fn load_from_state_restores_user_id() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        let data = test_save_data();
        let expected_user_id = data.user_id;

        Client::save_to_state(data, &reg).await.unwrap();

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let client = Client::load_from_state(token_handler, reg).await.unwrap();

        assert_eq!(
            client.internal.get_user_id(),
            Some(expected_user_id),
            "Restored client should have the saved user ID"
        );
    }

    #[tokio::test]
    async fn load_from_state_missing_base_urls_returns_error() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let result = Client::load_from_state(token_handler, reg).await;

        assert_missing_state(result, "BASE_URLS");
    }

    #[tokio::test]
    async fn load_from_state_missing_user_id_returns_error() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        // Only BASE_URLS is stored, so the USER_ID lookup is the first to fail.
        reg.setting(BASE_URLS)
            .unwrap()
            .update(test_base_urls())
            .await
            .unwrap();

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let result = Client::load_from_state(token_handler, reg).await;

        assert_missing_state(result, "USER_ID");
    }
}