Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow async predicate for cors AllowOrigin #478

Merged
merged 13 commits into from
Mar 15, 2024
105 changes: 95 additions & 10 deletions tower-http/src/cors/allow_origin.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use std::{array, fmt, sync::Arc};

use http::{
header::{self, HeaderName, HeaderValue},
request::Parts as RequestParts,
};
use pin_project_lite::pin_project;
use std::{
array, fmt,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use super::{Any, WILDCARD};

Expand Down Expand Up @@ -73,6 +79,21 @@ impl AllowOrigin {
Self(OriginInner::Predicate(Arc::new(f)))
}

/// Set the allowed origins from an async predicate
///
/// See [`CorsLayer::allow_origin`] for more details.
///
/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
pub fn async_predicate<F, Fut>(f: F) -> Self
where
F: FnOnce(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static + Clone,
Fut: Future<Output = bool> + Send + Sync + 'static,
{
Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| {
Box::pin((f.clone())(v, p))
})))
}

/// Allow any origin, by mirroring the request origin
///
/// This is equivalent to
Expand All @@ -90,18 +111,70 @@ impl AllowOrigin {
matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
}

pub(super) fn to_header(
pub(super) fn to_future(
&self,
origin: Option<&HeaderValue>,
parts: &RequestParts,
) -> Option<(HeaderName, HeaderValue)> {
let allow_origin = match &self.0 {
OriginInner::Const(v) => v.clone(),
OriginInner::List(l) => origin.filter(|o| l.contains(o))?.clone(),
OriginInner::Predicate(c) => origin.filter(|origin| c(origin, parts))?.clone(),
};
) -> AllowOriginFuture {
let name = header::ACCESS_CONTROL_ALLOW_ORIGIN;

Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin))
match &self.0 {
OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))),
OriginInner::List(l) => {
AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone())))
}
OriginInner::Predicate(c) => AllowOriginFuture::ok(
origin
.filter(|origin| c(origin, parts))
.map(|o| (name, o.clone())),
),
OriginInner::AsyncPredicate(f) => {
if let Some(origin) = origin.cloned() {
let fut = f(origin.clone(), parts);
AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) })
} else {
AllowOriginFuture::ok(None)
}
}
}
}
}

pin_project! {
#[project = AllowOriginFutureProj]
pub(super) enum AllowOriginFuture {
Ok{
res: Option<(HeaderName, HeaderValue)>
},
Future{
#[pin]
future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>>
},
}
}

impl AllowOriginFuture {
fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self {
Self::Ok { res }
}

fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>(
future: F,
) -> Self {
Self::Future {
future: Box::pin(future),
}
}
}

impl Future for AllowOriginFuture {
type Output = Option<(HeaderName, HeaderValue)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project() {
AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()),
AllowOriginFutureProj::Future { future } => future.poll(cx),
}
}
}

Expand All @@ -111,6 +184,7 @@ impl fmt::Debug for AllowOrigin {
OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(),
}
}
}
Expand Down Expand Up @@ -147,6 +221,17 @@ enum OriginInner {
Predicate(
Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
),
AsyncPredicate(
Arc<
dyn for<'a> Fn(
HeaderValue,
&'a RequestParts,
) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>
+ Send
+ Sync
+ 'static,
>,
),
}

impl Default for OriginInner {
Expand Down
82 changes: 78 additions & 4 deletions tower-http/src/cors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

#![allow(clippy::enum_variant_names)]

use allow_origin::AllowOriginFuture;
use bytes::{BufMut, BytesMut};
use http::{
header::{self, HeaderName},
Expand Down Expand Up @@ -326,6 +327,52 @@ impl CorsLayer {
/// ));
/// ```
///
/// You can also use an async closure:
///
/// ```
/// # #[derive(Clone)]
/// # struct Client;
/// # fn get_api_client() -> Client {
/// # Client
/// # }
/// # impl Client {
/// # async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
/// # vec![HeaderValue::from_static("http://example.com")]
/// # }
/// # async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
/// # vec![HeaderValue::from_static("http://example.com")]
/// # }
/// # }
/// use tower_http::cors::{CorsLayer, AllowOrigin};
/// use http::{request::Parts as RequestParts, HeaderValue};
///
/// let client = get_api_client();
///
/// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
/// |origin: HeaderValue, _request_parts: &RequestParts| async move {
/// // fetch list of origins that are allowed
/// let origins = client.fetch_allowed_origins().await;
/// origins.contains(&origin)
/// },
/// ));
///
/// let client = get_api_client();
///
/// // if using &RequestParts, make sure all the values are owned
/// // before passing into the future
/// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
/// |origin: HeaderValue, parts: &RequestParts| {
/// let path = parts.uri.path().to_owned();
///
/// async move {
/// // fetch list of origins that are allowed for this path
/// let origins = client.fetch_allowed_origins_for_path(path).await;
/// origins.contains(&origin)
/// }
/// },
/// ));
/// ```
///
/// Note that multiple calls to this method will override any previous
/// calls.
///
Expand Down Expand Up @@ -621,11 +668,13 @@ where

// These headers are applied to both preflight and subsequent regular CORS requests:
// https://fetch.spec.whatwg.org/#http-responses
headers.extend(self.layer.allow_origin.to_header(origin, &parts));

headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
headers.extend(self.layer.vary.to_header());

let allow_origin_future = self.layer.allow_origin.to_future(origin, &parts);

// Return results immediately upon preflight request
if parts.method == Method::OPTIONS {
// These headers are applied only to preflight requests
Expand All @@ -634,7 +683,10 @@ where
headers.extend(self.layer.max_age.to_header(origin, &parts));

ResponseFuture {
inner: Kind::PreflightCall { headers },
inner: Kind::PreflightCall {
allow_origin_future,
headers,
},
}
} else {
// This header is applied only to non-preflight requests
Expand All @@ -643,6 +695,8 @@ where
let req = Request::from_parts(parts, body);
ResponseFuture {
inner: Kind::CorsCall {
allow_origin_future,
allow_origin_complete: false,
future: self.inner.call(req),
headers,
},
Expand All @@ -663,11 +717,16 @@ pin_project! {
#[project = KindProj]
enum Kind<F> {
CorsCall {
#[pin]
allow_origin_future: AllowOriginFuture,
allow_origin_complete: bool,
#[pin]
future: F,
headers: HeaderMap,
},
PreflightCall {
#[pin]
allow_origin_future: AllowOriginFuture,
headers: HeaderMap,
},
}
Expand All @@ -682,7 +741,17 @@ where

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.project() {
KindProj::CorsCall { future, headers } => {
KindProj::CorsCall {
allow_origin_future,
allow_origin_complete,
future,
headers,
} => {
if !*allow_origin_complete {
headers.extend(ready!(allow_origin_future.poll(cx)));
*allow_origin_complete = true;
}

let mut response: Response<B> = ready!(future.poll(cx))?;

let response_headers = response.headers_mut();
Expand All @@ -697,7 +766,12 @@ where

Poll::Ready(Ok(response))
}
KindProj::PreflightCall { headers } => {
KindProj::PreflightCall {
allow_origin_future,
headers,
} => {
headers.extend(ready!(allow_origin_future.poll(cx)));

let mut response = Response::new(B::default());
mem::swap(response.headers_mut(), headers);

Expand Down
42 changes: 41 additions & 1 deletion tower-http/src/cors/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::test_helpers::Body;
use http::{header, HeaderValue, Request, Response};
use tower::{service_fn, util::ServiceExt, Layer};

use crate::cors::CorsLayer;
use crate::cors::{AllowOrigin, CorsLayer};

#[tokio::test]
#[allow(
Expand All @@ -31,3 +31,43 @@ async fn vary_set_by_inner_service() {
assert_eq!(vary_headers.next(), Some(&PERMISSIVE_CORS_VARY_HEADERS));
assert_eq!(vary_headers.next(), None);
}

#[tokio::test]
async fn test_allow_origin_async_predicate() {
#[derive(Clone)]
struct Client;

impl Client {
async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
vec![HeaderValue::from_static("http://example.com")]
}
}

let client = Client;

let allow_origin = AllowOrigin::async_predicate(|origin, parts| {
let path = parts.uri.path().to_owned();

async move {
let origins = client.fetch_allowed_origins_for_path(path).await;

origins.contains(&origin)
}
});

let valid_origin = HeaderValue::from_static("http://example.com");
let parts = http::Request::new("hello world").into_parts().0;

let header = allow_origin
.to_future(Some(&valid_origin), &parts)
.await
.unwrap();
assert_eq!(header.0, header::ACCESS_CONTROL_ALLOW_ORIGIN);
assert_eq!(header.1, valid_origin);

let invalid_origin = HeaderValue::from_static("http://example.org");
let parts = http::Request::new("hello world").into_parts().0;

let res = allow_origin.to_future(Some(&invalid_origin), &parts).await;
assert!(res.is_none());
}
Loading