Skip to content

Commit

Permalink
switching to hashsets
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubDoka committed Dec 25, 2023
1 parent d8417ea commit b9ffdae
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 24 deletions.
46 changes: 38 additions & 8 deletions swarm/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,10 @@ where
SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
>,

local_supported_protocols: Vec<<THandler::InboundProtocol as UpgradeInfoSend>::Info>,
local_supported_protocols:
HashSet<AsStrHashEq<<THandler::InboundProtocol as UpgradeInfoSend>::Info>>,
remote_supported_protocols: HashSet<StreamProtocol>,
temp_protocols_set: HashSet<AsStrHashEq<<THandler::InboundProtocol as UpgradeInfoSend>::Info>>,
temp_protocols: Vec<StreamProtocol>,

idle_timeout: Duration,
Expand Down Expand Up @@ -192,12 +194,13 @@ where
.listen_protocol()
.upgrade()
.protocol_info()
.collect::<Vec<_>>();
.map(AsStrHashEq)
.collect::<HashSet<_>>();

if !local_supported_protocols.is_empty() {
let temp = local_supported_protocols
.iter()
.filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok())
.filter_map(|i| StreamProtocol::try_from_owned(i.0.as_ref().to_owned()).ok())
.collect::<Vec<_>>();
handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
ProtocolsChange::Added(ProtocolsAdded {
Expand All @@ -216,6 +219,7 @@ where
requested_substreams: Default::default(),
local_supported_protocols,
remote_supported_protocols: Default::default(),
temp_protocols_set: Default::default(),
temp_protocols: Default::default(),
idle_timeout,
stream_counter: ActiveStreamCounter::default(),
Expand Down Expand Up @@ -264,6 +268,7 @@ where
substream_upgrade_protocol_override,
local_supported_protocols: supported_protocols,
remote_supported_protocols,
temp_protocols_set,
temp_protocols,
idle_timeout,
stream_counter,
Expand Down Expand Up @@ -457,17 +462,26 @@ where
}
}

let current_proocols = supported_protocols.len();
supported_protocols.extend(handler.listen_protocol().upgrade().protocol_info());
let (old, new) = supported_protocols.split_at(current_proocols);
let changes = ProtocolsChange::from_full_sets(old, new, temp_protocols);
supported_protocols.drain(..current_proocols);
temp_protocols_set.clear();
temp_protocols_set.extend(
handler
.listen_protocol()
.upgrade()
.protocol_info()
.map(AsStrHashEq),
);
let changes = ProtocolsChange::from_full_sets(
supported_protocols,
temp_protocols_set,
temp_protocols,
);
let mut has_changes = false;
for change in changes {
handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
has_changes = true;
}
if has_changes {
std::mem::swap(supported_protocols, temp_protocols_set);
continue;
}

Expand Down Expand Up @@ -1369,3 +1383,19 @@ impl From<ConnectedPoint> for PendingPoint {
}
}
}

pub(crate) struct AsStrHashEq<T>(pub(crate) T);

impl<T: AsRef<str>> Eq for AsStrHashEq<T> {}

impl<T: AsRef<str>> PartialEq for AsStrHashEq<T> {
fn eq(&self, other: &Self) -> bool {
self.0.as_ref() == other.0.as_ref()
}
}

impl<T: AsRef<str>> std::hash::Hash for AsStrHashEq<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.as_ref().hash(state)
}
}
29 changes: 13 additions & 16 deletions swarm/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ mod one_shot;
mod pending;
mod select;

use crate::connection::AsStrHashEq;
pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper, UpgradeInfoSend};
pub use map_in::MapInEvent;
pub use map_out::MapOutEvent;
Expand Down Expand Up @@ -332,26 +333,22 @@ pub enum ProtocolsChange<'a> {
impl<'a> ProtocolsChange<'a> {
/// Compute the [`ProtocolsChange`]s required to go from `existing_protocols` to `new_protocols`.
pub(crate) fn from_full_sets<T: AsRef<str>>(
existing_protocols: &[T],
new_protocols: &[T],
existing_protocols: &HashSet<AsStrHashEq<T>>,
new_protocols: &HashSet<AsStrHashEq<T>>,
temp_owner: &'a mut Vec<StreamProtocol>,
) -> impl Iterator<Item = Self> {
fn push_difference<'a, T: AsRef<str>>(
a: &'a [T],
b: &'a [T],
buffer: &mut Vec<StreamProtocol>,
) {
let iter = a
.iter()
.filter(|a| b.iter().all(|b| b.as_ref() != a.as_ref()))
.filter_map(|p| StreamProtocol::try_from_owned(p.as_ref().to_string()).ok());
buffer.extend(iter);
}

temp_owner.clear();
push_difference(new_protocols, existing_protocols, temp_owner);
temp_owner.extend(
new_protocols
.difference(existing_protocols)
.find_map(|p| StreamProtocol::try_from_owned(p.0.as_ref().to_owned()).ok()),
);
let added_count = temp_owner.len();
push_difference(existing_protocols, new_protocols, temp_owner);
temp_owner.extend(
existing_protocols
.difference(new_protocols)
.find_map(|p| StreamProtocol::try_from_owned(p.0.as_ref().to_owned()).ok()),
);

let (added, removed) = temp_owner.split_at(added_count);
added
Expand Down

0 comments on commit b9ffdae

Please sign in to comment.