Skip to content

Commit

Permalink
refactor: improve naming and key handling in API client (#588)
Browse files Browse the repository at this point in the history
This commit uses a more descriptive name for the code that refreshes
connection info. Instead of "refresher," this commit renames the object
to "adminAPIClient" to match what we're doing in other connectors.
Likewise "performRefresh" has been renamed "connectionInfo," a getter for
connection info.

In addition, the RSA key is no longer stored on the cache but instead
passed to the only code that needs the key: the API client.
  • Loading branch information
enocom authored Jun 6, 2024
1 parent b640ffb commit 519a9ac
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 41 deletions.
10 changes: 4 additions & 6 deletions internal/alloydb/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const (

var (
// Instance URI is in the format:
// '/projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>'
// 'projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>'
// Additionally, we have to support legacy "domain-scoped" projects
// (e.g. "google.com:PROJECT")
instURIRegex = regexp.MustCompile("projects/([^:]+(:[^:]+)?)/locations/([^:]+)/clusters/([^:]+)/instances/([^:]+)")
Expand Down Expand Up @@ -138,13 +138,12 @@ func (r *refreshOperation) isValid() bool {
type RefreshAheadCache struct {
instanceURI InstanceURI
logger debug.ContextLogger
key *rsa.PrivateKey
// refreshTimeout sets the maximum duration a refresh cycle can run
// for.
refreshTimeout time.Duration
// l controls the rate at which refresh cycles are run.
l *rate.Limiter
r refresher
r adminAPIClient

resultGuard sync.RWMutex
// cur represents the current refreshOperation that will be used to
Expand Down Expand Up @@ -175,9 +174,8 @@ func NewRefreshAheadCache(
i := &RefreshAheadCache{
instanceURI: instance,
logger: l,
key: key,
l: rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
r: newRefresher(client, dialerID),
r: newAdminAPIClient(client, key, dialerID),
refreshTimeout: refreshTimeout,
ctx: ctx,
cancel: cancel,
Expand Down Expand Up @@ -296,7 +294,7 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
r.err,
)
} else {
r.result, r.err = i.r.performRefresh(i.ctx, i.instanceURI, i.key)
r.result, r.err = i.r.connectionInfo(i.ctx, i.instanceURI)
i.logger.Debugf(
ctx,
"[%v] Connection info refresh operation complete",
Expand Down
16 changes: 8 additions & 8 deletions internal/alloydb/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ func genRSAKey() *rsa.PrivateKey {
return key
}

// RSAKey is used for test only.
var RSAKey = genRSAKey()
// rsaKey is used for test only.
var rsaKey = genRSAKey()

func TestParseInstURI(t *testing.T) {
tcs := []struct {
Expand All @@ -53,7 +53,7 @@ func TestParseInstURI(t *testing.T) {
}{
{
desc: "vanilla instance URI",
in: "/projects/proj/locations/reg/clusters/clust/instances/name",
in: "projects/proj/locations/reg/clusters/clust/instances/name",
want: InstanceURI{
project: "proj",
region: "reg",
Expand All @@ -63,7 +63,7 @@ func TestParseInstURI(t *testing.T) {
},
{
desc: "with legacy domain-scoped project",
in: "/projects/google.com:proj/locations/reg/clusters/clust/instances/name",
in: "projects/google.com:proj/locations/reg/clusters/clust/instances/name",
want: InstanceURI{
project: "google.com:proj",
region: "reg",
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestConnectionInfo(t *testing.T) {
i := NewRefreshAheadCache(
testInstanceURI(),
nullLogger{},
c, RSAKey, 30*time.Second, "dialer-id",
c, rsaKey, 30*time.Second, "dialer-id",
)
if err != nil {
t.Fatalf("failed to create mock instance: %v", err)
Expand Down Expand Up @@ -192,7 +192,7 @@ func TestConnectionInfo(t *testing.T) {
}

func testInstanceURI() InstanceURI {
i, _ := ParseInstURI("/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
i, _ := ParseInstURI("projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
return i
}

Expand All @@ -209,7 +209,7 @@ func TestConnectInfoErrors(t *testing.T) {
i := NewRefreshAheadCache(
testInstanceURI(),
nullLogger{},
c, RSAKey, 0, "dialer-id",
c, rsaKey, 0, "dialer-id",
)
if err != nil {
t.Fatalf("failed to initialize Instance: %v", err)
Expand Down Expand Up @@ -242,7 +242,7 @@ func TestClose(t *testing.T) {
i := NewRefreshAheadCache(
testInstanceURI(),
nullLogger{},
c, RSAKey, 30, "dialer-ider",
c, rsaKey, 30, "dialer-ider",
)
if err != nil {
t.Fatalf("failed to initialize Instance: %v", err)
Expand Down
9 changes: 4 additions & 5 deletions internal/alloydb/lazy.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ import (
type LazyRefreshCache struct {
uri InstanceURI
logger debug.ContextLogger
key *rsa.PrivateKey
r refresher
r adminAPIClient
mu sync.Mutex
needsRefresh bool
cached ConnectionInfo
Expand All @@ -48,9 +47,9 @@ func NewLazyRefreshCache(
return &LazyRefreshCache{
uri: uri,
logger: l,
key: key,
r: newRefresher(
r: newAdminAPIClient(
client,
key,
dialerID,
),
}
Expand Down Expand Up @@ -84,7 +83,7 @@ func (c *LazyRefreshCache) ConnectionInfo(
"[%v] Connection info refresh operation started",
c.uri.String(),
)
ci, err := c.r.performRefresh(ctx, c.uri, c.key)
ci, err := c.r.connectionInfo(ctx, c.uri)
if err != nil {
c.logger.Debugf(
ctx,
Expand Down
4 changes: 2 additions & 2 deletions internal/alloydb/lazy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestLazyRefreshCacheConnectionInfo(t *testing.T) {
}
cache := NewLazyRefreshCache(
testInstanceURI(), nullLogger{}, c,
RSAKey, 30*time.Second, "",
rsaKey, 30*time.Second, "",
)

ci, err := cache.ConnectionInfo(context.Background())
Expand Down Expand Up @@ -90,7 +90,7 @@ func TestLazyRefreshCacheForceRefresh(t *testing.T) {
}
cache := NewLazyRefreshCache(
testInstanceURI(), nullLogger{}, c,
RSAKey, 30*time.Second, "",
rsaKey, 30*time.Second, "",
)

_, err = cache.ConnectionInfo(context.Background())
Expand Down
29 changes: 16 additions & 13 deletions internal/alloydb/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,23 +221,25 @@ func newClientCertificate(
}, nil
}

// newRefresher creates a Refresher.
func newRefresher(
func newAdminAPIClient(
client *alloydbadmin.AlloyDBAdminClient,
key *rsa.PrivateKey,
dialerID string,
) refresher {
return refresher{
) adminAPIClient {
return adminAPIClient{
client: client,
key: key,
dialerID: dialerID,
}
}

// refresher manages the AlloyDB Admin API access to instance metadata and to
// ephemeral certificates.
type refresher struct {
// adminAPIClient manages the AlloyDB Admin API access to instance metadata and
// to ephemeral certificates.
type adminAPIClient struct {
// client provides access to the AlloyDB Admin API
client *alloydbadmin.AlloyDBAdminClient

// key is used to request client certificates
key *rsa.PrivateKey
// dialerID is the unique ID of the associated dialer.
dialerID string
}
Expand All @@ -251,16 +253,17 @@ type ConnectionInfo struct {
Expiration time.Time
}

func (r refresher) performRefresh(
ctx context.Context, i InstanceURI, k *rsa.PrivateKey,
func (c adminAPIClient) connectionInfo(
ctx context.Context, i InstanceURI,
) (res ConnectionInfo, err error) {

var refreshEnd trace.EndSpanFunc
ctx, refreshEnd = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.RefreshConnection",
trace.AddInstanceName(i.String()),
)
defer func() {
go trace.RecordRefreshResult(
context.Background(), i.String(), r.dialerID, err,
context.Background(), i.String(), c.dialerID, err,
)
refreshEnd(err)
}()
Expand All @@ -272,7 +275,7 @@ func (r refresher) performRefresh(
mdCh := make(chan mdRes, 1)
go func() {
defer close(mdCh)
c, err := fetchInstanceInfo(ctx, r.client, i)
c, err := fetchInstanceInfo(ctx, c.client, i)
mdCh <- mdRes{info: c, err: err}
}()

Expand All @@ -283,7 +286,7 @@ func (r refresher) performRefresh(
certCh := make(chan certRes, 1)
go func() {
defer close(certCh)
cc, err := fetchClientCertificate(ctx, r.client, i, k)
cc, err := fetchClientCertificate(ctx, c.client, i, c.key)
certCh <- certRes{cc: cc, err: err}
}()

Expand Down
14 changes: 7 additions & 7 deletions internal/alloydb/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestRefresh(t *testing.T) {
wantPublicIP := "127.0.0.1"
wantPSC := "x.y.alloydb.goog"
wantExpiry := time.Now().Add(time.Hour).UTC().Round(time.Second)
wantInstURI := "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
wantInstURI := "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
cn, err := ParseInstURI(wantInstURI)
if err != nil {
t.Fatalf("parseConnName(%s)failed : %v", cn, err)
Expand Down Expand Up @@ -62,8 +62,8 @@ func TestRefresh(t *testing.T) {
if err != nil {
t.Fatalf("admin API client error: %v", err)
}
r := newRefresher(cl, testDialerID)
res, err := r.performRefresh(context.Background(), cn, RSAKey)
r := newAdminAPIClient(cl, rsaKey, testDialerID)
res, err := r.connectionInfo(context.Background(), cn)
if err != nil {
t.Fatalf("performRefresh unexpectedly failed with error: %v", err)
}
Expand Down Expand Up @@ -98,7 +98,7 @@ func TestRefresh(t *testing.T) {
}

func TestRefreshFailsFast(t *testing.T) {
wantInstURI := "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
wantInstURI := "projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance"
cn, err := ParseInstURI(wantInstURI)
if err != nil {
t.Fatalf("parseConnName(%s)failed : %v", cn, err)
Expand All @@ -124,17 +124,17 @@ func TestRefreshFailsFast(t *testing.T) {
if err != nil {
t.Fatalf("admin API client error: %v", err)
}
r := newRefresher(cl, testDialerID)
r := newAdminAPIClient(cl, rsaKey, testDialerID)

_, err = r.performRefresh(context.Background(), cn, RSAKey)
_, err = r.connectionInfo(context.Background(), cn)
if err != nil {
t.Fatalf("expected no error, got = %v", err)
}

ctx, cancel := context.WithCancel(context.Background())
cancel()
// context is canceled
_, err = r.performRefresh(ctx, cn, RSAKey)
_, err = r.connectionInfo(ctx, cn)
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled error, got = %v", err)
}
Expand Down

0 comments on commit 519a9ac

Please sign in to comment.