Skip to content

Commit ee1f1ca

Browse files
committed
agent: support arbitrary STUN request timeouts
1 parent c106a15 commit ee1f1ca

File tree

1 file changed

+156
-5
lines changed

1 file changed

+156
-5
lines changed

stun-proto/src/agent.rs

+156-5
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ impl StunAgent {
174174
fields(stun_id = self.id)
175175
)]
176176
fn validated_peer(&mut self, addr: SocketAddr) {
177-
if self.validated_peers.get(&addr).is_none() {
177+
if !self.validated_peers.contains(&addr) {
178178
debug!("validated peer {:?}", addr);
179179
self.validated_peers.insert(addr);
180180
}
@@ -502,6 +502,7 @@ struct StunRequestState {
502502
from: SocketAddr,
503503
to: SocketAddr,
504504
timeouts_ms: Vec<u64>,
505+
last_retransmit_timeout_ms: u64,
505506
recv_cancelled: bool,
506507
send_cancelled: bool,
507508
timeout_i: usize,
@@ -516,10 +517,10 @@ impl StunRequestState {
516517
to: SocketAddr,
517518
) -> Self {
518519
let data = request.build();
519-
let timeouts_ms = if transport == TransportType::Tcp {
520-
vec![39500]
520+
let (timeouts_ms, last_retransmit_timeout_ms) = if transport == TransportType::Tcp {
521+
(vec![], 39500)
521522
} else {
522-
vec![500, 1000, 2000, 4000, 8000, 16000]
523+
(vec![500, 1000, 2000, 4000, 8000, 16000], 8000)
523524
};
524525
Self {
525526
transaction_id: request.transaction_id(),
@@ -531,6 +532,7 @@ impl StunRequestState {
531532
|| request.has_attribute(MessageIntegritySha256::TYPE),
532533
timeouts_ms,
533534
timeout_i: 0,
535+
last_retransmit_timeout_ms,
534536
recv_cancelled: false,
535537
send_cancelled: false,
536538
last_send_time: None,
@@ -551,6 +553,10 @@ impl StunRequestState {
551553
// TODO: account for TCP connect in timeout
552554
if let Some(last_send) = self.last_send_time {
553555
if self.timeout_i >= self.timeouts_ms.len() {
556+
let next_send = last_send + Duration::from_millis(self.last_retransmit_timeout_ms);
557+
if next_send > now {
558+
return StunRequestPollRet::WaitUntil(next_send);
559+
}
554560
return StunRequestPollRet::TimedOut;
555561
}
556562
let next_send = last_send + Duration::from_millis(self.timeouts_ms[self.timeout_i]);
@@ -560,7 +566,7 @@ impl StunRequestState {
560566
self.timeout_i += 1;
561567
}
562568
if self.send_cancelled {
563-
// this calcelaltion may need a different value
569+
// this cancellation may need a different value
564570
return StunRequestPollRet::Cancelled;
565571
}
566572
self.last_send_time = Some(now);
@@ -623,6 +629,42 @@ impl<'a> StunRequestMut<'a> {
623629
pub fn mut_agent(&mut self) -> &mut StunAgent {
624630
self.agent
625631
}
632+
633+
/// Configure timeouts for the STUN transaction. As specified in RFC 8489, `initial_rto`
634+
/// should be >= 500ms, `retransmits` has a default value of 7, and `last_retransmit_timeout`
635+
/// should be 16 * `initial_rto`.
636+
///
637+
/// STUN transactions over TCP will only send a single request and have a timeout of the sum of
638+
/// the timeouts of a UDP transaction.
639+
pub fn configure_timeout(
640+
&mut self,
641+
initial_rto: Duration,
642+
retransmits: u32,
643+
last_retransmit_timeout: Duration,
644+
) {
645+
if let Some(state) = self.agent.mut_request_state(self.transaction_id) {
646+
match state.transport {
647+
TransportType::Udp => {
648+
state.timeouts_ms = (0..retransmits)
649+
.map(|i| (initial_rto * 2u32.pow(i)).as_millis() as u64)
650+
.collect::<Vec<_>>();
651+
state.last_retransmit_timeout_ms = last_retransmit_timeout.as_millis() as u64
652+
}
653+
TransportType::Tcp => {
654+
state.timeouts_ms = vec![];
655+
state.last_retransmit_timeout_ms = (last_retransmit_timeout
656+
+ (0..retransmits)
657+
.fold(Duration::ZERO, |acc, i| acc + initial_rto * 2u32.pow(i)))
658+
.as_millis() as u64;
659+
}
660+
}
661+
tracing::error!("timeouts: {:?}", state.timeouts_ms);
662+
tracing::error!(
663+
"last retransmit timeout: {:?}",
664+
state.last_retransmit_timeout_ms
665+
);
666+
}
667+
}
626668
}
627669

628670
/// Return value when handling possible STUN data
@@ -843,6 +885,115 @@ pub(crate) mod tests {
843885
assert!(!agent.is_validated_peer(remote_addr));
844886
}
845887

888+
#[test]
889+
fn request_custom_timeout() {
890+
let _log = crate::tests::test_init_log();
891+
let local_addr = "127.0.0.1:2000".parse().unwrap();
892+
let remote_addr = "127.0.0.1:1000".parse().unwrap();
893+
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
894+
.remote_addr(remote_addr)
895+
.build();
896+
let msg = Message::builder_request(BINDING);
897+
let transaction_id = msg.transaction_id();
898+
let mut now = Instant::now();
899+
agent.send(msg, remote_addr, now).unwrap();
900+
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
901+
transaction.configure_timeout(Duration::from_secs(1), 2, Duration::from_secs(10));
902+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
903+
unreachable!();
904+
};
905+
assert_eq!(wait - now, Duration::from_secs(1));
906+
now = wait;
907+
let StunAgentPollRet::SendData(_) = agent.poll(now) else {
908+
unreachable!();
909+
};
910+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
911+
unreachable!();
912+
};
913+
assert_eq!(wait - now, Duration::from_secs(2));
914+
now = wait;
915+
let StunAgentPollRet::SendData(_) = agent.poll(now) else {
916+
unreachable!();
917+
};
918+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
919+
unreachable!();
920+
};
921+
assert_eq!(wait - now, Duration::from_secs(10));
922+
now = wait;
923+
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
924+
unreachable!();
925+
};
926+
assert_eq!(timed_out, transaction_id);
927+
928+
assert!(agent.request_transaction(transaction_id).is_none());
929+
assert!(agent.mut_request_transaction(transaction_id).is_none());
930+
931+
// unvalidated peer data should be dropped
932+
assert!(!agent.is_validated_peer(remote_addr));
933+
}
934+
935+
#[test]
936+
fn request_no_retransmit() {
937+
let _log = crate::tests::test_init_log();
938+
let local_addr = "127.0.0.1:2000".parse().unwrap();
939+
let remote_addr = "127.0.0.1:1000".parse().unwrap();
940+
let mut agent = StunAgent::builder(TransportType::Udp, local_addr)
941+
.remote_addr(remote_addr)
942+
.build();
943+
let msg = Message::builder_request(BINDING);
944+
let transaction_id = msg.transaction_id();
945+
let mut now = Instant::now();
946+
agent.send(msg, remote_addr, now).unwrap();
947+
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
948+
transaction.configure_timeout(Duration::from_secs(1), 0, Duration::from_secs(10));
949+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
950+
unreachable!();
951+
};
952+
assert_eq!(wait - now, Duration::from_secs(10));
953+
now = wait;
954+
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
955+
unreachable!();
956+
};
957+
assert_eq!(timed_out, transaction_id);
958+
959+
assert!(agent.request_transaction(transaction_id).is_none());
960+
assert!(agent.mut_request_transaction(transaction_id).is_none());
961+
962+
// unvalidated peer data should be dropped
963+
assert!(!agent.is_validated_peer(remote_addr));
964+
}
965+
966+
#[test]
967+
fn request_tcp_custom_timeout() {
968+
let _log = crate::tests::test_init_log();
969+
let local_addr = "127.0.0.1:2000".parse().unwrap();
970+
let remote_addr = "127.0.0.1:1000".parse().unwrap();
971+
let mut agent = StunAgent::builder(TransportType::Tcp, local_addr)
972+
.remote_addr(remote_addr)
973+
.build();
974+
let msg = Message::builder_request(BINDING);
975+
let transaction_id = msg.transaction_id();
976+
let mut now = Instant::now();
977+
agent.send(msg, remote_addr, now).unwrap();
978+
let mut transaction = agent.mut_request_transaction(transaction_id).unwrap();
979+
transaction.configure_timeout(Duration::from_secs(1), 3, Duration::from_secs(3));
980+
let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else {
981+
unreachable!();
982+
};
983+
assert_eq!(wait - now, Duration::from_secs(1 + 2 + 4 + 3));
984+
now = wait;
985+
let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else {
986+
unreachable!();
987+
};
988+
assert_eq!(timed_out, transaction_id);
989+
990+
assert!(agent.request_transaction(transaction_id).is_none());
991+
assert!(agent.mut_request_transaction(transaction_id).is_none());
992+
993+
// unvalidated peer data should be dropped
994+
assert!(!agent.is_validated_peer(remote_addr));
995+
}
996+
846997
#[test]
847998
fn request_without_credentials() {
848999
let _log = crate::tests::test_init_log();

0 commit comments

Comments
 (0)