bitwarden_core/auth/
auth_tokens.rs1use std::sync::{Arc, RwLock};
5
6use bitwarden_crypto::KeyStore;
7
8use crate::{client::LoginMethod, key_management::KeyIds};
9
/// Source of request middleware and sink for token updates.
///
/// Implementors hand out a [`reqwest_middleware::Middleware`] through
/// [`TokenHandler::initialize_middleware`] and receive new token material through
/// [`TokenHandler::set_tokens`].
#[async_trait::async_trait]
pub trait TokenHandler: 'static + Send + Sync {
    /// Builds the middleware provided by this handler.
    ///
    /// The handler is given a shared, lock-protected slot holding the optional current
    /// [`LoginMethod`], an API configuration (`identity_config`) and the [`KeyStore`].
    /// Both implementations in this file ignore all three arguments.
    fn initialize_middleware(
        &self,
        login_method: Arc<RwLock<Option<Arc<LoginMethod>>>>,
        identity_config: bitwarden_api_base::Configuration,
        key_store: KeyStore<KeyIds>,
    ) -> Arc<dyn reqwest_middleware::Middleware>;

    /// Passes a new access token, an optional refresh token and an `expires_in` value to the
    /// handler.
    ///
    /// NOTE(review): `expires_in` is presumably a lifetime in seconds; confirm against callers.
    /// Both implementations in this file reject the call by panicking.
    async fn set_tokens(&self, token: String, refresh_token: Option<String>, expires_in: u64);
}
31
/// Provider of access tokens that are obtained outside of this file.
///
/// With the `uniffi` feature enabled, the trait is exported via `uniffi::export(with_foreign)`;
/// presumably so that foreign-language bindings can implement it (verify against the UniFFI
/// documentation).
#[cfg_attr(feature = "uniffi", uniffi::export(with_foreign))]
#[async_trait::async_trait]
pub trait ClientManagedTokens: std::fmt::Debug + Send + Sync {
    /// Returns the access token to use, or `None` when no token is available.
    async fn get_access_token(&self) -> Option<String>;
}
39
/// [`TokenHandler`] backed by a [`ClientManagedTokens`] provider.
///
/// The handler itself is the middleware: for requests flagged with
/// `bitwarden_api_base::AuthRequired` it asks the provider for an access token and, if one is
/// returned, sends it as a `Bearer` `Authorization` header.
#[derive(Clone)]
pub struct ClientManagedTokenHandler {
    /// Provider queried for the access token on each authenticated request.
    tokens: Arc<dyn ClientManagedTokens>,
}
45
46impl ClientManagedTokenHandler {
47 pub fn new(tokens: Arc<dyn ClientManagedTokens>) -> Arc<Self> {
49 Arc::new(Self { tokens })
50 }
51}
52
53#[async_trait::async_trait]
54impl TokenHandler for ClientManagedTokenHandler {
55 fn initialize_middleware(
56 &self,
57 _login_method: Arc<RwLock<Option<Arc<LoginMethod>>>>,
58 _identity_config: bitwarden_api_base::Configuration,
59 _key_store: KeyStore<KeyIds>,
60 ) -> Arc<dyn reqwest_middleware::Middleware> {
61 Arc::new(self.clone())
62 }
63
64 async fn set_tokens(&self, _token: String, _refresh_token: Option<String>, _expires_on: u64) {
65 panic!("Client-managed tokens cannot be set by the SDK");
66 }
67}
68
69#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
70#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
71impl reqwest_middleware::Middleware for ClientManagedTokenHandler {
72 async fn handle(
73 &self,
74 mut req: reqwest::Request,
75 ext: &mut http::Extensions,
76 next: reqwest_middleware::Next<'_>,
77 ) -> Result<reqwest::Response, reqwest_middleware::Error> {
78 if ext.get::<bitwarden_api_base::AuthRequired>().is_some()
79 && let Some(token) = self.tokens.get_access_token().await
80 {
81 match format!("Bearer {}", token).parse() {
82 Ok(header_value) => {
83 req.headers_mut()
84 .insert(http::header::AUTHORIZATION, header_value);
85 }
86 Err(e) => {
87 tracing::warn!("Failed to parse auth token for header: {e}");
88 }
89 }
90 }
91
92 let resp = next.run(req, ext).await?;
93
94 Ok(resp)
95 }
96}
97
/// [`TokenHandler`] that adds no authentication.
///
/// Its middleware forwards every request unchanged, and it panics if asked to store tokens.
#[derive(Clone, Copy)]
pub struct NoopTokenHandler;
102
103#[async_trait::async_trait]
104impl TokenHandler for NoopTokenHandler {
105 fn initialize_middleware(
106 &self,
107 _login_method: Arc<RwLock<Option<Arc<LoginMethod>>>>,
108 _identity_config: bitwarden_api_base::Configuration,
109 _key_store: KeyStore<KeyIds>,
110 ) -> Arc<dyn reqwest_middleware::Middleware> {
111 Arc::new(*self)
112 }
113
114 async fn set_tokens(&self, _token: String, _refresh_token: Option<String>, _expires_on: u64) {
115 panic!("Cannot set tokens on NoopTokenHandler");
116 }
117}
118
119#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
120#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
121impl reqwest_middleware::Middleware for NoopTokenHandler {
122 async fn handle(
123 &self,
124 req: reqwest::Request,
125 ext: &mut http::Extensions,
126 next: reqwest_middleware::Next<'_>,
127 ) -> Result<reqwest::Response, reqwest_middleware::Error> {
128 next.run(req, ext).await
129 }
130}
131
#[cfg(test)]
mod tests {
    use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
    use wiremock::{Mock, MockServer, ResponseTemplate, matchers::any};

    use super::*;

    /// Token provider stub that always answers with the value it was built with.
    #[derive(Debug)]
    struct MockTokenProvider {
        token: Option<String>,
    }

    #[async_trait::async_trait]
    impl ClientManagedTokens for MockTokenProvider {
        async fn get_access_token(&self) -> Option<String> {
            self.token.clone()
        }
    }

    /// Builds a client using [`ClientManagedTokenHandler`] as its only middleware, together with
    /// a mock server that answers every request with `200`.
    async fn test_setup(token: Option<String>) -> (ClientWithMiddleware, MockServer) {
        let provider = Arc::new(MockTokenProvider { token });
        let handler = ClientManagedTokenHandler::new(provider);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(ClientManagedTokenHandler::clone(&handler))
            .build();

        let server = MockServer::start().await;
        Mock::given(any())
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        (client, server)
    }

    /// Sends `GET /test` to the mock server, optionally tagging the request with
    /// `AuthRequired::Bearer`.
    async fn get_test_endpoint(
        client: &ClientWithMiddleware,
        server: &MockServer,
        auth_required: bool,
    ) {
        let mut request = client.get(format!("{}/test", server.uri()));
        if auth_required {
            request = request.with_extension(bitwarden_api_base::AuthRequired::Bearer);
        }
        request.send().await.unwrap();
    }

    /// Asserts that the server saw exactly one request and returns its `Authorization` header.
    async fn sole_authorization_header(server: &MockServer) -> Option<String> {
        let requests = server.received_requests().await.unwrap();
        assert_eq!(requests.len(), 1);
        requests[0]
            .headers
            .get("Authorization")
            .map(|value| value.to_str().unwrap().to_owned())
    }

    #[tokio::test]
    async fn attaches_bearer_token_when_auth_required() {
        let (client, server) = test_setup(Some("test-token".to_string())).await;

        get_test_endpoint(&client, &server, true).await;

        assert_eq!(
            sole_authorization_header(&server).await.as_deref(),
            Some("Bearer test-token")
        );
    }

    #[tokio::test]
    async fn does_not_attach_token_without_auth_required() {
        let (client, server) = test_setup(Some("test-token".to_string())).await;

        get_test_endpoint(&client, &server, false).await;

        assert_eq!(sole_authorization_header(&server).await, None);
    }

    #[tokio::test]
    async fn does_not_attach_token_when_provider_returns_none() {
        let (client, server) = test_setup(None).await;

        get_test_endpoint(&client, &server, true).await;

        assert_eq!(sole_authorization_header(&server).await, None);
    }
}