Skip to content

Commit

Permalink
Merge pull request #45 from chainbound/feat/header-compression-type
Browse files Browse the repository at this point in the history
feat: Pub/Sub auto-decompression
  • Loading branch information
merklefruit authored Dec 20, 2023
2 parents 27acfde + d3a3cd0 commit b282179
Show file tree
Hide file tree
Showing 15 changed files with 291 additions and 80 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ rand = "0.8"
rustc-hash = "1"
flate2 = "1"
zstd = "0.13"
snap = "1"

[profile.dev]
opt-level = 1
Expand Down
46 changes: 37 additions & 9 deletions msg-socket/src/pub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use std::io;
use thiserror::Error;

mod driver;
use msg_wire::{compression::Compressor, pubsub};
use msg_wire::{
compression::{CompressionType, Compressor},
pubsub,
};
mod session;
mod socket;
mod stats;
Expand Down Expand Up @@ -43,6 +46,9 @@ pub struct PubOptions {
/// The maximum number of bytes that can be buffered in the session before being flushed.
/// This internally sets [`Framed::set_backpressure_boundary`](tokio_util::codec::Framed).
backpressure_boundary: usize,
/// Minimum payload size in bytes for compression to be used. If the payload is smaller than
/// this threshold, it will not be compressed.
min_compress_size: usize,
}

impl Default for PubOptions {
Expand All @@ -52,6 +58,7 @@ impl Default for PubOptions {
session_buffer_size: 1024,
flush_interval: Some(std::time::Duration::from_micros(50)),
backpressure_boundary: 8192,
min_compress_size: 8192,
}
}
}
Expand Down Expand Up @@ -83,12 +90,21 @@ impl PubOptions {
self.flush_interval = Some(flush_interval);
self
}

/// Configures the compression threshold, expressed in bytes of payload.
///
/// Payloads smaller than `min_compress_size` are not compressed. The updated
/// options are returned by value so that calls can be chained builder-style.
pub fn min_compress_size(self, min_compress_size: usize) -> Self {
    let mut options = self;
    options.min_compress_size = min_compress_size;
    options
}
}

/// A message received from a publisher.
/// Includes the source, topic, and payload.
#[derive(Debug, Clone)]
pub struct PubMessage {
/// The compression type used for the message payload.
compression_type: CompressionType,
/// The topic of the message.
topic: String,
/// The message payload.
Expand All @@ -98,7 +114,13 @@ pub struct PubMessage {
#[allow(unused)]
impl PubMessage {
pub fn new(topic: String, payload: Bytes) -> Self {
Self { topic, payload }
Self {
// Initialize the compression type to None.
// The actual compression type will be set in the `compress` method.
compression_type: CompressionType::None,
topic,
payload,
}
}

#[inline]
Expand All @@ -118,12 +140,18 @@ impl PubMessage {

#[inline]
pub fn into_wire(self, seq: u32) -> pubsub::Message {
pubsub::Message::new(seq, Bytes::from(self.topic), self.payload)
pubsub::Message::new(
seq,
Bytes::from(self.topic),
self.payload,
self.compression_type as u8,
)
}

/// Compresses the payload in place and records which algorithm was used.
///
/// The payload is replaced by the output of `compressor`, and the message's
/// compression type is set to the one reported by `compressor`.
///
/// # Errors
///
/// Returns the `io::Error` produced by the compressor. In that case neither
/// the payload nor the compression type is modified.
#[inline]
pub fn compress(&mut self, compressor: &dyn Compressor) -> Result<(), io::Error> {
    // Compress first: on failure we bail out before touching any field.
    let compressed = compressor.compress(&self.payload)?;

    self.payload = compressed;
    self.compression_type = compressor.compression_type();
    Ok(())
}
Expand All @@ -141,7 +169,7 @@ mod tests {

use futures::StreamExt;
use msg_transport::{Tcp, TcpOptions};
use msg_wire::compression::{GzipCompressor, GzipDecompressor};
use msg_wire::compression::GzipCompressor;

use crate::SubSocket;

Expand Down Expand Up @@ -216,16 +244,16 @@ mod tests {
async fn pubsub_many_compressed() {
let _ = tracing_subscriber::fmt::try_init();

let mut pub_socket = PubSocket::new(Tcp::new()).with_compressor(GzipCompressor::new(6));
let mut pub_socket =
PubSocket::with_options(Tcp::new(), PubOptions::default().min_compress_size(0))
.with_compressor(GzipCompressor::new(6));
let mut sub1 = SubSocket::new(Tcp::new_with_options(
TcpOptions::default().with_blocking_connect(),
))
.with_decompressor(GzipDecompressor::new());
));

let mut sub2 = SubSocket::new(Tcp::new_with_options(
TcpOptions::default().with_blocking_connect(),
))
.with_decompressor(GzipDecompressor::new());
));

pub_socket.bind("0.0.0.0:0").await.unwrap();
let addr = pub_socket.local_addr().unwrap();
Expand Down
24 changes: 13 additions & 11 deletions msg-socket/src/pub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,19 @@ impl<T: ServerTransport> PubSocket<T> {
let mut msg = PubMessage::new(topic, message);

// We compress here since that way we only have to do it once.
if let Some(ref compressor) = self.compressor {
let len_before = msg.payload().len();

// For relatively small messages, this takes <100us
msg.compress(compressor.as_ref())?;

debug!(
"Compressed message from {} to {} bytes",
len_before,
msg.payload().len(),
);
// Compression is only done if the message is larger than the
// configured minimum payload size.
let len_before = msg.payload().len();
if len_before > self.options.min_compress_size {
if let Some(ref compressor) = self.compressor {
msg.compress(compressor.as_ref())?;

debug!(
"Compressed message from {} to {} bytes",
len_before,
msg.payload().len(),
);
}
}

// Broadcast the message directly to all active sessions.
Expand Down
22 changes: 7 additions & 15 deletions msg-socket/src/sub/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use super::{
};
use msg_common::unix_micros;
use msg_transport::ClientTransport;
use msg_wire::compression::Decompressor;
use msg_wire::pubsub;

type ConnectionResult<Io, E> = Result<(SocketAddr, Io), E>;
Expand All @@ -35,8 +34,6 @@ pub(crate) struct SubDriver<T: ClientTransport> {
pub(super) to_socket: mpsc::Sender<PubMessage>,
/// A joinset of authentication tasks.
pub(super) connection_tasks: JoinSet<ConnectionResult<T::Io, T::Error>>,
/// Optional payload decompressor.
pub(super) decompressor: Option<Arc<dyn Decompressor>>,
/// The set of subscribed topics.
pub(super) subscribed_topics: HashSet<String>,
/// All active publisher sessions for this subscriber socket.
Expand All @@ -59,18 +56,19 @@ where
if let Poll::Ready(Some((addr, result))) = this.publishers.poll_next_unpin(cx) {
match result {
Ok(mut msg) => {
if let Some(ref compressor) = this.decompressor {
let Ok(decompressed) = compressor.decompress(&msg.payload) else {
match msg.try_decompress() {
None => { /* No decompression necessary */ }
Some(Ok(decompressed)) => msg.payload = decompressed,
Some(Err(e)) => {
error!(
topic = msg.topic.as_str(),
"Failed to decompress message payload"
"Failed to decompress message payload: {:?}", e
);

continue;
};

msg.payload = decompressed;
}
}

this.on_message(PubMessage::new(addr, msg.topic, msg.payload));
}
Err(e) => {
Expand Down Expand Up @@ -109,12 +107,6 @@ impl<T> SubDriver<T>
where
T: ClientTransport + Send + Sync + 'static,
{
/// Installs `decompressor` as the payload decompressor of this driver,
/// replacing any decompressor that was set before.
///
/// The decompressor is stored behind an `Arc` as a trait object.
pub fn set_decompressor<C: Decompressor>(&mut self, decompressor: C) {
    let shared: Arc<dyn Decompressor> = Arc::new(decompressor);
    self.decompressor = Some(shared);
}

fn on_command(&mut self, cmd: Command) {
debug!("Received command: {:?}", cmd);
match cmd {
Expand Down
13 changes: 0 additions & 13 deletions msg-socket/src/sub/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use tokio::{sync::mpsc, task::JoinSet};
use tokio_stream::StreamMap;

use msg_transport::ClientTransport;
use msg_wire::compression::Decompressor;

use super::{
Command, PubMessage, SocketState, SocketStats, SubDriver, SubError, SubOptions,
Expand Down Expand Up @@ -53,7 +52,6 @@ where
transport: Arc::new(transport),
from_socket,
to_socket,
decompressor: None,
connection_tasks: JoinSet::new(),
publishers: StreamMap::with_capacity(24),
subscribed_topics: HashSet::with_capacity(32),
Expand All @@ -70,17 +68,6 @@ where
}
}

/// Sets the payload decompressor for the socket. This decompressor will be used to decompress
/// all incoming messages from the publishers.
///
/// Consumes the socket and returns it, so the call can be chained right after construction.
///
/// # Panics
///
/// Panics if the driver is no longer owned by the socket (`self.driver` is `None`), since the
/// decompressor can then no longer be handed to it.
pub fn with_decompressor<C: Decompressor>(mut self, decompressor: C) -> Self {
    self.driver
        .as_mut()
        // The message names the decompressor: that is what this method configures.
        .expect("Driver has been spawned already, cannot set decompressor")
        .set_decompressor(decompressor);

    self
}

/// Asynchronously connects to the endpoint.
pub async fn connect(&mut self, endpoint: &str) -> Result<(), SubError> {
self.ensure_active_driver();
Expand Down
39 changes: 36 additions & 3 deletions msg-socket/src/sub/stream.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
use bytes::Bytes;
use futures::{SinkExt, Stream, StreamExt};
use std::{
io,
pin::Pin,
task::{ready, Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use tracing::debug;
use tracing::{debug, trace};

use super::SubError;
use msg_wire::pubsub;
use msg_wire::{
compression::{
CompressionType, Decompressor, GzipDecompressor, SnappyDecompressor, ZstdDecompressor,
},
pubsub,
};

/// Wraps a framed connection to a publisher and exposes all the PUBSUB specific methods.
pub(super) struct PublisherStream<Io> {
Expand Down Expand Up @@ -49,30 +55,57 @@ impl<Io: AsyncRead + AsyncWrite + Unpin> PublisherStream<Io> {

/// A message yielded by a `PublisherStream`, split into its individual parts.
pub(super) struct TopicMessage {
/// Timestamp value taken from the wire message (`msg.timestamp()`).
pub timestamp: u64,
/// Raw compression type byte taken from the wire message. It is interpreted through
/// `CompressionType::try_from` in `try_decompress`; values that fail the conversion
/// are reported there as an error.
pub compression_type: u8,
/// Topic of the message, converted from the wire bytes with lossy UTF-8 decoding.
pub topic: String,
/// Message payload exactly as received; call `try_decompress` to obtain the
/// decompressed bytes when `compression_type` indicates compression.
pub payload: Bytes,
}

impl TopicMessage {
    /// Decompresses the payload according to the message's compression type byte.
    ///
    /// The payload itself is left untouched; the decompressed bytes are returned.
    ///
    /// - `None`: the compression type is `CompressionType::None`, nothing to do.
    /// - `Some(Ok(bytes))`: the payload was compressed and has been decompressed.
    /// - `Some(Err(..))`: decompression failed, or the compression type byte is not
    ///   a supported `CompressionType` (reported as `io::ErrorKind::InvalidData`).
    pub fn try_decompress(&self) -> Option<Result<Bytes, io::Error>> {
        // Guard: reject compression type bytes we do not know how to handle.
        let kind = match CompressionType::try_from(self.compression_type) {
            Ok(kind) => kind,
            Err(unsupported_compression_type) => {
                return Some(Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("unsupported compression type: {unsupported_compression_type}"),
                )));
            }
        };

        // Decompressors are unit structs, so constructing one here costs nothing.
        let decompressed = match kind {
            CompressionType::None => return None,
            CompressionType::Gzip => GzipDecompressor.decompress(&self.payload),
            CompressionType::Zstd => ZstdDecompressor.decompress(&self.payload),
            CompressionType::Snappy => SnappyDecompressor.decompress(&self.payload),
        };

        Some(decompressed)
    }
}

impl<Io: AsyncRead + AsyncWrite + Unpin> Stream for PublisherStream<Io> {
type Item = Result<TopicMessage, pubsub::Error>;

#[inline]
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

// We set flush to false only when flush returns ready (i.e. the buffer is fully flushed)
if this.flush && this.conn.poll_flush_unpin(cx).is_ready() {
tracing::trace!("Flushed connection");
trace!("Flushed connection");
this.flush = false
}

if let Some(result) = ready!(this.conn.poll_next_unpin(cx)) {
return Poll::Ready(Some(result.map(|msg| {
let timestamp = msg.timestamp();
let compression_type = msg.compression_type();
let (topic, payload) = msg.into_parts();

// TODO: this will allocate. Can we just return the `Cow`?
let topic = String::from_utf8_lossy(&topic).to_string();
TopicMessage {
compression_type,
timestamp,
topic,
payload,
Expand Down
1 change: 1 addition & 0 deletions msg-wire/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ tokio-util.workspace = true
tracing.workspace = true
flate2.workspace = true
zstd.workspace = true
snap.workspace = true
12 changes: 5 additions & 7 deletions msg-wire/src/compression/gzip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use bytes::Bytes;
use flate2::{read::GzDecoder, write::GzEncoder, Compression};
use std::io::{self, Read, Write};

use super::{Compressor, Decompressor};
use super::{CompressionType, Compressor, Decompressor};

/// A compressor that uses the gzip algorithm.
pub struct GzipCompressor {
Expand All @@ -17,6 +17,10 @@ impl GzipCompressor {
}

impl Compressor for GzipCompressor {
fn compression_type(&self) -> CompressionType {
CompressionType::Gzip
}

fn compress(&self, data: &[u8]) -> Result<Bytes, io::Error> {
// Optimistically allocate the compressed buffer to 1/4 of the original size.
let mut encoder = GzEncoder::new(
Expand All @@ -35,12 +39,6 @@ impl Compressor for GzipCompressor {
#[derive(Debug, Default)]
pub struct GzipDecompressor;

impl GzipDecompressor {
pub fn new() -> Self {
Self
}
}

impl Decompressor for GzipDecompressor {
fn decompress(&self, data: &[u8]) -> Result<Bytes, io::Error> {
let mut decoder = GzDecoder::new(data);
Expand Down
Loading

0 comments on commit b282179

Please sign in to comment.