Skip to content

Commit

Permalink
Merge pull request #86 from tho/fix-issue-85
Browse files Browse the repository at this point in the history
Do not override jwks.ctx and jwks.cancel in Get
  • Loading branch information
MicahParks authored Apr 21, 2023
2 parents b63e165 + e57609a commit 006482b
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
5 changes: 4 additions & 1 deletion get.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) {
}

if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {
jwks.ctx, jwks.cancel = context.WithCancel(context.Background())
if jwks.ctx == nil {
jwks.ctx = context.Background()
}
jwks.ctx, jwks.cancel = context.WithCancel(jwks.ctx)
jwks.refreshRequests = make(chan refreshRequest, 1)
go jwks.backgroundRefresh()
}
Expand Down
82 changes: 82 additions & 0 deletions get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,85 @@ func TestJWKS_RefreshUsingBackgroundGoroutine(t *testing.T) {
t.Fatalf("Expected 2 refreshes, got %d.", count)
}
}

func TestJWKS_RefreshCancelCtx(t *testing.T) {
tests := map[string]struct {
provideOptionsCtx bool
cancelOptionsCtx bool
expectedRefreshes int
}{
"cancel Options.Ctx": {
provideOptionsCtx: true,
cancelOptionsCtx: true,
expectedRefreshes: 2,
},
"do not cancel Options.Ctx": {
provideOptionsCtx: true,
cancelOptionsCtx: false,
expectedRefreshes: 3,
},
"do not provide Options.Ctx": {
provideOptionsCtx: false,
cancelOptionsCtx: false,
expectedRefreshes: 3,
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
var (
ctx context.Context
cancel = func() {}
)
if tc.provideOptionsCtx {
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
}

var counter uint64
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddUint64(&counter, 1)
_, err := w.Write([]byte(jwksJSON))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
defer server.Close()

jwksURL := server.URL
opts := keyfunc.Options{
Ctx: ctx,
RefreshInterval: 1 * time.Second,
}
jwks, err := keyfunc.Get(jwksURL, opts)
if err != nil {
t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err)
}

// Wait for the first refresh to occur to ensure the
// JWKS gets refreshed at least once.
time.Sleep(1100 * time.Millisecond)

if tc.cancelOptionsCtx {
cancel()
}

// Wait for another refresh cycle to occur to ensure the
// JWKS either did or did not get refreshed depending on
// whether the passed Options.Ctx has been canceled.
time.Sleep(1101 * time.Millisecond)

jwks.EndBackground()

// Wait for another refresh cycle to occur to verify that
// the JWKS did not get refreshed after EndBackground()
// has been called.
time.Sleep(1100 * time.Millisecond)

count := atomic.LoadUint64(&counter)
if count != uint64(tc.expectedRefreshes) {
t.Fatalf("Expected %d refreshes, got %d.", tc.expectedRefreshes, count)
}
})
}
}

0 comments on commit 006482b

Please sign in to comment.