bitwarden_core/auth/
auth_tokens.rs1use std::sync::{Arc, RwLock};
5
6use bitwarden_crypto::KeyStore;
7
8use crate::{client::LoginMethod, key_management::KeyIds};
9
10pub trait TokenHandler: 'static + Send + Sync {
12 fn initialize_middleware(
17 &self,
18 login_method: Arc<RwLock<Option<Arc<LoginMethod>>>>,
19 identity_config: bitwarden_api_base::Configuration,
20 key_store: KeyStore<KeyIds>,
21 ) -> Arc<dyn reqwest_middleware::Middleware>;
22
23 fn set_tokens(&self, token: String, refresh_token: Option<String>, expires_in: u64);
29}
30
/// Client-implemented provider of access tokens, consumed by [`ClientManagedTokenHandler`].
///
/// With the `uniffi` feature enabled the trait is exported through UniFFI with `with_foreign`,
/// presumably so host applications can supply the implementation — confirm against bindings.
#[cfg_attr(feature = "uniffi", uniffi::export(with_foreign))]
#[async_trait::async_trait]
pub trait ClientManagedTokens: std::fmt::Debug + Send + Sync {
    /// Returns the access token to use right now, or `None` when none is available.
    ///
    /// When this returns `None`, [`ClientManagedTokenHandler`]'s middleware sends the request
    /// without an `Authorization` header.
    async fn get_access_token(&self) -> Option<String>;
}
38
39#[derive(Clone)]
41pub struct ClientManagedTokenHandler {
42 tokens: Arc<dyn ClientManagedTokens>,
43}
44
45impl ClientManagedTokenHandler {
46 pub fn new(tokens: Arc<dyn ClientManagedTokens>) -> Arc<Self> {
48 Arc::new(Self { tokens })
49 }
50}
51
52impl TokenHandler for ClientManagedTokenHandler {
53 fn initialize_middleware(
54 &self,
55 _login_method: Arc<RwLock<Option<Arc<LoginMethod>>>>,
56 _identity_config: bitwarden_api_base::Configuration,
57 _key_store: KeyStore<KeyIds>,
58 ) -> Arc<dyn reqwest_middleware::Middleware> {
59 Arc::new(self.clone())
60 }
61
62 fn set_tokens(&self, _token: String, _refresh_token: Option<String>, _expires_on: u64) {
63 panic!("Client-managed tokens cannot be set by the SDK");
64 }
65}
66
67#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
68#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
69impl reqwest_middleware::Middleware for ClientManagedTokenHandler {
70 async fn handle(
71 &self,
72 mut req: reqwest::Request,
73 ext: &mut http::Extensions,
74 next: reqwest_middleware::Next<'_>,
75 ) -> Result<reqwest::Response, reqwest_middleware::Error> {
76 if ext.get::<bitwarden_api_base::AuthRequired>().is_some()
77 && let Some(token) = self.tokens.get_access_token().await
78 {
79 match format!("Bearer {}", token).parse() {
80 Ok(header_value) => {
81 req.headers_mut()
82 .insert(http::header::AUTHORIZATION, header_value);
83 }
84 Err(e) => {
85 tracing::warn!("Failed to parse auth token for header: {e}");
86 }
87 }
88 }
89
90 let resp = next.run(req, ext).await?;
91
92 Ok(resp)
93 }
94}
95
/// [`TokenHandler`] that never attaches credentials: its middleware forwards every request
/// unchanged.
#[derive(Clone, Copy, Debug)]
pub struct NoopTokenHandler;
100
101impl TokenHandler for NoopTokenHandler {
102 fn initialize_middleware(
103 &self,
104 _login_method: Arc<RwLock<Option<Arc<LoginMethod>>>>,
105 _identity_config: bitwarden_api_base::Configuration,
106 _key_store: KeyStore<KeyIds>,
107 ) -> Arc<dyn reqwest_middleware::Middleware> {
108 Arc::new(*self)
109 }
110
111 fn set_tokens(&self, _token: String, _refresh_token: Option<String>, _expires_on: u64) {
112 panic!("Cannot set tokens on NoopTokenHandler");
113 }
114}
115
116#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
117#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
118impl reqwest_middleware::Middleware for NoopTokenHandler {
119 async fn handle(
120 &self,
121 req: reqwest::Request,
122 ext: &mut http::Extensions,
123 next: reqwest_middleware::Next<'_>,
124 ) -> Result<reqwest::Response, reqwest_middleware::Error> {
125 next.run(req, ext).await
126 }
127}
128
#[cfg(test)]
mod tests {
    use wiremock::MockServer;

    use super::*;

    /// Test double that always hands out the same (optional) token.
    #[derive(Debug)]
    struct MockTokenProvider {
        token: Option<String>,
    }

    #[async_trait::async_trait]
    impl ClientManagedTokens for MockTokenProvider {
        async fn get_access_token(&self) -> Option<String> {
            self.token.clone()
        }
    }

    /// Builds a client routed through [`ClientManagedTokenHandler`] and a mock server that
    /// answers every request with `200 OK`.
    async fn test_setup(
        token: Option<String>,
    ) -> (reqwest_middleware::ClientWithMiddleware, MockServer) {
        let handler = ClientManagedTokenHandler::new(Arc::new(MockTokenProvider { token }));
        let middleware = ClientManagedTokenHandler::clone(&handler);

        let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
            .with(middleware)
            .build();

        let server = MockServer::start().await;
        wiremock::Mock::given(wiremock::matchers::any())
            .respond_with(wiremock::ResponseTemplate::new(200))
            .mount(&server)
            .await;

        (client, server)
    }

    /// Sends a GET to `/test`, tagged with `auth` when given, and returns every request the
    /// server has recorded.
    async fn send_and_capture(
        client: &reqwest_middleware::ClientWithMiddleware,
        server: &MockServer,
        auth: Option<bitwarden_api_base::AuthRequired>,
    ) -> Vec<wiremock::Request> {
        let url = format!("{}/test", server.uri());
        let builder = match auth {
            Some(auth) => client.get(url).with_extension(auth),
            None => client.get(url),
        };
        builder.send().await.unwrap();
        server.received_requests().await.unwrap()
    }

    #[tokio::test]
    async fn attaches_bearer_token_when_auth_required() {
        let (client, server) = test_setup(Some("test-token".to_string())).await;

        let requests = send_and_capture(
            &client,
            &server,
            Some(bitwarden_api_base::AuthRequired::Bearer),
        )
        .await;

        assert_eq!(requests.len(), 1);
        let authorization = requests[0]
            .headers
            .get("Authorization")
            .map(|v| v.to_str().unwrap());
        assert_eq!(authorization, Some("Bearer test-token"));
    }

    #[tokio::test]
    async fn does_not_attach_token_without_auth_required() {
        let (client, server) = test_setup(Some("test-token".to_string())).await;

        let requests = send_and_capture(&client, &server, None).await;

        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].headers.get("Authorization"), None);
    }

    #[tokio::test]
    async fn does_not_attach_token_when_provider_returns_none() {
        let (client, server) = test_setup(None).await;

        let requests = send_and_capture(
            &client,
            &server,
            Some(bitwarden_api_base::AuthRequired::Bearer),
        )
        .await;

        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].headers.get("Authorization"), None);
    }
}