Skip to content

Commit

Permalink
fix: resolve accept_with_await never returning with 2 fixes (#122)
Browse files Browse the repository at this point in the history
* fix: accept logic blocking accept_with_cid logic

* fix: add timeout for accept_with_cid so it doesn't block forever

* fix: resolve pr concerns
  • Loading branch information
KolbyML authored Jan 25, 2024
1 parent 83d256f commit a519426
Showing 1 changed file with 82 additions and 54 deletions.
136 changes: 82 additions & 54 deletions src/socket.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::io;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;

use delay_map::HashMapDelay;
use futures::StreamExt;
use std::hash::{Hash, Hasher};
use tokio::net::UdpSocket;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::{mpsc, oneshot};

use crate::cid::{ConnectionId, ConnectionIdGenerator, ConnectionPeer, StdConnectionIdGenerator};
Expand All @@ -13,20 +19,28 @@ use crate::packet::{Packet, PacketType};
use crate::stream::UtpStream;
use crate::udp::AsyncUdpSocket;

type ConnChannel = mpsc::UnboundedSender<StreamEvent>;
type ConnChannel = UnboundedSender<StreamEvent>;

struct Accept<P> {
stream: oneshot::Sender<io::Result<UtpStream<P>>>,
config: ConnectionConfig,
}

const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize;
/// accept_with_cid() has unique interactions compared to accept()
/// accept() pulls awaiting requests off a queue, but accept_with_cid() only
/// takes a connection off if CID matches. Because of this if we are awaiting a CID
/// eventually we need to timeout the await, or the queue would never stop growing with stale awaits
/// 20 seconds is arbatrary, after the uTP cofig refactor is done that can replace this constant.
/// but thee uTP config refactor is currently very low priority.
const AWAITING_CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);

pub struct UtpSocket<P> {
conns: Arc<RwLock<HashMap<ConnectionId<P>, ConnChannel>>>,
cid_gen: Mutex<StdConnectionIdGenerator<P>>,
accepts: mpsc::UnboundedSender<(Accept<P>, Option<ConnectionId<P>>)>,
socket_events: mpsc::UnboundedSender<SocketEvent<P>>,
accepts: UnboundedSender<Accept<P>>,
accepts_with_cid: UnboundedSender<(Accept<P>, ConnectionId<P>)>,
socket_events: UnboundedSender<SocketEvent<P>>,
}

impl UtpSocket<SocketAddr> {
Expand All @@ -50,18 +64,21 @@ where

let cid_gen = Mutex::new(StdConnectionIdGenerator::new());

let awaiting: HashMap<ConnectionId<P>, Accept<P>> = HashMap::new();
let awaiting = Arc::new(RwLock::new(awaiting));
// if an accept_with_cid awaiting connection isn't connected in AWAITING_CONNECTION_TIMEOUT seconds, cancel and log it
let mut awaiting: HashMapDelay<u64, ((u16, u16), Accept<P>)> =
HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);

let mut incoming_conns = HashMap::new();

let (socket_event_tx, mut socket_event_rx) = mpsc::unbounded_channel();
let (accepts_tx, mut accepts_rx) = mpsc::unbounded_channel();
let (accepts_with_cid_tx, mut accepts_with_cid_rx) = mpsc::unbounded_channel();

let utp = Self {
conns: Arc::clone(&conns),
cid_gen,
accepts: accepts_tx,
accepts_with_cid: accepts_with_cid_tx,
socket_events: socket_event_tx.clone(),
};

Expand Down Expand Up @@ -94,12 +111,11 @@ where
None => {
if std::matches!(packet.packet_type(), PacketType::Syn) {
let cid = cid_from_packet(&packet, &src, IdType::RecvId);
let mut awaiting = awaiting.write().unwrap();

// If there was an awaiting connection with the CID, then
// create a new stream for that connection. Otherwise, add the
// connection to the incoming connections.
if let Some(accept) = awaiting.remove(&cid) {
if let Some((_, accept)) = awaiting.remove(&calculate_hash(&cid)) {
let (connected_tx, connected_rx) = oneshot::channel();
let (events_tx, events_rx) = mpsc::unbounded_channel();

Expand Down Expand Up @@ -135,51 +151,19 @@ where
},
}
}
Some((accept, cid)) = accepts_rx.recv(), if !incoming_conns.is_empty() => {
let (cid, syn) = match cid {
// If a CID was given, then check for an incoming connection with that
// CID. If one is found, then use that connection. Otherwise, add the
// CID to the awaiting connections.
Some(cid) => {
if let Some(syn) = incoming_conns.remove(&cid) {
(cid, syn)
} else {
awaiting.write().unwrap().insert(cid, accept);
continue;
}
}
// If a CID was not given, then pull an incoming connection, and use
// that connection's CID. An incoming connection is known to exist
// because of the condition in the `select` arm.
None => {
let cid = incoming_conns.keys().next().unwrap().clone();
let syn = incoming_conns.remove(&cid).unwrap();
(cid, syn)
}
Some((accept, cid)) = accepts_with_cid_rx.recv() => {
let (cid, syn) = if let Some(syn) = incoming_conns.remove(&cid) {
(cid, syn)
} else {
awaiting.insert(calculate_hash(&cid), ((cid.send, cid.recv), accept));
continue;
};

let (connected_tx, connected_rx) = oneshot::channel();
let (events_tx, events_rx) = mpsc::unbounded_channel();

{
conns
.write()
.unwrap()
.insert(cid.clone(), events_tx);
}

let stream = UtpStream::new(
cid,
accept.config,
Some(syn),
socket_event_tx.clone(),
events_rx,
connected_tx,
);

tokio::spawn(async move {
Self::await_connected(stream, accept, connected_rx).await
});
Self::select_accept_helper(cid, syn, conns.clone(), accept, socket_event_tx.clone());
}
Some(accept) = accepts_rx.recv(), if !incoming_conns.is_empty() => {
let cid = incoming_conns.keys().next().unwrap().clone();
let syn = incoming_conns.remove(&cid).unwrap();
Self::select_accept_helper(cid, syn, conns.clone(), accept, socket_event_tx.clone());
}
Some(event) = socket_event_rx.recv() => {
match event {
Expand All @@ -202,6 +186,14 @@ where
}
}
}
Some(Ok((_, ((send, recv), accept)))) = awaiting.next() => {
// accept_with_cid didn't recieve an inbound connection within the timeout period
// log it and return a timeout error
tracing::debug!(%send, %recv, "accept_with_cid timed out");
let _ = accept
.stream
.send(Err(io::Error::from(io::ErrorKind::TimedOut)));
}
}
}
});
Expand All @@ -218,21 +210,25 @@ where
self.conns.read().unwrap().len()
}

/// WARNING: only accept() or accept_with_cid() can be used in an application.
/// they aren't compatible to use interchangeably in a program
pub async fn accept(&self, config: ConnectionConfig) -> io::Result<UtpStream<P>> {
let (stream_tx, stream_rx) = oneshot::channel();
let accept = Accept {
stream: stream_tx,
config,
};
self.accepts
.send((accept, None))
.send(accept)
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
match stream_rx.await {
Ok(stream) => Ok(stream?),
Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
}
}

/// WARNING: only accept() or accept_with_cid() can be used in an application.
/// they aren't compatible to use interchangeably in a program
pub async fn accept_with_cid(
&self,
cid: ConnectionId<P>,
Expand All @@ -243,8 +239,8 @@ where
stream: stream_tx,
config,
};
self.accepts
.send((accept, Some(cid)))
self.accepts_with_cid
.send((accept, cid))
.map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
match stream_rx.await {
Ok(stream) => Ok(stream?),
Expand Down Expand Up @@ -337,6 +333,32 @@ where
}
}
}

fn select_accept_helper(
cid: ConnectionId<P>,
syn: Packet,
conns: Arc<RwLock<HashMap<ConnectionId<P>, UnboundedSender<StreamEvent>>>>,
accept: Accept<P>,
socket_event_tx: UnboundedSender<SocketEvent<P>>,
) {
let (connected_tx, connected_rx) = oneshot::channel();
let (events_tx, events_rx) = mpsc::unbounded_channel();

{
conns.write().unwrap().insert(cid.clone(), events_tx);
}

let stream = UtpStream::new(
cid,
accept.config,
Some(syn),
socket_event_tx,
events_rx,
connected_tx,
);

tokio::spawn(async move { Self::await_connected(stream, accept, connected_rx).await });
}
}

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -391,3 +413,9 @@ impl<P> Drop for UtpSocket<P> {
}
}
}

fn calculate_hash<T: Hash>(t: &T) -> u64 {
let mut s = DefaultHasher::new();
t.hash(&mut s);
s.finish()
}

0 comments on commit a519426

Please sign in to comment.