Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

agent: support arbitrary STUN request timeouts #28

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 151 additions & 5 deletions stun-proto/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
fields(stun_id = self.id)
)]
fn validated_peer(&mut self, addr: SocketAddr) {
if self.validated_peers.get(&addr).is_none() {
if !self.validated_peers.contains(&addr) {
debug!("validated peer {:?}", addr);
self.validated_peers.insert(addr);
}
Expand Down Expand Up @@ -502,6 +502,7 @@
from: SocketAddr,
to: SocketAddr,
timeouts_ms: Vec<u64>,
last_retransmit_timeout_ms: u64,
recv_cancelled: bool,
send_cancelled: bool,
timeout_i: usize,
Expand All @@ -516,10 +517,10 @@
to: SocketAddr,
) -> Self {
let data = request.build();
let timeouts_ms = if transport == TransportType::Tcp {
vec![39500]
let (timeouts_ms, last_retransmit_timeout_ms) = if transport == TransportType::Tcp {
(vec![], 39500)
} else {
vec![500, 1000, 2000, 4000, 8000, 16000]
(vec![500, 1000, 2000, 4000, 8000, 16000], 8000)
};
Self {
transaction_id: request.transaction_id(),
Expand All @@ -531,6 +532,7 @@
|| request.has_attribute(MessageIntegritySha256::TYPE),
timeouts_ms,
timeout_i: 0,
last_retransmit_timeout_ms,
recv_cancelled: false,
send_cancelled: false,
last_send_time: None,
Expand All @@ -551,6 +553,10 @@
// TODO: account for TCP connect in timeout
if let Some(last_send) = self.last_send_time {
if self.timeout_i >= self.timeouts_ms.len() {
let next_send = last_send + Duration::from_millis(self.last_retransmit_timeout_ms);
if next_send > now {
return StunRequestPollRet::WaitUntil(next_send);
}
return StunRequestPollRet::TimedOut;
}
let next_send = last_send + Duration::from_millis(self.timeouts_ms[self.timeout_i]);
Expand All @@ -560,7 +566,7 @@
self.timeout_i += 1;
}
if self.send_cancelled {
// this calcelaltion may need a different value
// this cancellation may need a different value
return StunRequestPollRet::Cancelled;
}
self.last_send_time = Some(now);
Expand Down Expand Up @@ -623,6 +629,37 @@
pub fn mut_agent(&mut self) -> &mut StunAgent {
self.agent
}

/// Configure timeouts for the STUN transaction. As specified in RFC 8489, `initial_rto`
/// should be >= 500ms, `retransmits` has a default value of 7, and `last_retransmit_timeout`
/// should be 16 * `initial_rto`.
///
/// STUN transactions over TCP will only send a single request and have a timeout of the sum of
/// the timeouts of a UDP transaction.
pub fn configure_timeout(
&mut self,
initial_rto: Duration,
retransmits: u32,
last_retransmit_timeout: Duration,
) {
if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
match state.transport {
TransportType::Udp => {

Check warning on line 647 in stun-proto/src/agent.rs

View check run for this annotation

Codecov / codecov/patch

stun-proto/src/agent.rs#L647

Added line #L647 was not covered by tests
state.timeouts_ms = (0..retransmits)
.map(|i| (initial_rto * 2u32.pow(i)).as_millis() as u64)
.collect::<Vec<_>>();

Check warning on line 650 in stun-proto/src/agent.rs

View check run for this annotation

Codecov / codecov/patch

stun-proto/src/agent.rs#L650

Added line #L650 was not covered by tests
state.last_retransmit_timeout_ms = last_retransmit_timeout.as_millis() as u64
}
TransportType::Tcp => {
state.timeouts_ms = vec![];
state.last_retransmit_timeout_ms = (last_retransmit_timeout
+ (0..retransmits)
.fold(Duration::ZERO, |acc, i| acc + initial_rto * 2u32.pow(i)))
.as_millis() as u64;

Check warning on line 658 in stun-proto/src/agent.rs

View check run for this annotation

Codecov / codecov/patch

stun-proto/src/agent.rs#L658

Added line #L658 was not covered by tests
}
}
}
}
}

/// Return value when handling possible STUN data
Expand Down Expand Up @@ -843,6 +880,115 @@
assert!(!agent.is_validated_peer(remote_addr));
}

#[test]
fn request_custom_timeout() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
.remote_addr(remote_addr)
.build();
let msg = Message::builder_request(BINDING);
let transaction_id = msg.transaction_id();
let mut now = Instant::now();
agent.send(msg, remote_addr, now).unwrap();
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
transaction.configure_timeout(Duration::from_secs(1), 2, Duration::from_secs(10));
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(1));
now = wait;
let StunAgentPollRet::SendData(_) = agent.poll(now) else {
unreachable!();
};
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(2));
now = wait;
let StunAgentPollRet::SendData(_) = agent.poll(now) else {
unreachable!();
};
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(10));
now = wait;
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
unreachable!();
};
assert_eq!(timed_out, transaction_id);

assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());

// unvalidated peer data should be dropped
assert!(!agent.is_validated_peer(remote_addr));
}

#[test]
fn request_no_retransmit() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
.remote_addr(remote_addr)
.build();
let msg = Message::builder_request(BINDING);
let transaction_id = msg.transaction_id();
let mut now = Instant::now();
agent.send(msg, remote_addr, now).unwrap();
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
transaction.configure_timeout(Duration::from_secs(1), 0, Duration::from_secs(10));
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(10));
now = wait;
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
unreachable!();
};
assert_eq!(timed_out, transaction_id);

assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());

// unvalidated peer data should be dropped
assert!(!agent.is_validated_peer(remote_addr));
}

#[test]
fn request_tcp_custom_timeout() {
let _log = crate::tests::test_init_log();
let local_addr = "127.0.0.1:2000".parse().unwrap();
let remote_addr = "127.0.0.1:1000".parse().unwrap();
let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
.remote_addr(remote_addr)
.build();
let msg = Message::builder_request(BINDING);
let transaction_id = msg.transaction_id();
let mut now = Instant::now();
agent.send(msg, remote_addr, now).unwrap();
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
transaction.configure_timeout(Duration::from_secs(1), 3, Duration::from_secs(3));
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
unreachable!();
};
assert_eq!(wait - now, Duration::from_secs(1 + 2 + 4 + 3));
now = wait;
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
unreachable!();
};
assert_eq!(timed_out, transaction_id);

assert!(agent.request_transaction(transaction_id).is_none());
assert!(agent.mut_request_transaction(transaction_id).is_none());

// unvalidated peer data should be dropped
assert!(!agent.is_validated_peer(remote_addr));
}

#[test]
fn request_without_credentials() {
let _log = crate::tests::test_init_log();
Expand Down
Loading