diff --git a/src/lib.rs b/src/lib.rs index a0ae373..da46ac3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,14 @@ pub async fn start() -> io::Result<()> { println!("{}[v{}] - {}", &cargo.pkg_name, &cargo.pkg_version, &cargo.description); squire::ascii_art::random(); + // Log a warning message for max payload size beyond 1 GB + if config.max_payload_size > 1024 * 1024 * 1024 { + // Since the default is just 100 MB, the only way to get here is to have an env var + log::warn!("Max payload size is set to '{}' which exceeds the optimal upload size.", + std::env::var("max_payload_size").unwrap()); + log::warn!("Please consider network bandwidth and latency, before using RuStream to upload such high-volume data."); + } + if config.secure_session { log::warn!( "Secure session is turned on! This means that the server can ONLY be hosted via HTTPS or localhost" @@ -61,14 +69,13 @@ pub async fn start() -> io::Result<()> { The closure is defining the configuration for the Actix web server. The purpose of the closure is to configure the server before it starts listening for incoming requests. */ - let max_payload_size = 10 * 1024 * 1024 * 1024; // 10 GB let application = move || { App::new() // Creates a new Actix web application .app_data(web::Data::new(config_clone.clone())) .app_data(web::Data::new(jinja.clone())) .app_data(web::Data::new(fernet.clone())) .app_data(web::Data::new(session.clone())) - .app_data(web::PayloadConfig::default().limit(max_payload_size)) + .app_data(web::PayloadConfig::default().limit(config_clone.max_payload_size)) .wrap(squire::middleware::get_cors(config_clone.websites.clone())) .wrap(middleware::Logger::default()) // Adds a default logger middleware to the application .service(routes::basics::health) // Registers a service for handling requests @@ -84,8 +91,8 @@ pub async fn start() -> io::Result<()> { .service(routes::upload::save_files) }; let server = HttpServer::new(application) - .workers(config.workers as usize) - .max_connections(config.max_connections as usize); + .workers(config.workers) + .max_connections(config.max_connections); // Reference: https://actix.rs/docs/http2/ if config.cert_file.exists() && config.key_file.exists() { log::info!("Binding SSL certificate to serve over HTTPS"); diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 47914b2..a396c95 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -55,7 +55,7 @@ pub async fn login(request: HttpRequest, let payload = serde_json::to_string(&mapped).unwrap(); let encrypted_payload = fernet.encrypt(payload.as_bytes()); - let cookie_duration = Duration::seconds(config.session_duration as i64); + let cookie_duration = Duration::seconds(config.session_duration); let expiration = OffsetDateTime::now_utc() + cookie_duration; let base_cookie = Cookie::build("session_token", encrypted_payload) .http_only(true) diff --git a/src/squire/authenticator.rs b/src/squire/authenticator.rs index b6bebf5..4563503 100644 --- a/src/squire/authenticator.rs +++ b/src/squire/authenticator.rs @@ -167,12 +167,12 @@ pub fn verify_token( username, }; } - if current_time - timestamp > config.session_duration as i64 { + if current_time - timestamp > config.session_duration { return AuthToken { ok: false, detail: "Session Expired".to_string(), username }; } AuthToken { ok: true, - detail: format!("Session valid for {}s", timestamp + config.session_duration as i64 - current_time), + detail: format!("Session valid for {}s", timestamp + config.session_duration - current_time), username, } } else { diff --git a/src/squire/settings.rs b/src/squire/settings.rs index d00c563..f1551cc 100644 --- a/src/squire/settings.rs +++ b/src/squire/settings.rs @@ -16,16 +16,18 @@ pub struct Config { /// Host IP address for media streaming. pub media_host: String, /// Port number for hosting the application. - pub media_port: i32, + pub media_port: u16, /// Duration of a session in seconds. - pub session_duration: i32, + pub session_duration: i64, /// List of supported file formats. pub file_formats: Vec, /// Number of worker threads to spin up the server. - pub workers: i32, + pub workers: usize, /// Maximum number of concurrent connections. - pub max_connections: i32, + pub max_connections: usize, + /// Max payload allowed by the server in request body. + pub max_payload_size: usize, /// List of websites (supports regex) to add to CORS configuration. pub websites: Vec, @@ -38,13 +40,13 @@ pub struct Config { pub cert_file: path::PathBuf, } -/// Returns the default value for debug flag +/// Returns the default value for debug flag. pub fn default_debug() -> bool { false } -/// Returns the default value for utc_logging +/// Returns the default value for UTC logging. pub fn default_utc_logging() -> bool { true } -/// Returns the default value for ssl files +/// Returns the default value for SSL files. pub fn default_ssl() -> path::PathBuf { path::PathBuf::new() } /// Returns the default media host based on the local machine's IP address. @@ -63,22 +65,22 @@ pub fn default_media_host() -> String { "localhost".to_string() } -/// Returns the default media port (8000). -pub fn default_media_port() -> i32 { 8000 } +/// Returns the default media port (8000) +pub fn default_media_port() -> u16 { 8000 } -/// Returns the default session duration (3600 seconds). -pub fn default_session_duration() -> i32 { 3600 } +/// Returns the default session duration (3600 seconds) +pub fn default_session_duration() -> i64 { 3600 } /// Returns the file formats supported by default. pub fn default_file_formats() -> Vec { vec!["mp4".to_string(), "mov".to_string(), "jpg".to_string(), "jpeg".to_string()] } -/// Returns the default number of worker threads (half of logical cores). -pub fn default_workers() -> i32 { +/// Returns the default number of worker threads (half of logical cores) +pub fn default_workers() -> usize { let logical_cores = thread::available_parallelism(); match logical_cores { - Ok(cores) => cores.get() as i32 / 2, + Ok(cores) => cores.get() / 2, Err(err) => { log::error!("{}", err); 3 @@ -86,10 +88,13 @@ pub fn default_workers() -> i32 { } } -/// Returns the default maximum number of concurrent connections (3). -pub fn default_max_connections() -> i32 { 3 } +/// Returns the default maximum number of concurrent connections (3) +pub fn default_max_connections() -> usize { 3 } -/// Returns an empty list as the default website (CORS configuration). +/// Returns the default max payload size (100 MB) +pub fn default_max_payload_size() -> usize { 100 * 1024 * 1024 } + +/// Returns an empty list as the default website (CORS configuration) pub fn default_websites() -> Vec { Vec::new() } /// Returns the default value for secure_session diff --git a/src/squire/startup.rs b/src/squire/startup.rs index 65828f1..78ccd91 100644 --- a/src/squire/startup.rs +++ b/src/squire/startup.rs @@ -1,5 +1,4 @@ use std; -use std::ffi::OsStr; use std::io::Write; use chrono::{DateTime, Local, Utc}; @@ -110,7 +109,7 @@ fn parse_bool(key: &str) -> Option { } } -/// Extracts the env var by key and parses it as a `i32` +/// Extracts the env var by key and parses it as a `i64` /// /// # Arguments /// @@ -118,17 +117,67 @@ fn parse_bool(key: &str) -> Option { /// /// # Returns /// -/// Returns an `Option` if the value is available. +/// Returns an `Option` if the value is available. /// /// # Panics /// /// If the value is present, but it is an invalid data-type. -fn parse_i32(key: &str) -> Option { +fn parse_i64(key: &str) -> Option { match std::env::var(key) { Ok(val) => match val.parse() { Ok(parsed) => Some(parsed), Err(_) => { - panic!("\n{}\n\texpected i32, received '{}' [value=invalid]\n", key, val); + panic!("\n{}\n\texpected i64, received '{}' [value=invalid]\n", key, val); + } + }, + Err(_) => None, + } +} + +/// Extracts the env var by key and parses it as a `u16` +/// +/// # Arguments +/// +/// * `key` - Key for the environment variable. +/// +/// # Returns +/// +/// Returns an `Option` if the value is available. +/// +/// # Panics +/// +/// If the value is present, but it is an invalid data-type. +fn parse_u16(key: &str) -> Option { + match std::env::var(key) { + Ok(val) => match val.parse() { + Ok(parsed) => Some(parsed), + Err(_) => { + panic!("\n{}\n\texpected u16, received '{}' [value=invalid]\n", key, val); + } + }, + Err(_) => None, + } +} + +/// Extracts the env var by key and parses it as a `usize` +/// +/// # Arguments +/// +/// * `key` - Key for the environment variable. +/// +/// # Returns +/// +/// Returns an `Option` if the value is available. +/// +/// # Panics +/// +/// If the value is present, but it is an invalid data-type. +fn parse_usize(key: &str) -> Option { + match std::env::var(key) { + Ok(val) => match val.parse() { + Ok(parsed) => Some(parsed), + Err(_) => { + panic!("\n{}\n\texpected usize, received '{}' [value=invalid]\n", key, val); } }, Err(_) => None, @@ -180,6 +229,76 @@ fn parse_path(key: &str) -> Option { } } +/// Parses the maximum payload size from human-readable memory format to bytes. +/// +/// - `key` - Key for the environment variable. +/// +/// ## See Also +/// +/// - This function handles internal panic gracefully, in the most detailed way possible. +/// - Panic outputs are suppressed with a custom hook. +/// - Custom hook is set before wrapping the potentially panicking function inside `catch_unwind`. +/// - Custom hook is reset later, so the future panics and go uncaught. +/// - Error message from panic payload is also further processed, to get a detailed reason for panic. +/// +/// # Returns +/// +/// Returns an option of usize if the value is parsable and within the allowed size limit. +fn parse_max_payload(key: &str) -> Option { + match std::env::var(key) { + Ok(value) => { + + let custom_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(|_panic_info| {})); + let result = std::panic::catch_unwind(|| parse_memory(&value)); + std::panic::set_hook(custom_hook); + + match result { + Ok(output) => { + if let Some(value) = output { + Some(value) + } else { + panic!("\n{}\n\texpected format: '100 MB', received '{}' [value=invalid]\n", + key, value); + } + } + Err(panic_payload) => { + if let Some(&error) = panic_payload.downcast_ref::<&str>() { + panic!("\n{}\n\t{} [value=invalid]\n", key, error); + } else if let Some(error) = panic_payload.downcast_ref::() { + panic!("\n{}\n\t{} [value=invalid]\n", key, error); + } else if let Some(error) = panic_payload.downcast_ref::>() { + panic!("\n{}\n\t{:?} [value=invalid]\n", key, error); + } else { + panic!("\n{}\n\tinvalid memory format! unable to parse panic payload [value=invalid]\n", key); + } + } + } + } + Err(_) => { + None + } + } +} + +fn parse_memory(memory: &str) -> Option { + let value = memory.trim(); + let (size_str, unit) = value.split_at(value.len() - 2); + let size: usize = match size_str.strip_suffix(' ').unwrap_or_default().parse() { + Ok(num) => num, + Err(_) => return None, + }; + + match unit.to_lowercase().as_str() { + "zb" => Some(size * 1024 * 1024 * 1024 * 1024 * 1024), + "tb" => Some(size * 1024 * 1024 * 1024 * 1024), + "gb" => Some(size * 1024 * 1024 * 1024), + "mb" => Some(size * 1024 * 1024), + "kb" => Some(size * 1024), + _ => None, + } +} + /// Handler that's responsible to parse all the env vars. /// /// # Returns @@ -190,15 +309,16 @@ fn load_env_vars() -> settings::Config { let debug = parse_bool("debug").unwrap_or(settings::default_debug()); let utc_logging = parse_bool("utc_logging").unwrap_or(settings::default_utc_logging()); let media_host = std::env::var("media_host").unwrap_or(settings::default_media_host()); - let media_port = parse_i32("media_port").unwrap_or(settings::default_media_port()); - let session_duration = parse_i32("session_duration").unwrap_or(settings::default_session_duration()); + let media_port = parse_u16("media_port").unwrap_or(settings::default_media_port()); + let session_duration = parse_i64("session_duration").unwrap_or(settings::default_session_duration()); let file_formats = parse_vec("file_formats").unwrap_or(settings::default_file_formats()); - let workers = parse_i32("workers").unwrap_or(settings::default_workers()); - let max_connections = parse_i32("max_connections").unwrap_or(settings::default_max_connections()); + let workers = parse_usize("workers").unwrap_or(settings::default_workers()); + let max_connections = parse_usize("max_connections").unwrap_or(settings::default_max_connections()); let websites = parse_vec("websites").unwrap_or(settings::default_websites()); let secure_session = parse_bool("secure_session").unwrap_or(settings::default_secure_session()); let key_file = parse_path("key_file").unwrap_or(settings::default_ssl()); let cert_file = parse_path("cert_file").unwrap_or(settings::default_ssl()); + let max_payload_size = parse_max_payload("max_payload_size").unwrap_or(settings::default_max_payload_size()); settings::Config { authorization, media_source, @@ -210,6 +330,7 @@ fn load_env_vars() -> settings::Config { file_formats, workers, max_connections, + max_payload_size, websites, secure_session, key_file, @@ -253,7 +374,7 @@ fn validate_dir_structure(config: &settings::Config, cargo: &Cargo) { let secure_dir = index_vec.last().unwrap(); // secure_parent_path is the secure index's location let secure_parent_path = &index_vec[0..index_vec.len() - 1] - .join(OsStr::new(std::path::MAIN_SEPARATOR_STR)); + .join(std::ffi::OsStr::new(std::path::MAIN_SEPARATOR_STR)); errors.push_str(&format!( "\n{:?}\n\tSecure index directory [{:?}] should be at the root [{:?}] [depth={}, valid=1]\n\ \t> Hint: Either move {:?} within {:?}, [OR] set the 'media_source' to {:?}\n", @@ -284,7 +405,7 @@ fn validate_dir_structure(config: &settings::Config, cargo: &Cargo) { get_time(config.utc_logging), cargo.crate_name, &secure_path.to_str().unwrap()) } - }, + } Err(err) => panic!("{}", err) } }