bitwarden_ipc/crypto_provider/noise/
crypto_provider.rs1use std::{sync::LazyLock, time::Duration};
2
3use bitwarden_threading::time::timeout;
4use serde::{Deserialize, Serialize};
5use tracing::{error, info, warn};
6
7use crate::{
8 crypto_provider::noise::{
9 handshake::{
10 CipherSuite, HandshakeFinishMessage, HandshakeInitiator, HandshakeResponder,
11 HandshakeStartMessage,
12 },
13 transport_state::{PersistentTransportState, TransportFrame},
14 },
15 error::IpcErrorKind,
16 message::{IncomingMessage, OutgoingMessage},
17 traits::{
18 CommunicationBackend, CommunicationBackendReceiver, CryptoProvider, SessionRepository,
19 },
20};
21
/// `CryptoProvider` that protects IPC traffic with per-peer Noise sessions.
///
/// The provider itself carries no data: every session lives in the `SessionRepository`
/// passed to each call, keyed by the peer's endpoint.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoiseCryptoProvider;
23
/// Errors produced by [`NoiseCryptoProvider`] while sending or receiving.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NoiseCryptoProviderError {
    /// A handshake frame could not be decoded or was rejected by the handshake state machine.
    HandshakeProtocol,
    /// No valid handshake response arrived within the handshake timeout.
    Timeout,
    /// The communication backend failed to send; `fatal` mirrors the backend error's
    /// `is_fatal()`.
    TransportSend { fatal: bool },
    /// The communication backend failed to receive; `fatal` mirrors the backend error's
    /// `is_fatal()`.
    TransportReceive { fatal: bool },
    /// Raised when encrypting an outgoing payload with the session fails. Despite the name,
    /// frames that fail to decrypt on receive are dropped rather than reported with this.
    DecryptionFailure,
}
39
40impl IpcErrorKind for NoiseCryptoProviderError {
41 fn is_fatal(&self) -> bool {
42 match self {
43 NoiseCryptoProviderError::HandshakeProtocol => false,
46 NoiseCryptoProviderError::Timeout => false,
48 NoiseCryptoProviderError::DecryptionFailure => false,
50 NoiseCryptoProviderError::TransportSend { fatal } => *fatal,
52 NoiseCryptoProviderError::TransportReceive { fatal } => *fatal,
53 }
54 }
55}
56
57static CRYPTO_STATE_GUARD: LazyLock<tokio::sync::Mutex<()>> =
60 LazyLock::new(|| tokio::sync::Mutex::new(()));
61
62impl NoiseCryptoProvider {
63 async fn perform_handshake<Com, Ses>(
64 communication: &Com,
65 sessions: &Ses,
66 destination: crate::endpoint::Endpoint,
67 ) -> Result<(), NoiseCryptoProviderError>
68 where
69 Com: CommunicationBackend,
70 Ses: SessionRepository<NoiseCryptoProviderState>,
71 {
72 info!("Starting noise handshake with {:?}", destination);
73
74 let mut initiator = HandshakeInitiator::new(&CipherSuite::default());
75 let message = initiator
76 .write_start_message()
77 .expect("Handshake start message should be buildable");
78 let receiver = communication.subscribe().await;
79
80 let handshake_frame = Frame::HandshakeStart(message);
81 communication
82 .send(OutgoingMessage {
83 payload: handshake_frame.to_cbor(),
84 destination: destination.clone(),
85 topic: None,
86 })
87 .await
88 .map_err(|e| NoiseCryptoProviderError::TransportSend {
89 fatal: e.is_fatal(),
90 })?;
91
92 timeout(Duration::from_secs(HANDSHAKE_TIMEOUT_SECS), async {
94 loop {
95 let incoming = receiver.receive().await.map_err(|e| {
96 NoiseCryptoProviderError::TransportReceive {
97 fatal: e.is_fatal(),
98 }
99 })?;
100
101 if incoming.source.to_endpoint() != destination {
103 continue;
104 }
105
106 let Ok(response_frame) = Frame::from_cbor(&incoming.payload) else {
108 return Err(NoiseCryptoProviderError::HandshakeProtocol);
109 };
110
111 if let Frame::HandshakeFinish(handshake_finish) = response_frame {
113 if initiator.read_response_message(&handshake_finish).is_err() {
114 error!("Failed to read handshake response message");
115 return Err(NoiseCryptoProviderError::HandshakeProtocol);
116 }
117 break;
118 }
119 }
120 Ok(())
121 })
122 .await
123 .map_err(|_| {
124 info!(
125 "Noise handshake with {:?} timed out after {} seconds",
126 destination, HANDSHAKE_TIMEOUT_SECS
127 );
128 NoiseCryptoProviderError::Timeout
129 })??;
132
133 let crypto_state = NoiseCryptoProviderState {
134 state: (&mut initiator).into(),
135 };
136 sessions
137 .save(destination.clone(), crypto_state)
138 .await
139 .expect("Save session should not fail");
140
141 info!(
142 "Noise handshake with {:?} completed, session established",
143 destination
144 );
145
146 Ok(())
147 }
148}
149
/// Session age, in seconds, after which `send` discards the session and re-handshakes
/// (five minutes). Passed to `PersistentTransportState::should_rehandshake`.
const REHANDSHAKE_INTERVAL_SECS: u64 = 5 * 60;

/// How long, in seconds, `perform_handshake` waits for the peer's `HandshakeFinish` frame.
const HANDSHAKE_TIMEOUT_SECS: u64 = 2;
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct NoiseCryptoProviderState {
160 state: PersistentTransportState,
161}
162
163impl<Com, Ses> CryptoProvider<Com, Ses> for NoiseCryptoProvider
164where
165 Com: CommunicationBackend,
166 Ses: SessionRepository<NoiseCryptoProviderState>,
167{
168 type Session = NoiseCryptoProviderState;
169 type SendError = NoiseCryptoProviderError;
170 type ReceiveError = NoiseCryptoProviderError;
171
172 async fn send(
173 &self,
174 communication: &Com,
175 sessions: &Ses,
176 message: OutgoingMessage,
177 ) -> Result<(), Self::SendError> {
178 let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
182
183 let destination = message.destination.clone();
184
185 let crypto_state = sessions
186 .get(destination.clone())
187 .await
188 .expect("Get session should not fail");
189
190 let mut should_handshake = crypto_state.is_none();
191 if let Some(state) = crypto_state.as_ref()
192 && state.state.should_rehandshake(REHANDSHAKE_INTERVAL_SECS)
193 {
194 info!(
195 "Noise session with {:?} is older than {}s, re-handshaking",
196 destination, REHANDSHAKE_INTERVAL_SECS
197 );
198 sessions
199 .remove(destination.clone())
200 .await
201 .expect("Delete session should not fail");
202 should_handshake = true;
203 }
204
205 if should_handshake {
206 if crypto_state.is_none() {
207 info!(
208 "Noise handshake with {:?} initiated for new session establishment",
209 destination
210 );
211 } else {
212 info!(
213 "Noise re-handshake with {:?} due to re-handshake interval",
214 destination
215 );
216 }
217
218 Self::perform_handshake(communication, sessions, destination.clone()).await?;
219 }
220
221 let mut crypto_state = sessions
222 .get(destination.clone())
223 .await
224 .expect("Get session should not fail")
225 .expect("Session should exist after handshake");
226
227 let transport_frame = crypto_state
229 .state
230 .send(message.payload.into())
231 .map_err(|_| NoiseCryptoProviderError::DecryptionFailure)?;
232 communication
233 .send(OutgoingMessage {
234 payload: Frame::TransportFrame(transport_frame).to_cbor(),
235 destination: destination.clone(),
236 topic: message.topic,
237 })
238 .await
239 .map_err(|e| NoiseCryptoProviderError::TransportSend {
240 fatal: e.is_fatal(),
241 })?;
242
243 sessions
244 .save(destination, crypto_state)
245 .await
246 .expect("Save session should not fail");
247
248 Ok(())
249 }
250
251 async fn receive(
252 &self,
253 receiver: &Com::Receiver,
254 communication: &Com,
255 sessions: &Ses,
256 ) -> Result<IncomingMessage, Self::ReceiveError> {
257 loop {
258 let message = receiver.receive().await.map_err(|e| {
259 NoiseCryptoProviderError::TransportReceive {
260 fatal: e.is_fatal(),
261 }
262 })?;
263
264 let source_endpoint: crate::endpoint::Endpoint = message.source.clone().into();
266
267 let Ok(transport_frame) = Frame::from_cbor(&message.payload) else {
269 warn!("Received malformed cbor message, ignoring");
270 continue;
271 };
272
273 match transport_frame {
274 Frame::HandshakeStart(handshake_start) => {
275 let mut responder = HandshakeResponder::new(&handshake_start.ciphersuite);
276 responder
277 .read_start_message(&handshake_start)
278 .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
279 let response_message = responder
280 .write_response_message()
281 .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
282 let handshake_frame = Frame::HandshakeFinish(response_message);
283 communication
284 .send(OutgoingMessage {
285 payload: handshake_frame.to_cbor(),
286 destination: source_endpoint.clone(),
287 topic: None,
288 })
289 .await
290 .map_err(|e| NoiseCryptoProviderError::TransportSend {
291 fatal: e.is_fatal(),
292 })?;
293
294 let crypto_state = NoiseCryptoProviderState {
295 state: (&mut responder).into(),
296 };
297 sessions
298 .save(source_endpoint, crypto_state)
299 .await
300 .expect("Save session should not fail");
301 }
302 Frame::TransportFrame(transport_frame) => {
303 let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
304 let crypto_state = sessions
305 .get(source_endpoint.clone())
306 .await
307 .expect("Get session should not fail");
308 let Some(mut state) = crypto_state else {
309 info!("No session for {:?}, waiting for handshake", message.source);
310 let frame = Frame::CryptoInvalidated.to_cbor();
311 communication
312 .send(OutgoingMessage {
313 payload: frame,
314 destination: source_endpoint,
315 topic: None,
316 })
317 .await
318 .map_err(|e| NoiseCryptoProviderError::TransportSend {
319 fatal: e.is_fatal(),
320 })?;
321 continue;
322 };
323
324 let payload = state.state.receive(&transport_frame);
325 let Ok(payload) = payload else {
326 info!("Failed to decrypt message from {:?}", message.source);
327 continue;
328 };
329
330 sessions
331 .save(source_endpoint, state)
332 .await
333 .expect("Save session should not fail");
334
335 return Ok(IncomingMessage {
336 payload: payload.as_ref().to_vec(),
337 destination: message.destination,
338 source: message.source,
339 topic: message.topic,
340 });
341 }
342 Frame::CryptoInvalidated => {
343 info!(
344 "Invalidated session for {:?} due to crypto error, deleting session and waiting for handshake",
345 message.source
346 );
347 sessions
348 .remove(source_endpoint)
349 .await
350 .expect("Delete session should not fail");
351 }
352 _ => continue,
353 }
354 }
355 }
356}
357
358#[derive(Serialize, Deserialize)]
360pub(super) enum Frame {
361 HandshakeStart(HandshakeStartMessage),
363 HandshakeFinish(HandshakeFinishMessage),
364 TransportFrame(TransportFrame),
366 CryptoInvalidated,
369}
370
371impl Frame {
372 pub(crate) fn to_cbor(&self) -> Vec<u8> {
373 let mut buffer = Vec::new();
374 ciborium::into_writer(self, &mut buffer).expect("Ciborium serialization should not fail");
375 buffer
376 }
377
378 pub(crate) fn from_cbor(buffer: &[u8]) -> Result<Self, ()> {
379 ciborium::from_reader(buffer).map_err(|_| ())
380 }
381}
382
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{
        IpcClientImpl,
        crypto_provider::noise::crypto_provider::NoiseCryptoProvider,
        endpoint::Endpoint,
        ipc_client_trait::IpcClient,
        message::OutgoingMessage,
        traits::{InMemorySessionRepository, TestTwoWayCommunicationBackend},
    };

    /// Bounces a counter between two Noise-backed clients until it reaches 255. This covers
    /// the initial handshake plus many encrypted round trips in both directions.
    #[tokio::test]
    async fn ping_pong() {
        let (provider_1, provider_2) = TestTwoWayCommunicationBackend::new();

        let session_map_1 = InMemorySessionRepository::new(HashMap::new());
        let client_1 = IpcClientImpl::new(NoiseCryptoProvider, provider_1, session_map_1);
        let _ = client_1.start(None).await;
        let mut recv_1 = client_1.subscribe(None).await.unwrap();

        let session_map_2 = InMemorySessionRepository::new(HashMap::new());
        let client_2 = IpcClientImpl::new(NoiseCryptoProvider, provider_2, session_map_2);
        let _ = client_2.start(None).await;
        let mut recv_2 = client_2.subscribe(None).await.unwrap();

        let handle_1 = tokio::spawn(async move {
            let mut val: u8 = 0;
            for _ in 0..255 {
                let message = OutgoingMessage {
                    payload: vec![val],
                    destination: Endpoint::DesktopMain,
                    topic: None,
                };
                client_1.send(message).await.unwrap();
                let recv_message = recv_1.receive(None).await.unwrap();
                val = recv_message.payload[0] + 1;
            }
            val
        });

        let handle_2 = tokio::spawn(async move {
            for _ in 0..255 {
                let recv_message = recv_2.receive(None).await.unwrap();
                let val = recv_message.payload[0];
                if val == 255 {
                    break;
                }

                client_2
                    .send(OutgoingMessage {
                        payload: vec![val],
                        destination: Endpoint::DesktopMain,
                        topic: None,
                    })
                    .await
                    .unwrap();
            }
        });

        // Check both join results: discarding the `JoinError`s would let a panicking
        // `unwrap` inside either task go unnoticed and the test would still pass.
        let (result_1, result_2) = tokio::join!(handle_1, handle_2);
        let final_value = result_1.expect("client 1 task panicked");
        result_2.expect("client 2 task panicked");
        assert_eq!(final_value, 255);
    }
}