-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Expose a convenient function to migrate files in repo scanner (#167)
- Loading branch information
Showing
7 changed files
with
224 additions
and
162 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
use std::sync::Arc; | ||
|
||
use anyhow::Result; | ||
use reqwest_middleware::ClientWithMiddleware; | ||
use utils::auth::{TokenInfo, TokenRefresher}; | ||
use utils::errors::AuthError; | ||
use xet_threadpool::ThreadPool; | ||
|
||
#[derive(Debug)] | ||
pub struct HubClient { | ||
pub endpoint: String, | ||
pub token: String, | ||
pub repo_type: String, | ||
pub repo_id: String, | ||
pub client: ClientWithMiddleware, | ||
} | ||
|
||
impl HubClient { | ||
// Get CAS access token from Hub access token. "token_type" is either "read" or "write". | ||
pub async fn get_jwt_token(&self, token_type: &str) -> Result<(String, String, u64)> { | ||
let endpoint = self.endpoint.as_str(); | ||
let repo_type = self.repo_type.as_str(); | ||
let repo_id = self.repo_id.as_str(); | ||
|
||
let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main"); | ||
|
||
let response = self | ||
.client | ||
.get(url) | ||
.bearer_auth(&self.token) | ||
.header("user-agent", "xtool") | ||
.send() | ||
.await?; | ||
|
||
let headers = response.headers(); | ||
let cas_endpoint = headers["X-Xet-Cas-Url"].to_str()?.to_owned(); | ||
let jwt_token: String = headers["X-Xet-Access-Token"].to_str()?.to_owned(); | ||
let jwt_token_expiry: u64 = headers["X-Xet-Token-Expiration"].to_str()?.parse()?; | ||
|
||
Ok((cas_endpoint, jwt_token, jwt_token_expiry)) | ||
} | ||
|
||
async fn refresh_jwt_token(&self, token_type: &str) -> Result<(String, u64)> { | ||
let (_, jwt_token, jwt_token_expiry) = self.get_jwt_token(token_type).await?; | ||
|
||
Ok((jwt_token, jwt_token_expiry)) | ||
} | ||
} | ||
|
||
#[derive(Debug)] | ||
pub struct HubClientTokenRefresher { | ||
pub threadpool: Arc<ThreadPool>, | ||
pub token_type: String, | ||
pub client: Arc<HubClient>, | ||
} | ||
|
||
impl TokenRefresher for HubClientTokenRefresher { | ||
fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> { | ||
let client = self.client.clone(); | ||
let token_type = self.token_type.clone(); | ||
let ret = self | ||
.threadpool | ||
.external_run_async_task(async move { client.refresh_jwt_token(&token_type).await }) | ||
.map_err(|e| AuthError::TokenRefreshFailure(e.to_string()))? | ||
.map_err(|e| AuthError::TokenRefreshFailure(e.to_string()))?; | ||
Ok(ret) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use anyhow::Result; | ||
use cas_client::build_http_client; | ||
|
||
use super::HubClient; | ||
|
||
#[tokio::test] | ||
#[ignore = "need valid token"] | ||
async fn test_get_jwt_token() -> Result<()> { | ||
let hub_client = HubClient { | ||
endpoint: "https://xethub-poc.us.dev.moon.huggingface.tech".to_owned(), | ||
token: "[MASKED]".to_owned(), | ||
repo_type: "dataset".to_owned(), | ||
repo_id: "test/t2".to_owned(), | ||
client: build_http_client(&None)?, | ||
}; | ||
|
||
let (cas_endpoint, jwt_token, jwt_token_expiry) = hub_client.get_jwt_token("read").await?; | ||
|
||
println!("{cas_endpoint}, {jwt_token}, {jwt_token_expiry}"); | ||
|
||
println!("{:?}", hub_client.refresh_jwt_token("write").await?); | ||
|
||
Ok(()) | ||
} | ||
} |
Oops, something went wrong.