From 5553ad33439d8e88c51e2a88dd5ea9611833f140 Mon Sep 17 00:00:00 2001 From: hatoo Date: Tue, 4 Feb 2025 19:48:51 +0900 Subject: [PATCH] refactoring --- src/client.rs | 93 +-------------------------------- src/db.rs | 35 +------------ src/main.rs | 51 ++---------------- src/tls_config.rs | 129 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 172 deletions(-) create mode 100644 src/tls_config.rs diff --git a/src/client.rs b/src/client.rs index 8381a5f6..327cd818 100644 --- a/src/client.rs +++ b/src/client.rs @@ -168,51 +168,6 @@ pub enum ClientError { SigV4Error(&'static str), } -#[cfg(feature = "rustls")] -pub struct RuslsConfigs { - no_alpn: Arc, - alpn_h2: Arc, -} - -#[cfg(feature = "rustls")] -impl RuslsConfigs { - pub fn new(config: rustls::ClientConfig) -> Self { - let mut no_alpn = config.clone(); - no_alpn.alpn_protocols = vec![]; - let mut alpn_h2 = config; - alpn_h2.alpn_protocols = vec![b"h2".to_vec()]; - Self { - no_alpn: Arc::new(no_alpn), - alpn_h2: Arc::new(alpn_h2), - } - } - - pub fn config(&self, is_http2: bool) -> &Arc { - if is_http2 { - &self.alpn_h2 - } else { - &self.no_alpn - } - } -} - -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -pub struct NativeTlsConnectors { - pub no_alpn: tokio_native_tls::TlsConnector, - pub alpn_h2: tokio_native_tls::TlsConnector, -} - -#[cfg(all(feature = "native-tls", not(feature = "rustls")))] -impl NativeTlsConnectors { - pub fn connector(&self, is_http2: bool) -> &tokio_native_tls::TlsConnector { - if is_http2 { - &self.alpn_h2 - } else { - &self.no_alpn - } - } -} - pub struct Client { pub http_version: http::Version, pub proxy_http_version: http::Version, @@ -231,9 +186,9 @@ pub struct Client { #[cfg(feature = "vsock")] pub vsock_addr: Option, #[cfg(feature = "rustls")] - pub rustls_configs: RuslsConfigs, + pub rustls_configs: crate::tls_config::RuslsConfigs, #[cfg(all(feature = "native-tls", not(feature = "rustls")))] - pub native_tls_connectors: NativeTlsConnectors, + pub native_tls_connectors: crate::tls_config::NativeTlsConnectors, } struct ClientStateHttp1 { @@ -874,50 +829,6 @@ impl Client { } } -/// A server certificate verifier that accepts any certificate. -#[cfg(feature = "rustls")] -#[derive(Debug)] -pub struct AcceptAnyServerCert; - -#[cfg(feature = "rustls")] -impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert { - fn verify_server_cert( - &self, - _end_entity: &rustls_pki_types::CertificateDer<'_>, - _intermediates: &[rustls_pki_types::CertificateDer<'_>], - _server_name: &rustls_pki_types::ServerName<'_>, - _ocsp_response: &[u8], - _now: rustls_pki_types::UnixTime, - ) -> Result { - Ok(rustls::client::danger::ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &rustls_pki_types::CertificateDer<'_>, - _dss: &rustls::DigitallySignedStruct, - ) -> Result { - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &rustls_pki_types::CertificateDer<'_>, - _dss: &rustls::DigitallySignedStruct, - ) -> Result { - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - rustls::crypto::CryptoProvider::get_default() - .unwrap() - .signature_verification_algorithms - .supported_schemes() - } -} - /// Check error and decide whether to cancel the connection fn is_cancel_error(res: &Result) -> bool { matches!(res, Err(ClientError::Deadline)) || is_too_many_open_files(res) diff --git a/src/db.rs b/src/db.rs index 3c0c2984..f5132c8c 100644 --- a/src/db.rs +++ b/src/db.rs @@ -93,40 +93,9 @@ mod test_db { #[cfg(feature = "vsock")] vsock_addr: None, #[cfg(feature = "rustls")] - rustls_configs: { - let mut root_cert_store = rustls::RootCertStore::empty(); - for cert in - rustls_native_certs::load_native_certs().expect("could not load platform certs") - { - root_cert_store.add(cert).unwrap(); - } - let config = rustls::ClientConfig::builder() - .with_root_certificates(root_cert_store.clone()) - .with_no_client_auth(); - crate::client::RuslsConfigs::new(config) - }, + rustls_configs: crate::tls_config::RuslsConfigs::new(false), #[cfg(all(feature = "native-tls", not(feature = "rustls")))] - native_tls_connectors: { - crate::client::NativeTlsConnectors { - no_alpn: { - let connector_builder = native_tls::TlsConnector::builder(); - - connector_builder - .build() - .expect("Failed to build native_tls::TlsConnector") - .into() - }, - alpn_h2: { - let mut connector_builder = native_tls::TlsConnector::builder(); - - connector_builder.request_alpns(&["h2"]); - connector_builder - .build() - .expect("Failed to build native_tls::TlsConnector") - .into() - }, - } - }, + native_tls_connectors: crate::tls_config::NativeTlsConnectors::new(false), }; let result = store(&client, ":memory:", start, &test_vec); assert_eq!(result.unwrap(), 2); diff --git a/src/main.rs b/src/main.rs index 3e6f94d3..62bd7a52 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,6 +33,7 @@ mod pcg64si; mod printer; mod result_data; mod timescale; +mod tls_config; mod url_generator; #[cfg(not(target_env = "msvc"))] @@ -541,55 +542,9 @@ async fn run() -> anyhow::Result<()> { #[cfg(feature = "vsock")] vsock_addr: opts.vsock_addr.map(|v| v.0), #[cfg(feature = "rustls")] - rustls_configs: { - let mut root_cert_store = rustls::RootCertStore::empty(); - for cert in - rustls_native_certs::load_native_certs().expect("could not load platform certs") - { - root_cert_store.add(cert).unwrap(); - } - let mut config = rustls::ClientConfig::builder() - .with_root_certificates(root_cert_store.clone()) - .with_no_client_auth(); - if opts.insecure { - config - .dangerous() - .set_certificate_verifier(Arc::new(client::AcceptAnyServerCert)); - } - client::RuslsConfigs::new(config) - }, + rustls_configs: tls_config::RuslsConfigs::new(opts.insecure), #[cfg(all(feature = "native-tls", not(feature = "rustls")))] - native_tls_connectors: { - client::NativeTlsConnectors { - no_alpn: { - let mut connector_builder = native_tls::TlsConnector::builder(); - if opts.insecure { - connector_builder - .danger_accept_invalid_certs(true) - .danger_accept_invalid_hostnames(true); - } - - connector_builder - .build() - .expect("Failed to build native_tls::TlsConnector") - .into() - }, - alpn_h2: { - let mut connector_builder = native_tls::TlsConnector::builder(); - if opts.insecure { - connector_builder - .danger_accept_invalid_certs(true) - .danger_accept_invalid_hostnames(true); - } - - connector_builder.request_alpns(&["h2"]); - connector_builder - .build() - .expect("Failed to build native_tls::TlsConnector") - .into() - }, - } - }, + native_tls_connectors: tls_config::RuslsConfigs::new(opts.insecure), }); if !opts.no_pre_lookup { diff --git a/src/tls_config.rs b/src/tls_config.rs new file mode 100644 index 00000000..7d9b326c --- /dev/null +++ b/src/tls_config.rs @@ -0,0 +1,129 @@ +use std::sync::Arc; + +#[cfg(feature = "rustls")] +pub struct RuslsConfigs { + no_alpn: Arc, + alpn_h2: Arc, +} + +#[cfg(feature = "rustls")] +impl RuslsConfigs { + pub fn new(insecure: bool) -> Self { + let mut root_cert_store = rustls::RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") + { + root_cert_store.add(cert).unwrap(); + } + let mut config = rustls::ClientConfig::builder() + .with_root_certificates(root_cert_store.clone()) + .with_no_client_auth(); + if insecure { + config + .dangerous() + .set_certificate_verifier(Arc::new(AcceptAnyServerCert)); + } + + let mut no_alpn = config.clone(); + no_alpn.alpn_protocols = vec![]; + let mut alpn_h2 = config; + alpn_h2.alpn_protocols = vec![b"h2".to_vec()]; + Self { + no_alpn: Arc::new(no_alpn), + alpn_h2: Arc::new(alpn_h2), + } + } + + pub fn config(&self, is_http2: bool) -> &Arc { + if is_http2 { + &self.alpn_h2 + } else { + &self.no_alpn + } + } +} + +#[cfg(all(feature = "native-tls", not(feature = "rustls")))] +pub struct NativeTlsConnectors { + pub no_alpn: tokio_native_tls::TlsConnector, + pub alpn_h2: tokio_native_tls::TlsConnector, +} + +#[cfg(all(feature = "native-tls", not(feature = "rustls")))] +impl NativeTlsConnectors { + pub fn new(insecure: bool) -> Self { + let new = |is_http2: bool| { + let mut connector_builder = native_tls::TlsConnector::builder(); + if insecure { + connector_builder + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true); + } + + if is_http2 { + connector_builder.request_alpns(&["h2"]); + } + + connector_builder + .build() + .expect("Failed to build native_tls::TlsConnector") + .into() + }; + + Self { + no_alpn: new(false), + alpn_h2: new(true), + } + } + + pub fn connector(&self, is_http2: bool) -> &tokio_native_tls::TlsConnector { + if is_http2 { + &self.alpn_h2 + } else { + &self.no_alpn + } + } +} + +/// A server certificate verifier that accepts any certificate. +#[cfg(feature = "rustls")] +#[derive(Debug)] +pub struct AcceptAnyServerCert; + +#[cfg(feature = "rustls")] +impl rustls::client::danger::ServerCertVerifier for AcceptAnyServerCert { + fn verify_server_cert( + &self, + _end_entity: &rustls_pki_types::CertificateDer<'_>, + _intermediates: &[rustls_pki_types::CertificateDer<'_>], + _server_name: &rustls_pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls_pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::CryptoProvider::get_default() + .unwrap() + .signature_verification_algorithms + .supported_schemes() + } +}