diff --git a/command/src/certificate.rs b/command/src/certificate.rs index c7697a5b4..c81d99553 100644 --- a/command/src/certificate.rs +++ b/command/src/certificate.rs @@ -253,3 +253,27 @@ pub fn load_full_certificate( names, }) } + +impl CertificateAndKey { + pub fn fingerprint(&self) -> Result { + let pem = parse_pem(self.certificate.as_bytes())?; + let fingerprint = Fingerprint(Sha256::digest(pem.contents).iter().cloned().collect()); + Ok(fingerprint) + } + + pub fn get_overriding_names(&self) -> Result, CertificateError> { + if self.names.is_empty() { + let pem = parse_pem(self.certificate.as_bytes())?; + let overriding_names = get_cn_and_san_attributes(&pem.contents)?; + + Ok(overriding_names.into_iter().collect()) + } else { + Ok(self.names.to_owned()) + } + } + + pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> { + self.names = self.get_overriding_names()?; + Ok(()) + } +} diff --git a/command/src/command.proto b/command/src/command.proto index f2459dcd9..58a4aba08 100644 --- a/command/src/command.proto +++ b/command/src/command.proto @@ -295,7 +295,10 @@ message CertificateAndKey { repeated string certificate_chain = 2; required string key = 3; repeated TlsVersion versions = 4; - // hostnames linked to the certificate + // this field overrides the certificate names + // TODO: find a proper way to document this, for instance: + // "if empty, there is no override" + // comment should be consistent with CertificateResolver::add_certificate repeated string names = 5; } diff --git a/command/src/state.rs b/command/src/state.rs index 6cae2918a..c96748dd2 100644 --- a/command/src/state.rs +++ b/command/src/state.rs @@ -13,7 +13,7 @@ use std::{ use prost::{DecodeError, Message}; use crate::{ - certificate::{self, calculate_fingerprint, Fingerprint}, + certificate::{self, calculate_fingerprint, CertificateError, Fingerprint}, proto::{ command::{ request::RequestType, ActivateListener, AddBackend, AddCertificate, CertificateAndKey, @@ -47,7 +47,7 @@ pub enum StateError { #[error("Wrong request: {0}")] WrongRequest(String), #[error("Could not add certificate: {0}")] - AddCertificate(String), + AddCertificate(CertificateError), #[error("Could not remove certificate: {0}")] RemoveCertificate(String), #[error("Could not replace certificate: {0}")] @@ -374,11 +374,15 @@ impl ConfigState { } fn add_certificate(&mut self, add: &AddCertificate) -> Result<(), StateError> { - let fingerprint = Fingerprint( - calculate_fingerprint(add.certificate.certificate.as_bytes()).map_err( - |fingerprint_err| StateError::AddCertificate(fingerprint_err.to_string()), - )?, - ); + let overriding_names = add + .certificate + .get_overriding_names() + .map_err(StateError::AddCertificate)?; + + let fingerprint = add + .certificate + .fingerprint() + .map_err(StateError::AddCertificate)?; let entry = self .certificates @@ -388,7 +392,7 @@ impl ConfigState { if entry.contains_key(&fingerprint) { info!( "Skip loading of certificate '{}' for domain '{}' on listener '{}', the certificate is already present.", - fingerprint, add.certificate.names.join(", "), add.address + fingerprint, overriding_names.join(", "), add.address ); return Ok(()); } @@ -1248,72 +1252,20 @@ impl ConfigState { &self, filters: QueryCertificatesFilters, ) -> BTreeMap { - if let Some(domain) = filters.domain { - self.certificates - .values() - .flat_map(|hash_map| hash_map.iter()) - .flat_map(|(fingerprint, cert)| { - if cert.names.is_empty() { - let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?; - let mut c = cert.to_owned(); - - c.names = certificate::get_cn_and_san_attributes(&pem.contents) - .ok()? - .into_iter() - .collect(); - - return Some((fingerprint, c)); - } - - Some((fingerprint, cert.to_owned())) - }) - .filter(|(_, cert)| cert.names.contains(&domain)) - .map(|(fingerprint, cert)| (fingerprint.to_string(), cert)) - .collect() - } else if let Some(f) = filters.fingerprint { - self.certificates - .values() - .flat_map(|hash_map| hash_map.iter()) - .filter(|(fingerprint, _cert)| fingerprint.to_string() == f) - .flat_map(|(fingerprint, cert)| { - if cert.names.is_empty() { - let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?; - let mut c = cert.to_owned(); - - c.names = certificate::get_cn_and_san_attributes(&pem.contents) - .ok()? - .into_iter() - .collect(); - - return Some((fingerprint, c)); - } - - Some((fingerprint, cert.to_owned())) - }) - .map(|(fingerprint, cert)| (fingerprint.to_string(), cert)) - .collect() - } else { - self.certificates - .values() - .flat_map(|hash_map| hash_map.iter()) - .flat_map(|(fingerprint, cert)| { - if cert.names.is_empty() { - let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?; - let mut c = cert.to_owned(); - - c.names = certificate::get_cn_and_san_attributes(&pem.contents) - .ok()? - .into_iter() - .collect(); - - return Some((fingerprint, c)); - } - - Some((fingerprint, cert.to_owned())) - }) - .map(|(fingerprint, cert)| (fingerprint.to_string(), cert)) - .collect() - } + self.certificates + .values() + .flat_map(|hash_map| hash_map.iter()) + .filter(|(fingerprint, cert)| { + if let Some(domain) = &filters.domain { + cert.names.contains(&domain) + } else if let Some(f) = &filters.fingerprint { + fingerprint.to_string() == *f + } else { + true + } + }) + .map(|(fingerprint, cert)| (fingerprint.to_string(), cert.to_owned())) + .collect() } pub fn list_frontends(&self, filters: FrontendFilters) -> ListedFrontends { diff --git a/lib/src/tls.rs b/lib/src/tls.rs index e33a78947..ab015f053 100644 --- a/lib/src/tls.rs +++ b/lib/src/tls.rs @@ -86,9 +86,10 @@ impl From<&AddCertificate> for CertificateOverride { #[derive(Clone, Debug)] pub struct CertifiedKeyWrapper { inner: Arc, - /// domain names (can be overriden) + /// domain names, override what can be found in the cert names: Vec, expiration: i64, + // TODO: add field fingerprint } impl CertifiedKeyWrapper { @@ -100,6 +101,10 @@ impl CertifiedKeyWrapper { fn names(&self) -> &[String] { &self.names } + + fn fingerprint(&self) -> Fingerprint { + Fingerprint(Sha256::digest(self.pem_bytes()).iter().cloned().collect()) + } } /// Convert an AddCertificate request into the Rustls format. @@ -139,11 +144,13 @@ impl TryFrom<&AddCertificate> for CertifiedKeyWrapper { _ => return Err(CertificateResolverError::EmptyKeys), }; + let overriding_names = cert.get_overriding_names()?; + match any_supported_type(&private_key) { Ok(signing_key) => { let stored_certificate = CertifiedKeyWrapper { inner: Arc::new(CertifiedKey::new(chain, signing_key)), - names: cert.names.clone(), + names: overriding_names, expiration, }; Ok(stored_certificate) @@ -214,7 +221,7 @@ impl CertificateResolver { certificate_to_add ); - let fingerprint = fingerprint(certificate_to_add.pem_bytes()); + let fingerprint = certificate_to_add.fingerprint(); let (should_insert, outdated_certs) = self.should_insert(&certificate_to_add)?; @@ -346,13 +353,14 @@ impl CertificateResolver { .map_err(CertificateResolverError::InvalidCommonNameAndSubjectAlternateNames) } */ + // this is better in my opinion fn certificate_names( &self, fingerprint: &Fingerprint, ) -> Result, CertificateResolverError> { if let Some(cert) = self.certificates.get(fingerprint) { - return Ok(cert.names().iter().map(|s| s.to_owned()).collect()); + return Ok(cert.names().iter().cloned().collect()); } Ok(HashSet::new()) } @@ -444,6 +452,8 @@ impl CertificateResolver { let mut related_certificates = HashSet::new(); + // TODO: make sure the tests give names or should_insert will fail + // or recalculate names to add for name in candidate_cert.names() { match self.name_fingerprint_idx.get(name) { None => should_insert = true, @@ -452,10 +462,6 @@ impl CertificateResolver { } } - if related_certificates.is_empty() { - return Ok((true, Vec::new())); - } - let mut outdated_certificates = Vec::new(); for fingerprint in related_certificates { @@ -684,9 +690,7 @@ mod tests { key: String::from(include_str!("../assets/tests/key-1y.pem")), ..Default::default() }; - let pem = parse_pem(certificate_and_key_1y.certificate.as_bytes())?; - // let names_1y = resolver.certificate_names(&pem.contents)?; let fingerprint_1y = resolver.add_certificate(&AddCertificate { address: address.clone(), certificate: certificate_and_key_1y, @@ -748,12 +752,10 @@ mod tests { let certificate_and_key_1y = CertificateAndKey { certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")), key: String::from(include_str!("../assets/tests/key-1y.pem")), + names: vec!["bla.com", "www.bla.com"], ..Default::default() }; - let pem = parse_pem(certificate_and_key_1y.certificate.as_bytes())?; - - // let names_1y = resolver.certificate_names(&pem.contents)?; let fingerprint_1y = resolver.add_certificate(&AddCertificate { address: address.clone(), certificate: certificate_and_key_1y, @@ -848,9 +850,7 @@ mod tests { let mut fingerprints = vec![]; for certificate in &certificates { - let pem = parse_pem(certificate.certificate.as_bytes())?; - - fingerprints.push(fingerprint(&pem.contents)); + fingerprints.push(certificate.fingerprint()?); } // randomize entries