Skip to content

Commit

Permalink
packet: properly implement packet number encoding/decoding
Browse files Browse the repository at this point in the history
The current naive packet number encoding implementation seems to break after about 2 billions packets sent.
  • Loading branch information
frochet authored Jun 26, 2024
1 parent 23e194f commit 6eb0850
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 23 deletions.
10 changes: 6 additions & 4 deletions quiche/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3613,7 +3613,9 @@ impl Connection {
};

let pn = pkt_space.next_pkt_num;
let pn_len = packet::pkt_num_len(pn)?;
let largest_acked_pkt =
path.recovery.get_largest_acked_on_epoch(epoch).unwrap_or(0);
let pn_len = packet::pkt_num_len(pn, largest_acked_pkt);

// The AEAD overhead at the current encryption level.
let crypto_overhead = pkt_space.crypto_overhead().ok_or(Error::Done)?;
Expand Down Expand Up @@ -3730,7 +3732,7 @@ impl Connection {
b.skip(PAYLOAD_LENGTH_LEN)?;
}

packet::encode_pkt_num(pn, &mut b)?;
packet::encode_pkt_num(pn, pn_len, &mut b)?;

let payload_offset = b.off();

Expand Down Expand Up @@ -12054,7 +12056,7 @@ mod tests {
let epoch = packet::Type::Initial.to_epoch().unwrap();

let pn = 0;
let pn_len = packet::pkt_num_len(pn).unwrap();
let pn_len = packet::pkt_num_len(pn, 0);

let dcid = pipe.client.destination_id();
let scid = pipe.client.source_id();
Expand All @@ -12079,7 +12081,7 @@ mod tests {
let len = pn_len + payload_len;
b.put_varint(len as u64).unwrap();

packet::encode_pkt_num(pn, &mut b).unwrap();
packet::encode_pkt_num(pn, pn_len, &mut b).unwrap();

let payload_offset = b.off();

Expand Down
52 changes: 33 additions & 19 deletions quiche/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,20 +555,12 @@ impl<'a> std::fmt::Debug for Header<'a> {
}
}

pub fn pkt_num_len(pn: u64) -> Result<usize> {
let len = if pn < u64::from(u8::MAX) {
1
} else if pn < u64::from(u16::MAX) {
2
} else if pn < 16_777_215u64 {
3
} else if pn < u64::from(u32::MAX) {
4
} else {
return Err(Error::InvalidPacket);
};

Ok(len)
pub fn pkt_num_len(pn: u64, largest_acked: u64) -> usize {
let num_unacked: u64 = pn.saturating_sub(largest_acked) + 1;
// computes ceil of num_unacked.log2()
let min_bits = u64::BITS - num_unacked.leading_zeros();
// get the num len in bytes
((min_bits + 7) / 8) as usize
}

pub fn decrypt_hdr(
Expand Down Expand Up @@ -713,10 +705,10 @@ pub fn encrypt_pkt(
Ok(payload_offset + ciphertext_len)
}

pub fn encode_pkt_num(pn: u64, b: &mut octets::OctetsMut) -> Result<()> {
let len = pkt_num_len(pn)?;

match len {
pub fn encode_pkt_num(
pn: u64, pn_len: usize, b: &mut octets::OctetsMut,
) -> Result<()> {
match pn_len {
1 => b.put_u8(pn as u8)?,

2 => b.put_u16(pn as u16)?,
Expand Down Expand Up @@ -1173,9 +1165,31 @@ mod tests {
}

#[test]
fn pkt_num_decode() {
fn pkt_num_encode_decode() {
let num_len = pkt_num_len(0, 0);
assert_eq!(num_len, 1);
let pn = decode_pkt_num(0xa82f30ea, 0x9b32, 2);
assert_eq!(pn, 0xa82f9b32);
let mut d = [0; 10];
let mut b = octets::OctetsMut::with_slice(&mut d);
let num_len = pkt_num_len(0xac5c02, 0xabe8b3);
assert_eq!(num_len, 2);
encode_pkt_num(0xac5c02, num_len, &mut b).unwrap();
// reading
let mut b = octets::OctetsMut::with_slice(&mut d);
let hdr_num = u64::from(b.get_u16().unwrap());
let pn = decode_pkt_num(0xac5c01, hdr_num, num_len);
assert_eq!(pn, 0xac5c02);
// sending 0xace8fe while having 0xabe8b3 acked
let num_len = pkt_num_len(0xace9fe, 0xabe8b3);
assert_eq!(num_len, 3);
let mut b = octets::OctetsMut::with_slice(&mut d);
encode_pkt_num(0xace9fe, num_len, &mut b).unwrap();
// reading
let mut b = octets::OctetsMut::with_slice(&mut d);
let hdr_num = u64::from(b.get_u24().unwrap());
let pn = decode_pkt_num(0xace9fa, hdr_num, num_len);
assert_eq!(pn, 0xace9fe);
}

#[test]
Expand Down
6 changes: 6 additions & 0 deletions quiche/src/recovery/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,12 @@ impl Recovery {
self.epochs[epoch].lost_frames.drain(..)
}

pub fn get_largest_acked_on_epoch(
&self, epoch: packet::Epoch,
) -> Option<u64> {
self.epochs[epoch].largest_acked_packet
}

pub fn has_lost_frames(&self, epoch: packet::Epoch) -> bool {
!self.epochs[epoch].lost_frames.is_empty()
}
Expand Down

0 comments on commit 6eb0850

Please sign in to comment.