From 99b129afc46d06121cccd1bc585bc75206f79027 Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Sun, 25 Aug 2024 12:12:15 +0200 Subject: [PATCH 1/7] feat: ipc socket --- Cargo.lock | 18 ++-- Cargo.toml | 9 +- README.md | 36 ++++--- msg-common/src/lib.rs | 7 +- msg-sim/src/lib.rs | 7 +- msg-socket/Cargo.toml | 4 +- msg-socket/src/connection/state.rs | 8 +- msg-socket/src/lib.rs | 10 +- msg-socket/src/pub/driver.rs | 14 +-- msg-socket/src/pub/mod.rs | 38 +++---- msg-socket/src/pub/socket.rs | 61 ++++++++--- msg-socket/src/rep/driver.rs | 48 ++++----- msg-socket/src/rep/mod.rs | 36 +++---- msg-socket/src/rep/socket.rs | 61 ++++++++--- msg-socket/src/req/driver.rs | 24 ++--- msg-socket/src/req/socket.rs | 44 ++++++-- msg-socket/src/sub/driver.rs | 87 ++++++++-------- msg-socket/src/sub/mod.rs | 37 ++++--- msg-socket/src/sub/session.rs | 31 +++--- msg-socket/src/sub/socket.rs | 161 ++++++++++++++++++++--------- msg-socket/src/sub/stats.rs | 24 +++-- msg-socket/tests/it/pubsub.rs | 46 +++++---- msg-transport/Cargo.toml | 4 - msg-transport/src/ipc/mod.rs | 153 +++++++++++++++++++++++++++ msg-transport/src/lib.rs | 35 +++++-- msg-transport/src/quic/config.rs | 4 +- msg-transport/src/quic/mod.rs | 16 +-- msg-transport/src/quic/stream.rs | 2 +- msg-transport/src/tcp/mod.rs | 3 +- msg-wire/Cargo.toml | 1 - msg-wire/src/lib.rs | 4 + msg/benches/pubsub.rs | 112 ++++++++++++++++++-- msg/benches/reqrep.rs | 85 +++++++++++++-- msg/examples/durable.rs | 4 +- msg/examples/ipc.rs | 44 ++++++++ msg/examples/pubsub.rs | 6 +- msg/examples/pubsub_auth.rs | 6 +- msg/examples/pubsub_compression.rs | 6 +- msg/examples/quic_vs_tcp.rs | 8 +- msg/examples/reqrep.rs | 4 +- msg/examples/reqrep_auth.rs | 4 +- msg/examples/reqrep_compression.rs | 4 +- msg/src/lib.rs | 3 + 43 files changed, 948 insertions(+), 371 deletions(-) create mode 100644 msg-transport/src/ipc/mod.rs create mode 100644 msg/examples/ipc.rs diff --git a/Cargo.lock b/Cargo.lock index d949ef4..f190edd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -713,7 +713,7 @@ dependencies = [ [[package]] name = "msg" -version = "0.1.1" +version = "0.1.2" dependencies = [ "bytes", "criterion", @@ -733,7 +733,7 @@ dependencies = [ [[package]] name = "msg-common" -version = "0.1.1" +version = "0.1.2" dependencies = [ "futures", "tokio", @@ -742,16 +742,15 @@ dependencies = [ [[package]] name = "msg-sim" -version = "0.1.1" +version = "0.1.2" dependencies = [ "pnet", ] [[package]] name = "msg-socket" -version = "0.1.1" +version = "0.1.2" dependencies = [ - "async-trait", "bytes", "futures", "msg-common", @@ -771,27 +770,23 @@ dependencies = [ [[package]] name = "msg-transport" -version = "0.1.1" +version = "0.1.2" dependencies = [ "async-trait", - "bytes", "futures", "msg-common", - "msg-wire", "quinn", - "rand", "rcgen", "rustls", "thiserror", "tokio", - "tokio-util", "tracing", "tracing-subscriber", ] [[package]] name = "msg-wire" -version = "0.1.1" +version = "0.1.2" dependencies = [ "bytes", "flate2", @@ -800,7 +795,6 @@ dependencies = [ "snap", "thiserror", "tokio-util", - "tracing", "zstd", ] diff --git a/Cargo.toml b/Cargo.toml index 54913e1..35960fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,9 @@ members = [ resolver = "2" [workspace.package] -version = "0.1.1" +version = "0.1.2" edition = "2021" -rust-version = "1.70" # Remember to update .clippy.toml and README.md +rust-version = "1.70" license = "MIT" description = "A flexible and lightweight messaging library for distributed systems" authors = ["Jonas Bostoen", "Nicolas Racchi"] @@ -35,7 +35,6 @@ futures = "0.3" tokio-stream = { version = "0.1", features = ["sync"] } parking_lot = "0.12" - # general bytes = "1" thiserror = "1" @@ -43,9 +42,9 @@ tracing = "0.1" rustc-hash = "1" rand = "0.8" -# NETWORKING +# networking quinn = "0.10" -# rustls needs to be the same version as the one used by quinn +# (rustls needs to be the same version as the one used by quinn) rustls = { version = "0.21", features = ["quic", "dangerous_configuration"] } rcgen = "0.12" diff --git a/README.md b/README.md index 030122d..34e98e7 100644 --- a/README.md +++ b/README.md @@ -18,31 +18,37 @@ `msg-rs` is a messaging library that was inspired by projects like [ZeroMQ](https://zeromq.org/) and [Nanomsg](https://nanomsg.org/). It was built because we needed a Rust-native messaging library like those above. -> MSG is still in ALPHA and is not ready for production use. - ## Documentation -The [MSG-RS Book][book] contains detailed information on how to use the library. +The 📖 [MSG-RS Book][book] contains detailed information on how to use the library. ## Features -- [ ] Multiple socket types +- Multiple socket types - [x] Request/Reply - [x] Publish/Subscribe +- Pluggable transport layers + - [x] TCP + - [x] QUIC + - [x] IPC +- Useful stats: latency, throughput, packet drops +- Durable IO abstraction (built-in retries and reconnections) +- Custom wire protocol with support for authentication and compression +- Network simulation mode with dummynet & pfctl +- Extensive benchmarks +- Integration tests + + ## MSRV @@ -59,6 +65,12 @@ Additionally, you can reach out to us on [Discord][discord] if you have any ques This project is licensed under the Open Source [MIT license][mit-license]. +## Disclaimer + + +This software is provided "as is", without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose and noninfringement. In no event shall the authors or copyright holders be liable for any claim, damages or other liability, whether in an action of contract, tort or otherwise, arising from, out of or in connection with the software or the use or other dealings in the software. + + [book]: https://chainbound.github.io/msg-rs/ diff --git a/msg-common/src/lib.rs b/msg-common/src/lib.rs index 1f2aedf..f7e4b94 100644 --- a/msg-common/src/lib.rs +++ b/msg-common/src/lib.rs @@ -1,11 +1,14 @@ -use futures::future::BoxFuture; +#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +#![cfg_attr(not(test), warn(unused_crate_dependencies))] + use std::{ pin::Pin, task::{Context, Poll}, time::SystemTime, }; -use futures::{Sink, SinkExt, Stream}; +use futures::{future::BoxFuture, Sink, SinkExt, Stream}; use tokio::sync::mpsc::{ self, error::{TryRecvError, TrySendError}, diff --git a/msg-sim/src/lib.rs b/msg-sim/src/lib.rs index 84c66d6..d95450f 100644 --- a/msg-sim/src/lib.rs +++ b/msg-sim/src/lib.rs @@ -1,8 +1,11 @@ -use std::{collections::HashMap, io, net::IpAddr, time::Duration}; +#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +#![cfg_attr(not(test), warn(unused_crate_dependencies))] -pub use protocol::Protocol; +use std::{collections::HashMap, io, net::IpAddr, time::Duration}; mod protocol; +pub use protocol::Protocol; #[cfg(target_os = "macos")] pub mod dummynet; diff --git a/msg-socket/Cargo.toml b/msg-socket/Cargo.toml index 22a7a3c..2378ac6 100644 --- a/msg-socket/Cargo.toml +++ b/msg-socket/Cargo.toml @@ -23,12 +23,12 @@ tokio-util.workspace = true thiserror.workspace = true rustc-hash.workspace = true tracing.workspace = true -async-trait.workspace = true tokio-stream.workspace = true -rand.workspace = true parking_lot.workspace = true [dev-dependencies] +rand.workspace = true + msg-sim.workspace = true tracing-subscriber = "0.3" diff --git a/msg-socket/src/connection/state.rs b/msg-socket/src/connection/state.rs index de5a5f6..c3f7814 100644 --- a/msg-socket/src/connection/state.rs +++ b/msg-socket/src/connection/state.rs @@ -1,4 +1,4 @@ -use std::net::SocketAddr; +use msg_transport::Address; use super::Backoff; @@ -6,20 +6,20 @@ use super::Backoff; /// /// * `C` is the channel type, which is used to send and receive generic messages. /// * `B` is the backoff type, used to control the backoff state for inactive connections. -pub enum ConnectionState { +pub enum ConnectionState { Active { /// Channel to control the underlying connection. This is used to send /// and receive any kind of message in any direction. channel: C, }, Inactive { - addr: SocketAddr, + addr: A, /// The current backoff state for inactive connections. backoff: B, }, } -impl ConnectionState { +impl ConnectionState { /// Returns `true` if the connection is active. #[allow(unused)] pub fn is_active(&self) -> bool { diff --git a/msg-socket/src/lib.rs b/msg-socket/src/lib.rs index 901e3d8..1705d11 100644 --- a/msg-socket/src/lib.rs +++ b/msg-socket/src/lib.rs @@ -1,4 +1,8 @@ -use std::net::SocketAddr; +#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +#![cfg_attr(not(test), warn(unused_crate_dependencies))] + +use msg_transport::Address; use tokio::io::{AsyncRead, AsyncWrite}; #[path = "pub/mod.rs"] @@ -36,8 +40,8 @@ pub trait Authenticator: Send + Sync + Unpin + 'static { fn authenticate(&self, id: &Bytes) -> bool; } -pub(crate) struct AuthResult { +pub(crate) struct AuthResult { id: Bytes, - addr: SocketAddr, + addr: A, stream: S, } diff --git a/msg-socket/src/pub/driver.rs b/msg-socket/src/pub/driver.rs index 5789828..e27f0d2 100644 --- a/msg-socket/src/pub/driver.rs +++ b/msg-socket/src/pub/driver.rs @@ -1,10 +1,11 @@ -use futures::{stream::FuturesUnordered, Future, SinkExt, StreamExt}; use std::{ io, pin::Pin, sync::Arc, task::{Context, Poll}, }; + +use futures::{stream::FuturesUnordered, Future, SinkExt, StreamExt}; use tokio::{sync::broadcast, task::JoinSet}; use tokio_util::codec::Framed; use tracing::{debug, error, info, warn}; @@ -16,6 +17,7 @@ use crate::{AuthResult, Authenticator}; use msg_transport::{PeerAddress, Transport}; use msg_wire::{auth, pubsub}; +#[allow(clippy::type_complexity)] pub(crate) struct PubDriver { /// Session ID counter. pub(super) id_counter: u32, @@ -30,7 +32,7 @@ pub(crate) struct PubDriver { /// A set of pending incoming connections, represented by [`Transport::Accept`]. pub(super) conn_tasks: FuturesUnordered, /// A joinset of authentication tasks. - pub(super) auth_tasks: JoinSet, PubError>>, + pub(super) auth_tasks: JoinSet, PubError>>, /// The receiver end of the message broadcast channel. The sender half is stored by [`PubSocket`](super::PubSocket). pub(super) from_socket_bcast: broadcast::Receiver, } @@ -50,7 +52,7 @@ where match auth { Ok(auth) => { // Run custom authenticator - debug!("Authentication passed for {:?} ({})", auth.id, auth.addr); + debug!("Authentication passed for {:?} ({:?})", auth.id, auth.addr); let mut framed = Framed::new(auth.stream, pubsub::Codec::new()); framed.set_backpressure_boundary(this.options.backpressure_boundary); @@ -137,12 +139,12 @@ where fn on_incoming(&mut self, io: T::Io) -> Result<(), io::Error> { let addr = io.peer_addr()?; - info!("New connection from {}", addr); + info!("New connection from {:?}", addr); // If authentication is enabled, start the authentication process if let Some(ref auth) = self.auth { let authenticator = Arc::clone(auth); - debug!("New connection from {}, authenticating", addr); + debug!("New connection from {:?}, authenticating", addr); self.auth_tasks.spawn(async move { let mut conn = Framed::new(io, auth::Codec::new_server()); @@ -201,7 +203,7 @@ where self.id_counter = self.id_counter.wrapping_add(1); debug!( - "New connection from {}, session ID {}", + "New connection from {:?}, session ID {}", addr, self.id_counter ); } diff --git a/msg-socket/src/pub/mod.rs b/msg-socket/src/pub/mod.rs index ef291d2..37f56d5 100644 --- a/msg-socket/src/pub/mod.rs +++ b/msg-socket/src/pub/mod.rs @@ -192,10 +192,10 @@ mod tests { let mut sub_socket = SubSocket::with_options(Tcp::default(), SubOptions::default()); - pub_socket.bind("0.0.0.0:0").await.unwrap(); + pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub_socket.connect(addr).await.unwrap(); + sub_socket.connect_socket(addr).await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -221,10 +221,10 @@ mod tests { SubOptions::default().auth_token(Bytes::from("client1")), ); - pub_socket.bind("0.0.0.0:0").await.unwrap(); + pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub_socket.connect(addr).await.unwrap(); + sub_socket.connect_socket(addr).await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -250,10 +250,10 @@ mod tests { SubOptions::default().auth_token(Bytes::from("client1")), ); - pub_socket.bind("0.0.0.0:0").await.unwrap(); + pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub_socket.connect(addr).await.unwrap(); + sub_socket.connect_socket(addr).await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -278,11 +278,11 @@ mod tests { let mut sub2 = SubSocket::new(Tcp::default()); - pub_socket.bind("0.0.0.0:0").await.unwrap(); + pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub1.connect(addr).await.unwrap(); - sub2.connect(addr).await.unwrap(); + sub1.connect_socket(addr).await.unwrap(); + sub2.connect_socket(addr).await.unwrap(); sub1.subscribe("HELLO".to_string()).await.unwrap(); sub2.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -313,11 +313,11 @@ mod tests { let mut sub2 = SubSocket::new(Tcp::default()); - pub_socket.bind("0.0.0.0:0").await.unwrap(); + pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub1.connect(addr).await.unwrap(); - sub2.connect(addr).await.unwrap(); + sub1.connect_socket(addr).await.unwrap(); + sub2.connect_socket(addr).await.unwrap(); sub1.subscribe("HELLO".to_string()).await.unwrap(); sub2.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -349,11 +349,11 @@ mod tests { let mut sub_socket = SubSocket::new(Tcp::default()); // Try to connect and subscribe before the publisher is up - sub_socket.connect("0.0.0.0:6662").await.unwrap(); + sub_socket.connect_socket("0.0.0.0:6662").await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(500)).await; - pub_socket.bind("0.0.0.0:6662").await.unwrap(); + pub_socket.bind_socket("0.0.0.0:6662").await.unwrap(); tokio::time::sleep(Duration::from_millis(2000)).await; pub_socket @@ -376,11 +376,11 @@ mod tests { let mut sub_socket = SubSocket::new(Quic::default()); // Try to connect and subscribe before the publisher is up - sub_socket.connect("0.0.0.0:6662").await.unwrap(); + sub_socket.connect_socket("0.0.0.0:6662").await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(1000)).await; - pub_socket.bind("0.0.0.0:6662").await.unwrap(); + pub_socket.bind_socket("0.0.0.0:6662").await.unwrap(); tokio::time::sleep(Duration::from_millis(2000)).await; pub_socket @@ -401,7 +401,7 @@ mod tests { let mut pub_socket = PubSocket::with_options(Tcp::default(), PubOptions::default().max_clients(1)); - pub_socket.bind("0.0.0.0:0").await.unwrap(); + pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); let mut sub1 = SubSocket::::with_options(Tcp::default(), SubOptions::default()); @@ -409,10 +409,10 @@ mod tests { let addr = pub_socket.local_addr().unwrap(); - sub1.connect(addr).await.unwrap(); + sub1.connect_socket(addr).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; assert_eq!(pub_socket.stats().active_clients(), 1); - sub2.connect(addr).await.unwrap(); + sub2.connect_socket(addr).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; assert_eq!(pub_socket.stats().active_clients(), 1); } diff --git a/msg-socket/src/pub/socket.rs b/msg-socket/src/pub/socket.rs index 6a0bbb8..991da56 100644 --- a/msg-socket/src/pub/socket.rs +++ b/msg-socket/src/pub/socket.rs @@ -1,12 +1,12 @@ use bytes::Bytes; use futures::stream::FuturesUnordered; -use std::{io, net::SocketAddr, sync::Arc}; +use std::{io, net::SocketAddr, path::PathBuf, sync::Arc}; use tokio::{ net::{lookup_host, ToSocketAddrs}, sync::broadcast, task::JoinSet, }; -use tracing::{debug, trace}; +use tracing::{debug, trace, warn}; use super::{driver::PubDriver, stats::SocketStats, PubError, PubMessage, PubOptions, SocketState}; use crate::Authenticator; @@ -32,7 +32,42 @@ pub struct PubSocket { // complicates the API a lot. We can always change this later for perf reasons. compressor: Option>, /// The local address this socket is bound to. - local_addr: Option, + local_addr: Option, +} + +impl PubSocket +where + T: Transport + Send + Unpin + 'static, + T::Addr: ToSocketAddrs, +{ +} + +impl PubSocket +where + T: Transport + Send + Unpin + 'static, +{ + /// Binds the socket to the given socket addres + /// + /// This method is only available for transports that support [`SocketAddr`] as address type, + /// like [`Tcp`](msg_transport::tcp::Tcp) and [`Quic`](msg_transport::quic::Quic). + pub async fn bind_socket(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> { + let addrs = lookup_host(addr).await?; + self.try_bind(addrs.collect()).await + } +} + +impl PubSocket +where + T: Transport + Send + Unpin + 'static, +{ + /// Binds the socket to the given path. + /// + /// This method is only available for transports that support [`PathBuf`] as address type, + /// like [`Ipc`](msg_transport::ipc::Ipc). + pub async fn bind_path(&mut self, path: impl AsRef) -> Result<(), PubError> { + let addr = path.as_ref().clone(); + self.try_bind(vec![addr]).await + } } impl PubSocket @@ -69,8 +104,10 @@ where self } - /// Binds the socket to the given address. This spawns the socket driver task. - pub async fn bind(&mut self, addr: A) -> Result<(), PubError> { + /// Binds the socket to the given addresses in order until one succeeds. + /// + /// This also spawns the socket driver task. + pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), PubError> { let (to_sessions_bcast, from_socket_bcast) = broadcast::channel(self.options.session_buffer_size); @@ -79,13 +116,11 @@ where .take() .expect("Transport has been moved already"); - let addrs = lookup_host(addr).await?; - - for addr in addrs { - match transport.bind(addr).await { + for addr in addresses { + match transport.bind(addr.clone()).await { Ok(_) => break, Err(e) => { - tracing::warn!("Failed to bind to {}, trying next address: {}", addr, e); + warn!("Failed to bind to {:?}, trying next address: {}", addr, e); continue; } } @@ -98,7 +133,7 @@ where ))); }; - tracing::debug!("Listening on {}", local_addr); + debug!("Listening on {:?}", local_addr); let backend = PubDriver { id_counter: 0, @@ -192,7 +227,7 @@ where } /// Returns the local address this socket is bound to. `None` if the socket is not bound. - pub fn local_addr(&self) -> Option { - self.local_addr + pub fn local_addr(&self) -> Option<&T::Addr> { + self.local_addr.as_ref() } } diff --git a/msg-socket/src/rep/driver.rs b/msg-socket/src/rep/driver.rs index c0d06c4..9899f16 100644 --- a/msg-socket/src/rep/driver.rs +++ b/msg-socket/src/rep/driver.rs @@ -1,13 +1,13 @@ -use bytes::Bytes; -use futures::{stream::FuturesUnordered, Future, FutureExt, SinkExt, Stream, StreamExt}; use std::{ collections::VecDeque, io, - net::SocketAddr, pin::Pin, sync::Arc, task::{Context, Poll}, }; + +use bytes::Bytes; +use futures::{stream::FuturesUnordered, Future, FutureExt, SinkExt, Stream, StreamExt}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::{mpsc, oneshot}, @@ -15,26 +15,28 @@ use tokio::{ }; use tokio_stream::{StreamMap, StreamNotifyClose}; use tokio_util::codec::Framed; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, info, trace, warn}; use crate::{rep::SocketState, AuthResult, Authenticator, PubError, RepOptions, Request}; -use msg_transport::{PeerAddress, Transport}; + +use msg_transport::{Address, PeerAddress, Transport}; use msg_wire::{ auth, compression::{try_decompress_payload, Compressor}, reqrep, }; -pub(crate) struct PeerState { +pub(crate) struct PeerState { pending_requests: FuturesUnordered, conn: Framed, - addr: SocketAddr, + addr: A, egress_queue: VecDeque, state: Arc, should_flush: bool, compressor: Option>, } +#[allow(clippy::type_complexity)] pub(crate) struct RepDriver { /// The server transport used to accept incoming connections. pub(crate) transport: T, @@ -44,9 +46,9 @@ pub(crate) struct RepDriver { /// Options shared with socket. pub(crate) options: Arc, /// [`StreamMap`] of connected peers. The key is the peer's address. - pub(crate) peer_states: StreamMap>>, + pub(crate) peer_states: StreamMap>>, /// Sender to the socket front-end. Used to notify the socket of incoming requests. - pub(crate) to_socket: mpsc::Sender, + pub(crate) to_socket: mpsc::Sender>, /// Optional connection authenticator. pub(crate) auth: Option>, /// Optional message compressor. This is shared with the socket to keep @@ -55,7 +57,7 @@ pub(crate) struct RepDriver { /// A set of pending incoming connections, represented by [`Transport::Accept`]. pub(super) conn_tasks: FuturesUnordered, /// A joinset of authentication tasks. - pub(crate) auth_tasks: JoinSet, PubError>>, + pub(crate) auth_tasks: JoinSet, PubError>>, } impl Future for RepDriver @@ -71,7 +73,7 @@ where if let Poll::Ready(Some((peer, msg))) = this.peer_states.poll_next_unpin(cx) { match msg { Some(Ok(mut request)) => { - debug!("Received request from peer {}", peer); + debug!("Received request from peer {:?}", peer); let size = request.msg().len(); @@ -88,10 +90,10 @@ where let _ = this.to_socket.try_send(request); } Some(Err(e)) => { - error!("Error receiving message from peer {}: {:?}", peer, e); + error!("Error receiving message from peer {:?}: {:?}", peer, e); } None => { - warn!("Peer {} disconnected", peer); + warn!("Peer {:?} disconnected", peer); this.state.stats.decrement_active_clients(); } } @@ -103,10 +105,10 @@ where match auth { Ok(auth) => { // Run custom authenticator - tracing::info!("Authentication passed for {:?} ({})", auth.id, auth.addr); + info!("Authentication passed for {:?} ({:?})", auth.id, auth.addr); this.peer_states.insert( - auth.addr, + auth.addr.clone(), StreamNotifyClose::new(PeerState { pending_requests: FuturesUnordered::new(), conn: Framed::new(auth.stream, reqrep::Codec::new()), @@ -183,12 +185,12 @@ where fn on_incoming(&mut self, io: T::Io) -> Result<(), io::Error> { let addr = io.peer_addr()?; - info!("New connection from {}", addr); + info!("New connection from {:?}", addr); // If authentication is enabled, start the authentication process if let Some(ref auth) = self.auth { let authenticator = Arc::clone(auth); - debug!("New connection from {}, authenticating", addr); + debug!("New connection from {:?}, authenticating", addr); self.auth_tasks.spawn(async move { let mut conn = Framed::new(io, auth::Codec::new_server()); @@ -229,7 +231,7 @@ where }); } else { self.peer_states.insert( - addr, + addr.clone(), StreamNotifyClose::new(PeerState { pending_requests: FuturesUnordered::new(), conn: Framed::new(io, reqrep::Codec::new()), @@ -246,8 +248,8 @@ where } } -impl Stream for PeerState { - type Item = Result; +impl Stream for PeerState { + type Item = Result, PubError>; /// Advances the state of the peer. fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -318,7 +320,7 @@ impl Stream for PeerState { // Finally we accept incoming requests from the peer. match this.conn.poll_next_unpin(cx) { Poll::Ready(Some(result)) => { - tracing::trace!("Received message from peer {}: {:?}", this.addr, result); + trace!("Received message from peer {:?}: {:?}", this.addr, result); let msg = result?; let (tx, rx) = oneshot::channel(); @@ -330,7 +332,7 @@ impl Stream for PeerState { }); let request = Request { - source: this.addr, + source: this.addr.clone(), response: tx, compression_type: msg.header().compression_type(), msg: msg.into_payload(), @@ -339,7 +341,7 @@ impl Stream for PeerState { return Poll::Ready(Some(Ok(request))); } Poll::Ready(None) => { - tracing::error!("Framed closed unexpectedly (peer {})", this.addr); + error!("Framed closed unexpectedly (peer {:?})", this.addr); return Poll::Ready(None); } Poll::Pending => {} diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index 282c1d8..d0b0e4a 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use std::net::SocketAddr; +use msg_transport::Address; use thiserror::Error; use tokio::sync::oneshot; @@ -62,9 +62,9 @@ pub(crate) struct SocketState { } /// A request received by the socket. -pub struct Request { +pub struct Request { /// The source address of the request. - source: SocketAddr, + source: A, /// The compression type used for the request payload compression_type: u8, /// The oneshot channel to respond to the request. @@ -73,10 +73,10 @@ pub struct Request { msg: Bytes, } -impl Request { +impl Request { /// Returns the source address of the request. - pub fn source(&self) -> SocketAddr { - self.source + pub fn source(&self) -> &A { + &self.source } /// Returns a reference to the message. @@ -94,7 +94,7 @@ impl Request { #[cfg(test)] mod tests { - use std::time::Duration; + use std::{net::SocketAddr, time::Duration}; use futures::StreamExt; use msg_transport::tcp::Tcp; @@ -113,10 +113,10 @@ mod tests { async fn reqrep_simple() { let _ = tracing_subscriber::fmt::try_init(); let mut rep = RepSocket::new(Tcp::default()); - rep.bind(localhost()).await.unwrap(); + rep.bind_socket(localhost()).await.unwrap(); let mut req = ReqSocket::new(Tcp::default()); - req.connect(rep.local_addr().unwrap()).await.unwrap(); + req.connect_socket(rep.local_addr().unwrap()).await.unwrap(); tokio::spawn(async move { loop { @@ -156,7 +156,7 @@ mod tests { // Try to connect even through the server isn't up yet let endpoint = addr.clone(); let connection_attempt = tokio::spawn(async move { - req.connect(endpoint).await.unwrap(); + req.connect_socket(endpoint).await.unwrap(); req }); @@ -164,7 +164,7 @@ mod tests { // Wait a moment to start the server tokio::time::sleep(Duration::from_millis(500)).await; let mut rep = RepSocket::new(Tcp::default()); - rep.bind(addr).await.unwrap(); + rep.bind_socket(addr).await.unwrap(); let req = connection_attempt.await.unwrap(); @@ -193,7 +193,7 @@ mod tests { let _ = tracing_subscriber::fmt::try_init(); let mut rep = RepSocket::new(Tcp::default()).with_auth(Auth); - rep.bind(localhost()).await.unwrap(); + rep.bind_socket(localhost()).await.unwrap(); // Initialize socket with a client ID. This will implicitly enable authentication. let mut req = ReqSocket::with_options( @@ -201,7 +201,7 @@ mod tests { ReqOptions::default().auth_token(Bytes::from("REQ")), ); - req.connect(rep.local_addr().unwrap()).await.unwrap(); + req.connect_socket(rep.local_addr().unwrap()).await.unwrap(); tracing::info!("Connected to rep"); @@ -236,16 +236,16 @@ mod tests { async fn rep_max_connections() { let _ = tracing_subscriber::fmt::try_init(); let mut rep = RepSocket::with_options(Tcp::default(), RepOptions::default().max_clients(1)); - rep.bind("127.0.0.1:0").await.unwrap(); + rep.bind_socket("127.0.0.1:0").await.unwrap(); let addr = rep.local_addr().unwrap(); let mut req1 = ReqSocket::new(Tcp::default()); - req1.connect(addr).await.unwrap(); + req1.connect_socket(addr).await.unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(rep.stats().active_clients(), 1); let mut req2 = ReqSocket::new(Tcp::default()); - req2.connect(addr).await.unwrap(); + req2.connect_socket(addr).await.unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(rep.stats().active_clients(), 1); } @@ -256,13 +256,13 @@ mod tests { RepSocket::with_options(Tcp::default(), RepOptions::default().min_compress_size(0)) .with_compressor(SnappyCompressor); - rep.bind("0.0.0.0:4445").await.unwrap(); + rep.bind_socket("0.0.0.0:4445").await.unwrap(); let mut req = ReqSocket::with_options(Tcp::default(), ReqOptions::default().min_compress_size(0)) .with_compressor(GzipCompressor::new(6)); - req.connect("0.0.0.0:4445").await.unwrap(); + req.connect_socket("0.0.0.0:4445").await.unwrap(); tokio::spawn(async move { let req = rep.next().await.unwrap(); diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs index 09ce53b..6a905a5 100644 --- a/msg-socket/src/rep/socket.rs +++ b/msg-socket/src/rep/socket.rs @@ -1,25 +1,29 @@ -use futures::{stream::FuturesUnordered, Stream}; -use msg_wire::compression::Compressor; use std::{ io, net::SocketAddr, + path::PathBuf, pin::Pin, sync::Arc, task::{Context, Poll}, }; + +use futures::{stream::FuturesUnordered, Stream}; use tokio::{ net::{lookup_host, ToSocketAddrs}, sync::mpsc, task::JoinSet, }; use tokio_stream::StreamMap; +use tracing::{debug, warn}; use crate::{ rep::{driver::RepDriver, DEFAULT_BUFFER_SIZE}, rep::{SocketState, SocketStats}, Authenticator, PubError, RepOptions, Request, }; -use msg_transport::{Transport, TransportExt}; + +use msg_transport::Transport; +use msg_wire::compression::Compressor; /// A reply socket. This socket implements [`Stream`] and yields incoming [`Request`]s. #[derive(Default)] @@ -29,21 +33,49 @@ pub struct RepSocket { /// The reply socket state, shared with the driver. state: Arc, /// Receiver from the socket driver. - from_driver: Option>, + from_driver: Option>>, /// The transport used by this socket. This value is temporary and will be moved /// to the driver task once the socket is bound. transport: Option, /// Optional connection authenticator. auth: Option>, /// The local address this socket is bound to. - local_addr: Option, + local_addr: Option, /// Optional message compressor. compressor: Option>, } impl RepSocket where - T: TransportExt + Send + Unpin + 'static, + T: Transport + Send + Unpin + 'static, +{ + /// Binds the socket to the given socket addres + /// + /// This method is only available for transports that support [`SocketAddr`] as address type, + /// like [`Tcp`](msg_transport::tcp::Tcp) and [`Quic`](msg_transport::quic::Quic). + pub async fn bind_socket(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> { + let addrs = lookup_host(addr).await?; + self.try_bind(addrs.collect()).await + } +} + +impl RepSocket +where + T: Transport + Send + Unpin + 'static, +{ + /// Binds the socket to the given path. + /// + /// This method is only available for transports that support [`PathBuf`] as address type, + /// like [`Ipc`](msg_transport::ipc::Ipc). + pub async fn bind_path(&mut self, path: impl Into) -> Result<(), PubError> { + let addr = path.into().clone(); + self.try_bind(vec![addr]).await + } +} + +impl RepSocket +where + T: Transport + Send + Unpin + 'static, { /// Creates a new reply socket with the default [`RepOptions`]. pub fn new(transport: T) -> Self { @@ -76,7 +108,7 @@ where } /// Binds the socket to the given address. This spawns the socket driver task. - pub async fn bind(&mut self, addr: A) -> Result<(), PubError> { + pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), PubError> { let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE); let mut transport = self @@ -84,12 +116,11 @@ where .take() .expect("Transport has been moved already"); - let addrs = lookup_host(addr).await?; - for addr in addrs { - match transport.bind(addr).await { + for addr in addresses { + match transport.bind(addr.clone()).await { Ok(_) => break, Err(e) => { - tracing::warn!("Failed to bind to {}, trying next address: {}", addr, e); + warn!("Failed to bind to {:?}, trying next address: {}", addr, e); continue; } } @@ -102,7 +133,7 @@ where ))); }; - tracing::debug!("Listening on {}", local_addr); + debug!("Listening on {:?}", local_addr); let backend = RepDriver { transport, @@ -129,13 +160,13 @@ where } /// Returns the local address this socket is bound to. `None` if the socket is not bound. - pub fn local_addr(&self) -> Option { - self.local_addr + pub fn local_addr(&self) -> Option<&T::Addr> { + self.local_addr.as_ref() } } impl Stream for RepSocket { - type Item = Request; + type Item = Request; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.get_mut() diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index f63f365..5fb1a9f 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -4,7 +4,6 @@ use rustc_hash::FxHashMap; use std::{ collections::VecDeque, io, - net::SocketAddr, pin::Pin, sync::Arc, task::{ready, Context, Poll}, @@ -44,12 +43,13 @@ pub(crate) struct ReqDriver { /// The transport for this socket. pub(crate) transport: T, /// The address of the server. - pub(crate) addr: SocketAddr, + pub(crate) addr: T::Addr, /// The connection task which handles the connection to the server. pub(crate) conn_task: Option>, /// The transport controller, wrapped in a [`ConnectionState`] for backoff. /// The [`Framed`] object can send and receive messages from the socket. - pub(crate) conn_state: ConnectionState, ExponentialBackoff>, + pub(crate) conn_state: + ConnectionState, ExponentialBackoff, T::Addr>, /// The outgoing message queue. pub(crate) egress_queue: VecDeque, /// The currently pending requests, if any. Uses [`FxHashMap`] for performance. @@ -79,17 +79,17 @@ where { /// Start the connection task to the server, handling authentication if necessary. /// The result will be polled by the driver and re-tried according to the backoff policy. - fn try_connect(&mut self, addr: SocketAddr) { - trace!("Trying to connect to {}", addr); + fn try_connect(&mut self, addr: T::Addr) { + trace!("Trying to connect to {:?}", addr); - let connect = self.transport.connect(addr); + let connect = self.transport.connect(addr.clone()); let token = self.options.auth_token.clone(); self.conn_task = Some(Box::pin(async move { let mut io = match connect.await { Ok(io) => io, Err(e) => { - error!("Failed to connect to {}: {:?}", addr, e); + error!("Failed to connect to {:?}: {:?}", addr, e); return Err(e); } }; @@ -108,7 +108,7 @@ where match conn.next().await { Some(res) => match res { Ok(auth::Message::Ack) => { - debug!("Connected to {}", addr); + debug!("Connected to {:?}", addr); Ok(io) } Ok(msg) => { @@ -126,7 +126,7 @@ where } } } else { - debug!("Connected to {}", addr); + debug!("Connected to {:?}", addr); Ok(io) } })); @@ -242,7 +242,7 @@ where #[inline] fn reset_connection(&mut self) { self.conn_state = ConnectionState::Inactive { - addr: self.addr, + addr: self.addr.clone(), backoff: ExponentialBackoff::new(Duration::from_millis(20), 16), }; } @@ -287,14 +287,14 @@ where // or poll the backoff timer if we're already trying to connect. if let ConnectionState::Inactive { ref mut backoff, - addr, + ref addr, } = this.conn_state { if let Poll::Ready(item) = backoff.poll_next_unpin(cx) { if let Some(duration) = item { if this.conn_task.is_none() { debug!(backoff = ?duration, "Retrying connection to {:?}", addr); - this.try_connect(addr); + this.try_connect(addr.clone()); } else { debug!(backoff = ?duration, "Not retrying connection to {:?} as there is already a connection task", addr); } diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 5f47a52..7edb1fa 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -1,6 +1,8 @@ use bytes::Bytes; use msg_wire::compression::Compressor; use rustc_hash::FxHashMap; +use std::net::SocketAddr; +use std::path::PathBuf; use std::{io, sync::Arc, time::Duration}; use tokio::net::{lookup_host, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; @@ -29,6 +31,34 @@ pub struct ReqSocket { compressor: Option>, } +impl ReqSocket +where + T: Transport + Send + Sync + Unpin + 'static, +{ + /// Connects to the target address with the default options. + pub async fn connect_socket(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> { + let mut addrs = lookup_host(addr).await?; + let endpoint = addrs.next().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "could not find any valid address", + ) + })?; + + self.try_connect(endpoint).await + } +} + +impl ReqSocket +where + T: Transport + Send + Sync + Unpin + 'static, +{ + /// Connects to the target path with the default options. + pub async fn connect_path(&mut self, addr: impl Into) -> Result<(), ReqError> { + self.try_connect(addr.into().clone()).await + } +} + impl ReqSocket where T: Transport + Send + Sync + Unpin + 'static, @@ -75,17 +105,9 @@ where response_rx.await.map_err(|_| ReqError::SocketClosed)? } - /// Connects to the target address with the default options. + /// Tries to connect to the target endpoint with the default options. /// A ReqSocket can only be connected to a single address. - pub async fn connect(&mut self, addr: A) -> Result<(), ReqError> { - let mut addrs = lookup_host(addr).await?; - let endpoint = addrs.next().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - ) - })?; - + pub async fn try_connect(&mut self, endpoint: T::Addr) -> Result<(), ReqError> { // Initialize communication channels let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); @@ -97,7 +119,7 @@ where // We initialize the connection as inactive, and let it be activated // by the backend task as soon as the driver is spawned. let conn_state = ConnectionState::Inactive { - addr: endpoint, + addr: endpoint.clone(), backoff: ExponentialBackoff::new(Duration::from_millis(20), 16), }; diff --git a/msg-socket/src/sub/driver.rs b/msg-socket/src/sub/driver.rs index 4f6b4c0..a157ecb 100644 --- a/msg-socket/src/sub/driver.rs +++ b/msg-socket/src/sub/driver.rs @@ -1,13 +1,13 @@ -use futures::{Future, SinkExt, StreamExt}; -use rustc_hash::FxHashMap; use std::{ collections::HashSet, io, - net::{IpAddr, Ipv4Addr, SocketAddr}, pin::Pin, sync::Arc, task::{Context, Poll}, }; + +use futures::{Future, SinkExt, StreamExt}; +use rustc_hash::FxHashMap; use tokio::sync::mpsc::{self, error::TrySendError}; use tokio_util::codec::Framed; use tracing::{debug, error, info, warn}; @@ -34,17 +34,18 @@ pub(crate) struct SubDriver { /// The transport for this socket. pub(super) transport: T, /// Commands from the socket. - pub(super) from_socket: mpsc::Receiver, + pub(super) from_socket: mpsc::Receiver>, /// Messages to the socket. - pub(super) to_socket: mpsc::Sender, + pub(super) to_socket: mpsc::Sender>, /// A joinset of authentication tasks. - pub(super) connection_tasks: JoinMap>, + pub(super) connection_tasks: JoinMap>, /// The set of subscribed topics. pub(super) subscribed_topics: HashSet, /// All publisher sessions for this subscriber socket, keyed by address. - pub(super) publishers: FxHashMap>, + pub(super) publishers: + FxHashMap>, /// Socket state. This is shared with the backend task. - pub(super) state: Arc, + pub(super) state: Arc>, } impl Future for SubDriver @@ -78,7 +79,7 @@ where this.on_connection(addr, io); } Err(e) => { - error!(%addr, "Error connecting to publisher: {:?}", e); + error!(?addr, "Error connecting to publisher: {:?}", e); } } @@ -96,10 +97,10 @@ where { /// De-activates a publisher by setting it to [`ConnectionState::Inactive`]. /// This will initialize the backoff stream. - fn reset_publisher(&mut self, addr: SocketAddr) { + fn reset_publisher(&mut self, addr: T::Addr) { debug!("Resetting publisher at {addr:?}"); self.publishers.insert( - addr, + addr.clone(), ConnectionState::Inactive { addr, backoff: ExponentialBackoff::new(self.options.initial_backoff, 16), @@ -108,7 +109,7 @@ where } /// Returns true if we're already connected to the given publisher address. - fn is_connected(&self, addr: &SocketAddr) -> bool { + fn is_connected(&self, addr: &T::Addr) -> bool { if self.publishers.get(addr).is_some_and(|s| s.is_active()) { return true; } @@ -116,7 +117,7 @@ where false } - fn is_known(&self, addr: &SocketAddr) -> bool { + fn is_known(&self, addr: &T::Addr) -> bool { self.publishers.contains_key(addr) } @@ -133,8 +134,8 @@ where if let Err(TrySendError::Closed(_)) = channel.try_send(SessionCommand::Subscribe(topic.clone())) { - warn!(publisher = %addr, "Error trying to subscribe to topic {topic}: publisher channel closed"); - inactive.push(*addr); + warn!(publisher = ?addr, "Error trying to subscribe to topic {topic}: publisher channel closed"); + inactive.push(addr.clone()); } } } @@ -167,8 +168,8 @@ where if let Err(TrySendError::Closed(_)) = channel.try_send(SessionCommand::Unsubscribe(topic.clone())) { - warn!(publisher = %addr, "Error trying to unsubscribe from topic {topic}: publisher channel closed"); - inactive.push(*addr); + warn!(publisher = ?addr, "Error trying to unsubscribe from topic {topic}: publisher channel closed"); + inactive.push(addr.clone()); } } } @@ -189,7 +190,7 @@ where } } - fn on_command(&mut self, cmd: Command) { + fn on_command(&mut self, cmd: Command) { debug!("Received command: {:?}", cmd); match cmd { Command::Subscribe { topic } => { @@ -198,20 +199,16 @@ where Command::Unsubscribe { topic } => { self.unsubscribe(topic); } - Command::Connect { mut endpoint } => { - // Some transport implementations (e.g. Quinn) can't dial an unspecified IP address, so replace - // it with localhost. - if endpoint.ip().is_unspecified() { - // TODO: support IPv6 - endpoint.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); - } - + Command::Connect { endpoint } => { if self.is_known(&endpoint) { - debug!(%endpoint, "Publisher already known, ignoring connect command"); + debug!( + ?endpoint, + "Publisher already known, ignoring connect command" + ); return; } - self.connect(endpoint); + self.connect(endpoint.clone()); // Also set the publisher to the disconnected state. This will make sure that if the // initial connection attempt fails, it will be retried in `poll_publishers`. @@ -219,10 +216,10 @@ where } Command::Disconnect { endpoint } => { if self.publishers.remove(&endpoint).is_some() { - debug!(%endpoint, "Disconnected from publisher"); + debug!(?endpoint, "Disconnected from publisher"); self.state.stats.remove(&endpoint); } else { - debug!(%endpoint, "Not connected to publisher"); + debug!(?endpoint, "Not connected to publisher"); }; } Command::Shutdown => { @@ -232,11 +229,11 @@ where } } - fn connect(&mut self, addr: SocketAddr) { - let connect = self.transport.connect(addr); + fn connect(&mut self, addr: T::Addr) { + let connect = self.transport.connect(addr.clone()); let token = self.options.auth_token.clone(); - self.connection_tasks.spawn(addr, async move { + self.connection_tasks.spawn(addr.clone(), async move { let io = match connect.await { Ok(io) => io, Err(e) => { @@ -297,21 +294,21 @@ where }); } - fn on_connection(&mut self, addr: SocketAddr, io: T::Io) { + fn on_connection(&mut self, addr: T::Addr, io: T::Io) { if self.is_connected(&addr) { // We're already connected to this publisher - warn!(%addr, "Already connected to publisher"); + warn!(?addr, "Already connected to publisher"); return; } - debug!("Connection to {} established, spawning session", addr); + debug!("Connection to {:?} established, spawning session", addr); let framed = Framed::with_capacity(io, pubsub::Codec::new(), self.options.read_buffer_size); let (driver_channel, mut publisher_channel) = channel(1024, 64); let publisher_session = - PublisherSession::new(addr, PublisherStream::from(framed), driver_channel); + PublisherSession::new(addr.clone(), PublisherStream::from(framed), driver_channel); // Get the shared session stats. let session_stats = publisher_session.stats(); @@ -324,12 +321,12 @@ where .try_send(SessionCommand::Subscribe(topic.clone())) .is_err() { - error!(publisher = %addr, "Error trying to subscribe to topic {topic} on startup: publisher channel closed / full"); + error!(publisher = ?addr, "Error trying to subscribe to topic {topic} on startup: publisher channel closed / full"); } } self.publishers.insert( - addr, + addr.clone(), ConnectionState::Active { channel: publisher_channel, }, @@ -364,9 +361,9 @@ where } }; - let msg = PubMessage::new(*addr, msg.topic, msg.payload); + let msg = PubMessage::new(addr.clone(), msg.topic, msg.payload); - debug!(source = %msg.source, "New message: {:?}", msg); + debug!(source = ?msg.source, "New message: {:?}", msg); // TODO: queuing if let Err(TrySendError::Full(msg)) = self.to_socket.try_send(msg) { error!( @@ -378,8 +375,8 @@ where progress = true; } Poll::Ready(None) => { - error!(source = %addr, "Publisher stream closed, removing channel"); - inactive.push(*addr); + error!(source = ?addr, "Publisher stream closed, removing channel"); + inactive.push(addr.clone()); progress = true; } @@ -395,13 +392,13 @@ where // Only retry if there are no active connection tasks if !self.connection_tasks.contains_key(addr) { debug!(backoff = ?duration, "Retrying connection to {:?}", addr); - to_retry.push(*addr); + to_retry.push(addr.clone()); } else { debug!(backoff = ?duration, "Not retrying connection to {:?} as there is already a connection task", addr); } } else { error!("Exceeded maximum number of retries for {:?}, terminating connection", addr); - to_terminate.push(*addr); + to_terminate.push(addr.clone()); } } } diff --git a/msg-socket/src/sub/mod.rs b/msg-socket/src/sub/mod.rs index 01b0774..b6d7a7d 100644 --- a/msg-socket/src/sub/mod.rs +++ b/msg-socket/src/sub/mod.rs @@ -1,7 +1,8 @@ use bytes::Bytes; use core::fmt; +use msg_transport::Address; use msg_wire::pubsub; -use std::{net::SocketAddr, time::Duration}; +use std::time::Duration; use thiserror::Error; mod driver; @@ -33,15 +34,15 @@ pub enum SubError { } #[derive(Debug)] -enum Command { +enum Command { /// Subscribe to a topic. Subscribe { topic: String }, /// Unsubscribe from a topic. Unsubscribe { topic: String }, /// Connect to a publisher socket. - Connect { endpoint: SocketAddr }, + Connect { endpoint: A }, /// Disconnect from a publisher socket. - Disconnect { endpoint: SocketAddr }, + Disconnect { endpoint: A }, /// Shut down the driver. Shutdown, } @@ -101,17 +102,17 @@ impl Default for SubOptions { /// A message received from a publisher. /// Includes the source, topic, and payload. #[derive(Clone)] -pub struct PubMessage { +pub struct PubMessage { /// The source address of the publisher. We need this because /// a subscriber can connect to multiple publishers. - source: SocketAddr, + source: A, /// The topic of the message. topic: String, /// The message payload. payload: Bytes, } -impl fmt::Debug for PubMessage { +impl fmt::Debug for PubMessage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PubMessage") .field("source", &self.source) @@ -121,8 +122,8 @@ impl fmt::Debug for PubMessage { } } -impl PubMessage { - pub fn new(source: SocketAddr, topic: String, payload: Bytes) -> Self { +impl PubMessage { + pub fn new(source: A, topic: String, payload: Bytes) -> Self { Self { source, topic, @@ -131,8 +132,8 @@ impl PubMessage { } #[inline] - pub fn source(&self) -> SocketAddr { - self.source + pub fn source(&self) -> &A { + &self.source } #[inline] @@ -153,12 +154,22 @@ impl PubMessage { /// The request socket state, shared between the backend task and the socket. #[derive(Debug, Default)] -pub(crate) struct SocketState { - pub(crate) stats: SocketStats, +pub(crate) struct SocketState { + pub(crate) stats: SocketStats, +} + +impl SocketState { + pub fn new() -> Self { + Self { + stats: SocketStats::new(), + } + } } #[cfg(test)] mod tests { + use std::net::SocketAddr; + use msg_transport::tcp::Tcp; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, diff --git a/msg-socket/src/sub/session.rs b/msg-socket/src/sub/session.rs index 0560743..e645338 100644 --- a/msg-socket/src/sub/session.rs +++ b/msg-socket/src/sub/session.rs @@ -1,14 +1,17 @@ +use std::{ + collections::VecDeque, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + use bytes::Bytes; use futures::{Future, StreamExt}; -use std::collections::VecDeque; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, error, warn}; use msg_common::{unix_micros, Channel}; +use msg_transport::Address; use msg_wire::pubsub; use super::{ @@ -23,9 +26,9 @@ pub(super) enum SessionCommand { /// Manages the state of a single publisher, represented as a [`Future`]. #[must_use = "This future must be spawned"] -pub(super) struct PublisherSession { +pub(super) struct PublisherSession { /// The addr of the publisher - addr: SocketAddr, + addr: A, /// The egress queue (for subscribe / unsubscribe messages) egress: VecDeque, /// The inner stream @@ -37,9 +40,9 @@ pub(super) struct PublisherSession { driver_channel: Channel, } -impl PublisherSession { +impl PublisherSession { pub(super) fn new( - addr: SocketAddr, + addr: A, stream: PublisherStream, channel: Channel, ) -> Self { @@ -82,11 +85,11 @@ impl PublisherSession { self.stats.update_latency(now.saturating_sub(msg.timestamp)); if self.driver_channel.try_send(msg).is_err() { - warn!(addr = %self.addr, "Failed to send message to driver"); + warn!(addr = ?self.addr, "Failed to send message to driver"); } } Err(e) => { - error!(addr = %self.addr, "Error receiving message: {:?}", e); + error!(addr = ?self.addr, "Error receiving message: {:?}", e); } } } @@ -99,7 +102,7 @@ impl PublisherSession { } } -impl Future for PublisherSession { +impl Future for PublisherSession { type Output = (); /// This poll implementation prioritizes incoming messages over outgoing messages. @@ -116,7 +119,7 @@ impl Future for PublisherSession { continue; } Poll::Ready(None) => { - error!(addr = %this.addr, "Publisher stream closed"); + error!(addr = ?this.addr, "Publisher stream closed"); return Poll::Ready(()); } Poll::Pending => {} @@ -145,7 +148,7 @@ impl Future for PublisherSession { continue; } None => { - warn!(addr = %this.addr, "Driver channel closed, shutting down session"); + warn!(addr = ?this.addr, "Driver channel closed, shutting down session"); return Poll::Ready(()); } } diff --git a/msg-socket/src/sub/socket.rs b/msg-socket/src/sub/socket.rs index 4fcd2e2..588cf14 100644 --- a/msg-socket/src/sub/socket.rs +++ b/msg-socket/src/sub/socket.rs @@ -3,7 +3,8 @@ use rustc_hash::FxHashMap; use std::{ collections::HashSet, io, - net::SocketAddr, + net::{IpAddr, Ipv4Addr, SocketAddr}, + path::PathBuf, pin::Pin, sync::Arc, task::{Context, Poll}, @@ -23,19 +24,117 @@ use super::{ pub struct SubSocket { /// Command channel to the socket driver. - to_driver: mpsc::Sender, + to_driver: mpsc::Sender>, /// Receiver channel from the socket driver. - from_driver: mpsc::Receiver, + from_driver: mpsc::Receiver>, /// Options for the socket. These are shared with the backend task. #[allow(unused)] options: Arc, /// The pending driver. driver: Option>, /// Socket state. This is shared with the socket frontend. - state: Arc, + state: Arc>, + /// Marker for the transport type. _marker: std::marker::PhantomData, } +impl SubSocket +where + T: Transport + Send + Sync + Unpin + 'static, +{ + /// Connects to the given endpoint asynchronously. + pub async fn connect_socket(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> { + let mut addrs = lookup_host(endpoint).await?; + let mut endpoint = addrs.next().ok_or(SubError::Io(io::Error::new( + io::ErrorKind::InvalidInput, + "could not find any valid address", + )))?; + + // Some transport implementations (e.g. Quinn) can't dial an unspecified + // IP address, so replace it with localhost. + if endpoint.ip().is_unspecified() { + // TODO: support IPv6 + endpoint.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); + } + + self.connect(endpoint).await + } + + /// Attempts to connect to the given endpoint immediately. + pub fn try_connect_socket(&mut self, endpoint: impl Into) -> Result<(), SubError> { + let addr = endpoint.into(); + let mut endpoint: SocketAddr = addr.parse().map_err(|_| { + SubError::Io(io::Error::new( + io::ErrorKind::InvalidInput, + "could not find any valid address", + )) + })?; + + // Some transport implementations (e.g. Quinn) can't dial an unspecified + // IP address, so replace it with localhost. + if endpoint.ip().is_unspecified() { + // TODO: support IPv6 + endpoint.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); + } + + self.try_connect(endpoint) + } + + pub async fn disconnect_socket( + &mut self, + endpoint: impl ToSocketAddrs, + ) -> Result<(), SubError> { + let mut addrs = lookup_host(endpoint).await?; + let endpoint = addrs.next().ok_or(SubError::Io(io::Error::new( + io::ErrorKind::InvalidInput, + "could not find any valid address", + )))?; + + self.disconnect(endpoint).await + } + + pub fn try_disconnect_socket(&mut self, endpoint: impl Into) -> Result<(), SubError> { + let endpoint = endpoint.into(); + let endpoint: SocketAddr = endpoint.parse().map_err(|_| { + SubError::Io(io::Error::new( + io::ErrorKind::InvalidInput, + "could not find any valid address", + )) + })?; + + self.try_disconnect(endpoint) + } +} + +impl SubSocket +where + T: Transport + Send + Sync + Unpin + 'static, +{ + /// Connects to the given path asynchronously. + pub async fn connect_path(&mut self, path: impl AsRef) -> Result<(), SubError> { + let path = path.as_ref().clone(); + self.connect(path).await + } + + /// Attempts to connect to the given path immediately. + pub fn try_connect_path(&mut self, path: impl AsRef) -> Result<(), SubError> { + let path = path.as_ref().clone(); + self.try_connect(path) + } + + /// Disconnects from the given path asynchronously. + pub async fn disconnect_path(&mut self, path: impl AsRef) -> Result<(), SubError> { + let path = path.as_ref().clone(); + self.disconnect(path).await + } + + /// Attempts to disconnect from the given path immediately. + pub fn try_disconnect_path(&mut self, path: impl AsRef) -> Result<(), SubError> { + let path = path.as_ref().clone(); + self.try_disconnect(path) + } +} + impl SubSocket where T: Transport + Send + Sync + Unpin + 'static, @@ -51,7 +150,7 @@ where let options = Arc::new(options); - let state = Arc::new(SocketState::default()); + let state = Arc::new(SocketState::new()); let mut publishers = FxHashMap::default(); publishers.reserve(32); @@ -78,66 +177,30 @@ where } /// Asynchronously connects to the endpoint. - pub async fn connect(&mut self, endpoint: A) -> Result<(), SubError> { + pub async fn connect(&mut self, endpoint: T::Addr) -> Result<(), SubError> { self.ensure_active_driver(); - - let mut addrs = lookup_host(endpoint).await?; - let endpoint = addrs.next().ok_or(SubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - )))?; - self.send_command(Command::Connect { endpoint }).await?; - Ok(()) } /// Immediately send a connect command to the driver. - pub fn try_connect(&mut self, endpoint: impl Into) -> Result<(), SubError> { + pub fn try_connect(&mut self, endpoint: T::Addr) -> Result<(), SubError> { self.ensure_active_driver(); - - let endpoint = endpoint.into(); - let endpoint: SocketAddr = endpoint.parse().map_err(|_| { - SubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - )) - })?; - self.try_send_command(Command::Connect { endpoint })?; - Ok(()) } /// Asynchronously disconnects from the endpoint. - pub async fn disconnect(&mut self, endpoint: A) -> Result<(), SubError> { + pub async fn disconnect(&mut self, endpoint: T::Addr) -> Result<(), SubError> { self.ensure_active_driver(); - - let mut addrs = lookup_host(endpoint).await?; - let endpoint = addrs.next().ok_or(SubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - )))?; - self.send_command(Command::Disconnect { endpoint }).await?; - Ok(()) } /// Immediately send a disconnect command to the driver. - pub fn try_disconnect(&mut self, endpoint: impl Into) -> Result<(), SubError> { + pub fn try_disconnect(&mut self, endpoint: T::Addr) -> Result<(), SubError> { self.ensure_active_driver(); - - let endpoint = endpoint.into(); - let endpoint: SocketAddr = endpoint.parse().map_err(|_| { - SubError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - )) - })?; - self.try_send_command(Command::Disconnect { endpoint })?; - Ok(()) } @@ -193,7 +256,7 @@ where /// Sends a command to the driver, returning [`SubError::SocketClosed`] if the /// driver has been dropped. - async fn send_command(&self, command: Command) -> Result<(), SubError> { + async fn send_command(&self, command: Command) -> Result<(), SubError> { self.to_driver .send(command) .await @@ -202,7 +265,7 @@ where Ok(()) } - fn try_send_command(&self, command: Command) -> Result<(), SubError> { + fn try_send_command(&self, command: Command) -> Result<(), SubError> { use mpsc::error::TrySendError::*; self.to_driver.try_send(command).map_err(|e| match e { Full(_) => SubError::ChannelFull, @@ -219,7 +282,7 @@ where } } - pub fn stats(&self) -> &SocketStats { + pub fn stats(&self) -> &SocketStats { &self.state.stats } } @@ -232,7 +295,7 @@ impl Drop for SubSocket { } impl Stream for SubSocket { - type Item = PubMessage; + type Item = PubMessage; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.from_driver.poll_recv(cx) diff --git a/msg-socket/src/sub/stats.rs b/msg-socket/src/sub/stats.rs index 8efeaf6..8f76bf5 100644 --- a/msg-socket/src/sub/stats.rs +++ b/msg-socket/src/sub/stats.rs @@ -1,35 +1,43 @@ use std::{ collections::HashMap, - net::SocketAddr, sync::{ atomic::{AtomicU64, AtomicUsize, Ordering}, Arc, }, }; +use msg_transport::Address; use parking_lot::RwLock; /// Statistics for a reply socket. These are shared between the driver task /// and the socket. #[derive(Debug, Default)] -pub struct SocketStats { +pub struct SocketStats { /// Individual session stats for each publisher - session_stats: RwLock>>, + session_stats: RwLock>>, } -impl SocketStats { +impl SocketStats { + pub fn new() -> Self { + Self { + session_stats: RwLock::new(HashMap::new()), + } + } +} + +impl SocketStats { #[inline] - pub(crate) fn insert(&self, addr: SocketAddr, stats: Arc) { + pub(crate) fn insert(&self, addr: A, stats: Arc) { self.session_stats.write().insert(addr, stats); } #[inline] - pub(crate) fn remove(&self, addr: &SocketAddr) { + pub(crate) fn remove(&self, addr: &A) { self.session_stats.write().remove(addr); } #[inline] - pub fn bytes_rx(&self, session_addr: &SocketAddr) -> Option { + pub fn bytes_rx(&self, session_addr: &A) -> Option { self.session_stats .read() .get(session_addr) @@ -38,7 +46,7 @@ impl SocketStats { /// Returns the average latency in microseconds for the given session. #[inline] - pub fn avg_latency(&self, session_addr: &SocketAddr) -> Option { + pub fn avg_latency(&self, session_addr: &A) -> Option { self.session_stats .read() .get(session_addr) diff --git a/msg-socket/tests/it/pubsub.rs b/msg-socket/tests/it/pubsub.rs index 1786bda..8645a54 100644 --- a/msg-socket/tests/it/pubsub.rs +++ b/msg-socket/tests/it/pubsub.rs @@ -35,29 +35,34 @@ async fn pubsub_channel() { }, ); - let result = pubsub_channel_transport(build_tcp).await; + let result = pubsub_channel_transport(build_tcp, "127.0.0.1:9879".parse().unwrap()).await; assert!(result.is_ok()); - let result = pubsub_channel_transport(build_quic).await; + let result = pubsub_channel_transport(build_quic, "127.0.0.1:9879".parse().unwrap()).await; assert!(result.is_ok()); simulator.stop(addr); } -async fn pubsub_channel_transport T, T: Transport + Send + Sync + Unpin + 'static>( +async fn pubsub_channel_transport( new_transport: F, -) -> Result<(), Box> { + addr: T::Addr, +) -> Result<(), Box> +where + F: Fn() -> T, + T: Transport + Send + Sync + Unpin + 'static, +{ let mut publisher = PubSocket::new(new_transport()); let mut subscriber = SubSocket::new(new_transport()); - subscriber.connect("127.0.0.1:9879").await?; + subscriber.connect(addr.clone()).await?; subscriber.subscribe(TOPIC).await?; inject_delay(400).await; - publisher.bind("127.0.0.1:9879").await?; + publisher.try_bind(vec![addr]).await?; // Spawn a task to keep sending messages until the subscriber receives one (after connection process) tokio::spawn(async move { @@ -96,11 +101,11 @@ async fn pubsub_fan_out() { }, ); - let result = pubsub_fan_out_transport(build_tcp, 10).await; + let result = pubsub_fan_out_transport(build_tcp, 10, "127.0.0.1:9880".parse().unwrap()).await; assert!(result.is_ok()); - let result = pubsub_fan_out_transport(build_quic, 10).await; + let result = pubsub_fan_out_transport(build_quic, 10, "127.0.0.1:9880".parse().unwrap()).await; assert!(result.is_ok()); @@ -113,15 +118,14 @@ async fn pubsub_fan_out_transport< >( new_transport: F, subscibers: usize, + addr: T::Addr, ) -> Result<(), Box> { let mut publisher = PubSocket::new(new_transport()); let mut sub_tasks = JoinSet::new(); - let addr = "127.0.0.1:9880"; - for i in 0..subscibers { - let cloned = addr.to_string(); + let cloned = addr.clone(); sub_tasks.spawn(async move { let mut subscriber = SubSocket::new(new_transport()); inject_delay((100 * (i + 1)) as u64).await; @@ -139,7 +143,7 @@ async fn pubsub_fan_out_transport< inject_delay(400).await; - publisher.bind(addr).await?; + publisher.try_bind(vec![addr]).await?; // Spawn a task to keep sending messages until the subscriber receives one (after connection process) tokio::spawn(async move { @@ -177,11 +181,11 @@ async fn pubsub_fan_in() { }, ); - let result = pubsub_fan_in_transport(build_tcp, 20).await; + let result = pubsub_fan_in_transport(build_tcp, 20, "127.0.0.1:9881".parse().unwrap()).await; assert!(result.is_ok()); - let result = pubsub_fan_in_transport(build_quic, 20).await; + let result = pubsub_fan_in_transport(build_quic, 20, "127.0.0.1:9881".parse().unwrap()).await; assert!(result.is_ok()); @@ -194,6 +198,7 @@ async fn pubsub_fan_in_transport< >( new_transport: F, publishers: usize, + addr: T::Addr, ) -> Result<(), Box> { let mut sub_tasks = JoinSet::new(); @@ -201,14 +206,15 @@ async fn pubsub_fan_in_transport< for i in 0..publishers { let tx = tx.clone(); + let addr = addr.clone(); sub_tasks.spawn(async move { let mut publisher = PubSocket::new(new_transport()); inject_delay((100 * (i + 1)) as u64).await; - publisher.bind("127.0.0.1:0").await.unwrap(); + publisher.try_bind(vec![addr]).await.unwrap(); - let addr = publisher.local_addr().unwrap(); - tx.send(addr).await.unwrap(); + let local_addr = publisher.local_addr().unwrap().clone(); + tx.send(local_addr).await.unwrap(); // Spawn a task to keep sending messages until the subscriber receives one (after connection process) tokio::spawn(async move { @@ -233,9 +239,9 @@ async fn pubsub_fan_in_transport< addrs.insert(addr); } - for addr in &addrs { + for addr in addrs.clone() { inject_delay(500).await; - subscriber.connect(addr).await.unwrap(); + subscriber.connect(addr.clone()).await.unwrap(); subscriber.subscribe(TOPIC).await.unwrap(); } @@ -249,7 +255,7 @@ async fn pubsub_fan_in_transport< assert_eq!(TOPIC, msg.topic()); assert_eq!("WORLD", msg.payload()); - addrs.remove(&msg.source()); + addrs.remove(msg.source()); } for _ in 0..publishers { diff --git a/msg-transport/Cargo.toml b/msg-transport/Cargo.toml index 15f1aa7..0484b8c 100644 --- a/msg-transport/Cargo.toml +++ b/msg-transport/Cargo.toml @@ -11,16 +11,12 @@ homepage.workspace = true repository.workspace = true [dependencies] -msg-wire.workspace = true msg-common.workspace = true async-trait.workspace = true -bytes.workspace = true futures.workspace = true tokio.workspace = true -tokio-util.workspace = true tracing.workspace = true -rand.workspace = true thiserror.workspace = true quinn.workspace = true diff --git a/msg-transport/src/ipc/mod.rs b/msg-transport/src/ipc/mod.rs new file mode 100644 index 0000000..91cfc39 --- /dev/null +++ b/msg-transport/src/ipc/mod.rs @@ -0,0 +1,153 @@ +use std::{ + io, + path::PathBuf, + pin::Pin, + task::{Context, Poll}, +}; + +use async_trait::async_trait; +use futures::future::BoxFuture; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::{UnixListener, UnixStream}, +}; +use tracing::debug; + +use crate::{Acceptor, PeerAddress, Transport, TransportExt}; + +use msg_common::async_error; + +#[derive(Debug, Default)] +pub struct Config; + +#[derive(Debug, Default)] +pub struct Ipc { + #[allow(unused)] + config: Config, + listener: Option, + path: Option, +} + +impl Ipc { + pub fn new(config: Config) -> Self { + Self { + config, + listener: None, + path: None, + } + } +} + +pub struct IpcStream { + peer: PathBuf, + stream: UnixStream, +} + +impl IpcStream { + pub async fn connect(peer: PathBuf) -> io::Result { + let stream = UnixStream::connect(&peer).await?; + Ok(Self { peer, stream }) + } +} + +impl AsyncRead for IpcStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_read(cx, buf) + } +} + +impl AsyncWrite for IpcStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_shutdown(cx) + } +} + +impl PeerAddress for IpcStream { + fn peer_addr(&self) -> Result { + Ok(self.peer.clone()) + } +} + +#[async_trait] +impl Transport for Ipc { + type Addr = PathBuf; + type Io = IpcStream; + + type Error = io::Error; + + type Connect = BoxFuture<'static, Result>; + type Accept = BoxFuture<'static, Result>; + + fn local_addr(&self) -> Option { + self.path.clone() + } + + async fn bind(&mut self, addr: Self::Addr) -> Result<(), Self::Error> { + if addr.exists() { + debug!("Socket file already exists. Attempting to remove."); + if let Err(e) = std::fs::remove_file(&addr) { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("Failed to remove existing socket file: {}", e), + )); + } + } + + let listener = UnixListener::bind(&addr)?; + self.listener = Some(listener); + self.path = Some(addr); + Ok(()) + } + + fn connect(&mut self, addr: Self::Addr) -> Self::Connect { + Box::pin(async move { IpcStream::connect(addr).await }) + } + + fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let Some(ref listener) = this.listener else { + return Poll::Ready(async_error(io::ErrorKind::NotConnected.into())); + }; + + match listener.poll_accept(cx) { + Poll::Ready(Ok((io, _addr))) => { + debug!("accepted IPC connection"); + let stream = IpcStream { + // We expect the path to be the same socket as the listener + peer: this.path.clone().expect("listener not bound"), + stream: io, + }; + Poll::Ready(Box::pin(async move { Ok(stream) })) + } + Poll::Ready(Err(e)) => Poll::Ready(async_error(e)), + Poll::Pending => Poll::Pending, + } + } +} + +#[async_trait::async_trait] +impl TransportExt for Ipc { + fn accept(&mut self) -> Acceptor<'_, Self> + where + Self: Sized + Unpin, + { + Acceptor::new(self) + } +} diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index 52e6fae..b241da9 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -1,28 +1,49 @@ +#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +#![cfg_attr(not(test), warn(unused_crate_dependencies))] + use futures::{Future, FutureExt}; use std::{ + fmt::Debug, + hash::Hash, + io, net::SocketAddr, + path::PathBuf, pin::Pin, task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite}; +pub mod ipc; pub mod quic; pub mod tcp; +/// A trait for address types that can be used by any transport. +pub trait Address: Clone + Debug + Send + Sync + Unpin + Hash + Eq + 'static {} + +/// IP address types, used for TCP and QUIC transports. +impl Address for SocketAddr {} + +/// File system path, used for IPC transport. +impl Address for PathBuf {} + /// A transport provides connection-oriented communication between two peers through /// ordered and reliable streams of bytes. /// /// It provides an interface to manage both inbound and outbound connections. #[async_trait::async_trait] pub trait Transport { + /// The generic address type used by this transport + type Addr: Address; + /// The result of a successful connection. /// /// The output type is transport-specific, and can be a handle to directly write to the /// connection, or it can be a substream multiplexer in the case of stream protocols. - type Io: AsyncRead + AsyncWrite + PeerAddress + Send + Unpin; + type Io: AsyncRead + AsyncWrite + PeerAddress + Send + Unpin; /// An error that occurred when setting up the connection. - type Error: std::error::Error + From + Send + Sync; + type Error: std::error::Error + From + Send + Sync; /// A pending [`Transport::Output`] for an outbound connection, /// obtained when calling [`Transport::connect`]. @@ -33,13 +54,13 @@ pub trait Transport { type Accept: Future> + Send + Unpin; /// Returns the local address this transport is bound to (if it is bound). - fn local_addr(&self) -> Option; + fn local_addr(&self) -> Option; /// Binds to the given address. - async fn bind(&mut self, addr: SocketAddr) -> Result<(), Self::Error>; + async fn bind(&mut self, addr: Self::Addr) -> Result<(), Self::Error>; /// Connects to the given address, returning a future representing a pending outbound connection. - fn connect(&mut self, addr: SocketAddr) -> Self::Connect; + fn connect(&mut self, addr: Self::Addr) -> Self::Connect; /// Poll for incoming connections. If an inbound connection is received, a future representing /// a pending inbound connection is returned. The future will resolve to [`Transport::Output`]. @@ -85,6 +106,6 @@ where } /// Trait for connection types that can return their peer address. -pub trait PeerAddress { - fn peer_addr(&self) -> Result; +pub trait PeerAddress { + fn peer_addr(&self) -> Result; } diff --git a/msg-transport/src/quic/config.rs b/msg-transport/src/quic/config.rs index 2f80771..0cce0ff 100644 --- a/msg-transport/src/quic/config.rs +++ b/msg-transport/src/quic/config.rs @@ -1,7 +1,9 @@ -use quinn::{congestion::ControllerFactory, IdleTimeout}; use std::{sync::Arc, time::Duration}; +use quinn::{congestion::ControllerFactory, IdleTimeout}; + use super::tls::{self_signed_certificate, unsafe_client_config}; + use msg_common::constants::MiB; #[derive(Debug, Clone)] diff --git a/msg-transport/src/quic/mod.rs b/msg-transport/src/quic/mod.rs index c9c8495..ebff405 100644 --- a/msg-transport/src/quic/mod.rs +++ b/msg-transport/src/quic/mod.rs @@ -1,4 +1,3 @@ -use futures::future::BoxFuture; use std::{ io, net::{SocketAddr, UdpSocket}, @@ -6,22 +5,24 @@ use std::{ sync::Arc, task::{ready, Poll}, }; + +use futures::future::BoxFuture; use thiserror::Error; use tokio::sync::mpsc::{self, Receiver}; use tracing::error; -use msg_common::async_error; - use crate::{Acceptor, Transport, TransportExt}; -mod config; -mod stream; -mod tls; +use msg_common::async_error; +mod config; pub use config::{Config, ConfigBuilder}; -pub use quinn::congestion; + +mod stream; use stream::QuicStream; +mod tls; + /// A QUIC error. #[derive(Debug, Error)] pub enum Error { @@ -84,6 +85,7 @@ impl Quic { #[async_trait::async_trait] impl Transport for Quic { + type Addr = SocketAddr; type Io = QuicStream; type Error = Error; diff --git a/msg-transport/src/quic/stream.rs b/msg-transport/src/quic/stream.rs index 76ba2f2..6575853 100644 --- a/msg-transport/src/quic/stream.rs +++ b/msg-transport/src/quic/stream.rs @@ -45,7 +45,7 @@ impl AsyncWrite for QuicStream { } } -impl PeerAddress for QuicStream { +impl PeerAddress for QuicStream { fn peer_addr(&self) -> Result { Ok(self.peer) } diff --git a/msg-transport/src/tcp/mod.rs b/msg-transport/src/tcp/mod.rs index fd16b11..0120263 100644 --- a/msg-transport/src/tcp/mod.rs +++ b/msg-transport/src/tcp/mod.rs @@ -29,7 +29,7 @@ impl Tcp { } } -impl PeerAddress for TcpStream { +impl PeerAddress for TcpStream { fn peer_addr(&self) -> io::Result { self.peer_addr() } @@ -37,6 +37,7 @@ impl PeerAddress for TcpStream { #[async_trait::async_trait] impl Transport for Tcp { + type Addr = SocketAddr; type Io = TcpStream; type Error = io::Error; diff --git a/msg-wire/Cargo.toml b/msg-wire/Cargo.toml index 2a33a2a..4e6144a 100644 --- a/msg-wire/Cargo.toml +++ b/msg-wire/Cargo.toml @@ -16,7 +16,6 @@ msg-common.workspace = true bytes.workspace = true thiserror.workspace = true tokio-util.workspace = true -tracing.workspace = true flate2 = "1" zstd = "0.13" diff --git a/msg-wire/src/lib.rs b/msg-wire/src/lib.rs index b78c693..41703d1 100644 --- a/msg-wire/src/lib.rs +++ b/msg-wire/src/lib.rs @@ -1,3 +1,7 @@ +#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] +#![cfg_attr(not(test), warn(unused_crate_dependencies))] + pub mod auth; pub mod pubsub; pub mod reqrep; diff --git a/msg/benches/pubsub.rs b/msg/benches/pubsub.rs index 4a70dc3..d1eaf53 100644 --- a/msg/benches/pubsub.rs +++ b/msg/benches/pubsub.rs @@ -4,9 +4,13 @@ use criterion::{ Throughput, }; use futures::StreamExt; +use msg::ipc::Ipc; use pprof::criterion::{Output, PProfProfiler}; use rand::Rng; -use std::time::{Duration, Instant}; +use std::{ + env::temp_dir, + time::{Duration, Instant}, +}; use tokio::runtime::Runtime; use msg_socket::{PubOptions, PubSocket, SubOptions, SubSocket}; @@ -31,13 +35,13 @@ struct PairBenchmark { impl PairBenchmark { /// Sets up the publisher and subscriber sockets. - fn init(&mut self) { + fn init(&mut self, addr: T::Addr) { // Set up the socket connections self.rt.block_on(async { - self.publisher.bind("127.0.0.1:0").await.unwrap(); + self.publisher.try_bind(vec![addr]).await.unwrap(); let addr = self.publisher.local_addr().unwrap(); - self.subscriber.connect(addr).await.unwrap(); + self.subscriber.connect(addr.clone()).await.unwrap(); self.subscriber .subscribe("HELLO".to_string()) @@ -172,7 +176,7 @@ fn pubsub_single_thread_tcp(c: &mut Criterion) { msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], }; - bench.init(); + bench.init("127.0.0.1:0".parse().unwrap()); let mut group = c.benchmark_group("pubsub_single_thread_tcp_bytes"); group.sample_size(10); @@ -217,7 +221,7 @@ fn pubsub_multi_thread_tcp(c: &mut Criterion) { msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], }; - bench.init(); + bench.init("127.0.0.1:0".parse().unwrap()); let mut group = c.benchmark_group("pubsub_multi_thread_tcp_bytes"); group.sample_size(10); @@ -261,7 +265,7 @@ fn pubsub_single_thread_quic(c: &mut Criterion) { msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], }; - bench.init(); + bench.init("127.0.0.1:0".parse().unwrap()); let mut group = c.benchmark_group("pubsub_single_thread_quic_bytes"); group.sample_size(10); @@ -306,7 +310,7 @@ fn pubsub_multi_thread_quic(c: &mut Criterion) { msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], }; - bench.init(); + bench.init("127.0.0.1:0".parse().unwrap()); let mut group = c.benchmark_group("pubsub_multi_thread_quic_bytes"); group.sample_size(10); @@ -317,10 +321,100 @@ fn pubsub_multi_thread_quic(c: &mut Criterion) { bench.bench_message_throughput(group); } +fn pubsub_single_thread_ipc(c: &mut Criterion) { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let buffer_size = 1024 * 64; + + let publisher = PubSocket::with_options( + Ipc::default(), + PubOptions::default() + .flush_interval(Duration::from_micros(100)) + .backpressure_boundary(buffer_size) + .session_buffer_size(N_REQS * 2), + ); + + let subscriber = SubSocket::with_options( + Ipc::default(), + SubOptions::default() + .read_buffer_size(buffer_size) + .ingress_buffer_size(N_REQS * 2), + ); + + let mut bench = PairBenchmark { + rt, + publisher, + subscriber, + n_reqs: N_REQS, + msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], + }; + + bench.init(temp_dir().join("msg-bench-pubsub-ipc.sock")); + + let mut group = c.benchmark_group("pubsub_single_thread_ipc_bytes"); + group.sample_size(10); + bench.bench_bytes_throughput(group); + + let mut group = c.benchmark_group("pubsub_single_thread_ipc_msgs"); + group.sample_size(10); + bench.bench_message_throughput(group); +} + +fn pubsub_multi_thread_ipc(c: &mut Criterion) { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + + let buffer_size = 1024 * 64; + + let publisher = PubSocket::with_options( + Ipc::default(), + PubOptions::default() + .flush_interval(Duration::from_micros(100)) + .backpressure_boundary(buffer_size) + .session_buffer_size(N_REQS * 2), + ); + + let subscriber = SubSocket::with_options( + Ipc::default(), + SubOptions::default() + .read_buffer_size(buffer_size) + .ingress_buffer_size(N_REQS * 2), + ); + + let mut bench = PairBenchmark { + rt, + publisher, + subscriber, + n_reqs: N_REQS, + msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], + }; + + bench.init(temp_dir().join("msg-bench-pubsub-ipc-multi.sock")); + + let mut group = c.benchmark_group("pubsub_multi_thread_ipc_bytes"); + group.sample_size(10); + bench.bench_bytes_throughput(group); + + let mut group = c.benchmark_group("pubsub_multi_thread_ipc_msgs"); + group.sample_size(10); + bench.bench_message_throughput(group); +} + criterion_group! { name = benches; config = Criterion::default().warm_up_time(Duration::from_secs(1)).with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = pubsub_single_thread_tcp, pubsub_multi_thread_tcp, pubsub_single_thread_quic, pubsub_multi_thread_quic + targets = pubsub_single_thread_tcp, pubsub_multi_thread_tcp, pubsub_single_thread_quic, + pubsub_multi_thread_quic, pubsub_single_thread_ipc, pubsub_multi_thread_ipc } // Runs various benchmarks for the `PubSocket` and `SubSocket`. diff --git a/msg/benches/reqrep.rs b/msg/benches/reqrep.rs index 1a0b624..bc84e19 100644 --- a/msg/benches/reqrep.rs +++ b/msg/benches/reqrep.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{env::temp_dir, time::Duration}; use bytes::Bytes; use criterion::{ @@ -6,6 +6,7 @@ use criterion::{ Throughput, }; use futures::StreamExt; +use msg::{ipc::Ipc, Transport}; use pprof::criterion::Output; use rand::Rng; @@ -22,23 +23,26 @@ const MSG_SIZE: usize = 512; #[global_allocator] static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; -struct PairBenchmark { +struct PairBenchmark { rt: Runtime, - req: ReqSocket, - rep: Option>, + req: ReqSocket, + rep: Option>, n_reqs: usize, msg_sizes: Vec, } -impl PairBenchmark { - fn init(&mut self) { +impl PairBenchmark { + fn init(&mut self, addr: T::Addr) { let mut rep = self.rep.take().unwrap(); // setup the socket connections self.rt.block_on(async { - rep.bind("127.0.0.1:0").await.unwrap(); + rep.try_bind(vec![addr]).await.unwrap(); - self.req.connect(rep.local_addr().unwrap()).await.unwrap(); + self.req + .try_connect(rep.local_addr().unwrap().clone()) + .await + .unwrap(); tokio::spawn(async move { rep.map(|req| async move { @@ -133,7 +137,7 @@ fn reqrep_single_thread_tcp(c: &mut Criterion) { msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], }; - bench.init(); + bench.init("127.0.0.1:0".parse().unwrap()); let mut group = c.benchmark_group("reqrep_single_thread_tcp_bytes"); group.sample_size(10); bench.bench_request_throughput(group); @@ -167,7 +171,7 @@ fn reqrep_multi_thread_tcp(c: &mut Criterion) { msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], }; - bench.init(); + bench.init("127.0.0.1:0".parse().unwrap()); let mut group = c.benchmark_group("reqrep_multi_thread_tcp_bytes"); group.sample_size(10); bench.bench_request_throughput(group); @@ -177,10 +181,69 @@ fn reqrep_multi_thread_tcp(c: &mut Criterion) { bench.bench_rps(group); } +fn reqrep_single_thread_ipc(c: &mut Criterion) { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + let req = ReqSocket::new(Ipc::default()); + let rep = RepSocket::new(Ipc::default()); + + let mut bench = PairBenchmark { + rt, + req, + rep: Some(rep), + n_reqs: N_REQS, + msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], + }; + + bench.init(temp_dir().join("msg-bench-reqrep-ipc.sock")); + let mut group = c.benchmark_group("reqrep_single_thread_ipc_bytes"); + group.sample_size(10); + bench.bench_request_throughput(group); + + let mut group = c.benchmark_group("reqrep_single_thread_ipc_rps"); + group.sample_size(10); + bench.bench_rps(group); +} + +fn reqrep_multi_thread_ipc(c: &mut Criterion) { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + + let req = ReqSocket::new(Ipc::default()); + let rep = RepSocket::new(Ipc::default()); + + let mut bench = PairBenchmark { + rt, + req, + rep: Some(rep), + n_reqs: N_REQS, + msg_sizes: vec![MSG_SIZE, MSG_SIZE * 8, MSG_SIZE * 64, MSG_SIZE * 128], + }; + + bench.init(temp_dir().join("msg-bench-reqrep-ipc-multi.sock")); + let mut group = c.benchmark_group("reqrep_multi_thread_ipc_bytes"); + group.sample_size(10); + bench.bench_request_throughput(group); + + let mut group = c.benchmark_group("reqrep_multi_thread_ipc_rps"); + group.sample_size(10); + bench.bench_rps(group); +} + criterion_group! { name = benches; config = Criterion::default().warm_up_time(Duration::from_secs(1)).with_profiler(pprof::criterion::PProfProfiler::new(100, Output::Flamegraph(None))); - targets = reqrep_single_thread_tcp, reqrep_multi_thread_tcp + targets = reqrep_single_thread_tcp, reqrep_multi_thread_tcp, reqrep_single_thread_ipc, reqrep_multi_thread_ipc } // Runs various benchmarks for the `ReqSocket` and `RepSocket`. diff --git a/msg/examples/durable.rs b/msg/examples/durable.rs index 42cff7e..f6a3662 100644 --- a/msg/examples/durable.rs +++ b/msg/examples/durable.rs @@ -23,7 +23,7 @@ async fn start_rep() { // Initialize the reply socket (server side) with a transport // and an authenticator. let mut rep = RepSocket::new(Tcp::default()).with_auth(Auth); - while rep.bind("0.0.0.0:4444").await.is_err() { + while rep.bind_socket("0.0.0.0:4444").await.is_err() { rep = RepSocket::new(Tcp::default()).with_auth(Auth); tracing::warn!("Failed to bind rep socket, retrying..."); tokio::time::sleep(Duration::from_secs(1)).await; @@ -76,7 +76,7 @@ async fn main() { tokio::spawn( async move { tracing::info!("Trying to connect to rep socket... This will start the connection process in the background, it won't immediately connect."); - req.connect("0.0.0.0:4444").await.unwrap(); + req.connect_socket("0.0.0.0:4444").await.unwrap(); for i in 0..10 { tracing::info!("Sending request {i}..."); diff --git a/msg/examples/ipc.rs b/msg/examples/ipc.rs new file mode 100644 index 0000000..29e32e4 --- /dev/null +++ b/msg/examples/ipc.rs @@ -0,0 +1,44 @@ +use std::env::temp_dir; + +use bytes::Bytes; +use tokio_stream::StreamExt; + +use msg::{ipc::Ipc, RepSocket, ReqSocket}; + +#[tokio::main] +async fn main() { + let _ = tracing_subscriber::fmt::try_init(); + + // Initialize the reply socket (server side) with a transport + let mut rep = RepSocket::new(Ipc::default()); + + // use a temporary file as the socket path + let path = temp_dir().join("test.sock"); + rep.bind_path(path.clone()).await.unwrap(); + println!("Listening on {:?}", rep.local_addr().unwrap()); + + // Initialize the request socket (client side) with a transport + let mut req = ReqSocket::new(Ipc::default()); + req.connect_path(path).await.unwrap(); + + tokio::spawn(async move { + // Receive the request and respond with "world" + // RepSocket implements `Stream` + let req = rep.next().await.unwrap(); + println!("Message: {:?}", req.msg()); + + req.respond(Bytes::from("world")).unwrap(); + }); + + let res: Bytes = req.request(Bytes::from("helloooo!")).await.unwrap(); + println!("Response: {:?}", res); + + // Access the socket statistics + let stats = req.stats(); + println!( + "Sent: {}B, Received: {}B | time: {}μs", + stats.bytes_tx(), + stats.bytes_rx(), + stats.rtt() + ); +} diff --git a/msg/examples/pubsub.rs b/msg/examples/pubsub.rs index 6b8c54c..8752870 100644 --- a/msg/examples/pubsub.rs +++ b/msg/examples/pubsub.rs @@ -31,17 +31,17 @@ async fn main() { ); tracing::info!("Setting up the sockets..."); - pub_socket.bind("127.0.0.1:0").await.unwrap(); + pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect(pub_addr).await.unwrap(); + sub1.connect_socket(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); - sub2.connect(pub_addr).await.unwrap(); + sub2.connect_socket(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); diff --git a/msg/examples/pubsub_auth.rs b/msg/examples/pubsub_auth.rs index 59d7208..5d4d796 100644 --- a/msg/examples/pubsub_auth.rs +++ b/msg/examples/pubsub_auth.rs @@ -45,17 +45,17 @@ async fn main() { ); tracing::info!("Setting up the sockets..."); - pub_socket.bind("127.0.0.1:0").await.unwrap(); + pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect(pub_addr).await.unwrap(); + sub1.connect_socket(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); - sub2.connect(pub_addr).await.unwrap(); + sub2.connect_socket(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); diff --git a/msg/examples/pubsub_compression.rs b/msg/examples/pubsub_compression.rs index 37cfaa9..c887805 100644 --- a/msg/examples/pubsub_compression.rs +++ b/msg/examples/pubsub_compression.rs @@ -21,17 +21,17 @@ async fn main() { let mut sub2 = SubSocket::new(Tcp::default()); tracing::info!("Setting up the sockets..."); - pub_socket.bind("127.0.0.1:0").await.unwrap(); + pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect(pub_addr).await.unwrap(); + sub1.connect_socket(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); - sub2.connect(pub_addr).await.unwrap(); + sub2.connect_socket(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); diff --git a/msg/examples/quic_vs_tcp.rs b/msg/examples/quic_vs_tcp.rs index 26809f7..c0583cc 100644 --- a/msg/examples/quic_vs_tcp.rs +++ b/msg/examples/quic_vs_tcp.rs @@ -30,12 +30,12 @@ async fn run_tcp() { ); tracing::info!("Setting up the sockets..."); - pub_socket.bind("127.0.0.1:0").await.unwrap(); + pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect(pub_addr).await.unwrap(); + sub1.connect_socket(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); @@ -63,12 +63,12 @@ async fn run_quic() { ); tracing::info!("Setting up the sockets..."); - pub_socket.bind("127.0.0.1:0").await.unwrap(); + pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect(pub_addr).await.unwrap(); + sub1.connect_socket(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); diff --git a/msg/examples/reqrep.rs b/msg/examples/reqrep.rs index 286facd..4b153a6 100644 --- a/msg/examples/reqrep.rs +++ b/msg/examples/reqrep.rs @@ -7,11 +7,11 @@ use msg::{tcp::Tcp, RepSocket, ReqSocket}; async fn main() { // Initialize the reply socket (server side) with a transport let mut rep = RepSocket::new(Tcp::default()); - rep.bind("0.0.0.0:4444").await.unwrap(); + rep.bind_socket("0.0.0.0:4444").await.unwrap(); // Initialize the request socket (client side) with a transport let mut req = ReqSocket::new(Tcp::default()); - req.connect("0.0.0.0:4444").await.unwrap(); + req.connect_socket("0.0.0.0:4444").await.unwrap(); tokio::spawn(async move { // Receive the request and respond with "world" diff --git a/msg/examples/reqrep_auth.rs b/msg/examples/reqrep_auth.rs index 2b4b54e..8c2e17e 100644 --- a/msg/examples/reqrep_auth.rs +++ b/msg/examples/reqrep_auth.rs @@ -20,7 +20,7 @@ async fn main() { // Initialize the reply socket (server side) with a transport // and an authenticator. let mut rep = RepSocket::new(Tcp::default()).with_auth(Auth); - rep.bind("0.0.0.0:4444").await.unwrap(); + rep.bind_socket("0.0.0.0:4444").await.unwrap(); // Initialize the request socket (client side) with a transport // and an identifier. This will implicitly turn on client authentication. @@ -29,7 +29,7 @@ async fn main() { ReqOptions::default().auth_token(Bytes::from("REQ")), ); - req.connect("0.0.0.0:4444").await.unwrap(); + req.connect_socket("0.0.0.0:4444").await.unwrap(); tokio::spawn(async move { // Receive the request and respond with "world" diff --git a/msg/examples/reqrep_compression.rs b/msg/examples/reqrep_compression.rs index bb85132..e5ac2d5 100644 --- a/msg/examples/reqrep_compression.rs +++ b/msg/examples/reqrep_compression.rs @@ -13,7 +13,7 @@ async fn main() { RepSocket::with_options(Tcp::default(), RepOptions::default().min_compress_size(0)) // Enable Gzip compression (compression level 6) .with_compressor(GzipCompressor::new(6)); - rep.bind("0.0.0.0:4444").await.unwrap(); + rep.bind_socket("0.0.0.0:4444").await.unwrap(); // Initialize the request socket (client side) with a transport // and a minimum compresion size of 0 bytes to compress all requests @@ -24,7 +24,7 @@ async fn main() { // use the same compression algorithm or level. .with_compressor(GzipCompressor::new(6)); - req.connect("0.0.0.0:4444").await.unwrap(); + req.connect_socket("0.0.0.0:4444").await.unwrap(); tokio::spawn(async move { // Receive the request and respond with "world" diff --git a/msg/src/lib.rs b/msg/src/lib.rs index 417d4b2..681e14f 100644 --- a/msg/src/lib.rs +++ b/msg/src/lib.rs @@ -1,3 +1,6 @@ +#![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")] +#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] + pub use msg_socket::*; pub use msg_transport::*; pub use msg_wire::compression; From 0595fdd97f131c25f40fce39d703a69ba230625f Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Sun, 25 Aug 2024 13:31:02 +0200 Subject: [PATCH 2/7] chore: lint --- msg-sim/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/msg-sim/src/lib.rs b/msg-sim/src/lib.rs index d95450f..0e9a6e8 100644 --- a/msg-sim/src/lib.rs +++ b/msg-sim/src/lib.rs @@ -1,6 +1,5 @@ #![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] -#![cfg_attr(not(test), warn(unused_crate_dependencies))] use std::{collections::HashMap, io, net::IpAddr, time::Duration}; From f711e0ac52db816ea2d3927eff8cf91aee9ac111 Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Sun, 25 Aug 2024 13:43:50 +0200 Subject: [PATCH 3/7] chore: small nit --- msg-socket/src/pub/socket.rs | 12 ++---------- msg-socket/src/sub/socket.rs | 20 ++++++++------------ 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/msg-socket/src/pub/socket.rs b/msg-socket/src/pub/socket.rs index 991da56..0e331d4 100644 --- a/msg-socket/src/pub/socket.rs +++ b/msg-socket/src/pub/socket.rs @@ -35,13 +35,6 @@ pub struct PubSocket { local_addr: Option, } -impl PubSocket -where - T: Transport + Send + Unpin + 'static, - T::Addr: ToSocketAddrs, -{ -} - impl PubSocket where T: Transport + Send + Unpin + 'static, @@ -64,9 +57,8 @@ where /// /// This method is only available for transports that support [`PathBuf`] as address type, /// like [`Ipc`](msg_transport::ipc::Ipc). - pub async fn bind_path(&mut self, path: impl AsRef) -> Result<(), PubError> { - let addr = path.as_ref().clone(); - self.try_bind(vec![addr]).await + pub async fn bind_path(&mut self, path: impl Into) -> Result<(), PubError> { + self.try_bind(vec![path.into()]).await } } diff --git a/msg-socket/src/sub/socket.rs b/msg-socket/src/sub/socket.rs index 588cf14..69e9cd8 100644 --- a/msg-socket/src/sub/socket.rs +++ b/msg-socket/src/sub/socket.rs @@ -111,27 +111,23 @@ where T: Transport + Send + Sync + Unpin + 'static, { /// Connects to the given path asynchronously. - pub async fn connect_path(&mut self, path: impl AsRef) -> Result<(), SubError> { - let path = path.as_ref().clone(); - self.connect(path).await + pub async fn connect_path(&mut self, path: impl Into) -> Result<(), SubError> { + self.connect(path.into()).await } /// Attempts to connect to the given path immediately. - pub fn try_connect_path(&mut self, path: impl AsRef) -> Result<(), SubError> { - let path = path.as_ref().clone(); - self.try_connect(path) + pub fn try_connect_path(&mut self, path: impl Into) -> Result<(), SubError> { + self.try_connect(path.into()) } /// Disconnects from the given path asynchronously. - pub async fn disconnect_path(&mut self, path: impl AsRef) -> Result<(), SubError> { - let path = path.as_ref().clone(); - self.disconnect(path).await + pub async fn disconnect_path(&mut self, path: impl Into) -> Result<(), SubError> { + self.disconnect(path.into()).await } /// Attempts to disconnect from the given path immediately. - pub fn try_disconnect_path(&mut self, path: impl AsRef) -> Result<(), SubError> { - let path = path.as_ref().clone(); - self.try_disconnect(path) + pub fn try_disconnect_path(&mut self, path: impl Into) -> Result<(), SubError> { + self.try_disconnect(path.into()) } } From a79265daa66e7753dc88e2495fc570d17a4af2b3 Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Sun, 25 Aug 2024 14:01:06 +0200 Subject: [PATCH 4/7] chore: docs --- msg-transport/src/ipc/mod.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/msg-transport/src/ipc/mod.rs b/msg-transport/src/ipc/mod.rs index 91cfc39..cd5490a 100644 --- a/msg-transport/src/ipc/mod.rs +++ b/msg-transport/src/ipc/mod.rs @@ -20,6 +20,18 @@ use msg_common::async_error; #[derive(Debug, Default)] pub struct Config; +/// An IPC (Inter-Process Communication) implementation using Unix domain sockets. +/// +/// This struct represents the IPC transport, which allows communication between processes +/// on the same machine using Unix domain sockets. +/// +/// # Features +/// - Asynchronous communication using Tokio's runtime +/// - Supports both connection-oriented (stream) and connectionless (datagram) sockets +/// - Implements standard transport traits for easy integration with other components +/// +/// Note: This implementation is specific to Unix-like operating systems and is not tested +/// on Windows or other non-Unix platforms. #[derive(Debug, Default)] pub struct Ipc { #[allow(unused)] @@ -142,7 +154,7 @@ impl Transport for Ipc { } } -#[async_trait::async_trait] +#[async_trait] impl TransportExt for Ipc { fn accept(&mut self) -> Acceptor<'_, Self> where From 30675dffb9a739088488aa0c9070e7ce20974f71 Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Mon, 26 Aug 2024 18:16:49 +0200 Subject: [PATCH 5/7] fix(ipc): remove api breaking changes by using generic in Transport --- msg-common/src/lib.rs | 6 ++ msg-socket/src/pub/driver.rs | 16 ++--- msg-socket/src/pub/mod.rs | 42 ++++++------- msg-socket/src/pub/socket.rs | 29 ++++----- msg-socket/src/rep/driver.rs | 18 +++--- msg-socket/src/rep/mod.rs | 22 +++---- msg-socket/src/rep/socket.rs | 43 ++++++-------- msg-socket/src/req/driver.rs | 21 +++---- msg-socket/src/req/socket.rs | 31 ++++++---- msg-socket/src/sub/driver.rs | 37 ++++++------ msg-socket/src/sub/socket.rs | 74 +++++++++++------------ msg-socket/tests/it/pubsub.rs | 25 ++++---- msg-transport/src/ipc/mod.rs | 13 ++-- msg-transport/src/lib.rs | 95 +++++++++++++++++++++++++----- msg-transport/src/quic/mod.rs | 7 +-- msg-transport/src/tcp/mod.rs | 7 +-- msg/benches/pubsub.rs | 14 ++--- msg/benches/reqrep.rs | 12 ++-- msg/examples/durable.rs | 4 +- msg/examples/ipc.rs | 4 +- msg/examples/pubsub.rs | 6 +- msg/examples/pubsub_auth.rs | 6 +- msg/examples/pubsub_compression.rs | 6 +- msg/examples/quic_vs_tcp.rs | 16 ++--- msg/examples/reqrep.rs | 4 +- msg/examples/reqrep_auth.rs | 4 +- msg/examples/reqrep_compression.rs | 4 +- 27 files changed, 321 insertions(+), 245 deletions(-) diff --git a/msg-common/src/lib.rs b/msg-common/src/lib.rs index f7e4b94..1225f5e 100644 --- a/msg-common/src/lib.rs +++ b/msg-common/src/lib.rs @@ -3,6 +3,7 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] use std::{ + io, pin::Pin, task::{Context, Poll}, time::SystemTime, @@ -34,6 +35,11 @@ pub fn async_error( Box::pin(async move { Err(e) }) } +/// Creates a new [`io::Error`] with the given message. +pub fn io_error(msg: impl Into) -> io::Error { + io::Error::new(io::ErrorKind::Other, msg.into()) +} + #[allow(non_upper_case_globals)] pub mod constants { pub const KiB: u32 = 1024; diff --git a/msg-socket/src/pub/driver.rs b/msg-socket/src/pub/driver.rs index e27f0d2..e6a5c10 100644 --- a/msg-socket/src/pub/driver.rs +++ b/msg-socket/src/pub/driver.rs @@ -14,11 +14,11 @@ use super::{ session::SubscriberSession, trie::PrefixTrie, PubError, PubMessage, PubOptions, SocketState, }; use crate::{AuthResult, Authenticator}; -use msg_transport::{PeerAddress, Transport}; +use msg_transport::{Address, PeerAddress, Transport}; use msg_wire::{auth, pubsub}; #[allow(clippy::type_complexity)] -pub(crate) struct PubDriver { +pub(crate) struct PubDriver, A: Address> { /// Session ID counter. pub(super) id_counter: u32, /// The server transport used to accept incoming connections. @@ -32,14 +32,15 @@ pub(crate) struct PubDriver { /// A set of pending incoming connections, represented by [`Transport::Accept`]. pub(super) conn_tasks: FuturesUnordered, /// A joinset of authentication tasks. - pub(super) auth_tasks: JoinSet, PubError>>, + pub(super) auth_tasks: JoinSet, PubError>>, /// The receiver end of the message broadcast channel. The sender half is stored by [`PubSocket`](super::PubSocket). pub(super) from_socket_bcast: broadcast::Receiver, } -impl Future for PubDriver +impl Future for PubDriver where - T: Transport + Unpin + 'static, + T: Transport + Unpin + 'static, + A: Address, { type Output = Result<(), PubError>; @@ -130,9 +131,10 @@ where } } -impl PubDriver +impl PubDriver where - T: Transport + Unpin + 'static, + T: Transport + Unpin + 'static, + A: Address, { /// Handles an incoming connection. If this returns an error, the active connections counter /// should be decremented. diff --git a/msg-socket/src/pub/mod.rs b/msg-socket/src/pub/mod.rs index 37f56d5..3360bf9 100644 --- a/msg-socket/src/pub/mod.rs +++ b/msg-socket/src/pub/mod.rs @@ -192,10 +192,10 @@ mod tests { let mut sub_socket = SubSocket::with_options(Tcp::default(), SubOptions::default()); - pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); + pub_socket.bind("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub_socket.connect_socket(addr).await.unwrap(); + sub_socket.connect(addr).await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -221,10 +221,10 @@ mod tests { SubOptions::default().auth_token(Bytes::from("client1")), ); - pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); + pub_socket.bind("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub_socket.connect_socket(addr).await.unwrap(); + sub_socket.connect(addr).await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -250,10 +250,10 @@ mod tests { SubOptions::default().auth_token(Bytes::from("client1")), ); - pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); + pub_socket.bind("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub_socket.connect_socket(addr).await.unwrap(); + sub_socket.connect(addr).await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -278,11 +278,11 @@ mod tests { let mut sub2 = SubSocket::new(Tcp::default()); - pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); + pub_socket.bind("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub1.connect_socket(addr).await.unwrap(); - sub2.connect_socket(addr).await.unwrap(); + sub1.connect(addr).await.unwrap(); + sub2.connect(addr).await.unwrap(); sub1.subscribe("HELLO".to_string()).await.unwrap(); sub2.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -313,11 +313,11 @@ mod tests { let mut sub2 = SubSocket::new(Tcp::default()); - pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); + pub_socket.bind("0.0.0.0:0").await.unwrap(); let addr = pub_socket.local_addr().unwrap(); - sub1.connect_socket(addr).await.unwrap(); - sub2.connect_socket(addr).await.unwrap(); + sub1.connect(addr).await.unwrap(); + sub2.connect(addr).await.unwrap(); sub1.subscribe("HELLO".to_string()).await.unwrap(); sub2.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; @@ -349,11 +349,11 @@ mod tests { let mut sub_socket = SubSocket::new(Tcp::default()); // Try to connect and subscribe before the publisher is up - sub_socket.connect_socket("0.0.0.0:6662").await.unwrap(); + sub_socket.connect("0.0.0.0:6662").await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(500)).await; - pub_socket.bind_socket("0.0.0.0:6662").await.unwrap(); + pub_socket.bind("0.0.0.0:6662").await.unwrap(); tokio::time::sleep(Duration::from_millis(2000)).await; pub_socket @@ -376,11 +376,11 @@ mod tests { let mut sub_socket = SubSocket::new(Quic::default()); // Try to connect and subscribe before the publisher is up - sub_socket.connect_socket("0.0.0.0:6662").await.unwrap(); + sub_socket.connect("0.0.0.0:6662").await.unwrap(); sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(1000)).await; - pub_socket.bind_socket("0.0.0.0:6662").await.unwrap(); + pub_socket.bind("0.0.0.0:6662").await.unwrap(); tokio::time::sleep(Duration::from_millis(2000)).await; pub_socket @@ -401,18 +401,18 @@ mod tests { let mut pub_socket = PubSocket::with_options(Tcp::default(), PubOptions::default().max_clients(1)); - pub_socket.bind_socket("0.0.0.0:0").await.unwrap(); + pub_socket.bind("0.0.0.0:0").await.unwrap(); - let mut sub1 = SubSocket::::with_options(Tcp::default(), SubOptions::default()); + let mut sub1 = SubSocket::with_options(Tcp::default(), SubOptions::default()); - let mut sub2 = SubSocket::::with_options(Tcp::default(), SubOptions::default()); + let mut sub2 = SubSocket::with_options(Tcp::default(), SubOptions::default()); let addr = pub_socket.local_addr().unwrap(); - sub1.connect_socket(addr).await.unwrap(); + sub1.connect(addr).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; assert_eq!(pub_socket.stats().active_clients(), 1); - sub2.connect_socket(addr).await.unwrap(); + sub2.connect(addr).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; assert_eq!(pub_socket.stats().active_clients(), 1); } diff --git a/msg-socket/src/pub/socket.rs b/msg-socket/src/pub/socket.rs index 0e331d4..8671c96 100644 --- a/msg-socket/src/pub/socket.rs +++ b/msg-socket/src/pub/socket.rs @@ -10,12 +10,12 @@ use tracing::{debug, trace, warn}; use super::{driver::PubDriver, stats::SocketStats, PubError, PubMessage, PubOptions, SocketState}; use crate::Authenticator; -use msg_transport::Transport; +use msg_transport::{Address, Transport}; use msg_wire::compression::Compressor; /// A publisher socket. This is thread-safe and can be cloned. #[derive(Clone, Default)] -pub struct PubSocket { +pub struct PubSocket, A: Address> { /// The reply socket options, shared with the driver. options: Arc, /// The reply socket state, shared with the driver. @@ -32,39 +32,40 @@ pub struct PubSocket { // complicates the API a lot. We can always change this later for perf reasons. compressor: Option>, /// The local address this socket is bound to. - local_addr: Option, + local_addr: Option, } -impl PubSocket +impl PubSocket where - T: Transport + Send + Unpin + 'static, + T: Transport + Send + Unpin + 'static, { /// Binds the socket to the given socket addres /// /// This method is only available for transports that support [`SocketAddr`] as address type, /// like [`Tcp`](msg_transport::tcp::Tcp) and [`Quic`](msg_transport::quic::Quic). - pub async fn bind_socket(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> { + pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> { let addrs = lookup_host(addr).await?; self.try_bind(addrs.collect()).await } } -impl PubSocket +impl PubSocket where - T: Transport + Send + Unpin + 'static, + T: Transport + Send + Unpin + 'static, { /// Binds the socket to the given path. /// /// This method is only available for transports that support [`PathBuf`] as address type, /// like [`Ipc`](msg_transport::ipc::Ipc). - pub async fn bind_path(&mut self, path: impl Into) -> Result<(), PubError> { + pub async fn bind(&mut self, path: impl Into) -> Result<(), PubError> { self.try_bind(vec![path.into()]).await } } -impl PubSocket +impl PubSocket where - T: Transport + Send + Unpin + 'static, + T: Transport + Send + Unpin + 'static, + A: Address, { /// Creates a new reply socket with the default [`PubOptions`]. pub fn new(transport: T) -> Self { @@ -85,7 +86,7 @@ where } /// Sets the connection authenticator for this socket. - pub fn with_auth(mut self, authenticator: A) -> Self { + pub fn with_auth(mut self, authenticator: O) -> Self { self.auth = Some(Arc::new(authenticator)); self } @@ -99,7 +100,7 @@ where /// Binds the socket to the given addresses in order until one succeeds. /// /// This also spawns the socket driver task. - pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), PubError> { + pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), PubError> { let (to_sessions_bcast, from_socket_bcast) = broadcast::channel(self.options.session_buffer_size); @@ -219,7 +220,7 @@ where } /// Returns the local address this socket is bound to. `None` if the socket is not bound. - pub fn local_addr(&self) -> Option<&T::Addr> { + pub fn local_addr(&self) -> Option<&A> { self.local_addr.as_ref() } } diff --git a/msg-socket/src/rep/driver.rs b/msg-socket/src/rep/driver.rs index 9899f16..90a9cce 100644 --- a/msg-socket/src/rep/driver.rs +++ b/msg-socket/src/rep/driver.rs @@ -37,7 +37,7 @@ pub(crate) struct PeerState { } #[allow(clippy::type_complexity)] -pub(crate) struct RepDriver { +pub(crate) struct RepDriver, A: Address> { /// The server transport used to accept incoming connections. pub(crate) transport: T, /// The reply socket state, shared with the socket front-end. @@ -46,9 +46,9 @@ pub(crate) struct RepDriver { /// Options shared with socket. pub(crate) options: Arc, /// [`StreamMap`] of connected peers. The key is the peer's address. - pub(crate) peer_states: StreamMap>>, + pub(crate) peer_states: StreamMap>>, /// Sender to the socket front-end. Used to notify the socket of incoming requests. - pub(crate) to_socket: mpsc::Sender>, + pub(crate) to_socket: mpsc::Sender>, /// Optional connection authenticator. pub(crate) auth: Option>, /// Optional message compressor. This is shared with the socket to keep @@ -57,12 +57,13 @@ pub(crate) struct RepDriver { /// A set of pending incoming connections, represented by [`Transport::Accept`]. pub(super) conn_tasks: FuturesUnordered, /// A joinset of authentication tasks. - pub(crate) auth_tasks: JoinSet, PubError>>, + pub(crate) auth_tasks: JoinSet, PubError>>, } -impl Future for RepDriver +impl Future for RepDriver where - T: Transport + Unpin + 'static, + T: Transport + Unpin + 'static, + A: Address, { type Output = Result<(), PubError>; @@ -176,9 +177,10 @@ where } } -impl RepDriver +impl RepDriver where - T: Transport + Unpin + 'static, + T: Transport + Unpin + 'static, + A: Address, { /// Handles an incoming connection. If this returns an error, the active connections counter /// should be decremented. diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index d0b0e4a..16f2161 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -113,10 +113,10 @@ mod tests { async fn reqrep_simple() { let _ = tracing_subscriber::fmt::try_init(); let mut rep = RepSocket::new(Tcp::default()); - rep.bind_socket(localhost()).await.unwrap(); + rep.bind(localhost()).await.unwrap(); let mut req = ReqSocket::new(Tcp::default()); - req.connect_socket(rep.local_addr().unwrap()).await.unwrap(); + req.connect(rep.local_addr().unwrap()).await.unwrap(); tokio::spawn(async move { loop { @@ -156,7 +156,7 @@ mod tests { // Try to connect even through the server isn't up yet let endpoint = addr.clone(); let connection_attempt = tokio::spawn(async move { - req.connect_socket(endpoint).await.unwrap(); + req.connect(endpoint).await.unwrap(); req }); @@ -164,7 +164,7 @@ mod tests { // Wait a moment to start the server tokio::time::sleep(Duration::from_millis(500)).await; let mut rep = RepSocket::new(Tcp::default()); - rep.bind_socket(addr).await.unwrap(); + rep.bind(addr).await.unwrap(); let req = connection_attempt.await.unwrap(); @@ -193,7 +193,7 @@ mod tests { let _ = tracing_subscriber::fmt::try_init(); let mut rep = RepSocket::new(Tcp::default()).with_auth(Auth); - rep.bind_socket(localhost()).await.unwrap(); + rep.bind(localhost()).await.unwrap(); // Initialize socket with a client ID. This will implicitly enable authentication. let mut req = ReqSocket::with_options( @@ -201,7 +201,7 @@ mod tests { ReqOptions::default().auth_token(Bytes::from("REQ")), ); - req.connect_socket(rep.local_addr().unwrap()).await.unwrap(); + req.connect(rep.local_addr().unwrap()).await.unwrap(); tracing::info!("Connected to rep"); @@ -236,16 +236,16 @@ mod tests { async fn rep_max_connections() { let _ = tracing_subscriber::fmt::try_init(); let mut rep = RepSocket::with_options(Tcp::default(), RepOptions::default().max_clients(1)); - rep.bind_socket("127.0.0.1:0").await.unwrap(); + rep.bind("127.0.0.1:0").await.unwrap(); let addr = rep.local_addr().unwrap(); let mut req1 = ReqSocket::new(Tcp::default()); - req1.connect_socket(addr).await.unwrap(); + req1.connect(addr).await.unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(rep.stats().active_clients(), 1); let mut req2 = ReqSocket::new(Tcp::default()); - req2.connect_socket(addr).await.unwrap(); + req2.connect(addr).await.unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(rep.stats().active_clients(), 1); } @@ -256,13 +256,13 @@ mod tests { RepSocket::with_options(Tcp::default(), RepOptions::default().min_compress_size(0)) .with_compressor(SnappyCompressor); - rep.bind_socket("0.0.0.0:4445").await.unwrap(); + rep.bind("0.0.0.0:4445").await.unwrap(); let mut req = ReqSocket::with_options(Tcp::default(), ReqOptions::default().min_compress_size(0)) .with_compressor(GzipCompressor::new(6)); - req.connect_socket("0.0.0.0:4445").await.unwrap(); + req.connect("0.0.0.0:4445").await.unwrap(); tokio::spawn(async move { let req = rep.next().await.unwrap(); diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs index 6a905a5..fa2bbee 100644 --- a/msg-socket/src/rep/socket.rs +++ b/msg-socket/src/rep/socket.rs @@ -22,60 +22,53 @@ use crate::{ Authenticator, PubError, RepOptions, Request, }; -use msg_transport::Transport; +use msg_transport::{Address, Transport}; use msg_wire::compression::Compressor; /// A reply socket. This socket implements [`Stream`] and yields incoming [`Request`]s. #[derive(Default)] -pub struct RepSocket { +pub struct RepSocket, A: Address> { /// The reply socket options, shared with the driver. options: Arc, /// The reply socket state, shared with the driver. state: Arc, /// Receiver from the socket driver. - from_driver: Option>>, + from_driver: Option>>, /// The transport used by this socket. This value is temporary and will be moved /// to the driver task once the socket is bound. transport: Option, /// Optional connection authenticator. auth: Option>, /// The local address this socket is bound to. - local_addr: Option, + local_addr: Option, /// Optional message compressor. compressor: Option>, } -impl RepSocket +impl RepSocket where - T: Transport + Send + Unpin + 'static, + T: Transport + Send + Unpin + 'static, { - /// Binds the socket to the given socket addres - /// - /// This method is only available for transports that support [`SocketAddr`] as address type, - /// like [`Tcp`](msg_transport::tcp::Tcp) and [`Quic`](msg_transport::quic::Quic). - pub async fn bind_socket(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> { + pub async fn bind(&mut self, addr: impl ToSocketAddrs) -> Result<(), PubError> { let addrs = lookup_host(addr).await?; self.try_bind(addrs.collect()).await } } -impl RepSocket +impl RepSocket where - T: Transport + Send + Unpin + 'static, + T: Transport + Send + Unpin + 'static, { - /// Binds the socket to the given path. - /// - /// This method is only available for transports that support [`PathBuf`] as address type, - /// like [`Ipc`](msg_transport::ipc::Ipc). - pub async fn bind_path(&mut self, path: impl Into) -> Result<(), PubError> { + pub async fn bind(&mut self, path: impl Into) -> Result<(), PubError> { let addr = path.into().clone(); self.try_bind(vec![addr]).await } } -impl RepSocket +impl RepSocket where - T: Transport + Send + Unpin + 'static, + T: Transport + Send + Unpin + 'static, + A: Address, { /// Creates a new reply socket with the default [`RepOptions`]. pub fn new(transport: T) -> Self { @@ -96,7 +89,7 @@ where } /// Sets the connection authenticator for this socket. - pub fn with_auth(mut self, authenticator: A) -> Self { + pub fn with_auth(mut self, authenticator: O) -> Self { self.auth = Some(Arc::new(authenticator)); self } @@ -108,7 +101,7 @@ where } /// Binds the socket to the given address. This spawns the socket driver task. - pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), PubError> { + pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), PubError> { let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE); let mut transport = self @@ -160,13 +153,13 @@ where } /// Returns the local address this socket is bound to. `None` if the socket is not bound. - pub fn local_addr(&self) -> Option<&T::Addr> { + pub fn local_addr(&self) -> Option<&A> { self.local_addr.as_ref() } } -impl Stream for RepSocket { - type Item = Request; +impl + Unpin, A: Address> Stream for RepSocket { + type Item = Request; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.get_mut() diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index 5fb1a9f..afbae5f 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -19,7 +19,7 @@ use tracing::{debug, error, trace}; use super::{Command, ReqError, ReqOptions}; use crate::{req::SocketState, ConnectionState, ExponentialBackoff}; -use msg_transport::Transport; +use msg_transport::{Address, Transport}; use msg_wire::{ auth, compression::{try_decompress_payload, Compressor}, @@ -30,7 +30,7 @@ type ConnectionTask = Pin> + Se /// The request socket driver. Endless future that drives /// the the socket forward. -pub(crate) struct ReqDriver { +pub(crate) struct ReqDriver, A: Address> { /// Options shared with the socket. #[allow(unused)] pub(crate) options: Arc, @@ -43,13 +43,12 @@ pub(crate) struct ReqDriver { /// The transport for this socket. pub(crate) transport: T, /// The address of the server. - pub(crate) addr: T::Addr, + pub(crate) addr: A, /// The connection task which handles the connection to the server. pub(crate) conn_task: Option>, /// The transport controller, wrapped in a [`ConnectionState`] for backoff. /// The [`Framed`] object can send and receive messages from the socket. - pub(crate) conn_state: - ConnectionState, ExponentialBackoff, T::Addr>, + pub(crate) conn_state: ConnectionState, ExponentialBackoff, A>, /// The outgoing message queue. pub(crate) egress_queue: VecDeque, /// The currently pending requests, if any. Uses [`FxHashMap`] for performance. @@ -73,13 +72,14 @@ pub(crate) struct PendingRequest { sender: oneshot::Sender>, } -impl ReqDriver +impl ReqDriver where - T: Transport + Send + Sync + 'static, + T: Transport + Send + Sync + 'static, + A: Address, { /// Start the connection task to the server, handling authentication if necessary. /// The result will be polled by the driver and re-tried according to the backoff policy. - fn try_connect(&mut self, addr: T::Addr) { + fn try_connect(&mut self, addr: A) { trace!("Trying to connect to {:?}", addr); let connect = self.transport.connect(addr.clone()); @@ -248,9 +248,10 @@ where } } -impl Future for ReqDriver +impl Future for ReqDriver where - T: Transport + Unpin + Send + Sync + 'static, + T: Transport + Unpin + Send + Sync + 'static, + A: Address, { type Output = (); diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 7edb1fa..8d61a0e 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -1,13 +1,14 @@ use bytes::Bytes; -use msg_wire::compression::Compressor; use rustc_hash::FxHashMap; +use std::marker::PhantomData; use std::net::SocketAddr; use std::path::PathBuf; use std::{io, sync::Arc, time::Duration}; use tokio::net::{lookup_host, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; -use msg_transport::Transport; +use msg_transport::{Address, Transport}; +use msg_wire::compression::Compressor; use super::{Command, ReqDriver, ReqError, ReqOptions, DEFAULT_BUFFER_SIZE}; use crate::{ @@ -16,7 +17,7 @@ use crate::{ }; /// The request socket. -pub struct ReqSocket { +pub struct ReqSocket, A: Address> { /// Command channel to the backend task. to_driver: Option>, /// The socket transport. @@ -29,14 +30,16 @@ pub struct ReqSocket { // NOTE: for now we're using dynamic dispatch, since using generics here // complicates the API a lot. We can always change this later for perf reasons. compressor: Option>, + /// Marker for the address type. + _marker: PhantomData, } -impl ReqSocket +impl ReqSocket where - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, { /// Connects to the target address with the default options. - pub async fn connect_socket(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> { + pub async fn connect(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> { let mut addrs = lookup_host(addr).await?; let endpoint = addrs.next().ok_or_else(|| { io::Error::new( @@ -49,19 +52,20 @@ where } } -impl ReqSocket +impl ReqSocket where - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, { /// Connects to the target path with the default options. - pub async fn connect_path(&mut self, addr: impl Into) -> Result<(), ReqError> { + pub async fn connect(&mut self, addr: impl Into) -> Result<(), ReqError> { self.try_connect(addr.into().clone()).await } } -impl ReqSocket +impl ReqSocket where - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, + A: Address, { pub fn new(transport: T) -> Self { Self::with_options(transport, ReqOptions::default()) @@ -74,6 +78,7 @@ where options: Arc::new(options), state: Arc::new(SocketState::default()), compressor: None, + _marker: PhantomData, } } @@ -107,7 +112,7 @@ where /// Tries to connect to the target endpoint with the default options. /// A ReqSocket can only be connected to a single address. - pub async fn try_connect(&mut self, endpoint: T::Addr) -> Result<(), ReqError> { + pub async fn try_connect(&mut self, endpoint: A) -> Result<(), ReqError> { // Initialize communication channels let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); @@ -132,7 +137,7 @@ where let pending_requests = FxHashMap::default(); // Create the socket backend - let driver: ReqDriver = ReqDriver { + let driver: ReqDriver = ReqDriver { addr: endpoint, options: Arc::clone(&self.options), socket_state: Arc::clone(&self.state), diff --git a/msg-socket/src/sub/driver.rs b/msg-socket/src/sub/driver.rs index a157ecb..b450fa1 100644 --- a/msg-socket/src/sub/driver.rs +++ b/msg-socket/src/sub/driver.rs @@ -21,36 +21,36 @@ use super::{ use crate::{ConnectionState, ExponentialBackoff}; use msg_common::{channel, task::JoinMap, Channel}; -use msg_transport::Transport; +use msg_transport::{Address, Transport}; use msg_wire::{auth, compression::try_decompress_payload, pubsub}; /// Publisher channel type, used to send messages to the publisher session /// and receive messages to forward to the socket frontend. type PubChannel = Channel; -pub(crate) struct SubDriver { +pub(crate) struct SubDriver, A: Address> { /// Options shared with the socket. pub(super) options: Arc, /// The transport for this socket. pub(super) transport: T, /// Commands from the socket. - pub(super) from_socket: mpsc::Receiver>, + pub(super) from_socket: mpsc::Receiver>, /// Messages to the socket. - pub(super) to_socket: mpsc::Sender>, + pub(super) to_socket: mpsc::Sender>, /// A joinset of authentication tasks. - pub(super) connection_tasks: JoinMap>, + pub(super) connection_tasks: JoinMap>, /// The set of subscribed topics. pub(super) subscribed_topics: HashSet, /// All publisher sessions for this subscriber socket, keyed by address. - pub(super) publishers: - FxHashMap>, + pub(super) publishers: FxHashMap>, /// Socket state. This is shared with the backend task. - pub(super) state: Arc>, + pub(super) state: Arc>, } -impl Future for SubDriver +impl Future for SubDriver where - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, + A: Address, { type Output = (); @@ -91,13 +91,14 @@ where } } -impl SubDriver +impl SubDriver where - T: Transport + Send + Sync + 'static, + T: Transport + Send + Sync + 'static, + A: Address, { /// De-activates a publisher by setting it to [`ConnectionState::Inactive`]. /// This will initialize the backoff stream. - fn reset_publisher(&mut self, addr: T::Addr) { + fn reset_publisher(&mut self, addr: A) { debug!("Resetting publisher at {addr:?}"); self.publishers.insert( addr.clone(), @@ -109,7 +110,7 @@ where } /// Returns true if we're already connected to the given publisher address. - fn is_connected(&self, addr: &T::Addr) -> bool { + fn is_connected(&self, addr: &A) -> bool { if self.publishers.get(addr).is_some_and(|s| s.is_active()) { return true; } @@ -117,7 +118,7 @@ where false } - fn is_known(&self, addr: &T::Addr) -> bool { + fn is_known(&self, addr: &A) -> bool { self.publishers.contains_key(addr) } @@ -190,7 +191,7 @@ where } } - fn on_command(&mut self, cmd: Command) { + fn on_command(&mut self, cmd: Command) { debug!("Received command: {:?}", cmd); match cmd { Command::Subscribe { topic } => { @@ -229,7 +230,7 @@ where } } - fn connect(&mut self, addr: T::Addr) { + fn connect(&mut self, addr: A) { let connect = self.transport.connect(addr.clone()); let token = self.options.auth_token.clone(); @@ -294,7 +295,7 @@ where }); } - fn on_connection(&mut self, addr: T::Addr, io: T::Io) { + fn on_connection(&mut self, addr: A, io: T::Io) { if self.is_connected(&addr) { // We're already connected to this publisher warn!(?addr, "Already connected to publisher"); diff --git a/msg-socket/src/sub/socket.rs b/msg-socket/src/sub/socket.rs index 69e9cd8..b767a32 100644 --- a/msg-socket/src/sub/socket.rs +++ b/msg-socket/src/sub/socket.rs @@ -15,35 +15,35 @@ use tokio::{ }; use msg_common::task::JoinMap; -use msg_transport::Transport; +use msg_transport::{Address, Transport}; use super::{ Command, PubMessage, SocketState, SocketStats, SubDriver, SubError, SubOptions, DEFAULT_BUFFER_SIZE, }; -pub struct SubSocket { +pub struct SubSocket, A: Address> { /// Command channel to the socket driver. - to_driver: mpsc::Sender>, + to_driver: mpsc::Sender>, /// Receiver channel from the socket driver. - from_driver: mpsc::Receiver>, + from_driver: mpsc::Receiver>, /// Options for the socket. These are shared with the backend task. #[allow(unused)] options: Arc, /// The pending driver. - driver: Option>, + driver: Option>, /// Socket state. This is shared with the socket frontend. - state: Arc>, + state: Arc>, /// Marker for the transport type. _marker: std::marker::PhantomData, } -impl SubSocket +impl SubSocket where - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, { /// Connects to the given endpoint asynchronously. - pub async fn connect_socket(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> { + pub async fn connect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> { let mut addrs = lookup_host(endpoint).await?; let mut endpoint = addrs.next().ok_or(SubError::Io(io::Error::new( io::ErrorKind::InvalidInput, @@ -57,11 +57,11 @@ where endpoint.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); } - self.connect(endpoint).await + self.connect_inner(endpoint).await } /// Attempts to connect to the given endpoint immediately. - pub fn try_connect_socket(&mut self, endpoint: impl Into) -> Result<(), SubError> { + pub fn try_connect(&mut self, endpoint: impl Into) -> Result<(), SubError> { let addr = endpoint.into(); let mut endpoint: SocketAddr = addr.parse().map_err(|_| { SubError::Io(io::Error::new( @@ -77,23 +77,22 @@ where endpoint.set_ip(IpAddr::V4(Ipv4Addr::LOCALHOST)); } - self.try_connect(endpoint) + self.try_connect_inner(endpoint) } - pub async fn disconnect_socket( - &mut self, - endpoint: impl ToSocketAddrs, - ) -> Result<(), SubError> { + /// Disconnects from the given endpoint asynchronously. + pub async fn disconnect(&mut self, endpoint: impl ToSocketAddrs) -> Result<(), SubError> { let mut addrs = lookup_host(endpoint).await?; let endpoint = addrs.next().ok_or(SubError::Io(io::Error::new( io::ErrorKind::InvalidInput, "could not find any valid address", )))?; - self.disconnect(endpoint).await + self.disconnect_inner(endpoint).await } - pub fn try_disconnect_socket(&mut self, endpoint: impl Into) -> Result<(), SubError> { + /// Attempts to disconnect from the given endpoint immediately. + pub fn try_disconnect(&mut self, endpoint: impl Into) -> Result<(), SubError> { let endpoint = endpoint.into(); let endpoint: SocketAddr = endpoint.parse().map_err(|_| { SubError::Io(io::Error::new( @@ -102,38 +101,39 @@ where )) })?; - self.try_disconnect(endpoint) + self.try_disconnect_inner(endpoint) } } -impl SubSocket +impl SubSocket where - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, { /// Connects to the given path asynchronously. pub async fn connect_path(&mut self, path: impl Into) -> Result<(), SubError> { - self.connect(path.into()).await + self.connect_inner(path.into()).await } /// Attempts to connect to the given path immediately. pub fn try_connect_path(&mut self, path: impl Into) -> Result<(), SubError> { - self.try_connect(path.into()) + self.try_connect_inner(path.into()) } /// Disconnects from the given path asynchronously. pub async fn disconnect_path(&mut self, path: impl Into) -> Result<(), SubError> { - self.disconnect(path.into()).await + self.disconnect_inner(path.into()).await } /// Attempts to disconnect from the given path immediately. pub fn try_disconnect_path(&mut self, path: impl Into) -> Result<(), SubError> { - self.try_disconnect(path.into()) + self.try_disconnect_inner(path.into()) } } -impl SubSocket +impl SubSocket where - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, + A: Address, { #[allow(clippy::new_without_default)] pub fn new(transport: T) -> Self { @@ -173,28 +173,28 @@ where } /// Asynchronously connects to the endpoint. - pub async fn connect(&mut self, endpoint: T::Addr) -> Result<(), SubError> { + pub async fn connect_inner(&mut self, endpoint: A) -> Result<(), SubError> { self.ensure_active_driver(); self.send_command(Command::Connect { endpoint }).await?; Ok(()) } /// Immediately send a connect command to the driver. - pub fn try_connect(&mut self, endpoint: T::Addr) -> Result<(), SubError> { + pub fn try_connect_inner(&mut self, endpoint: A) -> Result<(), SubError> { self.ensure_active_driver(); self.try_send_command(Command::Connect { endpoint })?; Ok(()) } /// Asynchronously disconnects from the endpoint. - pub async fn disconnect(&mut self, endpoint: T::Addr) -> Result<(), SubError> { + pub async fn disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> { self.ensure_active_driver(); self.send_command(Command::Disconnect { endpoint }).await?; Ok(()) } /// Immediately send a disconnect command to the driver. - pub fn try_disconnect(&mut self, endpoint: T::Addr) -> Result<(), SubError> { + pub fn try_disconnect_inner(&mut self, endpoint: A) -> Result<(), SubError> { self.ensure_active_driver(); self.try_send_command(Command::Disconnect { endpoint })?; Ok(()) @@ -252,7 +252,7 @@ where /// Sends a command to the driver, returning [`SubError::SocketClosed`] if the /// driver has been dropped. - async fn send_command(&self, command: Command) -> Result<(), SubError> { + async fn send_command(&self, command: Command) -> Result<(), SubError> { self.to_driver .send(command) .await @@ -261,7 +261,7 @@ where Ok(()) } - fn try_send_command(&self, command: Command) -> Result<(), SubError> { + fn try_send_command(&self, command: Command) -> Result<(), SubError> { use mpsc::error::TrySendError::*; self.to_driver.try_send(command).map_err(|e| match e { Full(_) => SubError::ChannelFull, @@ -278,20 +278,20 @@ where } } - pub fn stats(&self) -> &SocketStats { + pub fn stats(&self) -> &SocketStats { &self.state.stats } } -impl Drop for SubSocket { +impl, A: Address> Drop for SubSocket { fn drop(&mut self) { // Try to tell the driver to gracefully shut down. let _ = self.to_driver.try_send(Command::Shutdown); } } -impl Stream for SubSocket { - type Item = PubMessage; +impl + Unpin, A: Address> Stream for SubSocket { + type Item = PubMessage; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.from_driver.poll_recv(cx) diff --git a/msg-socket/tests/it/pubsub.rs b/msg-socket/tests/it/pubsub.rs index 8645a54..f89d245 100644 --- a/msg-socket/tests/it/pubsub.rs +++ b/msg-socket/tests/it/pubsub.rs @@ -6,7 +6,7 @@ use tokio::{sync::mpsc, task::JoinSet}; use tokio_stream::StreamExt; use msg_socket::{PubSocket, SubSocket}; -use msg_transport::{quic::Quic, tcp::Tcp, Transport}; +use msg_transport::{quic::Quic, tcp::Tcp, Address, Transport}; const TOPIC: &str = "test"; @@ -46,18 +46,19 @@ async fn pubsub_channel() { simulator.stop(addr); } -async fn pubsub_channel_transport( +async fn pubsub_channel_transport( new_transport: F, - addr: T::Addr, + addr: A, ) -> Result<(), Box> where F: Fn() -> T, - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, + A: Address, { let mut publisher = PubSocket::new(new_transport()); let mut subscriber = SubSocket::new(new_transport()); - subscriber.connect(addr.clone()).await?; + subscriber.connect_inner(addr.clone()).await?; subscriber.subscribe(TOPIC).await?; inject_delay(400).await; @@ -114,11 +115,12 @@ async fn pubsub_fan_out() { async fn pubsub_fan_out_transport< F: Fn() -> T + Send + 'static + Copy, - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, + A: Address, >( new_transport: F, subscibers: usize, - addr: T::Addr, + addr: A, ) -> Result<(), Box> { let mut publisher = PubSocket::new(new_transport()); @@ -130,7 +132,7 @@ async fn pubsub_fan_out_transport< let mut subscriber = SubSocket::new(new_transport()); inject_delay((100 * (i + 1)) as u64).await; - subscriber.connect(cloned).await.unwrap(); + subscriber.connect_inner(cloned).await.unwrap(); inject_delay((1000 / (i + 1)) as u64).await; subscriber.subscribe(TOPIC).await.unwrap(); @@ -194,11 +196,12 @@ async fn pubsub_fan_in() { async fn pubsub_fan_in_transport< F: Fn() -> T + Send + 'static + Copy, - T: Transport + Send + Sync + Unpin + 'static, + T: Transport + Send + Sync + Unpin + 'static, + A: Address, >( new_transport: F, publishers: usize, - addr: T::Addr, + addr: A, ) -> Result<(), Box> { let mut sub_tasks = JoinSet::new(); @@ -241,7 +244,7 @@ async fn pubsub_fan_in_transport< for addr in addrs.clone() { inject_delay(500).await; - subscriber.connect(addr.clone()).await.unwrap(); + subscriber.connect_inner(addr.clone()).await.unwrap(); subscriber.subscribe(TOPIC).await.unwrap(); } diff --git a/msg-transport/src/ipc/mod.rs b/msg-transport/src/ipc/mod.rs index cd5490a..ad0ccce 100644 --- a/msg-transport/src/ipc/mod.rs +++ b/msg-transport/src/ipc/mod.rs @@ -97,8 +97,7 @@ impl PeerAddress for IpcStream { } #[async_trait] -impl Transport for Ipc { - type Addr = PathBuf; +impl Transport for Ipc { type Io = IpcStream; type Error = io::Error; @@ -106,11 +105,11 @@ impl Transport for Ipc { type Connect = BoxFuture<'static, Result>; type Accept = BoxFuture<'static, Result>; - fn local_addr(&self) -> Option { + fn local_addr(&self) -> Option { self.path.clone() } - async fn bind(&mut self, addr: Self::Addr) -> Result<(), Self::Error> { + async fn bind(&mut self, addr: PathBuf) -> Result<(), Self::Error> { if addr.exists() { debug!("Socket file already exists. Attempting to remove."); if let Err(e) = std::fs::remove_file(&addr) { @@ -127,7 +126,7 @@ impl Transport for Ipc { Ok(()) } - fn connect(&mut self, addr: Self::Addr) -> Self::Connect { + fn connect(&mut self, addr: PathBuf) -> Self::Connect { Box::pin(async move { IpcStream::connect(addr).await }) } @@ -155,8 +154,8 @@ impl Transport for Ipc { } #[async_trait] -impl TransportExt for Ipc { - fn accept(&mut self) -> Acceptor<'_, Self> +impl TransportExt for Ipc { + fn accept(&mut self) -> Acceptor<'_, Self, PathBuf> where Self: Sized + Unpin, { diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index b241da9..9e65a4b 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -3,11 +3,13 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] use futures::{Future, FutureExt}; +use msg_common::io_error; use std::{ fmt::Debug, hash::Hash, io, - net::SocketAddr, + marker::PhantomData, + net::{SocketAddr, ToSocketAddrs}, path::PathBuf, pin::Pin, task::{Context, Poll}, @@ -27,20 +29,78 @@ impl Address for SocketAddr {} /// File system path, used for IPC transport. impl Address for PathBuf {} +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum AddressType { + SocketAddr(SocketAddr), + PathBuf(PathBuf), +} + +impl From for AddressType { + fn from(addr: SocketAddr) -> Self { + AddressType::SocketAddr(addr) + } +} + +impl From for AddressType { + fn from(path: PathBuf) -> Self { + AddressType::PathBuf(path) + } +} + +impl TryFrom<&str> for AddressType { + type Error = io::Error; + + fn try_from(value: &str) -> Result { + let s = value.to_string(); + if s.contains(':') { + // try to parse as socket address + let addr = s + .parse::() + .map_err(|_| io_error("invalid socket address"))?; + + Ok(AddressType::SocketAddr(addr)) + } else { + // try to parse as path + let path = PathBuf::from(s); + Ok(AddressType::PathBuf(path)) + } + } +} + +impl ToSocketAddrs for AddressType { + type Iter = std::vec::IntoIter; + + fn to_socket_addrs(&self) -> io::Result { + match self { + AddressType::SocketAddr(addr) => Ok(vec![*addr].into_iter()), + AddressType::PathBuf(_) => Err(io_error("path is not a valid socket address")), + } + } +} + +impl From for PathBuf { + fn from(val: AddressType) -> Self { + match val { + AddressType::SocketAddr(_) => panic!("socket address is not a valid path"), + AddressType::PathBuf(path) => path, + } + } +} + /// A transport provides connection-oriented communication between two peers through /// ordered and reliable streams of bytes. /// /// It provides an interface to manage both inbound and outbound connections. #[async_trait::async_trait] -pub trait Transport { - /// The generic address type used by this transport - type Addr: Address; +pub trait Transport { + // /// The generic address type used by this transport + // type Addr: Address; /// The result of a successful connection. /// /// The output type is transport-specific, and can be a handle to directly write to the /// connection, or it can be a substream multiplexer in the case of stream protocols. - type Io: AsyncRead + AsyncWrite + PeerAddress + Send + Unpin; + type Io: AsyncRead + AsyncWrite + PeerAddress + Send + Unpin; /// An error that occurred when setting up the connection. type Error: std::error::Error + From + Send + Sync; @@ -54,22 +114,22 @@ pub trait Transport { type Accept: Future> + Send + Unpin; /// Returns the local address this transport is bound to (if it is bound). - fn local_addr(&self) -> Option; + fn local_addr(&self) -> Option; /// Binds to the given address. - async fn bind(&mut self, addr: Self::Addr) -> Result<(), Self::Error>; + async fn bind(&mut self, addr: A) -> Result<(), Self::Error>; /// Connects to the given address, returning a future representing a pending outbound connection. - fn connect(&mut self, addr: Self::Addr) -> Self::Connect; + fn connect(&mut self, addr: A) -> Self::Connect; /// Poll for incoming connections. If an inbound connection is received, a future representing /// a pending inbound connection is returned. The future will resolve to [`Transport::Output`]. fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll; } -pub trait TransportExt: Transport { +pub trait TransportExt: Transport { /// Async-friendly interface for accepting inbound connections. - fn accept(&mut self) -> Acceptor<'_, Self> + fn accept(&mut self) -> Acceptor<'_, Self, A> where Self: Sized + Unpin, { @@ -77,19 +137,24 @@ pub trait TransportExt: Transport { } } -pub struct Acceptor<'a, T> { +pub struct Acceptor<'a, T, A> { inner: &'a mut T, + _marker: PhantomData, } -impl<'a, T> Acceptor<'a, T> { +impl<'a, T, A> Acceptor<'a, T, A> { fn new(inner: &'a mut T) -> Self { - Self { inner } + Self { + inner, + _marker: PhantomData, + } } } -impl<'a, T> Future for Acceptor<'a, T> +impl<'a, T, A> Future for Acceptor<'a, T, A> where - T: Transport + Unpin, + T: Transport + Unpin, + A: Address, { type Output = Result; diff --git a/msg-transport/src/quic/mod.rs b/msg-transport/src/quic/mod.rs index ebff405..120f3ce 100644 --- a/msg-transport/src/quic/mod.rs +++ b/msg-transport/src/quic/mod.rs @@ -84,8 +84,7 @@ impl Quic { } #[async_trait::async_trait] -impl Transport for Quic { - type Addr = SocketAddr; +impl Transport for Quic { type Io = QuicStream; type Error = Error; @@ -218,8 +217,8 @@ impl Transport for Quic { } } -impl TransportExt for Quic { - fn accept(&mut self) -> crate::Acceptor<'_, Self> +impl TransportExt for Quic { + fn accept(&mut self) -> crate::Acceptor<'_, Self, SocketAddr> where Self: Sized + Unpin, { diff --git a/msg-transport/src/tcp/mod.rs b/msg-transport/src/tcp/mod.rs index 0120263..79ee39d 100644 --- a/msg-transport/src/tcp/mod.rs +++ b/msg-transport/src/tcp/mod.rs @@ -36,8 +36,7 @@ impl PeerAddress for TcpStream { } #[async_trait::async_trait] -impl Transport for Tcp { - type Addr = SocketAddr; +impl Transport for Tcp { type Io = TcpStream; type Error = io::Error; @@ -89,8 +88,8 @@ impl Transport for Tcp { } #[async_trait::async_trait] -impl TransportExt for Tcp { - fn accept(&mut self) -> Acceptor<'_, Self> +impl TransportExt for Tcp { + fn accept(&mut self) -> Acceptor<'_, Self, SocketAddr> where Self: Sized + Unpin, { diff --git a/msg/benches/pubsub.rs b/msg/benches/pubsub.rs index d1eaf53..af5ff7e 100644 --- a/msg/benches/pubsub.rs +++ b/msg/benches/pubsub.rs @@ -4,7 +4,7 @@ use criterion::{ Throughput, }; use futures::StreamExt; -use msg::ipc::Ipc; +use msg::{ipc::Ipc, Address}; use pprof::criterion::{Output, PProfProfiler}; use rand::Rng; use std::{ @@ -24,24 +24,24 @@ const MSG_SIZE: usize = 512; #[global_allocator] static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; -struct PairBenchmark { +struct PairBenchmark, A: Address> { rt: Runtime, - publisher: PubSocket, - subscriber: SubSocket, + publisher: PubSocket, + subscriber: SubSocket, n_reqs: usize, msg_sizes: Vec, } -impl PairBenchmark { +impl + Send + Sync + Unpin + 'static, A: Address> PairBenchmark { /// Sets up the publisher and subscriber sockets. - fn init(&mut self, addr: T::Addr) { + fn init(&mut self, addr: A) { // Set up the socket connections self.rt.block_on(async { self.publisher.try_bind(vec![addr]).await.unwrap(); let addr = self.publisher.local_addr().unwrap(); - self.subscriber.connect(addr.clone()).await.unwrap(); + self.subscriber.connect_inner(addr.clone()).await.unwrap(); self.subscriber .subscribe("HELLO".to_string()) diff --git a/msg/benches/reqrep.rs b/msg/benches/reqrep.rs index bc84e19..981b77a 100644 --- a/msg/benches/reqrep.rs +++ b/msg/benches/reqrep.rs @@ -6,10 +6,10 @@ use criterion::{ Throughput, }; use futures::StreamExt; -use msg::{ipc::Ipc, Transport}; use pprof::criterion::Output; use rand::Rng; +use msg::{ipc::Ipc, Address, Transport}; use msg_socket::{RepSocket, ReqOptions, ReqSocket}; use msg_transport::tcp::Tcp; use tokio::runtime::Runtime; @@ -23,17 +23,17 @@ const MSG_SIZE: usize = 512; #[global_allocator] static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; -struct PairBenchmark { +struct PairBenchmark, A: Address> { rt: Runtime, - req: ReqSocket, - rep: Option>, + req: ReqSocket, + rep: Option>, n_reqs: usize, msg_sizes: Vec, } -impl PairBenchmark { - fn init(&mut self, addr: T::Addr) { +impl + Send + Sync + Unpin + 'static, A: Address> PairBenchmark { + fn init(&mut self, addr: A) { let mut rep = self.rep.take().unwrap(); // setup the socket connections self.rt.block_on(async { diff --git a/msg/examples/durable.rs b/msg/examples/durable.rs index f6a3662..42cff7e 100644 --- a/msg/examples/durable.rs +++ b/msg/examples/durable.rs @@ -23,7 +23,7 @@ async fn start_rep() { // Initialize the reply socket (server side) with a transport // and an authenticator. let mut rep = RepSocket::new(Tcp::default()).with_auth(Auth); - while rep.bind_socket("0.0.0.0:4444").await.is_err() { + while rep.bind("0.0.0.0:4444").await.is_err() { rep = RepSocket::new(Tcp::default()).with_auth(Auth); tracing::warn!("Failed to bind rep socket, retrying..."); tokio::time::sleep(Duration::from_secs(1)).await; @@ -76,7 +76,7 @@ async fn main() { tokio::spawn( async move { tracing::info!("Trying to connect to rep socket... This will start the connection process in the background, it won't immediately connect."); - req.connect_socket("0.0.0.0:4444").await.unwrap(); + req.connect("0.0.0.0:4444").await.unwrap(); for i in 0..10 { tracing::info!("Sending request {i}..."); diff --git a/msg/examples/ipc.rs b/msg/examples/ipc.rs index 29e32e4..b9578f1 100644 --- a/msg/examples/ipc.rs +++ b/msg/examples/ipc.rs @@ -14,12 +14,12 @@ async fn main() { // use a temporary file as the socket path let path = temp_dir().join("test.sock"); - rep.bind_path(path.clone()).await.unwrap(); + rep.bind(path.clone()).await.unwrap(); println!("Listening on {:?}", rep.local_addr().unwrap()); // Initialize the request socket (client side) with a transport let mut req = ReqSocket::new(Ipc::default()); - req.connect_path(path).await.unwrap(); + req.connect(path).await.unwrap(); tokio::spawn(async move { // Receive the request and respond with "world" diff --git a/msg/examples/pubsub.rs b/msg/examples/pubsub.rs index 8752870..6b8c54c 100644 --- a/msg/examples/pubsub.rs +++ b/msg/examples/pubsub.rs @@ -31,17 +31,17 @@ async fn main() { ); tracing::info!("Setting up the sockets..."); - pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); + pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect_socket(pub_addr).await.unwrap(); + sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); - sub2.connect_socket(pub_addr).await.unwrap(); + sub2.connect(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); diff --git a/msg/examples/pubsub_auth.rs b/msg/examples/pubsub_auth.rs index 5d4d796..59d7208 100644 --- a/msg/examples/pubsub_auth.rs +++ b/msg/examples/pubsub_auth.rs @@ -45,17 +45,17 @@ async fn main() { ); tracing::info!("Setting up the sockets..."); - pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); + pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect_socket(pub_addr).await.unwrap(); + sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); - sub2.connect_socket(pub_addr).await.unwrap(); + sub2.connect(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); diff --git a/msg/examples/pubsub_compression.rs b/msg/examples/pubsub_compression.rs index c887805..37cfaa9 100644 --- a/msg/examples/pubsub_compression.rs +++ b/msg/examples/pubsub_compression.rs @@ -21,17 +21,17 @@ async fn main() { let mut sub2 = SubSocket::new(Tcp::default()); tracing::info!("Setting up the sockets..."); - pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); + pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect_socket(pub_addr).await.unwrap(); + sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); - sub2.connect_socket(pub_addr).await.unwrap(); + sub2.connect(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); diff --git a/msg/examples/quic_vs_tcp.rs b/msg/examples/quic_vs_tcp.rs index c0583cc..f720f70 100644 --- a/msg/examples/quic_vs_tcp.rs +++ b/msg/examples/quic_vs_tcp.rs @@ -3,7 +3,7 @@ use futures::StreamExt; use msg_transport::{quic::Quic, Transport}; use std::time::{Duration, Instant}; -use msg::{tcp::Tcp, PubOptions, PubSocket, SubOptions, SubSocket}; +use msg::{tcp::Tcp, Address, PubOptions, PubSocket, SubOptions, SubSocket}; #[tokio::main] async fn main() { @@ -30,12 +30,12 @@ async fn run_tcp() { ); tracing::info!("Setting up the sockets..."); - pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); + pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect_socket(pub_addr).await.unwrap(); + sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); @@ -63,12 +63,12 @@ async fn run_quic() { ); tracing::info!("Setting up the sockets..."); - pub_socket.bind_socket("127.0.0.1:0").await.unwrap(); + pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); tracing::info!("Publisher listening on: {}", pub_addr); - sub1.connect_socket(pub_addr).await.unwrap(); + sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); @@ -78,10 +78,10 @@ async fn run_quic() { run_transfer("QUIC", &mut pub_socket, &mut sub1).await; } -async fn run_transfer( +async fn run_transfer + Send + Unpin + 'static, A: Address>( transport: &str, - pub_socket: &mut PubSocket, - sub_socket: &mut SubSocket, + pub_socket: &mut PubSocket, + sub_socket: &mut SubSocket, ) { let data = Bytes::from( std::fs::read("./testdata/mainnetCapellaBlock7928030.ssz") diff --git a/msg/examples/reqrep.rs b/msg/examples/reqrep.rs index 4b153a6..286facd 100644 --- a/msg/examples/reqrep.rs +++ b/msg/examples/reqrep.rs @@ -7,11 +7,11 @@ use msg::{tcp::Tcp, RepSocket, ReqSocket}; async fn main() { // Initialize the reply socket (server side) with a transport let mut rep = RepSocket::new(Tcp::default()); - rep.bind_socket("0.0.0.0:4444").await.unwrap(); + rep.bind("0.0.0.0:4444").await.unwrap(); // Initialize the request socket (client side) with a transport let mut req = ReqSocket::new(Tcp::default()); - req.connect_socket("0.0.0.0:4444").await.unwrap(); + req.connect("0.0.0.0:4444").await.unwrap(); tokio::spawn(async move { // Receive the request and respond with "world" diff --git a/msg/examples/reqrep_auth.rs b/msg/examples/reqrep_auth.rs index 8c2e17e..2b4b54e 100644 --- a/msg/examples/reqrep_auth.rs +++ b/msg/examples/reqrep_auth.rs @@ -20,7 +20,7 @@ async fn main() { // Initialize the reply socket (server side) with a transport // and an authenticator. let mut rep = RepSocket::new(Tcp::default()).with_auth(Auth); - rep.bind_socket("0.0.0.0:4444").await.unwrap(); + rep.bind("0.0.0.0:4444").await.unwrap(); // Initialize the request socket (client side) with a transport // and an identifier. This will implicitly turn on client authentication. @@ -29,7 +29,7 @@ async fn main() { ReqOptions::default().auth_token(Bytes::from("REQ")), ); - req.connect_socket("0.0.0.0:4444").await.unwrap(); + req.connect("0.0.0.0:4444").await.unwrap(); tokio::spawn(async move { // Receive the request and respond with "world" diff --git a/msg/examples/reqrep_compression.rs b/msg/examples/reqrep_compression.rs index e5ac2d5..bb85132 100644 --- a/msg/examples/reqrep_compression.rs +++ b/msg/examples/reqrep_compression.rs @@ -13,7 +13,7 @@ async fn main() { RepSocket::with_options(Tcp::default(), RepOptions::default().min_compress_size(0)) // Enable Gzip compression (compression level 6) .with_compressor(GzipCompressor::new(6)); - rep.bind_socket("0.0.0.0:4444").await.unwrap(); + rep.bind("0.0.0.0:4444").await.unwrap(); // Initialize the request socket (client side) with a transport // and a minimum compresion size of 0 bytes to compress all requests @@ -24,7 +24,7 @@ async fn main() { // use the same compression algorithm or level. .with_compressor(GzipCompressor::new(6)); - req.connect_socket("0.0.0.0:4444").await.unwrap(); + req.connect("0.0.0.0:4444").await.unwrap(); tokio::spawn(async move { // Receive the request and respond with "world" From f0cd1e3f44fc94736f5caa03a3df57d4f17235e1 Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Mon, 26 Aug 2024 18:21:41 +0200 Subject: [PATCH 6/7] chore: small fixes, removed unused code --- msg-common/src/lib.rs | 6 ---- msg-transport/src/lib.rs | 61 +--------------------------------------- 2 files changed, 1 insertion(+), 66 deletions(-) diff --git a/msg-common/src/lib.rs b/msg-common/src/lib.rs index 1225f5e..f7e4b94 100644 --- a/msg-common/src/lib.rs +++ b/msg-common/src/lib.rs @@ -3,7 +3,6 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] use std::{ - io, pin::Pin, task::{Context, Poll}, time::SystemTime, @@ -35,11 +34,6 @@ pub fn async_error( Box::pin(async move { Err(e) }) } -/// Creates a new [`io::Error`] with the given message. -pub fn io_error(msg: impl Into) -> io::Error { - io::Error::new(io::ErrorKind::Other, msg.into()) -} - #[allow(non_upper_case_globals)] pub mod constants { pub const KiB: u32 = 1024; diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index 9e65a4b..2419f3a 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -3,13 +3,12 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] use futures::{Future, FutureExt}; -use msg_common::io_error; use std::{ fmt::Debug, hash::Hash, io, marker::PhantomData, - net::{SocketAddr, ToSocketAddrs}, + net::SocketAddr, path::PathBuf, pin::Pin, task::{Context, Poll}, @@ -29,64 +28,6 @@ impl Address for SocketAddr {} /// File system path, used for IPC transport. impl Address for PathBuf {} -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum AddressType { - SocketAddr(SocketAddr), - PathBuf(PathBuf), -} - -impl From for AddressType { - fn from(addr: SocketAddr) -> Self { - AddressType::SocketAddr(addr) - } -} - -impl From for AddressType { - fn from(path: PathBuf) -> Self { - AddressType::PathBuf(path) - } -} - -impl TryFrom<&str> for AddressType { - type Error = io::Error; - - fn try_from(value: &str) -> Result { - let s = value.to_string(); - if s.contains(':') { - // try to parse as socket address - let addr = s - .parse::() - .map_err(|_| io_error("invalid socket address"))?; - - Ok(AddressType::SocketAddr(addr)) - } else { - // try to parse as path - let path = PathBuf::from(s); - Ok(AddressType::PathBuf(path)) - } - } -} - -impl ToSocketAddrs for AddressType { - type Iter = std::vec::IntoIter; - - fn to_socket_addrs(&self) -> io::Result { - match self { - AddressType::SocketAddr(addr) => Ok(vec![*addr].into_iter()), - AddressType::PathBuf(_) => Err(io_error("path is not a valid socket address")), - } - } -} - -impl From for PathBuf { - fn from(val: AddressType) -> Self { - match val { - AddressType::SocketAddr(_) => panic!("socket address is not a valid path"), - AddressType::PathBuf(path) => path, - } - } -} - /// A transport provides connection-oriented communication between two peers through /// ordered and reliable streams of bytes. /// From cbab9f372b12803aeb1efa7ef4bc5bfb5054caf6 Mon Sep 17 00:00:00 2001 From: nicolas <48695862+merklefruit@users.noreply.github.com> Date: Mon, 26 Aug 2024 18:22:29 +0200 Subject: [PATCH 7/7] chore: small fixes --- msg-transport/src/lib.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index 2419f3a..138c003 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -2,7 +2,6 @@ #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(not(test), warn(unused_crate_dependencies))] -use futures::{Future, FutureExt}; use std::{ fmt::Debug, hash::Hash, @@ -13,6 +12,9 @@ use std::{ pin::Pin, task::{Context, Poll}, }; + +use async_trait::async_trait; +use futures::{Future, FutureExt}; use tokio::io::{AsyncRead, AsyncWrite}; pub mod ipc; @@ -32,11 +34,8 @@ impl Address for PathBuf {} /// ordered and reliable streams of bytes. /// /// It provides an interface to manage both inbound and outbound connections. -#[async_trait::async_trait] +#[async_trait] pub trait Transport { - // /// The generic address type used by this transport - // type Addr: Address; - /// The result of a successful connection. /// /// The output type is transport-specific, and can be a handle to directly write to the