Skip to content

Commit

Permalink
Get repo_type from caller (#160)
Browse files Browse the repository at this point in the history
We'll use repo_type to determine the compression algorithm to use for
Xorbs.
The corresponding PR in huggingface_hub:
huggingface-internal/xetpoc_huggingface_hub#27
  • Loading branch information
seanses authored Feb 3, 2025
1 parent d3144cd commit e37d424
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
1 change: 1 addition & 0 deletions data/src/data_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ pub async fn upload_async(
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
progress_updater: Option<Arc<dyn ProgressUpdater>>,
_repo_type: String,
) -> errors::Result<Vec<PointerFile>> {
// chunk files
// produce Xorbs + Shards
Expand Down
4 changes: 3 additions & 1 deletion hf_xet/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ fn convert_data_processing_error(e: DataProcessingError) -> PyErr {
}

#[pyfunction]
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int, None]]) -> List[PyPointerFile]")]
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, repo_type), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], repo_type: str) -> List[PyPointerFile]")]
pub fn upload_files(
py: Python,
file_paths: Vec<String>,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Py<PyAny>>,
progress_updater: Option<Py<PyAny>>,
repo_type: String,
) -> PyResult<Vec<PyPointerFile>> {
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updater = progress_updater
Expand All @@ -55,6 +56,7 @@ pub fn upload_files(
token_info,
refresher.map(|v| v as Arc<_>),
updater.map(|v| v as Arc<_>),
repo_type,
)
.await
.map_err(convert_data_processing_error)?
Expand Down

0 comments on commit e37d424

Please sign in to comment.