diff --git a/muxers/yamux/src/lib.rs b/muxers/yamux/src/lib.rs index e53b65f6917c..fdf211ec341a 100644 --- a/muxers/yamux/src/lib.rs +++ b/muxers/yamux/src/lib.rs @@ -129,25 +129,23 @@ where ) -> Poll> { let this = self.get_mut(); - let inbound_stream = ready!(this.poll_inner(cx))?; - - if this.inbound_stream_buffer.len() >= MAX_BUFFERED_INBOUND_STREAMS { - tracing::warn!( - stream=%inbound_stream.0, - "dropping stream because buffer is full" - ); - drop(inbound_stream); - } else { - this.inbound_stream_buffer.push_back(inbound_stream); - - if let Some(waker) = this.inbound_stream_waker.take() { - waker.wake() + loop { + let inbound_stream = ready!(this.poll_inner(cx))?; + + if this.inbound_stream_buffer.len() >= MAX_BUFFERED_INBOUND_STREAMS { + tracing::warn!( + stream=%inbound_stream.0, + "dropping stream because buffer is full" + ); + drop(inbound_stream); + } else { + this.inbound_stream_buffer.push_back(inbound_stream); + + if let Some(waker) = this.inbound_stream_waker.take() { + waker.wake() + } } } - - // Schedule an immediate wake-up, allowing other code to run. - // cx.waker().wake_by_ref(); - Poll::Pending } } diff --git a/swarm/src/connection.rs b/swarm/src/connection.rs index 18ddacbdb64a..2b06c64b7a7d 100644 --- a/swarm/src/connection.rs +++ b/swarm/src/connection.rs @@ -27,7 +27,6 @@ pub use error::ConnectionError; pub(crate) use error::{ PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError, }; -use futures::task::AtomicWaker; use smallvec::SmallVec; pub use supported_protocols::SupportedProtocols; @@ -40,6 +39,7 @@ use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend}; use crate::{ ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol, }; +use delegation::DelegatedWaker; use futures::future::BoxFuture; use futures::stream::FuturesUnordered; use futures::StreamExt; @@ -56,9 +56,9 @@ use libp2p_identity::PeerId; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::future::Future; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -use std::task::{Wake, Waker}; +use std::task::Waker; use std::time::Duration; use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll}; @@ -117,8 +117,6 @@ where /// Node that handles the muxing. muxing: StreamMuxerBox, muxing_waker: Arc, - muxing_inbound_waker: Arc, - muxing_outbound_waker: Arc, /// The underlying handler. handler: THandler, handler_waker: Arc, @@ -210,8 +208,6 @@ where Connection { muxing: muxer, muxing_waker: DelegatedWaker::new(), - muxing_inbound_waker: DelegatedWaker::new(), - muxing_outbound_waker: DelegatedWaker::new(), handler, handler_waker: DelegatedWaker::new(), negotiating_in: Default::default(), @@ -262,9 +258,8 @@ where requested_substreams, muxing, muxing_waker, - muxing_inbound_waker, - muxing_outbound_waker, handler, + handler_waker, negotiating_out, negotiating_in, shutdown, @@ -277,113 +272,120 @@ where .. } = self.get_mut(); - let mut muxing_waker = muxing_waker.scope(); - let mut muxing_inbound_waker = muxing_inbound_waker.scope(); - let mut muxing_outbound_waker = muxing_outbound_waker.scope(); + let mut muxing_waker = muxing_waker.scope(cx); + let mut handler_waker = handler_waker.scope(cx); loop { - match requested_substreams.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(()))) => continue, - Poll::Ready(Some(Err(info))) => { + let mut hcx = handler_waker.guard(); + let mut mcx = muxing_waker.guard(); + let mut handler_mutated = false; + let mut muxer_mutated = false; + let mut new_requests = false; + + while let Poll::Ready(Some(res)) = + hcx.with(|cx| poll_unordered(requested_substreams, cx)) + { + if let Err(info) = res { handler.on_connection_event(ConnectionEvent::DialUpgradeError( DialUpgradeError { info, error: StreamUpgradeError::Timeout, }, )); - continue; + handler_mutated = true; } - Poll::Ready(None) | Poll::Pending => {} } - // Poll the [`ConnectionHandler`]. - match handler.poll(cx) { - Poll::Pending => {} - Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => { - let timeout = *protocol.timeout(); - let (upgrade, user_data) = protocol.into_upgrade(); + while let Poll::Ready(event) = hcx.with(|cx| handler.poll(cx)) { + match event { + ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => { + let timeout = *protocol.timeout(); + let (upgrade, user_data) = protocol.into_upgrade(); - requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade)); - continue; // Poll handler until exhausted. - } - Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => { - return Poll::Ready(Ok(Event::Handler(event))); - } - Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols( - ProtocolSupport::Added(protocols), - )) => { - let added = protocols - .into_iter() - .filter(|p| remote_supported_protocols.insert(p.clone())) - .collect::>(); - if !added.is_empty() { - handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange( - ProtocolsChange::Added(added), - )); + requested_substreams + .push(SubstreamRequested::new(user_data, timeout, upgrade)); + new_requests = true; } - continue; - } - Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols( - ProtocolSupport::Removed(protocols), - )) => { - let removed = protocols - .into_iter() - .filter_map(|p| remote_supported_protocols.take(&p)) - .collect::>(); - if !removed.is_empty() { - handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange( - ProtocolsChange::Removed(removed), - )); + ConnectionHandlerEvent::NotifyBehaviour(event) => { + handler_waker.wake(); + return Poll::Ready(Ok(Event::Handler(event))); + } + ConnectionHandlerEvent::ReportRemoteProtocols(ProtocolSupport::Added( + protocols, + )) => { + let added = protocols + .into_iter() + .filter(|p| remote_supported_protocols.insert(p.clone())) + .collect::>(); + if !added.is_empty() { + handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange( + ProtocolsChange::Added(added), + )); + } + } + ConnectionHandlerEvent::ReportRemoteProtocols(ProtocolSupport::Removed( + protocols, + )) => { + let removed = protocols + .into_iter() + .filter_map(|p| remote_supported_protocols.take(&p)) + .collect::>(); + if !removed.is_empty() { + handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange( + ProtocolsChange::Removed(removed), + )); + } } - continue; } + handler_mutated = true; } // In case the [`ConnectionHandler`] can not make any more progress, poll the negotiating outbound streams. - match negotiating_out.poll_next_unpin(cx) { - Poll::Pending | Poll::Ready(None) => {} - Poll::Ready(Some((info, Ok(protocol)))) => { - handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound( - FullyNegotiatedOutbound { protocol, info }, - )); - continue; - } - Poll::Ready(Some((info, Err(error)))) => { - handler.on_connection_event(ConnectionEvent::DialUpgradeError( + while let Poll::Ready(Some((info, res))) = + hcx.with(|cx| poll_unordered(negotiating_out, cx)) + { + match res { + Ok(protocol) => { + handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound( + FullyNegotiatedOutbound { protocol, info }, + )) + } + Err(error) => handler.on_connection_event(ConnectionEvent::DialUpgradeError( DialUpgradeError { info, error }, - )); - continue; + )), } + handler_mutated = true; } // In case both the [`ConnectionHandler`] and the negotiating outbound streams can not // make any more progress, poll the negotiating inbound streams. - match negotiating_in.poll_next_unpin(cx) { - Poll::Pending | Poll::Ready(None) => {} - Poll::Ready(Some((info, Ok(protocol)))) => { - handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound( - FullyNegotiatedInbound { protocol, info }, - )); - continue; - } - Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => { - handler.on_connection_event(ConnectionEvent::ListenUpgradeError( - ListenUpgradeError { info, error }, - )); - continue; - } - Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => { - tracing::debug!("failed to upgrade inbound stream: {e}"); - continue; - } - Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => { - tracing::debug!("no protocol could be agreed upon for inbound stream"); - continue; - } - Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => { - tracing::debug!("inbound stream upgrade timed out"); - continue; + while let Poll::Ready(Some((info, res))) = + hcx.with(|cx| poll_unordered(negotiating_in, cx)) + { + match res { + Ok(protocol) => { + handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound( + FullyNegotiatedInbound { protocol, info }, + )); + handler_mutated = true; + } + Err(StreamUpgradeError::Apply(error)) => { + handler.on_connection_event(ConnectionEvent::ListenUpgradeError( + ListenUpgradeError { info, error }, + )); + handler_mutated = true; + } + Err(StreamUpgradeError::Io(e)) => { + tracing::debug!("failed to upgrade inbound stream: {e}"); + } + Err(StreamUpgradeError::NegotiationFailed) => { + tracing::debug!("no protocol could be agreed upon for inbound stream"); + } + Err(StreamUpgradeError::Timeout) => { + tracing::debug!("inbound stream upgrade timed out"); + } } + // TODO: more this to respective branches } // Check if the connection (and handler) should be shut down. @@ -413,22 +415,22 @@ where *shutdown = Shutdown::None; } - match muxing_waker.guard(cx, |cx| muxing.poll_unpin(cx))? { + match mcx.with(|cx| muxing.poll_unpin(cx))? { Poll::Pending => {} Poll::Ready(StreamMuxerEvent::AddressChange(address)) => { handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange { new_address: &address, })); + muxing_waker.wake(); return Poll::Ready(Ok(Event::AddressChange(address))); } } - if let Some(requested_substream) = requested_substreams.iter_mut().next() { - match muxing_outbound_waker.guard(cx, |cx| muxing.poll_outbound_unpin(cx))? { - Poll::Pending => {} + for requested_substream in requested_substreams.iter_mut().filter(|s| s.can_extract()) { + match mcx.with(|cx| muxing.poll_outbound_unpin(cx))? { + Poll::Pending => break, Poll::Ready(substream) => { let (user_data, timeout, upgrade) = requested_substream.extract(); - negotiating_out.push(StreamUpgrade::new_outbound( substream, user_data, @@ -437,56 +439,57 @@ where *substream_upgrade_protocol_override, stream_counter.clone(), )); - - continue; // Go back to the top, handler can potentially make progress again. + new_requests = true; + muxer_mutated = true; } } } if negotiating_in.len() < *max_negotiating_inbound_streams { - match muxing_inbound_waker.guard(cx, |cx| muxing.poll_inbound_unpin(cx))? { - Poll::Pending => {} - Poll::Ready(substream) => { - let protocol = handler.listen_protocol(); - - negotiating_in.push(StreamUpgrade::new_inbound( - substream, - protocol, - stream_counter.clone(), - )); - - continue; // Go back to the top, handler can potentially make progress again. - } + while let Poll::Ready(substream) = mcx.with(|cx| muxing.poll_inbound_unpin(cx))? { + let protocol = handler.listen_protocol(); + negotiating_in.push(StreamUpgrade::new_inbound( + substream, + protocol, + stream_counter.clone(), + )); + new_requests = true; + muxer_mutated = true; } } - let prev_protocol_count = supported_protocols.len(); - supported_protocols.extend(handler.listen_protocol().upgrade().protocol_info()); + if new_requests { + stream_waker.wake(); + } - let (old, new) = supported_protocols.split_at(prev_protocol_count); - let added = collect_missing_info(new, old); - let removed = collect_missing_info(old, new); + if handler_mutated { + let prev_protocol_count = supported_protocols.len(); + supported_protocols.extend(handler.listen_protocol().upgrade().protocol_info()); - supported_protocols.drain(..prev_protocol_count); + let (old, new) = supported_protocols.split_at(prev_protocol_count); + let added = collect_missing_info(new, old); + let removed = collect_missing_info(old, new); - let [added_any, removed_any] = [!added.is_empty(), !removed.is_empty()]; - if added_any { - handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( - ProtocolsChange::Added(added), - )); - } - if removed_any { - handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( - ProtocolsChange::Removed(removed), - )); - } + supported_protocols.drain(..prev_protocol_count); - if added_any || removed_any { - continue; + if !added.is_empty() { + handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( + ProtocolsChange::Added(added), + )); + } + if !removed.is_empty() { + handler.on_connection_event(ConnectionEvent::LocalProtocolsChange( + ProtocolsChange::Removed(removed), + )); + } } - return Poll::Pending; // Nothing can make progress, return `Pending`. + if !muxer_mutated && !handler_mutated { + break; + } } + + Poll::Pending } #[cfg(test)] @@ -495,6 +498,18 @@ where } } +fn poll_unordered( + futures: &mut FuturesUnordered, + cx: &mut Context<'_>, +) -> Poll> { + // its cheaper to check this then let the poll do it since poll will also register waker + if futures.is_empty() { + return Poll::Pending; + } + + futures.poll_next_unpin(cx) +} + fn info_to_stream_protocol(info: &impl AsRef) -> Option { StreamProtocol::try_from_owned(info.as_ref().to_owned()).ok() } @@ -711,6 +726,10 @@ impl SubstreamRequested { } } + fn can_extract(&self) -> bool { + matches!(self, Self::Waiting { .. }) + } + fn extract(&mut self) -> (UserData, Delay, Upgrade) { match mem::replace(self, Self::Done) { SubstreamRequested::Waiting { @@ -780,6 +799,91 @@ enum Shutdown { Later(Delay, Instant), } +mod delegation { + use futures::task::AtomicWaker; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::task::{Poll, Wake, Waker}; + + pub(crate) struct DelegatedWaker { + inner: AtomicWaker, + woken_up: AtomicBool, + } + + impl DelegatedWaker { + pub(crate) fn new() -> Arc { + Arc::new(Self { + inner: AtomicWaker::new(), + woken_up: AtomicBool::new(true), + }) + } + + pub(crate) fn scope<'a>( + self: &'a Arc, + cx: &mut std::task::Context<'_>, + ) -> DelegatedWakerScope<'a> { + self.inner.register(cx.waker()); + DelegatedWakerScope { + cloned: None, + waker: self, + } + } + } + + pub(crate) struct DelegatedWakerScope<'a> { + waker: &'a Arc, + cloned: Option, + } + + impl DelegatedWakerScope<'_> { + pub(crate) fn guard(&mut self) -> DelegatedContext<'_> { + let woken_up = self.waker.woken_up.swap(false, Ordering::SeqCst); + if !woken_up { + return DelegatedContext { cx: None }; + } + + let wk = self + .cloned + .get_or_insert_with(|| Waker::from(self.waker.clone())); + let cx = std::task::Context::from_waker(&*wk); + DelegatedContext { cx: Some(cx) } + } + + pub(crate) fn wake(&mut self) { + self.waker.woken_up.store(true, Ordering::SeqCst); + } + } + + pub(crate) struct DelegatedContext<'a> { + cx: Option>, + } + + impl<'a> DelegatedContext<'a> { + pub(crate) fn with( + &mut self, + poll: impl FnOnce(&mut std::task::Context<'_>) -> Poll, + ) -> Poll { + if let Some(cx) = &mut self.cx { + poll(cx) + } else { + Poll::Pending + } + } + } + + impl Wake for DelegatedWaker { + fn wake_by_ref(self: &std::sync::Arc) { + if !self.woken_up.swap(true, Ordering::SeqCst) { + self.inner.wake(); + } + } + + fn wake(self: std::sync::Arc) { + self.wake_by_ref() + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -1394,63 +1498,3 @@ impl From for PendingPoint { } } } - -struct DelegatedWaker { - inner: AtomicWaker, - woken_up: AtomicBool, -} - -impl DelegatedWaker { - fn new() -> Arc { - Arc::new(Self { - inner: AtomicWaker::new(), - woken_up: AtomicBool::new(true), - }) - } - - fn scope<'a>(self: &'a Arc) -> DelegatedScope<'a> { - DelegatedScope { - waker: self, - dirty: true, - } - } -} - -struct DelegatedScope<'a> { - waker: &'a Arc, - dirty: bool, -} - -impl DelegatedScope<'_> { - fn guard( - &mut self, - cx: &mut std::task::Context<'_>, - poll: impl FnOnce(&mut std::task::Context<'_>) -> Poll, - ) -> Poll { - let woken_up = self.waker.woken_up.load(Ordering::Relaxed); - if !woken_up { - return Poll::Pending; - } - - self.waker.woken_up.store(false, Ordering::Relaxed); - if std::mem::take(&mut self.dirty) { - self.waker.inner.wake(); - } - self.waker.inner.register(cx.waker()); - let wk = Waker::from(Arc::clone(self.waker)); - let mut cx = std::task::Context::from_waker(&wk); - poll(&mut cx) - } -} - -impl Wake for DelegatedWaker { - fn wake_by_ref(self: &std::sync::Arc) { - if !self.woken_up.swap(true, Ordering::Relaxed) { - self.inner.wake(); - } - } - - fn wake(self: std::sync::Arc) { - self.wake_by_ref() - } -}