Merge branch 'main' into compression-ranges
jplatte authored Jan 26, 2024
2 parents 6b16b28 + cca6c68 commit db8ecc7
Showing 16 changed files with 153 additions and 65 deletions.
9 changes: 4 additions & 5 deletions examples/axum-key-value-store/Cargo.toml
@@ -7,11 +7,10 @@ publish = false
license = "MIT"

[dependencies]
hyper = { version = "0.14.15", features = ["full"] }
tokio = { version = "1.2.0", features = ["full"] }
tokio = { version = "1.32.0", features = ["full"] }
tower = { version = "0.4.13", features = ["full"] }
tower-http = { path = "../../tower-http", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3.11", features = ["env-filter"] }
axum = "0.6"
clap = { version = "4.3.16", features = ["derive"] }
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
axum = "0.7"
clap = { version = "4.4.4", features = ["derive"] }
19 changes: 7 additions & 12 deletions examples/axum-key-value-store/src/main.rs
@@ -1,8 +1,3 @@
fn main() {
eprintln!("this example has not yet been updated to hyper 1.0");
}

/*
use axum::{
body::Bytes,
extract::{Path, State},
@@ -18,6 +13,7 @@ use std::{
sync::{Arc, RwLock},
time::Duration,
};
use tokio::net::TcpListener;
use tower::ServiceBuilder;
use tower_http::{
timeout::TimeoutLayer,
@@ -49,10 +45,12 @@ async fn main() {
// Run our service
let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, config.port));
tracing::info!("Listening on {}", addr);
axum::Server::bind(&addr)
.serve(app().into_make_service())
.await
.expect("server error");
axum::serve(
TcpListener::bind(addr).await.expect("bind error"),
app().into_make_service(),
)
.await
.expect("server error");
}

fn app() -> Router {
@@ -79,8 +77,6 @@ fn app() -> Router {
.sensitive_response_headers(sensitive_headers)
// Set a timeout
.layer(TimeoutLayer::new(Duration::from_secs(10)))
// Box the response body so it implements `Default` which is required by axum
.map_response_body(axum::body::boxed)
// Compress responses
.compression()
// Set a `Content-Type` if there isn't one already.
@@ -113,4 +109,3 @@ async fn set_key(Path(path): Path<String>, state: State<AppState>, value: Bytes)

// See https://github.com/tokio-rs/axum/blob/main/examples/testing/src/main.rs for an example of
// how to test axum apps
*/
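
The updated example follows the axum 0.7 startup pattern: bind a `tokio::net::TcpListener` yourself and hand it to `axum::serve`, which replaces the removed `axum::Server::bind(..).serve(..)` call. A minimal standalone sketch of that pattern, assuming only `axum = "0.7"` and `tokio` with the `full` feature as pinned in the Cargo.toml above:

```rust
use axum::{routing::get, Router};
use tokio::net::TcpListener;

#[tokio::main]
async fn main() {
    let app = Router::new().route("/", get(|| async { "hello" }));

    // axum 0.7: bind the listener explicitly, then pass it to `axum::serve`.
    let listener = TcpListener::bind("0.0.0.0:3000")
        .await
        .expect("bind error");
    axum::serve(listener, app).await.expect("server error");
}
```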
Binary file added test-files/precompressed.txt.zst
27 changes: 19 additions & 8 deletions tower-http/CHANGELOG.md
@@ -7,23 +7,34 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

## Added
## Added:

- None.
- **compression:** Will now send a `vary: accept-encoding` header on compressed responses ([#399])

## Changed
## Fixed

- None.
- **compression:** Skip compression for range requests ([#446])

## Removed
[#399]: https://github.com/tower-rs/tower-http/pull/399
[#446]: https://github.com/tower-rs/tower-http/pull/446

- None.
# 0.5.1 (January 14, 2024)

## Added

- **fs:** Support files precompressed with `zstd` in `ServeFile`
- **trace:** Add default generic parameters for `ResponseBody` and `ResponseFuture` ([#455])
- **trace:** Add type aliases `HttpMakeClassifier` and `GrpcMakeClassifier` ([#455])

## Fixed

- **compression:** Skip compression for range requests ([#446])
- **cors:** Keep Vary headers set by the inner service when setting response headers ([#398])
- **fs:** `ServeDir` now no longer redirects from `/directory` to `/directory/`
if `append_index_html_on_directories` is disabled ([#421])

[#446]: https://github.com/tower-rs/tower-http/pull/446
[#398]: https://github.com/tower-rs/tower-http/pull/398
[#421]: https://github.com/tower-rs/tower-http/pull/421
[#455]: https://github.com/tower-rs/tower-http/pull/455

# 0.5.0 (November 21, 2023)

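
The two unreleased compression entries are observable directly on a wrapped service. The sketch below is not a test from this commit; it assumes tower-http's `compression-gzip` feature plus `tower` (with its `util` feature), `http`, `http-body-util` and `bytes`, and mirrors the changelog wording:

```rust
use bytes::Bytes;
use http::{header, Request, Response};
use http_body_util::Full;
use std::convert::Infallible;
use tower::{service_fn, ServiceBuilder, ServiceExt};
use tower_http::compression::CompressionLayer;

async fn demo() -> Result<(), Box<dyn std::error::Error>> {
    // Inner service returning a body large enough for the default size predicate.
    let svc = ServiceBuilder::new()
        .layer(CompressionLayer::new())
        .service(service_fn(|_req: Request<()>| async {
            Ok::<_, Infallible>(Response::new(Full::new(Bytes::from(vec![b'a'; 512]))))
        }));

    // Compressed responses now also advertise `vary: accept-encoding`.
    let compressed = svc
        .clone()
        .oneshot(
            Request::builder()
                .header(header::ACCEPT_ENCODING, "gzip")
                .body(())?,
        )
        .await?;
    assert_eq!(compressed.headers()["content-encoding"], "gzip");
    assert!(compressed
        .headers()
        .get_all(header::VARY)
        .iter()
        .any(|v| v == "accept-encoding"));

    // Range requests are served uncompressed so byte offsets keep their meaning.
    let ranged = svc
        .oneshot(
            Request::builder()
                .header(header::ACCEPT_ENCODING, "gzip")
                .header(header::RANGE, "bytes=0-99")
                .body(())?,
        )
        .await?;
    assert!(ranged.headers().get(header::CONTENT_ENCODING).is_none());
    Ok(())
}
```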
2 changes: 1 addition & 1 deletion tower-http/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "tower-http"
description = "Tower middleware and utilities for HTTP clients and servers"
version = "0.5.0"
version = "0.5.1"
authors = ["Tower Maintainers <team@tower-rs.com>"]
edition = "2018"
license = "MIT"
13 changes: 4 additions & 9 deletions tower-http/src/builder.rs
@@ -1,8 +1,5 @@
use tower::ServiceBuilder;

#[cfg(feature = "trace")]
use crate::classify::{GrpcErrorsAsFailures, ServerErrorsAsFailures, SharedClassifier};

#[allow(unused_imports)]
use http::header::HeaderName;
#[allow(unused_imports)]
@@ -126,7 +123,7 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
#[cfg(feature = "trace")]
fn trace_for_http(
self,
) -> ServiceBuilder<Stack<crate::trace::TraceLayer<SharedClassifier<ServerErrorsAsFailures>>, L>>;
) -> ServiceBuilder<Stack<crate::trace::TraceLayer<crate::trace::HttpMakeClassifier>, L>>;

/// High level tracing that classifies responses using gRPC headers.
///
@@ -140,7 +137,7 @@ pub trait ServiceBuilderExt<L>: crate::sealed::Sealed<L> + Sized {
#[cfg(feature = "trace")]
fn trace_for_grpc(
self,
) -> ServiceBuilder<Stack<crate::trace::TraceLayer<SharedClassifier<GrpcErrorsAsFailures>>, L>>;
) -> ServiceBuilder<Stack<crate::trace::TraceLayer<crate::trace::GrpcMakeClassifier>, L>>;

/// Follow redirect responses using the [`Standard`] policy.
///
@@ -427,16 +424,14 @@ impl<L> ServiceBuilderExt<L> for ServiceBuilder<L> {
#[cfg(feature = "trace")]
fn trace_for_http(
self,
) -> ServiceBuilder<Stack<crate::trace::TraceLayer<SharedClassifier<ServerErrorsAsFailures>>, L>>
{
) -> ServiceBuilder<Stack<crate::trace::TraceLayer<crate::trace::HttpMakeClassifier>, L>> {
self.layer(crate::trace::TraceLayer::new_for_http())
}

#[cfg(feature = "trace")]
fn trace_for_grpc(
self,
) -> ServiceBuilder<Stack<crate::trace::TraceLayer<SharedClassifier<GrpcErrorsAsFailures>>, L>>
{
) -> ServiceBuilder<Stack<crate::trace::TraceLayer<crate::trace::GrpcMakeClassifier>, L>> {
self.layer(crate::trace::TraceLayer::new_for_grpc())
}

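
The only visible change in `builder.rs` is that the spelled-out return types now use the `HttpMakeClassifier` and `GrpcMakeClassifier` aliases added to the trace module; call sites are untouched. A short sketch of what the alias buys when naming these types yourself, assuming the `trace` feature is enabled:

```rust
use tower::layer::util::{Identity, Stack};
use tower::ServiceBuilder;
use tower_http::trace::{HttpMakeClassifier, TraceLayer};
use tower_http::ServiceBuilderExt as _;

// Previously this return type had to be written as
// `TraceLayer<SharedClassifier<ServerErrorsAsFailures>>`.
fn http_tracing() -> TraceLayer<HttpMakeClassifier> {
    TraceLayer::new_for_http()
}

// `trace_for_http()` itself behaves exactly as before.
fn traced_builder() -> ServiceBuilder<Stack<TraceLayer<HttpMakeClassifier>, Identity>> {
    ServiceBuilder::new().trace_for_http()
}
```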
6 changes: 6 additions & 0 deletions tower-http/src/compression/future.rs
@@ -48,6 +48,12 @@

let (mut parts, body) = res.into_parts();

if should_compress {
parts
.headers
.append(header::VARY, header::ACCEPT_ENCODING.into());
}

let body = match (should_compress, self.encoding) {
// if compression is _not_ supported or the client doesn't accept it
(false, _) | (_, Encoding::Identity) => {
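
Using `append` rather than `insert` matters here: the response may already carry `Vary` values set by the inner service, and `append` adds `accept-encoding` as an additional value instead of replacing them. A standalone sketch of that `HeaderMap` behaviour, using only the `http` crate (not tower-http code):

```rust
use http::header::{HeaderMap, HeaderValue, VARY};

fn append_keeps_existing_vary() {
    let mut headers = HeaderMap::new();
    headers.insert(VARY, HeaderValue::from_static("origin"));

    // `append` adds a second value; `insert` would have replaced "origin".
    headers.append(VARY, HeaderValue::from_static("accept-encoding"));

    let mut values = headers.get_all(VARY).iter();
    assert_eq!(values.next().unwrap(), "origin");
    assert_eq!(values.next().unwrap(), "accept-encoding");
    assert!(values.next().is_none());
}
```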
23 changes: 8 additions & 15 deletions tower-http/src/cors/mod.rs
@@ -618,24 +618,10 @@

// These headers are applied to both preflight and subsequent regular CORS requests:
// https://fetch.spec.whatwg.org/#http-responses

headers.extend(self.layer.allow_origin.to_header(origin, &parts));
headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
headers.extend(self.layer.allow_private_network.to_header(origin, &parts));

let mut vary_headers = self.layer.vary.values();
if let Some(first) = vary_headers.next() {
let mut header = match headers.entry(header::VARY) {
header::Entry::Occupied(_) => {
unreachable!("no vary header inserted up to this point")
}
header::Entry::Vacant(v) => v.insert_entry(first),
};

for val in vary_headers {
header.append(val);
}
}
headers.extend(self.layer.vary.to_header());

// Return results immediately upon preflight request
if parts.method == Method::OPTIONS {
@@ -695,6 +681,13 @@
match self.project().inner.project() {
KindProj::CorsCall { future, headers } => {
let mut response: Response<B> = ready!(future.poll(cx))?;

// vary header can have multiple values, don't overwrite
// previously-set value(s).
if let Some(vary) = headers.remove(header::VARY) {
response.headers_mut().append(header::VARY, vary);
}
// extend will overwrite previous headers of remaining names
response.headers_mut().extend(headers.drain());

Poll::Ready(Ok(response))
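
The remove-then-append step above exists because `HeaderMap::extend` replaces every existing value for each key it receives; that is the desired behaviour for the remaining CORS headers, but it is exactly what used to clobber a `Vary` header set by the inner service. A minimal sketch of that semantics, again with plain `http` types rather than tower-http code:

```rust
use http::header::{HeaderMap, HeaderValue, VARY};

fn extend_replaces_existing_values() {
    let mut response_headers = HeaderMap::new();
    response_headers.insert(VARY, HeaderValue::from_static("accept"));

    let mut cors_headers = HeaderMap::new();
    cors_headers.insert(VARY, HeaderValue::from_static("origin"));

    // Draining straight into the response discards the value the inner
    // service set; this is the behaviour #398 fixes for `Vary`.
    response_headers.extend(cors_headers.drain());
    assert_eq!(response_headers.get_all(VARY).iter().count(), 1);
    assert_eq!(response_headers[VARY], "origin");
}
```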
33 changes: 33 additions & 0 deletions tower-http/src/cors/tests.rs
@@ -0,0 +1,33 @@
use std::convert::Infallible;

use http::{header, HeaderValue, Request, Response};
use hyper::Body;
use tower::{service_fn, util::ServiceExt, Layer};

use crate::cors::CorsLayer;

#[tokio::test]
#[allow(
clippy::declare_interior_mutable_const,
clippy::borrow_interior_mutable_const
)]
async fn vary_set_by_inner_service() {
const CUSTOM_VARY_HEADERS: HeaderValue = HeaderValue::from_static("accept, accept-encoding");
const PERMISSIVE_CORS_VARY_HEADERS: HeaderValue = HeaderValue::from_static(
"origin, access-control-request-method, access-control-request-headers",
);

async fn inner_svc(_: Request<Body>) -> Result<Response<Body>, Infallible> {
Ok(Response::builder()
.header(header::VARY, CUSTOM_VARY_HEADERS)
.body(Body::empty())
.unwrap())
}

let svc = CorsLayer::permissive().layer(service_fn(inner_svc));
let res = svc.oneshot(Request::new(Body::empty())).await.unwrap();
let mut vary_headers = res.headers().get_all(header::VARY).into_iter();
assert_eq!(vary_headers.next(), Some(&CUSTOM_VARY_HEADERS));
assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS));
assert_eq!(vary_headers.next(), None);
}
15 changes: 12 additions & 3 deletions tower-http/src/cors/vary.rs
@@ -1,6 +1,6 @@
use std::array;

use http::{header::HeaderName, HeaderValue};
use http::header::{self, HeaderName, HeaderValue};

use super::preflight_request_headers;

@@ -26,8 +26,17 @@ impl Vary {
Self(headers.into_iter().map(Into::into).collect())
}

pub(super) fn values(&self) -> impl Iterator<Item = HeaderValue> + '_ {
self.0.iter().cloned()
pub(super) fn to_header(&self) -> Option<(HeaderName, HeaderValue)> {
let values = &self.0;
let mut res = values.get(0)?.as_bytes().to_owned();
for val in &values[1..] {
res.extend_from_slice(b", ");
res.extend_from_slice(val.as_bytes());
}

let header_val = HeaderValue::from_bytes(&res)
.expect("comma-separated list of HeaderValues is always a valid HeaderValue");
Some((header::VARY, header_val))
}
}

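
`to_header` now folds every configured `Vary` value into one comma-separated header value, matching the list syntax `Vary` uses on the wire. A hypothetical standalone version of the same joining step (`join_vary` is illustrative, not tower-http API):

```rust
use http::header::{HeaderName, HeaderValue, VARY};

// Join several values into a single comma-separated `Vary` header value.
fn join_vary(values: &[HeaderValue]) -> Option<(HeaderName, HeaderValue)> {
    let mut iter = values.iter();
    let mut buf = iter.next()?.as_bytes().to_vec();
    for value in iter {
        buf.extend_from_slice(b", ");
        buf.extend_from_slice(value.as_bytes());
    }
    let joined = HeaderValue::from_bytes(&buf)
        .expect("comma-separated list of valid header values is still valid");
    Some((VARY, joined))
}
```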
35 changes: 35 additions & 0 deletions tower-http/src/services/fs/serve_file.rs
@@ -84,6 +84,20 @@ impl ServeFile {
Self(self.0.precompressed_deflate())
}

/// Informs the service that it should also look for a precompressed zstd
/// version of the file.
///
/// If the client has an `Accept-Encoding` header that allows the zstd encoding,
/// the file `foo.txt.zst` will be served instead of `foo.txt`.
/// If the precompressed file is not available, or the client doesn't support it,
/// the uncompressed version will be served instead.
/// Both the precompressed version and the uncompressed version are expected
/// to be present in the same directory. Different precompressed
/// variants can be combined.
pub fn precompressed_zstd(self) -> Self {
Self(self.0.precompressed_zstd())
}

/// Set a specific read buffer chunk size.
///
/// The default capacity is 64kb.
@@ -129,6 +143,7 @@
mod tests {
use crate::services::ServeFile;
use crate::test_helpers::Body;
use async_compression::tokio::bufread::ZstdDecoder;
use brotli::BrotliDecompress;
use flate2::bufread::DeflateDecoder;
use flate2::bufread::GzDecoder;
Expand All @@ -139,6 +154,7 @@ mod tests {
use mime::Mime;
use std::io::Read;
use std::str::FromStr;
use tokio::io::AsyncReadExt;
use tower::ServiceExt;

#[tokio::test]
@@ -342,6 +358,25 @@
assert!(decompressed.starts_with("\"This is a test file!\""));
}

#[tokio::test]
async fn precompressed_zstd() {
let svc = ServeFile::new("../test-files/precompressed.txt").precompressed_zstd();
let request = Request::builder()
.header("Accept-Encoding", "zstd,br")
.body(Body::empty())
.unwrap();
let res = svc.oneshot(request).await.unwrap();

assert_eq!(res.headers()["content-type"], "text/plain");
assert_eq!(res.headers()["content-encoding"], "zstd");

let body = res.into_body().collect().await.unwrap().to_bytes();
let mut decoder = ZstdDecoder::new(&body[..]);
let mut decompressed = String::new();
decoder.read_to_string(&mut decompressed).await.unwrap();
assert!(decompressed.starts_with("\"This is a test file!\""));
}

#[tokio::test]
async fn multi_precompressed() {
let svc = ServeFile::new("../test-files/precompressed.txt")
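
A usage sketch for the new method, assuming the `fs` feature and a layout where `assets/article.txt.zst` and `assets/article.txt.gz` sit next to `assets/article.txt` (the paths are illustrative):

```rust
use tower_http::services::ServeFile;

// Prefer the `.zst` variant for clients that send `Accept-Encoding: zstd`,
// the `.gz` variant for gzip, and fall back to the plain file otherwise.
fn article_service() -> ServeFile {
    ServeFile::new("assets/article.txt")
        .precompressed_gzip()
        .precompressed_zstd()
}
```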
4 changes: 2 additions & 2 deletions tower-http/src/trace/body.rs
@@ -1,4 +1,4 @@
use super::{OnBodyChunk, OnEos, OnFailure};
use super::{DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, OnBodyChunk, OnEos, OnFailure};
use crate::classify::ClassifyEos;
use http_body::{Body, Frame};
use pin_project_lite::pin_project;
@@ -14,7 +14,7 @@ pin_project! {
/// Response body for [`Trace`].
///
/// [`Trace`]: super::Trace
pub struct ResponseBody<B, C, OnBodyChunk, OnEos, OnFailure> {
pub struct ResponseBody<B, C, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
#[pin]
pub(crate) inner: B,
pub(crate) classify_eos: Option<C>,
7 changes: 5 additions & 2 deletions tower-http/src/trace/future.rs
@@ -1,4 +1,7 @@
use super::{OnBodyChunk, OnEos, OnFailure, OnResponse, ResponseBody};
use super::{
DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnResponse, OnBodyChunk, OnEos,
OnFailure, OnResponse, ResponseBody,
};
use crate::classify::{ClassifiedResponse, ClassifyResponse};
use http::Response;
use http_body::Body;
@@ -15,7 +18,7 @@ pin_project! {
/// Response future for [`Trace`].
///
/// [`Trace`]: super::Trace
pub struct ResponseFuture<F, C, OnResponse, OnBodyChunk, OnEos, OnFailure> {
pub struct ResponseFuture<F, C, OnResponse = DefaultOnResponse, OnBodyChunk = DefaultOnBodyChunk, OnEos = DefaultOnEos, OnFailure = DefaultOnFailure> {
#[pin]
pub(crate) inner: F,
pub(crate) span: Span,
6 changes: 3 additions & 3 deletions tower-http/src/trace/layer.rs
@@ -1,6 +1,6 @@
use super::{
DefaultMakeSpan, DefaultOnBodyChunk, DefaultOnEos, DefaultOnFailure, DefaultOnRequest,
DefaultOnResponse, Trace,
DefaultOnResponse, GrpcMakeClassifier, HttpMakeClassifier, Trace,
};
use crate::classify::{
GrpcErrorsAsFailures, MakeClassifier, ServerErrorsAsFailures, SharedClassifier,
@@ -176,7 +176,7 @@ impl<M, MakeSpan, OnRequest, OnResponse, OnBodyChunk, OnEos, OnFailure>
}
}

impl TraceLayer<SharedClassifier<ServerErrorsAsFailures>> {
impl TraceLayer<HttpMakeClassifier> {
/// Create a new [`TraceLayer`] using [`ServerErrorsAsFailures`] which supports classifying
/// regular HTTP responses based on the status code.
pub fn new_for_http() -> Self {
@@ -192,7 +192,7 @@ impl TraceLayer<SharedClassifier<ServerErrorsAsFailures>> {
}
}

impl TraceLayer<SharedClassifier<GrpcErrorsAsFailures>> {
impl TraceLayer<GrpcMakeClassifier> {
/// Create a new [`TraceLayer`] using [`GrpcErrorsAsFailures`] which supports classifying
/// gRPC responses and streams based on the `grpc-status` header.
pub fn new_for_grpc() -> Self {
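
Together with the new `HttpMakeClassifier` alias, the default parameters added to `ResponseBody` and `ResponseFuture` make Trace-related types nameable without spelling out every callback type. A sketch for the HTTP classifier, assuming the `trace` feature; the alias name `HttpTracedBody` is illustrative:

```rust
use tower_http::classify::{NeverClassifyEos, ServerErrorsFailureClass};
use tower_http::trace::ResponseBody;

// Previously this alias had to name DefaultOnBodyChunk, DefaultOnEos and
// DefaultOnFailure explicitly; now only the body and ClassifyEos types are needed.
type HttpTracedBody<B> = ResponseBody<B, NeverClassifyEos<ServerErrorsFailureClass>>;
```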