Skip to content

Commit

Permalink
Add ...WithContext() versions of GetSecret calls (#7)
Browse files Browse the repository at this point in the history
Previously, this library was using context.Background() for all network
requests. This adds APIs that take a context.
  • Loading branch information
ojrac authored Feb 13, 2023
1 parent 304869b commit f88736e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 21 deletions.
27 changes: 23 additions & 4 deletions secretcache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
package secretcache

import (
"context"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
Expand Down Expand Up @@ -83,15 +86,23 @@ func (c *Cache) getCachedSecret(secretId string) *secretCacheItem {
// GetSecretString gets the secret string value from the cache for given secret id and a default version stage.
// Returns the secret string and an error if operation failed.
func (c *Cache) GetSecretString(secretId string) (string, error) {
return c.GetSecretStringWithStage(secretId, DefaultVersionStage)
return c.GetSecretStringWithContext(aws.BackgroundContext(), secretId)
}

func (c *Cache) GetSecretStringWithContext(ctx context.Context, secretId string) (string, error) {
return c.GetSecretStringWithStageWithContext(ctx, secretId, DefaultVersionStage)
}

// GetSecretStringWithStage gets the secret string value from the cache for given secret id and version stage.
// Returns the secret string and an error if operation failed.
func (c *Cache) GetSecretStringWithStage(secretId string, versionStage string) (string, error) {
return c.GetSecretStringWithStageWithContext(aws.BackgroundContext(), secretId, versionStage)
}

func (c *Cache) GetSecretStringWithStageWithContext(ctx context.Context, secretId string, versionStage string) (string, error) {
secretCacheItem := c.getCachedSecret(secretId)

getSecretValueOutput, err := secretCacheItem.getSecretValue(versionStage)
getSecretValueOutput, err := secretCacheItem.getSecretValue(ctx, versionStage)

if err != nil {
return "", err
Expand All @@ -111,15 +122,23 @@ func (c *Cache) GetSecretStringWithStage(secretId string, versionStage string) (
// GetSecretBinary gets the secret binary value from the cache for given secret id and a default version stage.
// Returns the secret binary and an error if operation failed.
func (c *Cache) GetSecretBinary(secretId string) ([]byte, error) {
return c.GetSecretBinaryWithStage(secretId, DefaultVersionStage)
return c.GetSecretBinaryWithContext(aws.BackgroundContext(), secretId)
}

func (c *Cache) GetSecretBinaryWithContext(ctx context.Context, secretId string) ([]byte, error) {
return c.GetSecretBinaryWithStageWithContext(ctx, secretId, DefaultVersionStage)
}

// GetSecretBinaryWithStage gets the secret binary value from the cache for given secret id and version stage.
// Returns the secret binary and an error if operation failed.
func (c *Cache) GetSecretBinaryWithStage(secretId string, versionStage string) ([]byte, error) {
return c.GetSecretBinaryWithStageWithContext(aws.BackgroundContext(), secretId, versionStage)
}

func (c *Cache) GetSecretBinaryWithStageWithContext(ctx context.Context, secretId string, versionStage string) ([]byte, error) {
secretCacheItem := c.getCachedSecret(secretId)

getSecretValueOutput, err := secretCacheItem.getSecretValue(versionStage)
getSecretValueOutput, err := secretCacheItem.getSecretValue(ctx, versionStage)

if err != nil {
return nil, err
Expand Down
16 changes: 8 additions & 8 deletions secretcache/cacheItem.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
package secretcache

import (
"context"
"fmt"
"math"
"math/rand"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
Expand Down Expand Up @@ -80,12 +80,12 @@ func (ci *secretCacheItem) getVersionId(versionStage string) (string, bool) {

// executeRefresh performs the actual refresh of the cached secret information.
// Returns the DescribeSecret API result and an error if call failed.
func (ci *secretCacheItem) executeRefresh() (*secretsmanager.DescribeSecretOutput, error) {
func (ci *secretCacheItem) executeRefresh(ctx context.Context) (*secretsmanager.DescribeSecretOutput, error) {
input := &secretsmanager.DescribeSecretInput{
SecretId: &ci.secretId,
}

result, err := ci.client.DescribeSecretWithContext(aws.BackgroundContext(), input, request.WithAppendUserAgent(userAgent()))
result, err := ci.client.DescribeSecretWithContext(ctx, input, request.WithAppendUserAgent(userAgent()))

var maxTTL int64
if ci.config.CacheItemTTL == 0 {
Expand Down Expand Up @@ -132,14 +132,14 @@ func (ci *secretCacheItem) getVersion(versionStage string) (*cacheVersion, bool)
}

// refresh the cached object when needed.
func (ci *secretCacheItem) refresh() {
func (ci *secretCacheItem) refresh(ctx context.Context) {
if !ci.isRefreshNeeded() {
return
}

ci.refreshNeeded = false

result, err := ci.executeRefresh()
result, err := ci.executeRefresh(ctx)

if err != nil {
ci.errorCount++
Expand All @@ -158,7 +158,7 @@ func (ci *secretCacheItem) refresh() {

// getSecretValue gets the cached secret value for the given version stage.
// Returns the GetSecretValue API result and an error if operation fails.
func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager.GetSecretValueOutput, error) {
func (ci *secretCacheItem) getSecretValue(ctx context.Context, versionStage string) (*secretsmanager.GetSecretValueOutput, error) {
if versionStage == "" && ci.config.VersionStage == "" {
versionStage = DefaultVersionStage
} else if versionStage == "" && ci.config.VersionStage != "" {
Expand All @@ -168,7 +168,7 @@ func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager.
ci.mux.Lock()
defer ci.mux.Unlock()

ci.refresh()
ci.refresh(ctx)
version, ok := ci.getVersion(versionStage)

if !ok {
Expand All @@ -183,7 +183,7 @@ func (ci *secretCacheItem) getSecretValue(versionStage string) (*secretsmanager.
}

}
return version.getSecretValue()
return version.getSecretValue(ctx)
}

// setWithHook sets the cache item's data using the CacheHook, if one is configured.
Expand Down
4 changes: 2 additions & 2 deletions secretcache/cacheObjects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestMaxCacheTTL(t *testing.T) {
config := CacheConfig{CacheItemTTL: -1}
cacheItem.config = config

_, err := cacheItem.executeRefresh()
_, err := cacheItem.executeRefresh(aws.BackgroundContext())

if err == nil {
t.Fatalf("Expected error due to negative cache ttl")
Expand All @@ -83,7 +83,7 @@ func TestMaxCacheTTL(t *testing.T) {
config = CacheConfig{CacheItemTTL: 0}
cacheItem.config = config

_, err = cacheItem.executeRefresh()
_, err = cacheItem.executeRefresh(aws.BackgroundContext())

if err != nil {
t.Fatalf("Unexpected error on zero cache ttl")
Expand Down
14 changes: 7 additions & 7 deletions secretcache/cacheVersion.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
package secretcache

import (
"context"
"math"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/secretsmanager"
"github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface"
Expand All @@ -45,14 +45,14 @@ func (cv *cacheVersion) isRefreshNeeded() bool {
}

// refresh the cached object when needed.
func (cv *cacheVersion) refresh() {
func (cv *cacheVersion) refresh(ctx context.Context) {
if !cv.isRefreshNeeded() {
return
}

cv.refreshNeeded = false

result, err := cv.executeRefresh()
result, err := cv.executeRefresh(ctx)

if err != nil {
cv.errorCount++
Expand All @@ -72,21 +72,21 @@ func (cv *cacheVersion) refresh() {

// executeRefresh performs the actual refresh of the cached secret information.
// Returns the GetSecretValue API result and an error if operation fails.
func (cv *cacheVersion) executeRefresh() (*secretsmanager.GetSecretValueOutput, error) {
func (cv *cacheVersion) executeRefresh(ctx context.Context) (*secretsmanager.GetSecretValueOutput, error) {
input := &secretsmanager.GetSecretValueInput{
SecretId: &cv.secretId,
VersionId: &cv.versionId,
}
return cv.client.GetSecretValueWithContext(aws.BackgroundContext(), input, request.WithAppendUserAgent(userAgent()))
return cv.client.GetSecretValueWithContext(ctx, input, request.WithAppendUserAgent(userAgent()))
}

// getSecretValue gets the cached secret version value.
// Returns the GetSecretValue API cached result and an error if operation fails.
func (cv *cacheVersion) getSecretValue() (*secretsmanager.GetSecretValueOutput, error) {
func (cv *cacheVersion) getSecretValue(ctx context.Context) (*secretsmanager.GetSecretValueOutput, error) {
cv.mux.Lock()
defer cv.mux.Unlock()

cv.refresh()
cv.refresh(ctx)

return cv.getWithHook(), cv.err
}
Expand Down

0 comments on commit f88736e

Please sign in to comment.