diff --git a/src/socket.rs b/src/socket.rs index dad22f4..f0d2016 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -1,9 +1,15 @@ +use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; use std::io; use std::net::SocketAddr; use std::sync::{Arc, Mutex, RwLock}; +use std::time::Duration; +use delay_map::HashMapDelay; +use futures::StreamExt; +use std::hash::{Hash, Hasher}; use tokio::net::UdpSocket; +use tokio::sync::mpsc::UnboundedSender; use tokio::sync::{mpsc, oneshot}; use crate::cid::{ConnectionId, ConnectionIdGenerator, ConnectionPeer, StdConnectionIdGenerator}; @@ -13,7 +19,7 @@ use crate::packet::{Packet, PacketType}; use crate::stream::UtpStream; use crate::udp::AsyncUdpSocket; -type ConnChannel = mpsc::UnboundedSender; +type ConnChannel = UnboundedSender; struct Accept

{ stream: oneshot::Sender>>, @@ -21,12 +27,20 @@ struct Accept

{ } const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize; +/// accept_with_cid() has unique interactions compared to accept() +/// accept() pulls awaiting requests off a queue, but accept_with_cid() only +/// takes a connection off if CID matches. Because of this if we are awaiting a CID +/// eventually we need to timeout the await, or the queue would never stop growing with stale awaits +/// 20 seconds is arbatrary, after the uTP cofig refactor is done that can replace this constant. +/// but thee uTP config refactor is currently very low priority. +const AWAITING_CONNECTION_TIMEOUT: Duration = Duration::from_secs(20); pub struct UtpSocket

{ conns: Arc, ConnChannel>>>, cid_gen: Mutex>, - accepts: mpsc::UnboundedSender<(Accept

, Option>)>, - socket_events: mpsc::UnboundedSender>, + accepts: UnboundedSender>, + accepts_with_cid: UnboundedSender<(Accept

, ConnectionId

)>, + socket_events: UnboundedSender>, } impl UtpSocket { @@ -50,18 +64,21 @@ where let cid_gen = Mutex::new(StdConnectionIdGenerator::new()); - let awaiting: HashMap, Accept

> = HashMap::new(); - let awaiting = Arc::new(RwLock::new(awaiting)); + // if an accept_with_cid awaiting connection isn't connected in AWAITING_CONNECTION_TIMEOUT seconds, cancel and log it + let mut awaiting: HashMapDelay)> = + HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT); let mut incoming_conns = HashMap::new(); let (socket_event_tx, mut socket_event_rx) = mpsc::unbounded_channel(); let (accepts_tx, mut accepts_rx) = mpsc::unbounded_channel(); + let (accepts_with_cid_tx, mut accepts_with_cid_rx) = mpsc::unbounded_channel(); let utp = Self { conns: Arc::clone(&conns), cid_gen, accepts: accepts_tx, + accepts_with_cid: accepts_with_cid_tx, socket_events: socket_event_tx.clone(), }; @@ -94,12 +111,11 @@ where None => { if std::matches!(packet.packet_type(), PacketType::Syn) { let cid = cid_from_packet(&packet, &src, IdType::RecvId); - let mut awaiting = awaiting.write().unwrap(); // If there was an awaiting connection with the CID, then // create a new stream for that connection. Otherwise, add the // connection to the incoming connections. - if let Some(accept) = awaiting.remove(&cid) { + if let Some((_, accept)) = awaiting.remove(&calculate_hash(&cid)) { let (connected_tx, connected_rx) = oneshot::channel(); let (events_tx, events_rx) = mpsc::unbounded_channel(); @@ -135,51 +151,19 @@ where }, } } - Some((accept, cid)) = accepts_rx.recv(), if !incoming_conns.is_empty() => { - let (cid, syn) = match cid { - // If a CID was given, then check for an incoming connection with that - // CID. If one is found, then use that connection. Otherwise, add the - // CID to the awaiting connections. - Some(cid) => { - if let Some(syn) = incoming_conns.remove(&cid) { - (cid, syn) - } else { - awaiting.write().unwrap().insert(cid, accept); - continue; - } - } - // If a CID was not given, then pull an incoming connection, and use - // that connection's CID. An incoming connection is known to exist - // because of the condition in the `select` arm. - None => { - let cid = incoming_conns.keys().next().unwrap().clone(); - let syn = incoming_conns.remove(&cid).unwrap(); - (cid, syn) - } + Some((accept, cid)) = accepts_with_cid_rx.recv() => { + let (cid, syn) = if let Some(syn) = incoming_conns.remove(&cid) { + (cid, syn) + } else { + awaiting.insert(calculate_hash(&cid), ((cid.send, cid.recv), accept)); + continue; }; - - let (connected_tx, connected_rx) = oneshot::channel(); - let (events_tx, events_rx) = mpsc::unbounded_channel(); - - { - conns - .write() - .unwrap() - .insert(cid.clone(), events_tx); - } - - let stream = UtpStream::new( - cid, - accept.config, - Some(syn), - socket_event_tx.clone(), - events_rx, - connected_tx, - ); - - tokio::spawn(async move { - Self::await_connected(stream, accept, connected_rx).await - }); + Self::select_accept_helper(cid, syn, conns.clone(), accept, socket_event_tx.clone()); + } + Some(accept) = accepts_rx.recv(), if !incoming_conns.is_empty() => { + let cid = incoming_conns.keys().next().unwrap().clone(); + let syn = incoming_conns.remove(&cid).unwrap(); + Self::select_accept_helper(cid, syn, conns.clone(), accept, socket_event_tx.clone()); } Some(event) = socket_event_rx.recv() => { match event { @@ -202,6 +186,14 @@ where } } } + Some(Ok((_, ((send, recv), accept)))) = awaiting.next() => { + // accept_with_cid didn't recieve an inbound connection within the timeout period + // log it and return a timeout error + tracing::debug!(%send, %recv, "accept_with_cid timed out"); + let _ = accept + .stream + .send(Err(io::Error::from(io::ErrorKind::TimedOut))); + } } } }); @@ -218,6 +210,8 @@ where self.conns.read().unwrap().len() } + /// WARNING: only accept() or accept_with_cid() can be used in an application. + /// they aren't compatible to use interchangeably in a program pub async fn accept(&self, config: ConnectionConfig) -> io::Result> { let (stream_tx, stream_rx) = oneshot::channel(); let accept = Accept { @@ -225,7 +219,7 @@ where config, }; self.accepts - .send((accept, None)) + .send(accept) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?; match stream_rx.await { Ok(stream) => Ok(stream?), @@ -233,6 +227,8 @@ where } } + /// WARNING: only accept() or accept_with_cid() can be used in an application. + /// they aren't compatible to use interchangeably in a program pub async fn accept_with_cid( &self, cid: ConnectionId

, @@ -243,8 +239,8 @@ where stream: stream_tx, config, }; - self.accepts - .send((accept, Some(cid))) + self.accepts_with_cid + .send((accept, cid)) .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?; match stream_rx.await { Ok(stream) => Ok(stream?), @@ -337,6 +333,32 @@ where } } } + + fn select_accept_helper( + cid: ConnectionId

, + syn: Packet, + conns: Arc, UnboundedSender>>>, + accept: Accept

, + socket_event_tx: UnboundedSender>, + ) { + let (connected_tx, connected_rx) = oneshot::channel(); + let (events_tx, events_rx) = mpsc::unbounded_channel(); + + { + conns.write().unwrap().insert(cid.clone(), events_tx); + } + + let stream = UtpStream::new( + cid, + accept.config, + Some(syn), + socket_event_tx, + events_rx, + connected_tx, + ); + + tokio::spawn(async move { Self::await_connected(stream, accept, connected_rx).await }); + } } #[derive(Copy, Clone, Debug)] @@ -391,3 +413,9 @@ impl

Drop for UtpSocket

{ } } } + +fn calculate_hash(t: &T) -> u64 { + let mut s = DefaultHasher::new(); + t.hash(&mut s); + s.finish() +}