From 330baeeaa3657fe9cc0e8529dceb4e25c62f2f5a Mon Sep 17 00:00:00 2001 From: Matthew Waters Date: Tue, 15 Oct 2024 07:39:47 +1100 Subject: [PATCH] agent: support arbitrary STUN request timeouts --- stun-proto/src/agent.rs | 156 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 151 insertions(+), 5 deletions(-) diff --git a/stun-proto/src/agent.rs b/stun-proto/src/agent.rs index b8c5021..00b7942 100644 --- a/stun-proto/src/agent.rs +++ b/stun-proto/src/agent.rs @@ -174,7 +174,7 @@ impl StunAgent { fields(stun_id = self.id) )] fn validated_peer(&mut self, addr: SocketAddr) { - if self.validated_peers.get(&addr).is_none() { + if !self.validated_peers.contains(&addr) { debug!("validated peer {:?}", addr); self.validated_peers.insert(addr); } @@ -502,6 +502,7 @@ struct StunRequestState { from: SocketAddr, to: SocketAddr, timeouts_ms: Vec, + last_retransmit_timeout_ms: u64, recv_cancelled: bool, send_cancelled: bool, timeout_i: usize, @@ -516,10 +517,10 @@ impl StunRequestState { to: SocketAddr, ) -> Self { let data = request.build(); - let timeouts_ms = if transport == TransportType::Tcp { - vec![39500] + let (timeouts_ms, last_retransmit_timeout_ms) = if transport == TransportType::Tcp { + (vec![], 39500) } else { - vec![500, 1000, 2000, 4000, 8000, 16000] + (vec![500, 1000, 2000, 4000, 8000, 16000], 8000) }; Self { transaction_id: request.transaction_id(), @@ -531,6 +532,7 @@ impl StunRequestState { || request.has_attribute(MessageIntegritySha256::TYPE), timeouts_ms, timeout_i: 0, + last_retransmit_timeout_ms, recv_cancelled: false, send_cancelled: false, last_send_time: None, @@ -551,6 +553,10 @@ impl StunRequestState { // TODO: account for TCP connect in timeout if let Some(last_send) = self.last_send_time { if self.timeout_i >= self.timeouts_ms.len() { + let next_send = last_send + Duration::from_millis(self.last_retransmit_timeout_ms); + if next_send > now { + return StunRequestPollRet::WaitUntil(next_send); + } return StunRequestPollRet::TimedOut; } let next_send = last_send + Duration::from_millis(self.timeouts_ms[self.timeout_i]); @@ -560,7 +566,7 @@ impl StunRequestState { self.timeout_i += 1; } if self.send_cancelled { - // this calcelaltion may need a different value + // this cancellation may need a different value return StunRequestPollRet::Cancelled; } self.last_send_time = Some(now); @@ -623,6 +629,37 @@ impl<'a> StunRequestMut<'a> { pub fn mut_agent(&mut self) -> &mut StunAgent { self.agent } + + /// Configure timeouts for the STUN transaction. As specified in RFC 8489, `initial_rto` + /// should be >= 500ms, `retransmits` has a default value of 7, and `last_retransmit_timeout` + /// should be 16 * `initial_rto`. + /// + /// STUN transactions over TCP will only send a single request and have a timeout of the sum of + /// the timeouts of a UDP transaction. + pub fn configure_timeout( + &mut self, + initial_rto: Duration, + retransmits: u32, + last_retransmit_timeout: Duration, + ) { + if let Some(state) = self.agent.mut_request_state(self.transaction_id) { + match state.transport { + TransportType::Udp => { + state.timeouts_ms = (0..retransmits) + .map(|i| (initial_rto * 2u32.pow(i)).as_millis() as u64) + .collect::>(); + state.last_retransmit_timeout_ms = last_retransmit_timeout.as_millis() as u64 + } + TransportType::Tcp => { + state.timeouts_ms = vec![]; + state.last_retransmit_timeout_ms = (last_retransmit_timeout + + (0..retransmits) + .fold(Duration::ZERO, |acc, i| acc + initial_rto * 2u32.pow(i))) + .as_millis() as u64; + } + } + } + } } /// Return value when handling possible STUN data @@ -843,6 +880,115 @@ pub(crate) mod tests { assert!(!agent.is_validated_peer(remote_addr)); } + #[test] + fn request_custom_timeout() { + let _log = crate::tests::test_init_log(); + let local_addr = "127.0.0.1:2000".parse().unwrap(); + let remote_addr = "127.0.0.1:1000".parse().unwrap(); + let mut agent = StunAgent::builder(TransportType::Udp, local_addr) + .remote_addr(remote_addr) + .build(); + let msg = Message::builder_request(BINDING); + let transaction_id = msg.transaction_id(); + let mut now = Instant::now(); + agent.send(msg, remote_addr, now).unwrap(); + let mut transaction = agent.mut_request_transaction(transaction_id).unwrap(); + transaction.configure_timeout(Duration::from_secs(1), 2, Duration::from_secs(10)); + let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else { + unreachable!(); + }; + assert_eq!(wait - now, Duration::from_secs(1)); + now = wait; + let StunAgentPollRet::SendData(_) = agent.poll(now) else { + unreachable!(); + }; + let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else { + unreachable!(); + }; + assert_eq!(wait - now, Duration::from_secs(2)); + now = wait; + let StunAgentPollRet::SendData(_) = agent.poll(now) else { + unreachable!(); + }; + let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else { + unreachable!(); + }; + assert_eq!(wait - now, Duration::from_secs(10)); + now = wait; + let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else { + unreachable!(); + }; + assert_eq!(timed_out, transaction_id); + + assert!(agent.request_transaction(transaction_id).is_none()); + assert!(agent.mut_request_transaction(transaction_id).is_none()); + + // unvalidated peer data should be dropped + assert!(!agent.is_validated_peer(remote_addr)); + } + + #[test] + fn request_no_retransmit() { + let _log = crate::tests::test_init_log(); + let local_addr = "127.0.0.1:2000".parse().unwrap(); + let remote_addr = "127.0.0.1:1000".parse().unwrap(); + let mut agent = StunAgent::builder(TransportType::Udp, local_addr) + .remote_addr(remote_addr) + .build(); + let msg = Message::builder_request(BINDING); + let transaction_id = msg.transaction_id(); + let mut now = Instant::now(); + agent.send(msg, remote_addr, now).unwrap(); + let mut transaction = agent.mut_request_transaction(transaction_id).unwrap(); + transaction.configure_timeout(Duration::from_secs(1), 0, Duration::from_secs(10)); + let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else { + unreachable!(); + }; + assert_eq!(wait - now, Duration::from_secs(10)); + now = wait; + let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else { + unreachable!(); + }; + assert_eq!(timed_out, transaction_id); + + assert!(agent.request_transaction(transaction_id).is_none()); + assert!(agent.mut_request_transaction(transaction_id).is_none()); + + // unvalidated peer data should be dropped + assert!(!agent.is_validated_peer(remote_addr)); + } + + #[test] + fn request_tcp_custom_timeout() { + let _log = crate::tests::test_init_log(); + let local_addr = "127.0.0.1:2000".parse().unwrap(); + let remote_addr = "127.0.0.1:1000".parse().unwrap(); + let mut agent = StunAgent::builder(TransportType::Tcp, local_addr) + .remote_addr(remote_addr) + .build(); + let msg = Message::builder_request(BINDING); + let transaction_id = msg.transaction_id(); + let mut now = Instant::now(); + agent.send(msg, remote_addr, now).unwrap(); + let mut transaction = agent.mut_request_transaction(transaction_id).unwrap(); + transaction.configure_timeout(Duration::from_secs(1), 3, Duration::from_secs(3)); + let StunAgentPollRet::WaitUntil(wait) = agent.poll(now) else { + unreachable!(); + }; + assert_eq!(wait - now, Duration::from_secs(1 + 2 + 4 + 3)); + now = wait; + let StunAgentPollRet::TransactionTimedOut(timed_out) = agent.poll(now) else { + unreachable!(); + }; + assert_eq!(timed_out, transaction_id); + + assert!(agent.request_transaction(transaction_id).is_none()); + assert!(agent.mut_request_transaction(transaction_id).is_none()); + + // unvalidated peer data should be dropped + assert!(!agent.is_validated_peer(remote_addr)); + } + #[test] fn request_without_credentials() { let _log = crate::tests::test_init_log();