From 6aa45b9bd4ef8cadca308712d82f136e3747b7d5 Mon Sep 17 00:00:00 2001 From: Emmanuel Bosquet Date: Mon, 12 Feb 2024 17:10:53 +0100 Subject: [PATCH] fix and rewrite CertificateResolver refactor: - simplify overriding of certificate names and expiration - remove trait ResolveCertificate - rename MutexWrappedCertificateResolver to MutexCertificateResolver - rearrange error types, remove useless functions - change signature of get_cn_and_san_attributes fixes: - fix concurrent certificate insert in CertificateResolver - ensure proper removal of certificate in CertificateResolver - fix removal of certificate with identical fingerprint Co-Authored-By: Eloi DEMOLIS --- command/src/certificate.rs | 42 ++- command/src/command.proto | 3 +- command/src/state.rs | 100 ++----- lib/src/http.rs | 2 +- lib/src/https.rs | 91 +++--- lib/src/lib.rs | 13 +- lib/src/server.rs | 2 +- lib/src/tls.rs | 577 ++++++++++++++++++------------------- 8 files changed, 379 insertions(+), 451 deletions(-) diff --git a/command/src/certificate.rs b/command/src/certificate.rs index c7697a5b4..32ade6352 100644 --- a/command/src/certificate.rs +++ b/command/src/certificate.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, fmt, str::FromStr}; +use std::{fmt, str::FromStr}; use hex::{FromHex, FromHexError}; use serde::de::{self, Visitor}; @@ -56,13 +56,10 @@ pub fn parse_x509(pem_bytes: &[u8]) -> Result /// Retrieve from the pem (as bytes) the common name (a.k.a `CN`) and the /// subject alternate names (a.k.a `SAN`) -pub fn get_cn_and_san_attributes(pem_bytes: &[u8]) -> Result, CertificateError> { - let x509 = parse_x509(pem_bytes) - .map_err(|err| CertificateError::InvalidCertificate(err.to_string()))?; - - let mut names: HashSet = HashSet::new(); +pub fn get_cn_and_san_attributes(x509: &X509Certificate) -> Vec { + let mut names: Vec = Vec::new(); for name in x509.subject().iter_by_oid(&OID_X509_COMMON_NAME) { - names.insert( + names.push( name.as_str() .map(String::from) .unwrap_or_else(|_| String::from_utf8_lossy(name.as_slice()).to_string()), @@ -74,13 +71,14 @@ pub fn get_cn_and_san_attributes(pem_bytes: &[u8]) -> Result, Ce if let ParsedExtension::SubjectAlternativeName(san) = extension.parsed_extension() { for name in &san.general_names { if let GeneralName::DNSName(name) = name { - names.insert(name.to_string()); + names.push(name.to_string()); } } } } } - Ok(names) + names.dedup(); + names } // ----------------------------------------------------------------------------- @@ -253,3 +251,29 @@ pub fn load_full_certificate( names, }) } + +impl CertificateAndKey { + pub fn fingerprint(&self) -> Result { + let pem = parse_pem(self.certificate.as_bytes())?; + let fingerprint = Fingerprint(Sha256::digest(pem.contents).iter().cloned().collect()); + Ok(fingerprint) + } + + pub fn get_overriding_names(&self) -> Result, CertificateError> { + if self.names.is_empty() { + let pem = parse_pem(self.certificate.as_bytes())?; + let x509 = parse_x509(&pem.contents)?; + + let overriding_names = get_cn_and_san_attributes(&x509); + + Ok(overriding_names.into_iter().collect()) + } else { + Ok(self.names.to_owned()) + } + } + + pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> { + self.names = self.get_overriding_names()?; + Ok(()) + } +} diff --git a/command/src/command.proto b/command/src/command.proto index f2459dcd9..a3b4c2910 100644 --- a/command/src/command.proto +++ b/command/src/command.proto @@ -295,7 +295,8 @@ message CertificateAndKey { repeated string certificate_chain = 2; required string key = 3; repeated TlsVersion versions = 4; - // hostnames linked to the certificate + // a list of domain names. Override certificate names + // if empty, the names of the certificate will be used repeated string names = 5; } diff --git a/command/src/state.rs b/command/src/state.rs index 6cae2918a..30cc1dddb 100644 --- a/command/src/state.rs +++ b/command/src/state.rs @@ -13,7 +13,7 @@ use std::{ use prost::{DecodeError, Message}; use crate::{ - certificate::{self, calculate_fingerprint, Fingerprint}, + certificate::{calculate_fingerprint, CertificateError, Fingerprint}, proto::{ command::{ request::RequestType, ActivateListener, AddBackend, AddCertificate, CertificateAndKey, @@ -47,7 +47,7 @@ pub enum StateError { #[error("Wrong request: {0}")] WrongRequest(String), #[error("Could not add certificate: {0}")] - AddCertificate(String), + AddCertificate(CertificateError), #[error("Could not remove certificate: {0}")] RemoveCertificate(String), #[error("Could not replace certificate: {0}")] @@ -374,17 +374,21 @@ impl ConfigState { } fn add_certificate(&mut self, add: &AddCertificate) -> Result<(), StateError> { - let fingerprint = Fingerprint( - calculate_fingerprint(add.certificate.certificate.as_bytes()).map_err( - |fingerprint_err| StateError::AddCertificate(fingerprint_err.to_string()), - )?, - ); + let fingerprint = add + .certificate + .fingerprint() + .map_err(StateError::AddCertificate)?; let entry = self .certificates .entry(add.address.clone().into()) .or_insert_with(HashMap::new); + let mut add = add.clone(); + add.certificate + .apply_overriding_names() + .map_err(StateError::AddCertificate)?; + if entry.contains_key(&fingerprint) { info!( "Skip loading of certificate '{}' for domain '{}' on listener '{}', the certificate is already present.", @@ -393,7 +397,7 @@ impl ConfigState { return Ok(()); } - entry.insert(fingerprint, add.certificate.clone()); + entry.insert(fingerprint, add.certificate); Ok(()) } @@ -1248,72 +1252,20 @@ impl ConfigState { &self, filters: QueryCertificatesFilters, ) -> BTreeMap { - if let Some(domain) = filters.domain { - self.certificates - .values() - .flat_map(|hash_map| hash_map.iter()) - .flat_map(|(fingerprint, cert)| { - if cert.names.is_empty() { - let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?; - let mut c = cert.to_owned(); - - c.names = certificate::get_cn_and_san_attributes(&pem.contents) - .ok()? - .into_iter() - .collect(); - - return Some((fingerprint, c)); - } - - Some((fingerprint, cert.to_owned())) - }) - .filter(|(_, cert)| cert.names.contains(&domain)) - .map(|(fingerprint, cert)| (fingerprint.to_string(), cert)) - .collect() - } else if let Some(f) = filters.fingerprint { - self.certificates - .values() - .flat_map(|hash_map| hash_map.iter()) - .filter(|(fingerprint, _cert)| fingerprint.to_string() == f) - .flat_map(|(fingerprint, cert)| { - if cert.names.is_empty() { - let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?; - let mut c = cert.to_owned(); - - c.names = certificate::get_cn_and_san_attributes(&pem.contents) - .ok()? - .into_iter() - .collect(); - - return Some((fingerprint, c)); - } - - Some((fingerprint, cert.to_owned())) - }) - .map(|(fingerprint, cert)| (fingerprint.to_string(), cert)) - .collect() - } else { - self.certificates - .values() - .flat_map(|hash_map| hash_map.iter()) - .flat_map(|(fingerprint, cert)| { - if cert.names.is_empty() { - let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?; - let mut c = cert.to_owned(); - - c.names = certificate::get_cn_and_san_attributes(&pem.contents) - .ok()? - .into_iter() - .collect(); - - return Some((fingerprint, c)); - } - - Some((fingerprint, cert.to_owned())) - }) - .map(|(fingerprint, cert)| (fingerprint.to_string(), cert)) - .collect() - } + self.certificates + .values() + .flat_map(|hash_map| hash_map.iter()) + .filter(|(fingerprint, cert)| { + if let Some(domain) = &filters.domain { + cert.names.contains(domain) + } else if let Some(f) = &filters.fingerprint { + fingerprint.to_string() == *f + } else { + true + } + }) + .map(|(fingerprint, cert)| (fingerprint.to_string(), cert.to_owned())) + .collect() } pub fn list_frontends(&self, filters: FrontendFilters) -> ListedFrontends { diff --git a/lib/src/http.rs b/lib/src/http.rs index 068afeb6b..79c3735ec 100644 --- a/lib/src/http.rs +++ b/lib/src/http.rs @@ -522,7 +522,7 @@ impl HttpProxy { pub fn remove_listener(&mut self, remove: RemoveListener) -> Result<(), ProxyError> { let len = self.listeners.len(); - let remove_address = remove.address.clone().into(); + let remove_address = remove.address.into(); self.listeners .retain(|_, l| l.borrow().address != remove_address); diff --git a/lib/src/https.rs b/lib/src/https.rs index 0f6f44703..8d2fd1da8 100644 --- a/lib/src/https.rs +++ b/lib/src/https.rs @@ -66,7 +66,7 @@ use crate::{ server::{ListenToken, SessionManager}, socket::{server_bind, FrontRustls}, timer::TimeoutContainer, - tls::{CertifiedKeyWrapper, MutexWrappedCertificateResolver, ResolveCertificate}, + tls::MutexCertificateResolver, util::UnwrapLog, AcceptError, CachedTags, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError, ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed, @@ -529,7 +529,7 @@ pub struct HttpsListener { config: HttpsListenerConfig, fronts: Router, listener: Option, - resolver: Arc, + resolver: Arc, rustls_details: Arc, tags: BTreeMap, token: Token, @@ -609,51 +609,12 @@ impl L7ListenerHandler for HttpsListener { } } -impl ResolveCertificate for HttpsListener { - type Error = ListenerError; - - fn get_certificate(&self, fingerprint: &Fingerprint) -> Option { - let resolver = self - .resolver - .0 - .lock() - .map_err(|err| ListenerError::Lock(err.to_string())) - .ok()?; - - resolver.get_certificate(fingerprint) - } - - fn add_certificate(&mut self, opts: &AddCertificate) -> Result { - let mut resolver = self - .resolver - .0 - .lock() - .map_err(|err| ListenerError::Lock(err.to_string()))?; - - resolver - .add_certificate(opts) - .map_err(ListenerError::Resolver) - } - - fn remove_certificate(&mut self, fingerprint: &Fingerprint) -> Result<(), Self::Error> { - let mut resolver = self - .resolver - .0 - .lock() - .map_err(|err| ListenerError::Lock(err.to_string()))?; - - resolver - .remove_certificate(fingerprint) - .map_err(ListenerError::Resolver) - } -} - impl HttpsListener { pub fn try_new( config: HttpsListenerConfig, token: Token, ) -> Result { - let resolver = Arc::new(MutexWrappedCertificateResolver::default()); + let resolver = Arc::new(MutexCertificateResolver::default()); let server_config = Arc::new(Self::create_rustls_context(&config, resolver.to_owned())?); @@ -705,7 +666,7 @@ impl HttpsListener { pub fn create_rustls_context( config: &HttpsListenerConfig, - resolver: Arc, + resolver: Arc, ) -> Result { let cipher_names = if config.cipher_list.is_empty() { DEFAULT_CIPHER_SUITES.to_vec() @@ -850,7 +811,7 @@ impl HttpsProxy { ) -> Result, ProxyError> { let len = self.listeners.len(); - let remove_address = remove.address.clone().into(); + let remove_address = remove.address.into(); self.listeners .retain(|_, listener| listener.borrow().address != remove_address); @@ -1131,10 +1092,16 @@ impl HttpsProxy { .listeners .values() .find(|l| l.borrow().address == address) - .ok_or(ProxyError::NoListenerFound(address))?; + .ok_or(ProxyError::NoListenerFound(address))? + .borrow_mut(); - listener - .borrow_mut() + let mut resolver = listener + .resolver + .0 + .lock() + .map_err(|e| ProxyError::Lock(e.to_string()))?; + + resolver .add_certificate(&add_certificate) .map_err(ProxyError::AddCertificate)?; @@ -1150,19 +1117,25 @@ impl HttpsProxy { let fingerprint = Fingerprint( hex::decode(&remove_certificate.fingerprint) - .map_err(|e| ProxyError::WrongCertificateFingerprint(e.to_string()))?, + .map_err(ProxyError::WrongCertificateFingerprint)?, ); let listener = self .listeners .values() .find(|l| l.borrow().address == address) - .ok_or(ProxyError::NoListenerFound(address))?; + .ok_or(ProxyError::NoListenerFound(address))? + .borrow_mut(); - listener - .borrow_mut() + let mut resolver = listener + .resolver + .0 + .lock() + .map_err(|e| ProxyError::Lock(e.to_string()))?; + + resolver .remove_certificate(&fingerprint) - .map_err(ProxyError::AddCertificate)?; + .map_err(ProxyError::RemoveCertificate)?; Ok(None) } @@ -1178,10 +1151,16 @@ impl HttpsProxy { .listeners .values() .find(|l| l.borrow().address == address) - .ok_or(ProxyError::NoListenerFound(address))?; + .ok_or(ProxyError::NoListenerFound(address))? + .borrow_mut(); - listener - .borrow_mut() + let mut resolver = listener + .resolver + .0 + .lock() + .map_err(|e| ProxyError::Lock(e.to_string()))?; + + resolver .replace_certificate(&replace_certificate) .map_err(ProxyError::ReplaceCertificate)?; @@ -1592,7 +1571,7 @@ mod tests { )); let address = SocketAddress::new_v4(127, 0, 0, 1, 1032); - let resolver = Arc::new(MutexWrappedCertificateResolver::default()); + let resolver = Arc::new(MutexCertificateResolver::default()); let server_config = RustlsServerConfig::builder_with_protocol_versions(&[ &rustls::version::TLS12, diff --git a/lib/src/lib.rs b/lib/src/lib.rs index 1436c3f0d..3439ca256 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -351,6 +351,7 @@ use std::{ }; use backends::BackendError; +use hex::FromHexError; use mio::{net::TcpStream, Interest, Token}; use protocol::http::parser::Method; use router::RouterError; @@ -637,8 +638,6 @@ pub enum AcceptError { /// returned by the HTTP, HTTPS and TCP listeners #[derive(thiserror::Error, Debug)] pub enum ListenerError { - #[error("failed to acquire the lock, {0}")] - Lock(String), #[error("failed to handle certificate request, got a resolver error, {0}")] Resolver(CertificateResolverError), #[error("failed to parse pem, {0}")] @@ -689,15 +688,17 @@ pub enum ProxyError { #[error("could not remove frontend: {0}")] RemoveFrontend(ListenerError), #[error("could not add certificate: {0}")] - AddCertificate(ListenerError), + AddCertificate(CertificateResolverError), #[error("could not remove certificate: {0}")] - RemoveCertificate(ListenerError), + RemoveCertificate(CertificateResolverError), #[error("could not replace certificate: {0}")] - ReplaceCertificate(ListenerError), + ReplaceCertificate(CertificateResolverError), #[error("wrong certificate fingerprint: {0}")] - WrongCertificateFingerprint(String), + WrongCertificateFingerprint(FromHexError), #[error("this request is not supported by the proxy")] UnsupportedMessage, + #[error("failed to acquire the lock, {0}")] + Lock(String), } use self::server::ListenToken; diff --git a/lib/src/server.rs b/lib/src/server.rs index 81521ca5d..6b533ddb1 100644 --- a/lib/src/server.rs +++ b/lib/src/server.rs @@ -909,7 +909,7 @@ impl Server { ContentType::Clusters(ClusterInformations { vec: self .config_state - .cluster_state(&cluster_id) + .cluster_state(cluster_id) .map_or(vec![], |ci| vec![ci]), }) .into(), diff --git a/lib/src/tls.rs b/lib/src/tls.rs index e061c3d05..53e994597 100644 --- a/lib/src/tls.rs +++ b/lib/src/tls.rs @@ -25,7 +25,7 @@ use sozu_command::{ certificate::{ get_cn_and_san_attributes, parse_pem, parse_x509, CertificateError, Fingerprint, }, - proto::command::{AddCertificate, CertificateAndKey, ReplaceCertificate}, + proto::command::{AddCertificate, CertificateAndKey, ReplaceCertificate, SocketAddress}, }; use crate::router::trie::{Key, KeyValue, TrieNode}; @@ -34,122 +34,117 @@ use crate::router::trie::{Key, KeyValue, TrieNode}; // Default ParsedCertificateAndKey static DEFAULT_CERTIFICATE: Lazy>> = Lazy::new(|| { - let certificate_and_key = CertificateAndKey { - certificate: include_str!("../assets/certificate.pem").to_string(), - certificate_chain: vec![include_str!("../assets/certificate_chain.pem").to_string()], - key: include_str!("../assets/key.pem").to_string(), - versions: vec![], - names: vec![], + let add = AddCertificate { + certificate: CertificateAndKey { + certificate: include_str!("../assets/certificate.pem").to_string(), + certificate_chain: vec![include_str!("../assets/certificate_chain.pem").to_string()], + key: include_str!("../assets/key.pem").to_string(), + versions: vec![], + names: vec![], + }, + address: SocketAddress::new_v4(0, 0, 0, 0, 8080), // not used anyway + expired_at: None, }; - - CertificateResolver::parse(&certificate_and_key) - .ok() - .map(|c| c.inner) + CertifiedKeyWrapper::try_from(&add).ok().map(|c| c.inner) }); -// ----------------------------------------------------------------------------- -// CertificateResolver trait +#[derive(thiserror::Error, Debug)] +pub enum CertificateResolverError { + #[error("failed to get common name and subject alternate names from pem, {0}")] + InvalidCommonNameAndSubjectAlternateNames(CertificateError), + #[error("invalid private key: {0}")] + InvalidPrivateKey(String), + #[error("empty key")] + EmptyKeys, + #[error("error parsing x509 cert from bytes: {0}")] + ParseX509(CertificateError), + #[error("error parsing pem formated certificate from bytes: {0}")] + ParsePem(CertificateError), + #[error("error parsing overriding names in new certificate: {0}")] + ParseOverridingNames(CertificateError), +} -pub trait ResolveCertificate { - type Error; +/// A wrapper around the Rustls +/// [`CertifiedKey` type](https://docs.rs/rustls/latest/rustls/sign/struct.CertifiedKey.html), +/// stored and returned by the certificate resolver. +#[derive(Clone, Debug)] +pub struct CertifiedKeyWrapper { + inner: Arc, + /// domain names, override what can be found in the cert + names: Vec, + expiration: i64, + fingerprint: Fingerprint, +} - /// return the certificate in both a Rustls-usable form, and the pem format - fn get_certificate(&self, fingerprint: &Fingerprint) -> Option; +/// Convert an AddCertificate request into the Rustls format. +/// Support RSA and ECDSA certificates. +impl TryFrom<&AddCertificate> for CertifiedKeyWrapper { + type Error = CertificateResolverError; - /// persist a certificate, after ensuring validity, and checking if it can replace another certificate - fn add_certificate(&mut self, opts: &AddCertificate) -> Result; + fn try_from(add: &AddCertificate) -> Result { + let cert = add.certificate.clone(); - /// Delete a certificate from the resolver. May fail if there is no alternative for - // a domain name - fn remove_certificate(&mut self, opts: &Fingerprint) -> Result<(), Self::Error>; + let pem = + parse_pem(cert.certificate.as_bytes()).map_err(CertificateResolverError::ParsePem)?; - /// Short-hand for `add_certificate` and then `remove_certificate`. - /// It is possible that the certificate will not be replaced, if the - /// new certificate does not match `add_certificate` rules. - fn replace_certificate( - &mut self, - opts: &ReplaceCertificate, - ) -> Result { - match Fingerprint::from_str(&opts.old_fingerprint) { - Ok(old_fingerprint) => self.remove_certificate(&old_fingerprint)?, - Err(err) => { - error!("failed to parse fingerprint, {}", err); - } - } + let x509 = parse_x509(&pem.contents).map_err(CertificateResolverError::ParseX509)?; - self.add_certificate(&AddCertificate { - address: opts.address.to_owned(), - certificate: opts.new_certificate.to_owned(), - expired_at: opts.new_expired_at.to_owned(), - }) - } -} + let overriding_names = if add.certificate.names.is_empty() { + get_cn_and_san_attributes(&x509) + } else { + add.certificate.names.clone() + }; -// ----------------------------------------------------------------------------- -// CertificateOverride struct + let expiration = add + .expired_at + .unwrap_or(x509.validity().not_after.timestamp()); -/// Enables use of certificates for more domain names -#[derive(Clone, Debug)] -pub struct CertificateOverride { - pub names: Option>, - pub expiration: Option, -} + let fingerprint = Fingerprint(Sha256::digest(&pem.contents).iter().cloned().collect()); -impl From<&AddCertificate> for CertificateOverride { - fn from(opts: &AddCertificate) -> Self { - let mut names = None; - if !opts.certificate.names.is_empty() { - names = Some(opts.certificate.names.iter().cloned().collect()) - } + let mut chain = vec![CertificateDer::from(pem.contents)]; + for cert in &cert.certificate_chain { + let chain_link = parse_pem(cert.as_bytes()) + .map_err(CertificateResolverError::ParsePem)? + .contents; - Self { - names, - expiration: opts.expired_at.to_owned(), + chain.push(CertificateDer::from(chain_link)); } - } -} - -/// A wrapper around the Rustls -/// [`CertifiedKey` type](https://docs.rs/rustls/latest/rustls/sign/struct.CertifiedKey.html), -/// stored and returned by the certificate resolver. -#[derive(Clone)] -pub struct CertifiedKeyWrapper { - inner: Arc, -} -impl CertifiedKeyWrapper { - /// bytes of the pem formatted certificate, first of the chain - fn pem_bytes(&self) -> &[u8] { - self.inner.cert[0].as_ref() - } -} + let mut key_reader = BufReader::new(cert.key.as_bytes()); -// ----------------------------------------------------------------------------- -// CertificateResolverError enum + let item = match rustls_pemfile::read_one(&mut key_reader) + .map_err(|_| CertificateResolverError::EmptyKeys)? + { + Some(item) => item, + None => return Err(CertificateResolverError::EmptyKeys), + }; -#[derive(thiserror::Error, Debug)] -pub enum CertificateResolverError { - #[error("failed to get common name and subject alternate names from pem, {0}")] - InvalidCommonNameAndSubjectAlternateNames(CertificateError), - #[error("invalid private key: {0}")] - InvalidPrivateKey(String), - #[error("empty key")] - EmptyKeys, - #[error("certificate error: {0}")] - CertificateError(CertificateError), -} + let private_key = match item { + rustls_pemfile::Item::Pkcs1Key(rsa_key) => PrivateKeyDer::from(rsa_key), + rustls_pemfile::Item::Pkcs8Key(pkcs8_key) => PrivateKeyDer::from(pkcs8_key), + rustls_pemfile::Item::Sec1Key(ec_key) => PrivateKeyDer::from(ec_key), + _ => return Err(CertificateResolverError::EmptyKeys), + }; -impl From for CertificateResolverError { - fn from(value: CertificateError) -> Self { - Self::CertificateError(value) + match any_supported_type(&private_key) { + Ok(signing_key) => { + let stored_certificate = CertifiedKeyWrapper { + inner: Arc::new(CertifiedKey::new(chain, signing_key)), + names: overriding_names, + expiration, + fingerprint, + }; + Ok(stored_certificate) + } + Err(sign_error) => Err(CertificateResolverError::InvalidPrivateKey( + sign_error.to_string(), + )), + } } } -// ----------------------------------------------------------------------------- -// CertificateResolver struct - /// Parses and stores TLS certificates, makes them available to Rustls for TLS handshakes -#[derive(Default)] +#[derive(Default, Debug)] pub struct CertificateResolver { /// all fingerprints of all pub domains: TrieNode, @@ -157,80 +152,76 @@ pub struct CertificateResolver { certificates: HashMap, /// map of domain_name -> all fingerprints linked to this domain name name_fingerprint_idx: HashMap>, - /// map of fingerprint -> domain names to override - overrides: HashMap, } -impl ResolveCertificate for CertificateResolver { - type Error = CertificateResolverError; - - fn get_certificate(&self, fingerprint: &Fingerprint) -> Option { +impl CertificateResolver { + /// return the certificate in the Rustls-usable form + pub fn get_certificate(&self, fingerprint: &Fingerprint) -> Option { self.certificates.get(fingerprint).map(ToOwned::to_owned) } - fn add_certificate(&mut self, opts: &AddCertificate) -> Result { - // Check if we could parse the certificate, chain and private key, if not just throw an - // error. - let certificate_to_add = Self::parse(&opts.certificate)?; - let fingerprint = fingerprint(certificate_to_add.pem_bytes()); - if !opts.certificate.names.is_empty() || opts.expired_at.is_some() { - self.overrides - .insert(fingerprint.to_owned(), CertificateOverride::from(opts)); - } else { - self.overrides.remove(&fingerprint); - } + /// persist a certificate, after ensuring validity, and checking if it can replace another certificate + pub fn add_certificate( + &mut self, + add: &AddCertificate, + ) -> Result { + let cert_to_add = CertifiedKeyWrapper::try_from(add)?; + + let (should_insert, outdated_certs) = self.should_insert(&cert_to_add)?; - let (should_insert, certificates_to_remove) = - self.should_insert(&fingerprint, &certificate_to_add)?; if !should_insert { // if we do not need to insert the fingerprint just return the fingerprint - return Ok(fingerprint); + return Ok(cert_to_add.fingerprint); } - let new_names = match self.get_names_override(&fingerprint) { - Some(names) => names, - None => self.certificate_names(certificate_to_add.pem_bytes())?, - }; - - self.certificates - .insert(fingerprint.to_owned(), certificate_to_add); + for new_name in &cert_to_add.names { + self.domains.remove(&new_name.to_owned().into_bytes()); - for new_name in new_names { - self.domains - .insert(new_name.to_owned().into_bytes(), fingerprint.to_owned()); + self.domains.insert( + new_name.to_owned().into_bytes(), + cert_to_add.fingerprint.to_owned(), + ); self.name_fingerprint_idx - .entry(new_name) + .entry(new_name.to_owned()) .or_insert_with(HashSet::new) - .insert(fingerprint.to_owned()); + .insert(cert_to_add.fingerprint.to_owned()); } - for (fingerprint, names) in &certificates_to_remove { - for name in names { - if let Some(fingerprints) = self.name_fingerprint_idx.get_mut(name) { - fingerprints.remove(fingerprint); + for name in &cert_to_add.names { + if let Some(fingerprints) = self.name_fingerprint_idx.get_mut(name) { + for outdated in &outdated_certs { + fingerprints.remove(outdated); } } + } - self.certificates.remove(fingerprint); + for outdated in &outdated_certs { + self.certificates.remove(outdated); } - Ok(fingerprint.to_owned()) + self.certificates + .insert(cert_to_add.fingerprint.to_owned(), cert_to_add.clone()); + + Ok(cert_to_add.fingerprint.to_owned()) } - fn remove_certificate(&mut self, fingerprint: &Fingerprint) -> Result<(), Self::Error> { + /// Delete a certificate from the resolver. May fail if there is no alternative for + // a domain name + pub fn remove_certificate( + &mut self, + fingerprint: &Fingerprint, + ) -> Result<(), CertificateResolverError> { if let Some(certificate_to_remove) = self.get_certificate(fingerprint) { - let names = match self.get_names_override(fingerprint) { - Some(names) => names, - None => self.certificate_names(certificate_to_remove.pem_bytes())?, - }; - - for name in &names { - if let Some(fingerprints) = self.name_fingerprint_idx.get_mut(name) { + for name in certificate_to_remove.names { + if let Some(fingerprints) = self.name_fingerprint_idx.get_mut(&name) { fingerprints.remove(fingerprint); - if fingerprints.is_empty() { - self.domains.domain_remove(&name.to_owned().into_bytes()); + self.domains.domain_remove(&name.clone().into_bytes()); + + if let Some(fingerprint) = fingerprints.iter().next() { + self.domains + .insert(name.into_bytes(), fingerprint.to_owned()); } } } @@ -240,15 +231,31 @@ impl ResolveCertificate for CertificateResolver { Ok(()) } -} -/// hashes bytes of the pem-formatted certificate for storage in the hashmap -fn fingerprint(bytes: &[u8]) -> Fingerprint { - Fingerprint(Sha256::digest(bytes).iter().cloned().collect()) -} + /// Short-hand for `add_certificate` and then `remove_certificate`. + /// It is possible that the certificate will not be replaced, if the + /// new certificate does not match `add_certificate` rules. + pub fn replace_certificate( + &mut self, + replace: &ReplaceCertificate, + ) -> Result { + match Fingerprint::from_str(&replace.old_fingerprint) { + Ok(old_fingerprint) => self.remove_certificate(&old_fingerprint)?, + Err(err) => { + error!("failed to parse fingerprint, {}", err); + } + } -impl CertificateResolver { - /// return all fingerprints that are available, provided at least one name is given + self.add_certificate(&AddCertificate { + address: replace.address.to_owned(), + certificate: replace.new_certificate.to_owned(), + expired_at: replace.new_expired_at.to_owned(), + }) + } + + /// return all fingerprints that are available for these domain names, + /// provided at least one name is given + #[cfg(test)] fn find_certificates_by_names( &self, names: &HashSet, @@ -265,148 +272,63 @@ impl CertificateResolver { Ok(fingerprints) } - /// return the hashset of subjects that the certificate is able to handle, by - /// parsing the pem file and scrapping the information + /// return the hashset of subjects that the certificate is able to handle + #[cfg(test)] fn certificate_names( &self, - pem_bytes: &[u8], + fingerprint: &Fingerprint, ) -> Result, CertificateResolverError> { - let fingerprint = fingerprint(pem_bytes); - if let Some(certificate_override) = self.overrides.get(&fingerprint) { - if let Some(names) = &certificate_override.names { - return Ok(names.to_owned()); - } - } - - get_cn_and_san_attributes(pem_bytes) - .map_err(CertificateResolverError::InvalidCommonNameAndSubjectAlternateNames) - } - - /// Parse a raw certificate into the Rustls format. - /// Parses RSA and ECDSA certificates. - fn parse( - certificate_and_key: &CertificateAndKey, - ) -> Result { - let certificate_pem = - sozu_command::certificate::parse_pem(certificate_and_key.certificate.as_bytes())?; - - let mut chain = vec![CertificateDer::from(certificate_pem.contents)]; - for cert in &certificate_and_key.certificate_chain { - let chain_link = parse_pem(cert.as_bytes())?.contents; - - chain.push(CertificateDer::from(chain_link)); - } - - let mut key_reader = BufReader::new(certificate_and_key.key.as_bytes()); - - let item = match rustls_pemfile::read_one(&mut key_reader) - .map_err(|_| CertificateResolverError::EmptyKeys)? - { - Some(item) => item, - None => return Err(CertificateResolverError::EmptyKeys), - }; - - let private_key = match item { - rustls_pemfile::Item::Pkcs1Key(rsa_key) => PrivateKeyDer::from(rsa_key), - rustls_pemfile::Item::Pkcs8Key(pkcs8_key) => PrivateKeyDer::from(pkcs8_key), - rustls_pemfile::Item::Sec1Key(ec_key) => PrivateKeyDer::from(ec_key), - _ => return Err(CertificateResolverError::EmptyKeys), - }; - match any_supported_type(&private_key) { - Ok(signing_key) => { - let stored_certificate = CertifiedKeyWrapper { - inner: Arc::new(CertifiedKey::new(chain, signing_key)), - }; - Ok(stored_certificate) - } - Err(sign_error) => Err(CertificateResolverError::InvalidPrivateKey( - sign_error.to_string(), - )), + if let Some(cert) = self.certificates.get(fingerprint) { + return Ok(cert.names.iter().cloned().collect()); } + Ok(HashSet::new()) } -} -impl CertificateResolver { + /// check the certificate expiration and related certificates, + /// return a list of outdated certificates that should be removed fn should_insert( &self, - fingerprint: &Fingerprint, candidate_cert: &CertifiedKeyWrapper, - ) -> Result<(bool, HashMap>), CertificateResolverError> { - let x509 = parse_x509(candidate_cert.pem_bytes())?; - - // We need to know if the new certificate can replace an already existing one. - let new_names = match self.get_names_override(fingerprint) { - Some(names) => names, - None => self.certificate_names(candidate_cert.pem_bytes())?, - }; + ) -> Result<(bool, Vec), CertificateResolverError> { + let mut should_insert = false; - let expiration = self - .get_expiration_override(fingerprint) - .unwrap_or_else(|| x509.validity().not_after.timestamp()); + let mut related_certificates = HashSet::new(); - let fingerprints = self.find_certificates_by_names(&new_names)?; - let mut certificates = HashMap::new(); - for fingerprint in &fingerprints { - if let Some(cert) = self.get_certificate(fingerprint) { - certificates.insert(fingerprint, cert); + for name in &candidate_cert.names { + match self.name_fingerprint_idx.get(name) { + None => should_insert = true, + Some(fingerprints) if fingerprints.is_empty() => should_insert = true, + Some(fingerprints) => related_certificates.extend(fingerprints), } } - let mut should_insert = false; - let mut certificates_to_remove = HashMap::new(); - let mut certificates_names = HashSet::new(); - for (fingerprint, stored_certificate) in certificates { - let x509 = parse_x509(stored_certificate.pem_bytes())?; - - let certificate_names = match self.get_names_override(fingerprint) { - Some(names) => names, - None => self.certificate_names(stored_certificate.pem_bytes())?, - }; + let mut outdated_certificates = Vec::new(); - let certificate_expiration = self - .get_expiration_override(fingerprint) - .unwrap_or_else(|| x509.validity().not_after.timestamp()); - - let extra_names = certificate_names - .difference(&new_names) - .collect::>(); + for fingerprint in related_certificates { + let related_certificate = match self.certificates.get(fingerprint) { + Some(cert) => cert, + None => { + error!("certificates and fingerprint hashmaps are desynchronized"); + continue; + } + }; - // if the certificate has at least the same name or less and the expiration date - // is closer than the new one. We could remove it and allow the new insertion. - if extra_names.is_empty() && certificate_expiration < expiration { - certificates_to_remove.insert(fingerprint.to_owned(), certificate_names.to_owned()); - should_insert = true; + if related_certificate.expiration > candidate_cert.expiration { + continue; } - // We keep a track of all name of certificates that match our query to - // check, if the new certificate provide an extra domain which is not - // already exposed - for name in certificate_names { - certificates_names.insert(name); + for name in &related_certificate.names { + if !candidate_cert.names.contains(name) { + continue; + } } - } - // In the case where we do not insert the certificate, because there is - // no additional value, we have to check whether it provides an extra domain - // name not registered yet. - let diff: HashSet<&String> = new_names.difference(&certificates_names).collect(); - if !should_insert && diff.is_empty() { - // We already have all domain names registered and there is no update - // for expiration date of certificate. So, skipping the update. - return Ok((false, certificates_to_remove)); - } - - Ok((true, certificates_to_remove)) - } + should_insert = true; - fn get_expiration_override(&self, fingerprint: &Fingerprint) -> Option { - self.overrides.get(fingerprint).and_then(|co| co.expiration) - } + outdated_certificates.push(fingerprint.clone()); + } - fn get_names_override(&self, fingerprint: &Fingerprint) -> Option> { - self.overrides - .get(fingerprint) - .and_then(|co| co.names.to_owned()) + Ok((should_insert, outdated_certificates)) } pub fn domain_lookup( @@ -422,9 +344,9 @@ impl CertificateResolver { // MutexWrappedCertificateResolver struct #[derive(Default)] -pub struct MutexWrappedCertificateResolver(pub Mutex); +pub struct MutexCertificateResolver(pub Mutex); -impl ResolvesServerCert for MutexWrappedCertificateResolver { +impl ResolvesServerCert for MutexCertificateResolver { fn resolve(&self, client_hello: ClientHello) -> Option> { let server_name = client_hello.server_name(); let sigschemes = client_hello.signature_schemes(); @@ -468,7 +390,7 @@ impl ResolvesServerCert for MutexWrappedCertificateResolver { } } -impl Debug for MutexWrappedCertificateResolver { +impl Debug for MutexCertificateResolver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("MutexWrappedCertificateResolver") } @@ -485,13 +407,10 @@ mod tests { time::{Duration, SystemTime}, }; - use super::{fingerprint, CertificateResolver, ResolveCertificate}; + use super::CertificateResolver; use rand::{seq::SliceRandom, thread_rng}; - use sozu_command::{ - certificate::parse_pem, - proto::command::{AddCertificate, CertificateAndKey, SocketAddress}, - }; + use sozu_command::proto::command::{AddCertificate, CertificateAndKey, SocketAddress}; #[test] fn lifecycle() -> Result<(), Box> { @@ -503,13 +422,13 @@ mod tests { ..Default::default() }; - let pem = parse_pem(certificate_and_key.certificate.as_bytes())?; - - let fingerprint = resolver.add_certificate(&AddCertificate { - address, - certificate: certificate_and_key, - expired_at: None, - })?; + let fingerprint = resolver + .add_certificate(&AddCertificate { + address, + certificate: certificate_and_key, + expired_at: None, + }) + .expect("could not add certificate"); if resolver.get_certificate(&fingerprint).is_none() { return Err("failed to retrieve certificate".into()); @@ -519,7 +438,7 @@ mod tests { return Err(format!("the certificate must not been removed, {err}").into()); } - let names = resolver.certificate_names(&pem.contents)?; + let names = resolver.certificate_names(&fingerprint)?; if !resolver.find_certificates_by_names(&names)?.is_empty() && resolver.get_certificate(&fingerprint).is_some() { @@ -540,8 +459,6 @@ mod tests { ..Default::default() }; - let pem = parse_pem(certificate_and_key.certificate.as_bytes())?; - let fingerprint = resolver.add_certificate(&AddCertificate { address, certificate: certificate_and_key, @@ -561,14 +478,76 @@ mod tests { } if let Err(err) = resolver.remove_certificate(&fingerprint) { - return Err(format!("the certificate must not been removed, {err}").into()); + return Err(format!("the certificate could not be removed, {err}").into()); } - let names = resolver.certificate_names(&pem.contents)?; + let names = resolver.certificate_names(&fingerprint)?; if !resolver.find_certificates_by_names(&names)?.is_empty() && resolver.get_certificate(&fingerprint).is_some() { - return Err("We have retrieve the certificate that should be deleted".into()); + return Err("We have retrieved the certificate that should be deleted".into()); + } + + Ok(()) + } + + #[test] + fn properly_replace_outdated_cert() -> Result<(), Box> { + let address = SocketAddress::new_v4(127, 0, 0, 1, 8080); + let mut resolver = CertificateResolver::default(); + + let first_certificate = CertificateAndKey { + certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")), + key: String::from(include_str!("../assets/tests/key.pem")), + names: vec!["localhost".into()], + ..Default::default() + }; + let first = resolver.add_certificate(&AddCertificate { + address: address.clone(), + certificate: first_certificate, + expired_at: None, + })?; + if resolver.get_certificate(&first).is_none() { + return Err("failed to retrieve first certificate".into()); + } + match resolver.domain_lookup("localhost".as_bytes(), true) { + Some((_, fingerprint)) if fingerprint == &first => {} + Some((domain, fingerprint)) => { + return Err(format!( + "failed to lookup first inserted certificate. domain: {:?}, fingerprint: {}", + domain, fingerprint + ) + .into()) + } + _ => return Err("failed to lookup first inserted certificate".into()), + } + + let second_certificate = CertificateAndKey { + certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")), + key: String::from(include_str!("../assets/tests/key.pem")), + names: vec!["localhost".into(), "lolcatho.st".into()], + ..Default::default() + }; + let second = resolver.add_certificate(&AddCertificate { + address, + certificate: second_certificate, + expired_at: None, + })?; + + if resolver.get_certificate(&second).is_none() { + return Err("failed to retrieve second certificate".into()); + } + + match resolver.domain_lookup("localhost".as_bytes(), true) { + Some((_, fingerprint)) if fingerprint == &second => {} + Some((domain, fingerprint)) => { + return Err(format!( + "failed to lookup second inserted certificate. domain: {:?}, fingerprint: {}", + domain, fingerprint + ) + .into()) + } + _ => return Err("the former certificate has not been overriden by the new one".into()), } Ok(()) @@ -586,14 +565,13 @@ mod tests { key: String::from(include_str!("../assets/tests/key-1y.pem")), ..Default::default() }; - let pem = parse_pem(certificate_and_key_1y.certificate.as_bytes())?; - let names_1y = resolver.certificate_names(&pem.contents)?; let fingerprint_1y = resolver.add_certificate(&AddCertificate { address: address.clone(), certificate: certificate_and_key_1y, expired_at: None, })?; + let names_1y = resolver.certificate_names(&fingerprint_1y)?; if resolver.get_certificate(&fingerprint_1y).is_none() { return Err("failed to retrieve certificate".into()); @@ -652,9 +630,6 @@ mod tests { ..Default::default() }; - let pem = parse_pem(certificate_and_key_1y.certificate.as_bytes())?; - - let names_1y = resolver.certificate_names(&pem.contents)?; let fingerprint_1y = resolver.add_certificate(&AddCertificate { address: address.clone(), certificate: certificate_and_key_1y, @@ -664,6 +639,7 @@ mod tests { .as_secs() as i64, ), })?; + let names_1y = resolver.certificate_names(&fingerprint_1y)?; if resolver.get_certificate(&fingerprint_1y).is_none() { return Err("failed to retrieve certificate".into()); @@ -747,11 +723,6 @@ mod tests { ]; let mut fingerprints = vec![]; - for certificate in &certificates { - let pem = parse_pem(certificate.certificate.as_bytes())?; - - fingerprints.push(fingerprint(&pem.contents)); - } // randomize entries certificates.shuffle(&mut thread_rng()); @@ -762,11 +733,11 @@ mod tests { let mut resolver = CertificateResolver::default(); for certificate in &certificates { - resolver.add_certificate(&AddCertificate { + fingerprints.push(resolver.add_certificate(&AddCertificate { address: address.clone(), certificate: certificate.to_owned(), expired_at: None, - })?; + })?); } let mut names = HashSet::new();