Skip to content

Commit

Permalink
Add option to filter certificates by tag before adding it to LB (#658)
Browse files Browse the repository at this point in the history
* Add option to filter certificates by tag before adding it to LB

Signed-off-by: Lucas Thiesen <[email protected]>
  • Loading branch information
lucastt authored Nov 6, 2023
1 parent 04144f2 commit f2f28dc
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 30 deletions.
39 changes: 34 additions & 5 deletions aws/acm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws

import (
"crypto/x509"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/acm"
Expand All @@ -10,16 +11,17 @@ import (
)

type acmCertificateProvider struct {
api acmiface.ACMAPI
api acmiface.ACMAPI
filterTag string
}

func newACMCertProvider(api acmiface.ACMAPI) certs.CertificatesProvider {
return &acmCertificateProvider{api: api}
func newACMCertProvider(api acmiface.ACMAPI, certFilterTag string) certs.CertificatesProvider {
return &acmCertificateProvider{api: api, filterTag: certFilterTag}
}

// GetCertificates returns a list of AWS ACM certificates
func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary, error) {
acmSummaries, err := getACMCertificateSummaries(p.api)
acmSummaries, err := getACMCertificateSummaries(p.api, p.filterTag)
if err != nil {
return nil, err
}
Expand All @@ -34,20 +36,47 @@ func (p *acmCertificateProvider) GetCertificates() ([]*certs.CertificateSummary,
return result, nil
}

func getACMCertificateSummaries(api acmiface.ACMAPI) ([]*acm.CertificateSummary, error) {
func getACMCertificateSummaries(api acmiface.ACMAPI, filterTag string) ([]*acm.CertificateSummary, error) {
params := &acm.ListCertificatesInput{
CertificateStatuses: []*string{
aws.String(acm.CertificateStatusIssued),
},
}
acmSummaries := make([]*acm.CertificateSummary, 0)

err := api.ListCertificatesPages(params, func(page *acm.ListCertificatesOutput, lastPage bool) bool {
acmSummaries = append(acmSummaries, page.CertificateSummaryList...)
return true
})

if tag := strings.Split(filterTag, "="); filterTag != "=" && len(tag) == 2 {
return filterCertificatesByTag(api, acmSummaries, tag[0], tag[1])
}

return acmSummaries, err
}

func filterCertificatesByTag(api acmiface.ACMAPI, allSummaries []*acm.CertificateSummary, key, value string) ([]*acm.CertificateSummary, error) {
prodSummaries := make([]*acm.CertificateSummary, 0)
for _, summary := range allSummaries {
in := &acm.ListTagsForCertificateInput{
CertificateArn: summary.CertificateArn,
}
out, err := api.ListTagsForCertificate(in)
if err != nil {
return nil, err
}

for _, tag := range out.Tags {
if *tag.Key == key && *tag.Value == value {
prodSummaries = append(prodSummaries, summary)
}
}
}

return prodSummaries, nil
}

func getCertificateSummaryFromACM(api acmiface.ACMAPI, arn *string) (*certs.CertificateSummary, error) {
params := &acm.GetCertificateInput{CertificateArn: arn}
resp, err := api.GetCertificate(params)
Expand Down
106 changes: 93 additions & 13 deletions aws/acm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@ type acmExpect struct {
DomainNames []string
Chain int
Error error
EmptyList bool
}

func TestACM(t *testing.T) {
cert := mustRead("acm.txt")
chain := mustRead("chain.txt")

for _, ti := range []struct {
msg string
api acmiface.ACMAPI
expect acmExpect
msg string
api acmiface.ACMAPI
filterTag string
expect acmExpect
}{
{
msg: "Found ACM Cert foobar and a chain",
Expand All @@ -37,9 +39,11 @@ func TestACM(t *testing.T) {
},
},
},
acm.GetCertificateOutput{
Certificate: aws.String(cert),
CertificateChain: aws.String(chain),
map[string]*acm.GetCertificateOutput{
"foobar": {
Certificate: aws.String(cert),
CertificateChain: aws.String(chain),
},
},
),
expect: acmExpect{
Expand All @@ -59,19 +63,90 @@ func TestACM(t *testing.T) {
},
},
},
acm.GetCertificateOutput{
Certificate: aws.String(cert),
map[string]*acm.GetCertificateOutput{
"foobar": {
Certificate: aws.String(cert),
},
},
),
expect: acmExpect{
ARN: "foobar",
DomainNames: []string{"foobar.de"},
Error: nil,
},
},
{
msg: "Found one ACM Cert with correct filter tag",
api: fake.NewACMClientWithTags(
acm.ListCertificatesOutput{
CertificateSummaryList: []*acm.CertificateSummary{
{
CertificateArn: aws.String("foobar"),
DomainName: aws.String("foobar.de"),
},
{
CertificateArn: aws.String("foobaz"),
DomainName: aws.String("foobar.de"),
},
},
},
map[string]*acm.GetCertificateOutput{
"foobar": {
Certificate: aws.String(cert),
},
"foobaz": {
Certificate: aws.String(cert),
},
},
map[string]*acm.ListTagsForCertificateOutput{
"foobar": {
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("true")}},
},
"foobaz": {
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("false")}},
},
},
),
filterTag: "production=true",
expect: acmExpect{
ARN: "foobar",
DomainNames: []string{"foobar.de"},
Error: nil,
},
},
{
msg: "ACM Cert with incorrect filter tag should not be found",
api: fake.NewACMClientWithTags(
acm.ListCertificatesOutput{
CertificateSummaryList: []*acm.CertificateSummary{
{
CertificateArn: aws.String("foobar"),
DomainName: aws.String("foobar.de"),
},
},
},
map[string]*acm.GetCertificateOutput{
"foobar": {
Certificate: aws.String(cert),
},
},
map[string]*acm.ListTagsForCertificateOutput{
"foobar": {
Tags: []*acm.Tag{{Key: aws.String("production"), Value: aws.String("false")}},
},
},
),
filterTag: "production=true",
expect: acmExpect{
EmptyList: true,
ARN: "foobar",
DomainNames: []string{"foobar.de"},
Error: nil,
},
},
} {
t.Run(ti.msg, func(t *testing.T) {
provider := newACMCertProvider(ti.api)
provider := newACMCertProvider(ti.api, ti.filterTag)
list, err := provider.GetCertificates()

if ti.expect.Error != nil {
Expand All @@ -80,11 +155,16 @@ func TestACM(t *testing.T) {
require.NoError(t, err)
}

require.Equal(t, 1, len(list))
if ti.expect.EmptyList {
require.Equal(t, 0, len(list))

cert := list[0]
require.Equal(t, ti.expect.ARN, cert.ID())
require.Equal(t, ti.expect.DomainNames, cert.DomainNames())
} else {
require.Equal(t, 1, len(list))

cert := list[0]
require.Equal(t, ti.expect.ARN, cert.ID())
require.Equal(t, ti.expect.DomainNames, cert.DomainNames())
}
})
}
}
8 changes: 4 additions & 4 deletions aws/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,12 @@ func (a *Adapter) UpdateManifest(clusterID, vpcID string) (*Adapter, error) {
return a, err
}

func (a *Adapter) NewACMCertificateProvider() certs.CertificatesProvider {
return newACMCertProvider(a.acm)
func (a *Adapter) NewACMCertificateProvider(certFilterTag string) certs.CertificatesProvider {
return newACMCertProvider(a.acm, certFilterTag)
}

func (a *Adapter) NewIAMCertificateProvider() certs.CertificatesProvider {
return newIAMCertProvider(a.iam)
func (a *Adapter) NewIAMCertificateProvider(certFilterTag string) certs.CertificatesProvider {
return newIAMCertProvider(a.iam, certFilterTag)
}

// WithHealthCheckPath returns the receiver adapter after changing the health check path that will be used by
Expand Down
29 changes: 26 additions & 3 deletions aws/fake/acm.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
package fake

import (
"fmt"

"github.com/aws/aws-sdk-go/service/acm"
"github.com/aws/aws-sdk-go/service/acm/acmiface"
)

type ACMClient struct {
acmiface.ACMAPI
output acm.ListCertificatesOutput
cert acm.GetCertificateOutput
cert map[string]*acm.GetCertificateOutput
tags map[string]*acm.ListTagsForCertificateOutput
}

func (m ACMClient) ListCertificates(in *acm.ListCertificatesInput) (*acm.ListCertificatesOutput, error) {
Expand All @@ -21,12 +24,32 @@ func (m ACMClient) ListCertificatesPages(input *acm.ListCertificatesInput, fn fu
}

func (m ACMClient) GetCertificate(input *acm.GetCertificateInput) (*acm.GetCertificateOutput, error) {
return &m.cert, nil
return m.cert[*input.CertificateArn], nil
}

func (m ACMClient) ListTagsForCertificate(in *acm.ListTagsForCertificateInput) (*acm.ListTagsForCertificateOutput, error) {
if in.CertificateArn == nil {
return nil, fmt.Errorf("expected a valid CertificateArn, got: nil")
}
arn := *in.CertificateArn
return m.tags[arn], nil
}

func NewACMClient(output acm.ListCertificatesOutput, cert map[string]*acm.GetCertificateOutput) ACMClient {
return ACMClient{
output: output,
cert: cert,
}
}

func NewACMClient(output acm.ListCertificatesOutput, cert acm.GetCertificateOutput) ACMClient {
func NewACMClientWithTags(
output acm.ListCertificatesOutput,
cert map[string]*acm.GetCertificateOutput,
tags map[string]*acm.ListTagsForCertificateOutput,
) ACMClient {
return ACMClient{
output: output,
cert: cert,
tags: tags,
}
}
26 changes: 26 additions & 0 deletions aws/fake/iam.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package fake

import (
"fmt"

"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
)
Expand All @@ -9,6 +11,7 @@ type IAMClient struct {
iamiface.IAMAPI
list iam.ListServerCertificatesOutput
cert iam.GetServerCertificateOutput
tags map[string]*iam.ListServerCertificateTagsOutput
}

func (m IAMClient) ListServerCertificates(*iam.ListServerCertificatesInput) (*iam.ListServerCertificatesOutput, error) {
Expand All @@ -20,6 +23,17 @@ func (m IAMClient) ListServerCertificatesPages(input *iam.ListServerCertificates
return nil
}

func (m IAMClient) ListServerCertificateTags(
in *iam.ListServerCertificateTagsInput,
) (*iam.ListServerCertificateTagsOutput, error) {

if in.ServerCertificateName == nil {
return nil, fmt.Errorf("expected a valid CertificateArn, got: nil")
}
name := *in.ServerCertificateName
return m.tags[name], nil
}

func (m IAMClient) GetServerCertificate(*iam.GetServerCertificateInput) (*iam.GetServerCertificateOutput, error) {
return &m.cert, nil
}
Expand All @@ -30,3 +44,15 @@ func NewIAMClient(list iam.ListServerCertificatesOutput, cert iam.GetServerCerti
cert: cert,
}
}

func NewIAMClientWithTag(
list iam.ListServerCertificatesOutput,
cert iam.GetServerCertificateOutput,
tags map[string]*iam.ListServerCertificateTagsOutput,
) IAMClient {
return IAMClient{
list: list,
cert: cert,
tags: tags,
}
}
Loading

0 comments on commit f2f28dc

Please sign in to comment.