From 2af1054c7c83f47ff592bb97daa08aab5f6d7846 Mon Sep 17 00:00:00 2001 From: Matt Mastracci Date: Thu, 27 Feb 2025 11:20:57 -0700 Subject: [PATCH] Better surface TLS errors with hints --- gel-stream/Cargo.toml | 2 +- gel-stream/src/common/openssl.rs | 4 +- gel-stream/src/lib.rs | 4 +- gel-stream/src/server/acceptor.rs | 2 +- gel-stream/tests/tls.rs | 4 +- gel-tokio/Cargo.toml | 2 +- gel-tokio/examples/transaction_errors.rs | 12 ++---- gel-tokio/src/builder.rs | 38 ++++++++----------- gel-tokio/src/raw/connection.rs | 48 +++++++++++++----------- 9 files changed, 56 insertions(+), 60 deletions(-) diff --git a/gel-stream/Cargo.toml b/gel-stream/Cargo.toml index 3773fa43..811ef180 100644 --- a/gel-stream/Cargo.toml +++ b/gel-stream/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "gel-stream" license = "MIT/Apache-2.0" -version = "0.1.2" +version = "0.1.3" authors = ["MagicStack Inc. "] edition = "2021" description = "A library for streaming data between clients and servers." diff --git a/gel-stream/src/common/openssl.rs b/gel-stream/src/common/openssl.rs index e88233ff..9f3826c7 100644 --- a/gel-stream/src/common/openssl.rs +++ b/gel-stream/src/common/openssl.rs @@ -171,7 +171,7 @@ impl TlsDriver for OpensslDriver { let webpki_roots = WEBPKI_ROOTS.get_or_init(|| { let webpki_roots = webpki_root_certs::TLS_SERVER_ROOT_CERTS; let mut roots = Vec::new(); - for root in webpki_roots.iter().cloned() { + for root in webpki_roots { // Don't expect the roots to fail to load if let Ok(root) = openssl::x509::X509::from_der(root.as_ref()) { roots.push(root); @@ -352,7 +352,7 @@ impl TlsDriver for OpensslDriver { .peer_certificate() .map(|cert| cert.to_der()) .transpose()?; - let cert = cert.map(|cert| CertificateDer::from(cert)); + let cert = cert.map(CertificateDer::from); Ok(( TlsStream(stream), TlsHandshake { diff --git a/gel-stream/src/lib.rs b/gel-stream/src/lib.rs index 57fc0d61..eac3a1f0 100644 --- a/gel-stream/src/lib.rs +++ b/gel-stream/src/lib.rs @@ -117,7 +117,7 @@ impl SslError { _ => None, }, #[cfg(feature = "openssl")] - SslError::OpenSslErrorStack(e) => match e.errors().get(0).map(|err| err.code()) { + SslError::OpenSslErrorStack(e) => match e.errors().first().map(|err| err.code()) { // SSL_R_WRONG_VERSION_NUMBER Some(0xa00010b) => Some(CommonError::InvalidTlsProtocolData), // SSL_R_PACKET_LENGTH_TOO_LONG @@ -130,7 +130,7 @@ impl SslError { openssl_sys::SSL_ERROR_SSL => { match e .ssl_error() - .and_then(|e| e.errors().get(0)) + .and_then(|e| e.errors().first()) .map(|err| err.code()) { // SSL_R_WRONG_VERSION_NUMBER diff --git a/gel-stream/src/server/acceptor.rs b/gel-stream/src/server/acceptor.rs index 5d75dec7..d49aeadf 100644 --- a/gel-stream/src/server/acceptor.rs +++ b/gel-stream/src/server/acceptor.rs @@ -1,6 +1,6 @@ use crate::{ common::tokio_stream::TokioListenerStream, ConnectionError, LocalAddress, ResolvedTarget, - RewindStream, Ssl, SslError, StreamUpgrade, Target, TlsDriver, TlsServerParameterProvider, + RewindStream, Ssl, SslError, StreamUpgrade, TlsDriver, TlsServerParameterProvider, UpgradableStream, }; use futures::{FutureExt, StreamExt}; diff --git a/gel-stream/tests/tls.rs b/gel-stream/tests/tls.rs index bb0e0170..3d64ec77 100644 --- a/gel-stream/tests/tls.rs +++ b/gel-stream/tests/tls.rs @@ -99,7 +99,7 @@ async fn spawn_tls_server( assert_eq!(handshake.sni.as_deref(), expected_hostname.as_deref()); if validate_cert { assert!(handshake.cert.is_some()); - let cert = parse_cert(&handshake.cert.as_ref().unwrap()); + let cert = parse_cert(handshake.cert.as_ref().unwrap()); let subject = cert.subject().to_string(); assert!( subject.to_ascii_lowercase().contains("ssl_user"), @@ -713,6 +713,6 @@ tls_client_test! { fn parse_cert<'a>( cert: &'a rustls_pki_types::CertificateDer<'a>, ) -> x509_parser::prelude::X509Certificate<'a> { - let (_, cert) = x509_parser::parse_x509_certificate(&cert).unwrap(); + let (_, cert) = x509_parser::parse_x509_certificate(cert).unwrap(); cert } diff --git a/gel-tokio/Cargo.toml b/gel-tokio/Cargo.toml index 704b2f3c..35462129 100644 --- a/gel-tokio/Cargo.toml +++ b/gel-tokio/Cargo.toml @@ -17,7 +17,7 @@ gel-protocol = { path = "../gel-protocol", version = "0.8", features = [ ] } gel-errors = { path = "../gel-errors", version = "0.5" } gel-derive = { path = "../gel-derive", version = "0.7", optional = true } -gel-stream = { path = "../gel-stream", version = "0.1.2", features = ["client", "tokio", "rustls", "hickory", "keepalive"] } +gel-stream = { path = "../gel-stream", version = "0.1.3", features = ["client", "tokio", "rustls", "hickory", "keepalive"] } gel-auth = { path = "../gel-auth", version = "0.1.3" } tokio = { workspace = true, features = ["net", "time", "sync", "macros"] } bytes = "1.5.0" diff --git a/gel-tokio/examples/transaction_errors.rs b/gel-tokio/examples/transaction_errors.rs index 252a11a6..9380940a 100644 --- a/gel-tokio/examples/transaction_errors.rs +++ b/gel-tokio/examples/transaction_errors.rs @@ -8,19 +8,15 @@ use gel_errors::{ErrorKind, UserError}; struct CounterError; fn check_val0(val: i64) -> anyhow::Result<()> { - if val % 3 == 0 { - if rng().random_bool(0.9) { - Err(CounterError)?; - } + if val % 3 == 0 && rng().random_bool(0.9) { + Err(CounterError)?; } Ok(()) } fn check_val1(val: i64) -> Result<(), CounterError> { - if val % 3 == 1 { - if rng().random_bool(0.1) { - Err(CounterError)?; - } + if val % 3 == 1 && rng().random_bool(0.1) { + Err(CounterError)?; } Ok(()) } diff --git a/gel-tokio/src/builder.rs b/gel-tokio/src/builder.rs index c0950b38..e4d76d1d 100644 --- a/gel-tokio/src/builder.rs +++ b/gel-tokio/src/builder.rs @@ -230,7 +230,7 @@ impl CertCheck { pub fn new_fn> + Send + Sync + 'static>(function: impl for <'a> Fn(&'a [u8]) -> F + Send + Sync + 'static) -> Self { let function = Arc::new(move |cert: &'_[u8]| { let fut = function(cert); - Box::pin(async move { fut.await }) as _ + Box::pin(fut) as _ }); Self { function } @@ -1404,21 +1404,19 @@ impl Builder { instance )) })?; - if matches!(instance, InstanceName::Cloud { .. }) { - if cfg.secret_key.is_none() && cfg.cloud_profile.is_none() { - let path = stash_path.join("cloud-profile"); - let profile = fs::read_to_string(&path) - .await - .map_err(|e| { - ClientError::with_source(e).context(format!( - "error reading project settings {:?}: {:?}", - project_dir, path - )) - })? - .trim() - .into(); - cfg.cloud_profile = Some(profile); - } + if matches!(instance, InstanceName::Cloud { .. }) && cfg.secret_key.is_none() && cfg.cloud_profile.is_none() { + let path = stash_path.join("cloud-profile"); + let profile = fs::read_to_string(&path) + .await + .map_err(|e| { + ClientError::with_source(e).context(format!( + "error reading project settings {:?}: {:?}", + project_dir, path + )) + })? + .trim() + .into(); + cfg.cloud_profile = Some(profile); } read_instance(cfg, &instance).await?; let path = stash_path.join("database"); @@ -1941,7 +1939,7 @@ impl Config { tls.root_cert = TlsCert::Webpki; match &self.0.pem_certificates { Some(pem_certificates) => { - tls.root_cert = TlsCert::Custom(read_root_cert_pem(&pem_certificates)?); + tls.root_cert = TlsCert::Custom(read_root_cert_pem(pem_certificates)?); } None => { if let Some(cloud_certs) = self.0.cloud_certs { @@ -1967,11 +1965,7 @@ impl Config { Some(Cow::from(host)) } } else { - if let Some(host) = self.0.address.host() { - Some(Cow::from(host.to_string())) - } else { - None - } + self.0.address.host().map(|host| Cow::from(host.to_string())) } } else { None diff --git a/gel-tokio/src/raw/connection.rs b/gel-tokio/src/raw/connection.rs index 3cde1434..5994b79c 100644 --- a/gel-tokio/src/raw/connection.rs +++ b/gel-tokio/src/raw/connection.rs @@ -287,16 +287,28 @@ async fn connect2( // Allow plaintext reconnection if and only if ClientSecurity is InsecureDevMode and // the server replied with something that looks like TLS handshake failure. - if let Err(ConnectionError::SslError(e)) = &res { - if e.common_error() == Some(CommonError::InvalidTlsProtocolData) { - if cfg.0.client_security == ClientSecurity::InsecureDevMode { - target.try_remove_tls(); - warn!("TLS handshake failed, trying again without TLS"); - *warned = true; - - let mut connector = Connector::new(target.clone()).map_err(ClientConnectionError::with_source)?; - connector.set_keepalive(cfg.0.tcp_keepalive); - res = connector.connect().await; + if let Err(ConnectionError::SslError(e)) = res { + match e.common_error() { + Some(CommonError::InvalidTlsProtocolData) => { + if cfg.0.client_security == ClientSecurity::InsecureDevMode { + target.try_remove_tls(); + warn!("TLS handshake failed, trying again without TLS"); + *warned = true; + let mut connector = Connector::new(target.clone()).map_err(ClientConnectionError::with_source)?; + connector.set_keepalive(cfg.0.tcp_keepalive); + res = connector.connect().await; + } else { + res = Err(ConnectionError::SslError(e)); + } + } + Some(CommonError::InvalidCertificateForName) => { + return Err(ClientConnectionError::with_source(e).context(format!("The server's certificate does not match the requested host name ({:?}). Use `--tls-security no-host-verification` to bypass this check.", target.host().unwrap_or_default()))); + } + Some(e) => { + return Err(ClientConnectionError::with_source(e).context(format!("TLS handshake failed while connecting to ({:?}) ({e:?}). Check client and server TLS options and try again.", target))); + } + None => { + res = Err(ConnectionError::SslError(e)); } } } @@ -310,7 +322,7 @@ async fn connect4(cfg: &Config, mut stream: gel_stream::RawStream) -> Result Result { - resp = client_auth.drive(ClientAuthDrive::ScramResponse(&data)).map_err(AuthenticationError::with_source)?; + resp = client_auth.drive(ClientAuthDrive::ScramResponse(data)).map_err(AuthenticationError::with_source)?; } ServerMessage::Authentication(Authentication::SaslFinal { ref data }) => { - resp = client_auth.drive(ClientAuthDrive::ScramResponse(&data)).map_err(AuthenticationError::with_source)?; + resp = client_auth.drive(ClientAuthDrive::ScramResponse(data)).map_err(AuthenticationError::with_source)?; } ServerMessage::ErrorResponse(err) => { return Err(err.into()); @@ -414,9 +426,7 @@ async fn connect4(cfg: &Config, mut stream: gel_stream::RawStream) -> Result { - return Err(ProtocolError::with_message(format!( - "Unexpected authentication response", - ))); + return Err(ProtocolError::with_message("Unexpected authentication response".to_string())); } ClientAuthResponse::Complete => { break; @@ -708,11 +718,7 @@ fn is_temporary(e: &Error) -> bool { let mut e: &dyn std::error::Error = &e; while let Some(src) = e.source() { if let Some(io_err) = src.downcast_ref::() { - if is_io_error_temporary(io_err) { - return true; - } else { - return false; - } + return is_io_error_temporary(io_err) } e = src; }