Skip to content

Commit

Permalink
add check for SA if not found in cache but expected to be there
Browse files Browse the repository at this point in the history
Signed-off-by: Joshua Silverio <[email protected]>
  • Loading branch information
jsilverio22 committed Mar 6, 2023
1 parent c33767b commit dbe77b6
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 29 deletions.
24 changes: 14 additions & 10 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type CacheResponse struct {

type ServiceAccountCache interface {
Start(stop chan struct{})
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64)
Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, err error)
// ToJSON returns cache contents as JSON string
ToJSON() string
}
Expand Down Expand Up @@ -76,22 +76,26 @@ func init() {
// Get will return the cached configuration of the given ServiceAccount.
// It will first look at the set of ServiceAccounts configured using annotations. If none are found, it will look for any
// ServiceAccount configured through the pod-identity-webhook ConfigMap.
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
func (c *serviceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, err error) {
klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)
var respSA *CacheResponse
{
resp := c.getSA(name, namespace)
if resp != nil && resp.RoleARN != "" {
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
respSA = c.getSA(name, namespace)
if respSA != nil && respSA.RoleARN != "" {
return respSA.RoleARN, respSA.Audience, respSA.UseRegionalSTS, respSA.TokenExpiration, nil
}
}
{
resp := c.getCM(name, namespace)
if resp != nil {
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
respCM := c.getCM(name, namespace)
if respCM != nil {
return respCM.RoleARN, respCM.Audience, respCM.UseRegionalSTS, respCM.TokenExpiration, nil
}
}
klog.V(5).Infof("Service account %s/%s not found in cache", namespace, name)
return "", "", false, pkg.DefaultTokenExpiration
if respSA == nil {
return "", "", false, pkg.DefaultTokenExpiration, fmt.Errorf("service account %s/%s not found in cache and one is expected", namespace, name)
}

return "", "", false, pkg.DefaultTokenExpiration, nil
}

func (c *serviceAccountCache) getSA(name, namespace string) *CacheResponse {
Expand Down
56 changes: 43 additions & 13 deletions pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ func TestSaCache(t *testing.T) {
webhookUsage: prometheus.NewGauge(prometheus.GaugeOpts{}),
}

role, aud, useRegionalSTS, tokenExpiration := cache.Get("default", "default")

role, aud, useRegionalSTS, tokenExpiration, err := cache.Get("default", "default")
if err == nil {
t.Fatal("Expected err to not be empty")
}
if role != "" || aud != "" {
t.Errorf("Expected role and aud to be empty, got %s, %s, %t, %d", role, aud, useRegionalSTS, tokenExpiration)
}

cache.addSA(testSA)

role, aud, useRegionalSTS, tokenExpiration = cache.Get("default", "default")
role, aud, useRegionalSTS, tokenExpiration, err = cache.Get("default", "default")

assert.Equal(t, roleArn, role, "Expected role to be %s, got %s", roleArn, role)
assert.Equal(t, "sts.amazonaws.com", aud, "Expected aud to be sts.amzonaws.com, got %s", aud)
Expand Down Expand Up @@ -154,7 +156,10 @@ func TestNonRegionalSTS(t *testing.T) {
t.Fatalf("cache never called addSA: %v", err)
}

gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration := cache.Get("default", "default")
gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration, err := cache.Get("default", "default")
if err != nil {
t.Fatal(err)
}
if gotRoleArn != roleArn {
t.Errorf("got roleArn %v, expected %v", gotRoleArn, roleArn)
}
Expand Down Expand Up @@ -199,7 +204,10 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, err := c.Get("mysa2", "myns2")
if err != nil {
t.Fatal(err)
}
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand All @@ -211,7 +219,10 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, err := c.Get("mysa2", "myns2")
if err != nil {
t.Fatal(err)
}
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand All @@ -223,7 +234,8 @@ func TestPopulateCacheFromCM(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, _ := c.Get("mysa2", "myns2")
role, _, _, _, _ := c.Get("mysa2", "myns2")

if role != "" {
t.Errorf("found entry that should have been removed")
}
Expand Down Expand Up @@ -253,7 +265,10 @@ func TestSAAnnotationRemoval(t *testing.T) {
c.addSA(oldSA)

{
gotRoleArn, _, _, _ := c.Get("default", "default")
gotRoleArn, _, _, _, err := c.Get("default", "default")
if err != nil {
t.Fatal(err)
}
if gotRoleArn != roleArn {
t.Errorf("got roleArn %q, expected %q", gotRoleArn, roleArn)
}
Expand All @@ -265,7 +280,10 @@ func TestSAAnnotationRemoval(t *testing.T) {
c.addSA(newSA)

{
gotRoleArn, _, _, _ := c.Get("default", "default")
gotRoleArn, _, _, _, err := c.Get("default", "default")
if err != nil {
t.Fatal(err)
}
if gotRoleArn != "" {
t.Errorf("got roleArn %v, expected %q", gotRoleArn, "")
}
Expand Down Expand Up @@ -320,7 +338,10 @@ func TestCachePrecedence(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, err := c.Get("mysa2", "myns2")
if err != nil {
t.Fatal(err)
}
if role == "" {
t.Errorf("could not find entry that should have been added")
}
Expand All @@ -337,7 +358,10 @@ func TestCachePrecedence(t *testing.T) {
}

// Removing sa2 from CM, but SA still exists
role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, err := c.Get("mysa2", "myns2")
if err != nil {
t.Fatal(err)
}
if role == "" {
t.Errorf("could not find entry that should still exist")
}
Expand All @@ -353,7 +377,10 @@ func TestCachePrecedence(t *testing.T) {
c.addSA(sa2)

// Neither cache should return any hits now
role, _, _, _ := c.Get("myns2", "mysa2")
role, _, _, _, err := c.Get("myns2", "mysa2")
if err == nil {
t.Errorf("found entry that should not exist")
}
if role != "" {
t.Errorf("found entry that should not exist")
}
Expand All @@ -367,7 +394,10 @@ func TestCachePrecedence(t *testing.T) {
t.Errorf("failed to build cache: %v", err)
}

role, _, _, exp := c.Get("mysa2", "myns2")
role, _, _, exp, err := c.Get("mysa2", "myns2")
if err != nil {
t.Fatal(err)
}
if role == "" {
t.Errorf("cloud not find entry that should have been added")
}
Expand Down
9 changes: 5 additions & 4 deletions pkg/cache/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package cache

import (
"encoding/json"
"k8s.io/api/core/v1"
"strconv"
"sync"

"k8s.io/api/core/v1"

"github.com/aws/amazon-eks-pod-identity-webhook/pkg"
)

Expand Down Expand Up @@ -44,14 +45,14 @@ var _ ServiceAccountCache = &FakeServiceAccountCache{}
func (f *FakeServiceAccountCache) Start(chan struct{}) {}

// Get gets a service account from the cache
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64) {
func (f *FakeServiceAccountCache) Get(name, namespace string) (role, aud string, useRegionalSTS bool, tokenExpiration int64, err error) {
f.mu.RLock()
defer f.mu.RUnlock()
resp, ok := f.cache[namespace+"/"+name]
if !ok {
return "", "", false, pkg.DefaultTokenExpiration
return "", "", false, pkg.DefaultTokenExpiration, nil
}
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration
return resp.RoleARN, resp.Audience, resp.UseRegionalSTS, resp.TokenExpiration, nil
}

// Add adds a cache entry
Expand Down
17 changes: 15 additions & 2 deletions pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,22 @@ func (m *Modifier) MutatePod(ar *v1beta1.AdmissionReview) *v1beta1.AdmissionResp
// audience: serviceaccount annotation > flag
// regionalSTS: serviceaccount annotation > flag
// tokenExpiration: pod annotation > serviceaccount annotation > flag
podRole, audience, regionalSTS, tokenExpiration := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)

podRole, audience, regionalSTS, tokenExpiration, err := m.Cache.Get(pod.Spec.ServiceAccountName, pod.Namespace)
// determine whether to perform mutation
if err != nil {
klog.Errorf("Pod was not mutated. Reason: "+
"Service account was not found in cache and was expected. %s",
logContext(pod.Name,
pod.GenerateName,
pod.Spec.ServiceAccountName,
pod.Namespace))
return &v1beta1.AdmissionResponse{
Allowed: false,
Result: &metav1.Status{
Message: err.Error(),
},
}
}
if podRole == "" {
klog.V(4).Infof("Pod was not mutated. Reason: "+
"Service account did not have the right annotations or was not found in the cache. %s",
Expand Down

0 comments on commit dbe77b6

Please sign in to comment.