@@ -174,7 +174,7 @@ impl StunAgent {
174
174
fields( stun_id = self . id)
175
175
) ]
176
176
fn validated_peer ( & mut self , addr : SocketAddr ) {
177
- if self . validated_peers . get ( & addr) . is_none ( ) {
177
+ if ! self . validated_peers . contains ( & addr) {
178
178
debug ! ( "validated peer {:?}" , addr) ;
179
179
self . validated_peers . insert ( addr) ;
180
180
}
@@ -502,6 +502,7 @@ struct StunRequestState {
502
502
from : SocketAddr ,
503
503
to : SocketAddr ,
504
504
timeouts_ms : Vec < u64 > ,
505
+ last_retransmit_timeout_ms : u64 ,
505
506
recv_cancelled : bool ,
506
507
send_cancelled : bool ,
507
508
timeout_i : usize ,
@@ -516,10 +517,10 @@ impl StunRequestState {
516
517
to : SocketAddr ,
517
518
) -> Self {
518
519
let data = request. build ( ) ;
519
- let timeouts_ms = if transport == TransportType :: Tcp {
520
- vec ! [ 39500 ]
520
+ let ( timeouts_ms, last_retransmit_timeout_ms ) = if transport == TransportType :: Tcp {
521
+ ( vec ! [ ] , 39500 )
521
522
} else {
522
- vec ! [ 500 , 1000 , 2000 , 4000 , 8000 , 16000 ]
523
+ ( vec ! [ 500 , 1000 , 2000 , 4000 , 8000 , 16000 ] , 8000 )
523
524
} ;
524
525
Self {
525
526
transaction_id : request. transaction_id ( ) ,
@@ -531,6 +532,7 @@ impl StunRequestState {
531
532
|| request. has_attribute ( MessageIntegritySha256 :: TYPE ) ,
532
533
timeouts_ms,
533
534
timeout_i : 0 ,
535
+ last_retransmit_timeout_ms,
534
536
recv_cancelled : false ,
535
537
send_cancelled : false ,
536
538
last_send_time : None ,
@@ -551,6 +553,10 @@ impl StunRequestState {
551
553
// TODO: account for TCP connect in timeout
552
554
if let Some ( last_send) = self . last_send_time {
553
555
if self . timeout_i >= self . timeouts_ms . len ( ) {
556
+ let next_send = last_send + Duration :: from_millis ( self . last_retransmit_timeout_ms ) ;
557
+ if next_send > now {
558
+ return StunRequestPollRet :: WaitUntil ( next_send) ;
559
+ }
554
560
return StunRequestPollRet :: TimedOut ;
555
561
}
556
562
let next_send = last_send + Duration :: from_millis ( self . timeouts_ms [ self . timeout_i ] ) ;
@@ -560,7 +566,7 @@ impl StunRequestState {
560
566
self . timeout_i += 1 ;
561
567
}
562
568
if self . send_cancelled {
563
- // this calcelaltion may need a different value
569
+ // this cancellation may need a different value
564
570
return StunRequestPollRet :: Cancelled ;
565
571
}
566
572
self . last_send_time = Some ( now) ;
@@ -623,6 +629,42 @@ impl<'a> StunRequestMut<'a> {
623
629
pub fn mut_agent ( & mut self ) -> & mut StunAgent {
624
630
self . agent
625
631
}
632
+
633
+ /// Configure timeouts for the STUN transaction. As specified in RFC 8489, `initial_rto`
634
+ /// should be >= 500ms, `retransmits` has a default value of 7, and `last_retransmit_timeout`
635
+ /// should be 16 * `initial_rto`.
636
+ ///
637
+ /// STUN transactions over TCP will only send a single request and have a timeout of the sum of
638
+ /// the timeouts of a UDP transaction.
639
+ pub fn configure_timeout (
640
+ & mut self ,
641
+ initial_rto : Duration ,
642
+ retransmits : u32 ,
643
+ last_retransmit_timeout : Duration ,
644
+ ) {
645
+ if let Some ( state) = self . agent . mut_request_state ( self . transaction_id ) {
646
+ match state. transport {
647
+ TransportType :: Udp => {
648
+ state. timeouts_ms = ( 0 ..retransmits)
649
+ . map ( |i| ( initial_rto * 2u32 . pow ( i) ) . as_millis ( ) as u64 )
650
+ . collect :: < Vec < _ > > ( ) ;
651
+ state. last_retransmit_timeout_ms = last_retransmit_timeout. as_millis ( ) as u64
652
+ }
653
+ TransportType :: Tcp => {
654
+ state. timeouts_ms = vec ! [ ] ;
655
+ state. last_retransmit_timeout_ms = ( last_retransmit_timeout
656
+ + ( 0 ..retransmits)
657
+ . fold ( Duration :: ZERO , |acc, i| acc + initial_rto * 2u32 . pow ( i) ) )
658
+ . as_millis ( ) as u64 ;
659
+ }
660
+ }
661
+ tracing:: error!( "timeouts: {:?}" , state. timeouts_ms) ;
662
+ tracing:: error!(
663
+ "last retransmit timeout: {:?}" ,
664
+ state. last_retransmit_timeout_ms
665
+ ) ;
666
+ }
667
+ }
626
668
}
627
669
628
670
/// Return value when handling possible STUN data
@@ -843,6 +885,115 @@ pub(crate) mod tests {
843
885
assert ! ( !agent. is_validated_peer( remote_addr) ) ;
844
886
}
845
887
888
+ #[ test]
889
+ fn request_custom_timeout ( ) {
890
+ let _log = crate :: tests:: test_init_log ( ) ;
891
+ let local_addr = "127.0.0.1:2000" . parse ( ) . unwrap ( ) ;
892
+ let remote_addr = "127.0.0.1:1000" . parse ( ) . unwrap ( ) ;
893
+ let mut agent = StunAgent :: builder ( TransportType :: Udp , local_addr)
894
+ . remote_addr ( remote_addr)
895
+ . build ( ) ;
896
+ let msg = Message :: builder_request ( BINDING ) ;
897
+ let transaction_id = msg. transaction_id ( ) ;
898
+ let mut now = Instant :: now ( ) ;
899
+ agent. send ( msg, remote_addr, now) . unwrap ( ) ;
900
+ let mut transaction = agent. mut_request_transaction ( transaction_id) . unwrap ( ) ;
901
+ transaction. configure_timeout ( Duration :: from_secs ( 1 ) , 2 , Duration :: from_secs ( 10 ) ) ;
902
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
903
+ unreachable ! ( ) ;
904
+ } ;
905
+ assert_eq ! ( wait - now, Duration :: from_secs( 1 ) ) ;
906
+ now = wait;
907
+ let StunAgentPollRet :: SendData ( _) = agent. poll ( now) else {
908
+ unreachable ! ( ) ;
909
+ } ;
910
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
911
+ unreachable ! ( ) ;
912
+ } ;
913
+ assert_eq ! ( wait - now, Duration :: from_secs( 2 ) ) ;
914
+ now = wait;
915
+ let StunAgentPollRet :: SendData ( _) = agent. poll ( now) else {
916
+ unreachable ! ( ) ;
917
+ } ;
918
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
919
+ unreachable ! ( ) ;
920
+ } ;
921
+ assert_eq ! ( wait - now, Duration :: from_secs( 10 ) ) ;
922
+ now = wait;
923
+ let StunAgentPollRet :: TransactionTimedOut ( timed_out) = agent. poll ( now) else {
924
+ unreachable ! ( ) ;
925
+ } ;
926
+ assert_eq ! ( timed_out, transaction_id) ;
927
+
928
+ assert ! ( agent. request_transaction( transaction_id) . is_none( ) ) ;
929
+ assert ! ( agent. mut_request_transaction( transaction_id) . is_none( ) ) ;
930
+
931
+ // unvalidated peer data should be dropped
932
+ assert ! ( !agent. is_validated_peer( remote_addr) ) ;
933
+ }
934
+
935
+ #[ test]
936
+ fn request_no_retransmit ( ) {
937
+ let _log = crate :: tests:: test_init_log ( ) ;
938
+ let local_addr = "127.0.0.1:2000" . parse ( ) . unwrap ( ) ;
939
+ let remote_addr = "127.0.0.1:1000" . parse ( ) . unwrap ( ) ;
940
+ let mut agent = StunAgent :: builder ( TransportType :: Udp , local_addr)
941
+ . remote_addr ( remote_addr)
942
+ . build ( ) ;
943
+ let msg = Message :: builder_request ( BINDING ) ;
944
+ let transaction_id = msg. transaction_id ( ) ;
945
+ let mut now = Instant :: now ( ) ;
946
+ agent. send ( msg, remote_addr, now) . unwrap ( ) ;
947
+ let mut transaction = agent. mut_request_transaction ( transaction_id) . unwrap ( ) ;
948
+ transaction. configure_timeout ( Duration :: from_secs ( 1 ) , 0 , Duration :: from_secs ( 10 ) ) ;
949
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
950
+ unreachable ! ( ) ;
951
+ } ;
952
+ assert_eq ! ( wait - now, Duration :: from_secs( 10 ) ) ;
953
+ now = wait;
954
+ let StunAgentPollRet :: TransactionTimedOut ( timed_out) = agent. poll ( now) else {
955
+ unreachable ! ( ) ;
956
+ } ;
957
+ assert_eq ! ( timed_out, transaction_id) ;
958
+
959
+ assert ! ( agent. request_transaction( transaction_id) . is_none( ) ) ;
960
+ assert ! ( agent. mut_request_transaction( transaction_id) . is_none( ) ) ;
961
+
962
+ // unvalidated peer data should be dropped
963
+ assert ! ( !agent. is_validated_peer( remote_addr) ) ;
964
+ }
965
+
966
+ #[ test]
967
+ fn request_tcp_custom_timeout ( ) {
968
+ let _log = crate :: tests:: test_init_log ( ) ;
969
+ let local_addr = "127.0.0.1:2000" . parse ( ) . unwrap ( ) ;
970
+ let remote_addr = "127.0.0.1:1000" . parse ( ) . unwrap ( ) ;
971
+ let mut agent = StunAgent :: builder ( TransportType :: Tcp , local_addr)
972
+ . remote_addr ( remote_addr)
973
+ . build ( ) ;
974
+ let msg = Message :: builder_request ( BINDING ) ;
975
+ let transaction_id = msg. transaction_id ( ) ;
976
+ let mut now = Instant :: now ( ) ;
977
+ agent. send ( msg, remote_addr, now) . unwrap ( ) ;
978
+ let mut transaction = agent. mut_request_transaction ( transaction_id) . unwrap ( ) ;
979
+ transaction. configure_timeout ( Duration :: from_secs ( 1 ) , 3 , Duration :: from_secs ( 3 ) ) ;
980
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
981
+ unreachable ! ( ) ;
982
+ } ;
983
+ assert_eq ! ( wait - now, Duration :: from_secs( 1 + 2 + 4 + 3 ) ) ;
984
+ now = wait;
985
+ let StunAgentPollRet :: TransactionTimedOut ( timed_out) = agent. poll ( now) else {
986
+ unreachable ! ( ) ;
987
+ } ;
988
+ assert_eq ! ( timed_out, transaction_id) ;
989
+
990
+ assert ! ( agent. request_transaction( transaction_id) . is_none( ) ) ;
991
+ assert ! ( agent. mut_request_transaction( transaction_id) . is_none( ) ) ;
992
+
993
+ // unvalidated peer data should be dropped
994
+ assert ! ( !agent. is_validated_peer( remote_addr) ) ;
995
+ }
996
+
846
997
#[ test]
847
998
fn request_without_credentials ( ) {
848
999
let _log = crate :: tests:: test_init_log ( ) ;
0 commit comments