Skip to main content

bitwarden_ipc/
ipc_client.rs

1use std::sync::{Arc, Mutex};
2
3use bitwarden_threading::cancellation_token::CancellationToken;
4use serde::de::DeserializeOwned;
5use thiserror::Error;
6use tokio::select;
7
8use crate::{
9    constants::CHANNEL_BUFFER_CAPACITY,
10    error::{
11        AlreadyRunningError, IpcErrorKind, ReceiveError, SendError, SubscribeError,
12        TypedReceiveError,
13    },
14    message::{
15        IncomingMessage, OutgoingMessage, PayloadTypeName, TypedIncomingMessage,
16        TypedOutgoingMessage,
17    },
18    rpc::{
19        exec::{handler::ErasedRpcHandler, handler_registry::RpcHandlerRegistry},
20        request_message::{RPC_REQUEST_PAYLOAD_TYPE_NAME, RpcRequestPayload},
21        response_message::OutgoingRpcResponseMessage,
22    },
23    serde_utils,
24    traits::{CommunicationBackend, CryptoProvider, SessionRepository},
25};
26
27/// A subscription to receive messages over IPC.
28/// The subcription will start buffering messages after its creation and return them
29/// when receive() is called. Messages received before the subscription was created will not be
30/// returned.
31pub struct IpcClientSubscription {
32    pub(crate) receiver: tokio::sync::broadcast::Receiver<IncomingMessage>,
33    pub(crate) topic: Option<String>,
34}
35
36/// A subscription to receive messages over IPC.
37/// The subcription will start buffering messages after its creation and return them
38/// when receive() is called. Messages received before the subscription was created will not be
39/// returned.
40pub struct IpcClientTypedSubscription<Payload: DeserializeOwned + PayloadTypeName>(
41    IpcClientSubscription,
42    std::marker::PhantomData<Payload>,
43);
44
45/// Internal shared state for the IPC client.
46struct IpcClientInner<Crypto, Com, Ses>
47where
48    Crypto: CryptoProvider<Com, Ses>,
49    Com: CommunicationBackend,
50    Ses: SessionRepository<Crypto::Session>,
51{
52    crypto: Crypto,
53    communication: Com,
54    sessions: Ses,
55
56    handlers: RpcHandlerRegistry,
57    incoming: Mutex<Option<tokio::sync::broadcast::Receiver<IncomingMessage>>>,
58    cancellation_token: Mutex<Option<CancellationToken>>,
59}
60
61/// An IPC client that handles communication between different components and clients.
62/// It uses a crypto provider to encrypt and decrypt messages, a communication backend to send and
63/// receive messages, and a session repository to persist sessions.
64///
65/// This is the concrete implementation of the [`IpcClient`](crate::IpcClient) trait.
66pub struct IpcClientImpl<Crypto, Com, Ses>
67where
68    Crypto: CryptoProvider<Com, Ses>,
69    Com: CommunicationBackend,
70    Ses: SessionRepository<Crypto::Session>,
71{
72    inner: Arc<IpcClientInner<Crypto, Com, Ses>>,
73}
74
75impl<Crypto, Com, Ses> Clone for IpcClientImpl<Crypto, Com, Ses>
76where
77    Crypto: CryptoProvider<Com, Ses>,
78    Com: CommunicationBackend,
79    Ses: SessionRepository<Crypto::Session>,
80{
81    fn clone(&self) -> Self {
82        Self {
83            inner: self.inner.clone(),
84        }
85    }
86}
87
88impl<Crypto, Com, Ses> IpcClientImpl<Crypto, Com, Ses>
89where
90    Crypto: CryptoProvider<Com, Ses>,
91    Com: CommunicationBackend,
92    Ses: SessionRepository<Crypto::Session>,
93{
94    /// Create a new IPC client with the provided crypto provider, communication backend, and
95    /// session repository.
96    pub fn new(crypto: Crypto, communication: Com, sessions: Ses) -> Self {
97        Self {
98            inner: Arc::new(IpcClientInner {
99                crypto,
100                communication,
101                sessions,
102
103                handlers: RpcHandlerRegistry::new(),
104                incoming: Mutex::new(None),
105                cancellation_token: Mutex::new(None),
106            }),
107        }
108    }
109}
110
111#[async_trait::async_trait]
112impl<Crypto, Com, Ses> crate::ipc_client_trait::IpcClient for IpcClientImpl<Crypto, Com, Ses>
113where
114    Crypto: CryptoProvider<Com, Ses>,
115    Com: CommunicationBackend,
116    Ses: SessionRepository<Crypto::Session>,
117{
118    async fn start(
119        &self,
120        cancellation_token: Option<CancellationToken>,
121    ) -> Result<(), AlreadyRunningError> {
122        if self.is_running() {
123            return Err(AlreadyRunningError);
124        }
125
126        let cancellation_token = cancellation_token.unwrap_or_default();
127        self.inner
128            .cancellation_token
129            .lock()
130            .expect("Failed to lock cancellation token mutex")
131            .replace(cancellation_token.clone());
132
133        let com_receiver = self.inner.communication.subscribe().await;
134        let (client_tx, client_rx) = tokio::sync::broadcast::channel(CHANNEL_BUFFER_CAPACITY);
135
136        self.inner
137            .incoming
138            .lock()
139            .expect("Failed to lock incoming mutex")
140            .replace(client_rx);
141
142        let inner = self.inner.clone();
143        let future = async move {
144            loop {
145                let rpc_topic = RPC_REQUEST_PAYLOAD_TYPE_NAME.to_owned();
146                select! {
147                    _ = cancellation_token.cancelled() => {
148                        tracing::debug!("Cancellation signal received, stopping IPC client");
149                        break;
150                    }
151                    received = inner.crypto.receive(&com_receiver, &inner.communication, &inner.sessions) => {
152                        match received {
153                            Ok(message) if message.topic == Some(rpc_topic) => {
154                                handle_rpc_request(&inner, message)
155                            }
156                            Ok(message) => {
157                                if client_tx.send(message).is_err() {
158                                    tracing::error!("Failed to save incoming message");
159                                    break;
160                                };
161                            }
162                            Err(error) if error.is_fatal() => {
163                                tracing::error!(?error, "Fatal error receiving message, stopping IPC client");
164                                break;
165                            }
166                            Err(error) => {
167                                tracing::warn!(?error, "Recoverable error receiving message, continuing");
168                            }
169                        }
170                    }
171                }
172            }
173            tracing::debug!("IPC client shutting down");
174            stop_inner(&inner);
175        };
176
177        #[cfg(not(target_arch = "wasm32"))]
178        tokio::spawn(future);
179
180        #[cfg(target_arch = "wasm32")]
181        wasm_bindgen_futures::spawn_local(future);
182
183        Ok(())
184    }
185
186    fn is_running(&self) -> bool {
187        let has_incoming = self
188            .inner
189            .incoming
190            .lock()
191            .expect("Failed to lock incoming mutex")
192            .as_ref()
193            .map(|receiver| !receiver.is_closed())
194            .unwrap_or(false);
195        let has_cancellation_token = self
196            .inner
197            .cancellation_token
198            .lock()
199            .expect("Failed to lock cancellation token mutex")
200            .is_some();
201        has_incoming && has_cancellation_token
202    }
203
204    async fn send(&self, message: OutgoingMessage) -> Result<(), SendError> {
205        let result = self
206            .inner
207            .crypto
208            .send(&self.inner.communication, &self.inner.sessions, message)
209            .await;
210
211        if let Err(ref error) = result {
212            if error.is_fatal() {
213                tracing::error!(?error, "Fatal error sending message, stopping IPC client");
214                stop_inner(&self.inner);
215            } else {
216                tracing::warn!(
217                    ?error,
218                    "Recoverable error sending message, IPC client will continue running"
219                );
220            }
221        }
222
223        result.map_err(|e| SendError(format!("{e:?}")))
224    }
225
226    async fn subscribe(
227        &self,
228        topic: Option<String>,
229    ) -> Result<IpcClientSubscription, SubscribeError> {
230        Ok(IpcClientSubscription {
231            receiver: self
232                .inner
233                .incoming
234                .lock()
235                .expect("Failed to lock incoming mutex")
236                .as_ref()
237                .ok_or(SubscribeError::NotStarted)?
238                .resubscribe(),
239            topic,
240        })
241    }
242
243    async fn register_rpc_handler_erased(&self, name: &str, handler: Box<dyn ErasedRpcHandler>) {
244        self.inner
245            .handlers
246            .register_erased(name.to_owned(), handler)
247            .await;
248    }
249}
250
251fn stop_inner<Crypto, Com, Ses>(inner: &IpcClientInner<Crypto, Com, Ses>)
252where
253    Crypto: CryptoProvider<Com, Ses>,
254    Com: CommunicationBackend,
255    Ses: SessionRepository<Crypto::Session>,
256{
257    let mut cancellation_token = inner
258        .cancellation_token
259        .lock()
260        .expect("Failed to lock cancellation token mutex");
261
262    if let Some(cancellation_token) = cancellation_token.take() {
263        cancellation_token.cancel();
264    }
265}
266
267fn handle_rpc_request<Crypto, Com, Ses>(
268    inner: &Arc<IpcClientInner<Crypto, Com, Ses>>,
269    incoming_message: IncomingMessage,
270) where
271    Crypto: CryptoProvider<Com, Ses>,
272    Com: CommunicationBackend,
273    Ses: SessionRepository<Crypto::Session>,
274{
275    let inner = inner.clone();
276    let future = async move {
277        #[derive(Debug, Error)]
278        enum HandleError {
279            #[error("Failed to deserialize request message: {0}")]
280            Deserialize(String),
281
282            #[error("Failed to serialize response message: {0}")]
283            Serialize(String),
284        }
285
286        async fn handle(
287            incoming_message: IncomingMessage,
288            handlers: &RpcHandlerRegistry,
289        ) -> Result<OutgoingMessage, HandleError> {
290            let request = RpcRequestPayload::from_slice(incoming_message.payload.clone()).map_err(
291                |e: serde_utils::DeserializeError| HandleError::Deserialize(e.to_string()),
292            )?;
293
294            let response = handlers.handle(&request).await;
295
296            let response_message = OutgoingRpcResponseMessage {
297                request_id: request.request_id(),
298                request_type: request.request_type(),
299                result: response,
300            };
301
302            let outgoing = TypedOutgoingMessage {
303                payload: response_message,
304                destination: incoming_message.source.into(),
305            }
306            .try_into()
307            .map_err(|e: serde_utils::SerializeError| HandleError::Serialize(e.to_string()))?;
308
309            Ok(outgoing)
310        }
311
312        match handle(incoming_message, &inner.handlers).await {
313            Ok(outgoing_message) => {
314                // Send response directly through the crypto provider (not through the trait)
315                // since we're inside the background task and don't have a trait object.
316                let result = inner
317                    .crypto
318                    .send(&inner.communication, &inner.sessions, outgoing_message)
319                    .await;
320                if result.is_err() {
321                    tracing::error!("Failed to send response message");
322                }
323            }
324            Err(error) => {
325                tracing::error!(%error, "Error handling RPC request");
326            }
327        }
328    };
329
330    #[cfg(not(target_arch = "wasm32"))]
331    tokio::spawn(future);
332
333    #[cfg(target_arch = "wasm32")]
334    wasm_bindgen_futures::spawn_local(future);
335}
336
337impl IpcClientSubscription {
338    /// Receive a message, optionally filtering by topic.
339    /// Setting the cancellation_token to `None` will wait indefinitely.
340    pub async fn receive(
341        &mut self,
342        cancellation_token: Option<CancellationToken>,
343    ) -> Result<IncomingMessage, ReceiveError> {
344        let cancellation_token = cancellation_token.unwrap_or_default();
345
346        loop {
347            select! {
348                _ = cancellation_token.cancelled() => {
349                    return Err(ReceiveError::Cancelled)
350                }
351                result = self.receiver.recv() => {
352                    let received = result?;
353                    if self.topic.is_none() || received.topic == self.topic {
354                        return Ok::<IncomingMessage, ReceiveError>(received);
355                    }
356                }
357            }
358        }
359    }
360}
361
362impl<Payload> IpcClientTypedSubscription<Payload>
363where
364    Payload: DeserializeOwned + PayloadTypeName,
365{
366    pub(crate) fn new(subscription: IpcClientSubscription) -> Self {
367        Self(subscription, std::marker::PhantomData)
368    }
369
370    /// Receive a message.
371    /// Setting the cancellation_token to `None` will wait indefinitely.
372    pub async fn receive(
373        &mut self,
374        cancellation_token: Option<CancellationToken>,
375    ) -> Result<TypedIncomingMessage<Payload>, TypedReceiveError> {
376        let received = self.0.receive(cancellation_token).await?;
377        received
378            .try_into()
379            .map_err(|e: serde_utils::DeserializeError| TypedReceiveError::Typing(e.to_string()))
380    }
381}
382
383#[cfg(test)]
384mod tests {
385    use std::{collections::HashMap, time::Duration};
386
387    use bitwarden_threading::time::sleep;
388    use serde::{Deserialize, Serialize};
389
390    use super::*;
391    use crate::{
392        IpcClientExt,
393        endpoint::{Endpoint, HostId, Source},
394        ipc_client_trait::IpcClient,
395        message::PayloadTypeName,
396        rpc::{
397            request::RpcRequest,
398            request_message::{RPC_REQUEST_PAYLOAD_TYPE_NAME, RpcRequestMessage},
399            response_message::IncomingRpcResponseMessage,
400        },
401        traits::{InMemorySessionRepository, NoEncryptionCryptoProvider, TestCommunicationBackend},
402    };
403
404    /// Error type for [`TestCryptoProvider`] that carries an explicit fatal/recoverable
405    /// classification so tests can exercise both control-flow paths.
406    #[derive(Debug, Clone)]
407    struct TestCryptoError {
408        // Read only through the derived `Debug` impl (the outer `SendError` wraps it via
409        // `{e:?}`), which dead-code analysis does not count as a use.
410        #[allow(dead_code)]
411        message: String,
412        fatal: bool,
413    }
414
415    impl IpcErrorKind for TestCryptoError {
416        fn is_fatal(&self) -> bool {
417            self.fatal
418        }
419    }
420
421    struct TestCryptoProvider {
422        /// Simulate a send result. Set to `None` wait indefinitely
423        send_result: Option<Result<(), TestCryptoError>>,
424        /// Simulate a receive result. Set to `None` wait indefinitely
425        receive_result: Option<Result<IncomingMessage, TestCryptoError>>,
426    }
427
428    type TestSessionRepository = InMemorySessionRepository<String>;
429    impl CryptoProvider<TestCommunicationBackend, TestSessionRepository> for TestCryptoProvider {
430        type Session = String;
431        type SendError = TestCryptoError;
432        type ReceiveError = TestCryptoError;
433
434        async fn receive(
435            &self,
436            _receiver: &<TestCommunicationBackend as CommunicationBackend>::Receiver,
437            _communication: &TestCommunicationBackend,
438            _sessions: &TestSessionRepository,
439        ) -> Result<IncomingMessage, Self::ReceiveError> {
440            match &self.receive_result {
441                Some(result) => {
442                    // Yield (and throttle) so a recoverable error that makes the processing loop
443                    // `continue` doesn't busy-spin and starve the single-threaded test runtime.
444                    // Real backends await their underlying transport here, which has the same
445                    // effect.
446                    sleep(Duration::from_millis(5)).await;
447                    result.clone()
448                }
449                None => {
450                    // Simulate waiting for a message but never returning
451                    sleep(Duration::from_secs(600)).await;
452                    Err(TestCryptoError {
453                        message: "Simulated timeout".to_string(),
454                        fatal: true,
455                    })
456                }
457            }
458        }
459
460        async fn send(
461            &self,
462            _communication: &TestCommunicationBackend,
463            _sessions: &TestSessionRepository,
464            _message: OutgoingMessage,
465        ) -> Result<(), Self::SendError> {
466            match &self.send_result {
467                Some(result) => result.clone(),
468                None => {
469                    // Simulate waiting for a message to be send but never returning
470                    sleep(Duration::from_secs(600)).await;
471                    Err(TestCryptoError {
472                        message: "Simulated timeout".to_string(),
473                        fatal: true,
474                    })
475                }
476            }
477        }
478    }
479
480    #[tokio::test]
481    async fn returns_send_error_when_crypto_provider_returns_error() {
482        let message = OutgoingMessage {
483            payload: vec![],
484            destination: Endpoint::BrowserBackground { id: HostId::Own },
485            topic: None,
486        };
487        let crypto_provider = TestCryptoProvider {
488            send_result: Some(Err(TestCryptoError {
489                message: "Crypto error".to_string(),
490                fatal: false,
491            })),
492            receive_result: Some(Err(TestCryptoError {
493                message: "Should not have be called".to_string(),
494                fatal: false,
495            })),
496        };
497        let communication_provider = TestCommunicationBackend::new();
498        let session_map = TestSessionRepository::new(HashMap::new());
499        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
500        let _ = client.start(None).await;
501
502        let error = client.send(message).await.unwrap_err();
503
504        assert!(error.to_string().contains("Crypto error"));
505    }
506
507    #[tokio::test]
508    async fn communication_provider_has_outgoing_message_when_sending_through_ipc_client() {
509        let message = OutgoingMessage {
510            payload: vec![],
511            destination: Endpoint::BrowserBackground { id: HostId::Own },
512            topic: None,
513        };
514        let crypto_provider = NoEncryptionCryptoProvider;
515        let communication_provider = TestCommunicationBackend::new();
516        let session_map = InMemorySessionRepository::new(HashMap::new());
517        let client =
518            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
519        let _ = client.start(None).await;
520
521        client.send(message.clone()).await.unwrap();
522
523        let outgoing_messages = communication_provider.outgoing().await;
524        assert_eq!(outgoing_messages, vec![message]);
525    }
526
527    #[tokio::test]
528    async fn returns_received_message_when_received_from_backend() {
529        let message = IncomingMessage {
530            payload: vec![],
531            source: Source::Web {
532                tab_id: 9001,
533                document_id: "doc-1".to_string(),
534                origin: "https://example.com".to_string(),
535            },
536            destination: Endpoint::BrowserBackground { id: HostId::Own },
537            topic: None,
538        };
539        let crypto_provider = NoEncryptionCryptoProvider;
540        let communication_provider = TestCommunicationBackend::new();
541        let session_map = InMemorySessionRepository::new(HashMap::new());
542        let client =
543            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
544        let _ = client.start(None).await;
545
546        let mut subscription = client
547            .subscribe(None)
548            .await
549            .expect("Subscribing should not fail");
550        communication_provider.push_incoming(message.clone());
551        let received_message = subscription.receive(None).await.unwrap();
552
553        assert_eq!(received_message, message);
554    }
555
556    #[tokio::test]
557    async fn skips_non_matching_topics_and_returns_first_matching_message() {
558        let non_matching_message = IncomingMessage {
559            payload: vec![],
560            source: Source::Web {
561                tab_id: 9001,
562                document_id: "doc-1".to_string(),
563                origin: "https://example.com".to_string(),
564            },
565            destination: Endpoint::BrowserBackground { id: HostId::Own },
566            topic: Some("non_matching_topic".to_owned()),
567        };
568        let matching_message = IncomingMessage {
569            payload: vec![109],
570            source: Source::Web {
571                tab_id: 9001,
572                document_id: "doc-1".to_string(),
573                origin: "https://example.com".to_string(),
574            },
575            destination: Endpoint::BrowserBackground { id: HostId::Own },
576            topic: Some("matching_topic".to_owned()),
577        };
578
579        let crypto_provider = NoEncryptionCryptoProvider;
580        let communication_provider = TestCommunicationBackend::new();
581        let session_map = InMemorySessionRepository::new(HashMap::new());
582        let client =
583            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
584        let _ = client.start(None).await;
585        let mut subscription = client
586            .subscribe(Some("matching_topic".to_owned()))
587            .await
588            .expect("Subscribing should not fail");
589        communication_provider.push_incoming(non_matching_message.clone());
590        communication_provider.push_incoming(non_matching_message.clone());
591        communication_provider.push_incoming(matching_message.clone());
592
593        let received_message: IncomingMessage = subscription.receive(None).await.unwrap();
594
595        assert_eq!(received_message, matching_message);
596    }
597
598    #[tokio::test]
599    async fn skips_unrelated_messages_and_returns_typed_message() {
600        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
601        struct TestPayload {
602            some_data: String,
603        }
604
605        impl PayloadTypeName for TestPayload {
606            const PAYLOAD_TYPE_NAME: &str = "TestPayload";
607        }
608
609        let unrelated = IncomingMessage {
610            payload: vec![],
611            source: Source::Web {
612                tab_id: 9001,
613                document_id: "doc-1".to_string(),
614                origin: "https://example.com".to_string(),
615            },
616            destination: Endpoint::BrowserBackground { id: HostId::Own },
617            topic: None,
618        };
619        let typed_message = crate::message::TypedIncomingMessage {
620            payload: TestPayload {
621                some_data: "Hello, world!".to_string(),
622            },
623            source: Source::Web {
624                tab_id: 9001,
625                document_id: "doc-1".to_string(),
626                origin: "https://example.com".to_string(),
627            },
628            destination: Endpoint::BrowserBackground { id: HostId::Own },
629        };
630
631        let crypto_provider = NoEncryptionCryptoProvider;
632        let communication_provider = TestCommunicationBackend::new();
633        let session_map = InMemorySessionRepository::new(HashMap::new());
634        let client =
635            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
636        let _ = client.start(None).await;
637        let mut subscription = client
638            .subscribe_typed::<TestPayload>()
639            .await
640            .expect("Subscribing should not fail");
641        communication_provider.push_incoming(unrelated.clone());
642        communication_provider.push_incoming(unrelated.clone());
643        communication_provider.push_incoming(
644            typed_message
645                .clone()
646                .try_into()
647                .expect("Serialization should not fail"),
648        );
649
650        let received_message = subscription.receive(None).await.unwrap();
651
652        assert_eq!(received_message, typed_message);
653    }
654
655    #[tokio::test]
656    async fn returns_error_if_related_message_was_not_deserializable() {
657        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
658        struct TestPayload {
659            some_data: String,
660        }
661
662        impl PayloadTypeName for TestPayload {
663            const PAYLOAD_TYPE_NAME: &str = "TestPayload";
664        }
665
666        let non_deserializable_message = IncomingMessage {
667            payload: vec![],
668            source: Source::Web {
669                tab_id: 9001,
670                document_id: "doc-1".to_string(),
671                origin: "https://example.com".to_string(),
672            },
673            destination: Endpoint::BrowserBackground { id: HostId::Own },
674            topic: Some("TestPayload".to_owned()),
675        };
676
677        let crypto_provider = NoEncryptionCryptoProvider;
678        let communication_provider = TestCommunicationBackend::new();
679        let session_map = InMemorySessionRepository::new(HashMap::new());
680        let client =
681            IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
682        let _ = client.start(None).await;
683        let mut subscription = client
684            .subscribe_typed::<TestPayload>()
685            .await
686            .expect("Subscribing should not fail");
687        communication_provider.push_incoming(non_deserializable_message.clone());
688
689        let result = subscription.receive(None).await;
690        assert!(matches!(result, Err(TypedReceiveError::Typing(_))));
691    }
692
693    #[tokio::test]
694    async fn ipc_client_stops_if_crypto_returns_fatal_send_error() {
695        let message = OutgoingMessage {
696            payload: vec![],
697            destination: Endpoint::BrowserBackground { id: HostId::Own },
698            topic: None,
699        };
700        let crypto_provider = TestCryptoProvider {
701            send_result: Some(Err(TestCryptoError {
702                message: "Crypto error".to_string(),
703                fatal: true,
704            })),
705            receive_result: None,
706        };
707        let communication_provider = TestCommunicationBackend::new();
708        let session_map = TestSessionRepository::new(HashMap::new());
709        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
710        let _ = client.start(None).await;
711
712        let error = client.send(message).await.unwrap_err();
713        let is_running = client.is_running();
714
715        assert!(error.to_string().contains("Crypto error"));
716        assert!(!is_running);
717    }
718
719    #[tokio::test]
720    async fn ipc_client_keeps_running_if_crypto_returns_recoverable_send_error() {
721        let message = OutgoingMessage {
722            payload: vec![],
723            destination: Endpoint::BrowserBackground { id: HostId::Own },
724            topic: None,
725        };
726        let crypto_provider = TestCryptoProvider {
727            // A recoverable send error (e.g. a handshake timeout because the peer is down) must
728            // not tear down the shared client.
729            send_result: Some(Err(TestCryptoError {
730                message: "Crypto error".to_string(),
731                fatal: false,
732            })),
733            // Block forever on receive so the loop stays alive while we inspect it.
734            receive_result: None,
735        };
736        let communication_provider = TestCommunicationBackend::new();
737        let session_map = TestSessionRepository::new(HashMap::new());
738        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
739        let _ = client.start(None).await;
740
741        let error = client.send(message).await.unwrap_err();
742        let is_running = client.is_running();
743
744        // The error is still surfaced to the caller...
745        assert!(error.to_string().contains("Crypto error"));
746        // ...but the client keeps running so future sends/requests can succeed.
747        assert!(is_running);
748    }
749
750    #[tokio::test]
751    async fn ipc_client_stops_if_crypto_returns_fatal_receive_error() {
752        let crypto_provider = TestCryptoProvider {
753            send_result: None,
754            receive_result: Some(Err(TestCryptoError {
755                message: "Crypto error".to_string(),
756                fatal: true,
757            })),
758        };
759        let communication_provider = TestCommunicationBackend::new();
760        let session_map = TestSessionRepository::new(HashMap::new());
761        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
762        let cancellation_token = CancellationToken::new();
763        let _ = client.start(Some(cancellation_token.clone())).await;
764
765        // Give the client some time to process the error
766        tokio::time::sleep(Duration::from_millis(100)).await;
767        let is_running = client.is_running();
768
769        assert!(!is_running);
770        assert!(cancellation_token.is_cancelled());
771    }
772
773    #[tokio::test]
774    async fn ipc_client_keeps_running_if_crypto_returns_recoverable_receive_error() {
775        let crypto_provider = TestCryptoProvider {
776            send_result: None,
777            // A recoverable receive error must not stop the processing loop.
778            receive_result: Some(Err(TestCryptoError {
779                message: "Crypto error".to_string(),
780                fatal: false,
781            })),
782        };
783        let communication_provider = TestCommunicationBackend::new();
784        let session_map = TestSessionRepository::new(HashMap::new());
785        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
786        let cancellation_token = CancellationToken::new();
787        let _ = client.start(Some(cancellation_token.clone())).await;
788
789        // Give the client time to hit the recoverable receive error (repeatedly).
790        tokio::time::sleep(Duration::from_millis(100)).await;
791        let is_running = client.is_running();
792
793        assert!(is_running);
794        assert!(!cancellation_token.is_cancelled());
795    }
796
797    #[tokio::test]
798    async fn ipc_client_is_not_running_if_cancellation_token_is_cancelled() {
799        let crypto_provider = TestCryptoProvider {
800            send_result: None,
801            receive_result: None,
802        };
803        let communication_provider = TestCommunicationBackend::new();
804        let session_map = TestSessionRepository::new(HashMap::new());
805        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
806        let cancellation_token = CancellationToken::new();
807        let _ = client.start(Some(cancellation_token.clone())).await;
808
809        // Give the client some time to process
810        tokio::time::sleep(Duration::from_millis(100)).await;
811
812        // Cancel the token and give the client some time to process the cancellation
813        cancellation_token.cancel();
814        tokio::time::sleep(Duration::from_millis(100)).await;
815        let is_running = client.is_running();
816
817        assert!(!is_running);
818    }
819
820    #[tokio::test]
821    async fn ipc_client_is_running_if_no_errors_are_encountered() {
822        let crypto_provider = TestCryptoProvider {
823            send_result: None,
824            receive_result: None,
825        };
826        let communication_provider = TestCommunicationBackend::new();
827        let session_map = TestSessionRepository::new(HashMap::new());
828        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
829        let cancellation_token = CancellationToken::new();
830        let _ = client.start(Some(cancellation_token.clone())).await;
831
832        // Give the client some time to process
833        tokio::time::sleep(Duration::from_millis(100)).await;
834        let is_running = client.is_running();
835
836        assert!(is_running);
837        assert!(!cancellation_token.is_cancelled());
838    }
839
840    #[tokio::test]
841    async fn ipc_client_is_not_running_if_not_started() {
842        let crypto_provider = TestCryptoProvider {
843            send_result: None,
844            receive_result: None,
845        };
846        let communication_provider = TestCommunicationBackend::new();
847        let session_map = TestSessionRepository::new(HashMap::new());
848        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
849
850        // Give the client some time to process
851        tokio::time::sleep(Duration::from_millis(100)).await;
852        let is_running = client.is_running();
853
854        assert!(!is_running);
855    }
856
857    #[tokio::test]
858    async fn ipc_client_start_returns_error_if_already_running() {
859        let crypto_provider = TestCryptoProvider {
860            send_result: None,
861            receive_result: None,
862        };
863        let communication_provider = TestCommunicationBackend::new();
864        let session_map = TestSessionRepository::new(HashMap::new());
865        let client = IpcClientImpl::new(crypto_provider, communication_provider, session_map);
866        let cancellation_token = CancellationToken::new();
867        let first_result = client.start(Some(cancellation_token.clone())).await;
868        assert_eq!(first_result, Ok(()));
869
870        // Give the client some time to process
871        tokio::time::sleep(Duration::from_millis(100)).await;
872        assert!(client.is_running());
873
874        let second_result = client.start(Some(cancellation_token.clone())).await;
875        assert_eq!(second_result, Err(AlreadyRunningError));
876    }
877
878    mod request {
879        use super::*;
880        use crate::RpcHandler;
881
882        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
883        struct TestRequest {
884            a: i32,
885            b: i32,
886        }
887
888        #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
889        struct TestResponse {
890            result: i32,
891        }
892
893        impl RpcRequest for TestRequest {
894            type Response = TestResponse;
895
896            const NAME: &str = "TestRequest";
897        }
898
899        struct TestHandler;
900
901        impl RpcHandler for TestHandler {
902            type Request = TestRequest;
903
904            async fn handle(&self, request: Self::Request) -> TestResponse {
905                TestResponse {
906                    result: request.a + request.b,
907                }
908            }
909        }
910
911        #[tokio::test]
912        async fn request_sends_message_and_returns_response() {
913            let crypto_provider = NoEncryptionCryptoProvider;
914            let communication_provider = TestCommunicationBackend::new();
915            let session_map = InMemorySessionRepository::default();
916            let client =
917                IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
918            let _ = client.start(None).await;
919            let request = TestRequest { a: 1, b: 2 };
920            let response = TestResponse { result: 3 };
921
922            // Send the request
923            let request_clone = request.clone();
924            let client_clone = client.clone();
925            let result_handle = tokio::spawn(async move {
926                client_clone
927                    .request::<TestRequest>(
928                        request_clone,
929                        Endpoint::BrowserBackground { id: HostId::Own },
930                        None,
931                    )
932                    .await
933            });
934            tokio::time::sleep(Duration::from_millis(100)).await;
935
936            // Read and verify the outgoing message
937            let outgoing_messages = communication_provider.outgoing().await;
938            let outgoing_request: RpcRequestMessage<TestRequest> =
939                serde_utils::from_slice(&outgoing_messages[0].payload)
940                    .expect("Deserialization should not fail");
941            assert_eq!(outgoing_request.request_type, "TestRequest");
942            assert_eq!(outgoing_request.request, request);
943
944            // Simulate receiving a response
945            let simulated_response = IncomingRpcResponseMessage {
946                result: Ok(response),
947                request_id: outgoing_request.request_id.clone(),
948                request_type: outgoing_request.request_type.clone(),
949            };
950            let simulated_response = IncomingMessage {
951                payload: serde_utils::to_vec(&simulated_response)
952                    .expect("Serialization should not fail"),
953                source: Source::BrowserBackground { id: HostId::Own },
954                destination: Endpoint::Web {
955                    tab_id: 9001,
956                    document_id: "doc-1".to_string(),
957                },
958                topic: Some(
959                    IncomingRpcResponseMessage::<TestRequest>::PAYLOAD_TYPE_NAME.to_owned(),
960                ),
961            };
962            communication_provider.push_incoming(simulated_response);
963
964            // Wait for the response
965            let result = result_handle.await.unwrap();
966            assert_eq!(result.unwrap().result, 3);
967        }
968
969        #[tokio::test]
970        async fn incoming_rpc_message_handles_request_and_returns_response() {
971            let crypto_provider = NoEncryptionCryptoProvider;
972            let communication_provider = TestCommunicationBackend::new();
973            let session_map = InMemorySessionRepository::default();
974            let client =
975                IpcClientImpl::new(crypto_provider, communication_provider.clone(), session_map);
976            let _ = client.start(None).await;
977            let request_id = uuid::Uuid::new_v4().to_string();
978            let request = TestRequest { a: 1, b: 2 };
979            let response = TestResponse { result: 3 };
980
981            // Register the handler
982            client.register_rpc_handler(TestHandler).await;
983
984            // Simulate receiving a request
985            let simulated_request = RpcRequestMessage {
986                request,
987                request_id: request_id.clone(),
988                request_type: "TestRequest".to_string(),
989            };
990            let simulated_request_message = IncomingMessage {
991                payload: serde_utils::to_vec(&simulated_request)
992                    .expect("Serialization should not fail"),
993                source: Source::Web {
994                    tab_id: 9001,
995                    document_id: "doc-1".to_string(),
996                    origin: "https://example.com".to_string(),
997                },
998                destination: Endpoint::BrowserBackground { id: HostId::Own },
999                topic: Some(RPC_REQUEST_PAYLOAD_TYPE_NAME.to_owned()),
1000            };
1001            communication_provider.push_incoming(simulated_request_message);
1002
1003            // Give the client some time to process the request
1004            tokio::time::sleep(Duration::from_millis(100)).await;
1005
1006            // Read and verify the outgoing message
1007            let outgoing_messages = communication_provider.outgoing().await;
1008            let outgoing_response: IncomingRpcResponseMessage<TestResponse> =
1009                serde_utils::from_slice(&outgoing_messages[0].payload)
1010                    .expect("Deserialization should not fail");
1011
1012            assert_eq!(
1013                outgoing_messages[0].topic,
1014                Some(IncomingRpcResponseMessage::<TestResponse>::PAYLOAD_TYPE_NAME.to_owned())
1015            );
1016            assert_eq!(outgoing_response.request_type, "TestRequest");
1017            assert_eq!(outgoing_response.result, Ok(response));
1018        }
1019    }
1020}