Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose a convenient function to migrate files in repo scanner #167

Merged
merged 3 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ members = [
"utils",
"cas_object",
"cas_types",
"chunk_cache", "xet_threadpool",
"chunk_cache",
"xet_threadpool"
]

exclude = ["hf_xet", "chunk_cache_bench"]
Expand Down
177 changes: 16 additions & 161 deletions data/src/bin/xtool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,8 @@ use anyhow::Result;
use cas_client::build_http_client;
use cas_object::CompressionScheme;
use clap::{Args, Parser, Subcommand};
use data::data_client::{clean_file, default_config, xorb_compression_for_repo_type};
use data::errors::DataProcessingError;
use data::{PointerFile, PointerFileTranslator};
use mdb_shard::file_structs::MDBFileInfo;
use parutils::{tokio_par_for_each, ParallelError};
use reqwest_middleware::ClientWithMiddleware;
use utils::auth::{TokenInfo, TokenRefresher};
use utils::errors::AuthError;
use data::migration_tool::hub_client::HubClient;
use data::migration_tool::migrate::migrate_files_impl;
use walkdir::WalkDir;
use xet_threadpool::ThreadPool;

Expand Down Expand Up @@ -67,67 +61,6 @@ impl XCommand {
}
}

#[derive(Debug)]
struct HubClient {
endpoint: String,
token: String,
repo_type: String,
repo_id: String,
client: ClientWithMiddleware,
}

impl HubClient {
// Get CAS access token from Hub access token. "token_type" is either "read" or "write".
async fn get_jwt_token(&self, token_type: &str) -> Result<(String, String, u64)> {
let endpoint = self.endpoint.as_str();
let repo_type = self.repo_type.as_str();
let repo_id = self.repo_id.as_str();

let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main");

let response = self
.client
.get(url)
.bearer_auth(&self.token)
.header("user-agent", "xtool")
.send()
.await?;

let headers = response.headers();
let cas_endpoint = headers["X-Xet-Cas-Url"].to_str()?.to_owned();
let jwt_token: String = headers["X-Xet-Access-Token"].to_str()?.to_owned();
let jwt_token_expiry: u64 = headers["X-Xet-Token-Expiration"].to_str()?.parse()?;

Ok((cas_endpoint, jwt_token, jwt_token_expiry))
}

async fn refresh_jwt_token(&self, token_type: &str) -> Result<(String, u64)> {
let (_, jwt_token, jwt_token_expiry) = self.get_jwt_token(token_type).await?;

Ok((jwt_token, jwt_token_expiry))
}
}

#[derive(Debug)]
struct HubClientTokenRefresher {
threadpool: Arc<ThreadPool>,
token_type: String,
client: Arc<HubClient>,
}

impl TokenRefresher for HubClientTokenRefresher {
fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
let client = self.client.clone();
let token_type = self.token_type.clone();
let ret = self
.threadpool
.external_run_async_task(async move { client.refresh_jwt_token(&token_type).await })
.map_err(|e| AuthError::TokenRefreshFailure(e.to_string()))?
.map_err(|e| AuthError::TokenRefreshFailure(e.to_string()))?;
Ok(ret)
}
}

#[derive(Subcommand)]
enum Command {
/// Dry-run of file upload to get file info after dedup.
Expand Down Expand Up @@ -175,9 +108,11 @@ impl Command {
async fn run(self, hub_client: HubClient, threadpool: Arc<ThreadPool>) -> Result<()> {
match self {
Command::Dedup(arg) => {
let (all_file_info, clean_ret, total_bytes_trans) = dedup_files(
arg.files,
arg.recursive,
let file_paths = walk_files(arg.files, arg.recursive);
eprintln!("Dedupping {} files...", file_paths.len());

let (all_file_info, clean_ret, total_bytes_trans) = migrate_files_impl(
file_paths,
arg.sequential,
hub_client,
threadpool,
Expand Down Expand Up @@ -206,50 +141,12 @@ impl Command {

Ok(())
},
Command::Query(arg) => query_file(arg.hash, hub_client, threadpool),
Command::Query(_arg) => unimplemented!(),
}
}
}

fn main() -> Result<()> {
let cli = XCommand::parse();
let threadpool = Arc::new(ThreadPool::new_with_hardware_parallelism_limit()?);
let threadpool_internal = threadpool.clone();
threadpool.external_run_async_task(async move { cli.run(threadpool_internal).await })??;

Ok(())
}

async fn dedup_files(
files: Vec<String>,
recursive: bool,
sequential: bool,
hub_client: HubClient,
threadpool: Arc<ThreadPool>,
compression: Option<CompressionScheme>,
dry_run: bool,
) -> Result<(Vec<MDBFileInfo>, Vec<(PointerFile, u64)>, u64)> {
let compression = compression.unwrap_or_else(|| xorb_compression_for_repo_type(&hub_client.repo_type));
eprintln!("Using {compression} compression");

let token_type = "write";
let (endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token(token_type).await?;
let token_refresher = Arc::new(HubClientTokenRefresher {
threadpool: threadpool.clone(),
token_type: token_type.to_owned(),
client: Arc::new(hub_client),
}) as Arc<dyn TokenRefresher>;

let (config, _tempdir) =
default_config(endpoint, Some(compression), Some((jwt_token, jwt_token_expiry)), Some(token_refresher))?;

let num_workers = if sequential { 1 } else { threadpool.num_worker_threads() };
let processor = if dry_run {
Arc::new(PointerFileTranslator::dry_run(config, threadpool, None, false).await?)
} else {
Arc::new(PointerFileTranslator::new(config, threadpool, None, false).await?)
};

fn walk_files(files: Vec<String>, recursive: bool) -> Vec<String> {
// Scan all files if under recursive mode
let file_paths = if recursive {
files
Expand All @@ -272,60 +169,18 @@ async fn dedup_files(
files
};

eprintln!("Dedupping {} files...", file_paths.len());

let clean_ret = tokio_par_for_each(file_paths, num_workers, |f, _| async {
let proc = processor.clone();
clean_file(&proc, f).await
})
.await
.map_err(|e| match e {
ParallelError::JoinError => DataProcessingError::InternalError("Join error".to_string()),
ParallelError::TaskError(e) => e,
})?;

let total_bytes_trans = processor.finalize_cleaning().await?;

if dry_run {
let all_file_info = processor.summarize_file_info_of_session().await?;
Ok((all_file_info, clean_ret, total_bytes_trans))
} else {
Ok((vec![], clean_ret, total_bytes_trans))
}
file_paths
}

/// Returns `true` when `path` is exactly one of the git bookkeeping names
/// `.git`, `.gitignore` or `.gitattributes`.
///
/// The comparison is a whole-string, case-sensitive match; no path
/// components are split out before comparing.
fn is_git_special_files(path: &str) -> bool {
    const GIT_SPECIAL_NAMES: [&str; 3] = [".git", ".gitignore", ".gitattributes"];

    GIT_SPECIAL_NAMES.iter().any(|special| *special == path)
}

fn query_file(_hash: String, _hub_client: HubClient, _threadpool: Arc<ThreadPool>) -> Result<()> {
todo!()
}

#[cfg(test)]
mod tests {
use anyhow::Result;
use cas_client::build_http_client;

use crate::HubClient;

#[tokio::test]
#[ignore = "need valid token"]
async fn test_get_jwt_token() -> Result<()> {
let hub_client = HubClient {
endpoint: "https://xethub-poc.us.dev.moon.huggingface.tech".to_owned(),
token: "[MASKED]".to_owned(),
repo_type: "dataset".to_owned(),
repo_id: "test/t2".to_owned(),
client: build_http_client(&None)?,
};

let (cas_endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token("read").await?;

println!("{cas_endpoint}, {jwt_token}, {jwt_token_expiry}");

println!("{:?}", hub_client.refresh_jwt_token("write").await?);
/// Entry point: parses the command line, creates the thread pool, and runs the parsed command on it.
fn main() -> Result<()> {
let cli = XCommand::parse();
let threadpool = Arc::new(ThreadPool::new_with_hardware_parallelism_limit()?);
// A second handle to the pool is moved into the async task so the command itself can use it.
let threadpool_internal = threadpool.clone();
// The double `?` first propagates the outer result of `external_run_async_task`, then the command's own result.
threadpool.external_run_async_task(async move { cli.run(threadpool_internal).await })??;

Ok(())
}
Ok(())
}
1 change: 1 addition & 0 deletions data/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod data_client;
mod data_processing;
pub mod errors;
mod metrics;
pub mod migration_tool;
mod parallel_xorb_uploader;
mod pointer_file;
mod remote_shard_interface;
Expand Down
96 changes: 96 additions & 0 deletions data/src/migration_tool/hub_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::sync::Arc;

use anyhow::Result;
use reqwest_middleware::ClientWithMiddleware;
use utils::auth::{TokenInfo, TokenRefresher};
use utils::errors::AuthError;
use xet_threadpool::ThreadPool;

/// Connection details used to request Xet CAS tokens for a single Hub repository.
#[derive(Debug)]
pub struct HubClient {
/// Base URL that the token request URL is built from.
pub endpoint: String,
/// Hub access token, sent as the bearer credential on token requests.
pub token: String,
/// Repository type; an "s" is appended to it when the request path is built.
pub repo_type: String,
/// Repository identifier placed in the request path.
pub repo_id: String,
/// HTTP client (with middleware) used to send the token request.
pub client: ClientWithMiddleware,
}

impl HubClient {
// Get CAS access token from Hub access token. "token_type" is either "read" or "write".
pub async fn get_jwt_token(&self, token_type: &str) -> Result<(String, String, u64)> {
let endpoint = self.endpoint.as_str();
let repo_type = self.repo_type.as_str();
let repo_id = self.repo_id.as_str();

let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main");

let response = self
.client
.get(url)
.bearer_auth(&self.token)
.header("user-agent", "xtool")
.send()
.await?;

let headers = response.headers();
let cas_endpoint = headers["X-Xet-Cas-Url"].to_str()?.to_owned();
let jwt_token: String = headers["X-Xet-Access-Token"].to_str()?.to_owned();
let jwt_token_expiry: u64 = headers["X-Xet-Token-Expiration"].to_str()?.parse()?;

Ok((cas_endpoint, jwt_token, jwt_token_expiry))
}

async fn refresh_jwt_token(&self, token_type: &str) -> Result<(String, u64)> {
let (_, jwt_token, jwt_token_expiry) = self.get_jwt_token(token_type).await?;

Ok((jwt_token, jwt_token_expiry))
}
}

/// Token refresher that obtains new tokens through a shared `HubClient`.
#[derive(Debug)]
pub struct HubClientTokenRefresher {
/// Thread pool on which the async refresh request is run.
pub threadpool: Arc<ThreadPool>,
/// Token kind passed to the refresh call ("read" or "write").
pub token_type: String,
/// Hub client used to fetch the refreshed token.
pub client: Arc<HubClient>,
}

impl TokenRefresher for HubClientTokenRefresher {
    /// Fetches a new token by running `HubClient::refresh_jwt_token` on the
    /// thread pool through `external_run_async_task`.
    ///
    /// Both a failure reported by the thread pool and a failure of the refresh
    /// request itself are returned as `AuthError::TokenRefreshFailure`
    /// carrying the error's text.
    fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
        let hub = self.client.clone();
        let kind = self.token_type.clone();

        let outcome = self
            .threadpool
            .external_run_async_task(async move { hub.refresh_jwt_token(&kind).await });

        match outcome {
            // The task ran and the refresh request succeeded.
            Ok(Ok(token_info)) => Ok(token_info),
            // The task ran but the refresh request failed.
            Ok(Err(request_err)) => Err(AuthError::TokenRefreshFailure(request_err.to_string())),
            // The thread pool could not run the task to completion.
            Err(pool_err) => Err(AuthError::TokenRefreshFailure(pool_err.to_string())),
        }
    }
}

#[cfg(test)]
mod tests {
use anyhow::Result;
use cas_client::build_http_client;

use super::HubClient;

// Manual smoke test: requests a "read" token, then a refreshed "write" token, and prints both.
// Ignored by default because the masked token below has to be replaced with a valid one.
#[tokio::test]
#[ignore = "need valid token"]
async fn test_get_jwt_token() -> Result<()> {
let hub_client = HubClient {
endpoint: "https://xethub-poc.us.dev.moon.huggingface.tech".to_owned(),
token: "[MASKED]".to_owned(),
repo_type: "dataset".to_owned(),
repo_id: "test/t2".to_owned(),
client: build_http_client(&None)?,
};

// Full token fetch: CAS endpoint, token and expiry.
let (cas_endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token("read").await?;

println!("{cas_endpoint}, {jwt_token}, {jwt_token_expiry}");

// Refresh path: returns only the token and its expiry.
println!("{:?}", hub_client.refresh_jwt_token("write").await?);

Ok(())
}
}
Loading