// bitwarden_auth/token_management/middleware.rs

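//! Shared utilities for token renewal: a reqwest middleware that attaches bearer tokens, renews
//! them shortly before they expire, and retries a request once after a 401 response.
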
3use bitwarden_api_api::apis::AuthRequired;
4use bitwarden_core::auth::login::LoginError;
5use chrono::Utc;
6use reqwest_middleware::Middleware;

/// Safety margin, in seconds, before a stored token's recorded expiry at which the token is no
/// longer reused and a renewal is triggered instead (five minutes).
const TOKEN_RENEW_MARGIN_SECONDS: i64 = 300;

10/// Bridges a [MiddlewareExt] implementation to [reqwest_middleware::Middleware], which can't be
11/// implemented directly because the trait is defined in an external crate.
12///
13/// The inner [tokio::sync::Mutex] serializes token reads and renewals, ensuring at most one
14/// in-flight renewal and that no request goes out with a potentially-invalidated token while a
15/// renewal is in progress.
16pub(crate) struct MiddlewareWrapper<T>(tokio::sync::Mutex<T>);
17
18impl<T> MiddlewareWrapper<T> {
19    pub(crate) fn new(inner: T) -> Self {
20        Self(tokio::sync::Mutex::new(inner))
21    }
22}
23
24/// Implemented by token handlers to expose stored token state and a renewal hook. The middleware
25/// owns the coalescing decision under the [MiddlewareWrapper] mutex.
26#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
27#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
28pub(crate) trait MiddlewareExt: 'static + Send + Sync {
29    /// Returns the stored access token and its expiration timestamp (Unix seconds), or `None` if
30    /// no token state is available.
31    async fn current_token(&self) -> Option<(String, i64)>;
32
33    /// Renew the access token from the upstream identity service and persist the result.
34    async fn renew_token(&mut self) -> Result<Option<String>, LoginError>;
35}
36
37/// Attaches an auth token (when [AuthRequired] is present) and retries once on 401 with a forced
38/// renewal, in case the server invalidated the token out from under us.
39#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
40#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
41impl<T: MiddlewareExt> Middleware for MiddlewareWrapper<T> {
42    async fn handle(
43        &self,
44        mut req: reqwest::Request,
45        ext: &mut http::Extensions,
46        next: reqwest_middleware::Next<'_>,
47    ) -> Result<reqwest::Response, reqwest_middleware::Error> {
48        let auth_required = match ext.get::<AuthRequired>() {
49            Some(AuthRequired::Bearer) => true,
50            Some(other) => {
51                tracing::warn!(?other, "Unsupported authentication method in request");
52                false
53            }
54            None => false,
55        };
56
57        let attached = if auth_required {
58            attach_header(&mut req, self.resolve_initial().await)
59        } else {
60            None
61        };
62
63        let req_clone = req.try_clone();
64        let result = next.clone().run(req, ext).await?;
65
66        if auth_required
67            && let Some(mut req_clone) = req_clone
68            && result.status() == http::StatusCode::UNAUTHORIZED
69        {
70            tracing::info!("Received 401 response, attempting token refresh and retrying");
71            attach_header(&mut req_clone, self.resolve_retry(attached).await);
72            return next.run(req_clone, ext).await;
73        }
74
75        Ok(result)
76    }
77}
78
79impl<T: MiddlewareExt> MiddlewareWrapper<T> {
80    /// First-attempt resolution: reuse the stored token if it's locally valid, otherwise renew.
81    async fn resolve_initial(&self) -> Result<Option<String>, LoginError> {
82        let mut handler = self.0.lock().await;
83        if let Some((access_token, expires_on)) = handler.current_token().await
84            && Utc::now().timestamp() < expires_on - TOKEN_RENEW_MARGIN_SECONDS
85        {
86            return Ok(Some(access_token));
87        }
88        handler.renew_token().await
89    }
90
91    /// Retry resolution after a 401: if a concurrent retry already renewed (the stored token
92    /// differs from `previous`), reuse that result. Otherwise renew.
93    async fn resolve_retry(&self, previous: Option<String>) -> Result<Option<String>, LoginError> {
94        let mut handler = self.0.lock().await;
95        if let Some((access_token, _)) = handler.current_token().await
96            && let Some(prev) = &previous
97            && access_token != *prev
98        {
99            return Ok(Some(access_token));
100        }
101        handler.renew_token().await
102    }
103}
104
105fn attach_header(
106    req: &mut reqwest::Request,
107    token: Result<Option<String>, LoginError>,
108) -> Option<String> {
109    let token = match token {
110        Ok(Some(t)) => t,
111        Ok(None) => {
112            tracing::warn!("No token available for request requiring authentication");
113            return None;
114        }
115        Err(e) => {
116            tracing::warn!("Failed to get auth token: {e}");
117            return None;
118        }
119    };
120    match format!("Bearer {}", token).parse() {
121        Ok(header_value) => {
122            req.headers_mut()
123                .insert(http::header::AUTHORIZATION, header_value);
124            Some(token)
125        }
126        Err(e) => {
127            tracing::warn!("Failed to parse auth token for header: {e}");
128            None
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        sync::{Arc, Mutex},
    };

    use bitwarden_api_api::{apis::AuthRequired, new_http_client};
    use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
    use wiremock::MockServer;

    use super::*;

    /// Lifetime given to mock tokens; comfortably outside [TOKEN_RENEW_MARGIN_SECONDS].
    const MOCK_TOKEN_LIFETIME_SECONDS: i64 = 3600;

    #[derive(Default)]
    struct MockState {
        /// Current stored token and its expiry (Unix seconds).
        current: Option<(String, i64)>,
        /// Queue of renewal results, consumed one per `renew_token` call.
        renewals: VecDeque<Result<Option<String>, LoginError>>,
        /// Number of `renew_token` calls.
        renew_count: usize,
    }

    struct MockMiddleware {
        state: Arc<Mutex<MockState>>,
    }

    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    impl MiddlewareExt for MockMiddleware {
        async fn current_token(&self) -> Option<(String, i64)> {
            self.state.lock().unwrap().current.clone()
        }

        async fn renew_token(&mut self) -> Result<Option<String>, LoginError> {
            let mut state = self.state.lock().unwrap();
            state.renew_count += 1;
            let result = state
                .renewals
                .pop_front()
                .expect("Not enough mock renewals provided for test");
            if let Ok(Some(ref token)) = result {
                state.current = Some((
                    token.clone(),
                    Utc::now().timestamp() + MOCK_TOKEN_LIFETIME_SECONDS,
                ));
            }
            result
        }
    }

    fn build_mock_client(
        initial: Option<&str>,
        renewals: Vec<Result<Option<String>, LoginError>>,
    ) -> (ClientWithMiddleware, Arc<Mutex<MockState>>) {
        let state = Arc::new(Mutex::new(MockState {
            current: initial
                .map(|t| (t.to_string(), Utc::now().timestamp() + MOCK_TOKEN_LIFETIME_SECONDS)),
            renewals: renewals.into_iter().collect(),
            renew_count: 0,
        }));
        let ext = MockMiddleware {
            state: state.clone(),
        };

        let client = ClientBuilder::new(new_http_client())
            .with_arc(Arc::new(MiddlewareWrapper::new(ext)))
            .build();
        (client, state)
    }

    async fn start_server_returning(status: u16) -> MockServer {
        let server = MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::any())
            .respond_with(wiremock::ResponseTemplate::new(status))
            .mount(&server)
            .await;
        server
    }

    /// Server that answers `Bearer stale` with 401 and everything else with 200.
    async fn start_server_rejecting_stale() -> MockServer {
        let server = MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::header("Authorization", "Bearer stale"))
            .respond_with(wiremock::ResponseTemplate::new(401))
            .mount(&server)
            .await;
        wiremock::Mock::given(wiremock::matchers::any())
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        server
    }

    #[tokio::test]
    async fn does_not_renew_when_no_auth_extension() {
        let server = start_server_returning(401).await;
        let (client, state) = build_mock_client(None, vec![]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 401);

        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].headers.get("Authorization").is_none());
        assert_eq!(state.lock().unwrap().renew_count, 0);
    }

    #[tokio::test]
    async fn retries_with_renewed_token_on_401() {
        let server = start_server_rejecting_stale().await;
        let (client, state) = build_mock_client(Some("stale"), vec![Ok(Some("fresh".into()))]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        assert_eq!(state.lock().unwrap().renew_count, 1);
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 2);
        assert_eq!(
            requests[0].headers.get("Authorization").unwrap(),
            "Bearer stale"
        );
        assert_eq!(
            requests[1].headers.get("Authorization").unwrap(),
            "Bearer fresh"
        );
    }

    #[tokio::test]
    async fn surfaces_second_401_without_third_attempt() {
        let server = start_server_returning(401).await;
        let (client, state) = build_mock_client(Some("token-a"), vec![Ok(Some("token-b".into()))]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 401);

        assert_eq!(server.received_requests().await.unwrap().len(), 2);
        assert_eq!(state.lock().unwrap().renew_count, 1);
    }

    #[tokio::test]
    async fn does_not_retry_on_non_401_status() {
        let server = start_server_returning(500).await;
        let (client, state) = build_mock_client(Some("token"), vec![]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 500);

        assert_eq!(server.received_requests().await.unwrap().len(), 1);
        assert_eq!(state.lock().unwrap().renew_count, 0);
    }

    #[tokio::test]
    async fn does_not_retry_when_body_cannot_be_cloned() {
        let server = start_server_returning(401).await;
        let (client, state) = build_mock_client(Some("token"), vec![]);

        // Body::wrap forces the Streaming variant, for which try_clone returns None.
        let streaming_body = reqwest::Body::wrap(reqwest::Body::from("payload"));

        let response = client
            .post(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .body(streaming_body)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 401);

        assert_eq!(server.received_requests().await.unwrap().len(), 1);
        assert_eq!(state.lock().unwrap().renew_count, 0);
    }

    #[tokio::test]
    async fn renews_proactively_when_token_is_within_margin() {
        let server = start_server_returning(200).await;
        let (client, state) = build_mock_client(None, vec![Ok(Some("fresh".into()))]);
        // Still technically valid, but inside the renewal margin, so it must not be sent.
        state.lock().unwrap().current = Some((
            "expiring".into(),
            Utc::now().timestamp() + TOKEN_RENEW_MARGIN_SECONDS / 2,
        ));

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        assert_eq!(state.lock().unwrap().renew_count, 1);
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert_eq!(
            requests[0].headers.get("Authorization").unwrap(),
            "Bearer fresh"
        );
    }

    #[tokio::test]
    async fn sends_without_header_when_no_token_is_available() {
        let server = start_server_returning(200).await;
        let (client, state) = build_mock_client(None, vec![Ok(None)]);

        let response = client
            .get(format!("{}/test", server.uri()))
            .with_extension(AuthRequired::Bearer)
            .send()
            .await
            .unwrap();
        assert_eq!(response.status(), 200);

        assert_eq!(state.lock().unwrap().renew_count, 1);
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        assert!(requests[0].headers.get("Authorization").is_none());
    }

    #[tokio::test]
    async fn concurrent_401s_share_a_single_renewal() {
        let server = start_server_rejecting_stale().await;
        // Only one renewal is queued; a second `renew_token` call would panic in the mock.
        let (client, state) = build_mock_client(Some("stale"), vec![Ok(Some("fresh".into()))]);

        let send = || {
            client
                .get(format!("{}/test", server.uri()))
                .with_extension(AuthRequired::Bearer)
                .send()
        };
        let (first, second) = tokio::join!(send(), send());

        assert_eq!(first.unwrap().status(), 200);
        assert_eq!(second.unwrap().status(), 200);
        assert_eq!(state.lock().unwrap().renew_count, 1);
    }
}