bitwarden_ipc/crypto_provider/noise/
transport_state.rs1use serde::{Deserialize, Serialize};
2use serde_bytes::ByteBuf;
3use snow::resolvers::{CryptoResolver, DefaultResolver};
4use tracing::warn;
5
6use crate::crypto_provider::noise::NOISE_MAX_MESSAGE_LEN;
7
/// Length in bytes of a symmetric transport key (256 bits).
const KEY_SIZE: usize = 256 / 8;
10
11#[derive(Default, Debug, Clone, Serialize, Deserialize)]
13pub(crate) enum TransportCipher {
14 ChaCha20Poly1305 = 0,
15 #[default]
16 Aes256Gcm = 1,
17}
18
19#[derive(Clone, Serialize, Deserialize, zeroize::ZeroizeOnDrop)]
21pub(super) struct SymmetricKey(pub(crate) [u8; KEY_SIZE]);
22
23impl std::fmt::Debug for SymmetricKey {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("SymmetricKey").finish()
27 }
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub(crate) struct PersistentTransportState {
32 transport_cipher: TransportCipher,
34
35 send_key: SymmetricKey,
40 receive_key: SymmetricKey,
41
42 send_nonce: u64,
47 receive_nonce: u64,
49
50 last_handshake_time: u64,
51}
52
53impl PersistentTransportState {
54 pub(crate) fn new(
56 send_key: SymmetricKey,
57 receive_key: SymmetricKey,
58 transport_cipher: TransportCipher,
59 ) -> Self {
60 Self {
61 transport_cipher,
62 send_key,
63 receive_key,
64 send_nonce: 0,
65 receive_nonce: 0,
66 last_handshake_time: current_epoch_secs(),
67 }
68 }
69
70 pub(crate) fn should_rehandshake(&self, rehandshake_interval_secs: u64) -> bool {
71 self.is_older_than(current_epoch_secs(), rehandshake_interval_secs)
72 }
73
74 pub(crate) fn is_older_than(&self, now_epoch_secs: u64, max_age_secs: u64) -> bool {
75 now_epoch_secs.saturating_sub(self.last_handshake_time) > max_age_secs
76 }
77
78 #[cfg(test)]
79 pub(crate) fn set_last_handshake_epoch_secs_for_test(&mut self, epoch_secs: u64) {
80 self.last_handshake_time = epoch_secs;
81 }
82}
83
84#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
86pub(crate) struct Payload(pub(crate) Vec<u8>);
87
88impl From<Vec<u8>> for Payload {
89 fn from(v: Vec<u8>) -> Self {
90 Self(v)
91 }
92}
93
94impl AsRef<[u8]> for Payload {
95 fn as_ref(&self) -> &[u8] {
96 &self.0
97 }
98}
99
100impl PersistentTransportState {
101 pub(crate) fn send(&mut self, payload: Payload) -> Result<TransportFrame, ()> {
104 self.send_nonce = self
110 .send_nonce
111 .checked_add(1)
112 .expect("Nonce should never overflow. It is impossible to send 2^64 messages within the lifetime of a session.");
113
114 let encrypted_message = self.encrypt(&self.send_key, self.send_nonce, &payload);
115
116 Ok(TransportFrame {
117 payload: encrypted_message.into(),
118 nonce: self.send_nonce,
119 })
120 }
121
122 pub(crate) fn receive(
124 &mut self,
125 transport_frame: &TransportFrame,
126 ) -> Result<Payload, ReceiveError> {
127 if transport_frame.nonce > self.receive_nonce {
128 if let Ok(plaintext) = self.try_decrypt(&self.receive_key, transport_frame) {
129 self.receive_nonce = transport_frame.nonce;
130 Ok(Payload(plaintext))
131 } else {
132 warn!("Failed to decrypt incoming IPC message");
133 Err(ReceiveError::Decryption)
134 }
135 } else {
136 warn!("Ipc message was replayed! Discarding...");
137 Err(ReceiveError::NonceReplay)
138 }
139 }
140
141 fn encrypt(&self, key: &SymmetricKey, nonce: u64, payload: &Payload) -> Vec<u8> {
142 let mut buffer = vec![0u8; NOISE_MAX_MESSAGE_LEN];
143 let cipher = get_cipher_with_key(key, &self.transport_cipher);
144 let len = cipher.encrypt(nonce, &[], payload.as_ref(), &mut buffer);
145 buffer.truncate(len);
146 buffer
147 }
148
149 fn try_decrypt(
150 &self,
151 key: &SymmetricKey,
152 transport_message: &TransportFrame,
153 ) -> Result<Vec<u8>, ()> {
154 let mut buffer = vec![0u8; NOISE_MAX_MESSAGE_LEN];
155 let cipher = get_cipher_with_key(key, &self.transport_cipher);
156 let len = cipher
157 .decrypt(
158 transport_message.nonce,
159 &[],
160 &transport_message.payload,
161 &mut buffer,
162 )
163 .map_err(|_| ())?;
164 Ok(buffer[..len].to_vec())
165 }
166}
167
/// Reasons an incoming transport frame can be rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ReceiveError {
    /// The frame's nonce was not greater than the last accepted nonce (replayed or reordered).
    NonceReplay,
    /// The frame failed authenticated decryption (for example, tampered or encrypted under a
    /// different key).
    Decryption,
}

impl std::fmt::Display for ReceiveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ReceiveError::NonceReplay => {
                "IPC frame nonce was not newer than the last accepted nonce"
            }
            ReceiveError::Decryption => "IPC frame failed authenticated decryption",
        })
    }
}

impl std::error::Error for ReceiveError {}
173
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// On `wasm32` the time comes from the JavaScript `Date` API; elsewhere it comes from
/// [`std::time::SystemTime`]. On every target, a clock set before the epoch yields `0`
/// instead of panicking (the float-to-integer cast on wasm already saturates negatives to 0).
pub(crate) fn current_epoch_secs() -> u64 {
    // Both arms use the same predicate so exactly one of them is compiled on every target.
    // Mixing `target_family = "wasm"` with `not(target_arch = "wasm32")` compiled both arms
    // on wasm64, which does not type-check.
    // NOTE(review): keep this predicate aligned with how `js_sys` is gated in Cargo.toml.
    #[cfg(target_arch = "wasm32")]
    {
        js_sys::Date::now() as u64 / 1000
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs())
    }
}
188
189fn get_cipher_with_key(
190 key: &SymmetricKey,
191 cipher: &TransportCipher,
192) -> Box<dyn snow::types::Cipher> {
193 let resolver = DefaultResolver;
194 let snow_cipher = match cipher {
195 TransportCipher::ChaCha20Poly1305 => &snow::params::CipherChoice::ChaChaPoly,
196 TransportCipher::Aes256Gcm => &snow::params::CipherChoice::AESGCM,
197 };
198 let mut cipher = resolver
199 .resolve_cipher(snow_cipher)
200 .expect("Cipher should be supported by the resolver");
201 cipher.set(&key.0);
202 cipher
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize)]
207pub(crate) struct TransportFrame {
208 pub(crate) payload: ByteBuf,
209 pub(crate) nonce: u64,
210}
211
/// Test helper: asserts that two states are opposite ends of one session, i.e. each side's
/// send key equals the other side's receive key.
#[cfg(test)]
pub(crate) fn assert_matching_pair(
    state_1: &PersistentTransportState,
    state_2: &PersistentTransportState,
) {
    let crossed_keys = [
        (&state_1.send_key, &state_2.receive_key),
        (&state_1.receive_key, &state_2.send_key),
    ];
    for (ours, theirs) in crossed_keys {
        assert_eq!(ours.0, theirs.0);
    }
}
220
#[cfg(test)]
mod tests {
    use super::*;

    fn test_keys() -> (SymmetricKey, SymmetricKey) {
        let send_key = SymmetricKey([1u8; KEY_SIZE]);
        let receive_key = SymmetricKey([2u8; KEY_SIZE]);
        (send_key, receive_key)
    }

    /// Builds a connected (sender, receiver) pair. `cipher` is invoked once per side, so each
    /// state receives its own value.
    fn make_pair_with(
        cipher: fn() -> TransportCipher,
    ) -> (PersistentTransportState, PersistentTransportState) {
        let (send_key, receive_key) = test_keys();
        let sender =
            PersistentTransportState::new(send_key.clone(), receive_key.clone(), cipher());
        let receiver = PersistentTransportState::new(receive_key, send_key, cipher());
        (sender, receiver)
    }

    fn make_pair() -> (PersistentTransportState, PersistentTransportState) {
        make_pair_with(TransportCipher::default)
    }

    #[test]
    fn test_make_pair_produces_matching_keys() {
        let (sender, receiver) = make_pair();
        assert_matching_pair(&sender, &receiver);
    }

    #[test]
    fn test_send_and_receive_payload() {
        let (mut sender, mut receiver) = make_pair();

        let payload: Payload = b"ping".to_vec().into();
        let frame = sender.send(payload).expect("send should succeed");
        let received = receiver.receive(&frame).expect("receive should succeed");

        assert_eq!(received.as_ref(), b"ping");
    }

    #[test]
    fn test_send_and_receive_payload_with_each_cipher() {
        let ciphers: [fn() -> TransportCipher; 2] = [
            || TransportCipher::ChaCha20Poly1305,
            || TransportCipher::Aes256Gcm,
        ];
        for cipher in ciphers {
            let (mut sender, mut receiver) = make_pair_with(cipher);
            let frame = sender
                .send(b"ping".to_vec().into())
                .expect("send should succeed");
            let received = receiver.receive(&frame).expect("receive should succeed");
            assert_eq!(received.as_ref(), b"ping");
        }
    }

    #[test]
    fn test_send_and_receive_multiple_messages() {
        let (mut sender, mut receiver) = make_pair();

        for i in 0..5 {
            let payload: Payload = format!("msg-{i}").into_bytes().into();
            let frame = sender.send(payload).expect("send should succeed");
            let received = receiver.receive(&frame).expect("receive should succeed");
            assert_eq!(received.as_ref(), format!("msg-{i}").as_bytes());
        }
    }

    #[test]
    fn test_sent_nonces_start_at_one_and_increase() {
        let (mut sender, _) = make_pair();

        let first = sender
            .send(b"a".to_vec().into())
            .expect("send should succeed");
        let second = sender
            .send(b"b".to_vec().into())
            .expect("send should succeed");

        assert_eq!(first.nonce, 1);
        assert_eq!(second.nonce, 2);
    }

    #[test]
    fn test_nonce_replay_is_rejected() {
        let (mut sender, mut receiver) = make_pair();

        let payload: Payload = b"first".to_vec().into();
        let frame = sender.send(payload).expect("send should succeed");

        let replayed_frame = frame.clone();
        receiver
            .receive(&frame)
            .expect("first receive should succeed");

        let result = receiver.receive(&replayed_frame);
        assert!(
            matches!(result, Err(ReceiveError::NonceReplay)),
            "replayed frame must be rejected as a replay"
        );
    }

    #[test]
    fn test_old_nonce_is_rejected() {
        let (mut sender, mut receiver) = make_pair();

        let msg1: Payload = b"first".to_vec().into();
        let msg2: Payload = b"second".to_vec().into();
        let frame1 = sender.send(msg1).expect("send should succeed");
        let frame2 = sender.send(msg2).expect("send should succeed");

        receiver.receive(&frame2).expect("receive should succeed");

        let result = receiver.receive(&frame1);
        assert!(
            matches!(result, Err(ReceiveError::NonceReplay)),
            "out-of-order lower nonce must be rejected as a replay"
        );
    }

    #[test]
    fn test_decryption_with_tampered_ciphertext_fails() {
        let (mut sender, mut receiver) = make_pair();

        let payload: Payload = b"important".to_vec().into();
        let mut frame = sender.send(payload).expect("send should succeed");

        frame.payload[0] ^= 0xFF;

        let result = receiver.receive(&frame);
        assert!(
            matches!(result, Err(ReceiveError::Decryption)),
            "tampered ciphertext must fail decryption"
        );
    }

    #[test]
    fn test_rejected_frame_does_not_advance_receive_nonce() {
        let (mut sender, mut receiver) = make_pair();

        let frame = sender
            .send(b"genuine".to_vec().into())
            .expect("send should succeed");
        let mut forged = frame.clone();
        forged.payload[0] ^= 0xFF;

        assert!(matches!(
            receiver.receive(&forged),
            Err(ReceiveError::Decryption)
        ));
        let received = receiver
            .receive(&frame)
            .expect("genuine frame with the same nonce must still be accepted");
        assert_eq!(received.as_ref(), b"genuine");
    }

    #[test]
    fn test_frame_encrypted_with_wrong_key_is_rejected() {
        let (mut sender, _) = make_pair();
        let mut stranger = PersistentTransportState::new(
            SymmetricKey([3u8; KEY_SIZE]),
            SymmetricKey([4u8; KEY_SIZE]),
            TransportCipher::default(),
        );

        let frame = sender
            .send(b"secret".to_vec().into())
            .expect("send should succeed");

        assert!(matches!(
            stranger.receive(&frame),
            Err(ReceiveError::Decryption)
        ));
    }

    #[test]
    fn test_cipher_mismatch_is_rejected() {
        let (mut sender, _) = make_pair_with(|| TransportCipher::ChaCha20Poly1305);
        let (_, mut receiver) = make_pair_with(|| TransportCipher::Aes256Gcm);

        let frame = sender
            .send(b"ping".to_vec().into())
            .expect("send should succeed");

        assert!(matches!(
            receiver.receive(&frame),
            Err(ReceiveError::Decryption)
        ));
    }

    #[test]
    fn test_should_rehandshake_is_false_for_fresh_state() {
        let (state, _) = make_pair();
        assert!(!state.should_rehandshake(3600));
    }

    #[test]
    fn test_should_rehandshake_is_true_for_ancient_handshake() {
        let (mut state, _) = make_pair();
        state.set_last_handshake_epoch_secs_for_test(0);
        assert!(state.should_rehandshake(60));
    }

    #[test]
    fn test_is_older_than_returns_false_when_younger_than_threshold() {
        let (mut state, _) = make_pair();
        state.set_last_handshake_epoch_secs_for_test(100);

        let is_expired = state.is_older_than(150, 60);
        assert!(!is_expired, "session newer than threshold must not expire");
    }

    #[test]
    fn test_is_older_than_returns_false_when_equal_to_threshold() {
        let (mut state, _) = make_pair();
        state.set_last_handshake_epoch_secs_for_test(100);

        let is_expired = state.is_older_than(160, 60);
        assert!(!is_expired, "session equal to threshold must not expire");
    }

    #[test]
    fn test_is_older_than_returns_true_when_older_than_threshold() {
        let (mut state, _) = make_pair();
        state.set_last_handshake_epoch_secs_for_test(100);

        let is_expired = state.is_older_than(161, 60);
        assert!(is_expired, "session older than threshold must expire");
    }

    #[test]
    fn test_is_older_than_handles_clock_rollback_with_saturating_subtraction() {
        let (mut state, _) = make_pair();
        state.set_last_handshake_epoch_secs_for_test(200);

        let is_expired = state.is_older_than(100, 60);
        assert!(
            !is_expired,
            "clock rollback should not underflow or force expiry"
        );
    }
}
354}