bitwarden_auth/token_management/
middleware.rs1use bitwarden_api_api::apis::AuthRequired;
4use bitwarden_core::auth::login::LoginError;
5use chrono::Utc;
6use reqwest_middleware::Middleware;
7
/// Safety margin before a cached token's expiry at which it is renewed proactively,
/// in seconds (compared against `Utc::now().timestamp()`): 300 s = 5 minutes.
const TOKEN_RENEW_MARGIN_SECONDS: i64 = 300;
9
10pub(crate) struct MiddlewareWrapper<T>(tokio::sync::Mutex<T>);
17
18impl<T> MiddlewareWrapper<T> {
19 pub(crate) fn new(inner: T) -> Self {
20 Self(tokio::sync::Mutex::new(inner))
21 }
22}
23
24#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
27#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
28pub(crate) trait MiddlewareExt: 'static + Send + Sync {
29 async fn current_token(&self) -> Option<(String, i64)>;
32
33 async fn renew_token(&mut self) -> Result<Option<String>, LoginError>;
35}
36
37#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
40#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
41impl<T: MiddlewareExt> Middleware for MiddlewareWrapper<T> {
42 async fn handle(
43 &self,
44 mut req: reqwest::Request,
45 ext: &mut http::Extensions,
46 next: reqwest_middleware::Next<'_>,
47 ) -> Result<reqwest::Response, reqwest_middleware::Error> {
48 let auth_required = match ext.get::<AuthRequired>() {
49 Some(AuthRequired::Bearer) => true,
50 Some(other) => {
51 tracing::warn!(?other, "Unsupported authentication method in request");
52 false
53 }
54 None => false,
55 };
56
57 let attached = if auth_required {
58 attach_header(&mut req, self.resolve_initial().await)
59 } else {
60 None
61 };
62
63 let req_clone = req.try_clone();
64 let result = next.clone().run(req, ext).await?;
65
66 if auth_required
67 && let Some(mut req_clone) = req_clone
68 && result.status() == http::StatusCode::UNAUTHORIZED
69 {
70 tracing::info!("Received 401 response, attempting token refresh and retrying");
71 attach_header(&mut req_clone, self.resolve_retry(attached).await);
72 return next.run(req_clone, ext).await;
73 }
74
75 Ok(result)
76 }
77}
78
79impl<T: MiddlewareExt> MiddlewareWrapper<T> {
80 async fn resolve_initial(&self) -> Result<Option<String>, LoginError> {
82 let mut handler = self.0.lock().await;
83 if let Some((access_token, expires_on)) = handler.current_token().await
84 && Utc::now().timestamp() < expires_on - TOKEN_RENEW_MARGIN_SECONDS
85 {
86 return Ok(Some(access_token));
87 }
88 handler.renew_token().await
89 }
90
91 async fn resolve_retry(&self, previous: Option<String>) -> Result<Option<String>, LoginError> {
94 let mut handler = self.0.lock().await;
95 if let Some((access_token, _)) = handler.current_token().await
96 && let Some(prev) = &previous
97 && access_token != *prev
98 {
99 return Ok(Some(access_token));
100 }
101 handler.renew_token().await
102 }
103}
104
105fn attach_header(
106 req: &mut reqwest::Request,
107 token: Result<Option<String>, LoginError>,
108) -> Option<String> {
109 let token = match token {
110 Ok(Some(t)) => t,
111 Ok(None) => {
112 tracing::warn!("No token available for request requiring authentication");
113 return None;
114 }
115 Err(e) => {
116 tracing::warn!("Failed to get auth token: {e}");
117 return None;
118 }
119 };
120 match format!("Bearer {}", token).parse() {
121 Ok(header_value) => {
122 req.headers_mut()
123 .insert(http::header::AUTHORIZATION, header_value);
124 Some(token)
125 }
126 Err(e) => {
127 tracing::warn!("Failed to parse auth token for header: {e}");
128 None
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use bitwarden_api_api::{apis::AuthRequired, new_http_client};
    use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
    use wiremock::MockServer;

    use super::*;

    /// Lifetime, in seconds, of tokens created by the mock. Well outside the renew margin.
    const MOCK_TOKEN_LIFETIME_SECONDS: i64 = 3600;

    /// Shared state behind [`MockMiddleware`], inspectable from the tests.
    #[derive(Default)]
    struct MockState {
        /// Token reported by `current_token`, paired with its expiry (Unix seconds).
        current: Option<(String, i64)>,
        /// Scripted `renew_token` results, consumed front to back.
        renewals: VecDeque<Result<Option<String>, LoginError>>,
        /// Number of `renew_token` calls so far.
        renew_count: usize,
    }

    /// [`MiddlewareExt`] test double whose state stays observable through a shared handle.
    struct MockMiddleware {
        state: Arc<Mutex<MockState>>,
    }

    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    impl MiddlewareExt for MockMiddleware {
        async fn current_token(&self) -> Option<(String, i64)> {
            self.state.lock().unwrap().current.clone()
        }

        /// Pops the next scripted result. A successful token also becomes the current one.
        async fn renew_token(&mut self) -> Result<Option<String>, LoginError> {
            let mut state = self.state.lock().unwrap();
            state.renew_count += 1;
            let result = state
                .renewals
                .pop_front()
                .expect("Not enough mock renewals provided for test");
            if let Ok(Some(ref token)) = result {
                state.current = Some((
                    token.clone(),
                    Utc::now().timestamp() + MOCK_TOKEN_LIFETIME_SECONDS,
                ));
            }
            result
        }
    }

    /// Builds a client whose middleware starts with `initial`, comfortably far from expiry.
    fn build_mock_client(
        initial: Option<&str>,
        renewals: Vec<Result<Option<String>, LoginError>>,
    ) -> (ClientWithMiddleware, Arc<Mutex<MockState>>) {
        build_mock_client_expiring_in(initial, MOCK_TOKEN_LIFETIME_SECONDS, renewals)
    }

    /// Builds a client whose middleware starts with `initial`, expiring `expires_in`
    /// seconds from now.
    fn build_mock_client_expiring_in(
        initial: Option<&str>,
        expires_in: i64,
        renewals: Vec<Result<Option<String>, LoginError>>,
    ) -> (ClientWithMiddleware, Arc<Mutex<MockState>>) {
        let state = Arc::new(Mutex::new(MockState {
            current: initial.map(|t| (t.to_string(), Utc::now().timestamp() + expires_in)),
            renewals: renewals.into_iter().collect(),
            renew_count: 0,
        }));
        let ext = MockMiddleware {
            state: state.clone(),
        };

        let client = ClientBuilder::new(new_http_client())
            .with_arc(Arc::new(MiddlewareWrapper::new(ext)))
            .build();
        (client, state)
    }

    /// Starts a mock server that answers every request with `status`.
    async fn start_server_returning(status: u16) -> MockServer {
        let server = MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::any())
            .respond_with(wiremock::ResponseTemplate::new(status))
            .mount(&server)
            .await;
        server
    }

    #[tokio::test]
    async fn does_not_renew_when_no_auth_extension() {
        let server = start_server_returning(401).await;
        let (client, state) = build_mock_client(None, vec![]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 401);

        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].headers.get("Authorization").is_none());
        assert_eq!(state.lock().unwrap().renew_count, 0);
    }

    #[tokio::test]
    async fn retries_with_renewed_token_on_401() {
        let server = MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::header("Authorization", "Bearer stale"))
            .respond_with(wiremock::ResponseTemplate::new(401))
            .mount(&server)
            .await;
        wiremock::Mock::given(wiremock::matchers::any())
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let (client, state) = build_mock_client(Some("stale"), vec![Ok(Some("fresh".into()))]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        assert_eq!(state.lock().unwrap().renew_count, 1);
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(
            requests[0].headers.get("Authorization").unwrap(),
            "Bearer stale"
        );
        assert_eq!(
            requests[1].headers.get("Authorization").unwrap(),
            "Bearer fresh"
        );
    }

    #[tokio::test]
    async fn surfaces_second_401_without_third_attempt() {
        let server = start_server_returning(401).await;
        let (client, state) = build_mock_client(Some("token-a"), vec![Ok(Some("token-b".into()))]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 401);

        assert_eq!(server.received_requests().await.unwrap().len(), 2);
        assert_eq!(state.lock().unwrap().renew_count, 1);
    }

    #[tokio::test]
    async fn does_not_retry_on_non_401_status() {
        let server = start_server_returning(500).await;
        let (client, state) = build_mock_client(Some("token"), vec![]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 500);

        assert_eq!(server.received_requests().await.unwrap().len(), 1);
        assert_eq!(state.lock().unwrap().renew_count, 0);
    }

    #[tokio::test]
    async fn does_not_retry_when_body_cannot_be_cloned() {
        let server = start_server_returning(401).await;
        let (client, state) = build_mock_client(Some("token"), vec![]);

        // Wrapping turns the body into a stream, which `Request::try_clone` cannot copy.
        let streaming_body = reqwest::Body::wrap(reqwest::Body::from("payload"));

        let response = client
            .post(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .body(streaming_body)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 401);

        assert_eq!(server.received_requests().await.unwrap().len(), 1);
        assert_eq!(state.lock().unwrap().renew_count, 0);
    }

    #[tokio::test]
    async fn renews_before_sending_when_token_is_near_expiry() {
        let server = start_server_returning(200).await;
        // Valid for one more minute, which is inside the renew margin.
        let (client, state) =
            build_mock_client_expiring_in(Some("expiring"), 60, vec![Ok(Some("fresh".into()))]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        assert_eq!(state.lock().unwrap().renew_count, 1);
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert_eq!(
            requests[0].headers.get("Authorization").unwrap(),
            "Bearer fresh"
        );
    }

    #[tokio::test]
    async fn sends_without_header_when_no_token_is_available() {
        let server = start_server_returning(200).await;
        let (client, state) = build_mock_client(None, vec![Ok(None)]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        assert_eq!(state.lock().unwrap().renew_count, 1);
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].headers.get("Authorization").is_none());
    }
}