From e37d424cbcb99731b5353ce146798f4205e93942 Mon Sep 17 00:00:00 2001 From: Di Xiao Date: Tue, 4 Feb 2025 03:20:51 +0800 Subject: [PATCH] Get repo_type from caller (#160) We'll use repo_type to determine the compression algorithm to use for Xorbs. The corresponding PR in huggingface_hub: https://github.com/huggingface-internal/xetpoc_huggingface_hub/pull/27 --- data/src/data_client.rs | 1 + hf_xet/src/lib.rs | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/data/src/data_client.rs b/data/src/data_client.rs index e3abf3c7..e50f2ac9 100644 --- a/data/src/data_client.rs +++ b/data/src/data_client.rs @@ -108,6 +108,7 @@ pub async fn upload_async( token_info: Option<(String, u64)>, token_refresher: Option>, progress_updater: Option>, + _repo_type: String, ) -> errors::Result> { // chunk files // produce Xorbs + Shards diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs index bb1fdd0c..0615c190 100644 --- a/hf_xet/src/lib.rs +++ b/hf_xet/src/lib.rs @@ -32,7 +32,7 @@ fn convert_data_processing_error(e: DataProcessingError) -> PyErr { } #[pyfunction] -#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int, None]]) -> List[PyPointerFile]")] +#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, repo_type), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], repo_type: str) -> List[PyPointerFile]")] pub fn upload_files( py: Python, file_paths: Vec, @@ -40,6 +40,7 @@ pub fn upload_files( token_info: Option<(String, u64)>, token_refresher: Option>, progress_updater: Option>, + repo_type: String, ) -> PyResult> { let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new); let updater = progress_updater @@ -55,6 +56,7 @@ pub fn upload_files( token_info, refresher.map(|v| v as Arc<_>), updater.map(|v| v as Arc<_>), + repo_type, ) .await .map_err(convert_data_processing_error)?