bitwarden_threading/
thread_bound_runner.rs1#![allow(dead_code)]
2#![allow(unused_variables)]
3
4use std::{future::Future, pin::Pin, rc::Rc};
5
6use bitwarden_error::bitwarden_error;
7use thiserror::Error;
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::task::spawn_local;
10#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
11use wasm_bindgen_futures::spawn_local;
12
/// Boxed future produced by a [`CallFunction`].
///
/// It is deliberately not `Send`: it holds an `Rc` to the thread-bound state and is only
/// polled on the thread that owns that state.
type CallFuture = Pin<Box<dyn Future<Output = ()>>>;

/// A type-erased, one-shot job to run against the thread-bound state.
///
/// The closure must be `Send` so it can travel through the call channel. The future it
/// returns has no such requirement.
type CallFunction<ThreadState> = Box<dyn FnOnce(Rc<ThreadState>) -> CallFuture + Send>;

/// A single unit of work submitted to the runner's receive loop.
struct CallRequest<ThreadState> {
    /// The job to execute; it is handed a shared handle to the state.
    function: CallFunction<ThreadState>,
}
19
20#[derive(Debug, Error)]
24#[error("The call failed before it could return a value: {0}")]
25#[bitwarden_error(basic)]
26pub struct CallError(String);
27
28pub struct ThreadBoundRunner<ThreadState> {
84 call_channel_tx: tokio::sync::mpsc::Sender<CallRequest<ThreadState>>,
85}
86
87impl<ThreadState> Clone for ThreadBoundRunner<ThreadState> {
94 fn clone(&self) -> Self {
95 ThreadBoundRunner {
96 call_channel_tx: self.call_channel_tx.clone(),
97 }
98 }
99}
100
101impl<ThreadState> ThreadBoundRunner<ThreadState>
102where
103 ThreadState: 'static,
104{
105 #[allow(missing_docs)]
106 pub fn new(state: ThreadState) -> Self {
107 let (call_channel_tx, mut call_channel_rx) =
108 tokio::sync::mpsc::channel::<CallRequest<ThreadState>>(1);
109
110 spawn_local(async move {
111 let state = Rc::new(state);
112 while let Some(request) = call_channel_rx.recv().await {
113 spawn_local((request.function)(state.clone()));
114 }
115 });
116
117 ThreadBoundRunner { call_channel_tx }
118 }
119
120 pub async fn run_in_thread<F, Fut, Output>(&self, function: F) -> Result<Output, CallError>
130 where
131 F: FnOnce(Rc<ThreadState>) -> Fut + Send + 'static,
132 Fut: Future<Output = Output>,
133 Output: Send + Sync + 'static,
134 {
135 let (return_channel_tx, return_channel_rx) = tokio::sync::oneshot::channel();
136 let request = CallRequest {
137 function: Box::new(|state| {
138 Box::pin(async move {
139 let result = function(state);
140 return_channel_tx.send(result.await).unwrap_or_else(|_| {
141 log::warn!(
142 "ThreadBoundDispatcher failed to send result back to the caller"
143 );
144 });
145 })
146 }),
147 };
148
149 self.call_channel_tx
150 .send(request)
151 .await
152 .expect("Call channel should not be able to close while anything still still has a reference to this object");
153 return_channel_rx
154 .await
155 .map_err(|e| CallError(e.to_string()))
156 }
157}
158
#[cfg(test)]
mod test {
    use super::*;

    /// Runs `test` in an environment where `spawn_local` is usable and returns its output.
    ///
    /// On native targets the future is driven inside a fresh tokio `LocalSet`. On wasm32 it is
    /// awaited directly.
    async fn run_test<F>(test: F) -> F::Output
    where
        F: std::future::Future,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            let local_tasks = tokio::task::LocalSet::new();
            local_tasks.run_until(test).await
        }

        #[cfg(target_arch = "wasm32")]
        {
            test.await
        }
    }

    /// On native targets, runs `test` as a separately spawned tokio task, which is why it must
    /// be `Send + 'static`. Panics if that task panicked. On wasm32 it is awaited directly.
    async fn run_in_another_thread<F>(test: F)
    where
        F: std::future::Future + Send + 'static,
        F::Output: Send,
    {
        #[cfg(not(target_arch = "wasm32"))]
        {
            let join_handle = tokio::spawn(test);
            join_handle.await.expect("Thread panicked");
        }

        #[cfg(target_arch = "wasm32")]
        {
            test.await;
        }
    }

    /// Test state that is neither `Send` nor `Sync` because of the raw-pointer marker. This
    /// makes the tests exercise the runner with state that cannot leave its thread.
    #[derive(Default)]
    struct State {
        _un_send_marker: std::marker::PhantomData<*const ()>,
    }

    impl State {
        pub fn add(&self, (lhs, rhs): (i32, i32)) -> i32 {
            lhs + rhs
        }

        // Deliberately async so the tests cover awaiting inside a call.
        #[allow(clippy::unused_async)]
        pub async fn async_add(&self, (lhs, rhs): (i32, i32)) -> i32 {
            lhs + rhs
        }
    }

    #[tokio::test]
    async fn calls_function_and_returns_value() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            let sum = runner
                .run_in_thread(|state| async move { state.add((1, 2)) })
                .await
                .expect("Calling function failed");

            assert_eq!(sum, 3);
        })
        .await;
    }

    #[tokio::test]
    async fn calls_async_function_and_returns_value() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            let sum = runner
                .run_in_thread(|state| async move { state.async_add((1, 2)).await })
                .await
                .expect("Calling function failed");

            assert_eq!(sum, 3);
        })
        .await;
    }

    #[tokio::test]
    async fn can_continue_running_if_a_call_panics() {
        run_test(async {
            let runner = ThreadBoundRunner::new(State::default());

            // The first call panics inside the thread-bound task; the caller sees an error.
            runner
                .run_in_thread::<_, _, ()>(|_state| async move {
                    panic!("This is a test panic");
                })
                .await
                .expect_err("Calling function should have panicked");

            // The runner must still serve later calls.
            let sum = runner
                .run_in_thread(|state| async move { state.async_add((1, 2)).await })
                .await
                .expect("Calling function failed");

            assert_eq!(sum, 3);
        })
        .await;
    }
}