bitwarden_auth/token_management/
secrets_manager_token_handler.rs1use std::sync::{Arc, RwLock};
4
5use bitwarden_core::{
6 NotAuthenticatedError, OrganizationId,
7 auth::{TokenHandler, login::LoginError},
8 client::login_method::ServiceAccountLoginMethod,
9 key_management::KeySlotIds,
10};
11use bitwarden_crypto::KeyStore;
12use bitwarden_state::registry::StateRegistry;
13use chrono::Utc;
14
15use super::middleware::{MiddlewareExt, MiddlewareWrapper};
16
17#[derive(Clone, Default)]
19pub struct SecretsManagerTokenHandler {
20 inner: Arc<RwLock<SecretsManagerTokenHandlerInner>>,
21}
22
23#[derive(Clone, Default)]
24struct SecretsManagerTokenHandlerInner {
25 access_token: Option<String>,
26 expires_on: Option<i64>,
27
28 login_method: Option<Arc<ServiceAccountLoginMethod>>,
30 identity_config: Option<bitwarden_api_api::Configuration>,
31 key_store: Option<KeyStore<KeySlotIds>>,
32}
33
34#[async_trait::async_trait]
35impl TokenHandler for SecretsManagerTokenHandler {
36 fn initialize_middleware(
37 &self,
38 _state_registry: &StateRegistry,
39 identity_config: bitwarden_api_api::Configuration,
40 key_store: KeyStore<KeySlotIds>,
41 ) -> Arc<dyn reqwest_middleware::Middleware> {
42 {
43 let mut inner = self.inner.write().expect("RwLock is not poisoned");
44 inner.identity_config = Some(identity_config);
45 inner.key_store = Some(key_store);
46 }
47 Arc::new(MiddlewareWrapper::new(self.clone()))
48 }
49
50 async fn set_tokens(
51 &self,
52 access_token: String,
53 _refresh_token: Option<String>,
54 expires_in: u64,
55 ) {
56 let mut inner = self.inner.write().expect("RwLock is not poisoned");
57 inner.access_token = Some(access_token);
58 inner.expires_on = Some(Utc::now().timestamp() + expires_in as i64);
59 }
60
61 async fn set_sm_login_method(&self, login_method: ServiceAccountLoginMethod) {
62 let mut inner = self.inner.write().expect("RwLock is not poisoned");
63 inner.login_method = Some(Arc::new(login_method));
64 }
65}
66
67impl SecretsManagerTokenHandler {
68 pub fn get_access_token_organization(&self) -> Option<OrganizationId> {
70 let inner = self.inner.read().ok()?;
71 match inner.login_method.as_deref()? {
72 ServiceAccountLoginMethod::AccessToken {
73 organization_id, ..
74 } => Some(*organization_id),
75 }
76 }
77}
78
79#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
80#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
81impl MiddlewareExt for SecretsManagerTokenHandler {
82 async fn current_token(&self) -> Option<(String, i64)> {
83 let inner = self.inner.read().expect("RwLock is not poisoned").clone();
84 Some((inner.access_token?, inner.expires_on?))
85 }
86
87 async fn renew_token(&mut self) -> Result<Option<String>, LoginError> {
88 let inner = self.inner.read().expect("RwLock is not poisoned").clone();
89
90 let login_method = inner.login_method.ok_or(NotAuthenticatedError)?;
91 let identity_config = inner.identity_config.ok_or(NotAuthenticatedError)?;
92 let key_store = inner.key_store.ok_or(NotAuthenticatedError)?;
93
94 let (access_token, refresh_token, expires_in) =
95 bitwarden_core::auth::renew::renew_sm_token_sdk_managed(
96 login_method.as_ref(),
97 identity_config,
98 key_store,
99 )
100 .await?;
101
102 self.set_tokens(access_token.clone(), refresh_token, expires_in)
103 .await;
104 Ok(Some(access_token))
105 }
106}
107
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bitwarden_api_api::apis::AuthRequired;
    use bitwarden_core::{auth::AccessToken, client::login_method::ServiceAccountLoginMethod};
    use bitwarden_state::registry::StateRegistry;
    use wiremock::MockServer;

    use super::*;
    use crate::token_management::test_utils::*;

    /// Builds an access-token login method with a fixed token and organization id.
    fn service_account_login_method() -> ServiceAccountLoginMethod {
        let access_token = AccessToken::from_str(
            "0.ec2c1d46-6a4b-4751-a310-af9601317f2d.C2IgxjjLF7qSshsbwe8JGcbM075YXw:X8vbvA0bduihIDe/qrzIQQ==",
        )
        .unwrap();

        ServiceAccountLoginMethod::AccessToken {
            access_token,
            organization_id: "00000000-0000-0000-0000-000000000001".parse().unwrap(),
            state_file: None,
        }
    }

    /// A token with an hour left is attached as-is and the identity server is never called.
    #[tokio::test]
    async fn attaches_existing_token_when_not_expired() {
        let app_server = start_app_server().await;
        let identity_server = MockServer::start().await;

        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        handler
            .set_tokens("original-token".to_string(), None, 3600)
            .await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer original-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    /// A token whose lifetime is zero is renewed once before the request is sent.
    #[tokio::test]
    async fn renews_expired_token() {
        let app_server = start_app_server().await;
        let identity_server = start_renewal_server("renewed-token").await;

        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        handler
            // `expires_in` of 0 sets the expiry to the current timestamp.
            .set_tokens("expired-token".to_string(), None, 0)
            .await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer renewed-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    /// A token that still looks valid locally but is rejected by the app server triggers one
    /// renewal and one retry with the new token.
    #[tokio::test]
    async fn retries_with_renewed_token_on_401() {
        let app_server = start_app_server_rejecting("stale-token").await;
        let identity_server = start_renewal_server("renewed-token").await;

        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        handler
            // Not expired according to the stored expiry, so it is sent first.
            .set_tokens("stale-token".to_string(), None, 3600)
            .await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        let response = client
            .get(format!("{}/test", app_server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        let requests = app_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(
            requests[0].headers.get("Authorization").unwrap(),
            "Bearer stale-token"
        );
        assert_eq!(
            requests[1].headers.get("Authorization").unwrap(),
            "Bearer renewed-token"
        );
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
    }

    /// With no token stored, the first request goes out without an `Authorization` header
    /// and the retry carries the renewed token; the identity server is contacted twice.
    #[tokio::test]
    async fn refreshes_on_retry_when_initial_token_unavailable() {
        let app_server = start_app_server_accepting("renewed-token").await;
        let identity_server = start_renewal_server_failing_then_succeeding("renewed-token").await;

        // Only the login method is set; no tokens are stored up front.
        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        let response = client
            .get(format!("{}/test", app_server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        let requests = app_server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 2);
        assert!(requests[0].headers.get("Authorization").is_none());
        assert_eq!(
            requests[1].headers.get("Authorization").unwrap(),
            "Bearer renewed-token"
        );
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 2);
    }

    /// Five concurrent requests that are all rejected result in a single renewal call.
    #[tokio::test]
    async fn concurrent_401s_trigger_a_single_renewal() {
        let app_server = start_app_server_rejecting("stale-token").await;
        // The renewal endpoint is delayed so the concurrent requests overlap with it.
        let identity_server =
            start_renewal_server_with_delay("renewed-token", std::time::Duration::from_millis(100))
                .await;

        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        handler
            .set_tokens("stale-token".to_string(), None, 3600)
            .await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        send_concurrent_auth_requests(&client, &app_server, 5).await;

        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
    }

    /// Five concurrent requests starting from a zero-lifetime token share a single renewal,
    /// and every request reaches the app server with the renewed token.
    #[tokio::test]
    async fn concurrent_requests_trigger_a_single_renewal() {
        let app_server = start_app_server().await;
        let identity_server =
            // The renewal endpoint is delayed so the concurrent requests overlap with it.
            start_renewal_server_with_delay("renewed-token", std::time::Duration::from_millis(100))
                .await;

        let handler = SecretsManagerTokenHandler::default();
        handler
            .set_sm_login_method(service_account_login_method())
            .await;
        handler
            .set_tokens("expired-token".to_string(), None, 0)
            .await;

        let registry = StateRegistry::new_with_memory_db();
        let client = build_client(&handler, &registry, &identity_server);

        send_concurrent_auth_requests(&client, &app_server, 5).await;

        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        let app_requests = app_server.received_requests().await.unwrap();
        assert_eq!(app_requests.len(), 5);
        for req in app_requests {
            assert_eq!(
                req.headers.get("Authorization").unwrap(),
                "Bearer renewed-token"
            );
        }
    }
}