Skip to main content

bitwarden_ipc/
ipc_client_ext.rs

1use bitwarden_threading::cancellation_token::CancellationToken;
2use serde::{Serialize, de::DeserializeOwned};
3
4use crate::{
5    RpcHandler,
6    endpoint::Endpoint,
7    error::{RequestError, SubscribeError},
8    ipc_client::IpcClientTypedSubscription,
9    ipc_client_trait::IpcClient,
10    message::{PayloadTypeName, TypedOutgoingMessage},
11    rpc::{
12        error::RpcError, request::RpcRequest, request_message::RpcRequestMessage,
13        response_message::IncomingRpcResponseMessage,
14    },
15    serde_utils,
16};
17
18/// Extension trait providing generic convenience methods on any [`IpcClient`].
19///
20/// This trait is automatically implemented for all types that implement [`IpcClient`],
21/// including `dyn IpcClient`. It provides typed subscriptions, handler registration,
22/// and RPC request functionality with full static type safety.
23pub trait IpcClientExt: IpcClient {
24    /// Register a new RPC handler for processing incoming RPC requests.
25    /// The handler will be executed by the IPC client when an RPC request is received and
26    /// the response will be sent back over IPC.
27    fn register_rpc_handler<H>(&self, handler: H) -> impl std::future::Future<Output = ()> + Send
28    where
29        H: RpcHandler + Send + Sync + 'static,
30    {
31        async move {
32            self.register_rpc_handler_erased(H::Request::NAME, Box::new(handler))
33                .await;
34        }
35    }
36
37    /// Send a message with a payload of any serializable type to the specified destination.
38    fn send_typed<Payload>(
39        &self,
40        payload: Payload,
41        destination: Endpoint,
42    ) -> impl std::future::Future<Output = Result<(), RequestError>> + Send
43    where
44        Payload: Serialize + PayloadTypeName + Send,
45    {
46        async move {
47            let message = TypedOutgoingMessage {
48                payload,
49                destination,
50            }
51            .try_into()
52            .map_err(|e: serde_utils::DeserializeError| {
53                RequestError::Rpc(RpcError::RequestSerialization(e.to_string()))
54            })?;
55
56            self.send(message)
57                .await
58                .map_err(|e| RequestError::Send(format!("{e:?}")))
59        }
60    }
61
62    /// Create a subscription to receive messages that can be deserialized into the provided
63    /// payload type.
64    fn subscribe_typed<Payload>(
65        &self,
66    ) -> impl std::future::Future<
67        Output = Result<IpcClientTypedSubscription<Payload>, SubscribeError>,
68    > + Send
69    where
70        Payload: DeserializeOwned + PayloadTypeName,
71    {
72        async move {
73            Ok(IpcClientTypedSubscription::new(
74                self.subscribe(Some(Payload::PAYLOAD_TYPE_NAME.to_owned()))
75                    .await?,
76            ))
77        }
78    }
79
80    /// Send a request to the specified destination and wait for a response.
81    /// The destination must have a registered RPC handler for the request type, otherwise
82    /// an error will be returned by the remote endpoint.
83    fn request<Request>(
84        &self,
85        request: Request,
86        destination: Endpoint,
87        cancellation_token: Option<CancellationToken>,
88    ) -> impl std::future::Future<Output = Result<Request::Response, RequestError>> + Send
89    where
90        Request: RpcRequest + Send,
91        Request::Response: Send,
92    {
93        async move {
94            let request_id = uuid::Uuid::new_v4().to_string();
95            let mut response_subscription = self
96                .subscribe_typed::<IncomingRpcResponseMessage<Request::Response>>()
97                .await?;
98
99            let request_payload = RpcRequestMessage {
100                request,
101                request_id: request_id.clone(),
102                request_type: Request::NAME.to_owned(),
103            };
104
105            let message = TypedOutgoingMessage {
106                payload: request_payload,
107                destination,
108            }
109            .try_into()
110            .map_err(|e: serde_utils::DeserializeError| {
111                RequestError::Rpc(RpcError::RequestSerialization(e.to_string()))
112            })?;
113
114            self.send(message)
115                .await
116                .map_err(|e| RequestError::Send(format!("{e:?}")))?;
117
118            let response = loop {
119                let received = response_subscription
120                    .receive(cancellation_token.clone())
121                    .await
122                    .map_err(RequestError::Receive)?;
123
124                if received.payload.request_id == request_id {
125                    break received;
126                }
127            };
128
129            Ok(response.payload.result?)
130        }
131    }
132}
133
134/// Blanket implementation: every [`IpcClient`] gets the extension methods for free.
135impl<T: IpcClient + ?Sized> IpcClientExt for T {}