diff --git a/crates/core/src/catcher.rs b/crates/core/src/catcher.rs index 9b21ef4e7..5f85d99c8 100644 --- a/crates/core/src/catcher.rs +++ b/crates/core/src/catcher.rs @@ -86,16 +86,16 @@ impl Catcher { &mut self.hoops } - /// Add a handler as middleware, it will run the handler in current router or it's descendants - /// handle the request. + /// Add a handler as middleware, it will run the handler when error catched. #[inline] pub fn hoop(mut self, hoop: H) -> Self { self.hoops.push(Arc::new(hoop)); self } - /// Add a handler as middleware, it will run the handler in current router or it's descendants - /// handle the request. This middleware only effective when the filter return true. + /// Add a handler as middleware, it will run the handler when error catched. + /// + /// This middleware only effective when the filter return true. #[inline] pub fn hoop_when(mut self, hoop: H, filter: F) -> Self where diff --git a/crates/core/src/http/response.rs b/crates/core/src/http/response.rs index fb04bee70..4d86c5abc 100644 --- a/crates/core/src/http/response.rs +++ b/crates/core/src/http/response.rs @@ -220,11 +220,15 @@ impl Response { } } + let status_code = status_code.unwrap_or_else(|| match &body { + ResBody::None => StatusCode::NOT_FOUND, + ResBody::Error(e) => e.code, + _ => StatusCode::OK, + }); let mut res = hyper::Response::new(body); *res.extensions_mut() = extensions; *res.headers_mut() = headers; - // Default to a 404 if no response code was set - *res.status_mut() = status_code.unwrap_or(StatusCode::NOT_FOUND); + *res.status_mut() = status_code; res } diff --git a/crates/core/src/service.rs b/crates/core/src/service.rs index 62d98600f..ac021032a 100644 --- a/crates/core/src/service.rs +++ b/crates/core/src/service.rs @@ -10,6 +10,7 @@ use hyper::{Method, Request as HyperRequest, Response as HyperResponse}; use crate::catcher::{write_error_default, Catcher}; use crate::conn::SocketAddr; +use crate::handler::{Handler, WhenHoop}; use crate::http::body::{ReqBody, ResBody}; use crate::http::{Mime, Request, Response, StatusCode}; use crate::routing::{FlowCtrl, PathState, Router}; @@ -22,6 +23,8 @@ pub struct Service { pub router: Arc, /// The catcher of this service. pub catcher: Option>, + /// These hoops will alwways be called when request received. + pub hoops: Vec>, /// The allowed media types of this service. pub allowed_media_types: Arc>, } @@ -36,6 +39,7 @@ impl Service { Service { router: router.into(), catcher: None, + hoops: vec![], allowed_media_types: Arc::new(vec![]), } } @@ -74,6 +78,26 @@ impl Service { self } + /// Add a handler as middleware, it will run the handler when request received. + #[inline] + pub fn hoop(mut self, hoop: H) -> Self { + self.hoops.push(Arc::new(hoop)); + self + } + + /// Add a handler as middleware, it will run the handler when request received. + /// + /// This middleware only effective when the filter return true. + #[inline] + pub fn hoop_when(mut self, hoop: H, filter: F) -> Self + where + H: Handler, + F: Fn(&Request, &Depot) -> bool + Send + Sync + 'static, + { + self.hoops.push(Arc::new(WhenHoop { inner: hoop, filter })); + self + } + /// Sets allowed media types list and returns `Self` for write code chained. /// /// # Example @@ -110,6 +134,7 @@ impl Service { http_scheme, router: self.router.clone(), catcher: self.catcher.clone(), + hoops: self.hoops.clone(), allowed_media_types: self.allowed_media_types.clone(), alt_svc_h3, } @@ -143,6 +168,7 @@ pub struct HyperHandler { pub(crate) http_scheme: Scheme, pub(crate) router: Arc, pub(crate) catcher: Option>, + pub(crate) hoops: Vec>, pub(crate) allowed_media_types: Arc>, pub(crate) alt_svc_h3: Option, } @@ -167,14 +193,22 @@ impl HyperHandler { let mut path_state = PathState::new(req.uri().path()); let router = self.router.clone(); + let hoops = self.hoops.clone(); async move { if let Some(dm) = router.detect(&mut req, &mut path_state) { req.params = path_state.params; - let mut ctrl = FlowCtrl::new([&dm.hoops[..], &[dm.goal]].concat()); + let mut ctrl = FlowCtrl::new([&hoops[..], &dm.hoops[..], &[dm.goal]].concat()); ctrl.call_next(&mut req, &mut depot, &mut res).await; if res.status_code.is_none() { res.status_code = Some(StatusCode::OK); } + } else if !hoops.is_empty() { + req.params = path_state.params; + let mut ctrl = FlowCtrl::new(hoops); + ctrl.call_next(&mut req, &mut depot, &mut res).await; + if res.status_code.is_none() { + res.status_code = Some(StatusCode::NOT_FOUND); + } } else { res.status_code(StatusCode::NOT_FOUND); } diff --git a/crates/extra/src/logging.rs b/crates/extra/src/logging.rs index 36f8bb5d4..9535675d6 100644 --- a/crates/extra/src/logging.rs +++ b/crates/extra/src/logging.rs @@ -5,7 +5,7 @@ use std::time::Instant; use tracing::{Instrument, Level}; -use salvo_core::http::{Request, Response, StatusCode}; +use salvo_core::http::{Request, ResBody, Response, StatusCode}; use salvo_core::{async_trait, Depot, FlowCtrl, Handler}; /// A simple logger middleware. @@ -36,7 +36,11 @@ impl Handler for Logger { ctrl.call_next(req, depot, res).await; let duration = now.elapsed(); - let status = res.status_code.unwrap_or(StatusCode::OK); + let status = res.status_code.unwrap_or_else(|| match &res.body { + ResBody::None => StatusCode::NOT_FOUND, + ResBody::Error(e) => e.code, + _ => StatusCode::OK, + }); tracing::info!( %status, ?duration, diff --git a/examples/cors/src/main.rs b/examples/cors/src/main.rs index 5a0b161fb..f00427537 100644 --- a/examples/cors/src/main.rs +++ b/examples/cors/src/main.rs @@ -21,12 +21,10 @@ async fn backend_server() { .allow_headers("authorization") .into_handler(); - let router = Router::with_hoop(cors.clone()) - .push(Router::with_path("hello").post(hello)) - .options(handler::empty()); + let router = Router::with_path("hello").post(hello); + let service = Service::new(router).hoop(cors); let acceptor = TcpListener::new("0.0.0.0:5600").bind().await; - let service = Service::new(router).catcher(Catcher::default().hoop(cors)); Server::new(acceptor).serve(service).await; } diff --git a/examples/logging/src/main.rs b/examples/logging/src/main.rs index 3a82d6928..2e9c7e14f 100644 --- a/examples/logging/src/main.rs +++ b/examples/logging/src/main.rs @@ -10,8 +10,9 @@ async fn hello() -> &'static str { async fn main() { tracing_subscriber::fmt().init(); - let router = Router::new().hoop(Logger::new()).get(hello); + let router = Router::new().get(hello); + let service = Service::new(router).hoop(Logger::new()); let acceptor = TcpListener::new("0.0.0.0:5800").bind().await; - Server::new(acceptor).serve(router).await; + Server::new(acceptor).serve(service).await; }