bitwarden_core/client/
flags_client.rs1use std::{collections::HashMap, sync::Arc};
4
5use bitwarden_state::Setting;
6use chrono::{DateTime, Duration, Utc};
7#[cfg(feature = "wasm")]
8use wasm_bindgen::prelude::*;
9
10use crate::{
11 Client,
12 client::{
13 flags::Flags,
14 internal::ApiConfigurations,
15 persisted_state::{FLAGS, FLAGS_FETCHED_AT},
16 },
17};
18
19const FLAGS_TTL: Duration = Duration::hours(1);
20
21#[derive(Debug, thiserror::Error)]
23pub enum FetchFlagsError {
24 #[error("failed to fetch /config: {0}")]
26 Api(#[from] bitwarden_api_api::apis::Error),
27 #[error("state access error: {0}")]
29 State(#[from] bitwarden_state::SettingsError),
30}
31
32#[cfg_attr(feature = "wasm", wasm_bindgen)]
34pub struct FlagsClient {
35 flags: Setting<Flags>,
36 flags_fetched_at: Setting<DateTime<Utc>>,
37 api_configurations: Arc<ApiConfigurations>,
38}
39
40impl FlagsClient {
41 pub async fn load(&self, flags: HashMap<String, bool>) {
43 let flags = Flags::load_from_map(flags);
44 if let Err(e) = self.flags.update(flags).await {
45 tracing::warn!("Failed to persist flags: {e}");
46 }
47 }
48
49 pub async fn get(&self) -> Flags {
51 match self.flags.get().await {
52 Ok(flags) => flags.unwrap_or_default(),
53 Err(e) => {
54 tracing::warn!("Failed to read flags, using defaults: {e}");
55 Flags::default()
56 }
57 }
58 }
59
60 pub async fn fetch(&self, force: bool) -> Result<(), FetchFlagsError> {
66 if !force {
67 let last: Option<DateTime<Utc>> = self.flags_fetched_at.get().await?;
68 if let Some(fetched_at) = last
69 && Utc::now().signed_duration_since(fetched_at) < FLAGS_TTL
70 {
71 return Ok(());
72 }
73 }
74
75 let config = self
76 .api_configurations
77 .api_client
78 .config_api()
79 .get_configs()
80 .await?;
81 let feature_states = config.feature_states.unwrap_or_default();
82 let bool_map = feature_states
85 .into_iter()
86 .filter_map(|(k, v)| v.as_bool().map(|b| (k, b)))
87 .collect();
88 self.load(bool_map).await;
89 self.flags_fetched_at.update(Utc::now()).await?;
90 Ok(())
91 }
92}
93
94impl Client {
95 pub fn flags(&self) -> FlagsClient {
97 let registry = &self.internal.state_registry;
98 FlagsClient {
99 flags: registry
100 .setting(FLAGS)
101 .expect("Settings repository must be registered on the state registry"),
102 flags_fetched_at: registry
103 .setting(FLAGS_FETCHED_AT)
104 .expect("Settings repository must be registered on the state registry"),
105 api_configurations: self.internal.api_configurations.clone(),
106 }
107 }
108}
109
#[cfg(test)]
mod tests {
    use serde_json::json;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::{ClientSettings, DeviceType};

    /// Client settings whose identity and API URLs both point at `server`.
    fn settings_for(server: &MockServer) -> ClientSettings {
        let base_url = format!("http://{}", server.address());
        ClientSettings {
            identity_url: base_url.clone(),
            api_url: base_url,
            user_agent: "flags-tests".to_string(),
            device_type: DeviceType::SDK,
            device_identifier: None,
            bitwarden_client_version: None,
            bitwarden_package_type: None,
        }
    }

    /// A `GET /config` mock answering with `response`; callers add expectations and mount it.
    fn config_mock(response: ResponseTemplate) -> Mock {
        Mock::given(method("GET"))
            .and(path("/config"))
            .respond_with(response)
    }

    /// A `/config` body that turns on `enableCipherKeyEncryption`.
    fn enabled_config() -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(json!({
            "featureStates": { "enableCipherKeyEncryption": true }
        }))
    }

    /// An empty `/config` body.
    fn empty_config() -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(json!({}))
    }

    /// The persisted "last fetched at" setting of `client`.
    fn fetched_at_setting(client: &Client) -> Setting<DateTime<Utc>> {
        client
            .internal
            .state_registry
            .setting(FLAGS_FETCHED_AT)
            .unwrap()
    }

    async fn write_fetched_at(client: &Client, at: DateTime<Utc>) {
        fetched_at_setting(client).update(at).await.unwrap();
    }

    async fn read_fetched_at(client: &Client) -> Option<DateTime<Utc>> {
        fetched_at_setting(client).get().await.unwrap()
    }

    #[tokio::test]
    async fn load_round_trips_through_setting() {
        let client = Client::new(None);

        // Nothing has been persisted yet, so both flags read as false.
        let before_load = client.flags().get().await;
        assert!(!before_load.enable_cipher_key_encryption);
        assert!(!before_load.strict_cipher_decryption);

        let input = HashMap::from([
            ("enableCipherKeyEncryption".to_string(), true),
            ("pm-34500-strict-cipher-decryption".to_string(), true),
        ]);
        client.flags().load(input).await;

        let after_load = client.flags().get().await;
        assert!(after_load.enable_cipher_key_encryption);
        assert!(after_load.strict_cipher_decryption);

        // The values must also be visible when reading the setting directly.
        let stored = client
            .internal
            .state_registry
            .setting(FLAGS)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("flags should be persisted after load");
        assert!(stored.enable_cipher_key_encryption);
        assert!(stored.strict_cipher_decryption);
    }

    #[tokio::test]
    async fn fetch_force_persists_flags_and_timestamp() {
        let server = MockServer::start().await;
        config_mock(enabled_config()).expect(1).mount(&server).await;

        let client = Client::new(Some(settings_for(&server)));
        let started = Utc::now();
        client.flags().fetch(true).await.unwrap();

        assert!(client.flags().get().await.enable_cipher_key_encryption);
        let stamped = read_fetched_at(&client)
            .await
            .expect("fetched_at must be set after a successful fetch");
        assert!(stamped >= started);
    }

    #[tokio::test]
    async fn fetch_skips_when_fresh() {
        let server = MockServer::start().await;
        config_mock(empty_config()).expect(0).mount(&server).await;

        let client = Client::new(Some(settings_for(&server)));
        write_fetched_at(&client, Utc::now() - Duration::minutes(5)).await;

        client.flags().fetch(false).await.unwrap();
    }

    #[tokio::test]
    async fn fetch_force_ignores_ttl() {
        let server = MockServer::start().await;
        config_mock(empty_config()).expect(1).mount(&server).await;

        let client = Client::new(Some(settings_for(&server)));
        write_fetched_at(&client, Utc::now() - Duration::minutes(5)).await;

        client.flags().fetch(true).await.unwrap();
    }

    #[tokio::test]
    async fn fetch_refreshes_when_stale() {
        let server = MockServer::start().await;
        config_mock(enabled_config()).expect(1).mount(&server).await;

        let client = Client::new(Some(settings_for(&server)));
        let two_hours_ago = Utc::now() - Duration::hours(2);
        write_fetched_at(&client, two_hours_ago).await;

        client.flags().fetch(false).await.unwrap();

        assert!(client.flags().get().await.enable_cipher_key_encryption);
        let refreshed_at = read_fetched_at(&client).await.unwrap();
        assert!(refreshed_at > two_hours_ago);
    }

    #[tokio::test]
    async fn fetch_network_error_is_non_fatal_and_preserves_flags() {
        let server = MockServer::start().await;
        config_mock(ResponseTemplate::new(500)).mount(&server).await;

        let client = Client::new(Some(settings_for(&server)));
        let seeded = HashMap::from([("enableCipherKeyEncryption".to_string(), true)]);
        client.flags().load(seeded).await;

        let outcome = client.flags().fetch(true).await;
        assert!(outcome.is_err());
        assert!(
            client.flags().get().await.enable_cipher_key_encryption,
            "previously persisted flags must survive a failed fetch"
        );
    }
}