diff --git a/config/experimental/auth/auth.go b/config/experimental/auth/auth.go new file mode 100644 index 00000000..2f560498 --- /dev/null +++ b/config/experimental/auth/auth.go @@ -0,0 +1,212 @@ +// Package auth is an internal package that provides authentication utilities. +// +// IMPORTANT: This package is not meant to be used directly by consumers of the +// SDK and is subject to change without notice. +package auth + +import ( + "sync" + "time" + + "golang.org/x/oauth2" +) + +const ( + // Default duration for the stale period. The number has been set arbitrarily + // and might be changed in the future. + defaultStaleDuration = 3 * time.Minute + + // Disable the asynchronous token refresh by default. This is meant to + // change in the future once the feature is stable. + defaultDisableAsyncRefresh = true +) + +type Option func(*cachedTokenSource) + +// WithCachedToken sets the initial token to be used by a cached token source. +func WithCachedToken(t *oauth2.Token) Option { + return func(cts *cachedTokenSource) { + cts.cachedToken = t + } +} + +// WithAsyncRefresh enables or disables the asynchronous token refresh. +func WithAsyncRefresh(b bool) Option { + return func(cts *cachedTokenSource) { + cts.disableAsync = !b + } +} + +// NewCachedTokenSource wraps an [oauth2.TokenSource] to cache the tokens +// it returns. When enabled with WithAsyncRefresh (disabled by default), the +// cache refreshes tokens asynchronously a few minutes before they expire. +// +// The token cache is safe for concurrent use by multiple goroutines and will +// guarantee that only one token refresh is triggered at a time. +// +// The token cache does not take care of retries in case the token source +// returns an error; it is the responsibility of the provided token source to +// handle retries appropriately. +// +// If the TokenSource is already a cached token source (obtained by calling this +// function), it is returned as is. 
+func NewCachedTokenSource(ts oauth2.TokenSource, opts ...Option) oauth2.TokenSource { + // This is meant as a niche optimization to avoid double caching of the + // token source in situations where the caller needs caching guarantees + // but does not know if the token source is already cached. + if cts, ok := ts.(*cachedTokenSource); ok { + return cts + } + + cts := &cachedTokenSource{ + tokenSource: ts, + staleDuration: defaultStaleDuration, + disableAsync: defaultDisableAsyncRefresh, + cachedToken: nil, + timeNow: time.Now, + } + + for _, opt := range opts { + opt(cts) + } + + return cts +} + +type cachedTokenSource struct { + // The token source to obtain tokens from. + tokenSource oauth2.TokenSource + + // If true, only refresh the token with a blocking call when it is expired. + disableAsync bool + + // Duration during which a token is considered stale, see tokenState. + staleDuration time.Duration + + mu sync.Mutex + cachedToken *oauth2.Token + + // Indicates that an async refresh is in progress. This is used to prevent + // multiple async refreshes from being triggered at the same time. + isRefreshing bool + + // Error returned by the last async refresh. Async refreshes are disabled if + // this value is not nil so that the cache does not continue sending requests + // to a potentially failing server. The next blocking call re-enables async + // refreshes by resetting this value to nil before it checks or refreshes + // the token. + refreshErr error + + timeNow func() time.Time // for testing +} + +// Token returns a token from the cache or fetches a new one if the current +// token is expired. +func (cts *cachedTokenSource) Token() (*oauth2.Token, error) { + if cts.disableAsync { + return cts.blockingToken() + } + return cts.asyncToken() +} + +// tokenState represents the state of the token. Each token can be in one of +// the following three states: +// - fresh: The token is valid. +// - stale: The token is valid but will expire soon. 
+// - expired: The token has expired and cannot be used. +// +// Token state through time: +// +// issue time expiry time +// v v +// | fresh | stale | expired -> time +// | valid | +type tokenState int + +const ( + fresh tokenState = iota // The token is valid. + stale // The token is valid but will expire soon. + expired // The token has expired and cannot be used. +) + +// tokenState returns the state of the token. The function is not thread-safe +// and should be called with the lock held. +func (c *cachedTokenSource) tokenState() tokenState { + if c.cachedToken == nil { + return expired + } + switch lifeSpan := c.cachedToken.Expiry.Sub(c.timeNow()); { + case lifeSpan <= 0: + return expired + case lifeSpan <= c.staleDuration: + return stale + default: + return fresh + } +} + +func (cts *cachedTokenSource) asyncToken() (*oauth2.Token, error) { + cts.mu.Lock() + ts := cts.tokenState() + t := cts.cachedToken + cts.mu.Unlock() + + switch ts { + case fresh: + return t, nil + case stale: + cts.triggerAsyncRefresh() + return t, nil + default: // expired + return cts.blockingToken() + } +} + +func (cts *cachedTokenSource) blockingToken() (*oauth2.Token, error) { + cts.mu.Lock() + + // The lock is kept for the entire operation to ensure that only one + // blockingToken operation is running at a time. + defer cts.mu.Unlock() + + // This is important to recover from potential previous failed attempts + // to refresh the token asynchronously, see declaration of refreshErr for + // more information. + cts.isRefreshing = false + cts.refreshErr = nil + + // It's possible that the token got refreshed (either by a blockingToken or + // an asyncRefresh call) while this particular call was waiting to acquire + // the mutex. This check avoids refreshing the token again in such cases. 
+ if ts := cts.tokenState(); ts != expired { // fresh or stale + return cts.cachedToken, nil + } + + t, err := cts.tokenSource.Token() + if err != nil { + return nil, err + } + cts.cachedToken = t + return t, nil +} + +func (cts *cachedTokenSource) triggerAsyncRefresh() { + cts.mu.Lock() + defer cts.mu.Unlock() + if !cts.isRefreshing && cts.refreshErr == nil { + cts.isRefreshing = true + + go func() { + t, err := cts.tokenSource.Token() + + cts.mu.Lock() + defer cts.mu.Unlock() + cts.isRefreshing = false + if err != nil { + cts.refreshErr = err + return + } + cts.cachedToken = t + }() + } +} diff --git a/config/experimental/auth/auth_test.go b/config/experimental/auth/auth_test.go new file mode 100644 index 00000000..035ebe42 --- /dev/null +++ b/config/experimental/auth/auth_test.go @@ -0,0 +1,280 @@ +package auth + +import ( + "fmt" + "reflect" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/oauth2" +) + +type mockTokenSource func() (*oauth2.Token, error) + +func (m mockTokenSource) Token() (*oauth2.Token, error) { + return m() +} + +func TestNewCachedTokenSource_noCaching(t *testing.T) { + want := &cachedTokenSource{} + got := NewCachedTokenSource(want, nil) + if got != want { + t.Errorf("NewCachedTokenSource() = %v, want %v", got, want) + } +} + +func TestNewCachedTokenSource_default(t *testing.T) { + ts := mockTokenSource(func() (*oauth2.Token, error) { + return nil, nil + }) + + got, ok := NewCachedTokenSource(ts).(*cachedTokenSource) + if !ok { + t.Fatalf("NewCachedTokenSource() = %T, want *cachedTokenSource", got) + } + + if got.staleDuration != defaultStaleDuration { + t.Errorf("NewCachedTokenSource() staleDuration = %v, want %v", got.staleDuration, defaultStaleDuration) + } + if got.disableAsync != defaultDisableAsyncRefresh { + t.Errorf("NewCachedTokenSource() disableAsync = %v, want %v", got.disableAsync, defaultDisableAsyncRefresh) + } + if got.cachedToken != nil { + t.Errorf("NewCachedTokenSource() cachedToken = %v, want nil", 
got.cachedToken) + } +} + +func TestNewCachedTokenSource_options(t *testing.T) { + ts := mockTokenSource(func() (*oauth2.Token, error) { + return nil, nil + }) + + wantDisableAsync := false + wantCachedToken := &oauth2.Token{Expiry: time.Unix(42, 0)} + + opts := []Option{ + WithAsyncRefresh(!wantDisableAsync), + WithCachedToken(wantCachedToken), + } + + got, ok := NewCachedTokenSource(ts, opts...).(*cachedTokenSource) + if !ok { + t.Fatalf("NewCachedTokenSource() = %T, want *cachedTokenSource", got) + } + + if got.disableAsync != wantDisableAsync { + t.Errorf("NewCachedTokenSource(): disableAsync = %v, want %v", got.disableAsync, wantDisableAsync) + } + if got.cachedToken != wantCachedToken { + t.Errorf("NewCachedTokenSource(): cachedToken = %v, want %v", got.cachedToken, wantCachedToken) + } +} + +func TestCachedTokenSource_tokenState(t *testing.T) { + now := time.Unix(1337, 0) // mock value for time.Now() + + testCases := []struct { + token *oauth2.Token + staleDuration time.Duration + want tokenState + }{ + { + token: nil, + staleDuration: 10 * time.Minute, + want: expired, + }, + { + token: &oauth2.Token{ + Expiry: now.Add(-1 * time.Second), + }, + staleDuration: 10 * time.Minute, + want: expired, + }, + { + token: &oauth2.Token{ + Expiry: now.Add(1 * time.Hour), + }, + staleDuration: 10 * time.Minute, + want: fresh, + }, + { + token: &oauth2.Token{ + Expiry: now.Add(5 * time.Minute), + }, + staleDuration: 10 * time.Minute, + want: stale, + }, + } + + for _, tc := range testCases { + cts := &cachedTokenSource{ + cachedToken: tc.token, + staleDuration: tc.staleDuration, + disableAsync: false, + timeNow: func() time.Time { return now }, + } + + got := cts.tokenState() + + if got != tc.want { + t.Errorf("tokenState() = %v, want %v", got, tc.want) + } + } +} + +func TestCachedTokenSource_Token(t *testing.T) { + now := time.Unix(1337, 0) // mock value for time.Now() + nTokenCalls := 10 // number of goroutines calling Token() + testCases := []struct { + desc string 
// description of the test case + cachedToken *oauth2.Token // token cached before calling Token() + disableAsync bool // whether async refreshes are disabled or not + refreshErr error // whether the cache was in error state (NOTE(review): not applied to the cachedTokenSource built below) + + returnedToken *oauth2.Token // token returned by the token source + returnedError error // error returned by the token source + + wantCalls int // expected number of calls to the token source + wantToken *oauth2.Token // expected token in the cache + }{ + { + desc: "[Blocking] no cached token", + disableAsync: true, + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + }, + { + desc: "[Blocking] expired cached token", + disableAsync: true, + cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + }, + { + desc: "[Blocking] fresh cached token", + disableAsync: true, + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 0, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + }, + { + desc: "[Blocking] stale cached token", + disableAsync: true, + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + wantCalls: 0, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + }, + { + desc: "[Blocking] refresh error", + disableAsync: true, + returnedError: fmt.Errorf("test error"), + wantCalls: 10, + }, + { + desc: "[Blocking] recover from error", + disableAsync: true, + refreshErr: fmt.Errorf("refresh error"), + cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Minute)}, + returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Hour)}, + wantCalls: 10, + wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Hour)}, + }, + { + desc: "[Async] no cached token", + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * 
time.Hour)}, + }, + { + desc: "[Async] no cached token", + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + }, + { + desc: "[Async] expired cached token", + cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + }, + { + desc: "[Async] fresh cached token", + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 0, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + }, + { + desc: "[Async] stale cached token", + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + returnedToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Hour)}, + }, + { + desc: "[Async] refresh error", + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + returnedError: fmt.Errorf("test error"), + wantCalls: 1, + wantToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + }, + { + desc: "[Async] stale cached token, expired token returned", + cachedToken: &oauth2.Token{Expiry: now.Add(1 * time.Minute)}, + returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, + wantCalls: 10, + wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Second)}, + }, + { + desc: "[Async] recover from error", + refreshErr: fmt.Errorf("refresh error"), + cachedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Minute)}, + returnedToken: &oauth2.Token{Expiry: now.Add(-1 * time.Hour)}, + wantCalls: 10, + wantToken: &oauth2.Token{Expiry: now.Add(-1 * time.Hour)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + gotCalls := int32(0) + cts := &cachedTokenSource{ + disableAsync: tc.disableAsync, + staleDuration: 10 * time.Minute, + cachedToken: tc.cachedToken, + timeNow: func() time.Time { return now }, + tokenSource: 
mockTokenSource(func() (*oauth2.Token, error) { + atomic.AddInt32(&gotCalls, 1) + return tc.returnedToken, tc.returnedError + }), + } + + wg := sync.WaitGroup{} + for i := 0; i < nTokenCalls; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cts.Token() + }() + } + + wg.Wait() + + // Wait for async refreshes to finish. This part is a little brittle + // but necessary to ensure that the async refresh is done before + // checking the results. + time.Sleep(10 * time.Millisecond) + + if int(gotCalls) != tc.wantCalls { + t.Errorf("want %d calls to cts.tokenSource.Token(), got %d", tc.wantCalls, gotCalls) + } + if !reflect.DeepEqual(tc.wantToken, cts.cachedToken) { + t.Errorf("want cached token %v, got %v", tc.wantToken, cts.cachedToken) + } + }) + } +} diff --git a/config/oauth_visitors.go b/config/oauth_visitors.go index 2b172bf1..e9d3277c 100644 --- a/config/oauth_visitors.go +++ b/config/oauth_visitors.go @@ -5,14 +5,16 @@ import ( "net/http" "time" + "github.com/databricks/databricks-sdk-go/config/experimental/auth" "golang.org/x/oauth2" ) -// serviceToServiceVisitor returns a visitor that sets the Authorization header to the token from the auth token source -// and the provided secondary header to the token from the secondary token source. -func serviceToServiceVisitor(auth, secondary oauth2.TokenSource, secondaryHeader string) func(r *http.Request) error { - refreshableAuth := oauth2.ReuseTokenSource(nil, auth) - refreshableSecondary := oauth2.ReuseTokenSource(nil, secondary) +// serviceToServiceVisitor returns a visitor that sets the Authorization header +// to the token from the primary token source and the provided secondary header to +// the token from the secondary token source. 
+func serviceToServiceVisitor(primary, secondary oauth2.TokenSource, secondaryHeader string) func(r *http.Request) error { + refreshableAuth := auth.NewCachedTokenSource(primary) + refreshableSecondary := auth.NewCachedTokenSource(secondary) return func(r *http.Request) error { inner, err := refreshableAuth.Token() if err != nil { @@ -31,9 +33,9 @@ func serviceToServiceVisitor(auth, secondary oauth2.TokenSource, secondaryHeader // The same as serviceToServiceVisitor, but without a secondary token source. func refreshableVisitor(inner oauth2.TokenSource) func(r *http.Request) error { - refreshableAuth := oauth2.ReuseTokenSource(nil, inner) + cts := auth.NewCachedTokenSource(inner) return func(r *http.Request) error { - inner, err := refreshableAuth.Token() + inner, err := cts.Token() if err != nil { return fmt.Errorf("inner token: %w", err) } @@ -51,10 +53,32 @@ func azureVisitor(cfg *Config, inner func(*http.Request) error) func(*http.Reque } } -// azureReuseTokenSource calls into oauth2.ReuseTokenSourceWithExpiry with a 40 second expiry window. -// By default, the oauth2 library refreshes a token 10 seconds before it expires. -// Azure Databricks rejects tokens that expire in 30 seconds or less. -// We combine these and refresh the token 40 seconds before it expires. +// azureReuseTokenSource returns a cached token source that refreshes tokens 40 +// seconds before they expire. The reason for this is that Azure Databricks +// rejects tokens that expire in 30 seconds or less and we want to give a 10 +// second buffer. 
func azureReuseTokenSource(t *oauth2.Token, ts oauth2.TokenSource) oauth2.TokenSource { - return oauth2.ReuseTokenSourceWithExpiry(t, ts, 40*time.Second) + early := wrap(ts, func(t *oauth2.Token) *oauth2.Token { + t.Expiry = t.Expiry.Add(-40 * time.Second) + return t + }) + + return auth.NewCachedTokenSource(early, auth.WithCachedToken(t)) +} + +func wrap(ts oauth2.TokenSource, fn func(*oauth2.Token) *oauth2.Token) oauth2.TokenSource { + return &tokenSourceWrapper{fn: fn, inner: ts} +} + +type tokenSourceWrapper struct { + fn func(*oauth2.Token) *oauth2.Token + inner oauth2.TokenSource +} + +func (w *tokenSourceWrapper) Token() (*oauth2.Token, error) { + t, err := w.inner.Token() + if err != nil { + return nil, err + } + return w.fn(t), nil }