diff --git a/quiche/src/crypto/boringssl.rs b/quiche/src/crypto/boringssl.rs index 8eaed89d8a..447538fdef 100644 --- a/quiche/src/crypto/boringssl.rs +++ b/quiche/src/crypto/boringssl.rs @@ -1,5 +1,7 @@ use super::*; +use std::convert::TryFrom; + use std::mem::MaybeUninit; use libc::c_int; @@ -195,38 +197,54 @@ impl HeaderProtectionKey { } } - pub fn new_mask(&self, sample: &[u8]) -> Result<[u8; 5]> { - let mut new_mask = [0_u8; 5]; - + pub fn new_mask(&self, sample: &[u8]) -> Result { match self { - Self::Aes(aes_key) => unsafe { - AES_ecb_encrypt( - sample.as_ptr(), - new_mask.as_mut_ptr(), - aes_key as _, - 1, - ); + Self::Aes(aes_key) => { + let mut block = [0_u8; 16]; + + unsafe { + AES_ecb_encrypt( + sample.as_ptr(), + block.as_mut_ptr(), + aes_key as _, + 1, + ) + }; + + // Downsize the encrypted block to the size of the header + // protection mask. + // + // The length of the slice will always match the size of + // `HeaderProtectionMask` so the `unwrap()` is safe. + let new_mask = + HeaderProtectionMask::try_from(&block[..HP_MASK_LEN]) + .unwrap(); + Ok(new_mask) }, - Self::ChaCha(key) => unsafe { - const PLAINTEXT: &[u8; 5] = &[0_u8; 5]; + Self::ChaCha(key) => { + const PLAINTEXT: &[u8; HP_MASK_LEN] = &[0_u8; HP_MASK_LEN]; + + let mut new_mask = HeaderProtectionMask::default(); let counter = u32::from_le_bytes([ sample[0], sample[1], sample[2], sample[3], ]); - CRYPTO_chacha_20( - new_mask.as_mut_ptr(), - PLAINTEXT.as_ptr(), - PLAINTEXT.len(), - key.as_ptr(), - sample[std::mem::size_of::()..].as_ptr(), - counter, - ); + unsafe { + CRYPTO_chacha_20( + new_mask.as_mut_ptr(), + PLAINTEXT.as_ptr(), + PLAINTEXT.len(), + key.as_ptr(), + sample[std::mem::size_of::()..].as_ptr(), + counter, + ); + }; + + Ok(new_mask) }, } - - Ok(new_mask) } } diff --git a/quiche/src/crypto/mod.rs b/quiche/src/crypto/mod.rs index 367215ae60..cbe56714e1 100644 --- a/quiche/src/crypto/mod.rs +++ b/quiche/src/crypto/mod.rs @@ -35,6 +35,9 @@ use crate::packet; // All the AEAD algorithms we support use 96-bit nonces. pub const MAX_NONCE_LEN: usize = 12; +// Length of header protection mask. +pub const HP_MASK_LEN: usize = 5; + #[repr(C)] #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Level { @@ -120,6 +123,8 @@ struct EVP_MD { _unused: c_void, } +type HeaderProtectionMask = [u8; HP_MASK_LEN]; + pub struct Open { alg: Algorithm, diff --git a/quiche/src/crypto/openssl_quictls.rs b/quiche/src/crypto/openssl_quictls.rs index 2888ecbbc6..c4bc5cb9c3 100644 --- a/quiche/src/crypto/openssl_quictls.rs +++ b/quiche/src/crypto/openssl_quictls.rs @@ -316,10 +316,10 @@ impl HeaderProtectionKey { }) } - pub fn new_mask(&self, sample: &[u8]) -> Result<[u8; 5]> { + pub fn new_mask(&self, sample: &[u8]) -> Result { const PLAINTEXT: &[u8; 5] = &[0_u8; 5]; - let mut new_mask = [0_u8; 5]; + let mut new_mask = HeaderProtectionMask::default(); // Set IV (i.e. the sample). let rc = unsafe {