bitwarden_core/auth/
auth_tokens.rs1use std::sync::Arc;
5
6use bitwarden_crypto::KeyStore;
7use bitwarden_state::registry::StateRegistry;
8
9#[cfg(feature = "secrets")]
10use crate::client::login_method::ServiceAccountLoginMethod;
11use crate::key_management::KeySlotIds;
12
13#[async_trait::async_trait]
15pub trait TokenHandler: 'static + Send + Sync {
16 fn initialize_middleware(
21 &self,
22 state_registry: &StateRegistry,
23 identity_config: bitwarden_api_base::Configuration,
24 key_store: KeyStore<KeySlotIds>,
25 ) -> Arc<dyn reqwest_middleware::Middleware>;
26
27 async fn set_tokens(&self, token: String, refresh_token: Option<String>, expires_in: u64);
33
34 #[cfg(feature = "secrets")]
41 async fn set_sm_login_method(&self, _login_method: ServiceAccountLoginMethod) {}
42}
43
44#[cfg_attr(feature = "uniffi", uniffi::export(with_foreign))]
46#[async_trait::async_trait]
47pub trait ClientManagedTokens: std::fmt::Debug + Send + Sync {
48 async fn get_access_token(&self) -> Option<String>;
50}
51
52#[derive(Clone)]
54pub struct ClientManagedTokenHandler {
55 tokens: Arc<dyn ClientManagedTokens>,
56}
57
58impl ClientManagedTokenHandler {
59 pub fn new(tokens: Arc<dyn ClientManagedTokens>) -> Arc<Self> {
61 Arc::new(Self { tokens })
62 }
63}
64
65#[async_trait::async_trait]
66impl TokenHandler for ClientManagedTokenHandler {
67 fn initialize_middleware(
68 &self,
69 _state_registry: &StateRegistry,
70 _identity_config: bitwarden_api_base::Configuration,
71 _key_store: KeyStore<KeySlotIds>,
72 ) -> Arc<dyn reqwest_middleware::Middleware> {
73 Arc::new(self.clone())
74 }
75
76 async fn set_tokens(&self, _token: String, _refresh_token: Option<String>, _expires_on: u64) {
77 panic!("Client-managed tokens cannot be set by the SDK");
78 }
79}
80
81#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
82#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
83impl reqwest_middleware::Middleware for ClientManagedTokenHandler {
84 async fn handle(
85 &self,
86 mut req: reqwest::Request,
87 ext: &mut http::Extensions,
88 next: reqwest_middleware::Next<'_>,
89 ) -> Result<reqwest::Response, reqwest_middleware::Error> {
90 if ext.get::<bitwarden_api_base::AuthRequired>().is_some()
91 && let Some(token) = self.tokens.get_access_token().await
92 {
93 match format!("Bearer {}", token).parse() {
94 Ok(header_value) => {
95 req.headers_mut()
96 .insert(http::header::AUTHORIZATION, header_value);
97 }
98 Err(e) => {
99 tracing::warn!("Failed to parse auth token for header: {e}");
100 }
101 }
102 }
103
104 let resp = next.run(req, ext).await?;
105
106 Ok(resp)
107 }
108}
109
/// [`TokenHandler`] that performs no authentication: its middleware forwards every request
/// unchanged, and it refuses to store tokens.
#[derive(Copy, Clone)]
pub struct NoopTokenHandler;
114
115#[async_trait::async_trait]
116impl TokenHandler for NoopTokenHandler {
117 fn initialize_middleware(
118 &self,
119 _state_registry: &StateRegistry,
120 _identity_config: bitwarden_api_base::Configuration,
121 _key_store: KeyStore<KeySlotIds>,
122 ) -> Arc<dyn reqwest_middleware::Middleware> {
123 Arc::new(*self)
124 }
125
126 async fn set_tokens(&self, _token: String, _refresh_token: Option<String>, _expires_on: u64) {
127 panic!("Cannot set tokens on NoopTokenHandler");
128 }
129}
130
131#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
132#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
133impl reqwest_middleware::Middleware for NoopTokenHandler {
134 async fn handle(
135 &self,
136 req: reqwest::Request,
137 ext: &mut http::Extensions,
138 next: reqwest_middleware::Next<'_>,
139 ) -> Result<reqwest::Response, reqwest_middleware::Error> {
140 next.run(req, ext).await
141 }
142}
143
#[cfg(test)]
mod tests {
    use wiremock::MockServer;

    use super::*;

    /// Token provider double that always hands back the same (possibly absent) token.
    #[derive(Debug)]
    struct MockTokenProvider {
        token: Option<String>,
    }

    #[async_trait::async_trait]
    impl ClientManagedTokens for MockTokenProvider {
        async fn get_access_token(&self) -> Option<String> {
            self.token.clone()
        }
    }

    /// Starts a mock server that answers every request with `200 OK`.
    async fn start_server() -> MockServer {
        let server = MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::any())
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;
        server
    }

    /// Builds an HTTP client whose only middleware is `middleware`.
    fn client_with<M: reqwest_middleware::Middleware>(
        middleware: M,
    ) -> reqwest_middleware::ClientWithMiddleware {
        reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
            .with(middleware)
            .build()
    }

    /// Builds a client using a [`ClientManagedTokenHandler`] backed by `token`, plus a fresh
    /// mock server.
    async fn test_setup(
        token: Option<String>,
    ) -> (reqwest_middleware::ClientWithMiddleware, MockServer) {
        let handler = ClientManagedTokenHandler::new(Arc::new(MockTokenProvider { token }));
        (client_with((*handler).clone()), start_server().await)
    }

    /// Sends one GET to `/test` (tagged with `AuthRequired::Bearer` when `auth_required` is
    /// set), asserts it was the only request the server saw, and returns the `Authorization`
    /// header the server received, if any.
    async fn send_and_capture_auth(
        client: &reqwest_middleware::ClientWithMiddleware,
        server: &MockServer,
        auth_required: bool,
    ) -> Option<String> {
        let mut request = client.get(format!("{}/test", server.uri()));
        if auth_required {
            request = request.with_extension(bitwarden_api_base::AuthRequired::Bearer);
        }
        request.send().await.unwrap();

        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        requests[0]
            .headers
            .get("Authorization")
            .map(|v| v.to_str().unwrap().to_owned())
    }

    #[tokio::test]
    async fn attaches_bearer_token_when_auth_required() {
        let (client, server) = test_setup(Some("test-token".to_string())).await;

        assert_eq!(
            send_and_capture_auth(&client, &server, true).await.as_deref(),
            Some("Bearer test-token")
        );
    }

    #[tokio::test]
    async fn does_not_attach_token_without_auth_required() {
        let (client, server) = test_setup(Some("test-token".to_string())).await;

        assert_eq!(send_and_capture_auth(&client, &server, false).await, None);
    }

    #[tokio::test]
    async fn does_not_attach_token_when_provider_returns_none() {
        let (client, server) = test_setup(None).await;

        assert_eq!(send_and_capture_auth(&client, &server, true).await, None);
    }

    #[tokio::test]
    async fn noop_handler_never_attaches_token() {
        let client = client_with(NoopTokenHandler);
        let server = start_server().await;

        assert_eq!(send_and_capture_auth(&client, &server, true).await, None);
    }
}