bitwarden_auth/token_management/
password_manager_token_handler.rs1use std::sync::{Arc, RwLock};
4
5use bitwarden_core::{
6 NotAuthenticatedError,
7 auth::{TokenHandler, login::LoginError},
8 client::{
9 login_method::UserLoginMethod,
10 persisted_state::{AUTHENTICATION_TOKENS, AuthenticationTokens, USER_LOGIN_METHOD},
11 },
12 key_management::KeySlotIds,
13};
14use bitwarden_crypto::KeyStore;
15use bitwarden_state::{registry::StateRegistry, settings::Setting};
16use chrono::Utc;
17
18use super::middleware::{MiddlewareExt, MiddlewareWrapper};
19use crate::token_management::middleware::TOKEN_RENEW_MARGIN_SECONDS;
20
21#[derive(Clone, Default)]
23pub struct PasswordManagerTokenHandler {
24 inner: Arc<RwLock<PasswordManagerTokenHandlerInner>>,
25}
26
27#[derive(Clone, Default)]
28struct PasswordManagerTokenHandlerInner {
29 tokens: Option<Setting<AuthenticationTokens>>,
31 login_method: Option<Setting<UserLoginMethod>>,
32 identity_config: Option<bitwarden_api_api::Configuration>,
33}
34
35#[async_trait::async_trait]
36impl TokenHandler for PasswordManagerTokenHandler {
37 fn initialize_middleware(
38 &self,
39 state_registry: &StateRegistry,
40 identity_config: bitwarden_api_api::Configuration,
41 _key_store: KeyStore<KeySlotIds>,
42 ) -> Arc<dyn reqwest_middleware::Middleware> {
43 {
44 let mut inner = self.inner.write().expect("RwLock is not poisoned");
45 inner.tokens = state_registry.setting(AUTHENTICATION_TOKENS).ok();
46 inner.login_method = state_registry.setting(USER_LOGIN_METHOD).ok();
47 inner.identity_config = Some(identity_config);
48 }
49 Arc::new(MiddlewareWrapper(self.clone()))
50 }
51
52 async fn set_tokens(
53 &self,
54 access_token: String,
55 refresh_token: Option<String>,
56 expires_in: u64,
57 ) {
58 let tokens = self
59 .inner
60 .read()
61 .expect("RwLock is not poisoned")
62 .tokens
63 .clone();
64
65 if let Some(tokens) = tokens {
66 tokens
67 .update(AuthenticationTokens {
68 access_token,
69 refresh_token,
70 expires_on: Utc::now().timestamp() + expires_in as i64,
71 })
72 .await
73 .ok();
74 }
75 }
76}
77
78#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
79#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
80impl MiddlewareExt for PasswordManagerTokenHandler {
81 async fn get_token(&self) -> Result<Option<String>, LoginError> {
82 let inner = self.inner.read().expect("RwLock is not poisoned").clone();
88
89 let tokens = inner
90 .tokens
91 .ok_or(NotAuthenticatedError)?
92 .get()
93 .await
94 .ok()
95 .flatten()
96 .ok_or(NotAuthenticatedError)?;
97
98 if Utc::now().timestamp() < tokens.expires_on - TOKEN_RENEW_MARGIN_SECONDS {
100 return Ok(Some(tokens.access_token.clone()));
101 }
102
103 let login_method = inner.login_method.ok_or(NotAuthenticatedError)?;
106 let identity_config = inner.identity_config.ok_or(NotAuthenticatedError)?;
107
108 let login_method = login_method
109 .get()
110 .await
111 .ok()
112 .flatten()
113 .ok_or(NotAuthenticatedError)?;
114
115 let (access_token, refresh_token, expires_in) =
116 bitwarden_core::auth::renew::renew_pm_token_sdk_managed(
117 tokens.refresh_token,
118 &login_method,
119 identity_config,
120 )
121 .await?;
122
123 self.set_tokens(access_token.clone(), refresh_token, expires_in)
124 .await;
125 Ok(Some(access_token))
126 }
127}
128
#[cfg(test)]
mod tests {
    use bitwarden_core::{
        auth::TokenHandler,
        client::{
            login_method::UserLoginMethod,
            persisted_state::{AuthenticationTokens, USER_LOGIN_METHOD},
        },
        key_management::KeySlotIds,
    };
    use bitwarden_crypto::{Kdf, KeyStore};
    use bitwarden_state::registry::StateRegistry;
    use wiremock::MockServer;

    use super::*;
    use crate::token_management::test_utils::*;

    /// In-memory registry with an API-key login method persisted (no tokens yet).
    async fn registry_with_api_key_login() -> StateRegistry {
        let registry = StateRegistry::new_with_memory_db();
        registry
            .setting(USER_LOGIN_METHOD)
            .unwrap()
            .update(UserLoginMethod::ApiKey {
                client_id: "test-client".to_string(),
                client_secret: "test-secret".to_string(),
                email: "[email protected]".to_string(),
                kdf: Kdf::default_pbkdf2(),
            })
            .await
            .unwrap();
        registry
    }

    /// Persists a token pair expiring `expires_in` seconds from now.
    async fn seed_tokens(
        registry: &StateRegistry,
        access_token: &str,
        refresh_token: Option<&str>,
        expires_in: i64,
    ) {
        registry
            .setting(AUTHENTICATION_TOKENS)
            .unwrap()
            .update(AuthenticationTokens {
                access_token: access_token.to_string(),
                refresh_token: refresh_token.map(str::to_string),
                expires_on: Utc::now().timestamp() + expires_in,
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn attaches_existing_token_when_not_expired() {
        let app_server = start_app_server().await;
        let identity_server = MockServer::start().await;

        let registry = registry_with_api_key_login().await;
        seed_tokens(&registry, "original-token", Some("refresh"), 5000).await;

        let handler = PasswordManagerTokenHandler::default();
        let client = build_client(handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        ));

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer original-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn renews_expired_token() {
        let app_server = start_app_server().await;
        let identity_server = start_renewal_server("renewed-token").await;

        let registry = registry_with_api_key_login().await;
        seed_tokens(&registry, "expired-token", Some("old-refresh"), 0).await;

        let handler = PasswordManagerTokenHandler::default();
        let client = build_client(handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        ));

        let auth = send_auth_request(&client, &app_server).await;
        assert_eq!(auth.as_deref(), Some("Bearer renewed-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 1);
        assert_eq!(app_server.received_requests().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn get_token_fails_when_no_tokens_are_stored() {
        let identity_server = MockServer::start().await;
        let registry = registry_with_api_key_login().await;

        let handler = PasswordManagerTokenHandler::default();
        let _middleware = handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        );

        assert!(MiddlewareExt::get_token(&handler).await.is_err());
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
    }

    #[tokio::test]
    async fn set_tokens_are_returned_without_renewal() {
        let identity_server = MockServer::start().await;
        let registry = registry_with_api_key_login().await;

        let handler = PasswordManagerTokenHandler::default();
        let _middleware = handler.initialize_middleware(
            &registry,
            identity_config(&identity_server.uri()),
            KeyStore::<KeySlotIds>::default(),
        );

        handler
            .set_tokens("fresh-token".to_string(), Some("refresh".to_string()), 5000)
            .await;

        let token = MiddlewareExt::get_token(&handler).await.ok().flatten();
        assert_eq!(token.as_deref(), Some("fresh-token"));
        assert_eq!(identity_server.received_requests().await.unwrap().len(), 0);
    }
}