From a2e42de2218502b034b6adf6402e00b24482afd8 Mon Sep 17 00:00:00 2001 From: Matt Williams Date: Mon, 13 May 2024 17:50:18 -0400 Subject: [PATCH 1/4] Add transmitter sender data layer for asynchronous client Added a fairly indepth example client. --- .gitignore | 1 + Cargo.toml | 1 + socketio/Cargo.toml | 7 +- socketio/examples/async_transmitter.rs | 109 ++++++++++++++++++++ socketio/src/asynchronous/client/builder.rs | 65 ++++++++++-- socketio/src/asynchronous/client/client.rs | 71 +++++++++++-- socketio/src/error.rs | 2 + 7 files changed, 239 insertions(+), 17 deletions(-) create mode 100644 socketio/examples/async_transmitter.rs diff --git a/.gitignore b/.gitignore index 22e6dfbc..f41165ab 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ target +target_ra .vscode .idea ci/node_modules diff --git a/Cargo.toml b/Cargo.toml index 597aaf51..2ce2ed53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,3 @@ [workspace] members = ["engineio", "socketio"] +resolver = "2" diff --git a/socketio/Cargo.toml b/socketio/Cargo.toml index 92598d71..95ee7dfb 100644 --- a/socketio/Cargo.toml +++ b/socketio/Cargo.toml @@ -40,7 +40,7 @@ version = "1.37.0" features = ["macros", "rt-multi-thread"] [features] -default = [] +default = ["async"] async-callbacks = ["rust_engineio/async-callbacks"] async = ["async-callbacks", "rust_engineio/async", "tokio", "futures-util", "async-stream"] @@ -48,3 +48,8 @@ async = ["async-callbacks", "rust_engineio/async", "tokio", "futures-util", "asy name = "async" path = "examples/async.rs" required-features = ["async"] + +[[example]] +name = "async-transmitter" +path = "examples/async_transmitter.rs" +required-features = ["async"] diff --git a/socketio/examples/async_transmitter.rs b/socketio/examples/async_transmitter.rs new file mode 100644 index 00000000..119bcbb6 --- /dev/null +++ b/socketio/examples/async_transmitter.rs @@ -0,0 +1,109 @@ +use futures_util::FutureExt; +use rust_socketio::{ + asynchronous::{Client as SocketIOClient, ClientBuilder as SocketIOClientBuilder}, + Error as SocketIOError, Payload, +}; +use serde_json::json; +use std::sync::{mpsc, Arc}; +use std::time::Duration; +use tokio::time::sleep; + +struct ComplexData { + /// There should be many more fields below in real life, + /// probaly wrapped in Arc> if you're writing a more serious client. + data: String, +} + +struct TransmitterClient { + receiver: mpsc::Receiver, + complex: ComplexData, + client: SocketIOClient, +} + +impl TransmitterClient { + async fn connect(url: &str) -> Result { + let (sender, receiver) = mpsc::channel::(); + + let client = SocketIOClientBuilder::new(url) + .namespace("/admin") + .on("test", |payload: Payload, socket: SocketIOClient| { + async move { + match payload { + Payload::Text(values) => { + if let Some(value) = values.first() { + if value.is_string() { + let result = socket.try_transitter::>(); + + result + .map(|transmitter| { + transmitter.send(String::from(value.as_str().unwrap())) + }) + .map_err(|err| eprintln!("{}", err)) + .ok(); + } + } + } + Payload::Binary(_bin_data) => println!(), + #[allow(deprecated)] + Payload::String(str) => println!("Received: {}", str), + } + } + .boxed() + }) + .on("error", |err, _| { + async move { eprintln!("Error: {:#?}", err) }.boxed() + }) + .transmitter(Arc::new(sender)) + .connect() + .await?; + + Ok(Self { + client, + receiver, + complex: ComplexData { + data: "".to_string(), + }, + }) + } + + async fn get_test(&mut self) -> Option { + match self.client.emit("test", json!({"got ack": true})).await { + Ok(_) => { + match self.receiver.recv() { + Ok(complex_data) => { + // In the real world the data is probably a serialized json_rpc object + // or some other complex data layer which needs complex business and derserialization logic. + // Best to do that here, and not inside those restrictive callbacks. + self.complex.data = complex_data; + Some(self.complex.data.clone()) + } + Err(err) => { + eprintln!("Transmission buffer is probably full: {}", err); + None + } + } + } + Err(err) => { + eprintln!("Server unreachable: {}", err); + None + } + } + } +} + +#[tokio::main] +async fn main() { + match TransmitterClient::connect("http://localhost:4200/").await { + Ok(mut client) => { + if let Some(test_data) = client.get_test().await { + println!("test event data from internal transmitter: {}", test_data); + } + } + Err(err) => { + eprintln!("{}", err); + } + } + + // Wait so we can see our response + sleep(Duration::from_secs(2)).await; +} diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 44710e19..3d3e5859 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -1,3 +1,11 @@ +use super::{ + callback::{ + Callback, DynAsyncAnyCallback, DynAsyncCallback, DynAsyncReconnectSettingsCallback, + }, + client::{Client, ReconnectSettings}, +}; +use crate::asynchronous::socket::Socket as InnerSocket; +use crate::{error::Result, Event, Payload, TransportType}; use futures_util::future::BoxFuture; use log::trace; use native_tls::TlsConnector; @@ -6,18 +14,9 @@ use rust_engineio::{ header::{HeaderMap, HeaderValue}, }; use std::collections::HashMap; +use std::sync::Arc; use url::Url; -use crate::{error::Result, Event, Payload, TransportType}; - -use super::{ - callback::{ - Callback, DynAsyncAnyCallback, DynAsyncCallback, DynAsyncReconnectSettingsCallback, - }, - client::{Client, ReconnectSettings}, -}; -use crate::asynchronous::socket::Socket as InnerSocket; - /// A builder class for a `socket.io` socket. This handles setting up the client and /// configuring the callback, the namespace and metadata of the socket. If no /// namespace is specified, the default namespace `/` is taken. The `connect` method @@ -38,6 +37,7 @@ pub struct ClientBuilder { pub(crate) max_reconnect_attempts: Option, pub(crate) reconnect_delay_min: u64, pub(crate) reconnect_delay_max: u64, + pub(crate) transmitter: Option>, } impl ClientBuilder { @@ -97,9 +97,54 @@ impl ClientBuilder { max_reconnect_attempts: None, reconnect_delay_min: 1000, reconnect_delay_max: 5000, + transmitter: None, } } + /// Sets the data transmission object, ideally the standard libraries + /// multi-producer single consumer [`std::sync::mpsc::Sender`] should be used. + /// + /// # Example + /// + /// ```no_run + /// + /// let (sender, receiver) = mpsc::channel::(); + /// let client = ClientBuilder::new(url) + /// .namespace("/admin") + /// .on("test", |payload: Payload, socket: SocketIOClient| { + /// async move { + /// match payload { + /// Payload::Text(values) => { + /// if let Some(value) = values.first() { + /// if value.is_string() { + /// let result = socket.try_transitter::>(); + /// + /// result + /// .map(|transmitter| { + /// transmitter.send(String::from(value.as_str().unwrap())) + /// }) + /// .map_err(|err| eprintln!("{}", err)) + /// .ok(); + /// } + /// } + /// } + /// Payload::Binary(_bin_data) => println!(), + /// #[allow(deprecated)] + /// Payload::String(str) => println!("Received: {}", str), + /// } + /// } + /// .boxed() + /// }) + /// .transmitter(Arc::new(sender)) + /// .connect() + /// .await + /// .expect("Connection failed"); + /// ``` + pub fn transmitter(mut self, data: Arc) -> Self { + self.transmitter = Some(data); + self + } + /// Sets the target namespace of the client. The namespace should start /// with a leading `/`. Valid examples are e.g. `/admin`, `/foo`. /// If the String provided doesn't start with a leading `/`, it is diff --git a/socketio/src/asynchronous/client/client.rs b/socketio/src/asynchronous/client/client.rs index 67feb7db..7822ed7b 100644 --- a/socketio/src/asynchronous/client/client.rs +++ b/socketio/src/asynchronous/client/client.rs @@ -1,10 +1,9 @@ -use std::{ops::DerefMut, pin::Pin, sync::Arc}; - use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; use futures_util::{future::BoxFuture, stream, Stream, StreamExt}; use log::trace; use rand::{thread_rng, Rng}; use serde_json::Value; +use std::{ops::DerefMut, pin::Pin, sync::Arc}; use tokio::{ sync::RwLock, time::{sleep, Duration, Instant}, @@ -68,10 +67,13 @@ pub struct Client { /// The inner socket client to delegate the methods to. socket: Arc>, outstanding_acks: Arc>>, - // namespace, for multiplexing messages + /// namespace, for multiplexing messages nsp: String, - // Data send in the opening packet (commonly used as for auth) + /// Data send in the opening packet (commonly used as for auth) auth: Option, + /// Ideally a Arc> to send data to a receiver that is outside + /// the 'static lifetime restrictions of the callback handlers. + transmitter: Arc, builder: Arc>, disconnect_reason: Arc>, } @@ -87,11 +89,68 @@ impl Client { nsp: builder.namespace.to_owned(), outstanding_acks: Arc::new(RwLock::new(Vec::new())), auth: builder.auth.clone(), + transmitter: builder.transmitter.clone().unwrap_or(Arc::new(())), builder: Arc::new(RwLock::new(builder)), disconnect_reason: Arc::new(RwLock::new(DisconnectReason::default())), }) } + /// Attempts to retrieve the transmitted data of type `D` from the transmitter. + /// + /// This function clones the transmitter and attempts to downcast it to an `Arc`. + /// If the downcast is successful, it returns the cloned data wrapped in a `Result`. + /// If the downcast fails, indicating that the transmitter contains data of an incompatible type, + /// it returns an `Err` with an `Error::TransmitterTypeResolutionFailure`. + /// + /// # Generic Parameters + /// + /// - `D`: The type of data expected to be transmitted. + /// + /// # Returns + /// + /// - `Result>`: A `Result` containing the cloned data if successful, or an error otherwise. + /// + /// # Example + /// + /// ```no_run + /// use std::sync::{Arc, mpsc}; + /// use rust_socketio::{ + /// asynchronous::{Client, ClientBuilder}, + /// Payload, + /// }; + /// + /// let callback = | payload: Payload, socket: Client | { + /// async move { + /// match payload { + /// Payload::Text(values) => { + /// if let Some(value) = values.first() { + /// if value.is_string() { + /// let result = socket.try_transmitter::>(); + /// + /// result + /// .map(|transmitter| { + /// transmitter.send(String::from(value.as_str().unwrap())) + /// }) + /// .map_err(|err| eprintln!("{}", err)) + /// .ok(); + /// } + /// } + /// } + /// Payload::Binary(_bin_data) => println!(), + /// #[allow(deprecated)] + /// Payload::String(str) => println!("Received: {}", str), + /// } + /// } + /// .boxed() + /// }) + /// ``` + pub fn try_transitter(&self) -> Result> { + match Arc::clone(&self.transmitter).downcast() { + Ok(data) => Ok(data), + Err(_) => Err(Error::TransmitterTypeResolutionFailure), + } + } + /// Connects the client to a server. Afterwards the `emit_*` methods can be /// called to interact with the server. pub(crate) async fn connect(&self) -> Result<()> { @@ -415,7 +474,7 @@ impl Client { .await; } if let Some(ref attachments) = socket_packet.attachments { - if let Some(payload) = attachments.get(0) { + if let Some(payload) = attachments.first() { ack.callback.deref_mut()( Payload::Binary(payload.to_owned()), self.clone(), @@ -445,7 +504,7 @@ impl Client { }; if let Some(attachments) = &packet.attachments { - if let Some(binary_payload) = attachments.get(0) { + if let Some(binary_payload) = attachments.first() { self.callback(&event, Payload::Binary(binary_payload.to_owned())) .await?; } diff --git a/socketio/src/error.rs b/socketio/src/error.rs index cc25d897..2b0d2e6b 100644 --- a/socketio/src/error.rs +++ b/socketio/src/error.rs @@ -46,6 +46,8 @@ pub enum Error { InvalidAttachmentPacketType(u8), #[error("Underlying Engine.IO connection has closed")] StoppedEngineIoSocket, + #[error("Client::transmitter does not match the ClientBuilder::transmitter type")] + TransmitterTypeResolutionFailure, } pub(crate) type Result = std::result::Result; From f828ec3c65a0c43989206c5c28af48a17e0b14b3 Mon Sep 17 00:00:00 2001 From: Matt Williams Date: Tue, 14 May 2024 13:38:34 -0400 Subject: [PATCH 2/4] Fixed all clippy warnings. There were many warnings causing the `make clippy` command to fail. Some of the more noticable changes are moving the client.rs code to the corresponding mod.rs and deleting the former. This was to stop the clippy module inception warning. All unit tests, doc tests, and examples are passing. --- .../src/asynchronous/client/async_client.rs | 2 +- engineio/src/asynchronous/mod.rs | 2 +- engineio/src/client/client.rs | 662 ----------------- engineio/src/client/mod.rs | 664 +++++++++++++++++- engineio/src/lib.rs | 2 +- socketio/examples/async_transmitter.rs | 29 +- .../client/{client.rs => async_client.rs} | 50 +- socketio/src/asynchronous/client/builder.rs | 53 +- socketio/src/asynchronous/client/callback.rs | 6 +- socketio/src/asynchronous/client/mod.rs | 2 +- socketio/src/asynchronous/mod.rs | 2 +- socketio/src/client/builder.rs | 12 +- socketio/src/client/client.rs | 478 ------------- socketio/src/client/mod.rs | 481 ++++++++++++- socketio/src/lib.rs | 2 +- socketio/src/packet.rs | 7 +- 16 files changed, 1219 insertions(+), 1235 deletions(-) delete mode 100644 engineio/src/client/client.rs rename socketio/src/asynchronous/client/{client.rs => async_client.rs} (97%) delete mode 100644 socketio/src/client/client.rs diff --git a/engineio/src/asynchronous/client/async_client.rs b/engineio/src/asynchronous/client/async_client.rs index 99b0d0dd..783179eb 100644 --- a/engineio/src/asynchronous/client/async_client.rs +++ b/engineio/src/asynchronous/client/async_client.rs @@ -88,7 +88,7 @@ impl Debug for Client { } } -#[cfg(all(test))] +#[cfg(test)] mod test { use super::*; diff --git a/engineio/src/asynchronous/mod.rs b/engineio/src/asynchronous/mod.rs index 97bd9073..a27ae264 100644 --- a/engineio/src/asynchronous/mod.rs +++ b/engineio/src/asynchronous/mod.rs @@ -1,7 +1,7 @@ pub mod async_transports; pub mod transport; -pub(self) mod async_socket; +mod async_socket; #[cfg(feature = "async-callbacks")] mod callback; #[cfg(feature = "async")] diff --git a/engineio/src/client/client.rs b/engineio/src/client/client.rs deleted file mode 100644 index dc22ff77..00000000 --- a/engineio/src/client/client.rs +++ /dev/null @@ -1,662 +0,0 @@ -use super::super::socket::Socket as InnerSocket; -use crate::callback::OptionalCallback; -use crate::socket::DEFAULT_MAX_POLL_TIMEOUT; -use crate::transport::Transport; - -use crate::error::{Error, Result}; -use crate::header::HeaderMap; -use crate::packet::{HandshakePacket, Packet, PacketId}; -use crate::transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport}; -use crate::ENGINE_IO_VERSION; -use bytes::Bytes; -use native_tls::TlsConnector; -use std::convert::TryFrom; -use std::convert::TryInto; -use std::fmt::Debug; -use url::Url; - -/// An engine.io client that allows interaction with the connected engine.io -/// server. This client provides means for connecting, disconnecting and sending -/// packets to the server. -/// -/// ## Note: -/// There is no need to put this Client behind an `Arc`, as the type uses `Arc` -/// internally and provides a shared state beyond all cloned instances. -#[derive(Clone, Debug)] -pub struct Client { - socket: InnerSocket, -} - -#[derive(Clone, Debug)] -pub struct ClientBuilder { - url: Url, - tls_config: Option, - headers: Option, - handshake: Option, - on_error: OptionalCallback, - on_open: OptionalCallback<()>, - on_close: OptionalCallback<()>, - on_data: OptionalCallback, - on_packet: OptionalCallback, -} - -impl ClientBuilder { - pub fn new(url: Url) -> Self { - let mut url = url; - url.query_pairs_mut() - .append_pair("EIO", &ENGINE_IO_VERSION.to_string()); - - // No path add engine.io - if url.path() == "/" { - url.set_path("/engine.io/"); - } - ClientBuilder { - url, - headers: None, - tls_config: None, - handshake: None, - on_close: OptionalCallback::default(), - on_data: OptionalCallback::default(), - on_error: OptionalCallback::default(), - on_open: OptionalCallback::default(), - on_packet: OptionalCallback::default(), - } - } - - /// Specify transport's tls config - pub fn tls_config(mut self, tls_config: TlsConnector) -> Self { - self.tls_config = Some(tls_config); - self - } - - /// Specify transport's HTTP headers - pub fn headers(mut self, headers: HeaderMap) -> Self { - self.headers = Some(headers); - self - } - - /// Registers the `on_close` callback. - pub fn on_close(mut self, callback: T) -> Self - where - T: Fn(()) + 'static + Sync + Send, - { - self.on_close = OptionalCallback::new(callback); - self - } - - /// Registers the `on_data` callback. - pub fn on_data(mut self, callback: T) -> Self - where - T: Fn(Bytes) + 'static + Sync + Send, - { - self.on_data = OptionalCallback::new(callback); - self - } - - /// Registers the `on_error` callback. - pub fn on_error(mut self, callback: T) -> Self - where - T: Fn(String) + 'static + Sync + Send, - { - self.on_error = OptionalCallback::new(callback); - self - } - - /// Registers the `on_open` callback. - pub fn on_open(mut self, callback: T) -> Self - where - T: Fn(()) + 'static + Sync + Send, - { - self.on_open = OptionalCallback::new(callback); - self - } - - /// Registers the `on_packet` callback. - pub fn on_packet(mut self, callback: T) -> Self - where - T: Fn(Packet) + 'static + Sync + Send, - { - self.on_packet = OptionalCallback::new(callback); - self - } - - /// Performs the handshake - fn handshake_with_transport(&mut self, transport: &T) -> Result<()> { - // No need to handshake twice - if self.handshake.is_some() { - return Ok(()); - } - - let mut url = self.url.clone(); - - let handshake: HandshakePacket = - Packet::try_from(transport.poll(DEFAULT_MAX_POLL_TIMEOUT)?)?.try_into()?; - - // update the base_url with the new sid - url.query_pairs_mut().append_pair("sid", &handshake.sid[..]); - - self.handshake = Some(handshake); - - self.url = url; - - Ok(()) - } - - fn handshake(&mut self) -> Result<()> { - if self.handshake.is_some() { - return Ok(()); - } - - // Start with polling transport - let transport = PollingTransport::new( - self.url.clone(), - self.tls_config.clone(), - self.headers.clone().map(|v| v.try_into().unwrap()), - ); - - self.handshake_with_transport(&transport) - } - - /// Build websocket if allowed, if not fall back to polling - pub fn build(mut self) -> Result { - self.handshake()?; - - if self.websocket_upgrade()? { - self.build_websocket_with_upgrade() - } else { - self.build_polling() - } - } - - /// Build socket with polling transport - pub fn build_polling(mut self) -> Result { - self.handshake()?; - - // Make a polling transport with new sid - let transport = PollingTransport::new( - self.url, - self.tls_config, - self.headers.map(|v| v.try_into().unwrap()), - ); - - // SAFETY: handshake function called previously. - Ok(Client { - socket: InnerSocket::new( - transport.into(), - self.handshake.unwrap(), - self.on_close, - self.on_data, - self.on_error, - self.on_open, - self.on_packet, - ), - }) - } - - /// Build socket with a polling transport then upgrade to websocket transport - pub fn build_websocket_with_upgrade(mut self) -> Result { - self.handshake()?; - - if self.websocket_upgrade()? { - self.build_websocket() - } else { - Err(Error::IllegalWebsocketUpgrade()) - } - } - - /// Build socket with only a websocket transport - pub fn build_websocket(mut self) -> Result { - // SAFETY: Already a Url - let url = url::Url::parse(self.url.as_ref())?; - - let headers: Option = if let Some(map) = self.headers.clone() { - Some(map.try_into()?) - } else { - None - }; - - match url.scheme() { - "http" | "ws" => { - let transport = WebsocketTransport::new(url, headers)?; - if self.handshake.is_some() { - transport.upgrade()?; - } else { - self.handshake_with_transport(&transport)?; - } - // NOTE: Although self.url contains the sid, it does not propagate to the transport - // SAFETY: handshake function called previously. - Ok(Client { - socket: InnerSocket::new( - transport.into(), - self.handshake.unwrap(), - self.on_close, - self.on_data, - self.on_error, - self.on_open, - self.on_packet, - ), - }) - } - "https" | "wss" => { - let transport = - WebsocketSecureTransport::new(url, self.tls_config.clone(), headers)?; - if self.handshake.is_some() { - transport.upgrade()?; - } else { - self.handshake_with_transport(&transport)?; - } - // NOTE: Although self.url contains the sid, it does not propagate to the transport - // SAFETY: handshake function called previously. - Ok(Client { - socket: InnerSocket::new( - transport.into(), - self.handshake.unwrap(), - self.on_close, - self.on_data, - self.on_error, - self.on_open, - self.on_packet, - ), - }) - } - _ => Err(Error::InvalidUrlScheme(url.scheme().to_string())), - } - } - - /// Build websocket if allowed, if not allowed or errored fall back to polling. - /// WARNING: websocket errors suppressed, no indication of websocket success or failure. - pub fn build_with_fallback(self) -> Result { - let result = self.clone().build(); - if result.is_err() { - self.build_polling() - } else { - result - } - } - - /// Checks the handshake to see if websocket upgrades are allowed - fn websocket_upgrade(&mut self) -> Result { - // SAFETY: handshake set by above function. - Ok(self - .handshake - .as_ref() - .unwrap() - .upgrades - .iter() - .any(|upgrade| upgrade.to_lowercase() == *"websocket")) - } -} - -impl Client { - pub fn close(&self) -> Result<()> { - self.socket.disconnect() - } - - /// Opens the connection to a specified server. The first Pong packet is sent - /// to the server to trigger the Ping-cycle. - pub fn connect(&self) -> Result<()> { - self.socket.connect() - } - - /// Disconnects the connection. - pub fn disconnect(&self) -> Result<()> { - self.socket.disconnect() - } - - /// Sends a packet to the server. - pub fn emit(&self, packet: Packet) -> Result<()> { - self.socket.emit(packet) - } - - /// Polls for next payload - #[doc(hidden)] - pub fn poll(&self) -> Result> { - let packet = self.socket.poll()?; - if let Some(packet) = packet { - // check for the appropriate action or callback - self.socket.handle_packet(packet.clone()); - match packet.packet_id { - PacketId::MessageBinary => { - self.socket.handle_data(packet.data.clone()); - } - PacketId::Message => { - self.socket.handle_data(packet.data.clone()); - } - PacketId::Close => { - self.socket.handle_close(); - } - PacketId::Open => { - unreachable!("Won't happen as we open the connection beforehand"); - } - PacketId::Upgrade => { - // this is already checked during the handshake, so just do nothing here - } - PacketId::Ping => { - self.socket.pinged()?; - self.emit(Packet::new(PacketId::Pong, Bytes::new()))?; - } - PacketId::Pong => { - // this will never happen as the pong packet is - // only sent by the client - unreachable!(); - } - PacketId::Noop => (), - } - Ok(Some(packet)) - } else { - Ok(None) - } - } - - /// Check if the underlying transport client is connected. - pub fn is_connected(&self) -> Result { - self.socket.is_connected() - } - - pub fn iter(&self) -> Iter { - Iter { socket: self } - } -} - -#[derive(Clone)] -pub struct Iter<'a> { - socket: &'a Client, -} - -impl<'a> Iterator for Iter<'a> { - type Item = Result; - fn next(&mut self) -> std::option::Option<::Item> { - match self.socket.poll() { - Ok(Some(packet)) => Some(Ok(packet)), - Ok(None) => None, - Err(err) => Some(Err(err)), - } - } -} - -#[cfg(test)] -mod test { - - use crate::packet::PacketId; - - use super::*; - - /// The purpose of this test is to check whether the Client is properly cloneable or not. - /// As the documentation of the engine.io client states, the object needs to maintain it's internal - /// state when cloned and the cloned object should reflect the same state throughout the lifetime - /// of both objects (initial and cloned). - #[test] - fn test_client_cloneable() -> Result<()> { - let url = crate::test::engine_io_server()?; - let sut = builder(url).build()?; - - let cloned = sut.clone(); - - sut.connect()?; - - // when the underlying socket is connected, the - // state should also change on the cloned one - assert!(sut.is_connected()?); - assert!(cloned.is_connected()?); - - // both clients should reflect the same messages. - let mut iter = sut - .iter() - .map(|packet| packet.unwrap()) - .filter(|packet| packet.packet_id != PacketId::Ping); - - let mut iter_cloned = cloned - .iter() - .map(|packet| packet.unwrap()) - .filter(|packet| packet.packet_id != PacketId::Ping); - - assert_eq!( - iter.next(), - Some(Packet::new(PacketId::Message, "hello client")) - ); - - sut.emit(Packet::new(PacketId::Message, "respond"))?; - - assert_eq!( - iter_cloned.next(), - Some(Packet::new(PacketId::Message, "Roger Roger")) - ); - - cloned.disconnect()?; - - // when the underlying socket is disconnected, the - // state should also change on the cloned one - assert!(!sut.is_connected()?); - assert!(!cloned.is_connected()?); - - Ok(()) - } - - #[test] - fn test_illegal_actions() -> Result<()> { - let url = crate::test::engine_io_server()?; - let sut = builder(url.clone()).build()?; - - assert!(sut - .emit(Packet::new(PacketId::Close, Bytes::new())) - .is_err()); - - sut.connect()?; - - assert!(sut.poll().is_ok()); - - assert!(builder(Url::parse("fake://fake.fake").unwrap()) - .build_websocket() - .is_err()); - - Ok(()) - } - use reqwest::header::HOST; - - use crate::packet::Packet; - - fn builder(url: Url) -> ClientBuilder { - ClientBuilder::new(url) - .on_open(|_| { - println!("Open event!"); - }) - .on_packet(|packet| { - println!("Received packet: {:?}", packet); - }) - .on_data(|data| { - println!("Received data: {:?}", std::str::from_utf8(&data)); - }) - .on_close(|_| { - println!("Close event!"); - }) - .on_error(|error| { - println!("Error {}", error); - }) - } - - fn test_connection(socket: Client) -> Result<()> { - let socket = socket; - - socket.connect().unwrap(); - - // TODO: 0.3.X better tests - - let mut iter = socket - .iter() - .map(|packet| packet.unwrap()) - .filter(|packet| packet.packet_id != PacketId::Ping); - - assert_eq!( - iter.next(), - Some(Packet::new(PacketId::Message, "hello client")) - ); - - socket.emit(Packet::new(PacketId::Message, "respond"))?; - - assert_eq!( - iter.next(), - Some(Packet::new(PacketId::Message, "Roger Roger")) - ); - - socket.close() - } - - #[test] - fn test_connection_long() -> Result<()> { - // Long lived socket to receive pings - let url = crate::test::engine_io_server()?; - let socket = builder(url).build()?; - - socket.connect()?; - - let mut iter = socket.iter(); - // hello client - iter.next(); - // Ping - iter.next(); - - socket.disconnect()?; - - assert!(!socket.is_connected()?); - - Ok(()) - } - - #[test] - fn test_connection_dynamic() -> Result<()> { - let url = crate::test::engine_io_server()?; - let socket = builder(url).build()?; - test_connection(socket)?; - - let url = crate::test::engine_io_polling_server()?; - let socket = builder(url).build()?; - test_connection(socket) - } - - #[test] - fn test_connection_fallback() -> Result<()> { - let url = crate::test::engine_io_server()?; - let socket = builder(url).build_with_fallback()?; - test_connection(socket)?; - - let url = crate::test::engine_io_polling_server()?; - let socket = builder(url).build_with_fallback()?; - test_connection(socket) - } - - #[test] - fn test_connection_dynamic_secure() -> Result<()> { - let url = crate::test::engine_io_server_secure()?; - let mut builder = builder(url); - builder = builder.tls_config(crate::test::tls_connector()?); - let socket = builder.build()?; - test_connection(socket) - } - - #[test] - fn test_connection_polling() -> Result<()> { - let url = crate::test::engine_io_server()?; - let socket = builder(url).build_polling()?; - test_connection(socket) - } - - #[test] - fn test_connection_wss() -> Result<()> { - let url = crate::test::engine_io_polling_server()?; - assert!(builder(url).build_websocket_with_upgrade().is_err()); - - let host = - std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned()); - let mut url = crate::test::engine_io_server_secure()?; - - let mut headers = HeaderMap::default(); - headers.insert(HOST, host); - let mut builder = builder(url.clone()); - - builder = builder.tls_config(crate::test::tls_connector()?); - builder = builder.headers(headers.clone()); - let socket = builder.clone().build_websocket_with_upgrade()?; - - test_connection(socket)?; - - let socket = builder.build_websocket()?; - - test_connection(socket)?; - - url.set_scheme("wss").unwrap(); - - let builder = self::builder(url) - .tls_config(crate::test::tls_connector()?) - .headers(headers); - let socket = builder.clone().build_websocket()?; - - test_connection(socket)?; - - assert!(builder.build_websocket_with_upgrade().is_err()); - - Ok(()) - } - - #[test] - fn test_connection_ws() -> Result<()> { - let url = crate::test::engine_io_polling_server()?; - assert!(builder(url.clone()).build_websocket().is_err()); - assert!(builder(url).build_websocket_with_upgrade().is_err()); - - let mut url = crate::test::engine_io_server()?; - - let builder = builder(url.clone()); - let socket = builder.clone().build_websocket()?; - test_connection(socket)?; - - let socket = builder.build_websocket_with_upgrade()?; - test_connection(socket)?; - - url.set_scheme("ws").unwrap(); - - let builder = self::builder(url); - let socket = builder.clone().build_websocket()?; - - test_connection(socket)?; - - assert!(builder.build_websocket_with_upgrade().is_err()); - - Ok(()) - } - - #[test] - fn test_open_invariants() -> Result<()> { - let url = crate::test::engine_io_server()?; - let illegal_url = "this is illegal"; - - assert!(Url::parse(illegal_url).is_err()); - - let invalid_protocol = "file:///tmp/foo"; - assert!(builder(Url::parse(invalid_protocol).unwrap()) - .build() - .is_err()); - - let sut = builder(url.clone()).build()?; - let _error = sut - .emit(Packet::new(PacketId::Close, Bytes::new())) - .expect_err("error"); - assert!(matches!(Error::IllegalActionBeforeOpen(), _error)); - - // test missing match arm in socket constructor - let mut headers = HeaderMap::default(); - let host = - std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned()); - headers.insert(HOST, host); - - let _ = builder(url.clone()) - .tls_config( - TlsConnector::builder() - .danger_accept_invalid_certs(true) - .build() - .unwrap(), - ) - .build()?; - let _ = builder(url).headers(headers).build()?; - Ok(()) - } -} diff --git a/engineio/src/client/mod.rs b/engineio/src/client/mod.rs index 2feb68ab..894ba60e 100644 --- a/engineio/src/client/mod.rs +++ b/engineio/src/client/mod.rs @@ -1,3 +1,661 @@ -mod client; -pub use client::Iter; -pub use {client::Client, client::ClientBuilder, client::Iter as SocketIter}; +use super::socket::Socket as InnerSocket; +use crate::callback::OptionalCallback; +use crate::error::{Error, Result}; +use crate::header::HeaderMap; +use crate::packet::{HandshakePacket, Packet, PacketId}; +use crate::socket::DEFAULT_MAX_POLL_TIMEOUT; +use crate::transport::Transport; +use crate::transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport}; +use crate::ENGINE_IO_VERSION; +use bytes::Bytes; +use native_tls::TlsConnector; +use std::convert::TryFrom; +use std::convert::TryInto; +use std::fmt::Debug; +use url::Url; + +/// An engine.io client that allows interaction with the connected engine.io +/// server. This client provides means for connecting, disconnecting and sending +/// packets to the server. +/// +/// ## Note: +/// There is no need to put this Client behind an `Arc`, as the type uses `Arc` +/// internally and provides a shared state beyond all cloned instances. +#[derive(Clone, Debug)] +pub struct Client { + socket: InnerSocket, +} + +#[derive(Clone, Debug)] +pub struct ClientBuilder { + url: Url, + tls_config: Option, + headers: Option, + handshake: Option, + on_error: OptionalCallback, + on_open: OptionalCallback<()>, + on_close: OptionalCallback<()>, + on_data: OptionalCallback, + on_packet: OptionalCallback, +} + +impl ClientBuilder { + pub fn new(url: Url) -> Self { + let mut url = url; + url.query_pairs_mut() + .append_pair("EIO", &ENGINE_IO_VERSION.to_string()); + + // No path add engine.io + if url.path() == "/" { + url.set_path("/engine.io/"); + } + ClientBuilder { + url, + headers: None, + tls_config: None, + handshake: None, + on_close: OptionalCallback::default(), + on_data: OptionalCallback::default(), + on_error: OptionalCallback::default(), + on_open: OptionalCallback::default(), + on_packet: OptionalCallback::default(), + } + } + + /// Specify transport's tls config + pub fn tls_config(mut self, tls_config: TlsConnector) -> Self { + self.tls_config = Some(tls_config); + self + } + + /// Specify transport's HTTP headers + pub fn headers(mut self, headers: HeaderMap) -> Self { + self.headers = Some(headers); + self + } + + /// Registers the `on_close` callback. + pub fn on_close(mut self, callback: T) -> Self + where + T: Fn(()) + 'static + Sync + Send, + { + self.on_close = OptionalCallback::new(callback); + self + } + + /// Registers the `on_data` callback. + pub fn on_data(mut self, callback: T) -> Self + where + T: Fn(Bytes) + 'static + Sync + Send, + { + self.on_data = OptionalCallback::new(callback); + self + } + + /// Registers the `on_error` callback. + pub fn on_error(mut self, callback: T) -> Self + where + T: Fn(String) + 'static + Sync + Send, + { + self.on_error = OptionalCallback::new(callback); + self + } + + /// Registers the `on_open` callback. + pub fn on_open(mut self, callback: T) -> Self + where + T: Fn(()) + 'static + Sync + Send, + { + self.on_open = OptionalCallback::new(callback); + self + } + + /// Registers the `on_packet` callback. + pub fn on_packet(mut self, callback: T) -> Self + where + T: Fn(Packet) + 'static + Sync + Send, + { + self.on_packet = OptionalCallback::new(callback); + self + } + + /// Performs the handshake + fn handshake_with_transport(&mut self, transport: &T) -> Result<()> { + // No need to handshake twice + if self.handshake.is_some() { + return Ok(()); + } + + let mut url = self.url.clone(); + + let handshake: HandshakePacket = + Packet::try_from(transport.poll(DEFAULT_MAX_POLL_TIMEOUT)?)?.try_into()?; + + // update the base_url with the new sid + url.query_pairs_mut().append_pair("sid", &handshake.sid[..]); + + self.handshake = Some(handshake); + + self.url = url; + + Ok(()) + } + + fn handshake(&mut self) -> Result<()> { + if self.handshake.is_some() { + return Ok(()); + } + + // Start with polling transport + let transport = PollingTransport::new( + self.url.clone(), + self.tls_config.clone(), + self.headers.clone().map(|v| v.try_into().unwrap()), + ); + + self.handshake_with_transport(&transport) + } + + /// Build websocket if allowed, if not fall back to polling + pub fn build(mut self) -> Result { + self.handshake()?; + + if self.websocket_upgrade()? { + self.build_websocket_with_upgrade() + } else { + self.build_polling() + } + } + + /// Build socket with polling transport + pub fn build_polling(mut self) -> Result { + self.handshake()?; + + // Make a polling transport with new sid + let transport = PollingTransport::new( + self.url, + self.tls_config, + self.headers.map(|v| v.try_into().unwrap()), + ); + + // SAFETY: handshake function called previously. + Ok(Client { + socket: InnerSocket::new( + transport.into(), + self.handshake.unwrap(), + self.on_close, + self.on_data, + self.on_error, + self.on_open, + self.on_packet, + ), + }) + } + + /// Build socket with a polling transport then upgrade to websocket transport + pub fn build_websocket_with_upgrade(mut self) -> Result { + self.handshake()?; + + if self.websocket_upgrade()? { + self.build_websocket() + } else { + Err(Error::IllegalWebsocketUpgrade()) + } + } + + /// Build socket with only a websocket transport + pub fn build_websocket(mut self) -> Result { + // SAFETY: Already a Url + let url = url::Url::parse(self.url.as_ref())?; + + let headers: Option = if let Some(map) = self.headers.clone() { + Some(map.try_into()?) + } else { + None + }; + + match url.scheme() { + "http" | "ws" => { + let transport = WebsocketTransport::new(url, headers)?; + if self.handshake.is_some() { + transport.upgrade()?; + } else { + self.handshake_with_transport(&transport)?; + } + // NOTE: Although self.url contains the sid, it does not propagate to the transport + // SAFETY: handshake function called previously. + Ok(Client { + socket: InnerSocket::new( + transport.into(), + self.handshake.unwrap(), + self.on_close, + self.on_data, + self.on_error, + self.on_open, + self.on_packet, + ), + }) + } + "https" | "wss" => { + let transport = + WebsocketSecureTransport::new(url, self.tls_config.clone(), headers)?; + if self.handshake.is_some() { + transport.upgrade()?; + } else { + self.handshake_with_transport(&transport)?; + } + // NOTE: Although self.url contains the sid, it does not propagate to the transport + // SAFETY: handshake function called previously. + Ok(Client { + socket: InnerSocket::new( + transport.into(), + self.handshake.unwrap(), + self.on_close, + self.on_data, + self.on_error, + self.on_open, + self.on_packet, + ), + }) + } + _ => Err(Error::InvalidUrlScheme(url.scheme().to_string())), + } + } + + /// Build websocket if allowed, if not allowed or errored fall back to polling. + /// WARNING: websocket errors suppressed, no indication of websocket success or failure. + pub fn build_with_fallback(self) -> Result { + let result = self.clone().build(); + if result.is_err() { + self.build_polling() + } else { + result + } + } + + /// Checks the handshake to see if websocket upgrades are allowed + fn websocket_upgrade(&mut self) -> Result { + // SAFETY: handshake set by above function. + Ok(self + .handshake + .as_ref() + .unwrap() + .upgrades + .iter() + .any(|upgrade| upgrade.to_lowercase() == *"websocket")) + } +} + +impl Client { + pub fn close(&self) -> Result<()> { + self.socket.disconnect() + } + + /// Opens the connection to a specified server. The first Pong packet is sent + /// to the server to trigger the Ping-cycle. + pub fn connect(&self) -> Result<()> { + self.socket.connect() + } + + /// Disconnects the connection. + pub fn disconnect(&self) -> Result<()> { + self.socket.disconnect() + } + + /// Sends a packet to the server. + pub fn emit(&self, packet: Packet) -> Result<()> { + self.socket.emit(packet) + } + + /// Polls for next payload + #[doc(hidden)] + pub fn poll(&self) -> Result> { + let packet = self.socket.poll()?; + if let Some(packet) = packet { + // check for the appropriate action or callback + self.socket.handle_packet(packet.clone()); + match packet.packet_id { + PacketId::MessageBinary => { + self.socket.handle_data(packet.data.clone()); + } + PacketId::Message => { + self.socket.handle_data(packet.data.clone()); + } + PacketId::Close => { + self.socket.handle_close(); + } + PacketId::Open => { + unreachable!("Won't happen as we open the connection beforehand"); + } + PacketId::Upgrade => { + // this is already checked during the handshake, so just do nothing here + } + PacketId::Ping => { + self.socket.pinged()?; + self.emit(Packet::new(PacketId::Pong, Bytes::new()))?; + } + PacketId::Pong => { + // this will never happen as the pong packet is + // only sent by the client + unreachable!(); + } + PacketId::Noop => (), + } + Ok(Some(packet)) + } else { + Ok(None) + } + } + + /// Check if the underlying transport client is connected. + pub fn is_connected(&self) -> Result { + self.socket.is_connected() + } + + pub fn iter(&self) -> Iter { + Iter { socket: self } + } +} + +#[derive(Clone)] +pub struct Iter<'a> { + socket: &'a Client, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Result; + fn next(&mut self) -> std::option::Option<::Item> { + match self.socket.poll() { + Ok(Some(packet)) => Some(Ok(packet)), + Ok(None) => None, + Err(err) => Some(Err(err)), + } + } +} + +#[cfg(test)] +mod test { + + use crate::packet::PacketId; + + use super::*; + + /// The purpose of this test is to check whether the Client is properly cloneable or not. + /// As the documentation of the engine.io client states, the object needs to maintain it's internal + /// state when cloned and the cloned object should reflect the same state throughout the lifetime + /// of both objects (initial and cloned). + #[test] + fn test_client_cloneable() -> Result<()> { + let url = crate::test::engine_io_server()?; + let sut = builder(url).build()?; + + let cloned = sut.clone(); + + sut.connect()?; + + // when the underlying socket is connected, the + // state should also change on the cloned one + assert!(sut.is_connected()?); + assert!(cloned.is_connected()?); + + // both clients should reflect the same messages. + let mut iter = sut + .iter() + .map(|packet| packet.unwrap()) + .filter(|packet| packet.packet_id != PacketId::Ping); + + let mut iter_cloned = cloned + .iter() + .map(|packet| packet.unwrap()) + .filter(|packet| packet.packet_id != PacketId::Ping); + + assert_eq!( + iter.next(), + Some(Packet::new(PacketId::Message, "hello client")) + ); + + sut.emit(Packet::new(PacketId::Message, "respond"))?; + + assert_eq!( + iter_cloned.next(), + Some(Packet::new(PacketId::Message, "Roger Roger")) + ); + + cloned.disconnect()?; + + // when the underlying socket is disconnected, the + // state should also change on the cloned one + assert!(!sut.is_connected()?); + assert!(!cloned.is_connected()?); + + Ok(()) + } + + #[test] + fn test_illegal_actions() -> Result<()> { + let url = crate::test::engine_io_server()?; + let sut = builder(url.clone()).build()?; + + assert!(sut + .emit(Packet::new(PacketId::Close, Bytes::new())) + .is_err()); + + sut.connect()?; + + assert!(sut.poll().is_ok()); + + assert!(builder(Url::parse("fake://fake.fake").unwrap()) + .build_websocket() + .is_err()); + + Ok(()) + } + use reqwest::header::HOST; + + use crate::packet::Packet; + + fn builder(url: Url) -> ClientBuilder { + ClientBuilder::new(url) + .on_open(|_| { + println!("Open event!"); + }) + .on_packet(|packet| { + println!("Received packet: {:?}", packet); + }) + .on_data(|data| { + println!("Received data: {:?}", std::str::from_utf8(&data)); + }) + .on_close(|_| { + println!("Close event!"); + }) + .on_error(|error| { + println!("Error {}", error); + }) + } + + fn test_connection(socket: Client) -> Result<()> { + let socket = socket; + + socket.connect().unwrap(); + + // TODO: 0.3.X better tests + + let mut iter = socket + .iter() + .map(|packet| packet.unwrap()) + .filter(|packet| packet.packet_id != PacketId::Ping); + + assert_eq!( + iter.next(), + Some(Packet::new(PacketId::Message, "hello client")) + ); + + socket.emit(Packet::new(PacketId::Message, "respond"))?; + + assert_eq!( + iter.next(), + Some(Packet::new(PacketId::Message, "Roger Roger")) + ); + + socket.close() + } + + #[test] + fn test_connection_long() -> Result<()> { + // Long lived socket to receive pings + let url = crate::test::engine_io_server()?; + let socket = builder(url).build()?; + + socket.connect()?; + + let mut iter = socket.iter(); + // hello client + iter.next(); + // Ping + iter.next(); + + socket.disconnect()?; + + assert!(!socket.is_connected()?); + + Ok(()) + } + + #[test] + fn test_connection_dynamic() -> Result<()> { + let url = crate::test::engine_io_server()?; + let socket = builder(url).build()?; + test_connection(socket)?; + + let url = crate::test::engine_io_polling_server()?; + let socket = builder(url).build()?; + test_connection(socket) + } + + #[test] + fn test_connection_fallback() -> Result<()> { + let url = crate::test::engine_io_server()?; + let socket = builder(url).build_with_fallback()?; + test_connection(socket)?; + + let url = crate::test::engine_io_polling_server()?; + let socket = builder(url).build_with_fallback()?; + test_connection(socket) + } + + #[test] + fn test_connection_dynamic_secure() -> Result<()> { + let url = crate::test::engine_io_server_secure()?; + let mut builder = builder(url); + builder = builder.tls_config(crate::test::tls_connector()?); + let socket = builder.build()?; + test_connection(socket) + } + + #[test] + fn test_connection_polling() -> Result<()> { + let url = crate::test::engine_io_server()?; + let socket = builder(url).build_polling()?; + test_connection(socket) + } + + #[test] + fn test_connection_wss() -> Result<()> { + let url = crate::test::engine_io_polling_server()?; + assert!(builder(url).build_websocket_with_upgrade().is_err()); + + let host = + std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned()); + let mut url = crate::test::engine_io_server_secure()?; + + let mut headers = HeaderMap::default(); + headers.insert(HOST, host); + let mut builder = builder(url.clone()); + + builder = builder.tls_config(crate::test::tls_connector()?); + builder = builder.headers(headers.clone()); + let socket = builder.clone().build_websocket_with_upgrade()?; + + test_connection(socket)?; + + let socket = builder.build_websocket()?; + + test_connection(socket)?; + + url.set_scheme("wss").unwrap(); + + let builder = self::builder(url) + .tls_config(crate::test::tls_connector()?) + .headers(headers); + let socket = builder.clone().build_websocket()?; + + test_connection(socket)?; + + assert!(builder.build_websocket_with_upgrade().is_err()); + + Ok(()) + } + + #[test] + fn test_connection_ws() -> Result<()> { + let url = crate::test::engine_io_polling_server()?; + assert!(builder(url.clone()).build_websocket().is_err()); + assert!(builder(url).build_websocket_with_upgrade().is_err()); + + let mut url = crate::test::engine_io_server()?; + + let builder = builder(url.clone()); + let socket = builder.clone().build_websocket()?; + test_connection(socket)?; + + let socket = builder.build_websocket_with_upgrade()?; + test_connection(socket)?; + + url.set_scheme("ws").unwrap(); + + let builder = self::builder(url); + let socket = builder.clone().build_websocket()?; + + test_connection(socket)?; + + assert!(builder.build_websocket_with_upgrade().is_err()); + + Ok(()) + } + + #[test] + fn test_open_invariants() -> Result<()> { + let url = crate::test::engine_io_server()?; + let illegal_url = "this is illegal"; + + assert!(Url::parse(illegal_url).is_err()); + + let invalid_protocol = "file:///tmp/foo"; + assert!(builder(Url::parse(invalid_protocol).unwrap()) + .build() + .is_err()); + + let sut = builder(url.clone()).build()?; + let _error = sut + .emit(Packet::new(PacketId::Close, Bytes::new())) + .expect_err("error"); + assert!(matches!(Error::IllegalActionBeforeOpen(), _error)); + + // test missing match arm in socket constructor + let mut headers = HeaderMap::default(); + let host = + std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned()); + headers.insert(HOST, host); + + let _ = builder(url.clone()) + .tls_config( + TlsConnector::builder() + .danger_accept_invalid_certs(true) + .build() + .unwrap(), + ) + .build()?; + let _ = builder(url).headers(headers).build()?; + Ok(()) + } +} diff --git a/engineio/src/lib.rs b/engineio/src/lib.rs index 2ac5787d..2fcd16fb 100644 --- a/engineio/src/lib.rs +++ b/engineio/src/lib.rs @@ -85,7 +85,7 @@ pub mod client; /// Generic header map pub mod header; pub mod packet; -pub(self) mod socket; +mod socket; pub mod transport; pub mod transports; diff --git a/socketio/examples/async_transmitter.rs b/socketio/examples/async_transmitter.rs index 119bcbb6..415a2a3c 100644 --- a/socketio/examples/async_transmitter.rs +++ b/socketio/examples/async_transmitter.rs @@ -32,18 +32,26 @@ impl TransmitterClient { Payload::Text(values) => { if let Some(value) = values.first() { if value.is_string() { - let result = socket.try_transitter::>(); - - result - .map(|transmitter| { - transmitter.send(String::from(value.as_str().unwrap())) - }) - .map_err(|err| eprintln!("{}", err)) - .ok(); + socket + .try_transmitter::>() + .map_or_else( + |err| eprintln!("{}", err), + |tx| { + tx.send(String::from(value.as_str().unwrap())) + .map_or_else( + |err| eprintln!("{}", err), + |_| { + println!( + "Data transmitted successfully" + ) + }, + ); + }, + ); } } } - Payload::Binary(_bin_data) => println!(), + Payload::Binary(bin_data) => println!("Binary data: {:#?}", bin_data), #[allow(deprecated)] Payload::String(str) => println!("Received: {}", str), } @@ -100,10 +108,9 @@ async fn main() { } } Err(err) => { - eprintln!("{}", err); + eprintln!("Failed to connect to server: {}", err); } } - // Wait so we can see our response sleep(Duration::from_secs(2)).await; } diff --git a/socketio/src/asynchronous/client/client.rs b/socketio/src/asynchronous/client/async_client.rs similarity index 97% rename from socketio/src/asynchronous/client/client.rs rename to socketio/src/asynchronous/client/async_client.rs index 7822ed7b..00738eb2 100644 --- a/socketio/src/asynchronous/client/client.rs +++ b/socketio/src/asynchronous/client/async_client.rs @@ -1,14 +1,3 @@ -use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; -use futures_util::{future::BoxFuture, stream, Stream, StreamExt}; -use log::trace; -use rand::{thread_rng, Rng}; -use serde_json::Value; -use std::{ops::DerefMut, pin::Pin, sync::Arc}; -use tokio::{ - sync::RwLock, - time::{sleep, Duration, Instant}, -}; - use super::{ ack::Ack, builder::ClientBuilder, @@ -20,6 +9,16 @@ use crate::{ packet::{Packet, PacketId}, Event, Payload, }; +use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; +use futures_util::{future::BoxFuture, stream, Stream, StreamExt}; +use log::trace; +use rand::{thread_rng, Rng}; +use serde_json::Value; +use std::{ops::DerefMut, pin::Pin, sync::Arc}; +use tokio::{ + sync::RwLock, + time::{sleep, Duration, Instant}, +}; #[derive(Default)] enum DisconnectReason { @@ -113,9 +112,10 @@ impl Client { /// # Example /// /// ```no_run + /// use futures_util::FutureExt; /// use std::sync::{Arc, mpsc}; /// use rust_socketio::{ - /// asynchronous::{Client, ClientBuilder}, + /// asynchronous::Client, /// Payload, /// }; /// @@ -125,26 +125,28 @@ impl Client { /// Payload::Text(values) => { /// if let Some(value) = values.first() { /// if value.is_string() { - /// let result = socket.try_transmitter::>(); - /// - /// result - /// .map(|transmitter| { - /// transmitter.send(String::from(value.as_str().unwrap())) - /// }) - /// .map_err(|err| eprintln!("{}", err)) - /// .ok(); + /// socket.try_transmitter::>().map_or_else( + /// |err| eprintln!("{}", err), + /// |tx| { + /// tx.send(String::from(value.as_str().unwrap())) + /// .map_or_else( + /// |err| eprintln!("{}", err), + /// |_| println!("Data transmitted successfully"), + /// ); + /// }, + /// ); /// } /// } /// } - /// Payload::Binary(_bin_data) => println!(), + /// Payload::Binary(bin_data) => println!("{:#?}", bin_data), /// #[allow(deprecated)] /// Payload::String(str) => println!("Received: {}", str), /// } /// } /// .boxed() - /// }) + /// }; /// ``` - pub fn try_transitter(&self) -> Result> { + pub fn try_transmitter(&self) -> Result> { match Arc::clone(&self.transmitter).downcast() { Ok(data) => Ok(data), Err(_) => Err(Error::TransmitterTypeResolutionFailure), @@ -644,7 +646,7 @@ mod test { use crate::{ asynchronous::{ - client::{builder::ClientBuilder, client::Client}, + client::{async_client::Client, builder::ClientBuilder}, ReconnectSettings, }, error::Result, diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 3d3e5859..63fb9462 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -1,8 +1,8 @@ use super::{ + async_client::{Client, ReconnectSettings}, callback::{ Callback, DynAsyncAnyCallback, DynAsyncCallback, DynAsyncReconnectSettingsCallback, }, - client::{Client, ReconnectSettings}, }; use crate::asynchronous::socket::Socket as InnerSocket; use crate::{error::Result, Event, Payload, TransportType}; @@ -107,38 +107,27 @@ impl ClientBuilder { /// # Example /// /// ```no_run + /// use futures_util::FutureExt; + /// use std::sync::{Arc, mpsc}; + /// use rust_socketio::{ + /// asynchronous::{Client , ClientBuilder}, + /// Payload, Error, + /// }; + /// + /// async fn connect(url: &str) -> Result<(), Error> { + /// let (sender, receiver) = mpsc::channel::(); /// - /// let (sender, receiver) = mpsc::channel::(); - /// let client = ClientBuilder::new(url) - /// .namespace("/admin") - /// .on("test", |payload: Payload, socket: SocketIOClient| { - /// async move { - /// match payload { - /// Payload::Text(values) => { - /// if let Some(value) = values.first() { - /// if value.is_string() { - /// let result = socket.try_transitter::>(); - /// - /// result - /// .map(|transmitter| { - /// transmitter.send(String::from(value.as_str().unwrap())) - /// }) - /// .map_err(|err| eprintln!("{}", err)) - /// .ok(); - /// } - /// } - /// } - /// Payload::Binary(_bin_data) => println!(), - /// #[allow(deprecated)] - /// Payload::String(str) => println!("Received: {}", str), - /// } - /// } - /// .boxed() - /// }) - /// .transmitter(Arc::new(sender)) - /// .connect() - /// .await - /// .expect("Connection failed"); + /// let client = ClientBuilder::new(url) + /// .namespace("/admin") + /// .on("error", |err, _| { + /// async move { eprintln!("Error: {:#?}", err) }.boxed() + /// }) + /// .transmitter(Arc::new(sender)) + /// .connect() + /// .await?; + /// + /// Ok(()) + /// } /// ``` pub fn transmitter(mut self, data: Arc) -> Self { self.transmitter = Some(data); diff --git a/socketio/src/asynchronous/client/callback.rs b/socketio/src/asynchronous/client/callback.rs index 3188b175..9e49ecc5 100644 --- a/socketio/src/asynchronous/client/callback.rs +++ b/socketio/src/asynchronous/client/callback.rs @@ -1,13 +1,11 @@ +use super::async_client::{Client, ReconnectSettings}; +use crate::{Event, Payload}; use futures_util::future::BoxFuture; use std::{ fmt::Debug, ops::{Deref, DerefMut}, }; -use crate::{Event, Payload}; - -use super::client::{Client, ReconnectSettings}; - /// Internal type, provides a way to store futures and return them in a boxed manner. pub(crate) type DynAsyncCallback = Box FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync>; diff --git a/socketio/src/asynchronous/client/mod.rs b/socketio/src/asynchronous/client/mod.rs index bbf7cc92..bba5bd9c 100644 --- a/socketio/src/asynchronous/client/mod.rs +++ b/socketio/src/asynchronous/client/mod.rs @@ -1,5 +1,5 @@ mod ack; +pub(crate) mod async_client; pub(crate) mod builder; #[cfg(feature = "async-callbacks")] mod callback; -pub(crate) mod client; diff --git a/socketio/src/asynchronous/mod.rs b/socketio/src/asynchronous/mod.rs index e57cdf6a..8445c4cd 100644 --- a/socketio/src/asynchronous/mod.rs +++ b/socketio/src/asynchronous/mod.rs @@ -2,9 +2,9 @@ mod client; mod generator; mod socket; +pub use client::async_client::{Client, ReconnectSettings}; #[cfg(feature = "async")] pub use client::builder::ClientBuilder; -pub use client::client::{Client, ReconnectSettings}; // re-export the macro pub use crate::{async_any_callback, async_callback}; diff --git a/socketio/src/client/builder.rs b/socketio/src/client/builder.rs index 724971f0..2f19cedb 100644 --- a/socketio/src/client/builder.rs +++ b/socketio/src/client/builder.rs @@ -1,18 +1,16 @@ use super::super::{event::Event, payload::Payload}; use super::callback::Callback; -use super::client::Client; +use crate::client::callback::{SocketAnyCallback, SocketCallback}; +use crate::client::Client; +use crate::error::Result; +use crate::socket::Socket as InnerSocket; use crate::RawClient; use native_tls::TlsConnector; use rust_engineio::client::ClientBuilder as EngineIoClientBuilder; use rust_engineio::header::{HeaderMap, HeaderValue}; -use url::Url; - -use crate::client::callback::{SocketAnyCallback, SocketCallback}; -use crate::error::Result; use std::collections::HashMap; use std::sync::{Arc, Mutex}; - -use crate::socket::Socket as InnerSocket; +use url::Url; /// Flavor of Engine.IO transport. #[derive(Clone, Eq, PartialEq)] diff --git a/socketio/src/client/client.rs b/socketio/src/client/client.rs deleted file mode 100644 index fe924307..00000000 --- a/socketio/src/client/client.rs +++ /dev/null @@ -1,478 +0,0 @@ -use std::{ - sync::{Arc, Mutex, RwLock}, - time::Duration, -}; - -use super::{ClientBuilder, RawClient}; -use crate::{ - error::Result, - packet::{Packet, PacketId}, - Error, -}; -pub(crate) use crate::{event::Event, payload::Payload}; -use backoff::ExponentialBackoff; -use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; - -#[derive(Clone)] -pub struct Client { - builder: Arc>, - client: Arc>, - backoff: ExponentialBackoff, -} - -impl Client { - pub(crate) fn new(builder: ClientBuilder) -> Result { - let builder_clone = builder.clone(); - let client = builder_clone.connect_raw()?; - let backoff = ExponentialBackoffBuilder::new() - .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min)) - .with_max_interval(Duration::from_millis(builder.reconnect_delay_max)) - .build(); - - let s = Self { - builder: Arc::new(Mutex::new(builder)), - client: Arc::new(RwLock::new(client)), - backoff, - }; - s.poll_callback(); - - Ok(s) - } - - /// Updates the URL the client will connect to when reconnecting. - /// This is especially useful for updating query parameters. - pub fn set_reconnect_url>(&self, address: T) -> Result<()> { - self.builder.lock()?.address = address.into(); - Ok(()) - } - - /// Sends a message to the server using the underlying `engine.io` protocol. - /// This message takes an event, which could either be one of the common - /// events like "message" or "error" or a custom event like "foo". But be - /// careful, the data string needs to be valid JSON. It's recommended to use - /// a library like `serde_json` to serialize the data properly. - /// - /// # Example - /// ``` - /// use rust_socketio::{ClientBuilder, RawClient, Payload}; - /// use serde_json::json; - /// - /// let mut socket = ClientBuilder::new("http://localhost:4200/") - /// .on("test", |payload: Payload, socket: RawClient| { - /// println!("Received: {:#?}", payload); - /// socket.emit("test", json!({"hello": true})).expect("Server unreachable"); - /// }) - /// .connect() - /// .expect("connection failed"); - /// - /// let json_payload = json!({"token": 123}); - /// - /// let result = socket.emit("foo", json_payload); - /// - /// assert!(result.is_ok()); - /// ``` - pub fn emit(&self, event: E, data: D) -> Result<()> - where - E: Into, - D: Into, - { - let client = self.client.read()?; - // TODO(#230): like js client, buffer emit, resend after reconnect - client.emit(event, data) - } - - /// Sends a message to the server but `alloc`s an `ack` to check whether the - /// server responded in a given time span. This message takes an event, which - /// could either be one of the common events like "message" or "error" or a - /// custom event like "foo", as well as a data parameter. But be careful, - /// in case you send a [`Payload::String`], the string needs to be valid JSON. - /// It's even recommended to use a library like serde_json to serialize the data properly. - /// It also requires a timeout `Duration` in which the client needs to answer. - /// If the ack is acked in the correct time span, the specified callback is - /// called. The callback consumes a [`Payload`] which represents the data send - /// by the server. - /// - /// # Example - /// ``` - /// use rust_socketio::{ClientBuilder, Payload, RawClient}; - /// use serde_json::json; - /// use std::time::Duration; - /// use std::thread::sleep; - /// - /// let mut socket = ClientBuilder::new("http://localhost:4200/") - /// .on("foo", |payload: Payload, _| println!("Received: {:#?}", payload)) - /// .connect() - /// .expect("connection failed"); - /// - /// let ack_callback = |message: Payload, socket: RawClient| { - /// match message { - /// Payload::Text(values) => println!("{:#?}", values), - /// Payload::Binary(bytes) => println!("Received bytes: {:#?}", bytes), - /// // This is deprecated, use Payload::Text instead. - /// Payload::String(str) => println!("{}", str), - /// } - /// }; - /// - /// let payload = json!({"token": 123}); - /// socket.emit_with_ack("foo", payload, Duration::from_secs(2), ack_callback).unwrap(); - /// - /// sleep(Duration::from_secs(2)); - /// ``` - pub fn emit_with_ack( - &self, - event: E, - data: D, - timeout: Duration, - callback: F, - ) -> Result<()> - where - F: FnMut(Payload, RawClient) + 'static + Send, - E: Into, - D: Into, - { - let client = self.client.read()?; - // TODO(#230): like js client, buffer emit, resend after reconnect - client.emit_with_ack(event, data, timeout, callback) - } - - /// Disconnects this client from the server by sending a `socket.io` closing - /// packet. - /// # Example - /// ```rust - /// use rust_socketio::{ClientBuilder, Payload, RawClient}; - /// use serde_json::json; - /// - /// fn handle_test(payload: Payload, socket: RawClient) { - /// println!("Received: {:#?}", payload); - /// socket.emit("test", json!({"hello": true})).expect("Server unreachable"); - /// } - /// - /// let mut socket = ClientBuilder::new("http://localhost:4200/") - /// .on("test", handle_test) - /// .connect() - /// .expect("connection failed"); - /// - /// let json_payload = json!({"token": 123}); - /// - /// socket.emit("foo", json_payload); - /// - /// // disconnect from the server - /// socket.disconnect(); - /// - /// ``` - pub fn disconnect(&self) -> Result<()> { - let client = self.client.read()?; - client.disconnect() - } - - fn reconnect(&mut self) -> Result<()> { - let mut reconnect_attempts = 0; - let (reconnect, max_reconnect_attempts) = { - let builder = self.builder.lock()?; - (builder.reconnect, builder.max_reconnect_attempts) - }; - - if reconnect { - loop { - if let Some(max_reconnect_attempts) = max_reconnect_attempts { - reconnect_attempts += 1; - if reconnect_attempts > max_reconnect_attempts { - break; - } - } - - if let Some(backoff) = self.backoff.next_backoff() { - std::thread::sleep(backoff); - } - - if self.do_reconnect().is_ok() { - break; - } - } - } - - Ok(()) - } - - fn do_reconnect(&self) -> Result<()> { - let builder = self.builder.lock()?; - let new_client = builder.clone().connect_raw()?; - let mut client = self.client.write()?; - *client = new_client; - - Ok(()) - } - - pub(crate) fn iter(&self) -> Iter { - Iter { - socket: self.client.clone(), - } - } - - fn poll_callback(&self) { - let mut self_clone = self.clone(); - // Use thread to consume items in iterator in order to call callbacks - std::thread::spawn(move || { - // tries to restart a poll cycle whenever a 'normal' error occurs, - // it just panics on network errors, in case the poll cycle returned - // `Result::Ok`, the server receives a close frame so it's safe to - // terminate - for packet in self_clone.iter() { - let should_reconnect = match packet { - Err(Error::IncompleteResponseFromEngineIo(_)) => { - //TODO: 0.3.X handle errors - //TODO: logging error - true - } - Ok(Packet { - packet_type: PacketId::Disconnect, - .. - }) => match self_clone.builder.lock() { - Ok(builder) => builder.reconnect_on_disconnect, - Err(_) => false, - }, - _ => false, - }; - if should_reconnect { - let _ = self_clone.disconnect(); - let _ = self_clone.reconnect(); - } - } - }); - } -} - -pub(crate) struct Iter { - socket: Arc>, -} - -impl Iterator for Iter { - type Item = Result; - - fn next(&mut self) -> Option { - let socket = self.socket.read(); - match socket { - Ok(socket) => match socket.poll() { - Err(err) => Some(Err(err)), - Ok(Some(packet)) => Some(Ok(packet)), - // If the underlying engineIO connection is closed, - // throw an error so we know to reconnect - Ok(None) => Some(Err(Error::StoppedEngineIoSocket)), - }, - Err(_) => { - // Lock is poisoned, our iterator is useless. - None - } - } - } -} - -#[cfg(test)] -mod test { - use std::{ - sync::atomic::{AtomicUsize, Ordering}, - time::UNIX_EPOCH, - }; - - use super::*; - use crate::error::Result; - use crate::ClientBuilder; - use serde_json::json; - use serial_test::serial; - use std::time::{Duration, SystemTime}; - use url::Url; - - #[test] - #[serial(reconnect)] - fn socket_io_reconnect_integration() -> Result<()> { - static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); - static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0); - static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); - - let url = crate::test::socket_io_restart_server(); - - let socket = ClientBuilder::new(url) - .reconnect(true) - .max_reconnect_attempts(100) - .reconnect_delay(100, 100) - .on(Event::Connect, move |_, socket| { - CONNECT_NUM.fetch_add(1, Ordering::Release); - let r = socket.emit_with_ack( - "message", - json!(""), - Duration::from_millis(100), - |_, _| {}, - ); - assert!(r.is_ok(), "should emit message success"); - }) - .on(Event::Close, move |_, _| { - CLOSE_NUM.fetch_add(1, Ordering::Release); - }) - .on("message", move |_, _socket| { - // test the iterator implementation and make sure there is a constant - // stream of packets, even when reconnecting - MESSAGE_NUM.fetch_add(1, Ordering::Release); - }) - .connect(); - - assert!(socket.is_ok(), "should connect success"); - let socket = socket.unwrap(); - - // waiting for server to emit message - std::thread::sleep(std::time::Duration::from_millis(500)); - - assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); - assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); - assert_eq!(load(&CLOSE_NUM), 0, "should not close"); - - let r = socket.emit("restart_server", json!("")); - assert!(r.is_ok(), "should emit restart success"); - - // waiting for server to restart - for _ in 0..10 { - std::thread::sleep(std::time::Duration::from_millis(400)); - if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 { - break; - } - } - - assert_eq!(load(&CONNECT_NUM), 2, "should connect twice"); - assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages"); - assert_eq!(load(&CLOSE_NUM), 1, "should close once"); - - socket.disconnect()?; - Ok(()) - } - - #[test] - fn socket_io_reconnect_url_auth_integration() -> Result<()> { - static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); - static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0); - static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); - - fn get_url() -> Url { - let timestamp = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis(); - let mut url = crate::test::socket_io_restart_url_auth_server(); - url.set_query(Some(&format!("timestamp={timestamp}"))); - url - } - - let socket = ClientBuilder::new(get_url()) - .reconnect(true) - .max_reconnect_attempts(100) - .reconnect_delay(100, 100) - .on(Event::Connect, move |_, socket| { - CONNECT_NUM.fetch_add(1, Ordering::Release); - let result = socket.emit_with_ack( - "message", - json!(""), - Duration::from_millis(100), - |_, _| {}, - ); - assert!(result.is_ok(), "should emit message success"); - }) - .on(Event::Close, move |_, _| { - CLOSE_NUM.fetch_add(1, Ordering::Release); - }) - .on("message", move |_, _| { - // test the iterator implementation and make sure there is a constant - // stream of packets, even when reconnecting - MESSAGE_NUM.fetch_add(1, Ordering::Release); - }) - .connect(); - - assert!(socket.is_ok(), "should connect success"); - let socket = socket.unwrap(); - - // waiting for server to emit message - std::thread::sleep(std::time::Duration::from_millis(500)); - - assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); - assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); - assert_eq!(load(&CLOSE_NUM), 0, "should not close"); - - // waiting for timestamp in url to expire - std::thread::sleep(std::time::Duration::from_secs(1)); - - socket.set_reconnect_url(get_url())?; - - let result = socket.emit("restart_server", json!("")); - assert!(result.is_ok(), "should emit restart success"); - - // waiting for server to restart - for _ in 0..10 { - std::thread::sleep(std::time::Duration::from_millis(400)); - if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 { - break; - } - } - - assert_eq!(load(&CONNECT_NUM), 2, "should connect twice"); - assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages"); - assert_eq!(load(&CLOSE_NUM), 1, "should close once"); - - socket.disconnect()?; - Ok(()) - } - - #[test] - fn socket_io_iterator_integration() -> Result<()> { - let url = crate::test::socket_io_server(); - let builder = ClientBuilder::new(url); - let builder_clone = builder.clone(); - - let client = Arc::new(RwLock::new(builder_clone.connect_raw()?)); - let mut socket = Client { - builder: Arc::new(Mutex::new(builder)), - client, - backoff: Default::default(), - }; - let socket_clone = socket.clone(); - - let packets: Arc>> = Default::default(); - let packets_clone = packets.clone(); - - std::thread::spawn(move || { - for packet in socket_clone.iter() { - { - let mut packets = packets_clone.write().unwrap(); - if let Ok(packet) = packet { - (*packets).push(packet); - } - } - } - }); - - // waiting for client to emit messages - std::thread::sleep(Duration::from_millis(100)); - let lock = packets.read().unwrap(); - let pre_num = lock.len(); - drop(lock); - - let _ = socket.disconnect(); - socket.reconnect()?; - - // waiting for client to emit messages - std::thread::sleep(Duration::from_millis(100)); - - let lock = packets.read().unwrap(); - let post_num = lock.len(); - drop(lock); - - assert!( - pre_num < post_num, - "pre_num {} should less than post_num {}", - pre_num, - post_num - ); - - Ok(()) - } - - fn load(num: &AtomicUsize) -> usize { - num.load(Ordering::Acquire) - } -} diff --git a/socketio/src/client/mod.rs b/socketio/src/client/mod.rs index e3884b64..778db8f0 100644 --- a/socketio/src/client/mod.rs +++ b/socketio/src/client/mod.rs @@ -1,11 +1,484 @@ mod builder; +mod callback; mod raw_client; pub use builder::ClientBuilder; pub use builder::TransportType; -pub use client::Client; pub use raw_client::RawClient; -/// Internal callback type -mod callback; -mod client; +use crate::{ + error::Result, + packet::{Packet, PacketId}, + Error, +}; +pub(crate) use crate::{event::Event, payload::Payload}; +use backoff::ExponentialBackoff; +use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; +use std::{ + sync::{Arc, Mutex, RwLock}, + time::Duration, +}; + +#[derive(Clone)] +pub struct Client { + builder: Arc>, + client: Arc>, + backoff: ExponentialBackoff, +} + +impl Client { + pub(crate) fn new(builder: ClientBuilder) -> Result { + let builder_clone = builder.clone(); + let client = builder_clone.connect_raw()?; + let backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min)) + .with_max_interval(Duration::from_millis(builder.reconnect_delay_max)) + .build(); + + let s = Self { + builder: Arc::new(Mutex::new(builder)), + client: Arc::new(RwLock::new(client)), + backoff, + }; + s.poll_callback(); + + Ok(s) + } + + /// Updates the URL the client will connect to when reconnecting. + /// This is especially useful for updating query parameters. + pub fn set_reconnect_url>(&self, address: T) -> Result<()> { + self.builder.lock()?.address = address.into(); + Ok(()) + } + + /// Sends a message to the server using the underlying `engine.io` protocol. + /// This message takes an event, which could either be one of the common + /// events like "message" or "error" or a custom event like "foo". But be + /// careful, the data string needs to be valid JSON. It's recommended to use + /// a library like `serde_json` to serialize the data properly. + /// + /// # Example + /// ``` + /// use rust_socketio::{ClientBuilder, RawClient, Payload}; + /// use serde_json::json; + /// + /// let mut socket = ClientBuilder::new("http://localhost:4200/") + /// .on("test", |payload: Payload, socket: RawClient| { + /// println!("Received: {:#?}", payload); + /// socket.emit("test", json!({"hello": true})).expect("Server unreachable"); + /// }) + /// .connect() + /// .expect("connection failed"); + /// + /// let json_payload = json!({"token": 123}); + /// + /// let result = socket.emit("foo", json_payload); + /// + /// assert!(result.is_ok()); + /// ``` + pub fn emit(&self, event: E, data: D) -> Result<()> + where + E: Into, + D: Into, + { + let client = self.client.read()?; + // TODO(#230): like js client, buffer emit, resend after reconnect + client.emit(event, data) + } + + /// Sends a message to the server but `alloc`s an `ack` to check whether the + /// server responded in a given time span. This message takes an event, which + /// could either be one of the common events like "message" or "error" or a + /// custom event like "foo", as well as a data parameter. But be careful, + /// in case you send a [`Payload::String`], the string needs to be valid JSON. + /// It's even recommended to use a library like serde_json to serialize the data properly. + /// It also requires a timeout `Duration` in which the client needs to answer. + /// If the ack is acked in the correct time span, the specified callback is + /// called. The callback consumes a [`Payload`] which represents the data send + /// by the server. + /// + /// # Example + /// ``` + /// use rust_socketio::{ClientBuilder, Payload, RawClient}; + /// use serde_json::json; + /// use std::time::Duration; + /// use std::thread::sleep; + /// + /// let mut socket = ClientBuilder::new("http://localhost:4200/") + /// .on("foo", |payload: Payload, _| println!("Received: {:#?}", payload)) + /// .connect() + /// .expect("connection failed"); + /// + /// let ack_callback = |message: Payload, socket: RawClient| { + /// match message { + /// Payload::Text(values) => println!("{:#?}", values), + /// Payload::Binary(bytes) => println!("Received bytes: {:#?}", bytes), + /// // This is deprecated, use Payload::Text instead. + /// Payload::String(str) => println!("{}", str), + /// } + /// }; + /// + /// let payload = json!({"token": 123}); + /// socket.emit_with_ack("foo", payload, Duration::from_secs(2), ack_callback).unwrap(); + /// + /// sleep(Duration::from_secs(2)); + /// ``` + pub fn emit_with_ack( + &self, + event: E, + data: D, + timeout: Duration, + callback: F, + ) -> Result<()> + where + F: FnMut(Payload, RawClient) + 'static + Send, + E: Into, + D: Into, + { + let client = self.client.read()?; + // TODO(#230): like js client, buffer emit, resend after reconnect + client.emit_with_ack(event, data, timeout, callback) + } + + /// Disconnects this client from the server by sending a `socket.io` closing + /// packet. + /// # Example + /// ```rust + /// use rust_socketio::{ClientBuilder, Payload, RawClient}; + /// use serde_json::json; + /// + /// fn handle_test(payload: Payload, socket: RawClient) { + /// println!("Received: {:#?}", payload); + /// socket.emit("test", json!({"hello": true})).expect("Server unreachable"); + /// } + /// + /// let mut socket = ClientBuilder::new("http://localhost:4200/") + /// .on("test", handle_test) + /// .connect() + /// .expect("connection failed"); + /// + /// let json_payload = json!({"token": 123}); + /// + /// socket.emit("foo", json_payload); + /// + /// // disconnect from the server + /// socket.disconnect(); + /// + /// ``` + pub fn disconnect(&self) -> Result<()> { + let client = self.client.read()?; + client.disconnect() + } + + fn reconnect(&mut self) -> Result<()> { + let mut reconnect_attempts = 0; + let (reconnect, max_reconnect_attempts) = { + let builder = self.builder.lock()?; + (builder.reconnect, builder.max_reconnect_attempts) + }; + + if reconnect { + loop { + if let Some(max_reconnect_attempts) = max_reconnect_attempts { + reconnect_attempts += 1; + if reconnect_attempts > max_reconnect_attempts { + break; + } + } + + if let Some(backoff) = self.backoff.next_backoff() { + std::thread::sleep(backoff); + } + + if self.do_reconnect().is_ok() { + break; + } + } + } + + Ok(()) + } + + fn do_reconnect(&self) -> Result<()> { + let builder = self.builder.lock()?; + let new_client = builder.clone().connect_raw()?; + let mut client = self.client.write()?; + *client = new_client; + + Ok(()) + } + + pub(crate) fn iter(&self) -> Iter { + Iter { + socket: self.client.clone(), + } + } + + fn poll_callback(&self) { + let mut self_clone = self.clone(); + // Use thread to consume items in iterator in order to call callbacks + std::thread::spawn(move || { + // tries to restart a poll cycle whenever a 'normal' error occurs, + // it just panics on network errors, in case the poll cycle returned + // `Result::Ok`, the server receives a close frame so it's safe to + // terminate + for packet in self_clone.iter() { + let should_reconnect = match packet { + Err(Error::IncompleteResponseFromEngineIo(_)) => { + //TODO: 0.3.X handle errors + //TODO: logging error + true + } + Ok(Packet { + packet_type: PacketId::Disconnect, + .. + }) => match self_clone.builder.lock() { + Ok(builder) => builder.reconnect_on_disconnect, + Err(_) => false, + }, + _ => false, + }; + if should_reconnect { + let _ = self_clone.disconnect(); + let _ = self_clone.reconnect(); + } + } + }); + } +} + +pub(crate) struct Iter { + socket: Arc>, +} + +impl Iterator for Iter { + type Item = Result; + + fn next(&mut self) -> Option { + let socket = self.socket.read(); + match socket { + Ok(socket) => match socket.poll() { + Err(err) => Some(Err(err)), + Ok(Some(packet)) => Some(Ok(packet)), + // If the underlying engineIO connection is closed, + // throw an error so we know to reconnect + Ok(None) => Some(Err(Error::StoppedEngineIoSocket)), + }, + Err(_) => { + // Lock is poisoned, our iterator is useless. + None + } + } + } +} + +#[cfg(test)] +mod test { + use std::{ + sync::atomic::{AtomicUsize, Ordering}, + time::UNIX_EPOCH, + }; + + use super::*; + use crate::error::Result; + use crate::ClientBuilder; + use serde_json::json; + use serial_test::serial; + use std::time::{Duration, SystemTime}; + use url::Url; + + #[test] + #[serial(reconnect)] + fn socket_io_reconnect_integration() -> Result<()> { + static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); + static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0); + static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); + + let url = crate::test::socket_io_restart_server(); + + let socket = ClientBuilder::new(url) + .reconnect(true) + .max_reconnect_attempts(100) + .reconnect_delay(100, 100) + .on(Event::Connect, move |_, socket| { + CONNECT_NUM.fetch_add(1, Ordering::Release); + let r = socket.emit_with_ack( + "message", + json!(""), + Duration::from_millis(100), + |_, _| {}, + ); + assert!(r.is_ok(), "should emit message success"); + }) + .on(Event::Close, move |_, _| { + CLOSE_NUM.fetch_add(1, Ordering::Release); + }) + .on("message", move |_, _socket| { + // test the iterator implementation and make sure there is a constant + // stream of packets, even when reconnecting + MESSAGE_NUM.fetch_add(1, Ordering::Release); + }) + .connect(); + + assert!(socket.is_ok(), "should connect success"); + let socket = socket.unwrap(); + + // waiting for server to emit message + std::thread::sleep(std::time::Duration::from_millis(500)); + + assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); + assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); + assert_eq!(load(&CLOSE_NUM), 0, "should not close"); + + let r = socket.emit("restart_server", json!("")); + assert!(r.is_ok(), "should emit restart success"); + + // waiting for server to restart + for _ in 0..10 { + std::thread::sleep(std::time::Duration::from_millis(400)); + if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 { + break; + } + } + + assert_eq!(load(&CONNECT_NUM), 2, "should connect twice"); + assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages"); + assert_eq!(load(&CLOSE_NUM), 1, "should close once"); + + socket.disconnect()?; + Ok(()) + } + + #[test] + fn socket_io_reconnect_url_auth_integration() -> Result<()> { + static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); + static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0); + static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); + + fn get_url() -> Url { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis(); + let mut url = crate::test::socket_io_restart_url_auth_server(); + url.set_query(Some(&format!("timestamp={timestamp}"))); + url + } + + let socket = ClientBuilder::new(get_url()) + .reconnect(true) + .max_reconnect_attempts(100) + .reconnect_delay(100, 100) + .on(Event::Connect, move |_, socket| { + CONNECT_NUM.fetch_add(1, Ordering::Release); + let result = socket.emit_with_ack( + "message", + json!(""), + Duration::from_millis(100), + |_, _| {}, + ); + assert!(result.is_ok(), "should emit message success"); + }) + .on(Event::Close, move |_, _| { + CLOSE_NUM.fetch_add(1, Ordering::Release); + }) + .on("message", move |_, _| { + // test the iterator implementation and make sure there is a constant + // stream of packets, even when reconnecting + MESSAGE_NUM.fetch_add(1, Ordering::Release); + }) + .connect(); + + assert!(socket.is_ok(), "should connect success"); + let socket = socket.unwrap(); + + // waiting for server to emit message + std::thread::sleep(std::time::Duration::from_millis(500)); + + assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); + assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); + assert_eq!(load(&CLOSE_NUM), 0, "should not close"); + + // waiting for timestamp in url to expire + std::thread::sleep(std::time::Duration::from_secs(1)); + + socket.set_reconnect_url(get_url())?; + + let result = socket.emit("restart_server", json!("")); + assert!(result.is_ok(), "should emit restart success"); + + // waiting for server to restart + for _ in 0..10 { + std::thread::sleep(std::time::Duration::from_millis(400)); + if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 { + break; + } + } + + assert_eq!(load(&CONNECT_NUM), 2, "should connect twice"); + assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages"); + assert_eq!(load(&CLOSE_NUM), 1, "should close once"); + + socket.disconnect()?; + Ok(()) + } + + #[test] + fn socket_io_iterator_integration() -> Result<()> { + let url = crate::test::socket_io_server(); + let builder = ClientBuilder::new(url); + let builder_clone = builder.clone(); + + let client = Arc::new(RwLock::new(builder_clone.connect_raw()?)); + let mut socket = Client { + builder: Arc::new(Mutex::new(builder)), + client, + backoff: Default::default(), + }; + let socket_clone = socket.clone(); + + let packets: Arc>> = Default::default(); + let packets_clone = packets.clone(); + + std::thread::spawn(move || { + for packet in socket_clone.iter() { + { + let mut packets = packets_clone.write().unwrap(); + if let Ok(packet) = packet { + (*packets).push(packet); + } + } + } + }); + + // waiting for client to emit messages + std::thread::sleep(Duration::from_millis(100)); + let lock = packets.read().unwrap(); + let pre_num = lock.len(); + drop(lock); + + let _ = socket.disconnect(); + socket.reconnect()?; + + // waiting for client to emit messages + std::thread::sleep(Duration::from_millis(100)); + + let lock = packets.read().unwrap(); + let post_num = lock.len(); + drop(lock); + + assert!( + pre_num < post_num, + "pre_num {} should less than post_num {}", + pre_num, + post_num + ); + + Ok(()) + } + + fn load(num: &AtomicUsize) -> usize { + num.load(Ordering::Acquire) + } +} diff --git a/socketio/src/lib.rs b/socketio/src/lib.rs index b913eb4d..b0558e96 100644 --- a/socketio/src/lib.rs +++ b/socketio/src/lib.rs @@ -176,7 +176,7 @@ pub(crate) mod packet; /// Defines the types of payload (binary or string), that /// could be sent or received. pub mod payload; -pub(self) mod socket; +mod socket; /// Deprecated import since 0.3.0-alpha-2, use Error in the crate root instead. /// Contains the error type which will be returned with every result in this diff --git a/socketio/src/packet.rs b/socketio/src/packet.rs index e74dedb5..6b2629ed 100644 --- a/socketio/src/packet.rs +++ b/socketio/src/packet.rs @@ -2,7 +2,6 @@ use crate::error::{Error, Result}; use crate::{Event, Payload}; use bytes::Bytes; use serde::de::IgnoredAny; - use std::convert::TryFrom; use std::fmt::Write; use std::str::from_utf8 as str_from_utf8; @@ -34,10 +33,10 @@ impl Packet { /// Returns a packet for a payload, could be used for both binary and non binary /// events and acks. Convenience method. #[inline] - pub(crate) fn new_from_payload<'a>( + pub(crate) fn new_from_payload( payload: Payload, event: Event, - nsp: &'a str, + nsp: &str, id: Option, ) -> Result { match payload { @@ -216,7 +215,7 @@ impl TryFrom<&Bytes> for Packet { /// this member. This is done because the attachment is usually /// send in another packet. fn try_from(payload: &Bytes) -> Result { - let mut payload = str_from_utf8(&payload).map_err(Error::InvalidUtf8)?; + let mut payload = str_from_utf8(payload).map_err(Error::InvalidUtf8)?; let mut packet = Packet::default(); // packet_type From b386b0c4490e2d8274741edf6c1c891947385bfc Mon Sep 17 00:00:00 2001 From: Matt Williams Date: Tue, 14 May 2024 15:18:29 -0400 Subject: [PATCH 3/4] Add transmitter sender data layer for raw client --- socketio/Cargo.toml | 2 +- socketio/examples/readme.rs | 3 + socketio/examples/sync_transmitter.rs | 105 ++++++++++++++++++++ socketio/src/asynchronous/client/builder.rs | 4 +- socketio/src/client/builder.rs | 32 ++++++ socketio/src/client/raw_client.rs | 51 ++++++++++ 6 files changed, 194 insertions(+), 3 deletions(-) create mode 100644 socketio/examples/sync_transmitter.rs diff --git a/socketio/Cargo.toml b/socketio/Cargo.toml index 95ee7dfb..7c1c18e3 100644 --- a/socketio/Cargo.toml +++ b/socketio/Cargo.toml @@ -40,7 +40,7 @@ version = "1.37.0" features = ["macros", "rt-multi-thread"] [features] -default = ["async"] +default = [] async-callbacks = ["rust_engineio/async-callbacks"] async = ["async-callbacks", "rust_engineio/async", "tokio", "futures-util", "async-stream"] diff --git a/socketio/examples/readme.rs b/socketio/examples/readme.rs index 00878192..c8610211 100644 --- a/socketio/examples/readme.rs +++ b/socketio/examples/readme.rs @@ -1,5 +1,6 @@ use rust_socketio::{ClientBuilder, Payload, RawClient}; use serde_json::json; +use std::thread::sleep; use std::time::Duration; fn main() { @@ -45,5 +46,7 @@ fn main() { .emit_with_ack("test", json_payload, Duration::from_secs(2), ack_callback) .expect("Server unreachable"); + sleep(Duration::from_secs(2)); + socket.disconnect().expect("Disconnect failed") } diff --git a/socketio/examples/sync_transmitter.rs b/socketio/examples/sync_transmitter.rs new file mode 100644 index 00000000..7d5df54e --- /dev/null +++ b/socketio/examples/sync_transmitter.rs @@ -0,0 +1,105 @@ +use rust_socketio::{ + client::Client as SocketIOClient, ClientBuilder as SocketIOClientBuilder, + Error as SocketIOError, Payload, RawClient, +}; +use serde_json::json; +use std::sync::{mpsc, Arc}; +use std::thread::sleep; +use std::time::Duration; + +struct ComplexData { + /// There should be many more fields below in real life, + /// probaly wrapped in Arc> if you're writing a more serious client. + data: String, +} + +struct TransmitterClient { + client: SocketIOClient, + receiver: mpsc::Receiver, + complex: ComplexData, +} + +impl TransmitterClient { + fn connect(url: &str) -> Result { + let (sender, receiver) = mpsc::channel::(); + + let client = SocketIOClientBuilder::new(url) + .namespace("/admin") + .on( + "test", + |payload: Payload, socket: RawClient| match payload { + Payload::Text(values) => { + if let Some(value) = values.first() { + if value.is_string() { + socket + .try_transmitter::>() + .map_or_else( + |err| eprintln!("{}", err), + |tx| { + tx.send(String::from(value.as_str().unwrap())) + .map_or_else( + |err| eprintln!("{}", err), + |_| println!("Data transmitted successfully"), + ); + }, + ); + } + } + } + Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data), + #[allow(deprecated)] + Payload::String(str) => println!("Received: {}", str), + }, + ) + .on("error", |err, _| eprintln!("Error: {:#?}", err)) + .transmitter(Arc::new(sender)) + .connect()?; + + Ok(Self { + client, + receiver, + complex: ComplexData { + data: "".to_string(), + }, + }) + } + + fn get_test(&mut self) -> Option { + match self.client.emit("test", json!({"got ack": true})) { + Ok(_) => { + match self.receiver.recv() { + Ok(complex_data) => { + // In the real world the data is probably a serialized json_rpc object + // or some other complex data layer which needs complex business and derserialization logic. + // Best to do that here, and not inside those restrictive callbacks. + self.complex.data = complex_data; + Some(self.complex.data.clone()) + } + Err(err) => { + eprintln!("Transmission buffer is probably full: {}", err); + None + } + } + } + Err(err) => { + eprintln!("Server unreachable: {}", err); + None + } + } + } +} + +fn main() { + match TransmitterClient::connect("http://localhost:4200/") { + Ok(mut client) => { + if let Some(test_data) = client.get_test() { + println!("test event data from internal transmitter: {}", test_data); + } + } + Err(err) => { + eprintln!("Failed to connect to server: {}", err); + } + } + + sleep(Duration::from_secs(2)); +} diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 63fb9462..553cbf62 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -114,7 +114,7 @@ impl ClientBuilder { /// Payload, Error, /// }; /// - /// async fn connect(url: &str) -> Result<(), Error> { + /// async fn connect(url: &str) -> Result { /// let (sender, receiver) = mpsc::channel::(); /// /// let client = ClientBuilder::new(url) @@ -126,7 +126,7 @@ impl ClientBuilder { /// .connect() /// .await?; /// - /// Ok(()) + /// Ok(client) /// } /// ``` pub fn transmitter(mut self, data: Arc) -> Self { diff --git a/socketio/src/client/builder.rs b/socketio/src/client/builder.rs index 2f19cedb..16c20b21 100644 --- a/socketio/src/client/builder.rs +++ b/socketio/src/client/builder.rs @@ -39,6 +39,7 @@ pub struct ClientBuilder { opening_headers: Option, transport_type: TransportType, auth: Option, + transmitter: Option>, pub(crate) reconnect: bool, pub(crate) reconnect_on_disconnect: bool, // None reconnect attempts represent infinity. @@ -90,6 +91,7 @@ impl ClientBuilder { opening_headers: None, transport_type: TransportType::Any, auth: None, + transmitter: None, reconnect: true, reconnect_on_disconnect: false, // None means infinity @@ -99,6 +101,35 @@ impl ClientBuilder { } } + /// Sets the data transmission object, ideally the standard libraries + /// multi-producer single consumer [`std::sync::mpsc::Sender`] should be used. + /// + /// ```no_run + /// use rust_socketio::{ + /// client::Client, ClientBuilder, + /// Error , Payload, RawClient, + /// }; + /// use std::sync::{Arc, mpsc}; + /// + /// fn connect(url: &str) -> Result { + /// let (sender, receiver) = mpsc::channel::(); + /// + /// let client = ClientBuilder::new(url) + /// .namespace("/admin") + /// .on("error", |err, _| { + /// eprintln!("Error: {:#?}", err); + /// }) + /// .transmitter(Arc::new(sender)) + /// .connect()?; + /// + /// Ok(client) + /// } + /// ``` + pub fn transmitter(mut self, data: Arc) -> Self { + self.transmitter = Some(data); + self + } + /// Sets the target namespace of the client. The namespace should start /// with a leading `/`. Valid examples are e.g. `/admin`, `/foo`. pub fn namespace>(mut self, namespace: T) -> Self { @@ -363,6 +394,7 @@ impl ClientBuilder { self.on, self.on_any, self.auth, + self.transmitter.unwrap_or(Arc::new(())), )?; socket.connect()?; diff --git a/socketio/src/client/raw_client.rs b/socketio/src/client/raw_client.rs index 0686683f..f1f30d0f 100644 --- a/socketio/src/client/raw_client.rs +++ b/socketio/src/client/raw_client.rs @@ -41,6 +41,7 @@ pub struct RawClient { nsp: String, // Data send in the opening packet (commonly used as for auth) auth: Option, + transmitter: Arc, } impl RawClient { @@ -54,6 +55,7 @@ impl RawClient { on: Arc>>>, on_any: Arc>>>, auth: Option, + transmitter: Arc, ) -> Result { Ok(RawClient { socket, @@ -62,9 +64,58 @@ impl RawClient { on_any, outstanding_acks: Arc::new(Mutex::new(Vec::new())), auth, + transmitter, }) } + /// Attempts to retrieve the transmitted data of type `D` from the transmitter. + /// + /// This function clones the transmitter and attempts to downcast it to an `Arc`. + /// If the downcast is successful, it returns the cloned data wrapped in a `Result`. + /// If the downcast fails, indicating that the transmitter contains data of an incompatible type, + /// it returns an `Err` with an `Error::TransmitterTypeResolutionFailure`. + /// + /// # Example + /// + /// ```no_run + /// use rust_socketio::{ + /// client::Client, ClientBuilder, + /// Error , Payload, RawClient, + /// }; + /// use std::sync::mpsc; + /// + /// + /// let callback = |payload: Payload, socket: RawClient| { + /// match payload { + /// Payload::Text(values) => { + /// if let Some(value) = values.first() { + /// if value.is_string() { + /// socket.try_transmitter::>().map_or_else( + /// |err| eprintln!("{}", err), + /// |tx| { + /// tx.send(String::from(value.as_str().unwrap())) + /// .map_or_else( + /// |err| eprintln!("{}", err), + /// |_| println!("Data transmitted successfully"), + /// ); + /// }, + /// ); + /// } + /// } + /// } + /// Payload::Binary(bin_data) => println!("{:#?}", bin_data), + /// #[allow(deprecated)] + /// Payload::String(str) => println!("Received: {}", str), + /// } + /// }; + /// ``` + pub fn try_transmitter(&self) -> Result> { + match Arc::clone(&self.transmitter).downcast() { + Ok(data) => Ok(data), + Err(_) => Err(Error::TransmitterTypeResolutionFailure), + } + } + /// Connects the client to a server. Afterwards the `emit_*` methods can be /// called to interact with the server. Attention: it's not allowed to add a /// callback after a call to this method. From 82d5703d9288862d9ba62284bb0c986335367f16 Mon Sep 17 00:00:00 2001 From: Matt Williams Date: Thu, 16 May 2024 11:20:34 -0400 Subject: [PATCH 4/4] Refactor example code and documentation. Put client.rs modules back, and explicitly allow the clippy warnings. --- .gitignore | 1 - engineio/src/client/client.rs | 661 +++++++++++++++++ engineio/src/client/mod.rs | 665 +----------------- socketio/examples/async_transmitter.rs | 88 ++- socketio/examples/sync_transmitter.rs | 73 +- .../src/asynchronous/client/async_client.rs | 36 +- socketio/src/asynchronous/client/builder.rs | 7 +- socketio/src/client/builder.rs | 8 +- socketio/src/client/client.rs | 478 +++++++++++++ socketio/src/client/mod.rs | 482 +------------ socketio/src/client/raw_client.rs | 41 +- 11 files changed, 1264 insertions(+), 1276 deletions(-) create mode 100644 engineio/src/client/client.rs create mode 100644 socketio/src/client/client.rs diff --git a/.gitignore b/.gitignore index f41165ab..22e6dfbc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ target -target_ra .vscode .idea ci/node_modules diff --git a/engineio/src/client/client.rs b/engineio/src/client/client.rs new file mode 100644 index 00000000..69fdd8df --- /dev/null +++ b/engineio/src/client/client.rs @@ -0,0 +1,661 @@ +use super::super::socket::Socket as InnerSocket; +use crate::callback::OptionalCallback; +use crate::error::{Error, Result}; +use crate::header::HeaderMap; +use crate::packet::{HandshakePacket, Packet, PacketId}; +use crate::socket::DEFAULT_MAX_POLL_TIMEOUT; +use crate::transport::Transport; +use crate::transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport}; +use crate::ENGINE_IO_VERSION; +use bytes::Bytes; +use native_tls::TlsConnector; +use std::convert::TryFrom; +use std::convert::TryInto; +use std::fmt::Debug; +use url::Url; + +/// An engine.io client that allows interaction with the connected engine.io +/// server. This client provides means for connecting, disconnecting and sending +/// packets to the server. +/// +/// ## Note: +/// There is no need to put this Client behind an `Arc`, as the type uses `Arc` +/// internally and provides a shared state beyond all cloned instances. +#[derive(Clone, Debug)] +pub struct Client { + socket: InnerSocket, +} + +#[derive(Clone, Debug)] +pub struct ClientBuilder { + url: Url, + tls_config: Option, + headers: Option, + handshake: Option, + on_error: OptionalCallback, + on_open: OptionalCallback<()>, + on_close: OptionalCallback<()>, + on_data: OptionalCallback, + on_packet: OptionalCallback, +} + +impl ClientBuilder { + pub fn new(url: Url) -> Self { + let mut url = url; + url.query_pairs_mut() + .append_pair("EIO", &ENGINE_IO_VERSION.to_string()); + + // No path add engine.io + if url.path() == "/" { + url.set_path("/engine.io/"); + } + ClientBuilder { + url, + headers: None, + tls_config: None, + handshake: None, + on_close: OptionalCallback::default(), + on_data: OptionalCallback::default(), + on_error: OptionalCallback::default(), + on_open: OptionalCallback::default(), + on_packet: OptionalCallback::default(), + } + } + + /// Specify transport's tls config + pub fn tls_config(mut self, tls_config: TlsConnector) -> Self { + self.tls_config = Some(tls_config); + self + } + + /// Specify transport's HTTP headers + pub fn headers(mut self, headers: HeaderMap) -> Self { + self.headers = Some(headers); + self + } + + /// Registers the `on_close` callback. + pub fn on_close(mut self, callback: T) -> Self + where + T: Fn(()) + 'static + Sync + Send, + { + self.on_close = OptionalCallback::new(callback); + self + } + + /// Registers the `on_data` callback. + pub fn on_data(mut self, callback: T) -> Self + where + T: Fn(Bytes) + 'static + Sync + Send, + { + self.on_data = OptionalCallback::new(callback); + self + } + + /// Registers the `on_error` callback. + pub fn on_error(mut self, callback: T) -> Self + where + T: Fn(String) + 'static + Sync + Send, + { + self.on_error = OptionalCallback::new(callback); + self + } + + /// Registers the `on_open` callback. + pub fn on_open(mut self, callback: T) -> Self + where + T: Fn(()) + 'static + Sync + Send, + { + self.on_open = OptionalCallback::new(callback); + self + } + + /// Registers the `on_packet` callback. + pub fn on_packet(mut self, callback: T) -> Self + where + T: Fn(Packet) + 'static + Sync + Send, + { + self.on_packet = OptionalCallback::new(callback); + self + } + + /// Performs the handshake + fn handshake_with_transport(&mut self, transport: &T) -> Result<()> { + // No need to handshake twice + if self.handshake.is_some() { + return Ok(()); + } + + let mut url = self.url.clone(); + + let handshake: HandshakePacket = + Packet::try_from(transport.poll(DEFAULT_MAX_POLL_TIMEOUT)?)?.try_into()?; + + // update the base_url with the new sid + url.query_pairs_mut().append_pair("sid", &handshake.sid[..]); + + self.handshake = Some(handshake); + + self.url = url; + + Ok(()) + } + + fn handshake(&mut self) -> Result<()> { + if self.handshake.is_some() { + return Ok(()); + } + + // Start with polling transport + let transport = PollingTransport::new( + self.url.clone(), + self.tls_config.clone(), + self.headers.clone().map(|v| v.try_into().unwrap()), + ); + + self.handshake_with_transport(&transport) + } + + /// Build websocket if allowed, if not fall back to polling + pub fn build(mut self) -> Result { + self.handshake()?; + + if self.websocket_upgrade()? { + self.build_websocket_with_upgrade() + } else { + self.build_polling() + } + } + + /// Build socket with polling transport + pub fn build_polling(mut self) -> Result { + self.handshake()?; + + // Make a polling transport with new sid + let transport = PollingTransport::new( + self.url, + self.tls_config, + self.headers.map(|v| v.try_into().unwrap()), + ); + + // SAFETY: handshake function called previously. + Ok(Client { + socket: InnerSocket::new( + transport.into(), + self.handshake.unwrap(), + self.on_close, + self.on_data, + self.on_error, + self.on_open, + self.on_packet, + ), + }) + } + + /// Build socket with a polling transport then upgrade to websocket transport + pub fn build_websocket_with_upgrade(mut self) -> Result { + self.handshake()?; + + if self.websocket_upgrade()? { + self.build_websocket() + } else { + Err(Error::IllegalWebsocketUpgrade()) + } + } + + /// Build socket with only a websocket transport + pub fn build_websocket(mut self) -> Result { + // SAFETY: Already a Url + let url = url::Url::parse(self.url.as_ref())?; + + let headers: Option = if let Some(map) = self.headers.clone() { + Some(map.try_into()?) + } else { + None + }; + + match url.scheme() { + "http" | "ws" => { + let transport = WebsocketTransport::new(url, headers)?; + if self.handshake.is_some() { + transport.upgrade()?; + } else { + self.handshake_with_transport(&transport)?; + } + // NOTE: Although self.url contains the sid, it does not propagate to the transport + // SAFETY: handshake function called previously. + Ok(Client { + socket: InnerSocket::new( + transport.into(), + self.handshake.unwrap(), + self.on_close, + self.on_data, + self.on_error, + self.on_open, + self.on_packet, + ), + }) + } + "https" | "wss" => { + let transport = + WebsocketSecureTransport::new(url, self.tls_config.clone(), headers)?; + if self.handshake.is_some() { + transport.upgrade()?; + } else { + self.handshake_with_transport(&transport)?; + } + // NOTE: Although self.url contains the sid, it does not propagate to the transport + // SAFETY: handshake function called previously. + Ok(Client { + socket: InnerSocket::new( + transport.into(), + self.handshake.unwrap(), + self.on_close, + self.on_data, + self.on_error, + self.on_open, + self.on_packet, + ), + }) + } + _ => Err(Error::InvalidUrlScheme(url.scheme().to_string())), + } + } + + /// Build websocket if allowed, if not allowed or errored fall back to polling. + /// WARNING: websocket errors suppressed, no indication of websocket success or failure. + pub fn build_with_fallback(self) -> Result { + let result = self.clone().build(); + if result.is_err() { + self.build_polling() + } else { + result + } + } + + /// Checks the handshake to see if websocket upgrades are allowed + fn websocket_upgrade(&mut self) -> Result { + // SAFETY: handshake set by above function. + Ok(self + .handshake + .as_ref() + .unwrap() + .upgrades + .iter() + .any(|upgrade| upgrade.to_lowercase() == *"websocket")) + } +} + +impl Client { + pub fn close(&self) -> Result<()> { + self.socket.disconnect() + } + + /// Opens the connection to a specified server. The first Pong packet is sent + /// to the server to trigger the Ping-cycle. + pub fn connect(&self) -> Result<()> { + self.socket.connect() + } + + /// Disconnects the connection. + pub fn disconnect(&self) -> Result<()> { + self.socket.disconnect() + } + + /// Sends a packet to the server. + pub fn emit(&self, packet: Packet) -> Result<()> { + self.socket.emit(packet) + } + + /// Polls for next payload + #[doc(hidden)] + pub fn poll(&self) -> Result> { + let packet = self.socket.poll()?; + if let Some(packet) = packet { + // check for the appropriate action or callback + self.socket.handle_packet(packet.clone()); + match packet.packet_id { + PacketId::MessageBinary => { + self.socket.handle_data(packet.data.clone()); + } + PacketId::Message => { + self.socket.handle_data(packet.data.clone()); + } + PacketId::Close => { + self.socket.handle_close(); + } + PacketId::Open => { + unreachable!("Won't happen as we open the connection beforehand"); + } + PacketId::Upgrade => { + // this is already checked during the handshake, so just do nothing here + } + PacketId::Ping => { + self.socket.pinged()?; + self.emit(Packet::new(PacketId::Pong, Bytes::new()))?; + } + PacketId::Pong => { + // this will never happen as the pong packet is + // only sent by the client + unreachable!(); + } + PacketId::Noop => (), + } + Ok(Some(packet)) + } else { + Ok(None) + } + } + + /// Check if the underlying transport client is connected. + pub fn is_connected(&self) -> Result { + self.socket.is_connected() + } + + pub fn iter(&self) -> Iter { + Iter { socket: self } + } +} + +#[derive(Clone)] +pub struct Iter<'a> { + socket: &'a Client, +} + +impl<'a> Iterator for Iter<'a> { + type Item = Result; + fn next(&mut self) -> std::option::Option<::Item> { + match self.socket.poll() { + Ok(Some(packet)) => Some(Ok(packet)), + Ok(None) => None, + Err(err) => Some(Err(err)), + } + } +} + +#[cfg(test)] +mod test { + + use crate::packet::PacketId; + + use super::*; + + /// The purpose of this test is to check whether the Client is properly cloneable or not. + /// As the documentation of the engine.io client states, the object needs to maintain it's internal + /// state when cloned and the cloned object should reflect the same state throughout the lifetime + /// of both objects (initial and cloned). + #[test] + fn test_client_cloneable() -> Result<()> { + let url = crate::test::engine_io_server()?; + let sut = builder(url).build()?; + + let cloned = sut.clone(); + + sut.connect()?; + + // when the underlying socket is connected, the + // state should also change on the cloned one + assert!(sut.is_connected()?); + assert!(cloned.is_connected()?); + + // both clients should reflect the same messages. + let mut iter = sut + .iter() + .map(|packet| packet.unwrap()) + .filter(|packet| packet.packet_id != PacketId::Ping); + + let mut iter_cloned = cloned + .iter() + .map(|packet| packet.unwrap()) + .filter(|packet| packet.packet_id != PacketId::Ping); + + assert_eq!( + iter.next(), + Some(Packet::new(PacketId::Message, "hello client")) + ); + + sut.emit(Packet::new(PacketId::Message, "respond"))?; + + assert_eq!( + iter_cloned.next(), + Some(Packet::new(PacketId::Message, "Roger Roger")) + ); + + cloned.disconnect()?; + + // when the underlying socket is disconnected, the + // state should also change on the cloned one + assert!(!sut.is_connected()?); + assert!(!cloned.is_connected()?); + + Ok(()) + } + + #[test] + fn test_illegal_actions() -> Result<()> { + let url = crate::test::engine_io_server()?; + let sut = builder(url.clone()).build()?; + + assert!(sut + .emit(Packet::new(PacketId::Close, Bytes::new())) + .is_err()); + + sut.connect()?; + + assert!(sut.poll().is_ok()); + + assert!(builder(Url::parse("fake://fake.fake").unwrap()) + .build_websocket() + .is_err()); + + Ok(()) + } + use reqwest::header::HOST; + + use crate::packet::Packet; + + fn builder(url: Url) -> ClientBuilder { + ClientBuilder::new(url) + .on_open(|_| { + println!("Open event!"); + }) + .on_packet(|packet| { + println!("Received packet: {:?}", packet); + }) + .on_data(|data| { + println!("Received data: {:?}", std::str::from_utf8(&data)); + }) + .on_close(|_| { + println!("Close event!"); + }) + .on_error(|error| { + println!("Error {}", error); + }) + } + + fn test_connection(socket: Client) -> Result<()> { + let socket = socket; + + socket.connect().unwrap(); + + // TODO: 0.3.X better tests + + let mut iter = socket + .iter() + .map(|packet| packet.unwrap()) + .filter(|packet| packet.packet_id != PacketId::Ping); + + assert_eq!( + iter.next(), + Some(Packet::new(PacketId::Message, "hello client")) + ); + + socket.emit(Packet::new(PacketId::Message, "respond"))?; + + assert_eq!( + iter.next(), + Some(Packet::new(PacketId::Message, "Roger Roger")) + ); + + socket.close() + } + + #[test] + fn test_connection_long() -> Result<()> { + // Long lived socket to receive pings + let url = crate::test::engine_io_server()?; + let socket = builder(url).build()?; + + socket.connect()?; + + let mut iter = socket.iter(); + // hello client + iter.next(); + // Ping + iter.next(); + + socket.disconnect()?; + + assert!(!socket.is_connected()?); + + Ok(()) + } + + #[test] + fn test_connection_dynamic() -> Result<()> { + let url = crate::test::engine_io_server()?; + let socket = builder(url).build()?; + test_connection(socket)?; + + let url = crate::test::engine_io_polling_server()?; + let socket = builder(url).build()?; + test_connection(socket) + } + + #[test] + fn test_connection_fallback() -> Result<()> { + let url = crate::test::engine_io_server()?; + let socket = builder(url).build_with_fallback()?; + test_connection(socket)?; + + let url = crate::test::engine_io_polling_server()?; + let socket = builder(url).build_with_fallback()?; + test_connection(socket) + } + + #[test] + fn test_connection_dynamic_secure() -> Result<()> { + let url = crate::test::engine_io_server_secure()?; + let mut builder = builder(url); + builder = builder.tls_config(crate::test::tls_connector()?); + let socket = builder.build()?; + test_connection(socket) + } + + #[test] + fn test_connection_polling() -> Result<()> { + let url = crate::test::engine_io_server()?; + let socket = builder(url).build_polling()?; + test_connection(socket) + } + + #[test] + fn test_connection_wss() -> Result<()> { + let url = crate::test::engine_io_polling_server()?; + assert!(builder(url).build_websocket_with_upgrade().is_err()); + + let host = + std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned()); + let mut url = crate::test::engine_io_server_secure()?; + + let mut headers = HeaderMap::default(); + headers.insert(HOST, host); + let mut builder = builder(url.clone()); + + builder = builder.tls_config(crate::test::tls_connector()?); + builder = builder.headers(headers.clone()); + let socket = builder.clone().build_websocket_with_upgrade()?; + + test_connection(socket)?; + + let socket = builder.build_websocket()?; + + test_connection(socket)?; + + url.set_scheme("wss").unwrap(); + + let builder = self::builder(url) + .tls_config(crate::test::tls_connector()?) + .headers(headers); + let socket = builder.clone().build_websocket()?; + + test_connection(socket)?; + + assert!(builder.build_websocket_with_upgrade().is_err()); + + Ok(()) + } + + #[test] + fn test_connection_ws() -> Result<()> { + let url = crate::test::engine_io_polling_server()?; + assert!(builder(url.clone()).build_websocket().is_err()); + assert!(builder(url).build_websocket_with_upgrade().is_err()); + + let mut url = crate::test::engine_io_server()?; + + let builder = builder(url.clone()); + let socket = builder.clone().build_websocket()?; + test_connection(socket)?; + + let socket = builder.build_websocket_with_upgrade()?; + test_connection(socket)?; + + url.set_scheme("ws").unwrap(); + + let builder = self::builder(url); + let socket = builder.clone().build_websocket()?; + + test_connection(socket)?; + + assert!(builder.build_websocket_with_upgrade().is_err()); + + Ok(()) + } + + #[test] + fn test_open_invariants() -> Result<()> { + let url = crate::test::engine_io_server()?; + let illegal_url = "this is illegal"; + + assert!(Url::parse(illegal_url).is_err()); + + let invalid_protocol = "file:///tmp/foo"; + assert!(builder(Url::parse(invalid_protocol).unwrap()) + .build() + .is_err()); + + let sut = builder(url.clone()).build()?; + let _error = sut + .emit(Packet::new(PacketId::Close, Bytes::new())) + .expect_err("error"); + assert!(matches!(Error::IllegalActionBeforeOpen(), _error)); + + // test missing match arm in socket constructor + let mut headers = HeaderMap::default(); + let host = + std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned()); + headers.insert(HOST, host); + + let _ = builder(url.clone()) + .tls_config( + TlsConnector::builder() + .danger_accept_invalid_certs(true) + .build() + .unwrap(), + ) + .build()?; + let _ = builder(url).headers(headers).build()?; + Ok(()) + } +} diff --git a/engineio/src/client/mod.rs b/engineio/src/client/mod.rs index 894ba60e..04bc91bc 100644 --- a/engineio/src/client/mod.rs +++ b/engineio/src/client/mod.rs @@ -1,661 +1,4 @@ -use super::socket::Socket as InnerSocket; -use crate::callback::OptionalCallback; -use crate::error::{Error, Result}; -use crate::header::HeaderMap; -use crate::packet::{HandshakePacket, Packet, PacketId}; -use crate::socket::DEFAULT_MAX_POLL_TIMEOUT; -use crate::transport::Transport; -use crate::transports::{PollingTransport, WebsocketSecureTransport, WebsocketTransport}; -use crate::ENGINE_IO_VERSION; -use bytes::Bytes; -use native_tls::TlsConnector; -use std::convert::TryFrom; -use std::convert::TryInto; -use std::fmt::Debug; -use url::Url; - -/// An engine.io client that allows interaction with the connected engine.io -/// server. This client provides means for connecting, disconnecting and sending -/// packets to the server. -/// -/// ## Note: -/// There is no need to put this Client behind an `Arc`, as the type uses `Arc` -/// internally and provides a shared state beyond all cloned instances. -#[derive(Clone, Debug)] -pub struct Client { - socket: InnerSocket, -} - -#[derive(Clone, Debug)] -pub struct ClientBuilder { - url: Url, - tls_config: Option, - headers: Option, - handshake: Option, - on_error: OptionalCallback, - on_open: OptionalCallback<()>, - on_close: OptionalCallback<()>, - on_data: OptionalCallback, - on_packet: OptionalCallback, -} - -impl ClientBuilder { - pub fn new(url: Url) -> Self { - let mut url = url; - url.query_pairs_mut() - .append_pair("EIO", &ENGINE_IO_VERSION.to_string()); - - // No path add engine.io - if url.path() == "/" { - url.set_path("/engine.io/"); - } - ClientBuilder { - url, - headers: None, - tls_config: None, - handshake: None, - on_close: OptionalCallback::default(), - on_data: OptionalCallback::default(), - on_error: OptionalCallback::default(), - on_open: OptionalCallback::default(), - on_packet: OptionalCallback::default(), - } - } - - /// Specify transport's tls config - pub fn tls_config(mut self, tls_config: TlsConnector) -> Self { - self.tls_config = Some(tls_config); - self - } - - /// Specify transport's HTTP headers - pub fn headers(mut self, headers: HeaderMap) -> Self { - self.headers = Some(headers); - self - } - - /// Registers the `on_close` callback. - pub fn on_close(mut self, callback: T) -> Self - where - T: Fn(()) + 'static + Sync + Send, - { - self.on_close = OptionalCallback::new(callback); - self - } - - /// Registers the `on_data` callback. - pub fn on_data(mut self, callback: T) -> Self - where - T: Fn(Bytes) + 'static + Sync + Send, - { - self.on_data = OptionalCallback::new(callback); - self - } - - /// Registers the `on_error` callback. - pub fn on_error(mut self, callback: T) -> Self - where - T: Fn(String) + 'static + Sync + Send, - { - self.on_error = OptionalCallback::new(callback); - self - } - - /// Registers the `on_open` callback. - pub fn on_open(mut self, callback: T) -> Self - where - T: Fn(()) + 'static + Sync + Send, - { - self.on_open = OptionalCallback::new(callback); - self - } - - /// Registers the `on_packet` callback. - pub fn on_packet(mut self, callback: T) -> Self - where - T: Fn(Packet) + 'static + Sync + Send, - { - self.on_packet = OptionalCallback::new(callback); - self - } - - /// Performs the handshake - fn handshake_with_transport(&mut self, transport: &T) -> Result<()> { - // No need to handshake twice - if self.handshake.is_some() { - return Ok(()); - } - - let mut url = self.url.clone(); - - let handshake: HandshakePacket = - Packet::try_from(transport.poll(DEFAULT_MAX_POLL_TIMEOUT)?)?.try_into()?; - - // update the base_url with the new sid - url.query_pairs_mut().append_pair("sid", &handshake.sid[..]); - - self.handshake = Some(handshake); - - self.url = url; - - Ok(()) - } - - fn handshake(&mut self) -> Result<()> { - if self.handshake.is_some() { - return Ok(()); - } - - // Start with polling transport - let transport = PollingTransport::new( - self.url.clone(), - self.tls_config.clone(), - self.headers.clone().map(|v| v.try_into().unwrap()), - ); - - self.handshake_with_transport(&transport) - } - - /// Build websocket if allowed, if not fall back to polling - pub fn build(mut self) -> Result { - self.handshake()?; - - if self.websocket_upgrade()? { - self.build_websocket_with_upgrade() - } else { - self.build_polling() - } - } - - /// Build socket with polling transport - pub fn build_polling(mut self) -> Result { - self.handshake()?; - - // Make a polling transport with new sid - let transport = PollingTransport::new( - self.url, - self.tls_config, - self.headers.map(|v| v.try_into().unwrap()), - ); - - // SAFETY: handshake function called previously. - Ok(Client { - socket: InnerSocket::new( - transport.into(), - self.handshake.unwrap(), - self.on_close, - self.on_data, - self.on_error, - self.on_open, - self.on_packet, - ), - }) - } - - /// Build socket with a polling transport then upgrade to websocket transport - pub fn build_websocket_with_upgrade(mut self) -> Result { - self.handshake()?; - - if self.websocket_upgrade()? { - self.build_websocket() - } else { - Err(Error::IllegalWebsocketUpgrade()) - } - } - - /// Build socket with only a websocket transport - pub fn build_websocket(mut self) -> Result { - // SAFETY: Already a Url - let url = url::Url::parse(self.url.as_ref())?; - - let headers: Option = if let Some(map) = self.headers.clone() { - Some(map.try_into()?) - } else { - None - }; - - match url.scheme() { - "http" | "ws" => { - let transport = WebsocketTransport::new(url, headers)?; - if self.handshake.is_some() { - transport.upgrade()?; - } else { - self.handshake_with_transport(&transport)?; - } - // NOTE: Although self.url contains the sid, it does not propagate to the transport - // SAFETY: handshake function called previously. - Ok(Client { - socket: InnerSocket::new( - transport.into(), - self.handshake.unwrap(), - self.on_close, - self.on_data, - self.on_error, - self.on_open, - self.on_packet, - ), - }) - } - "https" | "wss" => { - let transport = - WebsocketSecureTransport::new(url, self.tls_config.clone(), headers)?; - if self.handshake.is_some() { - transport.upgrade()?; - } else { - self.handshake_with_transport(&transport)?; - } - // NOTE: Although self.url contains the sid, it does not propagate to the transport - // SAFETY: handshake function called previously. - Ok(Client { - socket: InnerSocket::new( - transport.into(), - self.handshake.unwrap(), - self.on_close, - self.on_data, - self.on_error, - self.on_open, - self.on_packet, - ), - }) - } - _ => Err(Error::InvalidUrlScheme(url.scheme().to_string())), - } - } - - /// Build websocket if allowed, if not allowed or errored fall back to polling. - /// WARNING: websocket errors suppressed, no indication of websocket success or failure. - pub fn build_with_fallback(self) -> Result { - let result = self.clone().build(); - if result.is_err() { - self.build_polling() - } else { - result - } - } - - /// Checks the handshake to see if websocket upgrades are allowed - fn websocket_upgrade(&mut self) -> Result { - // SAFETY: handshake set by above function. - Ok(self - .handshake - .as_ref() - .unwrap() - .upgrades - .iter() - .any(|upgrade| upgrade.to_lowercase() == *"websocket")) - } -} - -impl Client { - pub fn close(&self) -> Result<()> { - self.socket.disconnect() - } - - /// Opens the connection to a specified server. The first Pong packet is sent - /// to the server to trigger the Ping-cycle. - pub fn connect(&self) -> Result<()> { - self.socket.connect() - } - - /// Disconnects the connection. - pub fn disconnect(&self) -> Result<()> { - self.socket.disconnect() - } - - /// Sends a packet to the server. - pub fn emit(&self, packet: Packet) -> Result<()> { - self.socket.emit(packet) - } - - /// Polls for next payload - #[doc(hidden)] - pub fn poll(&self) -> Result> { - let packet = self.socket.poll()?; - if let Some(packet) = packet { - // check for the appropriate action or callback - self.socket.handle_packet(packet.clone()); - match packet.packet_id { - PacketId::MessageBinary => { - self.socket.handle_data(packet.data.clone()); - } - PacketId::Message => { - self.socket.handle_data(packet.data.clone()); - } - PacketId::Close => { - self.socket.handle_close(); - } - PacketId::Open => { - unreachable!("Won't happen as we open the connection beforehand"); - } - PacketId::Upgrade => { - // this is already checked during the handshake, so just do nothing here - } - PacketId::Ping => { - self.socket.pinged()?; - self.emit(Packet::new(PacketId::Pong, Bytes::new()))?; - } - PacketId::Pong => { - // this will never happen as the pong packet is - // only sent by the client - unreachable!(); - } - PacketId::Noop => (), - } - Ok(Some(packet)) - } else { - Ok(None) - } - } - - /// Check if the underlying transport client is connected. - pub fn is_connected(&self) -> Result { - self.socket.is_connected() - } - - pub fn iter(&self) -> Iter { - Iter { socket: self } - } -} - -#[derive(Clone)] -pub struct Iter<'a> { - socket: &'a Client, -} - -impl<'a> Iterator for Iter<'a> { - type Item = Result; - fn next(&mut self) -> std::option::Option<::Item> { - match self.socket.poll() { - Ok(Some(packet)) => Some(Ok(packet)), - Ok(None) => None, - Err(err) => Some(Err(err)), - } - } -} - -#[cfg(test)] -mod test { - - use crate::packet::PacketId; - - use super::*; - - /// The purpose of this test is to check whether the Client is properly cloneable or not. - /// As the documentation of the engine.io client states, the object needs to maintain it's internal - /// state when cloned and the cloned object should reflect the same state throughout the lifetime - /// of both objects (initial and cloned). - #[test] - fn test_client_cloneable() -> Result<()> { - let url = crate::test::engine_io_server()?; - let sut = builder(url).build()?; - - let cloned = sut.clone(); - - sut.connect()?; - - // when the underlying socket is connected, the - // state should also change on the cloned one - assert!(sut.is_connected()?); - assert!(cloned.is_connected()?); - - // both clients should reflect the same messages. - let mut iter = sut - .iter() - .map(|packet| packet.unwrap()) - .filter(|packet| packet.packet_id != PacketId::Ping); - - let mut iter_cloned = cloned - .iter() - .map(|packet| packet.unwrap()) - .filter(|packet| packet.packet_id != PacketId::Ping); - - assert_eq!( - iter.next(), - Some(Packet::new(PacketId::Message, "hello client")) - ); - - sut.emit(Packet::new(PacketId::Message, "respond"))?; - - assert_eq!( - iter_cloned.next(), - Some(Packet::new(PacketId::Message, "Roger Roger")) - ); - - cloned.disconnect()?; - - // when the underlying socket is disconnected, the - // state should also change on the cloned one - assert!(!sut.is_connected()?); - assert!(!cloned.is_connected()?); - - Ok(()) - } - - #[test] - fn test_illegal_actions() -> Result<()> { - let url = crate::test::engine_io_server()?; - let sut = builder(url.clone()).build()?; - - assert!(sut - .emit(Packet::new(PacketId::Close, Bytes::new())) - .is_err()); - - sut.connect()?; - - assert!(sut.poll().is_ok()); - - assert!(builder(Url::parse("fake://fake.fake").unwrap()) - .build_websocket() - .is_err()); - - Ok(()) - } - use reqwest::header::HOST; - - use crate::packet::Packet; - - fn builder(url: Url) -> ClientBuilder { - ClientBuilder::new(url) - .on_open(|_| { - println!("Open event!"); - }) - .on_packet(|packet| { - println!("Received packet: {:?}", packet); - }) - .on_data(|data| { - println!("Received data: {:?}", std::str::from_utf8(&data)); - }) - .on_close(|_| { - println!("Close event!"); - }) - .on_error(|error| { - println!("Error {}", error); - }) - } - - fn test_connection(socket: Client) -> Result<()> { - let socket = socket; - - socket.connect().unwrap(); - - // TODO: 0.3.X better tests - - let mut iter = socket - .iter() - .map(|packet| packet.unwrap()) - .filter(|packet| packet.packet_id != PacketId::Ping); - - assert_eq!( - iter.next(), - Some(Packet::new(PacketId::Message, "hello client")) - ); - - socket.emit(Packet::new(PacketId::Message, "respond"))?; - - assert_eq!( - iter.next(), - Some(Packet::new(PacketId::Message, "Roger Roger")) - ); - - socket.close() - } - - #[test] - fn test_connection_long() -> Result<()> { - // Long lived socket to receive pings - let url = crate::test::engine_io_server()?; - let socket = builder(url).build()?; - - socket.connect()?; - - let mut iter = socket.iter(); - // hello client - iter.next(); - // Ping - iter.next(); - - socket.disconnect()?; - - assert!(!socket.is_connected()?); - - Ok(()) - } - - #[test] - fn test_connection_dynamic() -> Result<()> { - let url = crate::test::engine_io_server()?; - let socket = builder(url).build()?; - test_connection(socket)?; - - let url = crate::test::engine_io_polling_server()?; - let socket = builder(url).build()?; - test_connection(socket) - } - - #[test] - fn test_connection_fallback() -> Result<()> { - let url = crate::test::engine_io_server()?; - let socket = builder(url).build_with_fallback()?; - test_connection(socket)?; - - let url = crate::test::engine_io_polling_server()?; - let socket = builder(url).build_with_fallback()?; - test_connection(socket) - } - - #[test] - fn test_connection_dynamic_secure() -> Result<()> { - let url = crate::test::engine_io_server_secure()?; - let mut builder = builder(url); - builder = builder.tls_config(crate::test::tls_connector()?); - let socket = builder.build()?; - test_connection(socket) - } - - #[test] - fn test_connection_polling() -> Result<()> { - let url = crate::test::engine_io_server()?; - let socket = builder(url).build_polling()?; - test_connection(socket) - } - - #[test] - fn test_connection_wss() -> Result<()> { - let url = crate::test::engine_io_polling_server()?; - assert!(builder(url).build_websocket_with_upgrade().is_err()); - - let host = - std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned()); - let mut url = crate::test::engine_io_server_secure()?; - - let mut headers = HeaderMap::default(); - headers.insert(HOST, host); - let mut builder = builder(url.clone()); - - builder = builder.tls_config(crate::test::tls_connector()?); - builder = builder.headers(headers.clone()); - let socket = builder.clone().build_websocket_with_upgrade()?; - - test_connection(socket)?; - - let socket = builder.build_websocket()?; - - test_connection(socket)?; - - url.set_scheme("wss").unwrap(); - - let builder = self::builder(url) - .tls_config(crate::test::tls_connector()?) - .headers(headers); - let socket = builder.clone().build_websocket()?; - - test_connection(socket)?; - - assert!(builder.build_websocket_with_upgrade().is_err()); - - Ok(()) - } - - #[test] - fn test_connection_ws() -> Result<()> { - let url = crate::test::engine_io_polling_server()?; - assert!(builder(url.clone()).build_websocket().is_err()); - assert!(builder(url).build_websocket_with_upgrade().is_err()); - - let mut url = crate::test::engine_io_server()?; - - let builder = builder(url.clone()); - let socket = builder.clone().build_websocket()?; - test_connection(socket)?; - - let socket = builder.build_websocket_with_upgrade()?; - test_connection(socket)?; - - url.set_scheme("ws").unwrap(); - - let builder = self::builder(url); - let socket = builder.clone().build_websocket()?; - - test_connection(socket)?; - - assert!(builder.build_websocket_with_upgrade().is_err()); - - Ok(()) - } - - #[test] - fn test_open_invariants() -> Result<()> { - let url = crate::test::engine_io_server()?; - let illegal_url = "this is illegal"; - - assert!(Url::parse(illegal_url).is_err()); - - let invalid_protocol = "file:///tmp/foo"; - assert!(builder(Url::parse(invalid_protocol).unwrap()) - .build() - .is_err()); - - let sut = builder(url.clone()).build()?; - let _error = sut - .emit(Packet::new(PacketId::Close, Bytes::new())) - .expect_err("error"); - assert!(matches!(Error::IllegalActionBeforeOpen(), _error)); - - // test missing match arm in socket constructor - let mut headers = HeaderMap::default(); - let host = - std::env::var("ENGINE_IO_SECURE_HOST").unwrap_or_else(|_| "localhost".to_owned()); - headers.insert(HOST, host); - - let _ = builder(url.clone()) - .tls_config( - TlsConnector::builder() - .danger_accept_invalid_certs(true) - .build() - .unwrap(), - ) - .build()?; - let _ = builder(url).headers(headers).build()?; - Ok(()) - } -} +#![allow(clippy::module_inception)] +mod client; +pub use client::Iter; +pub use {client::Client, client::ClientBuilder, client::Iter as SocketIter}; diff --git a/socketio/examples/async_transmitter.rs b/socketio/examples/async_transmitter.rs index 415a2a3c..25d4ccfe 100644 --- a/socketio/examples/async_transmitter.rs +++ b/socketio/examples/async_transmitter.rs @@ -1,13 +1,38 @@ -use futures_util::FutureExt; +use futures_util::future::{BoxFuture, FutureExt}; use rust_socketio::{ asynchronous::{Client as SocketIOClient, ClientBuilder as SocketIOClientBuilder}, Error as SocketIOError, Payload, }; -use serde_json::json; +use serde_json::{json, Value}; use std::sync::{mpsc, Arc}; use std::time::Duration; use tokio::time::sleep; +type JsonValues = Vec; + +fn test_event_handler<'event>(payload: Payload, socket: SocketIOClient) -> BoxFuture<'event, ()> { + async move { + if let Payload::Text(values) = payload { + match socket.try_transmitter::>() { + Ok(tx) => { + tx.send(values.to_owned()).map_or_else( + |err| eprintln!("{}", err), + |_| println!("Data transmitted successfully"), + ); + } + Err(err) => { + eprintln!("{}", err); + } + } + } + } + .boxed() +} + +fn error_event_handler<'event>(payload: Payload, _: SocketIOClient) -> BoxFuture<'event, ()> { + async move { eprintln!("Error: {:#?}", payload) }.boxed() +} + struct ComplexData { /// There should be many more fields below in real life, /// probaly wrapped in Arc> if you're writing a more serious client. @@ -15,52 +40,19 @@ struct ComplexData { } struct TransmitterClient { - receiver: mpsc::Receiver, + receiver: mpsc::Receiver, complex: ComplexData, client: SocketIOClient, } impl TransmitterClient { async fn connect(url: &str) -> Result { - let (sender, receiver) = mpsc::channel::(); + let (sender, receiver) = mpsc::channel::(); let client = SocketIOClientBuilder::new(url) .namespace("/admin") - .on("test", |payload: Payload, socket: SocketIOClient| { - async move { - match payload { - Payload::Text(values) => { - if let Some(value) = values.first() { - if value.is_string() { - socket - .try_transmitter::>() - .map_or_else( - |err| eprintln!("{}", err), - |tx| { - tx.send(String::from(value.as_str().unwrap())) - .map_or_else( - |err| eprintln!("{}", err), - |_| { - println!( - "Data transmitted successfully" - ) - }, - ); - }, - ); - } - } - } - Payload::Binary(bin_data) => println!("Binary data: {:#?}", bin_data), - #[allow(deprecated)] - Payload::String(str) => println!("Received: {}", str), - } - } - .boxed() - }) - .on("error", |err, _| { - async move { eprintln!("Error: {:#?}", err) }.boxed() - }) + .on("test", test_event_handler) + .on("error", error_event_handler) .transmitter(Arc::new(sender)) .connect() .await?; @@ -69,7 +61,7 @@ impl TransmitterClient { client, receiver, complex: ComplexData { - data: "".to_string(), + data: String::from(""), }, }) } @@ -78,12 +70,16 @@ impl TransmitterClient { match self.client.emit("test", json!({"got ack": true})).await { Ok(_) => { match self.receiver.recv() { - Ok(complex_data) => { - // In the real world the data is probably a serialized json_rpc object - // or some other complex data layer which needs complex business and derserialization logic. - // Best to do that here, and not inside those restrictive callbacks. - self.complex.data = complex_data; - Some(self.complex.data.clone()) + Ok(values) => { + // Json deserialization and parsing business logic should be implemented + // here to avoid over-complicating the handler callbacks. + if let Some(value) = values.first() { + if value.is_string() { + self.complex.data = String::from(value.as_str().unwrap()); + return Some(self.complex.data.clone()); + } + } + None } Err(err) => { eprintln!("Transmission buffer is probably full: {}", err); diff --git a/socketio/examples/sync_transmitter.rs b/socketio/examples/sync_transmitter.rs index 7d5df54e..8e49ddea 100644 --- a/socketio/examples/sync_transmitter.rs +++ b/socketio/examples/sync_transmitter.rs @@ -2,11 +2,33 @@ use rust_socketio::{ client::Client as SocketIOClient, ClientBuilder as SocketIOClientBuilder, Error as SocketIOError, Payload, RawClient, }; -use serde_json::json; +use serde_json::{json, Value}; use std::sync::{mpsc, Arc}; use std::thread::sleep; use std::time::Duration; +type JsonValues = Vec; + +fn test_event_handler(payload: Payload, socket: RawClient) { + if let Payload::Text(values) = payload { + match socket.try_transmitter::>() { + Ok(tx) => { + tx.send(values.to_owned()).map_or_else( + |err| eprintln!("{}", err), + |_| println!("Data transmitted successfully"), + ); + } + Err(err) => { + eprintln!("{}", err); + } + } + } +} + +fn error_event_handler(payload: Payload, _: RawClient) { + eprintln!("Error: {:#?}", payload); +} + struct ComplexData { /// There should be many more fields below in real life, /// probaly wrapped in Arc> if you're writing a more serious client. @@ -15,43 +37,18 @@ struct ComplexData { struct TransmitterClient { client: SocketIOClient, - receiver: mpsc::Receiver, + receiver: mpsc::Receiver, complex: ComplexData, } impl TransmitterClient { fn connect(url: &str) -> Result { - let (sender, receiver) = mpsc::channel::(); + let (sender, receiver) = mpsc::channel::(); let client = SocketIOClientBuilder::new(url) .namespace("/admin") - .on( - "test", - |payload: Payload, socket: RawClient| match payload { - Payload::Text(values) => { - if let Some(value) = values.first() { - if value.is_string() { - socket - .try_transmitter::>() - .map_or_else( - |err| eprintln!("{}", err), - |tx| { - tx.send(String::from(value.as_str().unwrap())) - .map_or_else( - |err| eprintln!("{}", err), - |_| println!("Data transmitted successfully"), - ); - }, - ); - } - } - } - Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data), - #[allow(deprecated)] - Payload::String(str) => println!("Received: {}", str), - }, - ) - .on("error", |err, _| eprintln!("Error: {:#?}", err)) + .on("test", test_event_handler) + .on("error", error_event_handler) .transmitter(Arc::new(sender)) .connect()?; @@ -68,12 +65,16 @@ impl TransmitterClient { match self.client.emit("test", json!({"got ack": true})) { Ok(_) => { match self.receiver.recv() { - Ok(complex_data) => { - // In the real world the data is probably a serialized json_rpc object - // or some other complex data layer which needs complex business and derserialization logic. - // Best to do that here, and not inside those restrictive callbacks. - self.complex.data = complex_data; - Some(self.complex.data.clone()) + Ok(values) => { + // Json deserialization and parsing business logic should be implemented + // here to avoid over-complicating the handler callbacks. + if let Some(value) = values.first() { + if value.is_string() { + self.complex.data = String::from(value.as_str().unwrap()); + return Some(self.complex.data.clone()); + } + } + None } Err(err) => { eprintln!("Transmission buffer is probably full: {}", err); diff --git a/socketio/src/asynchronous/client/async_client.rs b/socketio/src/asynchronous/client/async_client.rs index 00738eb2..288d7b71 100644 --- a/socketio/src/asynchronous/client/async_client.rs +++ b/socketio/src/asynchronous/client/async_client.rs @@ -111,40 +111,32 @@ impl Client { /// /// # Example /// - /// ```no_run - /// use futures_util::FutureExt; + /// ```rust + /// use futures_util::future::{BoxFuture, FutureExt}; /// use std::sync::{Arc, mpsc}; /// use rust_socketio::{ /// asynchronous::Client, /// Payload, /// }; /// - /// let callback = | payload: Payload, socket: Client | { + /// fn event_handler<'event>(payload: Payload, socket: Client) -> BoxFuture<'event, ()> { /// async move { - /// match payload { - /// Payload::Text(values) => { - /// if let Some(value) = values.first() { - /// if value.is_string() { - /// socket.try_transmitter::>().map_or_else( - /// |err| eprintln!("{}", err), - /// |tx| { - /// tx.send(String::from(value.as_str().unwrap())) - /// .map_or_else( - /// |err| eprintln!("{}", err), - /// |_| println!("Data transmitted successfully"), - /// ); - /// }, - /// ); - /// } + /// if let Payload::Text(values) = payload { + /// match socket.try_transmitter::>>() { + /// Ok(tx) => { + /// tx.send(values.to_owned()).map_or_else( + /// |err| eprintln!("{}", err), + /// |_| println!("Data transmitted successfully"), + /// ); + /// } + /// Err(err) => { + /// eprintln!("{}", err); /// } /// } - /// Payload::Binary(bin_data) => println!("{:#?}", bin_data), - /// #[allow(deprecated)] - /// Payload::String(str) => println!("Received: {}", str), /// } /// } /// .boxed() - /// }; + /// } /// ``` pub fn try_transmitter(&self) -> Result> { match Arc::clone(&self.transmitter).downcast() { diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 553cbf62..ed2d9e2a 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -106,16 +106,15 @@ impl ClientBuilder { /// /// # Example /// - /// ```no_run + /// ```rust /// use futures_util::FutureExt; /// use std::sync::{Arc, mpsc}; /// use rust_socketio::{ - /// asynchronous::{Client , ClientBuilder}, - /// Payload, Error, + /// asynchronous::{Client , ClientBuilder}, Error, /// }; /// /// async fn connect(url: &str) -> Result { - /// let (sender, receiver) = mpsc::channel::(); + /// let (sender, receiver) = mpsc::channel::>(); /// /// let client = ClientBuilder::new(url) /// .namespace("/admin") diff --git a/socketio/src/client/builder.rs b/socketio/src/client/builder.rs index 16c20b21..35c0bd10 100644 --- a/socketio/src/client/builder.rs +++ b/socketio/src/client/builder.rs @@ -104,7 +104,7 @@ impl ClientBuilder { /// Sets the data transmission object, ideally the standard libraries /// multi-producer single consumer [`std::sync::mpsc::Sender`] should be used. /// - /// ```no_run + /// ```rust /// use rust_socketio::{ /// client::Client, ClientBuilder, /// Error , Payload, RawClient, @@ -112,7 +112,7 @@ impl ClientBuilder { /// use std::sync::{Arc, mpsc}; /// /// fn connect(url: &str) -> Result { - /// let (sender, receiver) = mpsc::channel::(); + /// let (sender, receiver) = mpsc::channel::(); /// /// let client = ClientBuilder::new(url) /// .namespace("/admin") @@ -125,8 +125,8 @@ impl ClientBuilder { /// Ok(client) /// } /// ``` - pub fn transmitter(mut self, data: Arc) -> Self { - self.transmitter = Some(data); + pub fn transmitter(mut self, transmitter: Arc) -> Self { + self.transmitter = Some(transmitter); self } diff --git a/socketio/src/client/client.rs b/socketio/src/client/client.rs new file mode 100644 index 00000000..fe924307 --- /dev/null +++ b/socketio/src/client/client.rs @@ -0,0 +1,478 @@ +use std::{ + sync::{Arc, Mutex, RwLock}, + time::Duration, +}; + +use super::{ClientBuilder, RawClient}; +use crate::{ + error::Result, + packet::{Packet, PacketId}, + Error, +}; +pub(crate) use crate::{event::Event, payload::Payload}; +use backoff::ExponentialBackoff; +use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; + +#[derive(Clone)] +pub struct Client { + builder: Arc>, + client: Arc>, + backoff: ExponentialBackoff, +} + +impl Client { + pub(crate) fn new(builder: ClientBuilder) -> Result { + let builder_clone = builder.clone(); + let client = builder_clone.connect_raw()?; + let backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min)) + .with_max_interval(Duration::from_millis(builder.reconnect_delay_max)) + .build(); + + let s = Self { + builder: Arc::new(Mutex::new(builder)), + client: Arc::new(RwLock::new(client)), + backoff, + }; + s.poll_callback(); + + Ok(s) + } + + /// Updates the URL the client will connect to when reconnecting. + /// This is especially useful for updating query parameters. + pub fn set_reconnect_url>(&self, address: T) -> Result<()> { + self.builder.lock()?.address = address.into(); + Ok(()) + } + + /// Sends a message to the server using the underlying `engine.io` protocol. + /// This message takes an event, which could either be one of the common + /// events like "message" or "error" or a custom event like "foo". But be + /// careful, the data string needs to be valid JSON. It's recommended to use + /// a library like `serde_json` to serialize the data properly. + /// + /// # Example + /// ``` + /// use rust_socketio::{ClientBuilder, RawClient, Payload}; + /// use serde_json::json; + /// + /// let mut socket = ClientBuilder::new("http://localhost:4200/") + /// .on("test", |payload: Payload, socket: RawClient| { + /// println!("Received: {:#?}", payload); + /// socket.emit("test", json!({"hello": true})).expect("Server unreachable"); + /// }) + /// .connect() + /// .expect("connection failed"); + /// + /// let json_payload = json!({"token": 123}); + /// + /// let result = socket.emit("foo", json_payload); + /// + /// assert!(result.is_ok()); + /// ``` + pub fn emit(&self, event: E, data: D) -> Result<()> + where + E: Into, + D: Into, + { + let client = self.client.read()?; + // TODO(#230): like js client, buffer emit, resend after reconnect + client.emit(event, data) + } + + /// Sends a message to the server but `alloc`s an `ack` to check whether the + /// server responded in a given time span. This message takes an event, which + /// could either be one of the common events like "message" or "error" or a + /// custom event like "foo", as well as a data parameter. But be careful, + /// in case you send a [`Payload::String`], the string needs to be valid JSON. + /// It's even recommended to use a library like serde_json to serialize the data properly. + /// It also requires a timeout `Duration` in which the client needs to answer. + /// If the ack is acked in the correct time span, the specified callback is + /// called. The callback consumes a [`Payload`] which represents the data send + /// by the server. + /// + /// # Example + /// ``` + /// use rust_socketio::{ClientBuilder, Payload, RawClient}; + /// use serde_json::json; + /// use std::time::Duration; + /// use std::thread::sleep; + /// + /// let mut socket = ClientBuilder::new("http://localhost:4200/") + /// .on("foo", |payload: Payload, _| println!("Received: {:#?}", payload)) + /// .connect() + /// .expect("connection failed"); + /// + /// let ack_callback = |message: Payload, socket: RawClient| { + /// match message { + /// Payload::Text(values) => println!("{:#?}", values), + /// Payload::Binary(bytes) => println!("Received bytes: {:#?}", bytes), + /// // This is deprecated, use Payload::Text instead. + /// Payload::String(str) => println!("{}", str), + /// } + /// }; + /// + /// let payload = json!({"token": 123}); + /// socket.emit_with_ack("foo", payload, Duration::from_secs(2), ack_callback).unwrap(); + /// + /// sleep(Duration::from_secs(2)); + /// ``` + pub fn emit_with_ack( + &self, + event: E, + data: D, + timeout: Duration, + callback: F, + ) -> Result<()> + where + F: FnMut(Payload, RawClient) + 'static + Send, + E: Into, + D: Into, + { + let client = self.client.read()?; + // TODO(#230): like js client, buffer emit, resend after reconnect + client.emit_with_ack(event, data, timeout, callback) + } + + /// Disconnects this client from the server by sending a `socket.io` closing + /// packet. + /// # Example + /// ```rust + /// use rust_socketio::{ClientBuilder, Payload, RawClient}; + /// use serde_json::json; + /// + /// fn handle_test(payload: Payload, socket: RawClient) { + /// println!("Received: {:#?}", payload); + /// socket.emit("test", json!({"hello": true})).expect("Server unreachable"); + /// } + /// + /// let mut socket = ClientBuilder::new("http://localhost:4200/") + /// .on("test", handle_test) + /// .connect() + /// .expect("connection failed"); + /// + /// let json_payload = json!({"token": 123}); + /// + /// socket.emit("foo", json_payload); + /// + /// // disconnect from the server + /// socket.disconnect(); + /// + /// ``` + pub fn disconnect(&self) -> Result<()> { + let client = self.client.read()?; + client.disconnect() + } + + fn reconnect(&mut self) -> Result<()> { + let mut reconnect_attempts = 0; + let (reconnect, max_reconnect_attempts) = { + let builder = self.builder.lock()?; + (builder.reconnect, builder.max_reconnect_attempts) + }; + + if reconnect { + loop { + if let Some(max_reconnect_attempts) = max_reconnect_attempts { + reconnect_attempts += 1; + if reconnect_attempts > max_reconnect_attempts { + break; + } + } + + if let Some(backoff) = self.backoff.next_backoff() { + std::thread::sleep(backoff); + } + + if self.do_reconnect().is_ok() { + break; + } + } + } + + Ok(()) + } + + fn do_reconnect(&self) -> Result<()> { + let builder = self.builder.lock()?; + let new_client = builder.clone().connect_raw()?; + let mut client = self.client.write()?; + *client = new_client; + + Ok(()) + } + + pub(crate) fn iter(&self) -> Iter { + Iter { + socket: self.client.clone(), + } + } + + fn poll_callback(&self) { + let mut self_clone = self.clone(); + // Use thread to consume items in iterator in order to call callbacks + std::thread::spawn(move || { + // tries to restart a poll cycle whenever a 'normal' error occurs, + // it just panics on network errors, in case the poll cycle returned + // `Result::Ok`, the server receives a close frame so it's safe to + // terminate + for packet in self_clone.iter() { + let should_reconnect = match packet { + Err(Error::IncompleteResponseFromEngineIo(_)) => { + //TODO: 0.3.X handle errors + //TODO: logging error + true + } + Ok(Packet { + packet_type: PacketId::Disconnect, + .. + }) => match self_clone.builder.lock() { + Ok(builder) => builder.reconnect_on_disconnect, + Err(_) => false, + }, + _ => false, + }; + if should_reconnect { + let _ = self_clone.disconnect(); + let _ = self_clone.reconnect(); + } + } + }); + } +} + +pub(crate) struct Iter { + socket: Arc>, +} + +impl Iterator for Iter { + type Item = Result; + + fn next(&mut self) -> Option { + let socket = self.socket.read(); + match socket { + Ok(socket) => match socket.poll() { + Err(err) => Some(Err(err)), + Ok(Some(packet)) => Some(Ok(packet)), + // If the underlying engineIO connection is closed, + // throw an error so we know to reconnect + Ok(None) => Some(Err(Error::StoppedEngineIoSocket)), + }, + Err(_) => { + // Lock is poisoned, our iterator is useless. + None + } + } + } +} + +#[cfg(test)] +mod test { + use std::{ + sync::atomic::{AtomicUsize, Ordering}, + time::UNIX_EPOCH, + }; + + use super::*; + use crate::error::Result; + use crate::ClientBuilder; + use serde_json::json; + use serial_test::serial; + use std::time::{Duration, SystemTime}; + use url::Url; + + #[test] + #[serial(reconnect)] + fn socket_io_reconnect_integration() -> Result<()> { + static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); + static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0); + static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); + + let url = crate::test::socket_io_restart_server(); + + let socket = ClientBuilder::new(url) + .reconnect(true) + .max_reconnect_attempts(100) + .reconnect_delay(100, 100) + .on(Event::Connect, move |_, socket| { + CONNECT_NUM.fetch_add(1, Ordering::Release); + let r = socket.emit_with_ack( + "message", + json!(""), + Duration::from_millis(100), + |_, _| {}, + ); + assert!(r.is_ok(), "should emit message success"); + }) + .on(Event::Close, move |_, _| { + CLOSE_NUM.fetch_add(1, Ordering::Release); + }) + .on("message", move |_, _socket| { + // test the iterator implementation and make sure there is a constant + // stream of packets, even when reconnecting + MESSAGE_NUM.fetch_add(1, Ordering::Release); + }) + .connect(); + + assert!(socket.is_ok(), "should connect success"); + let socket = socket.unwrap(); + + // waiting for server to emit message + std::thread::sleep(std::time::Duration::from_millis(500)); + + assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); + assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); + assert_eq!(load(&CLOSE_NUM), 0, "should not close"); + + let r = socket.emit("restart_server", json!("")); + assert!(r.is_ok(), "should emit restart success"); + + // waiting for server to restart + for _ in 0..10 { + std::thread::sleep(std::time::Duration::from_millis(400)); + if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 { + break; + } + } + + assert_eq!(load(&CONNECT_NUM), 2, "should connect twice"); + assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages"); + assert_eq!(load(&CLOSE_NUM), 1, "should close once"); + + socket.disconnect()?; + Ok(()) + } + + #[test] + fn socket_io_reconnect_url_auth_integration() -> Result<()> { + static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); + static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0); + static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); + + fn get_url() -> Url { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_millis(); + let mut url = crate::test::socket_io_restart_url_auth_server(); + url.set_query(Some(&format!("timestamp={timestamp}"))); + url + } + + let socket = ClientBuilder::new(get_url()) + .reconnect(true) + .max_reconnect_attempts(100) + .reconnect_delay(100, 100) + .on(Event::Connect, move |_, socket| { + CONNECT_NUM.fetch_add(1, Ordering::Release); + let result = socket.emit_with_ack( + "message", + json!(""), + Duration::from_millis(100), + |_, _| {}, + ); + assert!(result.is_ok(), "should emit message success"); + }) + .on(Event::Close, move |_, _| { + CLOSE_NUM.fetch_add(1, Ordering::Release); + }) + .on("message", move |_, _| { + // test the iterator implementation and make sure there is a constant + // stream of packets, even when reconnecting + MESSAGE_NUM.fetch_add(1, Ordering::Release); + }) + .connect(); + + assert!(socket.is_ok(), "should connect success"); + let socket = socket.unwrap(); + + // waiting for server to emit message + std::thread::sleep(std::time::Duration::from_millis(500)); + + assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); + assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); + assert_eq!(load(&CLOSE_NUM), 0, "should not close"); + + // waiting for timestamp in url to expire + std::thread::sleep(std::time::Duration::from_secs(1)); + + socket.set_reconnect_url(get_url())?; + + let result = socket.emit("restart_server", json!("")); + assert!(result.is_ok(), "should emit restart success"); + + // waiting for server to restart + for _ in 0..10 { + std::thread::sleep(std::time::Duration::from_millis(400)); + if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 { + break; + } + } + + assert_eq!(load(&CONNECT_NUM), 2, "should connect twice"); + assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages"); + assert_eq!(load(&CLOSE_NUM), 1, "should close once"); + + socket.disconnect()?; + Ok(()) + } + + #[test] + fn socket_io_iterator_integration() -> Result<()> { + let url = crate::test::socket_io_server(); + let builder = ClientBuilder::new(url); + let builder_clone = builder.clone(); + + let client = Arc::new(RwLock::new(builder_clone.connect_raw()?)); + let mut socket = Client { + builder: Arc::new(Mutex::new(builder)), + client, + backoff: Default::default(), + }; + let socket_clone = socket.clone(); + + let packets: Arc>> = Default::default(); + let packets_clone = packets.clone(); + + std::thread::spawn(move || { + for packet in socket_clone.iter() { + { + let mut packets = packets_clone.write().unwrap(); + if let Ok(packet) = packet { + (*packets).push(packet); + } + } + } + }); + + // waiting for client to emit messages + std::thread::sleep(Duration::from_millis(100)); + let lock = packets.read().unwrap(); + let pre_num = lock.len(); + drop(lock); + + let _ = socket.disconnect(); + socket.reconnect()?; + + // waiting for client to emit messages + std::thread::sleep(Duration::from_millis(100)); + + let lock = packets.read().unwrap(); + let post_num = lock.len(); + drop(lock); + + assert!( + pre_num < post_num, + "pre_num {} should less than post_num {}", + pre_num, + post_num + ); + + Ok(()) + } + + fn load(num: &AtomicUsize) -> usize { + num.load(Ordering::Acquire) + } +} diff --git a/socketio/src/client/mod.rs b/socketio/src/client/mod.rs index 778db8f0..924d45c7 100644 --- a/socketio/src/client/mod.rs +++ b/socketio/src/client/mod.rs @@ -1,484 +1,12 @@ +#![allow(clippy::module_inception)] mod builder; -mod callback; mod raw_client; pub use builder::ClientBuilder; pub use builder::TransportType; +pub use client::Client; pub use raw_client::RawClient; -use crate::{ - error::Result, - packet::{Packet, PacketId}, - Error, -}; -pub(crate) use crate::{event::Event, payload::Payload}; -use backoff::ExponentialBackoff; -use backoff::{backoff::Backoff, ExponentialBackoffBuilder}; -use std::{ - sync::{Arc, Mutex, RwLock}, - time::Duration, -}; - -#[derive(Clone)] -pub struct Client { - builder: Arc>, - client: Arc>, - backoff: ExponentialBackoff, -} - -impl Client { - pub(crate) fn new(builder: ClientBuilder) -> Result { - let builder_clone = builder.clone(); - let client = builder_clone.connect_raw()?; - let backoff = ExponentialBackoffBuilder::new() - .with_initial_interval(Duration::from_millis(builder.reconnect_delay_min)) - .with_max_interval(Duration::from_millis(builder.reconnect_delay_max)) - .build(); - - let s = Self { - builder: Arc::new(Mutex::new(builder)), - client: Arc::new(RwLock::new(client)), - backoff, - }; - s.poll_callback(); - - Ok(s) - } - - /// Updates the URL the client will connect to when reconnecting. - /// This is especially useful for updating query parameters. - pub fn set_reconnect_url>(&self, address: T) -> Result<()> { - self.builder.lock()?.address = address.into(); - Ok(()) - } - - /// Sends a message to the server using the underlying `engine.io` protocol. - /// This message takes an event, which could either be one of the common - /// events like "message" or "error" or a custom event like "foo". But be - /// careful, the data string needs to be valid JSON. It's recommended to use - /// a library like `serde_json` to serialize the data properly. - /// - /// # Example - /// ``` - /// use rust_socketio::{ClientBuilder, RawClient, Payload}; - /// use serde_json::json; - /// - /// let mut socket = ClientBuilder::new("http://localhost:4200/") - /// .on("test", |payload: Payload, socket: RawClient| { - /// println!("Received: {:#?}", payload); - /// socket.emit("test", json!({"hello": true})).expect("Server unreachable"); - /// }) - /// .connect() - /// .expect("connection failed"); - /// - /// let json_payload = json!({"token": 123}); - /// - /// let result = socket.emit("foo", json_payload); - /// - /// assert!(result.is_ok()); - /// ``` - pub fn emit(&self, event: E, data: D) -> Result<()> - where - E: Into, - D: Into, - { - let client = self.client.read()?; - // TODO(#230): like js client, buffer emit, resend after reconnect - client.emit(event, data) - } - - /// Sends a message to the server but `alloc`s an `ack` to check whether the - /// server responded in a given time span. This message takes an event, which - /// could either be one of the common events like "message" or "error" or a - /// custom event like "foo", as well as a data parameter. But be careful, - /// in case you send a [`Payload::String`], the string needs to be valid JSON. - /// It's even recommended to use a library like serde_json to serialize the data properly. - /// It also requires a timeout `Duration` in which the client needs to answer. - /// If the ack is acked in the correct time span, the specified callback is - /// called. The callback consumes a [`Payload`] which represents the data send - /// by the server. - /// - /// # Example - /// ``` - /// use rust_socketio::{ClientBuilder, Payload, RawClient}; - /// use serde_json::json; - /// use std::time::Duration; - /// use std::thread::sleep; - /// - /// let mut socket = ClientBuilder::new("http://localhost:4200/") - /// .on("foo", |payload: Payload, _| println!("Received: {:#?}", payload)) - /// .connect() - /// .expect("connection failed"); - /// - /// let ack_callback = |message: Payload, socket: RawClient| { - /// match message { - /// Payload::Text(values) => println!("{:#?}", values), - /// Payload::Binary(bytes) => println!("Received bytes: {:#?}", bytes), - /// // This is deprecated, use Payload::Text instead. - /// Payload::String(str) => println!("{}", str), - /// } - /// }; - /// - /// let payload = json!({"token": 123}); - /// socket.emit_with_ack("foo", payload, Duration::from_secs(2), ack_callback).unwrap(); - /// - /// sleep(Duration::from_secs(2)); - /// ``` - pub fn emit_with_ack( - &self, - event: E, - data: D, - timeout: Duration, - callback: F, - ) -> Result<()> - where - F: FnMut(Payload, RawClient) + 'static + Send, - E: Into, - D: Into, - { - let client = self.client.read()?; - // TODO(#230): like js client, buffer emit, resend after reconnect - client.emit_with_ack(event, data, timeout, callback) - } - - /// Disconnects this client from the server by sending a `socket.io` closing - /// packet. - /// # Example - /// ```rust - /// use rust_socketio::{ClientBuilder, Payload, RawClient}; - /// use serde_json::json; - /// - /// fn handle_test(payload: Payload, socket: RawClient) { - /// println!("Received: {:#?}", payload); - /// socket.emit("test", json!({"hello": true})).expect("Server unreachable"); - /// } - /// - /// let mut socket = ClientBuilder::new("http://localhost:4200/") - /// .on("test", handle_test) - /// .connect() - /// .expect("connection failed"); - /// - /// let json_payload = json!({"token": 123}); - /// - /// socket.emit("foo", json_payload); - /// - /// // disconnect from the server - /// socket.disconnect(); - /// - /// ``` - pub fn disconnect(&self) -> Result<()> { - let client = self.client.read()?; - client.disconnect() - } - - fn reconnect(&mut self) -> Result<()> { - let mut reconnect_attempts = 0; - let (reconnect, max_reconnect_attempts) = { - let builder = self.builder.lock()?; - (builder.reconnect, builder.max_reconnect_attempts) - }; - - if reconnect { - loop { - if let Some(max_reconnect_attempts) = max_reconnect_attempts { - reconnect_attempts += 1; - if reconnect_attempts > max_reconnect_attempts { - break; - } - } - - if let Some(backoff) = self.backoff.next_backoff() { - std::thread::sleep(backoff); - } - - if self.do_reconnect().is_ok() { - break; - } - } - } - - Ok(()) - } - - fn do_reconnect(&self) -> Result<()> { - let builder = self.builder.lock()?; - let new_client = builder.clone().connect_raw()?; - let mut client = self.client.write()?; - *client = new_client; - - Ok(()) - } - - pub(crate) fn iter(&self) -> Iter { - Iter { - socket: self.client.clone(), - } - } - - fn poll_callback(&self) { - let mut self_clone = self.clone(); - // Use thread to consume items in iterator in order to call callbacks - std::thread::spawn(move || { - // tries to restart a poll cycle whenever a 'normal' error occurs, - // it just panics on network errors, in case the poll cycle returned - // `Result::Ok`, the server receives a close frame so it's safe to - // terminate - for packet in self_clone.iter() { - let should_reconnect = match packet { - Err(Error::IncompleteResponseFromEngineIo(_)) => { - //TODO: 0.3.X handle errors - //TODO: logging error - true - } - Ok(Packet { - packet_type: PacketId::Disconnect, - .. - }) => match self_clone.builder.lock() { - Ok(builder) => builder.reconnect_on_disconnect, - Err(_) => false, - }, - _ => false, - }; - if should_reconnect { - let _ = self_clone.disconnect(); - let _ = self_clone.reconnect(); - } - } - }); - } -} - -pub(crate) struct Iter { - socket: Arc>, -} - -impl Iterator for Iter { - type Item = Result; - - fn next(&mut self) -> Option { - let socket = self.socket.read(); - match socket { - Ok(socket) => match socket.poll() { - Err(err) => Some(Err(err)), - Ok(Some(packet)) => Some(Ok(packet)), - // If the underlying engineIO connection is closed, - // throw an error so we know to reconnect - Ok(None) => Some(Err(Error::StoppedEngineIoSocket)), - }, - Err(_) => { - // Lock is poisoned, our iterator is useless. - None - } - } - } -} - -#[cfg(test)] -mod test { - use std::{ - sync::atomic::{AtomicUsize, Ordering}, - time::UNIX_EPOCH, - }; - - use super::*; - use crate::error::Result; - use crate::ClientBuilder; - use serde_json::json; - use serial_test::serial; - use std::time::{Duration, SystemTime}; - use url::Url; - - #[test] - #[serial(reconnect)] - fn socket_io_reconnect_integration() -> Result<()> { - static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); - static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0); - static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); - - let url = crate::test::socket_io_restart_server(); - - let socket = ClientBuilder::new(url) - .reconnect(true) - .max_reconnect_attempts(100) - .reconnect_delay(100, 100) - .on(Event::Connect, move |_, socket| { - CONNECT_NUM.fetch_add(1, Ordering::Release); - let r = socket.emit_with_ack( - "message", - json!(""), - Duration::from_millis(100), - |_, _| {}, - ); - assert!(r.is_ok(), "should emit message success"); - }) - .on(Event::Close, move |_, _| { - CLOSE_NUM.fetch_add(1, Ordering::Release); - }) - .on("message", move |_, _socket| { - // test the iterator implementation and make sure there is a constant - // stream of packets, even when reconnecting - MESSAGE_NUM.fetch_add(1, Ordering::Release); - }) - .connect(); - - assert!(socket.is_ok(), "should connect success"); - let socket = socket.unwrap(); - - // waiting for server to emit message - std::thread::sleep(std::time::Duration::from_millis(500)); - - assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); - assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); - assert_eq!(load(&CLOSE_NUM), 0, "should not close"); - - let r = socket.emit("restart_server", json!("")); - assert!(r.is_ok(), "should emit restart success"); - - // waiting for server to restart - for _ in 0..10 { - std::thread::sleep(std::time::Duration::from_millis(400)); - if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 { - break; - } - } - - assert_eq!(load(&CONNECT_NUM), 2, "should connect twice"); - assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages"); - assert_eq!(load(&CLOSE_NUM), 1, "should close once"); - - socket.disconnect()?; - Ok(()) - } - - #[test] - fn socket_io_reconnect_url_auth_integration() -> Result<()> { - static CONNECT_NUM: AtomicUsize = AtomicUsize::new(0); - static CLOSE_NUM: AtomicUsize = AtomicUsize::new(0); - static MESSAGE_NUM: AtomicUsize = AtomicUsize::new(0); - - fn get_url() -> Url { - let timestamp = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis(); - let mut url = crate::test::socket_io_restart_url_auth_server(); - url.set_query(Some(&format!("timestamp={timestamp}"))); - url - } - - let socket = ClientBuilder::new(get_url()) - .reconnect(true) - .max_reconnect_attempts(100) - .reconnect_delay(100, 100) - .on(Event::Connect, move |_, socket| { - CONNECT_NUM.fetch_add(1, Ordering::Release); - let result = socket.emit_with_ack( - "message", - json!(""), - Duration::from_millis(100), - |_, _| {}, - ); - assert!(result.is_ok(), "should emit message success"); - }) - .on(Event::Close, move |_, _| { - CLOSE_NUM.fetch_add(1, Ordering::Release); - }) - .on("message", move |_, _| { - // test the iterator implementation and make sure there is a constant - // stream of packets, even when reconnecting - MESSAGE_NUM.fetch_add(1, Ordering::Release); - }) - .connect(); - - assert!(socket.is_ok(), "should connect success"); - let socket = socket.unwrap(); - - // waiting for server to emit message - std::thread::sleep(std::time::Duration::from_millis(500)); - - assert_eq!(load(&CONNECT_NUM), 1, "should connect once"); - assert_eq!(load(&MESSAGE_NUM), 1, "should receive one"); - assert_eq!(load(&CLOSE_NUM), 0, "should not close"); - - // waiting for timestamp in url to expire - std::thread::sleep(std::time::Duration::from_secs(1)); - - socket.set_reconnect_url(get_url())?; - - let result = socket.emit("restart_server", json!("")); - assert!(result.is_ok(), "should emit restart success"); - - // waiting for server to restart - for _ in 0..10 { - std::thread::sleep(std::time::Duration::from_millis(400)); - if load(&CONNECT_NUM) == 2 && load(&MESSAGE_NUM) == 2 { - break; - } - } - - assert_eq!(load(&CONNECT_NUM), 2, "should connect twice"); - assert_eq!(load(&MESSAGE_NUM), 2, "should receive two messages"); - assert_eq!(load(&CLOSE_NUM), 1, "should close once"); - - socket.disconnect()?; - Ok(()) - } - - #[test] - fn socket_io_iterator_integration() -> Result<()> { - let url = crate::test::socket_io_server(); - let builder = ClientBuilder::new(url); - let builder_clone = builder.clone(); - - let client = Arc::new(RwLock::new(builder_clone.connect_raw()?)); - let mut socket = Client { - builder: Arc::new(Mutex::new(builder)), - client, - backoff: Default::default(), - }; - let socket_clone = socket.clone(); - - let packets: Arc>> = Default::default(); - let packets_clone = packets.clone(); - - std::thread::spawn(move || { - for packet in socket_clone.iter() { - { - let mut packets = packets_clone.write().unwrap(); - if let Ok(packet) = packet { - (*packets).push(packet); - } - } - } - }); - - // waiting for client to emit messages - std::thread::sleep(Duration::from_millis(100)); - let lock = packets.read().unwrap(); - let pre_num = lock.len(); - drop(lock); - - let _ = socket.disconnect(); - socket.reconnect()?; - - // waiting for client to emit messages - std::thread::sleep(Duration::from_millis(100)); - - let lock = packets.read().unwrap(); - let post_num = lock.len(); - drop(lock); - - assert!( - pre_num < post_num, - "pre_num {} should less than post_num {}", - pre_num, - post_num - ); - - Ok(()) - } - - fn load(num: &AtomicUsize) -> usize { - num.load(Ordering::Acquire) - } -} +/// Internal callback type +mod callback; +mod client; diff --git a/socketio/src/client/raw_client.rs b/socketio/src/client/raw_client.rs index f1f30d0f..b06c84e7 100644 --- a/socketio/src/client/raw_client.rs +++ b/socketio/src/client/raw_client.rs @@ -77,37 +77,28 @@ impl RawClient { /// /// # Example /// - /// ```no_run + /// ```rust /// use rust_socketio::{ /// client::Client, ClientBuilder, /// Error , Payload, RawClient, /// }; /// use std::sync::mpsc; /// - /// - /// let callback = |payload: Payload, socket: RawClient| { - /// match payload { - /// Payload::Text(values) => { - /// if let Some(value) = values.first() { - /// if value.is_string() { - /// socket.try_transmitter::>().map_or_else( - /// |err| eprintln!("{}", err), - /// |tx| { - /// tx.send(String::from(value.as_str().unwrap())) - /// .map_or_else( - /// |err| eprintln!("{}", err), - /// |_| println!("Data transmitted successfully"), - /// ); - /// }, - /// ); - /// } - /// } - /// } - /// Payload::Binary(bin_data) => println!("{:#?}", bin_data), - /// #[allow(deprecated)] - /// Payload::String(str) => println!("Received: {}", str), - /// } - /// }; + /// fn event_handler(payload: Payload, socket: RawClient) { + /// if let Payload::Text(values) = payload { + /// match socket.try_transmitter::>>() { + /// Ok(tx) => { + /// tx.send(values.to_owned()).map_or_else( + /// |err| eprintln!("{}", err), + /// |_| println!("Data transmitted successfully"), + /// ); + /// } + /// Err(err) => { + /// eprintln!("{}", err); + /// } + /// } + /// } + /// } /// ``` pub fn try_transmitter(&self) -> Result> { match Arc::clone(&self.transmitter).downcast() {