1use std::sync::{Arc, Mutex};
2
3use bitwarden_threading::cancellation_token::CancellationToken;
4use serde::de::DeserializeOwned;
5use thiserror::Error;
6use tokio::select;
7
8use crate::{
9 constants::CHANNEL_BUFFER_CAPACITY,
10 error::{AlreadyRunningError, ReceiveError, SendError, SubscribeError, TypedReceiveError},
11 message::{
12 IncomingMessage, OutgoingMessage, PayloadTypeName, TypedIncomingMessage,
13 TypedOutgoingMessage,
14 },
15 rpc::{
16 exec::{handler::ErasedRpcHandler, handler_registry::RpcHandlerRegistry},
17 request_message::{RPC_REQUEST_PAYLOAD_TYPE_NAME, RpcRequestPayload},
18 response_message::OutgoingRpcResponseMessage,
19 },
20 serde_utils,
21 traits::{CommunicationBackend, CryptoProvider, SessionRepository},
22};
23
/// A subscription to the messages forwarded by an IPC client's receive loop.
///
/// Each subscription owns its own broadcast receiver (obtained via `resubscribe` on the
/// client's stored receiver), optionally filtered by topic in `receive`.
pub struct IpcClientSubscription {
    // Broadcast receiver fed by the client's receive loop with every incoming message that is
    // not an RPC request.
    pub(crate) receiver: tokio::sync::broadcast::Receiver<IncomingMessage>,
    // Topic filter applied by `receive`; `None` accepts every message.
    pub(crate) topic: Option<String>,
}
32
/// A subscription whose `receive` converts each message into a
/// `TypedIncomingMessage<Payload>`.
///
/// Thin wrapper around an [`IpcClientSubscription`]; the `PhantomData` only records the
/// payload type that messages are deserialized into.
pub struct IpcClientTypedSubscription<Payload: DeserializeOwned + PayloadTypeName>(
    IpcClientSubscription,
    std::marker::PhantomData<Payload>,
);
41
/// State shared (through an `Arc`) by every clone of an [`IpcClientImpl`].
struct IpcClientInner<Crypto, Com, Ses>
where
    Crypto: CryptoProvider<Com, Ses>,
    Com: CommunicationBackend,
    Ses: SessionRepository<Crypto::Session>,
{
    // Provider through which every message is sent and received.
    crypto: Crypto,
    // Transport backend; its `subscribe` yields the raw receiver polled by the receive loop.
    communication: Com,
    // Session storage handed to the crypto provider on every send/receive.
    sessions: Ses,

    // Handlers for incoming RPC requests, registered by name.
    handlers: RpcHandlerRegistry,
    // Receiver used to create subscriptions (`resubscribe`) and to detect whether the receive
    // loop's sender is still alive. `None` until `start` is called for the first time.
    incoming: Mutex<Option<tokio::sync::broadcast::Receiver<IncomingMessage>>>,
    // Token of the current session: set by `start`, taken and cancelled by `stop_inner`.
    cancellation_token: Mutex<Option<CancellationToken>>,
}
57
/// IPC client generic over its crypto provider, communication backend and session repository.
///
/// All state lives behind an `Arc`, so clones are cheap handles to the same client.
pub struct IpcClientImpl<Crypto, Com, Ses>
where
    Crypto: CryptoProvider<Com, Ses>,
    Com: CommunicationBackend,
    Ses: SessionRepository<Crypto::Session>,
{
    inner: Arc<IpcClientInner<Crypto, Com, Ses>>,
}
71
72impl<Crypto, Com, Ses> Clone for IpcClientImpl<Crypto, Com, Ses>
73where
74 Crypto: CryptoProvider<Com, Ses>,
75 Com: CommunicationBackend,
76 Ses: SessionRepository<Crypto::Session>,
77{
78 fn clone(&self) -> Self {
79 Self {
80 inner: self.inner.clone(),
81 }
82 }
83}
84
85impl<Crypto, Com, Ses> IpcClientImpl<Crypto, Com, Ses>
86where
87 Crypto: CryptoProvider<Com, Ses>,
88 Com: CommunicationBackend,
89 Ses: SessionRepository<Crypto::Session>,
90{
91 pub fn new(crypto: Crypto, communication: Com, sessions: Ses) -> Self {
94 Self {
95 inner: Arc::new(IpcClientInner {
96 crypto,
97 communication,
98 sessions,
99
100 handlers: RpcHandlerRegistry::new(),
101 incoming: Mutex::new(None),
102 cancellation_token: Mutex::new(None),
103 }),
104 }
105 }
106}
107
108#[async_trait::async_trait]
109impl<Crypto, Com, Ses> crate::ipc_client_trait::IpcClient for IpcClientImpl<Crypto, Com, Ses>
110where
111 Crypto: CryptoProvider<Com, Ses>,
112 Com: CommunicationBackend,
113 Ses: SessionRepository<Crypto::Session>,
114{
115 async fn start(
116 &self,
117 cancellation_token: Option<CancellationToken>,
118 ) -> Result<(), AlreadyRunningError> {
119 if self.is_running() {
120 return Err(AlreadyRunningError);
121 }
122
123 let cancellation_token = cancellation_token.unwrap_or_default();
124 self.inner
125 .cancellation_token
126 .lock()
127 .expect("Failed to lock cancellation token mutex")
128 .replace(cancellation_token.clone());
129
130 let com_receiver = self.inner.communication.subscribe().await;
131 let (client_tx, client_rx) = tokio::sync::broadcast::channel(CHANNEL_BUFFER_CAPACITY);
132
133 self.inner
134 .incoming
135 .lock()
136 .expect("Failed to lock incoming mutex")
137 .replace(client_rx);
138
139 let inner = self.inner.clone();
140 let future = async move {
141 loop {
142 let rpc_topic = RPC_REQUEST_PAYLOAD_TYPE_NAME.to_owned();
143 select! {
144 _ = cancellation_token.cancelled() => {
145 tracing::debug!("Cancellation signal received, stopping IPC client");
146 break;
147 }
148 received = inner.crypto.receive(&com_receiver, &inner.communication, &inner.sessions) => {
149 match received {
150 Ok(message) if message.topic == Some(rpc_topic) => {
151 handle_rpc_request(&inner, message)
152 }
153 Ok(message) => {
154 if client_tx.send(message).is_err() {
155 tracing::error!("Failed to save incoming message");
156 break;
157 };
158 }
159 Err(error) => {
160 tracing::error!(?error, "Error receiving message");
161 break;
162 }
163 }
164 }
165 }
166 }
167 tracing::debug!("IPC client shutting down");
168 stop_inner(&inner);
169 };
170
171 #[cfg(not(target_arch = "wasm32"))]
172 tokio::spawn(future);
173
174 #[cfg(target_arch = "wasm32")]
175 wasm_bindgen_futures::spawn_local(future);
176
177 Ok(())
178 }
179
180 fn is_running(&self) -> bool {
181 let has_incoming = self
182 .inner
183 .incoming
184 .lock()
185 .expect("Failed to lock incoming mutex")
186 .as_ref()
187 .map(|receiver| !receiver.is_closed())
188 .unwrap_or(false);
189 let has_cancellation_token = self
190 .inner
191 .cancellation_token
192 .lock()
193 .expect("Failed to lock cancellation token mutex")
194 .is_some();
195 has_incoming && has_cancellation_token
196 }
197
198 async fn send(&self, message: OutgoingMessage) -> Result<(), SendError> {
199 let result = self
200 .inner
201 .crypto
202 .send(&self.inner.communication, &self.inner.sessions, message)
203 .await;
204
205 if let Err(ref error) = result {
206 tracing::error!(?error, "Error sending message");
207 stop_inner(&self.inner);
208 }
209
210 result.map_err(|e| SendError(format!("{e:?}")))
211 }
212
213 async fn subscribe(
214 &self,
215 topic: Option<String>,
216 ) -> Result<IpcClientSubscription, SubscribeError> {
217 Ok(IpcClientSubscription {
218 receiver: self
219 .inner
220 .incoming
221 .lock()
222 .expect("Failed to lock incoming mutex")
223 .as_ref()
224 .ok_or(SubscribeError::NotStarted)?
225 .resubscribe(),
226 topic,
227 })
228 }
229
230 async fn register_rpc_handler_erased(&self, name: &str, handler: Box<dyn ErasedRpcHandler>) {
231 self.inner
232 .handlers
233 .register_erased(name.to_owned(), handler)
234 .await;
235 }
236}
237
238fn stop_inner<Crypto, Com, Ses>(inner: &IpcClientInner<Crypto, Com, Ses>)
239where
240 Crypto: CryptoProvider<Com, Ses>,
241 Com: CommunicationBackend,
242 Ses: SessionRepository<Crypto::Session>,
243{
244 let mut cancellation_token = inner
245 .cancellation_token
246 .lock()
247 .expect("Failed to lock cancellation token mutex");
248
249 if let Some(cancellation_token) = cancellation_token.take() {
250 cancellation_token.cancel();
251 }
252}
253
254fn handle_rpc_request<Crypto, Com, Ses>(
255 inner: &Arc<IpcClientInner<Crypto, Com, Ses>>,
256 incoming_message: IncomingMessage,
257) where
258 Crypto: CryptoProvider<Com, Ses>,
259 Com: CommunicationBackend,
260 Ses: SessionRepository<Crypto::Session>,
261{
262 let inner = inner.clone();
263 let future = async move {
264 #[derive(Debug, Error)]
265 enum HandleError {
266 #[error("Failed to deserialize request message: {0}")]
267 Deserialize(String),
268
269 #[error("Failed to serialize response message: {0}")]
270 Serialize(String),
271 }
272
273 async fn handle(
274 incoming_message: IncomingMessage,
275 handlers: &RpcHandlerRegistry,
276 ) -> Result<OutgoingMessage, HandleError> {
277 let request = RpcRequestPayload::from_slice(incoming_message.payload.clone()).map_err(
278 |e: serde_utils::DeserializeError| HandleError::Deserialize(e.to_string()),
279 )?;
280
281 let response = handlers.handle(&request).await;
282
283 let response_message = OutgoingRpcResponseMessage {
284 request_id: request.request_id(),
285 request_type: request.request_type(),
286 result: response,
287 };
288
289 let outgoing = TypedOutgoingMessage {
290 payload: response_message,
291 destination: incoming_message.source.into(),
292 }
293 .try_into()
294 .map_err(|e: serde_utils::SerializeError| HandleError::Serialize(e.to_string()))?;
295
296 Ok(outgoing)
297 }
298
299 match handle(incoming_message, &inner.handlers).await {
300 Ok(outgoing_message) => {
301 let result = inner
304 .crypto
305 .send(&inner.communication, &inner.sessions, outgoing_message)
306 .await;
307 if result.is_err() {
308 tracing::error!("Failed to send response message");
309 }
310 }
311 Err(error) => {
312 tracing::error!(%error, "Error handling RPC request");
313 }
314 }
315 };
316
317 #[cfg(not(target_arch = "wasm32"))]
318 tokio::spawn(future);
319
320 #[cfg(target_arch = "wasm32")]
321 wasm_bindgen_futures::spawn_local(future);
322}
323
324impl IpcClientSubscription {
325 pub async fn receive(
328 &mut self,
329 cancellation_token: Option<CancellationToken>,
330 ) -> Result<IncomingMessage, ReceiveError> {
331 let cancellation_token = cancellation_token.unwrap_or_default();
332
333 loop {
334 select! {
335 _ = cancellation_token.cancelled() => {
336 return Err(ReceiveError::Cancelled)
337 }
338 result = self.receiver.recv() => {
339 let received = result?;
340 if self.topic.is_none() || received.topic == self.topic {
341 return Ok::<IncomingMessage, ReceiveError>(received);
342 }
343 }
344 }
345 }
346 }
347}
348
349impl<Payload> IpcClientTypedSubscription<Payload>
350where
351 Payload: DeserializeOwned + PayloadTypeName,
352{
353 pub(crate) fn new(subscription: IpcClientSubscription) -> Self {
354 Self(subscription, std::marker::PhantomData)
355 }
356
357 pub async fn receive(
360 &mut self,
361 cancellation_token: Option<CancellationToken>,
362 ) -> Result<TypedIncomingMessage<Payload>, TypedReceiveError> {
363 let received = self.0.receive(cancellation_token).await?;
364 received
365 .try_into()
366 .map_err(|e: serde_utils::DeserializeError| TypedReceiveError::Typing(e.to_string()))
367 }
368}
369
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use bitwarden_threading::time::sleep;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        IpcClientExt,
        endpoint::{Endpoint, HostId, Source},
        ipc_client_trait::IpcClient,
        message::PayloadTypeName,
        rpc::{
            request::RpcRequest,
            request_message::{RPC_REQUEST_PAYLOAD_TYPE_NAME, RpcRequestMessage},
            response_message::IncomingRpcResponseMessage,
        },
        traits::{InMemorySessionRepository, NoEncryptionCryptoProvider, TestCommunicationBackend},
    };

    /// Crypto provider whose `send` / `receive` results are fixed per test.
    ///
    /// A `None` result makes the corresponding call sleep for 600 seconds before failing, which
    /// simulates an operation that never completes within a test.
    struct TestCryptoProvider {
        // Returned by every `send` call; `None` simulates a hang.
        send_result: Option<Result<(), String>>,
        // Returned by every `receive` call; `None` simulates a hang.
        receive_result: Option<Result<IncomingMessage, String>>,
    }

    type TestSessionRepository = InMemorySessionRepository<String>;
    impl CryptoProvider<TestCommunicationBackend, TestSessionRepository> for TestCryptoProvider {
        type Session = String;
        type SendError = String;
        type ReceiveError = String;

        async fn receive(
            &self,
            _receiver: &<TestCommunicationBackend as CommunicationBackend>::Receiver,
            _communication: &TestCommunicationBackend,
            _sessions: &TestSessionRepository,
        ) -> Result<IncomingMessage, Self::ReceiveError> {
            match &self.receive_result {
                Some(result) => result.clone(),
                None => {
                    // Simulate a receive that does not complete during the test.
                    sleep(Duration::from_secs(600)).await;
                    Err("Simulated timeout".to_string())
                }
            }
        }

        async fn send(
            &self,
            _communication: &TestCommunicationBackend,
            _sessions: &TestSessionRepository,
            _message: OutgoingMessage,
        ) -> Result<(), Self::SendError> {
            match &self.send_result {
                Some(result) => result.clone(),
                None => {
                    // Simulate a send that does not complete during the test.
                    sleep(Duration::from_secs(600)).await;
                    Err("Simulated timeout".to_string())
                }
            }
        }
    }

    /// A crypto-provider send error is surfaced in the `SendError` text.
    #[tokio::test]
    async fn returns_send_error_when_crypto_provider_returns_error() {
        let message = OutgoingMessage {
            payload: vec![],
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = TestCryptoProvider {
            send_result: Some(Err("Crypto error".to_string())),
            receive_result: Some(Err("Should not have be called".to_string())),
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let _ = client.start(None).await;

        let error = client.send(message).await.unwrap_err();

        assert!(error.to_string().contains("Crypto error"));
    }

    /// With the no-op crypto provider, sent messages reach the backend unchanged.
    #[tokio::test]
    async fn communication_provider_has_outgoing_message_when_sending_through_ipc_client() {
        let message = OutgoingMessage {
            payload: vec![],
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;

        client.send(message.clone()).await.unwrap();

        let outgoing_messages = communication_provider.outgoing().await;
        assert_eq!(outgoing_messages, vec![message]);
    }

    /// A message pushed into the backend is delivered to an unfiltered subscription.
    #[tokio::test]
    async fn returns_received_message_when_received_from_backend() {
        let message = IncomingMessage {
            payload: vec![],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;

        let mut subscription = client
            .subscribe(None)
            .await
            .expect("Subscribing should not fail");
        communication_provider.push_incoming(message.clone());
        let received_message = subscription.receive(None).await.unwrap();

        assert_eq!(received_message, message);
    }

    /// A topic-filtered subscription skips messages with other topics.
    #[tokio::test]
    async fn skips_non_matching_topics_and_returns_first_matching_message() {
        let non_matching_message = IncomingMessage {
            payload: vec![],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: Some("non_matching_topic".to_owned()),
        };
        let matching_message = IncomingMessage {
            payload: vec![109],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: Some("matching_topic".to_owned()),
        };

        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;
        let mut subscription = client
            .subscribe(Some("matching_topic".to_owned()))
            .await
            .expect("Subscribing should not fail");
        communication_provider.push_incoming(non_matching_message.clone());
        communication_provider.push_incoming(non_matching_message.clone());
        communication_provider.push_incoming(matching_message.clone());

        let received_message: IncomingMessage = subscription.receive(None).await.unwrap();

        assert_eq!(received_message, matching_message);
    }

    /// A typed subscription skips unrelated messages and deserializes the matching one.
    #[tokio::test]
    async fn skips_unrelated_messages_and_returns_typed_message() {
        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        struct TestPayload {
            some_data: String,
        }

        impl PayloadTypeName for TestPayload {
            const PAYLOAD_TYPE_NAME: &str = "TestPayload";
        }

        let unrelated = IncomingMessage {
            payload: vec![],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let typed_message = crate::message::TypedIncomingMessage {
            payload: TestPayload {
                some_data: "Hello, world!".to_string(),
            },
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
        };

        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;
        let mut subscription = client
            .subscribe_typed::<TestPayload>()
            .await
            .expect("Subscribing should not fail");
        communication_provider.push_incoming(unrelated.clone());
        communication_provider.push_incoming(unrelated.clone());
        communication_provider.push_incoming(
            typed_message
                .clone()
                .try_into()
                .expect("Serialization should not fail"),
        );

        let received_message = subscription.receive(None).await.unwrap();

        assert_eq!(received_message, typed_message);
    }

    /// A message with the typed topic but an undecodable payload yields a `Typing` error.
    #[tokio::test]
    async fn returns_error_if_related_message_was_not_deserializable() {
        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        struct TestPayload {
            some_data: String,
        }

        impl PayloadTypeName for TestPayload {
            const PAYLOAD_TYPE_NAME: &str = "TestPayload";
        }

        let non_deserializable_message = IncomingMessage {
            payload: vec![],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: Some("TestPayload".to_owned()),
        };

        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;
        let mut subscription = client
            .subscribe_typed::<TestPayload>()
            .await
            .expect("Subscribing should not fail");
        communication_provider.push_incoming(non_deserializable_message.clone());

        let result = subscription.receive(None).await;
        assert!(matches!(result, Err(TypedReceiveError::Typing(_))));
    }

    /// A failed `send` stops the client.
    #[tokio::test]
    async fn ipc_client_stops_if_crypto_returns_send_error() {
        let message = OutgoingMessage {
            payload: vec![],
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = TestCryptoProvider {
            send_result: Some(Err("Crypto error".to_string())),
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let _ = client.start(None).await;

        let error = client.send(message).await.unwrap_err();
        let is_running = client.is_running();

        assert!(error.to_string().contains("Crypto error"));
        assert!(!is_running);
    }

    /// A receive error stops the loop and cancels the caller's token.
    #[tokio::test]
    async fn ipc_client_stops_if_crypto_returns_receive_error() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: Some(Err("Crypto error".to_string())),
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let _ = client.start(Some(cancellation_token.clone())).await;

        // Give the spawned receive loop time to hit the error and shut down.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(!is_running);
        assert!(cancellation_token.is_cancelled());
    }

    /// Cancelling the token passed to `start` stops the client.
    #[tokio::test]
    async fn ipc_client_is_not_running_if_cancellation_token_is_cancelled() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let _ = client.start(Some(cancellation_token.clone())).await;

        // Let the receive loop start before cancelling it.
        tokio::time::sleep(Duration::from_millis(100)).await;

        cancellation_token.cancel();
        // Give the loop time to observe the cancellation and shut down.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(!is_running);
    }

    /// Without errors or cancellation the client keeps running.
    #[tokio::test]
    async fn ipc_client_is_running_if_no_errors_are_encountered() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let _ = client.start(Some(cancellation_token.clone())).await;

        // Give the receive loop time to run; it should still be alive afterwards.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(is_running);
        assert!(!cancellation_token.is_cancelled());
    }

    /// A client that was never started reports not running.
    #[tokio::test]
    async fn ipc_client_is_not_running_if_not_started() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);

        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(!is_running);
    }

    /// Starting an already running client is rejected.
    #[tokio::test]
    async fn ipc_client_start_returns_error_if_already_running() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let first_result = client.start(Some(cancellation_token.clone())).await;
        assert_eq!(first_result, Ok(()));

        // Let the first receive loop start before trying again.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(client.is_running());

        let second_result = client.start(Some(cancellation_token.clone())).await;
        assert_eq!(second_result, Err(AlreadyRunningError));
    }

    /// RPC round-trip tests: outgoing requests and incoming request handling.
    mod request {
        use super::*;
        use crate::RpcHandler;

        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        struct TestRequest {
            a: i32,
            b: i32,
        }

        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        struct TestResponse {
            result: i32,
        }

        impl RpcRequest for TestRequest {
            type Response = TestResponse;

            const NAME: &str = "TestRequest";
        }

        /// Handler that answers with the sum of the request's two operands.
        struct TestHandler;

        impl RpcHandler for TestHandler {
            type Request = TestRequest;

            async fn handle(&self, request: Self::Request) -> TestResponse {
                TestResponse {
                    result: request.a + request.b,
                }
            }
        }

        /// `request` sends an RPC request message and resolves with the matching response.
        #[tokio::test]
        async fn request_sends_message_and_returns_response() {
            let crypto_provider = NoEncryptionCryptoProvider;
            let communication_provider = TestCommunicationBackend::new();
            let session_map = InMemorySessionRepository::new(HashMap::new());
            let client =
                IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
            let _ = client.start(None).await;
            let request = TestRequest { a: 1, b: 2 };
            let response = TestResponse { result: 3 };

            // Issue the request on a separate task; it waits for the simulated response below.
            let request_clone = request.clone();
            let client_clone = client.clone();
            let result_handle = tokio::spawn(async move {
                client_clone
                    .request::<TestRequest>(
                        request_clone,
                        Endpoint::BrowserBackground { id: HostId::Own },
                        None,
                    )
                    .await
            });
            tokio::time::sleep(Duration::from_millis(100)).await;

            // The request must have been sent through the backend.
            let outgoing_messages = communication_provider.outgoing().await;
            let outgoing_request: RpcRequestMessage<TestRequest> =
                serde_utils::from_slice(&outgoing_messages[0].payload)
                    .expect("Deserialization should not fail");
            assert_eq!(outgoing_request.request_type, "TestRequest");
            assert_eq!(outgoing_request.request, request);

            // Feed back a response carrying the same request id.
            let simulated_response = IncomingRpcResponseMessage {
                result: Ok(response),
                request_id: outgoing_request.request_id.clone(),
                request_type: outgoing_request.request_type.clone(),
            };
            let simulated_response = IncomingMessage {
                payload: serde_utils::to_vec(&simulated_response)
                    .expect("Serialization should not fail"),
                source: Source::BrowserBackground { id: HostId::Own },
                destination: Endpoint::Web {
                    tab_id: 9001,
                    document_id: "doc-1".to_string(),
                },
                topic: Some(
                    IncomingRpcResponseMessage::<TestRequest>::PAYLOAD_TYPE_NAME.to_owned(),
                ),
            };
            communication_provider.push_incoming(simulated_response);

            // The spawned request resolves with the handler-computed value.
            let result = result_handle.await.unwrap();
            assert_eq!(result.unwrap().result, 3);
        }

        /// An incoming RPC request is handled and a response is sent back through the backend.
        #[tokio::test]
        async fn incoming_rpc_message_handles_request_and_returns_response() {
            let crypto_provider = NoEncryptionCryptoProvider;
            let communication_provider = TestCommunicationBackend::new();
            let session_map = InMemorySessionRepository::new(HashMap::new());
            let client =
                IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
            let _ = client.start(None).await;
            let request_id = uuid::Uuid::new_v4().to_string();
            let request = TestRequest { a: 1, b: 2 };
            let response = TestResponse { result: 3 };

            client.register_rpc_handler(TestHandler).await;

            // Build an incoming message on the RPC request topic.
            let simulated_request = RpcRequestMessage {
                request,
                request_id: request_id.clone(),
                request_type: "TestRequest".to_string(),
            };
            let simulated_request_message = IncomingMessage {
                payload: serde_utils::to_vec(&simulated_request)
                    .expect("Serialization should not fail"),
                source: Source::Web {
                    tab_id: 9001,
                    document_id: "doc-1".to_string(),
                    origin: "https://example.com".to_string(),
                },
                destination: Endpoint::BrowserBackground { id: HostId::Own },
                topic: Some(RPC_REQUEST_PAYLOAD_TYPE_NAME.to_owned()),
            };
            communication_provider.push_incoming(simulated_request_message);

            // Give the detached handler task time to respond.
            tokio::time::sleep(Duration::from_millis(100)).await;

            // The response is the first outgoing message on the backend.
            let outgoing_messages = communication_provider.outgoing().await;
            let outgoing_response: IncomingRpcResponseMessage<TestResponse> =
                serde_utils::from_slice(&outgoing_messages[0].payload)
                    .expect("Deserialization should not fail");

            assert_eq!(
                outgoing_messages[0].topic,
                Some(IncomingRpcResponseMessage::<TestResponse>::PAYLOAD_TYPE_NAME.to_owned())
            );
            assert_eq!(outgoing_response.request_type, "TestRequest");
            assert_eq!(outgoing_response.result, Ok(response));
        }
    }
}