Skip to main content

bitwarden_auth/token_management/
password_manager_token_handler.rs

1//! Token handler implementation for Bitwarden Password Manager authentication.
2
3use std::sync::{Arc, RwLock};
4
5use bitwarden_core::{
6    NotAuthenticatedError,
7    auth::{TokenHandler, login::LoginError},
8    client::{
9        login_method::UserLoginMethod,
10        persisted_state::{AUTHENTICATION_TOKENS, AuthenticationTokens, USER_LOGIN_METHOD},
11    },
12    key_management::KeySlotIds,
13};
14use bitwarden_crypto::KeyStore;
15use bitwarden_state::{registry::StateRegistry, settings::Setting};
16use chrono::Utc;
17
18use super::middleware::{MiddlewareExt, MiddlewareWrapper};
19
20/// Token handler for Bitwarden authentication.
21#[derive(Clone, Default)]
22pub struct PasswordManagerTokenHandler {
23    inner: Arc<RwLock<PasswordManagerTokenHandlerInner>>,
24}
25
26#[derive(Clone, Default)]
27struct PasswordManagerTokenHandlerInner {
28    // Filled in by initialize_middleware.
29    tokens: Option<Setting<AuthenticationTokens>>,
30    login_method: Option<Setting<UserLoginMethod>>,
31    identity_config: Option<bitwarden_api_api::Configuration>,
32}
33
34#[async_trait::async_trait]
35impl TokenHandler for PasswordManagerTokenHandler {
36    fn initialize_middleware(
37        &self,
38        state_registry: &StateRegistry,
39        identity_config: bitwarden_api_api::Configuration,
40        _key_store: KeyStore<KeySlotIds>,
41    ) -> Arc<dyn reqwest_middleware::Middleware> {
42        {
43            let mut inner = self.inner.write().expect("RwLock is not poisoned");
44            inner.tokens = state_registry.setting(AUTHENTICATION_TOKENS).ok();
45            inner.login_method = state_registry.setting(USER_LOGIN_METHOD).ok();
46            inner.identity_config = Some(identity_config);
47        }
48        Arc::new(MiddlewareWrapper::new(self.clone()))
49    }
50
51    async fn set_tokens(
52        &self,
53        access_token: String,
54        refresh_token: Option<String>,
55        expires_in: u64,
56    ) {
57        let tokens = self
58            .inner
59            .read()
60            .expect("RwLock is not poisoned")
61            .tokens
62            .clone();
63
64        if let Some(tokens) = tokens {
65            tokens
66                .update(AuthenticationTokens {
67                    access_token,
68                    refresh_token,
69                    expires_on: Utc::now().timestamp() + expires_in as i64,
70                })
71                .await
72                .ok();
73        }
74    }
75}
76
77#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
78#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
79impl MiddlewareExt for PasswordManagerTokenHandler {
80    async fn current_token(&self) -> Option<(String, i64)> {
81        let inner = self.inner.read().expect("RwLock is not poisoned").clone();
82        let tokens = inner.tokens?.get().await.ok().flatten()?;
83        Some((tokens.access_token, tokens.expires_on))
84    }
85
86    async fn renew_token(&mut self) -> Result<Option<String>, LoginError> {
87        let inner = self.inner.read().expect("RwLock is not poisoned").clone();
88
89        let tokens = inner
90            .tokens
91            .ok_or(NotAuthenticatedError)?
92            .get()
93            .await
94            .ok()
95            .flatten()
96            .ok_or(NotAuthenticatedError)?;
97
98        let login_method = inner.login_method.ok_or(NotAuthenticatedError)?;
99        let identity_config = inner.identity_config.ok_or(NotAuthenticatedError)?;
100
101        let login_method = login_method
102            .get()
103            .await
104            .ok()
105            .flatten()
106            .ok_or(NotAuthenticatedError)?;
107
108        let (access_token, refresh_token, expires_in) =
109            bitwarden_core::auth::renew::renew_pm_token_sdk_managed(
110                tokens.refresh_token,
111                &login_method,
112                identity_config,
113            )
114            .await?;
115
116        self.set_tokens(access_token.clone(), refresh_token, expires_in)
117            .await;
118        Ok(Some(access_token))
119    }
120}
121
#[cfg(test)]
mod tests {
    // These tests drive `PasswordManagerTokenHandler` through a client built by `build_client`
    // against wiremock `MockServer`s, one standing in for the app API and one for the identity
    // server. The `start_*` / `send_*` helpers come from `token_management::test_utils`; their
    // exact mock behavior is inferred from their names and should be confirmed there.
    use bitwarden_api_api::apis::AuthRequired;
    use bitwarden_core::client::{
        login_method::UserLoginMethod,
        persisted_state::{AuthenticationTokens, USER_LOGIN_METHOD},
    };
    use bitwarden_crypto::Kdf;
    use bitwarden_state::registry::StateRegistry;
    use wiremock::MockServer;

    use super::*;
    use crate::token_management::test_utils::*;

    /// Builds an in-memory `StateRegistry` whose `USER_LOGIN_METHOD` setting holds an API-key
    /// login, so renewal has a login method available.
    async fn registry_with_api_key_login() -> StateRegistry {
        let registry = StateRegistry::new_with_memory_db();
        registry
            .setting(USER_LOGIN_METHOD)
            .unwrap()
            .update(UserLoginMethod::ApiKey {
                client_id: "test-client".to_string(),
                client_secret: "test-secret".to_string(),
                email: "[email protected]".to_string(),
                kdf: Kdf::default_pbkdf2(),
            })
            .await
            .unwrap();
        registry
    }

    /// Stores an `AuthenticationTokens` entry in `registry` that expires `expires_in` seconds
    /// from now (`0` makes it already due for renewal).
    async fn seed_tokens(
        registry: &StateRegistry,
        access_token: &str,
        refresh_token: Option<&str>,
        expires_in: i64,
    ) {
        registry
            .setting(AUTHENTICATION_TOKENS)
            .unwrap()
            .update(AuthenticationTokens {
                access_token: access_token.to_string(),
                refresh_token: refresh_token.map(str::to_string),
                expires_on: Utc::now().timestamp() + expires_in,
            })
            .await
            .unwrap();
    }

    /// A token valid for another 5000 s is attached unchanged, and the identity server is
    /// never contacted.
    #[tokio::test]
    async fn attaches_existing_token_when_not_expired() {
        let app_server = start_app_server().await;
        let identity_server = MockServer::start().await;

        let registry = registry_with_api_key_login().await;
        seed_tokens(&registry, "original-token", Some("refresh"), 5000).await;

        let handler = PasswordManagerTokenHandler::default();
        let client = build_client(&handler, &registry, &identity_server);

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer original-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    /// An expired token triggers exactly one renewal before the request is sent, and the
    /// renewed token is the one attached.
    #[tokio::test]
    async fn renews_expired_token() {
        let app_server = start_app_server().await;
        let identity_server = start_renewal_server("renewed-token").await;

        let registry = registry_with_api_key_login().await;
        // expires_in=0 puts the token inside the renewal margin.
        seed_tokens(&registry, "expired-token", Some("old-refresh"), 0).await;

        let handler = PasswordManagerTokenHandler::default();
        let client = build_client(&handler, &registry, &identity_server);

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer renewed-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    /// When the app server rejects a locally-valid token (a 401, per the test name), the
    /// request is retried once. The retry carries a token from a single renewal call.
    #[tokio::test]
    async fn retries_with_renewed_token_on_401() {
        let app_server = start_app_server_rejecting("stale-token").await;
        let identity_server = start_renewal_server("renewed-token").await;

        let registry = registry_with_api_key_login().await;
        // Locally-valid token forces renewal through the 401 retry path.
        seed_tokens(&registry, "stale-token", Some("refresh"), 5000).await;

        let handler = PasswordManagerTokenHandler::default();
        let client = build_client(&handler, &registry, &identity_server);

        let response = client
            .get(format!("{}/test", app_server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        let requests = app_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(
            requests[0].headers.get("Authorization").unwrap(),
            "Bearer stale-token"
        );
        assert_eq!(
            requests[1].headers.get("Authorization").unwrap(),
            "Bearer renewed-token"
        );
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
    }

    /// If the initial renewal fails, the first request goes out without an `Authorization`
    /// header. The retry then renews again and succeeds, so the identity server sees two calls.
    #[tokio::test]
    async fn refreshes_on_retry_when_initial_token_unavailable() {
        // First identity call fails, so the initial request goes out unauthenticated and the
        // forced renewal on retry produces a valid token.
        let app_server = start_app_server_accepting("renewed-token").await;
        let identity_server = start_renewal_server_failing_then_succeeding("renewed-token").await;

        let registry = registry_with_api_key_login().await;
        seed_tokens(&registry, "stale-token", Some("refresh"), 0).await;

        let handler = PasswordManagerTokenHandler::default();
        let client = build_client(&handler, &registry, &identity_server);

        let response = client
            .get(format!("{}/test", app_server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        let requests = app_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 2);
        assert!(requests[0].headers.get("Authorization").is_none());
        assert_eq!(
            requests[1].headers.get("Authorization").unwrap(),
            "Bearer renewed-token"
        );
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 2);
    }

    /// Five concurrent requests that are all rejected result in only one identity-server call.
    #[tokio::test]
    async fn concurrent_401s_trigger_a_single_renewal() {
        // Locally-valid tokens, so renewal only happens via the 401 retry path. Coalescing should
        // collapse the five retries into a single identity-server call.
        let app_server = start_app_server_rejecting("stale-token").await;
        let identity_server =
            start_renewal_server_with_delay("renewed-token", std::time::Duration::from_millis(100))
                .await;

        let registry = registry_with_api_key_login().await;
        seed_tokens(&registry, "stale-token", Some("refresh"), 5000).await;

        let handler = PasswordManagerTokenHandler::default();
        let client = build_client(&handler, &registry, &identity_server);

        send_concurrent_auth_requests(&client, &app_server, 5).await;

        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
    }

    /// Five concurrent requests made with an expired token share a single renewal, and every
    /// request carries the renewed token.
    #[tokio::test]
    async fn concurrent_requests_trigger_a_single_renewal() {
        let app_server = start_app_server().await;
        // Renewal delay so that concurrent renewals would overlap if not serialized.
        let identity_server =
            start_renewal_server_with_delay("renewed-token", std::time::Duration::from_millis(100))
                .await;

        let registry = registry_with_api_key_login().await;
        seed_tokens(&registry, "expired-token", Some("refresh"), 0).await;

        let handler = PasswordManagerTokenHandler::default();
        let client = build_client(&handler, &registry, &identity_server);

        send_concurrent_auth_requests(&client, &app_server, 5).await;

        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        let app_requests = app_server.received_requests().await.unwrap();
        assert_eq!(app_requests.len(), 5);
        for req in app_requests {
            assert_eq!(
                req.headers.get("Authorization").unwrap(),
                "Bearer renewed-token"
            );
        }
    }
}