1use std::{
14 io,
15 pin::Pin,
16 task::{Context, Poll},
17};
18
19use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20
21use crate::{
22 CryptoError, KeySlotIds, KeyStoreContext, SymmetricCryptoKey,
23 stream::{
24 ChunkDecryptionResult, ChunkEncryptionResult, StreamDecryptionError, StreamEncryptionError,
25 StreamingDecryptor, StreamingEncryptor,
26 aes256_cbc_hmac_legacy_stream::{
27 StreamingAes256CbcHmacDecryptor, StreamingAes256CbcHmacEncryptor,
28 },
29 },
30};
31
/// First byte of every streaming-attachment wire payload; selects which stream cipher
/// handles the bytes that follow.
enum HeaderDiscriminator {
    /// AES-256-CBC-HMAC legacy stream format.
    Aes256CbcHmacLegacyStream = 0x02,
}

impl From<HeaderDiscriminator> for u8 {
    /// Returns the on-wire byte for `discriminator`.
    fn from(discriminator: HeaderDiscriminator) -> u8 {
        discriminator as u8
    }
}

/// Error produced when a header byte does not name any known [`HeaderDiscriminator`].
struct UnknownDiscriminator;

impl TryFrom<u8> for HeaderDiscriminator {
    type Error = UnknownDiscriminator;

    /// Parses an on-wire header byte, rejecting any value without a matching variant.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        // Derived from the enum itself so the wire value is declared in exactly one place.
        const AES256_CBC_HMAC_LEGACY_STREAM: u8 =
            HeaderDiscriminator::Aes256CbcHmacLegacyStream as u8;

        match byte {
            AES256_CBC_HMAC_LEGACY_STREAM => Ok(Self::Aes256CbcHmacLegacyStream),
            _ => Err(UnknownDiscriminator),
        }
    }
}
54
/// Size in bytes of the stack scratch buffer used for each read from the inner reader.
const READ_SCRATCH_SIZE: usize = 8 * 1024;

/// Progress of a [`StreamingAttachmentDecryptor`] through the wire payload.
enum StreamDecryptorState {
    /// No wire bytes consumed yet. The first byte is the header discriminator; `key` is the
    /// key the selected stream decryptor will be built from.
    NeedDiscriminator { key: SymmetricCryptoKey },
    /// Discriminator 0x02 was read; all further wire bytes are fed to this decryptor.
    Aes256CbcHmacLegacyStream {
        decryptor: Box<StreamingAes256CbcHmacDecryptor>,
    },
    /// The decryptor returned its final chunk; no further wire bytes are read from `inner`.
    Done,
    /// A failure occurred; once buffered plaintext is handed out, reads return an error.
    Error,
}

/// `AsyncRead` adapter that reads a streaming-attachment wire payload from `inner` and yields
/// the decrypted plaintext.
pub struct StreamingAttachmentDecryptor<R> {
    /// Source of wire bytes.
    inner: R,
    state: StreamDecryptorState,
    /// Decrypted bytes not yet copied to the caller; consumed from the front.
    plaintext_buf: Vec<u8>,
    /// Set once `inner` has returned a zero-byte read.
    inner_eof: bool,
}
83
84impl<R> StreamingAttachmentDecryptor<R> {
85 pub fn new<Ids: KeySlotIds>(
88 key_slot: Ids::Symmetric,
89 ctx: KeyStoreContext<Ids>,
90 inner: R,
91 ) -> Result<Self, CryptoError> {
92 let key = ctx.get_symmetric_key(key_slot)?;
93 match &key {
94 SymmetricCryptoKey::Aes256CbcHmacKey(_) => Ok(Self {
95 inner,
96 state: StreamDecryptorState::NeedDiscriminator {
97 key: key.to_owned(),
98 },
99 plaintext_buf: Vec::new(),
100 inner_eof: false,
101 }),
102 _ => Err(CryptoError::OperationNotSupported(
103 crate::error::UnsupportedOperationError::EncryptionNotImplementedForKey,
104 )),
105 }
106 }
107
108 fn drain_plaintext_into(&mut self, buf: &mut ReadBuf<'_>) -> bool {
111 let bytes_to_copy = std::cmp::min(buf.remaining(), self.plaintext_buf.len());
112 if bytes_to_copy == 0 {
114 return false;
115 }
116
117 buf.put_slice(&self.plaintext_buf[..bytes_to_copy]);
118 self.plaintext_buf.drain(..bytes_to_copy);
119 true
120 }
121
122 fn feed_bytes_to_decryptor(&mut self, mut data: &[u8]) -> io::Result<()> {
123 if data.is_empty() {
124 return Ok(());
125 }
126
127 if let StreamDecryptorState::NeedDiscriminator { key } = &self.state {
128 let discriminator_byte = HeaderDiscriminator::try_from(data[0]).map_err(|_| {
129 io::Error::other("streaming attachment: unknown header discriminator byte")
130 })?;
131 data = &data[1..];
132
133 match discriminator_byte {
134 HeaderDiscriminator::Aes256CbcHmacLegacyStream => {
135 let decryptor =
136 StreamingAes256CbcHmacDecryptor::try_new(key).map_err(|_| {
137 io::Error::other(
138 "streaming attachment: key does not match discriminator 0x02",
139 )
140 })?;
141 self.state = StreamDecryptorState::Aes256CbcHmacLegacyStream {
142 decryptor: Box::new(decryptor),
143 };
144 }
145 }
146 }
147
148 if data.is_empty() {
149 return Ok(());
150 }
151
152 match &mut self.state {
153 StreamDecryptorState::Aes256CbcHmacLegacyStream { decryptor: dec } => {
154 match dec.update(data, false) {
155 ChunkDecryptionResult::NeedMoreData => Ok(()),
156 ChunkDecryptionResult::DecryptedChunk(bytes) => {
157 self.plaintext_buf.extend_from_slice(&bytes);
158 Ok(())
159 }
160 ChunkDecryptionResult::FinalDecryptedChunk(bytes) => {
161 self.plaintext_buf.extend_from_slice(&bytes);
162 self.state = StreamDecryptorState::Done;
163 Ok(())
164 }
165 ChunkDecryptionResult::Error(e) => {
166 self.state = StreamDecryptorState::Error;
167 Err(io::Error::other(e))
168 }
169 }
170 }
171 StreamDecryptorState::Error | StreamDecryptorState::Done => Ok(()),
172 StreamDecryptorState::NeedDiscriminator { .. } => unreachable!("handled above"),
173 }
174 }
175
176 fn finalize_underlying(&mut self) -> io::Result<()> {
177 match std::mem::replace(&mut self.state, StreamDecryptorState::Error) {
178 StreamDecryptorState::NeedDiscriminator { .. } => Err(io::Error::other(
179 "streaming attachment: truncated before discriminator",
180 )),
181 StreamDecryptorState::Aes256CbcHmacLegacyStream { decryptor: mut dec } => {
182 match dec.update(&[], true) {
183 ChunkDecryptionResult::FinalDecryptedChunk(bytes) => {
184 self.plaintext_buf.extend_from_slice(&bytes);
185 self.state = StreamDecryptorState::Done;
186 Ok(())
187 }
188 ChunkDecryptionResult::Error(e) => Err(io::Error::other(e)),
189 ChunkDecryptionResult::NeedMoreData
191 | ChunkDecryptionResult::DecryptedChunk(_) => {
192 Err(io::Error::other(StreamDecryptionError))
193 }
194 }
195 }
196 StreamDecryptorState::Done => {
197 self.state = StreamDecryptorState::Done;
198 Ok(())
199 }
200 StreamDecryptorState::Error => {
201 Err(io::Error::other("streaming attachment: decryption error"))
202 }
203 }
204 }
205}
206
207impl<R: AsyncRead + Unpin> AsyncRead for StreamingAttachmentDecryptor<R> {
208 fn poll_read(
209 self: Pin<&mut Self>,
210 cx: &mut Context<'_>,
211 buf: &mut ReadBuf<'_>,
212 ) -> Poll<io::Result<()>> {
213 let this = self.get_mut();
214
215 loop {
216 if this.drain_plaintext_into(buf) {
218 return Poll::Ready(Ok(()));
219 }
220
221 if matches!(this.state, StreamDecryptorState::Error) {
223 return Poll::Ready(Err(io::Error::other(
224 "streaming attachment: decryption error",
225 )));
226 }
227
228 if matches!(this.state, StreamDecryptorState::Done) {
230 return Poll::Ready(Ok(()));
231 }
232
233 if this.inner_eof {
235 if let Err(e) = this.finalize_underlying() {
236 return Poll::Ready(Err(e));
237 }
238 continue;
239 }
240
241 let mut scratch = [0u8; READ_SCRATCH_SIZE];
243 let mut scratch_buf = ReadBuf::new(&mut scratch);
244 match Pin::new(&mut this.inner).poll_read(cx, &mut scratch_buf) {
245 Poll::Pending => return Poll::Pending,
246 Poll::Ready(Err(e)) => {
247 this.state = StreamDecryptorState::Error;
248 return Poll::Ready(Err(e));
249 }
250 Poll::Ready(Ok(())) => {
251 let filled = scratch_buf.filled();
252 if filled.is_empty() {
253 this.inner_eof = true;
254 } else if let Err(e) = this.feed_bytes_to_decryptor(filled) {
255 return Poll::Ready(Err(e));
256 }
257 }
258 }
259 }
260 }
261}
262
/// Builds a new `io::Error` with the same kind and display text as `e`.
///
/// `io::Error` is not `Clone`. This lets the encryptor keep a copy of a failure in its state
/// while handing the original to the current caller. The copy does not carry `e`'s source
/// error.
fn clone_io_error(e: &io::Error) -> io::Error {
    let kind = e.kind();
    let message = e.to_string();
    io::Error::new(kind, message)
}
268
/// Progress of a [`StreamingAttachmentEncryptor`].
enum StreamEncryptorState {
    /// Accepting plaintext; the AES-256-CBC-HMAC legacy stream encryptor is still live.
    Aes256CbcHmacLegacyStream {
        encryptor: Box<StreamingAes256CbcHmacEncryptor>,
    },
    /// The final ciphertext chunk has been produced; it may still be queued in
    /// `pending_write`. Further writes are rejected.
    Finalized,
    /// `inner` has been shut down successfully.
    Done,
    /// A failure occurred; later calls return a copy of this stored error.
    Error(io::Error),
}

/// `AsyncWrite` adapter that encrypts the plaintext written to it and writes the resulting
/// streaming-attachment wire payload to `inner`.
pub struct StreamingAttachmentEncryptor<W> {
    /// Destination for wire bytes.
    inner: W,
    state: StreamEncryptorState,
    /// Wire bytes produced but not yet accepted by `inner`; initially holds only the header
    /// discriminator byte.
    pending_write: Vec<u8>,
    /// Index of the first byte in `pending_write` that has not been written to `inner` yet.
    pending_head: usize,
}
294
295impl<W> StreamingAttachmentEncryptor<W> {
296 pub fn new<Ids: KeySlotIds>(
299 key_slot: Ids::Symmetric,
300 ctx: KeyStoreContext<Ids>,
301 inner: W,
302 plaintext_size: usize,
305 ) -> Result<Self, CryptoError> {
306 let key = ctx.get_symmetric_key(key_slot)?;
307 let (state, discriminator): (StreamEncryptorState, HeaderDiscriminator) = match &key {
308 SymmetricCryptoKey::Aes256CbcHmacKey(_) => {
309 let encryptor = StreamingAes256CbcHmacEncryptor::try_new(key, plaintext_size)
310 .map_err(|_| {
311 CryptoError::OperationNotSupported(
312 crate::error::UnsupportedOperationError::EncryptionNotImplementedForKey,
313 )
314 })?;
315 (
316 StreamEncryptorState::Aes256CbcHmacLegacyStream {
317 encryptor: Box::new(encryptor),
318 },
319 HeaderDiscriminator::Aes256CbcHmacLegacyStream,
320 )
321 }
322 _ => {
323 return Err(CryptoError::OperationNotSupported(
324 crate::error::UnsupportedOperationError::EncryptionNotImplementedForKey,
325 ));
326 }
327 };
328
329 Ok(Self {
330 inner,
331 state,
332 pending_write: vec![discriminator.into()],
333 pending_head: 0,
334 })
335 }
336}
337
338impl<W: AsyncWrite + Unpin> StreamingAttachmentEncryptor<W> {
339 fn poll_drain_pending(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
341 while self.pending_head < self.pending_write.len() {
342 let to_write = &self.pending_write[self.pending_head..];
343 match Pin::new(&mut self.inner).poll_write(cx, to_write) {
344 Poll::Pending => return Poll::Pending,
345 Poll::Ready(Err(e)) => {
346 self.state = StreamEncryptorState::Error(clone_io_error(&e));
347 return Poll::Ready(Err(e));
348 }
349 Poll::Ready(Ok(0)) => {
350 let e = io::Error::new(
351 io::ErrorKind::WriteZero,
352 "streaming attachment: inner writer accepted 0 bytes",
353 );
354 self.state = StreamEncryptorState::Error(clone_io_error(&e));
355 return Poll::Ready(Err(e));
356 }
357 Poll::Ready(Ok(n)) => {
358 self.pending_head += n;
359 }
360 }
361 }
362 self.pending_write.clear();
363 self.pending_head = 0;
364 Poll::Ready(Ok(()))
365 }
366}
367
368impl<W: AsyncWrite + Unpin> AsyncWrite for StreamingAttachmentEncryptor<W> {
369 fn poll_write(
370 self: Pin<&mut Self>,
371 cx: &mut Context<'_>,
372 buf: &[u8],
373 ) -> Poll<io::Result<usize>> {
374 let this = self.get_mut();
375
376 if let StreamEncryptorState::Error(e) = &this.state {
377 return Poll::Ready(Err(clone_io_error(e)));
378 }
379
380 if matches!(
381 this.state,
382 StreamEncryptorState::Finalized | StreamEncryptorState::Done
383 ) {
384 return Poll::Ready(Err(io::Error::other(
385 "streaming attachment: write after shutdown",
386 )));
387 }
388
389 if this.poll_drain_pending(cx).is_pending() {
391 return Poll::Pending;
392 }
393 if let StreamEncryptorState::Error(e) = &this.state {
394 return Poll::Ready(Err(clone_io_error(e)));
395 }
396
397 if buf.is_empty() {
398 return Poll::Ready(Ok(0));
399 }
400
401 let result = match &mut this.state {
402 StreamEncryptorState::Aes256CbcHmacLegacyStream { encryptor: enc } => {
403 enc.update(buf, false)
404 }
405 _ => unreachable!("state checked above"),
406 };
407
408 match result {
409 ChunkEncryptionResult::NeedMoreData => Poll::Ready(Ok(buf.len())),
410 ChunkEncryptionResult::EncryptedChunk(bytes) => {
411 this.pending_write = bytes;
412 this.pending_head = 0;
413 Poll::Ready(Ok(buf.len()))
414 }
415 ChunkEncryptionResult::FinalEncryptedChunk(bytes) => {
416 this.pending_write = bytes;
417 this.pending_head = 0;
418 this.state = StreamEncryptorState::Finalized;
419 Poll::Ready(Ok(buf.len()))
420 }
421 ChunkEncryptionResult::Error(e) => {
422 let err = io::Error::other(e);
423 this.state = StreamEncryptorState::Error(clone_io_error(&err));
424 Poll::Ready(Err(err))
425 }
426 }
427 }
428
429 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
430 let this = self.get_mut();
431
432 if this.poll_drain_pending(cx).is_pending() {
435 return Poll::Pending;
436 }
437
438 if let StreamEncryptorState::Error(e) = &this.state {
439 return Poll::Ready(Err(clone_io_error(e)));
440 }
441
442 Pin::new(&mut this.inner).poll_flush(cx)
443 }
444
445 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
446 let this = self.get_mut();
447
448 if this.poll_drain_pending(cx).is_pending() {
450 return Poll::Pending;
451 }
452 if let StreamEncryptorState::Error(e) = &this.state {
453 return Poll::Ready(Err(clone_io_error(e)));
454 }
455
456 if matches!(
458 this.state,
459 StreamEncryptorState::Aes256CbcHmacLegacyStream { .. }
460 ) {
461 let old = std::mem::replace(
462 &mut this.state,
463 StreamEncryptorState::Error(io::Error::other(
464 "streaming attachment: encryptor finalizing",
465 )),
466 );
467 let StreamEncryptorState::Aes256CbcHmacLegacyStream { encryptor: mut enc } = old else {
468 unreachable!("matched above");
469 };
470
471 let mut wire = Vec::new();
472 loop {
473 match enc.update(&[], true) {
474 ChunkEncryptionResult::EncryptedChunk(bytes) => wire.extend_from_slice(&bytes),
475 ChunkEncryptionResult::FinalEncryptedChunk(bytes) => {
476 wire.extend_from_slice(&bytes);
477 break;
478 }
479 ChunkEncryptionResult::Error(e) => {
480 let err = io::Error::other(e);
481 this.state = StreamEncryptorState::Error(clone_io_error(&err));
482 return Poll::Ready(Err(err));
483 }
484 ChunkEncryptionResult::NeedMoreData => {
485 let err = io::Error::other(StreamEncryptionError);
486 this.state = StreamEncryptorState::Error(clone_io_error(&err));
487 return Poll::Ready(Err(err));
488 }
489 }
490 }
491
492 this.pending_write = wire;
493 this.pending_head = 0;
494 this.state = StreamEncryptorState::Finalized;
495 }
496
497 if this.poll_drain_pending(cx).is_pending() {
499 return Poll::Pending;
500 }
501 if let StreamEncryptorState::Error(e) = &this.state {
502 return Poll::Ready(Err(clone_io_error(e)));
503 }
504
505 match Pin::new(&mut this.inner).poll_shutdown(cx) {
507 Poll::Pending => Poll::Pending,
508 Poll::Ready(Ok(())) => {
509 this.state = StreamEncryptorState::Done;
510 Poll::Ready(Ok(()))
511 }
512 Poll::Ready(Err(e)) => {
513 this.state = StreamEncryptorState::Error(clone_io_error(&e));
514 Poll::Ready(Err(e))
515 }
516 }
517 }
518}
519
#[cfg(test)]
mod tests {
    use std::{
        pin::Pin,
        sync::{Arc, Mutex},
    };

    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;
    use crate::{Aes256CbcHmacKey, KeyStore, traits::tests::TestIds};

    /// In-memory `AsyncWrite` sink. Everything written is appended to the shared vector, which
    /// stays readable through another clone of the `Arc` while the encryptor owns the sink.
    #[derive(Clone)]
    struct SharedSink(Arc<Mutex<Vec<u8>>>);

    impl AsyncWrite for SharedSink {
        // Always accepts the whole buffer immediately.
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.0
                .lock()
                .expect("mutex poisoned")
                .extend_from_slice(buf);
            Poll::Ready(Ok(buf.len()))
        }
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Fixed, deterministic AES-256-CBC-HMAC test key (not secret).
    fn aes_key() -> SymmetricCryptoKey {
        SymmetricCryptoKey::Aes256CbcHmacKey(Aes256CbcHmacKey {
            enc_key: Box::pin([0u8; 32].into()),
            mac_key: Box::pin([1u8; 32].into()),
        })
    }

    /// Encrypts `plaintext` with `key` through `StreamingAttachmentEncryptor` and returns the
    /// complete wire payload captured by a `SharedSink`.
    async fn encrypt_via_shared(key: SymmetricCryptoKey, plaintext: &[u8]) -> Vec<u8> {
        // Keep a handle to the sink's buffer so the wire bytes can be read back afterwards.
        let shared = Arc::new(Mutex::new(Vec::<u8>::new()));
        let sink = SharedSink(shared.clone());
        let mut enc = {
            let key_store: KeyStore<TestIds> = KeyStore::default();
            let mut ctx = key_store.context_mut();
            let key_slot = ctx.add_local_symmetric_key(key);
            StreamingAttachmentEncryptor::new(key_slot, ctx, sink, plaintext.len())
                .expect("encryptor construction")
        };
        enc.write_all(plaintext).await.expect("write_all");
        // Shutdown finalizes the encryptor and writes the remaining ciphertext.
        enc.shutdown().await.expect("shutdown");
        shared.lock().expect("mutex poisoned").clone()
    }

    /// Decrypts a full wire payload with `key` through `StreamingAttachmentDecryptor`.
    async fn decrypt_wire(key: SymmetricCryptoKey, wire: &[u8]) -> io::Result<Vec<u8>> {
        let mut dec = {
            let key_store: KeyStore<TestIds> = KeyStore::default();
            let mut ctx = key_store.context_mut();
            let key_slot = ctx.add_local_symmetric_key(key);
            StreamingAttachmentDecryptor::new(key_slot, ctx, wire).expect("decryptor construction")
        };
        let mut out = Vec::new();
        dec.read_to_end(&mut out).await?;
        Ok(out)
    }

    const PLAINTEXT_SHORT: &[u8] =
        b"streaming attachment cipher: AsyncRead/AsyncWrite roundtrip test plaintext.";

    #[tokio::test]
    async fn aes_cbc_hmac_roundtrip() {
        let wire = encrypt_via_shared(aes_key(), PLAINTEXT_SHORT).await;
        assert_eq!(
            wire.first().copied(),
            Some(HeaderDiscriminator::Aes256CbcHmacLegacyStream.into()),
            "wire should start with the AES-CBC-HMAC discriminator"
        );
        let roundtripped = decrypt_wire(aes_key(), &wire).await.expect("decrypt");
        assert_eq!(roundtripped, PLAINTEXT_SHORT);
    }

    #[tokio::test]
    async fn aes_cbc_hmac_roundtrip_1_mib() {
        let plaintext: Vec<u8> = (0..(1024 * 1024)).map(|i| (i % 251) as u8).collect();
        // 1 MiB of a repeating 0..=250 pattern, far larger than one READ_SCRATCH_SIZE read.
        let wire = encrypt_via_shared(aes_key(), &plaintext).await;
        let roundtripped = decrypt_wire(aes_key(), &wire).await.expect("decrypt");
        assert_eq!(roundtripped, plaintext);
    }

    #[tokio::test]
    async fn unknown_discriminator_byte_fails() {
        let mut wire = vec![0xFFu8];
        // 0xFF is not a known discriminator; the trailing zero bytes are filler.
        wire.extend_from_slice(&[0u8; 32]);
        let err = decrypt_wire(aes_key(), &wire)
            .await
            .expect_err("expected error for unknown discriminator");
        assert_eq!(err.kind(), io::ErrorKind::Other);
    }

    #[tokio::test]
    async fn truncated_wire_fails_aes() {
        let wire = encrypt_via_shared(aes_key(), PLAINTEXT_SHORT).await;
        // Dropping the last 10 bytes must be detected rather than yielding partial plaintext.
        let truncated = &wire[..wire.len() - 10];
        let err = decrypt_wire(aes_key(), truncated)
            .await
            .expect_err("expected error for truncated wire");
        assert_eq!(err.kind(), io::ErrorKind::Other);
    }

    #[tokio::test]
    async fn small_chunked_writes_roundtrip() {
        let plaintext = PLAINTEXT_SHORT;
        // One-byte writes and 7-byte reads exercise partial buffering on both sides.

        let shared = Arc::new(Mutex::new(Vec::<u8>::new()));
        let sink = SharedSink(shared.clone());
        let mut enc = {
            let key_store: KeyStore<TestIds> = KeyStore::default();
            let mut ctx = key_store.context_mut();
            let key_slot = ctx.add_local_symmetric_key(aes_key());
            StreamingAttachmentEncryptor::new(key_slot, ctx, sink, plaintext.len())
                .expect("encryptor construction")
        };
        for byte in plaintext {
            enc.write_all(std::slice::from_ref(byte))
                .await
                .expect("byte-wise write");
        }
        enc.shutdown().await.expect("shutdown");
        let wire = shared.lock().expect("mutex poisoned").clone();

        let mut dec = {
            // A fresh key store holding the same key material as the encryptor used.
            let key_store: KeyStore<TestIds> = KeyStore::default();
            let mut ctx = key_store.context_mut();
            let key_slot = ctx.add_local_symmetric_key(aes_key());
            StreamingAttachmentDecryptor::new(key_slot, ctx, &wire[..])
                .expect("decryptor construction")
        };
        let mut out = Vec::new();
        let mut tmp = [0u8; 7];
        loop {
            let n = dec.read(&mut tmp).await.expect("read");
            if n == 0 {
                break;
            }
            out.extend_from_slice(&tmp[..n]);
        }
        assert_eq!(out, plaintext);
    }

    #[tokio::test]
    async fn empty_plaintext_roundtrip_aes() {
        let wire = encrypt_via_shared(aes_key(), &[]).await;
        let roundtripped = decrypt_wire(aes_key(), &wire).await.expect("decrypt");
        assert!(roundtripped.is_empty());
    }
}