Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add builder for additional header values #400

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
//! Methods to connect to a WebSocket as a client.

use std::{
convert::TryFrom,
io::{Read, Write},
net::{SocketAddr, TcpStream, ToSocketAddrs},
result::Result as StdResult,
};

use http::{request::Parts, Uri};
use http::{request::Parts, HeaderName, Uri};
use log::*;

use url::Url;
Expand Down Expand Up @@ -265,3 +266,73 @@ impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
Request::from_httparse(self)
}
}

/// Builder for a custom [`IntoClientRequest`] with options to add
/// custom additional headers and sub protocols.
///
/// # Example
///
/// ```rust no_run
/// # use crate::*;
/// use http::Uri;
/// use tungstenite::{connect, ClientRequestBuilder};
///
/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
/// let token = "my_jwt_token";
/// let builder = ClientRequestBuilder::new(uri)
/// .with_header("Authorization", format!("Bearer {token}"))
/// .with_sub_protocol("my_sub_protocol");
/// let socket = connect(builder).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct ClientRequestBuilder {
uri: Uri,
/// Additional [`Request`] handshake headers
additional_headers: Vec<(String, String)>,
/// Handsake subprotocols
subprotocols: Vec<String>,
}

impl ClientRequestBuilder {
/// Initializes an empty request builder
#[must_use]
pub const fn new(uri: Uri) -> Self {
Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
}

/// Adds (`key`, `value`) as an additional header to the handshake request
pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
where
K: Into<String>,
V: Into<String>,
{
self.additional_headers.push((key.into(), value.into()));
self
}

/// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
where
P: Into<String>,
{
self.subprotocols.push(protocol.into());
self
}
}

impl IntoClientRequest for ClientRequestBuilder {
fn into_client_request(self) -> Result<Request> {
let mut request = self.uri.into_client_request()?;
let headers = request.headers_mut();
for (k, v) in self.additional_headers {
let key = HeaderName::try_from(k)?;
let value = v.parse()?;
headers.append(key, value);
}
if !self.subprotocols.is_empty() {
let protocols = self.subprotocols.join(", ").parse()?;
headers.append("Sec-WebSocket-Protocol", protocols);
}
Ok(request)
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub use crate::{

#[cfg(feature = "handshake")]
pub use crate::{
client::{client, connect},
client::{client, connect, ClientRequestBuilder},
handshake::{client::ClientHandshake, server::ServerHandshake, HandshakeError},
server::{accept, accept_hdr, accept_hdr_with_config, accept_with_config},
};
Expand Down
92 changes: 92 additions & 0 deletions tests/client_headers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#![cfg(feature = "handshake")]

use http::Uri;
use std::{
net::TcpListener,
process::exit,
thread::{sleep, spawn},
time::Duration,
};
use tungstenite::{
accept_hdr, connect,
handshake::server::{Request, Response},
ClientRequestBuilder, Error, Message,
};

/// Test for write buffering and flushing behaviour.
#[test]
fn test_headers() {
env_logger::init();
let uri: Uri = "ws://127.0.0.1:3013/socket".parse().unwrap();
let token = "my_jwt_token";
let full_token = format!("Bearer {token}");
let sub_protocol = "my_sub_protocol";
let builder = ClientRequestBuilder::new(uri)
.with_header("Authorization", full_token.to_owned())
.with_sub_protocol(sub_protocol.to_owned());

spawn(|| {
sleep(Duration::from_secs(5));
println!("Unit test executed too long, perhaps stuck on WOULDBLOCK...");
exit(1);
});

let server = TcpListener::bind("127.0.0.1:3013").unwrap();

let client_thread = spawn(move || {
let (mut client, _) = connect(builder).unwrap();
client.send(Message::Text("Hello WebSocket".into())).unwrap();

let message = client.read().unwrap(); // receive close from server
assert!(message.is_close());

let err = client.read().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}
});

let callback = |req: &Request, response: Response| {
println!("Received a new ws handshake");
println!("The request's path is: {}", req.uri().path());
println!("The request's headers are:");
let authorization_header: String = "authorization".to_ascii_lowercase();
let web_socket_proto: String = "sec-websocket-protocol".to_ascii_lowercase();

for (ref header, value) in req.headers() {
println!("* {}: {}", header, value.to_str().unwrap());
if header.to_string() == authorization_header {
println!("Matching authorization header");
assert_eq!(header.to_string(), authorization_header);
assert_eq!(value.to_str().unwrap(), full_token);
} else if header.to_string() == web_socket_proto {
println!("Matching sec-websocket-protocol header");
assert_eq!(header.to_string(), web_socket_proto);
assert_eq!(value.to_str().unwrap(), sub_protocol);
}
}
Ok(response)
};

let client_handler = server.incoming().next().unwrap();
let mut client_handler = accept_hdr(client_handler.unwrap(), callback).unwrap();

client_handler.close(None).unwrap(); // send close to client

// This read should succeed even though we already initiated a close
let message = client_handler.read().unwrap();
assert_eq!(message.into_data(), b"Hello WebSocket");

assert!(client_handler.read().unwrap().is_close()); // receive acknowledgement

let err = client_handler.read().unwrap_err(); // now we should get ConnectionClosed
match err {
Error::ConnectionClosed => {}
_ => panic!("unexpected error: {:?}", err),
}

drop(client_handler);

client_thread.join().unwrap();
}
Loading