Skip to main content

bitwarden_server_communication_config/
cookie_provider.rs

1use crate::{
2    AcquireCookieError, ServerCommunicationConfigClient, ServerCommunicationConfigPlatformApi,
3    ServerCommunicationConfigRepository,
4};
5
6/// Abstraction for acquiring and retrieving SSO load balancer cookies.
7///
8/// Allows bitwarden-core to request cookies without depending on
9/// bitwarden-server-communication-config. Middleware holds `Arc<dyn CookieProvider>`.
10///
11/// # Security
12///
13/// Implementors MUST NOT log cookie values. Cookie values are sensitive SSO tokens.
14#[async_trait::async_trait]
15pub trait CookieProvider: 'static + Send + Sync {
16    /// Returns stored cookies for the given hostname as name-value pairs.
17    ///
18    /// Returns an empty vec when no cookies are stored for the hostname.
19    async fn cookies(&self, hostname: &str) -> Vec<(String, String)>;
20
21    /// Acquires a fresh cookie from the platform and stores it for the given hostname.
22    ///
23    /// Triggers the platform-specific SSO acquisition flow (e.g., WebView redirect).
24    async fn acquire_cookie(&self, hostname: &str) -> Result<(), AcquireCookieError>;
25
26    /// Returns true if the hostname requires SSO cookie bootstrapping.
27    ///
28    /// Returns false when no configuration exists for the hostname, or when
29    /// the configuration is `Direct` (no cookie required). Middleware uses this
30    /// to skip acquisition on redirects unrelated to SSO bootstrapping.
31    async fn needs_bootstrap(&self, hostname: &str) -> bool;
32}
33
34#[async_trait::async_trait]
35impl<R, P> CookieProvider for ServerCommunicationConfigClient<R, P>
36where
37    R: ServerCommunicationConfigRepository + Send + 'static,
38    P: ServerCommunicationConfigPlatformApi + Send + 'static,
39{
40    async fn cookies(&self, hostname: &str) -> Vec<(String, String)> {
41        #[allow(deprecated)]
42        self.cookies(hostname.to_string()).await
43    }
44
45    async fn acquire_cookie(&self, hostname: &str) -> Result<(), AcquireCookieError> {
46        self.acquire_cookie(hostname).await.map(|_| ())
47    }
48
49    async fn needs_bootstrap(&self, hostname: &str) -> bool {
50        self.needs_bootstrap(hostname.to_string()).await
51    }
52}
53
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use tokio::sync::RwLock;

    use super::*;
    use crate::{
        AcquiredCookie, BootstrapConfig, ServerCommunicationConfig,
        ServerCommunicationConfigPlatformApi, ServerCommunicationConfigRepository,
        SsoCookieVendorConfig,
    };

    /// Hostname that every fixture configuration is stored under.
    const HOSTNAME: &str = "vault.example.com";

    /// Hostname for which no configuration is ever stored.
    const UNKNOWN_HOSTNAME: &str = "unknown.example.com";

    /// Mock in-memory repository for testing
    #[derive(Default, Clone)]
    struct MockRepository {
        storage: Arc<RwLock<HashMap<String, ServerCommunicationConfig>>>,
    }

    impl ServerCommunicationConfigRepository for MockRepository {
        type GetError = ();
        type SaveError = ();

        async fn get(&self, hostname: String) -> Result<Option<ServerCommunicationConfig>, ()> {
            Ok(self.storage.read().await.get(&hostname).cloned())
        }

        async fn save(
            &self,
            hostname: String,
            config: ServerCommunicationConfig,
        ) -> Result<(), ()> {
            self.storage.write().await.insert(hostname, config);
            Ok(())
        }
    }

    /// Mock platform API for testing; returns whatever cookies were last set via
    /// [`MockPlatformApi::set_cookies`] (initially `None`).
    #[derive(Clone)]
    struct MockPlatformApi {
        cookies_to_return: Arc<RwLock<Option<Vec<AcquiredCookie>>>>,
    }

    impl MockPlatformApi {
        fn new() -> Self {
            Self {
                cookies_to_return: Arc::new(RwLock::new(None)),
            }
        }

        async fn set_cookies(&self, cookies: Option<Vec<AcquiredCookie>>) {
            *self.cookies_to_return.write().await = cookies;
        }
    }

    #[async_trait::async_trait]
    impl ServerCommunicationConfigPlatformApi for MockPlatformApi {
        async fn acquire_cookies(&self, _vault_url: String) -> Option<Vec<AcquiredCookie>> {
            self.cookies_to_return.read().await.clone()
        }
    }

    /// Builds an SSO cookie-vendor configuration whose stored cookie is
    /// `cookie_value` (`None` = nothing acquired yet).
    fn sso_config(cookie_value: Option<Vec<AcquiredCookie>>) -> ServerCommunicationConfig {
        ServerCommunicationConfig {
            bootstrap: BootstrapConfig::SsoCookieVendor(SsoCookieVendorConfig {
                idp_login_url: Some("https://idp.example.com".to_string()),
                cookie_name: "TestCookie".to_string(),
                cookie_domain: "example.com".to_string(),
                vault_url: "https://vault.example.com".to_string(),
                cookie_value,
            }),
        }
    }

    /// A single `TestCookie` cookie carrying `value`.
    fn test_cookie(value: &str) -> Vec<AcquiredCookie> {
        vec![AcquiredCookie {
            name: "TestCookie".to_string(),
            value: value.to_string(),
        }]
    }

    /// A repository pre-populated with `config` under [`HOSTNAME`].
    async fn repo_with(config: ServerCommunicationConfig) -> MockRepository {
        let repo = MockRepository::default();
        repo.save(HOSTNAME.to_string(), config).await.unwrap();
        repo
    }

    #[test]
    fn dyn_cookie_provider_is_send_sync() {
        // Compile-time check: middleware shares the provider across tasks.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Arc<dyn CookieProvider>>();
    }

    #[tokio::test]
    async fn cookie_provider_cookies_delegates() {
        let repo = repo_with(sso_config(Some(test_cookie("test-value")))).await;
        let client = ServerCommunicationConfigClient::new(repo, MockPlatformApi::new());
        let provider: &dyn CookieProvider = &client;

        let cookies = provider.cookies(HOSTNAME).await;
        assert_eq!(
            cookies,
            vec![("TestCookie".to_string(), "test-value".to_string())]
        );
    }

    #[tokio::test]
    async fn cookie_provider_cookies_empty_for_unknown_hostname() {
        let client =
            ServerCommunicationConfigClient::new(MockRepository::default(), MockPlatformApi::new());
        let provider: &dyn CookieProvider = &client;

        assert!(provider.cookies(UNKNOWN_HOSTNAME).await.is_empty());
    }

    #[tokio::test]
    async fn cookie_provider_acquire_cookie_delegates() {
        let repo = repo_with(sso_config(None)).await;
        let platform_api = MockPlatformApi::new();
        platform_api
            .set_cookies(Some(test_cookie("acquired-value")))
            .await;

        let client = ServerCommunicationConfigClient::new(repo, platform_api);
        let provider: &dyn CookieProvider = &client;

        let result = provider.acquire_cookie(HOSTNAME).await;
        assert_eq!(result, Ok(()));

        // The trait contract says the acquired cookie is stored for the hostname,
        // so it must now be visible through `cookies`.
        assert_eq!(
            provider.cookies(HOSTNAME).await,
            vec![("TestCookie".to_string(), "acquired-value".to_string())]
        );
    }

    #[tokio::test]
    async fn cookie_provider_needs_bootstrap_delegates() {
        let repo = repo_with(sso_config(None)).await;
        let client = ServerCommunicationConfigClient::new(repo, MockPlatformApi::new());
        let provider: &dyn CookieProvider = &client;

        assert!(provider.needs_bootstrap(HOSTNAME).await);
        // No configuration stored: the trait contract requires `false`.
        assert!(!provider.needs_bootstrap(UNKNOWN_HOSTNAME).await);
    }
}