Skip to content

Commit

Permalink
Expose a convenient function to migrate files in repo scanner (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanses authored Feb 6, 2025
1 parent 169c368 commit fc6c3c9
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 162 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ members = [
"utils",
"cas_object",
"cas_types",
"chunk_cache", "xet_threadpool",
"chunk_cache",
"xet_threadpool"
]

exclude = ["hf_xet", "chunk_cache_bench"]
Expand Down
177 changes: 16 additions & 161 deletions data/src/bin/xtool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,8 @@ use anyhow::Result;
use cas_client::build_http_client;
use cas_object::CompressionScheme;
use clap::{Args, Parser, Subcommand};
use data::data_client::{clean_file, default_config, xorb_compression_for_repo_type};
use data::errors::DataProcessingError;
use data::{PointerFile, PointerFileTranslator};
use mdb_shard::file_structs::MDBFileInfo;
use parutils::{tokio_par_for_each, ParallelError};
use reqwest_middleware::ClientWithMiddleware;
use utils::auth::{TokenInfo, TokenRefresher};
use utils::errors::AuthError;
use data::migration_tool::hub_client::HubClient;
use data::migration_tool::migrate::migrate_files_impl;
use walkdir::WalkDir;
use xet_threadpool::ThreadPool;

Expand Down Expand Up @@ -67,67 +61,6 @@ impl XCommand {
}
}

/// Connection details used to request Xet CAS tokens from the Hub for one repository.
#[derive(Debug)]
struct HubClient {
    /// Base URL of the Hub; the token request URL is built as `{endpoint}/api/...`.
    endpoint: String,
    /// Hub access token, sent as the bearer credential on token requests.
    token: String,
    /// Repository type; an `s` is appended to it to form the API path segment.
    repo_type: String,
    /// Repository identifier placed in the API path after the repository type.
    repo_id: String,
    /// HTTP client used to send the token request.
    client: ClientWithMiddleware,
}

impl HubClient {
// Get CAS access token from Hub access token. "token_type" is either "read" or "write".
async fn get_jwt_token(&self, token_type: &str) -> Result<(String, String, u64)> {
let endpoint = self.endpoint.as_str();
let repo_type = self.repo_type.as_str();
let repo_id = self.repo_id.as_str();

let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main");

let response = self
.client
.get(url)
.bearer_auth(&self.token)
.header("user-agent", "xtool")
.send()
.await?;

let headers = response.headers();
let cas_endpoint = headers["X-Xet-Cas-Url"].to_str()?.to_owned();
let jwt_token: String = headers["X-Xet-Access-Token"].to_str()?.to_owned();
let jwt_token_expiry: u64 = headers["X-Xet-Token-Expiration"].to_str()?.parse()?;

Ok((cas_endpoint, jwt_token, jwt_token_expiry))
}

async fn refresh_jwt_token(&self, token_type: &str) -> Result<(String, u64)> {
let (_, jwt_token, jwt_token_expiry) = self.get_jwt_token(token_type).await?;

Ok((jwt_token, jwt_token_expiry))
}
}

/// Bundles what a token refresh needs: a thread pool, the token type and a shared Hub client.
#[derive(Debug)]
struct HubClientTokenRefresher {
    /// Thread pool on which the asynchronous refresh request is executed.
    threadpool: Arc<ThreadPool>,
    /// Token type passed to the Hub client on every refresh.
    token_type: String,
    /// Shared Hub client that performs the actual token request.
    client: Arc<HubClient>,
}

impl TokenRefresher for HubClientTokenRefresher {
    /// Runs `HubClient::refresh_jwt_token` on the thread pool and returns its result.
    ///
    /// Both failure layers (the thread pool failing to run the task, and the refresh request
    /// itself failing) are reported as `AuthError::TokenRefreshFailure` carrying the underlying
    /// error rendered as a string.
    fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
        let hub = self.client.clone();
        let kind = self.token_type.clone();

        let outcome = self
            .threadpool
            .external_run_async_task(async move { hub.refresh_jwt_token(&kind).await });

        match outcome {
            // The pool could not run the task to completion.
            Err(pool_err) => Err(AuthError::TokenRefreshFailure(pool_err.to_string())),
            // The task ran, but the token request failed.
            Ok(Err(request_err)) => Err(AuthError::TokenRefreshFailure(request_err.to_string())),
            Ok(Ok(token_info)) => Ok(token_info),
        }
    }
}

#[derive(Subcommand)]
enum Command {
/// Dry-run of file upload to get file info after dedup.
Expand Down Expand Up @@ -175,9 +108,11 @@ impl Command {
async fn run(self, hub_client: HubClient, threadpool: Arc<ThreadPool>) -> Result<()> {
match self {
Command::Dedup(arg) => {
let (all_file_info, clean_ret, total_bytes_trans) = dedup_files(
arg.files,
arg.recursive,
let file_paths = walk_files(arg.files, arg.recursive);
eprintln!("Dedupping {} files...", file_paths.len());

let (all_file_info, clean_ret, total_bytes_trans) = migrate_files_impl(
file_paths,
arg.sequential,
hub_client,
threadpool,
Expand Down Expand Up @@ -206,50 +141,12 @@ impl Command {

Ok(())
},
Command::Query(arg) => query_file(arg.hash, hub_client, threadpool),
Command::Query(_arg) => unimplemented!(),
}
}
}

/// Entry point: parses the command line, builds the thread pool and runs the command on it.
fn main() -> Result<()> {
    let command = XCommand::parse();
    let pool = Arc::new(ThreadPool::new_with_hardware_parallelism_limit()?);

    // The command receives its own handle to the pool it is running on.
    let pool_for_command = pool.clone();
    let outcome = pool.external_run_async_task(async move { command.run(pool_for_command).await });

    // First `?` unwraps the pool's own result, second `?` unwraps the command's result.
    outcome??;

    Ok(())
}

async fn dedup_files(
files: Vec<String>,
recursive: bool,
sequential: bool,
hub_client: HubClient,
threadpool: Arc<ThreadPool>,
compression: Option<CompressionScheme>,
dry_run: bool,
) -> Result<(Vec<MDBFileInfo>, Vec<(PointerFile, u64)>, u64)> {
let compression = compression.unwrap_or_else(|| xorb_compression_for_repo_type(&hub_client.repo_type));
eprintln!("Using {compression} compression");

let token_type = "write";
let (endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token(token_type).await?;
let token_refresher = Arc::new(HubClientTokenRefresher {
threadpool: threadpool.clone(),
token_type: token_type.to_owned(),
client: Arc::new(hub_client),
}) as Arc<dyn TokenRefresher>;

let (config, _tempdir) =
default_config(endpoint, Some(compression), Some((jwt_token, jwt_token_expiry)), Some(token_refresher))?;

let num_workers = if sequential { 1 } else { threadpool.num_worker_threads() };
let processor = if dry_run {
Arc::new(PointerFileTranslator::dry_run(config, threadpool, None, false).await?)
} else {
Arc::new(PointerFileTranslator::new(config, threadpool, None, false).await?)
};

fn walk_files(files: Vec<String>, recursive: bool) -> Vec<String> {
// Scan all files if under recursive mode
let file_paths = if recursive {
files
Expand All @@ -272,60 +169,18 @@ async fn dedup_files(
files
};

eprintln!("Dedupping {} files...", file_paths.len());

let clean_ret = tokio_par_for_each(file_paths, num_workers, |f, _| async {
let proc = processor.clone();
clean_file(&proc, f).await
})
.await
.map_err(|e| match e {
ParallelError::JoinError => DataProcessingError::InternalError("Join error".to_string()),
ParallelError::TaskError(e) => e,
})?;

let total_bytes_trans = processor.finalize_cleaning().await?;

if dry_run {
let all_file_info = processor.summarize_file_info_of_session().await?;
Ok((all_file_info, clean_ret, total_bytes_trans))
} else {
Ok((vec![], clean_ret, total_bytes_trans))
}
file_paths
}

/// Reports whether `path` is exactly `.git`, `.gitignore` or `.gitattributes`.
///
/// The comparison is against the whole string and is case-sensitive; no path components are
/// inspected.
fn is_git_special_files(path: &str) -> bool {
    const GIT_SPECIAL_NAMES: [&str; 3] = [".git", ".gitignore", ".gitattributes"];

    GIT_SPECIAL_NAMES.contains(&path)
}

/// Stub that is not implemented yet: every call panics through `todo!()`.
/// The underscore-prefixed parameters are accepted but unused.
fn query_file(_hash: String, _hub_client: HubClient, _threadpool: Arc<ThreadPool>) -> Result<()> {
todo!()
}

#[cfg(test)]
mod tests {
use anyhow::Result;
use cas_client::build_http_client;

use crate::HubClient;

#[tokio::test]
#[ignore = "need valid token"]
async fn test_get_jwt_token() -> Result<()> {
let hub_client = HubClient {
endpoint: "https://xethub-poc.us.dev.moon.huggingface.tech".to_owned(),
token: "[MASKED]".to_owned(),
repo_type: "dataset".to_owned(),
repo_id: "test/t2".to_owned(),
client: build_http_client(&None)?,
};

let (cas_endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token("read").await?;

println!("{cas_endpoint}, {jwt_token}, {jwt_token_expiry}");

println!("{:?}", hub_client.refresh_jwt_token("write").await?);
/// Entry point: parses the command line, builds the thread pool and runs the command on it.
fn main() -> Result<()> {
    let command = XCommand::parse();
    let pool = Arc::new(ThreadPool::new_with_hardware_parallelism_limit()?);

    // The command receives its own handle to the pool it is running on.
    let pool_for_command = pool.clone();
    let outcome = pool.external_run_async_task(async move { command.run(pool_for_command).await });

    // First `?` unwraps the pool's own result, second `?` unwraps the command's result.
    outcome??;

    Ok(())
}
Ok(())
}
1 change: 1 addition & 0 deletions data/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod data_client;
mod data_processing;
pub mod errors;
mod metrics;
pub mod migration_tool;
mod parallel_xorb_uploader;
mod pointer_file;
mod remote_shard_interface;
Expand Down
96 changes: 96 additions & 0 deletions data/src/migration_tool/hub_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use std::sync::Arc;

use anyhow::Result;
use reqwest_middleware::ClientWithMiddleware;
use utils::auth::{TokenInfo, TokenRefresher};
use utils::errors::AuthError;
use xet_threadpool::ThreadPool;

/// Connection details used to request Xet CAS tokens from the Hub for one repository.
#[derive(Debug)]
pub struct HubClient {
    /// Base URL of the Hub; the token request URL is built as `{endpoint}/api/...`.
    pub endpoint: String,
    /// Hub access token, sent as the bearer credential on token requests.
    pub token: String,
    /// Repository type; an `s` is appended to it to form the API path segment.
    pub repo_type: String,
    /// Repository identifier placed in the API path after the repository type.
    pub repo_id: String,
    /// HTTP client used to send the token request.
    pub client: ClientWithMiddleware,
}

impl HubClient {
// Get CAS access token from Hub access token. "token_type" is either "read" or "write".
pub async fn get_jwt_token(&self, token_type: &str) -> Result<(String, String, u64)> {
let endpoint = self.endpoint.as_str();
let repo_type = self.repo_type.as_str();
let repo_id = self.repo_id.as_str();

let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main");

let response = self
.client
.get(url)
.bearer_auth(&self.token)
.header("user-agent", "xtool")
.send()
.await?;

let headers = response.headers();
let cas_endpoint = headers["X-Xet-Cas-Url"].to_str()?.to_owned();
let jwt_token: String = headers["X-Xet-Access-Token"].to_str()?.to_owned();
let jwt_token_expiry: u64 = headers["X-Xet-Token-Expiration"].to_str()?.parse()?;

Ok((cas_endpoint, jwt_token, jwt_token_expiry))
}

async fn refresh_jwt_token(&self, token_type: &str) -> Result<(String, u64)> {
let (_, jwt_token, jwt_token_expiry) = self.get_jwt_token(token_type).await?;

Ok((jwt_token, jwt_token_expiry))
}
}

/// Bundles what a token refresh needs: a thread pool, the token type and a shared Hub client.
#[derive(Debug)]
pub struct HubClientTokenRefresher {
    /// Thread pool on which the asynchronous refresh request is executed.
    pub threadpool: Arc<ThreadPool>,
    /// Token type passed to the Hub client on every refresh.
    pub token_type: String,
    /// Shared Hub client that performs the actual token request.
    pub client: Arc<HubClient>,
}

impl TokenRefresher for HubClientTokenRefresher {
    /// Runs `HubClient::refresh_jwt_token` on the thread pool and returns its result.
    ///
    /// Both failure layers (the thread pool failing to run the task, and the refresh request
    /// itself failing) are reported as `AuthError::TokenRefreshFailure` carrying the underlying
    /// error rendered as a string.
    fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
        let hub = self.client.clone();
        let kind = self.token_type.clone();

        let outcome = self
            .threadpool
            .external_run_async_task(async move { hub.refresh_jwt_token(&kind).await });

        match outcome {
            // The pool could not run the task to completion.
            Err(pool_err) => Err(AuthError::TokenRefreshFailure(pool_err.to_string())),
            // The task ran, but the token request failed.
            Ok(Err(request_err)) => Err(AuthError::TokenRefreshFailure(request_err.to_string())),
            Ok(Ok(token_info)) => Ok(token_info),
        }
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use cas_client::build_http_client;

    use super::HubClient;

    /// Manual smoke test: requests a "read" token, then refreshes a "write" token, printing
    /// both results. Ignored by default because it needs a valid Hub token.
    #[tokio::test]
    #[ignore = "need valid token"]
    async fn test_get_jwt_token() -> Result<()> {
        let client = build_http_client(&None)?;
        let hub_client = HubClient {
            endpoint: "https://xethub-poc.us.dev.moon.huggingface.tech".to_owned(),
            token: "[MASKED]".to_owned(),
            repo_type: "dataset".to_owned(),
            repo_id: "test/t2".to_owned(),
            client,
        };

        let (cas_endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token("read").await?;
        println!("{cas_endpoint}, {jwt_token}, {jwt_token_expiry}");

        let refreshed = hub_client.refresh_jwt_token("write").await?;
        println!("{:?}", refreshed);

        Ok(())
    }
}
Loading

0 comments on commit fc6c3c9

Please sign in to comment.