bitwarden_server_communication_config/middleware.rs
1#[cfg(not(target_arch = "wasm32"))]
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use reqwest::header::HeaderValue;
6use reqwest_middleware::{Middleware, Next, Result};
7
8use crate::CookieProvider;
9
10/// Middleware that injects SSO load balancer cookies and re-acquires them on 302/307.
11///
12/// Must be outermost in the middleware chain so it observes raw 3xx responses
13/// before auth middleware. Auto-redirect must be disabled on the underlying
14/// reqwest::Client.
15///
16/// On WASM targets, uses a proactive strategy: checks `needs_bootstrap()` before
17/// each request, acquires cookies if needed, injects them, then sends. This is
18/// required because `reqwest::redirect::Policy` is unavailable on WASM and the
19/// browser auto-follows redirects, making reactive 302/307 detection impossible.
20///
21/// # Security
22///
23/// Cookie values are NEVER logged.
24pub struct ServerCommunicationConfigMiddleware {
25 provider: Arc<dyn CookieProvider>,
26 /// Tracks in-flight cookie acquisitions per hostname to prevent duplicate concurrent
27 /// SSO acquisition flows. When a task is acquiring for a hostname, other tasks wait
28 /// on the Notify rather than starting a redundant acquisition.
29 ///
30 /// Not used on WASM targets (single-threaded, proactive strategy).
31 #[cfg(not(target_arch = "wasm32"))]
32 in_flight: Arc<tokio::sync::Mutex<HashMap<String, Arc<tokio::sync::Notify>>>>,
33}
34
35impl ServerCommunicationConfigMiddleware {
36 /// Creates a new middleware instance wrapping the given cookie provider.
37 pub fn new(provider: Arc<dyn CookieProvider>) -> Self {
38 Self {
39 provider,
40 #[cfg(not(target_arch = "wasm32"))]
41 in_flight: Default::default(),
42 }
43 }
44}
45
46impl Clone for ServerCommunicationConfigMiddleware {
47 fn clone(&self) -> Self {
48 Self {
49 provider: Arc::clone(&self.provider),
50 #[cfg(not(target_arch = "wasm32"))]
51 in_flight: Arc::clone(&self.in_flight),
52 }
53 }
54}
55
56#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
57#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
58impl Middleware for ServerCommunicationConfigMiddleware {
59 async fn handle(
60 &self,
61 mut req: reqwest::Request,
62 extensions: &mut http::Extensions,
63 next: Next<'_>,
64 ) -> Result<reqwest::Response> {
65 // Extract hostname -- pass through unchanged if URL has no host.
66 let hostname = req.url().host_str().map(|h| h.to_string());
67
68 // WASM: proactive strategy -- check bootstrap before sending, inject cookies, then send.
69 // Reactive 302/307 detection is impossible on WASM because reqwest::redirect::Policy is
70 // unavailable and the browser auto-follows redirects.
71 #[cfg(target_arch = "wasm32")]
72 {
73 if let Some(ref hostname) = hostname {
74 if self.provider.needs_bootstrap(hostname).await {
75 let _ = self.provider.acquire_cookie(hostname).await;
76 }
77 inject_cookies(&mut req, self.provider.cookies(hostname).await);
78 }
79 return next.run(req, extensions).await;
80 }
81
82 // Non-WASM: reactive strategy -- send request, detect 302/307, acquire cookie, retry.
83 #[cfg(not(target_arch = "wasm32"))]
84 {
85 let hostname = match hostname {
86 Some(h) => h,
87 None => {
88 return next.run(req, extensions).await;
89 }
90 };
91
92 // Clone the request before forwarding so we can retry on 302/307.
93 // try_clone() returns None for streaming bodies; in that case we
94 // cannot retry and will return the redirect response as-is.
95 let req_clone = req.try_clone();
96
97 // Inject stored cookies into the Cookie header.
98 inject_cookies(&mut req, self.provider.cookies(&hostname).await);
99
100 // Forward the request.
101 let response = next.clone().run(req, extensions).await?;
102
103 // On 302 or 307: check if bootstrap is needed, acquire fresh cookie, and retry.
104 let status = response.status();
105 if status == reqwest::StatusCode::FOUND
106 || status == reqwest::StatusCode::TEMPORARY_REDIRECT
107 {
108 tracing::debug!(
109 %status,
110 hostname = %hostname,
111 "Cookie middleware: intercepting redirect"
112 );
113
114 // Only acquire if bootstrap is required for this hostname.
115 // Mirrors clients needsBootstrap$ check; avoids spurious acquisition
116 // for redirects unrelated to SSO cookie bootstrapping.
117 if !self.provider.needs_bootstrap(&hostname).await {
118 tracing::debug!(
119 hostname = %hostname,
120 "Cookie middleware: bootstrap not required, returning redirect"
121 );
122 return Ok(response);
123 }
124
125 // Deduplicate concurrent acquisition attempts for the same hostname.
126 // If another task is already acquiring, wait for it rather than starting
127 // a redundant SSO flow. Mirrors clients pendingAcquisition deduplication.
128 //
129 // We clone the Arc and call enable() while the lock is held so the
130 // Notified future is registered as a waiter before the acquirer can
131 // call notify_waiters(). Without this, a notify_waiters() call between
132 // lock release and the first poll of .notified() would be missed,
133 // causing the waiter to block forever.
134 let should_acquire = {
135 let mut in_flight = self.in_flight.lock().await;
136 if let Some(notify) = in_flight.get(&hostname) {
137 // Another task is already acquiring. Clone the Arc so Notify
138 // outlives the guard, then enable before releasing the lock.
139 let notify = Arc::clone(notify);
140 let notified = notify.notified();
141 let mut notified = std::pin::pin!(notified);
142 notified.as_mut().enable();
143 drop(in_flight);
144 notified.await;
145 false
146 } else {
147 // We are the first -- register as the acquirer.
148 in_flight.insert(hostname.clone(), Arc::new(tokio::sync::Notify::new()));
149 true
150 }
151 };
152
153 if should_acquire {
154 // Acquire the new cookie (best-effort; log warning on failure).
155 if let Err(e) = self.provider.acquire_cookie(&hostname).await {
156 tracing::warn!(
157 hostname = %hostname,
158 error = ?e,
159 "Cookie middleware: cookie acquisition failed"
160 );
161 }
162
163 // Signal all waiters and remove from in-flight map.
164 let mut in_flight = self.in_flight.lock().await;
165 if let Some(notify) = in_flight.remove(&hostname) {
166 notify.notify_waiters();
167 }
168 }
169
170 // Retry with the cloned request if available (acquisition complete).
171 if let Some(mut retry_req) = req_clone {
172 // Re-inject fresh cookies onto the retry request.
173 inject_cookies(&mut retry_req, self.provider.cookies(&hostname).await);
174 return next.run(retry_req, extensions).await;
175 }
176 // No clone available (streaming body) -- return the redirect response.
177 }
178
179 Ok(response)
180 }
181 }
182}
183
184/// Injects cookie name-value pairs as a Cookie header on the request.
185/// Skips injection if the cookie list is empty.
186/// Cookie values are NOT logged.
187fn inject_cookies(req: &mut reqwest::Request, cookies: Vec<(String, String)>) {
188 if cookies.is_empty() {
189 return;
190 }
191 let cookie_header = cookies
192 .iter()
193 .map(|(name, value)| format!("{name}={value}"))
194 .collect::<Vec<_>>()
195 .join("; ");
196 if let Ok(header_value) = HeaderValue::from_str(&cookie_header) {
197 req.headers_mut()
198 .insert(reqwest::header::COOKIE, header_value);
199 } else {
200 tracing::warn!("Cookie middleware: failed to encode cookie header (invalid characters)");
201 }
202}
203
204#[cfg(test)]
205mod tests {
206 use super::*;
207 use crate::{AcquireCookieError, CookieProvider};
208
209 // Compile-time test: Arc<dyn CookieProvider> wrapping a non-Clone type
210 // must still allow ServerCommunicationConfigMiddleware to be cloned via Arc::clone.
211 struct NoClonemockProvider;
212
213 #[async_trait::async_trait]
214 impl CookieProvider for NoClonemockProvider {
215 async fn cookies(&self, _hostname: &str) -> Vec<(String, String)> {
216 vec![]
217 }
218
219 async fn acquire_cookie(
220 &self,
221 _hostname: &str,
222 ) -> std::result::Result<(), AcquireCookieError> {
223 Ok(())
224 }
225
226 async fn needs_bootstrap(&self, _hostname: &str) -> bool {
227 false
228 }
229 }
230
231 #[test]
232 fn middleware_clone_does_not_require_cookie_provider_clone() {
233 let arc: Arc<dyn CookieProvider> = Arc::new(NoClonemockProvider);
234 let middleware = ServerCommunicationConfigMiddleware::new(arc);
235 // Compilation of this line is the assertion.
236 let _cloned = middleware.clone();
237 }
238
239 #[test]
240 fn inject_cookies_formats_header_correctly() {
241 let mut req = reqwest::Request::new(
242 reqwest::Method::GET,
243 "https://vault.example.com/api".parse().unwrap(),
244 );
245 inject_cookies(
246 &mut req,
247 vec![
248 ("name1".to_string(), "val1".to_string()),
249 ("name2".to_string(), "val2".to_string()),
250 ],
251 );
252 let header = req
253 .headers()
254 .get(reqwest::header::COOKIE)
255 .expect("Cookie header should be set");
256 assert_eq!(header.to_str().unwrap(), "name1=val1; name2=val2");
257 }
258
259 #[test]
260 fn inject_cookies_skips_when_empty() {
261 let mut req = reqwest::Request::new(
262 reqwest::Method::GET,
263 "https://vault.example.com/api".parse().unwrap(),
264 );
265 inject_cookies(&mut req, vec![]);
266 assert!(
267 req.headers().get(reqwest::header::COOKIE).is_none(),
268 "Cookie header should not be set when no cookies"
269 );
270 }
271}