diff --git a/Cargo.toml b/Cargo.toml index 8f99e3d..e10e27f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "clashrup" description = "Simple CLI to manage your systemd clash.service and config subscriptions on Linux." -version = "0.2.2" +version = "0.2.3" edition = "2021" readme = "README.md" license = "MIT" @@ -18,9 +18,12 @@ clap = { version = "4.0.32", features = ["derive"] } colored = "2.0.0" serde = { version = "1.0.152", features = ["derive"] } toml = "0.5.10" -reqwest = { version = "0.11", features = ["blocking"] } flate2 = "1.0" shellexpand = "3.0.0" openssl = { version = "0.10", features = ["vendored"] } serde_yaml = "0.9.16" local-ip-address = "0.5.0" +reqwest = { version = "0.11", features = ["stream"] } +futures-util = "0.3" +indicatif = "0.17" +tokio = { version = "1.24", features = ["full"] } diff --git a/src/main.rs b/src/main.rs index 49f08ee..0797c07 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ use clap::Parser; use clap::Subcommand; use colored::Colorize; use local_ip_address::local_ip; +use reqwest::Client; use shellexpand::tilde; use config::apply_clash_override; @@ -68,7 +69,8 @@ enum ProxyCommands { Unset, } -fn main() { +#[tokio::main] +async fn main() { let args = Args::parse(); let prefix = "clashrup:"; let config_path = tilde(&args.clashrup_config).to_string(); @@ -91,6 +93,9 @@ fn main() { let clash_target_service_path = tilde(&format!("{}/clash.service", config.user_systemd_root)).to_string(); + // Reuse http client for file download + let client = Client::new(); + match &args.command { Some(Commands::Setup) => { // Attempt to download and setup clash binary if needed @@ -109,7 +114,9 @@ fn main() { } // Download clash binary and set permission to executable - download_file(&config.remote_clash_binary_url, clash_gzipped_path); + download_file(&client, &config.remote_clash_binary_url, clash_gzipped_path) + .await + .unwrap(); extract_gzip(clash_gzipped_path, &clash_target_binary_path, prefix); let executable = fs::Permissions::from_mode(0o755); @@ -117,11 +124,19 @@ fn main() { } // Download remote clash config and apply override - download_file(&config.remote_config_url, &clash_target_config_path); + download_file( + &client, + &config.remote_config_url, + &clash_target_config_path, + ) + .await + .unwrap(); apply_clash_override(&clash_target_config_path, &config.clash_config); // Download remote Country.mmdb - download_file(&config.remote_mmdb_url, &clash_target_mmdb_path); + download_file(&client, &config.remote_mmdb_url, &clash_target_mmdb_path) + .await + .unwrap(); // Create clash.service systemd file create_clash_service( @@ -136,12 +151,20 @@ fn main() { } Some(Commands::Update) => { // Download remote clash config and apply override - download_file(&config.remote_config_url, &clash_target_config_path); + download_file( + &client, + &config.remote_config_url, + &clash_target_config_path, + ) + .await + .unwrap(); apply_clash_override(&clash_target_config_path, &config.clash_config); println!("{} Updated and applied config overrides", prefix.yellow()); // Download remote Country.mmdb - download_file(&config.remote_mmdb_url, &clash_target_mmdb_path); + download_file(&client, &config.remote_mmdb_url, &clash_target_mmdb_path) + .await + .unwrap(); // Restart clash systemd service println!("{} Restart clash.service", prefix.green()); diff --git a/src/utils.rs b/src/utils.rs index 02a1c11..7cc0775 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,31 +1,86 @@ +use std::cmp::min; use std::fs; +use std::fs::File; use std::io; +use std::io::Write; use std::path::Path; use colored::Colorize; use flate2::read::GzDecoder; +use futures_util::StreamExt; +use indicatif::ProgressBar; +use indicatif::ProgressStyle; +use reqwest::Client; -pub fn download_file(url: &str, path: &str) { +/// Download file from url to path with a reusable http client. +/// +/// Renders a progress bar if content-length is available from the url headers provided. If not, +/// renders a spinner to indicate that something is downloading. +/// +/// With reference from: +/// * https://github.com/mihaigalos/tutorials/blob/800d5acbc333fd4068622e9b3d870cb5b7d34e12/rust/download_with_progressbar/src/main.rs +/// * https://github.com/console-rs/indicatif/blob/2954b1a24ac5f1900a7861992e4825bff643c9e2/examples/yarnish.rs +/// +/// Note: Allow `clippy::unused_io_amount` because we are writing downloaded chunks on the fly. +#[allow(clippy::unused_io_amount)] +pub async fn download_file(client: &Client, url: &str, path: &str) -> Result<(), String> { // Create parent directory for download destination if not exists let parent_dir = Path::new(path).parent().unwrap(); if !parent_dir.exists() { fs::create_dir_all(parent_dir).unwrap(); } - // Download file - println!( - "{} Downloading from {}", - "download:".blue(), - url.underline().yellow() - ); - let mut resp = reqwest::blocking::get(url).unwrap(); - let mut file = fs::File::create(path).unwrap(); - resp.copy_to(&mut file).unwrap(); - println!( - "{} Downloaded to {}", - "download:".blue(), - path.underline().yellow() - ); + // Create shared http client for multiple downloads when possible + let res = client + .get(url) + .send() + .await + .or(Err(format!("Failed to GET from '{}'", &url)))?; + + // If content length is not available or 0, use a spinner instead of a progress bar + let total_size = res.content_length().unwrap_or(0); + let pb = ProgressBar::new(total_size); + + let bar_style = ProgressStyle::with_template( + "{prefix:.blue}: {msg}\n {elapsed_precise} [{bar:30.white/blue}] \ + {bytes}/{total_bytes} ({bytes_per_sec}, {eta})", + ) + .unwrap() + .progress_chars("- "); + let spinner_style = ProgressStyle::with_template( + "{prefix:.blue}: {wide_msg}\n \ + {spinner} {elapsed_precise} - Download speed {bytes_per_sec}", + ) + .unwrap(); + + if total_size == 0 { + pb.set_style(spinner_style); + } else { + pb.set_style(bar_style); + } + pb.set_prefix("download"); + pb.set_message(format!("Downloading {}", url.underline())); + + // Start file download and update progress bar when new data chunk is received + let mut file = File::create(path).unwrap(); + let mut downloaded: u64 = 0; + let mut stream = res.bytes_stream(); + + while let Some(item) = stream.next().await { + let chunk = item.or(Err("Error while downloading file"))?; + + file.write(&chunk).or(Err("Error while writing to file"))?; + if total_size != 0 { + let new = min(downloaded + (chunk.len() as u64), total_size); + downloaded = new; + pb.set_position(new); + } else { + pb.inc(chunk.len() as u64); + } + } + + pb.finish_with_message(format!("Downloaded to {}", path.underline())); + Ok(()) } pub fn delete_file(path: &str, prefix: &str) {