diff --git a/neqo-bin/src/bin/client/http09.rs b/neqo-bin/src/bin/client/http09.rs index 6d9a26fec2..f589dbbe55 100644 --- a/neqo-bin/src/bin/client/http09.rs +++ b/neqo-bin/src/bin/client/http09.rs @@ -143,6 +143,10 @@ impl super::Client for Connection { self.process(dgram, now) } + fn process_input(&mut self, dgram: &Datagram, now: Instant) { + self.process_input(dgram, now); + } + fn close(&mut self, now: Instant, app_error: neqo_transport::AppError, msg: S) where S: AsRef + std::fmt::Display, diff --git a/neqo-bin/src/bin/client/http3.rs b/neqo-bin/src/bin/client/http3.rs index 07cc0e4cde..85fea07578 100644 --- a/neqo-bin/src/bin/client/http3.rs +++ b/neqo-bin/src/bin/client/http3.rs @@ -119,6 +119,10 @@ impl super::Client for Http3Client { self.process(dgram, now) } + fn process_input(&mut self, dgram: &Datagram, now: Instant) { + self.process_input(dgram, now); + } + fn close(&mut self, now: Instant, app_error: AppError, msg: S) where S: AsRef + Display, diff --git a/neqo-bin/src/bin/client/main.rs b/neqo-bin/src/bin/client/main.rs index 7b1a5928a6..4464fe6f02 100644 --- a/neqo-bin/src/bin/client/main.rs +++ b/neqo-bin/src/bin/client/main.rs @@ -322,7 +322,9 @@ trait Handler { /// Network client, e.g. [`neqo_transport::Connection`] or [`neqo_http3::Http3Client`]. trait Client { + // TODO: datagram option needed? fn process(&mut self, dgram: Option<&Datagram>, now: Instant) -> Output; + fn process_input(&mut self, dgram: &Datagram, now: Instant); fn close(&mut self, now: Instant, app_error: AppError, msg: S) where S: AsRef + Display; @@ -367,11 +369,13 @@ impl<'a, H: Handler> Runner<'a, H> { match ready(self.socket, self.timeout.as_mut()).await? { Ready::Socket => loop { let dgrams = self.socket.recv(&self.local_addr)?; - if dgrams.is_empty() { - break; + let mut is_empty = true; + for dgram in dgrams { + is_empty = false; + self.client.process_input(&dgram, Instant::now()); } - for dgram in &dgrams { - self.process(Some(dgram)).await?; + if is_empty { + break; } self.handler.maybe_key_update(&mut self.client)?; }, diff --git a/neqo-bin/src/bin/server/main.rs b/neqo-bin/src/bin/server/main.rs index f694cf98c1..61d9ca69e7 100644 --- a/neqo-bin/src/bin/server/main.rs +++ b/neqo-bin/src/bin/server/main.rs @@ -543,7 +543,8 @@ impl ServersRunner { match self.ready().await? { Ready::Socket(inx) => loop { let (host, socket) = self.sockets.get_mut(inx).unwrap(); - let dgrams = socket.recv(host)?; + // TODO: Remove collect. + let dgrams: Vec<_> = socket.recv(host)?.collect(); if dgrams.is_empty() { break; } diff --git a/neqo-bin/src/udp.rs b/neqo-bin/src/udp.rs index f4ede0b5c2..94b6fed8e0 100644 --- a/neqo-bin/src/udp.rs +++ b/neqo-bin/src/udp.rs @@ -8,15 +8,24 @@ #![allow(clippy::missing_panics_doc)] // Functions simply delegate to tokio and quinn-udp. use std::{ + array, io::{self, IoSliceMut}, + mem::MaybeUninit, net::{SocketAddr, ToSocketAddrs}, slice, }; -use neqo_common::{Datagram, IpTos}; +use neqo_common::{qwarn, Datagram, IpTos}; use quinn_udp::{EcnCodepoint, RecvMeta, Transmit, UdpSocketState}; use tokio::io::Interest; +#[cfg(not(any(target_os = "macos", target_os = "ios")))] +// Chosen somewhat arbitrarily; might benefit from additional tuning. +pub(crate) const BATCH_SIZE: usize = 32; + +#[cfg(any(target_os = "macos", target_os = "ios"))] +pub(crate) const BATCH_SIZE: usize = 1; + /// Socket receive buffer size. /// /// Allows reading multiple datagrams in a single [`Socket::recv`] call. @@ -25,7 +34,7 @@ const RECV_BUF_SIZE: usize = u16::MAX as usize; pub struct Socket { socket: tokio::net::UdpSocket, state: UdpSocketState, - recv_buf: Vec, + recv_bufs: [Vec; BATCH_SIZE], } impl Socket { @@ -36,7 +45,7 @@ impl Socket { Ok(Self { state: quinn_udp::UdpSocketState::new((&socket).into())?, socket: tokio::net::UdpSocket::from_std(socket)?, - recv_buf: vec![0; RECV_BUF_SIZE], + recv_bufs: array::from_fn(|_| vec![0; RECV_BUF_SIZE]), }) } @@ -75,54 +84,58 @@ impl Socket { Ok(()) } - /// Receive a UDP datagram on the specified socket. - pub fn recv(&mut self, local_address: &SocketAddr) -> Result, io::Error> { - let mut meta = RecvMeta::default(); + /// Receive UDP datagrams on the specified socket. + pub fn recv<'a>( + &'a mut self, + local_address: &'a SocketAddr, + ) -> Result + 'a, io::Error> { + let mut metas = [RecvMeta::default(); BATCH_SIZE]; + + // TODO: Safe? Double check. + let mut iovs = MaybeUninit::<[IoSliceMut<'_>; BATCH_SIZE]>::uninit(); + for (i, iov) in self + .recv_bufs + .iter_mut() + .map(|b| IoSliceMut::new(b)) + .enumerate() + { + unsafe { + iovs.as_mut_ptr().cast::().add(i).write(iov); + }; + } + let mut iovs = unsafe { iovs.assume_init() }; - match self.socket.try_io(Interest::READABLE, || { - self.state.recv( - (&self.socket).into(), - &mut [IoSliceMut::new(&mut self.recv_buf)], - slice::from_mut(&mut meta), - ) + let msgs = match self.socket.try_io(Interest::READABLE, || { + self.state + .recv((&self.socket).into(), &mut iovs, &mut metas) }) { - Ok(n) => { - assert_eq!(n, 1, "only passed one slice"); - } - Err(ref err) - if err.kind() == io::ErrorKind::WouldBlock - || err.kind() == io::ErrorKind::Interrupted => - { - return Ok(vec![]) - } - Err(err) => { - return Err(err); - } + Ok(n) => n, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => 0, + Err(e) => return Err(e), }; - if meta.len == 0 { - eprintln!("zero length datagram received?"); - return Ok(vec![]); - } - if meta.len == self.recv_buf.len() { - eprintln!( - "Might have received more than {} bytes", - self.recv_buf.len() - ); - } - - Ok(self.recv_buf[0..meta.len] - .chunks(meta.stride.min(self.recv_buf.len())) - .map(|d| { - Datagram::new( - meta.addr, - *local_address, - meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), - None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 - d, - ) - }) - .collect()) + Ok(metas + .into_iter() + .zip(self.recv_bufs.iter()) + .take(msgs) + .flat_map(move |(meta, buf)| { + // TODO: Needed? + if meta.len == buf.len() { + qwarn!("Might have received more than {} bytes", buf.len()); + } + + buf[0..meta.len] + .chunks(meta.stride.min(buf.len())) + .map(move |d| { + Datagram::new( + meta.addr, + *local_address, + meta.ecn.map(|n| IpTos::from(n as u8)).unwrap_or_default(), + None, // TODO: get the real TTL https://github.com/quinn-rs/quinn/issues/1749 + d, + ) + }) + })) } }