Skip to main content

bitwarden_auth/token_management/
password_manager_token_handler.rs

1//! Token handler implementation for Bitwarden Password Manager authentication.
2
3use std::sync::{Arc, RwLock};
4
5use bitwarden_core::{
6    NotAuthenticatedError,
7    auth::{TokenHandler, login::LoginError},
8    client::login_method::LoginMethod,
9    key_management::KeySlotIds,
10};
11use bitwarden_crypto::KeyStore;
12use chrono::Utc;
13
14use super::middleware::{MiddlewareExt, MiddlewareWrapper};
15use crate::token_management::middleware::TOKEN_RENEW_MARGIN_SECONDS;
16
17/// Token handler for Bitwarden authentication.
18#[derive(Clone, Default)]
19pub struct PasswordManagerTokenHandler {
20    inner: Arc<RwLock<PasswordManagerTokenHandlerInner>>,
21}
22
23#[derive(Clone, Default)]
24struct PasswordManagerTokenHandlerInner {
25    access_token: Option<String>,
26    expires_on: Option<i64>,
27
28    refresh_token: Option<String>,
29
30    // The following are passed as optional as they are filled in when instantiating the
31    // middleware.
32    login_method: Option<Arc<RwLock<Option<Arc<LoginMethod>>>>>,
33    identity_config: Option<bitwarden_api_api::Configuration>,
34}
35
36#[async_trait::async_trait]
37impl TokenHandler for PasswordManagerTokenHandler {
38    fn initialize_middleware(
39        &self,
40        login_method: Arc<RwLock<Option<Arc<LoginMethod>>>>,
41        identity_config: bitwarden_api_api::Configuration,
42        _key_store: KeyStore<KeySlotIds>,
43    ) -> Arc<dyn reqwest_middleware::Middleware> {
44        {
45            let mut inner = self.inner.write().expect("RwLock is not poisoned");
46            inner.login_method = Some(login_method);
47            inner.identity_config = Some(identity_config);
48        }
49        Arc::new(MiddlewareWrapper(self.clone()))
50    }
51
52    async fn set_tokens(
53        &self,
54        access_token: String,
55        refresh_token: Option<String>,
56        expires_in: u64,
57    ) {
58        let mut inner = self.inner.write().expect("RwLock is not poisoned");
59        inner.access_token = Some(access_token);
60        inner.refresh_token = refresh_token;
61        inner.expires_on = Some(Utc::now().timestamp() + expires_in as i64);
62    }
63}
64
65#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
66#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
67impl MiddlewareExt for PasswordManagerTokenHandler {
68    async fn get_token(&self) -> Result<Option<String>, LoginError> {
69        // We're not holding on to a lock for the duration of the token renewal, so if multiple
70        // requests come in at the same time when the token is expired, we may end up renewing the
71        // token multiple times. This is not ideal, but it's the behavior of the previous
72        // implementation. We should be able to introduce an async semaphore or something
73        // similar to prevent this if it becomes an issue in practice.
74        let inner = self.inner.read().expect("RwLock is not poisoned").clone();
75
76        // Validate the token, returning early if it's still valid.
77        if let Some(expires) = inner.expires_on
78            && Utc::now().timestamp() < expires - TOKEN_RENEW_MARGIN_SECONDS
79        {
80            return Ok(inner.access_token.clone());
81        }
82
83        // These should always be set by initialize_middleware before we get here, but we return an
84        // error if not.
85        let login_method = inner.login_method.ok_or(NotAuthenticatedError)?;
86        let identity_config = inner.identity_config.ok_or(NotAuthenticatedError)?;
87
88        let login_method = login_method
89            .read()
90            .expect("RwLock is not poisoned")
91            .clone()
92            .ok_or(NotAuthenticatedError)?;
93
94        #[allow(irrefutable_let_patterns)]
95        let LoginMethod::User(user_login_method) = login_method.as_ref() else {
96            return Err(NotAuthenticatedError.into());
97        };
98
99        let (access_token, refresh_token, expires_in) =
100            bitwarden_core::auth::renew::renew_pm_token_sdk_managed(
101                inner.refresh_token,
102                user_login_method,
103                identity_config,
104            )
105            .await?;
106
107        self.set_tokens(access_token.clone(), refresh_token, expires_in)
108            .await;
109        Ok(Some(access_token))
110    }
111}
112
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use bitwarden_core::{
        auth::TokenHandler,
        client::login_method::{LoginMethod, UserLoginMethod},
        key_management::KeySlotIds,
    };
    use bitwarden_crypto::{Kdf, KeyStore};
    use wiremock::MockServer;

    use super::*;
    use crate::token_management::test_utils::*;

    /// Builds a shared login-method slot holding an API-key user login with fixed test
    /// credentials.
    fn api_key_login_method() -> Arc<RwLock<Option<Arc<LoginMethod>>>> {
        let method = LoginMethod::User(UserLoginMethod::ApiKey {
            client_id: String::from("test-client"),
            client_secret: String::from("test-secret"),
            email: String::from("[email protected]"),
            kdf: Kdf::default_pbkdf2(),
        });
        Arc::new(RwLock::new(Some(Arc::new(method))))
    }

    /// A token well outside the renewal margin is attached as-is, and the identity server is
    /// never contacted.
    #[tokio::test]
    async fn attaches_existing_token_when_not_expired() {
        let app_server = start_app_server().await;
        let identity_server = MockServer::start().await;

        let handler = PasswordManagerTokenHandler::default();
        handler
            .set_tokens(
                String::from("original-token"),
                Some(String::from("refresh")),
                5000,
            )
            .await;
        let middleware = handler.initialize_middleware(
            api_key_login_method(),
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        );
        let client = build_client(middleware);

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer original-token"));

        let identity_hits = identity_server.received_requests().await.unwrap().len();
        let app_hits = app_server.received_requests().await.unwrap().len();
        assert_eq!(identity_hits, 0);
        assert_eq!(app_hits, 1);
    }

    /// A token that is already inside the renewal margin triggers exactly one renewal, and
    /// the renewed token is attached to the request.
    #[tokio::test]
    async fn renews_expired_token() {
        let app_server = start_app_server().await;
        let identity_server = start_renewal_server("renewed-token").await;

        let handler = PasswordManagerTokenHandler::default();
        // A lifetime of 0 seconds is below the renewal margin, so the token counts as expired.
        handler
            .set_tokens(
                String::from("expired-token"),
                Some(String::from("old-refresh")),
                0,
            )
            .await;
        let middleware = handler.initialize_middleware(
            api_key_login_method(),
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        );
        let client = build_client(middleware);

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer renewed-token"));

        let identity_hits = identity_server.received_requests().await.unwrap().len();
        let app_hits = app_server.received_requests().await.unwrap().len();
        assert_eq!(identity_hits, 1);
        assert_eq!(app_hits, 1);
    }
}