bitwarden_ipc/rpc/exec/
handler_registry.rs1use erased_serde::Serialize as ErasedSerialize;
2use tokio::sync::RwLock;
3
4use super::handler::{ErasedRpcHandler, RpcHandler};
5use crate::rpc::{error::RpcError, request::RpcRequest, request_message::RpcRequestPayload};
6
7pub struct RpcHandlerRegistry {
8 handlers: RwLock<std::collections::HashMap<String, Box<dyn ErasedRpcHandler>>>,
9}
10
11impl RpcHandlerRegistry {
12 pub fn new() -> Self {
13 Self {
14 handlers: RwLock::new(std::collections::HashMap::new()),
15 }
16 }
17
18 pub async fn register<H>(&self, handler: H)
19 where
20 H: RpcHandler + ErasedRpcHandler + 'static,
21 {
22 let name = H::Request::NAME.to_owned();
23 self.register_erased(name, Box::new(handler)).await;
24 }
25
26 pub async fn register_erased(&self, name: String, handler: Box<dyn ErasedRpcHandler>) {
27 self.handlers.write().await.insert(name, handler);
28 }
29
30 pub async fn handle(
31 &self,
32 request: &RpcRequestPayload,
33 ) -> Result<Box<dyn ErasedSerialize>, RpcError> {
34 match self.handlers.read().await.get(request.request_type()) {
35 Some(handler) => handler.handle(request).await,
36 None => Err(RpcError::NoHandlerFound),
37 }
38 }
39}
40
#[cfg(test)]
mod test {
    use serde::{Deserialize, Serialize, de::DeserializeOwned};

    use super::*;
    use crate::{
        rpc::{request::RpcRequest, request_message::RpcRequestMessage},
        serde_utils,
    };

    /// Request used to exercise the registry: asks for `a + b`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestRequest {
        a: i32,
        b: i32,
    }

    /// Response produced by [`TestHandler`].
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestResponse {
        result: i32,
    }

    impl RpcRequest for TestRequest {
        type Response = TestResponse;

        const NAME: &str = "TestRequest";
    }

    /// Handler that adds the two operands of a [`TestRequest`].
    struct TestHandler;

    impl RpcHandler for TestHandler {
        type Request = TestRequest;

        async fn handle(&self, request: Self::Request) -> TestResponse {
            let TestRequest { a, b } = request;
            TestResponse { result: a + b }
        }
    }

    /// Builds the `RpcRequestPayload` for `TestRequest { a: 1, b: 2 }` shared by the tests.
    fn one_plus_two_payload() -> RpcRequestPayload {
        let envelope = RpcRequestMessage {
            request: TestRequest { a: 1, b: 2 },
            request_id: "test_id".to_string(),
            request_type: "TestRequest".to_string(),
        };
        let bytes = serde_utils::to_vec(&envelope).unwrap();
        RpcRequestPayload::from_slice(bytes).unwrap()
    }

    #[tokio::test]
    async fn handle_returns_error_when_no_handler_can_be_found() {
        let registry = RpcHandlerRegistry::new();
        let payload = one_plus_two_payload();

        let outcome = registry.handle(&payload).await;

        assert!(matches!(outcome, Err(RpcError::NoHandlerFound)));
    }

    #[tokio::test]
    async fn handle_runs_previously_registered_handler() {
        let registry = RpcHandlerRegistry::new();
        registry.register(TestHandler).await;
        let payload = one_plus_two_payload();

        let erased = registry
            .handle(&payload)
            .await
            .expect("Failed to handle request");
        let response: TestResponse = deserialize_erased_object(&erased);

        assert_eq!(response.result, 3);
    }

    /// Round-trips `value` through `serde_utils` to recover it as a concrete type `R`.
    fn deserialize_erased_object<T, R>(value: &T) -> R
    where
        T: Serialize,
        R: DeserializeOwned,
    {
        let bytes = serde_utils::to_vec(value).expect("Failed to serialize erased serialize");
        serde_utils::from_slice(&bytes).expect("Failed to deserialize erased serialize")
    }
}