diff --git a/src/stream/read/mod.rs b/src/stream/read/mod.rs index 5aea774f..e42bfc49 100644 --- a/src/stream/read/mod.rs +++ b/src/stream/read/mod.rs @@ -82,47 +82,15 @@ fn consume(this: &mut R, mut bytes_count: usize) -> io::Result } } -/// Like Read::read_exact(), but seek back to the starting position of the reader in case of an -/// error. -#[cfg(feature = "experimental")] -fn read_exact_or_seek_back(this: &mut R, mut buf: &mut [u8]) -> io::Result<()> { - let mut bytes_read = 0; - while !buf.is_empty() { - match this.read(buf) { - Ok(0) => break, - Ok(n) => { - bytes_read += n as i64; - let tmp = buf; - buf = &mut tmp[n..]; - } - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(e) => { - if let Err(error) = this.seek(SeekFrom::Current(-bytes_read)) { - panic!("Error while seeking back to the start: {}", error); - } - return Err(e) - }, - } - } - if !buf.is_empty() { - if let Err(error) = this.seek(SeekFrom::Current(-bytes_read)) { - panic!("Error while seeking back to the start: {}", error); - } - Err(io::Error::new(io::ErrorKind::UnexpectedEof, "failed to fill whole buffer")) - } else { - Ok(()) - } -} - #[cfg(feature = "experimental")] impl<'a, R: Read + Seek> Decoder<'a, BufReader> { fn read_skippable_frame_size(&mut self) -> io::Result { let mut magic_buffer = [0u8; U32_SIZE]; - read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?; + self.reader.reader_mut().read_exact(&mut magic_buffer)?; // Read skippable frame size. let mut buffer = [0u8; U32_SIZE]; - read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?; + self.reader.reader_mut().read_exact(&mut buffer)?; let content_size = u32::from_le_bytes(buffer) as usize; self.seek_back(U32_SIZE * 2); @@ -139,47 +107,27 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader> { /// Attempt to read a skippable frame and write its content to `dest`. /// If it cannot read a skippable frame, the reader will be back to its starting position. pub fn read_skippable_frame(&mut self, dest: &mut [u8]) -> io::Result<(usize, MagicVariant)> { - let mut bytes_to_seek = 0; - - let res = (|| { - let mut magic_buffer = [0u8; U32_SIZE]; - read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?; - let magic_number = u32::from_le_bytes(magic_buffer); - - // Read skippable frame size. - let mut buffer = [0u8; U32_SIZE]; - read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?; - let content_size = u32::from_le_bytes(buffer) as usize; - - let op = self.reader.operation(); - // FIXME: I feel like we should do that check right after reading the magic number, but - // ZSTD does it after reading the content size. - if !op.is_skippable_frame(&magic_buffer) { - bytes_to_seek = U32_SIZE * 2; - return Err(io::Error::new(io::ErrorKind::Other, "Unsupported frame parameter")); - } - if content_size > dest.len() { - bytes_to_seek = U32_SIZE * 2; - return Err(io::Error::new(io::ErrorKind::Other, "Destination buffer is too small")); - } + let magic_buffer = self.reader.peek_4bytes()?; + let op = self.reader.operation(); + if !op.is_skippable_frame(&magic_buffer) { + return Err(io::Error::new(io::ErrorKind::Other, "Unsupported frame parameter")); + } + self.reader.clear_peeked_data(); - if content_size > 0 { - read_exact_or_seek_back(self.reader.reader_mut(), &mut dest[..content_size])?; - } + let magic_number = u32::from_le_bytes(magic_buffer); + + // Read skippable frame size. + let mut buffer = [0u8; U32_SIZE]; + self.reader.reader_mut().read_exact(&mut buffer)?; + let content_size = u32::from_le_bytes(buffer) as usize; - Ok((magic_number, content_size)) - })(); + if content_size > dest.len() { + return Err(io::Error::new(io::ErrorKind::Other, "Destination buffer is too small")); + } - let (magic_number, content_size) = - match res { - Ok(data) => data, - Err(err) => { - if bytes_to_seek != 0 { - self.seek_back(bytes_to_seek); - } - return Err(err); - }, - }; + if content_size > 0 { + self.reader.reader_mut().read_exact(&mut dest[..content_size])?; + } let magic_variant = magic_number - MAGIC_SKIPPABLE_START; @@ -202,7 +150,13 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader> { // TODO: should we support legacy format? let mut magic_buffer = [0u8; U32_SIZE]; - self.reader.reader_mut().read_exact(&mut magic_buffer)?; + if self.reader.peeking() { + magic_buffer = self.reader.peeked_data(); + self.reader.clear_peeked_data(); + } + else { + self.reader.reader_mut().read_exact(&mut magic_buffer)?; + } let magic_number = u32::from_le_bytes(magic_buffer); self.seek_back(U32_SIZE); if magic_number & MAGIC_SKIPPABLE_MASK == MAGIC_SKIPPABLE_START { @@ -240,7 +194,7 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader> { use crate::map_error_code; const MAX_FRAME_HEADER_SIZE_PREFIX: usize = 5; let mut buffer = [0u8; MAX_FRAME_HEADER_SIZE_PREFIX]; - read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?; + self.reader.reader_mut().read_exact(&mut buffer)?; let size = frame_header_size(&buffer) .map_err(map_error_code)?; let byte = buffer[MAX_FRAME_HEADER_SIZE_PREFIX - 1]; diff --git a/src/stream/zio/reader.rs b/src/stream/zio/reader.rs index 457f01bc..de80172f 100644 --- a/src/stream/zio/reader.rs +++ b/src/stream/zio/reader.rs @@ -17,6 +17,9 @@ pub struct Reader { single_frame: bool, finished_frame: bool, + + peeking: bool, + peeked_data: [u8; 4], } enum State { @@ -39,6 +42,8 @@ impl Reader { state: State::Reading, single_frame: false, finished_frame: false, + peeking: false, + peeked_data: [0; 4], } } @@ -81,7 +86,37 @@ impl Reader { { self.operation.flush(&mut OutBuffer::around(output)) } + + /// Read some data, but do not consume it. + pub fn peek_4bytes(&mut self) -> io::Result<[u8; 4]> + where + R: BufRead, + D: Operation, + { + if !self.peeking { + self.reader.read_exact(&mut self.peeked_data)?; + self.peeking = true; + } + + Ok(self.peeked_data) + } + + /// Clear the peeked data. + pub fn clear_peeked_data(&mut self) { + self.peeking = false; + } + + /// Check if there is currently any peeked data. + pub fn peeking(&self) -> bool { + self.peeking + } + + /// Get the peeked data. + pub fn peeked_data(&self) -> [u8; 4] { + self.peeked_data + } } + // Read and retry on Interrupted errors. fn fill_buf(reader: &mut R) -> io::Result<&[u8]> where @@ -118,12 +153,17 @@ where loop { match self.state { State::Reading => { + let is_peeking = self.peeking; + let (bytes_read, bytes_written) = { // Start with a fresh pool of un-processed data. // This is the only line that can return an interruption error. let input = if first { // eprintln!("First run, no input coming."); b"" + } else if self.peeking { + self.clear_peeked_data(); + &self.peeked_data } else { fill_buf(&mut self.reader)? }; @@ -170,7 +210,9 @@ where (src.pos(), dst.pos()) }; - self.reader.consume(bytes_read); + if !is_peeking { + self.reader.consume(bytes_read); + } if bytes_written > 0 { return Ok(bytes_written); diff --git a/src/stream/zio/writer.rs b/src/stream/zio/writer.rs index 6ff1f125..98758775 100644 --- a/src/stream/zio/writer.rs +++ b/src/stream/zio/writer.rs @@ -358,6 +358,8 @@ mod tests { }; let mut target = vec![]; + assert!(decoder.read_skippable_frame(&mut frame).is_err()); + assert!(decoder.read_skippable_frame(&mut frame).is_err()); io::copy(&mut decoder, &mut target).unwrap(); assert_eq!("compressed frame 1", String::from_utf8(target).unwrap()); @@ -371,6 +373,8 @@ mod tests { let (size, _) = decoder.read_skippable_frame(&mut frame).unwrap(); assert_eq!("SKIP", String::from_utf8_lossy(&frame[..size])); + assert!(decoder.read_skippable_frame(&mut frame).is_err()); + assert!(decoder.read_skippable_frame(&mut frame).is_err()); decoder.skip_frame().unwrap(); let inner = decoder.finish(); @@ -391,6 +395,8 @@ mod tests { }; let mut target = vec![]; + assert!(decoder.read_skippable_frame(&mut frame).is_err()); + assert!(decoder.read_skippable_frame(&mut frame).is_err()); io::copy(&mut decoder, &mut target).unwrap(); assert_eq!("compressed frame 3", String::from_utf8(target).unwrap()); }