bitwarden_ipc/rpc/exec/handler_registry.rs

1use erased_serde::Serialize as ErasedSerialize;
2use tokio::sync::RwLock;
3
4use super::handler::{ErasedRpcHandler, RpcHandler};
5use crate::rpc::{error::RpcError, request::RpcRequest, request_message::RpcRequestPayload};
6
7pub struct RpcHandlerRegistry {
8    handlers: RwLock<std::collections::HashMap<String, Box<dyn ErasedRpcHandler>>>,
9}
10
11impl RpcHandlerRegistry {
12    pub fn new() -> Self {
13        Self {
14            handlers: RwLock::new(std::collections::HashMap::new()),
15        }
16    }
17
18    pub async fn register<H>(&self, handler: H)
19    where
20        H: RpcHandler + ErasedRpcHandler + 'static,
21    {
22        let name = H::Request::NAME.to_owned();
23        self.handlers.write().await.insert(name, Box::new(handler));
24    }
25
26    pub async fn handle(
27        &self,
28        request: &RpcRequestPayload,
29    ) -> Result<Box<dyn ErasedSerialize>, RpcError> {
30        match self.handlers.read().await.get(request.request_type()) {
31            Some(handler) => handler.handle(request).await,
32            None => Err(RpcError::NoHandlerFound),
33        }
34    }
35}
36
#[cfg(test)]
mod test {
    use serde::{de::DeserializeOwned, Deserialize, Serialize};

    use super::*;
    use crate::{
        rpc::{request::RpcRequest, request_message::RpcRequestMessage},
        serde_utils,
    };

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestRequest {
        a: i32,
        b: i32,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestResponse {
        result: i32,
    }

    impl RpcRequest for TestRequest {
        type Response = TestResponse;

        const NAME: &str = "TestRequest";
    }

    /// Answers a [`TestRequest`] with the sum of its operands.
    struct TestHandler;

    impl RpcHandler for TestHandler {
        type Request = TestRequest;

        async fn handle(&self, request: Self::Request) -> TestResponse {
            TestResponse {
                result: request.a + request.b,
            }
        }
    }

    /// Handles the same request type as [`TestHandler`] but answers with twice the sum,
    /// so tests can tell which of the two handlers the registry dispatched to.
    struct DoublingHandler;

    impl RpcHandler for DoublingHandler {
        type Request = TestRequest;

        async fn handle(&self, request: Self::Request) -> TestResponse {
            TestResponse {
                result: (request.a + request.b) * 2,
            }
        }
    }

    #[tokio::test]
    async fn handle_returns_error_when_no_handler_can_be_found() {
        let registry = RpcHandlerRegistry::new();

        let result = registry
            .handle(&test_payload(TestRequest { a: 1, b: 2 }))
            .await;

        assert!(matches!(result, Err(RpcError::NoHandlerFound)));
    }

    #[tokio::test]
    async fn handle_runs_previously_registered_handler() {
        let registry = RpcHandlerRegistry::new();
        registry.register(TestHandler).await;

        let response = handle_and_decode(&registry, TestRequest { a: 1, b: 2 }).await;

        assert_eq!(response.result, 3);
    }

    #[tokio::test]
    async fn register_replaces_handler_for_same_request_type() {
        let registry = RpcHandlerRegistry::new();
        registry.register(TestHandler).await;
        registry.register(DoublingHandler).await;

        let response = handle_and_decode(&registry, TestRequest { a: 1, b: 2 }).await;

        assert_eq!(response.result, 6);
    }

    /// Wraps `request` in an [`RpcRequestMessage`] keyed by [`TestRequest::NAME`] (the
    /// same key `register` uses) and serializes it into an [`RpcRequestPayload`].
    fn test_payload(request: TestRequest) -> RpcRequestPayload {
        let message = RpcRequestMessage {
            request,
            request_id: "test_id".to_string(),
            request_type: TestRequest::NAME.to_string(),
        };
        RpcRequestPayload::from_slice(serde_utils::to_vec(&message).unwrap()).unwrap()
    }

    /// Sends `request` through `registry` and decodes the erased response, panicking if
    /// dispatch fails.
    async fn handle_and_decode(
        registry: &RpcHandlerRegistry,
        request: TestRequest,
    ) -> TestResponse {
        let result = registry
            .handle(&test_payload(request))
            .await
            .expect("Failed to handle request");
        deserialize_erased_object(&result)
    }

    /// Round-trips `value` through the crate's serde format to recover a concrete type
    /// from a type-erased response.
    fn deserialize_erased_object<T, R>(value: &T) -> R
    where
        T: Serialize,
        R: DeserializeOwned,
    {
        let serialized = serde_utils::to_vec(value).expect("Failed to serialize erased serialize");

        serde_utils::from_slice(&serialized).expect("Failed to deserialize erased serialize")
    }
}