Skip to content

Commit

Permalink
more triing, some bugs in swarm discovered
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubDoka committed Dec 22, 2023
1 parent 7ec9854 commit e54bb99
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 86 deletions.
18 changes: 15 additions & 3 deletions protocols/kad/src/behaviour/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ impl Arbitrary for Seed {

#[test]
fn bootstrap() {
let _ = tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init();
fn prop(seed: Seed) {
let mut rng = StdRng::from_seed(seed.0);

Expand Down Expand Up @@ -222,7 +225,13 @@ fn bootstrap() {
known.insert(*e.node.key.preimage());
}
}
assert_eq!(expected_known, known);
assert_eq!(
expected_known,
known,
"{} {}",
expected_known.len(),
known.len()
);
return Poll::Ready(());
}
}
Expand Down Expand Up @@ -1090,6 +1099,9 @@ fn exp_decr_expiration_overflow() {

#[test]
fn disjoint_query_does_not_finish_before_all_paths_did() {
let _ = tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init();
let mut config = Config::default();
config.disjoint_query_paths(true);
// I.e. setting the amount disjoint paths to be explored to 2.
Expand Down Expand Up @@ -1129,7 +1141,7 @@ fn disjoint_query_does_not_finish_before_all_paths_did() {

// The default peer timeout is 10 seconds. Choosing 1 seconds here should
// give enough head room to prevent connections to `bob` to time out.
let mut before_timeout = Delay::new(Duration::from_secs(1));
let mut before_timeout = Delay::new(Duration::from_secs(3));

// Poll only `alice` and `trudy` expecting `alice` not yet to return a query
// result as it is not able to connect to `bob` just yet.
Expand Down Expand Up @@ -1161,7 +1173,7 @@ fn disjoint_query_does_not_finish_before_all_paths_did() {
}
}
// Ignore any other event.
Poll::Ready(Some(_)) => (),
Poll::Ready(Some(_)) => {}
Poll::Ready(None) => panic!("Expected Kademlia behaviour not to finish."),
Poll::Pending => break,
}
Expand Down
35 changes: 16 additions & 19 deletions protocols/kad/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,16 +492,7 @@ impl Handler {
if let Some(sender) = self.pending_streams.pop_front() {
let _ = sender.send(Ok(stream));
}

if self.protocol_status.is_none() {
// Upon the first successfully negotiated substream, we know that the
// remote is configured with the same protocol name and we want
// the behaviour to add this peer to the routing table, if possible.
self.protocol_status = Some(ProtocolStatus {
supported: true,
reported: false,
});
}
self.lazy_init_protocol_status();
}

fn on_fully_negotiated_inbound(
Expand All @@ -518,15 +509,7 @@ impl Handler {
future::Either::Right(p) => void::unreachable(p),
};

if self.protocol_status.is_none() {
// Upon the first successfully negotiated substream, we know that the
// remote is configured with the same protocol name and we want
// the behaviour to add this peer to the routing table, if possible.
self.protocol_status = Some(ProtocolStatus {
supported: true,
reported: false,
});
}
self.lazy_init_protocol_status();

if self.inbound_substreams.len() == MAX_NUM_STREAMS {
if let Some(s) = self.inbound_substreams.iter_mut().find(|s| {
Expand Down Expand Up @@ -564,6 +547,19 @@ impl Handler {
});
}

fn lazy_init_protocol_status(&mut self) {
if self.protocol_status.is_none() {
// Upon the first successfully negotiated substream, we know that the
// remote is configured with the same protocol name and we want
// the behaviour to add this peer to the routing table, if possible.
self.protocol_status = Some(ProtocolStatus {
supported: false,
reported: false,
});
self.wake();
}
}

/// Takes the given [`KadRequestMsg`] and composes it into an outbound request-response protocol handshake using a [`oneshot::channel`].
fn queue_new_stream(&mut self, id: QueryId, msg: KadRequestMsg) {
let (sender, receiver) = oneshot::channel();
Expand Down Expand Up @@ -823,6 +819,7 @@ impl ConnectionHandler for Handler {
.iter()
.any(|p| self.protocol_config.protocol_names().contains(p));

self.wake();
self.protocol_status = Some(compute_new_protocol_status(
remote_supports_our_kademlia_protocols,
self.protocol_status,
Expand Down
151 changes: 88 additions & 63 deletions swarm/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ where
SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
>,

stream_waker: Arc<DelegatedWaker>,
local_supported_protocols: Vec<<THandler::InboundProtocol as UpgradeInfoSend>::Info>,
remote_supported_protocols: HashSet<StreamProtocol>,
idle_timeout: Duration,
Expand Down Expand Up @@ -217,6 +218,7 @@ where
max_negotiating_inbound_streams,
requested_substreams: Default::default(),
local_supported_protocols: initial_protocols,
stream_waker: DelegatedWaker::new(),
remote_supported_protocols: Default::default(),
idle_timeout,
stream_counter: ActiveStreamCounter::default(),
Expand Down Expand Up @@ -269,20 +271,68 @@ where
remote_supported_protocols,
idle_timeout,
stream_counter,
stream_waker,
..
} = self.get_mut();

let mut muxing_waker = muxing_waker.scope(cx);
let mut handler_waker = handler_waker.scope(cx);
let mut stream_waker = stream_waker.scope(cx);

loop {
let mut hcx = handler_waker.guard();
let mut mcx = muxing_waker.guard();
let mut handler_mutated = false;
let mut muxer_mutated = false;

// we do polling first so that handler can update its waker
let mut hcx = handler_waker.guard();
while let Poll::Ready(event) = hcx.with(|cx| handler.poll(cx)) {
match event {
ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
let timeout = *protocol.timeout();
let (upgrade, user_data) = protocol.into_upgrade();

requested_substreams
.push(SubstreamRequested::new(user_data, timeout, upgrade));
muxing_waker.wake();
}
ConnectionHandlerEvent::NotifyBehaviour(event) => {
handler_waker.wake();
muxing_waker.wake();
return Poll::Ready(Ok(Event::Handler(event)));
}
ConnectionHandlerEvent::ReportRemoteProtocols(ProtocolSupport::Added(
protocols,
)) => {
let added = protocols
.into_iter()
.filter(|p| remote_supported_protocols.insert(p.clone()))
.collect::<SmallVec<_>>();
if !added.is_empty() {
handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
ProtocolsChange::Added(added),
));
}
}
ConnectionHandlerEvent::ReportRemoteProtocols(ProtocolSupport::Removed(
protocols,
)) => {
let removed = protocols
.into_iter()
.filter_map(|p| remote_supported_protocols.take(&p))
.collect::<SmallVec<_>>();
if !removed.is_empty() {
handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
ProtocolsChange::Removed(removed),
));
}
}
}
handler_mutated = true;
}

let mut scx = stream_waker.guard();
while let Poll::Ready(Some(res)) =
hcx.with(|cx| poll_unordered(requested_substreams, cx))
scx.with(|cx| poll_unordered(requested_substreams, cx))
{
if let Err(info) = res {
handler.on_connection_event(ConnectionEvent::DialUpgradeError(
Expand All @@ -297,7 +347,7 @@ where

// In case the [`ConnectionHandler`] can not make any more progress, poll the negotiating outbound streams.
while let Poll::Ready(Some((info, res))) =
hcx.with(|cx| poll_unordered(negotiating_out, cx))
scx.with(|cx| poll_unordered(negotiating_out, cx))
{
match res {
Ok(protocol) => {
Expand All @@ -315,7 +365,7 @@ where
// In case both the [`ConnectionHandler`] and the negotiating outbound streams can not
// make any more progress, poll the negotiating inbound streams.
while let Poll::Ready(Some((info, res))) =
hcx.with(|cx| poll_unordered(negotiating_in, cx))
scx.with(|cx| poll_unordered(negotiating_in, cx))
{
match res {
Ok(protocol) => {
Expand Down Expand Up @@ -343,49 +393,6 @@ where
// TODO: more this to respective branches
}

while let Poll::Ready(event) = hcx.with(|cx| handler.poll(cx)) {
match event {
ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
let timeout = *protocol.timeout();
let (upgrade, user_data) = protocol.into_upgrade();

requested_substreams
.push(SubstreamRequested::new(user_data, timeout, upgrade));
}
ConnectionHandlerEvent::NotifyBehaviour(event) => {
handler_waker.wake();
return Poll::Ready(Ok(Event::Handler(event)));
}
ConnectionHandlerEvent::ReportRemoteProtocols(ProtocolSupport::Added(
protocols,
)) => {
let added = protocols
.into_iter()
.filter(|p| remote_supported_protocols.insert(p.clone()))
.collect::<SmallVec<_>>();
if !added.is_empty() {
handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
ProtocolsChange::Added(added),
));
}
}
ConnectionHandlerEvent::ReportRemoteProtocols(ProtocolSupport::Removed(
protocols,
)) => {
let removed = protocols
.into_iter()
.filter_map(|p| remote_supported_protocols.take(&p))
.collect::<SmallVec<_>>();
if !removed.is_empty() {
handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
ProtocolsChange::Removed(removed),
));
}
}
}
handler_mutated = true;
}

// Check if the connection (and handler) should be shut down.
// As long as we're still negotiating substreams or have any active streams shutdown is always postponed.
if negotiating_in.is_empty()
Expand Down Expand Up @@ -413,17 +420,21 @@ where
*shutdown = Shutdown::None;
}

let mut mcx = muxing_waker.guard();
match mcx.with(|cx| muxing.poll_unpin(cx))? {
Poll::Pending => {}
Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
new_address: &address,
}));
muxing_waker.wake();
handler_waker.wake();
stream_waker.wake();
return Poll::Ready(Ok(Event::AddressChange(address)));
}
}

let mut trigger_streams = false;
for requested_substream in requested_substreams.iter_mut().filter(|s| s.can_extract()) {
match mcx.with(|cx| muxing.poll_outbound_unpin(cx))? {
Poll::Pending => break,
Expand All @@ -437,6 +448,7 @@ where
*substream_upgrade_protocol_override,
stream_counter.clone(),
));
trigger_streams = true;
muxer_mutated = true;
}
}
Expand All @@ -450,10 +462,15 @@ where
protocol,
stream_counter.clone(),
));
trigger_streams = true;
muxer_mutated = true;
}
}

if trigger_streams {
stream_waker.wake();
}

if handler_mutated {
let prev_protocol_count = supported_protocols.len();
supported_protocols.extend(handler.listen_protocol().upgrade().protocol_info());
Expand Down Expand Up @@ -827,40 +844,48 @@ mod delegation {
cloned: Option<std::task::Waker>,
}

impl DelegatedWakerScope<'_> {
pub(crate) fn guard(&mut self) -> DelegatedContext<'_> {
impl<'a> DelegatedWakerScope<'a> {
pub(crate) fn guard(&mut self) -> DelegatedContext<'_, 'a> {
let woken_up = self.waker.woken_up.swap(false, Ordering::SeqCst);
if !woken_up {
return DelegatedContext { cx: None };
DelegatedContext {
active: woken_up,
original: self,
}

let wk = self
.cloned
.get_or_insert_with(|| Waker::from(self.waker.clone()));
let cx = std::task::Context::from_waker(&*wk);
DelegatedContext { cx: Some(cx) }
}

pub(crate) fn wake(&mut self) {
self.waker.woken_up.store(true, Ordering::SeqCst);
}
}

pub(crate) struct DelegatedContext<'a> {
cx: Option<std::task::Context<'a>>,
pub(crate) struct DelegatedContext<'a, 'b> {
original: &'a mut DelegatedWakerScope<'b>,
active: bool,
}

impl<'a> DelegatedContext<'a> {
impl<'a, 'b> DelegatedContext<'a, 'b> {
pub(crate) fn with<R>(
&mut self,
poll: impl FnOnce(&mut std::task::Context<'_>) -> Poll<R>,
) -> Poll<R> {
if let Some(cx) = &mut self.cx {
poll(cx)
if self.active {
let wk = self
.original
.cloned
.get_or_insert_with(|| Waker::from(self.original.waker.clone()));
let mut cx = std::task::Context::from_waker(wk);
poll(&mut cx)
} else {
Poll::Pending
}
}

pub(crate) fn wake(&mut self) {
if !self.active {
self.original.waker.woken_up.store(true, Ordering::SeqCst);
self.active = true;
}
}
}

impl Wake for DelegatedWaker {
Expand Down
3 changes: 2 additions & 1 deletion swarm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ use libp2p_core::{
use libp2p_identity::PeerId;
use smallvec::SmallVec;
use std::collections::{HashMap, HashSet, VecDeque};
use std::mem;
use std::num::{NonZeroU32, NonZeroU8, NonZeroUsize};
use std::time::Duration;
use std::{
Expand Down Expand Up @@ -1330,7 +1331,7 @@ where
Poll::Pending => pending.push(id),
Poll::Ready(Err(())) => {} // connection is closing
Poll::Ready(Ok(())) => {
let e = event.take().expect("by (1),(2)");
let e = mem::take(&mut event).expect("by (1),(2)");
if let Err(e) = conn.notify_handler(e) {
event = Some(e) // (2)
} else {
Expand Down

0 comments on commit e54bb99

Please sign in to comment.