Skip to content

Commit

Permalink
remove trait ResolveCertificate
Browse files Browse the repository at this point in the history
rename MutexWrappedCertificateResolver to MutexCertificateResolver
  • Loading branch information
Keksoj committed Feb 13, 2024
1 parent f10f83f commit 76d4b7f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 117 deletions.
89 changes: 34 additions & 55 deletions lib/src/https.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use crate::{
server::{ListenToken, SessionManager},
socket::{server_bind, FrontRustls},
timer::TimeoutContainer,
tls::{CertifiedKeyWrapper, MutexWrappedCertificateResolver, ResolveCertificate},
tls::MutexCertificateResolver,
util::UnwrapLog,
AcceptError, CachedTags, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError,
ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed,
Expand Down Expand Up @@ -529,7 +529,7 @@ pub struct HttpsListener {
config: HttpsListenerConfig,
fronts: Router,
listener: Option<MioTcpListener>,
resolver: Arc<MutexWrappedCertificateResolver>,
resolver: Arc<MutexCertificateResolver>,
rustls_details: Arc<RustlsServerConfig>,
tags: BTreeMap<String, CachedTags>,
token: Token,
Expand Down Expand Up @@ -609,51 +609,12 @@ impl L7ListenerHandler for HttpsListener {
}
}

impl ResolveCertificate for HttpsListener {
type Error = ListenerError;

fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
let resolver = self
.resolver
.0
.lock()
.map_err(|err| ListenerError::Lock(err.to_string()))
.ok()?;

resolver.get_certificate(fingerprint)
}

fn add_certificate(&mut self, opts: &AddCertificate) -> Result<Fingerprint, Self::Error> {
let mut resolver = self
.resolver
.0
.lock()
.map_err(|err| ListenerError::Lock(err.to_string()))?;

resolver
.add_certificate(opts)
.map_err(ListenerError::Resolver)
}

fn remove_certificate(&mut self, fingerprint: &Fingerprint) -> Result<(), Self::Error> {
let mut resolver = self
.resolver
.0
.lock()
.map_err(|err| ListenerError::Lock(err.to_string()))?;

resolver
.remove_certificate(fingerprint)
.map_err(ListenerError::Resolver)
}
}

impl HttpsListener {
pub fn try_new(
config: HttpsListenerConfig,
token: Token,
) -> Result<HttpsListener, ListenerError> {
let resolver = Arc::new(MutexWrappedCertificateResolver::default());
let resolver = Arc::new(MutexCertificateResolver::default());

let server_config = Arc::new(Self::create_rustls_context(&config, resolver.to_owned())?);

Expand Down Expand Up @@ -705,7 +666,7 @@ impl HttpsListener {

pub fn create_rustls_context(
config: &HttpsListenerConfig,
resolver: Arc<MutexWrappedCertificateResolver>,
resolver: Arc<MutexCertificateResolver>,
) -> Result<RustlsServerConfig, ListenerError> {
let cipher_names = if config.cipher_list.is_empty() {
DEFAULT_CIPHER_SUITES.to_vec()
Expand Down Expand Up @@ -1131,10 +1092,16 @@ impl HttpsProxy {
.listeners
.values()
.find(|l| l.borrow().address == address)
.ok_or(ProxyError::NoListenerFound(address))?;
.ok_or(ProxyError::NoListenerFound(address))?
.borrow_mut();

listener
.borrow_mut()
let mut resolver = listener
.resolver
.0
.lock()
.map_err(|e| ProxyError::Lock(e.to_string()))?;

resolver
.add_certificate(&add_certificate)
.map_err(ProxyError::AddCertificate)?;

Expand All @@ -1150,19 +1117,25 @@ impl HttpsProxy {

let fingerprint = Fingerprint(
hex::decode(&remove_certificate.fingerprint)
.map_err(|e| ProxyError::WrongCertificateFingerprint(e.to_string()))?,
.map_err(ProxyError::WrongCertificateFingerprint)?,
);

let listener = self
.listeners
.values()
.find(|l| l.borrow().address == address)
.ok_or(ProxyError::NoListenerFound(address))?;
.ok_or(ProxyError::NoListenerFound(address))?
.borrow_mut();

listener
.borrow_mut()
let mut resolver = listener
.resolver
.0
.lock()
.map_err(|e| ProxyError::Lock(e.to_string()))?;

resolver
.remove_certificate(&fingerprint)
.map_err(ProxyError::AddCertificate)?;
.map_err(ProxyError::RemoveCertificate)?;

Ok(None)
}
Expand All @@ -1178,10 +1151,16 @@ impl HttpsProxy {
.listeners
.values()
.find(|l| l.borrow().address == address)
.ok_or(ProxyError::NoListenerFound(address))?;
.ok_or(ProxyError::NoListenerFound(address))?
.borrow_mut();

listener
.borrow_mut()
let mut resolver = listener
.resolver
.0
.lock()
.map_err(|e| ProxyError::Lock(e.to_string()))?;

resolver
.replace_certificate(&replace_certificate)
.map_err(ProxyError::ReplaceCertificate)?;

Expand Down Expand Up @@ -1592,7 +1571,7 @@ mod tests {
));

let address = SocketAddress::new_v4(127, 0, 0, 1, 1032);
let resolver = Arc::new(MutexWrappedCertificateResolver::default());
let resolver = Arc::new(MutexCertificateResolver::default());

let server_config = RustlsServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS12,
Expand Down
13 changes: 7 additions & 6 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ use std::{
};

use backends::BackendError;
use hex::FromHexError;
use mio::{net::TcpStream, Interest, Token};
use protocol::http::parser::Method;
use router::RouterError;
Expand Down Expand Up @@ -637,8 +638,6 @@ pub enum AcceptError {
/// returned by the HTTP, HTTPS and TCP listeners
#[derive(thiserror::Error, Debug)]
pub enum ListenerError {
#[error("failed to acquire the lock, {0}")]
Lock(String),
#[error("failed to handle certificate request, got a resolver error, {0}")]
Resolver(CertificateResolverError),
#[error("failed to parse pem, {0}")]
Expand Down Expand Up @@ -689,15 +688,17 @@ pub enum ProxyError {
#[error("could not remove frontend: {0}")]
RemoveFrontend(ListenerError),
#[error("could not add certificate: {0}")]
AddCertificate(ListenerError),
AddCertificate(CertificateResolverError),
#[error("could not remove certificate: {0}")]
RemoveCertificate(ListenerError),
RemoveCertificate(CertificateResolverError),
#[error("could not replace certificate: {0}")]
ReplaceCertificate(ListenerError),
ReplaceCertificate(CertificateResolverError),
#[error("wrong certificate fingerprint: {0}")]
WrongCertificateFingerprint(String),
WrongCertificateFingerprint(FromHexError),
#[error("this request is not supported by the proxy")]
UnsupportedMessage,
#[error("failed to acquire the lock, {0}")]
Lock(String),
}

use self::server::ListenToken;
Expand Down
101 changes: 45 additions & 56 deletions lib/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,6 @@ static DEFAULT_CERTIFICATE: Lazy<Option<Arc<CertifiedKey>>> = Lazy::new(|| {
// .map(|c| c.inner)
});

// -----------------------------------------------------------------------------
// CertificateResolver trait

pub trait ResolveCertificate {
type Error;

/// return the certificate in both a Rustls-usable form, and the pem format
fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper>;

/// persist a certificate, after ensuring validity, and checking if it can replace another certificate
fn add_certificate(&mut self, opts: &AddCertificate) -> Result<Fingerprint, Self::Error>;

/// Delete a certificate from the resolver. May fail if there is no alternative for
// a domain name
fn remove_certificate(&mut self, opts: &Fingerprint) -> Result<(), Self::Error>;

/// Short-hand for `add_certificate` and then `remove_certificate`.
/// It is possible that the certificate will not be replaced, if the
/// new certificate does not match `add_certificate` rules.
fn replace_certificate(
&mut self,
opts: &ReplaceCertificate,
) -> Result<Fingerprint, Self::Error> {
match Fingerprint::from_str(&opts.old_fingerprint) {
Ok(old_fingerprint) => self.remove_certificate(&old_fingerprint)?,
Err(err) => {
error!("failed to parse fingerprint, {}", err);
}
}

self.add_certificate(&AddCertificate {
address: opts.address.to_owned(),
certificate: opts.new_certificate.to_owned(),
expired_at: opts.new_expired_at.to_owned(),
})
}
}

// -----------------------------------------------------------------------------
// CertificateOverride struct

Expand Down Expand Up @@ -140,8 +102,8 @@ impl CertifiedKeyWrapper {
}
}

/// Parse a raw certificate into the Rustls format.
/// Parses RSA and ECDSA certificates.
/// Convert an AddCertificate request into the Rustls format.
/// Support RSA and ECDSA certificates.
impl TryFrom<&AddCertificate> for CertifiedKeyWrapper {
type Error = CertificateResolverError;

Expand Down Expand Up @@ -232,14 +194,17 @@ pub struct CertificateResolver {
// overrides: HashMap<Fingerprint, CertificateOverride>,
}

impl ResolveCertificate for CertificateResolver {
type Error = CertificateResolverError;

fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
impl CertificateResolver {
/// return the certificate in the Rustls-usable form
pub fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
self.certificates.get(fingerprint).map(ToOwned::to_owned)
}

fn add_certificate(&mut self, add: &AddCertificate) -> Result<Fingerprint, Self::Error> {
/// persist a certificate, after ensuring validity, and checking if it can replace another certificate
pub fn add_certificate(
&mut self,
add: &AddCertificate,
) -> Result<Fingerprint, CertificateResolverError> {
// Check if we could parse the certificate, chain and private key, if not just throw an
// error.
let certificate_to_add = CertifiedKeyWrapper::try_from(add)?;
Expand Down Expand Up @@ -293,7 +258,12 @@ impl ResolveCertificate for CertificateResolver {
Ok(fingerprint.to_owned())
}

fn remove_certificate(&mut self, fingerprint: &Fingerprint) -> Result<(), Self::Error> {
/// Delete a certificate from the resolver. May fail if there is no alternative for
// a domain name
pub fn remove_certificate(
&mut self,
fingerprint: &Fingerprint,
) -> Result<(), CertificateResolverError> {
if let Some(certificate_to_remove) = self.get_certificate(fingerprint) {
// let names = match self.get_names_override(fingerprint) {
// Some(names) => names,
Expand All @@ -317,14 +287,28 @@ impl ResolveCertificate for CertificateResolver {

Ok(())
}
}

/// hashes bytes of the pem-formatted certificate for storage in the hashmap
fn fingerprint(bytes: &[u8]) -> Fingerprint {
Fingerprint(Sha256::digest(bytes).iter().cloned().collect())
}
/// Short-hand for `add_certificate` and then `remove_certificate`.
/// It is possible that the certificate will not be replaced, if the
/// new certificate does not match `add_certificate` rules.
pub fn replace_certificate(
&mut self,
replace: &ReplaceCertificate,
) -> Result<Fingerprint, CertificateResolverError> {
match Fingerprint::from_str(&replace.old_fingerprint) {
Ok(old_fingerprint) => self.remove_certificate(&old_fingerprint)?,
Err(err) => {
error!("failed to parse fingerprint, {}", err);
}
}

self.add_certificate(&AddCertificate {
address: replace.address.to_owned(),
certificate: replace.new_certificate.to_owned(),
expired_at: replace.new_expired_at.to_owned(),
})
}

impl CertificateResolver {
/// return all fingerprints that are available for these domain names,
/// provided at least one name is given
fn find_certificates_by_names(
Expand Down Expand Up @@ -522,13 +506,18 @@ impl CertificateResolver {
}
}

/// hashes bytes of the pem-formatted certificate for storage in the hashmap
fn fingerprint(bytes: &[u8]) -> Fingerprint {
Fingerprint(Sha256::digest(bytes).iter().cloned().collect())
}

// -----------------------------------------------------------------------------
// MutexWrappedCertificateResolver struct

#[derive(Default)]
pub struct MutexWrappedCertificateResolver(pub Mutex<CertificateResolver>);
pub struct MutexCertificateResolver(pub Mutex<CertificateResolver>);

impl ResolvesServerCert for MutexWrappedCertificateResolver {
impl ResolvesServerCert for MutexCertificateResolver {
fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
let server_name = client_hello.server_name();
let sigschemes = client_hello.signature_schemes();
Expand Down Expand Up @@ -572,7 +561,7 @@ impl ResolvesServerCert for MutexWrappedCertificateResolver {
}
}

impl Debug for MutexWrappedCertificateResolver {
impl Debug for MutexCertificateResolver {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("MutexWrappedCertificateResolver")
}
Expand All @@ -589,7 +578,7 @@ mod tests {
time::{Duration, SystemTime},
};

use super::{fingerprint, CertificateResolver, ResolveCertificate};
use super::{fingerprint, CertificateResolver};

use rand::{seq::SliceRandom, thread_rng};
use sozu_command::{
Expand Down

0 comments on commit 76d4b7f

Please sign in to comment.