From 9832ab412587c5b6a1cef275b4781b654f5801e7 Mon Sep 17 00:00:00 2001 From: MiaAltieri Date: Fri, 20 Dec 2024 20:52:23 +0000 Subject: [PATCH 1/6] update libs --- lib/charms/grafana_agent/v0/cos_agent.py | 17 +- lib/charms/operator_libs_linux/v2/snap.py | 45 +- .../v3/tls_certificates.py | 960 ++++++++---------- 3 files changed, 473 insertions(+), 549 deletions(-) diff --git a/lib/charms/grafana_agent/v0/cos_agent.py b/lib/charms/grafana_agent/v0/cos_agent.py index c57e3f059..1ea79a625 100644 --- a/lib/charms/grafana_agent/v0/cos_agent.py +++ b/lib/charms/grafana_agent/v0/cos_agent.py @@ -22,7 +22,6 @@ Using the `COSAgentProvider` object only requires instantiating it, typically in the `__init__` method of your charm (the one which sends telemetry). -The constructor of `COSAgentProvider` has only one required and ten optional parameters: ```python def __init__( @@ -235,10 +234,10 @@ def __init__(self, *args): import pydantic from cosl import GrafanaDashboard, JujuTopology from cosl.rules import AlertRules -from ops import CharmBase from ops.charm import RelationChangedEvent from ops.framework import EventBase, EventSource, Object, ObjectEvents from ops.model import ModelError, Relation +from ops.testing import CharmType if TYPE_CHECKING: try: @@ -253,7 +252,7 @@ class _MetricsEndpointDict(TypedDict): LIBID = "dc15fa84cef84ce58155fb84f6c6213a" LIBAPI = 0 -LIBPATCH = 11 +LIBPATCH = 12 PYDEPS = ["cosl", "pydantic"] @@ -468,7 +467,7 @@ def dump(self, databag: Optional[MutableMapping] = None, clear: bool = True): return databag -class CosAgentProviderUnitData(DatabagModel): # pyright: ignore [reportGeneralTypeIssues] +class CosAgentProviderUnitData(DatabagModel): """Unit databag model for `cos-agent` relation.""" # The following entries are the same for all units of the same principal. @@ -495,7 +494,7 @@ class CosAgentProviderUnitData(DatabagModel): # pyright: ignore [reportGeneralT KEY: ClassVar[str] = "config" -class CosAgentPeersUnitData(DatabagModel): # pyright: ignore [reportGeneralTypeIssues] +class CosAgentPeersUnitData(DatabagModel): """Unit databag model for `peers` cos-agent machine charm peer relation.""" # We need the principal unit name and relation metadata to be able to render identifiers @@ -594,9 +593,7 @@ class Receiver(pydantic.BaseModel): ) -class CosAgentRequirerUnitData( - DatabagModel -): # pyright: ignore [reportGeneralTypeIssues] # noqa: D101 +class CosAgentRequirerUnitData(DatabagModel): # noqa: D101 """Application databag model for the COS-agent requirer.""" receivers: List[Receiver] = pydantic.Field( @@ -610,7 +607,7 @@ class COSAgentProvider(Object): def __init__( self, - charm: CharmBase, + charm: CharmType, relation_name: str = DEFAULT_RELATION_NAME, metrics_endpoints: Optional[List["_MetricsEndpointDict"]] = None, metrics_rules_dir: str = "./src/prometheus_alert_rules", @@ -879,7 +876,7 @@ class COSAgentRequirer(Object): def __init__( self, - charm: CharmBase, + charm: CharmType, *, relation_name: str = DEFAULT_RELATION_NAME, peer_relation_name: str = DEFAULT_PEER_RELATION_NAME, diff --git a/lib/charms/operator_libs_linux/v2/snap.py b/lib/charms/operator_libs_linux/v2/snap.py index 9d09a78d3..d14f864fd 100644 --- a/lib/charms/operator_libs_linux/v2/snap.py +++ b/lib/charms/operator_libs_linux/v2/snap.py @@ -64,6 +64,7 @@ import socket import subprocess import sys +import time import urllib.error import urllib.parse import urllib.request @@ -83,7 +84,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 7 +LIBPATCH = 9 # Regex to locate 7-bit C1 ANSI sequences @@ -332,7 +333,7 @@ def get(self, key: Optional[str], *, typed: bool = False) -> Any: return self._snap("get", [key]).strip() - def set(self, config: Dict[str, Any], *, typed: bool = False) -> str: + def set(self, config: Dict[str, Any], *, typed: bool = False) -> None: """Set a snap configuration value. Args: @@ -340,11 +341,9 @@ def set(self, config: Dict[str, Any], *, typed: bool = False) -> str: typed: set to True to convert all values in the config into typed values while configuring the snap (set with typed=True). Default is not to convert. """ - if typed: - kv = [f"{key}={json.dumps(val)}" for key, val in config.items()] - return self._snap("set", ["-t"] + kv) - - return self._snap("set", [f"{key}={val}" for key, val in config.items()]) + if not typed: + config = {k: str(v) for k, v in config.items()} + self._snap_client._put_snap_conf(self._name, config) def unset(self, key) -> str: """Unset a snap configuration value. @@ -770,7 +769,33 @@ def _request( headers["Content-Type"] = "application/json" response = self._request_raw(method, path, query, headers, data) - return json.loads(response.read().decode())["result"] + response = json.loads(response.read().decode()) + if response["type"] == "async": + return self._wait(response["change"]) + return response["result"] + + def _wait(self, change_id: str, timeout=300) -> JSONType: + """Wait for an async change to complete. + + The poll time is 100 milliseconds, the same as in snap clients. + """ + deadline = time.time() + timeout + while True: + if time.time() > deadline: + raise TimeoutError(f"timeout waiting for snap change {change_id}") + response = self._request("GET", f"changes/{change_id}") + status = response["status"] + if status == "Done": + return response.get("data") + if status == "Doing" or status == "Do": + time.sleep(0.1) + continue + if status == "Wait": + logger.warning("snap change %s succeeded with status 'Wait'", change_id) + return response.get("data") + raise SnapError( + f"snap change {response.get('kind')!r} id {change_id} failed with status {status}" + ) def _request_raw( self, @@ -818,6 +843,10 @@ def get_installed_snap_apps(self, name: str) -> List: """Query the snap server for apps belonging to a named, currently installed snap.""" return self._request("GET", "apps", {"names": name, "select": "service"}) + def _put_snap_conf(self, name: str, conf: Dict[str, Any]): + """Set the configuration details for an installed snap.""" + return self._request("PUT", f"snaps/{name}/conf", body=conf) + class SnapCache(Mapping): """An abstraction to represent installed/available packages. diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py index 141412b00..c232362fe 100644 --- a/lib/charms/tls_certificates_interface/v3/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -1,4 +1,4 @@ -# Copyright 2024 Canonical Ltd. +# Copyright 2021 Canonical Ltd. # See LICENSE file for licensing details. @@ -7,19 +7,16 @@ This library contains the Requires and Provides classes for handling the tls-certificates interface. -Pre-requisites: - - Juju >= 3.0 - ## Getting Started From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: - jsonschema -- cryptography >= 42.0.0 +- cryptography Add the following section to the charm's `charmcraft.yaml` file: ```yaml @@ -39,10 +36,10 @@ Example: ```python -from charms.tls_certificates_interface.v3.tls_certificates import ( +from charms.tls_certificates_interface.v2.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV3, + TLSCertificatesProvidesV2, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -62,7 +59,7 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV3(self, "certificates") + self.certificates = TLSCertificatesProvidesV2(self, "certificates") self.framework.observe( self.certificates.on.certificate_request, self._on_certificate_request @@ -111,7 +108,6 @@ def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> Non ca=ca_certificate, chain=[ca_certificate, certificate], relation_id=event.relation_id, - recommended_expiry_notification_time=720, ) def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: @@ -130,15 +126,15 @@ def _on_certificate_revocation_request(self, event: CertificateRevocationRequest Example: ```python -from charms.tls_certificates_interface.v3.tls_certificates import ( +from charms.tls_certificates_interface.v2.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV3, + TLSCertificatesRequiresV2, generate_csr, generate_private_key, ) -from ops.charm import CharmBase, RelationCreatedEvent +from ops.charm import CharmBase, RelationJoinedEvent from ops.main import main from ops.model import ActiveStatus, WaitingStatus from typing import Union @@ -149,10 +145,10 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV3(self, "certificates") + self.certificates = TLSCertificatesRequiresV2(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( - self.on.certificates_relation_created, self._on_certificates_relation_created + self.on.certificates_relation_joined, self._on_certificates_relation_joined ) self.framework.observe( self.certificates.on.certificate_available, self._on_certificate_available @@ -180,7 +176,7 @@ def _on_install(self, event) -> None: {"private_key_password": "banana", "private_key": private_key.decode()} ) - def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: + def _on_certificates_relation_joined(self, event: RelationJoinedEvent) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -277,19 +273,19 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven """ # noqa: D405, D410, D411, D214, D416 import copy -import ipaddress import json import logging import uuid from contextlib import suppress -from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import List, Literal, Optional, Union +from ipaddress import IPv4Address +from typing import Any, Dict, List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import pkcs12 from jsonschema import exceptions, validate from ops.charm import ( CharmBase, @@ -297,28 +293,21 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven RelationBrokenEvent, RelationChangedEvent, SecretExpiredEvent, + UpdateStatusEvent, ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import ( - Application, - ModelError, - Relation, - RelationDataContent, - Secret, - SecretNotFoundError, - Unit, -) +from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 3 +LIBAPI = 2 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 23 +LIBPATCH = 29 PYDEPS = ["cryptography", "jsonschema"] @@ -433,58 +422,6 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven logger = logging.getLogger(__name__) -@dataclass -class RequirerCSR: - """This class represents a certificate signing request from an interface Requirer.""" - - relation_id: int - application_name: str - unit_name: str - csr: str - is_ca: bool - - -@dataclass -class ProviderCertificate: - """This class represents a certificate from an interface Provider.""" - - relation_id: int - application_name: str - csr: str - certificate: str - ca: str - chain: List[str] - revoked: bool - expiry_time: datetime - expiry_notification_time: Optional[datetime] = None - - def chain_as_pem(self) -> str: - """Return full certificate chain as a PEM string.""" - return "\n\n".join(reversed(self.chain)) - - def to_json(self) -> str: - """Return the object as a JSON string. - - Returns: - str: JSON representation of the object - """ - return json.dumps( - { - "relation_id": self.relation_id, - "application_name": self.application_name, - "csr": self.csr, - "certificate": self.certificate, - "ca": self.ca, - "chain": self.chain, - "revoked": self.revoked, - "expiry_time": self.expiry_time.isoformat(), - "expiry_notification_time": self.expiry_notification_time.isoformat() - if self.expiry_notification_time - else None, - } - ) - - class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -518,15 +455,11 @@ def restore(self, snapshot: dict): self.ca = snapshot["ca"] self.chain = snapshot["chain"] - def chain_as_pem(self) -> str: - """Return full certificate chain as a PEM string.""" - return "\n\n".join(reversed(self.chain)) - class CertificateExpiringEvent(EventBase): """Charm Event triggered when a TLS certificate is almost expired.""" - def __init__(self, handle, certificate: str, expiry: str): + def __init__(self, handle: Handle, certificate: str, expiry: str): """CertificateExpiringEvent. Args: @@ -708,49 +641,21 @@ def _get_closest_future_time( ) -def calculate_expiry_notification_time( - validity_start_time: datetime, - expiry_time: datetime, - provider_recommended_notification_time: Optional[int], - requirer_recommended_notification_time: Optional[int], -) -> datetime: - """Calculate a reasonable time to notify the user about the certificate expiry. - - It takes into account the time recommended by the provider and by the requirer. - Time recommended by the provider is preferred, - then time recommended by the requirer, - then dynamically calculated time. +def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: + """Extract expiry time from a certificate string. Args: - validity_start_time: Certificate validity time - expiry_time: Certificate expiry time - provider_recommended_notification_time: - Time in hours prior to expiry to notify the user. - Recommended by the provider. - requirer_recommended_notification_time: - Time in hours prior to expiry to notify the user. - Recommended by the requirer. + certificate (str): x509 certificate as a string Returns: - datetime: Time to notify the user about the certificate expiry. + Optional[datetime]: Expiry datetime or None """ - if provider_recommended_notification_time is not None: - provider_recommended_notification_time = abs(provider_recommended_notification_time) - provider_recommendation_time_delta = expiry_time - timedelta( - hours=provider_recommended_notification_time - ) - if validity_start_time < provider_recommendation_time_delta: - return provider_recommendation_time_delta - - if requirer_recommended_notification_time is not None: - requirer_recommended_notification_time = abs(requirer_recommended_notification_time) - requirer_recommendation_time_delta = expiry_time - timedelta( - hours=requirer_recommended_notification_time - ) - if validity_start_time < requirer_recommendation_time_delta: - return requirer_recommendation_time_delta - calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) - return expiry_time - timedelta(hours=calculated_hours) + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + return certificate_object.not_valid_after_utc + except ValueError: + logger.warning("Could not load certificate.") + return None def generate_ca( @@ -981,6 +886,38 @@ def generate_certificate( return cert.public_bytes(serialization.Encoding.PEM) +def generate_pfx_package( + certificate: bytes, + private_key: bytes, + package_password: str, + private_key_password: Optional[bytes] = None, +) -> bytes: + """Generate a PFX package to contain the TLS certificate and private key. + + Args: + certificate (bytes): TLS certificate + private_key (bytes): Private key + package_password (str): Password to open the PFX package + private_key_password (bytes): Private key password + + Returns: + bytes: + """ + private_key_object = serialization.load_pem_private_key( + private_key, password=private_key_password + ) + certificate_object = x509.load_pem_x509_certificate(certificate) + name = certificate_object.subject.rfc4514_string() + pfx_bytes = pkcs12.serialize_key_and_certificates( + name=name.encode(), + cert=certificate_object, + key=private_key_object, # type: ignore[arg-type] + cas=None, + encryption_algorithm=serialization.BestAvailableEncryption(package_password.encode()), + ) + return pfx_bytes + + def generate_private_key( password: Optional[bytes] = None, key_size: int = 2048, @@ -1019,8 +956,6 @@ def generate_csr( # noqa: C901 organization: Optional[str] = None, email_address: Optional[str] = None, country_name: Optional[str] = None, - state_or_province_name: Optional[str] = None, - locality_name: Optional[str] = None, private_key_password: Optional[bytes] = None, sans: Optional[List[str]] = None, sans_oid: Optional[List[str]] = None, @@ -1039,8 +974,6 @@ def generate_csr( # noqa: C901 organization (str): Name of organization. email_address (str): Email address. country_name (str): Country Name. - state_or_province_name (str): State or Province Name. - locality_name (str): Locality Name. private_key_password (bytes): Private key password sans (list): Use sans_dns - this will be deprecated in a future release List of DNS subject alternative names (keeping it for now for backward compatibility) @@ -1066,19 +999,13 @@ def generate_csr( # noqa: C901 subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) if country_name: subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) - if state_or_province_name: - subject_name.append( - x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) - ) - if locality_name: - subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) _sans: List[x509.GeneralName] = [] if sans_oid: _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) if sans_ip: - _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) + _sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip]) if sans: _sans.extend([x509.DNSName(san) for san in sans]) if sans_dns: @@ -1094,13 +1021,6 @@ def generate_csr( # noqa: C901 return signed_certificate.public_bytes(serialization.Encoding.PEM) -def get_sha256_hex(data: str) -> str: - """Calculate the hash of the provided data and return the hexadecimal representation.""" - digest = hashes.Hash(hashes.SHA256()) - digest.update(data.encode()) - return digest.finalize().hex() - - def csr_matches_certificate(csr: str, cert: str) -> bool: """Check if a CSR matches a certificate. @@ -1110,39 +1030,27 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: Returns: bool: True/False depending on whether the CSR matches the certificate. """ - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - return True - - -def _relation_data_is_valid( - relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict -) -> bool: - """Check whether relation data is valid based on json schema. - - Args: - relation (Relation): Relation object - app_or_unit (Union[Application, Unit]): Application or unit object - json_schema (dict): Json schema - - Returns: - bool: Whether relation data is valid. - """ - relation_data = _load_relation_data(relation.data[app_or_unit]) try: - validate(instance=relation_data, schema=json_schema) - return True - except exceptions.ValidationError: + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): + return False + if ( + csr_object.public_key().public_numbers().n # type: ignore[union-attr] + != cert_object.public_key().public_numbers().n # type: ignore[union-attr] + ): + return False + except ValueError: + logger.warning("Could not load certificate or CSR.") return False + return True class CertificatesProviderCharmEvents(CharmEvents): @@ -1161,7 +1069,7 @@ class CertificatesRequirerCharmEvents(CharmEvents): all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) -class TLSCertificatesProvidesV3(Object): +class TLSCertificatesProvidesV2(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] @@ -1197,7 +1105,6 @@ def _add_certificate( certificate_signing_request: str, ca: str, chain: List[str], - recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificate to relation data. @@ -1207,8 +1114,6 @@ def _add_certificate( certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate chain (list): CA Chain - recommended_expiry_notification_time (int): - Time in hours before the certificate expires to notify the user. Returns: None @@ -1226,7 +1131,6 @@ def _add_certificate( "certificate_signing_request": certificate_signing_request, "ca": ca, "chain": chain, - "recommended_expiry_notification_time": recommended_expiry_notification_time, } provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) @@ -1274,6 +1178,22 @@ def _remove_certificate( certificates.remove(certificate_dict) relation.data[self.model.app]["certificates"] = json.dumps(certificates) + @staticmethod + def _relation_data_is_valid(certificates_data: dict) -> bool: + """Use JSON schema validator to validate relation data content. + + Args: + certificates_data (dict): Certificate data dictionary as retrieved from relation data. + + Returns: + bool: True/False depending on whether the relation data follows the json schema. + """ + try: + validate(instance=certificates_data, schema=REQUIRER_JSON_SCHEMA) + return True + except exceptions.ValidationError: + return False + def revoke_all_certificates(self) -> None: """Revoke all certificates of this provider. @@ -1293,7 +1213,6 @@ def set_relation_certificate( ca: str, chain: List[str], relation_id: int, - recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificates to relation data. @@ -1303,8 +1222,6 @@ def set_relation_certificate( ca (str): CA Certificate chain (list): CA Chain relation_id (int): Juju relation ID - recommended_expiry_notification_time (int): - Recommended time in hours before the certificate expires to notify the user. Returns: None @@ -1326,7 +1243,6 @@ def set_relation_certificate( certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), chain=[cert.strip() for cert in chain], - recommended_expiry_notification_time=recommended_expiry_notification_time, ) def remove_certificate(self, certificate: str) -> None: @@ -1346,24 +1262,16 @@ def remove_certificate(self, certificate: str) -> None: def get_issued_certificates( self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued (non revoked) certificates. + ) -> Dict[str, List[Dict[str, str]]]: + """Return a dictionary of issued certificates. - Returns: - List: List of ProviderCertificate objects - """ - provider_certificates = self.get_provider_certificates(relation_id=relation_id) - return [certificate for certificate in provider_certificates if not certificate.revoked] - - def get_provider_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return a List of issued certificates. + It returns certificates from all relations if relation_id is not specified. + Certificates are returned per application name and CSR. Returns: - List: List of ProviderCertificate objects + dict: Certificates per application name. """ - certificates: List[ProviderCertificate] = [] + certificates: Dict[str, List[Dict[str, str]]] = {} relations = ( [ relation @@ -1374,33 +1282,19 @@ def get_provider_certificates( else self.model.relations.get(self.relationship_name, []) ) for relation in relations: - if not relation.app: - logger.warning("Relation %s does not have an application", relation.id) - continue provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) + + certificates[relation.app.name] = [] # type: ignore[union-attr] for certificate in provider_certificates: - try: - certificate_object = x509.load_pem_x509_certificate( - data=certificate["certificate"].encode() + if not certificate.get("revoked", False): + certificates[relation.app.name].append( # type: ignore[union-attr] + { + "csr": certificate["certificate_signing_request"], + "certificate": certificate["certificate"], + } ) - except ValueError as e: - logger.error("Could not load certificate - Skipping: %s", e) - continue - provider_certificate = ProviderCertificate( - relation_id=relation.id, - application_name=relation.app.name, - csr=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], - revoked=certificate.get("revoked", False), - expiry_time=certificate_object.not_valid_after_utc, - expiry_notification_time=certificate.get( - "recommended_expiry_notification_time" - ), - ) - certificates.append(provider_certificate) + return certificates def _on_relation_changed(self, event: RelationChangedEvent) -> None: @@ -1423,90 +1317,124 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: return if not self.model.unit.is_leader(): return - if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): + requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) + provider_relation_data = self._load_app_relation_data(event.relation) + if not self._relation_data_is_valid(requirer_relation_data): logger.debug("Relation data did not pass JSON Schema validation") return - provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) - requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) + provider_certificates = provider_relation_data.get("certificates", []) + requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) provider_csrs = [ - certificate_creation_request.csr + certificate_creation_request["certificate_signing_request"] for certificate_creation_request in provider_certificates ] - for certificate_request in requirer_csrs: - if certificate_request.csr not in provider_csrs: + requirer_unit_certificate_requests = [ + { + "csr": certificate_creation_request["certificate_signing_request"], + "is_ca": certificate_creation_request.get("ca", False), + } + for certificate_creation_request in requirer_csrs + ] + for certificate_request in requirer_unit_certificate_requests: + if certificate_request["csr"] not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_request.csr, - relation_id=certificate_request.relation_id, - is_ca=certificate_request.is_ca, + certificate_signing_request=certificate_request["csr"], + relation_id=event.relation.id, + is_ca=certificate_request["is_ca"], ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: """Revoke certificates for which no unit has a CSR. - Goes through all generated certificates and compare against the list of CSRs for all units. + Goes through all generated certificates and compare against the list of CSRs for all units + of a given relationship. + + Args: + relation_id (int): Relation id Returns: None """ - provider_certificates = self.get_unsolicited_certificates(relation_id=relation_id) - for provider_certificate in provider_certificates: - self.on.certificate_revocation_request.emit( - certificate=provider_certificate.certificate, - certificate_signing_request=provider_certificate.csr, - ca=provider_certificate.ca, - chain=provider_certificate.chain, - ) - self.remove_certificate(certificate=provider_certificate.certificate) - - def get_unsolicited_certificates( - self, relation_id: Optional[int] = None - ) -> List[ProviderCertificate]: - """Return provider certificates for which no certificate requests exists. - - Those certificates should be revoked. - """ - unsolicited_certificates: List[ProviderCertificate] = [] - provider_certificates = self.get_provider_certificates(relation_id=relation_id) - requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) - list_of_csrs = [csr.csr for csr in requirer_csrs] + certificates_relation = self.model.get_relation( + relation_name=self.relationship_name, relation_id=relation_id + ) + if not certificates_relation: + raise RuntimeError(f"Relation {self.relationship_name} does not exist") + provider_relation_data = self._load_app_relation_data(certificates_relation) + list_of_csrs: List[str] = [] + for unit in certificates_relation.units: + requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) + requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) + provider_certificates = provider_relation_data.get("certificates", []) for certificate in provider_certificates: - if certificate.csr not in list_of_csrs: - unsolicited_certificates.append(certificate) - return unsolicited_certificates + if certificate["certificate_signing_request"] not in list_of_csrs: + self.on.certificate_revocation_request.emit( + certificate=certificate["certificate"], + certificate_signing_request=certificate["certificate_signing_request"], + ca=certificate["ca"], + chain=certificate["chain"], + ) + self.remove_certificate(certificate=certificate["certificate"]) def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None - ) -> List[RequirerCSR]: + ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: """Return CSR's for which no certificate has been issued. + Example return: [ + { + "relation_id": 0, + "application_name": "tls-certificates-requirer", + "unit_name": "tls-certificates-requirer/0", + "unit_csrs": [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "is_ca": false + } + ] + } + ] + Args: relation_id (int): Relation id Returns: - list: List of RequirerCSR objects. + list: List of dictionaries that contain the unit's csrs + that don't have a certificate issued. """ - requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) - outstanding_csrs: List[RequirerCSR] = [] - for relation_csr in requirer_csrs: - if not self.certificate_issued_for_csr( - app_name=relation_csr.application_name, - csr=relation_csr.csr, - relation_id=relation_id, - ): - outstanding_csrs.append(relation_csr) - return outstanding_csrs - - def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: - """Return a list of requirers' CSRs. + all_unit_csr_mappings = copy.deepcopy(self.get_requirer_csrs(relation_id=relation_id)) + filtered_all_unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] + for unit_csr_mapping in all_unit_csr_mappings: + csrs_without_certs = [] + for csr in unit_csr_mapping["unit_csrs"]: # type: ignore[union-attr] + if not self.certificate_issued_for_csr( + app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] + csr=csr["certificate_signing_request"], # type: ignore[index] + relation_id=relation_id, + ): + csrs_without_certs.append(csr) + if csrs_without_certs: + unit_csr_mapping["unit_csrs"] = csrs_without_certs # type: ignore[assignment] + filtered_all_unit_csr_mappings.append(unit_csr_mapping) + return filtered_all_unit_csr_mappings + + def get_requirer_csrs( + self, relation_id: Optional[int] = None + ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + """Return a list of requirers' CSRs grouped by unit. It returns CSRs from all relations if relation_id is not specified. CSRs are returned per relation id, application name and unit name. Returns: - list: List[RequirerCSR] + list: List of dictionaries that contain the unit's csrs + with the following information + relation_id, application_name and unit_name. """ - relation_csrs: List[RequirerCSR] = [] + unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] + relations = ( [ relation @@ -1521,24 +1449,15 @@ def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerC for unit in relation.units: requirer_relation_data = _load_relation_data(relation.data[unit]) unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) - for unit_csr in unit_csrs_list: - csr = unit_csr.get("certificate_signing_request") - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - ca = unit_csr.get("ca", False) - if not relation.app: - logger.warning("No remote app in relation - Skipping") - continue - relation_csr = RequirerCSR( - relation_id=relation.id, - application_name=relation.app.name, - unit_name=unit.name, - csr=csr, - is_ca=ca, - ) - relation_csrs.append(relation_csr) - return relation_csrs + unit_csr_mappings.append( + { + "relation_id": relation.id, + "application_name": relation.app.name, # type: ignore[union-attr] + "unit_name": unit.name, + "unit_csrs": unit_csrs_list, + } + ) + return unit_csr_mappings def certificate_issued_for_csr( self, app_name: str, csr: str, relation_id: Optional[int] @@ -1549,18 +1468,19 @@ def certificate_issued_for_csr( app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. relation_id (Optional[int]): Relation ID - Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) - for issued_certificate in issued_certificates_per_csr: - if issued_certificate.csr == csr and issued_certificate.application_name == app_name: - return csr_matches_certificate(csr, issued_certificate.certificate) + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ + app_name + ] + for issued_pair in issued_certificates_per_csr: + if "csr" in issued_pair and issued_pair["csr"] == csr: + return csr_matches_certificate(csr, issued_pair["certificate"]) return False -class TLSCertificatesRequiresV3(Object): +class TLSCertificatesRequiresV2(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] @@ -1569,21 +1489,17 @@ def __init__( self, charm: CharmBase, relationship_name: str, - expiry_notification_time: Optional[int] = None, + expiry_notification_time: int = 168, ): """Generate/use private key and observes relation changed event. Args: charm: Charm object relationship_name: Juju relation name - expiry_notification_time (int): Number of hours prior to certificate expiry. - Used to trigger the CertificateExpiring event. - This value is used as a recommendation only, - The actual value is calculated taking into account the provider's recommendation. + expiry_notification_time (int): Time difference between now and expiry (in hours). + Used to trigger the CertificateExpiring event. Default: 7 days. """ super().__init__(charm, relationship_name) - if not JujuVersion.from_environ().has_secrets: - logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") self.relationship_name = relationship_name self.charm = charm self.expiry_notification_time = expiry_notification_time @@ -1593,39 +1509,32 @@ def __init__( self.framework.observe( charm.on[relationship_name].relation_broken, self._on_relation_broken ) - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + if JujuVersion.from_environ().has_secrets: + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) + else: + self.framework.observe(charm.on.update_status, self._on_update_status) - def get_requirer_csrs(self) -> List[RequirerCSR]: + @property + def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: """Return list of requirer's CSRs from relation unit data. - Returns: - list: List of RequirerCSR objects. + Example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] """ relation = self.model.get_relation(self.relationship_name) if not relation: - return [] - requirer_csrs = [] + raise RuntimeError(f"Relation {self.relationship_name} does not exist") requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) - for requirer_csr_dict in requirer_csrs_dict: - csr = requirer_csr_dict.get("certificate_signing_request") - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - ca = requirer_csr_dict.get("ca", False) - relation_csr = RequirerCSR( - relation_id=relation.id, - application_name=self.model.app.name, - unit_name=self.model.unit.name, - csr=csr, - is_ca=ca, - ) - requirer_csrs.append(relation_csr) - return requirer_csrs + return requirer_relation_data.get("certificate_signing_requests", []) - def get_provider_certificates(self) -> List[ProviderCertificate]: + @property + def _provider_certificates(self) -> List[Dict[str, str]]: """Return list of certificates from the provider's relation data.""" - provider_certificates: List[ProviderCertificate] = [] relation = self.model.get_relation(self.relationship_name) if not relation: logger.debug("No relation: %s", self.relationship_name) @@ -1634,50 +1543,12 @@ def get_provider_certificates(self) -> List[ProviderCertificate]: logger.debug("No remote app in relation: %s", self.relationship_name) return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) - provider_certificate_dicts = provider_relation_data.get("certificates", []) - for provider_certificate_dict in provider_certificate_dicts: - certificate = provider_certificate_dict.get("certificate") - if not certificate: - logger.warning("No certificate found in relation data - Skipping") - continue - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - except ValueError as e: - logger.error("Could not load certificate - Skipping: %s", e) - continue - ca = provider_certificate_dict.get("ca") - chain = provider_certificate_dict.get("chain", []) - csr = provider_certificate_dict.get("certificate_signing_request") - recommended_expiry_notification_time = provider_certificate_dict.get( - "recommended_expiry_notification_time" - ) - expiry_time = certificate_object.not_valid_after_utc - validity_start_time = certificate_object.not_valid_before_utc - expiry_notification_time = calculate_expiry_notification_time( - validity_start_time=validity_start_time, - expiry_time=expiry_time, - provider_recommended_notification_time=recommended_expiry_notification_time, - requirer_recommended_notification_time=self.expiry_notification_time, - ) - if not csr: - logger.warning("No CSR found in relation data - Skipping") - continue - revoked = provider_certificate_dict.get("revoked", False) - provider_certificate = ProviderCertificate( - relation_id=relation.id, - application_name=relation.app.name, - csr=csr, - certificate=certificate, - ca=ca, - chain=chain, - revoked=revoked, - expiry_time=expiry_time, - expiry_notification_time=expiry_notification_time, - ) - provider_certificates.append(provider_certificate) - return provider_certificates + if not self._relation_data_is_valid(provider_relation_data): + logger.warning("Provider relation data did not pass JSON Schema validation") + return [] + return provider_relation_data.get("certificates", []) - def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: + def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: """Add CSR to relation data. Args: @@ -1693,23 +1564,18 @@ def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - for requirer_csr in self.get_requirer_csrs(): - if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: - logger.info("CSR already in relation data - Doing nothing") - return - new_csr_dict = { + new_csr_dict: Dict[str, Union[bool, str]] = { "certificate_signing_request": csr, "ca": is_ca, } - requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) - new_relation_data = copy.deepcopy(existing_relation_data) - new_relation_data.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( - new_relation_data - ) + if new_csr_dict in self._requirer_csrs: + logger.info("CSR already in relation data - Doing nothing") + return + requirer_csrs = copy.deepcopy(self._requirer_csrs) + requirer_csrs.append(new_csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) - def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: + def _remove_requirer_csr(self, csr: str) -> None: """Remove CSR from relation data. Args: @@ -1724,18 +1590,14 @@ def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - if not self.get_requirer_csrs(): + requirer_csrs = copy.deepcopy(self._requirer_csrs) + if not requirer_csrs: logger.info("No CSRs in relation data - Doing nothing") return - requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) - new_relation_data = copy.deepcopy(existing_relation_data) - for requirer_csr in new_relation_data: + for requirer_csr in requirer_csrs: if requirer_csr["certificate_signing_request"] == csr: - new_relation_data.remove(requirer_csr) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( - new_relation_data - ) + requirer_csrs.remove(requirer_csr) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) def request_certificate_creation( self, certificate_signing_request: bytes, is_ca: bool = False @@ -1755,9 +1617,7 @@ def request_certificate_creation( f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr_to_relation_data( - certificate_signing_request.decode().strip(), is_ca=is_ca - ) + self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: @@ -1773,7 +1633,7 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> Returns: None """ - self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) + self._remove_requirer_csr(certificate_signing_request.decode().strip()) logger.info("Certificate revocation sent to provider") def request_certificate_renewal( @@ -1801,58 +1661,107 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") - def get_assigned_certificates(self) -> List[ProviderCertificate]: + def get_assigned_certificates(self) -> List[Dict[str, str]]: """Get a list of certificates that were assigned to this unit. Returns: - List: List[ProviderCertificate] + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] """ - assigned_certificates = [] - for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): - if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - assigned_certificates.append(cert) - return assigned_certificates - - def get_expiring_certificates(self) -> List[ProviderCertificate]: + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert isinstance(csr["certificate_signing_request"], str) + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + final_list.append(cert) + return final_list + + def get_expiring_certificates(self) -> List[Dict[str, str]]: """Get a list of certificates that were assigned to this unit that are expiring or expired. Returns: - List: List[ProviderCertificate] + List of certificates. For example: + [ + { + "ca": "-----BEGIN CERTIFICATE-----...", + "chain": [ + "-----BEGIN CERTIFICATE-----..." + ], + "certificate": "-----BEGIN CERTIFICATE-----...", + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + } + ] """ - expiring_certificates: List[ProviderCertificate] = [] - for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): - if cert := self._find_certificate_in_relation_data(requirer_csr.csr): - if not cert.expiry_time or not cert.expiry_notification_time: + final_list = [] + for csr in self.get_certificate_signing_requests(fulfilled_only=True): + assert isinstance(csr["certificate_signing_request"], str) + if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): + expiry_time = _get_certificate_expiry_time(cert["certificate"]) + if not expiry_time: continue - if datetime.now(timezone.utc) > cert.expiry_notification_time: - expiring_certificates.append(cert) - return expiring_certificates + expiry_notification_time = expiry_time - timedelta( + hours=self.expiry_notification_time + ) + if datetime.now(timezone.utc) > expiry_notification_time: + final_list.append(cert) + return final_list def get_certificate_signing_requests( self, fulfilled_only: bool = False, unfulfilled_only: bool = False, - ) -> List[RequirerCSR]: + ) -> List[Dict[str, Union[bool, str]]]: """Get the list of CSR's that were sent to the provider. You can choose to get only the CSR's that have a certificate assigned or only the CSR's - that don't. + that don't. Args: fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. unfulfilled_only (bool): This option will discard CSRs that have certificates signed. Returns: - List of RequirerCSR objects. + List of CSR dictionaries. For example: + [ + { + "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", + "ca": false + } + ] """ - csrs = [] - for requirer_csr in self.get_requirer_csrs(): - cert = self._find_certificate_in_relation_data(requirer_csr.csr) + final_list = [] + for csr in self._requirer_csrs: + assert isinstance(csr["certificate_signing_request"], str) + cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) if (unfulfilled_only and cert) or (fulfilled_only and not cert): continue - csrs.append(requirer_csr) + final_list.append(csr) + + return final_list + + @staticmethod + def _relation_data_is_valid(certificates_data: dict) -> bool: + """Check whether relation data is valid based on json schema. - return csrs + Args: + certificates_data: Certificate data in dict format. + + Returns: + bool: Whether relation data is valid. + """ + try: + validate(instance=certificates_data, schema=PROVIDER_JSON_SCHEMA) + return True + except exceptions.ValidationError: + return False def _on_relation_changed(self, event: RelationChangedEvent) -> None: """Handle relation changed event. @@ -1862,8 +1771,9 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: If the provider certificate is revoked, emit a CertificateInvalidateEvent, otherwise emit a CertificateAvailableEvent. - Remove the secret for revoked certificate, or add a secret with the correct expiry - time for new certificates. + When Juju secrets are available, remove the secret for revoked certificate, + or add a secret with the correct expiry time for new certificates. + Args: event: Juju event @@ -1871,74 +1781,54 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ - if not event.app: - logger.warning("No remote app in relation - Skipping") - return - if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): - logger.debug("Relation data did not pass JSON Schema validation") - return - provider_certificates = self.get_provider_certificates() requirer_csrs = [ - certificate_creation_request.csr - for certificate_creation_request in self.get_requirer_csrs() + certificate_creation_request["certificate_signing_request"] + for certificate_creation_request in self._requirer_csrs ] - for certificate in provider_certificates: - if certificate.csr in requirer_csrs: - csr_in_sha256_hex = get_sha256_hex(certificate.csr) - if certificate.revoked: - with suppress(SecretNotFoundError): - logger.debug( - "Removing secret with label %s", - f"{LIBID}-{csr_in_sha256_hex}", - ) - secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") - secret.remove_all_revisions() + for certificate in self._provider_certificates: + if certificate["certificate_signing_request"] in requirer_csrs: + if certificate.get("revoked", False): + if JujuVersion.from_environ().has_secrets: + with suppress(SecretNotFoundError): + secret = self.model.get_secret( + label=f"{LIBID}-{certificate['certificate_signing_request']}" + ) + secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", - certificate=certificate.certificate, - certificate_signing_request=certificate.csr, - ca=certificate.ca, - chain=certificate.chain, + certificate=certificate["certificate"], + certificate_signing_request=certificate["certificate_signing_request"], + ca=certificate["ca"], + chain=certificate["chain"], ) else: - try: - secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") - logger.debug( - "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" - ) - # Juju < 3.6 will create a new revision even if the content is the same - if ( - secret.get_content(refresh=True).get("certificate", "") - == certificate.certificate - ): - logger.debug( - "Secret %s with correct certificate already exists", - f"{LIBID}-{csr_in_sha256_hex}", + if JujuVersion.from_environ().has_secrets: + try: + secret = self.model.get_secret( + label=f"{LIBID}-{certificate['certificate_signing_request']}" + ) + secret.set_content({"certificate": certificate["certificate"]}) + secret.set_info( + expire=self._get_next_secret_expiry_time( + certificate["certificate"] + ), + ) + except SecretNotFoundError: + secret = self.charm.unit.add_secret( + {"certificate": certificate["certificate"]}, + label=f"{LIBID}-{certificate['certificate_signing_request']}", + expire=self._get_next_secret_expiry_time( + certificate["certificate"] + ), ) - continue - secret.set_content( - {"certificate": certificate.certificate, "csr": certificate.csr} - ) - secret.set_info( - expire=self._get_next_secret_expiry_time(certificate), - ) - except SecretNotFoundError: - logger.debug( - "Creating new secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" - ) - secret = self.charm.unit.add_secret( - {"certificate": certificate.certificate, "csr": certificate.csr}, - label=f"{LIBID}-{csr_in_sha256_hex}", - expire=self._get_next_secret_expiry_time(certificate), - ) self.on.certificate_available.emit( - certificate_signing_request=certificate.csr, - certificate=certificate.certificate, - ca=certificate.ca, - chain=certificate.chain, + certificate_signing_request=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], ) - def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: + def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: """Return the expiry time or expiry notification time. Extracts the expiry time from the provided certificate, calculates the @@ -1946,21 +1836,20 @@ def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Opti the future. Args: - certificate: ProviderCertificate object + certificate: x509 certificate Returns: Optional[datetime]: None if the certificate expiry time cannot be read, next expiry time otherwise. """ - if not certificate.expiry_time or not certificate.expiry_notification_time: + expiry_time = _get_certificate_expiry_time(certificate) + if not expiry_time: return None - return _get_closest_future_time( - certificate.expiry_notification_time, - certificate.expiry_time, - ) + expiry_notification_time = expiry_time - timedelta(hours=self.expiry_notification_time) + return _get_closest_future_time(expiry_notification_time, expiry_time) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handle Relation Broken Event. + """Handle relation broken event. Emitting `all_certificates_invalidated` from `relation-broken` rather than `relation-departed` since certs are stored in app data. @@ -1974,7 +1863,7 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: self.on.all_certificates_invalidated.emit() def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Handle Secret Expired Event. + """Handle secret expired event. Loads the certificate from the secret, and will emit 1 of 2 events. @@ -1989,73 +1878,82 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: Args: event (SecretExpiredEvent): Juju event """ - csr = self._get_csr_from_secret(event.secret) - if not csr: - logger.error("Failed to get CSR from secret %s", event.secret.label) + if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return - provider_certificate = self._find_certificate_in_relation_data(csr) - if not provider_certificate: + csr = event.secret.label[len(f"{LIBID}-") :] + certificate_dict = self._find_certificate_in_relation_data(csr) + if not certificate_dict: # A secret expired but we did not find matching certificate. Cleaning up - logger.warning( - "Failed to find matching certificate for csr, cleaning up secret %s", - event.secret.label, - ) event.secret.remove_all_revisions() return - if not provider_certificate.expiry_time: + expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + if not expiry_time: # A secret expired but matching certificate is invalid. Cleaning up - logger.warning( - "Certificate matching csr is invalid, cleaning up secret %s", - event.secret.label, - ) event.secret.remove_all_revisions() return - if datetime.now(timezone.utc) < provider_certificate.expiry_time: + if datetime.now(timezone.utc) < expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( - certificate=provider_certificate.certificate, - expiry=provider_certificate.expiry_time.isoformat(), + certificate=certificate_dict["certificate"], + expiry=expiry_time.isoformat(), ) event.secret.set_info( - expire=provider_certificate.expiry_time, + expire=_get_certificate_expiry_time(certificate_dict["certificate"]), ) else: logger.warning("Certificate is expired") self.on.certificate_invalidated.emit( reason="expired", - certificate=provider_certificate.certificate, - certificate_signing_request=provider_certificate.csr, - ca=provider_certificate.ca, - chain=provider_certificate.chain, + certificate=certificate_dict["certificate"], + certificate_signing_request=certificate_dict["certificate_signing_request"], + ca=certificate_dict["ca"], + chain=certificate_dict["chain"], ) - self.request_certificate_revocation(provider_certificate.certificate.encode()) + self.request_certificate_revocation(certificate_dict["certificate"].encode()) event.secret.remove_all_revisions() - def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: + def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: """Return the certificate that match the given CSR.""" - for provider_certificate in self.get_provider_certificates(): - if provider_certificate.csr != csr: + for certificate_dict in self._provider_certificates: + if certificate_dict["certificate_signing_request"] != csr: continue - return provider_certificate + return certificate_dict return None - def _get_csr_from_secret(self, secret: Secret) -> Union[str, None]: - """Extract the CSR from the secret label or content. + def _on_update_status(self, event: UpdateStatusEvent) -> None: + """Handle update status event. - This function is a workaround to maintain backwards compatibility - and fix the issue reported in - https://github.com/canonical/tls-certificates-interface/issues/228 + Goes through each certificate in the "certificates" relation and checks their expiry date. + If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if + they are expired, emits a CertificateExpiredEvent. + + Args: + event (UpdateStatusEvent): Juju event + + Returns: + None """ - try: - content = secret.get_content(refresh=True) - except SecretNotFoundError: - return None - if not (csr := content.get("csr", None)): - # In versions <14 of the Lib we were storing the CSR in the label of the secret - # The CSR now is stored int the content of the secret, which was a breaking change - # Here we get the CSR if the secret was created by an app using libpatch 14 or lower - if secret.label and secret.label.startswith(f"{LIBID}-"): - csr = secret.label[len(f"{LIBID}-") :] - return csr + for certificate_dict in self._provider_certificates: + expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + if not expiry_time: + continue + time_difference = expiry_time - datetime.now(timezone.utc) + if time_difference.total_seconds() < 0: + logger.warning("Certificate is expired") + self.on.certificate_invalidated.emit( + reason="expired", + certificate=certificate_dict["certificate"], + certificate_signing_request=certificate_dict["certificate_signing_request"], + ca=certificate_dict["ca"], + chain=certificate_dict["chain"], + ) + self.request_certificate_revocation(certificate_dict["certificate"].encode()) + continue + if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): + logger.warning("Certificate almost expired") + self.on.certificate_expiring.emit( + certificate=certificate_dict["certificate"], + expiry=expiry_time.isoformat(), + ) From 3a62a08b30360dc861693ef312b4a734eaaed00e Mon Sep 17 00:00:00 2001 From: MiaAltieri Date: Fri, 20 Dec 2024 21:02:57 +0000 Subject: [PATCH 2/6] fix ci --- .github/workflows/ci.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8b3d5ba77..403fbeea9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -49,11 +49,11 @@ jobs: terraform fmt terraform validate - name: lint test charm module - working-directory: ./terraform/tests - run: | - terraform init - terraform fmt - terraform validate + working-directory: ./terraform/tests + run: | + terraform init + terraform fmt + terraform validate - name: run checks - prepare run: | sudo snap install juju --channel=3.6/beta --classic From e17b425ec53f7429cc9e28e144d9338df1fb16d2 Mon Sep 17 00:00:00 2001 From: MiaAltieri Date: Fri, 20 Dec 2024 21:07:44 +0000 Subject: [PATCH 3/6] update tls libs --- .../v3/tls_certificates.py | 958 ++++++++++-------- 1 file changed, 530 insertions(+), 428 deletions(-) diff --git a/lib/charms/tls_certificates_interface/v3/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py index c232362fe..8cab71336 100644 --- a/lib/charms/tls_certificates_interface/v3/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -1,4 +1,4 @@ -# Copyright 2021 Canonical Ltd. +# Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. @@ -7,16 +7,19 @@ This library contains the Requires and Provides classes for handling the tls-certificates interface. +Pre-requisites: + - Juju >= 3.0 + ## Getting Started From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: - jsonschema -- cryptography +- cryptography >= 42.0.0 Add the following section to the charm's `charmcraft.yaml` file: ```yaml @@ -36,10 +39,10 @@ Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV2, + TLSCertificatesProvidesV3, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -59,7 +62,7 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV2(self, "certificates") + self.certificates = TLSCertificatesProvidesV3(self, "certificates") self.framework.observe( self.certificates.on.certificate_request, self._on_certificate_request @@ -108,6 +111,7 @@ def _on_certificate_request(self, event: CertificateCreationRequestEvent) -> Non ca=ca_certificate, chain=[ca_certificate, certificate], relation_id=event.relation_id, + recommended_expiry_notification_time=720, ) def _on_certificate_revocation_request(self, event: CertificateRevocationRequestEvent) -> None: @@ -126,15 +130,15 @@ def _on_certificate_revocation_request(self, event: CertificateRevocationRequest Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV2, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) -from ops.charm import CharmBase, RelationJoinedEvent +from ops.charm import CharmBase, RelationCreatedEvent from ops.main import main from ops.model import ActiveStatus, WaitingStatus from typing import Union @@ -145,10 +149,10 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV2(self, "certificates") + self.certificates = TLSCertificatesRequiresV3(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( - self.on.certificates_relation_joined, self._on_certificates_relation_joined + self.on.certificates_relation_created, self._on_certificates_relation_created ) self.framework.observe( self.certificates.on.certificate_available, self._on_certificate_available @@ -176,7 +180,7 @@ def _on_install(self, event) -> None: {"private_key_password": "banana", "private_key": private_key.decode()} ) - def _on_certificates_relation_joined(self, event: RelationJoinedEvent) -> None: + def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -273,19 +277,19 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven """ # noqa: D405, D410, D411, D214, D416 import copy +import ipaddress import json import logging import uuid from contextlib import suppress +from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from ipaddress import IPv4Address -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.serialization import pkcs12 from jsonschema import exceptions, validate from ops.charm import ( CharmBase, @@ -293,21 +297,28 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven RelationBrokenEvent, RelationChangedEvent, SecretExpiredEvent, - UpdateStatusEvent, ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError +from ops.model import ( + Application, + ModelError, + Relation, + RelationDataContent, + Secret, + SecretNotFoundError, + Unit, +) # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 2 +LIBAPI = 3 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 29 +LIBPATCH = 24 PYDEPS = ["cryptography", "jsonschema"] @@ -422,6 +433,58 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven logger = logging.getLogger(__name__) +@dataclass +class RequirerCSR: + """This class represents a certificate signing request from an interface Requirer.""" + + relation_id: int + application_name: str + unit_name: str + csr: str + is_ca: bool + + +@dataclass +class ProviderCertificate: + """This class represents a certificate from an interface Provider.""" + + relation_id: int + application_name: str + csr: str + certificate: str + ca: str + chain: List[str] + revoked: bool + expiry_time: datetime + expiry_notification_time: Optional[datetime] = None + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + + def to_json(self) -> str: + """Return the object as a JSON string. + + Returns: + str: JSON representation of the object + """ + return json.dumps( + { + "relation_id": self.relation_id, + "application_name": self.application_name, + "csr": self.csr, + "certificate": self.certificate, + "ca": self.ca, + "chain": self.chain, + "revoked": self.revoked, + "expiry_time": self.expiry_time.isoformat(), + "expiry_notification_time": self.expiry_notification_time.isoformat() + if self.expiry_notification_time + else None, + } + ) + + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -455,6 +518,10 @@ def restore(self, snapshot: dict): self.ca = snapshot["ca"] self.chain = snapshot["chain"] + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + class CertificateExpiringEvent(EventBase): """Charm Event triggered when a TLS certificate is almost expired.""" @@ -641,21 +708,49 @@ def _get_closest_future_time( ) -def _get_certificate_expiry_time(certificate: str) -> Optional[datetime]: - """Extract expiry time from a certificate string. +def calculate_expiry_notification_time( + validity_start_time: datetime, + expiry_time: datetime, + provider_recommended_notification_time: Optional[int], + requirer_recommended_notification_time: Optional[int], +) -> datetime: + """Calculate a reasonable time to notify the user about the certificate expiry. + + It takes into account the time recommended by the provider and by the requirer. + Time recommended by the provider is preferred, + then time recommended by the requirer, + then dynamically calculated time. Args: - certificate (str): x509 certificate as a string + validity_start_time: Certificate validity time + expiry_time: Certificate expiry time + provider_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the provider. + requirer_recommended_notification_time: + Time in hours prior to expiry to notify the user. + Recommended by the requirer. Returns: - Optional[datetime]: Expiry datetime or None + datetime: Time to notify the user about the certificate expiry. """ - try: - certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) - return certificate_object.not_valid_after_utc - except ValueError: - logger.warning("Could not load certificate.") - return None + if provider_recommended_notification_time is not None: + provider_recommended_notification_time = abs(provider_recommended_notification_time) + provider_recommendation_time_delta = expiry_time - timedelta( + hours=provider_recommended_notification_time + ) + if validity_start_time < provider_recommendation_time_delta: + return provider_recommendation_time_delta + + if requirer_recommended_notification_time is not None: + requirer_recommended_notification_time = abs(requirer_recommended_notification_time) + requirer_recommendation_time_delta = expiry_time - timedelta( + hours=requirer_recommended_notification_time + ) + if validity_start_time < requirer_recommendation_time_delta: + return requirer_recommendation_time_delta + calculated_hours = (expiry_time - validity_start_time).total_seconds() / (3600 * 3) + return expiry_time - timedelta(hours=calculated_hours) def generate_ca( @@ -886,38 +981,6 @@ def generate_certificate( return cert.public_bytes(serialization.Encoding.PEM) -def generate_pfx_package( - certificate: bytes, - private_key: bytes, - package_password: str, - private_key_password: Optional[bytes] = None, -) -> bytes: - """Generate a PFX package to contain the TLS certificate and private key. - - Args: - certificate (bytes): TLS certificate - private_key (bytes): Private key - package_password (str): Password to open the PFX package - private_key_password (bytes): Private key password - - Returns: - bytes: - """ - private_key_object = serialization.load_pem_private_key( - private_key, password=private_key_password - ) - certificate_object = x509.load_pem_x509_certificate(certificate) - name = certificate_object.subject.rfc4514_string() - pfx_bytes = pkcs12.serialize_key_and_certificates( - name=name.encode(), - cert=certificate_object, - key=private_key_object, # type: ignore[arg-type] - cas=None, - encryption_algorithm=serialization.BestAvailableEncryption(package_password.encode()), - ) - return pfx_bytes - - def generate_private_key( password: Optional[bytes] = None, key_size: int = 2048, @@ -956,6 +1019,8 @@ def generate_csr( # noqa: C901 organization: Optional[str] = None, email_address: Optional[str] = None, country_name: Optional[str] = None, + state_or_province_name: Optional[str] = None, + locality_name: Optional[str] = None, private_key_password: Optional[bytes] = None, sans: Optional[List[str]] = None, sans_oid: Optional[List[str]] = None, @@ -974,6 +1039,8 @@ def generate_csr( # noqa: C901 organization (str): Name of organization. email_address (str): Email address. country_name (str): Country Name. + state_or_province_name (str): State or Province Name. + locality_name (str): Locality Name. private_key_password (bytes): Private key password sans (list): Use sans_dns - this will be deprecated in a future release List of DNS subject alternative names (keeping it for now for backward compatibility) @@ -999,13 +1066,19 @@ def generate_csr( # noqa: C901 subject_name.append(x509.NameAttribute(x509.NameOID.EMAIL_ADDRESS, email_address)) if country_name: subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name)) + if state_or_province_name: + subject_name.append( + x509.NameAttribute(x509.NameOID.STATE_OR_PROVINCE_NAME, state_or_province_name) + ) + if locality_name: + subject_name.append(x509.NameAttribute(x509.NameOID.LOCALITY_NAME, locality_name)) csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name)) _sans: List[x509.GeneralName] = [] if sans_oid: _sans.extend([x509.RegisteredID(x509.ObjectIdentifier(san)) for san in sans_oid]) if sans_ip: - _sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip]) + _sans.extend([x509.IPAddress(ipaddress.ip_address(san)) for san in sans_ip]) if sans: _sans.extend([x509.DNSName(san) for san in sans]) if sans_dns: @@ -1021,6 +1094,13 @@ def generate_csr( # noqa: C901 return signed_certificate.public_bytes(serialization.Encoding.PEM) +def get_sha256_hex(data: str) -> str: + """Calculate the hash of the provided data and return the hexadecimal representation.""" + digest = hashes.Hash(hashes.SHA256()) + digest.update(data.encode()) + return digest.finalize().hex() + + def csr_matches_certificate(csr: str, cert: str) -> bool: """Check if a CSR matches a certificate. @@ -1030,29 +1110,41 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: Returns: bool: True/False depending on whether the CSR matches the certificate. """ - try: - csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) - cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) - - if csr_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) != cert_object.public_key().public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ): - return False - if ( - csr_object.public_key().public_numbers().n # type: ignore[union-attr] - != cert_object.public_key().public_numbers().n # type: ignore[union-attr] - ): - return False - except ValueError: - logger.warning("Could not load certificate or CSR.") + csr_object = x509.load_pem_x509_csr(csr.encode("utf-8")) + cert_object = x509.load_pem_x509_certificate(cert.encode("utf-8")) + + if csr_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) != cert_object.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ): return False return True +def _relation_data_is_valid( + relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict +) -> bool: + """Check whether relation data is valid based on json schema. + + Args: + relation (Relation): Relation object + app_or_unit (Union[Application, Unit]): Application or unit object + json_schema (dict): Json schema + + Returns: + bool: Whether relation data is valid. + """ + relation_data = _load_relation_data(relation.data[app_or_unit]) + try: + validate(instance=relation_data, schema=json_schema) + return True + except exceptions.ValidationError: + return False + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -1069,7 +1161,7 @@ class CertificatesRequirerCharmEvents(CharmEvents): all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) -class TLSCertificatesProvidesV2(Object): +class TLSCertificatesProvidesV3(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] @@ -1105,6 +1197,7 @@ def _add_certificate( certificate_signing_request: str, ca: str, chain: List[str], + recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificate to relation data. @@ -1114,6 +1207,8 @@ def _add_certificate( certificate_signing_request (str): Certificate Signing Request ca (str): CA Certificate chain (list): CA Chain + recommended_expiry_notification_time (int): + Time in hours before the certificate expires to notify the user. Returns: None @@ -1131,6 +1226,7 @@ def _add_certificate( "certificate_signing_request": certificate_signing_request, "ca": ca, "chain": chain, + "recommended_expiry_notification_time": recommended_expiry_notification_time, } provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) @@ -1178,22 +1274,6 @@ def _remove_certificate( certificates.remove(certificate_dict) relation.data[self.model.app]["certificates"] = json.dumps(certificates) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Use JSON schema validator to validate relation data content. - - Args: - certificates_data (dict): Certificate data dictionary as retrieved from relation data. - - Returns: - bool: True/False depending on whether the relation data follows the json schema. - """ - try: - validate(instance=certificates_data, schema=REQUIRER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False - def revoke_all_certificates(self) -> None: """Revoke all certificates of this provider. @@ -1213,6 +1293,7 @@ def set_relation_certificate( ca: str, chain: List[str], relation_id: int, + recommended_expiry_notification_time: Optional[int] = None, ) -> None: """Add certificates to relation data. @@ -1222,6 +1303,8 @@ def set_relation_certificate( ca (str): CA Certificate chain (list): CA Chain relation_id (int): Juju relation ID + recommended_expiry_notification_time (int): + Recommended time in hours before the certificate expires to notify the user. Returns: None @@ -1243,6 +1326,7 @@ def set_relation_certificate( certificate_signing_request=certificate_signing_request.strip(), ca=ca.strip(), chain=[cert.strip() for cert in chain], + recommended_expiry_notification_time=recommended_expiry_notification_time, ) def remove_certificate(self, certificate: str) -> None: @@ -1262,16 +1346,24 @@ def remove_certificate(self, certificate: str) -> None: def get_issued_certificates( self, relation_id: Optional[int] = None - ) -> Dict[str, List[Dict[str, str]]]: - """Return a dictionary of issued certificates. + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. - It returns certificates from all relations if relation_id is not specified. - Certificates are returned per application name and CSR. + Returns: + List: List of ProviderCertificate objects + """ + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] + + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates. Returns: - dict: Certificates per application name. + List: List of ProviderCertificate objects """ - certificates: Dict[str, List[Dict[str, str]]] = {} + certificates: List[ProviderCertificate] = [] relations = ( [ relation @@ -1282,19 +1374,33 @@ def get_issued_certificates( else self.model.relations.get(self.relationship_name, []) ) for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) - - certificates[relation.app.name] = [] # type: ignore[union-attr] for certificate in provider_certificates: - if not certificate.get("revoked", False): - certificates[relation.app.name].append( # type: ignore[union-attr] - { - "csr": certificate["certificate_signing_request"], - "certificate": certificate["certificate"], - } + try: + certificate_object = x509.load_pem_x509_certificate( + data=certificate["certificate"].encode() ) - + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], + revoked=certificate.get("revoked", False), + expiry_time=certificate_object.not_valid_after_utc, + expiry_notification_time=certificate.get( + "recommended_expiry_notification_time" + ), + ) + certificates.append(provider_certificate) return certificates def _on_relation_changed(self, event: RelationChangedEvent) -> None: @@ -1317,124 +1423,90 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: return if not self.model.unit.is_leader(): return - requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) - provider_relation_data = self._load_app_relation_data(event.relation) - if not self._relation_data_is_valid(requirer_relation_data): + if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): logger.debug("Relation data did not pass JSON Schema validation") return - provider_certificates = provider_relation_data.get("certificates", []) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) + requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) provider_csrs = [ - certificate_creation_request["certificate_signing_request"] + certificate_creation_request.csr for certificate_creation_request in provider_certificates ] - requirer_unit_certificate_requests = [ - { - "csr": certificate_creation_request["certificate_signing_request"], - "is_ca": certificate_creation_request.get("ca", False), - } - for certificate_creation_request in requirer_csrs - ] - for certificate_request in requirer_unit_certificate_requests: - if certificate_request["csr"] not in provider_csrs: + for certificate_request in requirer_csrs: + if certificate_request.csr not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_request["csr"], - relation_id=event.relation.id, - is_ca=certificate_request["is_ca"], + certificate_signing_request=certificate_request.csr, + relation_id=certificate_request.relation_id, + is_ca=certificate_request.is_ca, ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: """Revoke certificates for which no unit has a CSR. - Goes through all generated certificates and compare against the list of CSRs for all units - of a given relationship. - - Args: - relation_id (int): Relation id + Goes through all generated certificates and compare against the list of CSRs for all units. Returns: None """ - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=relation_id - ) - if not certificates_relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = self._load_app_relation_data(certificates_relation) - list_of_csrs: List[str] = [] - for unit in certificates_relation.units: - requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) - list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) - provider_certificates = provider_relation_data.get("certificates", []) + provider_certificates = self.get_unsolicited_certificates(relation_id=relation_id) + for provider_certificate in provider_certificates: + self.on.certificate_revocation_request.emit( + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, + ) + self.remove_certificate(certificate=provider_certificate.certificate) + + def get_unsolicited_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return provider certificates for which no certificate requests exists. + + Those certificates should be revoked. + """ + unsolicited_certificates: List[ProviderCertificate] = [] + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) + list_of_csrs = [csr.csr for csr in requirer_csrs] for certificate in provider_certificates: - if certificate["certificate_signing_request"] not in list_of_csrs: - self.on.certificate_revocation_request.emit( - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], - ) - self.remove_certificate(certificate=certificate["certificate"]) + if certificate.csr not in list_of_csrs: + unsolicited_certificates.append(certificate) + return unsolicited_certificates def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + ) -> List[RequirerCSR]: """Return CSR's for which no certificate has been issued. - Example return: [ - { - "relation_id": 0, - "application_name": "tls-certificates-requirer", - "unit_name": "tls-certificates-requirer/0", - "unit_csrs": [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "is_ca": false - } - ] - } - ] - Args: relation_id (int): Relation id Returns: - list: List of dictionaries that contain the unit's csrs - that don't have a certificate issued. + list: List of RequirerCSR objects. """ - all_unit_csr_mappings = copy.deepcopy(self.get_requirer_csrs(relation_id=relation_id)) - filtered_all_unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - for unit_csr_mapping in all_unit_csr_mappings: - csrs_without_certs = [] - for csr in unit_csr_mapping["unit_csrs"]: # type: ignore[union-attr] - if not self.certificate_issued_for_csr( - app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] - csr=csr["certificate_signing_request"], # type: ignore[index] - relation_id=relation_id, - ): - csrs_without_certs.append(csr) - if csrs_without_certs: - unit_csr_mapping["unit_csrs"] = csrs_without_certs # type: ignore[assignment] - filtered_all_unit_csr_mappings.append(unit_csr_mapping) - return filtered_all_unit_csr_mappings - - def get_requirer_csrs( - self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Return a list of requirers' CSRs grouped by unit. + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) + outstanding_csrs: List[RequirerCSR] = [] + for relation_csr in requirer_csrs: + if not self.certificate_issued_for_csr( + app_name=relation_csr.application_name, + csr=relation_csr.csr, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + """Return a list of requirers' CSRs. It returns CSRs from all relations if relation_id is not specified. CSRs are returned per relation id, application name and unit name. Returns: - list: List of dictionaries that contain the unit's csrs - with the following information - relation_id, application_name and unit_name. + list: List[RequirerCSR] """ - unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - + relation_csrs: List[RequirerCSR] = [] relations = ( [ relation @@ -1449,15 +1521,24 @@ def get_requirer_csrs( for unit in relation.units: requirer_relation_data = _load_relation_data(relation.data[unit]) unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) - unit_csr_mappings.append( - { - "relation_id": relation.id, - "application_name": relation.app.name, # type: ignore[union-attr] - "unit_name": unit.name, - "unit_csrs": unit_csrs_list, - } - ) - return unit_csr_mappings + for unit_csr in unit_csrs_list: + csr = unit_csr.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = unit_csr.get("ca", False) + if not relation.app: + logger.warning("No remote app in relation - Skipping") + continue + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=relation.app.name, + unit_name=unit.name, + csr=csr, + is_ca=ca, + ) + relation_csrs.append(relation_csr) + return relation_csrs def certificate_issued_for_csr( self, app_name: str, csr: str, relation_id: Optional[int] @@ -1468,19 +1549,18 @@ def certificate_issued_for_csr( app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. relation_id (Optional[int]): Relation ID + Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ - app_name - ] - for issued_pair in issued_certificates_per_csr: - if "csr" in issued_pair and issued_pair["csr"] == csr: - return csr_matches_certificate(csr, issued_pair["certificate"]) + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.csr == csr and issued_certificate.application_name == app_name: + return csr_matches_certificate(csr, issued_certificate.certificate) return False -class TLSCertificatesRequiresV2(Object): +class TLSCertificatesRequiresV3(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] @@ -1489,17 +1569,21 @@ def __init__( self, charm: CharmBase, relationship_name: str, - expiry_notification_time: int = 168, + expiry_notification_time: Optional[int] = None, ): """Generate/use private key and observes relation changed event. Args: charm: Charm object relationship_name: Juju relation name - expiry_notification_time (int): Time difference between now and expiry (in hours). - Used to trigger the CertificateExpiring event. Default: 7 days. + expiry_notification_time (int): Number of hours prior to certificate expiry. + Used to trigger the CertificateExpiring event. + This value is used as a recommendation only, + The actual value is calculated taking into account the provider's recommendation. """ super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") self.relationship_name = relationship_name self.charm = charm self.expiry_notification_time = expiry_notification_time @@ -1509,32 +1593,39 @@ def __init__( self.framework.observe( charm.on[relationship_name].relation_broken, self._on_relation_broken ) - if JujuVersion.from_environ().has_secrets: - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - else: - self.framework.observe(charm.on.update_status, self._on_update_status) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - @property - def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: + def get_requirer_csrs(self) -> List[RequirerCSR]: """Return list of requirer's CSRs from relation unit data. - Example: - [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "ca": false - } - ] + Returns: + list: List of RequirerCSR objects. """ relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") + return [] + requirer_csrs = [] requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - return requirer_relation_data.get("certificate_signing_requests", []) + requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) + for requirer_csr_dict in requirer_csrs_dict: + csr = requirer_csr_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = requirer_csr_dict.get("ca", False) + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=self.model.app.name, + unit_name=self.model.unit.name, + csr=csr, + is_ca=ca, + ) + requirer_csrs.append(relation_csr) + return requirer_csrs - @property - def _provider_certificates(self) -> List[Dict[str, str]]: + def get_provider_certificates(self) -> List[ProviderCertificate]: """Return list of certificates from the provider's relation data.""" + provider_certificates: List[ProviderCertificate] = [] relation = self.model.get_relation(self.relationship_name) if not relation: logger.debug("No relation: %s", self.relationship_name) @@ -1543,12 +1634,50 @@ def _provider_certificates(self) -> List[Dict[str, str]]: logger.debug("No remote app in relation: %s", self.relationship_name) return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning("Provider relation data did not pass JSON Schema validation") - return [] - return provider_relation_data.get("certificates", []) + provider_certificate_dicts = provider_relation_data.get("certificates", []) + for provider_certificate_dict in provider_certificate_dicts: + certificate = provider_certificate_dict.get("certificate") + if not certificate: + logger.warning("No certificate found in relation data - Skipping") + continue + try: + certificate_object = x509.load_pem_x509_certificate(data=certificate.encode()) + except ValueError as e: + logger.error("Could not load certificate - Skipping: %s", e) + continue + ca = provider_certificate_dict.get("ca") + chain = provider_certificate_dict.get("chain", []) + csr = provider_certificate_dict.get("certificate_signing_request") + recommended_expiry_notification_time = provider_certificate_dict.get( + "recommended_expiry_notification_time" + ) + expiry_time = certificate_object.not_valid_after_utc + validity_start_time = certificate_object.not_valid_before_utc + expiry_notification_time = calculate_expiry_notification_time( + validity_start_time=validity_start_time, + expiry_time=expiry_time, + provider_recommended_notification_time=recommended_expiry_notification_time, + requirer_recommended_notification_time=self.expiry_notification_time, + ) + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + revoked = provider_certificate_dict.get("revoked", False) + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=csr, + certificate=certificate, + ca=ca, + chain=chain, + revoked=revoked, + expiry_time=expiry_time, + expiry_notification_time=expiry_notification_time, + ) + provider_certificates.append(provider_certificate) + return provider_certificates - def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: + def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: """Add CSR to relation data. Args: @@ -1564,18 +1693,23 @@ def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict: Dict[str, Union[bool, str]] = { + for requirer_csr in self.get_requirer_csrs(): + if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: + logger.info("CSR already in relation data - Doing nothing") + return + new_csr_dict = { "certificate_signing_request": csr, "ca": is_ca, } - if new_csr_dict in self._requirer_csrs: - logger.info("CSR already in relation data - Doing nothing") - return - requirer_csrs = copy.deepcopy(self._requirer_csrs) - requirer_csrs.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + new_relation_data.append(new_csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) - def _remove_requirer_csr(self, csr: str) -> None: + def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: """Remove CSR from relation data. Args: @@ -1590,14 +1724,18 @@ def _remove_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - requirer_csrs = copy.deepcopy(self._requirer_csrs) - if not requirer_csrs: + if not self.get_requirer_csrs(): logger.info("No CSRs in relation data - Doing nothing") return - for requirer_csr in requirer_csrs: + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + for requirer_csr in new_relation_data: if requirer_csr["certificate_signing_request"] == csr: - requirer_csrs.remove(requirer_csr) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + new_relation_data.remove(requirer_csr) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) def request_certificate_creation( self, certificate_signing_request: bytes, is_ca: bool = False @@ -1617,7 +1755,9 @@ def request_certificate_creation( f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) + self._add_requirer_csr_to_relation_data( + certificate_signing_request.decode().strip(), is_ca=is_ca + ) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: @@ -1633,7 +1773,7 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> Returns: None """ - self._remove_requirer_csr(certificate_signing_request.decode().strip()) + self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) logger.info("Certificate revocation sent to provider") def request_certificate_renewal( @@ -1661,107 +1801,58 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") - def get_assigned_certificates(self) -> List[Dict[str, str]]: + def get_assigned_certificates(self) -> List[ProviderCertificate]: """Get a list of certificates that were assigned to this unit. Returns: - List of certificates. For example: - [ - { - "ca": "-----BEGIN CERTIFICATE-----...", - "chain": [ - "-----BEGIN CERTIFICATE-----..." - ], - "certificate": "-----BEGIN CERTIFICATE-----...", - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - } - ] + List: List[ProviderCertificate] """ - final_list = [] - for csr in self.get_certificate_signing_requests(fulfilled_only=True): - assert isinstance(csr["certificate_signing_request"], str) - if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): - final_list.append(cert) - return final_list - - def get_expiring_certificates(self) -> List[Dict[str, str]]: + assigned_certificates = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + assigned_certificates.append(cert) + return assigned_certificates + + def get_expiring_certificates(self) -> List[ProviderCertificate]: """Get a list of certificates that were assigned to this unit that are expiring or expired. Returns: - List of certificates. For example: - [ - { - "ca": "-----BEGIN CERTIFICATE-----...", - "chain": [ - "-----BEGIN CERTIFICATE-----..." - ], - "certificate": "-----BEGIN CERTIFICATE-----...", - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - } - ] + List: List[ProviderCertificate] """ - final_list = [] - for csr in self.get_certificate_signing_requests(fulfilled_only=True): - assert isinstance(csr["certificate_signing_request"], str) - if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): - expiry_time = _get_certificate_expiry_time(cert["certificate"]) - if not expiry_time: + expiring_certificates: List[ProviderCertificate] = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + if not cert.expiry_time or not cert.expiry_notification_time: continue - expiry_notification_time = expiry_time - timedelta( - hours=self.expiry_notification_time - ) - if datetime.now(timezone.utc) > expiry_notification_time: - final_list.append(cert) - return final_list + if datetime.now(timezone.utc) > cert.expiry_notification_time: + expiring_certificates.append(cert) + return expiring_certificates def get_certificate_signing_requests( self, fulfilled_only: bool = False, unfulfilled_only: bool = False, - ) -> List[Dict[str, Union[bool, str]]]: + ) -> List[RequirerCSR]: """Get the list of CSR's that were sent to the provider. You can choose to get only the CSR's that have a certificate assigned or only the CSR's - that don't. + that don't. Args: fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. unfulfilled_only (bool): This option will discard CSRs that have certificates signed. Returns: - List of CSR dictionaries. For example: - [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "ca": false - } - ] + List of RequirerCSR objects. """ - final_list = [] - for csr in self._requirer_csrs: - assert isinstance(csr["certificate_signing_request"], str) - cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) + csrs = [] + for requirer_csr in self.get_requirer_csrs(): + cert = self._find_certificate_in_relation_data(requirer_csr.csr) if (unfulfilled_only and cert) or (fulfilled_only and not cert): continue - final_list.append(csr) - - return final_list - - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Check whether relation data is valid based on json schema. + csrs.append(requirer_csr) - Args: - certificates_data: Certificate data in dict format. - - Returns: - bool: Whether relation data is valid. - """ - try: - validate(instance=certificates_data, schema=PROVIDER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False + return csrs def _on_relation_changed(self, event: RelationChangedEvent) -> None: """Handle relation changed event. @@ -1771,9 +1862,8 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: If the provider certificate is revoked, emit a CertificateInvalidateEvent, otherwise emit a CertificateAvailableEvent. - When Juju secrets are available, remove the secret for revoked certificate, - or add a secret with the correct expiry time for new certificates. - + Remove the secret for revoked certificate, or add a secret with the correct expiry + time for new certificates. Args: event: Juju event @@ -1781,54 +1871,74 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ + if not event.app: + logger.warning("No remote app in relation - Skipping") + return + if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): + logger.debug("Relation data did not pass JSON Schema validation") + return + provider_certificates = self.get_provider_certificates() requirer_csrs = [ - certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in self._requirer_csrs + certificate_creation_request.csr + for certificate_creation_request in self.get_requirer_csrs() ] - for certificate in self._provider_certificates: - if certificate["certificate_signing_request"] in requirer_csrs: - if certificate.get("revoked", False): - if JujuVersion.from_environ().has_secrets: - with suppress(SecretNotFoundError): - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.remove_all_revisions() + for certificate in provider_certificates: + if certificate.csr in requirer_csrs: + csr_in_sha256_hex = get_sha256_hex(certificate.csr) + if certificate.revoked: + with suppress(SecretNotFoundError): + logger.debug( + "Removing secret with label %s", + f"{LIBID}-{csr_in_sha256_hex}", + ) + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") + secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) else: - if JujuVersion.from_environ().has_secrets: - try: - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.set_content({"certificate": certificate["certificate"]}) - secret.set_info( - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) - except SecretNotFoundError: - secret = self.charm.unit.add_secret( - {"certificate": certificate["certificate"]}, - label=f"{LIBID}-{certificate['certificate_signing_request']}", - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), + try: + secret = self.model.get_secret(label=f"{LIBID}-{csr_in_sha256_hex}") + logger.debug( + "Setting secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + # Juju < 3.6 will create a new revision even if the content is the same + if ( + secret.get_content(refresh=True).get("certificate", "") + == certificate.certificate + ): + logger.debug( + "Secret %s with correct certificate already exists", + f"{LIBID}-{csr_in_sha256_hex}", ) + continue + secret.set_content( + {"certificate": certificate.certificate, "csr": certificate.csr} + ) + secret.set_info( + expire=self._get_next_secret_expiry_time(certificate), + ) + except SecretNotFoundError: + logger.debug( + "Creating new secret with label %s", f"{LIBID}-{csr_in_sha256_hex}" + ) + secret = self.charm.unit.add_secret( + {"certificate": certificate.certificate, "csr": certificate.csr}, + label=f"{LIBID}-{csr_in_sha256_hex}", + expire=self._get_next_secret_expiry_time(certificate), + ) self.on.certificate_available.emit( - certificate_signing_request=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate_signing_request=certificate.csr, + certificate=certificate.certificate, + ca=certificate.ca, + chain=certificate.chain, ) - def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: + def _get_next_secret_expiry_time(self, certificate: ProviderCertificate) -> Optional[datetime]: """Return the expiry time or expiry notification time. Extracts the expiry time from the provided certificate, calculates the @@ -1836,20 +1946,21 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: the future. Args: - certificate: x509 certificate + certificate: ProviderCertificate object Returns: Optional[datetime]: None if the certificate expiry time cannot be read, next expiry time otherwise. """ - expiry_time = _get_certificate_expiry_time(certificate) - if not expiry_time: + if not certificate.expiry_time or not certificate.expiry_notification_time: return None - expiry_notification_time = expiry_time - timedelta(hours=self.expiry_notification_time) - return _get_closest_future_time(expiry_notification_time, expiry_time) + return _get_closest_future_time( + certificate.expiry_notification_time, + certificate.expiry_time, + ) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handle relation broken event. + """Handle Relation Broken Event. Emitting `all_certificates_invalidated` from `relation-broken` rather than `relation-departed` since certs are stored in app data. @@ -1863,7 +1974,7 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: self.on.all_certificates_invalidated.emit() def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Handle secret expired event. + """Handle Secret Expired Event. Loads the certificate from the secret, and will emit 1 of 2 events. @@ -1878,82 +1989,73 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: Args: event (SecretExpiredEvent): Juju event """ - if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): + csr = self._get_csr_from_secret(event.secret) + if not csr: + logger.error("Failed to get CSR from secret %s", event.secret.label) return - csr = event.secret.label[len(f"{LIBID}-") :] - certificate_dict = self._find_certificate_in_relation_data(csr) - if not certificate_dict: + provider_certificate = self._find_certificate_in_relation_data(csr) + if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up + logger.warning( + "Failed to find matching certificate for csr, cleaning up secret %s", + event.secret.label, + ) event.secret.remove_all_revisions() return - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) - if not expiry_time: + if not provider_certificate.expiry_time: # A secret expired but matching certificate is invalid. Cleaning up + logger.warning( + "Certificate matching csr is invalid, cleaning up secret %s", + event.secret.label, + ) event.secret.remove_all_revisions() return - if datetime.now(timezone.utc) < expiry_time: + if datetime.now(timezone.utc) < provider_certificate.expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], - expiry=expiry_time.isoformat(), + certificate=provider_certificate.certificate, + expiry=provider_certificate.expiry_time.isoformat(), ) event.secret.set_info( - expire=_get_certificate_expiry_time(certificate_dict["certificate"]), + expire=provider_certificate.expiry_time, ) else: logger.warning("Certificate is expired") self.on.certificate_invalidated.emit( reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) + self.request_certificate_revocation(provider_certificate.certificate.encode()) event.secret.remove_all_revisions() - def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: + def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: """Return the certificate that match the given CSR.""" - for certificate_dict in self._provider_certificates: - if certificate_dict["certificate_signing_request"] != csr: + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.csr != csr: continue - return certificate_dict + return provider_certificate return None - def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Handle update status event. + def _get_csr_from_secret(self, secret: Secret) -> Union[str, None]: + """Extract the CSR from the secret label or content. - Goes through each certificate in the "certificates" relation and checks their expiry date. - If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if - they are expired, emits a CertificateExpiredEvent. - - Args: - event (UpdateStatusEvent): Juju event - - Returns: - None + This function is a workaround to maintain backwards compatibility + and fix the issue reported in + https://github.com/canonical/tls-certificates-interface/issues/228 """ - for certificate_dict in self._provider_certificates: - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) - if not expiry_time: - continue - time_difference = expiry_time - datetime.now(timezone.utc) - if time_difference.total_seconds() < 0: - logger.warning("Certificate is expired") - self.on.certificate_invalidated.emit( - reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], - ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) - continue - if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): - logger.warning("Certificate almost expired") - self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], - expiry=expiry_time.isoformat(), - ) + try: + content = secret.get_content(refresh=True) + except SecretNotFoundError: + return None + if not (csr := content.get("csr", None)): + # In versions <14 of the Lib we were storing the CSR in the label of the secret + # The CSR now is stored int the content of the secret, which was a breaking change + # Here we get the CSR if the secret was created by an app using libpatch 14 or lower + if secret.label and secret.label.startswith(f"{LIBID}-"): + csr = secret.label[len(f"{LIBID}-") :] + return csr From fa08e00c319070a27332483c2503f05e90c11dc6 Mon Sep 17 00:00:00 2001 From: MiaAltieri Date: Fri, 20 Dec 2024 21:35:16 +0000 Subject: [PATCH 4/6] fix CI workflow --- .github/workflows/ci.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 403fbeea9..37148de06 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -51,9 +51,9 @@ jobs: - name: lint test charm module working-directory: ./terraform/tests run: | - terraform init - terraform fmt - terraform validate + terraform init + terraform fmt + terraform validate - name: run checks - prepare run: | sudo snap install juju --channel=3.6/beta --classic From 0ad048d79591e8118dd8589898348f849cc39012 Mon Sep 17 00:00:00 2001 From: MiaAltieri Date: Fri, 20 Dec 2024 21:57:39 +0000 Subject: [PATCH 5/6] fix TF --- .github/workflows/ci.yaml | 13 +++++++------ terraform/main.tf | 2 +- terraform/tests/simple_deployment.tf | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 37148de06..58cba7f50 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,17 +43,17 @@ jobs: fetch-depth: 0 - name: lint charm module - working-directory: ./terraform run: | + pushd ./terraform terraform init terraform fmt terraform validate - - name: lint test charm module - working-directory: ./terraform/tests - run: | + pushd ./tests terraform init terraform fmt terraform validate + popd + popd - name: run checks - prepare run: | sudo snap install juju --channel=3.6/beta --classic @@ -75,9 +75,10 @@ jobs: juju model-defaults logging-config='=INFO; unit=DEBUG' juju add-model test - name: Terraform deploy - working-directory: ./terraform/tests/ run: | - terraform apply -var "model=test" -target null_resource.simple_deployment_juju_wait_deployment -auto-approve + pushd ./terraform/tests/ + terraform apply -var "model_name=test" -target null_resource.simple_deployment_juju_wait_deployment -auto-approve + popd lib-check: name: Check libraries diff --git a/terraform/main.tf b/terraform/main.tf index bf330f69c..c8c6b5fdf 100644 --- a/terraform/main.tf +++ b/terraform/main.tf @@ -7,7 +7,7 @@ resource "juju_application" "mongodb" { name = "mongodb" channel = var.channel revision = var.revision - base = "ubuntu@22.04" + base = "ubuntu@22.04" } config = var.config model = var.model diff --git a/terraform/tests/simple_deployment.tf b/terraform/tests/simple_deployment.tf index 74e3ba041..a3e30d584 100644 --- a/terraform/tests/simple_deployment.tf +++ b/terraform/tests/simple_deployment.tf @@ -3,7 +3,7 @@ module "mongodb" { app_name = var.app_name model = var.model_name units = var.simple_mongodb_units - channel = "6/edge" + channel = "6/edge" } resource "juju_integration" "simple_deployment_tls-operator_mongodb-integration" { From ea3ccd1c95b623a2555667eb8e286ba8dc711a81 Mon Sep 17 00:00:00 2001 From: MiaAltieri Date: Fri, 20 Dec 2024 22:49:26 +0000 Subject: [PATCH 6/6] clean up CI workflows to nehas suggestions --- .github/workflows/ci.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 58cba7f50..9073ce1c3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,17 +43,17 @@ jobs: fetch-depth: 0 - name: lint charm module + working-directory: ./terraform run: | - pushd ./terraform terraform init terraform fmt terraform validate - pushd ./tests + - name: lint test charm module + working-directory: ./terraform/tests + run: | terraform init terraform fmt terraform validate - popd - popd - name: run checks - prepare run: | sudo snap install juju --channel=3.6/beta --classic @@ -75,10 +75,9 @@ jobs: juju model-defaults logging-config='=INFO; unit=DEBUG' juju add-model test - name: Terraform deploy + working-directory: ./terraform/tests/ run: | - pushd ./terraform/tests/ terraform apply -var "model_name=test" -target null_resource.simple_deployment_juju_wait_deployment -auto-approve - popd lib-check: name: Check libraries