From 20243af200e89f07c3258631655a7418e2d4df88 Mon Sep 17 00:00:00 2001 From: Ragesh Krishna Date: Thu, 21 Mar 2024 11:17:31 +0530 Subject: [PATCH] Introduce `DisconnectReason` enum The enum replaces the need for multiple `AtomicBool`'s to maintain the disconnection reason. This makes the code easier to read and more ergonomic to maintain the state. --- socketio/src/asynchronous/client/client.rs | 58 ++++++++++++---------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/socketio/src/asynchronous/client/client.rs b/socketio/src/asynchronous/client/client.rs index 2043f97f..525ab39a 100644 --- a/socketio/src/asynchronous/client/client.rs +++ b/socketio/src/asynchronous/client/client.rs @@ -29,6 +29,17 @@ use crate::{ Event, Payload, }; +#[derive(Default)] +enum DisconnectReason { + /// There is no known reason for the disconnect; likely a network error + #[default] + Unknown, + /// The user disconnected manually + Manual, + /// The server disconnected + Server, +} + /// A socket which handles communication with the server. It's initialized with /// a specific address as well as an optional namespace to connect to. If `None` /// is given the client will connect to the default namespace `"/"`. @@ -42,8 +53,7 @@ pub struct Client { // Data send in the opening packet (commonly used as for auth) auth: Option, builder: Arc>, - manually_disconnected: Arc, - server_disconnected: Arc, + disconnect_reason: Arc>, } impl Client { @@ -58,8 +68,7 @@ impl Client { outstanding_acks: Arc::new(RwLock::new(Vec::new())), auth: builder.auth.clone(), builder: Arc::new(RwLock::new(builder)), - manually_disconnected: Arc::new(AtomicBool::new(false)), - server_disconnected: Arc::new(AtomicBool::new(false)), + disconnect_reason: Arc::new(RwLock::new(DisconnectReason::default())), }) } @@ -85,6 +94,9 @@ impl Client { // New inner socket that can be connected let mut client_socket = self.socket.write().await; *client_socket = socket; + + // Now that we have replaced `self.socket`, we drop the write lock + // because the `connect` method we call below will need to use it drop(client_socket); self.connect().await?; @@ -98,6 +110,8 @@ impl Client { let reconnect_delay_min = builder.reconnect_delay_min; let reconnect_delay_max = builder.reconnect_delay_max; let max_reconnect_attempts = builder.max_reconnect_attempts; + let reconnect = builder.reconnect; + let reconnect_on_disconnect = builder.reconnect_on_disconnect; drop(builder); let mut client_clone = self.clone(); @@ -115,7 +129,13 @@ impl Client { // Drop the stream so we can once again use `socket_clone` as mutable drop(stream); - if client_clone.should_reconnect().await { + let should_reconnect = match *(client_clone.disconnect_reason.read().await) { + DisconnectReason::Unknown => reconnect, + DisconnectReason::Manual => false, + DisconnectReason::Server => reconnect_on_disconnect, + }; + + if should_reconnect { let mut reconnect_attempts = 0; let mut backoff = ExponentialBackoffBuilder::new() .with_initial_interval(Duration::from_millis(reconnect_delay_min)) @@ -233,7 +253,7 @@ impl Client { /// } /// ``` pub async fn disconnect(&self) -> Result<()> { - self.manually_disconnected.store(true, Ordering::Release); + *(self.disconnect_reason.write().await) = DisconnectReason::Manual; let disconnect_packet = Packet::new(PacketId::Disconnect, self.nsp.clone(), None, None, 0, None); @@ -456,11 +476,11 @@ impl Client { } } PacketId::Connect => { - self.server_disconnected.store(false, Ordering::Release); + *(self.disconnect_reason.write().await) = DisconnectReason::default(); self.callback(&Event::Connect, "").await?; } PacketId::Disconnect => { - self.server_disconnected.store(true, Ordering::Release); + *(self.disconnect_reason.write().await) = DisconnectReason::Server; self.callback(&Event::Close, "").await?; } PacketId::ConnectError => { @@ -484,31 +504,15 @@ impl Client { Ok(()) } - /// Indicates whether the client should try to reconnect - pub(crate) async fn should_reconnect(&self) -> bool { - let manually_disconnected = self.manually_disconnected.load(Ordering::Acquire); - let server_disconnected = self.server_disconnected.load(Ordering::Acquire); - - if server_disconnected { - self.builder.read().await.reconnect_on_disconnect - } else { - !manually_disconnected - } - } - /// Returns the packet stream for the client. pub(crate) fn as_stream<'a>( &'a self, ) -> Pin> + Send + 'a>> { - let socket_clone = self.socket.clone(); + let socket_clone = (*self.socket.blocking_read()).clone(); - stream::unfold(socket_clone, |socket| async { - let mut socket_read = { - let s = socket.read().await; - (*s).clone() - }; + stream::unfold(socket_clone, |mut socket| async { // wait for the next payload - let packet: Option> = socket_read.next().await; + let packet: Option> = socket.next().await; match packet { // end the stream if the underlying one is closed None => None,