Skip to main content

bitwarden_state_bridge_macro/
lib.rs

1//! Proc macro for generating the key-management state bridge surface.
2//!
//! Provides the `state_bridge!` function-like macro that, from a list of typed fields, expands
//! into the `StateBridgeImpl` trait, wrapper methods on `StateBridge` and `StateBridgeClient`,
//! WASM extern bindings, the `WasmStateBridge` trait impl, the matching TypeScript interface,
//! and a `#[cfg(test)]` in-memory `test_support::InMemoryStateBridge` implementation.
6
7use proc_macro::TokenStream;
8use quote::{format_ident, quote};
9use syn::{
10    Ident, LitStr, Token, Type,
11    parse::{Parse, ParseStream},
12    parse_macro_input,
13    punctuated::Punctuated,
14};
15
16mod state_bridge_kw {
17    syn::custom_keyword!(ts);
18}
19
20struct StateBridgeField {
21    name: Ident,
22    ty: Type,
23    ts: LitStr,
24}
25
26impl Parse for StateBridgeField {
27    fn parse(input: ParseStream) -> syn::Result<Self> {
28        let name: Ident = input.parse()?;
29        input.parse::<Token![:]>()?;
30        let ty: Type = input.parse()?;
31        input.parse::<Token![as]>()?;
32        input.parse::<state_bridge_kw::ts>()?;
33        let ts: LitStr = input.parse()?;
34        Ok(Self { name, ty, ts })
35    }
36}
37
38struct StateBridgeInput {
39    fields: Punctuated<StateBridgeField, Token![,]>,
40}
41
42impl Parse for StateBridgeInput {
43    fn parse(input: ParseStream) -> syn::Result<Self> {
44        Ok(Self {
45            fields: Punctuated::parse_terminated(input)?,
46        })
47    }
48}
49
50/// Generates the full state bridge surface for a fixed list of fields.
51///
52/// Each field expands to:
53/// 1. Three methods on the `StateBridgeImpl` trait (`set_$name`, `get_$name`, `clear_$name`).
54/// 2. Three corresponding wrapper methods on `StateBridge`.
55/// 3. Three corresponding methods on `StateBridgeClient`.
56/// 4. Three method declarations on the WASM `RawWasmStateBridge` extern type and three forwarders
57///    in the `StateBridgeImpl` impl for `WasmStateBridge`.
58/// 5. Three lines in the `WasmStateBridge` TypeScript interface.
59/// 6. One field on `test_support::InMemoryStateBridge` and three forwarders in its
60///    `StateBridgeImpl` impl, gated on `#[cfg(test)]`.
61///
62/// All fields share the same shape: `set_$name(value: $ty)`, `get_$name() -> Option<$ty>`,
63/// `clear_$name()`.
64#[proc_macro]
65pub fn state_bridge(input: TokenStream) -> TokenStream {
66    let StateBridgeInput { fields } = parse_macro_input!(input as StateBridgeInput);
67
68    let trait_methods = fields.iter().map(|f| {
69        let ty = &f.ty;
70        let n = f.name.to_string();
71        let set = format_ident!("set_{}", f.name);
72        let get = format_ident!("get_{}", f.name);
73        let clear = format_ident!("clear_{}", f.name);
74        let set_doc = format!("Stores the `{n}` value.");
75        let get_doc = format!("Returns the `{n}` value, if available.");
76        let clear_doc = format!("Clears the `{n}` value.");
77        quote! {
78            #[doc = #set_doc]
79            async fn #set(&self, value: #ty);
80            #[doc = #get_doc]
81            async fn #get(&self) -> Option<#ty>;
82            #[doc = #clear_doc]
83            async fn #clear(&self);
84        }
85    });
86
87    let bridge_wrappers = fields.iter().map(|f| {
88        let ty = &f.ty;
89        let n = f.name.to_string();
90        let set = format_ident!("set_{}", f.name);
91        let get = format_ident!("get_{}", f.name);
92        let clear = format_ident!("clear_{}", f.name);
93        let set_doc = format!("Stores the `{n}` value.");
94        let get_doc = format!("Returns the `{n}` value, if available.");
95        let clear_doc = format!("Clears the `{n}` value.");
96        quote! {
97            #[doc = #set_doc]
98            pub async fn #set(&self, value: &#ty) {
99                let implementation = self
100                    .implementation
101                    .lock()
102                    .expect("Mutex is not poisoned")
103                    .as_ref()
104                    .expect("StateBridge not registered")
105                    .clone();
106                implementation.#set(value.to_owned()).await
107            }
108
109            #[doc = #get_doc]
110            pub async fn #get(&self) -> Option<#ty> {
111                let implementation = self
112                    .implementation
113                    .lock()
114                    .expect("Mutex is not poisoned")
115                    .as_ref()
116                    .expect("StateBridge not registered")
117                    .clone();
118                implementation.#get().await
119            }
120
121            #[doc = #clear_doc]
122            pub async fn #clear(&self) {
123                let implementation = self
124                    .implementation
125                    .lock()
126                    .expect("Mutex is not poisoned")
127                    .as_ref()
128                    .expect("StateBridge not registered")
129                    .clone();
130                implementation.#clear().await
131            }
132        }
133    });
134
135    let client_forwarders = fields.iter().map(|f| {
136        let ty = &f.ty;
137        let n = f.name.to_string();
138        let set = format_ident!("set_{}", f.name);
139        let get = format_ident!("get_{}", f.name);
140        let clear = format_ident!("clear_{}", f.name);
141        let set_doc = format!("Sets the `{n}` value in client-managed state.");
142        let get_doc = format!("Gets the `{n}` value from client-managed state, if available.");
143        let clear_doc = format!("Clears the `{n}` value from client-managed state.");
144        quote! {
145            #[doc = #set_doc]
146            pub async fn #set(&self, value: &#ty) {
147                self.client.internal.state_bridge.#set(value).await;
148            }
149
150            #[doc = #get_doc]
151            pub async fn #get(&self) -> Option<#ty> {
152                self.client.internal.state_bridge.#get().await
153            }
154
155            #[doc = #clear_doc]
156            pub async fn #clear(&self) {
157                self.client.internal.state_bridge.#clear().await;
158            }
159        }
160    });
161
162    let mut ts_iface = String::from(
163        "/**\n * Typescript interface that the state bridge needs to implement. The state bridge\n * is a temporary layer that allows quickly transitioning non-repository shaped\n * state to be accessible from within the SDK.\n */\nexport interface WasmStateBridge {\n",
164    );
165    for f in &fields {
166        let n = f.name.to_string();
167        let t = f.ts.value();
168        ts_iface.push_str(&format!("    set_{n}(value: {t}): Promise<void>;\n"));
169        ts_iface.push_str(&format!("    get_{n}(): Promise<{t} | null>;\n"));
170        ts_iface.push_str(&format!("    clear_{n}(): Promise<void>;\n"));
171    }
172    ts_iface.push_str("}\n");
173
174    let extern_methods = fields.iter().map(|f| {
175        let ty = &f.ty;
176        let n = f.name.to_string();
177        let set = format_ident!("set_{}", f.name);
178        let get = format_ident!("get_{}", f.name);
179        let clear = format_ident!("clear_{}", f.name);
180        let set_doc = format!("JS-side `set_{n}` method on `WasmStateBridge`.");
181        let get_doc = format!("JS-side `get_{n}` method on `WasmStateBridge`.");
182        let clear_doc = format!("JS-side `clear_{n}` method on `WasmStateBridge`.");
183        quote! {
184            #[doc = #set_doc]
185            #[wasm_bindgen(method)]
186            pub async fn #set(
187                this: &crate::key_management::state_bridge::RawWasmStateBridge,
188                value: #ty,
189            );
190            #[doc = #get_doc]
191            #[wasm_bindgen(method)]
192            pub async fn #get(
193                this: &crate::key_management::state_bridge::RawWasmStateBridge,
194            ) -> Option<#ty>;
195            #[doc = #clear_doc]
196            #[wasm_bindgen(method)]
197            pub async fn #clear(
198                this: &crate::key_management::state_bridge::RawWasmStateBridge,
199            );
200        }
201    });
202
203    let test_support_struct_fields = fields.iter().map(|f| {
204        let name = &f.name;
205        let ty = &f.ty;
206        quote! { #name: ::std::sync::Mutex<Option<#ty>>, }
207    });
208
209    let test_support_impl_methods = fields.iter().map(|f| {
210        let name = &f.name;
211        let ty = &f.ty;
212        let set = format_ident!("set_{}", f.name);
213        let get = format_ident!("get_{}", f.name);
214        let clear = format_ident!("clear_{}", f.name);
215        quote! {
216            async fn #set(&self, value: #ty) {
217                *self.#name.lock().expect("not poisoned") = Some(value);
218            }
219            async fn #get(&self) -> Option<#ty> {
220                self.#name.lock().expect("not poisoned").clone()
221            }
222            async fn #clear(&self) {
223                *self.#name.lock().expect("not poisoned") = None;
224            }
225        }
226    });
227
228    let wasm_impls = fields.iter().map(|f| {
229        let ty = &f.ty;
230        let set = format_ident!("set_{}", f.name);
231        let get = format_ident!("get_{}", f.name);
232        let clear = format_ident!("clear_{}", f.name);
233        quote! {
234            async fn #set(&self, value: #ty) {
235                self.0
236                    .run_in_thread(move |state| async move {
237                        state.#set(value).await
238                    })
239                    .await
240                    .expect("State bridge call panicked");
241            }
242            async fn #get(&self) -> Option<#ty> {
243                self.0
244                    .run_in_thread(|state| async move {
245                        state.#get().await
246                    })
247                    .await
248                    .expect("State bridge call panicked")
249            }
250            async fn #clear(&self) {
251                self.0
252                    .run_in_thread(|state| async move {
253                        state.#clear().await
254                    })
255                    .await
256                    .expect("State bridge call panicked");
257            }
258        }
259    });
260
261    let expanded = quote! {
262        /// Host-provided storage bridge for key-management state.
263        ///
264        /// SDK consumers register an implementation that persists or caches sensitive
265        /// account state across unlock flows.
266        #[cfg_attr(target_arch = "wasm32", ::async_trait::async_trait(?Send))]
267        #[cfg_attr(not(target_arch = "wasm32"), ::async_trait::async_trait)]
268        pub trait StateBridgeImpl: Send + Sync {
269            #(#trait_methods)*
270        }
271
272        impl StateBridge {
273            #(#bridge_wrappers)*
274        }
275
276        impl crate::key_management::state_bridge::StateBridgeClient {
277            #(#client_forwarders)*
278        }
279
280        #[cfg(target_arch = "wasm32")]
281        #[::wasm_bindgen::prelude::wasm_bindgen(typescript_custom_section)]
282        const TS_CUSTOM_TYPES_STATE_BRIDGE: &'static str = #ts_iface;
283
284        #[cfg(target_arch = "wasm32")]
285        #[::wasm_bindgen::prelude::wasm_bindgen]
286        extern "C" {
287            #(#extern_methods)*
288        }
289
290        #[cfg(target_arch = "wasm32")]
291        #[::async_trait::async_trait(?Send)]
292        impl StateBridgeImpl for crate::key_management::state_bridge::WasmStateBridge {
293            #(#wasm_impls)*
294        }
295
296        #[cfg(test)]
297        pub(crate) mod test_support {
298            use super::*;
299
300            /// In-memory `StateBridgeImpl` for use in tests.
301            #[derive(Default)]
302            pub(crate) struct InMemoryStateBridge {
303                #(#test_support_struct_fields)*
304            }
305
306            #[cfg_attr(target_arch = "wasm32", ::async_trait::async_trait(?Send))]
307            #[cfg_attr(not(target_arch = "wasm32"), ::async_trait::async_trait)]
308            impl super::StateBridgeImpl for InMemoryStateBridge {
309                #(#test_support_impl_methods)*
310            }
311        }
312    };
313
314    TokenStream::from(expanded)
315}