Skip to main content

bitwarden_core/client/
rehydration.rs

1use std::sync::Arc;
2
3use bitwarden_state::registry::StateRegistry;
4
5use super::Client;
6use crate::{
7    UserId,
8    auth::auth_tokens::TokenHandler,
9    client::{
10        ClientBuilder, get_host_platform_info,
11        persisted_state::{ACCOUNT_CRYPTO_STATE, BASE_URLS, BaseUrls, USER_EMAIL, USER_ID},
12    },
13    key_management::account_cryptographic_state::WrappedAccountCryptographicState,
14};
15
16/// Errors that can occur during client rehydration.
17#[derive(Debug, thiserror::Error)]
18pub enum RehydrationError {
19    /// A required value was not found in the state registry.
20    #[error("Required state value not found in registry: {0}")]
21    MissingState(String),
22    /// An error occurred accessing or updating a setting in the state registry.
23    #[error("State access error: {0}")]
24    State(#[from] bitwarden_state::SettingsError),
25}
26
27/// Data required to populate a [`StateRegistry`] via [`Client::save_to_state`].
28///
29/// Contains the values the auth flow does not yet persist automatically. Once the auth crate
30/// handles persistence directly, this type will be removed.
31pub struct SaveStateData {
32    /// The authenticated user's ID.
33    pub user_id: UserId,
34    /// The authenticated user's email.
35    pub email: String,
36    /// The base API URLs for the user's server.
37    pub urls: BaseUrls,
38    /// The user's wrapped account cryptographic state.
39    pub crypto_state: WrappedAccountCryptographicState,
40}
41
42impl Client {
43    /// Populates a [`StateRegistry`] with the state required for [`Client::load_from_state`].
44    ///
45    /// Call this after a successful login to persist the values that the auth flow does not yet
46    /// write automatically. Once the auth crate handles persistence directly, this will be removed.
47    pub async fn save_to_state(
48        data: SaveStateData,
49        reg: &StateRegistry,
50    ) -> Result<(), RehydrationError> {
51        reg.setting(BASE_URLS)
52            .map_err(|e| RehydrationError::State(e.into()))?
53            .update(data.urls)
54            .await
55            .map_err(RehydrationError::State)?;
56        reg.setting(USER_ID)
57            .map_err(|e| RehydrationError::State(e.into()))?
58            .update(data.user_id)
59            .await
60            .map_err(RehydrationError::State)?;
61        reg.setting(USER_EMAIL)
62            .map_err(|e| RehydrationError::State(e.into()))?
63            .update(data.email)
64            .await
65            .map_err(RehydrationError::State)?;
66        reg.setting(ACCOUNT_CRYPTO_STATE)
67            .map_err(|e| RehydrationError::State(e.into()))?
68            .update(data.crypto_state)
69            .await
70            .map_err(RehydrationError::State)?;
71        Ok(())
72    }
73
74    /// Reconstruct a locked Client from a populated StateRegistry.
75    ///
76    /// Does NOT unlock the vault.
77    pub async fn load_from_state(
78        token_handler: Arc<dyn TokenHandler>,
79        registry: StateRegistry,
80    ) -> Result<Self, RehydrationError> {
81        let base_urls: BaseUrls = registry
82            .setting(BASE_URLS)
83            .map_err(|e| RehydrationError::State(e.into()))?
84            .get()
85            .await
86            .map_err(RehydrationError::State)?
87            .ok_or_else(|| RehydrationError::MissingState("BASE_URLS".to_string()))?;
88
89        let user_id: UserId = registry
90            .setting(USER_ID)
91            .map_err(|e| RehydrationError::State(e.into()))?
92            .get()
93            .await
94            .map_err(RehydrationError::State)?
95            .ok_or_else(|| RehydrationError::MissingState("USER_ID".to_string()))?;
96
97        let platform = get_host_platform_info();
98        let settings = crate::ClientSettings {
99            identity_url: base_urls.identity_url,
100            api_url: base_urls.api_url,
101            user_agent: platform.user_agent.clone(),
102            device_type: platform.device_type,
103            device_identifier: platform.device_identifier.clone(),
104            bitwarden_client_version: platform.bitwarden_client_version.clone(),
105            bitwarden_package_type: platform.bitwarden_package_type.clone(),
106        };
107
108        let client = ClientBuilder::new()
109            .with_settings(settings)
110            .with_token_handler(token_handler)
111            .with_state(registry)
112            .build();
113
114        client
115            .internal
116            .init_user_id(user_id)
117            .await
118            .expect("user ID cannot already be set on a freshly built client");
119
120        // Refresh feature flags from /config if the cached set is missing or stale (TTL 1h).
121        // Failure is non-fatal — an offline client continues with the previously persisted flags.
122        if let Err(e) = client.flags().fetch(false).await {
123            tracing::warn!(
124                "Failed to refresh feature flags on startup; using previously stored flags: {e}"
125            );
126        }
127
128        Ok(client)
129    }
130}
131
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Once};

    use bitwarden_crypto::{
        KeyStore, PublicKeyEncryptionAlgorithm, SignatureAlgorithm, SymmetricKeyAlgorithm,
    };
    use bitwarden_state::registry::StateRegistry;

    use super::*;
    use crate::{
        DeviceType, HostPlatformInfo, UserId,
        auth::auth_tokens::NoopTokenHandler,
        client::persisted_state::{ACCOUNT_CRYPTO_STATE, BASE_URLS, BaseUrls, USER_EMAIL, USER_ID},
        key_management::{
            KeySlotIds, SecurityState,
            account_cryptographic_state::WrappedAccountCryptographicState,
        },
    };

    /// Guards the one-time initialisation performed by [`ensure_platform_info`].
    static INIT: Once = Once::new();

    /// Initialises the host platform info exactly once per test binary.
    ///
    /// `load_from_state` reads it via `get_host_platform_info`, so every test that calls
    /// `load_from_state` runs this first.
    fn ensure_platform_info() {
        INIT.call_once(|| {
            crate::init_host_platform_info(HostPlatformInfo {
                user_agent: "rehydration-tests".to_string(),
                device_type: DeviceType::SDK,
                device_identifier: None,
                bitwarden_client_version: None,
                bitwarden_package_type: None,
            });
        });
    }

    /// Fixed user ID used across the tests.
    fn test_user_id() -> UserId {
        "d5b1fde2-a1e3-4c5b-9e0f-1a2b3c4d5e6f".parse().unwrap()
    }

    /// Base URLs pointing at a loopback port with nothing listening.
    fn test_base_urls() -> BaseUrls {
        // Port 1 on loopback has no server, so the feature-flag fetch inside `load_from_state`
        // fails quickly instead of reaching a real server.
        BaseUrls {
            identity_url: "http://127.0.0.1:1".to_string(),
            api_url: "http://127.0.0.1:1".to_string(),
        }
    }

    /// Builds a freshly generated V2 account cryptographic state.
    ///
    /// The private and signing keys are wrapped with a new user key, the public key is signed,
    /// and a signed security state is attached.
    fn test_crypto_state() -> WrappedAccountCryptographicState {
        let store: KeyStore<KeySlotIds> = KeyStore::default();
        let mut ctx = store.context_mut();
        let user_key = ctx.make_symmetric_key(SymmetricKeyAlgorithm::XChaCha20Poly1305);
        let private_key_id = ctx.make_private_key(PublicKeyEncryptionAlgorithm::RsaOaepSha1);
        let signing_key_id = ctx.make_signing_key(SignatureAlgorithm::Ed25519);
        let signed_public_key = ctx
            .make_signed_public_key(private_key_id, signing_key_id)
            .unwrap();
        let security_state = SecurityState::new();
        let signed_security_state = security_state.sign(signing_key_id, &mut ctx).unwrap();
        let wrapped_private = ctx.wrap_private_key(user_key, private_key_id).unwrap();
        let wrapped_signing = ctx.wrap_signing_key(user_key, signing_key_id).unwrap();
        WrappedAccountCryptographicState::V2 {
            private_key: wrapped_private,
            signed_public_key: Some(signed_public_key),
            signing_key: wrapped_signing,
            security_state: signed_security_state,
        }
    }

    /// Fixture email address.
    ///
    /// `example.com` is reserved for documentation/testing (RFC 2606), so this address can
    /// never belong to a real user.
    fn test_email() -> String {
        "test@example.com".to_string()
    }

    /// A complete [`SaveStateData`] built from the fixtures above.
    fn test_save_data() -> SaveStateData {
        SaveStateData {
            user_id: test_user_id(),
            email: test_email(),
            urls: test_base_urls(),
            crypto_state: test_crypto_state(),
        }
    }

    #[tokio::test]
    async fn save_to_state_writes_all_settings() {
        let reg = StateRegistry::new_with_memory_db();
        let data = test_save_data();
        let expected_user_id = data.user_id;
        let expected_email = data.email.clone();
        let expected_urls_identity = data.urls.identity_url.clone();
        let expected_urls_api = data.urls.api_url.clone();

        Client::save_to_state(data, &reg).await.unwrap();

        // Read back each setting directly from the registry.
        let base_urls: BaseUrls = reg
            .setting(BASE_URLS)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("BASE_URLS should be present");
        assert_eq!(base_urls.identity_url, expected_urls_identity);
        assert_eq!(base_urls.api_url, expected_urls_api);

        let user_id: UserId = reg
            .setting(USER_ID)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("USER_ID should be present");
        assert_eq!(user_id, expected_user_id);

        let email: String = reg
            .setting(USER_EMAIL)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("USER_EMAIL should be present");
        assert_eq!(email, expected_email);

        let crypto_state: WrappedAccountCryptographicState = reg
            .setting(ACCOUNT_CRYPTO_STATE)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("ACCOUNT_CRYPTO_STATE should be present");
        assert!(
            matches!(crypto_state, WrappedAccountCryptographicState::V2 { .. }),
            "Expected V2 crypto state"
        );
    }

    #[tokio::test]
    async fn load_from_state_restores_user_id() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        let data = test_save_data();
        let expected_user_id = data.user_id;

        Client::save_to_state(data, &reg).await.unwrap();

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let client = Client::load_from_state(token_handler, reg).await.unwrap();

        assert_eq!(
            client.internal.get_user_id(),
            Some(expected_user_id),
            "Restored client should have the saved user ID"
        );
    }

    #[tokio::test]
    async fn load_from_state_missing_base_urls_returns_error() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        // Registry is empty: no settings have been written.

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let result = Client::load_from_state(token_handler, reg).await;

        match result {
            Err(RehydrationError::MissingState(s)) => {
                assert!(
                    s.contains("BASE_URLS"),
                    "Error message should mention BASE_URLS, got: {s}"
                );
            }
            Err(e) => panic!("Expected MissingState error for BASE_URLS, got: {e:?}"),
            Ok(_) => panic!("Expected MissingState error for BASE_URLS, got Ok"),
        }
    }

    #[tokio::test]
    async fn load_from_state_missing_user_id_returns_error() {
        ensure_platform_info();

        let reg = StateRegistry::new_with_memory_db();
        // Write only BASE_URLS, omit USER_ID.
        reg.setting(BASE_URLS)
            .unwrap()
            .update(test_base_urls())
            .await
            .unwrap();

        let token_handler: Arc<dyn TokenHandler> = Arc::new(NoopTokenHandler);
        let result = Client::load_from_state(token_handler, reg).await;

        match result {
            Err(RehydrationError::MissingState(s)) => {
                assert!(
                    s.contains("USER_ID"),
                    "Error message should mention USER_ID, got: {s}"
                );
            }
            Err(e) => panic!("Expected MissingState error for USER_ID, got: {e:?}"),
            Ok(_) => panic!("Expected MissingState error for USER_ID, got Ok"),
        }
    }
}