Skip to main content

bitwarden_auth/token_management/
secrets_manager_token_handler.rs

1//! Token handler implementation for Bitwarden Secrets Manager authentication.
2
3use std::sync::{Arc, RwLock};
4
5use bitwarden_core::{
6    NotAuthenticatedError, OrganizationId,
7    auth::{TokenHandler, login::LoginError},
8    client::login_method::ServiceAccountLoginMethod,
9    key_management::KeySlotIds,
10};
11use bitwarden_crypto::KeyStore;
12use bitwarden_state::registry::StateRegistry;
13use chrono::Utc;
14
15use super::middleware::{MiddlewareExt, MiddlewareWrapper};
16use crate::token_management::middleware::TOKEN_RENEW_MARGIN_SECONDS;
17
18/// Token handler for Bitwarden authentication.
19#[derive(Clone, Default)]
20pub struct SecretsManagerTokenHandler {
21    inner: Arc<RwLock<SecretsManagerTokenHandlerInner>>,
22}
23
24#[derive(Clone, Default)]
25struct SecretsManagerTokenHandlerInner {
26    access_token: Option<String>,
27    expires_on: Option<i64>,
28
29    // The following are passed as optional as they are filled in when instantiating the
30    // middleware.
31    login_method: Option<Arc<ServiceAccountLoginMethod>>,
32    identity_config: Option<bitwarden_api_api::Configuration>,
33    key_store: Option<KeyStore<KeySlotIds>>,
34}
35
36#[async_trait::async_trait]
37impl TokenHandler for SecretsManagerTokenHandler {
38    fn initialize_middleware(
39        &self,
40        _state_registry: &StateRegistry,
41        identity_config: bitwarden_api_api::Configuration,
42        key_store: KeyStore<KeySlotIds>,
43    ) -> Arc<dyn reqwest_middleware::Middleware> {
44        {
45            let mut inner = self.inner.write().expect("RwLock is not poisoned");
46            inner.identity_config = Some(identity_config);
47            inner.key_store = Some(key_store);
48        }
49        Arc::new(MiddlewareWrapper(self.clone()))
50    }
51
52    async fn set_tokens(
53        &self,
54        access_token: String,
55        _refresh_token: Option<String>,
56        expires_in: u64,
57    ) {
58        let mut inner = self.inner.write().expect("RwLock is not poisoned");
59        inner.access_token = Some(access_token);
60        inner.expires_on = Some(Utc::now().timestamp() + expires_in as i64);
61    }
62
63    async fn set_sm_login_method(&self, login_method: ServiceAccountLoginMethod) {
64        let mut inner = self.inner.write().expect("RwLock is not poisoned");
65        inner.login_method = Some(Arc::new(login_method));
66    }
67}
68
69impl SecretsManagerTokenHandler {
70    /// Get the organization ID associated with the current access token, if available.
71    pub fn get_access_token_organization(&self) -> Option<OrganizationId> {
72        let inner = self.inner.read().ok()?;
73        match inner.login_method.as_deref()? {
74            ServiceAccountLoginMethod::AccessToken {
75                organization_id, ..
76            } => Some(*organization_id),
77        }
78    }
79}
80
81#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
82#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
83impl MiddlewareExt for SecretsManagerTokenHandler {
84    async fn get_token(&self) -> Result<Option<String>, LoginError> {
85        // We're not holding on to a lock for the duration of the token renewal, so if multiple
86        // requests come in at the same time when the token is expired, we may end up renewing the
87        // token multiple times. This is not ideal, but it's the behavior of the previous
88        // implementation. We should be able to introduce an async semaphore or something
89        // similar to prevent this if it becomes an issue in practice.
90        let inner = self.inner.read().expect("RwLock is not poisoned").clone();
91
92        // Validate the token, returning early if it's still valid.
93        if let Some(expires) = inner.expires_on
94            && Utc::now().timestamp() < expires - TOKEN_RENEW_MARGIN_SECONDS
95        {
96            return Ok(inner.access_token.clone());
97        }
98
99        // These should always be set by initialize_middleware / set_sm_login_method before we get
100        // here, but we return an error if not.
101        let login_method = inner.login_method.ok_or(NotAuthenticatedError)?;
102        let identity_config = inner.identity_config.ok_or(NotAuthenticatedError)?;
103        let key_store = inner.key_store.ok_or(NotAuthenticatedError)?;
104
105        let (access_token, refresh_token, expires_in) =
106            bitwarden_core::auth::renew::renew_sm_token_sdk_managed(
107                login_method.as_ref(),
108                identity_config,
109                key_store,
110            )
111            .await?;
112
113        self.set_tokens(access_token.clone(), refresh_token, expires_in)
114            .await;
115        Ok(Some(access_token))
116    }
117}
118
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bitwarden_core::{
        auth::{AccessToken, TokenHandler},
        client::login_method::ServiceAccountLoginMethod,
        key_management::KeySlotIds,
    };
    use bitwarden_crypto::KeyStore;
    use bitwarden_state::registry::StateRegistry;
    use wiremock::MockServer;

    use super::*;
    use crate::token_management::test_utils::*;

    /// Organization ID embedded in the login method built by `service_account_login_method`.
    const TEST_ORGANIZATION_ID: &str = "00000000-0000-0000-0000-000000000001";

    fn service_account_login_method() -> ServiceAccountLoginMethod {
        let access_token = AccessToken::from_str(
            "0.ec2c1d46-6a4b-4751-a310-af9601317f2d.C2IgxjjLF7qSshsbwe8JGcbM075YXw:X8vbvA0bduihIDe/qrzIQQ==",
        )
        .unwrap();

        ServiceAccountLoginMethod::AccessToken {
            access_token,
            organization_id: TEST_ORGANIZATION_ID.parse().unwrap(),
            state_file: None,
        }
    }

    #[tokio::test]
    async fn attaches_existing_token_when_not_expired() {
        let app_server = start_app_server().await;
        let identity_server = MockServer::start().await;

        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        handler
            .set_tokens("original-token".to_string(), None, 3600)
            .await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        ));

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer original-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn renews_expired_token() {
        let app_server = start_app_server().await;
        let identity_server = start_renewal_server("renewed-token").await;

        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        // expires_in=0 means the token is immediately considered expired
        handler
            .set_tokens("expired-token".to_string(), None, 0)
            .await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        ));

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer renewed-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn access_token_organization_comes_from_login_method() {
        let handler = SecretsManagerTokenHandler::default();
        // Nothing configured yet.
        assert!(handler.get_access_token_organization().is_none());

        handler
            .set_sm_login_method(service_account_login_method())
            .await;

        let expected: OrganizationId = TEST_ORGANIZATION_ID.parse().unwrap();
        assert_eq!(handler.get_access_token_organization(), Some(expected));
    }

    #[tokio::test]
    async fn get_token_fails_without_login_method() {
        let identity_server = MockServer::start().await;

        // Identity config and key store are provided, but no login method and no cached token.
        let handler = SecretsManagerTokenHandler::default();
        let registry = StateRegistry::new_with_memory_db();
        let _middleware = handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        );

        // Renewal cannot proceed, so an error is returned without contacting the identity server.
        assert!(MiddlewareExt::get_token(&handler).await.is_err());
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
    }
}