Skip to content

Commit

Permalink
Do not seek back on error
Browse files Browse the repository at this point in the history
  • Loading branch information
antoyo committed Feb 6, 2024
1 parent e3e0431 commit f75bcf8
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 75 deletions.
102 changes: 28 additions & 74 deletions src/stream/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,47 +82,15 @@ fn consume<R: Read + ?Sized>(this: &mut R, mut bytes_count: usize) -> io::Result
}
}

/// Like Read::read_exact(), but seek back to the starting position of the reader in case of an
/// error.
#[cfg(feature = "experimental")]
fn read_exact_or_seek_back<R: Read + Seek + ?Sized>(this: &mut R, mut buf: &mut [u8]) -> io::Result<()> {
let mut bytes_read = 0;
while !buf.is_empty() {
match this.read(buf) {
Ok(0) => break,
Ok(n) => {
bytes_read += n as i64;
let tmp = buf;
buf = &mut tmp[n..];
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(e) => {
if let Err(error) = this.seek(SeekFrom::Current(-bytes_read)) {
panic!("Error while seeking back to the start: {}", error);
}
return Err(e)
},
}
}
if !buf.is_empty() {
if let Err(error) = this.seek(SeekFrom::Current(-bytes_read)) {
panic!("Error while seeking back to the start: {}", error);
}
Err(io::Error::new(io::ErrorKind::UnexpectedEof, "failed to fill whole buffer"))
} else {
Ok(())
}
}

#[cfg(feature = "experimental")]
impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
fn read_skippable_frame_size(&mut self) -> io::Result<usize> {
let mut magic_buffer = [0u8; U32_SIZE];
read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?;
self.reader.reader_mut().read_exact(&mut magic_buffer)?;

// Read skippable frame size.
let mut buffer = [0u8; U32_SIZE];
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
self.reader.reader_mut().read_exact(&mut buffer)?;
let content_size = u32::from_le_bytes(buffer) as usize;

self.seek_back(U32_SIZE * 2);
Expand All @@ -139,47 +107,27 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
/// Attempt to read a skippable frame and write its content to `dest`.
/// If it cannot read a skippable frame, the reader will be back to its starting position.
pub fn read_skippable_frame(&mut self, dest: &mut [u8]) -> io::Result<(usize, MagicVariant)> {
let mut bytes_to_seek = 0;

let res = (|| {
let mut magic_buffer = [0u8; U32_SIZE];
read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?;
let magic_number = u32::from_le_bytes(magic_buffer);

// Read skippable frame size.
let mut buffer = [0u8; U32_SIZE];
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
let content_size = u32::from_le_bytes(buffer) as usize;

let op = self.reader.operation();
// FIXME: I feel like we should do that check right after reading the magic number, but
// ZSTD does it after reading the content size.
if !op.is_skippable_frame(&magic_buffer) {
bytes_to_seek = U32_SIZE * 2;
return Err(io::Error::new(io::ErrorKind::Other, "Unsupported frame parameter"));
}
if content_size > dest.len() {
bytes_to_seek = U32_SIZE * 2;
return Err(io::Error::new(io::ErrorKind::Other, "Destination buffer is too small"));
}
let magic_buffer = self.reader.peek_4bytes()?;
let op = self.reader.operation();
if !op.is_skippable_frame(&magic_buffer) {
return Err(io::Error::new(io::ErrorKind::Other, "Unsupported frame parameter"));
}
self.reader.clear_peeked_data();

if content_size > 0 {
read_exact_or_seek_back(self.reader.reader_mut(), &mut dest[..content_size])?;
}
let magic_number = u32::from_le_bytes(magic_buffer);

// Read skippable frame size.
let mut buffer = [0u8; U32_SIZE];
self.reader.reader_mut().read_exact(&mut buffer)?;
let content_size = u32::from_le_bytes(buffer) as usize;

Ok((magic_number, content_size))
})();
if content_size > dest.len() {
return Err(io::Error::new(io::ErrorKind::Other, "Destination buffer is too small"));
}

let (magic_number, content_size) =
match res {
Ok(data) => data,
Err(err) => {
if bytes_to_seek != 0 {
self.seek_back(bytes_to_seek);
}
return Err(err);
},
};
if content_size > 0 {
self.reader.reader_mut().read_exact(&mut dest[..content_size])?;
}

let magic_variant = magic_number - MAGIC_SKIPPABLE_START;

Expand All @@ -202,7 +150,13 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {

// TODO: should we support legacy format?
let mut magic_buffer = [0u8; U32_SIZE];
self.reader.reader_mut().read_exact(&mut magic_buffer)?;
if self.reader.peeking() {
magic_buffer = self.reader.peeked_data();
self.reader.clear_peeked_data();
}
else {
self.reader.reader_mut().read_exact(&mut magic_buffer)?;
}
let magic_number = u32::from_le_bytes(magic_buffer);
self.seek_back(U32_SIZE);
if magic_number & MAGIC_SKIPPABLE_MASK == MAGIC_SKIPPABLE_START {
Expand Down Expand Up @@ -240,7 +194,7 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
use crate::map_error_code;
const MAX_FRAME_HEADER_SIZE_PREFIX: usize = 5;
let mut buffer = [0u8; MAX_FRAME_HEADER_SIZE_PREFIX];
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
self.reader.reader_mut().read_exact(&mut buffer)?;
let size = frame_header_size(&buffer)
.map_err(map_error_code)?;
let byte = buffer[MAX_FRAME_HEADER_SIZE_PREFIX - 1];
Expand Down
44 changes: 43 additions & 1 deletion src/stream/zio/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ pub struct Reader<R, D> {

single_frame: bool,
finished_frame: bool,

peeking: bool,
peeked_data: [u8; 4],
}

enum State {
Expand All @@ -39,6 +42,8 @@ impl<R, D> Reader<R, D> {
state: State::Reading,
single_frame: false,
finished_frame: false,
peeking: false,
peeked_data: [0; 4],
}
}

Expand Down Expand Up @@ -81,7 +86,37 @@ impl<R, D> Reader<R, D> {
{
self.operation.flush(&mut OutBuffer::around(output))
}

/// Read some data, but do not consume it.
pub fn peek_4bytes(&mut self) -> io::Result<[u8; 4]>
where
R: BufRead,
D: Operation,
{
if !self.peeking {
self.reader.read_exact(&mut self.peeked_data)?;
self.peeking = true;
}

Ok(self.peeked_data)
}

/// Clear the peeked data.
pub fn clear_peeked_data(&mut self) {
self.peeking = false;
}

/// Check if there is currently any peeked data.
pub fn peeking(&self) -> bool {
self.peeking
}

/// Get the peeked data.
pub fn peeked_data(&self) -> [u8; 4] {
self.peeked_data
}
}

// Read and retry on Interrupted errors.
fn fill_buf<R>(reader: &mut R) -> io::Result<&[u8]>
where
Expand Down Expand Up @@ -118,12 +153,17 @@ where
loop {
match self.state {
State::Reading => {
let is_peeking = self.peeking;

let (bytes_read, bytes_written) = {
// Start with a fresh pool of un-processed data.
// This is the only line that can return an interruption error.
let input = if first {
// eprintln!("First run, no input coming.");
b""
} else if self.peeking {
self.clear_peeked_data();
&self.peeked_data
} else {
fill_buf(&mut self.reader)?
};
Expand Down Expand Up @@ -170,7 +210,9 @@ where
(src.pos(), dst.pos())
};

self.reader.consume(bytes_read);
if !is_peeking {
self.reader.consume(bytes_read);
}

if bytes_written > 0 {
return Ok(bytes_written);
Expand Down
6 changes: 6 additions & 0 deletions src/stream/zio/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ mod tests {
};

let mut target = vec![];
assert!(decoder.read_skippable_frame(&mut frame).is_err());
assert!(decoder.read_skippable_frame(&mut frame).is_err());
io::copy(&mut decoder, &mut target).unwrap();
assert_eq!("compressed frame 1", String::from_utf8(target).unwrap());

Expand All @@ -371,6 +373,8 @@ mod tests {
let (size, _) = decoder.read_skippable_frame(&mut frame).unwrap();
assert_eq!("SKIP", String::from_utf8_lossy(&frame[..size]));

assert!(decoder.read_skippable_frame(&mut frame).is_err());
assert!(decoder.read_skippable_frame(&mut frame).is_err());
decoder.skip_frame().unwrap();

let inner = decoder.finish();
Expand All @@ -391,6 +395,8 @@ mod tests {
};

let mut target = vec![];
assert!(decoder.read_skippable_frame(&mut frame).is_err());
assert!(decoder.read_skippable_frame(&mut frame).is_err());
io::copy(&mut decoder, &mut target).unwrap();
assert_eq!("compressed frame 3", String::from_utf8(target).unwrap());
}
Expand Down

0 comments on commit f75bcf8

Please sign in to comment.