Skip to main content

bitwarden_state_bridge_macro/
lib.rs

//! Proc macro for generating the key-management state bridge surface.
//!
//! Provides the `state_bridge!` function-like macro that, from a list of typed fields, expands
//! into the `StateBridgeImpl` trait, wrapper methods on `StateBridge` and `StateBridgeClient`,
//! WASM extern bindings, the `StateBridgeImpl` impl for `WasmStateBridge`, the matching
//! TypeScript interface, the UniFFI `StateBridgeForeignImpl` trait, and an in-memory test fixture.

7use proc_macro::TokenStream;
8use quote::{format_ident, quote};
9use syn::{
10    Ident, LitStr, Token, Type,
11    parse::{Parse, ParseStream},
12    parse_macro_input,
13    punctuated::Punctuated,
14};
15
16mod state_bridge_kw {
17    syn::custom_keyword!(ts);
18}
19
20struct StateBridgeField {
21    name: Ident,
22    ty: Type,
23    ts: LitStr,
24}
25
26impl Parse for StateBridgeField {
27    fn parse(input: ParseStream) -> syn::Result<Self> {
28        let name: Ident = input.parse()?;
29        input.parse::<Token![:]>()?;
30        let ty: Type = input.parse()?;
31        input.parse::<Token![as]>()?;
32        input.parse::<state_bridge_kw::ts>()?;
33        let ts: LitStr = input.parse()?;
34        Ok(Self { name, ty, ts })
35    }
36}
37
38struct StateBridgeInput {
39    fields: Punctuated<StateBridgeField, Token![,]>,
40}
41
42impl Parse for StateBridgeInput {
43    fn parse(input: ParseStream) -> syn::Result<Self> {
44        Ok(Self {
45            fields: Punctuated::parse_terminated(input)?,
46        })
47    }
48}
49
50/// Generates the full state bridge surface for a fixed list of fields.
51///
52/// Each field expands to:
53/// 1. Three methods on the `StateBridgeImpl` trait (`set_$name`, `get_$name`, `clear_$name`).
54/// 2. Three corresponding wrapper methods on `StateBridge`.
55/// 3. Three corresponding methods on `StateBridgeClient`.
56/// 4. Three method declarations on the WASM `RawWasmStateBridge` extern type and three forwarders
57///    in the `StateBridgeImpl` impl for `WasmStateBridge`.
58/// 5. Three lines in the `WasmStateBridge` TypeScript interface.
59/// 6. One field on `test_support::InMemoryStateBridge` and three forwarders in its
60///    `StateBridgeImpl` impl, gated on `#[cfg(test)]`.
61///
62/// All fields share the same shape: `set_$name(value: $ty)`, `get_$name() -> Option<$ty>`,
63/// `clear_$name()`.
64#[proc_macro]
65pub fn state_bridge(input: TokenStream) -> TokenStream {
66    let StateBridgeInput { fields } = parse_macro_input!(input as StateBridgeInput);
67
68    let trait_methods = fields.iter().map(|f| {
69        let ty = &f.ty;
70        let n = f.name.to_string();
71        let set = format_ident!("set_{}", f.name);
72        let get = format_ident!("get_{}", f.name);
73        let clear = format_ident!("clear_{}", f.name);
74        let set_doc = format!("Stores the `{n}` value.");
75        let get_doc = format!("Returns the `{n}` value, if available.");
76        let clear_doc = format!("Clears the `{n}` value.");
77        quote! {
78            #[doc = #set_doc]
79            async fn #set(&self, value: #ty);
80            #[doc = #get_doc]
81            async fn #get(&self) -> Option<#ty>;
82            #[doc = #clear_doc]
83            async fn #clear(&self);
84        }
85    });
86
87    let bridge_wrappers = fields.iter().map(|f| {
88        let ty = &f.ty;
89        let n = f.name.to_string();
90        let set = format_ident!("set_{}", f.name);
91        let get = format_ident!("get_{}", f.name);
92        let clear = format_ident!("clear_{}", f.name);
93        let set_doc = format!("Stores the `{n}` value.");
94        let get_doc = format!("Returns the `{n}` value, if available.");
95        let clear_doc = format!("Clears the `{n}` value.");
96        quote! {
97            #[doc = #set_doc]
98            pub async fn #set(&self, value: &#ty) {
99                let implementation = self
100                    .implementation
101                    .lock()
102                    .expect("Mutex is not poisoned")
103                    .as_ref()
104                    .expect("StateBridge not registered")
105                    .clone();
106                implementation.#set(value.to_owned()).await
107            }
108
109            #[doc = #get_doc]
110            pub async fn #get(&self) -> Option<#ty> {
111                let implementation = self
112                    .implementation
113                    .lock()
114                    .expect("Mutex is not poisoned")
115                    .as_ref()
116                    .expect("StateBridge not registered")
117                    .clone();
118                implementation.#get().await
119            }
120
121            #[doc = #clear_doc]
122            pub async fn #clear(&self) {
123                let implementation = self
124                    .implementation
125                    .lock()
126                    .expect("Mutex is not poisoned")
127                    .as_ref()
128                    .expect("StateBridge not registered")
129                    .clone();
130                implementation.#clear().await
131            }
132        }
133    });
134
135    let client_forwarders = fields.iter().map(|f| {
136        let ty = &f.ty;
137        let n = f.name.to_string();
138        let set = format_ident!("set_{}", f.name);
139        let get = format_ident!("get_{}", f.name);
140        let clear = format_ident!("clear_{}", f.name);
141        let set_doc = format!("Sets the `{n}` value in client-managed state.");
142        let get_doc = format!("Gets the `{n}` value from client-managed state, if available.");
143        let clear_doc = format!("Clears the `{n}` value from client-managed state.");
144        quote! {
145            #[doc = #set_doc]
146            pub async fn #set(&self, value: &#ty) {
147                self.client.internal.state_bridge.#set(value).await;
148            }
149
150            #[doc = #get_doc]
151            pub async fn #get(&self) -> Option<#ty> {
152                self.client.internal.state_bridge.#get().await
153            }
154
155            #[doc = #clear_doc]
156            pub async fn #clear(&self) {
157                self.client.internal.state_bridge.#clear().await;
158            }
159        }
160    });
161
162    let mut ts_iface = String::from(
163        "/**\n * Typescript interface that the state bridge needs to implement. The state bridge\n * is a temporary layer that allows quickly transitioning non-repository shaped\n * state to be accessible from within the SDK.\n */\nexport interface WasmStateBridge {\n",
164    );
165    for f in &fields {
166        let n = f.name.to_string();
167        let t = f.ts.value();
168        ts_iface.push_str(&format!("    set_{n}(value: {t}): Promise<void>;\n"));
169        ts_iface.push_str(&format!("    get_{n}(): Promise<{t} | null>;\n"));
170        ts_iface.push_str(&format!("    clear_{n}(): Promise<void>;\n"));
171    }
172    ts_iface.push_str("}\n");
173
174    let extern_methods = fields.iter().map(|f| {
175        let ty = &f.ty;
176        let n = f.name.to_string();
177        let set = format_ident!("set_{}", f.name);
178        let get = format_ident!("get_{}", f.name);
179        let clear = format_ident!("clear_{}", f.name);
180        let set_doc = format!("JS-side `set_{n}` method on `WasmStateBridge`.");
181        let get_doc = format!("JS-side `get_{n}` method on `WasmStateBridge`.");
182        let clear_doc = format!("JS-side `clear_{n}` method on `WasmStateBridge`.");
183        quote! {
184            #[doc = #set_doc]
185            #[wasm_bindgen(method)]
186            pub async fn #set(
187                this: &crate::key_management::state_bridge::RawWasmStateBridge,
188                value: #ty,
189            );
190            #[doc = #get_doc]
191            #[wasm_bindgen(method)]
192            pub async fn #get(
193                this: &crate::key_management::state_bridge::RawWasmStateBridge,
194            ) -> ::wasm_bindgen::JsValue;
195            #[doc = #clear_doc]
196            #[wasm_bindgen(method)]
197            pub async fn #clear(
198                this: &crate::key_management::state_bridge::RawWasmStateBridge,
199            );
200        }
201    });
202
203    let test_support_struct_fields = fields.iter().map(|f| {
204        let name = &f.name;
205        let ty = &f.ty;
206        quote! { #name: ::std::sync::Mutex<Option<#ty>>, }
207    });
208
209    let test_support_impl_methods = fields.iter().map(|f| {
210        let name = &f.name;
211        let ty = &f.ty;
212        let set = format_ident!("set_{}", f.name);
213        let get = format_ident!("get_{}", f.name);
214        let clear = format_ident!("clear_{}", f.name);
215        quote! {
216            async fn #set(&self, value: #ty) {
217                *self.#name.lock().expect("not poisoned") = Some(value);
218            }
219            async fn #get(&self) -> Option<#ty> {
220                self.#name.lock().expect("not poisoned").clone()
221            }
222            async fn #clear(&self) {
223                *self.#name.lock().expect("not poisoned") = None;
224            }
225        }
226    });
227
228    let uniffi_trait_methods = fields.iter().map(|f| {
229        let ty = &f.ty;
230        let n = f.name.to_string();
231        let set = format_ident!("set_{}", f.name);
232        let get = format_ident!("get_{}", f.name);
233        let clear = format_ident!("clear_{}", f.name);
234        let set_doc = format!("Stores the `{n}` value.");
235        let get_doc = format!("Returns the `{n}` value, if available.");
236        let clear_doc = format!("Clears the `{n}` value.");
237        quote! {
238            #[doc = #set_doc]
239            async fn #set(&self, value: #ty);
240            #[doc = #get_doc]
241            async fn #get(&self) -> Option<#ty>;
242            #[doc = #clear_doc]
243            async fn #clear(&self);
244        }
245    });
246
247    let uniffi_impls = fields.iter().map(|f| {
248        let ty = &f.ty;
249        let set = format_ident!("set_{}", f.name);
250        let get = format_ident!("get_{}", f.name);
251        let clear = format_ident!("clear_{}", f.name);
252        quote! {
253            async fn #set(&self, value: #ty) {
254                self.0.#set(value).await
255            }
256            async fn #get(&self) -> Option<#ty> {
257                self.0.#get().await
258            }
259            async fn #clear(&self) {
260                self.0.#clear().await
261            }
262        }
263    });
264
265    let wasm_impls = fields.iter().map(|f| {
266        let ty = &f.ty;
267        let n = f.name.to_string();
268        let set = format_ident!("set_{}", f.name);
269        let get = format_ident!("get_{}", f.name);
270        let clear = format_ident!("clear_{}", f.name);
271        let get_err = format!("State bridge `get_{n}` failed to deserialize value from JsValue");
272        quote! {
273            async fn #set(&self, value: #ty) {
274                self.0
275                    .run_in_thread(move |state| async move {
276                        state.#set(value).await
277                    })
278                    .await
279                    .expect("State bridge call panicked");
280            }
281            async fn #get(&self) -> Option<#ty> {
282                let js: ::wasm_bindgen::JsValue = self.0
283                    .run_in_thread(|state| async move {
284                        state.#get().await
285                    })
286                    .await
287                    .expect("State bridge call panicked");
288                if js.is_null() || js.is_undefined() {
289                    None
290                } else {
291                    Some(
292                        <#ty as ::core::convert::TryFrom<::wasm_bindgen::JsValue>>::try_from(js)
293                            .expect(#get_err),
294                    )
295                }
296            }
297            async fn #clear(&self) {
298                self.0
299                    .run_in_thread(|state| async move {
300                        state.#clear().await
301                    })
302                    .await
303                    .expect("State bridge call panicked");
304            }
305        }
306    });
307
308    let expanded = quote! {
309        /// Host-provided storage bridge for key-management state.
310        ///
311        /// SDK consumers register an implementation that persists or caches sensitive
312        /// account state across unlock flows.
313        #[cfg_attr(target_arch = "wasm32", ::async_trait::async_trait(?Send))]
314        #[cfg_attr(not(target_arch = "wasm32"), ::async_trait::async_trait)]
315        pub trait StateBridgeImpl: Send + Sync {
316            #(#trait_methods)*
317        }
318
319        impl StateBridge {
320            #(#bridge_wrappers)*
321        }
322
323        impl crate::key_management::state_bridge::StateBridgeClient {
324            #(#client_forwarders)*
325        }
326
327        #[cfg(target_arch = "wasm32")]
328        #[::wasm_bindgen::prelude::wasm_bindgen(typescript_custom_section)]
329        const TS_CUSTOM_TYPES_STATE_BRIDGE: &'static str = #ts_iface;
330
331        #[cfg(target_arch = "wasm32")]
332        #[::wasm_bindgen::prelude::wasm_bindgen]
333        extern "C" {
334            #(#extern_methods)*
335        }
336
337        #[cfg(target_arch = "wasm32")]
338        #[::async_trait::async_trait(?Send)]
339        impl StateBridgeImpl for crate::key_management::state_bridge::WasmStateBridge {
340            #(#wasm_impls)*
341        }
342
343        /// Foreign trait that Swift/Kotlin hosts implement to provide the state bridge.
344        ///
345        /// `StateBridgeImpl` is automatically implemented for the
346        /// `UniffiStateBridge` adapter that wraps an
347        /// `Arc<dyn StateBridgeForeignImpl>`.
348        #[cfg(feature = "uniffi")]
349        #[::uniffi::export(with_foreign)]
350        #[::async_trait::async_trait]
351        pub trait StateBridgeForeignImpl: Send + Sync {
352            #(#uniffi_trait_methods)*
353        }
354
355        #[cfg(feature = "uniffi")]
356        #[::async_trait::async_trait]
357        impl StateBridgeImpl for crate::key_management::state_bridge::UniffiStateBridge {
358            #(#uniffi_impls)*
359        }
360
361        #[cfg(any(test, feature = "internal-test-utils"))]
362        #[allow(missing_docs)]
363        pub mod test_support {
364            //! In-memory test fixtures for the state bridge.
365            use super::*;
366
367            /// In-memory `StateBridgeImpl` for use in tests.
368            #[derive(Default)]
369            pub struct InMemoryStateBridge {
370                #(#test_support_struct_fields)*
371            }
372
373            #[cfg_attr(target_arch = "wasm32", ::async_trait::async_trait(?Send))]
374            #[cfg_attr(not(target_arch = "wasm32"), ::async_trait::async_trait)]
375            impl super::StateBridgeImpl for InMemoryStateBridge {
376                #(#test_support_impl_methods)*
377            }
378        }
379    };
380
381    TokenStream::from(expanded)
382}