Skip to main content

bitwarden_server_communication_config/
middleware.rs

1#[cfg(not(target_arch = "wasm32"))]
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use reqwest::header::HeaderValue;
6use reqwest_middleware::{Middleware, Next, Result};
7
8use crate::CookieProvider;
9
10/// Middleware that injects SSO load balancer cookies and re-acquires them on 302/307.
11///
12/// Must be outermost in the middleware chain so it observes raw 3xx responses
13/// before auth middleware. Auto-redirect must be disabled on the underlying
14/// reqwest::Client.
15///
16/// On WASM targets, uses a proactive strategy: checks `needs_bootstrap()` before
17/// each request, acquires cookies if needed, injects them, then sends. This is
18/// required because `reqwest::redirect::Policy` is unavailable on WASM and the
19/// browser auto-follows redirects, making reactive 302/307 detection impossible.
20///
21/// # Security
22///
23/// Cookie values are NEVER logged.
24pub struct ServerCommunicationConfigMiddleware {
25    provider: Arc<dyn CookieProvider>,
26    /// Tracks in-flight cookie acquisitions per hostname to prevent duplicate concurrent
27    /// SSO acquisition flows. When a task is acquiring for a hostname, other tasks wait
28    /// on the Notify rather than starting a redundant acquisition.
29    ///
30    /// Not used on WASM targets (single-threaded, proactive strategy).
31    #[cfg(not(target_arch = "wasm32"))]
32    in_flight: Arc<tokio::sync::Mutex<HashMap<String, Arc<tokio::sync::Notify>>>>,
33}
34
35impl ServerCommunicationConfigMiddleware {
36    /// Creates a new middleware instance wrapping the given cookie provider.
37    pub fn new(provider: Arc<dyn CookieProvider>) -> Self {
38        Self {
39            provider,
40            #[cfg(not(target_arch = "wasm32"))]
41            in_flight: Default::default(),
42        }
43    }
44}
45
46impl Clone for ServerCommunicationConfigMiddleware {
47    fn clone(&self) -> Self {
48        Self {
49            provider: Arc::clone(&self.provider),
50            #[cfg(not(target_arch = "wasm32"))]
51            in_flight: Arc::clone(&self.in_flight),
52        }
53    }
54}
55
56#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
57#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
58impl Middleware for ServerCommunicationConfigMiddleware {
59    async fn handle(
60        &self,
61        mut req: reqwest::Request,
62        extensions: &mut http::Extensions,
63        next: Next<'_>,
64    ) -> Result<reqwest::Response> {
65        // Extract hostname -- pass through unchanged if URL has no host.
66        let hostname = req.url().host_str().map(|h| h.to_string());
67
68        // WASM: proactive strategy -- check bootstrap before sending, inject cookies, then send.
69        // Reactive 302/307 detection is impossible on WASM because reqwest::redirect::Policy is
70        // unavailable and the browser auto-follows redirects.
71        #[cfg(target_arch = "wasm32")]
72        {
73            if let Some(ref hostname) = hostname {
74                if self.provider.needs_bootstrap(hostname).await {
75                    let _ = self.provider.acquire_cookie(hostname).await;
76                }
77                inject_cookies(&mut req, self.provider.cookies(hostname).await);
78            }
79            return next.run(req, extensions).await;
80        }
81
82        // Non-WASM: reactive strategy -- send request, detect 302/307, acquire cookie, retry.
83        #[cfg(not(target_arch = "wasm32"))]
84        {
85            let hostname = match hostname {
86                Some(h) => h,
87                None => {
88                    return next.run(req, extensions).await;
89                }
90            };
91
92            // Clone the request before forwarding so we can retry on 302/307.
93            // try_clone() returns None for streaming bodies; in that case we
94            // cannot retry and will return the redirect response as-is.
95            let req_clone = req.try_clone();
96
97            // Inject stored cookies into the Cookie header.
98            inject_cookies(&mut req, self.provider.cookies(&hostname).await);
99
100            // Forward the request.
101            let response = next.clone().run(req, extensions).await?;
102
103            // On 302 or 307: check if bootstrap is needed, acquire fresh cookie, and retry.
104            let status = response.status();
105            if status == reqwest::StatusCode::FOUND
106                || status == reqwest::StatusCode::TEMPORARY_REDIRECT
107            {
108                tracing::debug!(
109                    %status,
110                    hostname = %hostname,
111                    "Cookie middleware: intercepting redirect"
112                );
113
114                // Only acquire if bootstrap is required for this hostname.
115                // Mirrors clients needsBootstrap$ check; avoids spurious acquisition
116                // for redirects unrelated to SSO cookie bootstrapping.
117                if !self.provider.needs_bootstrap(&hostname).await {
118                    tracing::debug!(
119                        hostname = %hostname,
120                        "Cookie middleware: bootstrap not required, returning redirect"
121                    );
122                    return Ok(response);
123                }
124
125                // Deduplicate concurrent acquisition attempts for the same hostname.
126                // If another task is already acquiring, wait for it rather than starting
127                // a redundant SSO flow. Mirrors clients pendingAcquisition deduplication.
128                //
129                // We clone the Arc and call enable() while the lock is held so the
130                // Notified future is registered as a waiter before the acquirer can
131                // call notify_waiters(). Without this, a notify_waiters() call between
132                // lock release and the first poll of .notified() would be missed,
133                // causing the waiter to block forever.
134                let should_acquire = {
135                    let mut in_flight = self.in_flight.lock().await;
136                    if let Some(notify) = in_flight.get(&hostname) {
137                        // Another task is already acquiring. Clone the Arc so Notify
138                        // outlives the guard, then enable before releasing the lock.
139                        let notify = Arc::clone(notify);
140                        let notified = notify.notified();
141                        let mut notified = std::pin::pin!(notified);
142                        notified.as_mut().enable();
143                        drop(in_flight);
144                        notified.await;
145                        false
146                    } else {
147                        // We are the first -- register as the acquirer.
148                        in_flight.insert(hostname.clone(), Arc::new(tokio::sync::Notify::new()));
149                        true
150                    }
151                };
152
153                if should_acquire {
154                    // Acquire the new cookie (best-effort; log warning on failure).
155                    if let Err(e) = self.provider.acquire_cookie(&hostname).await {
156                        tracing::warn!(
157                            hostname = %hostname,
158                            error = ?e,
159                            "Cookie middleware: cookie acquisition failed"
160                        );
161                    }
162
163                    // Signal all waiters and remove from in-flight map.
164                    let mut in_flight = self.in_flight.lock().await;
165                    if let Some(notify) = in_flight.remove(&hostname) {
166                        notify.notify_waiters();
167                    }
168                }
169
170                // Retry with the cloned request if available (acquisition complete).
171                if let Some(mut retry_req) = req_clone {
172                    // Re-inject fresh cookies onto the retry request.
173                    inject_cookies(&mut retry_req, self.provider.cookies(&hostname).await);
174                    return next.run(retry_req, extensions).await;
175                }
176                // No clone available (streaming body) -- return the redirect response.
177            }
178
179            Ok(response)
180        }
181    }
182}
183
184/// Injects cookie name-value pairs as a Cookie header on the request.
185/// Skips injection if the cookie list is empty.
186/// Cookie values are NOT logged.
187fn inject_cookies(req: &mut reqwest::Request, cookies: Vec<(String, String)>) {
188    if cookies.is_empty() {
189        return;
190    }
191    let cookie_header = cookies
192        .iter()
193        .map(|(name, value)| format!("{name}={value}"))
194        .collect::<Vec<_>>()
195        .join("; ");
196    if let Ok(header_value) = HeaderValue::from_str(&cookie_header) {
197        req.headers_mut()
198            .insert(reqwest::header::COOKIE, header_value);
199    } else {
200        tracing::warn!("Cookie middleware: failed to encode cookie header (invalid characters)");
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AcquireCookieError, CookieProvider};

    /// Provider that deliberately does not implement `Clone`. Used to check that
    /// cloning the middleware only clones `Arc` handles, never the provider itself.
    struct NoCloneMockProvider;

    #[async_trait::async_trait]
    impl CookieProvider for NoCloneMockProvider {
        async fn cookies(&self, _hostname: &str) -> Vec<(String, String)> {
            Vec::new()
        }

        async fn acquire_cookie(
            &self,
            _hostname: &str,
        ) -> std::result::Result<(), AcquireCookieError> {
            Ok(())
        }

        async fn needs_bootstrap(&self, _hostname: &str) -> bool {
            false
        }
    }

    /// Builds a plain GET request against a fixed vault URL.
    fn vault_request() -> reqwest::Request {
        reqwest::Request::new(
            reqwest::Method::GET,
            "https://vault.example.com/api".parse().unwrap(),
        )
    }

    #[test]
    fn middleware_clone_does_not_require_cookie_provider_clone() {
        let provider: Arc<dyn CookieProvider> = Arc::new(NoCloneMockProvider);
        let original = ServerCommunicationConfigMiddleware::new(provider);
        // This only has to compile: cloning must not need `NoCloneMockProvider: Clone`.
        let _copy = original.clone();
    }

    #[test]
    fn inject_cookies_formats_header_correctly() {
        let mut request = vault_request();
        let pairs = vec![
            ("name1".to_string(), "val1".to_string()),
            ("name2".to_string(), "val2".to_string()),
        ];

        inject_cookies(&mut request, pairs);

        let cookie_header = request
            .headers()
            .get(reqwest::header::COOKIE)
            .expect("Cookie header should be set");
        assert_eq!(cookie_header.to_str().unwrap(), "name1=val1; name2=val2");
    }

    #[test]
    fn inject_cookies_skips_when_empty() {
        let mut request = vault_request();

        inject_cookies(&mut request, Vec::new());

        let cookie_header = request.headers().get(reqwest::header::COOKIE);
        assert!(
            cookie_header.is_none(),
            "Cookie header should not be set when no cookies"
        );
    }
}