Skip to content

Commit

Permalink
normalize_path: Add Append mode
Browse files Browse the repository at this point in the history
  • Loading branch information
daalfox committed Mar 1, 2025
1 parent d0c522b commit 72be335
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 23 deletions.
18 changes: 18 additions & 0 deletions tower-http/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,17 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
fn trim_trailing_slash(
self,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>;

/// Normalize paths based on the specified `mode`.
///
/// See [`tower_http::normalize_path`] for more details.
///
/// [`tower_http::normalize_path`]: crate::normalize_path
#[cfg(feature = "normalize-path")]
fn normalize_path(
self,
mode: crate::normalize_path::NormalizeMode,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>>;
}

impl<L> crate::sealed::Sealed<L> for ServiceBuilder<L> {}
Expand Down Expand Up @@ -594,4 +605,11 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> {
self.layer(crate::normalize_path::NormalizePathLayer::trim_trailing_slash())
}
#[cfg(feature = "normalize-path")]
fn normalize_path(
self,
mode: crate::normalize_path::NormalizeMode,
) -> ServiceBuilder<Stack<crate::normalize_path::NormalizePathLayer, L>> {
self.layer(crate::normalize_path::NormalizePathLayer::new(mode))
}
}
187 changes: 164 additions & 23 deletions tower-http/src/normalize_path.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
//! Middleware that normalizes paths.
//!
//! Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
//! will be changed to `/foo` before reaching the inner service.
//! Normalizes the request paths based on the provided `NormalizeMode`
//!
//! # Example
//!
//! ```
//! use tower_http::normalize_path::NormalizePathLayer;
//! use tower_http::normalize_path::{NormalizePathLayer, NormalizeMode};
//! use http::{Request, Response, StatusCode};
//! use http_body_util::Full;
//! use bytes::Bytes;
Expand All @@ -22,7 +21,7 @@
//!
//! let mut service = ServiceBuilder::new()
//! // trim trailing slashes from paths
//! .layer(NormalizePathLayer::trim_trailing_slash())
//! .layer(NormalizePathLayer::new(NormalizeMode::Trim))
//! .service_fn(handle);
//!
//! // call the service
Expand All @@ -45,27 +44,47 @@ use std::{
use tower_layer::Layer;
use tower_service::Service;

/// Different modes of normalizing paths
#[derive(Debug, Copy, Clone)]
pub enum NormalizeMode {
/// Normalizes paths by trimming the trailing slashes, e.g. /foo/ -> /foo
Trim,
/// Normalizes paths by appending trailing slash, e.g. /foo -> /foo/
Append,
}

/// Layer that applies [`NormalizePath`] which normalizes paths.
///
/// See the [module docs](self) for more details.
#[derive(Debug, Copy, Clone)]
pub struct NormalizePathLayer {}
pub struct NormalizePathLayer {
mode: NormalizeMode,
}

impl NormalizePathLayer {
/// Create a new [`NormalizePathLayer`].
///
/// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
/// will be changed to `/foo` before reaching the inner service.
pub fn trim_trailing_slash() -> Self {
NormalizePathLayer {}
NormalizePathLayer {
mode: NormalizeMode::Trim,
}
}

/// Create a new [`NormalizePathLayer`].
///
/// Creates a new `NormalizePathLayer` with the specified mode.
pub fn new(mode: NormalizeMode) -> Self {
NormalizePathLayer { mode }
}
}

impl<S> Layer<S> for NormalizePathLayer {
type Service = NormalizePath<S>;

fn layer(&self, inner: S) -> Self::Service {
NormalizePath::trim_trailing_slash(inner)
NormalizePath::new(inner, self.mode)
}
}

Expand All @@ -74,16 +93,16 @@ impl<S> Layer<S> for NormalizePathLayer {
/// See the [module docs](self) for more details.
#[derive(Debug, Copy, Clone)]
pub struct NormalizePath<S> {
mode: NormalizeMode,
inner: S,
}

impl<S> NormalizePath<S> {
/// Create a new [`NormalizePath`].
///
/// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
/// will be changed to `/foo` before reaching the inner service.
pub fn trim_trailing_slash(inner: S) -> Self {
Self { inner }
/// Normalize path based on the specified `mode`
pub fn new(inner: S, mode: NormalizeMode) -> Self {
Self { mode, inner }
}

define_inner_service_accessors!();
Expand All @@ -103,12 +122,15 @@ where
}

fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
normalize_trailing_slash(req.uri_mut());
match self.mode {
NormalizeMode::Trim => trim_trailing_slash(req.uri_mut()),
NormalizeMode::Append => append_trailing_slash(req.uri_mut()),
}
self.inner.call(req)
}
}

fn normalize_trailing_slash(uri: &mut Uri) {
fn trim_trailing_slash(uri: &mut Uri) {
if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
return;
}
Expand Down Expand Up @@ -137,14 +159,48 @@ fn normalize_trailing_slash(uri: &mut Uri) {
}
}

fn append_trailing_slash(uri: &mut Uri) {
if uri.path().ends_with("/") && !uri.path().ends_with("//") {
return;
}

let trimmed = uri.path().trim_matches('/');
let new_path = if trimmed.is_empty() {
"/".to_string()
} else {
format!("/{}/", trimmed)
};

let mut parts = uri.clone().into_parts();

let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
let new_path_and_query = if let Some(query) = path_and_query.query() {
Cow::Owned(format!("{}?{}", new_path, query))
} else {
new_path.into()
}
.parse()
.unwrap();

Some(new_path_and_query)
} else {
Some(new_path.parse().unwrap())
};

parts.path_and_query = new_path_and_query;
if let Ok(new_uri) = Uri::from_parts(parts) {
*uri = new_uri;
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::convert::Infallible;
use tower::{ServiceBuilder, ServiceExt};

#[tokio::test]
async fn works() {
async fn trim_works() {
async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
Ok(Response::new(request.uri().to_string()))
}
Expand All @@ -168,63 +224,148 @@ mod tests {
#[test]
fn is_noop_if_no_trailing_slash() {
let mut uri = "/foo".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[test]
fn maintains_query() {
let mut uri = "/foo/?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn removes_multiple_trailing_slashes() {
let mut uri = "/foo////".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[test]
fn removes_multiple_trailing_slashes_even_with_query() {
let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn is_noop_on_index() {
let mut uri = "/".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn removes_multiple_trailing_slashes_on_index() {
let mut uri = "////".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn removes_multiple_trailing_slashes_on_index_even_with_query() {
let mut uri = "////?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/?a=a");
}

#[test]
fn removes_multiple_preceding_slashes_even_with_query() {
let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo?a=a");
}

#[test]
fn removes_multiple_preceding_slashes() {
let mut uri = "///foo".parse::<Uri>().unwrap();
normalize_trailing_slash(&mut uri);
trim_trailing_slash(&mut uri);
assert_eq!(uri, "/foo");
}

#[tokio::test]
async fn append_works() {
async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
Ok(Response::new(request.uri().to_string()))
}

let mut svc = ServiceBuilder::new()
.layer(NormalizePathLayer::new(NormalizeMode::Trim))
.service_fn(handle);

let body = svc
.ready()
.await
.unwrap()
.call(Request::builder().uri("/foo").body(()).unwrap())
.await
.unwrap()
.into_body();

assert_eq!(body, "/foo/");
}

#[test]
fn is_noop_if_trailing_slash() {
let mut uri = "/foo/".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}

#[test]
fn append_maintains_query() {
let mut uri = "/foo?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_only_keeps_one_slash() {
let mut uri = "/foo////".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}

#[test]
fn append_only_keeps_one_slash_even_with_query() {
let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_is_noop_on_index() {
let mut uri = "/".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn append_removes_multiple_trailing_slashes_on_index() {
let mut uri = "////".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/");
}

#[test]
fn append_removes_multiple_trailing_slashes_on_index_even_with_query() {
let mut uri = "////?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/?a=a");
}

#[test]
fn append_removes_multiple_preceding_slashes_even_with_query() {
let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/?a=a");
}

#[test]
fn append_removes_multiple_preceding_slashes() {
let mut uri = "///foo".parse::<Uri>().unwrap();
append_trailing_slash(&mut uri);
assert_eq!(uri, "/foo/");
}
}

0 comments on commit 72be335

Please sign in to comment.