Skip to content

Commit a772f7b

Browse files
committed
agent: support arbitrary STUN request timeouts
1 parent c106a15 commit a772f7b

File tree

1 file changed

+151
-5
lines changed

1 file changed

+151
-5
lines changed

stun-proto/src/agent.rs

+151-5
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ impl StunAgent {
174174
fields(stun_id = self.id)
175175
)]
176176
fn validated_peer(&mut self, addr: SocketAddr) {
177-
if self.validated_peers.get(&addr).is_none() {
177+
if !self.validated_peers.contains(&addr) {
178178
debug!("validated peer {:?}", addr);
179179
self.validated_peers.insert(addr);
180180
}
@@ -502,6 +502,7 @@ struct StunRequestState {
502502
from: SocketAddr,
503503
to: SocketAddr,
504504
timeouts_ms: Vec<u64>,
505+
last_retransmit_timeout_ms: u64,
505506
recv_cancelled: bool,
506507
send_cancelled: bool,
507508
timeout_i: usize,
@@ -516,10 +517,10 @@ impl StunRequestState {
516517
to: SocketAddr,
517518
) -> Self {
518519
let data = request.build();
519-
let timeouts_ms = if transport == TransportType::Tcp {
520-
vec![39500]
520+
let (timeouts_ms, last_retransmit_timeout_ms) = if transport == TransportType::Tcp {
521+
(vec![], 39500)
521522
} else {
522-
vec![500, 1000, 2000, 4000, 8000, 16000]
523+
(vec![500, 1000, 2000, 4000, 8000, 16000], 8000)
523524
};
524525
Self {
525526
transaction_id: request.transaction_id(),
@@ -531,6 +532,7 @@ impl StunRequestState {
531532
|| request.has_attribute(MessageIntegritySha256::TYPE),
532533
timeouts_ms,
533534
timeout_i: 0,
535+
last_retransmit_timeout_ms,
534536
recv_cancelled: false,
535537
send_cancelled: false,
536538
last_send_time: None,
@@ -551,6 +553,10 @@ impl StunRequestState {
551553
// TODO: account for TCP connect in timeout
552554
if let Some(last_send) = self.last_send_time {
553555
if self.timeout_i >= self.timeouts_ms.len() {
556+
let next_send = last_send + Duration::from_millis(self.last_retransmit_timeout_ms);
557+
if next_send > now {
558+
return StunRequestPollRet::WaitUntil(next_send);
559+
}
554560
return StunRequestPollRet::TimedOut;
555561
}
556562
let next_send = last_send + Duration::from_millis(self.timeouts_ms[self.timeout_i]);
@@ -560,7 +566,7 @@ impl StunRequestState {
560566
self.timeout_i += 1;
561567
}
562568
if self.send_cancelled {
563-
// this calcelaltion may need a different value
569+
// this cancellation may need a different value
564570
return StunRequestPollRet::Cancelled;
565571
}
566572
self.last_send_time = Some(now);
@@ -623,6 +629,37 @@ impl<'a> StunRequestMut<'a> {
623629
pub fn mut_agent(&mut self) -> &mut StunAgent {
624630
self.agent
625631
}
632+
633+
/// Configure timeouts for the STUN transaction. As specified in RFC 8489, `initial_rto`
634+
/// should be >= 500ms, `retransmits` has a default value of 7, and `last_retransmit_timeout`
635+
/// should be 16 * `initial_rto`.
636+
///
637+
/// STUN transactions over TCP will only send a single request and have a timeout of the sum of
638+
/// the timeouts of a UDP transaction.
639+
pub fn configure_timeout(
640+
&mut self,
641+
initial_rto: Duration,
642+
retransmits: u32,
643+
last_retransmit_timeout: Duration,
644+
) {
645+
if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
646+
match state.transport {
647+
TransportType::Udp => {
648+
state.timeouts_ms = (0..retransmits)
649+
.map(|i| (initial_rto * 2u32.pow(i)).as_millis() as u64)
650+
.collect::<Vec<_>>();
651+
state.last_retransmit_timeout_ms = last_retransmit_timeout.as_millis() as u64
652+
}
653+
TransportType::Tcp => {
654+
state.timeouts_ms = vec![];
655+
state.last_retransmit_timeout_ms = (last_retransmit_timeout
656+
+ (0..retransmits)
657+
.fold(Duration::ZERO, |acc, i| acc + initial_rto * 2u32.pow(i)))
658+
.as_millis() as u64;
659+
}
660+
}
661+
}
662+
}
626663
}
627664

628665
/// Return value when handling possible STUN data
@@ -843,6 +880,115 @@ pub(crate) mod tests {
843880
assert!(!agent.is_validated_peer(remote_addr));
844881
}
845882

883+
#[test]
884+
fn request_custom_timeout() {
885+
let _log = crate::tests::test_init_log();
886+
let local_addr = "127.0.0.1:2000".parse().unwrap();
887+
let remote_addr = "127.0.0.1:1000".parse().unwrap();
888+
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
889+
.remote_addr(remote_addr)
890+
.build();
891+
let msg = Message::builder_request(BINDING);
892+
let transaction_id = msg.transaction_id();
893+
let mut now = Instant::now();
894+
agent.send(msg, remote_addr, now).unwrap();
895+
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
896+
transaction.configure_timeout(Duration::from_secs(1), 2, Duration::from_secs(10));
897+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
898+
unreachable!();
899+
};
900+
assert_eq!(wait - now, Duration::from_secs(1));
901+
now = wait;
902+
let StunAgentPollRet::SendData(_) = agent.poll(now) else {
903+
unreachable!();
904+
};
905+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
906+
unreachable!();
907+
};
908+
assert_eq!(wait - now, Duration::from_secs(2));
909+
now = wait;
910+
let StunAgentPollRet::SendData(_) = agent.poll(now) else {
911+
unreachable!();
912+
};
913+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
914+
unreachable!();
915+
};
916+
assert_eq!(wait - now, Duration::from_secs(10));
917+
now = wait;
918+
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
919+
unreachable!();
920+
};
921+
assert_eq!(timed_out, transaction_id);
922+
923+
assert!(agent.request_transaction(transaction_id).is_none());
924+
assert!(agent.mut_request_transaction(transaction_id).is_none());
925+
926+
// unvalidated peer data should be dropped
927+
assert!(!agent.is_validated_peer(remote_addr));
928+
}
929+
930+
#[test]
931+
fn request_no_retransmit() {
932+
let _log = crate::tests::test_init_log();
933+
let local_addr = "127.0.0.1:2000".parse().unwrap();
934+
let remote_addr = "127.0.0.1:1000".parse().unwrap();
935+
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
936+
.remote_addr(remote_addr)
937+
.build();
938+
let msg = Message::builder_request(BINDING);
939+
let transaction_id = msg.transaction_id();
940+
let mut now = Instant::now();
941+
agent.send(msg, remote_addr, now).unwrap();
942+
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
943+
transaction.configure_timeout(Duration::from_secs(1), 0, Duration::from_secs(10));
944+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
945+
unreachable!();
946+
};
947+
assert_eq!(wait - now, Duration::from_secs(10));
948+
now = wait;
949+
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
950+
unreachable!();
951+
};
952+
assert_eq!(timed_out, transaction_id);
953+
954+
assert!(agent.request_transaction(transaction_id).is_none());
955+
assert!(agent.mut_request_transaction(transaction_id).is_none());
956+
957+
// unvalidated peer data should be dropped
958+
assert!(!agent.is_validated_peer(remote_addr));
959+
}
960+
961+
#[test]
962+
fn request_tcp_custom_timeout() {
963+
let _log = crate::tests::test_init_log();
964+
let local_addr = "127.0.0.1:2000".parse().unwrap();
965+
let remote_addr = "127.0.0.1:1000".parse().unwrap();
966+
let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
967+
.remote_addr(remote_addr)
968+
.build();
969+
let msg = Message::builder_request(BINDING);
970+
let transaction_id = msg.transaction_id();
971+
let mut now = Instant::now();
972+
agent.send(msg, remote_addr, now).unwrap();
973+
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
974+
transaction.configure_timeout(Duration::from_secs(1), 3, Duration::from_secs(3));
975+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
976+
unreachable!();
977+
};
978+
assert_eq!(wait - now, Duration::from_secs(1 + 2 + 4 + 3));
979+
now = wait;
980+
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
981+
unreachable!();
982+
};
983+
assert_eq!(timed_out, transaction_id);
984+
985+
assert!(agent.request_transaction(transaction_id).is_none());
986+
assert!(agent.mut_request_transaction(transaction_id).is_none());
987+
988+
// unvalidated peer data should be dropped
989+
assert!(!agent.is_validated_peer(remote_addr));
990+
}
991+
846992
#[test]
847993
fn request_without_credentials() {
848994
let _log = crate::tests::test_init_log();

0 commit comments

Comments
 (0)