Skip to main content

bitwarden_ipc/crypto_provider/noise/
crypto_provider.rs

1use std::{sync::LazyLock, time::Duration};
2
3use serde::{Deserialize, Serialize};
4use tokio::time::timeout;
5use tracing::{error, info, warn};
6
7use crate::{
8    crypto_provider::noise::{
9        handshake::{
10            CipherSuite, HandshakeFinishMessage, HandshakeInitiator, HandshakeResponder,
11            HandshakeStartMessage,
12        },
13        transport_state::{PersistentTransportState, TransportFrame},
14    },
15    message::{IncomingMessage, OutgoingMessage},
16    traits::{
17        CommunicationBackend, CommunicationBackendReceiver, CryptoProvider, SessionRepository,
18    },
19};
20
/// Crypto provider that protects IPC traffic with a Noise handshake followed by
/// encrypted transport frames.
///
/// The type itself carries no data; all per-peer state lives in the session repository
/// that is passed to the provider's methods.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoiseCryptoProvider;
22
/// Errors returned by the Noise crypto provider while handshaking, sending, or receiving.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NoiseCryptoProviderError {
    /// A protocol error (missing message, malformed message)
    HandshakeProtocol,
    /// A timeout waiting for a message
    Timeout,
    /// Could not send via the underlying transport
    TransportSend,
    /// Could not receive via the underlying transport
    TransportReceive,
    /// A cryptographic error. In most cases, such messages are just dropped.
    DecryptionFailure,
}

impl std::fmt::Display for NoiseCryptoProviderError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::HandshakeProtocol => "noise handshake protocol error",
            Self::Timeout => "timed out waiting for a message",
            Self::TransportSend => "could not send via the underlying transport",
            Self::TransportReceive => "could not receive via the underlying transport",
            Self::DecryptionFailure => "cryptographic operation failed",
        };
        f.write_str(description)
    }
}

// Implementing `Error` lets callers box the error or convert it with `?` into generic
// error types.
impl std::error::Error for NoiseCryptoProviderError {}
36
37// Serialize send operations to prevent concurrent reads of the same persisted
38// transport state, which can cause nonce reuse.
39static CRYPTO_STATE_GUARD: LazyLock<tokio::sync::Mutex<()>> =
40    LazyLock::new(|| tokio::sync::Mutex::new(()));
41
42impl NoiseCryptoProvider {
43    async fn perform_handshake<Com, Ses>(
44        communication: &Com,
45        sessions: &Ses,
46        destination: crate::endpoint::Endpoint,
47    ) -> Result<(), NoiseCryptoProviderError>
48    where
49        Com: CommunicationBackend,
50        Ses: SessionRepository<NoiseCryptoProviderState>,
51    {
52        info!("Starting noise handshake with {:?}", destination);
53
54        let mut initiator = HandshakeInitiator::new(&CipherSuite::default());
55        let message = initiator
56            .write_start_message()
57            .expect("Handshake start message should be buildable");
58        let receiver = communication.subscribe().await;
59
60        let handshake_frame = Frame::HandshakeStart(message);
61        communication
62            .send(OutgoingMessage {
63                payload: handshake_frame.to_cbor(),
64                destination: destination.clone(),
65                topic: None,
66            })
67            .await
68            .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
69
70        // Wait for the handshake response (with timeout)
71        timeout(Duration::from_secs(HANDSHAKE_TIMEOUT_SECS), async {
72            loop {
73                let incoming = receiver
74                    .receive()
75                    .await
76                    .map_err(|_| NoiseCryptoProviderError::TransportReceive)?;
77
78                // For concurrent handshakes, ignore messages
79                if incoming.source.to_endpoint() != destination {
80                    continue;
81                }
82
83                // Malformed messages will cancel the handshake
84                let Ok(response_frame) = Frame::from_cbor(&incoming.payload) else {
85                    return Err(NoiseCryptoProviderError::HandshakeProtocol);
86                };
87
88                // Only accept handshake finish messages until the handshake is complete
89                if let Frame::HandshakeFinish(handshake_finish) = response_frame {
90                    if initiator.read_response_message(&handshake_finish).is_err() {
91                        error!("Failed to read handshake response message");
92                        return Err(NoiseCryptoProviderError::HandshakeProtocol);
93                    }
94                    break;
95                }
96            }
97            Ok(())
98        })
99        .await
100        .map_err(|_| {
101            info!(
102                "Noise handshake with {:?} timed out after {} seconds",
103                destination, HANDSHAKE_TIMEOUT_SECS
104            );
105            NoiseCryptoProviderError::Timeout
106            // Both the timeout error, and errors from within the handshake loop are propagated
107            // here, hence the double question mark.
108        })??;
109
110        let crypto_state = NoiseCryptoProviderState {
111            state: (&mut initiator).into(),
112        };
113        sessions
114            .save(destination.clone(), crypto_state)
115            .await
116            .expect("Save session should not fail");
117
118        info!(
119            "Noise handshake with {:?} completed, session established",
120            destination
121        );
122
123        Ok(())
124    }
125}
126
/// Maximum session age, in seconds (five minutes). `send` hands this value to
/// `should_rehandshake` and performs a fresh handshake when that check returns true.
const REHANDSHAKE_INTERVAL_SECS: u64 = 5 * 60;

/// How long, in seconds, the initiator waits for the remote peer's handshake response.
const HANDSHAKE_TIMEOUT_SECS: u64 = 2;
133
134/// Session state for the Noise crypto provider.
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct NoiseCryptoProviderState {
137    state: PersistentTransportState,
138}
139
140impl<Com, Ses> CryptoProvider<Com, Ses> for NoiseCryptoProvider
141where
142    Com: CommunicationBackend,
143    Ses: SessionRepository<NoiseCryptoProviderState>,
144{
145    type Session = NoiseCryptoProviderState;
146    type SendError = NoiseCryptoProviderError;
147    type ReceiveError = NoiseCryptoProviderError;
148
149    async fn send(
150        &self,
151        communication: &Com,
152        sessions: &Ses,
153        message: OutgoingMessage,
154    ) -> Result<(), Self::SendError> {
155        // Send operations *MUST* be serialized, otherwise nonce re-use may happen since
156        // concurrent sends may acquire the same copy of the transport state before nonce
157        // updating.
158        let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
159
160        let destination = message.destination.clone();
161
162        let crypto_state = sessions
163            .get(destination.clone())
164            .await
165            .expect("Get session should not fail");
166
167        let mut should_handshake = crypto_state.is_none();
168        if let Some(state) = crypto_state.as_ref()
169            && state.state.should_rehandshake(REHANDSHAKE_INTERVAL_SECS)
170        {
171            info!(
172                "Noise session with {:?} is older than {}s, re-handshaking",
173                destination, REHANDSHAKE_INTERVAL_SECS
174            );
175            sessions
176                .remove(destination.clone())
177                .await
178                .expect("Delete session should not fail");
179            should_handshake = true;
180        }
181
182        if should_handshake {
183            if crypto_state.is_none() {
184                info!(
185                    "Noise handshake with {:?} initiated for new session establishment",
186                    destination
187                );
188            } else {
189                info!(
190                    "Noise re-handshake with {:?} due to re-handshake interval",
191                    destination
192                );
193            }
194
195            Self::perform_handshake(communication, sessions, destination.clone()).await?;
196        }
197
198        let mut crypto_state = sessions
199            .get(destination.clone())
200            .await
201            .expect("Get session should not fail")
202            .expect("Session should exist after handshake");
203
204        // Encrypt and send the payload
205        let transport_frame = crypto_state
206            .state
207            .send(message.payload.into())
208            .map_err(|_| NoiseCryptoProviderError::DecryptionFailure)?;
209        communication
210            .send(OutgoingMessage {
211                payload: Frame::TransportFrame(transport_frame).to_cbor(),
212                destination: destination.clone(),
213                topic: message.topic,
214            })
215            .await
216            .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
217
218        sessions
219            .save(destination, crypto_state)
220            .await
221            .expect("Save session should not fail");
222
223        Ok(())
224    }
225
226    async fn receive(
227        &self,
228        receiver: &Com::Receiver,
229        communication: &Com,
230        sessions: &Ses,
231    ) -> Result<IncomingMessage, Self::ReceiveError> {
232        loop {
233            let message = receiver
234                .receive()
235                .await
236                .map_err(|_| NoiseCryptoProviderError::TransportReceive)?;
237
238            // Ensure session exists
239            let source_endpoint: crate::endpoint::Endpoint = message.source.clone().into();
240
241            // Decode outer transport frame from wire
242            let Ok(transport_frame) = Frame::from_cbor(&message.payload) else {
243                warn!("Received malformed cbor message, ignoring");
244                continue;
245            };
246
247            match transport_frame {
248                Frame::HandshakeStart(handshake_start) => {
249                    let mut responder = HandshakeResponder::new(&handshake_start.ciphersuite);
250                    responder
251                        .read_start_message(&handshake_start)
252                        .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
253                    let response_message = responder
254                        .write_response_message()
255                        .map_err(|_| NoiseCryptoProviderError::HandshakeProtocol)?;
256                    let handshake_frame = Frame::HandshakeFinish(response_message);
257                    communication
258                        .send(OutgoingMessage {
259                            payload: handshake_frame.to_cbor(),
260                            destination: source_endpoint.clone(),
261                            topic: None,
262                        })
263                        .await
264                        .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
265
266                    let crypto_state = NoiseCryptoProviderState {
267                        state: (&mut responder).into(),
268                    };
269                    sessions
270                        .save(source_endpoint, crypto_state)
271                        .await
272                        .expect("Save session should not fail");
273                }
274                Frame::TransportFrame(transport_frame) => {
275                    let _crypto_state_guard = CRYPTO_STATE_GUARD.lock().await;
276                    let crypto_state = sessions
277                        .get(source_endpoint.clone())
278                        .await
279                        .expect("Get session should not fail");
280                    let Some(mut state) = crypto_state else {
281                        info!("No session for {:?}, waiting for handshake", message.source);
282                        let frame = Frame::CryptoInvalidated.to_cbor();
283                        communication
284                            .send(OutgoingMessage {
285                                payload: frame,
286                                destination: source_endpoint,
287                                topic: None,
288                            })
289                            .await
290                            .map_err(|_| NoiseCryptoProviderError::TransportSend)?;
291                        continue;
292                    };
293
294                    let payload = state
295                        .state
296                        .receive(&transport_frame)
297                        .map_err(|_| NoiseCryptoProviderError::DecryptionFailure)?;
298                    sessions
299                        .save(source_endpoint, state)
300                        .await
301                        .expect("Save session should not fail");
302
303                    return Ok(IncomingMessage {
304                        payload: payload.as_ref().to_vec(),
305                        destination: message.destination,
306                        source: message.source,
307                        topic: message.topic,
308                    });
309                }
310                Frame::CryptoInvalidated => {
311                    info!(
312                        "Invalidated session for {:?} due to crypto error, deleting session and waiting for handshake",
313                        message.source
314                    );
315                    sessions
316                        .remove(source_endpoint)
317                        .await
318                        .expect("Delete session should not fail");
319                }
320                _ => continue,
321            }
322        }
323    }
324}
325
326/// The raw frame that is sent via IPC.
327#[derive(Serialize, Deserialize)]
328pub(super) enum Frame {
329    // Handshake Frames
330    HandshakeStart(HandshakeStartMessage),
331    HandshakeFinish(HandshakeFinishMessage),
332    // After the handshake is done, transport frames are used to wrap ciphertexts
333    TransportFrame(TransportFrame),
334    // If crypto is invalidated, this message is sent by the device noticing
335    // the invalidation so that both sides reset the crypto.
336    CryptoInvalidated,
337}
338
339impl Frame {
340    pub(crate) fn to_cbor(&self) -> Vec<u8> {
341        let mut buffer = Vec::new();
342        ciborium::into_writer(self, &mut buffer).expect("Ciborium serialization should not fail");
343        buffer
344    }
345
346    pub(crate) fn from_cbor(buffer: &[u8]) -> Result<Self, ()> {
347        ciborium::from_reader(buffer).map_err(|_| ())
348    }
349}
350
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{
        IpcClientImpl,
        crypto_provider::noise::crypto_provider::NoiseCryptoProvider,
        endpoint::Endpoint,
        ipc_client_trait::IpcClient,
        message::OutgoingMessage,
        traits::{InMemorySessionRepository, TestTwoWayCommunicationBackend},
    };

    /// Two clients connected back-to-back exchange 255 single-byte messages: client 1
    /// sends a counter value, client 2 echoes it, and client 1 checks the echo before
    /// sending the next value.
    #[tokio::test]
    async fn ping_pong() {
        let (provider_1, provider_2) = TestTwoWayCommunicationBackend::new();

        let session_map_1 = InMemorySessionRepository::new(HashMap::new());
        let client_1 = IpcClientImpl::new(NoiseCryptoProvider, provider_1, session_map_1);
        let _ = client_1.start(None).await;
        let mut recv_1 = client_1.subscribe(None).await.unwrap();

        let session_map_2 = InMemorySessionRepository::new(HashMap::new());
        let client_2 = IpcClientImpl::new(NoiseCryptoProvider, provider_2, session_map_2);
        let _ = client_2.start(None).await;
        let mut recv_2 = client_2.subscribe(None).await.unwrap();

        let handle_1 = tokio::spawn(async move {
            let mut val: u8 = 0;
            for _ in 0..255 {
                let message = OutgoingMessage {
                    payload: vec![val],
                    destination: Endpoint::DesktopMain,
                    topic: None,
                };
                client_1.send(message).await.unwrap();
                let recv_message = recv_1.receive(None).await.unwrap();
                // The peer echoes every payload, so the round trip must return exactly
                // what was sent.
                assert_eq!(
                    recv_message.payload,
                    vec![val],
                    "echoed payload should match the payload that was sent"
                );
                val = recv_message.payload[0] + 1;
            }
        });

        let handle_2 = tokio::spawn(async move {
            for _ in 0..255 {
                let recv_message = recv_2.receive(None).await.unwrap();
                let val = recv_message.payload[0];
                if val == 255 {
                    break;
                }

                client_2
                    .send(OutgoingMessage {
                        payload: vec![val],
                        destination: Endpoint::DesktopMain,
                        topic: None,
                    })
                    .await
                    .unwrap();
            }
        });

        // A panic inside a spawned task only surfaces as a `JoinError` on its handle.
        // Discarding the join results would let the test pass even when a task panicked,
        // so fail as soon as either task reports an error.
        tokio::try_join!(handle_1, handle_2)
            .expect("ping-pong tasks should complete without panicking");
    }
}