From a2ec84b8223f0c0f3ece7a08803cf92aa2420065 Mon Sep 17 00:00:00 2001 From: conorbros Date: Tue, 27 Jun 2023 16:40:52 +1000 Subject: [PATCH 1/3] subprotocol --- src/error.rs | 20 ++++++ src/handshake/client.rs | 43 ++++++++++++- tests/handshake.rs | 134 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+), 3 deletions(-) create mode 100644 tests/handshake.rs diff --git a/src/error.rs b/src/error.rs index a7b33545..be69ea0a 100644 --- a/src/error.rs +++ b/src/error.rs @@ -146,6 +146,23 @@ pub enum CapacityError { }, } +/// Indicates the specific type/cause of a subprotocol header error. +#[derive(Error, Clone, PartialEq, Eq, Debug, Copy)] +pub enum SubProtocolError { + /// The server sent a subprotocol to a client handshake request but none was requested + #[error("Server sent a subprotocol but none was requested")] + ServerSentSubProtocolNoneRequested, + + /// The server sent an invalid subprotocol to a client handhshake request + #[error("Server sent an invalid subprotocol")] + InvalidSubProtocol, + + /// The server sent no subprotocol to a client handshake request that requested one or more + /// subprotocols + #[error("Server sent no subprotocol")] + NoSubProtocol, +} + /// Indicates the specific type/cause of a protocol error. #[allow(missing_copy_implementations)] #[derive(Error, Debug, PartialEq, Eq, Clone)] @@ -171,6 +188,9 @@ pub enum ProtocolError { /// The `Sec-WebSocket-Accept` header is either not present or does not specify the correct key value. #[error("Key mismatch in \"Sec-WebSocket-Accept\" header")] SecWebSocketAcceptKeyMismatch, + /// The `Sec-WebSocket-Protocol` header was invalid + #[error("SubProtocol error: {0}")] + SecWebSocketSubProtocolError(SubProtocolError), /// Garbage data encountered after client request. #[error("Junk after client request")] JunkAfterRequest, diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 08cc7b29..51c244d4 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -18,7 +18,7 @@ use super::{ HandshakeRole, MidHandshake, ProcessingResult, }; use crate::{ - error::{Error, ProtocolError, Result, UrlError}, + error::{Error, ProtocolError, Result, SubProtocolError, UrlError}, protocol::{Role, WebSocket, WebSocketConfig}, }; @@ -54,6 +54,8 @@ impl ClientHandshake { // Check the URI scheme: only ws or wss are supported let _ = crate::client::uri_mode(request.uri())?; + let subprotocols = extract_subprotocols_from_request(&request)?; + // Convert and verify the `http::Request` and turn it into the request as per RFC. // Also extract the key from it (it must be present in a correct request). let (request, key) = generate_request(request)?; @@ -62,7 +64,11 @@ impl ClientHandshake { let client = { let accept_key = derive_accept_key(key.as_ref()); - ClientHandshake { verify_data: VerifyData { accept_key }, config, _marker: PhantomData } + ClientHandshake { + verify_data: VerifyData { accept_key, subprotocols }, + config, + _marker: PhantomData, + } }; trace!("Client handshake initiated."); @@ -178,11 +184,22 @@ pub fn generate_request(mut request: Request) -> Result<(Vec, String)> { Ok((req, key)) } +fn extract_subprotocols_from_request(request: &Request) -> Result>> { + if let Some(subprotocols) = request.headers().get("Sec-WebSocket-Protocol") { + Ok(Some(subprotocols.to_str()?.split(",").map(|s| s.to_string()).collect())) + } else { + Ok(None) + } +} + /// Information for handshake verification. #[derive(Debug)] struct VerifyData { /// Accepted server key. accept_key: String, + + /// Accepted subprotocols + subprotocols: Option>, } impl VerifyData { @@ -238,7 +255,27 @@ impl VerifyData { // not present in the client's handshake (the server has indicated a // subprotocol not requested by the client), the client MUST _Fail // the WebSocket Connection_. (RFC 6455) - // TODO + if headers.get("Sec-WebSocket-Protocol").is_none() && self.subprotocols.is_some() { + return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError( + SubProtocolError::NoSubProtocol, + ))); + } + + if headers.get("Sec-WebSocket-Protocol").is_some() && self.subprotocols.is_none() { + return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError( + SubProtocolError::ServerSentSubProtocolNoneRequested, + ))); + } + + if let Some(returned_subprotocol) = headers.get("Sec-WebSocket-Protocol") { + if let Some(accepted_subprotocols) = &self.subprotocols { + if !accepted_subprotocols.contains(&returned_subprotocol.to_str()?.to_string()) { + return Err(Error::Protocol(ProtocolError::SecWebSocketSubProtocolError( + SubProtocolError::InvalidSubProtocol, + ))); + } + } + } Ok(response) } diff --git a/tests/handshake.rs b/tests/handshake.rs new file mode 100644 index 00000000..81e788f2 --- /dev/null +++ b/tests/handshake.rs @@ -0,0 +1,134 @@ +use std::net::TcpListener; +use std::thread::spawn; +use tungstenite::error::{Error, ProtocolError, SubProtocolError}; +use tungstenite::handshake::client::generate_key; +use tungstenite::handshake::server::{Request, Response}; +use tungstenite::{accept_hdr, connect}; + +fn create_http_request(uri: &str, subprotocols: Option>) -> http::Request<()> { + let uri = uri.parse::().unwrap(); + + let authority = uri.authority().unwrap().as_str(); + let host = + authority.find('@').map(|idx| authority.split_at(idx + 1).1).unwrap_or_else(|| authority); + + if host.is_empty() { + panic!("Empty host name"); + } + + let mut builder = http::Request::builder() + .method("GET") + .header("Host", host) + .header("Connection", "Upgrade") + .header("Upgrade", "websocket") + .header("Sec-WebSocket-Version", "13") + .header("Sec-WebSocket-Key", generate_key()); + + if let Some(subprotocols) = subprotocols { + builder = builder.header("Sec-WebSocket-Protocol", subprotocols.join(",")); + } + + builder.uri(uri).body(()).unwrap() +} + +fn server_thread(port: u16, server_subprotocols: Option>) { + spawn(move || { + let server = TcpListener::bind(("127.0.0.1", port)) + .expect("Can't listen, is this port already in use?"); + let client_handler = server.incoming().next().unwrap(); + + let callback = |_request: &Request, mut response: Response| { + if let Some(subprotocols) = server_subprotocols { + let headers = response.headers_mut(); + headers.append("Sec-WebSocket-Protocol", subprotocols.join(",").parse().unwrap()); + } + Ok(response) + }; + + let _client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap(); + }); +} + +#[test] +fn test_server_send_no_subprotocol() { + server_thread(3012, None); + + let err = + connect(create_http_request("ws://127.0.0.1:3012", Some(vec!["my-sub-protocol".into()]))) + .unwrap_err(); + + assert!(matches!( + err, + Error::Protocol(ProtocolError::SecWebSocketSubProtocolError( + SubProtocolError::NoSubProtocol + )) + )); +} + +#[test] +fn test_server_sent_subprotocol_none_requested() { + server_thread(3013, Some(vec!["my-sub-protocol".to_string()])); + + let err = connect(create_http_request("ws://127.0.0.1:3013", None)).unwrap_err(); + + assert!(matches!( + err, + Error::Protocol(ProtocolError::SecWebSocketSubProtocolError( + SubProtocolError::ServerSentSubProtocolNoneRequested + )) + )); +} + +#[test] +fn test_invalid_subprotocol() { + server_thread(3014, Some(vec!["invalid-sub-protocol".to_string()])); + + let err = connect(create_http_request( + "ws://127.0.0.1:3014", + Some(vec!["my-sub-protocol".to_string()]), + )) + .unwrap_err(); + + assert!(matches!( + err, + Error::Protocol(ProtocolError::SecWebSocketSubProtocolError( + SubProtocolError::InvalidSubProtocol + )) + )); +} + +#[test] +fn test_request_multiple_subprotocols() { + server_thread(3015, Some(vec!["my-sub-protocol".to_string()])); + + let (_, response) = connect(create_http_request( + "ws://127.0.0.1:3015", + Some(vec![ + "my-sub-protocol".to_string(), + "my-sub-protocol-1".to_string(), + "my-sub-protocol-2".to_string(), + ]), + )) + .unwrap(); + + assert_eq!( + response.headers().get("Sec-WebSocket-Protocol").unwrap(), + "my-sub-protocol".parse::().unwrap() + ); +} + +#[test] +fn test_request_single_subprotocol() { + server_thread(3016, Some(vec!["my-sub-protocol".to_string()])); + + let (_, response) = connect(create_http_request( + "ws://127.0.0.1:3016", + Some(vec!["my-sub-protocol".to_string()]), + )) + .unwrap(); + + assert_eq!( + response.headers().get("Sec-WebSocket-Protocol").unwrap(), + "my-sub-protocol".parse::().unwrap() + ); +} From adbc70a6b45c4d8e58e2bd98a33138e623a37085 Mon Sep 17 00:00:00 2001 From: n4n5 Date: Fri, 8 Mar 2024 15:37:30 +0100 Subject: [PATCH 2/3] fix test for merging --- tests/client_headers.rs | 6 +++++- tests/handshake.rs | 15 +++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/client_headers.rs b/tests/client_headers.rs index 1a037af7..f943f8ae 100644 --- a/tests/client_headers.rs +++ b/tests/client_headers.rs @@ -47,7 +47,7 @@ fn test_headers() { } }); - let callback = |req: &Request, response: Response| { + let callback = |req: &Request, mut response: Response| { println!("Received a new ws handshake"); println!("The request's path is: {}", req.uri().path()); println!("The request's headers are:"); @@ -64,6 +64,10 @@ fn test_headers() { println!("Matching sec-websocket-protocol header"); assert_eq!(header.to_string(), web_socket_proto); assert_eq!(value.to_str().unwrap(), sub_protocol); + // the server needs to respond with the same sub-protocol + response + .headers_mut() + .append("sec-websocket-protocol", sub_protocol.parse().unwrap()); } } Ok(response) diff --git a/tests/handshake.rs b/tests/handshake.rs index 81e788f2..d5953237 100644 --- a/tests/handshake.rs +++ b/tests/handshake.rs @@ -1,5 +1,7 @@ +#![cfg(feature = "handshake")] use std::net::TcpListener; -use std::thread::spawn; +use std::thread::{sleep, spawn}; +use std::time::Duration; use tungstenite::error::{Error, ProtocolError, SubProtocolError}; use tungstenite::handshake::client::generate_key; use tungstenite::handshake::server::{Request, Response}; @@ -35,7 +37,6 @@ fn server_thread(port: u16, server_subprotocols: Option>) { spawn(move || { let server = TcpListener::bind(("127.0.0.1", port)) .expect("Can't listen, is this port already in use?"); - let client_handler = server.incoming().next().unwrap(); let callback = |_request: &Request, mut response: Response| { if let Some(subprotocols) = server_subprotocols { @@ -45,13 +46,16 @@ fn server_thread(port: u16, server_subprotocols: Option>) { Ok(response) }; - let _client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap(); + let client_handler = server.incoming().next().unwrap(); + let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap(); + client_handler.close(None).unwrap(); }); } #[test] fn test_server_send_no_subprotocol() { server_thread(3012, None); + sleep(Duration::from_secs(1)); let err = connect(create_http_request("ws://127.0.0.1:3012", Some(vec!["my-sub-protocol".into()]))) @@ -68,6 +72,7 @@ fn test_server_send_no_subprotocol() { #[test] fn test_server_sent_subprotocol_none_requested() { server_thread(3013, Some(vec!["my-sub-protocol".to_string()])); + sleep(Duration::from_secs(1)); let err = connect(create_http_request("ws://127.0.0.1:3013", None)).unwrap_err(); @@ -82,6 +87,7 @@ fn test_server_sent_subprotocol_none_requested() { #[test] fn test_invalid_subprotocol() { server_thread(3014, Some(vec!["invalid-sub-protocol".to_string()])); + sleep(Duration::from_secs(1)); let err = connect(create_http_request( "ws://127.0.0.1:3014", @@ -100,7 +106,7 @@ fn test_invalid_subprotocol() { #[test] fn test_request_multiple_subprotocols() { server_thread(3015, Some(vec!["my-sub-protocol".to_string()])); - + sleep(Duration::from_secs(1)); let (_, response) = connect(create_http_request( "ws://127.0.0.1:3015", Some(vec![ @@ -120,6 +126,7 @@ fn test_request_multiple_subprotocols() { #[test] fn test_request_single_subprotocol() { server_thread(3016, Some(vec!["my-sub-protocol".to_string()])); + sleep(Duration::from_secs(1)); let (_, response) = connect(create_http_request( "ws://127.0.0.1:3016", From 734234a5e2b97fc2ad37a89fe3f7285b1fb7fb06 Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Fri, 10 May 2024 08:35:58 +1000 Subject: [PATCH 3/3] Update tests/handshake.rs Co-authored-by: n4n5 <56606507+Its-Just-Nans@users.noreply.github.com> --- tests/handshake.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/handshake.rs b/tests/handshake.rs index d5953237..2d07e51b 100644 --- a/tests/handshake.rs +++ b/tests/handshake.rs @@ -1,11 +1,17 @@ #![cfg(feature = "handshake")] -use std::net::TcpListener; -use std::thread::{sleep, spawn}; -use std::time::Duration; -use tungstenite::error::{Error, ProtocolError, SubProtocolError}; -use tungstenite::handshake::client::generate_key; -use tungstenite::handshake::server::{Request, Response}; -use tungstenite::{accept_hdr, connect}; +use std::{ + net::TcpListener, + thread::{sleep, spawn}, + time::Duration, +}; +use tungstenite::{ + accept_hdr, connect, + error::{Error, ProtocolError, SubProtocolError}, + handshake::{ + client::generate_key, + server::{Request, Response}, + }, +}; fn create_http_request(uri: &str, subprotocols: Option>) -> http::Request<()> { let uri = uri.parse::().unwrap();