Skip to content

Commit

Permalink
fix lots of bugs, remove bytecheck, upgraded examples
Browse files Browse the repository at this point in the history
  • Loading branch information
zyansheep committed Apr 26, 2023
1 parent f30ad68 commit 214768f
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 93 deletions.
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ repository = "https://github.com/zyansheep/rkyv_codec"

[dependencies]
rkyv = { version = "0.7", features = ["validation"] }
bytecheck = "0.6"
futures = "0.3"
pin-project = "1.0"
unsigned-varint = { version = "0.7", features = ["futures"] }
Expand All @@ -20,10 +19,13 @@ bytes-old = { package = "bytes", version = "0.5", optional = true }
thiserror = "1.0"

[dev-dependencies]
# Runtime
async-std = { version = "1.12.0", features = ["attributes"] }
# For tests & benchmark comparisons
serde = "1.0.152"
lazy_static = "1.4.0"
# Examples
# For examples
async-broadcast = "0.5.1"
rustyline-async = "0.3.0"
anyhow = "1.0.68"

Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ To run:
Simple usage example:
```rust
#[derive(Archive, Deserialize, Serialize, Debug, PartialEq, Clone)]
#[archive_attr(derive(CheckBytes, Debug))] // Checkbytes is required
#[archive(check_bytes)] // check_bytes is required
#[archive_attr(derive(Debug))]
struct Test {
int: u8,
string: String,
Expand Down
7 changes: 3 additions & 4 deletions examples/chat_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::{fmt, io::Write};
use async_std::{io, net::TcpStream};
use futures::{FutureExt, SinkExt};

use bytecheck::CheckBytes;
use rkyv::{AlignedVec, Archive, Deserialize, Infallible, Serialize};

use rkyv_codec::{archive_stream, RkyvWriter, VarintLength};
Expand All @@ -12,9 +11,9 @@ use rustyline_async::{Readline, ReadlineError};

#[derive(Archive, Deserialize, Serialize, Debug, PartialEq, Clone)]
// This will generate a PartialEq impl between our unarchived and archived types
#[archive(compare(PartialEq))]
// To use the safe API, you have to derive CheckBytes for the archived type
#[archive_attr(derive(CheckBytes, Debug))]
// To use the safe API, you must use the check_bytes option for the archive
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
struct ChatMessage {
sender: Option<String>,
message: String,
Expand Down
78 changes: 34 additions & 44 deletions examples/chat_server.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
use std::{net::SocketAddr, sync::Arc};
use std::net::SocketAddr;

use anyhow::Context;
use async_broadcast::{Receiver, Sender};
use async_std::{
channel::{bounded, Receiver, Sender, TrySendError},
io,
net::{TcpListener, TcpStream},
sync::Mutex,
task,
};
use futures::{prelude::*, SinkExt, StreamExt};

use bytecheck::CheckBytes;
use rkyv::{AlignedVec, Archive, Deserialize, Infallible, Serialize};

use anyhow::Context;

use rkyv_codec::{archive_stream, RkyvWriter, VarintLength};

#[derive(Archive, Deserialize, Serialize, Debug, PartialEq, Clone)]
// This will generate a PartialEq impl between our unarchived and archived types
#[archive(compare(PartialEq))]
// To use the safe API, you have to derive CheckBytes for the archived type
#[archive_attr(derive(CheckBytes, Debug))]
// To use the safe API, you have to enable the check_bytes option for the archive
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
struct ChatMessage {
sender: Option<String>,
message: String,
}

// Process a given TcpStream
async fn process(
stream: TcpStream,
outgoing: Sender<ChatMessage>,
Expand All @@ -34,26 +32,26 @@ async fn process(
) -> anyhow::Result<()> {
println!("[{addr}] Joined Server");
outgoing
.send(ChatMessage {
.broadcast(ChatMessage {
sender: Some("Sever".to_owned()),
message: format!("{} Joined the Chat!", addr),
})
.await?;

let mut reader = stream.clone();

let mut writer = RkyvWriter::<_, VarintLength>::new(stream);

let mut buffer = AlignedVec::new();

loop {
futures::select! {
// Read incoming messages
archive = archive_stream::<_, ChatMessage, VarintLength>(&mut reader, &mut buffer).fuse() => match archive {
Ok(archive) => {
let mut msg: ChatMessage = archive.deserialize(&mut Infallible)?;
msg.sender = Some(format!("{addr}"));
println!("[{addr}] sent {msg:?}");
outgoing.send(msg).await?;
outgoing.broadcast(msg).await?;
}
_ => break,
},
Expand All @@ -63,6 +61,14 @@ async fn process(
}
}

println!("[{addr}] Left Server");
outgoing
.broadcast(ChatMessage {
sender: Some("Sever".to_owned()),
message: format!("{} Left the Chat!", addr),
})
.await?;

Ok(())
}

Expand All @@ -73,42 +79,26 @@ async fn main() -> io::Result<()> {

let mut incoming = listener.incoming();

let (broadcast_sender, message_receiver) = bounded::<ChatMessage>(20);

// Broadcast incoming messages to everyone else connected to the server.
let outgoing_send_list = Arc::new(Mutex::new(Vec::<Sender<ChatMessage>>::new()));
let sender_list = outgoing_send_list.clone();
task::spawn(async move {
while let Ok(msg) = message_receiver.recv().await {
outgoing_send_list.lock().await.retain(|sender| {
if let Err(TrySendError::Closed(_)) = sender.try_send(msg.clone()) {
false
} else {
true
}
})
}
});
// Broadcast channels
let (broadcast_sender, broadcast_receiver) = async_broadcast::broadcast::<ChatMessage>(20);

// Listen for incoming connections
while let Some(stream) = incoming.next().await {
let stream = match stream {
Ok(stream) => stream,
Err(err) => {
println!("error: {err}");
continue;
match stream {
Ok(stream) => {
let outgoing = broadcast_sender.clone();
let incoming = broadcast_receiver.clone();

task::spawn(async move {
let addr = stream.peer_addr().unwrap();
if let Err(err) = process(stream, outgoing, incoming, &addr).await {
println!("[{addr}] error: {err}")
}
});
}
Err(err) => println!("error: {err}"),
};
let outgoing = broadcast_sender.clone();

let (sender, incoming) = bounded(20);

sender_list.lock().await.push(sender);
task::spawn(async move {
let addr = stream.peer_addr().unwrap();
if let Err(err) = process(stream, outgoing, incoming, &addr).await {
println!("[{addr}] error: {err}")
}
});
}

Ok(())
}
3 changes: 1 addition & 2 deletions src/futures_stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::marker::PhantomData;

use bytecheck::CheckBytes;
use bytes_old::Buf;
use futures_codec::{BytesMut, Decoder, Encoder};
use rkyv::{
Expand Down Expand Up @@ -70,7 +69,7 @@ where
impl<Packet, L: LengthCodec> Decoder for RkyvCodec<Packet, L>
where
Packet: Archive + 'static,
Packet::Archived: for<'b> CheckBytes<rkyv::validation::validators::DefaultValidator<'b>>
Packet::Archived: for<'b> rkyv::CheckBytes<rkyv::validation::validators::DefaultValidator<'b>>
+ Deserialize<Packet, Infallible>,
{
type Item = Packet;
Expand Down
14 changes: 6 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
//! ```rust
//! # use rkyv::{Infallible, Archived, AlignedVec, Archive, Serialize, Deserialize};
//! # use rkyv_codec::{archive_stream, RkyvWriter, VarintLength};
//! # use bytecheck::CheckBytes;
//! # use futures::SinkExt;
//! # async_std::task::block_on(async {
//! #[derive(Archive, Deserialize, Serialize, Debug, PartialEq, Clone)]
//! #[archive_attr(derive(CheckBytes, Debug))] // Checkbytes is required
//! #[archive(check_bytes)] // check_bytes is required
//! #[archive_attr(derive(Debug))]
//! struct Test {
//! int: u8,
//! string: String,
Expand Down Expand Up @@ -129,8 +129,7 @@ mod no_std_feature {
) -> Result<&'b Archived<Packet>, RkyvCodecError>
where
Packet: rkyv::Archive,
Packet::Archived:
bytecheck::CheckBytes<rkyv::validation::validators::DefaultValidator<'b>> + 'b,
Packet::Archived: rkyv::CheckBytes<rkyv::validation::validators::DefaultValidator<'b>> + 'b,
{
// Read length
let mut length_buf = L::Buffer::default();
Expand Down Expand Up @@ -161,7 +160,6 @@ mod tests {
extern crate test;

use async_std::task::block_on;
use bytecheck::CheckBytes;
use bytes::BytesMut;
use futures::{io::Cursor, AsyncRead, AsyncWrite, SinkExt, StreamExt, TryStreamExt};
use futures_codec::{CborCodec, Framed};
Expand All @@ -185,9 +183,9 @@ mod tests {
serde::Deserialize,
)]
// This will generate a PartialEq impl between our unarchived and archived types
#[archive(compare(PartialEq))]
// To use the safe API, you have to derive CheckBytes for the archived type
#[archive_attr(derive(CheckBytes, Debug))]
// To use the safe API, you have to use the check_bytes option for the archive
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
struct Test {
int: u8,
string: String,
Expand Down
68 changes: 36 additions & 32 deletions src/rkyv_codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::{

use pin_project::pin_project;

use bytecheck::CheckBytes;
use rkyv::{
ser::{
serializers::{
Expand Down Expand Up @@ -36,7 +35,13 @@ pub async fn archive_sink<'b, Inner: AsyncWrite + Unpin, L: LengthCodec>(
}
/// Reads a single `&Archived<Object>` from an `AsyncRead` without checking for correct byte formatting
/// # Safety
/// This will cause undefined behavior if the bytestream is not the correct format (i.e. not generated through `archive_sink[_bytes]`, `RkyvWriter`, or `RkyvCodec`) with the correct LengthCodec
/// This may cause undefined behavior if the bytestream is not a valid archive (i.e. not generated through `archive_sink[_bytes]`, or `RkyvWriter`)
///
/// As an optimisation, this function may pass uninitialized bytes to the reader for the reader to read into. Make sure the particular reader in question is implemented correctly and does not read from its passed buffer in the poll_read() function without first writing to it.
/// # Warning
/// Passed buffer is reallocated so it may fit the size of the packet being written. This may allow for DOS attacks if remote sends too large a length encoding
/// # Errors
/// Will return an error if there are not enough bytes to read to read the length of the packet, or read the packet itself. Will also return an error if the length encoding format is invalid.
pub async unsafe fn archive_stream_unsafe<
'b,
Inner: AsyncRead + Unpin,
Expand All @@ -52,57 +57,57 @@ pub async unsafe fn archive_stream_unsafe<
let mut length_buf = L::Buffer::default();
let length_buf = L::as_slice(&mut length_buf);
inner.read_exact(&mut *length_buf).await?;
let (archive_len, remaining) =
let (archive_len, unused) =
L::decode(length_buf).map_err(|_| RkyvCodecError::ReadLengthError)?;

// Reserve buffer
// Reserve enough bytes in buffer to contain
buffer.reserve(archive_len - buffer.len()); // Reserve at least the amount of bytes needed
// Safety: Already reserved the required space
unsafe {
buffer.set_len(archive_len);

// If not enough capacity in buffer to fit `archive_len`, reserve more.
if buffer.capacity() < archive_len {
buffer.reserve(buffer.capacity() - archive_len)
}
// Write any potentially unused bytes from length_buf to buffer
buffer.extend_from_slice(unused);

// Safety: Caller should make sure that reader does not read from this potentially uninitialized buffer passed to poll_read()
unsafe { buffer.set_len(archive_len) }

buffer[0..remaining.len()].copy_from_slice(remaining); // Copy unread length_buf bytes to buffer
// Read into buffer, after any unused length bytes
inner.read_exact(&mut buffer[unused.len()..]).await?;

inner.read_exact(&mut buffer[remaining.len()..]).await?;
// Safety: Caller should make sure that reader does not produce invalid packets.
unsafe { Ok(rkyv::archived_root::<Packet>(buffer)) }
}

/// Reads a single `&Archived<Object>` from an `AsyncRead` using the passed buffer.
///
/// Until streaming iterators (and streaming futures) are implemented in rust, this currently the fastest method I could come up with that requires no recurring heap allocations.
///
/// Requires rkyv validation feature & CheckBytes
/// Requires rkyv "validation" feature
/// # Safety
/// As an optimisation, this function may pass uninitialized bytes to the reader for the reader to read into. Make sure the particular reader in question is implemented correctly and does not read from its passed buffer in the poll_read() function without first writing to it.
/// # Warning
/// Passed buffer is reallocated so it may fit the size of the packet being written. This may allow for DOS attacks if remote sends too large a length encoding
/// # Errors
/// Will return an error if there are not enough bytes to read to read the length of the packet, or read the packet itself. Will also return an error if the length encoding format is invalid or the packet archive itself is invalid.
pub async fn archive_stream<'b, Inner: AsyncRead + Unpin, Packet, L: LengthCodec>(
inner: &mut Inner,
buffer: &'b mut AlignedVec,
) -> Result<&'b Archived<Packet>, RkyvCodecError>
where
Packet: rkyv::Archive,
Packet::Archived: CheckBytes<DefaultValidator<'b>> + 'b,
Packet::Archived: rkyv::CheckBytes<DefaultValidator<'b>>,
{
buffer.clear();

// Read length
let mut length_buf = L::Buffer::default();
let length_buf = L::as_slice(&mut length_buf);
inner.read_exact(length_buf).await?;
let (archive_len, remaining) =
L::decode(length_buf).map_err(|_| RkyvCodecError::ReadLengthError)?;

// Reserve buffer
buffer.reserve(archive_len - buffer.len()); // Reserve at least the amount of bytes needed for packet
// Safety: Already reserved the required space
// Safety: This should not trigger undefined behavior as even if the packet in question is an invalid archive, the archive is not actually read from.
// Safety: Even though this is not an unsafe function, it may under some circumstances produce undefined behavior if the AsyncRead implementation is bad. Make sure the `<Inner as AsyncRead>::poll_read()` implementation does not read from the passed `buf` before writing to it first.
unsafe {
buffer.set_len(archive_len);
let _ = archive_stream_unsafe::<Inner, Packet, L>(inner, buffer).await?;
}

// Read into aligned buffer
buffer[0..remaining.len()].copy_from_slice(remaining); // Copy unread length_buf bytes to buffer
inner.read_exact(&mut buffer[remaining.len()..]).await?; // Read into buffer after the appended bytes

let archive = rkyv::check_archived_root::<'b, Packet>(buffer)
.map_err(|_| RkyvCodecError::CheckArchiveError)?;

Ok(archive)
}

Expand Down Expand Up @@ -225,7 +230,6 @@ mod tests {
extern crate test;

use async_std::task::block_on;
use bytecheck::CheckBytes;
use futures::{io::Cursor, AsyncRead, AsyncWrite, SinkExt, StreamExt, TryStreamExt};
use futures_codec::{CborCodec, Framed};
use rkyv::{to_bytes, AlignedVec, Archive, Archived, Deserialize, Infallible, Serialize};
Expand All @@ -248,9 +252,9 @@ mod tests {
serde::Deserialize,
)]
// This will generate a PartialEq impl between our unarchived and archived types
#[archive(compare(PartialEq))]
// To use the safe API, you have to derive CheckBytes for the archived type
#[archive_attr(derive(CheckBytes, Debug))]
// To use the safe API, you have to use the check_byte option for the archived type
#[archive(compare(PartialEq), check_bytes)]
#[archive_attr(derive(Debug))]
struct Test {
int: u8,
string: String,
Expand Down

0 comments on commit 214768f

Please sign in to comment.