diff --git a/Cargo.toml b/Cargo.toml
index 4f25678..03cfc7f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,8 @@ members = [
     "utils",
     "cas_object",
     "cas_types",
-    "chunk_cache", "xet_threadpool",
+    "chunk_cache",
+    "xet_threadpool"
 ]
 
 exclude = ["hf_xet", "chunk_cache_bench"]
diff --git a/data/src/bin/xtool.rs b/data/src/bin/xtool.rs
index 96ea09b..dc938fe 100644
--- a/data/src/bin/xtool.rs
+++ b/data/src/bin/xtool.rs
@@ -7,14 +7,8 @@ use anyhow::Result;
 use cas_client::build_http_client;
 use cas_object::CompressionScheme;
 use clap::{Args, Parser, Subcommand};
-use data::data_client::{clean_file, default_config, xorb_compression_for_repo_type};
-use data::errors::DataProcessingError;
-use data::{PointerFile, PointerFileTranslator};
-use mdb_shard::file_structs::MDBFileInfo;
-use parutils::{tokio_par_for_each, ParallelError};
-use reqwest_middleware::ClientWithMiddleware;
-use utils::auth::{TokenInfo, TokenRefresher};
-use utils::errors::AuthError;
+use data::migration_tool::hub_client::HubClient;
+use data::migration_tool::migrate::migrate_files_impl;
 use walkdir::WalkDir;
 use xet_threadpool::ThreadPool;
 
@@ -67,67 +61,6 @@ impl XCommand {
     }
 }
 
-#[derive(Debug)]
-struct HubClient {
-    endpoint: String,
-    token: String,
-    repo_type: String,
-    repo_id: String,
-    client: ClientWithMiddleware,
-}
-
-impl HubClient {
-    // Get CAS access token from Hub access token. "token_type" is either "read" or "write".
-    async fn get_jwt_token(&self, token_type: &str) -> Result<(String, String, u64)> {
-        let endpoint = self.endpoint.as_str();
-        let repo_type = self.repo_type.as_str();
-        let repo_id = self.repo_id.as_str();
-
-        let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main");
-
-        let response = self
-            .client
-            .get(url)
-            .bearer_auth(&self.token)
-            .header("user-agent", "xtool")
-            .send()
-            .await?;
-
-        let headers = response.headers();
-        let cas_endpoint = headers["X-Xet-Cas-Url"].to_str()?.to_owned();
-        let jwt_token: String = headers["X-Xet-Access-Token"].to_str()?.to_owned();
-        let jwt_token_expiry: u64 = headers["X-Xet-Token-Expiration"].to_str()?.parse()?;
-
-        Ok((cas_endpoint, jwt_token, jwt_token_expiry))
-    }
-
-    async fn refresh_jwt_token(&self, token_type: &str) -> Result<(String, u64)> {
-        let (_, jwt_token, jwt_token_expiry) = self.get_jwt_token(token_type).await?;
-
-        Ok((jwt_token, jwt_token_expiry))
-    }
-}
-
-#[derive(Debug)]
-struct HubClientTokenRefresher {
-    threadpool: Arc<ThreadPool>,
-    token_type: String,
-    client: Arc<HubClient>,
-}
-
-impl TokenRefresher for HubClientTokenRefresher {
-    fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
-        let client = self.client.clone();
-        let token_type = self.token_type.clone();
-        let ret = self
-            .threadpool
-            .external_run_async_task(async move { client.refresh_jwt_token(&token_type).await })
-            .map_err(|e| AuthError::TokenRefreshFailure(e.to_string()))?
-            .map_err(|e| AuthError::TokenRefreshFailure(e.to_string()))?;
-        Ok(ret)
-    }
-}
-
 #[derive(Subcommand)]
 enum Command {
     /// Dry-run of file upload to get file info after dedup.
@@ -175,9 +108,11 @@ impl Command {
     async fn run(self, hub_client: HubClient, threadpool: Arc<ThreadPool>) -> Result<()> {
         match self {
             Command::Dedup(arg) => {
-                let (all_file_info, clean_ret, total_bytes_trans) = dedup_files(
-                    arg.files,
-                    arg.recursive,
+                let file_paths = walk_files(arg.files, arg.recursive);
+                eprintln!("Dedupping {} files...", file_paths.len());
+
+                let (all_file_info, clean_ret, total_bytes_trans) = migrate_files_impl(
+                    file_paths,
                     arg.sequential,
                     hub_client,
                     threadpool,
@@ -206,50 +141,12 @@ impl Command {
 
                 Ok(())
             },
-            Command::Query(arg) => query_file(arg.hash, hub_client, threadpool),
+            Command::Query(_arg) => unimplemented!(),
         }
     }
 }
 
-fn main() -> Result<()> {
-    let cli = XCommand::parse();
-    let threadpool = Arc::new(ThreadPool::new_with_hardware_parallelism_limit()?);
-    let threadpool_internal = threadpool.clone();
-    threadpool.external_run_async_task(async move { cli.run(threadpool_internal).await })??;
-
-    Ok(())
-}
-
-async fn dedup_files(
-    files: Vec<String>,
-    recursive: bool,
-    sequential: bool,
-    hub_client: HubClient,
-    threadpool: Arc<ThreadPool>,
-    compression: Option<CompressionScheme>,
-    dry_run: bool,
-) -> Result<(Vec<MDBFileInfo>, Vec<(PointerFile, u64)>, u64)> {
-    let compression = compression.unwrap_or_else(|| xorb_compression_for_repo_type(&hub_client.repo_type));
-    eprintln!("Using {compression} compression");
-
-    let token_type = "write";
-    let (endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token(token_type).await?;
-    let token_refresher = Arc::new(HubClientTokenRefresher {
-        threadpool: threadpool.clone(),
-        token_type: token_type.to_owned(),
-        client: Arc::new(hub_client),
-    }) as Arc<dyn TokenRefresher>;
-
-    let (config, _tempdir) =
-        default_config(endpoint, Some(compression), Some((jwt_token, jwt_token_expiry)), Some(token_refresher))?;
-
-    let num_workers = if sequential { 1 } else { threadpool.num_worker_threads() };
-    let processor = if dry_run {
-        Arc::new(PointerFileTranslator::dry_run(config, threadpool, None, false).await?)
-    } else {
-        Arc::new(PointerFileTranslator::new(config, threadpool, None, false).await?)
-    };
-
+fn walk_files(files: Vec<String>, recursive: bool) -> Vec<String> {
     // Scan all files if under recursive mode
     let file_paths = if recursive {
         files
@@ -272,60 +169,18 @@ async fn dedup_files(
         files
     };
 
-    eprintln!("Dedupping {} files...", file_paths.len());
-
-    let clean_ret = tokio_par_for_each(file_paths, num_workers, |f, _| async {
-        let proc = processor.clone();
-        clean_file(&proc, f).await
-    })
-    .await
-    .map_err(|e| match e {
-        ParallelError::JoinError => DataProcessingError::InternalError("Join error".to_string()),
-        ParallelError::TaskError(e) => e,
-    })?;
-
-    let total_bytes_trans = processor.finalize_cleaning().await?;
-
-    if dry_run {
-        let all_file_info = processor.summarize_file_info_of_session().await?;
-        Ok((all_file_info, clean_ret, total_bytes_trans))
-    } else {
-        Ok((vec![], clean_ret, total_bytes_trans))
-    }
+    file_paths
 }
 
 fn is_git_special_files(path: &str) -> bool {
     matches!(path, ".git" | ".gitignore" | ".gitattributes")
 }
 
-fn query_file(_hash: String, _hub_client: HubClient, _threadpool: Arc<ThreadPool>) -> Result<()> {
-    todo!()
-}
-
-#[cfg(test)]
-mod tests {
-    use anyhow::Result;
-    use cas_client::build_http_client;
-
-    use crate::HubClient;
-
-    #[tokio::test]
-    #[ignore = "need valid token"]
-    async fn test_get_jwt_token() -> Result<()> {
-        let hub_client = HubClient {
-            endpoint: "https://xethub-poc.us.dev.moon.huggingface.tech".to_owned(),
-            token: "[MASKED]".to_owned(),
-            repo_type: "dataset".to_owned(),
-            repo_id: "test/t2".to_owned(),
-            client: build_http_client(&None)?,
-        };
-
-        let (cas_endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token("read").await?;
-
-        println!("{cas_endpoint}, {jwt_token}, {jwt_token_expiry}");
-
-        println!("{:?}", hub_client.refresh_jwt_token("write").await?);
+fn main() -> Result<()> {
+    let cli = XCommand::parse();
+    let threadpool = Arc::new(ThreadPool::new_with_hardware_parallelism_limit()?);
+    let threadpool_internal = threadpool.clone();
+    threadpool.external_run_async_task(async move { cli.run(threadpool_internal).await })??;
 
-        Ok(())
-    }
+    Ok(())
 }
diff --git a/data/src/lib.rs b/data/src/lib.rs
index ef10fd2..96ad227 100644
--- a/data/src/lib.rs
+++ b/data/src/lib.rs
@@ -9,6 +9,7 @@ pub mod data_client;
 mod data_processing;
 pub mod errors;
 mod metrics;
+pub mod migration_tool;
 mod parallel_xorb_uploader;
 mod pointer_file;
 mod remote_shard_interface;
diff --git a/data/src/migration_tool/hub_client.rs b/data/src/migration_tool/hub_client.rs
new file mode 100644
index 0000000..1f1a1d3
--- /dev/null
+++ b/data/src/migration_tool/hub_client.rs
@@ -0,0 +1,96 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use reqwest_middleware::ClientWithMiddleware;
+use utils::auth::{TokenInfo, TokenRefresher};
+use utils::errors::AuthError;
+use xet_threadpool::ThreadPool;
+
+#[derive(Debug)]
+pub struct HubClient {
+    pub endpoint: String,
+    pub token: String,
+    pub repo_type: String,
+    pub repo_id: String,
+    pub client: ClientWithMiddleware,
+}
+
+impl HubClient {
+    // Get CAS access token from Hub access token. "token_type" is either "read" or "write".
+    pub async fn get_jwt_token(&self, token_type: &str) -> Result<(String, String, u64)> {
+        let endpoint = self.endpoint.as_str();
+        let repo_type = self.repo_type.as_str();
+        let repo_id = self.repo_id.as_str();
+
+        let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main");
+
+        let response = self
+            .client
+            .get(url)
+            .bearer_auth(&self.token)
+            .header("user-agent", "xtool")
+            .send()
+            .await?;
+
+        let headers = response.headers();
+        let cas_endpoint = headers["X-Xet-Cas-Url"].to_str()?.to_owned();
+        let jwt_token: String = headers["X-Xet-Access-Token"].to_str()?.to_owned();
+        let jwt_token_expiry: u64 = headers["X-Xet-Token-Expiration"].to_str()?.parse()?;
+
+        Ok((cas_endpoint, jwt_token, jwt_token_expiry))
+    }
+
+    async fn refresh_jwt_token(&self, token_type: &str) -> Result<(String, u64)> {
+        let (_, jwt_token, jwt_token_expiry) = self.get_jwt_token(token_type).await?;
+
+        Ok((jwt_token, jwt_token_expiry))
+    }
+}
+
+#[derive(Debug)]
+pub struct HubClientTokenRefresher {
+    pub threadpool: Arc<ThreadPool>,
+    pub token_type: String,
+    pub client: Arc<HubClient>,
+}
+
+impl TokenRefresher for HubClientTokenRefresher {
+    fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
+        let client = self.client.clone();
+        let token_type = self.token_type.clone();
+        let ret = self
+            .threadpool
+            .external_run_async_task(async move { client.refresh_jwt_token(&token_type).await })
+            .map_err(|e| AuthError::TokenRefreshFailure(e.to_string()))?
+            .map_err(|e| AuthError::TokenRefreshFailure(e.to_string()))?;
+        Ok(ret)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use anyhow::Result;
+    use cas_client::build_http_client;
+
+    use super::HubClient;
+
+    #[tokio::test]
+    #[ignore = "need valid token"]
+    async fn test_get_jwt_token() -> Result<()> {
+        let hub_client = HubClient {
+            endpoint: "https://xethub-poc.us.dev.moon.huggingface.tech".to_owned(),
+            token: "[MASKED]".to_owned(),
+            repo_type: "dataset".to_owned(),
+            repo_id: "test/t2".to_owned(),
+            client: build_http_client(&None)?,
+        };
+
+        let (cas_endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token("read").await?;
+
+        println!("{cas_endpoint}, {jwt_token}, {jwt_token_expiry}");
+
+        println!("{:?}", hub_client.refresh_jwt_token("write").await?);
+
+        Ok(())
+    }
+}
diff --git a/data/src/migration_tool/migrate.rs b/data/src/migration_tool/migrate.rs
new file mode 100644
index 0000000..15817a8
--- /dev/null
+++ b/data/src/migration_tool/migrate.rs
@@ -0,0 +1,98 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use cas_client::build_http_client;
+use cas_object::CompressionScheme;
+use mdb_shard::file_structs::MDBFileInfo;
+use parutils::{tokio_par_for_each, ParallelError};
+use utils::auth::TokenRefresher;
+use xet_threadpool::ThreadPool;
+
+use super::hub_client::{HubClient, HubClientTokenRefresher};
+use crate::data_client::{clean_file, default_config, xorb_compression_for_repo_type};
+use crate::errors::DataProcessingError;
+use crate::{PointerFile, PointerFileTranslator};
+
+/// Migrate files to the Hub with external async runtime.
+/// How to use:
+/// ```no_run
+/// let file_paths = vec!["/path/to/file1".to_string(), "/path/to/file2".to_string()];
+/// let hub_endpoint = "https://huggingface.co";
+/// let hub_token = "your_token";
+/// let repo_type = "model";
+/// let repo_id = "your_repo_id";
+/// let handle = tokio::runtime::Handle::current();
+/// migrate_with_external_runtime(file_paths, hub_endpoint, hub_token, repo_type, repo_id, handle)
+///     .await?;
+/// ```
+pub async fn migrate_with_external_runtime(
+    file_paths: Vec<String>,
+    hub_endpoint: &str,
+    hub_token: &str,
+    repo_type: &str,
+    repo_id: &str,
+    handle: tokio::runtime::Handle,
+) -> Result<()> {
+    let hub_client = HubClient {
+        endpoint: hub_endpoint.to_owned(),
+        token: hub_token.to_owned(),
+        repo_type: repo_type.to_owned(),
+        repo_id: repo_id.to_owned(),
+        client: build_http_client(&None)?,
+    };
+
+    let threadpool = Arc::new(ThreadPool::from_external(handle));
+
+    migrate_files_impl(file_paths, false, hub_client, threadpool, None, false).await?;
+
+    Ok(())
+}
+
+pub async fn migrate_files_impl(
+    file_paths: Vec<String>,
+    sequential: bool,
+    hub_client: HubClient,
+    threadpool: Arc<ThreadPool>,
+    compression: Option<CompressionScheme>,
+    dry_run: bool,
+) -> Result<(Vec<MDBFileInfo>, Vec<(PointerFile, u64)>, u64)> {
+    let compression = compression.unwrap_or_else(|| xorb_compression_for_repo_type(&hub_client.repo_type));
+    eprintln!("Using {compression} compression");
+
+    let token_type = "write";
+    let (endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token(token_type).await?;
+    let token_refresher = Arc::new(HubClientTokenRefresher {
+        threadpool: threadpool.clone(),
+        token_type: token_type.to_owned(),
+        client: Arc::new(hub_client),
+    }) as Arc<dyn TokenRefresher>;
+
+    let (config, _tempdir) =
+        default_config(endpoint, Some(compression), Some((jwt_token, jwt_token_expiry)), Some(token_refresher))?;
+
+    let num_workers = if sequential { 1 } else { threadpool.num_worker_threads() };
+    let processor = if dry_run {
+        Arc::new(PointerFileTranslator::dry_run(config, threadpool, None, false).await?)
+    } else {
+        Arc::new(PointerFileTranslator::new(config, threadpool, None, false).await?)
+    };
+
+    let clean_ret = tokio_par_for_each(file_paths, num_workers, |f, _| async {
+        let proc = processor.clone();
+        clean_file(&proc, f).await
+    })
+    .await
+    .map_err(|e| match e {
+        ParallelError::JoinError => DataProcessingError::InternalError("Join error".to_string()),
+        ParallelError::TaskError(e) => e,
+    })?;
+
+    let total_bytes_trans = processor.finalize_cleaning().await?;
+
+    if dry_run {
+        let all_file_info = processor.summarize_file_info_of_session().await?;
+        Ok((all_file_info, clean_ret, total_bytes_trans))
+    } else {
+        Ok((vec![], clean_ret, total_bytes_trans))
+    }
+}
diff --git a/data/src/migration_tool/mod.rs b/data/src/migration_tool/mod.rs
new file mode 100644
index 0000000..d7135fa
--- /dev/null
+++ b/data/src/migration_tool/mod.rs
@@ -0,0 +1,2 @@
+pub mod hub_client;
+pub mod migrate;
diff --git a/xet_threadpool/src/threadpool.rs b/xet_threadpool/src/threadpool.rs
index 89dfb09..d4db848 100644
--- a/xet_threadpool/src/threadpool.rs
+++ b/xet_threadpool/src/threadpool.rs
@@ -98,6 +98,15 @@ impl ThreadPool {
         })
     }
 
+    pub fn from_external(handle: tokio::runtime::Handle) -> Self {
+        Self {
+            runtime: std::sync::RwLock::new(None),
+            handle,
+            external_executor_count: 0.into(),
+            sigint_shutdown: false.into(),
+        }
+    }
+
     pub fn num_worker_threads(&self) -> usize {
         self.handle.metrics().num_workers()
     }
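For context, a minimal sketch of how the `migrate_with_external_runtime` entry point introduced by this patch might be driven from an existing tokio runtime. The endpoint, token, repo id, and file path below are placeholders, not values from this change:

```rust
use anyhow::Result;
use data::migration_tool::migrate::migrate_with_external_runtime;

#[tokio::main]
async fn main() -> Result<()> {
    // Placeholder inputs: substitute a real Hub token, repo, and local file.
    let file_paths = vec!["./model.safetensors".to_string()];

    // Reuse the runtime this async fn already runs on; ThreadPool::from_external
    // wraps the handle instead of spawning a dedicated runtime.
    let handle = tokio::runtime::Handle::current();

    migrate_with_external_runtime(
        file_paths,
        "https://huggingface.co", // hub_endpoint
        "hf_xxx",                 // hub_token (placeholder)
        "model",                  // repo_type, e.g. "model" or "dataset"
        "user/repo",              // repo_id (placeholder)
        handle,
    )
    .await
}
```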