Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: The https tls handshake is not 100% successful at high concurrency. #646

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions crates/core/src/conn/acme/listener.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use std::io::Result as IoResult;
use std::path::PathBuf;
use std::sync::{Arc, Weak};
use std::time::Duration;
Expand All @@ -11,7 +11,7 @@ use tokio_rustls::rustls::sign::CertifiedKey;
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;

use crate::conn::{Accepted, Acceptor, Holding, Listener};
use crate::conn::{Accepted, Acceptor, HandshakeStream, Holding, Listener};

use crate::http::uri::Scheme;
use crate::http::Version;
Expand Down Expand Up @@ -426,7 +426,7 @@ where
T: Acceptor + Send + 'static,
<T as Acceptor>::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
type Conn = TlsStream<T::Conn>;
type Conn = HandshakeStream<TlsStream<T::Conn>>;

#[inline]
fn holdings(&self) -> &[Holding] {
Expand All @@ -442,13 +442,8 @@ where
http_version,
http_scheme,
} = self.inner.accept().await?;
let conn = self
.tls_acceptor
.accept(conn)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))?;
Ok(Accepted {
conn,
conn: HandshakeStream::new(self.tls_acceptor.accept(conn)),
local_addr,
remote_addr,
http_version,
Expand Down
142 changes: 142 additions & 0 deletions crates/core/src/conn/handshake_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
use std::future::Future;
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_util::{future::BoxFuture, FutureExt};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result};

use crate::conn::HttpBuilder;
use crate::http::HttpConnection;
use crate::service::HyperHandler;

enum State<S> {
Handshaking(BoxFuture<'static, Result<S>>),
Ready(S),
Error,
}

/// Tls stream.
pub struct HandshakeStream<S> {
state: State<S>,
}

impl<S> HandshakeStream<S> {
pub(crate) fn new<F>(handshake: F) -> Self
where
F: Future<Output = Result<S>> + Send + 'static,
{
Self {
state: State::Handshaking(handshake.boxed()),
}
}
}

impl<S> AsyncRead for HandshakeStream<S>
where
S: AsyncRead + Unpin + Send + 'static,
{
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
let this = &mut *self;

loop {
match &mut this.state {
State::Handshaking(fut) => match fut.poll_unpin(cx) {
Poll::Ready(Ok(s)) => this.state = State::Ready(s),
Poll::Ready(Err(err)) => {
this.state = State::Error;
return Poll::Ready(Err(err));
}
Poll::Pending => return Poll::Pending,
},
State::Ready(stream) => return Pin::new(stream).poll_read(cx, buf),
State::Error => return Poll::Ready(Err(invalid_data_error("poll read invalid data"))),
}
}
}
}

impl<S> AsyncWrite for HandshakeStream<S>
where
S: AsyncWrite + Unpin + Send + 'static,
{
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
let this = &mut *self;

loop {
match &mut this.state {
State::Handshaking(fut) => match fut.poll_unpin(cx) {
Poll::Ready(Ok(s)) => this.state = State::Ready(s),
Poll::Ready(Err(err)) => {
this.state = State::Error;
return Poll::Ready(Err(err));
}
Poll::Pending => return Poll::Pending,
},
State::Ready(stream) => return Pin::new(stream).poll_write(cx, buf),
State::Error => return Poll::Ready(Err(invalid_data_error("poll write invalid data"))),
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
let this = &mut *self;

loop {
match &mut this.state {
State::Handshaking(fut) => match fut.poll_unpin(cx) {
Poll::Ready(Ok(s)) => this.state = State::Ready(s),
Poll::Ready(Err(err)) => {
this.state = State::Error;
return Poll::Ready(Err(err));
}
Poll::Pending => return Poll::Pending,
},
State::Ready(stream) => return Pin::new(stream).poll_flush(cx),
State::Error => return Poll::Ready(Err(invalid_data_error("poll flush invalid data"))),
}
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
let this = &mut *self;

loop {
match &mut this.state {
State::Handshaking(fut) => match fut.poll_unpin(cx) {
Poll::Ready(Ok(s)) => this.state = State::Ready(s),
Poll::Ready(Err(err)) => {
this.state = State::Error;
return Poll::Ready(Err(err));
}
Poll::Pending => return Poll::Pending,
},
State::Ready(stream) => return Pin::new(stream).poll_shutdown(cx),
State::Error => return Poll::Ready(Err(invalid_data_error("poll shutdown invalid data"))),
}
}
}
}

impl<S> HttpConnection for HandshakeStream<S>
where
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
async fn serve(
self,
handler: HyperHandler,
builder: Arc<HttpBuilder>,
idle_timeout: Option<Duration>,
) -> IoResult<()> {
builder
.serve_connection(self, handler, idle_timeout)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))
}
}

fn invalid_data_error(msg: &'static str) -> IoError {
IoError::new(ErrorKind::InvalidData, msg)
}
1 change: 0 additions & 1 deletion crates/core/src/conn/joined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ impl<A, B> JoinedAcceptor<A, B> {
}
}

#[async_trait]
impl<A, B> HttpConnection for JoinedStream<A, B>
where
A: HttpConnection + Send,
Expand Down
7 changes: 5 additions & 2 deletions crates/core/src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ use crate::http::{HttpConnection, Version};
mod proto;
pub use proto::HttpBuilder;

cfg_feature! {
#![any(feature = "native-tls", feature = "rustls", feature = "openssl-tls", feature = "acme")]
mod handshake_stream;
pub use handshake_stream::HandshakeStream;
}
cfg_feature! {
#![feature = "acme"]
pub mod acme;
Expand Down Expand Up @@ -77,13 +82,11 @@ cfg_feature! {
use tokio_rustls::server::TlsStream;
use tokio::io::{AsyncRead, AsyncWrite};

use crate::async_trait;
use crate::service::HyperHandler;
use crate::http::{HttpConnection};
use crate::conn::HttpBuilder;

#[cfg(any(feature = "rustls", feature = "acme"))]
#[async_trait]
impl<S> HttpConnection for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
Expand Down
21 changes: 12 additions & 9 deletions crates/core/src/conn/native_tls/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio_native_tls::TlsStream;

use crate::async_trait;
use crate::conn::{Accepted, Acceptor, Holding, HttpBuilder, IntoConfigStream, Listener};
use crate::conn::{Accepted, Acceptor, HandshakeStream, Holding, HttpBuilder, IntoConfigStream, Listener};
use crate::http::{HttpConnection, Version};
use crate::service::HyperHandler;

Expand Down Expand Up @@ -62,7 +62,6 @@ where
}
}

#[async_trait]
impl<S> HttpConnection for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
Expand Down Expand Up @@ -134,7 +133,7 @@ where
<T as Acceptor>::Conn: AsyncRead + AsyncWrite + Unpin + Send,
E: StdError + Send,
{
type Conn = TlsStream<T::Conn>;
type Conn = HandshakeStream<TlsStream<T::Conn>>;

#[inline]
fn holdings(&self) -> &[Holding] {
Expand All @@ -154,7 +153,9 @@ where
config
};
if let Some(config) = config {
let identity = config.try_into().map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))?;
let identity = config
.try_into()
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))?;
let tls_acceptor = tokio_native_tls::native_tls::TlsAcceptor::new(identity);
match tls_acceptor {
Ok(tls_acceptor) => {
Expand All @@ -180,12 +181,14 @@ where
http_version,
http_scheme,
} = self.inner.accept().await?;
let conn = tls_acceptor
.accept(conn)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))?;
let conn = async move {
tls_acceptor
.accept(conn)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))
};
Ok(Accepted {
conn,
conn: HandshakeStream::new(conn),
local_addr,
remote_addr,
http_version,
Expand Down
31 changes: 17 additions & 14 deletions crates/core/src/conn/openssl/listener.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//! openssl module
use std::error::Error as StdError;
use std::io::{Error as IoError, Result as IoResult};
use std::marker::PhantomData;
use std::error::Error as StdError;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use futures_util::stream::{BoxStream, Stream, StreamExt};
use futures_util::task::noop_waker_ref;
use futures_util::stream::{Stream,BoxStream, StreamExt};
use http::uri::Scheme;
use openssl::ssl::{Ssl, SslAcceptor};
use tokio::io::ErrorKind;
Expand All @@ -17,7 +17,7 @@ use tokio_openssl::SslStream;
use super::SslAcceptorBuilder;

use crate::async_trait;
use crate::conn::{Accepted, Acceptor,HttpBuilder, Holding, IntoConfigStream, Listener};
use crate::conn::{Accepted, Acceptor, HandshakeStream, Holding, HttpBuilder, IntoConfigStream, Listener};
use crate::http::{HttpConnection, Version};
use crate::service::HyperHandler;

Expand Down Expand Up @@ -112,7 +112,6 @@ where
}
}

#[async_trait]
impl<S> HttpConnection for SslStream<S>
where
S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
Expand All @@ -138,7 +137,7 @@ where
T: Acceptor + Send + 'static,
E: StdError + Send,
{
type Conn = SslStream<T::Conn>;
type Conn = HandshakeStream<SslStream<T::Conn>>;

/// Get the local address bound to this listener.
fn holdings(&self) -> &[Holding] {
Expand Down Expand Up @@ -182,16 +181,20 @@ where
http_version,
http_scheme,
} = self.inner.accept().await?;
let ssl = Ssl::new(tls_acceptor.context()).map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?;
let mut tls_stream =
SslStream::new(ssl, conn).map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?;
use std::pin::Pin;
Pin::new(&mut tls_stream)
.accept()
.await
.map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?;
let conn = async move {
let ssl =
Ssl::new(tls_acceptor.context()).map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?;
let mut tls_stream =
SslStream::new(ssl, conn).map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?;
std::pin::Pin::new(&mut tls_stream)
.accept()
.await
.map_err(|err| IoError::new(ErrorKind::Other, err.to_string()))?;
Ok(tls_stream)
};

Ok(Accepted {
conn: tls_stream,
conn: HandshakeStream::new(conn),
local_addr,
remote_addr,
http_version,
Expand Down
2 changes: 0 additions & 2 deletions crates/core/src/conn/quinn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use salvo_http3::http3_quinn;
pub use salvo_http3::http3_quinn::ServerConfig;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::async_trait;
use crate::conn::rustls::RustlsConfig;
use crate::conn::{HttpBuilder, IntoConfigStream};
use crate::http::HttpConnection;
Expand Down Expand Up @@ -71,7 +70,6 @@ impl AsyncWrite for H3Connection {
}
}

#[async_trait]
impl HttpConnection for H3Connection {
async fn serve(
self,
Expand Down
11 changes: 3 additions & 8 deletions crates/core/src/conn/rustls/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@ use tokio_rustls::server::TlsStream;

use crate::async_trait;
use crate::conn::Holding;
use crate::conn::{Accepted, Acceptor, IntoConfigStream, Listener};
use crate::conn::{Accepted, HandshakeStream, Acceptor, IntoConfigStream, Listener};
use crate::http::uri::Scheme;
use crate::http::Version;

use super::ServerConfig;


/// A wrapper of `Listener` with rustls.
pub struct RustlsListener<S, C, T, E> {
config_stream: S,
Expand Down Expand Up @@ -120,7 +119,7 @@ where
<T as Acceptor>::Conn: AsyncRead + AsyncWrite + Send + Unpin + 'static,
E: StdError + Send,
{
type Conn = TlsStream<T::Conn>;
type Conn = HandshakeStream<TlsStream<T::Conn>>;

fn holdings(&self) -> &[Holding] {
&self.holdings
Expand Down Expand Up @@ -161,12 +160,8 @@ where
http_version,
http_scheme,
} = self.inner.accept().await?;
let conn = tls_acceptor
.accept(conn)
.await
.map_err(|e| IoError::new(ErrorKind::Other, e.to_string()))?;
Ok(Accepted {
conn,
conn: HandshakeStream::new(tls_acceptor.accept(conn)),
local_addr,
remote_addr,
http_version,
Expand Down
Loading
Loading