Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
hatoo committed Feb 4, 2025
1 parent 4dccdef commit 5553ad3
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 172 deletions.
93 changes: 2 additions & 91 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,51 +168,6 @@ pub enum ClientError {
SigV4Error(&'static str),
}

#[cfg(feature = "rustls")]
pub struct RuslsConfigs {
no_alpn: Arc<rustls::ClientConfig>,
alpn_h2: Arc<rustls::ClientConfig>,
}

#[cfg(feature = "rustls")]
impl RuslsConfigs {
pub fn new(config: rustls::ClientConfig) -> Self {
let mut no_alpn = config.clone();
no_alpn.alpn_protocols = vec![];
let mut alpn_h2 = config;
alpn_h2.alpn_protocols = vec![b"h2".to_vec()];
Self {
no_alpn: Arc::new(no_alpn),
alpn_h2: Arc::new(alpn_h2),
}
}

pub fn config(&self, is_http2: bool) -> &Arc<rustls::ClientConfig> {
if is_http2 {
&self.alpn_h2
} else {
&self.no_alpn
}
}
}

#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
pub struct NativeTlsConnectors {
pub no_alpn: tokio_native_tls::TlsConnector,
pub alpn_h2: tokio_native_tls::TlsConnector,
}

#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
impl NativeTlsConnectors {
pub fn connector(&self, is_http2: bool) -> &tokio_native_tls::TlsConnector {
if is_http2 {
&self.alpn_h2
} else {
&self.no_alpn
}
}
}

pub struct Client {
pub http_version: http::Version,
pub proxy_http_version: http::Version,
Expand All @@ -231,9 +186,9 @@ pub struct Client {
#[cfg(feature = "vsock")]
pub vsock_addr: Option<tokio_vsock::VsockAddr>,
#[cfg(feature = "rustls")]
pub rustls_configs: RuslsConfigs,
pub rustls_configs: crate::tls_config::RuslsConfigs,
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
pub native_tls_connectors: NativeTlsConnectors,
pub native_tls_connectors: crate::tls_config::NativeTlsConnectors,
}

struct ClientStateHttp1 {
Expand Down Expand Up @@ -874,50 +829,6 @@ impl Client {
}
}

/// A server certificate verifier that accepts any certificate.
#[cfg(feature = "rustls")]
#[derive(Debug)]
pub struct AcceptAnyServerCert;

#[cfg(feature = "rustls")]
impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer<'_>,
_intermediates: &[rustls_pki_types::CertificateDer<'_>],
_server_name: &rustls_pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::CryptoProvider::get_default()
.unwrap()
.signature_verification_algorithms
.supported_schemes()
}
}

/// Check error and decide whether to cancel the connection
fn is_cancel_error(res: &Result<RequestResult, ClientError>) -> bool {
matches!(res, Err(ClientError::Deadline)) || is_too_many_open_files(res)
Expand Down
35 changes: 2 additions & 33 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,40 +93,9 @@ mod test_db {
#[cfg(feature = "vsock")]
vsock_addr: None,
#[cfg(feature = "rustls")]
rustls_configs: {
let mut root_cert_store = rustls::RootCertStore::empty();
for cert in
rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store.add(cert).unwrap();
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store.clone())
.with_no_client_auth();
crate::client::RuslsConfigs::new(config)
},
rustls_configs: crate::tls_config::RuslsConfigs::new(false),
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
native_tls_connectors: {
crate::client::NativeTlsConnectors {
no_alpn: {
let connector_builder = native_tls::TlsConnector::builder();

connector_builder
.build()
.expect("Failed to build native_tls::TlsConnector")
.into()
},
alpn_h2: {
let mut connector_builder = native_tls::TlsConnector::builder();

connector_builder.request_alpns(&["h2"]);
connector_builder
.build()
.expect("Failed to build native_tls::TlsConnector")
.into()
},
}
},
native_tls_connectors: crate::tls_config::NativeTlsConnectors::new(false),
};
let result = store(&client, ":memory:", start, &test_vec);
assert_eq!(result.unwrap(), 2);
Expand Down
51 changes: 3 additions & 48 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mod pcg64si;
mod printer;
mod result_data;
mod timescale;
mod tls_config;
mod url_generator;

#[cfg(not(target_env = "msvc"))]
Expand Down Expand Up @@ -541,55 +542,9 @@ async fn run() -> anyhow::Result<()> {
#[cfg(feature = "vsock")]
vsock_addr: opts.vsock_addr.map(|v| v.0),
#[cfg(feature = "rustls")]
rustls_configs: {
let mut root_cert_store = rustls::RootCertStore::empty();
for cert in
rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store.add(cert).unwrap();
}
let mut config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store.clone())
.with_no_client_auth();
if opts.insecure {
config
.dangerous()
.set_certificate_verifier(Arc::new(client::AcceptAnyServerCert));
}
client::RuslsConfigs::new(config)
},
rustls_configs: tls_config::RuslsConfigs::new(opts.insecure),
#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
native_tls_connectors: {
client::NativeTlsConnectors {
no_alpn: {
let mut connector_builder = native_tls::TlsConnector::builder();
if opts.insecure {
connector_builder
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true);
}

connector_builder
.build()
.expect("Failed to build native_tls::TlsConnector")
.into()
},
alpn_h2: {
let mut connector_builder = native_tls::TlsConnector::builder();
if opts.insecure {
connector_builder
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true);
}

connector_builder.request_alpns(&["h2"]);
connector_builder
.build()
.expect("Failed to build native_tls::TlsConnector")
.into()
},
}
},
native_tls_connectors: tls_config::RuslsConfigs::new(opts.insecure),
});

if !opts.no_pre_lookup {
Expand Down
129 changes: 129 additions & 0 deletions src/tls_config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
use std::sync::Arc;

#[cfg(feature = "rustls")]
pub struct RuslsConfigs {
no_alpn: Arc<rustls::ClientConfig>,
alpn_h2: Arc<rustls::ClientConfig>,
}

#[cfg(feature = "rustls")]
impl RuslsConfigs {
pub fn new(insecure: bool) -> Self {
let mut root_cert_store = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
root_cert_store.add(cert).unwrap();
}
let mut config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store.clone())
.with_no_client_auth();
if insecure {
config
.dangerous()
.set_certificate_verifier(Arc::new(AcceptAnyServerCert));
}

let mut no_alpn = config.clone();
no_alpn.alpn_protocols = vec![];
let mut alpn_h2 = config;
alpn_h2.alpn_protocols = vec![b"h2".to_vec()];
Self {
no_alpn: Arc::new(no_alpn),
alpn_h2: Arc::new(alpn_h2),
}
}

pub fn config(&self, is_http2: bool) -> &Arc<rustls::ClientConfig> {
if is_http2 {
&self.alpn_h2
} else {
&self.no_alpn
}
}
}

#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
pub struct NativeTlsConnectors {
pub no_alpn: tokio_native_tls::TlsConnector,
pub alpn_h2: tokio_native_tls::TlsConnector,
}

#[cfg(all(feature = "native-tls", not(feature = "rustls")))]
impl NativeTlsConnectors {
pub fn new(insecure: bool) -> Self {
let new = |is_http2: bool| {
let mut connector_builder = native_tls::TlsConnector::builder();
if insecure {
connector_builder
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true);
}

if is_http2 {
connector_builder.request_alpns(&["h2"]);
}

connector_builder
.build()
.expect("Failed to build native_tls::TlsConnector")
.into()
};

Self {
no_alpn: new(false),
alpn_h2: new(true),
}
}

pub fn connector(&self, is_http2: bool) -> &tokio_native_tls::TlsConnector {
if is_http2 {
&self.alpn_h2
} else {
&self.no_alpn
}
}
}

/// A server certificate verifier that accepts any certificate.
#[cfg(feature = "rustls")]
#[derive(Debug)]
pub struct AcceptAnyServerCert;

#[cfg(feature = "rustls")]
impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert {
fn verify_server_cert(
&self,
_end_entity: &rustls_pki_types::CertificateDer<'_>,
_intermediates: &[rustls_pki_types::CertificateDer<'_>],
_server_name: &rustls_pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls_pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}

fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls_pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}

fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::CryptoProvider::get_default()
.unwrap()
.signature_verification_algorithms
.supported_schemes()
}
}

0 comments on commit 5553ad3

Please sign in to comment.