Skip to main content

bitwarden_ipc/crypto_provider/noise/
transport_state.rs

1use serde::{Deserialize, Serialize};
2use serde_bytes::ByteBuf;
3use snow::resolvers::{CryptoResolver, DefaultResolver};
4use tracing::warn;
5
6use crate::crypto_provider::noise::NOISE_MAX_MESSAGE_LEN;
7
// Ref: http://noiseprotocol.org/noise.html#message-format
/// Length in bytes of every noise symmetric key (256 bits).
const KEY_SIZE: usize = 256 / 8;
10
11/// Supported ciphers for the transport mode of noise.
12#[derive(Default, Debug, Clone, Serialize, Deserialize)]
13pub(crate) enum TransportCipher {
14    ChaCha20Poly1305 = 0,
15    #[default]
16    Aes256Gcm = 1,
17}
18
19/// A newtype for symmetric keys used in noise. A noise key is always 256-bits.
20#[derive(Clone, Serialize, Deserialize, zeroize::ZeroizeOnDrop)]
21pub(super) struct SymmetricKey(pub(crate) [u8; KEY_SIZE]);
22
23/// Implement Debug manually to avoid accidentally logging the key material.
24impl std::fmt::Debug for SymmetricKey {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("SymmetricKey").finish()
27    }
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub(crate) struct PersistentTransportState {
32    // The symmetric algorithm used for transport encryption
33    transport_cipher: TransportCipher,
34
35    // Noise has two keys, the initiator to responder key (i2r) and the responder to initiator key
36    // (r2i).
37    // For the initiator, send_key = i2r and receive_key = r2i.
38    // For the responder, send_key = r2i and receive_key = i2r.
39    send_key: SymmetricKey,
40    receive_key: SymmetricKey,
41
42    // Noise transport messages include a nonce that must be unique for every message encrypted
43    // with the same key. The nonce increases monotonically with every sent/received message
44    // and is never reset to a lower value. Re-using nonces results in catastrophic
45    // cryptographic failure.
46    send_nonce: u64,
47    // For receiving, skipping nonces is allowed, but never going back.
48    receive_nonce: u64,
49
50    last_handshake_time: u64,
51}
52
53impl PersistentTransportState {
54    /// Create a new transport state with the given keys and cipher.
55    pub(crate) fn new(
56        send_key: SymmetricKey,
57        receive_key: SymmetricKey,
58        transport_cipher: TransportCipher,
59    ) -> Self {
60        Self {
61            transport_cipher,
62            send_key,
63            receive_key,
64            send_nonce: 0,
65            receive_nonce: 0,
66            last_handshake_time: current_epoch_secs(),
67        }
68    }
69
70    pub(crate) fn should_rehandshake(&self, rehandshake_interval_secs: u64) -> bool {
71        self.is_older_than(current_epoch_secs(), rehandshake_interval_secs)
72    }
73
74    pub(crate) fn is_older_than(&self, now_epoch_secs: u64, max_age_secs: u64) -> bool {
75        now_epoch_secs.saturating_sub(self.last_handshake_time) > max_age_secs
76    }
77
78    #[cfg(test)]
79    pub(crate) fn set_last_handshake_epoch_secs_for_test(&mut self, epoch_secs: u64) {
80        self.last_handshake_time = epoch_secs;
81    }
82}
83
84/// A newtype for a plaintext payload carried inside the encrypted noise tunnel.
85#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
86pub(crate) struct Payload(pub(crate) Vec<u8>);
87
88impl From<Vec<u8>> for Payload {
89    fn from(v: Vec<u8>) -> Self {
90        Self(v)
91    }
92}
93
94impl AsRef<[u8]> for Payload {
95    fn as_ref(&self) -> &[u8] {
96        &self.0
97    }
98}
99
100impl PersistentTransportState {
101    /// Encrypts the message, mutates the state and returns the transport frame to be sent over
102    /// IPC.
103    pub(crate) fn send(&mut self, payload: Payload) -> Result<TransportFrame, ()> {
104        // Increase nonce. WARNING: Re-used nonces lead to catastrophic
105        // crypto failure. Ensure this increases always. It is impossible to send 2^64 messages
106        // within the lifetime of a session. Nonetheless, the cryptographic guarantees are
107        // not upheld, should a nonce ever be re-used, thus we panic in the event that an
108        // overflow would occur.
109        self.send_nonce = self
110            .send_nonce
111            .checked_add(1)
112            .expect("Nonce should never overflow. It is impossible to send 2^64 messages within the lifetime of a session.");
113
114        let encrypted_message = self.encrypt(&self.send_key, self.send_nonce, &payload);
115
116        Ok(TransportFrame {
117            payload: encrypted_message.into(),
118            nonce: self.send_nonce,
119        })
120    }
121
122    /// Decrypts the transport frame, mutates the state and returns the plaintext message.
123    pub(crate) fn receive(
124        &mut self,
125        transport_frame: &TransportFrame,
126    ) -> Result<Payload, ReceiveError> {
127        if transport_frame.nonce > self.receive_nonce {
128            if let Ok(plaintext) = self.try_decrypt(&self.receive_key, transport_frame) {
129                self.receive_nonce = transport_frame.nonce;
130                Ok(Payload(plaintext))
131            } else {
132                warn!("Failed to decrypt incoming IPC message");
133                Err(ReceiveError::Decryption)
134            }
135        } else {
136            warn!("Ipc message was replayed! Discarding...");
137            Err(ReceiveError::NonceReplay)
138        }
139    }
140
141    fn encrypt(&self, key: &SymmetricKey, nonce: u64, payload: &Payload) -> Vec<u8> {
142        let mut buffer = vec![0u8; NOISE_MAX_MESSAGE_LEN];
143        let cipher = get_cipher_with_key(key, &self.transport_cipher);
144        let len = cipher.encrypt(nonce, &[], payload.as_ref(), &mut buffer);
145        buffer.truncate(len);
146        buffer
147    }
148
149    fn try_decrypt(
150        &self,
151        key: &SymmetricKey,
152        transport_message: &TransportFrame,
153    ) -> Result<Vec<u8>, ()> {
154        let mut buffer = vec![0u8; NOISE_MAX_MESSAGE_LEN];
155        let cipher = get_cipher_with_key(key, &self.transport_cipher);
156        let len = cipher
157            .decrypt(
158                transport_message.nonce,
159                &[],
160                &transport_message.payload,
161                &mut buffer,
162            )
163            .map_err(|_| ())?;
164        Ok(buffer[..len].to_vec())
165    }
166}
167
/// Reasons an incoming transport frame is rejected by `receive`.
#[derive(Debug, Clone)]
pub(crate) enum ReceiveError {
    /// The frame's nonce is not greater than the last accepted one (replayed or stale frame).
    NonceReplay,
    /// The frame could not be decrypted.
    Decryption,
}
173
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// On `wasm32` the time comes from JavaScript's `Date.now()`. On every other target it comes
/// from [`std::time::SystemTime`].
///
/// # Panics
///
/// On non-`wasm32` targets, panics if the system clock reports a time before the Unix epoch.
pub(crate) fn current_epoch_secs() -> u64 {
    // The two branches use exactly complementary `cfg` predicates, so precisely one body is
    // compiled on every target. Mixing `target_family = "wasm"` with
    // `not(target_arch = "wasm32")` would compile both bodies on non-wasm32 wasm targets
    // (e.g. wasm64) and fail to type-check.
    #[cfg(target_arch = "wasm32")]
    {
        // `Date::now()` is in milliseconds.
        js_sys::Date::now() as u64 / 1000
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("System clock is before Unix epoch")
            .as_secs()
    }
}
188
189fn get_cipher_with_key(
190    key: &SymmetricKey,
191    cipher: &TransportCipher,
192) -> Box<dyn snow::types::Cipher> {
193    let resolver = DefaultResolver;
194    let snow_cipher = match cipher {
195        TransportCipher::ChaCha20Poly1305 => &snow::params::CipherChoice::ChaChaPoly,
196        TransportCipher::Aes256Gcm => &snow::params::CipherChoice::AESGCM,
197    };
198    let mut cipher = resolver
199        .resolve_cipher(snow_cipher)
200        .expect("Cipher should be supported by the resolver");
201    cipher.set(&key.0);
202    cipher
203}
204
205/// Wire format — always encrypted with current symmetric keys.
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub(crate) struct TransportFrame {
208    pub(crate) payload: ByteBuf,
209    pub(crate) nonce: u64,
210}
211
/// Test helper: asserts that two transport states are the two ends of one session, i.e. each
/// side's send key equals the other side's receive key.
#[cfg(test)]
pub(crate) fn assert_matching_pair(
    state_1: &PersistentTransportState,
    state_2: &PersistentTransportState,
) {
    let directions = [
        (&state_1.send_key, &state_2.receive_key),
        (&state_1.receive_key, &state_2.send_key),
    ];
    for (own, peer) in directions {
        assert_eq!(own.0, peer.0);
    }
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Two fixed, distinct keys; the first is the sender's outbound key.
    fn test_keys() -> (SymmetricKey, SymmetricKey) {
        (SymmetricKey([1u8; KEY_SIZE]), SymmetricKey([2u8; KEY_SIZE]))
    }

    /// Builds a (sender, receiver) pair with mirrored keys and the default cipher.
    fn make_pair() -> (PersistentTransportState, PersistentTransportState) {
        let (outbound, inbound) = test_keys();
        let sender = PersistentTransportState::new(
            outbound.clone(),
            inbound.clone(),
            TransportCipher::default(),
        );
        let receiver = PersistentTransportState::new(inbound, outbound, TransportCipher::default());
        (sender, receiver)
    }

    /// Returns the sender half of a fresh pair, with its handshake time forced to `epoch_secs`.
    fn state_with_handshake_at(epoch_secs: u64) -> PersistentTransportState {
        let (mut state, _) = make_pair();
        state.set_last_handshake_epoch_secs_for_test(epoch_secs);
        state
    }

    #[test]
    fn test_send_and_receive_payload() {
        let (mut sender, mut receiver) = make_pair();

        let frame = sender
            .send(Payload::from(b"ping".to_vec()))
            .expect("send should succeed");
        let received = receiver.receive(&frame).expect("receive should succeed");

        assert_eq!(received.as_ref(), b"ping");
    }

    #[test]
    fn test_send_and_receive_multiple_messages() {
        let (mut sender, mut receiver) = make_pair();

        for index in 0..5 {
            let expected = format!("msg-{index}");
            let frame = sender
                .send(Payload(expected.clone().into_bytes()))
                .expect("send should succeed");
            let received = receiver.receive(&frame).expect("receive should succeed");
            assert_eq!(received.as_ref(), expected.as_bytes());
        }
    }

    #[test]
    fn test_nonce_replay_is_rejected() {
        let (mut sender, mut receiver) = make_pair();

        let frame = sender
            .send(b"first".to_vec().into())
            .expect("send should succeed");

        // The first delivery is accepted...
        receiver
            .receive(&frame)
            .expect("first receive should succeed");

        // ...but delivering the very same frame (same nonce) again must fail.
        assert!(
            receiver.receive(&frame).is_err(),
            "replayed frame must be rejected"
        );
    }

    #[test]
    fn test_old_nonce_is_rejected() {
        let (mut sender, mut receiver) = make_pair();

        let first = sender
            .send(b"first".to_vec().into())
            .expect("send should succeed");
        let second = sender
            .send(b"second".to_vec().into())
            .expect("send should succeed");

        // Accept the later frame first; the earlier (lower) nonce is then stale.
        receiver.receive(&second).expect("receive should succeed");

        let outcome = receiver.receive(&first);
        assert!(outcome.is_err(), "out-of-order lower nonce must be rejected");
    }

    #[test]
    fn test_decryption_with_tampered_ciphertext_fails() {
        let (mut sender, mut receiver) = make_pair();

        let mut frame = sender
            .send(b"important".to_vec().into())
            .expect("send should succeed");

        // Flip every bit of the first ciphertext byte.
        frame.payload[0] ^= 0xFF;

        assert!(
            receiver.receive(&frame).is_err(),
            "tampered ciphertext must fail decryption"
        );
    }

    #[test]
    fn test_is_older_than_returns_false_when_younger_than_threshold() {
        let state = state_with_handshake_at(100);
        assert!(
            !state.is_older_than(150, 60),
            "session newer than threshold must not expire"
        );
    }

    #[test]
    fn test_is_older_than_returns_false_when_equal_to_threshold() {
        let state = state_with_handshake_at(100);
        assert!(
            !state.is_older_than(160, 60),
            "session equal to threshold must not expire"
        );
    }

    #[test]
    fn test_is_older_than_returns_true_when_older_than_threshold() {
        let state = state_with_handshake_at(100);
        assert!(
            state.is_older_than(161, 60),
            "session older than threshold must expire"
        );
    }

    #[test]
    fn test_is_older_than_handles_clock_rollback_with_saturating_subtraction() {
        let state = state_with_handshake_at(200);
        assert!(
            !state.is_older_than(100, 60),
            "clock rollback should not underflow or force expiry"
        );
    }
}