Skip to content

Commit

Permalink
wip: upgrade rustls
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Dec 14, 2023
1 parent a59f3c4 commit 11352ef
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 105 deletions.
6 changes: 3 additions & 3 deletions core/http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ percent-encoding = "2"
http = "0.2"
time = { version = "0.3", features = ["formatting", "macros"] }
indexmap = "2"
rustls = { version = "0.21", optional = true }
tokio-rustls = { version = "0.24", optional = true }
rustls-pemfile = { version = "1.0.2", optional = true }
rustls = { version = "0.22", optional = true }
tokio-rustls = { version = "0.25", optional = true }
rustls-pemfile = { version = "2.0.0", optional = true }
tokio = { version = "1.6.1", features = ["net", "sync", "time"] }
log = "0.4"
ref-cast = "1.0"
Expand Down
33 changes: 21 additions & 12 deletions core/http/src/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,36 +17,45 @@ use state::InitCell;
pub use tokio::net::TcpListener;

/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
// NOTE: `rustls::Certificate` is exactly isomorphic to `CertificateData`.
#[doc(inline)]
#[cfg(feature = "tls")]
pub use rustls::Certificate as CertificateData;
#[cfg(not(feature = "tls"))]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CertificateDer(pub(crate) Vec<u8>);

/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct CertificateData(pub Vec<u8>);
#[cfg(feature = "tls")]
#[derive(Debug, Clone, Eq, PartialEq)]
#[repr(transparent)]
pub struct CertificateDer(pub(crate) rustls::pki_types::CertificateDer<'static>);

/// A collection of raw certificate data.
#[derive(Clone, Default)]
pub struct Certificates(Arc<InitCell<Vec<CertificateData>>>);
pub struct Certificates(Arc<InitCell<Vec<CertificateDer>>>);

impl From<Vec<CertificateData>> for Certificates {
fn from(value: Vec<CertificateData>) -> Self {
impl From<Vec<CertificateDer>> for Certificates {
fn from(value: Vec<CertificateDer>) -> Self {
Certificates(Arc::new(value.into()))
}
}

#[cfg(feature = "tls")]
impl From<Vec<rustls::pki_types::CertificateDer<'static>>> for Certificates {
fn from(value: Vec<rustls::pki_types::CertificateDer<'static>>) -> Self {
let value: Vec<_> = value.into_iter().map(CertificateDer).collect();
Certificates(Arc::new(value.into()))
}
}

#[doc(hidden)]
impl Certificates {
/// Set the the raw certificate chain data. Only the first call actually
/// sets the data; the remaining do nothing.
#[cfg(feature = "tls")]
pub(crate) fn set(&self, data: Vec<CertificateData>) {
pub(crate) fn set(&self, data: Vec<CertificateDer>) {
self.0.set(data);
}

/// Returns the raw certificate chain data, if any is available.
pub fn chain_data(&self) -> Option<&[CertificateData]> {
pub fn chain_data(&self) -> Option<&[CertificateDer]> {
self.0.try_get().map(|v| v.as_slice())
}
}
Expand Down
53 changes: 53 additions & 0 deletions core/http/src/tls/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Debug)]
pub enum KeyError {
// .map_err(|_| err("invalid key file"))
BadFile(std::io::Error),
// ("failed to find key header; supported formats are: RSA, PKCS8, SEC1")
MissingHeader,
NoKeysFound,
// Err(err("no valid keys found; is the file malformed?")),
// Err(err(format!("expected 1 key, found {}", n))),
BadKeyCount(usize),
// .map_err(|_| err("key parsed but is unusable"))
Unsupported,
Io(std::io::Error),
Unusable(rustls::Error),
BadItem(rustls_pemfile::Item),
}

#[derive(Debug)]
pub enum Error {
Io(std::io::Error),
Tls(rustls::Error),
Mtls(rustls::server::VerifierBuilderError),
CertChain(std::io::Error),
MissingKeyHeader,
PrivKey(KeyError),
CertAuth(rustls::Error),
}

impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Error::Io(e)
}
}

impl From<rustls::Error> for Error {
fn from(e: rustls::Error) -> Self {
Error::Tls(e)
}
}

impl From<rustls::server::VerifierBuilderError> for Error {
fn from(value: rustls::server::VerifierBuilderError) -> Self {
Error::Mtls(value)
}
}

impl From<KeyError> for Error {
fn from(value: KeyError) -> Self {
Error::PrivKey(value)
}
}
70 changes: 34 additions & 36 deletions core/http/src/tls/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{Accept, TlsAcceptor, server::TlsStream as BareTlsStream};
use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier};

use crate::tls::util::{load_certs, load_private_key, load_ca_certs};
use crate::listener::{Connection, Listener, Certificates};
use crate::tls::util::{load_cert_chain, load_key, load_ca_certs};
use crate::listener::{Connection, Listener, Certificates, CertificateDer};

/// A TLS listener over TCP.
pub struct TlsListener {
Expand Down Expand Up @@ -40,7 +41,7 @@ pub struct TlsListener {
///
/// To work around this, we "lie" when `peer_certificates()` are requested and
/// always return `Some(Certificates)`. Internally, `Certificates` is an
/// `Arc<InitCell<Vec<CertificateData>>>`, effectively a shared, thread-safe,
/// `Arc<InitCell<Vec<CertificateDer>>>`, effectively a shared, thread-safe,
/// `OnceCell`. The cell is initially empty and is filled as soon as the
/// handshake is complete. If the certificate data were to be requested prior to
/// this point, it would be empty. However, in Rocket, we only request
Expand Down Expand Up @@ -72,49 +73,42 @@ pub struct Config<R> {
}

impl TlsListener {
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> io::Result<TlsListener>
pub async fn bind<R>(addr: SocketAddr, mut c: Config<R>) -> crate::tls::Result<TlsListener>
where R: io::BufRead
{
use rustls::server::{AllowAnyAuthenticatedClient, AllowAnyAnonymousOrAuthenticatedClient};
use rustls::server::{NoClientAuth, ServerSessionMemoryCache, ServerConfig};

let cert_chain = load_certs(&mut c.cert_chain)
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS cert chain: {}", e)))?;

let key = load_private_key(&mut c.private_key)
.map_err(|e| io::Error::new(e.kind(), format!("bad TLS private key: {}", e)))?;
let provider = rustls::crypto::CryptoProvider {
cipher_suites: c.ciphersuites,
..rustls::crypto::ring::default_provider()
};

let client_auth = match c.ca_certs {
Some(ref mut ca_certs) => match load_ca_certs(ca_certs) {
Ok(ca) if c.mandatory_mtls => AllowAnyAuthenticatedClient::new(ca).boxed(),
Ok(ca) => AllowAnyAnonymousOrAuthenticatedClient::new(ca).boxed(),
Err(e) => return Err(io::Error::new(e.kind(), format!("bad CA cert(s): {}", e))),
let verifier = match c.ca_certs {
Some(ref mut ca_certs) => {
let ca_roots = load_ca_certs(ca_certs)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_roots));
match c.mandatory_mtls {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
}
},
None => NoClientAuth::boxed(),
None => WebPkiClientVerifier::no_client_auth(),
};

let mut tls_config = ServerConfig::builder()
.with_cipher_suites(&c.ciphersuites)
.with_safe_default_kx_groups()
.with_safe_default_protocol_versions()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?
.with_client_cert_verifier(client_auth)
.with_single_cert(cert_chain, key)
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?;
let (cert_chain, key) = (load_cert_chain(&mut c.cert_chain)?, load_key(&mut c.private_key)?);
let mut config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
.with_single_cert(cert_chain, key)?;

tls_config.ignore_client_order = c.prefer_server_order;

tls_config.alpn_protocols = vec![b"http/1.1".to_vec()];
config.ignore_client_order = c.prefer_server_order;
config.session_storage = ServerSessionMemoryCache::new(1024);
config.ticketer = rustls::crypto::ring::Ticketer::new()?;
config.alpn_protocols = vec![b"http/1.1".to_vec()];
if cfg!(feature = "http2") {
tls_config.alpn_protocols.insert(0, b"h2".to_vec());
config.alpn_protocols.insert(0, b"h2".to_vec());
}

tls_config.session_storage = ServerSessionMemoryCache::new(1024);
tls_config.ticketer = rustls::Ticketer::new()
.map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?;

let listener = TcpListener::bind(addr).await?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let acceptor = TlsAcceptor::from(Arc::new(config));
Ok(TlsListener { listener, acceptor })
}
}
Expand Down Expand Up @@ -180,7 +174,11 @@ impl TlsStream {
match futures::ready!(Pin::new(accept).poll(cx)) {
Ok(stream) => {
if let Some(cert_chain) = stream.get_ref().1.peer_certificates() {
self.certs.set(cert_chain.to_vec());
let owned_cert_chain = cert_chain.into_iter()
.map(|v| CertificateDer(v.clone().into_owned()))
.collect();

self.certs.set(owned_cert_chain);
}

self.state = TlsState::Streaming(stream);
Expand Down
3 changes: 3 additions & 0 deletions core/http/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ pub mod mtls;
pub use rustls;
pub use listener::{TlsListener, Config};
pub mod util;
pub mod error;

pub use error::Result;
6 changes: 3 additions & 3 deletions core/http/src/tls/mtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use x509_parser::nom;
use x509::{ParsedExtension, X509Name, X509Certificate, TbsCertificate, X509Error, FromDer};
use oid::OID_X509_EXT_SUBJECT_ALT_NAME as SUBJECT_ALT_NAME;

use crate::listener::CertificateData;
use crate::listener::CertificateDer;

/// A type alias for [`Result`](std::result::Result) with the error type set to
/// [`Error`].
Expand Down Expand Up @@ -144,7 +144,7 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Debug, PartialEq)]
pub struct Certificate<'a> {
x509: X509Certificate<'a>,
data: &'a CertificateData,
data: &'a CertificateDer,
}

/// An X.509 Distinguished Name (DN) found in a [`Certificate`].
Expand Down Expand Up @@ -224,7 +224,7 @@ impl<'a> Certificate<'a> {

/// PRIVATE: For internal Rocket use only!
#[doc(hidden)]
pub fn parse(chain: &[CertificateData]) -> Result<Certificate<'_>> {
pub fn parse(chain: &[CertificateDer]) -> Result<Certificate<'_>> {
let data = chain.first().ok_or_else(|| Error::Empty)?;
let x509 = Certificate::parse_one(&data.0)?;
Ok(Certificate { x509, data })
Expand Down
86 changes: 39 additions & 47 deletions core/http/src/tls/util.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,47 @@
use std::io::{self, Cursor, Read};
use std::io;

use rustls::{Certificate, PrivateKey, RootCertStore};
use rustls::RootCertStore;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};

fn err(message: impl Into<std::borrow::Cow<'static, str>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, message.into())
}
use crate::tls::error::{Result, Error, KeyError};

/// Loads certificates from `reader`.
pub fn load_certs(reader: &mut dyn io::BufRead) -> io::Result<Vec<Certificate>> {
let certs = rustls_pemfile::certs(reader).map_err(|_| err("invalid certificate"))?;
Ok(certs.into_iter().map(Certificate).collect())
pub fn load_cert_chain(reader: &mut dyn io::BufRead) -> Result<Vec<CertificateDer<'static>>> {
rustls_pemfile::certs(reader)
.collect::<Result<_, _>>()
.map_err(Error::CertChain)
}

/// Load and decode the private key from `reader`.
pub fn load_private_key(reader: &mut dyn io::BufRead) -> io::Result<PrivateKey> {
// "rsa" (PKCS1) PEM files have a different first-line header than PKCS8
// PEM files, use that to determine the parse function to use.
let mut header = String::new();
let private_keys_fn = loop {
header.clear();
if reader.read_line(&mut header)? == 0 {
return Err(err("failed to find key header; supported formats are: RSA, PKCS8, SEC1"));
}

break match header.trim_end() {
"-----BEGIN RSA PRIVATE KEY-----" => rustls_pemfile::rsa_private_keys,
"-----BEGIN PRIVATE KEY-----" => rustls_pemfile::pkcs8_private_keys,
"-----BEGIN EC PRIVATE KEY-----" => rustls_pemfile::ec_private_keys,
_ => continue,
};
};

let key = private_keys_fn(&mut Cursor::new(header).chain(reader))
.map_err(|_| err("invalid key file"))
.and_then(|mut keys| match keys.len() {
0 => Err(err("no valid keys found; is the file malformed?")),
1 => Ok(PrivateKey(keys.remove(0))),
n => Err(err(format!("expected 1 key, found {}", n))),
})?;
pub fn load_key(reader: &mut dyn io::BufRead) -> Result<PrivateKeyDer<'static>> {
use rustls_pemfile::Item::*;

let mut keys: Vec<PrivateKeyDer<'static>> = rustls_pemfile::read_all(reader)
.map(|result| result.map_err(KeyError::Io)
.and_then(|item| match item {
Pkcs1Key(key) => Ok(key.into()),
Pkcs8Key(key) => Ok(key.into()),
Sec1Key(key) => Ok(key.into()),
_ => Err(KeyError::BadItem(item))
})
)
.collect::<Result<_, _>>()?;

if keys.len() != 1 {
return Err(KeyError::BadKeyCount(keys.len()).into());
}

// Ensure we can use the key.
rustls::sign::any_supported_type(&key)
.map_err(|_| err("key parsed but is unusable"))
.map(|_| key)
let key = keys.remove(0);
rustls::crypto::ring::sign::any_supported_type(&key).map_err(KeyError::Unusable)?;
Ok(key)
}

/// Load and decode CA certificates from `reader`.
pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> io::Result<RootCertStore> {
pub fn load_ca_certs(reader: &mut dyn io::BufRead) -> Result<RootCertStore> {
let mut roots = rustls::RootCertStore::empty();
for cert in load_certs(reader)? {
roots.add(&cert).map_err(|e| err(format!("CA cert error: {}", e)))?;
for cert in load_cert_chain(reader)? {
roots.add(cert).map_err(Error::CertAuth)?;
}

Ok(roots)
Expand All @@ -72,10 +64,10 @@ mod test {
let ecdsa_nistp384_sha384_key = tls_example_key!("ecdsa_nistp384_sha384_key_pkcs8.pem");
let ed2551_key = tls_example_key!("ed25519_key.pem");

load_private_key(&mut Cursor::new(rsa_sha256_key))?;
load_private_key(&mut Cursor::new(ecdsa_nistp256_sha256_key))?;
load_private_key(&mut Cursor::new(ecdsa_nistp384_sha384_key))?;
load_private_key(&mut Cursor::new(ed2551_key))?;
load_key(&mut Cursor::new(rsa_sha256_key))?;
load_key(&mut Cursor::new(ecdsa_nistp256_sha256_key))?;
load_key(&mut Cursor::new(ecdsa_nistp384_sha384_key))?;
load_key(&mut Cursor::new(ed2551_key))?;

Ok(())
}
Expand All @@ -87,10 +79,10 @@ mod test {
let ecdsa_nistp384_sha384_cert = tls_example_key!("ecdsa_nistp384_sha384_cert.pem");
let ed2551_cert = tls_example_key!("ed25519_cert.pem");

load_certs(&mut Cursor::new(rsa_sha256_cert))?;
load_certs(&mut Cursor::new(ecdsa_nistp256_sha256_cert))?;
load_certs(&mut Cursor::new(ecdsa_nistp384_sha384_cert))?;
load_certs(&mut Cursor::new(ed2551_cert))?;
load_cert_chain(&mut Cursor::new(rsa_sha256_cert))?;
load_cert_chain(&mut Cursor::new(ecdsa_nistp256_sha256_cert))?;
load_cert_chain(&mut Cursor::new(ecdsa_nistp384_sha384_cert))?;
load_cert_chain(&mut Cursor::new(ed2551_cert))?;

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion core/lib/src/config/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ mod with_tls_feature {

use crate::http::tls::Config;
use crate::http::tls::rustls::SupportedCipherSuite as RustlsCipher;
use crate::http::tls::rustls::cipher_suite;
use crate::http::tls::rustls::crypto::ring::cipher_suite;

use yansi::Paint;

Expand Down
Loading

0 comments on commit 11352ef

Please sign in to comment.