diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index 03eb79556..5d7902176 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `set_session_expiry_interval` and `session_expiry_interval` methods on `MqttOptions`. * `Auth` packet as per MQTT5 standards * Allow configuring the `nodelay` property of underlying TCP client with the `tcp_nodelay` field in `NetworkOptions` +* `publish` / `subscribe` / `unsubscribe` methods on `AsyncClient` and `Client` now return an `AckPromise` which resolves when the packet(except for QoS 0 publishes, which resolve as soon as handled) is acknowledged by the broker. ### Changed diff --git a/rumqttc/examples/ack_promise.rs b/rumqttc/examples/ack_promise.rs new file mode 100644 index 000000000..f4c837b29 --- /dev/null +++ b/rumqttc/examples/ack_promise.rs @@ -0,0 +1,81 @@ +use tokio::task::{self, JoinSet}; + +use rumqttc::{AsyncClient, MqttOptions, QoS}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + match client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .await + { + Ok(pkid) => println!("Acknowledged Sub({pkid:?})"), + Err(e) => println!("Subscription failed: {e:?}"), + } + + // Publish at all QoS levels and wait for broker acknowledgement + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() + { + match client + .publish("hello/world", qos, false, vec![1; i]) + .await + .unwrap() + .await + { + Ok(ack) => println!("Acknowledged Pub({ack:?})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Publish with different QoS levels and spawn wait for notification + let mut set = JoinSet::new(); + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() + { + let token = client + .publish("hello/world", qos, false, vec![1; i]) + .await + .unwrap(); + set.spawn(token); + } + + while let Some(Ok(res)) = set.join_next().await { + match res { + Ok(ack) => println!("Acknowledged Pub({ack:?})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Unsubscribe and wait for broker acknowledgement + match client.unsubscribe("hello/world").await.unwrap().await { + Ok(ack) => println!("Acknowledged Unsub({ack:?})"), + Err(e) => println!("Unsubscription failed: {e:?}"), + } + + Ok(()) +} diff --git a/rumqttc/examples/ack_promise_sync.rs b/rumqttc/examples/ack_promise_sync.rs new file mode 100644 index 000000000..1eaa20e0d --- /dev/null +++ b/rumqttc/examples/ack_promise_sync.rs @@ -0,0 +1,98 @@ +use flume::bounded; +use rumqttc::{Client, MqttOptions, QoS, TokenError}; +use std::error::Error; +use std::thread::{self, sleep}; +use std::time::Duration; + +fn main() -> Result<(), Box> { + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut conn) = Client::new(mqttoptions, 10); + thread::spawn(move || { + for event in conn.iter() { + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + match client + .subscribe("hello/world", QoS::AtMostOnce) + .unwrap() + .wait() + { + Ok(pkid) => println!("Acknowledged Sub({pkid:?})"), + Err(e) => println!("Subscription failed: {e:?}"), + } + + // Publish at all QoS levels and wait for broker acknowledgement + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() + { + match client + .publish("hello/world", qos, false, vec![1; i]) + .unwrap() + .wait() + { + Ok(ack) => println!("Acknowledged Pub({ack:?})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Spawn threads for each publish, use channel to notify result + let (tx, rx) = bounded(1); + + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() + { + let token = client + .publish("hello/world", qos, false, vec![1; i]) + .unwrap(); + let tx = tx.clone(); + thread::spawn(move || { + let res = token.wait(); + tx.send(res).unwrap() + }); + } + + // Try resolving a promise, if it is waiting to resolve, try again after a sleep of 1s + let mut token = client + .publish("hello/world", QoS::AtMostOnce, false, vec![1; 4]) + .unwrap(); + thread::spawn(move || loop { + match token.check() { + Err(TokenError::Waiting) => { + println!("Promise yet to resolve, retrying"); + sleep(Duration::from_secs(1)); + } + res => { + tx.send(res).unwrap(); + break; + } + } + }); + + while let Ok(res) = rx.recv() { + match res { + Ok(ack) => println!("Acknowledged Pub({ack:?})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Unsubscribe and wait for broker acknowledgement + match client.unsubscribe("hello/world").unwrap().wait() { + Ok(ack) => println!("Acknowledged Unsub({ack:?})"), + Err(e) => println!("Unsubscription failed: {e:?}"), + } + + Ok(()) +} diff --git a/rumqttc/examples/ack_promise_v5.rs b/rumqttc/examples/ack_promise_v5.rs new file mode 100644 index 000000000..de2fdf566 --- /dev/null +++ b/rumqttc/examples/ack_promise_v5.rs @@ -0,0 +1,81 @@ +use tokio::task::{self, JoinSet}; + +use rumqttc::v5::{mqttbytes::QoS, AsyncClient, MqttOptions}; +use std::error::Error; +use std::time::Duration; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883); + mqttoptions.set_keep_alive(Duration::from_secs(5)); + + let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10); + task::spawn(async move { + loop { + let event = eventloop.poll().await; + match &event { + Ok(v) => { + println!("Event = {v:?}"); + } + Err(e) => { + println!("Error = {e:?}"); + } + } + } + }); + + // Subscribe and wait for broker acknowledgement + match client + .subscribe("hello/world", QoS::AtMostOnce) + .await + .unwrap() + .await + { + Ok(pkid) => println!("Acknowledged Sub({pkid:?})"), + Err(e) => println!("Subscription failed: {e:?}"), + } + + // Publish at all QoS levels and wait for broker acknowledgement + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() + { + match client + .publish("hello/world", qos, false, vec![1; i]) + .await + .unwrap() + .await + { + Ok(pkid) => println!("Acknowledged Pub({pkid:?})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Publish with different QoS levels and spawn wait for notification + let mut set = JoinSet::new(); + for (i, qos) in [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce] + .into_iter() + .enumerate() + { + let token = client + .publish("hello/world", qos, false, vec![1; i]) + .await + .unwrap(); + set.spawn(token); + } + + while let Some(Ok(res)) = set.join_next().await { + match res { + Ok(pkid) => println!("Acknowledged Pub({pkid:?})"), + Err(e) => println!("Publish failed: {e:?}"), + } + } + + // Unsubscribe and wait for broker acknowledgement + match client.unsubscribe("hello/world").await.unwrap().await { + Ok(pkid) => println!("Acknowledged Unsub({pkid:?})"), + Err(e) => println!("Unsubscription failed: {e:?}"), + } + + Ok(()) +} diff --git a/rumqttc/src/client.rs b/rumqttc/src/client.rs index cb58cf82d..4ed010266 100644 --- a/rumqttc/src/client.rs +++ b/rumqttc/src/client.rs @@ -3,7 +3,11 @@ use std::time::Duration; use crate::mqttbytes::{v4::*, QoS}; -use crate::{valid_filter, valid_topic, ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::tokens::{NoResponse, Resolver, Token}; +use crate::{ + valid_filter, valid_topic, AckOfAck, AckOfPub, ConnectionError, Event, EventLoop, MqttOptions, + Request, +}; use bytes::Bytes; use flume::{SendError, Sender, TrySendError}; @@ -72,20 +76,22 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, V: Into>, { + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish, resolver); if !valid_topic(&topic) { - return Err(ClientError::Request(publish)); + return Err(ClientError::Request(request)); } - self.request_tx.send_async(publish).await?; - Ok(()) + self.request_tx.send_async(request).await?; + + Ok(token) } /// Attempts to send a MQTT Publish to the `EventLoop`. @@ -95,39 +101,44 @@ impl AsyncClient { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, V: Into>, { + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let request = Request::Publish(publish, resolver); if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); + return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send(publish)?; - Ok(()) + self.request_tx.try_send(request)?; + + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - + pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { self.request_tx.send_async(ack).await?; } - Ok(()) + + Ok(token) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { self.request_tx.try_send(ack)?; } - Ok(()) + + Ok(token) } /// Sends a MQTT Publish to the `EventLoop` @@ -137,102 +148,144 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, { + let (resolver, token) = Resolver::new(); let mut publish = Publish::from_bytes(topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); - self.request_tx.send_async(publish).await?; - Ok(()) + let request = Request::Publish(publish, resolver); + self.request_tx.send_async(request).await?; + + Ok(token) } /// Sends a MQTT Subscribe to the `EventLoop` - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub async fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } + self.request_tx.send_async(request).await?; - self.request_tx.send_async(subscribe.into()).await?; - Ok(()) + Ok(token) } /// Attempts to send a MQTT Subscribe to the `EventLoop` - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } + self.request_tx.try_send(request)?; - self.request_tx.try_send(subscribe.into())?; - Ok(()) + Ok(token) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } + self.request_tx.send_async(request).await?; - self.request_tx.send_async(subscribe.into()).await?; - Ok(()) + Ok(token) } /// Attempts to send a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } - self.request_tx.try_send(subscribe.into())?; - Ok(()) + self.request_tx.try_send(request)?; + + Ok(token) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub async fn unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(unsubscribe, resolver); self.request_tx.send_async(request).await?; - Ok(()) + + Ok(token) } /// Attempts to send a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(unsubscribe, resolver); self.request_tx.try_send(request)?; - Ok(()) + + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect(Disconnect); + pub async fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); self.request_tx.send_async(request).await?; - Ok(()) + + Ok(token) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect(Disconnect); + pub fn try_disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); self.request_tx.try_send(request)?; - Ok(()) + + Ok(token) } } -fn get_ack_req(publish: &Publish) -> Option { +fn get_ack_req(publish: &Publish, resolver: Resolver) -> Option { let ack = match publish.qos { - QoS::AtMostOnce => return None, - QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid)), + QoS::AtMostOnce => { + resolver.resolve(AckOfAck::None); + return None; + } + QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid), resolver), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid), resolver), }; Some(ack) } @@ -285,20 +338,22 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, V: Into>, { + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; - Ok(()) + + Ok(token) } pub fn try_publish( @@ -307,63 +362,76 @@ impl Client { qos: QoS, retain: bool, payload: V, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, V: Into>, { - self.client.try_publish(topic, qos, retain, payload)?; - Ok(()) + self.client.try_publish(topic, qos, retain, payload) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); - + pub fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { self.client.request_tx.send(ack)?; } - Ok(()) + + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - self.client.try_ack(publish)?; - Ok(()) + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the `EventLoop` - pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new(topic, qos); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } + self.client.request_tx.send(request)?; - self.client.request_tx.send(subscribe.into())?; - Ok(()) + Ok(token) } /// Sends a MQTT Subscribe to the `EventLoop` - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { - self.client.try_subscribe(topic, qos)?; - Ok(()) + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result, ClientError> { + self.client.try_subscribe(topic, qos) } /// Sends a MQTT Subscribe for multiple topics to the `EventLoop` - pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } + self.client.request_tx.send(request)?; - self.client.request_tx.send(subscribe.into())?; - Ok(()) + Ok(token) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -371,30 +439,35 @@ impl Client { } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic.into()); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(unsubscribe, resolver); self.client.request_tx.send(request)?; - Ok(()) + + Ok(token) } /// Sends a MQTT Unsubscribe to the `EventLoop` - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { - self.client.try_unsubscribe(topic)?; - Ok(()) + pub fn try_unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { + self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect(Disconnect); + pub fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); self.client.request_tx.send(request)?; - Ok(()) + + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.client.try_disconnect()?; - Ok(()) + pub fn try_disconnect(&self) -> Result, ClientError> { + self.client.try_disconnect() } } diff --git a/rumqttc/src/eventloop.rs b/rumqttc/src/eventloop.rs index d31690d99..00063bd0f 100644 --- a/rumqttc/src/eventloop.rs +++ b/rumqttc/src/eventloop.rs @@ -134,7 +134,7 @@ impl EventLoop { requests_in_channel.retain(|request| { match request { - Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack + Request::PubAck(..) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, } }); @@ -260,7 +260,7 @@ impl EventLoop { let timeout = self.keepalive_timeout.as_mut().unwrap(); timeout.as_mut().reset(Instant::now() + self.mqtt_options.keep_alive); - if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq(PingReq))? { + if let Some(outgoing) = self.state.handle_outgoing_packet(Request::PingReq)? { network.write(outgoing).await?; } match time::timeout(network_timeout, network.flush()).await { diff --git a/rumqttc/src/lib.rs b/rumqttc/src/lib.rs index 07694ffaf..8cf870c80 100644 --- a/rumqttc/src/lib.rs +++ b/rumqttc/src/lib.rs @@ -133,10 +133,9 @@ type RequestModifierFn = Arc< #[cfg(feature = "proxy")] mod proxy; +mod tokens; -pub use client::{ - AsyncClient, Client, ClientError, Connection, Iter, RecvError, RecvTimeoutError, TryRecvError, -}; +pub use client::{AsyncClient, Client, ClientError, Connection, Iter, RecvError, RecvTimeoutError}; pub use eventloop::{ConnectionError, Event, EventLoop}; pub use mqttbytes::v4::*; pub use mqttbytes::*; @@ -145,6 +144,8 @@ use rustls_native_certs::load_native_certs; pub use state::{MqttState, StateError}; #[cfg(any(feature = "use-rustls", feature = "use-native-tls"))] pub use tls::Error as TlsError; +use tokens::Resolver; +pub use tokens::{Token, TokenError}; #[cfg(feature = "use-native-tls")] pub use tokio_native_tls; #[cfg(feature = "use-native-tls")] @@ -159,6 +160,21 @@ pub use proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; +/// Used to encapsulate all publish/pubrec acknowledgements in v4 +#[derive(Debug, PartialEq)] +pub enum AckOfPub { + PubAck(PubAck), + PubComp(PubComp), + None, +} + +/// Used to encapsulate all ack/pubrel acknowledgements in v4 +#[derive(Debug)] +pub enum AckOfAck { + None, + PubRel(PubRel), +} + /// Current outgoing activity on the eventloop #[derive(Debug, Clone, PartialEq, Eq)] pub enum Outgoing { @@ -188,39 +204,20 @@ pub enum Outgoing { /// Requests by the client to mqtt event loop. Request are /// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Debug)] pub enum Request { - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubComp(PubComp), - PubRel(PubRel), - PingReq(PingReq), - PingResp(PingResp), - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect(Disconnect), -} - -impl From for Request { - fn from(publish: Publish) -> Request { - Request::Publish(publish) - } -} - -impl From for Request { - fn from(subscribe: Subscribe) -> Request { - Request::Subscribe(subscribe) - } + Publish(Publish, Resolver), + PubAck(PubAck, Resolver), + PubRec(PubRec, Resolver), + PubRel(PubRel, Resolver), + Subscribe(Subscribe, Resolver), + Unsubscribe(Unsubscribe, Resolver), + Disconnect(Resolver<()>), + PingReq, } -impl From for Request { - fn from(unsubscribe: Unsubscribe) -> Request { - Request::Unsubscribe(unsubscribe) - } -} +/// Packet Identifier with which Publish/Subscribe/Unsubscribe packets are identified while inflight. +pub type Pkid = u16; /// Transport methods. Defaults to TCP. #[derive(Clone)] diff --git a/rumqttc/src/state.rs b/rumqttc/src/state.rs index f7cb34841..f52bcdf32 100644 --- a/rumqttc/src/state.rs +++ b/rumqttc/src/state.rs @@ -1,9 +1,10 @@ -use crate::{Event, Incoming, Outgoing, Request}; +use crate::{tokens::Resolver, Event, Incoming, Outgoing, Request}; +use crate::{AckOfAck, AckOfPub, Pkid}; use crate::mqttbytes::v4::*; use crate::mqttbytes::{self, *}; use fixedbitset::FixedBitSet; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::{io, time::Instant}; /// Errors during state handling @@ -40,7 +41,7 @@ pub enum StateError { // This is done for 2 reasons // Bad acks or out of order acks aren't O(n) causing cpu spikes // Any missing acks from the broker are detected during the next recycled use of packet ids -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct MqttState { /// Status of last ping pub await_pingresp: bool, @@ -67,11 +68,19 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option, + pub collision: Option<(Publish, Resolver)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately pub manual_acks: bool, + /// Waiters for publish acknowledgements + pub_ack_waiter: HashMap>, + /// Waiters for PubRel, qos 2 + pub_rel_waiter: HashMap>, + /// Waiters for subscribe acknowledgements + sub_ack_waiter: HashMap>, + /// Waiters for unsubscribe acknowledgements + unsub_ack_waiter: HashMap>, } impl MqttState { @@ -96,6 +105,10 @@ impl MqttState { // TODO: Optimize these sizes later events: VecDeque::with_capacity(100), manual_acks, + pub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + pub_rel_waiter: HashMap::with_capacity(max_inflight as usize), + sub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + unsub_ack_waiter: HashMap::with_capacity(max_inflight as usize), } } @@ -108,16 +121,25 @@ impl MqttState { for publish in second_half.iter_mut().chain(first_half) { if let Some(publish) = publish.take() { - let request = Request::Publish(publish); + let resolver = self.pub_ack_waiter.remove(&publish.pkid).unwrap(); + let request = Request::Publish(publish, resolver); pending.push(request); } } // remove and collect pending releases for pkid in self.outgoing_rel.ones() { - let request = Request::PubRel(PubRel::new(pkid as u16)); + let pkid = pkid as u16; + let resolver = self.pub_ack_waiter.remove(&pkid).unwrap(); + let request = Request::PubRel(PubRel::new(pkid), resolver); pending.push(request); } + + // we don't retransmit subscribe and unsubscribe packet + // so we can clear their state + self.sub_ack_waiter.clear(); + self.unsub_ack_waiter.clear(); + self.outgoing_rel.clear(); // remove packet ids of incoming qos2 publishes @@ -140,15 +162,24 @@ impl MqttState { request: Request, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, - Request::PingReq(_) => self.outgoing_ping()?, - Request::Disconnect(_) => self.outgoing_disconnect()?, - Request::PubAck(puback) => self.outgoing_puback(puback)?, - Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, - _ => unimplemented!(), + Request::Publish(publish, resolver) => self.outgoing_publish(publish, resolver)?, + Request::PubRel(pubrel, resolver) => self.outgoing_pubrel(pubrel, resolver)?, + Request::Subscribe(subscribe, resolver) => { + self.outgoing_subscribe(subscribe, resolver)? + } + Request::Unsubscribe(unsubscribe, resolver) => { + self.outgoing_unsubscribe(unsubscribe, resolver)? + } + Request::PingReq => self.outgoing_ping()?, + Request::Disconnect(resolver) => { + resolver.resolve(()); + self.outgoing_disconnect()? + } + Request::PubAck(puback, resolver) => { + resolver.resolve(AckOfAck::None); + self.outgoing_puback(puback)? + } + Request::PubRec(pubrec, resolver) => self.outgoing_pubrec(pubrec, resolver)?, }; self.last_outgoing = Instant::now(); @@ -165,11 +196,11 @@ impl MqttState { ) -> Result, StateError> { self.events.push_back(Event::Incoming(packet.clone())); - let outgoing = match &packet { + let outgoing = match packet { Incoming::PingResp => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, - Incoming::SubAck(_suback) => self.handle_incoming_suback()?, - Incoming::UnsubAck(_unsuback) => self.handle_incoming_unsuback()?, + Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, + Incoming::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback)?, Incoming::PubAck(puback) => self.handle_incoming_puback(puback)?, Incoming::PubRec(pubrec) => self.handle_incoming_pubrec(pubrec)?, Incoming::PubRel(pubrel) => self.handle_incoming_pubrel(pubrel)?, @@ -184,17 +215,32 @@ impl MqttState { Ok(outgoing) } - fn handle_incoming_suback(&mut self) -> Result, StateError> { + fn handle_incoming_suback(&mut self, suback: SubAck) -> Result, StateError> { + let Some(resolver) = self.sub_ack_waiter.remove(&suback.pkid) else { + return Err(StateError::Unsolicited(suback.pkid)); + }; + + resolver.resolve(suback); + Ok(None) } - fn handle_incoming_unsuback(&mut self) -> Result, StateError> { + fn handle_incoming_unsuback( + &mut self, + unsuback: UnsubAck, + ) -> Result, StateError> { + let Some(resolver) = self.unsub_ack_waiter.remove(&unsuback.pkid) else { + return Err(StateError::Unsolicited(unsuback.pkid)); + }; + + resolver.resolve(unsuback); + Ok(None) } /// Results in a publish notification in all the QoS cases. Replys with an ack /// in case of QoS1 and Replys rec in case of QoS while also storing the message - fn handle_incoming_publish(&mut self, publish: &Publish) -> Result, StateError> { + fn handle_incoming_publish(&mut self, publish: Publish) -> Result, StateError> { let qos = publish.qos; match qos { @@ -212,34 +258,44 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid); - return self.outgoing_pubrec(pubrec); + let (resolver, _) = Resolver::new(); + return self.outgoing_pubrec(pubrec, resolver); } Ok(None) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self + fn handle_incoming_puback(&mut self, puback: PubAck) -> Result, StateError> { + let pkid = puback.pkid; + let p = self .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; + .get_mut(pkid as usize) + .ok_or(StateError::Unsolicited(pkid))?; - self.last_puback = puback.pkid; + self.last_puback = pkid; - if publish.take().is_none() { - error!("Unsolicited puback packet: {:?}", puback.pkid); - return Err(StateError::Unsolicited(puback.pkid)); + if p.take().is_none() { + error!("Unsolicited puback packet: {pkid:?}"); + return Err(StateError::Unsolicited(pkid)); } + let Some(resolver) = self.pub_ack_waiter.remove(&pkid) else { + return Err(StateError::Unsolicited(pkid)); + }; + + // Resolve promise for QoS 1 + resolver.resolve(AckOfPub::PubAck(puback)); + self.inflight -= 1; - let packet = self.check_collision(puback.pkid).map(|publish| { + let packet = self.check_collision(pkid).map(|(publish, resolver)| { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; + self.pub_ack_waiter.insert(publish.pkid, resolver); Packet::Publish(publish) }); @@ -247,13 +303,14 @@ impl MqttState { Ok(packet) } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { - let publish = self + fn handle_incoming_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { + if self .outgoing_pub .get_mut(pubrec.pkid as usize) - .ok_or(StateError::Unsolicited(pubrec.pkid))?; - - if publish.take().is_none() { + .ok_or(StateError::Unsolicited(pubrec.pkid))? + .take() + .is_none() + { error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); return Err(StateError::Unsolicited(pubrec.pkid)); } @@ -267,32 +324,45 @@ impl MqttState { Ok(Some(Packet::PubRel(pubrel))) } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { - if !self.incoming_pub.contains(pubrel.pkid as usize) { - error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); - return Err(StateError::Unsolicited(pubrel.pkid)); + fn handle_incoming_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { + let pkid = pubrel.pkid; + if !self.incoming_pub.contains(pkid as usize) { + error!("Unsolicited pubrel packet: {:?}", pkid); + return Err(StateError::Unsolicited(pkid)); } - self.incoming_pub.set(pubrel.pkid as usize, false); - let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid)); - let pubcomp = PubComp { pkid: pubrel.pkid }; + let resolver = self.pub_rel_waiter.remove(&pkid).unwrap(); + resolver.resolve(AckOfAck::PubRel(pubrel)); + + self.incoming_pub.set(pkid as usize, false); + let event = Event::Outgoing(Outgoing::PubComp(pkid)); + let pubcomp = PubComp { pkid }; self.events.push_back(event); Ok(Some(Packet::PubComp(pubcomp))) } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { - error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); - return Err(StateError::Unsolicited(pubcomp.pkid)); + fn handle_incoming_pubcomp(&mut self, pubcomp: PubComp) -> Result, StateError> { + let pkid = pubcomp.pkid; + if !self.outgoing_rel.contains(pkid as usize) { + error!("Unsolicited pubcomp packet: {pkid:?}"); + return Err(StateError::Unsolicited(pkid)); } - self.outgoing_rel.set(pubcomp.pkid as usize, false); + let Some(resolver) = self.pub_ack_waiter.remove(&pkid) else { + return Err(StateError::Unsolicited(pkid)); + }; + + // Resolve promise for QoS 2 + resolver.resolve(AckOfPub::PubComp(pubcomp)); + + self.outgoing_rel.set(pkid as usize, false); self.inflight -= 1; - let packet = self.check_collision(pubcomp.pkid).map(|publish| { + let packet = self.check_collision(pkid).map(|(publish, resolver)| { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); self.collision_ping_count = 0; + self.pub_ack_waiter.insert(publish.pkid, resolver); Packet::Publish(publish) }); @@ -308,7 +378,11 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + resolver: Resolver, + ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -322,7 +396,7 @@ impl MqttState { .is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some(publish); + self.collision = Some((publish, resolver)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -343,16 +417,26 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); self.events.push_back(event); + if publish.qos == QoS::AtMostOnce { + resolver.resolve(AckOfPub::None); + } else { + self.pub_ack_waiter.insert(publish.pkid, resolver); + } Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + resolver: Resolver, + ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); + self.pub_ack_waiter.insert(pubrel.pkid, resolver); Ok(Some(Packet::PubRel(pubrel))) } @@ -364,9 +448,14 @@ impl MqttState { Ok(Some(Packet::PubAck(puback))) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { + fn outgoing_pubrec( + &mut self, + pubrec: PubRec, + resolver: Resolver, + ) -> Result, StateError> { let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid)); self.events.push_back(event); + self.pub_rel_waiter.insert(pubrec.pkid, resolver); Ok(Some(Packet::PubRec(pubrec))) } @@ -409,6 +498,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, + resolver: Resolver, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -424,6 +514,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Subscribe(subscription.pkid)); self.events.push_back(event); + self.sub_ack_waiter.insert(subscription.pkid, resolver); Ok(Some(Packet::Subscribe(subscription))) } @@ -431,6 +522,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, + resolver: Resolver, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -442,6 +534,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Unsubscribe(unsub.pkid)); self.events.push_back(event); + self.unsub_ack_waiter.insert(unsub.pkid, resolver); Ok(Some(Packet::Unsubscribe(unsub))) } @@ -455,8 +548,8 @@ impl MqttState { Ok(Some(Packet::Disconnect)) } - fn check_collision(&mut self, pkid: u16) -> Option { - if let Some(publish) = &self.collision { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { + if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); } @@ -504,7 +597,8 @@ impl MqttState { mod test { use super::{MqttState, StateError}; use crate::mqttbytes::v4::*; - use crate::mqttbytes::*; + use crate::tokens::Resolver; + use crate::{mqttbytes::*, Pkid}; use crate::{Event, Incoming, Outgoing, Request}; fn build_outgoing_publish(qos: QoS) -> Publish { @@ -555,7 +649,8 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -563,12 +658,14 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -576,12 +673,14 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -595,9 +694,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); // only qos2 publish should be add to queue assert!(mqtt.incoming_pub.contains(3)); @@ -612,9 +711,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { assert_eq!(pkid, 2); @@ -639,9 +738,9 @@ mod test { let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&publish1).unwrap(); - mqtt.handle_incoming_publish(&publish2).unwrap(); - mqtt.handle_incoming_publish(&publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); assert!(mqtt.incoming_pub.contains(3)); @@ -653,7 +752,7 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); + let packet = mqtt.handle_incoming_publish(publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), _ => panic!("Invalid network request: {:?}", packet), @@ -667,14 +766,16 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish1, resolver).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish2, resolver).unwrap(); assert_eq!(mqtt.inflight, 2); - mqtt.handle_incoming_puback(&PubAck::new(1)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1)).unwrap(); assert_eq!(mqtt.inflight, 1); - mqtt.handle_incoming_puback(&PubAck::new(2)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2)).unwrap(); assert_eq!(mqtt.inflight, 0); assert!(mqtt.outgoing_pub[1].is_none()); @@ -685,7 +786,7 @@ mod test { fn incoming_puback_with_pkid_greater_than_max_inflight_should_be_handled_gracefully() { let mut mqtt = build_mqttstate(); - let got = mqtt.handle_incoming_puback(&PubAck::new(101)).unwrap_err(); + let got = mqtt.handle_incoming_puback(PubAck::new(101)).unwrap_err(); match got { StateError::Unsolicited(pkid) => assert_eq!(pkid, 101), @@ -700,10 +801,12 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish1, resolver); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish2, resolver); - mqtt.handle_incoming_pubrec(&PubRec::new(2)).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(2)).unwrap(); assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 @@ -719,14 +822,15 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - let packet = mqtt.outgoing_publish(publish).unwrap().unwrap(); + let resolver = Resolver::mock(); + let packet = mqtt.outgoing_publish(publish, resolver).unwrap().unwrap(); match packet { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } let packet = mqtt - .handle_incoming_pubrec(&PubRec::new(1)) + .handle_incoming_pubrec(PubRec::new(1)) .unwrap() .unwrap(); match packet { @@ -740,14 +844,14 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - let packet = mqtt.handle_incoming_publish(&publish).unwrap().unwrap(); + let packet = mqtt.handle_incoming_publish(publish).unwrap().unwrap(); match packet { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } let packet = mqtt - .handle_incoming_pubrel(&PubRel::new(1)) + .handle_incoming_pubrel(PubRel::new(1)) .unwrap() .unwrap(); match packet { @@ -761,10 +865,11 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); - mqtt.handle_incoming_pubrec(&PubRec::new(1)).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(1)).unwrap(); - mqtt.handle_incoming_pubcomp(&PubComp::new(1)).unwrap(); + mqtt.handle_incoming_pubcomp(PubComp::new(1)).unwrap(); assert_eq!(mqtt.inflight, 0); } @@ -775,7 +880,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) + let resolver = Resolver::mock(); + mqtt.handle_outgoing_packet(Request::Publish(publish, resolver)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1))) .unwrap(); @@ -804,8 +910,8 @@ mod test { fn clean_is_calculating_pending_correctly() { let mut mqtt = build_mqttstate(); - fn build_outgoing_pub() -> Vec> { - vec![ + fn build_outgoing_pub(state: &mut MqttState) { + state.outgoing_pub = vec![ None, Some(Publish { dup: false, @@ -841,39 +947,47 @@ mod test { pkid: 6, payload: "".into(), }), - ] + ]; + for (i, _) in state + .outgoing_pub + .iter() + .enumerate() + .filter(|(_, p)| p.is_some()) + { + state.pub_ack_waiter.insert(i as Pkid, Resolver::mock()); + } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 3; let requests = mqtt.clean(); let res = vec![6, 1, 2, 3]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 0; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() } } - mqtt.outgoing_pub = build_outgoing_pub(); + build_outgoing_pub(&mut mqtt); mqtt.last_puback = 6; let requests = mqtt.clean(); let res = vec![1, 2, 3, 6]; for (req, idx) in requests.iter().zip(res) { - if let Request::Publish(publish) = req { + if let Request::Publish(publish, _) = &req { assert_eq!(publish.pkid, idx); } else { unreachable!() diff --git a/rumqttc/src/tokens.rs b/rumqttc/src/tokens.rs new file mode 100644 index 000000000..040fedb4e --- /dev/null +++ b/rumqttc/src/tokens.rs @@ -0,0 +1,94 @@ +use std::{ + fmt::Debug, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::sync::oneshot::{self, error::TryRecvError}; + +#[derive(Debug, thiserror::Error)] +pub enum TokenError { + #[error("Sender has nothing to send instantly")] + Waiting, + #[error("Sender side of channel was dropped")] + Disconnected, +} + +pub type NoResponse = (); + +/// Resolves with [`Pkid`] used against packet when: +/// 1. Packet is acknowldged by the broker, e.g. QoS 1/2 Publish, Subscribe and Unsubscribe +/// 2. QoS 0 packet finishes processing in the [`EventLoop`] +pub struct Token { + rx: oneshot::Receiver, +} + +impl Future for Token { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let polled = unsafe { self.map_unchecked_mut(|s| &mut s.rx) }.poll(cx); + + match polled { + Poll::Ready(Ok(p)) => Poll::Ready(Ok(p)), + Poll::Ready(Err(_)) => Poll::Ready(Err(TokenError::Disconnected)), + Poll::Pending => Poll::Pending, + } + } +} + +/// There is a type of token returned for each type of [`Request`] when it is created and +/// sent to the [`EventLoop`] for further processing from the [`Client`]/[`AsyncClient`]. +/// Some tokens such as those associated with the resolve with the `pkid` value used in the packet sent to the broker while other +/// tokens don't return such a value. +impl Token { + /// Blocks on the current thread and waits till the packet completes being handled. + /// + /// ## Errors + /// Returns [`TokenError::Disconnected`] if the [`EventLoop`] was dropped(usually), + /// [`TokenError::Rejection`] if the packet acknowledged but not accepted. + pub fn wait(self) -> Result { + self.rx + .blocking_recv() + .map_err(|_| TokenError::Disconnected) + } + + /// Attempts to check if the packet handling has been completed, without blocking the current thread. + /// + /// ## Errors + /// Returns [`TokenError::Waiting`] if the packet wasn't acknowledged yet. + /// Multiple calls to this functions can fail with [`TokenError::Disconnected`] + /// if the promise has already been resolved. + pub fn check(&mut self) -> Result { + self.rx.try_recv().map_err(|e| match e { + TryRecvError::Empty => TokenError::Waiting, + TryRecvError::Closed => TokenError::Disconnected, + }) + } +} + +#[derive(Debug)] +pub struct Resolver { + tx: oneshot::Sender, +} + +impl Resolver { + pub fn new() -> (Self, Token) { + let (tx, rx) = oneshot::channel(); + + (Self { tx }, Token { rx }) + } + + #[cfg(test)] + pub fn mock() -> Self { + let (tx, _) = oneshot::channel(); + + Self { tx } + } + + pub fn resolve(self, resolved: T) { + if self.tx.send(resolved).is_err() { + trace!("Promise was dropped") + } + } +} diff --git a/rumqttc/src/v5/client.rs b/rumqttc/src/v5/client.rs index 7a86333f2..e3f0f04c5 100644 --- a/rumqttc/src/v5/client.rs +++ b/rumqttc/src/v5/client.rs @@ -3,11 +3,12 @@ use std::time::Duration; use super::mqttbytes::v5::{ - Filter, PubAck, PubRec, Publish, PublishProperties, Subscribe, SubscribeProperties, - Unsubscribe, UnsubscribeProperties, + Filter, PubAck, PubRec, Publish, PublishProperties, SubAck, Subscribe, SubscribeProperties, + UnsubAck, Unsubscribe, UnsubscribeProperties, }; use super::mqttbytes::QoS; -use super::{ConnectionError, Event, EventLoop, MqttOptions, Request}; +use super::{AckOfAck, AckOfPub, ConnectionError, Event, EventLoop, MqttOptions, Request}; +use crate::tokens::{NoResponse, Resolver, Token}; use crate::{valid_filter, valid_topic}; use bytes::Bytes; @@ -37,7 +38,7 @@ impl From> for ClientError { } } -/// An asynchronous client, communicates with MQTT `EventLoop`. +// An asynchronous client, communicates with MQTT `EventLoop`. /// /// This is cloneable and can be used to asynchronously [`publish`](`AsyncClient::publish`), /// [`subscribe`](`AsyncClient::subscribe`) through the `EventLoop`, which is to be polled parallelly. @@ -78,20 +79,22 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, { + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.request_tx.send_async(publish).await?; - Ok(()) + + Ok(token) } pub async fn publish_with_properties( @@ -101,7 +104,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -116,7 +119,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -132,20 +135,22 @@ impl AsyncClient { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, { + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::TryRequest(publish)); } self.request_tx.try_send(publish)?; - Ok(()) + + Ok(token) } pub fn try_publish_with_properties( @@ -155,7 +160,7 @@ impl AsyncClient { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -169,7 +174,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -178,22 +183,26 @@ impl AsyncClient { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub async fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub async fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { self.request_tx.send_async(ack).await?; } - Ok(()) + + Ok(token) } /// Attempts to send a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { self.request_tx.try_send(ack)?; } - Ok(()) + + Ok(token) } /// Sends a MQTT Publish to the `EventLoop` @@ -204,19 +213,18 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: Option, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, { + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); - if !valid_topic(&topic) { - return Err(ClientError::TryRequest(publish)); - } + let publish = Request::Publish(publish, resolver); self.request_tx.send_async(publish).await?; - Ok(()) + + Ok(token) } pub async fn publish_bytes_with_properties( @@ -226,7 +234,7 @@ impl AsyncClient { retain: bool, payload: Bytes, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, { @@ -240,7 +248,7 @@ impl AsyncClient { qos: QoS, retain: bool, payload: Bytes, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, { @@ -254,15 +262,18 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } + self.request_tx.send_async(request).await?; - self.request_tx.send_async(subscribe.into()).await?; - Ok(()) + Ok(token) } pub async fn subscribe_with_properties>( @@ -270,11 +281,15 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, Some(properties)).await } - pub async fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub async fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, None).await } @@ -284,15 +299,18 @@ impl AsyncClient { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } + self.request_tx.try_send(request)?; - self.request_tx.try_send(subscribe.into())?; - Ok(()) + Ok(token) } pub fn try_subscribe_with_properties>( @@ -300,11 +318,15 @@ impl AsyncClient { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { self.handle_try_subscribe(topic, qos, Some(properties)) } - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result, ClientError> { self.handle_try_subscribe(topic, qos, None) } @@ -313,32 +335,34 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where T: IntoIterator, { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } + self.request_tx.send_async(request).await?; - self.request_tx.send_async(subscribe.into()).await?; - - Ok(()) + Ok(token) } pub async fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)).await } - pub async fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub async fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -350,31 +374,34 @@ impl AsyncClient { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where T: IntoIterator, { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::TryRequest(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::TryRequest(request)); } + self.request_tx.try_send(request)?; - self.request_tx.try_send(subscribe.into())?; - Ok(()) + Ok(token) } pub fn try_subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where T: IntoIterator, { self.handle_try_subscribe_many(topics, Some(properties)) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -386,22 +413,27 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(unsubscribe, resolver); self.request_tx.send_async(request).await?; - Ok(()) + + Ok(token) } pub async fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { self.handle_unsubscribe(topic, Some(properties)).await } - pub async fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub async fn unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { self.handle_unsubscribe(topic, None).await } @@ -410,45 +442,57 @@ impl AsyncClient { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(unsubscribe, resolver); self.request_tx.try_send(request)?; - Ok(()) + + Ok(token) } pub fn try_unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { self.handle_try_unsubscribe(topic, Some(properties)) } - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { self.handle_try_unsubscribe(topic, None) } /// Sends a MQTT disconnect to the `EventLoop` - pub async fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; + pub async fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); self.request_tx.send_async(request).await?; - Ok(()) + + Ok(token) } /// Attempts to send a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; + pub fn try_disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); self.request_tx.try_send(request)?; - Ok(()) + + Ok(token) } } -fn get_ack_req(publish: &Publish) -> Option { +fn get_ack_req(publish: &Publish, resolver: Resolver) -> Option { let ack = match publish.qos { - QoS::AtMostOnce => return None, - QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid, None)), - QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid, None)), + QoS::AtMostOnce => { + resolver.resolve(AckOfAck::None); + return None; + } + QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid, None), resolver), + QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid, None), resolver), }; Some(ack) } @@ -503,20 +547,22 @@ impl Client { retain: bool, payload: P, properties: Option, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, { + let (resolver, token) = Resolver::new(); let topic = topic.into(); let mut publish = Publish::new(&topic, qos, payload, properties); publish.retain = retain; - let publish = Request::Publish(publish); + let publish = Request::Publish(publish, resolver); if !valid_topic(&topic) { return Err(ClientError::Request(publish)); } self.client.request_tx.send(publish)?; - Ok(()) + + Ok(token) } pub fn publish_with_properties( @@ -526,7 +572,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -540,7 +586,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -555,7 +601,7 @@ impl Client { retain: bool, payload: P, properties: PublishProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -570,7 +616,7 @@ impl Client { qos: QoS, retain: bool, payload: P, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where S: Into, P: Into, @@ -579,19 +625,20 @@ impl Client { } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn ack(&self, publish: &Publish) -> Result<(), ClientError> { - let ack = get_ack_req(publish); + pub fn ack(&self, publish: &Publish) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let ack = get_ack_req(publish, resolver); if let Some(ack) = ack { self.client.request_tx.send(ack)?; } - Ok(()) + + Ok(token) } /// Sends a MQTT PubAck to the `EventLoop`. Only needed in if `manual_acks` flag is set. - pub fn try_ack(&self, publish: &Publish) -> Result<(), ClientError> { - self.client.try_ack(publish)?; - Ok(()) + pub fn try_ack(&self, publish: &Publish) -> Result, ClientError> { + self.client.try_ack(publish) } /// Sends a MQTT Subscribe to the `EventLoop` @@ -600,15 +647,18 @@ impl Client { topic: S, qos: QoS, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let filter = Filter::new(topic, qos); let subscribe = Subscribe::new(filter, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } + self.client.request_tx.send(request)?; - self.client.request_tx.send(subscribe.into())?; - Ok(()) + Ok(token) } pub fn subscribe_with_properties>( @@ -616,11 +666,15 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, Some(properties)) } - pub fn subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result, ClientError> { self.handle_subscribe(topic, qos, None) } @@ -630,12 +684,16 @@ impl Client { topic: S, qos: QoS, properties: SubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { self.client .try_subscribe_with_properties(topic, qos, properties) } - pub fn try_subscribe>(&self, topic: S, qos: QoS) -> Result<(), ClientError> { + pub fn try_subscribe>( + &self, + topic: S, + qos: QoS, + ) -> Result, ClientError> { self.client.try_subscribe(topic, qos) } @@ -644,31 +702,34 @@ impl Client { &self, topics: T, properties: Option, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where T: IntoIterator, { + let (resolver, token) = Resolver::new(); let subscribe = Subscribe::new_many(topics, properties); - if !subscribe_has_valid_filters(&subscribe) { - return Err(ClientError::Request(subscribe.into())); + let is_valid = subscribe_has_valid_filters(&subscribe); + let request = Request::Subscribe(subscribe, resolver); + if !is_valid { + return Err(ClientError::Request(request)); } + self.client.request_tx.send(request)?; - self.client.request_tx.send(subscribe.into())?; - Ok(()) + Ok(token) } pub fn subscribe_many_with_properties( &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where T: IntoIterator, { self.handle_subscribe_many(topics, Some(properties)) } - pub fn subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -679,7 +740,7 @@ impl Client { &self, topics: T, properties: SubscribeProperties, - ) -> Result<(), ClientError> + ) -> Result, ClientError> where T: IntoIterator, { @@ -687,7 +748,7 @@ impl Client { .try_subscribe_many_with_properties(topics, properties) } - pub fn try_subscribe_many(&self, topics: T) -> Result<(), ClientError> + pub fn try_subscribe_many(&self, topics: T) -> Result, ClientError> where T: IntoIterator, { @@ -699,22 +760,24 @@ impl Client { &self, topic: S, properties: Option, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); let unsubscribe = Unsubscribe::new(topic, properties); - let request = Request::Unsubscribe(unsubscribe); + let request = Request::Unsubscribe(unsubscribe, resolver); self.client.request_tx.send(request)?; - Ok(()) + + Ok(token) } pub fn unsubscribe_with_properties>( &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { self.handle_unsubscribe(topic, Some(properties)) } - pub fn unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn unsubscribe>(&self, topic: S) -> Result, ClientError> { self.handle_unsubscribe(topic, None) } @@ -723,26 +786,30 @@ impl Client { &self, topic: S, properties: UnsubscribeProperties, - ) -> Result<(), ClientError> { + ) -> Result, ClientError> { self.client .try_unsubscribe_with_properties(topic, properties) } - pub fn try_unsubscribe>(&self, topic: S) -> Result<(), ClientError> { + pub fn try_unsubscribe>( + &self, + topic: S, + ) -> Result, ClientError> { self.client.try_unsubscribe(topic) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn disconnect(&self) -> Result<(), ClientError> { - let request = Request::Disconnect; + pub fn disconnect(&self) -> Result, ClientError> { + let (resolver, token) = Resolver::new(); + let request = Request::Disconnect(resolver); self.client.request_tx.send(request)?; - Ok(()) + + Ok(token) } /// Sends a MQTT disconnect to the `EventLoop` - pub fn try_disconnect(&self) -> Result<(), ClientError> { - self.client.try_disconnect()?; - Ok(()) + pub fn try_disconnect(&self) -> Result, ClientError> { + self.client.try_disconnect() } } diff --git a/rumqttc/src/v5/eventloop.rs b/rumqttc/src/v5/eventloop.rs index cd0568ada..ea361b4eb 100644 --- a/rumqttc/src/v5/eventloop.rs +++ b/rumqttc/src/v5/eventloop.rs @@ -130,7 +130,7 @@ impl EventLoop { requests_in_channel.retain(|request| { match request { - Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack + Request::PubAck(..) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack _ => true, } }); diff --git a/rumqttc/src/v5/mod.rs b/rumqttc/src/v5/mod.rs index 6e0e43931..d1b044b31 100644 --- a/rumqttc/src/v5/mod.rs +++ b/rumqttc/src/v5/mod.rs @@ -14,8 +14,8 @@ mod framed; pub mod mqttbytes; mod state; -use crate::Outgoing; -use crate::{NetworkOptions, Transport}; +use crate::tokens::Resolver; +use crate::{NetworkOptions, Outgoing, Transport}; use mqttbytes::v5::*; @@ -31,28 +31,33 @@ pub use crate::proxy::{Proxy, ProxyAuth, ProxyType}; pub type Incoming = Packet; -/// Requests by the client to mqtt event loop. Request are -/// handled one by one. -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum Request { - Publish(Publish), +/// Used to encapsulate all publish acknowledgents in v5 +#[derive(Debug)] +pub enum AckOfPub { PubAck(PubAck), - PubRec(PubRec), PubComp(PubComp), + None, +} + +/// Used to encapsulate all ack/pubrel acknowledgements in v5 +#[derive(Debug)] +pub enum AckOfAck { + None, PubRel(PubRel), - PingReq, - PingResp, - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - Disconnect, } -impl From for Request { - fn from(subscribe: Subscribe) -> Self { - Self::Subscribe(subscribe) - } +/// Requests by the client to mqtt event loop. Request are +/// handled one by one. +#[derive(Debug)] +pub enum Request { + Publish(Publish, Resolver), + PubAck(PubAck, Resolver), + PubRec(PubRec, Resolver), + PubRel(PubRel, Resolver), + Subscribe(Subscribe, Resolver), + Unsubscribe(Unsubscribe, Resolver), + Disconnect(Resolver<()>), + PingReq, } #[cfg(feature = "websocket")] diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 9a7485f3b..b49a2c2cd 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -1,11 +1,17 @@ -use super::mqttbytes::v5::{ - ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, - PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, Publish, - SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, +use crate::{tokens::Resolver, Pkid}; + +use super::{ + mqttbytes::{ + self, + v5::{ + ConnAck, ConnectReturnCode, Disconnect, DisconnectReasonCode, Packet, PingReq, PubAck, + PubAckReason, PubComp, PubCompReason, PubRec, PubRecReason, PubRel, PubRelReason, + Publish, SubAck, Subscribe, SubscribeReasonCode, UnsubAck, UnsubAckReason, Unsubscribe, + }, + Error as MqttError, QoS, + }, + AckOfAck, AckOfPub, Event, Incoming, Outgoing, Request, }; -use super::mqttbytes::{self, Error as MqttError, QoS}; - -use super::{Event, Incoming, Outgoing, Request}; use bytes::Bytes; use fixedbitset::FixedBitSet; @@ -74,7 +80,7 @@ impl From for StateError { // This is done for 2 reasons // Bad acks or out of order acks aren't O(n) causing cpu spikes // Any missing acks from the broker are detected during the next recycled use of packet ids -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct MqttState { /// Status of last ping pub await_pingresp: bool, @@ -97,7 +103,7 @@ pub struct MqttState { /// Packet ids on incoming QoS 2 publishes pub(crate) incoming_pub: FixedBitSet, /// Last collision due to broker not acking in order - pub collision: Option, + pub collision: Option<(Publish, Resolver)>, /// Buffered incoming packets pub events: VecDeque, /// Indicates if acknowledgements should be send immediately @@ -110,6 +116,14 @@ pub struct MqttState { pub(crate) max_outgoing_inflight: u16, /// Upper limit on the maximum number of allowed inflight QoS1 & QoS2 requests max_outgoing_inflight_upper_limit: u16, + /// Waiters for publish acknowledgements, qos 1/2 + pub_ack_waiter: HashMap>, + /// Waiters for PubRel, qos 2 + pub_rel_waiter: HashMap>, + /// Waiters for subscribe acknowledgements + sub_ack_waiter: HashMap>, + /// Waiters for unsubscribe acknowledgements + unsub_ack_waiter: HashMap>, } impl MqttState { @@ -137,6 +151,10 @@ impl MqttState { broker_topic_alias_max: 0, max_outgoing_inflight: max_inflight, max_outgoing_inflight_upper_limit: max_inflight, + pub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + pub_rel_waiter: HashMap::with_capacity(max_inflight as usize), + sub_ack_waiter: HashMap::with_capacity(max_inflight as usize), + unsub_ack_waiter: HashMap::with_capacity(max_inflight as usize), } } @@ -146,16 +164,24 @@ impl MqttState { // remove and collect pending publishes for publish in self.outgoing_pub.iter_mut() { if let Some(publish) = publish.take() { - let request = Request::Publish(publish); + let resolver = self.pub_ack_waiter.remove(&publish.pkid).unwrap(); + let request = Request::Publish(publish, resolver); pending.push(request); } } // remove and collect pending releases for pkid in self.outgoing_rel.ones() { - let request = Request::PubRel(PubRel::new(pkid as u16, None)); + let resolver = self.pub_ack_waiter.remove(&(pkid as u16)).unwrap(); + let request = Request::PubRel(PubRel::new(pkid as u16, None), resolver); pending.push(request); } + + // we don't retransmit subscribe and unsubscribe packet + // so we can clear their state + self.sub_ack_waiter.clear(); + self.unsub_ack_waiter.clear(); + self.outgoing_rel.clear(); // remove packed ids of incoming qos2 publishes @@ -178,17 +204,24 @@ impl MqttState { request: Request, ) -> Result, StateError> { let packet = match request { - Request::Publish(publish) => self.outgoing_publish(publish)?, - Request::PubRel(pubrel) => self.outgoing_pubrel(pubrel)?, - Request::Subscribe(subscribe) => self.outgoing_subscribe(subscribe)?, - Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?, + Request::Publish(publish, resolver) => self.outgoing_publish(publish, resolver)?, + Request::PubRel(pubrel, resolver) => self.outgoing_pubrel(pubrel, resolver)?, + Request::Subscribe(subscribe, resolver) => { + self.outgoing_subscribe(subscribe, resolver)? + } + Request::Unsubscribe(unsubscribe, resolver) => { + self.outgoing_unsubscribe(unsubscribe, resolver)? + } Request::PingReq => self.outgoing_ping()?, - Request::Disconnect => { + Request::Disconnect(resolver) => { + resolver.resolve(()); self.outgoing_disconnect(DisconnectReasonCode::NormalDisconnection)? } - Request::PubAck(puback) => self.outgoing_puback(puback)?, - Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?, - _ => unimplemented!(), + Request::PubAck(puback, resolver) => { + resolver.resolve(super::AckOfAck::None); + self.outgoing_puback(puback)? + } + Request::PubRec(pubrec, resolver) => self.outgoing_pubrec(pubrec, resolver)?, }; self.last_outgoing = Instant::now(); @@ -201,11 +234,11 @@ impl MqttState { /// be forwarded to user and Pubck packet will be written to network pub fn handle_incoming_packet( &mut self, - mut packet: Incoming, + packet: Incoming, ) -> Result, StateError> { self.events.push_back(Event::Incoming(packet.to_owned())); - let outgoing = match &mut packet { + let outgoing = match packet { Incoming::PingResp(_) => self.handle_incoming_pingresp()?, Incoming::Publish(publish) => self.handle_incoming_publish(publish)?, Incoming::SubAck(suback) => self.handle_incoming_suback(suback)?, @@ -231,10 +264,12 @@ impl MqttState { self.outgoing_disconnect(DisconnectReasonCode::ProtocolError) } - fn handle_incoming_suback( - &mut self, - suback: &mut SubAck, - ) -> Result, StateError> { + fn handle_incoming_suback(&mut self, suback: SubAck) -> Result, StateError> { + // Expected ack for a subscribe packet, not a publish packet + let Some(resolver) = self.sub_ack_waiter.remove(&suback.pkid) else { + return Err(StateError::Unsolicited(suback.pkid)); + }; + for reason in suback.return_codes.iter() { match reason { SubscribeReasonCode::Success(qos) => { @@ -245,25 +280,32 @@ impl MqttState { } } } + + resolver.resolve(suback); + Ok(None) } fn handle_incoming_unsuback( &mut self, - unsuback: &mut UnsubAck, + unsuback: UnsubAck, ) -> Result, StateError> { + let Some(resolver) = self.unsub_ack_waiter.remove(&unsuback.pkid) else { + return Err(StateError::Unsolicited(unsuback.pkid)); + }; + for reason in unsuback.reasons.iter() { if reason != &UnsubAckReason::Success { warn!("UnsubAck Pkid = {:?}, Reason = {:?}", unsuback.pkid, reason); } } + + resolver.resolve(unsuback); + Ok(None) } - fn handle_incoming_connack( - &mut self, - connack: &mut ConnAck, - ) -> Result, StateError> { + fn handle_incoming_connack(&mut self, connack: ConnAck) -> Result, StateError> { if connack.code != ConnectReturnCode::Success { return Err(StateError::ConnFail { reason: connack.code, @@ -287,7 +329,7 @@ impl MqttState { fn handle_incoming_disconn( &mut self, - disconn: &mut Disconnect, + disconn: Disconnect, ) -> Result, StateError> { let reason_code = disconn.reason_code; let reason_string = if let Some(props) = &disconn.properties { @@ -305,7 +347,7 @@ impl MqttState { /// in case of QoS1 and Replys rec in case of QoS while also storing the message fn handle_incoming_publish( &mut self, - publish: &mut Publish, + mut publish: Publish, ) -> Result, StateError> { let qos = publish.qos; @@ -341,23 +383,27 @@ impl MqttState { if !self.manual_acks { let pubrec = PubRec::new(pkid, None); - return self.outgoing_pubrec(pubrec); + let (resolver, _) = Resolver::new(); + return self.outgoing_pubrec(pubrec, resolver); } Ok(None) } } } - fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; - - if publish.take().is_none() { + fn handle_incoming_puback(&mut self, puback: PubAck) -> Result, StateError> { + let Some(resolver) = self.pub_ack_waiter.remove(&puback.pkid) else { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); - } + }; + + self.outgoing_pub + .get_mut(puback.pkid as usize) + .ok_or(StateError::Unsolicited(puback.pkid))? + .take(); + + // Resolve promise for QoS 1 + resolver.resolve(AckOfPub::PubAck(puback.clone())); self.inflight -= 1; @@ -371,7 +417,7 @@ impl MqttState { return Ok(None); } - if let Some(publish) = self.check_collision(puback.pkid) { + if let Some((publish, resolver)) = self.check_collision(puback.pkid) { self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); self.inflight += 1; @@ -379,6 +425,7 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); self.collision_ping_count = 0; + self.pub_ack_waiter.insert(puback.pkid, resolver); return Ok(Some(Packet::Publish(publish))); } @@ -386,7 +433,7 @@ impl MqttState { Ok(None) } - fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { + fn handle_incoming_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { let publish = self .outgoing_pub .get_mut(pubrec.pkid as usize) @@ -415,13 +462,16 @@ impl MqttState { Ok(Some(Packet::PubRel(PubRel::new(pubrec.pkid, None)))) } - fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result, StateError> { + fn handle_incoming_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { if !self.incoming_pub.contains(pubrel.pkid as usize) { error!("Unsolicited pubrel packet: {:?}", pubrel.pkid); return Err(StateError::Unsolicited(pubrel.pkid)); } self.incoming_pub.set(pubrel.pkid as usize, false); + let resolver = self.pub_rel_waiter.remove(&pubrel.pkid).unwrap(); + resolver.resolve(AckOfAck::PubRel(pubrel.clone())); + if pubrel.reason != PubRelReason::Success { warn!( "PubRel Pkid = {:?}, reason: {:?}", @@ -436,21 +486,26 @@ impl MqttState { Ok(Some(Packet::PubComp(PubComp::new(pubrel.pkid, None)))) } - fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result, StateError> { - let outgoing = self.check_collision(pubcomp.pkid).map(|publish| { - let pkid = publish.pkid; - let event = Event::Outgoing(Outgoing::Publish(pkid)); - self.events.push_back(event); - self.collision_ping_count = 0; - - Packet::Publish(publish) - }); - - if !self.outgoing_rel.contains(pubcomp.pkid as usize) { + fn handle_incoming_pubcomp(&mut self, pubcomp: PubComp) -> Result, StateError> { + let Some(resolver) = self.pub_ack_waiter.remove(&pubcomp.pkid) else { error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid); return Err(StateError::Unsolicited(pubcomp.pkid)); - } + }; + + // Resolve promise for QoS 2 + resolver.resolve(AckOfPub::PubComp(pubcomp.clone())); + self.outgoing_rel.set(pubcomp.pkid as usize, false); + let outgoing = self + .check_collision(pubcomp.pkid) + .map(|(publish, resolver)| { + let event = Event::Outgoing(Outgoing::Publish(publish.pkid)); + self.events.push_back(event); + self.collision_ping_count = 0; + self.pub_ack_waiter.insert(pubcomp.pkid, resolver); + + Packet::Publish(publish) + }); if pubcomp.reason != PubCompReason::Success { warn!( @@ -471,7 +526,11 @@ impl MqttState { /// Adds next packet identifier to QoS 1 and 2 publish packets and returns /// it buy wrapping publish in packet - fn outgoing_publish(&mut self, mut publish: Publish) -> Result, StateError> { + fn outgoing_publish( + &mut self, + mut publish: Publish, + resolver: Resolver, + ) -> Result, StateError> { if publish.qos != QoS::AtMostOnce { if publish.pkid == 0 { publish.pkid = self.next_pkid(); @@ -485,7 +544,7 @@ impl MqttState { .is_some() { info!("Collision on packet id = {:?}", publish.pkid); - self.collision = Some(publish); + self.collision = Some((publish, resolver)); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); self.events.push_back(event); return Ok(None); @@ -521,17 +580,27 @@ impl MqttState { let event = Event::Outgoing(Outgoing::Publish(pkid)); self.events.push_back(event); + if publish.qos == QoS::AtMostOnce { + resolver.resolve(AckOfPub::None) + } else { + self.pub_ack_waiter.insert(publish.pkid, resolver); + } Ok(Some(Packet::Publish(publish))) } - fn outgoing_pubrel(&mut self, pubrel: PubRel) -> Result, StateError> { + fn outgoing_pubrel( + &mut self, + pubrel: PubRel, + resolver: Resolver, + ) -> Result, StateError> { let pubrel = self.save_pubrel(pubrel)?; debug!("Pubrel. Pkid = {}", pubrel.pkid); let event = Event::Outgoing(Outgoing::PubRel(pubrel.pkid)); self.events.push_back(event); + self.pub_ack_waiter.insert(pubrel.pkid, resolver); Ok(Some(Packet::PubRel(PubRel::new(pubrel.pkid, None)))) } @@ -544,10 +613,15 @@ impl MqttState { Ok(Some(Packet::PubAck(puback))) } - fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result, StateError> { + fn outgoing_pubrec( + &mut self, + pubrec: PubRec, + resolver: Resolver, + ) -> Result, StateError> { let pkid = pubrec.pkid; let event = Event::Outgoing(Outgoing::PubRec(pkid)); self.events.push_back(event); + self.pub_rel_waiter.insert(pubrec.pkid, resolver); Ok(Some(Packet::PubRec(pubrec))) } @@ -587,6 +661,7 @@ impl MqttState { fn outgoing_subscribe( &mut self, mut subscription: Subscribe, + resolver: Resolver, ) -> Result, StateError> { if subscription.filters.is_empty() { return Err(StateError::EmptySubscription); @@ -603,6 +678,7 @@ impl MqttState { let pkid = subscription.pkid; let event = Event::Outgoing(Outgoing::Subscribe(pkid)); self.events.push_back(event); + self.sub_ack_waiter.insert(subscription.pkid, resolver); Ok(Some(Packet::Subscribe(subscription))) } @@ -610,6 +686,7 @@ impl MqttState { fn outgoing_unsubscribe( &mut self, mut unsub: Unsubscribe, + resolver: Resolver, ) -> Result, StateError> { let pkid = self.next_pkid(); unsub.pkid = pkid; @@ -622,6 +699,7 @@ impl MqttState { let pkid = unsub.pkid; let event = Event::Outgoing(Outgoing::Unsubscribe(pkid)); self.events.push_back(event); + self.unsub_ack_waiter.insert(unsub.pkid, resolver); Ok(Some(Packet::Unsubscribe(unsub))) } @@ -637,8 +715,8 @@ impl MqttState { Ok(Some(Packet::Disconnect(Disconnect::new(reason)))) } - fn check_collision(&mut self, pkid: u16) -> Option { - if let Some(publish) = &self.collision { + fn check_collision(&mut self, pkid: u16) -> Option<(Publish, Resolver)> { + if let Some((publish, _)) = &self.collision { if publish.pkid == pkid { return self.collision.take(); } @@ -684,6 +762,8 @@ impl MqttState { #[cfg(test)] mod test { + use crate::tokens::Resolver; + use super::mqttbytes::v5::*; use super::mqttbytes::*; use super::{Event, Incoming, Outgoing, Request}; @@ -737,7 +817,9 @@ mod test { let publish = build_outgoing_publish(QoS::AtMostOnce); // QoS 0 publish shouldn't be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 0); @@ -745,12 +827,15 @@ mod test { let publish = build_outgoing_publish(QoS::AtLeastOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 2); assert_eq!(mqtt.inflight, 2); @@ -758,12 +843,15 @@ mod test { let publish = build_outgoing_publish(QoS::ExactlyOnce); // Packet id should be set and publish should be saved in queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 3); assert_eq!(mqtt.inflight, 3); // Packet id should be incremented and publish should be saved in queue - mqtt.outgoing_publish(publish).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); assert_eq!(mqtt.last_pkid, 4); assert_eq!(mqtt.inflight, 4); } @@ -775,27 +863,31 @@ mod test { // QoS2 publish let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish.clone()).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 1); // Packet id should be set back down to 0, since we hit the limit - mqtt.outgoing_publish(publish.clone()).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); // This should cause a collition - mqtt.outgoing_publish(publish.clone()).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 1); assert_eq!(mqtt.inflight, 2); assert!(mqtt.collision.is_some()); - mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); - mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 1); // Now there should be space in the outgoing queue - mqtt.outgoing_publish(publish.clone()).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish.clone(), resolver).unwrap(); assert_eq!(mqtt.last_pkid, 0); assert_eq!(mqtt.inflight, 2); } @@ -805,13 +897,13 @@ mod test { let mut mqtt = build_mqttstate(); // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); // only qos2 publish should be add to queue assert!(mqtt.incoming_pub.contains(3)); @@ -822,13 +914,13 @@ mod test { let mut mqtt = build_mqttstate(); // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] { assert_eq!(pkid, 2); @@ -849,13 +941,13 @@ mod test { mqtt.manual_acks = true; // QoS0, 1, 2 Publishes - let mut publish1 = build_incoming_publish(QoS::AtMostOnce, 1); - let mut publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); - let mut publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); + let publish1 = build_incoming_publish(QoS::AtMostOnce, 1); + let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2); + let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3); - mqtt.handle_incoming_publish(&mut publish1).unwrap(); - mqtt.handle_incoming_publish(&mut publish2).unwrap(); - mqtt.handle_incoming_publish(&mut publish3).unwrap(); + mqtt.handle_incoming_publish(publish1).unwrap(); + mqtt.handle_incoming_publish(publish2).unwrap(); + mqtt.handle_incoming_publish(publish3).unwrap(); assert!(mqtt.incoming_pub.contains(3)); assert!(mqtt.events.is_empty()); @@ -864,9 +956,9 @@ mod test { #[test] fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() { let mut mqtt = build_mqttstate(); - let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() { + match mqtt.handle_incoming_publish(publish).unwrap().unwrap() { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } @@ -879,14 +971,16 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish1).unwrap(); - mqtt.outgoing_publish(publish2).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish1, resolver).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish2, resolver).unwrap(); assert_eq!(mqtt.inflight, 2); - mqtt.handle_incoming_puback(&PubAck::new(1, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(1, None)).unwrap(); assert_eq!(mqtt.inflight, 1); - mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); + mqtt.handle_incoming_puback(PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 0); assert!(mqtt.outgoing_pub[1].is_none()); @@ -898,7 +992,7 @@ mod test { let mut mqtt = build_mqttstate(); let got = mqtt - .handle_incoming_puback(&PubAck::new(101, None)) + .handle_incoming_puback(PubAck::new(101, None)) .unwrap_err(); match got { @@ -914,10 +1008,12 @@ mod test { let publish1 = build_outgoing_publish(QoS::AtLeastOnce); let publish2 = build_outgoing_publish(QoS::ExactlyOnce); - let _publish_out = mqtt.outgoing_publish(publish1); - let _publish_out = mqtt.outgoing_publish(publish2); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish1, resolver); + let resolver = Resolver::mock(); + let _publish_out = mqtt.outgoing_publish(publish2, resolver); - mqtt.handle_incoming_pubrec(&PubRec::new(2, None)).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 @@ -933,13 +1029,14 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - match mqtt.outgoing_publish(publish).unwrap().unwrap() { + let resolver = Resolver::mock(); + match mqtt.outgoing_publish(publish, resolver).unwrap().unwrap() { Packet::Publish(publish) => assert_eq!(publish.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } match mqtt - .handle_incoming_pubrec(&PubRec::new(1, None)) + .handle_incoming_pubrec(PubRec::new(1, None)) .unwrap() .unwrap() { @@ -951,15 +1048,15 @@ mod test { #[test] fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() { let mut mqtt = build_mqttstate(); - let mut publish = build_incoming_publish(QoS::ExactlyOnce, 1); + let publish = build_incoming_publish(QoS::ExactlyOnce, 1); - match mqtt.handle_incoming_publish(&mut publish).unwrap().unwrap() { + match mqtt.handle_incoming_publish(publish).unwrap().unwrap() { Packet::PubRec(pubrec) => assert_eq!(pubrec.pkid, 1), packet => panic!("Invalid network request: {:?}", packet), } match mqtt - .handle_incoming_pubrel(&PubRel::new(1, None)) + .handle_incoming_pubrel(PubRel::new(1, None)) .unwrap() .unwrap() { @@ -973,11 +1070,11 @@ mod test { let mut mqtt = build_mqttstate(); let publish = build_outgoing_publish(QoS::ExactlyOnce); - mqtt.outgoing_publish(publish).unwrap(); - mqtt.handle_incoming_pubrec(&PubRec::new(1, None)).unwrap(); + let resolver = Resolver::mock(); + mqtt.outgoing_publish(publish, resolver).unwrap(); + mqtt.handle_incoming_pubrec(PubRec::new(1, None)).unwrap(); - mqtt.handle_incoming_pubcomp(&PubComp::new(1, None)) - .unwrap(); + mqtt.handle_incoming_pubcomp(PubComp::new(1, None)).unwrap(); assert_eq!(mqtt.inflight, 0); } @@ -988,7 +1085,8 @@ mod test { // network activity other than pingresp let publish = build_outgoing_publish(QoS::AtLeastOnce); - mqtt.handle_outgoing_packet(Request::Publish(publish)) + let resolver = Resolver::mock(); + mqtt.handle_outgoing_packet(Request::Publish(publish, resolver)) .unwrap(); mqtt.handle_incoming_packet(Incoming::PubAck(PubAck::new(1, None))) .unwrap(); diff --git a/rumqttc/tests/broker.rs b/rumqttc/tests/broker.rs index 760a2ab37..609b381ed 100644 --- a/rumqttc/tests/broker.rs +++ b/rumqttc/tests/broker.rs @@ -9,9 +9,39 @@ use tokio::{task, time}; use bytes::BytesMut; use flume::{bounded, Receiver, Sender}; -use rumqttc::{Event, Incoming, Outgoing, Packet}; +use rumqttc::{Incoming, Packet}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +#[derive(Debug, PartialEq)] +pub enum Event { + Incoming(Packet), + Outgoing(Outgoing), +} + +#[derive(Debug, PartialEq)] +pub enum Outgoing { + /// Publish packet with packet identifier. 0 implies QoS 0 + Publish(u16), + /// SubAck packet with packet identifier + SubAck(u16), + /// UnsubAck packet with packet identifier + UnsubAck(u16), + /// PubAck packet + PubAck(u16), + /// PubRec packet + PubRec(u16), + /// PubRel packet + PubRel(u16), + /// PubComp packet + PubComp(u16), + /// Ping request packet + PingReq, + /// Ping response packet + PingResp, + /// Disconnect packet + Disconnect, +} + pub struct Broker { pub(crate) framed: Network, pub(crate) incoming: VecDeque, @@ -116,12 +146,36 @@ impl Broker { } } - /// Sends an acknowledgement - pub async fn ack(&mut self, pkid: u16) { + /// Sends a publish acknowledgement + pub async fn puback(&mut self, pkid: u16) { let packet = Packet::PubAck(PubAck::new(pkid)); self.framed.write(packet).await.unwrap(); } + /// Sends a publish record + pub async fn pubrec(&mut self, pkid: u16) { + let packet = Packet::PubRec(PubRec::new(pkid)); + self.framed.write(packet).await.unwrap(); + } + + /// Sends a publish complete + pub async fn pubcomp(&mut self, pkid: u16) { + let packet = Packet::PubComp(PubComp::new(pkid)); + self.framed.write(packet).await.unwrap(); + } + + /// Sends a subscribe acknowledgement + pub async fn suback(&mut self, pkid: u16, qos: QoS) { + let packet = Packet::SubAck(SubAck::new(pkid, vec![SubscribeReasonCode::Success(qos)])); + self.framed.write(packet).await.unwrap(); + } + + /// Sends an unsubscribe acknowledgement + pub async fn unsuback(&mut self, pkid: u16) { + let packet = Packet::UnsubAck(UnsubAck::new(pkid)); + self.framed.write(packet).await.unwrap(); + } + /// Sends an acknowledgement pub async fn pingresp(&mut self) { let packet = Packet::PingResp; @@ -296,8 +350,8 @@ fn outgoing(packet: &Packet) -> Outgoing { Packet::PubRec(pubrec) => Outgoing::PubRec(pubrec.pkid), Packet::PubRel(pubrel) => Outgoing::PubRel(pubrel.pkid), Packet::PubComp(pubcomp) => Outgoing::PubComp(pubcomp.pkid), - Packet::Subscribe(subscribe) => Outgoing::Subscribe(subscribe.pkid), - Packet::Unsubscribe(unsubscribe) => Outgoing::Unsubscribe(unsubscribe.pkid), + Packet::SubAck(suback) => Outgoing::SubAck(suback.pkid), + Packet::UnsubAck(unsuback) => Outgoing::UnsubAck(unsuback.pkid), Packet::PingReq => Outgoing::PingReq, Packet::PingResp => Outgoing::PingResp, Packet::Disconnect => Outgoing::Disconnect, diff --git a/rumqttc/tests/reliability.rs b/rumqttc/tests/reliability.rs index 633ca4706..098f697ef 100644 --- a/rumqttc/tests/reliability.rs +++ b/rumqttc/tests/reliability.rs @@ -1,6 +1,9 @@ use matches::assert_matches; use std::time::{Duration, Instant}; -use tokio::{task, time}; +use tokio::{ + task, + time::{self, timeout}, +}; mod broker; @@ -176,7 +179,7 @@ async fn some_outgoing_and_no_incoming_should_trigger_pings_on_time() { loop { let event = broker.tick().await; - if event == Event::Incoming(Incoming::PingReq) { + if event == broker::Event::Incoming(Incoming::PingReq) { // wait for 3 pings count += 1; if count == 3 { @@ -215,7 +218,7 @@ async fn some_incoming_and_no_outgoing_should_trigger_pings_on_time() { loop { let event = broker.tick().await; - if event == Event::Incoming(Incoming::PingReq) { + if event == broker::Event::Incoming(Incoming::PingReq) { // wait for 3 pings count += 1; if count == 3 { @@ -317,12 +320,12 @@ async fn requests_are_recovered_after_inflight_queue_size_falls_below_max() { assert!(broker.read_publish().await.is_none()); // ack packet 1 and client would produce packet 4 - broker.ack(1).await; + broker.puback(1).await; assert!(broker.read_publish().await.is_some()); assert!(broker.read_publish().await.is_none()); // ack packet 2 and client would produce packet 5 - broker.ack(2).await; + broker.puback(2).await; assert!(broker.read_publish().await.is_some()); assert!(broker.read_publish().await.is_none()); } @@ -350,18 +353,18 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { } // out of order ack - broker.ack(3).await; - broker.ack(4).await; + broker.puback(3).await; + broker.puback(4).await; time::sleep(Duration::from_secs(5)).await; - broker.ack(1).await; - broker.ack(2).await; + broker.puback(1).await; + broker.puback(2).await; // read and ack remaining packets in order for i in 5..=15 { let packet = broker.read_publish().await; let packet = packet.unwrap(); assert_eq!(packet.payload[0], i); - broker.ack(packet.pkid).await; + broker.puback(packet.pkid).await; } time::sleep(Duration::from_secs(10)).await; @@ -373,7 +376,7 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { // Poll until there is collision. loop { match eventloop.poll().await.unwrap() { - Event::Outgoing(Outgoing::AwaitAck(1)) => break, + rumqttc::Event::Outgoing(rumqttc::Outgoing::AwaitAck(1)) => break, v => { println!("Poll = {v:?}"); continue; @@ -387,7 +390,7 @@ async fn packet_id_collisions_are_detected_and_flow_control_is_applied() { println!("Poll = {event:?}"); match event { - Event::Outgoing(Outgoing::Publish(ack)) => { + rumqttc::Event::Outgoing(rumqttc::Outgoing::Publish(ack)) => { if ack == 1 { let elapsed = start.elapsed().as_millis() as i64; let deviation_millis: i64 = (5000 - elapsed).abs(); @@ -463,7 +466,7 @@ async fn next_poll_after_connect_failure_reconnects() { } match eventloop.poll().await { - Ok(Event::Incoming(Packet::ConnAck(ConnAck { + Ok(rumqttc::Event::Incoming(Packet::ConnAck(ConnAck { code: ConnectReturnCode::Success, session_present: false, }))) => (), @@ -495,7 +498,7 @@ async fn reconnection_resumes_from_the_previous_state() { for i in 1..=2 { let packet = broker.read_publish().await.unwrap(); assert_eq!(i, packet.payload[0]); - broker.ack(packet.pkid).await; + broker.puback(packet.pkid).await; } // NOTE: An interesting thing to notice here is that reassigning a new broker @@ -509,7 +512,7 @@ async fn reconnection_resumes_from_the_previous_state() { for i in 3..=4 { let packet = broker.read_publish().await.unwrap(); assert_eq!(i, packet.payload[0]); - broker.ack(packet.pkid).await; + broker.puback(packet.pkid).await; } } @@ -585,3 +588,313 @@ async fn state_is_being_cleaned_properly_and_pending_request_calculated_properly }); handle.await.unwrap(); } + +#[tokio::test] +async fn resolve_on_qos0_before_write_to_tcp_buffer() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3005); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3005, 0, false).await; + + let token = client + .publish("hello/world", QoS::AtMostOnce, false, [1; 1]) + .await + .unwrap(); + + // Token can resolve as soon as it was processed by eventloop + assert_eq!( + timeout(Duration::from_secs(1), token) + .await + .unwrap() + .unwrap(), + AckOfPub::None + ); + + // Verify the packet still reached broker + // NOTE: this can't always be guaranteed + let Packet::Publish(Publish { + qos, + topic, + pkid, + payload, + .. + }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!(topic, "hello/world"); + assert_eq!(qos, QoS::AtMostOnce); + assert_eq!(payload.to_vec(), [1; 1]); + assert_eq!(pkid, 0); +} + +#[tokio::test] +async fn resolve_on_qos1_ack_from_broker() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3006); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3006, 0, false).await; + + let mut token = client + .publish("hello/world", QoS::AtLeastOnce, false, [1; 1]) + .await + .unwrap(); + + // Token shouldn't resolve before reaching broker + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + let Packet::Publish(Publish { + qos, + topic, + pkid, + payload, + .. + }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!(topic, "hello/world"); + assert_eq!(qos, QoS::AtLeastOnce); + assert_eq!(payload.to_vec(), [1; 1]); + assert_eq!(pkid, 1); + + // Token shouldn't resolve until packet is acked + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Finally ack the packet + broker.puback(1).await; + + // Token shouldn't resolve until packet is acked + assert_eq!( + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap() + .unwrap(), + AckOfPub::PubAck(PubAck { pkid: 1 }) + ); +} + +#[tokio::test] +async fn resolve_on_qos2_ack_from_broker() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3007); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3007, 0, false).await; + + let mut token = client + .publish("hello/world", QoS::ExactlyOnce, false, [1; 1]) + .await + .unwrap(); + + // Token shouldn't resolve before reaching broker + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + let Packet::Publish(Publish { + qos, + topic, + pkid, + payload, + .. + }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!(topic, "hello/world"); + assert_eq!(qos, QoS::ExactlyOnce); + assert_eq!(payload.to_vec(), [1; 1]); + assert_eq!(pkid, 1); + + // Token shouldn't resolve till publish recorded + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Record the publish message + broker.pubrec(1).await; + + // Token shouldn't resolve till publish complete + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Complete the publish message ack + broker.pubcomp(1).await; + + // Finally the publish is QoS2 acked + assert_eq!( + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap() + .unwrap(), + AckOfPub::PubComp(PubComp { pkid: 1 }) + ); +} + +#[tokio::test] +async fn resolve_on_sub_ack_from_broker() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3006); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3006, 0, false).await; + + let mut token = client + .subscribe("hello/world", QoS::AtLeastOnce) + .await + .unwrap(); + + // Token shouldn't resolve before reaching broker + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + let Packet::Subscribe(Subscribe { pkid, filters, .. }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!( + filters, + [SubscribeFilter { + path: "hello/world".to_owned(), + qos: QoS::AtLeastOnce + }] + ); + assert_eq!(pkid, 1); + + // Token shouldn't resolve until packet is acked + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Finally ack the packet + broker.suback(1, QoS::AtLeastOnce).await; + + // Token shouldn't resolve until packet is acked + assert_eq!( + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap() + .unwrap() + .pkid, + 1 + ); +} + +#[tokio::test] +async fn resolve_on_unsub_ack_from_broker() { + let options = MqttOptions::new("dummy", "127.0.0.1", 3006); + let (client, mut eventloop) = AsyncClient::new(options, 5); + + task::spawn(async move { + let res = run(&mut eventloop, false).await; + if let Err(e) = res { + match e { + ConnectionError::FlushTimeout => { + assert!(eventloop.network.is_none()); + println!("State is being clean properly"); + } + _ => { + println!("Couldn't fill the TCP send buffer to run this test properly. Try reducing the size of buffer."); + } + } + } + }); + + let mut broker = Broker::new(3006, 0, false).await; + + let mut token = client.unsubscribe("hello/world").await.unwrap(); + + // Token shouldn't resolve before reaching broker + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + let Packet::Unsubscribe(Unsubscribe { topics, pkid, .. }) = broker.read_packet().await.unwrap() + else { + unreachable!() + }; + assert_eq!(topics, vec!["hello/world"]); + assert_eq!(pkid, 1); + + // Token shouldn't resolve until packet is acked + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap_err(); + + // Finally ack the packet + broker.unsuback(1).await; + + // Token shouldn't resolve until packet is acked + assert_eq!( + timeout(Duration::from_secs(1), &mut token) + .await + .unwrap() + .unwrap(), + UnsubAck { pkid: 1 } + ); +}