fix: Fallback to fully reading the package stream when downloading before attempting decompression #797

Merged · 19 commits · Aug 1, 2024
2 changes: 2 additions & 0 deletions crates/rattler_package_streaming/Cargo.toml
@@ -26,6 +26,7 @@ tempfile = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["fs"] }
tokio-util = { workspace = true, features = ["io-util"] }
tracing = { workspace = true }
url = { workspace = true }
zip = { workspace = true, features = ["deflate", "time"] }
zstd = { workspace = true, features = ["zstdmt"] }
@@ -44,3 +45,4 @@ tools = { path = "../tools" }
walkdir = { workspace = true }
rstest = { workspace = true }
rstest_reuse = { workspace = true }
insta = { workspace = true, features = ["yaml"] }
2 changes: 1 addition & 1 deletion crates/rattler_package_streaming/src/fs.rs
@@ -32,7 +32,7 @@ pub fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<ExtractResu
/// ```
pub fn extract_conda(archive: &Path, destination: &Path) -> Result<ExtractResult, ExtractError> {
let file = File::open(archive)?;
crate::read::extract_conda(file, destination)
crate::read::extract_conda_via_streaming(file, destination)
}

/// Extracts the contents of a package archive at the specified path to a directory. The type of
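For reference, a minimal sketch of how this filesystem entry point is typically called; the archive path and destination below are placeholders, not taken from this PR:

```rust
use std::path::Path;
use rattler_package_streaming::fs::extract_conda;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Opens the local .conda archive and delegates to
    // read::extract_conda_via_streaming under the hood.
    let result = extract_conda(
        Path::new("my-package-1.0-0.conda"), // hypothetical archive
        Path::new("./extracted"),
    )?;
    println!("sha256: {:x}  md5: {:x}", result.sha256, result.md5);
    Ok(())
}
```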
90 changes: 69 additions & 21 deletions crates/rattler_package_streaming/src/read.rs
@@ -2,9 +2,12 @@
//! [`std::io::Read`] trait.

use super::{ExtractError, ExtractResult};
use rattler_digest::HashingReader;
use std::io::{copy, Seek, SeekFrom};
use std::mem::ManuallyDrop;
use std::{ffi::OsStr, io::Read, path::Path};
use zip::read::read_zipfile_from_stream;
use tempfile::SpooledTempFile;
use zip::read::{read_zipfile_from_stream, ZipArchive, ZipFile};

/// Returns the `.tar.bz2` as a decompressed `tar::Archive`. The `tar::Archive` can be used to
/// extract the files from it, or perform introspection.
@@ -44,7 +47,10 @@ pub fn extract_tar_bz2(
}

/// Extracts the contents of a `.conda` package archive.
pub fn extract_conda(reader: impl Read, destination: &Path) -> Result<ExtractResult, ExtractError> {
pub fn extract_conda_via_streaming(
reader: impl Read,
destination: &Path,
) -> Result<ExtractResult, ExtractError> {
// Construct the destination path if it doesn't exist yet
std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;

@@ -56,27 +62,69 @@ pub fn extract_conda(reader: impl Read, destination: &Path) -> Result<ExtractRes

// Iterate over all entries in the zip-file and extract them one-by-one
while let Some(file) = read_zipfile_from_stream(&mut md5_reader)? {
// If an error occurs while we are reading the contents of the zip we don't want to
// seek to the end of the file. Using [`ManuallyDrop`] we prevent `drop` from being called on
// the `file` in case the stack unwinds.
let mut file = ManuallyDrop::new(file);

if file
.mangled_name()
.file_name()
.map(OsStr::to_string_lossy)
.map_or(false, |file_name| file_name.ends_with(".tar.zst"))
{
stream_tar_zst(&mut *file)?.unpack(destination)?;
} else {
// Manually read to the end of the stream if that didn't happen.
std::io::copy(&mut *file, &mut std::io::sink())?;
}

// Take the file out of the [`ManuallyDrop`] to properly drop it.
let _ = ManuallyDrop::into_inner(file);
extract_zipfile(file, destination)?;
}
compute_hashes(md5_reader)
}

/// Extracts the contents of a .conda package archive by fully reading the stream and then decompressing
pub fn extract_conda_via_buffering(
reader: impl Read,
destination: &Path,
) -> Result<ExtractResult, ExtractError> {
// Delete the destination first, as this method is usually used as a fallback after failed streaming decompression
if destination.exists() {
std::fs::remove_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;
}
std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;

// Create a SpooledTempFile with a 5 MiB in-memory limit; larger payloads spill to a temporary file on disk
let mut temp_file = SpooledTempFile::new(5 * 1024 * 1024);
let sha256_reader = rattler_digest::HashingReader::<_, rattler_digest::Sha256>::new(reader);
let mut md5_reader =
rattler_digest::HashingReader::<_, rattler_digest::Md5>::new(sha256_reader);

copy(&mut md5_reader, &mut temp_file)?;
temp_file.seek(SeekFrom::Start(0))?;
let mut archive = ZipArchive::new(temp_file)?;

for i in 0..archive.len() {
let file = archive.by_index(i)?;
extract_zipfile(file, destination)?;
}
// Read the file to the end to make sure the hash is properly computed.
std::io::copy(&mut md5_reader, &mut std::io::sink())?;

compute_hashes(md5_reader)
}

fn extract_zipfile(zip_file: ZipFile<'_>, destination: &Path) -> Result<(), ExtractError> {
// If an error occurs while we are reading the contents of the zip we don't want to
// seek to the end of the file. Using [`ManuallyDrop`] we prevent `drop` from being called on
// the `file` in case the stack unwinds.
let mut file = ManuallyDrop::new(zip_file);

if file
.mangled_name()
.file_name()
.map(OsStr::to_string_lossy)
.map_or(false, |file_name| file_name.ends_with(".tar.zst"))
{
stream_tar_zst(&mut *file)?.unpack(destination)?;
} else {
// Manually read to the end of the stream if that didn't happen.
std::io::copy(&mut *file, &mut std::io::sink())?;
}

// Take the file out of the [`ManuallyDrop`] to properly drop it.
let _ = ManuallyDrop::into_inner(file);

Ok(())
}

fn compute_hashes<R: Read>(
mut md5_reader: HashingReader<HashingReader<R, rattler_digest::Sha256>, rattler_digest::Md5>,
) -> Result<ExtractResult, ExtractError> {
// Read the file to the end to make sure the hash is properly computed.
std::io::copy(&mut md5_reader, &mut std::io::sink())?;

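To make the two read-side entry points concrete, here is a hedged sketch of how a caller could combine them: try cheap streaming first, and fall back to the buffered path (which spools the stream into a `SpooledTempFile`, spilling to disk above 5 MiB) when the archive uses zip data descriptors. The `extract_with_fallback` helper and its error-message check are illustrative only; they mirror the constant defined in `reqwest/tokio.rs` below but are not crate API:

```rust
use std::fs::File;
use std::path::Path;

use rattler_package_streaming::read::{extract_conda_via_buffering, extract_conda_via_streaming};
use rattler_package_streaming::{ExtractError, ExtractResult};
use zip::result::ZipError;

// Illustrative helper (not part of this PR): stream first, buffer on
// data-descriptor failures.
fn extract_with_fallback(archive: &Path, dest: &Path) -> Result<ExtractResult, ExtractError> {
    match extract_conda_via_streaming(File::open(archive)?, dest) {
        Err(ExtractError::ZipError(ZipError::UnsupportedArchive(msg)))
            if msg.contains("The file length is not available in the local header") =>
        {
            // Re-open the file for a second pass; the buffered path clears
            // the partially extracted destination itself.
            extract_conda_via_buffering(File::open(archive)?, dest)
        }
        other => other,
    }
}
```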
38 changes: 33 additions & 5 deletions crates/rattler_package_streaming/src/reqwest/tokio.rs
@@ -11,7 +11,16 @@ use std::sync::Arc;
use tokio::io::BufReader;
use tokio_util::either::Either;
use tokio_util::io::StreamReader;
use tracing;
use url::Url;
use zip::result::ZipError;

/// Zip files may use data descriptors to signal that the decompressor needs to seek ahead in the buffer
/// to find the compressed data length.
/// Since we stream the package over a non-seekable HTTP connection, this condition causes an error during
/// decompression. In that case, we fall back to reading the whole stream into a buffer before attempting decompression.
/// Read more at <https://github.com/conda-incubator/rattler/issues/794>
const DATA_DESCRIPTOR_ERROR_MESSAGE: &str = "The file length is not available in the local header";

fn error_for_status(response: reqwest::Response) -> reqwest_middleware::Result<Response> {
response
@@ -131,12 +140,31 @@ pub async fn extract_conda(
reporter: Option<Arc<dyn DownloadReporter>>,
) -> Result<ExtractResult, ExtractError> {
// The `response` is used to stream in the package data
let reader = get_reader(url.clone(), client, expected_sha256, reporter.clone()).await?;
let result = crate::tokio::async_read::extract_conda(reader, destination).await?;
if let Some(reporter) = &reporter {
reporter.on_download_complete();
let reader = get_reader(
url.clone(),
client.clone(),
expected_sha256,
reporter.clone(),
)
.await?;
match crate::tokio::async_read::extract_conda(reader, destination).await {
Ok(result) => {
if let Some(reporter) = &reporter {
reporter.on_download_complete();
}
Ok(result)
}
// https://github.com/conda-incubator/rattler/issues/794
Err(ExtractError::ZipError(ZipError::UnsupportedArchive(zip_error)))
if (zip_error.contains(DATA_DESCRIPTOR_ERROR_MESSAGE)) =>
{
tracing::warn!("Failed to stream decompress conda package from '{}' due to the presence of zip data descriptors. Falling back to non-streaming decompression", url);
let new_reader =
get_reader(url.clone(), client, expected_sha256, reporter.clone()).await?;
crate::tokio::async_read::extract_conda_via_buffering(new_reader, destination).await
}
Err(e) => Err(e),
}
Ok(result)
}

/// Extracts the contents of a package archive from the specified remote location. The type of package
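For background on the root cause: a zip entry whose local file header has bit 3 of the general-purpose flags set stores its real sizes in a data descriptor after the compressed data, so a reader on a non-seekable stream cannot tell where the entry ends. A small illustrative sketch (not code from this PR) of detecting that flag in a raw local file header:

```rust
// Sketch: bit 3 of the general-purpose flags in a zip local file header
// means "sizes follow the data in a trailing data descriptor", which is
// what defeats streaming decompression over HTTP.
fn uses_data_descriptor(local_header: &[u8]) -> Option<bool> {
    // Layout: signature (4 bytes), version (2 bytes), flags (2 bytes, LE), ...
    if local_header.len() < 8 || &local_header[..4] != b"PK\x03\x04" {
        return None; // not a local file header
    }
    let flags = u16::from_le_bytes([local_header[6], local_header[7]]);
    Some(flags & (1 << 3) != 0)
}
```

Note the design choice in the match arm above: because the failed streaming attempt has already consumed the HTTP body, the fallback requests a fresh reader (a second download) before retrying via the buffered path.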
36 changes: 34 additions & 2 deletions crates/rattler_package_streaming/src/tokio/async_read.rs
@@ -2,6 +2,7 @@
//! [`tokio::io::AsyncRead`] trait.

use crate::{ExtractError, ExtractResult};
use std::io::Read;
use std::path::Path;
use tokio::io::AsyncRead;
use tokio_util::io::SyncIoBridge;
@@ -30,17 +31,48 @@ pub async fn extract_tar_bz2(
}

/// Extracts the contents of a `.conda` package archive.
/// This will perform on-the-fly decompression by streaming the reader.
pub async fn extract_conda(
reader: impl AsyncRead + Send + 'static,
destination: &Path,
) -> Result<ExtractResult, ExtractError> {
extract_conda_internal(
reader,
destination,
crate::read::extract_conda_via_streaming,
)
.await
}

/// Extracts the contents of a .conda package archive by fully reading the stream and then decompressing
pub async fn extract_conda_via_buffering(
reader: impl AsyncRead + Send + 'static,
destination: &Path,
) -> Result<ExtractResult, ExtractError> {
extract_conda_internal(
reader,
destination,
crate::read::extract_conda_via_buffering,
)
.await
}

/// Extracts the contents of a `.conda` package archive using the provided extraction function
async fn extract_conda_internal(
reader: impl AsyncRead + Send + 'static,
destination: &Path,
extract_fn: fn(Box<dyn Read>, &Path) -> Result<ExtractResult, ExtractError>,
) -> Result<ExtractResult, ExtractError> {
// Create an async -> sync bridge
let reader = SyncIoBridge::new(Box::pin(reader));

// Spawn a blocking task to perform the extraction
let destination = destination.to_owned();
match tokio::task::spawn_blocking(move || crate::read::extract_conda(reader, &destination))
.await
match tokio::task::spawn_blocking(move || {
let reader: Box<dyn Read> = Box::new(reader);
extract_fn(reader, &destination)
})
.await
{
Ok(result) => result,
Err(err) => {
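The async-to-sync bridge used above is worth seeing in isolation. A minimal, self-contained sketch of the same pattern, with a stand-in "extraction" that just drains the stream (the function name here is ours, not the crate's):

```rust
use std::io::Read;
use tokio::io::AsyncRead;
use tokio_util::io::SyncIoBridge;

// Wrap an AsyncRead so a blocking function can consume it through
// std::io::Read on a dedicated blocking thread.
async fn drain_on_blocking_thread(
    reader: impl AsyncRead + Send + 'static,
) -> std::io::Result<u64> {
    let bridge = SyncIoBridge::new(Box::pin(reader));
    tokio::task::spawn_blocking(move || {
        let mut reader: Box<dyn Read> = Box::new(bridge);
        // Stand-in for extract_fn: count the bytes instead of extracting.
        std::io::copy(&mut reader, &mut std::io::sink())
    })
    .await
    .expect("the blocking task panicked")
}
```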
46 changes: 35 additions & 11 deletions crates/rattler_package_streaming/tests/extract.rs
@@ -6,7 +6,7 @@ use std::{

use rattler_conda_types::package::IndexJson;
use rattler_package_streaming::{
read::{extract_conda, extract_tar_bz2},
read::{extract_conda_via_buffering, extract_conda_via_streaming, extract_tar_bz2},
ExtractError,
};
use rstest::rstest;
@@ -111,7 +111,7 @@ fn test_extract_conda(#[case] input: Url, #[case] sha256: &str, #[case] md5: &st
println!("Target dir: {}", temp_dir.display());
let file_path = tools::download_and_cache_file(input, sha256).unwrap();
let target_dir = temp_dir.join(file_path.file_stem().unwrap());
let result = extract_conda(
let result = extract_conda_via_streaming(
File::open(test_data_dir().join(file_path)).unwrap(),
&target_dir,
)
@@ -211,14 +211,15 @@ async fn test_extract_conda_async(#[case] input: Url, #[case] sha256: &str, #[ca
.unwrap();

let target_dir = temp_dir.join(file_path.file_stem().unwrap());
let result = rattler_package_streaming::tokio::async_read::extract_conda(
tokio::fs::File::open(&test_data_dir().join(file_path))
.await
.unwrap(),
&target_dir,
)
.await
.unwrap();
let result: rattler_package_streaming::ExtractResult =
rattler_package_streaming::tokio::async_read::extract_conda(
tokio::fs::File::open(&test_data_dir().join(file_path))
.await
.unwrap(),
&target_dir,
)
.await
.unwrap();

assert_eq!(&format!("{:x}", result.sha256), sha256);
assert_eq!(&format!("{:x}", result.md5), md5);
@@ -266,7 +267,7 @@ fn test_extract_flaky_conda(#[values(0, 1, 13, 50, 74, 150, 8096, 16384, 20000)]
let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR"));
println!("Target dir: {}", temp_dir.display());
let target_dir = temp_dir.join(package_path.file_stem().unwrap());
let result = extract_conda(
let result = extract_conda_via_streaming(
Contributor: I think the same here ^

Contributor (Author): same response as above ^

FlakyReader {
reader: File::open(package_path).unwrap(),
total_read: 0,
@@ -279,6 +280,29 @@ fn test_extract_flaky_conda(#[values(0, 1, 13, 50, 74, 150, 8096, 16384, 20000)]
assert_matches::assert_matches!(result, ExtractError::IoError(_));
}

#[rstest]
fn test_extract_data_descriptor_package_fails_streaming_and_uses_buffering() {
let package_path = "tests/resources/ca-certificates-2024.7.4-hbcca054_0.conda";

let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR"));
let target_dir = temp_dir.join("package_using_data_descriptors");
let result = extract_conda_via_streaming(File::open(package_path).unwrap(), &target_dir)
.expect_err("this should error out and not panic");

assert_matches::assert_matches!(
result,
ExtractError::ZipError(zip::result::ZipError::UnsupportedArchive(
"The file length is not available in the local header"
))
);

let new_result =
extract_conda_via_buffering(File::open(package_path).unwrap(), &target_dir).unwrap();

insta::assert_snapshot!("new_result_sha256", &format!("{:x}", new_result.sha256));
insta::assert_snapshot!("new_result_md5", &format!("{:x}", new_result.md5));
}

struct FlakyReader<R: Read> {
reader: R,
cutoff: usize,
Expand Down
Binary file not shown.
@@ -0,0 +1,5 @@
---
source: crates/rattler_package_streaming/tests/extract.rs
expression: "&format!(\"{:x}\", new_result.md5)"
---
a1d1adb5a5dc516dfb3dccc7b9b574a9
@@ -0,0 +1,5 @@
---
source: crates/rattler_package_streaming/tests/extract.rs
expression: "&format!(\"{:x}\", new_result.sha256)"
---
6a5d6d8a1a7552dbf8c617312ef951a77d2dac09f2aeaba661deebce603a7a97
4 changes: 2 additions & 2 deletions crates/rattler_package_streaming/tests/write.rs
@@ -1,5 +1,5 @@
use rattler_conda_types::package::ArchiveType;
use rattler_package_streaming::read::{extract_conda, extract_tar_bz2};
use rattler_package_streaming::read::{extract_conda_via_streaming, extract_tar_bz2};
use rattler_package_streaming::write::{
write_conda_package, write_tar_bz2_package, CompressionLevel,
};
@@ -209,7 +209,7 @@ fn test_rewrite_conda() {

let name = file_path.file_stem().unwrap().to_string_lossy();
let target_dir = temp_dir.join(file_path.file_stem().unwrap());
extract_conda(File::open(&file_path).unwrap(), &target_dir).unwrap();
extract_conda_via_streaming(File::open(&file_path).unwrap(), &target_dir).unwrap();

let new_archive = temp_dir.join(format!(
"{}-new.conda",