Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmastrac committed Feb 21, 2025
1 parent 1101bdb commit 17f3ada
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 25 deletions.
154 changes: 130 additions & 24 deletions gel-stream/src/common/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,24 @@ impl TlsDriver for RustlsDriver {
}
}

fn make_roots(
root_certs: &[CertificateDer<'static>],
webpki: bool,
) -> Result<RootCertStore, crate::SslError> {
let mut roots = RootCertStore::empty();
if webpki {
let webpki_roots = webpki_roots::TLS_SERVER_ROOTS;
roots.extend(webpki_roots.iter().cloned());
}
let (loaded, ignored) = roots.add_parsable_certificates(root_certs.iter().cloned());
if !root_certs.is_empty() && (loaded == 0 || ignored > 0) {
return Err(
std::io::Error::new(std::io::ErrorKind::InvalidInput, "Invalid certificate").into(),
);
}
Ok(roots)
}

fn make_verifier(
server_cert_verify: &TlsServerCertVerify,
root_cert: &TlsCert,
Expand All @@ -242,22 +260,12 @@ fn make_verifier(
root_cert,
TlsCert::Webpki | TlsCert::WebpkiPlus(_) | TlsCert::Custom(_)
) {
let mut roots = RootCertStore::empty();
if matches!(root_cert, TlsCert::Webpki | TlsCert::WebpkiPlus(_)) {
let webpki_roots = webpki_roots::TLS_SERVER_ROOTS;
roots.extend(webpki_roots.iter().cloned());
}

if let TlsCert::Custom(root) = root_cert {
let (loaded, ignored) = roots.add_parsable_certificates(root.iter().cloned());
if loaded == 0 || ignored > 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Invalid certificate",
)
.into());
}
}
let roots = match root_cert {
TlsCert::Webpki => make_roots(&[], true),
TlsCert::Custom(roots) => make_roots(roots, false),
TlsCert::WebpkiPlus(roots) => make_roots(roots, true),
_ => unreachable!(),
}?;

let verifier = WebPkiServerVerifier::builder(Arc::new(roots))
.with_crls(crls)
Expand All @@ -268,17 +276,23 @@ fn make_verifier(
return Ok(verifier);
}

let verifier = if let TlsCert::SystemPlus(roots) = root_cert {
Verifier::new_with_extra_roots(roots.iter().cloned())?
let verifier: Arc<dyn ServerCertVerifier> = if let TlsCert::SystemPlus(roots) = root_cert {
let roots = make_roots(roots, false)?;
let v1 = WebPkiServerVerifier::builder(Arc::new(roots))
.with_crls(crls)
.build()?;
let v2 = Arc::new(Verifier::new());
Arc::new(ChainingVerifier::new(v1, v2))
} else {
Verifier::new()
Arc::new(Verifier::new())
};

let verifier = if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
Arc::new(IgnoreHostnameVerifier::new(Arc::new(verifier))) as Arc<dyn ServerCertVerifier>
} else {
Arc::new(verifier)
};
let verifier: Arc<dyn ServerCertVerifier> =
if *server_cert_verify == TlsServerCertVerify::IgnoreHostname {
Arc::new(IgnoreHostnameVerifier::new(verifier))
} else {
verifier
};

Ok(verifier)
}
Expand Down Expand Up @@ -335,6 +349,98 @@ impl ServerCertVerifier for IgnoreHostnameVerifier {
}
}

#[derive(Debug)]
struct ChainingVerifier {
verifier1: Arc<dyn ServerCertVerifier>,
verifier2: Arc<dyn ServerCertVerifier>,
}

impl ChainingVerifier {
fn new(verifier1: Arc<dyn ServerCertVerifier>, verifier2: Arc<dyn ServerCertVerifier>) -> Self {
Self {
verifier1,
verifier2,
}
}
}

impl ServerCertVerifier for ChainingVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
server_name: &ServerName,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
let res = self.verifier1.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
);
if let Ok(res) = res {
return Ok(res);
}

let res2 = self.verifier2.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
);
if let Ok(res) = res2 {
return Ok(res);
}

res
}

fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
let res = self.verifier1.verify_tls12_signature(message, cert, dss);
if let Ok(res) = res {
return Ok(res);
}

let res2 = self.verifier2.verify_tls12_signature(message, cert, dss);
if let Ok(res) = res2 {
return Ok(res);
}

res
}

fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
let res = self.verifier1.verify_tls13_signature(message, cert, dss);
if let Ok(res) = res {
return Ok(res);
}

let res2 = self.verifier2.verify_tls13_signature(message, cert, dss);
if let Ok(res) = res2 {
return Ok(res);
}

res
}

fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.verifier1.supported_verify_schemes()
}
}

#[derive(Debug)]
struct NullVerifier;

Expand Down
2 changes: 1 addition & 1 deletion gel-stream/tests/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ tls_test! {

Ok(())
}

/// The certificate is not valid for 127.0.0.1, so the connection should fail.
#[tokio::test]
#[ntest::timeout(30_000)]
Expand Down

0 comments on commit 17f3ada

Please sign in to comment.