1use std::sync::{Arc, Mutex};
2
3use bitwarden_threading::cancellation_token::CancellationToken;
4use serde::de::DeserializeOwned;
5use thiserror::Error;
6use tokio::select;
7
8use crate::{
9 constants::CHANNEL_BUFFER_CAPACITY,
10 error::{
11 AlreadyRunningError, IpcErrorKind, ReceiveError, SendError, SubscribeError,
12 TypedReceiveError,
13 },
14 message::{
15 IncomingMessage, OutgoingMessage, PayloadTypeName, TypedIncomingMessage,
16 TypedOutgoingMessage,
17 },
18 rpc::{
19 exec::{handler::ErasedRpcHandler, handler_registry::RpcHandlerRegistry},
20 request_message::{RPC_REQUEST_PAYLOAD_TYPE_NAME, RpcRequestPayload},
21 response_message::OutgoingRpcResponseMessage,
22 },
23 serde_utils,
24 traits::{CommunicationBackend, CryptoProvider, SessionRepository},
25};
26
/// A subscription to the non-RPC messages received by a running IPC client.
///
/// Created by `IpcClient::subscribe`; call [`IpcClientSubscription::receive`] to wait for
/// the next message that passes the topic filter.
pub struct IpcClientSubscription {
    /// Broadcast receiver fed by the client's background receive loop.
    pub(crate) receiver: tokio::sync::broadcast::Receiver<IncomingMessage>,
    /// Topic filter: `None` accepts every message, `Some(topic)` accepts only messages
    /// whose `topic` is exactly equal to it.
    pub(crate) topic: Option<String>,
}
35
/// A subscription whose received messages are deserialized into `Payload`.
///
/// Wraps an untyped [`IpcClientSubscription`]; the `PhantomData` records the payload type
/// without storing a value of it.
pub struct IpcClientTypedSubscription<Payload: DeserializeOwned + PayloadTypeName>(
    IpcClientSubscription,
    std::marker::PhantomData<Payload>,
);
44
/// State shared (through an `Arc`) by every clone of [`IpcClientImpl`].
struct IpcClientInner<Crypto, Com, Ses>
where
    Crypto: CryptoProvider<Com, Ses>,
    Com: CommunicationBackend,
    Ses: SessionRepository<Crypto::Session>,
{
    /// Provider through which every outgoing message is sent and every incoming one received.
    crypto: Crypto,
    /// Backend handed to the crypto provider; subscribed to when the client starts.
    communication: Com,
    /// Session storage handed to the crypto provider on every send and receive.
    sessions: Ses,

    /// Registered RPC handlers used to answer incoming RPC request messages.
    handlers: RpcHandlerRegistry,
    /// Receiver half of the broadcast channel that the receive loop forwards non-RPC
    /// messages into. `None` until `start` is first called. It is never read directly:
    /// subscriptions are created from it with `resubscribe`.
    incoming: Mutex<Option<tokio::sync::broadcast::Receiver<IncomingMessage>>>,
    /// Token that stops the current receive loop; `None` before `start`, and cleared when
    /// the client is stopped.
    cancellation_token: Mutex<Option<CancellationToken>>,
}
60
/// Handle to an IPC client.
///
/// Cloning is cheap: all clones share the same [`IpcClientInner`] through an `Arc`, so
/// starting, stopping, subscribing or registering handlers through one clone is visible
/// through all of them.
pub struct IpcClientImpl<Crypto, Com, Ses>
where
    Crypto: CryptoProvider<Com, Ses>,
    Com: CommunicationBackend,
    Ses: SessionRepository<Crypto::Session>,
{
    inner: Arc<IpcClientInner<Crypto, Com, Ses>>,
}
74
75impl<Crypto, Com, Ses> Clone for IpcClientImpl<Crypto, Com, Ses>
76where
77 Crypto: CryptoProvider<Com, Ses>,
78 Com: CommunicationBackend,
79 Ses: SessionRepository<Crypto::Session>,
80{
81 fn clone(&self) -> Self {
82 Self {
83 inner: self.inner.clone(),
84 }
85 }
86}
87
88impl<Crypto, Com, Ses> IpcClientImpl<Crypto, Com, Ses>
89where
90 Crypto: CryptoProvider<Com, Ses>,
91 Com: CommunicationBackend,
92 Ses: SessionRepository<Crypto::Session>,
93{
94 pub fn new(crypto: Crypto, communication: Com, sessions: Ses) -> Self {
97 Self {
98 inner: Arc::new(IpcClientInner {
99 crypto,
100 communication,
101 sessions,
102
103 handlers: RpcHandlerRegistry::new(),
104 incoming: Mutex::new(None),
105 cancellation_token: Mutex::new(None),
106 }),
107 }
108 }
109}
110
111#[async_trait::async_trait]
112impl<Crypto, Com, Ses> crate::ipc_client_trait::IpcClient for IpcClientImpl<Crypto, Com, Ses>
113where
114 Crypto: CryptoProvider<Com, Ses>,
115 Com: CommunicationBackend,
116 Ses: SessionRepository<Crypto::Session>,
117{
118 async fn start(
119 &self,
120 cancellation_token: Option<CancellationToken>,
121 ) -> Result<(), AlreadyRunningError> {
122 if self.is_running() {
123 return Err(AlreadyRunningError);
124 }
125
126 let cancellation_token = cancellation_token.unwrap_or_default();
127 self.inner
128 .cancellation_token
129 .lock()
130 .expect("Failed to lock cancellation token mutex")
131 .replace(cancellation_token.clone());
132
133 let com_receiver = self.inner.communication.subscribe().await;
134 let (client_tx, client_rx) = tokio::sync::broadcast::channel(CHANNEL_BUFFER_CAPACITY);
135
136 self.inner
137 .incoming
138 .lock()
139 .expect("Failed to lock incoming mutex")
140 .replace(client_rx);
141
142 let inner = self.inner.clone();
143 let future = async move {
144 loop {
145 let rpc_topic = RPC_REQUEST_PAYLOAD_TYPE_NAME.to_owned();
146 select! {
147 _ = cancellation_token.cancelled() => {
148 tracing::debug!("Cancellation signal received, stopping IPC client");
149 break;
150 }
151 received = inner.crypto.receive(&com_receiver, &inner.communication, &inner.sessions) => {
152 match received {
153 Ok(message) if message.topic == Some(rpc_topic) => {
154 handle_rpc_request(&inner, message)
155 }
156 Ok(message) => {
157 if client_tx.send(message).is_err() {
158 tracing::error!("Failed to save incoming message");
159 break;
160 };
161 }
162 Err(error) if error.is_fatal() => {
163 tracing::error!(?error, "Fatal error receiving message, stopping IPC client");
164 break;
165 }
166 Err(error) => {
167 tracing::warn!(?error, "Recoverable error receiving message, continuing");
168 }
169 }
170 }
171 }
172 }
173 tracing::debug!("IPC client shutting down");
174 stop_inner(&inner);
175 };
176
177 #[cfg(not(target_arch = "wasm32"))]
178 tokio::spawn(future);
179
180 #[cfg(target_arch = "wasm32")]
181 wasm_bindgen_futures::spawn_local(future);
182
183 Ok(())
184 }
185
186 fn is_running(&self) -> bool {
187 let has_incoming = self
188 .inner
189 .incoming
190 .lock()
191 .expect("Failed to lock incoming mutex")
192 .as_ref()
193 .map(|receiver| !receiver.is_closed())
194 .unwrap_or(false);
195 let has_cancellation_token = self
196 .inner
197 .cancellation_token
198 .lock()
199 .expect("Failed to lock cancellation token mutex")
200 .is_some();
201 has_incoming && has_cancellation_token
202 }
203
204 async fn send(&self, message: OutgoingMessage) -> Result<(), SendError> {
205 let result = self
206 .inner
207 .crypto
208 .send(&self.inner.communication, &self.inner.sessions, message)
209 .await;
210
211 if let Err(ref error) = result {
212 if error.is_fatal() {
213 tracing::error!(?error, "Fatal error sending message, stopping IPC client");
214 stop_inner(&self.inner);
215 } else {
216 tracing::warn!(
217 ?error,
218 "Recoverable error sending message, IPC client will continue running"
219 );
220 }
221 }
222
223 result.map_err(|e| SendError(format!("{e:?}")))
224 }
225
226 async fn subscribe(
227 &self,
228 topic: Option<String>,
229 ) -> Result<IpcClientSubscription, SubscribeError> {
230 Ok(IpcClientSubscription {
231 receiver: self
232 .inner
233 .incoming
234 .lock()
235 .expect("Failed to lock incoming mutex")
236 .as_ref()
237 .ok_or(SubscribeError::NotStarted)?
238 .resubscribe(),
239 topic,
240 })
241 }
242
243 async fn register_rpc_handler_erased(&self, name: &str, handler: Box<dyn ErasedRpcHandler>) {
244 self.inner
245 .handlers
246 .register_erased(name.to_owned(), handler)
247 .await;
248 }
249}
250
251fn stop_inner<Crypto, Com, Ses>(inner: &IpcClientInner<Crypto, Com, Ses>)
252where
253 Crypto: CryptoProvider<Com, Ses>,
254 Com: CommunicationBackend,
255 Ses: SessionRepository<Crypto::Session>,
256{
257 let mut cancellation_token = inner
258 .cancellation_token
259 .lock()
260 .expect("Failed to lock cancellation token mutex");
261
262 if let Some(cancellation_token) = cancellation_token.take() {
263 cancellation_token.cancel();
264 }
265}
266
267fn handle_rpc_request<Crypto, Com, Ses>(
268 inner: &Arc<IpcClientInner<Crypto, Com, Ses>>,
269 incoming_message: IncomingMessage,
270) where
271 Crypto: CryptoProvider<Com, Ses>,
272 Com: CommunicationBackend,
273 Ses: SessionRepository<Crypto::Session>,
274{
275 let inner = inner.clone();
276 let future = async move {
277 #[derive(Debug, Error)]
278 enum HandleError {
279 #[error("Failed to deserialize request message: {0}")]
280 Deserialize(String),
281
282 #[error("Failed to serialize response message: {0}")]
283 Serialize(String),
284 }
285
286 async fn handle(
287 incoming_message: IncomingMessage,
288 handlers: &RpcHandlerRegistry,
289 ) -> Result<OutgoingMessage, HandleError> {
290 let request = RpcRequestPayload::from_slice(incoming_message.payload.clone()).map_err(
291 |e: serde_utils::DeserializeError| HandleError::Deserialize(e.to_string()),
292 )?;
293
294 let response = handlers.handle(&request).await;
295
296 let response_message = OutgoingRpcResponseMessage {
297 request_id: request.request_id(),
298 request_type: request.request_type(),
299 result: response,
300 };
301
302 let outgoing = TypedOutgoingMessage {
303 payload: response_message,
304 destination: incoming_message.source.into(),
305 }
306 .try_into()
307 .map_err(|e: serde_utils::SerializeError| HandleError::Serialize(e.to_string()))?;
308
309 Ok(outgoing)
310 }
311
312 match handle(incoming_message, &inner.handlers).await {
313 Ok(outgoing_message) => {
314 let result = inner
317 .crypto
318 .send(&inner.communication, &inner.sessions, outgoing_message)
319 .await;
320 if result.is_err() {
321 tracing::error!("Failed to send response message");
322 }
323 }
324 Err(error) => {
325 tracing::error!(%error, "Error handling RPC request");
326 }
327 }
328 };
329
330 #[cfg(not(target_arch = "wasm32"))]
331 tokio::spawn(future);
332
333 #[cfg(target_arch = "wasm32")]
334 wasm_bindgen_futures::spawn_local(future);
335}
336
337impl IpcClientSubscription {
338 pub async fn receive(
341 &mut self,
342 cancellation_token: Option<CancellationToken>,
343 ) -> Result<IncomingMessage, ReceiveError> {
344 let cancellation_token = cancellation_token.unwrap_or_default();
345
346 loop {
347 select! {
348 _ = cancellation_token.cancelled() => {
349 return Err(ReceiveError::Cancelled)
350 }
351 result = self.receiver.recv() => {
352 let received = result?;
353 if self.topic.is_none() || received.topic == self.topic {
354 return Ok::<IncomingMessage, ReceiveError>(received);
355 }
356 }
357 }
358 }
359 }
360}
361
362impl<Payload> IpcClientTypedSubscription<Payload>
363where
364 Payload: DeserializeOwned + PayloadTypeName,
365{
366 pub(crate) fn new(subscription: IpcClientSubscription) -> Self {
367 Self(subscription, std::marker::PhantomData)
368 }
369
370 pub async fn receive(
373 &mut self,
374 cancellation_token: Option<CancellationToken>,
375 ) -> Result<TypedIncomingMessage<Payload>, TypedReceiveError> {
376 let received = self.0.receive(cancellation_token).await?;
377 received
378 .try_into()
379 .map_err(|e: serde_utils::DeserializeError| TypedReceiveError::Typing(e.to_string()))
380 }
381}
382
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use bitwarden_threading::time::sleep;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        IpcClientExt,
        endpoint::{Endpoint, HostId, Source},
        ipc_client_trait::IpcClient,
        message::PayloadTypeName,
        rpc::{
            request::RpcRequest,
            request_message::{RPC_REQUEST_PAYLOAD_TYPE_NAME, RpcRequestMessage},
            response_message::IncomingRpcResponseMessage,
        },
        traits::{InMemorySessionRepository, NoEncryptionCryptoProvider, TestCommunicationBackend},
    };

    /// Error produced by [`TestCryptoProvider`]; `fatal` is what `is_fatal` reports.
    #[derive(Debug, Clone)]
    struct TestCryptoError {
        // Only observed through the derived `Debug` output (e.g. the text of `SendError`),
        // which the dead-code lint does not count as a read.
        #[allow(dead_code)]
        message: String,
        fatal: bool,
    }

    impl IpcErrorKind for TestCryptoError {
        fn is_fatal(&self) -> bool {
            self.fatal
        }
    }

    /// Crypto provider with scripted results.
    ///
    /// `Some(result)` makes the call return that result (receive sleeps ~5 ms first).
    /// `None` makes the call sleep for 600 s before returning a fatal "timeout" error, which
    /// effectively simulates a call that never completes during a test.
    struct TestCryptoProvider {
        send_result: Option<Result<(), TestCryptoError>>,
        receive_result: Option<Result<IncomingMessage, TestCryptoError>>,
    }

    type TestSessionRepository = InMemorySessionRepository<String>;
    impl CryptoProvider<TestCommunicationBackend, TestSessionRepository> for TestCryptoProvider {
        type Session = String;
        type SendError = TestCryptoError;
        type ReceiveError = TestCryptoError;

        async fn receive(
            &self,
            _receiver: &<TestCommunicationBackend as CommunicationBackend>::Receiver,
            _communication: &TestCommunicationBackend,
            _sessions: &TestSessionRepository,
        ) -> Result<IncomingMessage, Self::ReceiveError> {
            match &self.receive_result {
                Some(result) => {
                    // Presumably keeps a loop that keeps getting recoverable errors from
                    // busy-spinning.
                    sleep(Duration::from_millis(5)).await;
                    result.clone()
                }
                None => {
                    sleep(Duration::from_secs(600)).await;
                    Err(TestCryptoError {
                        message: "Simulated timeout".to_string(),
                        fatal: true,
                    })
                }
            }
        }

        async fn send(
            &self,
            _communication: &TestCommunicationBackend,
            _sessions: &TestSessionRepository,
            _message: OutgoingMessage,
        ) -> Result<(), Self::SendError> {
            match &self.send_result {
                Some(result) => result.clone(),
                None => {
                    sleep(Duration::from_secs(600)).await;
                    Err(TestCryptoError {
                        message: "Simulated timeout".to_string(),
                        fatal: true,
                    })
                }
            }
        }
    }

    #[tokio::test]
    async fn returns_send_error_when_crypto_provider_returns_error() {
        let message = OutgoingMessage {
            payload: vec![],
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = TestCryptoProvider {
            send_result: Some(Err(TestCryptoError {
                message: "Crypto error".to_string(),
                fatal: false,
            })),
            // NOTE(review): `start` below runs the receive loop, which does call this
            // (every ~5 ms). The text is misleading but harmless because the error is
            // recoverable.
            receive_result: Some(Err(TestCryptoError {
                message: "Should not have be called".to_string(),
                fatal: false,
            })),
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let _ = client.start(None).await;

        let error = client.send(message).await.unwrap_err();

        assert!(error.to_string().contains("Crypto error"));
    }

    #[tokio::test]
    async fn communication_provider_has_outgoing_message_when_sending_through_ipc_client() {
        let message = OutgoingMessage {
            payload: vec![],
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;

        client.send(message.clone()).await.unwrap();

        let outgoing_messages = communication_provider.outgoing().await;
        assert_eq!(outgoing_messages, vec![message]);
    }

    #[tokio::test]
    async fn returns_received_message_when_received_from_backend() {
        let message = IncomingMessage {
            payload: vec![],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;

        // Subscribe before pushing, so the message is broadcast after the subscription exists.
        let mut subscription = client
            .subscribe(None)
            .await
            .expect("Subscribing should not fail");
        communication_provider.push_incoming(message.clone());
        let received_message = subscription.receive(None).await.unwrap();

        assert_eq!(received_message, message);
    }

    #[tokio::test]
    async fn skips_non_matching_topics_and_returns_first_matching_message() {
        let non_matching_message = IncomingMessage {
            payload: vec![],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: Some("non_matching_topic".to_owned()),
        };
        let matching_message = IncomingMessage {
            payload: vec![109],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: Some("matching_topic".to_owned()),
        };

        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;
        let mut subscription = client
            .subscribe(Some("matching_topic".to_owned()))
            .await
            .expect("Subscribing should not fail");
        communication_provider.push_incoming(non_matching_message.clone());
        communication_provider.push_incoming(non_matching_message.clone());
        communication_provider.push_incoming(matching_message.clone());

        let received_message: IncomingMessage = subscription.receive(None).await.unwrap();

        assert_eq!(received_message, matching_message);
    }

    #[tokio::test]
    async fn skips_unrelated_messages_and_returns_typed_message() {
        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        struct TestPayload {
            some_data: String,
        }

        impl PayloadTypeName for TestPayload {
            const PAYLOAD_TYPE_NAME: &str = "TestPayload";
        }

        let unrelated = IncomingMessage {
            payload: vec![],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let typed_message = crate::message::TypedIncomingMessage {
            payload: TestPayload {
                some_data: "Hello, world!".to_string(),
            },
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
        };

        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;
        let mut subscription = client
            .subscribe_typed::<TestPayload>()
            .await
            .expect("Subscribing should not fail");
        communication_provider.push_incoming(unrelated.clone());
        communication_provider.push_incoming(unrelated.clone());
        communication_provider.push_incoming(
            typed_message
                .clone()
                .try_into()
                .expect("Serialization should not fail"),
        );

        let received_message = subscription.receive(None).await.unwrap();

        assert_eq!(received_message, typed_message);
    }

    #[tokio::test]
    async fn returns_error_if_related_message_was_not_deserializable() {
        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        struct TestPayload {
            some_data: String,
        }

        impl PayloadTypeName for TestPayload {
            const PAYLOAD_TYPE_NAME: &str = "TestPayload";
        }

        // Carries the payload type's topic but an empty payload, so decoding it must fail.
        let non_deserializable_message = IncomingMessage {
            payload: vec![],
            source: Source::Web {
                tab_id: 9001,
                document_id: "doc-1".to_string(),
                origin: "https://example.com".to_string(),
            },
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: Some("TestPayload".to_owned()),
        };

        let crypto_provider = NoEncryptionCryptoProvider;
        let communication_provider = TestCommunicationBackend::new();
        let session_map = InMemorySessionRepository::new(HashMap::new());
        let client =
            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
        let _ = client.start(None).await;
        let mut subscription = client
            .subscribe_typed::<TestPayload>()
            .await
            .expect("Subscribing should not fail");
        communication_provider.push_incoming(non_deserializable_message.clone());

        let result = subscription.receive(None).await;
        assert!(matches!(result, Err(TypedReceiveError::Typing(_))));
    }

    #[tokio::test]
    async fn ipc_client_stops_if_crypto_returns_fatal_send_error() {
        let message = OutgoingMessage {
            payload: vec![],
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = TestCryptoProvider {
            send_result: Some(Err(TestCryptoError {
                message: "Crypto error".to_string(),
                fatal: true,
            })),
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let _ = client.start(None).await;

        let error = client.send(message).await.unwrap_err();
        let is_running = client.is_running();

        assert!(error.to_string().contains("Crypto error"));
        assert!(!is_running);
    }

    #[tokio::test]
    async fn ipc_client_keeps_running_if_crypto_returns_recoverable_send_error() {
        let message = OutgoingMessage {
            payload: vec![],
            destination: Endpoint::BrowserBackground { id: HostId::Own },
            topic: None,
        };
        let crypto_provider = TestCryptoProvider {
            send_result: Some(Err(TestCryptoError {
                message: "Crypto error".to_string(),
                fatal: false,
            })),
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let _ = client.start(None).await;

        let error = client.send(message).await.unwrap_err();
        let is_running = client.is_running();

        assert!(error.to_string().contains("Crypto error"));
        assert!(is_running);
    }

    #[tokio::test]
    async fn ipc_client_stops_if_crypto_returns_fatal_receive_error() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: Some(Err(TestCryptoError {
                message: "Crypto error".to_string(),
                fatal: true,
            })),
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let _ = client.start(Some(cancellation_token.clone())).await;

        // Give the spawned receive loop time to hit the fatal error and shut down.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(!is_running);
        // The caller's clone of the token is cancelled by the client's shutdown.
        assert!(cancellation_token.is_cancelled());
    }

    #[tokio::test]
    async fn ipc_client_keeps_running_if_crypto_returns_recoverable_receive_error() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: Some(Err(TestCryptoError {
                message: "Crypto error".to_string(),
                fatal: false,
            })),
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let _ = client.start(Some(cancellation_token.clone())).await;

        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(is_running);
        assert!(!cancellation_token.is_cancelled());
    }

    #[tokio::test]
    async fn ipc_client_is_not_running_if_cancellation_token_is_cancelled() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let _ = client.start(Some(cancellation_token.clone())).await;

        tokio::time::sleep(Duration::from_millis(100)).await;

        cancellation_token.cancel();
        // Give the receive loop time to observe the cancellation and shut down.
        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(!is_running);
    }

    #[tokio::test]
    async fn ipc_client_is_running_if_no_errors_are_encountered() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let _ = client.start(Some(cancellation_token.clone())).await;

        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(is_running);
        assert!(!cancellation_token.is_cancelled());
    }

    #[tokio::test]
    async fn ipc_client_is_not_running_if_not_started() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);

        tokio::time::sleep(Duration::from_millis(100)).await;
        let is_running = client.is_running();

        assert!(!is_running);
    }

    #[tokio::test]
    async fn ipc_client_start_returns_error_if_already_running() {
        let crypto_provider = TestCryptoProvider {
            send_result: None,
            receive_result: None,
        };
        let communication_provider = TestCommunicationBackend::new();
        let session_map = TestSessionRepository::new(HashMap::new());
        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
        let cancellation_token = CancellationToken::new();
        let first_result = client.start(Some(cancellation_token.clone())).await;
        assert_eq!(first_result, Ok(()));

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(client.is_running());

        let second_result = client.start(Some(cancellation_token.clone())).await;
        assert_eq!(second_result, Err(AlreadyRunningError));
    }

    mod request {
        use super::*;
        use crate::RpcHandler;

        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        struct TestRequest {
            a: i32,
            b: i32,
        }

        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
        struct TestResponse {
            result: i32,
        }

        impl RpcRequest for TestRequest {
            type Response = TestResponse;

            const NAME: &str = "TestRequest";
        }

        /// Handler that answers a [`TestRequest`] with the sum of its two operands.
        struct TestHandler;

        impl RpcHandler for TestHandler {
            type Request = TestRequest;

            async fn handle(&self, request: Self::Request) -> TestResponse {
                TestResponse {
                    result: request.a + request.b,
                }
            }
        }

        #[tokio::test]
        async fn request_sends_message_and_returns_response() {
            let crypto_provider = NoEncryptionCryptoProvider;
            let communication_provider = TestCommunicationBackend::new();
            let session_map = InMemorySessionRepository::default();
            let client =
                IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
            let _ = client.start(None).await;
            let request = TestRequest { a: 1, b: 2 };
            let response = TestResponse { result: 3 };

            // Issue the request on a separate task: it waits for a response that this test
            // injects below.
            let request_clone = request.clone();
            let client_clone = client.clone();
            let result_handle = tokio::spawn(async move {
                client_clone
                    .request::<TestRequest>(
                        request_clone,
                        Endpoint::BrowserBackground { id: HostId::Own },
                        None,
                    )
                    .await
            });
            tokio::time::sleep(Duration::from_millis(100)).await;

            let outgoing_messages = communication_provider.outgoing().await;
            let outgoing_request: RpcRequestMessage<TestRequest> =
                serde_utils::from_slice(&outgoing_messages[0].payload)
                    .expect("Deserialization should not fail");
            assert_eq!(outgoing_request.request_type, "TestRequest");
            assert_eq!(outgoing_request.request, request);

            // Echo back a response carrying the same request id and type.
            let simulated_response = IncomingRpcResponseMessage {
                result: Ok(response),
                request_id: outgoing_request.request_id.clone(),
                request_type: outgoing_request.request_type.clone(),
            };
            let simulated_response = IncomingMessage {
                payload: serde_utils::to_vec(&simulated_response)
                    .expect("Serialization should not fail"),
                source: Source::BrowserBackground { id: HostId::Own },
                destination: Endpoint::Web {
                    tab_id: 9001,
                    document_id: "doc-1".to_string(),
                },
                topic: Some(
                    IncomingRpcResponseMessage::<TestRequest>::PAYLOAD_TYPE_NAME.to_owned(),
                ),
            };
            communication_provider.push_incoming(simulated_response);

            let result = result_handle.await.unwrap();
            assert_eq!(result.unwrap().result, 3);
        }

        #[tokio::test]
        async fn incoming_rpc_message_handles_request_and_returns_response() {
            let crypto_provider = NoEncryptionCryptoProvider;
            let communication_provider = TestCommunicationBackend::new();
            let session_map = InMemorySessionRepository::default();
            let client =
                IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
            let _ = client.start(None).await;
            let request_id = uuid::Uuid::new_v4().to_string();
            let request = TestRequest { a: 1, b: 2 };
            let response = TestResponse { result: 3 };

            client.register_rpc_handler(TestHandler).await;

            let simulated_request = RpcRequestMessage {
                request,
                request_id: request_id.clone(),
                request_type: "TestRequest".to_string(),
            };
            // The RPC request topic routes this message to the handler registry instead of
            // to subscribers.
            let simulated_request_message = IncomingMessage {
                payload: serde_utils::to_vec(&simulated_request)
                    .expect("Serialization should not fail"),
                source: Source::Web {
                    tab_id: 9001,
                    document_id: "doc-1".to_string(),
                    origin: "https://example.com".to_string(),
                },
                destination: Endpoint::BrowserBackground { id: HostId::Own },
                topic: Some(RPC_REQUEST_PAYLOAD_TYPE_NAME.to_owned()),
            };
            communication_provider.push_incoming(simulated_request_message);

            // The response is produced on a spawned task; wait for it to be sent.
            tokio::time::sleep(Duration::from_millis(100)).await;

            let outgoing_messages = communication_provider.outgoing().await;
            let outgoing_response: IncomingRpcResponseMessage<TestResponse> =
                serde_utils::from_slice(&outgoing_messages[0].payload)
                    .expect("Deserialization should not fail");

            // NOTE(review): `request_id` is sent, but the response's request id is not
            // asserted here.
            assert_eq!(
                outgoing_messages[0].topic,
                Some(IncomingRpcResponseMessage::<TestResponse>::PAYLOAD_TYPE_NAME.to_owned())
            );
            assert_eq!(outgoing_response.request_type, "TestRequest");
            assert_eq!(outgoing_response.result, Ok(response));
        }
    }
}