From 13875acedd8a8f32de964926b208b6eb644dc770 Mon Sep 17 00:00:00 2001 From: Lucian Petrut Date: Wed, 6 Nov 2024 13:55:12 +0200 Subject: [PATCH] Fix certificate refresh and add e2e tests (#766) --- src/k8s/pkg/k8sd/api/certificates_refresh.go | 54 ++++++++++-- .../k8sd/controllers/csrsigning/reconcile.go | 22 +++-- .../k8sd/controllers/csrsigning/validate.go | 8 +- .../controllers/csrsigning/validate_test.go | 5 +- tests/integration/requirements-test.txt | 1 + .../templates/bootstrap-csr-auto-approve.yaml | 9 ++ tests/integration/tests/test_clustering.py | 88 +++++++++++++++++++ tests/integration/tox.ini | 2 +- 8 files changed, 172 insertions(+), 17 deletions(-) create mode 100644 tests/integration/templates/bootstrap-csr-auto-approve.yaml diff --git a/src/k8s/pkg/k8sd/api/certificates_refresh.go b/src/k8s/pkg/k8sd/api/certificates_refresh.go index beb519752..c94f33433 100644 --- a/src/k8s/pkg/k8sd/api/certificates_refresh.go +++ b/src/k8s/pkg/k8sd/api/certificates_refresh.go @@ -2,10 +2,14 @@ package api import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" "crypto/x509/pkix" + "encoding/base64" "fmt" "math" - "math/rand" + "math/big" "net" "net/http" "path/filepath" @@ -29,7 +33,11 @@ import ( ) func (e *Endpoints) postRefreshCertsPlan(s state.State, r *http.Request) response.Response { - seed := rand.Intn(math.MaxInt) + seedBigInt, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt)) + if err != nil { + return response.InternalError(fmt.Errorf("failed to generate seed: %w", err)) + } + seed := int(seedBigInt.Int64()) snap := e.provider.Snap() isWorker, err := snaputil.IsWorker(snap) @@ -216,6 +224,18 @@ func refreshCertsRunWorker(s state.State, r *http.Request, snap snap.Snap) respo certificates.CACert = clusterConfig.Certificates.GetCACert() certificates.ClientCACert = clusterConfig.Certificates.GetClientCACert() + k8sdPublicKey, err := pkiutil.LoadRSAPublicKey(clusterConfig.Certificates.GetK8sdPublicKey()) + if err != nil { + return response.InternalError(fmt.Errorf("failed to load k8sd public key, error: %w", err)) + } + + hostnames := []string{snap.Hostname()} + ips := []net.IP{net.ParseIP(s.Address().Hostname())} + + extraIPs, extraNames := utils.SplitIPAndDNSSANs(req.ExtraSANs) + hostnames = append(hostnames, extraNames...) + ips = append(ips, extraIPs...) + g, ctx := errgroup.WithContext(r.Context()) for _, csr := range []struct { @@ -234,8 +254,8 @@ func refreshCertsRunWorker(s state.State, r *http.Request, snap snap.Snap) respo commonName: fmt.Sprintf("system:node:%s", snap.Hostname()), organization: []string{"system:nodes"}, usages: []certv1.KeyUsage{certv1.UsageDigitalSignature, certv1.UsageKeyEncipherment, certv1.UsageServerAuth}, - hostnames: []string{snap.Hostname()}, - ips: []net.IP{net.ParseIP(s.Address().Hostname())}, + hostnames: hostnames, + ips: ips, signerName: "k8sd.io/kubelet-serving", certificate: &certificates.KubeletCert, key: &certificates.KubeletKey, @@ -272,14 +292,34 @@ func refreshCertsRunWorker(s state.State, r *http.Request, snap snap.Snap) respo return fmt.Errorf("failed to generate CSR for %s: %w", csr.name, err) } + // Obtain the SHA256 sum of the CSR request. + hash := sha256.New() + _, err = hash.Write([]byte(csrPEM)) + if err != nil { + return fmt.Errorf("failed to checksum CSR %s, err: %w", csr.name, err) + } + + signature, err := rsa.EncryptPKCS1v15(rand.Reader, k8sdPublicKey, hash.Sum(nil)) + if err != nil { + return fmt.Errorf("failed to sign CSR %s, err: %w", csr.name, err) + } + signatureB64 := base64.StdEncoding.EncodeToString(signature) + + expirationSeconds := int32(req.ExpirationSeconds) + if _, err = client.CertificatesV1().CertificateSigningRequests().Create(ctx, &certv1.CertificateSigningRequest{ ObjectMeta: metav1.ObjectMeta{ Name: csr.name, + Annotations: map[string]string{ + "k8sd.io/signature": signatureB64, + "k8sd.io/node": snap.Hostname(), + }, }, Spec: certv1.CertificateSigningRequestSpec{ - Request: []byte(csrPEM), - Usages: csr.usages, - SignerName: csr.signerName, + Request: []byte(csrPEM), + ExpirationSeconds: &expirationSeconds, + Usages: csr.usages, + SignerName: csr.signerName, }, }, metav1.CreateOptions{}); err != nil { return fmt.Errorf("failed to create CSR for %s: %w", csr.name, err) diff --git a/src/k8s/pkg/k8sd/controllers/csrsigning/reconcile.go b/src/k8s/pkg/k8sd/controllers/csrsigning/reconcile.go index 95ccbbb95..779f10777 100644 --- a/src/k8s/pkg/k8sd/controllers/csrsigning/reconcile.go +++ b/src/k8s/pkg/k8sd/controllers/csrsigning/reconcile.go @@ -9,6 +9,7 @@ import ( "fmt" "time" + "github.com/canonical/k8s/pkg/utils" pkiutil "github.com/canonical/k8s/pkg/utils/pki" certv1 "k8s.io/api/certificates/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -96,6 +97,15 @@ func (r *csrSigningReconciler) Reconcile(ctx context.Context, req ctrl.Request) return ctrl.Result{}, err } + notBefore := time.Now() + var notAfter time.Time + + if obj.Spec.ExpirationSeconds != nil { + notAfter = utils.SecondsToExpirationDate(notBefore, int(*obj.Spec.ExpirationSeconds)) + } else { + notAfter = time.Now().AddDate(10, 0, 0) + } + var crtPEM []byte switch obj.Spec.SignerName { case "k8sd.io/kubelet-serving": @@ -114,8 +124,8 @@ func (r *csrSigningReconciler) Reconcile(ctx context.Context, req ctrl.Request) CommonName: obj.Spec.Username, Organization: obj.Spec.Groups, }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), // TODO: expiration date from obj, or config + NotBefore: notBefore, + NotAfter: notAfter, IPAddresses: certRequest.IPAddresses, DNSNames: certRequest.DNSNames, BasicConstraintsValid: true, @@ -149,8 +159,8 @@ func (r *csrSigningReconciler) Reconcile(ctx context.Context, req ctrl.Request) CommonName: obj.Spec.Username, Organization: obj.Spec.Groups, }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), // TODO: expiration date from obj, or config + NotBefore: notBefore, + NotAfter: notAfter, BasicConstraintsValid: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, @@ -181,8 +191,8 @@ func (r *csrSigningReconciler) Reconcile(ctx context.Context, req ctrl.Request) Subject: pkix.Name{ CommonName: "system:kube-proxy", }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), // TODO: expiration date from obj, or config + NotBefore: notBefore, + NotAfter: notAfter, BasicConstraintsValid: true, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, diff --git a/src/k8s/pkg/k8sd/controllers/csrsigning/validate.go b/src/k8s/pkg/k8sd/controllers/csrsigning/validate.go index d3d9df0a9..4e0c9b414 100644 --- a/src/k8s/pkg/k8sd/controllers/csrsigning/validate.go +++ b/src/k8s/pkg/k8sd/controllers/csrsigning/validate.go @@ -4,6 +4,7 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/subtle" + "encoding/base64" "fmt" "github.com/canonical/k8s/pkg/utils" @@ -21,7 +22,12 @@ func validateCSR(obj *certv1.CertificateSigningRequest, priv *rsa.PrivateKey) er return fmt.Errorf("failed to parse x509 certificate request: %w", err) } - encryptedSignature := obj.Annotations["k8sd.io/signature"] + encryptedSignatureB64 := obj.Annotations["k8sd.io/signature"] + encryptedSignature, err := base64.StdEncoding.DecodeString(encryptedSignatureB64) + if err != nil { + return fmt.Errorf("failed to decode b64 signature: %w", err) + } + signature, err := rsa.DecryptPKCS1v15(nil, priv, []byte(encryptedSignature)) if err != nil { return fmt.Errorf("failed to decrypt signature: %w", err) diff --git a/src/k8s/pkg/k8sd/controllers/csrsigning/validate_test.go b/src/k8s/pkg/k8sd/controllers/csrsigning/validate_test.go index 7c806a484..7e9a919a8 100644 --- a/src/k8s/pkg/k8sd/controllers/csrsigning/validate_test.go +++ b/src/k8s/pkg/k8sd/controllers/csrsigning/validate_test.go @@ -5,6 +5,7 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509/pkix" + "encoding/base64" "testing" pkiutil "github.com/canonical/k8s/pkg/utils/pki" @@ -93,7 +94,7 @@ func TestValidateCSREncryption(t *testing.T) { }, }, expectErr: true, - expectErrMessage: "failed to decrypt signature", + expectErrMessage: "failed to decode b64 signature", }, { name: "Missing Signature", @@ -219,5 +220,5 @@ func mustCreateEncryptedSignature(g Gomega, pub *rsa.PublicKey, csrPEM string) s signature, err := rsa.EncryptPKCS1v15(rand.Reader, pub, hash.Sum(nil)) g.Expect(err).NotTo(HaveOccurred()) - return string(signature) + return base64.StdEncoding.EncodeToString(signature) } diff --git a/tests/integration/requirements-test.txt b/tests/integration/requirements-test.txt index 91282e09c..0fcd9c093 100644 --- a/tests/integration/requirements-test.txt +++ b/tests/integration/requirements-test.txt @@ -3,3 +3,4 @@ pytest==7.3.1 PyYAML==6.0.1 tenacity==8.2.3 pylint==3.2.5 +cryptography==43.0.3 diff --git a/tests/integration/templates/bootstrap-csr-auto-approve.yaml b/tests/integration/templates/bootstrap-csr-auto-approve.yaml new file mode 100644 index 000000000..43fe77c98 --- /dev/null +++ b/tests/integration/templates/bootstrap-csr-auto-approve.yaml @@ -0,0 +1,9 @@ +cluster-config: + network: + enabled: true + dns: + enabled: true + metrics-server: + enabled: true + annotations: + k8sd/v1alpha1/csrsigning/auto-approve: true diff --git a/tests/integration/tests/test_clustering.py b/tests/integration/tests/test_clustering.py index a77e3f9c5..8235e56cf 100644 --- a/tests/integration/tests/test_clustering.py +++ b/tests/integration/tests/test_clustering.py @@ -1,10 +1,16 @@ # # Copyright 2024 Canonical, Ltd. # +import datetime import logging +import os +import subprocess +import tempfile from typing import List import pytest +from cryptography import x509 +from cryptography.hazmat.backends import default_backend from test_util import config, harness, util LOG = logging.getLogger(__name__) @@ -228,3 +234,85 @@ def test_join_with_custom_token_name(instances: List[harness.Instance]): cluster_node.exec(["k8s", "remove-node", joining_cp_with_hostname.id]) nodes = util.ready_nodes(cluster_node) assert len(nodes) == 1, "cp node with hostname should be removed from the cluster" + + +@pytest.mark.node_count(2) +@pytest.mark.bootstrap_config( + (config.MANIFESTS_DIR / "bootstrap-csr-auto-approve.yaml").read_text() +) +def test_cert_refresh(instances: List[harness.Instance]): + cluster_node = instances[0] + joining_worker = instances[1] + + join_token_worker = util.get_join_token(cluster_node, joining_worker, "--worker") + util.join_cluster(joining_worker, join_token_worker) + + util.wait_until_k8s_ready(cluster_node, instances) + nodes = util.ready_nodes(cluster_node) + assert len(nodes) == 2, "nodes should have joined cluster" + + assert "control-plane" in util.get_local_node_status(cluster_node) + assert "worker" in util.get_local_node_status(joining_worker) + + extra_san = "test_san.local" + + def _check_cert(instance, cert_fname): + # Ensure that the certificate was refreshed, having the right expiry date + # and extra SAN. + cert_dir = _get_k8s_cert_dir(instance) + cert_path = os.path.join(cert_dir, cert_fname) + + cert = _get_instance_cert(instance, cert_path) + date = datetime.datetime.now() + assert (cert.not_valid_after - date).days in (364, 365) + + san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName) + san_dns_names = san.value.get_values_for_type(x509.DNSName) + assert extra_san in san_dns_names + + joining_worker.exec( + ["k8s", "refresh-certs", "--expires-in", "1y", "--extra-sans", extra_san] + ) + + _check_cert(joining_worker, "kubelet.crt") + + cluster_node.exec( + ["k8s", "refresh-certs", "--expires-in", "1y", "--extra-sans", extra_san] + ) + + _check_cert(cluster_node, "kubelet.crt") + _check_cert(cluster_node, "apiserver.crt") + + # Ensure that the services come back online after refreshing the certificates. + util.wait_until_k8s_ready(cluster_node, instances) + + +def _get_k8s_cert_dir(instance: harness.Instance): + tested_paths = [ + "/etc/kubernetes/pki/", + "/var/snap/k8s/common/etc/kubernetes/pki/", + ] + for path in tested_paths: + if _instance_path_exists(instance, path): + return path + + raise Exception("Could not find k8s certificates dir.") + + +def _instance_path_exists(instance: harness.Instance, remote_path: str): + try: + instance.exec(["ls", remote_path]) + return True + except subprocess.CalledProcessError: + return False + + +def _get_instance_cert( + instance: harness.Instance, remote_path: str +) -> x509.Certificate: + with tempfile.NamedTemporaryFile() as fp: + instance.pull_file(remote_path, fp.name) + + pem = fp.read() + cert = x509.load_pem_x509_certificate(pem, default_backend()) + return cert diff --git a/tests/integration/tox.ini b/tests/integration/tox.ini index b59d696d7..1b33bcda9 100644 --- a/tests/integration/tox.ini +++ b/tests/integration/tox.ini @@ -46,6 +46,6 @@ passenv = [flake8] max-line-length = 120 select = E,W,F,C,N -ignore = W503 +ignore = W503,E231,E226 exclude = venv,.git,.tox,.tox_env,.venv,build,dist,*.egg_info show-source = true