@@ -174,7 +174,7 @@ impl StunAgent {
174
174
fields( stun_id = self . id)
175
175
) ]
176
176
fn validated_peer ( & mut self , addr : SocketAddr ) {
177
- if self . validated_peers . get ( & addr) . is_none ( ) {
177
+ if ! self . validated_peers . contains ( & addr) {
178
178
debug ! ( "validated peer {:?}" , addr) ;
179
179
self . validated_peers . insert ( addr) ;
180
180
}
@@ -502,6 +502,7 @@ struct StunRequestState {
502
502
from : SocketAddr ,
503
503
to : SocketAddr ,
504
504
timeouts_ms : Vec < u64 > ,
505
+ last_retransmit_timeout_ms : u64 ,
505
506
recv_cancelled : bool ,
506
507
send_cancelled : bool ,
507
508
timeout_i : usize ,
@@ -516,10 +517,10 @@ impl StunRequestState {
516
517
to : SocketAddr ,
517
518
) -> Self {
518
519
let data = request. build ( ) ;
519
- let timeouts_ms = if transport == TransportType :: Tcp {
520
- vec ! [ 39500 ]
520
+ let ( timeouts_ms, last_retransmit_timeout_ms ) = if transport == TransportType :: Tcp {
521
+ ( vec ! [ ] , 39500 )
521
522
} else {
522
- vec ! [ 500 , 1000 , 2000 , 4000 , 8000 , 16000 ]
523
+ ( vec ! [ 500 , 1000 , 2000 , 4000 , 8000 , 16000 ] , 8000 )
523
524
} ;
524
525
Self {
525
526
transaction_id : request. transaction_id ( ) ,
@@ -531,6 +532,7 @@ impl StunRequestState {
531
532
|| request. has_attribute ( MessageIntegritySha256 :: TYPE ) ,
532
533
timeouts_ms,
533
534
timeout_i : 0 ,
535
+ last_retransmit_timeout_ms,
534
536
recv_cancelled : false ,
535
537
send_cancelled : false ,
536
538
last_send_time : None ,
@@ -551,6 +553,10 @@ impl StunRequestState {
551
553
// TODO: account for TCP connect in timeout
552
554
if let Some ( last_send) = self . last_send_time {
553
555
if self . timeout_i >= self . timeouts_ms . len ( ) {
556
+ let next_send = last_send + Duration :: from_millis ( self . last_retransmit_timeout_ms ) ;
557
+ if next_send > now {
558
+ return StunRequestPollRet :: WaitUntil ( next_send) ;
559
+ }
554
560
return StunRequestPollRet :: TimedOut ;
555
561
}
556
562
let next_send = last_send + Duration :: from_millis ( self . timeouts_ms [ self . timeout_i ] ) ;
@@ -560,7 +566,7 @@ impl StunRequestState {
560
566
self . timeout_i += 1 ;
561
567
}
562
568
if self . send_cancelled {
563
- // this calcelaltion may need a different value
569
+ // this cancellation may need a different value
564
570
return StunRequestPollRet :: Cancelled ;
565
571
}
566
572
self . last_send_time = Some ( now) ;
@@ -623,6 +629,37 @@ impl<'a> StunRequestMut<'a> {
623
629
pub fn mut_agent ( & mut self ) -> & mut StunAgent {
624
630
self . agent
625
631
}
632
+
633
+ /// Configure timeouts for the STUN transaction. As specified in RFC 8489, `initial_rto`
634
+ /// should be >= 500ms, `retransmits` has a default value of 7, and `last_retransmit_timeout`
635
+ /// should be 16 * `initial_rto`.
636
+ ///
637
+ /// STUN transactions over TCP will only send a single request and have a timeout of the sum of
638
+ /// the timeouts of a UDP transaction.
639
+ pub fn configure_timeout (
640
+ & mut self ,
641
+ initial_rto : Duration ,
642
+ retransmits : u32 ,
643
+ last_retransmit_timeout : Duration ,
644
+ ) {
645
+ if let Some ( state) = self . agent . mut_request_state ( self . transaction_id ) {
646
+ match state. transport {
647
+ TransportType :: Udp => {
648
+ state. timeouts_ms = ( 0 ..retransmits)
649
+ . map ( |i| ( initial_rto * 2u32 . pow ( i) ) . as_millis ( ) as u64 )
650
+ . collect :: < Vec < _ > > ( ) ;
651
+ state. last_retransmit_timeout_ms = last_retransmit_timeout. as_millis ( ) as u64
652
+ }
653
+ TransportType :: Tcp => {
654
+ state. timeouts_ms = vec ! [ ] ;
655
+ state. last_retransmit_timeout_ms = ( last_retransmit_timeout
656
+ + ( 0 ..retransmits)
657
+ . fold ( Duration :: ZERO , |acc, i| acc + initial_rto * 2u32 . pow ( i) ) )
658
+ . as_millis ( ) as u64 ;
659
+ }
660
+ }
661
+ }
662
+ }
626
663
}
627
664
628
665
/// Return value when handling possible STUN data
@@ -843,6 +880,115 @@ pub(crate) mod tests {
843
880
assert ! ( !agent. is_validated_peer( remote_addr) ) ;
844
881
}
845
882
883
+ #[ test]
884
+ fn request_custom_timeout ( ) {
885
+ let _log = crate :: tests:: test_init_log ( ) ;
886
+ let local_addr = "127.0.0.1:2000" . parse ( ) . unwrap ( ) ;
887
+ let remote_addr = "127.0.0.1:1000" . parse ( ) . unwrap ( ) ;
888
+ let mut agent = StunAgent :: builder ( TransportType :: Udp , local_addr)
889
+ . remote_addr ( remote_addr)
890
+ . build ( ) ;
891
+ let msg = Message :: builder_request ( BINDING ) ;
892
+ let transaction_id = msg. transaction_id ( ) ;
893
+ let mut now = Instant :: now ( ) ;
894
+ agent. send ( msg, remote_addr, now) . unwrap ( ) ;
895
+ let mut transaction = agent. mut_request_transaction ( transaction_id) . unwrap ( ) ;
896
+ transaction. configure_timeout ( Duration :: from_secs ( 1 ) , 2 , Duration :: from_secs ( 10 ) ) ;
897
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
898
+ unreachable ! ( ) ;
899
+ } ;
900
+ assert_eq ! ( wait - now, Duration :: from_secs( 1 ) ) ;
901
+ now = wait;
902
+ let StunAgentPollRet :: SendData ( _) = agent. poll ( now) else {
903
+ unreachable ! ( ) ;
904
+ } ;
905
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
906
+ unreachable ! ( ) ;
907
+ } ;
908
+ assert_eq ! ( wait - now, Duration :: from_secs( 2 ) ) ;
909
+ now = wait;
910
+ let StunAgentPollRet :: SendData ( _) = agent. poll ( now) else {
911
+ unreachable ! ( ) ;
912
+ } ;
913
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
914
+ unreachable ! ( ) ;
915
+ } ;
916
+ assert_eq ! ( wait - now, Duration :: from_secs( 10 ) ) ;
917
+ now = wait;
918
+ let StunAgentPollRet :: TransactionTimedOut ( timed_out) = agent. poll ( now) else {
919
+ unreachable ! ( ) ;
920
+ } ;
921
+ assert_eq ! ( timed_out, transaction_id) ;
922
+
923
+ assert ! ( agent. request_transaction( transaction_id) . is_none( ) ) ;
924
+ assert ! ( agent. mut_request_transaction( transaction_id) . is_none( ) ) ;
925
+
926
+ // unvalidated peer data should be dropped
927
+ assert ! ( !agent. is_validated_peer( remote_addr) ) ;
928
+ }
929
+
930
+ #[ test]
931
+ fn request_no_retransmit ( ) {
932
+ let _log = crate :: tests:: test_init_log ( ) ;
933
+ let local_addr = "127.0.0.1:2000" . parse ( ) . unwrap ( ) ;
934
+ let remote_addr = "127.0.0.1:1000" . parse ( ) . unwrap ( ) ;
935
+ let mut agent = StunAgent :: builder ( TransportType :: Udp , local_addr)
936
+ . remote_addr ( remote_addr)
937
+ . build ( ) ;
938
+ let msg = Message :: builder_request ( BINDING ) ;
939
+ let transaction_id = msg. transaction_id ( ) ;
940
+ let mut now = Instant :: now ( ) ;
941
+ agent. send ( msg, remote_addr, now) . unwrap ( ) ;
942
+ let mut transaction = agent. mut_request_transaction ( transaction_id) . unwrap ( ) ;
943
+ transaction. configure_timeout ( Duration :: from_secs ( 1 ) , 0 , Duration :: from_secs ( 10 ) ) ;
944
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
945
+ unreachable ! ( ) ;
946
+ } ;
947
+ assert_eq ! ( wait - now, Duration :: from_secs( 10 ) ) ;
948
+ now = wait;
949
+ let StunAgentPollRet :: TransactionTimedOut ( timed_out) = agent. poll ( now) else {
950
+ unreachable ! ( ) ;
951
+ } ;
952
+ assert_eq ! ( timed_out, transaction_id) ;
953
+
954
+ assert ! ( agent. request_transaction( transaction_id) . is_none( ) ) ;
955
+ assert ! ( agent. mut_request_transaction( transaction_id) . is_none( ) ) ;
956
+
957
+ // unvalidated peer data should be dropped
958
+ assert ! ( !agent. is_validated_peer( remote_addr) ) ;
959
+ }
960
+
961
+ #[ test]
962
+ fn request_tcp_custom_timeout ( ) {
963
+ let _log = crate :: tests:: test_init_log ( ) ;
964
+ let local_addr = "127.0.0.1:2000" . parse ( ) . unwrap ( ) ;
965
+ let remote_addr = "127.0.0.1:1000" . parse ( ) . unwrap ( ) ;
966
+ let mut agent = StunAgent :: builder ( TransportType :: Tcp , local_addr)
967
+ . remote_addr ( remote_addr)
968
+ . build ( ) ;
969
+ let msg = Message :: builder_request ( BINDING ) ;
970
+ let transaction_id = msg. transaction_id ( ) ;
971
+ let mut now = Instant :: now ( ) ;
972
+ agent. send ( msg, remote_addr, now) . unwrap ( ) ;
973
+ let mut transaction = agent. mut_request_transaction ( transaction_id) . unwrap ( ) ;
974
+ transaction. configure_timeout ( Duration :: from_secs ( 1 ) , 3 , Duration :: from_secs ( 3 ) ) ;
975
+ let StunAgentPollRet :: WaitUntil ( wait) = agent. poll ( now) else {
976
+ unreachable ! ( ) ;
977
+ } ;
978
+ assert_eq ! ( wait - now, Duration :: from_secs( 1 + 2 + 4 + 3 ) ) ;
979
+ now = wait;
980
+ let StunAgentPollRet :: TransactionTimedOut ( timed_out) = agent. poll ( now) else {
981
+ unreachable ! ( ) ;
982
+ } ;
983
+ assert_eq ! ( timed_out, transaction_id) ;
984
+
985
+ assert ! ( agent. request_transaction( transaction_id) . is_none( ) ) ;
986
+ assert ! ( agent. mut_request_transaction( transaction_id) . is_none( ) ) ;
987
+
988
+ // unvalidated peer data should be dropped
989
+ assert ! ( !agent. is_validated_peer( remote_addr) ) ;
990
+ }
991
+
846
992
#[ test]
847
993
fn request_without_credentials ( ) {
848
994
let _log = crate :: tests:: test_init_log ( ) ;
0 commit comments