bitwarden_threading/
thread_bound_runner.rs

1#![allow(dead_code)]
2#![allow(unused_variables)]
3
4use std::{future::Future, pin::Pin, rc::Rc};
5
6use bitwarden_error::bitwarden_error;
7use thiserror::Error;
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::task::spawn_local;
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use wasm_bindgen_futures::spawn_local;
12
/// Type-erased unit of work submitted to the thread that owns the state.
///
/// The boxed closure must be `Send` so it can be handed across threads, but the future it
/// produces (which may hold the `Rc` to the state) carries no `Send` requirement.
type CallFunction<ThreadState> = Box<
    dyn FnOnce(Rc<ThreadState>) -> Pin<Box<dyn Future<Output = ()>>> + Send,
>;
15
16struct CallRequest<ThreadState> {
17    function: CallFunction<ThreadState>,
18}
19
20/// The call failed before it could return a value. This should not happen unless
21/// the thread panics, which can only happen if the function passed to `run_in_thread`
22/// panics.
23#[derive(Debug, Error)]
24#[error("The call failed before it could return a value: {0}")]
25#[bitwarden_error(basic)]
26pub struct CallError(String);
27
28/// A runner that takes a non-`Send` state and makes it `Send` compatible.
29///
30/// `ThreadBoundRunner` is designed to safely encapsulate a `!Send` state object by
31/// pinning it to a single thread using `spawn_local`. It provides a `Send` API that
32/// allows other threads to submit tasks (function pointers or closures) that operate on the
33/// thread-bound state.
34///
35/// Tasks are queued via an internal channel and are executed sequentially on the owning thread.
36///
37/// # Example
38/// ```
39/// # tokio_test::block_on(tokio::task::LocalSet::new().run_until(async {
40/// use bitwarden_threading::ThreadBoundRunner;
41///
42/// struct State;
43///
44/// impl State {
45///     pub async fn do_something(&self, some_input: i32) -> i32 {
46///         return some_input;
47///     }
48/// }
49///
50/// let runner = ThreadBoundRunner::new(State);
51/// let input = 42;
52///
53/// let output = runner.run_in_thread(move |state| async move {
54///   return state.do_something(input).await;
55/// }).await;
56///
57/// assert_eq!(output.unwrap(), 42);
58/// # }));
59/// ```
60///
61/// If you need mutable access to the state, you can wrap the `ThreadState` in a `Mutex` or
62/// `RwLock` and use the `run_in_thread` method to lock it before accessing it.
63///
64/// # Example
65/// ```
66/// # tokio_test::block_on(tokio::task::LocalSet::new().run_until(async {
67/// use bitwarden_threading::ThreadBoundRunner;
68/// use tokio::sync::Mutex;
69///
70/// struct State(i32);
71///
72/// let runner = ThreadBoundRunner::new(Mutex::new(State(0)));
73///
74/// runner.run_in_thread(|state| async move {
75///   state.lock().await.0 += 1;
76/// }).await;
77/// # }));
78/// ```
79///
80/// This pattern is useful for interacting with APIs or data structures that must remain
81/// on the same thread, such as GUI toolkits, WebAssembly contexts, or other thread-bound
82/// environments.
83#[derive(Clone)]
84pub struct ThreadBoundRunner<ThreadState> {
85    call_channel_tx: tokio::sync::mpsc::Sender<CallRequest<ThreadState>>,
86}
87
88impl<ThreadState> ThreadBoundRunner<ThreadState>
89where
90    ThreadState: 'static,
91{
92    #[allow(missing_docs)]
93    pub fn new(state: ThreadState) -> Self {
94        let (call_channel_tx, mut call_channel_rx) =
95            tokio::sync::mpsc::channel::<CallRequest<ThreadState>>(1);
96
97        spawn_local(async move {
98            let state = Rc::new(state);
99            while let Some(request) = call_channel_rx.recv().await {
100                spawn_local((request.function)(state.clone()));
101            }
102        });
103
104        ThreadBoundRunner { call_channel_tx }
105    }
106
107    /// Submit a task to be executed on the thread-bound state.
108    ///
109    /// The provided function is executed on the thread that owns the internal `ThreadState`,
110    /// ensuring safe access to `!Send` data. Tasks are dispatched in the order they are
111    /// received, but because they are asynchronous, multiple tasks may be in-flight and running
112    /// concurrently if their futures yield.
113    ///
114    /// # Returns
115    /// A future that resolves to the result of the function once it has been executed.
116    pub async fn run_in_thread<F, Fut, Output>(&self, function: F) -> Result<Output, CallError>
117    where
118        F: FnOnce(Rc<ThreadState>) -> Fut + Send + 'static,
119        Fut: Future<Output = Output>,
120        Output: Send + Sync + 'static,
121    {
122        let (return_channel_tx, return_channel_rx) = tokio::sync::oneshot::channel();
123        let request = CallRequest {
124            function: Box::new(|state| {
125                Box::pin(async move {
126                    let result = function(state);
127                    return_channel_tx.send(result.await).unwrap_or_else(|_| {
128                        log::warn!(
129                            "ThreadBoundDispatcher failed to send result back to the caller"
130                        );
131                    });
132                })
133            }),
134        };
135
136        self.call_channel_tx
137            .send(request)
138            .await
139            .expect("Call channel should not be able to close while anything still still has a reference to this object");
140        return_channel_rx
141            .await
142            .map_err(|e| CallError(e.to_string()))
143    }
144}
145
#[cfg(test)]
mod test {
    use super::*;

    /// Utility function to run a test in a local context (allows using tokio::..::spawn_local)
    async fn run_test<F>(test: F) -> F::Output
    where
        F: std::future::Future,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            let local_set = tokio::task::LocalSet::new();
            local_set.run_until(test).await
        }

        #[cfg(target_arch = "wasm32")]
        {
            test.await
        }
    }

    /// Run `test` as a separate `Send` task (via `tokio::spawn` on native targets), so the
    /// test body must not capture anything thread-bound.
    async fn run_in_another_thread<F>(test: F)
    where
        F: std::future::Future + Send + 'static,
        F::Output: Send,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            tokio::spawn(test).await.expect("Thread panicked");
        }

        #[cfg(target_arch = "wasm32")]
        {
            test.await;
        }
    }

    #[derive(Default)]
    struct State {
        /// This is a marker to ensure that the struct is not Send
        _un_send_marker: std::marker::PhantomData<*const ()>,
    }

    impl State {
        pub fn add(&self, input: (i32, i32)) -> i32 {
            input.0 + input.1
        }

        #[allow(clippy::unused_async)]
        pub async fn async_add(&self, input: (i32, i32)) -> i32 {
            input.0 + input.1
        }
    }

    #[tokio::test]
    async fn calls_function_and_returns_value() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            let result = runner
                .run_in_thread(|state| async move {
                    let input = (1, 2);
                    state.add(input)
                })
                .await
                .expect("Calling function failed");

            assert_eq!(result, 3);
        })
        .await;
    }

    #[tokio::test]
    async fn calls_async_function_and_returns_value() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            let result = runner
                .run_in_thread(|state| async move {
                    let input = (1, 2);
                    state.async_add(input).await
                })
                .await
                .expect("Calling function failed");

            assert_eq!(result, 3);
        })
        .await;
    }

    #[tokio::test]
    async fn can_be_called_from_another_task() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            // `State` is `!Send`, yet the runner itself can be moved into a `Send` task and
            // used from there; the call still executes on the thread that owns the state.
            run_in_another_thread(async move {
                let result = runner
                    .run_in_thread(|state| async move { state.add((1, 2)) })
                    .await
                    .expect("Calling function failed");

                assert_eq!(result, 3);
            })
            .await;
        })
        .await;
    }

    #[tokio::test]
    async fn can_continue_running_if_a_call_panics() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            runner
                .run_in_thread::<_, _, ()>(|_state| async move {
                    panic!("This is a test panic");
                })
                .await
                .expect_err("Calling function should have panicked");

            let result = runner
                .run_in_thread(|state| async move {
                    let input = (1, 2);
                    state.async_add(input).await
                })
                .await
                .expect("Calling function failed");

            assert_eq!(result, 3);
        })
        .await;
    }
}