diff --git a/stun-types/src/message.rs b/stun-types/src/message.rs index 77fd27e..9cd0ab5 100644 --- a/stun-types/src/message.rs +++ b/stun-types/src/message.rs @@ -1468,15 +1468,7 @@ impl<'a> MessageBuilder<'a> { .map(|attr| attr.padded_len()) .sum::(); let mut ret = vec![0; MessageHeader::LENGTH + attr_size]; - self.msg_type.write_into(&mut ret[..2]); - let transaction: u128 = self.transaction_id.into(); - let tid = (MAGIC_COOKIE as u128) << 96 | transaction & 0xffff_ffff_ffff_ffff_ffff_ffff; - BigEndian::write_u128(&mut ret[4..20], tid); - BigEndian::write_u16(&mut ret[2..4], attr_size as u16); - let mut offset = MessageHeader::LENGTH; - for attr in &self.attributes { - offset += attr.write_into(&mut ret[offset..]).unwrap(); - } + let _ = self.write_into(&mut ret); ret } @@ -1574,16 +1566,32 @@ impl<'a> MessageBuilder<'a> { credentials: &MessageIntegrityCredentials, algorithm: IntegrityAlgorithm, ) -> Result<(), StunWriteError> { - if self.has_attribute(MessageIntegrity::TYPE) && algorithm == IntegrityAlgorithm::Sha1 { - return Err(StunWriteError::AttributeExists(MessageIntegrity::TYPE)); - } - if self.has_attribute(MessageIntegritySha256::TYPE) { - return Err(StunWriteError::AttributeExists( - MessageIntegritySha256::TYPE, - )); + let mut atypes = [AttributeType::new(0); 3]; + let mut i = 0; + atypes[i] = match algorithm { + IntegrityAlgorithm::Sha1 => MessageIntegrity::TYPE, + IntegrityAlgorithm::Sha256 => MessageIntegritySha256::TYPE, + }; + i += 1; + if algorithm == IntegrityAlgorithm::Sha1 { + atypes[i] = MessageIntegritySha256::TYPE; + i += 1; } - if self.has_attribute(Fingerprint::TYPE) { - return Err(StunWriteError::FingerprintExists); + atypes[i] = Fingerprint::TYPE; + i += 1; + + match self.has_any_attribute(&atypes[..i]) { + // can't validly add generic attributes after message integrity or fingerprint + Some(MessageIntegrity::TYPE) => { + return Err(StunWriteError::AttributeExists(MessageIntegrity::TYPE)) + } + Some(MessageIntegritySha256::TYPE) => { + return Err(StunWriteError::AttributeExists( + MessageIntegritySha256::TYPE, + )); + } + Some(Fingerprint::TYPE) => return Err(StunWriteError::FingerprintExists), + _ => (), } self.add_message_integrity_unchecked(credentials, algorithm); @@ -1720,18 +1728,20 @@ impl<'a> MessageBuilder<'a> { } _ => (), } - if self.has_attribute(ty) { - return Err(StunWriteError::AttributeExists(ty)); - } - // can't validly add generic attributes after message integrity or fingerprint - if self.has_attribute(MessageIntegrity::TYPE) { - return Err(StunWriteError::MessageIntegrityExists); - } - if self.has_attribute(MessageIntegritySha256::TYPE) { - return Err(StunWriteError::MessageIntegrityExists); - } - if self.has_attribute(Fingerprint::TYPE) { - return Err(StunWriteError::FingerprintExists); + match self.has_any_attribute(&[ + ty, + MessageIntegrity::TYPE, + MessageIntegritySha256::TYPE, + Fingerprint::TYPE, + ]) { + // can't validly add generic attributes after message integrity or fingerprint + Some(MessageIntegrity::TYPE) => return Err(StunWriteError::MessageIntegrityExists), + Some(MessageIntegritySha256::TYPE) => { + return Err(StunWriteError::MessageIntegrityExists) + } + Some(Fingerprint::TYPE) => return Err(StunWriteError::FingerprintExists), + Some(typ) if typ == ty => return Err(StunWriteError::AttributeExists(ty)), + _ => (), } self.attributes.push(AttrOrRaw::Attr(attr)); self.attribute_types.push(ty); @@ -1777,18 +1787,20 @@ impl<'a> MessageBuilder<'a> { } _ => (), } - if self.has_attribute(ty) { - return Err(StunWriteError::AttributeExists(ty)); - } - // can't validly add generic attributes after message integrity or fingerprint - if self.has_attribute(MessageIntegrity::TYPE) { - return Err(StunWriteError::MessageIntegrityExists); - } - if self.has_attribute(MessageIntegritySha256::TYPE) { - return Err(StunWriteError::MessageIntegrityExists); - } - if self.has_attribute(Fingerprint::TYPE) { - return Err(StunWriteError::FingerprintExists); + match self.has_any_attribute(&[ + ty, + MessageIntegrity::TYPE, + MessageIntegritySha256::TYPE, + Fingerprint::TYPE, + ]) { + // can't validly add generic attributes after message integrity or fingerprint + Some(MessageIntegrity::TYPE) => return Err(StunWriteError::MessageIntegrityExists), + Some(MessageIntegritySha256::TYPE) => { + return Err(StunWriteError::MessageIntegrityExists) + } + Some(Fingerprint::TYPE) => return Err(StunWriteError::FingerprintExists), + Some(typ) if typ == ty => return Err(StunWriteError::AttributeExists(ty)), + _ => (), } self.attributes.push(AttrOrRaw::Raw(attr)); self.attribute_types.push(ty); @@ -1799,6 +1811,15 @@ impl<'a> MessageBuilder<'a> { pub fn has_attribute(&self, atype: AttributeType) -> bool { self.attribute_types.iter().any(|&ty| ty == atype) } + + /// Return whether this [`MessageBuilder`] contains any of the provided attributes and + /// returns the attribute found. + pub fn has_any_attribute(&self, atypes: &[AttributeType]) -> Option { + self.attribute_types + .iter() + .find(|&ty| atypes.contains(ty)) + .cloned() + } } #[cfg(test)]