@@ -1468,15 +1468,7 @@ impl<'a> MessageBuilder<'a> {
1468
1468
. map ( |attr| attr. padded_len ( ) )
1469
1469
. sum :: < usize > ( ) ;
1470
1470
let mut ret = vec ! [ 0 ; MessageHeader :: LENGTH + attr_size] ;
1471
- self . msg_type . write_into ( & mut ret[ ..2 ] ) ;
1472
- let transaction: u128 = self . transaction_id . into ( ) ;
1473
- let tid = ( MAGIC_COOKIE as u128 ) << 96 | transaction & 0xffff_ffff_ffff_ffff_ffff_ffff ;
1474
- BigEndian :: write_u128 ( & mut ret[ 4 ..20 ] , tid) ;
1475
- BigEndian :: write_u16 ( & mut ret[ 2 ..4 ] , attr_size as u16 ) ;
1476
- let mut offset = MessageHeader :: LENGTH ;
1477
- for attr in & self . attributes {
1478
- offset += attr. write_into ( & mut ret[ offset..] ) . unwrap ( ) ;
1479
- }
1471
+ let _ = self . write_into ( & mut ret) ;
1480
1472
ret
1481
1473
}
1482
1474
@@ -1574,16 +1566,32 @@ impl<'a> MessageBuilder<'a> {
1574
1566
credentials : & MessageIntegrityCredentials ,
1575
1567
algorithm : IntegrityAlgorithm ,
1576
1568
) -> Result < ( ) , StunWriteError > {
1577
- if self . has_attribute ( MessageIntegrity :: TYPE ) && algorithm == IntegrityAlgorithm :: Sha1 {
1578
- return Err ( StunWriteError :: AttributeExists ( MessageIntegrity :: TYPE ) ) ;
1579
- }
1580
- if self . has_attribute ( MessageIntegritySha256 :: TYPE ) {
1581
- return Err ( StunWriteError :: AttributeExists (
1582
- MessageIntegritySha256 :: TYPE ,
1583
- ) ) ;
1569
+ let mut atypes = [ AttributeType :: new ( 0 ) ; 3 ] ;
1570
+ let mut i = 0 ;
1571
+ atypes[ i] = match algorithm {
1572
+ IntegrityAlgorithm :: Sha1 => MessageIntegrity :: TYPE ,
1573
+ IntegrityAlgorithm :: Sha256 => MessageIntegritySha256 :: TYPE ,
1574
+ } ;
1575
+ i += 1 ;
1576
+ if algorithm == IntegrityAlgorithm :: Sha1 {
1577
+ atypes[ i] = MessageIntegritySha256 :: TYPE ;
1578
+ i += 1 ;
1584
1579
}
1585
- if self . has_attribute ( Fingerprint :: TYPE ) {
1586
- return Err ( StunWriteError :: FingerprintExists ) ;
1580
+ atypes[ i] = Fingerprint :: TYPE ;
1581
+ i += 1 ;
1582
+
1583
+ match self . has_any_attribute ( & atypes[ ..i] ) {
1584
+ // can't validly add generic attributes after message integrity or fingerprint
1585
+ Some ( MessageIntegrity :: TYPE ) => {
1586
+ return Err ( StunWriteError :: AttributeExists ( MessageIntegrity :: TYPE ) )
1587
+ }
1588
+ Some ( MessageIntegritySha256 :: TYPE ) => {
1589
+ return Err ( StunWriteError :: AttributeExists (
1590
+ MessageIntegritySha256 :: TYPE ,
1591
+ ) ) ;
1592
+ }
1593
+ Some ( Fingerprint :: TYPE ) => return Err ( StunWriteError :: FingerprintExists ) ,
1594
+ _ => ( ) ,
1587
1595
}
1588
1596
1589
1597
self . add_message_integrity_unchecked ( credentials, algorithm) ;
@@ -1720,18 +1728,20 @@ impl<'a> MessageBuilder<'a> {
1720
1728
}
1721
1729
_ => ( ) ,
1722
1730
}
1723
- if self . has_attribute ( ty) {
1724
- return Err ( StunWriteError :: AttributeExists ( ty) ) ;
1725
- }
1726
- // can't validly add generic attributes after message integrity or fingerprint
1727
- if self . has_attribute ( MessageIntegrity :: TYPE ) {
1728
- return Err ( StunWriteError :: MessageIntegrityExists ) ;
1729
- }
1730
- if self . has_attribute ( MessageIntegritySha256 :: TYPE ) {
1731
- return Err ( StunWriteError :: MessageIntegrityExists ) ;
1732
- }
1733
- if self . has_attribute ( Fingerprint :: TYPE ) {
1734
- return Err ( StunWriteError :: FingerprintExists ) ;
1731
+ match self . has_any_attribute ( & [
1732
+ ty,
1733
+ MessageIntegrity :: TYPE ,
1734
+ MessageIntegritySha256 :: TYPE ,
1735
+ Fingerprint :: TYPE ,
1736
+ ] ) {
1737
+ // can't validly add generic attributes after message integrity or fingerprint
1738
+ Some ( MessageIntegrity :: TYPE ) => return Err ( StunWriteError :: MessageIntegrityExists ) ,
1739
+ Some ( MessageIntegritySha256 :: TYPE ) => {
1740
+ return Err ( StunWriteError :: MessageIntegrityExists )
1741
+ }
1742
+ Some ( Fingerprint :: TYPE ) => return Err ( StunWriteError :: FingerprintExists ) ,
1743
+ Some ( typ) if typ == ty => return Err ( StunWriteError :: AttributeExists ( ty) ) ,
1744
+ _ => ( ) ,
1735
1745
}
1736
1746
self . attributes . push ( AttrOrRaw :: Attr ( attr) ) ;
1737
1747
self . attribute_types . push ( ty) ;
@@ -1777,18 +1787,20 @@ impl<'a> MessageBuilder<'a> {
1777
1787
}
1778
1788
_ => ( ) ,
1779
1789
}
1780
- if self . has_attribute ( ty) {
1781
- return Err ( StunWriteError :: AttributeExists ( ty) ) ;
1782
- }
1783
- // can't validly add generic attributes after message integrity or fingerprint
1784
- if self . has_attribute ( MessageIntegrity :: TYPE ) {
1785
- return Err ( StunWriteError :: MessageIntegrityExists ) ;
1786
- }
1787
- if self . has_attribute ( MessageIntegritySha256 :: TYPE ) {
1788
- return Err ( StunWriteError :: MessageIntegrityExists ) ;
1789
- }
1790
- if self . has_attribute ( Fingerprint :: TYPE ) {
1791
- return Err ( StunWriteError :: FingerprintExists ) ;
1790
+ match self . has_any_attribute ( & [
1791
+ ty,
1792
+ MessageIntegrity :: TYPE ,
1793
+ MessageIntegritySha256 :: TYPE ,
1794
+ Fingerprint :: TYPE ,
1795
+ ] ) {
1796
+ // can't validly add generic attributes after message integrity or fingerprint
1797
+ Some ( MessageIntegrity :: TYPE ) => return Err ( StunWriteError :: MessageIntegrityExists ) ,
1798
+ Some ( MessageIntegritySha256 :: TYPE ) => {
1799
+ return Err ( StunWriteError :: MessageIntegrityExists )
1800
+ }
1801
+ Some ( Fingerprint :: TYPE ) => return Err ( StunWriteError :: FingerprintExists ) ,
1802
+ Some ( typ) if typ == ty => return Err ( StunWriteError :: AttributeExists ( ty) ) ,
1803
+ _ => ( ) ,
1792
1804
}
1793
1805
self . attributes . push ( AttrOrRaw :: Raw ( attr) ) ;
1794
1806
self . attribute_types . push ( ty) ;
@@ -1799,6 +1811,15 @@ impl<'a> MessageBuilder<'a> {
1799
1811
pub fn has_attribute ( & self , atype : AttributeType ) -> bool {
1800
1812
self . attribute_types . iter ( ) . any ( |& ty| ty == atype)
1801
1813
}
1814
+
1815
+ /// Return whether this [`MessageBuilder`] contains any of the provided attributes and
1816
+ /// returns the attribute found.
1817
+ pub fn has_any_attribute ( & self , atypes : & [ AttributeType ] ) -> Option < AttributeType > {
1818
+ self . attribute_types
1819
+ . iter ( )
1820
+ . find ( |& ty| atypes. contains ( ty) )
1821
+ . cloned ( )
1822
+ }
1802
1823
}
1803
1824
1804
1825
#[ cfg( test) ]
0 commit comments