bitwarden_ipc/rpc/exec/
handler_registry.rs1use erased_serde::Serialize as ErasedSerialize;
2use tokio::sync::RwLock;
3
4use super::handler::{ErasedRpcHandler, RpcHandler};
5use crate::rpc::{error::RpcError, request::RpcRequest, request_message::RpcRequestPayload};
6
7pub struct RpcHandlerRegistry {
8 handlers: RwLock<std::collections::HashMap<String, Box<dyn ErasedRpcHandler>>>,
9}
10
11impl RpcHandlerRegistry {
12 pub fn new() -> Self {
13 Self {
14 handlers: RwLock::new(std::collections::HashMap::new()),
15 }
16 }
17
18 pub async fn register<H>(&self, handler: H)
19 where
20 H: RpcHandler + ErasedRpcHandler + 'static,
21 {
22 let name = H::Request::NAME.to_owned();
23 self.handlers.write().await.insert(name, Box::new(handler));
24 }
25
26 pub async fn handle(
27 &self,
28 request: &RpcRequestPayload,
29 ) -> Result<Box<dyn ErasedSerialize>, RpcError> {
30 match self.handlers.read().await.get(request.request_type()) {
31 Some(handler) => handler.handle(request).await,
32 None => Err(RpcError::NoHandlerFound),
33 }
34 }
35}
36
#[cfg(test)]
mod test {
    use serde::{de::DeserializeOwned, Deserialize, Serialize};

    use super::*;
    use crate::{
        rpc::{request::RpcRequest, request_message::RpcRequestMessage},
        serde_utils,
    };

    /// Request used by these tests: asks the handler to add `a` and `b`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestRequest {
        a: i32,
        b: i32,
    }

    /// Response produced by [`TestHandler`]: the sum of the request's operands.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestResponse {
        result: i32,
    }

    impl RpcRequest for TestRequest {
        type Response = TestResponse;

        const NAME: &str = "TestRequest";
    }

    /// Handler that answers a [`TestRequest`] with `a + b`.
    struct TestHandler;

    impl RpcHandler for TestHandler {
        type Request = TestRequest;

        async fn handle(&self, request: Self::Request) -> TestResponse {
            let TestRequest { a, b } = request;
            TestResponse { result: a + b }
        }
    }

    /// Wraps `request` in an `RpcRequestMessage` and encodes it into the payload form that
    /// `RpcHandlerRegistry::handle` consumes.
    fn to_payload(request: TestRequest) -> RpcRequestPayload {
        let envelope = RpcRequestMessage {
            request,
            request_id: "test_id".to_string(),
            request_type: "TestRequest".to_string(),
        };
        let encoded = serde_utils::to_vec(&envelope).unwrap();
        RpcRequestPayload::from_slice(encoded).unwrap()
    }

    #[tokio::test]
    async fn handle_returns_error_when_no_handler_can_be_found() {
        let registry = RpcHandlerRegistry::new();
        let payload = to_payload(TestRequest { a: 1, b: 2 });

        let outcome = registry.handle(&payload).await;

        assert!(matches!(outcome, Err(RpcError::NoHandlerFound)));
    }

    #[tokio::test]
    async fn handle_runs_previously_registered_handler() {
        let registry = RpcHandlerRegistry::new();
        registry.register(TestHandler).await;
        let payload = to_payload(TestRequest { a: 1, b: 2 });

        let erased = registry
            .handle(&payload)
            .await
            .expect("Failed to handle request");
        let response: TestResponse = deserialize_erased_object(&erased);

        assert_eq!(response.result, 3);
    }

    /// Round-trips `value` through `serde_utils` to recover a concrete type from an erased
    /// serializable value.
    fn deserialize_erased_object<Input, Output>(value: &Input) -> Output
    where
        Input: Serialize,
        Output: DeserializeOwned,
    {
        let bytes = serde_utils::to_vec(value).expect("Failed to serialize erased serialize");
        serde_utils::from_slice(&bytes).expect("Failed to deserialize erased serialize")
    }
}
127}