Skip to content

Commit

Permalink
Merge pull request #1074 from sozu-proxy/fix-certificate-insert
Browse files Browse the repository at this point in the history
Fix certificate insertion
  • Loading branch information
Keksoj authored Feb 19, 2024
2 parents e451177 + 6aa45b9 commit 05f07ef
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 451 deletions.
42 changes: 33 additions & 9 deletions command/src/certificate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashSet, fmt, str::FromStr};
use std::{fmt, str::FromStr};

use hex::{FromHex, FromHexError};
use serde::de::{self, Visitor};
Expand Down Expand Up @@ -56,13 +56,10 @@ pub fn parse_x509(pem_bytes: &[u8]) -> Result<X509Certificate, CertificateError>

/// Retrieve from the pem (as bytes) the common name (a.k.a `CN`) and the
/// subject alternate names (a.k.a `SAN`)
pub fn get_cn_and_san_attributes(pem_bytes: &[u8]) -> Result<HashSet<String>, CertificateError> {
let x509 = parse_x509(pem_bytes)
.map_err(|err| CertificateError::InvalidCertificate(err.to_string()))?;

let mut names: HashSet<String> = HashSet::new();
pub fn get_cn_and_san_attributes(x509: &X509Certificate) -> Vec<String> {
let mut names: Vec<String> = Vec::new();
for name in x509.subject().iter_by_oid(&OID_X509_COMMON_NAME) {
names.insert(
names.push(
name.as_str()
.map(String::from)
.unwrap_or_else(|_| String::from_utf8_lossy(name.as_slice()).to_string()),
Expand All @@ -74,13 +71,14 @@ pub fn get_cn_and_san_attributes(pem_bytes: &[u8]) -> Result<HashSet<String>, Ce
if let ParsedExtension::SubjectAlternativeName(san) = extension.parsed_extension() {
for name in &san.general_names {
if let GeneralName::DNSName(name) = name {
names.insert(name.to_string());
names.push(name.to_string());
}
}
}
}
}
Ok(names)
names.dedup();
names
}

// -----------------------------------------------------------------------------
Expand Down Expand Up @@ -253,3 +251,29 @@ pub fn load_full_certificate(
names,
})
}

impl CertificateAndKey {
pub fn fingerprint(&self) -> Result<Fingerprint, CertificateError> {
let pem = parse_pem(self.certificate.as_bytes())?;
let fingerprint = Fingerprint(Sha256::digest(pem.contents).iter().cloned().collect());
Ok(fingerprint)
}

pub fn get_overriding_names(&self) -> Result<Vec<String>, CertificateError> {
if self.names.is_empty() {
let pem = parse_pem(self.certificate.as_bytes())?;
let x509 = parse_x509(&pem.contents)?;

let overriding_names = get_cn_and_san_attributes(&x509);

Ok(overriding_names.into_iter().collect())
} else {
Ok(self.names.to_owned())
}
}

pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> {
self.names = self.get_overriding_names()?;
Ok(())
}
}
3 changes: 2 additions & 1 deletion command/src/command.proto
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ message CertificateAndKey {
repeated string certificate_chain = 2;
required string key = 3;
repeated TlsVersion versions = 4;
// hostnames linked to the certificate
// a list of domain names. Override certificate names
// if empty, the names of the certificate will be used
repeated string names = 5;
}

Expand Down
100 changes: 26 additions & 74 deletions command/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::{
use prost::{DecodeError, Message};

use crate::{
certificate::{self, calculate_fingerprint, Fingerprint},
certificate::{calculate_fingerprint, CertificateError, Fingerprint},
proto::{
command::{
request::RequestType, ActivateListener, AddBackend, AddCertificate, CertificateAndKey,
Expand Down Expand Up @@ -47,7 +47,7 @@ pub enum StateError {
#[error("Wrong request: {0}")]
WrongRequest(String),
#[error("Could not add certificate: {0}")]
AddCertificate(String),
AddCertificate(CertificateError),
#[error("Could not remove certificate: {0}")]
RemoveCertificate(String),
#[error("Could not replace certificate: {0}")]
Expand Down Expand Up @@ -374,17 +374,21 @@ impl ConfigState {
}

fn add_certificate(&mut self, add: &AddCertificate) -> Result<(), StateError> {
let fingerprint = Fingerprint(
calculate_fingerprint(add.certificate.certificate.as_bytes()).map_err(
|fingerprint_err| StateError::AddCertificate(fingerprint_err.to_string()),
)?,
);
let fingerprint = add
.certificate
.fingerprint()
.map_err(StateError::AddCertificate)?;

let entry = self
.certificates
.entry(add.address.clone().into())
.or_insert_with(HashMap::new);

let mut add = add.clone();
add.certificate
.apply_overriding_names()
.map_err(StateError::AddCertificate)?;

if entry.contains_key(&fingerprint) {
info!(
"Skip loading of certificate '{}' for domain '{}' on listener '{}', the certificate is already present.",
Expand All @@ -393,7 +397,7 @@ impl ConfigState {
return Ok(());
}

entry.insert(fingerprint, add.certificate.clone());
entry.insert(fingerprint, add.certificate);
Ok(())
}

Expand Down Expand Up @@ -1248,72 +1252,20 @@ impl ConfigState {
&self,
filters: QueryCertificatesFilters,
) -> BTreeMap<String, CertificateAndKey> {
if let Some(domain) = filters.domain {
self.certificates
.values()
.flat_map(|hash_map| hash_map.iter())
.flat_map(|(fingerprint, cert)| {
if cert.names.is_empty() {
let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?;
let mut c = cert.to_owned();

c.names = certificate::get_cn_and_san_attributes(&pem.contents)
.ok()?
.into_iter()
.collect();

return Some((fingerprint, c));
}

Some((fingerprint, cert.to_owned()))
})
.filter(|(_, cert)| cert.names.contains(&domain))
.map(|(fingerprint, cert)| (fingerprint.to_string(), cert))
.collect()
} else if let Some(f) = filters.fingerprint {
self.certificates
.values()
.flat_map(|hash_map| hash_map.iter())
.filter(|(fingerprint, _cert)| fingerprint.to_string() == f)
.flat_map(|(fingerprint, cert)| {
if cert.names.is_empty() {
let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?;
let mut c = cert.to_owned();

c.names = certificate::get_cn_and_san_attributes(&pem.contents)
.ok()?
.into_iter()
.collect();

return Some((fingerprint, c));
}

Some((fingerprint, cert.to_owned()))
})
.map(|(fingerprint, cert)| (fingerprint.to_string(), cert))
.collect()
} else {
self.certificates
.values()
.flat_map(|hash_map| hash_map.iter())
.flat_map(|(fingerprint, cert)| {
if cert.names.is_empty() {
let pem = certificate::parse_pem(cert.certificate.as_bytes()).ok()?;
let mut c = cert.to_owned();

c.names = certificate::get_cn_and_san_attributes(&pem.contents)
.ok()?
.into_iter()
.collect();

return Some((fingerprint, c));
}

Some((fingerprint, cert.to_owned()))
})
.map(|(fingerprint, cert)| (fingerprint.to_string(), cert))
.collect()
}
self.certificates
.values()
.flat_map(|hash_map| hash_map.iter())
.filter(|(fingerprint, cert)| {
if let Some(domain) = &filters.domain {
cert.names.contains(domain)
} else if let Some(f) = &filters.fingerprint {
fingerprint.to_string() == *f
} else {
true
}
})
.map(|(fingerprint, cert)| (fingerprint.to_string(), cert.to_owned()))
.collect()
}

pub fn list_frontends(&self, filters: FrontendFilters) -> ListedFrontends {
Expand Down
2 changes: 1 addition & 1 deletion lib/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ impl HttpProxy {

pub fn remove_listener(&mut self, remove: RemoveListener) -> Result<(), ProxyError> {
let len = self.listeners.len();
let remove_address = remove.address.clone().into();
let remove_address = remove.address.into();
self.listeners
.retain(|_, l| l.borrow().address != remove_address);

Expand Down
91 changes: 35 additions & 56 deletions lib/src/https.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use crate::{
server::{ListenToken, SessionManager},
socket::{server_bind, FrontRustls},
timer::TimeoutContainer,
tls::{CertifiedKeyWrapper, MutexWrappedCertificateResolver, ResolveCertificate},
tls::MutexCertificateResolver,
util::UnwrapLog,
AcceptError, CachedTags, FrontendFromRequestError, L7ListenerHandler, L7Proxy, ListenerError,
ListenerHandler, Protocol, ProxyConfiguration, ProxyError, ProxySession, SessionIsToBeClosed,
Expand Down Expand Up @@ -529,7 +529,7 @@ pub struct HttpsListener {
config: HttpsListenerConfig,
fronts: Router,
listener: Option<MioTcpListener>,
resolver: Arc<MutexWrappedCertificateResolver>,
resolver: Arc<MutexCertificateResolver>,
rustls_details: Arc<RustlsServerConfig>,
tags: BTreeMap<String, CachedTags>,
token: Token,
Expand Down Expand Up @@ -609,51 +609,12 @@ impl L7ListenerHandler for HttpsListener {
}
}

impl ResolveCertificate for HttpsListener {
type Error = ListenerError;

fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
let resolver = self
.resolver
.0
.lock()
.map_err(|err| ListenerError::Lock(err.to_string()))
.ok()?;

resolver.get_certificate(fingerprint)
}

fn add_certificate(&mut self, opts: &AddCertificate) -> Result<Fingerprint, Self::Error> {
let mut resolver = self
.resolver
.0
.lock()
.map_err(|err| ListenerError::Lock(err.to_string()))?;

resolver
.add_certificate(opts)
.map_err(ListenerError::Resolver)
}

fn remove_certificate(&mut self, fingerprint: &Fingerprint) -> Result<(), Self::Error> {
let mut resolver = self
.resolver
.0
.lock()
.map_err(|err| ListenerError::Lock(err.to_string()))?;

resolver
.remove_certificate(fingerprint)
.map_err(ListenerError::Resolver)
}
}

impl HttpsListener {
pub fn try_new(
config: HttpsListenerConfig,
token: Token,
) -> Result<HttpsListener, ListenerError> {
let resolver = Arc::new(MutexWrappedCertificateResolver::default());
let resolver = Arc::new(MutexCertificateResolver::default());

let server_config = Arc::new(Self::create_rustls_context(&config, resolver.to_owned())?);

Expand Down Expand Up @@ -705,7 +666,7 @@ impl HttpsListener {

pub fn create_rustls_context(
config: &HttpsListenerConfig,
resolver: Arc<MutexWrappedCertificateResolver>,
resolver: Arc<MutexCertificateResolver>,
) -> Result<RustlsServerConfig, ListenerError> {
let cipher_names = if config.cipher_list.is_empty() {
DEFAULT_CIPHER_SUITES.to_vec()
Expand Down Expand Up @@ -850,7 +811,7 @@ impl HttpsProxy {
) -> Result<Option<ResponseContent>, ProxyError> {
let len = self.listeners.len();

let remove_address = remove.address.clone().into();
let remove_address = remove.address.into();
self.listeners
.retain(|_, listener| listener.borrow().address != remove_address);

Expand Down Expand Up @@ -1131,10 +1092,16 @@ impl HttpsProxy {
.listeners
.values()
.find(|l| l.borrow().address == address)
.ok_or(ProxyError::NoListenerFound(address))?;
.ok_or(ProxyError::NoListenerFound(address))?
.borrow_mut();

listener
.borrow_mut()
let mut resolver = listener
.resolver
.0
.lock()
.map_err(|e| ProxyError::Lock(e.to_string()))?;

resolver
.add_certificate(&add_certificate)
.map_err(ProxyError::AddCertificate)?;

Expand All @@ -1150,19 +1117,25 @@ impl HttpsProxy {

let fingerprint = Fingerprint(
hex::decode(&remove_certificate.fingerprint)
.map_err(|e| ProxyError::WrongCertificateFingerprint(e.to_string()))?,
.map_err(ProxyError::WrongCertificateFingerprint)?,
);

let listener = self
.listeners
.values()
.find(|l| l.borrow().address == address)
.ok_or(ProxyError::NoListenerFound(address))?;
.ok_or(ProxyError::NoListenerFound(address))?
.borrow_mut();

listener
.borrow_mut()
let mut resolver = listener
.resolver
.0
.lock()
.map_err(|e| ProxyError::Lock(e.to_string()))?;

resolver
.remove_certificate(&fingerprint)
.map_err(ProxyError::AddCertificate)?;
.map_err(ProxyError::RemoveCertificate)?;

Ok(None)
}
Expand All @@ -1178,10 +1151,16 @@ impl HttpsProxy {
.listeners
.values()
.find(|l| l.borrow().address == address)
.ok_or(ProxyError::NoListenerFound(address))?;
.ok_or(ProxyError::NoListenerFound(address))?
.borrow_mut();

listener
.borrow_mut()
let mut resolver = listener
.resolver
.0
.lock()
.map_err(|e| ProxyError::Lock(e.to_string()))?;

resolver
.replace_certificate(&replace_certificate)
.map_err(ProxyError::ReplaceCertificate)?;

Expand Down Expand Up @@ -1592,7 +1571,7 @@ mod tests {
));

let address = SocketAddress::new_v4(127, 0, 0, 1, 1032);
let resolver = Arc::new(MutexWrappedCertificateResolver::default());
let resolver = Arc::new(MutexCertificateResolver::default());

let server_config = RustlsServerConfig::builder_with_protocol_versions(&[
&rustls::version::TLS12,
Expand Down
Loading

0 comments on commit 05f07ef

Please sign in to comment.