Skip to content

Commit

Permalink
rustls upgrade: done
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed Dec 15, 2023
1 parent 11352ef commit eb02c0d
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 33 deletions.
64 changes: 53 additions & 11 deletions core/http/src/tls/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,9 @@ pub type Result<T, E = Error> = std::result::Result<T, E>;

#[derive(Debug)]
pub enum KeyError {
// .map_err(|_| err("invalid key file"))
BadFile(std::io::Error),
// ("failed to find key header; supported formats are: RSA, PKCS8, SEC1")
MissingHeader,
NoKeysFound,
// Err(err("no valid keys found; is the file malformed?")),
// Err(err(format!("expected 1 key, found {}", n))),
BadKeyCount(usize),
// .map_err(|_| err("key parsed but is unusable"))
Unsupported,
Io(std::io::Error),
Unusable(rustls::Error),
Unsupported(rustls::Error),
BadItem(rustls_pemfile::Item),
}

Expand All @@ -23,11 +14,62 @@ pub enum Error {
Tls(rustls::Error),
Mtls(rustls::server::VerifierBuilderError),
CertChain(std::io::Error),
MissingKeyHeader,
PrivKey(KeyError),
CertAuth(rustls::Error),
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use Error::*;

match self {
Io(e) => write!(f, "i/o error during tls binding: {e}"),
Tls(e) => write!(f, "tls configuration error: {e}"),
Mtls(e) => write!(f, "mtls verifier error: {e}"),
CertChain(e) => write!(f, "failed to process certificate chain: {e}"),
PrivKey(e) => write!(f, "failed to process private key: {e}"),
CertAuth(e) => write!(f, "failed to process certificate authority: {e}"),
}
}
}

impl std::fmt::Display for KeyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use KeyError::*;

match self {
Io(e) => write!(f, "error reading key file: {e}"),
BadKeyCount(0) => write!(f, "no valid keys found. is the file malformed?"),
BadKeyCount(n) => write!(f, "expected exactly 1 key, found {n}"),
Unsupported(e) => write!(f, "key is valid but is unsupported: {e}"),
BadItem(i) => write!(f, "found unexpected item in key file: {i:#?}"),
}
}
}

impl std::error::Error for KeyError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
KeyError::Io(e) => Some(e),
KeyError::Unsupported(e) => Some(e),
_ => None,
}
}
}

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Io(e) => Some(e),
Error::Tls(e) => Some(e),
Error::Mtls(e) => Some(e),
Error::CertChain(e) => Some(e),
Error::PrivKey(e) => Some(e),
Error::CertAuth(e) => Some(e),
}
}
}

impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Error::Io(e)
Expand Down
15 changes: 7 additions & 8 deletions core/http/src/tls/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ impl TlsListener {

let verifier = match c.ca_certs {
Some(ref mut ca_certs) => {
let ca_roots = load_ca_certs(ca_certs)?;
let verifier = WebPkiClientVerifier::builder(Arc::new(ca_roots));
let ca_roots = Arc::new(load_ca_certs(ca_certs)?);
let verifier = WebPkiClientVerifier::builder(ca_roots);
match c.mandatory_mtls {
true => verifier.build()?,
false => verifier.allow_unauthenticated().build()?,
Expand All @@ -93,7 +93,8 @@ impl TlsListener {
None => WebPkiClientVerifier::no_client_auth(),
};

let (cert_chain, key) = (load_cert_chain(&mut c.cert_chain)?, load_key(&mut c.private_key)?);
let key = load_key(&mut c.private_key)?;
let cert_chain = load_cert_chain(&mut c.cert_chain)?;
let mut config = ServerConfig::builder_with_provider(Arc::new(provider))
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(verifier)
Expand Down Expand Up @@ -173,12 +174,10 @@ impl TlsStream {
TlsState::Handshaking(ref mut accept) => {
match futures::ready!(Pin::new(accept).poll(cx)) {
Ok(stream) => {
if let Some(cert_chain) = stream.get_ref().1.peer_certificates() {
let owned_cert_chain = cert_chain.into_iter()
if let Some(peer_certs) = stream.get_ref().1.peer_certificates() {
self.certs.set(peer_certs.into_iter()
.map(|v| CertificateDer(v.clone().into_owned()))
.collect();

self.certs.set(owned_cert_chain);
.collect());
}

self.state = TlsState::Streaming(stream);
Expand Down
2 changes: 1 addition & 1 deletion core/http/src/tls/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub fn load_key(reader: &mut dyn io::BufRead) -> Result<PrivateKeyDer<'static>>

// Ensure we can use the key.
let key = keys.remove(0);
rustls::crypto::ring::sign::any_supported_type(&key).map_err(KeyError::Unusable)?;
rustls::crypto::ring::sign::any_supported_type(&key).map_err(KeyError::Unsupported)?;
Ok(key)
}

Expand Down
14 changes: 9 additions & 5 deletions core/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,11 @@ impl Error {

"aborting due to failed shutdown"
}
ErrorKind::TlsBind(_) => todo!(),
ErrorKind::TlsBind(e) => {
error!("Rocket failed to bind via TLS to network socket.");
info_!("{}", e);
"aborting due to TLS bind error"
}
}
}
}
Expand All @@ -247,16 +251,16 @@ impl fmt::Display for ErrorKind {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ErrorKind::Bind(e) => write!(f, "binding failed: {}", e),
ErrorKind::Io(e) => write!(f, "I/O error: {}", e),
ErrorKind::Bind(e) => write!(f, "binding failed: {e}"),
ErrorKind::Io(e) => write!(f, "I/O error: {e}"),
ErrorKind::Collisions(_) => "collisions detected".fmt(f),
ErrorKind::FailedFairings(_) => "launch fairing(s) failed".fmt(f),
ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f),
ErrorKind::Config(_) => "failed to extract configuration".fmt(f),
ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f),
ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {}", e),
ErrorKind::Shutdown(_, Some(e)) => write!(f, "shutdown failed: {e}"),
ErrorKind::Shutdown(_, None) => "shutdown failed".fmt(f),
ErrorKind::TlsBind(_) => todo!(),
ErrorKind::TlsBind(e) => write!(f, "TLS bind failed: {e}"),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions examples/tls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ publish = false

[dependencies]
rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] }
yansi = "1.0.0-rc.1"
2 changes: 1 addition & 1 deletion examples/tls/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ fn rocket() -> _ {
// Run `./private/gen_certs.sh` to generate a CA and key pairs.
rocket::build()
.mount("/", routes![hello, mutual])
.attach(redirector::Redirector { port: 3000 })
.attach(crate::redirector::Redirector::on(3000))
}
39 changes: 32 additions & 7 deletions examples/tls/src/redirector.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
//! Redirect all HTTP requests to HTTPs.
use std::sync::OnceLock;

use rocket::http::Status;
use rocket::log::LogLevel;
use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite, Config};
use rocket::fairing::{Fairing, Info, Kind};
use rocket::response::Redirect;

#[derive(Debug, Copy, Clone)]
#[derive(Debug, Clone)]
pub struct Redirector {
pub port: u16
pub listen_port: u16,
pub tls_port: OnceLock<u16>,
}

impl Redirector {
// Route function that gets call on every single request.
pub fn on(port: u16) -> Self {
Redirector { listen_port: port, tls_port: OnceLock::new() }
}

// Route function that gets called on every single request.
fn redirect<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> {
// FIXME: Check the host against a whitelist!
let redirector = req.rocket().state::<Self>().expect("managed Self");
if let Some(host) = req.host() {
let https_uri = format!("https://{}{}", host, req.uri());
let domain = host.domain();
let https_uri = match redirector.tls_port.get() {
Some(443) | None => format!("https://{domain}{}", req.uri()),
Some(port) => format!("https://{domain}:{port}{}", req.uri()),
};

route::Outcome::from(req, Redirect::permanent(https_uri)).pin()
} else {
route::Outcome::from(req, Status::BadRequest).pin()
Expand All @@ -25,20 +38,29 @@ impl Redirector {

// Launch an instance of Rocket than handles redirection on `self.port`.
pub async fn try_launch(self, mut config: Config) -> Result<Rocket<Ignite>, Error> {
use yansi::Paint;
use rocket::http::Method::*;

// Determine the port TLS is being served on.
let tls_port = self.tls_port.get_or_init(|| config.port);

// Adjust config for redirector: disable TLS, set port, disable logging.
config.tls = None;
config.port = self.port;
config.port = self.listen_port;
config.log_level = LogLevel::Critical;

info!("{}{}", "🔒 ".mask(), "HTTP -> HTTPS Redirector:".magenta());
info_!("redirecting on insecure port {} to TLS port {}",
self.listen_port.yellow(), tls_port.green());

// Build a vector of routes to `redirect` on `<path..>` for each method.
let redirects = [Get, Put, Post, Delete, Options, Head, Trace, Connect, Patch]
.into_iter()
.map(|m| Route::new(m, "/<path..>", Self::redirect))
.collect::<Vec<_>>();

rocket::custom(config)
.manage(self)
.mount("/", redirects)
.launch()
.await
Expand All @@ -48,11 +70,14 @@ impl Redirector {
#[rocket::async_trait]
impl Fairing for Redirector {
fn info(&self) -> Info {
Info { name: "HTTP -> HTTPS Redirector", kind: Kind::Liftoff }
Info {
name: "HTTP -> HTTPS Redirector",
kind: Kind::Liftoff | Kind::Singleton
}
}

async fn on_liftoff(&self, rkt: &Rocket<Orbit>) {
let (this, shutdown, config) = (*self, rkt.shutdown(), rkt.config().clone());
let (this, shutdown, config) = (self.clone(), rkt.shutdown(), rkt.config().clone());
let _ = rocket::tokio::spawn(async move {
if let Err(e) = this.try_launch(config).await {
error!("Failed to start HTTP -> HTTPS redirector.");
Expand Down

0 comments on commit eb02c0d

Please sign in to comment.