bitwarden_ipc/crypto_provider/noise/
crypto_provider.rs1use std::{sync::LazyLock, time::Duration};
2
3use serde::{Deserialize, Serialize};
4use tokio::time::timeout;
5use tracing::{error, info, warn};
6
7use crate::{
8 crypto_provider::noise::{
9 handshake::{
10 CipherSuite, HandshakeFinishMessage, HandshakeInitiator, HandshakeResponder,
11 HandshakeStartMessage,
12 },
13 transport_state::{PersistentTransportState, TransportFrame},
14 },
15 message::{IncomingMessage, OutgoingMessage},
16 traits::{
17 CommunicationBackend, CommunicationBackendReceiver, CryptoProvider, SessionRepository,
18 },
19};
20
/// [`CryptoProvider`] that protects IPC messages with Noise-protocol sessions.
///
/// The provider itself is stateless; per-peer session state lives in the
/// `SessionRepository` passed to `send`/`receive`. `send` performs a handshake when no
/// session exists for the destination (or when the stored one is due for a
/// re-handshake), and `receive` answers incoming handshakes.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoiseCryptoProvider;
22
/// Errors returned by [`NoiseCryptoProvider`] when sending or receiving messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NoiseCryptoProviderError {
    /// A handshake message could not be written or read, or the peer answered a
    /// handshake with a frame that is not valid CBOR.
    HandshakeProtocol,
    /// The peer did not complete the handshake within the handshake timeout.
    Timeout,
    /// The communication backend failed to send a frame.
    TransportSend,
    /// The communication backend's receiver returned an error.
    TransportReceive,
    /// The session's transport state rejected a frame. This is returned when decrypting an
    /// incoming frame fails, and currently also when encrypting an outgoing payload fails.
    DecryptionFailure,
}

impl std::fmt::Display for NoiseCryptoProviderError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::HandshakeProtocol => "noise handshake protocol error",
            Self::Timeout => "noise handshake timed out",
            Self::TransportSend => "failed to send message over the communication backend",
            Self::TransportReceive => "failed to receive message from the communication backend",
            Self::DecryptionFailure => "noise transport state failed to process a frame",
        };
        f.write_str(description)
    }
}

impl std::error::Error for NoiseCryptoProviderError {}
36
37static CRYPTO_STATE_GUARD: LazyLock<tokio::sync::Mutex<()>> =
40 LazyLock::new(|| tokio::sync::Mutex::new(()));
41
42impl NoiseCryptoProvider {
43 async fn perform_handshake<Com, Ses>(
44 communication: &Com,
45 sessions: &Ses,
46 destination: crate::endpoint::Endpoint,
47 ) -> Result<(), NoiseCryptoProviderError>
48 where
49 Com: CommunicationBackend,
50 Ses: SessionRepository<NoiseCryptoProviderState>,
51 {
52 info!("Starting noise handshake with {:?}", destination);
53
54 let mut initiator = HandshakeInitiator::new(&CipherSuite::default());
55 let message = initiator
56 .write_start_message()
57 .expect("Handshake start message should be buildable");
58 let receiver = communication.subscribe().await;
59
60 let handshake_frame = Frame::HandshakeStart(message);
61 communication
62 .send(OutgoingMessage {
63 payload: handshake_frame.to_cbor(),
64 destination: destination.clone(),
65 topic: None,
66 })
67 .await
68 .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
69
70 timeout(Duration::from_secs(HANDSHAKE_TIMEOUT_SECS), async {
72 loop {
73 let incoming = receiver
74 .receive()
75 .await
76 .map_err(|_| NoiseCryptoProviderError::TransportReceive)?;
77
78 if incoming.source.to_endpoint() != destination {
80 continue;
81 }
82
83 let Ok(response_frame) = Frame::from_cbor(&incoming.payload) else {
85 return Err(NoiseCryptoProviderError::HandshakeProtocol);
86 };
87
88 if let Frame::HandshakeFinish(handshake_finish) = response_frame {
90 if initiator.read_response_message(&handshake_finish).is_err() {
91 error!("Failed to read handshake response message");
92 return Err(NoiseCryptoProviderError::HandshakeProtocol);
93 }
94 break;
95 }
96 }
97 Ok(())
98 })
99 .await
100 .map_err(|_| {
101 info!(
102 "Noise handshake with {:?} timed out after {} seconds",
103 destination, HANDSHAKE_TIMEOUT_SECS
104 );
105 NoiseCryptoProviderError::Timeout
106 })??;
109
110 let crypto_state = NoiseCryptoProviderState {
111 state: (&mut initiator).into(),
112 };
113 sessions
114 .save(destination.clone(), crypto_state)
115 .await
116 .expect("Save session should not fail");
117
118 info!(
119 "Noise handshake with {:?} completed, session established",
120 destination
121 );
122
123 Ok(())
124 }
125}
126
/// Age, in seconds, after which `send` discards a session and re-handshakes (five minutes).
const REHANDSHAKE_INTERVAL_SECS: u64 = 5 * 60;

/// Seconds `perform_handshake` waits for the peer's `HandshakeFinish` before giving up.
const HANDSHAKE_TIMEOUT_SECS: u64 = 2;
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct NoiseCryptoProviderState {
137 state: PersistentTransportState,
138}
139
140impl<Com, Ses> CryptoProvider<Com, Ses> for NoiseCryptoProvider
141where
142 Com: CommunicationBackend,
143 Ses: SessionRepository<NoiseCryptoProviderState>,
144{
145 type Session = NoiseCryptoProviderState;
146 type SendError = NoiseCryptoProviderError;
147 type ReceiveError = NoiseCryptoProviderError;
148
149 async fn send(
150 &self,
151 communication: &Com,
152 sessions: &Ses,
153 message: OutgoingMessage,
154 ) -> Result<(), Self::SendError> {
155 let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
159
160 let destination = message.destination.clone();
161
162 let crypto_state = sessions
163 .get(destination.clone())
164 .await
165 .expect("Get session should not fail");
166
167 let mut should_handshake = crypto_state.is_none();
168 if let Some(state) = crypto_state.as_ref()
169 && state.state.should_rehandshake(REHANDSHAKE_INTERVAL_SECS)
170 {
171 info!(
172 "Noise session with {:?} is older than {}s, re-handshaking",
173 destination, REHANDSHAKE_INTERVAL_SECS
174 );
175 sessions
176 .remove(destination.clone())
177 .await
178 .expect("Delete session should not fail");
179 should_handshake = true;
180 }
181
182 if should_handshake {
183 if crypto_state.is_none() {
184 info!(
185 "Noise handshake with {:?} initiated for new session establishment",
186 destination
187 );
188 } else {
189 info!(
190 "Noise re-handshake with {:?} due to re-handshake interval",
191 destination
192 );
193 }
194
195 Self::perform_handshake(communication, sessions, destination.clone()).await?;
196 }
197
198 let mut crypto_state = sessions
199 .get(destination.clone())
200 .await
201 .expect("Get session should not fail")
202 .expect("Session should exist after handshake");
203
204 let transport_frame = crypto_state
206 .state
207 .send(message.payload.into())
208 .map_err(|_| NoiseCryptoProviderError::DecryptionFailure)?;
209 communication
210 .send(OutgoingMessage {
211 payload: Frame::TransportFrame(transport_frame).to_cbor(),
212 destination: destination.clone(),
213 topic: message.topic,
214 })
215 .await
216 .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
217
218 sessions
219 .save(destination, crypto_state)
220 .await
221 .expect("Save session should not fail");
222
223 Ok(())
224 }
225
226 async fn receive(
227 &self,
228 receiver: &Com::Receiver,
229 communication: &Com,
230 sessions: &Ses,
231 ) -> Result<IncomingMessage, Self::ReceiveError> {
232 loop {
233 let message = receiver
234 .receive()
235 .await
236 .map_err(|_| NoiseCryptoProviderError::TransportReceive)?;
237
238 let source_endpoint: crate::endpoint::Endpoint = message.source.clone().into();
240
241 let Ok(transport_frame) = Frame::from_cbor(&message.payload) else {
243 warn!("Received malformed cbor message, ignoring");
244 continue;
245 };
246
247 match transport_frame {
248 Frame::HandshakeStart(handshake_start) => {
249 let mut responder = HandshakeResponder::new(&handshake_start.ciphersuite);
250 responder
251 .read_start_message(&handshake_start)
252 .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
253 let response_message = responder
254 .write_response_message()
255 .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
256 let handshake_frame = Frame::HandshakeFinish(response_message);
257 communication
258 .send(OutgoingMessage {
259 payload: handshake_frame.to_cbor(),
260 destination: source_endpoint.clone(),
261 topic: None,
262 })
263 .await
264 .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
265
266 let crypto_state = NoiseCryptoProviderState {
267 state: (&mut responder).into(),
268 };
269 sessions
270 .save(source_endpoint, crypto_state)
271 .await
272 .expect("Save session should not fail");
273 }
274 Frame::TransportFrame(transport_frame) => {
275 let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
276 let crypto_state = sessions
277 .get(source_endpoint.clone())
278 .await
279 .expect("Get session should not fail");
280 let Some(mut state) = crypto_state else {
281 info!("No session for {:?}, waiting for handshake", message.source);
282 let frame = Frame::CryptoInvalidated.to_cbor();
283 communication
284 .send(OutgoingMessage {
285 payload: frame,
286 destination: source_endpoint,
287 topic: None,
288 })
289 .await
290 .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
291 continue;
292 };
293
294 let payload = state
295 .state
296 .receive(&transport_frame)
297 .map_err(|_| NoiseCryptoProviderError::DecryptionFailure)?;
298 sessions
299 .save(source_endpoint, state)
300 .await
301 .expect("Save session should not fail");
302
303 return Ok(IncomingMessage {
304 payload: payload.as_ref().to_vec(),
305 destination: message.destination,
306 source: message.source,
307 topic: message.topic,
308 });
309 }
310 Frame::CryptoInvalidated => {
311 info!(
312 "Invalidated session for {:?} due to crypto error, deleting session and waiting for handshake",
313 message.source
314 );
315 sessions
316 .remove(source_endpoint)
317 .await
318 .expect("Delete session should not fail");
319 }
320 _ => continue,
321 }
322 }
323 }
324}
325
/// Frame exchanged between peers over the communication backend, encoded as CBOR by
/// [`Frame::to_cbor`].
///
/// NOTE(review): the variant names (and possibly their order) determine the serialized
/// form. Changing them likely breaks interoperability with peers on other versions —
/// confirm before editing.
#[derive(Serialize, Deserialize)]
pub(super) enum Frame {
    /// First handshake message, sent by the initiator in `perform_handshake`.
    HandshakeStart(HandshakeStartMessage),
    /// Handshake response, sent by `receive` after it reads a `HandshakeStart`.
    HandshakeFinish(HandshakeFinishMessage),
    /// Application payload encrypted with an established session.
    TransportFrame(TransportFrame),
    /// Tells the peer that its session is unusable on this side. The peer deletes its
    /// session, so its next `send` performs a new handshake.
    CryptoInvalidated,
}
338
339impl Frame {
340 pub(crate) fn to_cbor(&self) -> Vec<u8> {
341 let mut buffer = Vec::new();
342 ciborium::into_writer(self, &mut buffer).expect("Ciborium serialization should not fail");
343 buffer
344 }
345
346 pub(crate) fn from_cbor(buffer: &[u8]) -> Result<Self, ()> {
347 ciborium::from_reader(buffer).map_err(|_| ())
348 }
349}
350
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, time::Duration};

    use crate::{
        IpcClientImpl,
        crypto_provider::noise::crypto_provider::NoiseCryptoProvider,
        endpoint::Endpoint,
        ipc_client_trait::IpcClient,
        message::OutgoingMessage,
        traits::{InMemorySessionRepository, TestTwoWayCommunicationBackend},
    };

    /// Upper bound for the whole exchange, so a protocol regression fails the test
    /// instead of hanging it.
    const PING_PONG_TIMEOUT: Duration = Duration::from_secs(30);

    /// Two clients using `NoiseCryptoProvider` bounce a counter back and forth.
    ///
    /// Client 1 sends `val` and waits for the echo, then increments; client 2 echoes
    /// whatever it receives. Client 1 sends values 0..=254.
    #[tokio::test]
    async fn ping_pong() {
        let (provider_1, provider_2) = TestTwoWayCommunicationBackend::new();

        let session_map_1 = InMemorySessionRepository::new(HashMap::new());
        let client_1 = IpcClientImpl::new(NoiseCryptoProvider, provider_1, session_map_1);
        let _ = client_1.start(None).await;
        let mut recv_1 = client_1.subscribe(None).await.unwrap();

        let session_map_2 = InMemorySessionRepository::new(HashMap::new());
        let client_2 = IpcClientImpl::new(NoiseCryptoProvider, provider_2, session_map_2);
        let _ = client_2.start(None).await;
        let mut recv_2 = client_2.subscribe(None).await.unwrap();

        let handle_1 = tokio::spawn(async move {
            let mut val: u8 = 0;
            for _ in 0..255 {
                let message = OutgoingMessage {
                    payload: vec![val],
                    destination: Endpoint::DesktopMain,
                    topic: None,
                };
                client_1.send(message).await.unwrap();
                let recv_message = recv_1.receive(None).await.unwrap();
                val = recv_message.payload[0] + 1;
            }
        });

        let handle_2 = tokio::spawn(async move {
            for _ in 0..255 {
                let recv_message = recv_2.receive(None).await.unwrap();
                let val = recv_message.payload[0];
                if val == 255 {
                    break;
                }

                client_2
                    .send(OutgoingMessage {
                        payload: vec![val],
                        destination: Endpoint::DesktopMain,
                        topic: None,
                    })
                    .await
                    .unwrap();
            }
        });

        let (result_1, result_2) = tokio::time::timeout(PING_PONG_TIMEOUT, async move {
            tokio::join!(handle_1, handle_2)
        })
        .await
        .expect("ping-pong exchange timed out");

        // A panic inside a spawned task (e.g. a failed `unwrap`) only surfaces as a
        // `JoinError`; it has to be checked here or the test passes regardless.
        result_1.expect("client 1 task failed");
        result_2.expect("client 2 task failed");
    }
}