bitwarden_threading/
thread_bound_runner.rs

1#![allow(dead_code)]
2#![allow(unused_variables)]
3
4use std::{future::Future, pin::Pin, rc::Rc};
5
6use bitwarden_error::bitwarden_error;
7use thiserror::Error;
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::task::spawn_local;
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use wasm_bindgen_futures::spawn_local;
12
/// Type-erased unit of work carried by a `CallRequest`.
///
/// It is a boxed, `Send`, one-shot closure that receives a shared handle (`Rc`) to the
/// thread-bound state and returns a pinned, boxed future with no output. Only the closure
/// is required to be `Send`; the future it returns carries no `Send` bound.
type CallFunction<ThreadState> =
    Box<dyn FnOnce(Rc<ThreadState>) -> Pin<Box<dyn Future<Output = ()>>> + Send>;
15
/// Message sent over the runner's call channel; it wraps the boxed function that should be
/// invoked with the thread-bound state.
struct CallRequest<ThreadState> {
    /// The work to run once the request has been received.
    function: CallFunction<ThreadState>,
}
19
/// The call failed before it could return a value. This should not happen unless
/// the thread panics, which can only happen if the function passed to `run_in_thread`
/// panics.
///
/// The wrapped `String` is a human-readable description of the underlying failure; it is
/// appended to the message produced by the `Display` implementation.
#[derive(Debug, Error)]
#[error("The call failed before it could return a value: {0}")]
#[bitwarden_error(basic)]
pub struct CallError(String);
27
/// A runner that takes a non-`Send` state and makes it `Send` compatible.
///
/// `ThreadBoundRunner` is designed to safely encapsulate a `!Send` state object by
/// pinning it to a single thread using `spawn_local`. It provides a `Send` API that
/// allows other threads to submit tasks (function pointers or closures) that operate on the
/// thread-bound state.
///
/// Tasks are queued via an internal channel and are started on the owning thread in the
/// order they are received. Each task is spawned as its own local task, so a task whose
/// future yields may overlap with tasks that were started after it.
///
/// # Example
/// ```
/// # tokio_test::block_on(tokio::task::LocalSet::new().run_until(async {
/// use bitwarden_threading::ThreadBoundRunner;
///
/// struct State;
///
/// impl State {
///     pub async fn do_something(&self, some_input: i32) -> i32 {
///         return some_input;
///     }
/// }
///
/// let runner = ThreadBoundRunner::new(State);
/// let input = 42;
///
/// let output = runner.run_in_thread(move |state| async move {
///   return state.do_something(input).await;
/// }).await;
///
/// assert_eq!(output.unwrap(), 42);
/// # }));
/// ```
///
/// If you need mutable access to the state, you can wrap the `ThreadState` in a `Mutex` or
/// `RwLock` and use the `run_in_thread` method to lock it before accessing it.
///
/// # Example
/// ```
/// # tokio_test::block_on(tokio::task::LocalSet::new().run_until(async {
/// use bitwarden_threading::ThreadBoundRunner;
/// use tokio::sync::Mutex;
///
/// struct State(i32);
///
/// let runner = ThreadBoundRunner::new(Mutex::new(State(0)));
///
/// runner.run_in_thread(|state| async move {
///   state.lock().await.0 += 1;
/// }).await;
/// # }));
/// ```
///
/// This pattern is useful for interacting with APIs or data structures that must remain
/// on the same thread, such as GUI toolkits, WebAssembly contexts, or other thread-bound
/// environments.
pub struct ThreadBoundRunner<ThreadState> {
    /// Sending half of the channel used to hand `CallRequest`s to the local task that owns
    /// the state.
    call_channel_tx: tokio::sync::mpsc::Sender<CallRequest<ThreadState>>,
}
86
87/// Makes a clone of the runner handle.
88///
89/// This creates another handle to the same underlying runner object.
90/// The underlying state is not duplicated; all clones refer to the same
91/// instance.
92// This is not implemented using derive to remove the implicit bound on `ThreadState: Clone`
93impl<ThreadState> Clone for ThreadBoundRunner<ThreadState> {
94    fn clone(&self) -> Self {
95        ThreadBoundRunner {
96            call_channel_tx: self.call_channel_tx.clone(),
97        }
98    }
99}
100
101impl<ThreadState> ThreadBoundRunner<ThreadState>
102where
103    ThreadState: 'static,
104{
105    #[allow(missing_docs)]
106    pub fn new(state: ThreadState) -> Self {
107        let (call_channel_tx, mut call_channel_rx) =
108            tokio::sync::mpsc::channel::<CallRequest<ThreadState>>(1);
109
110        spawn_local(async move {
111            let state = Rc::new(state);
112            while let Some(request) = call_channel_rx.recv().await {
113                spawn_local((request.function)(state.clone()));
114            }
115        });
116
117        ThreadBoundRunner { call_channel_tx }
118    }
119
120    /// Submit a task to be executed on the thread-bound state.
121    ///
122    /// The provided function is executed on the thread that owns the internal `ThreadState`,
123    /// ensuring safe access to `!Send` data. Tasks are dispatched in the order they are
124    /// received, but because they are asynchronous, multiple tasks may be in-flight and running
125    /// concurrently if their futures yield.
126    ///
127    /// # Returns
128    /// A future that resolves to the result of the function once it has been executed.
129    pub async fn run_in_thread<F, Fut, Output>(&self, function: F) -> Result<Output, CallError>
130    where
131        F: FnOnce(Rc<ThreadState>) -> Fut + Send + 'static,
132        Fut: Future<Output = Output>,
133        Output: Send + Sync + 'static,
134    {
135        let (return_channel_tx, return_channel_rx) = tokio::sync::oneshot::channel();
136        let request = CallRequest {
137            function: Box::new(|state| {
138                Box::pin(async move {
139                    let result = function(state);
140                    return_channel_tx.send(result.await).unwrap_or_else(|_| {
141                        log::warn!(
142                            "ThreadBoundDispatcher failed to send result back to the caller"
143                        );
144                    });
145                })
146            }),
147        };
148
149        self.call_channel_tx
150            .send(request)
151            .await
152            .expect("Call channel should not be able to close while anything still still has a reference to this object");
153        return_channel_rx
154            .await
155            .map_err(|e| CallError(e.to_string()))
156    }
157}
158
#[cfg(test)]
mod test {
    use super::*;

    /// Drives `test` to completion in an environment where `spawn_local` can be used.
    ///
    /// Off wasm the future runs inside a fresh tokio `LocalSet`; on wasm it is awaited as-is.
    async fn run_test<F>(test: F) -> F::Output
    where
        F: std::future::Future,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            let locals = tokio::task::LocalSet::new();
            locals.run_until(test).await
        }

        #[cfg(target_arch = "wasm32")]
        {
            test.await
        }
    }

    /// Runs `test` as a separate tokio task (off wasm) or inline (on wasm) and discards
    /// its output.
    async fn run_in_another_thread<F>(test: F)
    where
        F: std::future::Future + Send + 'static,
        F::Output: Send,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            tokio::spawn(test).await.expect("Thread panicked");
        }

        #[cfg(target_arch = "wasm32")]
        {
            test.await;
        }
    }

    /// Test state that is deliberately `!Send`.
    #[derive(Default)]
    struct State {
        /// Raw-pointer marker that keeps the struct from being `Send`.
        _un_send_marker: std::marker::PhantomData<*const ()>,
    }

    impl State {
        /// Synchronously adds the two halves of `input`.
        pub fn add(&self, input: (i32, i32)) -> i32 {
            let (lhs, rhs) = input;
            lhs + rhs
        }

        /// Async flavour of [`State::add`].
        #[allow(clippy::unused_async)]
        pub async fn async_add(&self, input: (i32, i32)) -> i32 {
            let (lhs, rhs) = input;
            lhs + rhs
        }
    }

    #[tokio::test]
    async fn calls_function_and_returns_value() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            let call = runner.run_in_thread(|state| async move { state.add((1, 2)) });
            let sum = call.await.expect("Calling function failed");

            assert_eq!(sum, 3);
        })
        .await;
    }

    #[tokio::test]
    async fn calls_async_function_and_returns_value() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            let call =
                runner.run_in_thread(|state| async move { state.async_add((1, 2)).await });
            let sum = call.await.expect("Calling function failed");

            assert_eq!(sum, 3);
        })
        .await;
    }

    #[tokio::test]
    async fn can_continue_running_if_a_call_panics() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            // The first call blows up inside the runner...
            let panicking = runner.run_in_thread::<_, _, ()>(|_state| async move {
                panic!("This is a test panic");
            });
            panicking
                .await
                .expect_err("Calling function should have panicked");

            // ...and the runner must still serve the next one.
            let call =
                runner.run_in_thread(|state| async move { state.async_add((1, 2)).await });
            let sum = call.await.expect("Calling function failed");

            assert_eq!(sum, 3);
        })
        .await;
    }
}
271        .await;
272    }
273}