bitwarden_threading/
thread_bound_runner.rs1#![allow(dead_code)]
2#![allow(unused_variables)]
3
4use std::{future::Future, pin::Pin, rc::Rc};
5
6use bitwarden_error::bitwarden_error;
7use thiserror::Error;
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::task::spawn_local;
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use wasm_bindgen_futures::spawn_local;
12
/// Type-erased future produced by a queued call. It is not required to be `Send`, so it can
/// hold the `Rc` to the thread-bound state.
type CallFuture = Pin<Box<dyn Future<Output = ()>>>;

/// Type-erased call that can be moved across threads (`Send`) and, once invoked with the
/// thread-bound state, yields the future that performs the actual work.
type CallFunction<ThreadState> = Box<dyn FnOnce(Rc<ThreadState>) -> CallFuture + Send>;
15
/// A single unit of work sent over the call channel to the task that owns the state.
struct CallRequest<ThreadState> {
    /// The type-erased call; it is invoked with a shared handle to the state and returns the
    /// future that runs the call.
    function: CallFunction<ThreadState>,
}
19
/// Error returned by `ThreadBoundRunner::run_in_thread` when a submitted call finished (or was
/// dropped) without delivering a return value. The wrapped string carries the underlying
/// channel error text.
// NOTE(review): `bitwarden_error(basic)` is assumed to expose this error through the SDK's
// common error machinery — confirm against the `bitwarden_error` crate.
#[derive(Debug, Error)]
#[error("The call failed before it could return a value: {0}")]
#[bitwarden_error(basic)]
pub struct CallError(String);
27
28#[derive(Clone)]
84pub struct ThreadBoundRunner<ThreadState> {
85 call_channel_tx: tokio::sync::mpsc::Sender<CallRequest<ThreadState>>,
86}
87
88impl<ThreadState> ThreadBoundRunner<ThreadState>
89where
90 ThreadState: 'static,
91{
92 #[allow(missing_docs)]
93 pub fn new(state: ThreadState) -> Self {
94 let (call_channel_tx, mut call_channel_rx) =
95 tokio::sync::mpsc::channel::<CallRequest<ThreadState>>(1);
96
97 spawn_local(async move {
98 let state = Rc::new(state);
99 while let Some(request) = call_channel_rx.recv().await {
100 spawn_local((request.function)(state.clone()));
101 }
102 });
103
104 ThreadBoundRunner { call_channel_tx }
105 }
106
107 pub async fn run_in_thread<F, Fut, Output>(&self, function: F) -> Result<Output, CallError>
117 where
118 F: FnOnce(Rc<ThreadState>) -> Fut + Send + 'static,
119 Fut: Future<Output = Output>,
120 Output: Send + Sync + 'static,
121 {
122 let (return_channel_tx, return_channel_rx) = tokio::sync::oneshot::channel();
123 let request = CallRequest {
124 function: Box::new(|state| {
125 Box::pin(async move {
126 let result = function(state);
127 return_channel_tx.send(result.await).unwrap_or_else(|_| {
128 log::warn!(
129 "ThreadBoundDispatcher failed to send result back to the caller"
130 );
131 });
132 })
133 }),
134 };
135
136 self.call_channel_tx
137 .send(request)
138 .await
139 .expect("Call channel should not be able to close while anything still still has a reference to this object");
140 return_channel_rx
141 .await
142 .map_err(|e| CallError(e.to_string()))
143 }
144}
145
#[cfg(test)]
mod test {
    use super::*;

    /// Awaits `test` in a context where `spawn_local` is usable: inside a fresh
    /// `tokio::task::LocalSet` off wasm32, directly on wasm32.
    async fn run_test<F>(test: F) -> F::Output
    where
        F: std::future::Future,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            let local = tokio::task::LocalSet::new();
            local.run_until(test).await
        }

        #[cfg(target_arch = "wasm32")]
        {
            test.await
        }
    }

    /// Runs `test` as a separate tokio task off wasm32 (failing if that task panicked);
    /// on wasm32 the future is simply awaited in place.
    async fn run_in_another_thread<F>(test: F)
    where
        F: std::future::Future + Send + 'static,
        F::Output: Send,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            let handle = tokio::spawn(test);
            handle.await.expect("Thread panicked");
        }

        #[cfg(target_arch = "wasm32")]
        {
            test.await;
        }
    }

    /// Test state that is deliberately `!Send` (raw-pointer marker), so it can only be used
    /// through the runner.
    #[derive(Default)]
    struct State {
        _un_send_marker: std::marker::PhantomData<*const ()>,
    }

    impl State {
        /// Adds the two halves of `input`.
        pub fn add(&self, input: (i32, i32)) -> i32 {
            let (lhs, rhs) = input;
            lhs + rhs
        }

        /// Async variant of [`State::add`].
        #[allow(clippy::unused_async)]
        pub async fn async_add(&self, input: (i32, i32)) -> i32 {
            let (lhs, rhs) = input;
            lhs + rhs
        }
    }

    #[tokio::test]
    async fn calls_function_and_returns_value() {
        let scenario = async {
            let runner = ThreadBoundRunner::new(State::default());

            let sum = runner
                .run_in_thread(|state| async move { state.add((1, 2)) })
                .await
                .expect("Calling function failed");

            assert_eq!(sum, 3);
        };

        run_test(scenario).await;
    }

    #[tokio::test]
    async fn calls_async_function_and_returns_value() {
        let scenario = async {
            let runner = ThreadBoundRunner::new(State::default());

            let sum = runner
                .run_in_thread(|state| async move { state.async_add((1, 2)).await })
                .await
                .expect("Calling function failed");

            assert_eq!(sum, 3);
        };

        run_test(scenario).await;
    }

    #[tokio::test]
    async fn can_continue_running_if_a_call_panics() {
        let scenario = async {
            let runner = ThreadBoundRunner::new(State::default());

            // First call blows up; the runner must report an error rather than hang or die.
            let panicking_call = runner.run_in_thread::<_, _, ()>(|_state| async move {
                panic!("This is a test panic");
            });
            panicking_call
                .await
                .expect_err("Calling function should have panicked");

            // A subsequent call on the same runner must still be served.
            let sum = runner
                .run_in_thread(|state| async move { state.async_add((1, 2)).await })
                .await
                .expect("Calling function failed");

            assert_eq!(sum, 3);
        };

        run_test(scenario).await;
    }
}