Skip to content

Commit

Permalink
add shutdown signal to inbound messaging worker
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbondi committed May 3, 2024
1 parent bb00135 commit 6cd1301
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
15 changes: 10 additions & 5 deletions comms/core/src/protocol/messaging/inbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

use std::io;

use futures::StreamExt;
use futures::{future::Either, SinkExt, StreamExt};
use log::*;
use tari_shutdown::ShutdownSignal;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::{broadcast, mpsc},
Expand All @@ -32,7 +33,7 @@ use tokio::{
#[cfg(feature = "metrics")]
use super::metrics;
use super::{MessagingEvent, MessagingProtocol};
use crate::{message::InboundMessage, peer_manager::NodeId};
use crate::{message::InboundMessage, peer_manager::NodeId, protocol::rpc::__macro_reexports::future};

const LOG_TARGET: &str = "comms::protocol::messaging::inbound";

Expand All @@ -42,6 +43,7 @@ pub struct InboundMessaging {
inbound_message_tx: mpsc::Sender<InboundMessage>,
messaging_events_tx: broadcast::Sender<MessagingEvent>,
enable_message_received_event: bool,
shutdown_signal: ShutdownSignal,
}

impl InboundMessaging {
Expand All @@ -50,16 +52,18 @@ impl InboundMessaging {
inbound_message_tx: mpsc::Sender<InboundMessage>,
messaging_events_tx: broadcast::Sender<MessagingEvent>,
enable_message_received_event: bool,
shutdown_signal: ShutdownSignal,
) -> Self {
Self {
peer,
inbound_message_tx,
messaging_events_tx,
enable_message_received_event,
shutdown_signal,
}
}

pub async fn run<S>(self, socket: S)
pub async fn run<S>(mut self, socket: S)
where S: AsyncRead + AsyncWrite + Unpin {
let peer = &self.peer;
#[cfg(feature = "metrics")]
Expand All @@ -71,10 +75,9 @@ impl InboundMessaging {
);

let stream = MessagingProtocol::framed(socket);

tokio::pin!(stream);

while let Some(result) = stream.next().await {
while let Either::Right((Some(result), _)) = future::select(self.shutdown_signal.wait(), stream.next()).await {
match result {
Ok(raw_msg) => {
#[cfg(feature = "metrics")]
Expand Down Expand Up @@ -138,6 +141,8 @@ impl InboundMessaging {
}
}

let _ignore = stream.close().await;

let _ignore = self
.messaging_events_tx
.send(MessagingEvent::InboundProtocolExited(peer.clone()));
Expand Down
1 change: 1 addition & 0 deletions comms/core/src/protocol/messaging/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ impl MessagingProtocol {
inbound_message_tx,
messaging_events_tx,
self.enable_message_received_event,
self.shutdown_signal.clone(),
);
let handle = tokio::spawn(inbound_messaging.run(substream));
self.active_inbound.insert(peer, handle);
Expand Down
3 changes: 2 additions & 1 deletion comms/dht/tests/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ async fn test_dht_propagate_dedup() {
#[allow(non_snake_case)]
#[allow(clippy::too_many_lines)]
async fn test_dht_do_not_store_invalid_message_in_dedup() {
env_logger::builder().filter_level(log::LevelFilter::Debug).init();
let mut config = dht_config();
config.dedup_allowed_message_occurrences = 1;

Expand Down Expand Up @@ -586,7 +587,7 @@ async fn test_dht_do_not_store_invalid_message_in_dedup() {
node_C.shutdown().await;

// Check the message flow BEFORE deduping
let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(20)));
let received = filter_received(collect_try_recv!(node_C_messaging, timeout = Duration::from_secs(40)));

let received_from_a = count_messages_received(&received, &[&node_A_id]);
let received_from_b = count_messages_received(&received, &[&node_B_id]);
Expand Down

0 comments on commit 6cd1301

Please sign in to comment.