Skip to main content

bitwarden_core/client/
flags_client.rs

1//! Feature flag retrieval, persistence, and refresh from the server `/config` endpoint.
2
3use std::{collections::HashMap, sync::Arc};
4
5use bitwarden_state::Setting;
6use chrono::{DateTime, Duration, Utc};
7#[cfg(feature = "wasm")]
8use wasm_bindgen::prelude::*;
9
10use crate::{
11    Client,
12    client::{
13        flags::Flags,
14        internal::ApiConfigurations,
15        persisted_state::{FLAGS, FLAGS_FETCHED_AT},
16    },
17};
18
/// How long flags fetched from `/config` are treated as fresh.
///
/// [`FlagsClient::fetch`] called with `force = false` skips the network request while the
/// stored `fetched_at` timestamp is younger than this.
const FLAGS_TTL: Duration = Duration::hours(1);
20
/// Errors returned by [`FlagsClient::fetch`].
#[derive(Debug, thiserror::Error)]
pub enum FetchFlagsError {
    /// The `/config` request failed: a transport error, a non-success HTTP status, or a
    /// response body that could not be deserialized.
    #[error("failed to fetch /config: {0}")]
    Api(#[from] bitwarden_api_api::apis::Error),
    /// Accessing the state registry failed (e.g. reading or writing the `fetched_at`
    /// timestamp setting).
    #[error("state access error: {0}")]
    State(#[from] bitwarden_state::SettingsError),
}
31
/// A client for inspecting and refreshing feature flags.
///
/// Obtained via [`Client::flags`]. It holds handles to the persisted flag settings and the
/// API configuration used to reach the server's `/config` endpoint.
#[cfg_attr(feature = "wasm", wasm_bindgen)]
pub struct FlagsClient {
    /// Persisted flag values (the `FLAGS` setting).
    flags: Setting<Flags>,
    /// Time of the last successful `/config` fetch (the `FLAGS_FETCHED_AT` setting); used
    /// by `fetch` to enforce the TTL.
    flags_fetched_at: Setting<DateTime<Utc>>,
    /// API configuration whose client issues the `/config` request.
    api_configurations: Arc<ApiConfigurations>,
}
39
40impl FlagsClient {
41    /// Persist a flag map (e.g. from `/config`) into the state registry.
42    pub async fn load(&self, flags: HashMap<String, bool>) {
43        let flags = Flags::load_from_map(flags);
44        if let Err(e) = self.flags.update(flags).await {
45            tracing::warn!("Failed to persist flags: {e}");
46        }
47    }
48
49    /// Retrieve the active feature flags from the state registry.
50    pub async fn get(&self) -> Flags {
51        match self.flags.get().await {
52            Ok(flags) => flags.unwrap_or_default(),
53            Err(e) => {
54                tracing::warn!("Failed to read flags, using defaults: {e}");
55                Flags::default()
56            }
57        }
58    }
59
60    /// Fetch flags from `/config` and persist both the flag values and a `fetched_at` timestamp.
61    ///
62    /// Pass `force = true` from `from_authenticated_data` (PM-27624) immediately before
63    /// `save_to_state`, so the initial flag fetch is part of the persisted login state.
64    /// [`Client::load_from_state`] calls this with `force = false` to honour the 1-hour TTL.
65    pub async fn fetch(&self, force: bool) -> Result<(), FetchFlagsError> {
66        if !force {
67            let last: Option<DateTime<Utc>> = self.flags_fetched_at.get().await?;
68            if let Some(fetched_at) = last
69                && Utc::now().signed_duration_since(fetched_at) < FLAGS_TTL
70            {
71                return Ok(());
72            }
73        }
74
75        let config = self
76            .api_configurations
77            .api_client
78            .config_api()
79            .get_configs()
80            .await?;
81        let feature_states = config.feature_states.unwrap_or_default();
82        // `/config` returns `serde_json::Value`; coerce to bool. Non-bool values are dropped
83        // because `Flags` only models boolean flags today.
84        let bool_map = feature_states
85            .into_iter()
86            .filter_map(|(k, v)| v.as_bool().map(|b| (k, b)))
87            .collect();
88        self.load(bool_map).await;
89        self.flags_fetched_at.update(Utc::now()).await?;
90        Ok(())
91    }
92}
93
94impl Client {
95    /// Access to feature flag retrieval, persistence, and refresh.
96    pub fn flags(&self) -> FlagsClient {
97        let registry = &self.internal.state_registry;
98        FlagsClient {
99            flags: registry
100                .setting(FLAGS)
101                .expect("Settings repository must be registered on the state registry"),
102            flags_fetched_at: registry
103                .setting(FLAGS_FETCHED_AT)
104                .expect("Settings repository must be registered on the state registry"),
105            api_configurations: self.internal.api_configurations.clone(),
106        }
107    }
108}
109
#[cfg(test)]
mod tests {
    //! Tests for `FlagsClient`: round-tripping flags through state, and the TTL / `force`
    //! behaviour of `FlagsClient::fetch` against a mocked `/config` endpoint.

    use serde_json::json;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::{ClientSettings, DeviceType};

    /// Client settings that point both the identity and API URLs at `server`.
    fn settings_for(server: &MockServer) -> ClientSettings {
        ClientSettings {
            identity_url: format!("http://{}", server.address()),
            api_url: format!("http://{}", server.address()),
            user_agent: "flags-tests".to_string(),
            device_type: DeviceType::SDK,
            device_identifier: None,
            bitwarden_client_version: None,
            bitwarden_package_type: None,
        }
    }

    /// Seed the `FLAGS_FETCHED_AT` setting directly, bypassing `fetch`.
    async fn write_fetched_at(client: &Client, at: DateTime<Utc>) {
        client
            .internal
            .state_registry
            .setting(FLAGS_FETCHED_AT)
            .unwrap()
            .update(at)
            .await
            .unwrap();
    }

    /// Read the `FLAGS_FETCHED_AT` setting directly; `None` if it was never written.
    async fn read_fetched_at(client: &Client) -> Option<DateTime<Utc>> {
        client
            .internal
            .state_registry
            .setting(FLAGS_FETCHED_AT)
            .unwrap()
            .get()
            .await
            .unwrap()
    }

    /// `load` persists into the `FLAGS` setting and `get` reads it back.
    #[tokio::test]
    async fn load_round_trips_through_setting() {
        let client = Client::new(None);

        // With no flags loaded yet, get should return defaults.
        let initial = client.flags().get().await;
        assert!(!initial.enable_cipher_key_encryption);
        assert!(!initial.strict_cipher_decryption);

        // Loading flags should persist them via the FLAGS setting.
        let mut map = HashMap::new();
        map.insert("enableCipherKeyEncryption".to_string(), true);
        map.insert("pm-34500-strict-cipher-decryption".to_string(), true);
        client.flags().load(map).await;

        // get should now return the loaded values.
        let loaded = client.flags().get().await;
        assert!(loaded.enable_cipher_key_encryption);
        assert!(loaded.strict_cipher_decryption);

        // The values should be readable directly from the setting too.
        let persisted = client
            .internal
            .state_registry
            .setting(FLAGS)
            .unwrap()
            .get()
            .await
            .unwrap()
            .expect("flags should be persisted after load");
        assert!(persisted.enable_cipher_key_encryption);
        assert!(persisted.strict_cipher_decryption);
    }

    /// A forced fetch stores both the flag values and a fresh timestamp.
    #[tokio::test]
    async fn fetch_force_persists_flags_and_timestamp() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/config"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "featureStates": { "enableCipherKeyEncryption": true }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = Client::new(Some(settings_for(&server)));
        let before = Utc::now();
        client.flags().fetch(true).await.unwrap();

        assert!(client.flags().get().await.enable_cipher_key_encryption);
        let fetched_at = read_fetched_at(&client)
            .await
            .expect("fetched_at must be set after a successful fetch");
        assert!(fetched_at >= before);
    }

    /// A non-forced fetch within the TTL must not hit the network.
    #[tokio::test]
    async fn fetch_skips_when_fresh() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/config"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({})))
            .expect(0)
            .mount(&server)
            .await;

        let client = Client::new(Some(settings_for(&server)));
        write_fetched_at(&client, Utc::now() - Duration::minutes(5)).await;

        client.flags().fetch(false).await.unwrap();
    }

    /// `force = true` fetches even within the TTL and refreshes the timestamp.
    #[tokio::test]
    async fn fetch_force_ignores_ttl() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/config"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({})))
            .expect(1)
            .mount(&server)
            .await;

        let client = Client::new(Some(settings_for(&server)));
        let seeded = Utc::now() - Duration::minutes(5);
        write_fetched_at(&client, seeded).await;

        client.flags().fetch(true).await.unwrap();

        let fetched_at = read_fetched_at(&client)
            .await
            .expect("fetched_at must still be set after a forced fetch");
        assert!(fetched_at > seeded, "a forced fetch must refresh the timestamp");
    }

    /// A non-forced fetch past the TTL refetches and advances the timestamp.
    #[tokio::test]
    async fn fetch_refreshes_when_stale() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/config"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "featureStates": { "enableCipherKeyEncryption": true }
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = Client::new(Some(settings_for(&server)));
        let stale = Utc::now() - Duration::hours(2);
        write_fetched_at(&client, stale).await;

        client.flags().fetch(false).await.unwrap();

        assert!(client.flags().get().await.enable_cipher_key_encryption);
        let fetched_at = read_fetched_at(&client).await.unwrap();
        assert!(fetched_at > stale);
    }

    /// A failed `/config` request returns an error, keeps previously persisted flags, and
    /// does not record a timestamp (otherwise the TTL would suppress the retry).
    #[tokio::test]
    async fn fetch_network_error_is_non_fatal_and_preserves_flags() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/config"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let client = Client::new(Some(settings_for(&server)));
        client
            .flags()
            .load(HashMap::from([(
                "enableCipherKeyEncryption".to_string(),
                true,
            )]))
            .await;

        assert!(client.flags().fetch(true).await.is_err());
        assert!(
            client.flags().get().await.enable_cipher_key_encryption,
            "previously persisted flags must survive a failed fetch"
        );
        assert!(
            read_fetched_at(&client).await.is_none(),
            "a failed fetch must not record a fetched_at timestamp"
        );
    }
}