Skip to main content

bitwarden_auth/token_management/
secrets_manager_token_handler.rs

1//! Token handler implementation for Bitwarden Secrets Manager authentication.
2
3use std::sync::{Arc, RwLock};
4
5use bitwarden_core::{
6    NotAuthenticatedError, OrganizationId,
7    auth::{TokenHandler, login::LoginError},
8    client::login_method::ServiceAccountLoginMethod,
9    key_management::KeySlotIds,
10};
11use bitwarden_crypto::KeyStore;
12use bitwarden_state::registry::StateRegistry;
13use chrono::Utc;
14
15use super::middleware::{MiddlewareExt, MiddlewareWrapper};
16
17/// Token handler for Bitwarden authentication.
18#[derive(Clone, Default)]
19pub struct SecretsManagerTokenHandler {
20    inner: Arc<RwLock<SecretsManagerTokenHandlerInner>>,
21}
22
23#[derive(Clone, Default)]
24struct SecretsManagerTokenHandlerInner {
25    access_token: Option<String>,
26    expires_on: Option<i64>,
27
28    // Filled in by initialize_middleware / set_sm_login_method.
29    login_method: Option<Arc<ServiceAccountLoginMethod>>,
30    identity_config: Option<bitwarden_api_api::Configuration>,
31    key_store: Option<KeyStore<KeySlotIds>>,
32}
33
34#[async_trait::async_trait]
35impl TokenHandler for SecretsManagerTokenHandler {
36    fn initialize_middleware(
37        &self,
38        _state_registry: &StateRegistry,
39        identity_config: bitwarden_api_api::Configuration,
40        key_store: KeyStore<KeySlotIds>,
41    ) -> Arc<dyn reqwest_middleware::Middleware> {
42        {
43            let mut inner = self.inner.write().expect("RwLock is not poisoned");
44            inner.identity_config = Some(identity_config);
45            inner.key_store = Some(key_store);
46        }
47        Arc::new(MiddlewareWrapper::new(self.clone()))
48    }
49
50    async fn set_tokens(
51        &self,
52        access_token: String,
53        _refresh_token: Option<String>,
54        expires_in: u64,
55    ) {
56        let mut inner = self.inner.write().expect("RwLock is not poisoned");
57        inner.access_token = Some(access_token);
58        inner.expires_on = Some(Utc::now().timestamp() + expires_in as i64);
59    }
60
61    async fn set_sm_login_method(&self, login_method: ServiceAccountLoginMethod) {
62        let mut inner = self.inner.write().expect("RwLock is not poisoned");
63        inner.login_method = Some(Arc::new(login_method));
64    }
65}
66
67impl SecretsManagerTokenHandler {
68    /// Get the organization ID associated with the current access token, if available.
69    pub fn get_access_token_organization(&self) -> Option<OrganizationId> {
70        let inner = self.inner.read().ok()?;
71        match inner.login_method.as_deref()? {
72            ServiceAccountLoginMethod::AccessToken {
73                organization_id, ..
74            } => Some(*organization_id),
75        }
76    }
77}
78
79#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
80#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
81impl MiddlewareExt for SecretsManagerTokenHandler {
82    async fn current_token(&self) -> Option<(String, i64)> {
83        let inner = self.inner.read().expect("RwLock is not poisoned").clone();
84        Some((inner.access_token?, inner.expires_on?))
85    }
86
87    async fn renew_token(&mut self) -> Result<Option<String>, LoginError> {
88        let inner = self.inner.read().expect("RwLock is not poisoned").clone();
89
90        let login_method = inner.login_method.ok_or(NotAuthenticatedError)?;
91        let identity_config = inner.identity_config.ok_or(NotAuthenticatedError)?;
92        let key_store = inner.key_store.ok_or(NotAuthenticatedError)?;
93
94        let (access_token, refresh_token, expires_in) =
95            bitwarden_core::auth::renew::renew_sm_token_sdk_managed(
96                login_method.as_ref(),
97                identity_config,
98                key_store,
99            )
100            .await?;
101
102        self.set_tokens(access_token.clone(), refresh_token, expires_in)
103            .await;
104        Ok(Some(access_token))
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::{str::FromStr, time::Duration};

    use bitwarden_api_api::apis::AuthRequired;
    use bitwarden_core::{auth::AccessToken, client::login_method::ServiceAccountLoginMethod};
    use bitwarden_state::registry::StateRegistry;
    use wiremock::MockServer;

    use super::*;
    use crate::token_management::test_utils::*;

    /// Service-account access token shared by every test's login method.
    const SERVICE_ACCOUNT_ACCESS_TOKEN: &str = "0.ec2c1d46-6a4b-4751-a310-af9601317f2d.C2IgxjjLF7qSshsbwe8JGcbM075YXw:X8vbvA0bduihIDe/qrzIQQ==";
    /// Organization the test access token is associated with.
    const SERVICE_ACCOUNT_ORGANIZATION_ID: &str = "00000000-0000-0000-0000-000000000001";
    /// Artificial identity-server latency so that concurrent renewals would overlap if they
    /// were not coalesced.
    const RENEWAL_DELAY: Duration = Duration::from_millis(100);

    fn service_account_login_method() -> ServiceAccountLoginMethod {
        ServiceAccountLoginMethod::AccessToken {
            access_token: AccessToken::from_str(SERVICE_ACCOUNT_ACCESS_TOKEN).unwrap(),
            organization_id: SERVICE_ACCOUNT_ORGANIZATION_ID.parse().unwrap(),
            state_file: None,
        }
    }

    /// Builds a handler configured with the test login method and, when given, a pre-stored
    /// `(token, expires_in)` pair.
    async fn handler_with_token(
        initial_token: Option<(&str, u64)>,
    ) -> SecretsManagerTokenHandler {
        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        if let Some((token, expires_in)) = initial_token {
            handler.set_tokens(token.to_owned(), None, expires_in).await;
        }
        handler
    }

    #[tokio::test]
    async fn attaches_existing_token_when_not_expired() {
        let app_server = start_app_server().await;
        let identity_server = MockServer::start().await;
        let handler = handler_with_token(Some(("original-token", 3600))).await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer original-token"));
        // The stored token is still valid, so the identity server is never contacted.
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn renews_expired_token() {
        let app_server = start_app_server().await;
        let identity_server = start_renewal_server("renewed-token").await;
        // expires_in=0 puts the token inside the renewal margin.
        let handler = handler_with_token(Some(("expired-token", 0))).await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer renewed-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn retries_with_renewed_token_on_401() {
        let app_server = start_app_server_rejecting("stale-token").await;
        let identity_server = start_renewal_server("renewed-token").await;
        // Locally-valid token forces renewal through the 401 retry path.
        let handler = handler_with_token(Some(("stale-token", 3600))).await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        let response = client
            .get(format!("{}/test", app_server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        // First attempt carries the stale token; the retry carries the renewed one.
        let requests = app_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(
            requests[0].headers.get("Authorization").unwrap(),
            "Bearer stale-token"
        );
        assert_eq!(
            requests[1].headers.get("Authorization").unwrap(),
            "Bearer renewed-token"
        );
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn refreshes_on_retry_when_initial_token_unavailable() {
        // First identity call fails, so the initial request goes out unauthenticated and the
        // forced renewal on retry produces a valid token.
        let app_server = start_app_server_accepting("renewed-token").await;
        let identity_server = start_renewal_server_failing_then_succeeding("renewed-token").await;
        let handler = handler_with_token(None).await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        let response = client
            .get(format!("{}/test", app_server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        let requests = app_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 2);
        assert!(requests[0].headers.get("Authorization").is_none());
        assert_eq!(
            requests[1].headers.get("Authorization").unwrap(),
            "Bearer renewed-token"
        );
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn concurrent_401s_trigger_a_single_renewal() {
        // Locally-valid tokens, so renewal only happens via the 401 retry path. Coalescing should
        // collapse the five retries into a single identity-server call.
        let app_server = start_app_server_rejecting("stale-token").await;
        let identity_server =
            start_renewal_server_with_delay("renewed-token", RENEWAL_DELAY).await;
        let handler = handler_with_token(Some(("stale-token", 3600))).await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        send_concurrent_auth_requests(&client, &app_server, 5).await;

        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn concurrent_requests_trigger_a_single_renewal() {
        let app_server = start_app_server().await;
        let identity_server =
            start_renewal_server_with_delay("renewed-token", RENEWAL_DELAY).await;
        let handler = handler_with_token(Some(("expired-token", 0))).await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        send_concurrent_auth_requests(&client, &app_server, 5).await;

        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        // Every request, not just the one that triggered renewal, uses the renewed token.
        let app_requests = app_server.received_requests().await.unwrap();
        assert_eq!(app_requests.len(), 5);
        for request in &app_requests {
            assert_eq!(
                request.headers.get("Authorization").unwrap(),
                "Bearer renewed-token"
            );
        }
    }
}