Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add option for connection check opt-out #620

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ type Dialer struct {
// ahead cache assumes a background goroutine may run consistently.
lazyRefresh bool

// disableMetadataExchange is a temporary addition to help clients who
// cannot use the metadata exchange yet. In future versions, this field
// should be removed.
disableMetadataExchange bool

staticConnInfo io.Reader

client *alloydbadmin.AlloyDBAdminClient
Expand Down Expand Up @@ -183,6 +188,10 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, cfg.err
}
}
if cfg.disableMetadataExchange && cfg.useIAMAuthN {
return nil, errors.New("incompatible options: WithOptOutOfAdvancedConnection " +
"check cannot be used with WithIAMAuthN")
}
userAgent := strings.Join(cfg.userAgents, " ")
// Add this to the end to make sure it's not overridden
cfg.adminOpts = append(cfg.adminOpts, option.WithUserAgent(userAgent))
Expand Down Expand Up @@ -221,21 +230,22 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, err
}
d := &Dialer{
closed: make(chan struct{}),
cache: make(map[alloydb.InstanceURI]monitoredCache),
lazyRefresh: cfg.lazyRefresh,
staticConnInfo: cfg.staticConnInfo,
keyGenerator: g,
refreshTimeout: cfg.refreshTimeout,
client: client,
logger: cfg.logger,
defaultDialCfg: dialCfg,
dialerID: uuid.New().String(),
dialFunc: cfg.dialFunc,
useIAMAuthN: cfg.useIAMAuthN,
iamTokenSource: ts,
userAgent: userAgent,
buffer: newBuffer(),
closed: make(chan struct{}),
cache: make(map[alloydb.InstanceURI]monitoredCache),
lazyRefresh: cfg.lazyRefresh,
disableMetadataExchange: cfg.disableMetadataExchange,
staticConnInfo: cfg.staticConnInfo,
keyGenerator: g,
refreshTimeout: cfg.refreshTimeout,
client: client,
logger: cfg.logger,
defaultDialCfg: dialCfg,
dialerID: uuid.New().String(),
dialFunc: cfg.dialFunc,
useIAMAuthN: cfg.useIAMAuthN,
iamTokenSource: ts,
userAgent: userAgent,
buffer: newBuffer(),
}
return d, nil
}
Expand Down Expand Up @@ -351,12 +361,14 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
return nil, errtype.NewDialError("handshake failed", inst.String(), err)
}

// The metadata exchange must occur after the TLS connection is established
// to avoid leaking sensitive information.
err = d.metadataExchange(tlsConn)
if err != nil {
_ = tlsConn.Close() // best effort close attempt
return nil, err
if !d.disableMetadataExchange {
// The metadata exchange must occur after the TLS connection is established
// to avoid leaking sensitive information.
err = d.metadataExchange(tlsConn)
if err != nil {
_ = tlsConn.Close() // best effort close attempt
return nil, err
}
}

latency := time.Since(startTime).Milliseconds()
Expand Down Expand Up @@ -598,6 +610,7 @@ func (d *Dialer) connectionInfoCache(
d.logger,
d.client, k,
d.refreshTimeout, d.dialerID,
d.disableMetadataExchange,
)
case d.staticConnInfo != nil:
var err error
Expand All @@ -615,6 +628,7 @@ func (d *Dialer) connectionInfoCache(
d.logger,
d.client, k,
d.refreshTimeout, d.dialerID,
d.disableMetadataExchange,
)
}
var open uint64
Expand Down
21 changes: 21 additions & 0 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,27 @@ func (stubTokenSource) Token() (*oauth2.Token, error) {
return &oauth2.Token{}, nil
}

func TestDialerIncompatibleOptions(t *testing.T) {
tcs := []struct {
desc string
opts []Option
}{
{
desc: "opt out connection check doesn't work with IAM authn",
opts: []Option{WithOptOutOfAdvancedConnectionCheck(), WithIAMAuthN()},
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, err := NewDialer(context.Background(), tc.opts...)
if err == nil {
t.Fatalf("got = %v, want no error", err)
}
})
}
}

func TestDialerCanConnectToInstance(t *testing.T) {
ctx := context.Background()
inst := mock.NewFakeInstance(
Expand Down
10 changes: 10 additions & 0 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ func TestPgxConnect(t *testing.T) {
)
},
},
{
desc: "metadata exchange disabled",
f: func(ctx context.Context) (*pgxpool.Pool, func() error, error) {
return connectPgx(
ctx, alloydbInstanceName,
alloydbUser, alloydbPass, alloydbDB,
alloydbconn.WithOptOutOfAdvancedConnectionCheck(),
)
},
},
}

for _, tc := range tcs {
Expand Down
3 changes: 2 additions & 1 deletion internal/alloydb/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,14 @@ func NewRefreshAheadCache(
key *rsa.PrivateKey,
refreshTimeout time.Duration,
dialerID string,
disableMetadataExchange bool,
) *RefreshAheadCache {
ctx, cancel := context.WithCancel(context.Background())
i := &RefreshAheadCache{
instanceURI: instance,
logger: l,
l: rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
r: newAdminAPIClient(client, key, dialerID),
r: newAdminAPIClient(client, key, dialerID, disableMetadataExchange),
refreshTimeout: refreshTimeout,
ctx: ctx,
cancel: cancel,
Expand Down
3 changes: 3 additions & 0 deletions internal/alloydb/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ func TestConnectionInfo(t *testing.T) {
testInstanceURI(),
nullLogger{},
c, rsaKey, 30*time.Second, "dialer-id",
false,
)
if err != nil {
t.Fatalf("failed to create mock instance: %v", err)
Expand Down Expand Up @@ -210,6 +211,7 @@ func TestConnectInfoErrors(t *testing.T) {
testInstanceURI(),
nullLogger{},
c, rsaKey, 0, "dialer-id",
false,
)
if err != nil {
t.Fatalf("failed to initialize Instance: %v", err)
Expand Down Expand Up @@ -243,6 +245,7 @@ func TestClose(t *testing.T) {
testInstanceURI(),
nullLogger{},
c, rsaKey, 30, "dialer-ider",
false,
)
if err != nil {
t.Fatalf("failed to initialize Instance: %v", err)
Expand Down
7 changes: 2 additions & 5 deletions internal/alloydb/lazy.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,12 @@ func NewLazyRefreshCache(
key *rsa.PrivateKey,
_ time.Duration,
dialerID string,
disableMetadataExchange bool,
) *LazyRefreshCache {
return &LazyRefreshCache{
uri: uri,
logger: l,
r: newAdminAPIClient(
client,
key,
dialerID,
),
r: newAdminAPIClient(client, key, dialerID, disableMetadataExchange),
}
}

Expand Down
2 changes: 2 additions & 0 deletions internal/alloydb/lazy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func TestLazyRefreshCacheConnectionInfo(t *testing.T) {
cache := NewLazyRefreshCache(
testInstanceURI(), nullLogger{}, c,
rsaKey, 30*time.Second, "",
false,
)

ci, err := cache.ConnectionInfo(context.Background())
Expand Down Expand Up @@ -91,6 +92,7 @@ func TestLazyRefreshCacheForceRefresh(t *testing.T) {
cache := NewLazyRefreshCache(
testInstanceURI(), nullLogger{}, c,
rsaKey, 30*time.Second, "",
false,
)

_, err = cache.ConnectionInfo(context.Background())
Expand Down
16 changes: 11 additions & 5 deletions internal/alloydb/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ func fetchClientCertificate(
cl *alloydbadmin.AlloyDBAdminClient,
inst InstanceURI,
key *rsa.PrivateKey,
disableMetadataExchange bool,
) (cc *clientCertificate, err error) {
var end trace.EndSpanFunc
ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.FetchEphemeralCert")
Expand All @@ -138,7 +139,7 @@ func fetchClientCertificate(
),
PublicKey: buf.String(),
CertDuration: durationpb.New(time.Second * 3600),
UseMetadataExchange: true,
UseMetadataExchange: !disableMetadataExchange,
}
resp, err := cl.GenerateClientCertificate(ctx, req)
if err != nil {
Expand Down Expand Up @@ -225,11 +226,13 @@ func newAdminAPIClient(
client *alloydbadmin.AlloyDBAdminClient,
key *rsa.PrivateKey,
dialerID string,
disableMetadataExchange bool,
) adminAPIClient {
return adminAPIClient{
client: client,
key: key,
dialerID: dialerID,
client: client,
key: key,
dialerID: dialerID,
disableMetadataExchange: disableMetadataExchange,
}
}

Expand All @@ -242,6 +245,9 @@ type adminAPIClient struct {
key *rsa.PrivateKey
// dialerID is the unique ID of the associated dialer.
dialerID string
// disableMetadataExchange is a temporary addition to ease the migration to
// when the metadata exchange is required.
disableMetadataExchange bool
}

// ConnectionInfo holds all the data necessary to connect to an instance.
Expand Down Expand Up @@ -286,7 +292,7 @@ func (c adminAPIClient) connectionInfo(
certCh := make(chan certRes, 1)
go func() {
defer close(certCh)
cc, err := fetchClientCertificate(ctx, c.client, i, c.key)
cc, err := fetchClientCertificate(ctx, c.client, i, c.key, c.disableMetadataExchange)
certCh <- certRes{cc: cc, err: err}
}()

Expand Down
4 changes: 2 additions & 2 deletions internal/alloydb/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestRefresh(t *testing.T) {
if err != nil {
t.Fatalf("admin API client error: %v", err)
}
r := newAdminAPIClient(cl, rsaKey, testDialerID)
r := newAdminAPIClient(cl, rsaKey, testDialerID, false)
res, err := r.connectionInfo(context.Background(), cn)
if err != nil {
t.Fatalf("performRefresh unexpectedly failed with error: %v", err)
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestRefreshFailsFast(t *testing.T) {
if err != nil {
t.Fatalf("admin API client error: %v", err)
}
r := newAdminAPIClient(cl, rsaKey, testDialerID)
r := newAdminAPIClient(cl, rsaKey, testDialerID, false)

_, err = r.connectionInfo(context.Background(), cn)
if err != nil {
Expand Down
22 changes: 22 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ type dialerConfig struct {
logger debug.ContextLogger
lazyRefresh bool

// disableMetadataExchange is a temporary addition and will be removed in
// future versions.
disableMetadataExchange bool

staticConnInfo io.Reader
// err tracks any dialer options that may have failed.
err error
Expand Down Expand Up @@ -241,6 +245,24 @@ func WithStaticConnectionInfo(r io.Reader) Option {
}
}

// WithOptOutOfAdvancedConnectionCheck disables the dataplane permission check.
// It is intended only for clients who are running in an environment where the
// workload's IP address is otherwise unknown and cannot be allow-listed in a
// VPC Service Control security perimeter. This option is incompatible with IAM
// Authentication.
//
// NOTE: This option is for internal usage only and is meant to ease the
// migration when the advanced check will be required on the server. In future
// versions this will revert to a no-op and should not be used. If you think
// you need this option, open an issue on
// https://github.com/GoogleCloudPlatform/alloydb-go-connector for design
// advice.
func WithOptOutOfAdvancedConnectionCheck() Option {
return func(d *dialerConfig) {
d.disableMetadataExchange = true
}
}

// A DialOption is an option for configuring how a Dialer's Dial call is
// executed.
type DialOption func(d *dialCfg)
Expand Down
3 changes: 2 additions & 1 deletion pgxpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ import (
// that should be called when you're done with the database connection.
func connectPgx(
ctx context.Context, instURI, user, pass, dbname string,
opts ...alloydbconn.Option,
) (*pgxpool.Pool, func() error, error) {
// First initialize the dialer. alloydbconn.NewDialer accepts additional
// options to configure credentials, timeouts, etc.
//
// For details, see:
// https://pkg.go.dev/cloud.google.com/go/alloydbconn#Option
d, err := alloydbconn.NewDialer(ctx)
d, err := alloydbconn.NewDialer(ctx, opts...)
if err != nil {
noop := func() error { return nil }
return nil, noop, fmt.Errorf("failed to init Dialer: %v", err)
Expand Down
Loading