Skip to content

Commit

Permalink
refactor: shrink connectionInfoCache interface (#553)
Browse files Browse the repository at this point in the history
This commit moves connection tracking to the dialer and reduces the size
of the connectionInfoCache interface. This change makes it easier to
implement alternate caches (e.g., a lazy refresh cache). In addition,
this commit renames some variables to match the new names.
enocom authored Apr 18, 2024
1 parent 2dd2371 commit 2d4b79a
Showing 3 changed files with 47 additions and 42 deletions.
69 changes: 39 additions & 30 deletions dialer.go
Original file line number Diff line number Diff line change
@@ -78,19 +78,24 @@ func getDefaultKeys() (*rsa.PrivateKey, error) {
}

// connectionInfoCache is the set of operations the Dialer needs from a
// per-instance cache of connection info.
type connectionInfoCache interface {
OpenConns() *uint64
// ConnectionInfo returns the connection info for the instance, or an error.
// Dial calls it before connecting and again after forcing a refresh.
ConnectionInfo(context.Context) (alloydb.ConnectionInfo, error)
// ForceRefresh is called by Dial when the client certificate is found to be
// invalid and after a failed dial or TLS handshake.
ForceRefresh()
// Close is called by Dial when ConnectionInfo fails (to stop background
// refreshes) and by Dialer.Close for every cached entry.
io.Closer
}

// monitoredCache is a wrapper around a connectionInfoCache that tracks the
// number of connections to the associated instance.
//
// NOTE(review): the Dialer stores monitoredCache in its map by value and
// returns it by value, and Dial updates the counter via &cache.openConns on
// its local copy. Each Dial call therefore appears to count on its own copy
// rather than on one counter shared per instance. Confirm whether the map
// should hold *monitoredCache (or openConns should be a *uint64).
type monitoredCache struct {
// openConns is the open-connection counter; Dial updates it with
// atomic.AddUint64.
openConns uint64
// connectionInfoCache is the wrapped cache; embedding promotes its methods
// onto monitoredCache.
connectionInfoCache
}

// A Dialer is used to create connections to an AlloyDB instance.
//
// Use NewDialer to initialize a Dialer.
type Dialer struct {
lock sync.RWMutex
// instances map instance URIs to *alloydb.Instance types
instances map[alloydb.InstanceURI]connectionInfoCache
lock sync.RWMutex
cache map[alloydb.InstanceURI]monitoredCache
key *rsa.PrivateKey
refreshTimeout time.Duration
// closed reports if the dialer has been closed.
@@ -180,7 +185,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
}
d := &Dialer{
closed: make(chan struct{}),
instances: make(map[alloydb.InstanceURI]connectionInfoCache),
cache: make(map[alloydb.InstanceURI]monitoredCache),
key: cfg.rsaKey,
refreshTimeout: cfg.refreshTimeout,
client: client,
@@ -226,12 +231,12 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)

var endInfo trace.EndSpanFunc
ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.InstanceInfo")
i, err := d.instance(inst)
cache, err := d.connectionInfoCache(inst)
if err != nil {
endInfo(err)
return nil, err
}
ci, err := i.ConnectionInfo(ctx)
ci, err := cache.ConnectionInfo(ctx)
if err != nil {
d.lock.Lock()
defer d.lock.Unlock()
@@ -241,8 +246,8 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
err,
)
// Stop all background refreshes
i.Close()
delete(d.instances, inst)
cache.Close()
delete(d.cache, inst)
endInfo(err)
return nil, err
}
@@ -255,9 +260,9 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
// So check that the certificate is valid before proceeding.
if invalidClientCert(inst, d.logger, ci.Expiration) {
d.logger.Debugf("[%v] Refreshing certificate now", inst.String())
i.ForceRefresh()
cache.ForceRefresh()
// Block on refreshed connection info
ci, err = i.ConnectionInfo(ctx)
ci, err = cache.ConnectionInfo(ctx)
if err != nil {
d.lock.Lock()
defer d.lock.Unlock()
@@ -267,8 +272,8 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
err,
)
// Stop all background refreshes
i.Close()
delete(d.instances, inst)
cache.Close()
delete(d.cache, inst)
return nil, err
}
}
@@ -294,7 +299,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
if err != nil {
d.logger.Debugf("[%v] Dialing %v failed: %v", inst.String(), hostPort, err)
// refresh the instance info in case it caused the connection failure
i.ForceRefresh()
cache.ForceRefresh()
return nil, errtype.NewDialError("failed to dial", inst.String(), err)
}
if c, ok := conn.(*net.TCPConn); ok {
@@ -331,7 +336,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
if err := tlsConn.HandshakeContext(ctx); err != nil {
d.logger.Debugf("[%v] TLS handshake failed: %v", inst.String(), err)
// refresh the instance info in case it caused the handshake failure
i.ForceRefresh()
cache.ForceRefresh()
_ = tlsConn.Close() // best effort close attempt
return nil, errtype.NewDialError("handshake failed", inst.String(), err)
}
@@ -346,13 +351,13 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)

latency := time.Since(startTime).Milliseconds()
go func() {
n := atomic.AddUint64(i.OpenConns(), 1)
n := atomic.AddUint64(&cache.openConns, 1)
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, inst.String())
trace.RecordDialLatency(ctx, instance, d.dialerID, latency)
}()

return newInstrumentedConn(tlsConn, func() {
n := atomic.AddUint64(i.OpenConns(), ^uint64(0))
n := atomic.AddUint64(&cache.openConns, ^uint64(0))
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, inst.String())
}), nil
}
@@ -529,34 +534,38 @@ func (d *Dialer) Close() error {

d.lock.Lock()
defer d.lock.Unlock()
for _, i := range d.instances {
for _, i := range d.cache {
i.Close()
}
return nil
}

func (d *Dialer) instance(instance alloydb.InstanceURI) (connectionInfoCache, error) {
func (d *Dialer) connectionInfoCache(
uri alloydb.InstanceURI,
) (monitoredCache, error) {
d.lock.RLock()
i, ok := d.instances[instance]
c, ok := d.cache[uri]
d.lock.RUnlock()
if !ok {
d.lock.Lock()
defer d.lock.Unlock()
// Recheck to ensure instance wasn't created between locks
i, ok = d.instances[instance]
c, ok = d.cache[uri]
if !ok {
i = alloydb.NewRefreshAheadCache(
instance,
d.logger,
d.client, d.key,
d.refreshTimeout, d.dialerID,
)
c = monitoredCache{
connectionInfoCache: alloydb.NewRefreshAheadCache(
uri,
d.logger,
d.client, d.key,
d.refreshTimeout, d.dialerID,
),
}
d.logger.Debugf(
"[%v] Connection info added to cache",
instance.String(),
uri.String(),
)
d.instances[instance] = i
d.cache[uri] = c
}
}
return i, nil
return c, nil
}
12 changes: 8 additions & 4 deletions dialer_test.go
Original file line number Diff line number Diff line change
@@ -247,7 +247,9 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
err: errors.New("connect info failed"),
}},
}
d.instances[badInst] = spy
d.cache[badInst] = monitoredCache{
connectionInfoCache: spy,
}

_, err = d.Dial(context.Background(), badInstanceName)
if err == nil {
@@ -261,7 +263,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {

// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.instances[badInst]
_, ok := d.cache[badInst]
d.lock.RUnlock()
if ok {
t.Fatal("bad instance was not removed from the cache")
@@ -297,7 +299,9 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
},
},
}
d.instances[cn] = spy
d.cache[cn] = monitoredCache{
connectionInfoCache: spy,
}

_, err = d.Dial(context.Background(), inst)
if !errors.Is(err, sentinel) {
@@ -317,7 +321,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {

// Now verify that bad connection name has been deleted from map.
d.lock.RLock()
_, ok := d.instances[cn]
_, ok := d.cache[cn]
d.lock.RUnlock()
if ok {
t.Fatal("bad instance was not removed from the cache")
8 changes: 0 additions & 8 deletions internal/alloydb/instance.go
Original file line number Diff line number Diff line change
@@ -127,9 +127,6 @@ func (r *refreshOperation) isValid() bool {
// required information approximately 4 minutes before the previous certificate
// expires (every ~56 minutes).
type RefreshAheadCache struct {
// OpenConns is the number of open connections to the instance.
openConns uint64

instanceURI InstanceURI
logger debug.Logger
key *rsa.PrivateKey
@@ -185,11 +182,6 @@ func NewRefreshAheadCache(
return i
}

// OpenConns reports the number of open connections.
func (i *RefreshAheadCache) OpenConns() *uint64 {
return &i.openConns
}

// Close closes the instance; it stops the refresh cycle and prevents it from
// making additional calls to the AlloyDB Admin API.
func (i *RefreshAheadCache) Close() error {

0 comments on commit 2d4b79a

Please sign in to comment.