Skip to content

Commit

Permalink
Better surface TLS errors with hints
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Feb 27, 2025
1 parent 111d535 commit 2af1054
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 60 deletions.
2 changes: 1 addition & 1 deletion gel-stream/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "gel-stream"
license = "MIT/Apache-2.0"
version = "0.1.2"
version = "0.1.3"
authors = ["MagicStack Inc. <hello@magic.io>"]
edition = "2021"
description = "A library for streaming data between clients and servers."
Expand Down
4 changes: 2 additions & 2 deletions gel-stream/src/common/openssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ impl TlsDriver for OpensslDriver {
let webpki_roots = WEBPKI_ROOTS.get_or_init(|| {
let webpki_roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS;
let mut roots = Vec::new();
for root in webpki_roots.iter().cloned() {
for root in webpki_roots {
// Don't expect the roots to fail to load
if let Ok(root) = openssl::x509::X509::from_der(root.as_ref()) {
roots.push(root);
Expand Down Expand Up @@ -352,7 +352,7 @@ impl TlsDriver for OpensslDriver {
.peer_certificate()
.map(|cert| cert.to_der())
.transpose()?;
let cert = cert.map(|cert| CertificateDer::from(cert));
let cert = cert.map(CertificateDer::from);
Ok((
TlsStream(stream),
TlsHandshake {
Expand Down
4 changes: 2 additions & 2 deletions gel-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl SslError {
_ => None,
},
#[cfg(feature = "openssl")]
SslError::OpenSslErrorStack(e) => match e.errors().get(0).map(|err| err.code()) {
SslError::OpenSslErrorStack(e) => match e.errors().first().map(|err| err.code()) {
// SSL_R_WRONG_VERSION_NUMBER
Some(0xa00010b) => Some(CommonError::InvalidTlsProtocolData),
// SSL_R_PACKET_LENGTH_TOO_LONG
Expand All @@ -130,7 +130,7 @@ impl SslError {
openssl_sys::SSL_ERROR_SSL => {
match e
.ssl_error()
.and_then(|e| e.errors().get(0))
.and_then(|e| e.errors().first())
.map(|err| err.code())
{
// SSL_R_WRONG_VERSION_NUMBER
Expand Down
2 changes: 1 addition & 1 deletion gel-stream/src/server/acceptor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
common::tokio_stream::TokioListenerStream, ConnectionError, LocalAddress, ResolvedTarget,
RewindStream, Ssl, SslError, StreamUpgrade, Target, TlsDriver, TlsServerParameterProvider,
RewindStream, Ssl, SslError, StreamUpgrade, TlsDriver, TlsServerParameterProvider,
UpgradableStream,
};
use futures::{FutureExt, StreamExt};
Expand Down
4 changes: 2 additions & 2 deletions gel-stream/tests/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async fn spawn_tls_server<S: TlsDriver>(
assert_eq!(handshake.sni.as_deref(), expected_hostname.as_deref());
if validate_cert {
assert!(handshake.cert.is_some());
let cert = parse_cert(&handshake.cert.as_ref().unwrap());
let cert = parse_cert(handshake.cert.as_ref().unwrap());
let subject = cert.subject().to_string();
assert!(
subject.to_ascii_lowercase().contains("ssl_user"),
Expand Down Expand Up @@ -713,6 +713,6 @@ tls_client_test! {
fn parse_cert<'a>(
cert: &'a rustls_pki_types::CertificateDer<'a>,
) -> x509_parser::prelude::X509Certificate<'a> {
let (_, cert) = x509_parser::parse_x509_certificate(&cert).unwrap();
let (_, cert) = x509_parser::parse_x509_certificate(cert).unwrap();
cert
}
2 changes: 1 addition & 1 deletion gel-tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ gel-protocol = { path = "../gel-protocol", version = "0.8", features = [
] }
gel-errors = { path = "../gel-errors", version = "0.5" }
gel-derive = { path = "../gel-derive", version = "0.7", optional = true }
gel-stream = { path = "../gel-stream", version = "0.1.2", features = ["client", "tokio", "rustls", "hickory", "keepalive"] }
gel-stream = { path = "../gel-stream", version = "0.1.3", features = ["client", "tokio", "rustls", "hickory", "keepalive"] }
gel-auth = { path = "../gel-auth", version = "0.1.3" }
tokio = { workspace = true, features = ["net", "time", "sync", "macros"] }
bytes = "1.5.0"
Expand Down
12 changes: 4 additions & 8 deletions gel-tokio/examples/transaction_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,15 @@ use gel_errors::{ErrorKind, UserError};
struct CounterError;

fn check_val0(val: i64) -> anyhow::Result<()> {
if val % 3 == 0 {
if rng().random_bool(0.9) {
Err(CounterError)?;
}
if val % 3 == 0 && rng().random_bool(0.9) {
Err(CounterError)?;
}
Ok(())
}

fn check_val1(val: i64) -> Result<(), CounterError> {
if val % 3 == 1 {
if rng().random_bool(0.1) {
Err(CounterError)?;
}
if val % 3 == 1 && rng().random_bool(0.1) {
Err(CounterError)?;
}
Ok(())
}
Expand Down
38 changes: 16 additions & 22 deletions gel-tokio/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl CertCheck {
pub fn new_fn<F: Future<Output = Result<(), gel_errors::Error>> + Send + Sync + 'static>(function: impl for <'a> Fn(&'a [u8]) -> F + Send + Sync + 'static) -> Self {
let function = Arc::new(move |cert: &'_[u8]| {
let fut = function(cert);
Box::pin(async move { fut.await }) as _
Box::pin(fut) as _
});

Self { function }
Expand Down Expand Up @@ -1404,21 +1404,19 @@ impl Builder {
instance
))
})?;
if matches!(instance, InstanceName::Cloud { .. }) {
if cfg.secret_key.is_none() && cfg.cloud_profile.is_none() {
let path = stash_path.join("cloud-profile");
let profile = fs::read_to_string(&path)
.await
.map_err(|e| {
ClientError::with_source(e).context(format!(
"error reading project settings {:?}: {:?}",
project_dir, path
))
})?
.trim()
.into();
cfg.cloud_profile = Some(profile);
}
if matches!(instance, InstanceName::Cloud { .. }) && cfg.secret_key.is_none() && cfg.cloud_profile.is_none() {
let path = stash_path.join("cloud-profile");
let profile = fs::read_to_string(&path)
.await
.map_err(|e| {
ClientError::with_source(e).context(format!(
"error reading project settings {:?}: {:?}",
project_dir, path
))
})?
.trim()
.into();
cfg.cloud_profile = Some(profile);
}
read_instance(cfg, &instance).await?;
let path = stash_path.join("database");
Expand Down Expand Up @@ -1941,7 +1939,7 @@ impl Config {
tls.root_cert = TlsCert::Webpki;
match &self.0.pem_certificates {
Some(pem_certificates) => {
tls.root_cert = TlsCert::Custom(read_root_cert_pem(&pem_certificates)?);
tls.root_cert = TlsCert::Custom(read_root_cert_pem(pem_certificates)?);
}
None => {
if let Some(cloud_certs) = self.0.cloud_certs {
Expand All @@ -1967,11 +1965,7 @@ impl Config {
Some(Cow::from(host))
}
} else {
if let Some(host) = self.0.address.host() {
Some(Cow::from(host.to_string()))
} else {
None
}
self.0.address.host().map(|host| Cow::from(host.to_string()))
}
} else {
None
Expand Down
48 changes: 27 additions & 21 deletions gel-tokio/src/raw/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,28 @@ async fn connect2(

// Allow plaintext reconnection if and only if ClientSecurity is InsecureDevMode and
// the server replied with something that looks like TLS handshake failure.
if let Err(ConnectionError::SslError(e)) = &res {
if e.common_error() == Some(CommonError::InvalidTlsProtocolData) {
if cfg.0.client_security == ClientSecurity::InsecureDevMode {
target.try_remove_tls();
warn!("TLS handshake failed, trying again without TLS");
*warned = true;

let mut connector = Connector::new(target.clone()).map_err(ClientConnectionError::with_source)?;
connector.set_keepalive(cfg.0.tcp_keepalive);
res = connector.connect().await;
if let Err(ConnectionError::SslError(e)) = res {
match e.common_error() {
Some(CommonError::InvalidTlsProtocolData) => {
if cfg.0.client_security == ClientSecurity::InsecureDevMode {
target.try_remove_tls();
warn!("TLS handshake failed, trying again without TLS");
*warned = true;
let mut connector = Connector::new(target.clone()).map_err(ClientConnectionError::with_source)?;
connector.set_keepalive(cfg.0.tcp_keepalive);
res = connector.connect().await;
} else {
res = Err(ConnectionError::SslError(e));
}
}
Some(CommonError::InvalidCertificateForName) => {
return Err(ClientConnectionError::with_source(e).context(format!("The server's certificate does not match the requested host name ({:?}). Use `--tls-security no-host-verification` to bypass this check.", target.host().unwrap_or_default())));
}
Some(e) => {
return Err(ClientConnectionError::with_source(e).context(format!("TLS handshake failed while connecting to ({:?}) ({e:?}). Check client and server TLS options and try again.", target)));
}
None => {
res = Err(ConnectionError::SslError(e));
}
}
}
Expand All @@ -310,7 +322,7 @@ async fn connect4(cfg: &Config, mut stream: gel_stream::RawStream) -> Result<Con
if let Some(cert_check) = &cfg.0.cert_check {
if let Some(handshake) = stream.handshake() {
if let Some(cert) = &handshake.cert {
cert_check.call(&cert).await?;
cert_check.call(cert).await?;
}
}
}
Expand Down Expand Up @@ -382,10 +394,10 @@ async fn connect4(cfg: &Config, mut stream: gel_stream::RawStream) -> Result<Con
}
}
ServerMessage::Authentication(Authentication::SaslContinue { ref data }) => {
resp = client_auth.drive(ClientAuthDrive::ScramResponse(&data)).map_err(AuthenticationError::with_source)?;
resp = client_auth.drive(ClientAuthDrive::ScramResponse(data)).map_err(AuthenticationError::with_source)?;
}
ServerMessage::Authentication(Authentication::SaslFinal { ref data }) => {
resp = client_auth.drive(ClientAuthDrive::ScramResponse(&data)).map_err(AuthenticationError::with_source)?;
resp = client_auth.drive(ClientAuthDrive::ScramResponse(data)).map_err(AuthenticationError::with_source)?;
}
ServerMessage::ErrorResponse(err) => {
return Err(err.into());
Expand Down Expand Up @@ -414,9 +426,7 @@ async fn connect4(cfg: &Config, mut stream: gel_stream::RawStream) -> Result<Con
.await?;
},
ClientAuthResponse::Initial(..) => {
return Err(ProtocolError::with_message(format!(
"Unexpected authentication response",
)));
return Err(ProtocolError::with_message("Unexpected authentication response".to_string()));
}
ClientAuthResponse::Complete => {
break;
Expand Down Expand Up @@ -708,11 +718,7 @@ fn is_temporary(e: &Error) -> bool {
let mut e: &dyn std::error::Error = &e;
while let Some(src) = e.source() {
if let Some(io_err) = src.downcast_ref::<io::Error>() {
if is_io_error_temporary(io_err) {
return true;
} else {
return false;
}
return is_io_error_temporary(io_err)
}
e = src;
}
Expand Down

0 comments on commit 2af1054

Please sign in to comment.