Skip to main content

bitwarden_ipc/crypto_provider/noise/
crypto_provider.rs

1use std::{sync::LazyLock, time::Duration};
2
3use bitwarden_threading::time::timeout;
4use serde::{Deserialize, Serialize};
5use tracing::{error, info, warn};
6
7use crate::{
8    crypto_provider::noise::{
9        handshake::{
10            CipherSuite, HandshakeFinishMessage, HandshakeInitiator, HandshakeResponder,
11            HandshakeStartMessage,
12        },
13        transport_state::{PersistentTransportState, TransportFrame},
14    },
15    error::IpcErrorKind,
16    message::{IncomingMessage, OutgoingMessage},
17    traits::{
18        CommunicationBackend, CommunicationBackendReceiver, CryptoProvider, SessionRepository,
19    },
20};
21
/// [`CryptoProvider`] that secures IPC messages with a Noise handshake followed by
/// encrypted transport frames.
///
/// The provider itself carries no state; per-peer session state lives in the
/// `SessionRepository` passed to each call.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoiseCryptoProvider;
23
/// Errors produced by the Noise crypto provider's `send` and `receive` operations.
///
/// Whether an error is fatal for the shared client is decided by its [`IpcErrorKind`]
/// implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NoiseCryptoProviderError {
    /// A handshake protocol error, such as a malformed or rejected handshake message.
    HandshakeProtocol,
    /// Timed out while waiting for the peer's handshake response.
    Timeout,
    /// Could not send via the underlying transport. `fatal` is derived from the underlying
    /// backend error's [`IpcErrorKind`] classification.
    TransportSend { fatal: bool },
    /// Could not receive via the underlying transport. `fatal` is derived from the underlying
    /// backend error's [`IpcErrorKind`] classification.
    TransportReceive { fatal: bool },
    /// A cryptographic error. This is also reported when encrypting an outgoing payload
    /// fails. Undecryptable incoming messages are usually just dropped.
    DecryptionFailure,
}
39
40impl IpcErrorKind for NoiseCryptoProviderError {
41    fn is_fatal(&self) -> bool {
42        match self {
43            // A bad/missing handshake frame from one peer does not affect the shared client; the
44            // peer can retry the handshake.
45            NoiseCryptoProviderError::HandshakeProtocol => false,
46            // The handshake is retryable on a subsequent send.
47            NoiseCryptoProviderError::Timeout => false,
48            // A decryption failure only affects the offending message, which is dropped.
49            NoiseCryptoProviderError::DecryptionFailure => false,
50            // Defer to the underlying backend's classification, captured at construction.
51            NoiseCryptoProviderError::TransportSend { fatal } => *fatal,
52            NoiseCryptoProviderError::TransportReceive { fatal } => *fatal,
53        }
54    }
55}
56
57// Serialize send operations to prevent concurrent reads of the same persisted
58// transport state, which can cause nonce reuse.
59static CRYPTO_STATE_GUARD: LazyLock<tokio::sync::Mutex<()>> =
60    LazyLock::new(|| tokio::sync::Mutex::new(()));
61
62impl NoiseCryptoProvider {
63    async fn perform_handshake<Com, Ses>(
64        communication: &Com,
65        sessions: &Ses,
66        destination: crate::endpoint::Endpoint,
67    ) -> Result<(), NoiseCryptoProviderError>
68    where
69        Com: CommunicationBackend,
70        Ses: SessionRepository<NoiseCryptoProviderState>,
71    {
72        info!("Starting noise handshake with {:?}", destination);
73
74        let mut initiator = HandshakeInitiator::new(&CipherSuite::default());
75        let message = initiator
76            .write_start_message()
77            .expect("Handshake start message should be buildable");
78        let receiver = communication.subscribe().await;
79
80        let handshake_frame = Frame::HandshakeStart(message);
81        communication
82            .send(OutgoingMessage {
83                payload: handshake_frame.to_cbor(),
84                destination: destination.clone(),
85                topic: None,
86            })
87            .await
88            .map_err(|e| NoiseCryptoProviderError::TransportSend {
89                fatal: e.is_fatal(),
90            })?;
91
92        // Wait for the handshake response (with timeout)
93        timeout(Duration::from_secs(HANDSHAKE_TIMEOUT_SECS), async {
94            loop {
95                let incoming = receiver.receive().await.map_err(|e| {
96                    NoiseCryptoProviderError::TransportReceive {
97                        fatal: e.is_fatal(),
98                    }
99                })?;
100
101                // For concurrent handshakes, ignore messages
102                if incoming.source.to_endpoint() != destination {
103                    continue;
104                }
105
106                // Malformed messages will cancel the handshake
107                let Ok(response_frame) = Frame::from_cbor(&incoming.payload) else {
108                    return Err(NoiseCryptoProviderError::HandshakeProtocol);
109                };
110
111                // Only accept handshake finish messages until the handshake is complete
112                if let Frame::HandshakeFinish(handshake_finish) = response_frame {
113                    if initiator.read_response_message(&handshake_finish).is_err() {
114                        error!("Failed to read handshake response message");
115                        return Err(NoiseCryptoProviderError::HandshakeProtocol);
116                    }
117                    break;
118                }
119            }
120            Ok(())
121        })
122        .await
123        .map_err(|_| {
124            info!(
125                "Noise handshake with {:?} timed out after {} seconds",
126                destination, HANDSHAKE_TIMEOUT_SECS
127            );
128            NoiseCryptoProviderError::Timeout
129            // Both the timeout error, and errors from within the handshake loop are propagated
130            // here, hence the double question mark.
131        })??;
132
133        let crypto_state = NoiseCryptoProviderState {
134            state: (&mut initiator).into(),
135        };
136        sessions
137            .save(destination.clone(), crypto_state)
138            .await
139            .expect("Save session should not fail");
140
141        info!(
142            "Noise handshake with {:?} completed, session established",
143            destination
144        );
145
146        Ok(())
147    }
148}
149
/// Maximum age of a Noise session, in seconds (five minutes). On the next send, a session
/// older than this is discarded and re-established with a fresh handshake.
const REHANDSHAKE_INTERVAL_SECS: u64 = 5 * 60;

/// How long, in seconds, the initiator waits for the peer's handshake response before
/// giving up.
const HANDSHAKE_TIMEOUT_SECS: u64 = 2;
156
157/// Session state for the Noise crypto provider.
158#[derive(Debug, Clone, Serialize, Deserialize)]
159pub struct NoiseCryptoProviderState {
160    state: PersistentTransportState,
161}
162
163impl<Com, Ses> CryptoProvider<Com, Ses> for NoiseCryptoProvider
164where
165    Com: CommunicationBackend,
166    Ses: SessionRepository<NoiseCryptoProviderState>,
167{
168    type Session = NoiseCryptoProviderState;
169    type SendError = NoiseCryptoProviderError;
170    type ReceiveError = NoiseCryptoProviderError;
171
172    async fn send(
173        &self,
174        communication: &Com,
175        sessions: &Ses,
176        message: OutgoingMessage,
177    ) -> Result<(), Self::SendError> {
178        // Send operations *MUST* be serialized, otherwise nonce re-use may happen since
179        // concurrent sends may acquire the same copy of the transport state before nonce
180        // updating.
181        let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
182
183        let destination = message.destination.clone();
184
185        let crypto_state = sessions
186            .get(destination.clone())
187            .await
188            .expect("Get session should not fail");
189
190        let mut should_handshake = crypto_state.is_none();
191        if let Some(state) = crypto_state.as_ref()
192            && state.state.should_rehandshake(REHANDSHAKE_INTERVAL_SECS)
193        {
194            info!(
195                "Noise session with {:?} is older than {}s, re-handshaking",
196                destination, REHANDSHAKE_INTERVAL_SECS
197            );
198            sessions
199                .remove(destination.clone())
200                .await
201                .expect("Delete session should not fail");
202            should_handshake = true;
203        }
204
205        if should_handshake {
206            if crypto_state.is_none() {
207                info!(
208                    "Noise handshake with {:?} initiated for new session establishment",
209                    destination
210                );
211            } else {
212                info!(
213                    "Noise re-handshake with {:?} due to re-handshake interval",
214                    destination
215                );
216            }
217
218            Self::perform_handshake(communication, sessions, destination.clone()).await?;
219        }
220
221        let mut crypto_state = sessions
222            .get(destination.clone())
223            .await
224            .expect("Get session should not fail")
225            .expect("Session should exist after handshake");
226
227        // Encrypt and send the payload
228        let transport_frame = crypto_state
229            .state
230            .send(message.payload.into())
231            .map_err(|_| NoiseCryptoProviderError::DecryptionFailure)?;
232        communication
233            .send(OutgoingMessage {
234                payload: Frame::TransportFrame(transport_frame).to_cbor(),
235                destination: destination.clone(),
236                topic: message.topic,
237            })
238            .await
239            .map_err(|e| NoiseCryptoProviderError::TransportSend {
240                fatal: e.is_fatal(),
241            })?;
242
243        sessions
244            .save(destination, crypto_state)
245            .await
246            .expect("Save session should not fail");
247
248        Ok(())
249    }
250
251    async fn receive(
252        &self,
253        receiver: &Com::Receiver,
254        communication: &Com,
255        sessions: &Ses,
256    ) -> Result<IncomingMessage, Self::ReceiveError> {
257        loop {
258            let message = receiver.receive().await.map_err(|e| {
259                NoiseCryptoProviderError::TransportReceive {
260                    fatal: e.is_fatal(),
261                }
262            })?;
263
264            // Ensure session exists
265            let source_endpoint: crate::endpoint::Endpoint = message.source.clone().into();
266
267            // Decode outer transport frame from wire
268            let Ok(transport_frame) = Frame::from_cbor(&message.payload) else {
269                warn!("Received malformed cbor message, ignoring");
270                continue;
271            };
272
273            match transport_frame {
274                Frame::HandshakeStart(handshake_start) => {
275                    let mut responder = HandshakeResponder::new(&handshake_start.ciphersuite);
276                    responder
277                        .read_start_message(&handshake_start)
278                        .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
279                    let response_message = responder
280                        .write_response_message()
281                        .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
282                    let handshake_frame = Frame::HandshakeFinish(response_message);
283                    communication
284                        .send(OutgoingMessage {
285                            payload: handshake_frame.to_cbor(),
286                            destination: source_endpoint.clone(),
287                            topic: None,
288                        })
289                        .await
290                        .map_err(|e| NoiseCryptoProviderError::TransportSend {
291                            fatal: e.is_fatal(),
292                        })?;
293
294                    let crypto_state = NoiseCryptoProviderState {
295                        state: (&mut responder).into(),
296                    };
297                    sessions
298                        .save(source_endpoint, crypto_state)
299                        .await
300                        .expect("Save session should not fail");
301                }
302                Frame::TransportFrame(transport_frame) => {
303                    let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
304                    let crypto_state = sessions
305                        .get(source_endpoint.clone())
306                        .await
307                        .expect("Get session should not fail");
308                    let Some(mut state) = crypto_state else {
309                        info!("No session for {:?}, waiting for handshake", message.source);
310                        let frame = Frame::CryptoInvalidated.to_cbor();
311                        communication
312                            .send(OutgoingMessage {
313                                payload: frame,
314                                destination: source_endpoint,
315                                topic: None,
316                            })
317                            .await
318                            .map_err(|e| NoiseCryptoProviderError::TransportSend {
319                                fatal: e.is_fatal(),
320                            })?;
321                        continue;
322                    };
323
324                    let payload = state.state.receive(&transport_frame);
325                    let Ok(payload) = payload else {
326                        info!("Failed to decrypt message from {:?}", message.source);
327                        continue;
328                    };
329
330                    sessions
331                        .save(source_endpoint, state)
332                        .await
333                        .expect("Save session should not fail");
334
335                    return Ok(IncomingMessage {
336                        payload: payload.as_ref().to_vec(),
337                        destination: message.destination,
338                        source: message.source,
339                        topic: message.topic,
340                    });
341                }
342                Frame::CryptoInvalidated => {
343                    info!(
344                        "Invalidated session for {:?} due to crypto error, deleting session and waiting for handshake",
345                        message.source
346                    );
347                    sessions
348                        .remove(source_endpoint)
349                        .await
350                        .expect("Delete session should not fail");
351                }
352                _ => continue,
353            }
354        }
355    }
356}
357
358/// The raw frame that is sent via IPC.
359#[derive(Serialize, Deserialize)]
360pub(super) enum Frame {
361    // Handshake Frames
362    HandshakeStart(HandshakeStartMessage),
363    HandshakeFinish(HandshakeFinishMessage),
364    // After the handshake is done, transport frames are used to wrap ciphertexts
365    TransportFrame(TransportFrame),
366    // If crypto is invalidated, this message is sent by the device noticing
367    // the invalidation so that both sides reset the crypto.
368    CryptoInvalidated,
369}
370
371impl Frame {
372    pub(crate) fn to_cbor(&self) -> Vec<u8> {
373        let mut buffer = Vec::new();
374        ciborium::into_writer(self, &mut buffer).expect("Ciborium serialization should not fail");
375        buffer
376    }
377
378    pub(crate) fn from_cbor(buffer: &[u8]) -> Result<Self, ()> {
379        ciborium::from_reader(buffer).map_err(|_| ())
380    }
381}
382
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{
        IpcClientImpl,
        crypto_provider::noise::crypto_provider::NoiseCryptoProvider,
        endpoint::Endpoint,
        ipc_client_trait::IpcClient,
        message::OutgoingMessage,
        traits::{InMemorySessionRepository, TestTwoWayCommunicationBackend},
    };

    /// End-to-end round trip through the Noise provider.
    ///
    /// Client 1 sends the values 0..=254, and client 2 echoes each one back. Client 1
    /// checks every echo. The first send establishes the session via a handshake.
    #[tokio::test]
    async fn ping_pong() {
        let (provider_1, provider_2) = TestTwoWayCommunicationBackend::new();

        let session_map_1 = InMemorySessionRepository::new(HashMap::new());
        let client_1 = IpcClientImpl::new(NoiseCryptoProvider, provider_1, session_map_1);
        let _ = client_1.start(None).await;
        let mut recv_1 = client_1.subscribe(None).await.unwrap();

        let session_map_2 = InMemorySessionRepository::new(HashMap::new());
        let client_2 = IpcClientImpl::new(NoiseCryptoProvider, provider_2, session_map_2);
        let _ = client_2.start(None).await;
        let mut recv_2 = client_2.subscribe(None).await.unwrap();

        let handle_1 = tokio::spawn(async move {
            for val in 0..255u8 {
                let message = OutgoingMessage {
                    payload: vec![val],
                    destination: Endpoint::DesktopMain,
                    topic: None,
                };
                client_1.send(message).await.unwrap();
                let recv_message = recv_1.receive(None).await.unwrap();
                assert_eq!(
                    recv_message.payload,
                    vec![val],
                    "client 2 should echo the value back unchanged"
                );
            }
        });

        let handle_2 = tokio::spawn(async move {
            for _ in 0..255 {
                let recv_message = recv_2.receive(None).await.unwrap();
                let val = recv_message.payload[0];
                if val == 255 {
                    break;
                }

                client_2
                    .send(OutgoingMessage {
                        payload: vec![val],
                        destination: Endpoint::DesktopMain,
                        topic: None,
                    })
                    .await
                    .unwrap();
            }
        });

        // Surface panics from the spawned tasks; discarding the JoinErrors would let failed
        // unwraps/asserts pass silently. `try_join!` returns as soon as either task fails,
        // instead of waiting forever on the other side's `receive`.
        tokio::try_join!(handle_1, handle_2).expect("ping-pong task panicked");
    }
}