bitwarden_ipc/crypto_provider/noise/
crypto_provider.rs1use std::{sync::LazyLock, time::Duration};
2
3use bitwarden_threading::time::timeout;
4use serde::{Deserialize, Serialize};
5use tracing::{error, info, warn};
6
7use crate::{
8 crypto_provider::noise::{
9 handshake::{
10 CipherSuite, HandshakeFinishMessage, HandshakeInitiator, HandshakeResponder,
11 HandshakeStartMessage,
12 },
13 transport_state::{PersistentTransportState, TransportFrame},
14 },
15 message::{IncomingMessage, OutgoingMessage},
16 traits::{
17 CommunicationBackend, CommunicationBackendReceiver, CryptoProvider, SessionRepository,
18 },
19};
20
/// [`CryptoProvider`] that protects IPC traffic with a Noise handshake followed by
/// encrypted transport frames.
///
/// The provider itself is stateless; per-peer session state lives in the
/// [`SessionRepository`] passed to each call.
pub struct NoiseCryptoProvider;
22
/// Errors produced by [`NoiseCryptoProvider`] while sending or receiving.
#[derive(Debug)]
pub enum NoiseCryptoProviderError {
    /// A handshake message was malformed or rejected by the handshake state machine.
    HandshakeProtocol,
    /// The peer did not finish the handshake within the allotted time.
    Timeout,
    /// The underlying communication backend failed to send a message.
    TransportSend,
    /// The underlying communication backend failed to deliver an incoming message.
    TransportReceive,
    /// The transport state failed to process a frame.
    ///
    /// Note: `send` also reports a failure to *encrypt* an outgoing payload with this
    /// variant, because there is no dedicated encryption-failure variant.
    DecryptionFailure,
}
36
/// Process-wide async lock that serializes read-modify-write sequences on session state.
///
/// `send` holds it for the whole operation (including any handshake it triggers) and
/// `receive` holds it while looking up and decrypting a transport frame. A `tokio`
/// mutex is used because the guard is held across `.await` points. Because this is a
/// `static`, it is shared by every provider instance in the process.
// NOTE(review): the `HandshakeStart` and `CryptoInvalidated` paths in `receive` write or
// delete sessions without taking this lock — confirm that is intentional.
static CRYPTO_STATE_GUARD: LazyLock<tokio::sync::Mutex<()>> =
    LazyLock::new(|| tokio::sync::Mutex::new(()));
41
42impl NoiseCryptoProvider {
43 async fn perform_handshake<Com, Ses>(
44 communication: &Com,
45 sessions: &Ses,
46 destination: crate::endpoint::Endpoint,
47 ) -> Result<(), NoiseCryptoProviderError>
48 where
49 Com: CommunicationBackend,
50 Ses: SessionRepository<NoiseCryptoProviderState>,
51 {
52 info!("Starting noise handshake with {:?}", destination);
53
54 let mut initiator = HandshakeInitiator::new(&CipherSuite::default());
55 let message = initiator
56 .write_start_message()
57 .expect("Handshake start message should be buildable");
58 let receiver = communication.subscribe().await;
59
60 let handshake_frame = Frame::HandshakeStart(message);
61 communication
62 .send(OutgoingMessage {
63 payload: handshake_frame.to_cbor(),
64 destination: destination.clone(),
65 topic: None,
66 })
67 .await
68 .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
69
70 timeout(Duration::from_secs(HANDSHAKE_TIMEOUT_SECS), async {
72 loop {
73 let incoming = receiver
74 .receive()
75 .await
76 .map_err(|_| NoiseCryptoProviderError::TransportReceive)?;
77
78 if incoming.source.to_endpoint() != destination {
80 continue;
81 }
82
83 let Ok(response_frame) = Frame::from_cbor(&incoming.payload) else {
85 return Err(NoiseCryptoProviderError::HandshakeProtocol);
86 };
87
88 if let Frame::HandshakeFinish(handshake_finish) = response_frame {
90 if initiator.read_response_message(&handshake_finish).is_err() {
91 error!("Failed to read handshake response message");
92 return Err(NoiseCryptoProviderError::HandshakeProtocol);
93 }
94 break;
95 }
96 }
97 Ok(())
98 })
99 .await
100 .map_err(|_| {
101 info!(
102 "Noise handshake with {:?} timed out after {} seconds",
103 destination, HANDSHAKE_TIMEOUT_SECS
104 );
105 NoiseCryptoProviderError::Timeout
106 })??;
109
110 let crypto_state = NoiseCryptoProviderState {
111 state: (&mut initiator).into(),
112 };
113 sessions
114 .save(destination.clone(), crypto_state)
115 .await
116 .expect("Save session should not fail");
117
118 info!(
119 "Noise handshake with {:?} completed, session established",
120 destination
121 );
122
123 Ok(())
124 }
125}
126
/// Maximum age, in seconds, of an established session before `send` discards it and
/// performs a fresh handshake (five minutes).
const REHANDSHAKE_INTERVAL_SECS: u64 = 5 * 60;

/// How long, in seconds, `perform_handshake` waits for the peer's handshake response
/// before failing with a timeout.
const HANDSHAKE_TIMEOUT_SECS: u64 = 2;
133
/// Per-peer session state stored in the [`SessionRepository`].
///
/// Wraps the transport state derived from a completed handshake, either as initiator
/// (`perform_handshake`) or as responder (`receive` handling a `HandshakeStart`).
// NOTE(review): derives Serialize/Deserialize, presumably so repositories can persist
// sessions; changing the field name or type likely affects already-stored sessions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NoiseCryptoProviderState {
    /// Transport state used to encrypt outgoing and decrypt incoming frames; also
    /// queried via `should_rehandshake` to decide when the session is too old.
    state: PersistentTransportState,
}
139
140impl<Com, Ses> CryptoProvider<Com, Ses> for NoiseCryptoProvider
141where
142 Com: CommunicationBackend,
143 Ses: SessionRepository<NoiseCryptoProviderState>,
144{
145 type Session = NoiseCryptoProviderState;
146 type SendError = NoiseCryptoProviderError;
147 type ReceiveError = NoiseCryptoProviderError;
148
149 async fn send(
150 &self,
151 communication: &Com,
152 sessions: &Ses,
153 message: OutgoingMessage,
154 ) -> Result<(), Self::SendError> {
155 let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
159
160 let destination = message.destination.clone();
161
162 let crypto_state = sessions
163 .get(destination.clone())
164 .await
165 .expect("Get session should not fail");
166
167 let mut should_handshake = crypto_state.is_none();
168 if let Some(state) = crypto_state.as_ref()
169 && state.state.should_rehandshake(REHANDSHAKE_INTERVAL_SECS)
170 {
171 info!(
172 "Noise session with {:?} is older than {}s, re-handshaking",
173 destination, REHANDSHAKE_INTERVAL_SECS
174 );
175 sessions
176 .remove(destination.clone())
177 .await
178 .expect("Delete session should not fail");
179 should_handshake = true;
180 }
181
182 if should_handshake {
183 if crypto_state.is_none() {
184 info!(
185 "Noise handshake with {:?} initiated for new session establishment",
186 destination
187 );
188 } else {
189 info!(
190 "Noise re-handshake with {:?} due to re-handshake interval",
191 destination
192 );
193 }
194
195 Self::perform_handshake(communication, sessions, destination.clone()).await?;
196 }
197
198 let mut crypto_state = sessions
199 .get(destination.clone())
200 .await
201 .expect("Get session should not fail")
202 .expect("Session should exist after handshake");
203
204 let transport_frame = crypto_state
206 .state
207 .send(message.payload.into())
208 .map_err(|_| NoiseCryptoProviderError::DecryptionFailure)?;
209 communication
210 .send(OutgoingMessage {
211 payload: Frame::TransportFrame(transport_frame).to_cbor(),
212 destination: destination.clone(),
213 topic: message.topic,
214 })
215 .await
216 .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
217
218 sessions
219 .save(destination, crypto_state)
220 .await
221 .expect("Save session should not fail");
222
223 Ok(())
224 }
225
226 async fn receive(
227 &self,
228 receiver: &Com::Receiver,
229 communication: &Com,
230 sessions: &Ses,
231 ) -> Result<IncomingMessage, Self::ReceiveError> {
232 loop {
233 let message = receiver
234 .receive()
235 .await
236 .map_err(|_| NoiseCryptoProviderError::TransportReceive)?;
237
238 let source_endpoint: crate::endpoint::Endpoint = message.source.clone().into();
240
241 let Ok(transport_frame) = Frame::from_cbor(&message.payload) else {
243 warn!("Received malformed cbor message, ignoring");
244 continue;
245 };
246
247 match transport_frame {
248 Frame::HandshakeStart(handshake_start) => {
249 let mut responder = HandshakeResponder::new(&handshake_start.ciphersuite);
250 responder
251 .read_start_message(&handshake_start)
252 .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
253 let response_message = responder
254 .write_response_message()
255 .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
256 let handshake_frame = Frame::HandshakeFinish(response_message);
257 communication
258 .send(OutgoingMessage {
259 payload: handshake_frame.to_cbor(),
260 destination: source_endpoint.clone(),
261 topic: None,
262 })
263 .await
264 .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
265
266 let crypto_state = NoiseCryptoProviderState {
267 state: (&mut responder).into(),
268 };
269 sessions
270 .save(source_endpoint, crypto_state)
271 .await
272 .expect("Save session should not fail");
273 }
274 Frame::TransportFrame(transport_frame) => {
275 let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
276 let crypto_state = sessions
277 .get(source_endpoint.clone())
278 .await
279 .expect("Get session should not fail");
280 let Some(mut state) = crypto_state else {
281 info!("No session for {:?}, waiting for handshake", message.source);
282 let frame = Frame::CryptoInvalidated.to_cbor();
283 communication
284 .send(OutgoingMessage {
285 payload: frame,
286 destination: source_endpoint,
287 topic: None,
288 })
289 .await
290 .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
291 continue;
292 };
293
294 let payload = state.state.receive(&transport_frame);
295 let Ok(payload) = payload else {
296 info!("Failed to decrypt message from {:?}", message.source);
297 continue;
298 };
299
300 sessions
301 .save(source_endpoint, state)
302 .await
303 .expect("Save session should not fail");
304
305 return Ok(IncomingMessage {
306 payload: payload.as_ref().to_vec(),
307 destination: message.destination,
308 source: message.source,
309 topic: message.topic,
310 });
311 }
312 Frame::CryptoInvalidated => {
313 info!(
314 "Invalidated session for {:?} due to crypto error, deleting session and waiting for handshake",
315 message.source
316 );
317 sessions
318 .remove(source_endpoint)
319 .await
320 .expect("Delete session should not fail");
321 }
322 _ => continue,
323 }
324 }
325 }
326}
327
/// Wire frame exchanged between Noise IPC peers; every payload this provider sends is
/// one of these variants, encoded with [`Frame::to_cbor`].
// NOTE(review): the serde encoding presumably depends on variant names — renaming or
// reshaping variants likely breaks compatibility with peers running older code.
#[derive(Serialize, Deserialize)]
pub(super) enum Frame {
    /// First handshake message, sent by the initiator in `perform_handshake`.
    HandshakeStart(HandshakeStartMessage),
    /// Responder's reply to a `HandshakeStart`, completing the handshake.
    HandshakeFinish(HandshakeFinishMessage),
    /// Encrypted application payload for an established session.
    TransportFrame(TransportFrame),
    /// Sent by `receive` when a transport frame arrives from a peer with no session;
    /// the recipient deletes its session for the sender.
    CryptoInvalidated,
}
340
341impl Frame {
342 pub(crate) fn to_cbor(&self) -> Vec<u8> {
343 let mut buffer = Vec::new();
344 ciborium::into_writer(self, &mut buffer).expect("Ciborium serialization should not fail");
345 buffer
346 }
347
348 pub(crate) fn from_cbor(buffer: &[u8]) -> Result<Self, ()> {
349 ciborium::from_reader(buffer).map_err(|_| ())
350 }
351}
352
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{
        IpcClientImpl,
        crypto_provider::noise::crypto_provider::NoiseCryptoProvider,
        endpoint::Endpoint,
        ipc_client_trait::IpcClient,
        message::OutgoingMessage,
        traits::{InMemorySessionRepository, TestTwoWayCommunicationBackend},
    };

    /// Two clients over a connected test backend exchange 255 encrypted round trips.
    /// Client 1 sends a counter, client 2 echoes it back, and client 1 increments it.
    #[tokio::test]
    async fn ping_pong() {
        let (provider_1, provider_2) = TestTwoWayCommunicationBackend::new();

        let session_map_1 = InMemorySessionRepository::new(HashMap::new());
        let client_1 = IpcClientImpl::new(NoiseCryptoProvider, provider_1, session_map_1);
        let _ = client_1.start(None).await;
        let mut recv_1 = client_1.subscribe(None).await.unwrap();

        let session_map_2 = InMemorySessionRepository::new(HashMap::new());
        let client_2 = IpcClientImpl::new(NoiseCryptoProvider, provider_2, session_map_2);
        let _ = client_2.start(None).await;
        let mut recv_2 = client_2.subscribe(None).await.unwrap();

        let handle_1 = tokio::spawn(async move {
            let mut val: u8 = 0;
            for _ in 0..255 {
                let message = OutgoingMessage {
                    payload: vec![val],
                    destination: Endpoint::DesktopMain,
                    topic: None,
                };
                client_1.send(message).await.unwrap();
                let recv_message = recv_1.receive(None).await.unwrap();
                val = recv_message.payload[0] + 1;
            }
        });

        let handle_2 = tokio::spawn(async move {
            for _ in 0..255 {
                let recv_message = recv_2.receive(None).await.unwrap();
                let val = recv_message.payload[0];
                if val == 255 {
                    break;
                }

                client_2
                    .send(OutgoingMessage {
                        payload: vec![val],
                        destination: Endpoint::DesktopMain,
                        topic: None,
                    })
                    .await
                    .unwrap();
            }
        });

        let (result_1, result_2) = tokio::join!(handle_1, handle_2);
        // A panic inside a spawned task (e.g. a failed `unwrap`) only surfaces as a
        // `JoinError`; discarding it would let the test pass despite the failure.
        result_1.expect("client 1 task panicked");
        result_2.expect("client 2 task panicked");
    }
}