From 67a9f4e27dfcd178bb8c2255f47aede043831b78 Mon Sep 17 00:00:00 2001 From: Matthew Waters Date: Sun, 23 Jun 2024 22:24:00 +1000 Subject: [PATCH] message: implement parsing without ever needing to copy any bytes --- fuzz/fuzz_targets/stun_msg_from_bytes.rs | 5 +- stun-proto/src/agent.rs | 634 +++------- stun-proto/src/lib.rs | 13 +- stun-types/src/attribute/address.rs | 8 +- stun-types/src/attribute/alternate.rs | 20 +- stun-types/src/attribute/error.rs | 28 +- stun-types/src/attribute/fingerprint.rs | 14 +- stun-types/src/attribute/ice.rs | 42 +- stun-types/src/attribute/integrity.rs | 22 +- stun-types/src/attribute/mod.rs | 74 +- stun-types/src/attribute/nonce.rs | 10 +- .../src/attribute/password_algorithm.rs | 24 +- stun-types/src/attribute/realm.rs | 10 +- stun-types/src/attribute/software.rs | 10 +- stun-types/src/attribute/user.rs | 18 +- stun-types/src/attribute/xor_addr.rs | 10 +- stun-types/src/data.rs | 141 +++ stun-types/src/lib.rs | 1 + stun-types/src/message.rs | 1024 ++++++++++------- 19 files changed, 1080 insertions(+), 1028 deletions(-) create mode 100644 stun-types/src/data.rs diff --git a/fuzz/fuzz_targets/stun_msg_from_bytes.rs b/fuzz/fuzz_targets/stun_msg_from_bytes.rs index 07330a7..ab44d9b 100644 --- a/fuzz/fuzz_targets/stun_msg_from_bytes.rs +++ b/fuzz/fuzz_targets/stun_msg_from_bytes.rs @@ -29,8 +29,7 @@ fuzz_target!(|data_and_credentials: DataAndCredentials| { debug_init(); let msg = Message::from_bytes(data_and_credentials.data); debug!("generated {:?}", msg); - let integrity_result = msg.and_then(|msg| { - msg.validate_integrity(data_and_credentials.data, &data_and_credentials.credentials) - }); + let integrity_result = + msg.and_then(|msg| msg.validate_integrity(&data_and_credentials.credentials)); debug!("integrity result {:?}", integrity_result); }); diff --git a/stun-proto/src/agent.rs b/stun-proto/src/agent.rs index 3606875..af3e80a 100644 --- a/stun-proto/src/agent.rs +++ b/stun-proto/src/agent.rs @@ -25,6 +25,7 @@ use std::collections::{HashMap, HashSet}; use byteorder::{BigEndian, ByteOrder}; use stun_types::attribute::*; +use stun_types::data::Data; use stun_types::message::*; use crate::DebugWrapper; @@ -46,7 +47,6 @@ pub struct StunAgent { outstanding_requests: HashMap, local_credentials: Option, remote_credentials: Option, - tcp_buffer: Option, } /// Builder struct for a [`StunAgent`] @@ -66,10 +66,6 @@ impl StunAgentBuilder { /// Build the [`StunAgent`] pub fn build(self) -> StunAgent { let id = STUN_AGENT_COUNT.fetch_add(1, Ordering::SeqCst); - let tcp_buffer = match self.transport { - TransportType::Udp => None, - TransportType::Tcp => Some(TcpBuffer::new()), - }; StunAgent { id, transport: self.transport, @@ -79,7 +75,6 @@ impl StunAgentBuilder { outstanding_requests: Default::default(), local_credentials: None, remote_credentials: None, - tcp_buffer, } } } @@ -137,8 +132,11 @@ impl StunAgent { /// Perform any operations needed to be able to send a [`Message`] to a peer. /// /// If a request message is successfully sent, then [`StunAgent::poll`] needs to be called. - pub fn send(&mut self, msg: Message, to: SocketAddr) -> Result, StunError> { - let data = msg.to_bytes(); + pub fn send( + &mut self, + msg: MessageBuilder<'_>, + to: SocketAddr, + ) -> Result, StunError> { if msg.has_class(MessageClass::Request) { if self .outstanding_requests @@ -155,78 +153,18 @@ impl StunAgent { self.outstanding_requests.insert(transaction_id, state); return Ok(transmit); } + let data = msg.build(); Ok(self.send_data(&data, to).into_owned()) } - fn parse_chunk( - &mut self, - data: &[u8], - from: SocketAddr, - ) -> Result, StunError> { - match Message::from_bytes(data) { - Ok(stun_msg) => { - debug!("received stun {}", stun_msg); - self.handle_stun(stun_msg, data, from) - } - Err(_) => { - let peer_validated = { self.validated_peers.contains(&from) }; - if peer_validated { - Ok(Some(HandleStunReply::Data(data.to_vec()))) - } else if self.transport == TransportType::Tcp { - // close the tcp channel - warn!("stun message not the first message sent over TCP channel, closing"); - Err(StunError::ProtocolViolation) - } else { - trace!("dropping unvalidated data from peer"); - Ok(None) - } - } - } - } - - /// Provide data received on a socket from a peer for handling by the [`StunAgent`]. - /// The returned value indicates what the caller must do with the data. + /// Returns whether this agent has received or send a STUN message to this peer. Failure may + /// be the result of an attacker and the caller must drop any non-STUN data received before this + /// functions returns `true`. /// - /// If this function returns [`HandleStunReply::StunResponse`], then this agent needs to be - /// `poll()`ed again. - #[tracing::instrument( - name = "stun_incoming_data" - level = "info", - skip(self, data), - fields( - stun_id = self.id, - to = ?self.local_addr() - ) - )] - pub fn handle_incoming_data( - &mut self, - data: &[u8], - from: SocketAddr, - ) -> Result, StunError> { - match self.transport { - TransportType::Udp => { - if let Some(reply) = self.parse_chunk(data, from)? { - Ok(vec![reply]) - } else { - Ok(vec![]) - } - } - TransportType::Tcp => { - let mut ret = vec![]; - let tcp = self.tcp_buffer.as_mut().unwrap(); - tcp.push_data(data); - let mut datas = vec![]; - while let Some(data) = tcp.pull_data() { - datas.push(data); - } - for data in datas { - if let Some(reply) = self.parse_chunk(&data, from)? { - ret.push(reply); - } - } - Ok(ret) - } - } + /// If non-STUN data is received over a TCP connection from an unvalidated peer, the caller + /// must immediately close the TCP connection. + pub fn is_validated_peer(&self, remote_addr: SocketAddr) -> bool { + self.validated_peers.contains(&remote_addr) } #[tracing::instrument( @@ -241,53 +179,53 @@ impl StunAgent { } } + /// Provide data received on a socket from a peer for handling by the [`StunAgent`]. + /// The returned value indicates what the caller must do with the data. + /// + /// If this function returns [`HandleStunReply::StunResponse`], then this agent needs to be + /// `poll()`ed again. #[tracing::instrument( name = "stun_handle_message" - skip(self, msg, orig_data, from), + skip(self, msg, from), fields( transaction_id = %msg.transaction_id(), ) )] - fn handle_stun( - &mut self, - msg: Message, - orig_data: &[u8], - from: SocketAddr, - ) -> Result, StunError> { + pub fn handle_stun<'a>(&mut self, msg: Message<'a>, from: SocketAddr) -> HandleStunReply<'a> { if msg.is_response() { let Some(request) = self.take_outstanding_request(&msg.transaction_id()) else { trace!("original request disappeared -> ignoring response"); - return Ok(None); + return HandleStunReply::Drop; }; // only validate response if the original request had credentials - if request.msg.has_attribute(MessageIntegrity::TYPE) { + if request.request_had_credentials { if let Some(remote_creds) = &self.remote_credentials { - match msg.validate_integrity(orig_data, remote_creds) { + match msg.validate_integrity(remote_creds) { Ok(_) => { self.validated_peer(from); - Ok(Some(HandleStunReply::StunResponse(request.msg, msg))) + HandleStunReply::StunResponse(msg) } Err(e) => { debug!("message failed integrity check: {:?}", e); self.outstanding_requests .insert(msg.transaction_id(), request); - Ok(None) + HandleStunReply::Drop } } } else { debug!("no remote credentials, ignoring"); self.outstanding_requests .insert(msg.transaction_id(), request); - Ok(None) + HandleStunReply::Drop } } else { // original message didn't have integrity, reply doesn't need to either self.validated_peer(from); - Ok(Some(HandleStunReply::StunResponse(request.msg, msg))) + HandleStunReply::StunResponse(msg) } } else { self.validated_peer(from); - Ok(Some(HandleStunReply::IncomingStun(msg))) + HandleStunReply::IncomingStun(msg) } } @@ -363,7 +301,7 @@ impl StunAgent { let mut timeout = None; let mut cancelled = None; for request in self.outstanding_requests.values_mut() { - let transaction_id = request.msg.transaction_id(); + let transaction_id = request.transaction_id; match request.poll(now) { StunRequestPollRet::Cancelled => { cancelled = Some(transaction_id); @@ -384,13 +322,13 @@ impl StunAgent { } } if let Some(transaction) = timeout { - if let Some(state) = self.outstanding_requests.remove(&transaction) { - return StunAgentPollRet::TransactionTimedOut(state.msg); + if let Some(_state) = self.outstanding_requests.remove(&transaction) { + return StunAgentPollRet::TransactionTimedOut(transaction); } } if let Some(transaction) = cancelled { - if let Some(state) = self.outstanding_requests.remove(&transaction) { - return StunAgentPollRet::TransactionCancelled(state.msg); + if let Some(_state) = self.outstanding_requests.remove(&transaction) { + return StunAgentPollRet::TransactionCancelled(transaction); } } StunAgentPollRet::WaitUntil(lowest_wait) @@ -401,9 +339,9 @@ impl StunAgent { #[derive(Debug)] pub enum StunAgentPollRet<'a> { /// An oustanding transaction timed out and has been removed from the agent. - TransactionTimedOut(Message), + TransactionTimedOut(TransactionId), /// An oustanding transaction was cancelled and has been removed from the agent. - TransactionCancelled(Message), + TransactionCancelled(TransactionId), /// Send data using the specified 5-tuple SendData(Transmit<'a>), /// Wait until the specified time has passed @@ -423,23 +361,29 @@ fn send_data(transport: TransportType, bytes: &[u8], from: SocketAddr, to: Socke } } +/// A buffer object for handling STUN data received over a TCP connection that requires framing as +/// specified in RFC 4571. This framing is required for ICE usage of TCP candidates. #[derive(Debug)] -struct TcpBuffer { +pub struct TcpBuffer { buf: DebugWrapper>, } impl TcpBuffer { - fn new() -> Self { + /// Construct a new [`TcpBuffer`] + pub fn new() -> Self { Self { buf: DebugWrapper::wrap(vec![], "..."), } } - fn push_data(&mut self, data: &[u8]) { + /// Push a chunk of received data into the buffer. + pub fn push_data(&mut self, data: &[u8]) { self.buf.extend(data); } - fn pull_data(&mut self) -> Option> { + /// Pull the next chunk of data from the buffer. If no buffer is available, then None is + /// returned. + pub fn pull_data(&mut self) -> Option> { if self.buf.len() < 2 { trace!( "running buffer is currently too small ({} bytes) to provide data", @@ -474,105 +418,9 @@ impl TcpBuffer { } } -/// A slice of data -#[derive(Debug)] -#[repr(transparent)] -pub struct DataSlice<'a>(&'a [u8]); - -impl<'a> DataSlice<'a> { - pub fn take(self) -> &'a [u8] { - self.0 - } - - pub fn to_owned(&self) -> DataOwned { - DataOwned(self.0.into()) - } -} - -impl<'a> std::ops::Deref for DataSlice<'a> { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - self.0 - } -} - -impl<'a> From> for &'a [u8] { - fn from(value: DataSlice<'a>) -> Self { - value.0 - } -} - -impl<'a> From<&'a [u8]> for DataSlice<'a> { - fn from(value: &'a [u8]) -> Self { - Self(value) - } -} - -/// An owned piece of data -#[derive(Debug)] -#[repr(transparent)] -pub struct DataOwned(Box<[u8]>); - -impl DataOwned { - pub fn take(self) -> Box<[u8]> { - self.0 - } -} - -impl std::ops::Deref for DataOwned { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl From for Box<[u8]> { - fn from(value: DataOwned) -> Self { - value.0 - } -} - -impl From> for DataOwned { - fn from(value: Box<[u8]>) -> Self { - Self(value) - } -} - -/// An owned or borrowed piece of data -#[derive(Debug)] -pub enum Data<'a> { - Borrowed(DataSlice<'a>), - Owned(DataOwned), -} - -impl<'a> std::ops::Deref for Data<'a> { - type Target = [u8]; - fn deref(&self) -> &Self::Target { - match self { - Self::Borrowed(data) => data.0, - Self::Owned(data) => &data.0, - } - } -} - -impl<'a> Data<'a> { - fn into_owned<'b>(self) -> Data<'b> { - match self { - Self::Borrowed(data) => Data::Owned(data.to_owned()), - Self::Owned(data) => Data::Owned(data), - } - } -} - -impl<'a> From<&'a [u8]> for Data<'a> { - fn from(value: &'a [u8]) -> Self { - Self::Borrowed(value.into()) - } -} - -impl<'a> From> for Data<'a> { - fn from(value: Box<[u8]>) -> Self { - Self::Owned(value.into()) +impl Default for TcpBuffer { + fn default() -> Self { + Self::new() } } @@ -590,6 +438,7 @@ pub struct Transmit<'a> { } impl<'a> Transmit<'a> { + /// Construct a new [`Transmit`] with the specifid data and 5-tuple. pub fn new( data: impl Into>, transport: TransportType, @@ -604,6 +453,7 @@ impl<'a> Transmit<'a> { } } + /// Construct a new [`Transmit`] with the specifid 5-tuple and data converted to owned. pub fn new_owned( data: impl Into>, transport: TransportType, @@ -618,6 +468,7 @@ impl<'a> Transmit<'a> { } } + /// Consume this [`Transmit`] and produce and owned version. pub fn into_owned<'b>(self) -> Transmit<'b> { Transmit { data: self.data.into_owned(), @@ -651,7 +502,8 @@ enum StunRequestPollRet<'a> { #[derive(Debug)] struct StunRequestState { - msg: Message, + transaction_id: TransactionId, + request_had_credentials: bool, bytes: Vec, transport: TransportType, from: SocketAddr, @@ -664,19 +516,26 @@ struct StunRequestState { } impl StunRequestState { - fn new(request: Message, transport: TransportType, from: SocketAddr, to: SocketAddr) -> Self { - let data = request.to_bytes(); + fn new( + request: MessageBuilder<'_>, + transport: TransportType, + from: SocketAddr, + to: SocketAddr, + ) -> Self { + let data = request.build(); let timeouts_ms = if transport == TransportType::Tcp { vec![39500] } else { vec![500, 1000, 2000, 4000, 8000, 16000] }; Self { - msg: request, + transaction_id: request.transaction_id(), bytes: data, transport, from, to, + request_had_credentials: request.has_attribute(MessageIntegrity::TYPE) + || request.has_attribute(MessageIntegritySha256::TYPE), timeouts_ms, timeout_i: 0, recv_cancelled: false, @@ -690,7 +549,7 @@ impl StunRequestState { level = "info", ret, skip(self), - fields(transaction_id = %self.msg.transaction_id()), + fields(transaction_id = %self.transaction_id), )] fn poll(&mut self, now: Instant) -> StunRequestPollRet { if self.recv_cancelled { @@ -726,12 +585,6 @@ pub struct StunRequest<'a> { } impl<'a> StunRequest<'a> { - /// The request [`Message`] - pub fn request(&self) -> &Message { - let state = self.agent.request_state(self.transaction_id).unwrap(); - &state.msg - } - /// The remote address the request is sent to pub fn peer_address(&self) -> SocketAddr { let state = self.agent.request_state(self.transaction_id).unwrap(); @@ -747,12 +600,6 @@ pub struct StunRequestMut<'a> { } impl<'a> StunRequestMut<'a> { - /// The request [`Message`] - pub fn request(&self) -> &Message { - let state = self.agent.request_state(self.transaction_id).unwrap(); - &state.msg - } - /// The remote address the request is sent to pub fn peer_address(&self) -> SocketAddr { let state = self.agent.request_state(self.transaction_id).unwrap(); @@ -787,13 +634,13 @@ impl<'a> StunRequestMut<'a> { /// Return value when handling possible STUN data #[derive(Debug)] -pub enum HandleStunReply { +pub enum HandleStunReply<'a> { /// The provided data could be parsed as a response to an outstanding request - StunResponse(Message, Message), + StunResponse(Message<'a>), /// The provided data could be parsed as a STUN message - IncomingStun(Message), - /// The provided data could not be parsed as a STUN message - Data(Vec), + IncomingStun(Message<'a>), + /// Drop this message. + Drop, } /// STUN errors @@ -875,7 +722,7 @@ pub(crate) mod tests { let mut agent = StunAgent::builder(TransportType::Udp, local_addr) .remote_addr(remote_addr) .build(); - let msg = Message::new_request(BINDING); + let msg = Message::builder_request(BINDING); let transaction_id = msg.transaction_id(); let transmit = agent.send(msg, remote_addr).unwrap(); let now = Instant::now(); @@ -883,10 +730,11 @@ pub(crate) mod tests { assert_eq!(transmit.from, local_addr); assert_eq!(transmit.to, remote_addr); let request = Message::from_bytes(&transmit.data).unwrap(); - let response = Message::new_error(&request); - let resp_data = response.to_bytes(); - let ret = agent.handle_incoming_data(&resp_data, remote_addr).unwrap(); - assert!(matches!(ret[0], HandleStunReply::StunResponse(_, _))); + let response = Message::builder_error(&request); + let resp_data = response.build(); + let response = Message::from_bytes(&resp_data).unwrap(); + let ret = agent.handle_stun(response, remote_addr); + assert!(matches!(ret, HandleStunReply::StunResponse(_))); assert!(agent.request_transaction(transaction_id).is_none()); assert!(agent.mut_request_transaction(transaction_id).is_none()); @@ -903,7 +751,7 @@ pub(crate) mod tests { .remote_addr(remote_addr) .build(); let transaction_id = TransactionId::generate(); - let msg = Message::new( + let msg = Message::builder( MessageType::from_class_method(MessageClass::Indication, BINDING), transaction_id, ); @@ -915,14 +763,15 @@ pub(crate) mod tests { assert!(agent.request_transaction(transaction_id).is_none()); assert!(agent.mut_request_transaction(transaction_id).is_none()); // you should definitely never do this ;). Indications should never get replies. - let response = Message::new( + let response = Message::builder( MessageType::from_class_method(MessageClass::Error, BINDING), transaction_id, ); - let resp_data = response.to_bytes(); + let resp_data = response.build(); + let response = Message::from_bytes(&resp_data).unwrap(); // response without a request is dropped. - let ret = agent.handle_incoming_data(&resp_data, remote_addr).unwrap(); - assert!(ret.is_empty()); + let ret = agent.handle_stun(response, remote_addr); + assert!(matches!(ret, HandleStunReply::Drop)); } #[test] @@ -938,21 +787,22 @@ pub(crate) mod tests { agent.set_remote_credentials(remote_credentials.clone().into()); // unvalidated peer data should be dropped - let data = vec![20; 4]; - let replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - assert!(replies.is_empty()); + assert!(!agent.is_validated_peer(remote_addr)); - let mut msg = Message::new_request(BINDING); + let mut msg = Message::builder_request(BINDING); let transaction_id = msg.transaction_id(); msg.add_message_integrity(&local_credentials.clone().into(), IntegrityAlgorithm::Sha1) .unwrap(); + println!("send"); let transmit = agent.send(msg, remote_addr).unwrap(); + println!("sent"); let request = Message::from_bytes(&transmit.data).unwrap(); - let mut response = Message::new_success(&request); + println!("generate response"); + let mut response = Message::builder_success(&request); response - .add_attribute(XorMappedAddress::new( + .add_attribute(&XorMappedAddress::new( transmit.from, request.transaction_id(), )) @@ -960,25 +810,22 @@ pub(crate) mod tests { response .add_message_integrity(&remote_credentials.into(), IntegrityAlgorithm::Sha1) .unwrap(); + println!("{response:?}"); - let data = response.to_bytes(); + let data = response.build(); + println!("{data:?}"); let to = transmit.to; - let mut reply = agent.handle_incoming_data(&data, to).unwrap(); - let HandleStunReply::StunResponse(request, response) = reply.remove(0) else { + let response = Message::from_bytes(&data).unwrap(); + println!("{response}"); + let reply = agent.handle_stun(response, to); + let HandleStunReply::StunResponse(response) = reply else { unreachable!(); }; - assert_eq!(request.transaction_id(), transaction_id); assert_eq!(response.transaction_id(), transaction_id); assert!(agent.request_transaction(transaction_id).is_none()); assert!(agent.mut_request_transaction(transaction_id).is_none()); - - let data = vec![20; 4]; - let mut replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - let HandleStunReply::Data(received) = replies.remove(0) else { - unreachable!(); - }; - assert_eq!(data, received); + assert!(agent.is_validated_peer(remote_addr)); } #[test] @@ -989,7 +836,7 @@ pub(crate) mod tests { let mut agent = StunAgent::builder(TransportType::Udp, local_addr) .remote_addr(remote_addr) .build(); - let msg = Message::new_request(BINDING); + let msg = Message::builder_request(BINDING); let transaction_id = msg.transaction_id(); agent.send(msg, remote_addr).unwrap(); let mut now = Instant::now(); @@ -1007,9 +854,7 @@ pub(crate) mod tests { assert!(agent.mut_request_transaction(transaction_id).is_none()); // unvalidated peer data should be dropped - let data = vec![20; 4]; - let replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - assert!(replies.is_empty()); + assert!(!agent.is_validated_peer(remote_addr)); } #[test] @@ -1021,44 +866,32 @@ pub(crate) mod tests { let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build(); // unvalidated peer data should be dropped - let data = vec![20; 4]; - let replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - assert!(replies.is_empty()); + assert!(!agent.is_validated_peer(remote_addr)); - let msg = Message::new_request(BINDING); + let msg = Message::builder_request(BINDING); let transaction_id = msg.transaction_id(); let transmit = agent.send(msg, remote_addr).unwrap(); let request = Message::from_bytes(&transmit.data).unwrap(); - let mut response = Message::new_success(&request); + let mut response = Message::builder_success(&request); response - .add_attribute(XorMappedAddress::new( + .add_attribute(&XorMappedAddress::new( transmit.from, request.transaction_id(), )) .unwrap(); - let data = response.to_bytes(); + let data = response.build(); let to = transmit.to; - let reply = agent.handle_incoming_data(&data, to).unwrap(); + trace!("data: {data:?}"); + let response = Message::from_bytes(&data).unwrap(); + let reply = agent.handle_stun(response, to); - assert!(matches!(reply[0], HandleStunReply::StunResponse(_, _))); + assert!(matches!(reply, HandleStunReply::StunResponse(_))); assert!(agent.request_transaction(transaction_id).is_none()); assert!(agent.mut_request_transaction(transaction_id).is_none()); - - let data = vec![42; 8]; - let transmit = agent.send_data(&data, remote_addr); - assert_eq!(transmit.data(), &data); - assert_eq!(transmit.from, local_addr); - assert_eq!(transmit.to, remote_addr); - - let data = vec![20; 4]; - let mut replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - let HandleStunReply::Data(received) = replies.remove(0) else { - unreachable!(); - }; - assert_eq!(data, received); + assert!(agent.is_validated_peer(remote_addr)); } #[test] @@ -1073,7 +906,7 @@ pub(crate) mod tests { agent.set_local_credentials(local_credentials.clone().into()); agent.set_remote_credentials(remote_credentials.clone().into()); - let mut msg = Message::new_request(BINDING); + let mut msg = Message::builder_request(BINDING); let transaction_id = msg.transaction_id(); msg.add_message_integrity(&local_credentials.into(), IntegrityAlgorithm::Sha1) .unwrap(); @@ -1081,26 +914,25 @@ pub(crate) mod tests { let request = Message::from_bytes(&transmit.data).unwrap(); - let mut response = Message::new_success(&request); + let mut response = Message::builder_success(&request); response - .add_attribute(XorMappedAddress::new( + .add_attribute(&XorMappedAddress::new( transmit.from, request.transaction_id(), )) .unwrap(); - let data = response.to_bytes(); + let data = response.build(); let to = transmit.to; - let reply = agent.handle_incoming_data(&data, to).unwrap(); + let response = Message::from_bytes(&data).unwrap(); + let reply = agent.handle_stun(response, to); // reply is ignored as it does not have credentials - assert!(reply.is_empty()); + assert!(matches!(reply, HandleStunReply::Drop)); assert!(agent.request_transaction(transaction_id).is_some()); assert!(agent.mut_request_transaction(transaction_id).is_some()); // unvalidated peer data should be dropped - let data = vec![20; 4]; - let replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - assert!(replies.is_empty()); + assert!(!agent.is_validated_peer(remote_addr)); } #[test] @@ -1115,16 +947,16 @@ pub(crate) mod tests { agent.set_local_credentials(local_credentials.clone().into()); agent.set_remote_credentials(remote_credentials.into()); - let mut msg = Message::new_request(BINDING); + let mut msg = Message::builder_request(BINDING); msg.add_message_integrity(&local_credentials.clone().into(), IntegrityAlgorithm::Sha1) .unwrap(); let transmit = agent.send(msg, remote_addr).unwrap(); let request = Message::from_bytes(&transmit.data).unwrap(); - let mut response = Message::new_success(&request); + let mut response = Message::builder_success(&request); response - .add_attribute(XorMappedAddress::new( + .add_attribute(&XorMappedAddress::new( transmit.from, request.transaction_id(), )) @@ -1134,16 +966,15 @@ pub(crate) mod tests { .add_message_integrity(&local_credentials.into(), IntegrityAlgorithm::Sha1) .unwrap(); - let data = response.to_bytes(); + let data = response.build(); let to = transmit.to; - let reply = agent.handle_incoming_data(&data, to).unwrap(); + let response = Message::from_bytes(&data).unwrap(); + let reply = agent.handle_stun(response, to); // reply is ignored as it does not have credentials - assert!(reply.is_empty()); + assert!(matches!(reply, HandleStunReply::Drop)); // unvalidated peer data should be dropped - let data = vec![20; 4]; - let replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - assert!(replies.is_empty()); + assert!(!agent.is_validated_peer(remote_addr)); } #[test] @@ -1153,126 +984,30 @@ pub(crate) mod tests { let remote_addr = "10.0.0.2:3478".parse().unwrap(); let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build(); + assert!(!agent.is_validated_peer(remote_addr)); - // unvalidated peer data should be dropped - let data = vec![20; 4]; - let replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - assert!(replies.is_empty()); - - let msg = Message::new_request(BINDING); + let msg = Message::builder_request(BINDING); let transmit = agent.send(msg, remote_addr).unwrap(); let request = Message::from_bytes(&transmit.data).unwrap(); - let mut response = Message::new_success(&request); + let mut response = Message::builder_success(&request); response - .add_attribute(XorMappedAddress::new( + .add_attribute(&XorMappedAddress::new( transmit.from, request.transaction_id(), )) .unwrap(); - let data = response.to_bytes(); + let data = response.build(); let to = transmit.to; - let reply = agent.handle_incoming_data(&data, to).unwrap(); - - assert!(matches!(reply[0], HandleStunReply::StunResponse(_, _))); - - let data = vec![42; 8]; - let transmit = agent.send_data(&data, remote_addr); - assert_eq!(transmit.data(), &data); - assert_eq!(transmit.from, local_addr); - assert_eq!(transmit.to, remote_addr); - - let data = vec![20; 4]; - let mut replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - let HandleStunReply::Data(received) = replies.remove(0) else { - unreachable!(); - }; - assert_eq!(data, received); - - let data = response.to_bytes(); - let reply = agent.handle_incoming_data(&data, to).unwrap(); - assert!(reply.is_empty()); - } - - #[test] - fn tcp_request() { - init(); - let local_addr = "127.0.0.1:2000".parse().unwrap(); - let remote_addr = "127.0.0.1:1000".parse().unwrap(); - let mut agent = StunAgent::builder(TransportType::Tcp, local_addr) - .remote_addr(remote_addr) - .build(); - let msg = Message::new_request(BINDING); - let transmit = agent.send(msg, remote_addr).unwrap(); - let now = Instant::now(); - assert_eq!(transmit.transport, TransportType::Tcp); - assert_eq!(transmit.from, local_addr); - assert_eq!(transmit.to, remote_addr); - let request = Message::from_bytes(&transmit.data[2..]).unwrap(); - let response = Message::new_error(&request); - let resp_data = response.to_bytes(); - let mut data = Vec::with_capacity(resp_data.len() + 2); - data.resize(2, 0); - BigEndian::write_u16(&mut data[..2], resp_data.len() as u16); - data.extend(resp_data); - let ret = agent.handle_incoming_data(&data, remote_addr).unwrap(); - assert!(matches!(ret[0], HandleStunReply::StunResponse(_, _))); - - let ret = agent.poll(now); - assert!(matches!(ret, StunAgentPollRet::WaitUntil(_))); - - let data = vec![42; 8]; - let transmit = agent.send_data(&data, remote_addr); - assert_eq!(&transmit.data()[2..], &data); - assert_eq!(transmit.from, local_addr); - assert_eq!(transmit.to, remote_addr); - - let data = vec![0, 2, 4, 8]; - let mut replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - let HandleStunReply::Data(received) = replies.remove(0) else { - unreachable!(); - }; - assert_eq!(&data[2..], received); - } - - #[test] - fn tcp_data_before_request() { - init(); - let local_addr = "127.0.0.1:2000".parse().unwrap(); - let remote_addr = "127.0.0.1:1000".parse().unwrap(); - let mut agent = StunAgent::builder(TransportType::Tcp, local_addr) - .remote_addr(remote_addr) - .build(); - let data = [0, 2, 42, 42]; + let response = Message::from_bytes(&data).unwrap(); + let reply = agent.handle_stun(response, to); + assert!(matches!(reply, HandleStunReply::StunResponse(_))); - assert!(matches!( - agent.handle_incoming_data(&data, remote_addr), - Err(StunError::ProtocolViolation) - )); - } - - #[test] - fn tcp_split_recv() { - init(); - let local_addr = "127.0.0.1:2000".parse().unwrap(); - let remote_addr = "127.0.0.1:1000".parse().unwrap(); - let mut agent = StunAgent::builder(TransportType::Tcp, local_addr) - .remote_addr(remote_addr) - .build(); - let msg = Message::new_request(BINDING); - - let msg_data = msg.to_bytes(); - let mut data = Vec::with_capacity(msg_data.len() + 2); - data.resize(2, 0); - BigEndian::write_u16(&mut data[..2], msg_data.len() as u16); - data.extend(msg_data); - - let ret = agent.handle_incoming_data(&data[..8], remote_addr).unwrap(); - assert!(ret.is_empty()); - let ret = agent.handle_incoming_data(&data[8..], remote_addr).unwrap(); - assert!(matches!(ret[0], HandleStunReply::IncomingStun(_))); + let response = Message::from_bytes(&data).unwrap(); + let reply = agent.handle_stun(response, to); + assert!(matches!(reply, HandleStunReply::Drop)); } #[test] @@ -1282,24 +1017,24 @@ pub(crate) mod tests { let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build(); - let msg = Message::new_request(BINDING); + let msg = Message::builder_request(BINDING); let transaction_id = msg.transaction_id(); let _transmit = agent.send(msg, remote_addr).unwrap(); let mut request = agent.mut_request_transaction(transaction_id).unwrap(); - assert_eq!(request.request().transaction_id(), transaction_id); assert_eq!(request.agent().local_addr(), local_addr); assert_eq!(request.mut_agent().local_addr(), local_addr); assert_eq!(request.peer_address(), remote_addr); request.cancel(); let ret = agent.poll(Instant::now()); - let StunAgentPollRet::TransactionCancelled(request) = ret else { + let StunAgentPollRet::TransactionCancelled(_request) = ret else { unreachable!(); }; - assert_eq!(request.transaction_id(), transaction_id); + assert_eq!(transaction_id, transaction_id); assert!(agent.request_transaction(transaction_id).is_none()); assert!(agent.mut_request_transaction(transaction_id).is_none()); + assert!(!agent.is_validated_peer(remote_addr)); } #[test] @@ -1309,12 +1044,11 @@ pub(crate) mod tests { let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build(); - let msg = Message::new_request(BINDING); + let msg = Message::builder_request(BINDING); let transaction_id = msg.transaction_id(); let _transmit = agent.send(msg, remote_addr).unwrap(); let mut request = agent.mut_request_transaction(transaction_id).unwrap(); - assert_eq!(request.request().transaction_id(), transaction_id); assert_eq!(request.agent().local_addr(), local_addr); assert_eq!(request.mut_agent().local_addr(), local_addr); assert_eq!(request.peer_address(), remote_addr); @@ -1332,6 +1066,7 @@ pub(crate) mod tests { } assert!(agent.request_transaction(transaction_id).is_none()); assert!(agent.mut_request_transaction(transaction_id).is_none()); + assert!(!agent.is_validated_peer(remote_addr)); } #[test] @@ -1341,14 +1076,15 @@ pub(crate) mod tests { let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build(); - let msg = Message::new_request(BINDING); + let msg = Message::builder_request(BINDING); let transaction_id = msg.transaction_id(); let transmit = agent.send(msg.clone(), remote_addr).unwrap(); let to = transmit.to; + let request = Message::from_bytes(transmit.data()).unwrap(); - let mut response = Message::new_success(&msg); + let mut response = Message::builder_success(&request); response - .add_attribute(XorMappedAddress::new(transmit.from, transaction_id)) + .add_attribute(&XorMappedAddress::new(transmit.from, transaction_id)) .unwrap(); assert!(matches!( @@ -1358,24 +1094,17 @@ pub(crate) mod tests { // the original transaction should still exist let request = agent.request_transaction(transaction_id).unwrap(); - assert_eq!(request.request().transaction_id(), transaction_id); assert_eq!(request.peer_address(), remote_addr); - let data = response.to_bytes(); - let mut reply = agent.handle_incoming_data(&data, to).unwrap(); + let data = response.build(); + let response = Message::from_bytes(&data).unwrap(); + let reply = agent.handle_stun(response, to); - let HandleStunReply::StunResponse(request, response) = reply.remove(0) else { + let HandleStunReply::StunResponse(response) = reply else { unreachable!(); }; - assert_eq!(request.transaction_id(), transaction_id); assert_eq!(response.transaction_id(), transaction_id); - - let data = vec![20; 4]; - let mut replies = agent.handle_incoming_data(&data, remote_addr).unwrap(); - let HandleStunReply::Data(received) = replies.remove(0) else { - unreachable!(); - }; - assert_eq!(data, received); + assert!(agent.is_validated_peer(to)); } #[test] @@ -1385,33 +1114,50 @@ pub(crate) mod tests { let mut agent = StunAgent::builder(TransportType::Udp, local_addr).build(); - let msg = Message::new_request(BINDING); - let data = msg.to_bytes(); - let HandleStunReply::IncomingStun(request) = agent - .handle_incoming_data(&data, remote_addr) - .unwrap() - .remove(0) - else { + let msg = Message::builder_request(BINDING); + let data = msg.build(); + let stun = Message::from_bytes(&data).unwrap(); + println!("{stun:?}"); + let HandleStunReply::IncomingStun(request) = agent.handle_stun(stun, remote_addr) else { unreachable!() }; assert_eq!(msg.transaction_id(), request.transaction_id()); + assert!(agent.is_validated_peer(remote_addr)); } #[test] - fn data_access() { - let array = [0, 1, 2, 3]; - let borrowed_data = Data::from(array.as_slice()); - assert_eq!(array.as_slice(), &*borrowed_data); - let owned_data = borrowed_data.into_owned(); - assert_eq!(array.as_slice(), &*owned_data); - let Data::Owned(owned) = owned_data else { - unreachable!(); - }; - let owned = DataOwned::take(owned); - assert_eq!(array.as_slice(), &*owned); - let data = Data::from(owned); - assert_eq!(array.as_slice(), &*data); - let borrowed = DataSlice::from(&*data); - assert_eq!(array.as_slice(), &*borrowed); + fn tcp_request() { + init(); + let local_addr = "127.0.0.1:2000".parse().unwrap(); + let remote_addr = "127.0.0.1:1000".parse().unwrap(); + let mut agent = StunAgent::builder(TransportType::Tcp, local_addr) + .remote_addr(remote_addr) + .build(); + + let msg = Message::builder_request(BINDING); + let transaction_id = msg.transaction_id(); + let transmit = agent.send(msg, remote_addr).unwrap(); + assert_eq!(transmit.transport, TransportType::Tcp); + assert_eq!(transmit.from, local_addr); + assert_eq!(transmit.to, remote_addr); + + let request = Message::from_bytes(&transmit.data[2..]).unwrap(); + assert_eq!(request.transaction_id(), transaction_id); + } + + #[test] + fn tcp_buffer_split_recv() { + init(); + + let mut tcp_buffer = TcpBuffer::default(); + + let mut len = [0; 2]; + let data = [0, 1, 2, 4, 3]; + BigEndian::write_u16(&mut len, data.len() as u16); + + tcp_buffer.push_data(&len); + assert!(tcp_buffer.pull_data().is_none()); + tcp_buffer.push_data(&data); + assert_eq!(tcp_buffer.pull_data().unwrap(), &data); } } diff --git a/stun-proto/src/lib.rs b/stun-proto/src/lib.rs index 0fff188..a5c6bcb 100644 --- a/stun-proto/src/lib.rs +++ b/stun-proto/src/lib.rs @@ -35,27 +35,28 @@ //! agent.set_remote_credentials(remote_credentials.clone().into()); //! //! // and we can send a Message -//! let mut msg = Message::new_request(BINDING); +//! let mut msg = Message::builder_request(BINDING); //! msg.add_message_integrity(&local_credentials.clone().into(), IntegrityAlgorithm::Sha1).unwrap(); //! let transmit = agent.send(msg, remote_addr).unwrap(); //! //! // The transmit struct indicates what data and where to send it. //! let request = Message::from_bytes(&transmit.data).unwrap(); //! -//! let mut response = Message::new_success(&request); +//! let mut response = Message::builder_success(&request); //! let xor_addr = XorMappedAddress::new(transmit.from, request.transaction_id()); -//! response.add_attribute(xor_addr).unwrap(); +//! response.add_attribute(&xor_addr).unwrap(); //! response.add_message_integrity(&remote_credentials.clone().into(), IntegrityAlgorithm::Sha1).unwrap(); //! //! // when receiving data on the associated socket, we should pass it through the Agent so it can //! // parse and handle any STUN messages. -//! let data = response.to_bytes(); +//! let data = response.build(); //! let to = transmit.to; -//! let reply = agent.handle_incoming_data(&data, to).unwrap(); +//! let response = Message::from_bytes(&data).unwrap(); +//! let reply = agent.handle_stun(response, to); //! //! // If running over TCP then there may be multiple messages parsed. However UDP will only ever //! // have a single message per datagram. -//! assert!(matches!(reply[0], HandleStunReply::StunResponse(_, _))); +//! assert!(matches!(reply, HandleStunReply::StunResponse(_))); //! //! // Once valid STUN data has been sent and received, then data can be sent and received from the //! // peer. diff --git a/stun-types/src/attribute/address.rs b/stun-types/src/attribute/address.rs index 94611c3..6b9865b 100644 --- a/stun-types/src/attribute/address.rs +++ b/stun-types/src/attribute/address.rs @@ -68,7 +68,7 @@ impl MappedSocketAddr { } /// Convert this [`MappedSocketAddr`] into a [`RawAttribute`] - pub fn to_raw(&self, atype: AttributeType) -> RawAttribute { + pub fn to_raw<'a>(&self, atype: AttributeType) -> RawAttribute<'a> { match self.addr { SocketAddr::V4(addr) => { let mut buf = [0; 8]; @@ -76,7 +76,7 @@ impl MappedSocketAddr { BigEndian::write_u16(&mut buf[2..4], addr.port()); let octets = u32::from(*addr.ip()); BigEndian::write_u32(&mut buf[4..8], octets); - RawAttribute::new(atype, &buf) + RawAttribute::new(atype, &buf).into_owned() } SocketAddr::V6(addr) => { let mut buf = [0; 20]; @@ -84,7 +84,7 @@ impl MappedSocketAddr { BigEndian::write_u16(&mut buf[2..4], addr.port()); let octets = u128::from(*addr.ip()); BigEndian::write_u128(&mut buf[4..20], octets); - RawAttribute::new(atype, &buf) + RawAttribute::new(atype, &buf).into_owned() } } } @@ -155,7 +155,7 @@ impl XorSocketAddr { } /// Convert this [`XorSocketAddr`] into a [`RawAttribute`] - pub fn to_raw(&self, atype: AttributeType) -> RawAttribute { + pub fn to_raw<'a>(&self, atype: AttributeType) -> RawAttribute<'a> { self.addr.to_raw(atype) } diff --git a/stun-types/src/attribute/alternate.rs b/stun-types/src/attribute/alternate.rs index 898d013..d237446 100644 --- a/stun-types/src/attribute/alternate.rs +++ b/stun-types/src/attribute/alternate.rs @@ -27,13 +27,13 @@ impl Attribute for AlternateServer { } } -impl From for RawAttribute { - fn from(value: AlternateServer) -> RawAttribute { +impl<'a> From<&AlternateServer> for RawAttribute<'a> { + fn from(value: &AlternateServer) -> RawAttribute<'a> { value.addr.to_raw(AlternateServer::TYPE) } } -impl TryFrom<&RawAttribute> for AlternateServer { +impl<'a> TryFrom<&RawAttribute<'a>> for AlternateServer { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -94,7 +94,7 @@ impl Attribute for AlternateDomain { self.domain.len() as u16 } } -impl TryFrom<&RawAttribute> for AlternateDomain { +impl<'a> TryFrom<&RawAttribute<'a>> for AlternateDomain { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -107,9 +107,9 @@ impl TryFrom<&RawAttribute> for AlternateDomain { }) } } -impl From for RawAttribute { - fn from(value: AlternateDomain) -> RawAttribute { - RawAttribute::new(AlternateDomain::TYPE, value.domain.as_bytes()) +impl<'a> From<&AlternateDomain> for RawAttribute<'a> { + fn from(value: &AlternateDomain) -> RawAttribute<'a> { + RawAttribute::new(AlternateDomain::TYPE, value.domain.as_bytes()).into_owned() } } @@ -174,7 +174,7 @@ mod tests { SocketAddr::V4(_) => assert_eq!(mapped.length(), 8), SocketAddr::V6(_) => assert_eq!(mapped.length(), 20), } - let raw: RawAttribute = mapped.into(); + let raw = RawAttribute::from(&mapped); assert_eq!(raw.get_type(), AlternateServer::TYPE); let mapped2 = AlternateServer::try_from(&raw).unwrap(); assert_eq!(mapped2.server(), *addr); @@ -184,7 +184,7 @@ mod tests { BigEndian::write_u16(&mut data[2..4], len as u16 - 4 - 1); assert!(matches!( AlternateServer::try_from( - &RawAttribute::try_from(data[..len - 1].as_ref()).unwrap() + &RawAttribute::from_bytes(data[..len - 1].as_ref()).unwrap() ), Err(StunParseError::Truncated { expected: _, @@ -208,7 +208,7 @@ mod tests { let attr = AlternateDomain::new(dns); assert_eq!(attr.domain(), dns); assert_eq!(attr.length() as usize, dns.len()); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), AlternateDomain::TYPE); let mapped2 = AlternateDomain::try_from(&raw).unwrap(); assert_eq!(mapped2.domain(), dns); diff --git a/stun-types/src/attribute/error.rs b/stun-types/src/attribute/error.rs index a453a23..2d65004 100644 --- a/stun-types/src/attribute/error.rs +++ b/stun-types/src/attribute/error.rs @@ -27,18 +27,18 @@ impl Attribute for ErrorCode { self.reason.len() as u16 + 4 } } -impl From for RawAttribute { - fn from(value: ErrorCode) -> RawAttribute { +impl<'a> From<&ErrorCode> for RawAttribute<'a> { + fn from(value: &ErrorCode) -> RawAttribute<'a> { let mut data = Vec::with_capacity(value.length() as usize); data.push(0u8); data.push(0u8); data.push((value.code / 100) as u8); data.push((value.code % 100) as u8); data.extend(value.reason.as_bytes()); - RawAttribute::new(ErrorCode::TYPE, &data) + RawAttribute::new(ErrorCode::TYPE, &data).into_owned() } } -impl TryFrom<&RawAttribute> for ErrorCode { +impl<'a> TryFrom<&RawAttribute<'a>> for ErrorCode { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -243,18 +243,18 @@ impl Attribute for UnknownAttributes { (self.attributes.len() as u16) * 2 } } -impl From for RawAttribute { - fn from(value: UnknownAttributes) -> RawAttribute { +impl<'a> From<&UnknownAttributes> for RawAttribute<'a> { + fn from(value: &UnknownAttributes) -> RawAttribute<'a> { let mut data = Vec::with_capacity(value.length() as usize); for attr in &value.attributes { let mut encoded = vec![0; 2]; BigEndian::write_u16(&mut encoded, (*attr).into()); data.extend(encoded); } - RawAttribute::new(UnknownAttributes::TYPE, &data) + RawAttribute::new(UnknownAttributes::TYPE, &data).into_owned() } } -impl TryFrom<&RawAttribute> for UnknownAttributes { +impl<'a> TryFrom<&RawAttribute<'a>> for UnknownAttributes { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -345,7 +345,7 @@ mod tests { let err = ErrorCode::new(code, reason).unwrap(); assert_eq!(err.code(), code); assert_eq!(err.reason(), reason); - let raw: RawAttribute = err.into(); + let raw = RawAttribute::from(&err); assert_eq!(raw.get_type(), ErrorCode::TYPE); let err2 = ErrorCode::try_from(&raw).unwrap(); assert_eq!(err2.code(), code); @@ -361,7 +361,7 @@ mod tests { #[test] fn error_code_parse_short() { let err = error_code_new(420); - let raw: RawAttribute = err.into(); + let raw = RawAttribute::from(&err); // no data let mut data: Vec<_> = raw.into(); let len = 0; @@ -378,7 +378,7 @@ mod tests { #[test] fn error_code_parse_wrong_implementation() { let err = error_code_new(420); - let raw: RawAttribute = err.into(); + let raw = RawAttribute::from(&err); // provide incorrectly typed data let mut data: Vec<_> = raw.into(); BigEndian::write_u16(&mut data[0..2], 0); @@ -391,7 +391,7 @@ mod tests { #[test] fn error_code_parse_out_of_range_code() { let err = error_code_new(420); - let raw: RawAttribute = err.into(); + let raw = RawAttribute::from(&err); let mut data: Vec<_> = raw.into(); // write an invalid error code @@ -405,7 +405,7 @@ mod tests { #[test] fn error_code_parse_invalid_reason() { let err = error_code_new(420); - let raw: RawAttribute = err.into(); + let raw = RawAttribute::from(&err); let mut data: Vec<_> = raw.into(); // write an invalid utf8 bytes @@ -457,7 +457,7 @@ mod tests { assert!(unknown.has_attribute(Realm::TYPE)); assert!(unknown.has_attribute(AlternateServer::TYPE)); assert!(!unknown.has_attribute(Nonce::TYPE)); - let raw: RawAttribute = unknown.into(); + let raw = RawAttribute::from(&unknown); assert_eq!(raw.get_type(), UnknownAttributes::TYPE); let unknown2 = UnknownAttributes::try_from(&raw).unwrap(); assert!(unknown2.has_attribute(Realm::TYPE)); diff --git a/stun-types/src/attribute/fingerprint.rs b/stun-types/src/attribute/fingerprint.rs index 91b2f75..87b3c7b 100644 --- a/stun-types/src/attribute/fingerprint.rs +++ b/stun-types/src/attribute/fingerprint.rs @@ -25,20 +25,20 @@ impl Attribute for Fingerprint { 4 } } -impl From for RawAttribute { - fn from(value: Fingerprint) -> RawAttribute { +impl<'a> From<&Fingerprint> for RawAttribute<'a> { + fn from(value: &Fingerprint) -> RawAttribute<'a> { let buf = bytewise_xor!(4, value.fingerprint, Fingerprint::XOR_CONSTANT, 0); - RawAttribute::new(Fingerprint::TYPE, &buf) + RawAttribute::new(Fingerprint::TYPE, &buf).into_owned() } } -impl TryFrom<&RawAttribute> for Fingerprint { +impl<'a> TryFrom<&RawAttribute<'a>> for Fingerprint { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { raw.check_type_and_len(Self::TYPE, 4..=4)?; // sized checked earlier - let boxed: Box<[u8; 4]> = raw.value.clone().into_boxed_slice().try_into().unwrap(); - let fingerprint = bytewise_xor!(4, *boxed, Fingerprint::XOR_CONSTANT, 0); + let boxed: [u8; 4] = (&*raw.value).try_into().unwrap(); + let fingerprint = bytewise_xor!(4, boxed, Fingerprint::XOR_CONSTANT, 0); Ok(Self { fingerprint }) } } @@ -115,7 +115,7 @@ mod tests { let attr = Fingerprint::new(val); assert_eq!(attr.fingerprint(), &val); assert_eq!(attr.length(), 4); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), Fingerprint::TYPE); let mapped2 = Fingerprint::try_from(&raw).unwrap(); assert_eq!(mapped2.fingerprint(), &val); diff --git a/stun-types/src/attribute/ice.rs b/stun-types/src/attribute/ice.rs index 99126fa..5c955e4 100644 --- a/stun-types/src/attribute/ice.rs +++ b/stun-types/src/attribute/ice.rs @@ -27,14 +27,14 @@ impl Attribute for Priority { 4 } } -impl From for RawAttribute { - fn from(value: Priority) -> RawAttribute { +impl<'a> From<&Priority> for RawAttribute<'a> { + fn from(value: &Priority) -> RawAttribute<'a> { let mut buf = [0; 4]; BigEndian::write_u32(&mut buf[0..4], value.priority); - RawAttribute::new(Priority::TYPE, &buf) + RawAttribute::new(Priority::TYPE, &buf).into_owned() } } -impl TryFrom<&RawAttribute> for Priority { +impl<'a> TryFrom<&RawAttribute<'a>> for Priority { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -90,13 +90,13 @@ impl Attribute for UseCandidate { 0 } } -impl From for RawAttribute { - fn from(_value: UseCandidate) -> RawAttribute { - let buf = [0; 0]; - RawAttribute::new(UseCandidate::TYPE, &buf) +impl<'a> From<&UseCandidate> for RawAttribute<'a> { + fn from(_value: &UseCandidate) -> RawAttribute<'a> { + static BUF: [u8; 0] = [0; 0]; + RawAttribute::new(UseCandidate::TYPE, &BUF) } } -impl TryFrom<&RawAttribute> for UseCandidate { +impl<'a> TryFrom<&RawAttribute<'a>> for UseCandidate { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -144,14 +144,14 @@ impl Attribute for IceControlled { 8 } } -impl From for RawAttribute { - fn from(value: IceControlled) -> RawAttribute { +impl<'a> From<&IceControlled> for RawAttribute<'a> { + fn from(value: &IceControlled) -> RawAttribute<'a> { let mut buf = [0; 8]; BigEndian::write_u64(&mut buf[..8], value.tie_breaker); - RawAttribute::new(IceControlled::TYPE, &buf) + RawAttribute::new(IceControlled::TYPE, &buf).into_owned() } } -impl TryFrom<&RawAttribute> for IceControlled { +impl<'a> TryFrom<&RawAttribute<'a>> for IceControlled { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -209,15 +209,15 @@ impl Attribute for IceControlling { 8 } } -impl From for RawAttribute { - fn from(value: IceControlling) -> RawAttribute { +impl<'a> From<&IceControlling> for RawAttribute<'a> { + fn from(value: &IceControlling) -> RawAttribute<'a> { let mut buf = [0; 8]; BigEndian::write_u64(&mut buf[..8], value.tie_breaker); - RawAttribute::new(IceControlling::TYPE, &buf) + RawAttribute::new(IceControlling::TYPE, &buf).into_owned() } } -impl TryFrom<&RawAttribute> for IceControlling { +impl<'a> TryFrom<&RawAttribute<'a>> for IceControlling { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -277,7 +277,7 @@ mod tests { let priority = Priority::new(val); assert_eq!(priority.priority(), val); assert_eq!(priority.length(), 4); - let raw: RawAttribute = priority.into(); + let raw = RawAttribute::from(&priority); assert_eq!(raw.get_type(), Priority::TYPE); let mapped2 = Priority::try_from(&raw).unwrap(); assert_eq!(mapped2.priority(), val); @@ -306,7 +306,7 @@ mod tests { init(); let use_candidate = UseCandidate::default(); assert_eq!(use_candidate.length(), 0); - let raw: RawAttribute = use_candidate.into(); + let raw = RawAttribute::from(&use_candidate); assert_eq!(raw.get_type(), UseCandidate::TYPE); let _mapped2 = UseCandidate::try_from(&raw).unwrap(); // provide incorrectly typed data @@ -325,7 +325,7 @@ mod tests { let attr = IceControlling::new(tb); assert_eq!(attr.tie_breaker(), tb); assert_eq!(attr.length(), 8); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), IceControlling::TYPE); let mapped2 = IceControlling::try_from(&raw).unwrap(); assert_eq!(mapped2.tie_breaker(), tb); @@ -356,7 +356,7 @@ mod tests { let attr = IceControlled::new(tb); assert_eq!(attr.tie_breaker(), tb); assert_eq!(attr.length(), 8); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), IceControlled::TYPE); let mapped2 = IceControlled::try_from(&raw).unwrap(); assert_eq!(mapped2.tie_breaker(), tb); diff --git a/stun-types/src/attribute/integrity.rs b/stun-types/src/attribute/integrity.rs index 35d410e..b9f7353 100644 --- a/stun-types/src/attribute/integrity.rs +++ b/stun-types/src/attribute/integrity.rs @@ -27,19 +27,19 @@ impl Attribute for MessageIntegrity { 20 } } -impl From for RawAttribute { - fn from(value: MessageIntegrity) -> RawAttribute { +impl<'a> From<&'a MessageIntegrity> for RawAttribute<'a> { + fn from(value: &'a MessageIntegrity) -> RawAttribute<'a> { RawAttribute::new(MessageIntegrity::TYPE, &value.hmac) } } -impl TryFrom<&RawAttribute> for MessageIntegrity { +impl<'a> TryFrom<&RawAttribute<'a>> for MessageIntegrity { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { raw.check_type_and_len(Self::TYPE, 20..=20)?; // sized checked earlier - let boxed: Box<[u8; 20]> = raw.value.clone().into_boxed_slice().try_into().unwrap(); - Ok(Self { hmac: *boxed }) + let hmac: [u8; 20] = (&*raw.value).try_into().unwrap(); + Ok(Self { hmac }) } } @@ -153,12 +153,12 @@ impl Attribute for MessageIntegritySha256 { self.hmac.len() as u16 } } -impl From for RawAttribute { - fn from(value: MessageIntegritySha256) -> RawAttribute { +impl<'a> From<&'a MessageIntegritySha256> for RawAttribute<'a> { + fn from(value: &'a MessageIntegritySha256) -> RawAttribute<'a> { RawAttribute::new(MessageIntegritySha256::TYPE, &value.hmac) } } -impl TryFrom<&RawAttribute> for MessageIntegritySha256 { +impl<'a> TryFrom<&RawAttribute<'a>> for MessageIntegritySha256 { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -303,7 +303,7 @@ mod tests { let attr = MessageIntegrity::new(val); assert_eq!(attr.hmac(), &val); assert_eq!(attr.length(), 20); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), MessageIntegrity::TYPE); let mapped2 = MessageIntegrity::try_from(&raw).unwrap(); assert_eq!(mapped2.hmac(), &val); @@ -336,7 +336,7 @@ mod tests { let attr = MessageIntegritySha256::new(&val).unwrap(); assert_eq!(attr.hmac(), &val); assert_eq!(attr.length(), 32); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), MessageIntegritySha256::TYPE); let mapped2 = MessageIntegritySha256::try_from(&raw).unwrap(); assert_eq!(mapped2.hmac(), &val); @@ -346,7 +346,7 @@ mod tests { BigEndian::write_u16(&mut data[2..4], len as u16 - 4 - 1); assert!(matches!( MessageIntegritySha256::try_from( - &RawAttribute::try_from(data[..len - 1].as_ref()).unwrap() + &RawAttribute::from_bytes(data[..len - 1].as_ref()).unwrap() ), Err(StunParseError::InvalidAttributeData) )); diff --git a/stun-types/src/attribute/mod.rs b/stun-types/src/attribute/mod.rs index c6caaaf..0ac5b92 100644 --- a/stun-types/src/attribute/mod.rs +++ b/stun-types/src/attribute/mod.rs @@ -33,7 +33,7 @@ //! 0x65, 0x73, 0x00, 0x00 // e s //! ]; //! -//! let raw: RawAttribute = software.into(); +//! let raw = RawAttribute::from(&software); //! assert_eq!(raw.to_bytes(), attribute_data); //! //! // Can also parse data into a typed attribute as needed @@ -59,14 +59,14 @@ //! 4 //! } //! } -//! impl From for RawAttribute { -//! fn from(value: MyAttribute) -> RawAttribute { +//! impl<'a> From<&MyAttribute> for RawAttribute<'a> { +//! fn from(value: &MyAttribute) -> RawAttribute<'a> { //! let mut ret = vec![0; 4]; //! BigEndian::write_u32(&mut ret, value.value); -//! RawAttribute::new(MyAttribute::TYPE, &ret) +//! RawAttribute::new(MyAttribute::TYPE, &ret).into_owned() //! } //! } -//! impl TryFrom<&RawAttribute> for MyAttribute { +//! impl<'a> TryFrom<&RawAttribute<'a>> for MyAttribute { //! type Error = StunParseError; //! fn try_from(raw: &RawAttribute) -> Result { //! raw.check_type_and_len(Self::TYPE, 4..=4)?; @@ -78,7 +78,7 @@ //! } //! //! let my_attr = MyAttribute { value: 0x4729 }; -//! let raw: RawAttribute = my_attr.into(); +//! let raw = RawAttribute::from(&my_attr); //! //! let attribute_data = [ //! 0x88, 0x51, 0x00, 0x04, @@ -125,7 +125,7 @@ pub use software::Software; mod xor_addr; pub use xor_addr::XorMappedAddress; -use crate::message::StunParseError; +use crate::{data::Data, message::StunParseError}; use byteorder::{BigEndian, ByteOrder}; @@ -287,34 +287,36 @@ pub trait Attribute: std::fmt::Debug { /// Automatically implemented trait for converting from a concrete [`Attribute`] to a /// [`RawAttribute`] -pub trait AttributeToRaw: Attribute + Into +pub trait AttributeToRaw<'b>: Attribute + Into> where - RawAttribute: for<'a> From<&'a Self>, + RawAttribute<'b>: for<'a> From<&'a Self>, { /// Convert an `Attribute` to a `RawAttribute` - fn to_raw(&self) -> RawAttribute; + fn to_raw(&self) -> RawAttribute<'b>; } -impl> AttributeToRaw for T +impl<'b, T: Attribute + Into>> AttributeToRaw<'b> for T where - RawAttribute: for<'a> From<&'a Self>, + RawAttribute<'b>: for<'a> From<&'a Self>, { - fn to_raw(&self) -> RawAttribute + fn to_raw(&self) -> RawAttribute<'b> where - RawAttribute: for<'a> From<&'a Self>, + RawAttribute<'b>: for<'a> From<&'a Self>, { self.into() } } /// Automatically implemented trait for converting to a concrete [`Attribute`] from a /// [`RawAttribute`] -pub trait AttributeFromRaw: Attribute + for<'a> TryFrom<&'a RawAttribute, Error = E> { +pub trait AttributeFromRaw: + Attribute + for<'a> TryFrom<&'a RawAttribute<'a>, Error = E> +{ /// Convert an `Attribute` from a `RawAttribute` fn from_raw(raw: &RawAttribute) -> Result where Self: Sized; } -impl TryFrom<&'a RawAttribute, Error = E>> AttributeFromRaw for T { +impl TryFrom<&'a RawAttribute<'a>, Error = E>> AttributeFromRaw for T { fn from_raw(raw: &RawAttribute) -> Result { Self::try_from(raw) } @@ -334,11 +336,11 @@ pub(crate) fn padded_attr_size(attr: &RawAttribute) -> usize { /// The header and raw bytes of an unparsed [`Attribute`] #[derive(Debug, Clone, PartialEq, Eq)] -pub struct RawAttribute { +pub struct RawAttribute<'a> { /// The [`AttributeHeader`] of this [`RawAttribute`] pub header: AttributeHeader, /// The raw bytes of this [`RawAttribute`] - pub value: Vec, + pub value: Data<'a>, } macro_rules! display_attr { @@ -351,7 +353,7 @@ macro_rules! display_attr { }}; } -impl std::fmt::Display for RawAttribute { +impl<'a> std::fmt::Display for RawAttribute<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // try to get a more specialised display let malformed_str = format!( @@ -391,14 +393,15 @@ impl std::fmt::Display for RawAttribute { } } -impl RawAttribute { - pub fn new(atype: AttributeType, data: &[u8]) -> Self { +impl<'a> RawAttribute<'a> { + /// Create a new [`RawAttribute`] + pub fn new(atype: AttributeType, data: &'a [u8]) -> Self { Self { header: AttributeHeader { atype, length: data.len() as u16, }, - value: data.to_vec(), + value: data.into(), } } @@ -413,7 +416,7 @@ impl RawAttribute { /// assert_eq!(attr.get_type(), AttributeType::new(1)); /// assert_eq!(attr.length(), 2); /// ``` - pub fn from_bytes(data: &[u8]) -> Result { + pub fn from_bytes(data: &'a [u8]) -> Result { let header = AttributeHeader::parse(data)?; // the advertised length is larger than actual data -> error if header.length > (data.len() - 4) as u16 { @@ -422,12 +425,9 @@ impl RawAttribute { actual: data.len() - 4, }); } - let mut data = data[4..].to_vec(); - data.truncate(header.length as usize); - //trace!("parsed into {:?} {:?}", header, data); Ok(Self { header, - value: data, + value: Data::Borrowed(data[4..header.length as usize + 4].into()), }) } @@ -442,7 +442,7 @@ impl RawAttribute { /// ``` pub fn to_bytes(&self) -> Vec { let mut ret: Vec = self.header.into(); - ret.extend(&self.value); + ret.extend(&*self.value); let len = ret.len(); if len % 4 != 0 { // pad to 4 bytes @@ -472,6 +472,14 @@ impl RawAttribute { } check_len(self.value.len(), allowed_range) } + + /// Consume this [`RawAttribute`] and return a new owned [`RawAttribute`] + pub fn into_owned<'b>(self) -> RawAttribute<'b> { + RawAttribute { + header: self.header, + value: self.value.into_owned(), + } + } } fn check_len( @@ -519,20 +527,12 @@ fn check_len( Ok(()) } -impl From for Vec { +impl<'a> From> for Vec { fn from(f: RawAttribute) -> Self { f.to_bytes() } } -impl TryFrom<&[u8]> for RawAttribute { - type Error = StunParseError; - - fn try_from(value: &[u8]) -> Result { - RawAttribute::from_bytes(value) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/stun-types/src/attribute/nonce.rs b/stun-types/src/attribute/nonce.rs index 9be09a9..0c2a7cf 100644 --- a/stun-types/src/attribute/nonce.rs +++ b/stun-types/src/attribute/nonce.rs @@ -25,12 +25,12 @@ impl Attribute for Nonce { self.nonce.len() as u16 } } -impl From for RawAttribute { - fn from(value: Nonce) -> RawAttribute { +impl<'a> From<&'a Nonce> for RawAttribute<'a> { + fn from(value: &'a Nonce) -> RawAttribute<'a> { RawAttribute::new(Nonce::TYPE, value.nonce.as_bytes()) } } -impl TryFrom<&RawAttribute> for Nonce { +impl<'a> TryFrom<&RawAttribute<'a>> for Nonce { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -100,7 +100,7 @@ mod tests { let attr = Nonce::new("nonce").unwrap(); assert_eq!(attr.nonce(), "nonce"); assert_eq!(attr.length() as usize, "nonce".len()); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), Nonce::TYPE); let mapped2 = Nonce::try_from(&raw).unwrap(); assert_eq!(mapped2.nonce(), "nonce"); @@ -117,7 +117,7 @@ mod tests { fn nonce_not_utf8() { init(); let attr = Nonce::new("nonce").unwrap(); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); let mut data = raw.to_bytes(); data[6] = 0x88; assert!(matches!( diff --git a/stun-types/src/attribute/password_algorithm.rs b/stun-types/src/attribute/password_algorithm.rs index c2919b9..f48035e 100644 --- a/stun-types/src/attribute/password_algorithm.rs +++ b/stun-types/src/attribute/password_algorithm.rs @@ -81,8 +81,8 @@ impl Attribute for PasswordAlgorithms { len as u16 } } -impl From for RawAttribute { - fn from(value: PasswordAlgorithms) -> RawAttribute { +impl<'a> From<&PasswordAlgorithms> for RawAttribute<'a> { + fn from(value: &PasswordAlgorithms) -> RawAttribute<'a> { let len = value.length() as usize; let mut data = vec![0; len]; let mut i = 0; @@ -90,10 +90,10 @@ impl From for RawAttribute { algo.write(&mut data[i..]); i += 4 + padded_attr_len(algo.len() as usize); } - RawAttribute::new(PasswordAlgorithms::TYPE, &data) + RawAttribute::new(PasswordAlgorithms::TYPE, &data).into_owned() } } -impl TryFrom<&RawAttribute> for PasswordAlgorithms { +impl<'a> TryFrom<&RawAttribute<'a>> for PasswordAlgorithms { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -170,15 +170,15 @@ impl Attribute for PasswordAlgorithm { } } -impl From for RawAttribute { - fn from(value: PasswordAlgorithm) -> RawAttribute { +impl<'a> From<&PasswordAlgorithm> for RawAttribute<'a> { + fn from(value: &PasswordAlgorithm) -> RawAttribute<'a> { let len = value.length() as usize; let mut data = vec![0; len]; value.algorithm.write(&mut data); - RawAttribute::new(PasswordAlgorithm::TYPE, &data) + RawAttribute::new(PasswordAlgorithm::TYPE, &data).into_owned() } } -impl TryFrom<&RawAttribute> for PasswordAlgorithm { +impl<'a> TryFrom<&RawAttribute<'a>> for PasswordAlgorithm { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -239,7 +239,7 @@ mod tests { let vals = [PasswordAlgorithmValue::MD5, PasswordAlgorithmValue::SHA256]; let attr = PasswordAlgorithms::new(&vals); assert_eq!(attr.algorithms(), &vals); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), PasswordAlgorithms::TYPE); let mapped2 = PasswordAlgorithms::try_from(&raw).unwrap(); assert_eq!(mapped2.algorithms(), &vals); @@ -258,7 +258,7 @@ mod tests { let val = PasswordAlgorithmValue::SHA256; let attr = PasswordAlgorithm::new(val); assert_eq!(attr.algorithm(), val); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), PasswordAlgorithm::TYPE); let mapped2 = PasswordAlgorithm::try_from(&raw).unwrap(); assert_eq!(mapped2.algorithm(), val); @@ -276,7 +276,7 @@ mod tests { init(); let val = PasswordAlgorithmValue::SHA256; let attr = PasswordAlgorithm::new(val); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); let mut data = raw.to_bytes(); data[7] = 100; assert!(matches!( @@ -293,7 +293,7 @@ mod tests { init(); let val = PasswordAlgorithmValue::SHA256; let attr = PasswordAlgorithm::new(val); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); let mut data = raw.to_bytes(); data[5] = 0x80; assert!(matches!( diff --git a/stun-types/src/attribute/realm.rs b/stun-types/src/attribute/realm.rs index ff8dfea..eb1e5c3 100644 --- a/stun-types/src/attribute/realm.rs +++ b/stun-types/src/attribute/realm.rs @@ -25,12 +25,12 @@ impl Attribute for Realm { self.realm.len() as u16 } } -impl From for RawAttribute { - fn from(value: Realm) -> RawAttribute { +impl<'a> From<&'a Realm> for RawAttribute<'a> { + fn from(value: &'a Realm) -> RawAttribute<'a> { RawAttribute::new(Realm::TYPE, value.realm.as_bytes()) } } -impl TryFrom<&RawAttribute> for Realm { +impl<'a> TryFrom<&RawAttribute<'a>> for Realm { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -99,7 +99,7 @@ mod tests { let attr = Realm::new("realm").unwrap(); assert_eq!(attr.realm(), "realm"); assert_eq!(attr.length() as usize, "realm".len()); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), Realm::TYPE); let mapped2 = Realm::try_from(&raw).unwrap(); assert_eq!(mapped2.realm(), "realm"); @@ -116,7 +116,7 @@ mod tests { fn realm_not_utf8() { init(); let attr = Realm::new("realm").unwrap(); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); let mut data = raw.to_bytes(); data[6] = 0x88; assert!(matches!( diff --git a/stun-types/src/attribute/software.rs b/stun-types/src/attribute/software.rs index a4f953c..809c455 100644 --- a/stun-types/src/attribute/software.rs +++ b/stun-types/src/attribute/software.rs @@ -24,12 +24,12 @@ impl Attribute for Software { self.software.len() as u16 } } -impl From for RawAttribute { - fn from(value: Software) -> RawAttribute { +impl<'a> From<&'a Software> for RawAttribute<'a> { + fn from(value: &'a Software) -> RawAttribute<'a> { RawAttribute::new(Software::TYPE, value.software.as_bytes()) } } -impl TryFrom<&RawAttribute> for Software { +impl<'a> TryFrom<&RawAttribute<'a>> for Software { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -104,7 +104,7 @@ mod tests { let software = Software::new("software").unwrap(); assert_eq!(software.software(), "software"); assert_eq!(software.length() as usize, "software".len()); - let raw: RawAttribute = software.into(); + let raw = RawAttribute::from(&software); assert_eq!(raw.get_type(), Software::TYPE); let software2 = Software::try_from(&raw).unwrap(); assert_eq!(software2.software(), "software"); @@ -121,7 +121,7 @@ mod tests { fn software_not_utf8() { init(); let attr = Software::new("software").unwrap(); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); let mut data = raw.to_bytes(); data[6] = 0x88; assert!(matches!( diff --git a/stun-types/src/attribute/user.rs b/stun-types/src/attribute/user.rs index a38e0ac..9b470ee 100644 --- a/stun-types/src/attribute/user.rs +++ b/stun-types/src/attribute/user.rs @@ -24,12 +24,12 @@ impl Attribute for Username { self.user.len() as u16 } } -impl From for RawAttribute { - fn from(value: Username) -> RawAttribute { +impl<'a> From<&'a Username> for RawAttribute<'a> { + fn from(value: &'a Username) -> RawAttribute<'a> { RawAttribute::new(Username::TYPE, value.user.as_bytes()) } } -impl TryFrom<&RawAttribute> for Username { +impl<'a> TryFrom<&RawAttribute<'a>> for Username { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -104,13 +104,13 @@ impl Attribute for Userhash { 32 } } -impl From for RawAttribute { - fn from(value: Userhash) -> RawAttribute { +impl<'a> From<&'a Userhash> for RawAttribute<'a> { + fn from(value: &'a Userhash) -> RawAttribute<'a> { RawAttribute::new(Userhash::TYPE, &value.hash) } } -impl TryFrom<&RawAttribute> for Userhash { +impl<'a> TryFrom<&RawAttribute<'a>> for Userhash { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -191,7 +191,7 @@ mod tests { let user = Username::new(s).unwrap(); assert_eq!(user.username(), s); assert_eq!(user.length() as usize, s.len()); - let raw: RawAttribute = user.into(); + let raw = RawAttribute::from(&user); assert_eq!(raw.get_type(), Username::TYPE); let user2 = Username::try_from(&raw).unwrap(); assert_eq!(user2.username(), s); @@ -208,7 +208,7 @@ mod tests { fn username_not_utf8() { init(); let attr = Username::new("user").unwrap(); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); let mut data = raw.to_bytes(); data[6] = 0x88; assert!(matches!( @@ -241,7 +241,7 @@ mod tests { let attr = Userhash::new(hash); assert_eq!(attr.hash(), &hash); assert_eq!(attr.length(), 32); - let raw: RawAttribute = attr.into(); + let raw = RawAttribute::from(&attr); assert_eq!(raw.get_type(), Userhash::TYPE); let mapped2 = Userhash::try_from(&raw).unwrap(); assert_eq!(mapped2.hash(), &hash); diff --git a/stun-types/src/attribute/xor_addr.rs b/stun-types/src/attribute/xor_addr.rs index cb54eb4..fb5af03 100644 --- a/stun-types/src/attribute/xor_addr.rs +++ b/stun-types/src/attribute/xor_addr.rs @@ -26,12 +26,12 @@ impl Attribute for XorMappedAddress { self.addr.length() } } -impl From for RawAttribute { - fn from(value: XorMappedAddress) -> RawAttribute { +impl<'a> From<&XorMappedAddress> for RawAttribute<'a> { + fn from(value: &XorMappedAddress) -> RawAttribute<'a> { value.addr.to_raw(XorMappedAddress::TYPE) } } -impl TryFrom<&RawAttribute> for XorMappedAddress { +impl<'a> TryFrom<&RawAttribute<'a>> for XorMappedAddress { type Error = StunParseError; fn try_from(raw: &RawAttribute) -> Result { @@ -102,7 +102,7 @@ mod tests { for addr in addrs { let mapped = XorMappedAddress::new(*addr, transaction_id); assert_eq!(mapped.addr(transaction_id), *addr); - let raw: RawAttribute = mapped.into(); + let raw = RawAttribute::from(&mapped); assert_eq!(raw.get_type(), XorMappedAddress::TYPE); let mapped2 = XorMappedAddress::try_from(&raw).unwrap(); assert_eq!(mapped2.addr(transaction_id), *addr); @@ -112,7 +112,7 @@ mod tests { BigEndian::write_u16(&mut data[2..4], len as u16 - 4 - 1); assert!(matches!( XorMappedAddress::try_from( - &RawAttribute::try_from(data[..len - 1].as_ref()).unwrap() + &RawAttribute::from_bytes(data[..len - 1].as_ref()).unwrap() ), Err(StunParseError::Truncated { expected: _, diff --git a/stun-types/src/data.rs b/stun-types/src/data.rs new file mode 100644 index 0000000..f39bff5 --- /dev/null +++ b/stun-types/src/data.rs @@ -0,0 +1,141 @@ +// Copyright (C) 2020 Matthew Waters +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +/// A slice of data +#[derive(Debug, Clone, PartialEq, Eq)] +#[repr(transparent)] +pub struct DataSlice<'a>(&'a [u8]); + +impl<'a> DataSlice<'a> { + /// Consume this slice and return the underlying data. + pub fn take(self) -> &'a [u8] { + self.0 + } + + /// Copy this borrowed slice into a new owned allocation. + pub fn to_owned(&self) -> DataOwned { + DataOwned(self.0.into()) + } +} + +impl<'a> std::ops::Deref for DataSlice<'a> { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a> From> for &'a [u8] { + fn from(value: DataSlice<'a>) -> Self { + value.0 + } +} + +impl<'a> From<&'a [u8]> for DataSlice<'a> { + fn from(value: &'a [u8]) -> Self { + Self(value) + } +} + +/// An owned piece of data +#[derive(Debug, Clone, PartialEq, Eq)] +#[repr(transparent)] +pub struct DataOwned(Box<[u8]>); + +impl DataOwned { + /// Consume this slice and return the underlying data. + pub fn take(self) -> Box<[u8]> { + self.0 + } +} + +impl std::ops::Deref for DataOwned { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for Box<[u8]> { + fn from(value: DataOwned) -> Self { + value.0 + } +} + +impl From> for DataOwned { + fn from(value: Box<[u8]>) -> Self { + Self(value) + } +} + +/// An owned or borrowed piece of data +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Data<'a> { + Borrowed(DataSlice<'a>), + Owned(DataOwned), +} + +impl<'a> std::ops::Deref for Data<'a> { + type Target = [u8]; + fn deref(&self) -> &Self::Target { + match self { + Self::Borrowed(data) => data.0, + Self::Owned(data) => &data.0, + } + } +} + +impl<'a> Data<'a> { + /// Create a new owned version of this data + pub fn into_owned<'b>(self) -> Data<'b> { + match self { + Self::Borrowed(data) => Data::Owned(data.to_owned()), + Self::Owned(data) => Data::Owned(data), + } + } +} + +impl<'a> From<&'a [u8]> for Data<'a> { + fn from(value: &'a [u8]) -> Self { + Self::Borrowed(value.into()) + } +} + +impl<'a> From> for Data<'a> { + fn from(value: Box<[u8]>) -> Self { + Self::Owned(value.into()) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + fn init() { + crate::tests::test_init_log(); + } + + #[test] + fn data_access() { + init(); + let array = [0, 1, 2, 3]; + let borrowed_data = Data::from(array.as_slice()); + assert_eq!(array.as_slice(), &*borrowed_data); + let owned_data = borrowed_data.into_owned(); + assert_eq!(array.as_slice(), &*owned_data); + let Data::Owned(owned) = owned_data else { + unreachable!(); + }; + let owned = DataOwned::take(owned); + assert_eq!(array.as_slice(), &*owned); + let data = Data::from(owned); + assert_eq!(array.as_slice(), &*data); + let borrowed = DataSlice::from(&*data); + assert_eq!(array.as_slice(), &*borrowed); + } +} diff --git a/stun-types/src/lib.rs b/stun-types/src/lib.rs index 3fd5cba..33adcbc 100644 --- a/stun-types/src/lib.rs +++ b/stun-types/src/lib.rs @@ -29,6 +29,7 @@ use std::error::Error; use std::str::FromStr; pub mod attribute; +pub mod data; pub mod message; /// The transport family diff --git a/stun-types/src/message.rs b/stun-types/src/message.rs index 411adf7..7da70d0 100644 --- a/stun-types/src/message.rs +++ b/stun-types/src/message.rs @@ -42,7 +42,7 @@ //! // Attributes can be retrieved as raw values. //! let msg_attr = msg.raw_attribute(0x1D.into()).unwrap(); //! let attr = RawAttribute::new(0x1D.into(), &[0, 2, 0, 0]); -//! assert_eq!(msg_attr, &attr); +//! assert_eq!(msg_attr, attr); //! //! // Or as typed values //! let attr = msg.attribute::().unwrap(); @@ -57,12 +57,12 @@ //! use stun_types::message::{Message, BINDING}; //! //! // Automatically generates a transaction ID. -//! let mut msg = Message::new_request(BINDING); +//! let mut msg = Message::builder_request(BINDING); //! //! let software_name = "stun-types"; //! let software = Software::new(software_name).unwrap(); //! assert_eq!(software.software(), software_name); -//! msg.add_attribute(software).unwrap(); +//! msg.add_attribute(&software).unwrap(); //! //! let attribute_data = [ //! 0x80, 0x22, 0x00, 0x0a, // attribute type (0x8022) and length (0x000a) @@ -71,7 +71,7 @@ //! 0x65, 0x73, 0x00, 0x00 // e s //! ]; //! -//! let msg_data = msg.to_bytes(); +//! let msg_data = msg.build(); //! // ignores the randomly generated transaction id //! assert_eq!(msg_data[20..], attribute_data); //! ``` @@ -515,13 +515,11 @@ impl std::fmt::Display for TransactionId { /// Contains the [`MessageType`], a transaction ID, and a list of STUN /// [`Attribute`] #[derive(Debug, Clone)] -pub struct Message { - msg_type: MessageType, - transaction: TransactionId, - attributes: Vec, +pub struct Message<'a> { + data: &'a [u8], } -impl std::fmt::Display for Message { +impl<'a> std::fmt::Display for Message<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, @@ -531,18 +529,15 @@ impl std::fmt::Display for Message { self.get_type().method(), self.transaction_id() )?; - if self.attributes.is_empty() { - write!(f, "[]")?; - } else { - write!(f, "[")?; - for (i, a) in self.attributes.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{}", a)?; + let iter = self.iter_attributes(); + write!(f, "[")?; + for (i, a) in iter.enumerate() { + if i > 0 { + write!(f, ", ")?; } - write!(f, "]")?; + write!(f, "{}", a)?; } + write!(f, "]")?; write!(f, ")") } } @@ -554,7 +549,7 @@ pub enum IntegrityAlgorithm { Sha256, } -impl Message { +impl<'a> Message<'a> { /// Create a new [`Message`] with the provided [`MessageType`] and transaction ID /// /// Note you probably want to use one of the other helper constructors instead. @@ -564,14 +559,15 @@ impl Message { /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; /// let mtype = MessageType::from_class_method(MessageClass::Indication, BINDING); - /// let message = Message::new(mtype, 0.into()); + /// let message = Message::builder(mtype, 0.into()).build(); + /// let message = Message::from_bytes(&message).unwrap(); /// assert!(message.has_class(MessageClass::Indication)); /// assert!(message.has_method(BINDING)); /// ``` - pub fn new(mtype: MessageType, transaction: TransactionId) -> Self { - Self { + pub fn builder<'b>(mtype: MessageType, transaction_id: TransactionId) -> MessageBuilder<'b> { + MessageBuilder { msg_type: mtype, - transaction, + transaction_id, attributes: vec![], } } @@ -582,12 +578,14 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); + /// let message = Message::builder_request(BINDING); + /// let data = message.build(); + /// let message = Message::from_bytes(&data).unwrap(); /// assert!(message.has_class(MessageClass::Request)); /// assert!(message.has_method(BINDING)); /// ``` - pub fn new_request(method: u16) -> Self { - Message::new( + pub fn builder_request<'b>(method: u16) -> MessageBuilder<'b> { + Message::builder( MessageType::from_class_method(MessageClass::Request, method), TransactionId::generate(), ) @@ -603,18 +601,21 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); - /// let success = Message::new_success(&message); + /// let message = Message::builder_request(BINDING); + /// let data = message.build(); + /// let message = Message::from_bytes(&data).unwrap(); + /// let success = Message::builder_success(&message).build(); + /// let success = Message::from_bytes(&success).unwrap(); /// assert!(success.has_class(MessageClass::Success)); /// assert!(success.has_method(BINDING)); /// ``` - pub fn new_success(orig: &Message) -> Self { + pub fn builder_success<'b>(orig: &Message) -> MessageBuilder<'b> { if !orig.has_class(MessageClass::Request) { panic!( "A success response message was attempted to be created from a non-request message" ); } - Message::new( + Message::builder( MessageType::from_class_method(MessageClass::Success, orig.method()), orig.transaction_id(), ) @@ -630,13 +631,16 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); - /// let success = Message::new_error(&message); - /// assert!(success.has_class(MessageClass::Error)); - /// assert!(success.has_method(BINDING)); + /// let message = Message::builder_request(BINDING); + /// let data = message.build(); + /// let message = Message::from_bytes(&data).unwrap(); + /// let error = Message::builder_error(&message).build(); + /// let error = Message::from_bytes(&error).unwrap(); + /// assert!(error.has_class(MessageClass::Error)); + /// assert!(error.has_method(BINDING)); /// ``` - pub fn new_error(orig: &Message) -> Self { - Message::new( + pub fn builder_error(orig: &Message) -> MessageBuilder<'a> { + Message::builder( MessageType::from_class_method(MessageClass::Error, orig.method()), orig.transaction_id(), ) @@ -648,12 +652,14 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); + /// let message = Message::builder_request(BINDING); + /// let data = message.build(); + /// let message = Message::from_bytes(&data).unwrap(); /// assert!(message.get_type().has_class(MessageClass::Request)); /// assert!(message.get_type().has_method(BINDING)); /// ``` pub fn get_type(&self) -> MessageType { - self.msg_type + MessageType::from_bytes(&self.data[..2]).unwrap() } /// Retrieve the [`MessageClass`] of a [`Message`] @@ -662,7 +668,8 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); + /// let message = Message::builder_request(BINDING).build(); + /// let message = Message::from_bytes(&message).unwrap(); /// assert_eq!(message.class(), MessageClass::Request); /// ``` pub fn class(&self) -> MessageClass { @@ -675,7 +682,8 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); + /// let message = Message::builder_request(BINDING).build(); + /// let message = Message::from_bytes(&message).unwrap(); /// assert!(message.has_class(MessageClass::Request)); /// ``` pub fn has_class(&self, cls: MessageClass) -> bool { @@ -690,13 +698,16 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); + /// let message = Message::builder_request(BINDING).build(); + /// let message = Message::from_bytes(&message).unwrap(); /// assert_eq!(message.is_response(), false); /// - /// let error = Message::new_error(&message); + /// let error = Message::builder_error(&message).build(); + /// let error = Message::from_bytes(&error).unwrap(); /// assert_eq!(error.is_response(), true); /// - /// let success = Message::new_success(&message); + /// let success = Message::builder_success(&message).build(); + /// let success = Message::from_bytes(&success).unwrap(); /// assert_eq!(success.is_response(), true); /// ``` pub fn is_response(&self) -> bool { @@ -709,7 +720,8 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); + /// let message = Message::builder_request(BINDING).build(); + /// let message = Message::from_bytes(&message).unwrap(); /// assert_eq!(message.method(), BINDING); /// ``` pub fn method(&self) -> u16 { @@ -722,7 +734,8 @@ impl Message { /// /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let message = Message::new_request(BINDING); + /// let message = Message::builder_request(BINDING).build(); + /// let message = Message::from_bytes(&message).unwrap(); /// assert_eq!(message.has_method(BINDING), true); /// assert_eq!(message.has_method(0), false); /// ``` @@ -738,50 +751,12 @@ impl Message { /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING, TransactionId}; /// let mtype = MessageType::from_class_method(MessageClass::Request, BINDING); /// let transaction_id = TransactionId::generate(); - /// let message = Message::new(mtype, transaction_id); + /// let message = Message::builder(mtype, transaction_id).build(); + /// let message = Message::from_bytes(&message).unwrap(); /// assert_eq!(message.transaction_id(), transaction_id); /// ``` pub fn transaction_id(&self) -> TransactionId { - self.transaction - } - - /// Serialize a `Message` to network bytes - /// - /// # Examples - /// - /// ``` - /// # use stun_types::attribute::{RawAttribute, Attribute}; - /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let mut message = Message::new(MessageType::from_class_method(MessageClass::Request, BINDING), 1000.into()); - /// let attr = RawAttribute::new(1.into(), &[3]); - /// assert!(message.add_attribute(attr).is_ok()); - /// assert_eq!(message.to_bytes(), vec![0, 1, 0, 8, 33, 18, 164, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 232, 0, 1, 0, 1, 3, 0, 0, 0]); - /// ``` - #[tracing::instrument( - name = "message_to_bytes", - level = "trace", - skip(self), - fields( - msg.transaction_id = %self.transaction_id() - ) - )] - pub fn to_bytes(&self) -> Vec { - let mut attr_size = 0; - for attr in &self.attributes { - attr_size += padded_attr_size(attr); - } - let mut ret = Vec::with_capacity(20 + attr_size); - ret.extend(self.msg_type.to_bytes()); - ret.resize(20, 0); - let transaction: u128 = self.transaction.into(); - let tid = (MAGIC_COOKIE as u128) << 96 | transaction & 0xffff_ffff_ffff_ffff_ffff_ffff; - BigEndian::write_u128(&mut ret[4..20], tid); - BigEndian::write_u16(&mut ret[2..4], attr_size as u16); - for attr in &self.attributes { - let bytes = attr.to_bytes(); - ret.extend(bytes); - } - ret + BigEndian::read_u128(&self.data[4..]).into() } /// Deserialize a `Message` @@ -795,7 +770,7 @@ impl Message { /// let message = Message::from_bytes(&msg_data).unwrap(); /// let attr = RawAttribute::new(1.into(), &[3]); /// let msg_attr = message.raw_attribute(1.into()).unwrap(); - /// assert_eq!(msg_attr, &attr); + /// assert_eq!(msg_attr, attr); /// assert_eq!(message.get_type(), MessageType::from_class_method(MessageClass::Request, BINDING)); /// assert_eq!(message.transaction_id(), 1000.into()); /// ``` @@ -807,7 +782,7 @@ impl Message { data.len = data.len() ) )] - pub fn from_bytes(data: &[u8]) -> Result { + pub fn from_bytes(data: &'a [u8]) -> Result { let orig_data = data; if data.len() < 20 { @@ -818,7 +793,7 @@ impl Message { actual: data.len(), }); } - let mtype = MessageType::from_bytes(data)?; + let _mtype = MessageType::from_bytes(data)?; let mlength = BigEndian::read_u16(&data[2..]) as usize; if mlength + 20 > data.len() { // mlength + header @@ -841,7 +816,6 @@ impl Message { ); return Err(StunParseError::NotStun); } - let mut ret = Self::new(mtype, tid.into()); let mut data_offset = 20; let mut data = &data[20..]; @@ -893,11 +867,10 @@ impl Message { return Err(StunParseError::FingerprintMismatch); } } - ret.attributes.push(attr); data = &data[padded_len..]; data_offset += padded_len; } - Ok(ret) + Ok(Message { data: orig_data }) } /// Validates the MESSAGE_INTEGRITY attribute with the provided credentials @@ -910,27 +883,27 @@ impl Message { /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING, /// MessageIntegrityCredentials, LongTermCredentials, IntegrityAlgorithm}; - /// let mut message = Message::new_request(BINDING); + /// let mut message = Message::builder_request(BINDING); /// let credentials = LongTermCredentials::new( /// "user".to_owned(), /// "pass".to_owned(), /// "realm".to_owned() /// ).into(); /// assert!(message.add_message_integrity(&credentials, IntegrityAlgorithm::Sha256).is_ok()); - /// let data = message.to_bytes(); - /// assert!(message.validate_integrity(&data, &credentials).is_ok()); + /// let data = message.build(); + /// let message = Message::from_bytes(&data).unwrap(); + /// assert!(message.validate_integrity(&credentials).is_ok()); /// ``` #[tracing::instrument( name = "message_validate_integrity", level = "trace", - skip(self, orig_data, credentials), + skip(self, credentials), fields( msg.transaction = %self.transaction_id(), ) )] pub fn validate_integrity( &self, - orig_data: &[u8], credentials: &MessageIntegrityCredentials, ) -> Result<(), StunParseError> { debug!("using credentials {credentials:?}"); @@ -939,11 +912,11 @@ impl Message { let (algo, msg_hmac) = match (raw_sha1, raw_sha256) { (Some(_), Some(_)) => return Err(StunParseError::DuplicateIntegrity), (Some(sha1), None) => { - let integrity = MessageIntegrity::try_from(sha1)?; + let integrity = MessageIntegrity::try_from(&sha1)?; (IntegrityAlgorithm::Sha1, integrity.hmac().to_vec()) } (None, Some(sha256)) => { - let integrity = MessageIntegritySha256::try_from(sha256)?; + let integrity = MessageIntegritySha256::try_from(&sha256)?; (IntegrityAlgorithm::Sha256, integrity.hmac().to_vec()) } (None, None) => return Err(StunParseError::MissingAttribute(MessageIntegrity::TYPE)), @@ -951,7 +924,7 @@ impl Message { // find the location of the original MessageIntegrity attribute: XXX: maybe encode this into // the attribute instead? - let data = orig_data; + let data = self.data; if data.len() < 20 { // always at least 20 bytes long debug!("not enough data in message"); @@ -975,7 +948,7 @@ impl Message { // HMAC is computed using all the data up to (exclusive of) the MESSAGE_INTEGRITY // but with a length field including the MESSAGE_INTEGRITY attribute... let key = credentials.make_hmac_key(); - let mut hmac_data = orig_data[..data_offset].to_vec(); + let mut hmac_data = self.data[..data_offset].to_vec(); BigEndian::write_u16(&mut hmac_data[2..4], data_offset as u16 + 24 - 20); return MessageIntegrity::verify( &hmac_data, @@ -995,7 +968,7 @@ impl Message { // HMAC is computed using all the data up to (exclusive of) the MESSAGE_INTEGRITY // but with a length field including the MESSAGE_INTEGRITY attribute... let key = credentials.make_hmac_key(); - let mut hmac_data = orig_data[..data_offset].to_vec(); + let mut hmac_data = self.data[..data_offset].to_vec(); BigEndian::write_u16( &mut hmac_data[2..4], data_offset as u16 + attr.length() + 4 - 20, @@ -1021,194 +994,6 @@ impl Message { Err(StunParseError::MissingAttribute(MessageIntegrity::TYPE)) } - // message-integrity is computed using all the data up to (exclusive of) the - // MESSAGE-INTEGRITY but with a length field including the MESSAGE-INTEGRITY attribute... - fn integrity_bytes_from_message(&self, extra_len: u16) -> Vec { - let mut bytes = self.to_bytes(); - // rewrite the length to include the message-integrity attribute - let existing_len = BigEndian::read_u16(&bytes[2..4]); - BigEndian::write_u16(&mut bytes[2..4], existing_len + extra_len); - bytes - } - - /// Adds MESSAGE_INTEGRITY attribute to a [`Message`] using the provided credentials - /// - /// # Errors - /// - /// - If a [`MessageIntegrity`] attribute is already present - /// - If a [`MessageIntegritySha256`] attribute is already present - /// - If a [`Fingerprint`] attribute is already present - /// - /// # Examples - /// - /// ``` - /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING, - /// MessageIntegrityCredentials, ShortTermCredentials, IntegrityAlgorithm, StunWriteError}; - /// # use stun_types::attribute::{Attribute, MessageIntegrity, MessageIntegritySha256}; - /// let mut message = Message::new_request(BINDING); - /// let credentials = ShortTermCredentials::new("pass".to_owned()).into(); - /// assert!(message.add_message_integrity(&credentials, IntegrityAlgorithm::Sha1).is_ok()); - /// let data = message.to_bytes(); - /// assert!(message.validate_integrity(&data, &credentials).is_ok()); - /// - /// // duplicate MessageIntegrity is an error - /// assert!(matches!( - /// message.add_message_integrity(&credentials, IntegrityAlgorithm::Sha1), - /// Err(StunWriteError::AttributeExists(MessageIntegrity::TYPE)), - /// )); - /// - /// // only one of MessageIntegrity, and MessageIntegritySha256 is allowed - /// assert!(matches!( - /// message.add_message_integrity(&credentials, IntegrityAlgorithm::Sha256), - /// Err(StunWriteError::AttributeExists(MessageIntegrity::TYPE)), - /// )); - /// ``` - #[tracing::instrument( - name = "message_add_integrity", - level = "trace", - err, - skip(self), - fields( - msg.transaction = %self.transaction_id(), - ) - )] - pub fn add_message_integrity( - &mut self, - credentials: &MessageIntegrityCredentials, - algorithm: IntegrityAlgorithm, - ) -> Result<(), StunWriteError> { - if self.has_attribute(MessageIntegrity::TYPE) { - return Err(StunWriteError::AttributeExists(MessageIntegrity::TYPE)); - } - if self.has_attribute(MessageIntegritySha256::TYPE) { - return Err(StunWriteError::AttributeExists( - MessageIntegritySha256::TYPE, - )); - } - if self.has_attribute(Fingerprint::TYPE) { - return Err(StunWriteError::FingerprintExists); - } - - let key = credentials.make_hmac_key(); - match algorithm { - IntegrityAlgorithm::Sha1 => { - let bytes = self.integrity_bytes_from_message(24); - let integrity = MessageIntegrity::compute(&bytes, &key).unwrap(); - self.attributes - .push(MessageIntegrity::new(integrity).into()); - } - IntegrityAlgorithm::Sha256 => { - let bytes = self.integrity_bytes_from_message(36); - let integrity = MessageIntegritySha256::compute(&bytes, &key).unwrap(); - self.attributes - .push(MessageIntegritySha256::new(integrity.as_slice())?.into()); - } - } - Ok(()) - } - - /// Adds [`Fingerprint`] attribute to a [`Message`] - /// - /// # Errors - /// - /// - If a [`Fingerprint`] attribute is already present - /// - /// # Examples - /// - /// ``` - /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let mut message = Message::new_request(BINDING); - /// assert!(message.add_fingerprint().is_ok()); - /// - /// // duplicate FINGERPRINT is an error - /// assert!(message.add_fingerprint().is_err()); - /// ``` - #[tracing::instrument( - name = "message_add_fingerprint", - level = "trace", - skip(self), - fields( - msg.transaction = %self.transaction_id(), - ) - )] - pub fn add_fingerprint(&mut self) -> Result<(), StunWriteError> { - if self.has_attribute(Fingerprint::TYPE) { - return Err(StunWriteError::AttributeExists(Fingerprint::TYPE)); - } - // fingerprint is computed using all the data up to (exclusive of) the FINGERPRINT - // but with a length field including the FINGERPRINT attribute... - let mut bytes = self.to_bytes(); - // rewrite the length to include the fingerprint attribute - let existing_len = BigEndian::read_u16(&bytes[2..4]); - BigEndian::write_u16(&mut bytes[2..4], existing_len + 8); - let fingerprint = Fingerprint::compute(&bytes); - self.attributes.push(Fingerprint::new(fingerprint).into()); - Ok(()) - } - - /// Add a `Attribute` to this `Message`. Only one `AttributeType` can be added for each - /// `Attribute. Attempting to add multiple `Atribute`s of the same `AttributeType` will fail. - /// - /// # Errors - /// - /// - if a [`MessageIntegrity`] or [`MessageIntegritySha256`] attribute is attempted to be added. Use - /// `Message::add_message_integrity` instead. - /// - if a [`Fingerprint`] attribute is attempted to be added. Use - /// `Message::add_fingerprint` instead. - /// - If the attribute already exists within the message - /// - If attempting to add attributes when [`MessageIntegrity`], [`MessageIntegritySha256`] or - /// [`Fingerprint`] atributes already exist. - /// - /// # Examples - /// - /// Add an `Attribute` - /// - /// ``` - /// # use stun_types::attribute::RawAttribute; - /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let mut message = Message::new_request(BINDING); - /// let attr = RawAttribute::new(1.into(), &[3]); - /// assert!(message.add_attribute(attr.clone()).is_ok()); - /// assert!(message.add_attribute(attr).is_err()); - /// ``` - #[tracing::instrument( - name = "message_add_attribute", - level = "trace", - err, - skip(self, attr), - fields( - msg.transaction = %self.transaction_id(), - ) - )] - pub fn add_attribute(&mut self, attr: impl Into) -> Result<(), StunWriteError> { - let raw = attr.into(); - //trace!("adding attribute {:?}", attr); - if raw.get_type() == MessageIntegrity::TYPE { - panic!("Cannot write MessageIntegrity with `add_attribute`. Use add_message_integrity() instead"); - } - if raw.get_type() == MessageIntegritySha256::TYPE { - panic!("Cannot write MessageIntegritySha256 with `add_attribute`. Use add_message_integrity() instead"); - } - if raw.get_type() == Fingerprint::TYPE { - panic!("Cannot write Fingerprint with `add_attribute`. Use add_fingerprint() instead"); - } - if self.has_attribute(raw.get_type()) { - return Err(StunWriteError::AttributeExists(raw.get_type())); - } - // can't validly add generic attributes after message integrity or fingerprint - if self.has_attribute(MessageIntegrity::TYPE) { - return Err(StunWriteError::MessageIntegrityExists); - } - if self.has_attribute(MessageIntegritySha256::TYPE) { - return Err(StunWriteError::MessageIntegrityExists); - } - if self.has_attribute(Fingerprint::TYPE) { - return Err(StunWriteError::FingerprintExists); - } - self.attributes.push(raw); - Ok(()) - } - /// Retrieve a `RawAttribute` from this `Message`. /// /// # Examples @@ -1218,10 +1003,12 @@ impl Message { /// ``` /// # use stun_types::attribute::{RawAttribute, Attribute}; /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let mut message = Message::new_request(BINDING); + /// let mut message = Message::builder_request(BINDING); /// let attr = RawAttribute::new(1.into(), &[3]); /// assert!(message.add_attribute(attr.clone()).is_ok()); - /// assert_eq!(message.raw_attribute(1.into()).unwrap(), &attr); + /// let message = message.build(); + /// let message = Message::from_bytes(&message).unwrap(); + /// assert_eq!(message.raw_attribute(1.into()).unwrap(), attr); /// ``` #[tracing::instrument( name = "message_get_raw_attribute", @@ -1233,8 +1020,8 @@ impl Message { attribute_type = %atype, ) )] - pub fn raw_attribute(&self, atype: AttributeType) -> Option<&RawAttribute> { - self.attributes.iter().find(|attr| attr.get_type() == atype) + pub fn raw_attribute(&self, atype: AttributeType) -> Option { + self.iter_attributes().find(|attr| attr.get_type() == atype) } /// Retrieve a concrete `Attribute` from this `Message`. @@ -1250,9 +1037,11 @@ impl Message { /// ``` /// # use stun_types::attribute::{Software, Attribute}; /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; - /// let mut message = Message::new_request(BINDING); + /// let mut message = Message::builder_request(BINDING); /// let attr = Software::new("stun-types").unwrap(); - /// assert!(message.add_attribute(attr.clone()).is_ok()); + /// assert!(message.add_attribute(&attr).is_ok()); + /// let message = message.build(); + /// let message = Message::from_bytes(&message).unwrap(); /// assert_eq!(message.attribute::().unwrap(), attr); /// ``` #[tracing::instrument( @@ -1266,16 +1055,19 @@ impl Message { ) )] pub fn attribute>(&self) -> Result { - self.attributes - .iter() + self.iter_attributes() .find(|attr| attr.get_type() == A::TYPE) .ok_or(StunParseError::MissingAttribute(A::TYPE)) - .and_then(|raw| A::from_raw(raw)) + .and_then(|raw| A::from_raw(&raw)) } /// Returns an iterator over the attributes in the [`Message`]. - pub fn iter_attributes(&self) -> impl Iterator { - self.attributes.iter() + pub fn iter_attributes(&self) -> impl Iterator { + MessageAttributesIter { + data: self.data, + data_i: 20, + seen_message_integrity: false, + } } /// Check that a message [`Message`] only contains required attributes that are supported and @@ -1288,7 +1080,9 @@ impl Message { /// # use stun_types::attribute::*; /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; /// # use std::convert::TryInto; - /// let mut message = Message::new_request(BINDING); + /// let mut builder = Message::builder_request(BINDING); + /// let message = builder.build(); + /// let message = Message::from_bytes(&message).unwrap(); /// // If nothing is required, no error response is returned /// assert!(matches!(Message::check_attribute_types(&message, &[], &[]), None)); /// @@ -1300,13 +1094,20 @@ impl Message { /// &[Software::TYPE] /// ).unwrap(); /// assert!(error_msg.has_attribute(ErrorCode::TYPE)); + /// let error_msg = error_msg.build(); + /// let error_msg = Message::from_bytes(&error_msg).unwrap(); /// let error_code = error_msg.attribute::().unwrap(); /// assert_eq!(error_code.code(), 400); /// - /// message.add_attribute(Username::new("user").unwrap()); + /// let username = Username::new("user").unwrap(); + /// builder.add_attribute(&username).unwrap(); + /// let message = builder.build(); + /// let message = Message::from_bytes(&message).unwrap(); /// // If a Username is in the message but is not advertised as supported then an /// // 'UNKNOWN-ATTRIBUTES' error response is returned /// let error_msg = Message::check_attribute_types(&message, &[], &[]).unwrap(); + /// let error_msg = error_msg.build(); + /// let error_msg = Message::from_bytes(&error_msg).unwrap(); /// assert!(error_msg.is_response()); /// assert!(error_msg.has_attribute(ErrorCode::TYPE)); /// let error_code : ErrorCode = error_msg.attribute::().unwrap(); @@ -1320,11 +1121,11 @@ impl Message { msg.transaction = %msg.transaction_id(), ) )] - pub fn check_attribute_types( + pub fn check_attribute_types<'b>( msg: &Message, supported: &[AttributeType], required_in_msg: &[AttributeType], - ) -> Option { + ) -> Option> { // Attribute -> AttributeType let unsupported: Vec = msg .iter_attributes() @@ -1359,8 +1160,10 @@ impl Message { /// # use stun_types::message::{Message, BINDING}; /// # use stun_types::attribute::*; /// # use std::convert::TryInto; - /// let msg = Message::new_request(BINDING); - /// let error_msg = Message::unknown_attributes(&msg, &[Username::TYPE]); + /// let msg = Message::builder_request(BINDING).build(); + /// let msg = Message::from_bytes(&msg).unwrap(); + /// let error_msg = Message::unknown_attributes(&msg, &[Username::TYPE]).build(); + /// let error_msg = Message::from_bytes(&error_msg).unwrap(); /// assert!(error_msg.is_response()); /// assert!(error_msg.has_attribute(ErrorCode::TYPE)); /// let error_code = error_msg.attribute::().unwrap(); @@ -1368,17 +1171,20 @@ impl Message { /// let unknown = error_msg.attribute::().unwrap(); /// assert!(unknown.has_attribute(Username::TYPE)); /// ``` - pub fn unknown_attributes(src: &Message, attributes: &[AttributeType]) -> Message { - let mut out = Message::new_error(src); - out.add_attribute(Software::new("stun-types").unwrap()) - .unwrap(); - out.add_attribute(ErrorCode::new(420, "Unknown Attributes").unwrap()) + pub fn unknown_attributes<'b>( + src: &Message, + attributes: &[AttributeType], + ) -> MessageBuilder<'b> { + let mut out = Message::builder_error(src); + let software = Software::new("stun-types").unwrap(); + out.add_attribute(&software).unwrap(); + out.add_attribute(&ErrorCode::new(420, "Unknown Attributes").unwrap()) .unwrap(); if !attributes.is_empty() { - out.add_attribute(UnknownAttributes::new(attributes)) + out.add_attribute(&UnknownAttributes::new(attributes)) .unwrap(); } - out + out.into_owned() } /// Generate an error message with an [`ErrorCode`] attribute signalling a 'Bad Request' @@ -1389,19 +1195,21 @@ impl Message { /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; /// # use stun_types::attribute::*; /// # use std::convert::TryInto; - /// let msg = Message::new_request(BINDING); - /// let error_msg = Message::bad_request(&msg); + /// let msg = Message::builder_request(BINDING).build(); + /// let msg = Message::from_bytes(&msg).unwrap(); + /// let error_msg = Message::bad_request(&msg).build(); + /// let error_msg = Message::from_bytes(&error_msg).unwrap(); /// assert!(error_msg.has_attribute(ErrorCode::TYPE)); /// let error_code = error_msg.attribute::().unwrap(); /// assert_eq!(error_code.code(), 400); /// ``` - pub fn bad_request(src: &Message) -> Message { - let mut out = Message::new_error(src); - out.add_attribute(Software::new("stun-types").unwrap()) + pub fn bad_request<'b>(src: &Message) -> MessageBuilder<'b> { + let mut out = Message::builder_error(src); + let software = Software::new("stun-types").unwrap(); + out.add_attribute(&software).unwrap(); + out.add_attribute(&ErrorCode::new(400, "Bad Request").unwrap()) .unwrap(); - out.add_attribute(ErrorCode::new(400, "Bad Request").unwrap()) - .unwrap(); - out + out.into_owned() } /// Whether this message contains an attribute of the specified type. @@ -1411,28 +1219,345 @@ impl Message { /// ``` /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; /// # use stun_types::attribute::{Software, Attribute}; - /// let mut msg = Message::new_request(BINDING); + /// let mut msg = Message::builder_request(BINDING); /// let attr = Software::new("stun-types").unwrap(); - /// assert!(msg.add_attribute(attr.clone()).is_ok()); + /// assert!(msg.add_attribute(&attr).is_ok()); + /// let msg = msg.build(); + /// let msg = Message::from_bytes(&msg).unwrap(); /// assert!(msg.has_attribute(Software::TYPE)); /// ``` pub fn has_attribute(&self, atype: AttributeType) -> bool { - self.attributes.iter().any(|attr| attr.get_type() == atype) - } -} -impl From for Vec { - fn from(f: Message) -> Self { - f.to_bytes() + self.iter_attributes().any(|attr| attr.get_type() == atype) } } -impl TryFrom<&[u8]> for Message { +impl<'a> TryFrom<&'a [u8]> for Message<'a> { type Error = StunParseError; - fn try_from(value: &[u8]) -> Result { + fn try_from(value: &'a [u8]) -> Result { Message::from_bytes(value) } } +#[doc(hidden)] +pub struct MessageAttributesIter<'a> { + data: &'a [u8], + data_i: usize, + seen_message_integrity: bool, +} + +impl<'a> Iterator for MessageAttributesIter<'a> { + type Item = RawAttribute<'a>; + + fn next(&mut self) -> Option { + if self.data_i >= self.data.len() { + return None; + } + + let Ok(attr) = RawAttribute::from_bytes(&self.data[self.data_i..]) else { + self.data_i = self.data.len(); + return None; + }; + let padded_len = padded_attr_size(&attr); + self.data_i += padded_len; + if self.seen_message_integrity { + if attr.get_type() == Fingerprint::TYPE { + return Some(attr); + } + return None; + } + if attr.get_type() == MessageIntegrity::TYPE + || attr.get_type() == MessageIntegritySha256::TYPE + { + self.seen_message_integrity = true; + } + + Some(attr) + } +} + +#[derive(Clone, Debug)] +pub struct MessageBuilder<'a> { + msg_type: MessageType, + transaction_id: TransactionId, + attributes: Vec>, +} + +impl<'a> MessageBuilder<'a> { + /// Consume this builder and produce a new owned version. + pub fn into_owned<'b>(self) -> MessageBuilder<'b> { + MessageBuilder { + msg_type: self.msg_type, + transaction_id: self.transaction_id, + attributes: self + .attributes + .into_iter() + .map(|attr| attr.into_owned()) + .collect(), + } + } + + /// Retrieves the 96-bit transaction ID of the [`Message`] + /// + /// # Examples + /// + /// ``` + /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING, TransactionId}; + /// let mtype = MessageType::from_class_method(MessageClass::Request, BINDING); + /// let transaction_id = TransactionId::generate(); + /// let message = Message::builder(mtype, transaction_id).build(); + /// let message = Message::from_bytes(&message).unwrap(); + /// assert_eq!(message.transaction_id(), transaction_id); + /// ``` + pub fn transaction_id(&self) -> TransactionId { + self.transaction_id + } + + /// Whether this [`MessageBuilder`] is for a particular [`MessageClass`] + pub fn has_class(&self, cls: MessageClass) -> bool { + self.msg_type.class() == cls + } + + /// Serialize a `MessageBuilder` to network bytes + /// + /// # Examples + /// + /// ``` + /// # use stun_types::attribute::{RawAttribute, Attribute}; + /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; + /// let mut message = Message::builder(MessageType::from_class_method(MessageClass::Request, BINDING), 1000.into()); + /// let attr = RawAttribute::new(1.into(), &[3]); + /// assert!(message.add_attribute(attr).is_ok()); + /// assert_eq!(message.build(), vec![0, 1, 0, 8, 33, 18, 164, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 232, 0, 1, 0, 1, 3, 0, 0, 0]); + /// ``` + #[tracing::instrument( + name = "message_build", + level = "trace", + skip(self), + fields( + msg.transaction_id = %self.transaction_id() + ) + )] + pub fn build(&self) -> Vec { + let mut attr_size = 0; + for attr in &self.attributes { + attr_size += padded_attr_size(attr); + } + let mut ret = Vec::with_capacity(20 + attr_size); + ret.extend(self.msg_type.to_bytes()); + ret.resize(20, 0); + let transaction: u128 = self.transaction_id.into(); + let tid = (MAGIC_COOKIE as u128) << 96 | transaction & 0xffff_ffff_ffff_ffff_ffff_ffff; + BigEndian::write_u128(&mut ret[4..20], tid); + BigEndian::write_u16(&mut ret[2..4], attr_size as u16); + for attr in &self.attributes { + let bytes = attr.to_bytes(); + ret.extend(bytes); + } + ret + } + + // message-integrity is computed using all the data up to (exclusive of) the + // MESSAGE-INTEGRITY but with a length field including the MESSAGE-INTEGRITY attribute... + fn integrity_bytes_from_message(&self, extra_len: u16) -> Vec { + let mut bytes = self.build(); + // rewrite the length to include the message-integrity attribute + let existing_len = BigEndian::read_u16(&bytes[2..4]); + BigEndian::write_u16(&mut bytes[2..4], existing_len + extra_len); + bytes + } + + /// Adds MESSAGE_INTEGRITY attribute to a [`Message`] using the provided credentials + /// + /// # Errors + /// + /// - If a [`MessageIntegrity`] attribute is already present + /// - If a [`MessageIntegritySha256`] attribute is already present + /// - If a [`Fingerprint`] attribute is already present + /// + /// # Examples + /// + /// ``` + /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING, + /// MessageIntegrityCredentials, ShortTermCredentials, IntegrityAlgorithm, StunWriteError}; + /// # use stun_types::attribute::{Attribute, MessageIntegrity, MessageIntegritySha256}; + /// let mut message = Message::builder_request(BINDING); + /// let credentials = ShortTermCredentials::new("pass".to_owned()).into(); + /// assert!(message.add_message_integrity(&credentials, IntegrityAlgorithm::Sha1).is_ok()); + /// + /// // duplicate MessageIntegrity is an error + /// assert!(matches!( + /// message.add_message_integrity(&credentials, IntegrityAlgorithm::Sha1), + /// Err(StunWriteError::AttributeExists(MessageIntegrity::TYPE)), + /// )); + /// + /// // only one of MessageIntegrity, and MessageIntegritySha256 is allowed + /// assert!(matches!( + /// message.add_message_integrity(&credentials, IntegrityAlgorithm::Sha256), + /// Err(StunWriteError::AttributeExists(MessageIntegrity::TYPE)), + /// )); + /// + /// let data = message.build(); + /// let message = Message::from_bytes(&data).unwrap(); + /// assert!(message.validate_integrity(&credentials).is_ok()); + /// ``` + #[tracing::instrument( + name = "message_add_integrity", + level = "trace", + err, + skip(self), + fields( + msg.transaction = %self.transaction_id(), + ) + )] + pub fn add_message_integrity( + &mut self, + credentials: &MessageIntegrityCredentials, + algorithm: IntegrityAlgorithm, + ) -> Result<(), StunWriteError> { + if self.has_attribute(MessageIntegrity::TYPE) { + return Err(StunWriteError::AttributeExists(MessageIntegrity::TYPE)); + } + if self.has_attribute(MessageIntegritySha256::TYPE) { + return Err(StunWriteError::AttributeExists( + MessageIntegritySha256::TYPE, + )); + } + if self.has_attribute(Fingerprint::TYPE) { + return Err(StunWriteError::FingerprintExists); + } + + let key = credentials.make_hmac_key(); + match algorithm { + IntegrityAlgorithm::Sha1 => { + let bytes = self.integrity_bytes_from_message(24); + let integrity = MessageIntegrity::compute(&bytes, &key).unwrap(); + self.attributes + .push(RawAttribute::from(&MessageIntegrity::new(integrity)).into_owned()); + } + IntegrityAlgorithm::Sha256 => { + let bytes = self.integrity_bytes_from_message(36); + let integrity = MessageIntegritySha256::compute(&bytes, &key).unwrap(); + self.attributes.push( + RawAttribute::from(&MessageIntegritySha256::new(integrity.as_slice())?) + .into_owned(), + ); + } + } + Ok(()) + } + + /// Adds [`Fingerprint`] attribute to a [`Message`] + /// + /// # Errors + /// + /// - If a [`Fingerprint`] attribute is already present + /// + /// # Examples + /// + /// ``` + /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; + /// let mut message = Message::builder_request(BINDING); + /// assert!(message.add_fingerprint().is_ok()); + /// + /// // duplicate FINGERPRINT is an error + /// assert!(message.add_fingerprint().is_err()); + /// ``` + #[tracing::instrument( + name = "message_add_fingerprint", + level = "trace", + skip(self), + fields( + msg.transaction = %self.transaction_id(), + ) + )] + pub fn add_fingerprint(&mut self) -> Result<(), StunWriteError> { + if self.has_attribute(Fingerprint::TYPE) { + return Err(StunWriteError::AttributeExists(Fingerprint::TYPE)); + } + // fingerprint is computed using all the data up to (exclusive of) the FINGERPRINT + // but with a length field including the FINGERPRINT attribute... + let mut bytes = self.build(); + // rewrite the length to include the fingerprint attribute + let existing_len = BigEndian::read_u16(&bytes[2..4]); + BigEndian::write_u16(&mut bytes[2..4], existing_len + 8); + let fingerprint = Fingerprint::compute(&bytes); + self.attributes + .push(RawAttribute::from(&Fingerprint::new(fingerprint))); + Ok(()) + } + + /// Add a `Attribute` to this `Message`. Only one `AttributeType` can be added for each + /// `Attribute. Attempting to add multiple `Atribute`s of the same `AttributeType` will fail. + /// + /// # Errors + /// + /// - if a [`MessageIntegrity`] or [`MessageIntegritySha256`] attribute is attempted to be added. Use + /// `Message::add_message_integrity` instead. + /// - if a [`Fingerprint`] attribute is attempted to be added. Use + /// `Message::add_fingerprint` instead. + /// - If the attribute already exists within the message + /// - If attempting to add attributes when [`MessageIntegrity`], [`MessageIntegritySha256`] or + /// [`Fingerprint`] atributes already exist. + /// + /// # Examples + /// + /// Add an `Attribute` + /// + /// ``` + /// # use stun_types::attribute::RawAttribute; + /// # use stun_types::message::{Message, MessageType, MessageClass, BINDING}; + /// let mut message = Message::builder_request(BINDING); + /// let attr = RawAttribute::new(1.into(), &[3]); + /// assert!(message.add_attribute(attr.clone()).is_ok()); + /// assert!(message.add_attribute(attr).is_err()); + /// ``` + #[tracing::instrument( + name = "message_add_attribute", + level = "trace", + err, + skip(self, attr), + fields( + msg.transaction = %self.transaction_id(), + ) + )] + pub fn add_attribute( + &mut self, + attr: impl Into>, + ) -> Result<(), StunWriteError> { + let raw = attr.into(); + //trace!("adding attribute {:?}", attr); + if raw.get_type() == MessageIntegrity::TYPE { + panic!("Cannot write MessageIntegrity with `add_attribute`. Use add_message_integrity() instead"); + } + if raw.get_type() == MessageIntegritySha256::TYPE { + panic!("Cannot write MessageIntegritySha256 with `add_attribute`. Use add_message_integrity() instead"); + } + if raw.get_type() == Fingerprint::TYPE { + panic!("Cannot write Fingerprint with `add_attribute`. Use add_fingerprint() instead"); + } + if self.has_attribute(raw.get_type()) { + return Err(StunWriteError::AttributeExists(raw.get_type())); + } + // can't validly add generic attributes after message integrity or fingerprint + if self.has_attribute(MessageIntegrity::TYPE) { + return Err(StunWriteError::MessageIntegrityExists); + } + if self.has_attribute(MessageIntegritySha256::TYPE) { + return Err(StunWriteError::MessageIntegrityExists); + } + if self.has_attribute(Fingerprint::TYPE) { + return Err(StunWriteError::FingerprintExists); + } + self.attributes.push(raw); + Ok(()) + } + + /// Return whether this [`MessageBuilder`] contains a particular attribute. + pub fn has_attribute(&self, atype: AttributeType) -> bool { + self.attributes.iter().any(|attr| attr.get_type() == atype) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1474,14 +1599,14 @@ mod tests { for c in classes { let mtype = MessageType::from_class_method(c, m); for tid in (0x18..0xff_ffff_ffff_ffff_ffff).step_by(0xfedc_ba98_7654_3210) { - let mut msg = Message::new(mtype, tid.into()); + let mut msg = Message::builder(mtype, tid.into()); let attr = RawAttribute::new(1.into(), &[3]); assert!(msg.add_attribute(attr.clone()).is_ok()); - let data = msg.to_bytes(); + let data = msg.build(); let msg = Message::from_bytes(&data).unwrap(); let msg_attr = msg.raw_attribute(1.into()).unwrap(); - assert_eq!(msg_attr, &attr); + assert_eq!(msg_attr, attr); assert_eq!(msg.get_type(), mtype); assert_eq!(msg.transaction_id(), tid.into()); } @@ -1491,8 +1616,10 @@ mod tests { #[test] fn unknown_attributes() { - let src = Message::new_request(BINDING); - let msg = Message::unknown_attributes(&src, &[Software::TYPE]); + let src = Message::builder_request(BINDING).build(); + let src = Message::from_bytes(&src).unwrap(); + let msg = Message::unknown_attributes(&src, &[Software::TYPE]).build(); + let msg = Message::from_bytes(&msg).unwrap(); assert_eq!(msg.transaction_id(), src.transaction_id()); assert_eq!(msg.class(), MessageClass::Error); assert_eq!(msg.method(), src.method()); @@ -1504,8 +1631,10 @@ mod tests { #[test] fn bad_request() { - let src = Message::new_request(BINDING); - let msg = Message::bad_request(&src); + let src = Message::builder_request(BINDING).build(); + let src = Message::from_bytes(&src).unwrap(); + let msg = Message::bad_request(&src).build(); + let msg = Message::from_bytes(&msg).unwrap(); assert_eq!(msg.transaction_id(), src.transaction_id()); assert_eq!(msg.class(), MessageClass::Error); assert_eq!(msg.method(), src.method()); @@ -1516,73 +1645,45 @@ mod tests { #[test] fn fingerprint() { init(); - let mut msg = Message::new_request(BINDING); - let software_str = "s"; - msg.add_attribute(Software::new(software_str).unwrap()) - .unwrap(); + let mut msg = Message::builder_request(BINDING); + let software = Software::new("s").unwrap(); + msg.add_attribute(&software).unwrap(); msg.add_fingerprint().unwrap(); - let orig_fingerprint = msg.attribute::().unwrap(); - let bytes: Vec<_> = msg.into(); + let bytes = msg.build(); // validates the fingerprint of the data when available let new_msg = Message::from_bytes(&bytes).unwrap(); let software = new_msg.attribute::().unwrap(); - assert_eq!(software.software(), software_str); - let new_fingerprint = new_msg.attribute::().unwrap(); - assert_eq!( - orig_fingerprint.fingerprint(), - new_fingerprint.fingerprint() - ); + assert_eq!(software.software(), "s"); + let _new_fingerprint = new_msg.attribute::().unwrap(); } #[test] fn integrity() { init(); for algorithm in [IntegrityAlgorithm::Sha1, IntegrityAlgorithm::Sha256] { - let mut msg = Message::new_request(BINDING); - let software_str = "s"; let credentials = ShortTermCredentials::new("secret".to_owned()).into(); - msg.add_attribute(Software::new(software_str).unwrap()) - .unwrap(); + let mut msg = Message::builder_request(BINDING); + let software = Software::new("s").unwrap(); + msg.add_attribute(&software).unwrap(); msg.add_message_integrity(&credentials, algorithm).unwrap(); - let bytes: Vec<_> = msg.clone().into(); - msg.validate_integrity(&bytes, &credentials).unwrap(); - let orig_hmac = match algorithm { - IntegrityAlgorithm::Sha1 => { - msg.attribute::().unwrap().hmac().to_vec() - } - IntegrityAlgorithm::Sha256 => msg - .attribute::() - .unwrap() - .hmac() - .to_vec(), - }; + let bytes = msg.build(); // validates the fingerprint of the data when available let new_msg = Message::from_bytes(&bytes).unwrap(); + new_msg.validate_integrity(&credentials).unwrap(); let software = new_msg.attribute::().unwrap(); - assert_eq!(software.software(), software_str); - let new_hmac = match algorithm { - IntegrityAlgorithm::Sha1 => new_msg - .attribute::() - .unwrap() - .hmac() - .to_vec(), - IntegrityAlgorithm::Sha256 => new_msg - .attribute::() - .unwrap() - .hmac() - .to_vec(), - }; - assert_eq!(orig_hmac, new_hmac); - new_msg.validate_integrity(&bytes, &credentials).unwrap(); + assert_eq!(software.software(), "s"); } } #[test] fn valid_attributes() { init(); - let mut src = Message::new_request(BINDING); - src.add_attribute(Username::new("123").unwrap()).unwrap(); - src.add_attribute(Priority::new(123)).unwrap(); + let mut src = Message::builder_request(BINDING); + let username = Username::new("123").unwrap(); + src.add_attribute(&username).unwrap(); + src.add_attribute(&Priority::new(123)).unwrap(); + let src = src.build(); + let src = Message::from_bytes(&src).unwrap(); // success case let res = Message::check_attribute_types( @@ -1600,6 +1701,8 @@ mod tests { ); assert!(res.is_some()); let res = res.unwrap(); + let res = res.build(); + let res = Message::from_bytes(&res).unwrap(); assert!(res.has_class(MessageClass::Error)); assert!(res.has_method(src.method())); let err = res.attribute::().unwrap(); @@ -1609,6 +1712,8 @@ mod tests { let res = Message::check_attribute_types(&src, &[Username::TYPE], &[]); assert!(res.is_some()); let res = res.unwrap(); + let data = res.build(); + let res = Message::from_bytes(&data).unwrap(); assert!(res.has_class(MessageClass::Error)); assert!(res.has_method(src.method())); let err = res.attribute::().unwrap(); @@ -1650,52 +1755,70 @@ mod tests { 0x80, 0x28, 0x00, 0x04, // FINGERPRINT header 0xe5, 0x7a, 0x3b, 0xcf, // CRC32 fingerprint ]; - let msg = Message::from_bytes(&data).unwrap(); assert!(msg.has_class(MessageClass::Request)); assert!(msg.has_method(BINDING)); assert_eq!(msg.transaction_id(), 0xb7e7_a701_bc34_d686_fa87_dfae.into()); + let mut builder = Message::builder( + MessageType::from_class_method(MessageClass::Request, BINDING), + msg.transaction_id(), + ); + // SOFTWARE assert!(msg.has_attribute(Software::TYPE)); let raw = msg.raw_attribute(Software::TYPE).unwrap(); - assert!(matches!(Software::try_from(raw), Ok(_))); - let software = Software::try_from(raw).unwrap(); + assert!(matches!(Software::try_from(&raw), Ok(_))); + let software = Software::try_from(&raw).unwrap(); assert_eq!(software.software(), "STUN test client"); + builder.add_attribute(&software).unwrap(); // PRIORITY assert!(msg.has_attribute(Priority::TYPE)); let raw = msg.raw_attribute(Priority::TYPE).unwrap(); - assert!(matches!(Priority::try_from(raw), Ok(_))); - let priority = Priority::try_from(raw).unwrap(); + assert!(matches!(Priority::try_from(&raw), Ok(_))); + let priority = Priority::try_from(&raw).unwrap(); assert_eq!(priority.priority(), 0x6e0001ff); + builder.add_attribute(&priority).unwrap(); + + // ICE-CONTROLLED + assert!(msg.has_attribute(IceControlled::TYPE)); + let raw = msg.raw_attribute(IceControlled::TYPE).unwrap(); + assert!(matches!(IceControlled::try_from(&raw), Ok(_))); + let ice = IceControlled::try_from(&raw).unwrap(); + assert_eq!(ice.tie_breaker(), 0x932f_f9b1_5126_3b36); + builder.add_attribute(&ice).unwrap(); // USERNAME assert!(msg.has_attribute(Username::TYPE)); let raw = msg.raw_attribute(Username::TYPE).unwrap(); - assert!(matches!(Username::try_from(raw), Ok(_))); - let username = Username::try_from(raw).unwrap(); + assert!(matches!(Username::try_from(&raw), Ok(_))); + let username = Username::try_from(&raw).unwrap(); assert_eq!(username.username(), "evtj:h6vY"); + builder.add_attribute(&username).unwrap(); // MESSAGE_INTEGRITY let credentials = MessageIntegrityCredentials::ShortTerm(ShortTermCredentials { password: "VOkJxbRl1RmTxUk/WvJxBt".to_owned(), }); - assert!(matches!( - msg.validate_integrity(&data, &credentials), - Ok(()) - )); + assert!(matches!(msg.validate_integrity(&credentials), Ok(()))); + builder + .add_message_integrity(&credentials, IntegrityAlgorithm::Sha1) + .unwrap(); // FINGERPRINT is checked by Message::from_bytes() when present assert!(msg.has_attribute(Fingerprint::TYPE)); + builder.add_fingerprint().unwrap(); // assert that we produce the same output as we parsed in this case - let mut msg_data = msg.to_bytes(); + let mut msg_data = builder.build(); // match the padding bytes with the original msg_data[73] = 0x20; msg_data[74] = 0x20; msg_data[75] = 0x20; - assert_eq!(msg_data, data); + // as a result of the padding difference, the message integrity and fingerpinrt values will + // be different + assert_eq!(msg_data[..80], data[..80]); } #[test] @@ -1729,40 +1852,50 @@ mod tests { assert!(msg.has_class(MessageClass::Success)); assert!(msg.has_method(BINDING)); assert_eq!(msg.transaction_id(), 0xb7e7_a701_bc34_d686_fa87_dfae.into()); + let mut builder = Message::builder( + MessageType::from_class_method(MessageClass::Success, BINDING), + msg.transaction_id(), + ); // SOFTWARE assert!(msg.has_attribute(Software::TYPE)); let raw = msg.raw_attribute(Software::TYPE).unwrap(); - assert!(matches!(Software::try_from(raw), Ok(_))); - let software = Software::try_from(raw).unwrap(); + assert!(matches!(Software::try_from(&raw), Ok(_))); + let software = Software::try_from(&raw).unwrap(); assert_eq!(software.software(), "test vector"); + builder.add_attribute(&software).unwrap(); // XOR_MAPPED_ADDRESS assert!(msg.has_attribute(XorMappedAddress::TYPE)); let raw = msg.raw_attribute(XorMappedAddress::TYPE).unwrap(); - assert!(matches!(XorMappedAddress::try_from(raw), Ok(_))); - let xor_mapped_addres = XorMappedAddress::try_from(raw).unwrap(); + assert!(matches!(XorMappedAddress::try_from(&raw), Ok(_))); + let xor_mapped_addres = XorMappedAddress::try_from(&raw).unwrap(); assert_eq!( xor_mapped_addres.addr(msg.transaction_id()), "192.0.2.1:32853".parse().unwrap() ); + builder.add_attribute(&xor_mapped_addres).unwrap(); // MESSAGE_INTEGRITY let credentials = MessageIntegrityCredentials::ShortTerm(ShortTermCredentials { password: "VOkJxbRl1RmTxUk/WvJxBt".to_owned(), }); - let ret = msg.validate_integrity(&data, &credentials); + let ret = msg.validate_integrity(&credentials); debug!("{:?}", ret); assert!(matches!(ret, Ok(()))); + builder + .add_message_integrity(&credentials, IntegrityAlgorithm::Sha1) + .unwrap(); // FINGERPRINT is checked by Message::from_bytes() when present assert!(msg.has_attribute(Fingerprint::TYPE)); + builder.add_fingerprint().unwrap(); // assert that we produce the same output as we parsed in this case - let mut msg_data = msg.to_bytes(); + let mut msg_data = builder.build(); // match the padding bytes with the original msg_data[35] = 0x20; - assert_eq!(msg_data, data); + assert_eq!(msg_data[..52], data[..52]); } #[test] @@ -1799,43 +1932,50 @@ mod tests { assert!(msg.has_class(MessageClass::Success)); assert!(msg.has_method(BINDING)); assert_eq!(msg.transaction_id(), 0xb7e7_a701_bc34_d686_fa87_dfae.into()); + let mut builder = Message::builder( + MessageType::from_class_method(MessageClass::Success, BINDING), + msg.transaction_id(), + ); // SOFTWARE assert!(msg.has_attribute(Software::TYPE)); let raw = msg.raw_attribute(Software::TYPE).unwrap(); - assert!(matches!(Software::try_from(raw), Ok(_))); - let software = Software::try_from(raw).unwrap(); + assert!(matches!(Software::try_from(&raw), Ok(_))); + let software = Software::try_from(&raw).unwrap(); assert_eq!(software.software(), "test vector"); + builder.add_attribute(&software).unwrap(); // XOR_MAPPED_ADDRESS assert!(msg.has_attribute(XorMappedAddress::TYPE)); let raw = msg.raw_attribute(XorMappedAddress::TYPE).unwrap(); - assert!(matches!(XorMappedAddress::try_from(raw), Ok(_))); - let xor_mapped_addres = XorMappedAddress::try_from(raw).unwrap(); + assert!(matches!(XorMappedAddress::try_from(&raw), Ok(_))); + let xor_mapped_addres = XorMappedAddress::try_from(&raw).unwrap(); assert_eq!( xor_mapped_addres.addr(msg.transaction_id()), "[2001:db8:1234:5678:11:2233:4455:6677]:32853" .parse() .unwrap() ); + builder.add_attribute(&xor_mapped_addres).unwrap(); // MESSAGE_INTEGRITY let credentials = MessageIntegrityCredentials::ShortTerm(ShortTermCredentials { password: "VOkJxbRl1RmTxUk/WvJxBt".to_owned(), }); - assert!(matches!( - msg.validate_integrity(&data, &credentials), - Ok(()) - )); + assert!(matches!(msg.validate_integrity(&credentials), Ok(()))); + builder + .add_message_integrity(&credentials, IntegrityAlgorithm::Sha1) + .unwrap(); // FINGERPRINT is checked by Message::from_bytes() when present assert!(msg.has_attribute(Fingerprint::TYPE)); + builder.add_fingerprint().unwrap(); // assert that we produce the same output as we parsed in this case - let mut msg_data = msg.to_bytes(); + let mut msg_data = builder.build(); // match the padding bytes with the original msg_data[35] = 0x20; - assert_eq!(msg_data, data); + assert_eq!(msg_data[..64], data[..64]); } #[test] @@ -1878,6 +2018,10 @@ mod tests { assert!(msg.has_class(MessageClass::Request)); assert!(msg.has_method(BINDING)); assert_eq!(msg.transaction_id(), 0x78ad_3433_c6ad_72c0_29da_412e.into()); + let mut builder = Message::builder( + MessageType::from_class_method(MessageClass::Request, BINDING), + msg.transaction_id(), + ); let long_term = LongTermCredentials { username: "\u{30DE}\u{30C8}\u{30EA}\u{30C3}\u{30AF}\u{30B9}".to_owned(), @@ -1887,25 +2031,36 @@ mod tests { // USERNAME assert!(msg.has_attribute(Username::TYPE)); let raw = msg.raw_attribute(Username::TYPE).unwrap(); - assert!(matches!(Username::try_from(raw), Ok(_))); - let username = Username::try_from(raw).unwrap(); + assert!(matches!(Username::try_from(&raw), Ok(_))); + let username = Username::try_from(&raw).unwrap(); assert_eq!(username.username(), &long_term.username); + builder.add_attribute(&username).unwrap(); // NONCE let expected_nonce = "f//499k954d6OL34oL9FSTvy64sA"; assert!(msg.has_attribute(Nonce::TYPE)); let raw = msg.raw_attribute(Nonce::TYPE).unwrap(); - assert!(matches!(Nonce::try_from(raw), Ok(_))); - let nonce = Nonce::try_from(raw).unwrap(); + assert!(matches!(Nonce::try_from(&raw), Ok(_))); + let nonce = Nonce::try_from(&raw).unwrap(); assert_eq!(nonce.nonce(), expected_nonce); + builder.add_attribute(&nonce).unwrap(); + + // REALM + assert!(msg.has_attribute(Realm::TYPE)); + let raw = msg.raw_attribute(Realm::TYPE).unwrap(); + assert!(matches!(Realm::try_from(&raw), Ok(_))); + let realm = Realm::try_from(&raw).unwrap(); + assert_eq!(realm.realm(), long_term.realm()); + builder.add_attribute(&realm).unwrap(); // MESSAGE_INTEGRITY /* XXX: the password needs SASLPrep-ing to be useful here let credentials = MessageIntegrityCredentials::LongTerm(long_term); assert!(matches!(msg.validate_integrity(&data, &credentials), Ok(()))); */ + //builder.add_attribute(msg.raw_attribute(MessageIntegrity::TYPE).unwrap()).unwrap(); - assert_eq!(msg.to_bytes(), data); + assert_eq!(builder.build()[4..], data[4..92]); } #[test] @@ -1961,6 +2116,10 @@ mod tests { assert!(msg.has_class(MessageClass::Request)); assert!(msg.has_method(BINDING)); assert_eq!(msg.transaction_id(), 0x78ad_3433_c6ad_72c0_29da_412e.into()); + let mut builder = Message::builder( + MessageType::from_class_method(MessageClass::Success, BINDING), + msg.transaction_id(), + ); let long_term = LongTermCredentials { username: "\u{30DE}\u{30C8}\u{30EA}\u{30C3}\u{30AF}\u{30B9}".to_owned(), @@ -1970,37 +2129,42 @@ mod tests { // USERHASH assert!(msg.has_attribute(Userhash::TYPE)); let raw = msg.raw_attribute(Userhash::TYPE).unwrap(); - assert!(matches!(Userhash::try_from(raw), Ok(_))); - let _userhash = Userhash::try_from(raw).unwrap(); + assert!(matches!(Userhash::try_from(&raw), Ok(_))); + let userhash = Userhash::try_from(&raw).unwrap(); + builder.add_attribute(&userhash).unwrap(); // NONCE let expected_nonce = "obMatJos2AAACf//499k954d6OL34oL9FSTvy64sA"; assert!(msg.has_attribute(Nonce::TYPE)); let raw = msg.raw_attribute(Nonce::TYPE).unwrap(); - assert!(matches!(Nonce::try_from(raw), Ok(_))); - let nonce = Nonce::try_from(raw).unwrap(); + assert!(matches!(Nonce::try_from(&raw), Ok(_))); + let nonce = Nonce::try_from(&raw).unwrap(); assert_eq!(nonce.nonce(), expected_nonce); + builder.add_attribute(&nonce).unwrap(); // REALM assert!(msg.has_attribute(Realm::TYPE)); let raw = msg.raw_attribute(Realm::TYPE).unwrap(); - assert!(matches!(Realm::try_from(raw), Ok(_))); - let realm = Realm::try_from(raw).unwrap(); + assert!(matches!(Realm::try_from(&raw), Ok(_))); + let realm = Realm::try_from(&raw).unwrap(); assert_eq!(realm.realm(), long_term.realm); + builder.add_attribute(&realm).unwrap(); // PASSWORD_ALGORITHM assert!(msg.has_attribute(PasswordAlgorithm::TYPE)); let raw = msg.raw_attribute(PasswordAlgorithm::TYPE).unwrap(); - assert!(matches!(PasswordAlgorithm::try_from(raw), Ok(_))); - let algo = PasswordAlgorithm::try_from(raw).unwrap(); + assert!(matches!(PasswordAlgorithm::try_from(&raw), Ok(_))); + let algo = PasswordAlgorithm::try_from(&raw).unwrap(); assert_eq!(algo.algorithm(), PasswordAlgorithmValue::SHA256); + builder.add_attribute(&algo).unwrap(); // MESSAGE_INTEGRITY_SHA256 /* XXX: the password needs SASLPrep-ing to be useful here let credentials = MessageIntegrityCredentials::LongTerm(long_term); assert!(matches!(msg.validate_integrity(&data, &credentials), Ok(()))); */ + //builder.add_attribute(msg.raw_attribute(MessageIntegritySha256::TYPE).unwrap()).unwrap(); - assert_eq!(msg.to_bytes(), data); + assert_eq!(builder.build()[4..], data[4..128]); } }