From 8edd6898fa0c85097259f484aa858d8823d2bb0e Mon Sep 17 00:00:00 2001 From: Nimi Wariboko Jr Date: Thu, 11 Jan 2024 14:18:58 -0800 Subject: [PATCH] scylla: Add support for rustls Fixes #293 --- scylla/Cargo.toml | 6 +- scylla/src/lib.rs | 6 ++ scylla/src/transport/connection.rs | 96 +++++++++++++++++++++++-- scylla/src/transport/connection_pool.rs | 2 +- scylla/src/transport/session.rs | 15 +++- scylla/src/transport/session_builder.rs | 34 +++++++++ 6 files changed, 149 insertions(+), 10 deletions(-) diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml index adbb51f04a..0cabdf2860 100644 --- a/scylla/Cargo.toml +++ b/scylla/Cargo.toml @@ -14,9 +14,10 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [features] -default = [] +default = ["rustls"] ssl = ["dep:tokio-openssl", "dep:openssl"] -cloud = ["ssl", "scylla-cql/serde", "dep:serde_yaml", "dep:serde", "dep:url", "dep:base64"] +rustls = ["dep:tokio-rustls"] +cloud = ["scylla-cql/serde", "dep:serde_yaml", "dep:serde", "dep:url", "dep:base64"] secret = ["scylla-cql/secret"] chrono = ["scylla-cql/chrono"] time = ["scylla-cql/time"] @@ -42,6 +43,7 @@ tracing = "0.1.36" chrono = { version = "0.4.20", default-features = false, features = ["clock"] } openssl = { version = "0.10.32", optional = true } tokio-openssl = { version = "0.6.1", optional = true } +tokio-rustls = { version = "0.25", optional = true } arc-swap = "1.3.0" dashmap = "5.2" strum = "0.23" diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 5bf9bc69e8..8671d8c76e 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -142,3 +142,9 @@ pub use transport::retry_policy; pub use transport::speculative_execution; pub use transport::metrics::Metrics; + +#[cfg(all(feature = "ssl", feature = "rustls"))] +compile_error!("both rustls and ssl should not be enabled together."); + +#[cfg(all(feature = "cloud", not(any(feature = "ssl", feature = "rustls"))))] +compile_error!("cloud feature requires either the rustls or ssl feature."); diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index b6b91b69db..dd55309aa3 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -23,8 +23,10 @@ use std::sync::atomic::AtomicU64; use std::time::Duration; #[cfg(feature = "ssl")] use tokio_openssl::SslStream; +#[cfg(feature = "rustls")] +use tokio_rustls::TlsConnector; -#[cfg(feature = "ssl")] +#[cfg(any(feature = "ssl", feature = "rustls"))] pub(crate) use ssl_config::SslConfig; use crate::authentication::AuthenticatorProvider; @@ -280,12 +282,19 @@ impl NonErrorQueryResponse { }) } } -#[cfg(feature = "ssl")] +#[cfg(any(feature = "ssl", feature = "rustls"))] mod ssl_config { + #[cfg(feature = "ssl")] use openssl::{ error::ErrorStack, ssl::{Ssl, SslContext}, }; + #[cfg(feature = "rustls")] + use std::{net::IpAddr, sync::Arc}; + #[cfg(feature = "rustls")] + use tokio_rustls::rustls::pki_types::ServerName; + #[cfg(feature = "rustls")] + use tokio_rustls::rustls::ClientConfig; #[cfg(feature = "cloud")] use uuid::Uuid; @@ -299,6 +308,7 @@ mod ssl_config { // NodeConnectionPool::new(). Inside that function, the field is mutated to contain SslConfig specific // for the particular node. (The SslConfig must be different, because SNIs differ for different nodes.) // Thenceforth, all connections to that node share the same SslConfig. + #[cfg(feature = "ssl")] #[derive(Clone)] pub struct SslConfig { context: SslContext, @@ -306,6 +316,7 @@ mod ssl_config { sni: Option, } + #[cfg(feature = "ssl")] impl SslConfig { // Used in case when the user provided their own SslContext to be used in all connections. pub fn new_with_global_context(context: SslContext) -> Self { @@ -345,6 +356,58 @@ mod ssl_config { Ok(ssl) } } + + #[cfg(feature = "rustls")] + #[derive(Clone)] + pub struct SslConfig { + config: Arc, + #[cfg(feature = "cloud")] + sni: Option>, + } + + impl SslConfig { + // Used in case when the user provided their own ClientConfig to be used in all connections. + pub fn new_with_global_config(config: &Arc) -> Self { + Self { + config: config.clone(), + #[cfg(feature = "cloud")] + sni: None, + } + } + + // Used in case of Serverless Cloud connections. + #[cfg(feature = "cloud")] + pub(crate) fn new_for_sni( + config: &Arc, + domain_name: &str, + host_id: Option, + ) -> Self { + Self { + config: config.clone(), + #[cfg(feature = "cloud")] + sni: Some(if let Some(host_id) = host_id { + ServerName::try_from(&format!("{}.{}", host_id, domain_name)) + .expect("invalid DNS name") + .to_owned() + } else { + ServerName::try_from(domain_name.into().expect("invalid DNS name")).to_owned() + }), + } + } + + pub(crate) fn server_name(&self, node_addr: IpAddr) -> ServerName<'static> { + #[cfg(feature = "cloud")] + if let Some(sni) = self.sni.as_ref() { + return sni.clone(); + } + ServerName::IpAddress(node_addr.into()) + } + + // A reference to the rustls Client Config to produce a TlsConnection + pub(crate) fn config(&self) -> &Arc { + &self.config + } + } } #[derive(Clone)] @@ -352,7 +415,7 @@ pub struct ConnectionConfig { pub compression: Option, pub tcp_nodelay: bool, pub tcp_keepalive_interval: Option, - #[cfg(feature = "ssl")] + #[cfg(any(feature = "ssl", feature = "rustls"))] pub ssl_config: Option, pub connect_timeout: std::time::Duration, // should be Some only in control connections, @@ -375,7 +438,7 @@ impl Default for ConnectionConfig { tcp_nodelay: true, tcp_keepalive_interval: None, event_sender: None, - #[cfg(feature = "ssl")] + #[cfg(any(feature = "ssl", feature = "rustls"))] ssl_config: None, connect_timeout: std::time::Duration::from_secs(5), default_consistency: Default::default(), @@ -393,7 +456,7 @@ impl Default for ConnectionConfig { } impl ConnectionConfig { - #[cfg(feature = "ssl")] + #[cfg(any(feature = "ssl", feature = "rustls"))] pub fn is_ssl(&self) -> bool { #[cfg(feature = "cloud")] if self.cloud_config.is_some() { @@ -402,7 +465,7 @@ impl ConnectionConfig { self.ssl_config.is_some() } - #[cfg(not(feature = "ssl"))] + #[cfg(not(any(feature = "ssl", feature = "rustls")))] pub fn is_ssl(&self) -> bool { false } @@ -1034,6 +1097,27 @@ impl Connection { return Ok(handle); } + #[cfg(feature = "rustls")] + if let Some(rustls_config) = &config.ssl_config { + let connector = TlsConnector::from(rustls_config.config().clone()); + let stream = connector + .connect(rustls_config.server_name(node_address), stream) + .await?; + + let (task, handle) = Self::router( + config, + stream, + receiver, + error_sender, + orphan_notification_receiver, + router_handle, + node_address, + ) + .remote_handle(); + tokio::task::spawn(task.with_current_subscriber()); + return Ok(handle); + } + let (task, handle) = Self::router( config, stream, diff --git a/scylla/src/transport/connection_pool.rs b/scylla/src/transport/connection_pool.rs index f26ea36ac2..c3599a4b3f 100644 --- a/scylla/src/transport/connection_pool.rs +++ b/scylla/src/transport/connection_pool.rs @@ -1282,7 +1282,7 @@ mod tests { let connection_config = ConnectionConfig { compression: None, tcp_nodelay: true, - #[cfg(feature = "ssl")] + #[cfg(any(feature = "ssl", feature = "rustls"))] ssl_config: None, ..Default::default() }; diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index f4f5ab2365..4085acd284 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -36,7 +36,7 @@ use uuid::Uuid; use super::connection::NonErrorQueryResponse; use super::connection::QueryResponse; -#[cfg(feature = "ssl")] +#[cfg(any(feature = "ssl", feature = "rustls"))] use super::connection::SslConfig; use super::errors::{NewSessionError, QueryError}; use super::execution_profile::{ExecutionProfile, ExecutionProfileHandle, ExecutionProfileInner}; @@ -77,6 +77,8 @@ use crate::authentication::AuthenticatorProvider; #[cfg(feature = "ssl")] use openssl::ssl::SslContext; use scylla_cql::errors::BadQuery; +#[cfg(feature = "rustls")] +use tokio_rustls::rustls::ClientConfig; /// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses. /// @@ -196,6 +198,10 @@ pub struct SessionConfig { #[cfg(feature = "ssl")] pub ssl_context: Option, + /// Provide our Session with TLS + #[cfg(feature = "rustls")] + pub rustls_config: Option>, + pub authenticator: Option>, pub connect_timeout: Duration, @@ -312,6 +318,8 @@ impl SessionConfig { keyspace_case_sensitive: false, #[cfg(feature = "ssl")] ssl_context: None, + #[cfg(feature = "rustls")] + rustls_config: None, authenticator: None, connect_timeout: Duration::from_secs(5), connection_pool_size: Default::default(), @@ -499,6 +507,11 @@ impl Session { tcp_keepalive_interval: config.tcp_keepalive_interval, #[cfg(feature = "ssl")] ssl_config: config.ssl_context.map(SslConfig::new_with_global_context), + #[cfg(feature = "rustls")] + ssl_config: config + .rustls_config + .as_ref() + .map(SslConfig::new_with_global_config), authenticator: config.authenticator.clone(), connect_timeout: config.connect_timeout, event_sender: None, diff --git a/scylla/src/transport/session_builder.rs b/scylla/src/transport/session_builder.rs index 09ee03b961..da779520b3 100644 --- a/scylla/src/transport/session_builder.rs +++ b/scylla/src/transport/session_builder.rs @@ -25,6 +25,8 @@ use std::time::Duration; use crate::authentication::{AuthenticatorProvider, PlainTextAuthenticator}; #[cfg(feature = "ssl")] use openssl::ssl::SslContext; +#[cfg(feature = "rustls")] +use tokio_rustls::rustls::ClientConfig; use tracing::warn; mod sealed { @@ -334,6 +336,38 @@ impl GenericSessionBuilder { self.config.ssl_context = ssl_context; self } + + /// rustls feature + /// Provide SessionBuilder with ClientConfig from rustls crate that will be + /// used to create an ssl connection to the database. + /// If set to None SSL connection won't be used. + /// Default is None. + /// + /// # Example + /// ``` + /// # use std::fs; + /// # use std::path::PathBuf; + /// # use scylla::{Session, SessionBuilder}; + /// # use openssl::ssl::{SslContextBuilder, SslVerifyMode, SslMethod, SslFiletype}; + /// # async fn example() -> Result<(), Box> { + /// let certdir = fs::canonicalize(PathBuf::from("./examples/certs/scylla.crt"))?; + /// let mut context_builder = SslContextBuilder::new(SslMethod::tls())?; + /// context_builder.set_certificate_file(certdir.as_path(), SslFiletype::PEM)?; + /// context_builder.set_verify(SslVerifyMode::NONE); + /// + /// let session: Session = SessionBuilder::new() + /// .known_node("127.0.0.1:9042") + /// .ssl_context(Some(context_builder.build())) + /// .build() + /// .await?; + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "rustls")] + pub fn rustls_config(mut self, config: Option>) -> Self { + self.config.rustls_config = config; + self + } } // NOTE: this `impl` block contains configuration options specific for **Cloud** [`Session`].