From 72be3355b8041ef2358fe7c36b6b65c7f508b5f6 Mon Sep 17 00:00:00 2001 From: daalfox Date: Sun, 2 Mar 2025 02:29:52 +0330 Subject: [PATCH] normalize_path: Add `Append` mode --- tower-http/src/builder.rs | 18 +++ tower-http/src/normalize_path.rs | 187 +++++++++++++++++++++++++++---- 2 files changed, 182 insertions(+), 23 deletions(-) diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index 58b789f2..cce1311a 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -364,6 +364,17 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { fn trim_trailing_slash( self, ) -> ServiceBuilder>; + + /// Normalize paths based on the specified `mode`. + /// + /// See [`tower_http::normalize_path`] for more details. + /// + /// [`tower_http::normalize_path`]: crate::normalize_path + #[cfg(feature = "normalize-path")] + fn normalize_path( + self, + mode: crate::normalize_path::NormalizeMode, + ) -> ServiceBuilder>; } impl crate::sealed::Sealed for ServiceBuilder {} @@ -594,4 +605,11 @@ impl ServiceBuilderExt for ServiceBuilder { ) -> ServiceBuilder> { self.layer(crate::normalize_path::NormalizePathLayer::trim_trailing_slash()) } + #[cfg(feature = "normalize-path")] + fn normalize_path( + self, + mode: crate::normalize_path::NormalizeMode, + ) -> ServiceBuilder> { + self.layer(crate::normalize_path::NormalizePathLayer::new(mode)) + } } diff --git a/tower-http/src/normalize_path.rs b/tower-http/src/normalize_path.rs index efc7be52..92e3fc70 100644 --- a/tower-http/src/normalize_path.rs +++ b/tower-http/src/normalize_path.rs @@ -1,12 +1,11 @@ //! Middleware that normalizes paths. //! -//! Any trailing slashes from request paths will be removed. For example, a request with `/foo/` -//! will be changed to `/foo` before reaching the inner service. +//! Normalizes the request paths based on the provided `NormalizeMode` //! //! # Example //! //! ``` -//! use tower_http::normalize_path::NormalizePathLayer; +//! use tower_http::normalize_path::{NormalizePathLayer, NormalizeMode}; //! use http::{Request, Response, StatusCode}; //! use http_body_util::Full; //! use bytes::Bytes; @@ -22,7 +21,7 @@ //! //! let mut service = ServiceBuilder::new() //! // trim trailing slashes from paths -//! .layer(NormalizePathLayer::trim_trailing_slash()) +//! .layer(NormalizePathLayer::new(NormalizeMode::Trim)) //! .service_fn(handle); //! //! // call the service @@ -45,11 +44,22 @@ use std::{ use tower_layer::Layer; use tower_service::Service; +/// Different modes of normalizing paths +#[derive(Debug, Copy, Clone)] +pub enum NormalizeMode { + /// Normalizes paths by trimming the trailing slashes, e.g. /foo/ -> /foo + Trim, + /// Normalizes paths by appending trailing slash, e.g. /foo -> /foo/ + Append, +} + /// Layer that applies [`NormalizePath`] which normalizes paths. /// /// See the [module docs](self) for more details. #[derive(Debug, Copy, Clone)] -pub struct NormalizePathLayer {} +pub struct NormalizePathLayer { + mode: NormalizeMode, +} impl NormalizePathLayer { /// Create a new [`NormalizePathLayer`]. @@ -57,7 +67,16 @@ impl NormalizePathLayer { /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` /// will be changed to `/foo` before reaching the inner service. pub fn trim_trailing_slash() -> Self { - NormalizePathLayer {} + NormalizePathLayer { + mode: NormalizeMode::Trim, + } + } + + /// Create a new [`NormalizePathLayer`]. + /// + /// Creates a new `NormalizePathLayer` with the specified mode. + pub fn new(mode: NormalizeMode) -> Self { + NormalizePathLayer { mode } } } @@ -65,7 +84,7 @@ impl Layer for NormalizePathLayer { type Service = NormalizePath; fn layer(&self, inner: S) -> Self::Service { - NormalizePath::trim_trailing_slash(inner) + NormalizePath::new(inner, self.mode) } } @@ -74,16 +93,16 @@ impl Layer for NormalizePathLayer { /// See the [module docs](self) for more details. #[derive(Debug, Copy, Clone)] pub struct NormalizePath { + mode: NormalizeMode, inner: S, } impl NormalizePath { /// Create a new [`NormalizePath`]. /// - /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/` - /// will be changed to `/foo` before reaching the inner service. - pub fn trim_trailing_slash(inner: S) -> Self { - Self { inner } + /// Normalize path based on the specified `mode` + pub fn new(inner: S, mode: NormalizeMode) -> Self { + Self { mode, inner } } define_inner_service_accessors!(); @@ -103,12 +122,15 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - normalize_trailing_slash(req.uri_mut()); + match self.mode { + NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()), + NormalizeMode::Append => append_trailing_slash(req.uri_mut()), + } self.inner.call(req) } } -fn normalize_trailing_slash(uri: &mut Uri) { +fn trim_trailing_slash(uri: &mut Uri) { if !uri.path().ends_with('/') && !uri.path().starts_with("//") { return; } @@ -137,6 +159,40 @@ fn normalize_trailing_slash(uri: &mut Uri) { } } +fn append_trailing_slash(uri: &mut Uri) { + if uri.path().ends_with("/") && !uri.path().ends_with("//") { + return; + } + + let trimmed = uri.path().trim_matches('/'); + let new_path = if trimmed.is_empty() { + "/".to_string() + } else { + format!("/{}/", trimmed) + }; + + let mut parts = uri.clone().into_parts(); + + let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query { + let new_path_and_query = if let Some(query) = path_and_query.query() { + Cow::Owned(format!("{}?{}", new_path, query)) + } else { + new_path.into() + } + .parse() + .unwrap(); + + Some(new_path_and_query) + } else { + Some(new_path.parse().unwrap()) + }; + + parts.path_and_query = new_path_and_query; + if let Ok(new_uri) = Uri::from_parts(parts) { + *uri = new_uri; + } +} + #[cfg(test)] mod tests { use super::*; @@ -144,7 +200,7 @@ mod tests { use tower::{ServiceBuilder, ServiceExt}; #[tokio::test] - async fn works() { + async fn trim_works() { async fn handle(request: Request<()>) -> Result, Infallible> { Ok(Response::new(request.uri().to_string())) } @@ -168,63 +224,148 @@ mod tests { #[test] fn is_noop_if_no_trailing_slash() { let mut uri = "/foo".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } #[test] fn maintains_query() { let mut uri = "/foo/?a=a".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn removes_multiple_trailing_slashes() { let mut uri = "/foo////".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } #[test] fn removes_multiple_trailing_slashes_even_with_query() { let mut uri = "/foo////?a=a".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn is_noop_on_index() { let mut uri = "/".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/"); } #[test] fn removes_multiple_trailing_slashes_on_index() { let mut uri = "////".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/"); } #[test] fn removes_multiple_trailing_slashes_on_index_even_with_query() { let mut uri = "////?a=a".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/?a=a"); } #[test] fn removes_multiple_preceding_slashes_even_with_query() { let mut uri = "///foo//?a=a".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo?a=a"); } #[test] fn removes_multiple_preceding_slashes() { let mut uri = "///foo".parse::().unwrap(); - normalize_trailing_slash(&mut uri); + trim_trailing_slash(&mut uri); assert_eq!(uri, "/foo"); } + + #[tokio::test] + async fn append_works() { + async fn handle(request: Request<()>) -> Result, Infallible> { + Ok(Response::new(request.uri().to_string())) + } + + let mut svc = ServiceBuilder::new() + .layer(NormalizePathLayer::new(NormalizeMode::Trim)) + .service_fn(handle); + + let body = svc + .ready() + .await + .unwrap() + .call(Request::builder().uri("/foo").body(()).unwrap()) + .await + .unwrap() + .into_body(); + + assert_eq!(body, "/foo/"); + } + + #[test] + fn is_noop_if_trailing_slash() { + let mut uri = "/foo/".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } + + #[test] + fn append_maintains_query() { + let mut uri = "/foo?a=a".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_only_keeps_one_slash() { + let mut uri = "/foo////".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } + + #[test] + fn append_only_keeps_one_slash_even_with_query() { + let mut uri = "/foo////?a=a".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_is_noop_on_index() { + let mut uri = "/".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn append_removes_multiple_trailing_slashes_on_index() { + let mut uri = "////".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/"); + } + + #[test] + fn append_removes_multiple_trailing_slashes_on_index_even_with_query() { + let mut uri = "////?a=a".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/?a=a"); + } + + #[test] + fn append_removes_multiple_preceding_slashes_even_with_query() { + let mut uri = "///foo//?a=a".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/?a=a"); + } + + #[test] + fn append_removes_multiple_preceding_slashes() { + let mut uri = "///foo".parse::().unwrap(); + append_trailing_slash(&mut uri); + assert_eq!(uri, "/foo/"); + } }