Skip to main content

bitwarden_ipc/rpc/exec/
handler_registry.rs

1use erased_serde::Serialize as ErasedSerialize;
2use tokio::sync::RwLock;
3
4use super::handler::{ErasedRpcHandler, RpcHandler};
5use crate::rpc::{error::RpcError, request::RpcRequest, request_message::RpcRequestPayload};
6
7pub struct RpcHandlerRegistry {
8    handlers: RwLock<std::collections::HashMap<String, Box<dyn ErasedRpcHandler>>>,
9}
10
11impl RpcHandlerRegistry {
12    pub fn new() -> Self {
13        Self {
14            handlers: RwLock::new(std::collections::HashMap::new()),
15        }
16    }
17
18    pub async fn register<H>(&self, handler: H)
19    where
20        H: RpcHandler + ErasedRpcHandler + 'static,
21    {
22        let name = H::Request::NAME.to_owned();
23        self.register_erased(name, Box::new(handler)).await;
24    }
25
26    pub async fn register_erased(&self, name: String, handler: Box<dyn ErasedRpcHandler>) {
27        self.handlers.write().await.insert(name, handler);
28    }
29
30    pub async fn handle(
31        &self,
32        request: &RpcRequestPayload,
33    ) -> Result<Box<dyn ErasedSerialize>, RpcError> {
34        match self.handlers.read().await.get(request.request_type()) {
35            Some(handler) => handler.handle(request).await,
36            None => Err(RpcError::NoHandlerFound),
37        }
38    }
39}
40
#[cfg(test)]
mod test {
    use serde::{Deserialize, Serialize, de::DeserializeOwned};

    use super::*;
    use crate::{
        rpc::{request::RpcRequest, request_message::RpcRequestMessage},
        serde_utils,
    };

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestRequest {
        a: i32,
        b: i32,
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestResponse {
        result: i32,
    }

    impl RpcRequest for TestRequest {
        type Response = TestResponse;

        const NAME: &str = "TestRequest";
    }

    /// Handler that answers a `TestRequest` with the sum of its operands.
    struct TestHandler;

    impl RpcHandler for TestHandler {
        type Request = TestRequest;

        async fn handle(&self, request: Self::Request) -> TestResponse {
            TestResponse {
                result: request.a + request.b,
            }
        }
    }

    /// Handler for the same request type as `TestHandler` that multiplies instead, so tests
    /// can observe which of the two handlers the registry dispatched to.
    struct MultiplyHandler;

    impl RpcHandler for MultiplyHandler {
        type Request = TestRequest;

        async fn handle(&self, request: Self::Request) -> TestResponse {
            TestResponse {
                result: request.a * request.b,
            }
        }
    }

    /// Builds the serialized payload for `TestRequest { a, b }`, as a remote caller would send it.
    fn serialized_test_request(a: i32, b: i32) -> RpcRequestPayload {
        let message = RpcRequestMessage {
            request: TestRequest { a, b },
            request_id: "test_id".to_string(),
            request_type: TestRequest::NAME.to_string(),
        };

        RpcRequestPayload::from_slice(serde_utils::to_vec(&message).unwrap()).unwrap()
    }

    #[tokio::test]
    async fn handle_returns_error_when_no_handler_can_be_found() {
        let registry = RpcHandlerRegistry::new();

        let result = registry.handle(&serialized_test_request(1, 2)).await;

        assert!(matches!(result, Err(RpcError::NoHandlerFound)));
    }

    #[tokio::test]
    async fn handle_runs_previously_registered_handler() {
        let registry = RpcHandlerRegistry::new();
        registry.register(TestHandler).await;

        let result = registry
            .handle(&serialized_test_request(1, 2))
            .await
            .expect("Failed to handle request");
        let response: TestResponse = deserialize_erased_object(&result);

        assert_eq!(response.result, 3);
    }

    #[tokio::test]
    async fn register_replaces_existing_handler_for_same_request_type() {
        let registry = RpcHandlerRegistry::new();
        registry.register(TestHandler).await;
        registry.register(MultiplyHandler).await;

        let result = registry
            .handle(&serialized_test_request(2, 3))
            .await
            .expect("Failed to handle request");
        let response: TestResponse = deserialize_erased_object(&result);

        // 2 * 3, not 2 + 3: the later registration wins.
        assert_eq!(response.result, 6);
    }

    /// Round-trips a (possibly type-erased) serializable value into a concrete type `R`.
    fn deserialize_erased_object<T, R>(value: &T) -> R
    where
        T: Serialize,
        R: DeserializeOwned,
    {
        let serialized = serde_utils::to_vec(value).expect("Failed to serialize erased serialize");

        serde_utils::from_slice(&serialized).expect("Failed to deserialize erased serialize")
    }
}