Skip to content

Commit

Permalink
feat: rfc9207
Browse files Browse the repository at this point in the history
This implements RFC9207 OAuth 2.0 Authorization Server Issuer Identification. See Also: https://datatracker.ietf.org/doc/html/rfc9207.
  • Loading branch information
james-d-elliott committed Dec 22, 2023
1 parent 0b6232d commit f30e737
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 21 deletions.
67 changes: 47 additions & 20 deletions authorize_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ func TestWriteAuthorizeError(t *testing.T) {
err *RFC6749Error
debug bool
doNotUseLegacyFormat bool
mock func(*MockResponseWriter, *MockAuthorizeRequester, http.Header)
setup func(t *testing.T, provider *Fosite, rw *MockResponseWriter, requester *MockAuthorizeRequester, header http.Header)
checkHeader func(*testing.T, http.Header)
}{
{
name: "ShouldHandleInvalidGrantResponseModeDefault",
err: ErrInvalidGrant,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(false)
req.EXPECT().GetResponseMode().Return(ResponseModeDefault)
rw.EXPECT().Header().Times(3).Return(header)
Expand All @@ -70,7 +70,7 @@ func TestWriteAuthorizeError(t *testing.T) {
name: "ShouldHandleInvalidRequestResponseModeQueryWithDebug",
debug: true,
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -92,7 +92,7 @@ func TestWriteAuthorizeError(t *testing.T) {
debug: true,
doNotUseLegacyFormat: true,
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -113,7 +113,7 @@ func TestWriteAuthorizeError(t *testing.T) {
name: "ShouldHandleInvalidRequestResponseModeQueryWithNonLegacy",
doNotUseLegacyFormat: true,
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -133,7 +133,7 @@ func TestWriteAuthorizeError(t *testing.T) {
{
name: "ShouldHandleInvalidRequestResponseModeDefault",
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -153,7 +153,7 @@ func TestWriteAuthorizeError(t *testing.T) {
{
name: "ShouldHandleInvalidRequestResponseModeQuery",
err: ErrInvalidRequest,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -173,7 +173,7 @@ func TestWriteAuthorizeError(t *testing.T) {
{
name: "ShouldHandleUnsupportedGrantTypeResponseModeFragment",
err: ErrUnsupportedGrantType,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -193,7 +193,7 @@ func TestWriteAuthorizeError(t *testing.T) {
{
name: "ShouldHandleInvalidRequestResponseModeFragment",
err: ErrInvalidRequest,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -213,7 +213,7 @@ func TestWriteAuthorizeError(t *testing.T) {
{
name: "ShouldHandleInvalidRequestResponseModeFragmentAltURL",
err: ErrInvalidRequest,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -233,7 +233,7 @@ func TestWriteAuthorizeError(t *testing.T) {
{
name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugOmitted",
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -254,7 +254,7 @@ func TestWriteAuthorizeError(t *testing.T) {
name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebug",
err: ErrInvalidRequest.WithDebug("with-debug"),
debug: true,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -276,7 +276,7 @@ func TestWriteAuthorizeError(t *testing.T) {
err: ErrInvalidRequest.WithDebug("with-debug"),
debug: true,
doNotUseLegacyFormat: true,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -299,7 +299,7 @@ func TestWriteAuthorizeError(t *testing.T) {
name: "ShouldHandleInvalidRequestResponseModeFragmentWithoutLegacy",
err: ErrInvalidRequest.WithDebug("with-debug"),
doNotUseLegacyFormat: true,
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -322,7 +322,7 @@ func TestWriteAuthorizeError(t *testing.T) {
{
name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugOmittedAltURL",
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -343,7 +343,7 @@ func TestWriteAuthorizeError(t *testing.T) {
name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugAltURL",
debug: true,
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -364,7 +364,7 @@ func TestWriteAuthorizeError(t *testing.T) {
name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugAltURLImplicitIDToken",
debug: true,
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -385,7 +385,7 @@ func TestWriteAuthorizeError(t *testing.T) {
name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugAltURLImplicitToken",
debug: true,
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
Expand All @@ -402,11 +402,38 @@ func TestWriteAuthorizeError(t *testing.T) {
assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma))
},
},
{
name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugAltURLImplicitTokenWithIdentifier",
debug: true,
err: ErrInvalidRequest.WithDebug("with-debug"),
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
provider.Config = &Config{
SendDebugMessagesToClients: false,
UseLegacyErrorFormat: false,
AuthorizationServerIdentificationIssuer: "https://example.com",
}

req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
req.EXPECT().GetResponseTypes().AnyTimes().Return(Arguments([]string{consts.ResponseTypeImplicitFlowToken}))
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).AnyTimes()
rw.EXPECT().Header().Times(3).Return(header)
rw.EXPECT().WriteHeader(http.StatusSeeOther)
},
checkHeader: func(t *testing.T, header http.Header) {
a, _ := url.Parse("https://foobar.com/?foo=bar#error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.+Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&iss=https%3A%2F%2Fexample.com&state=foostate")
b, _ := url.Parse(header.Get(consts.HeaderLocation))
assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String())
assert.Equal(t, consts.CacheControlNoStore, header.Get(consts.HeaderCacheControl))
assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma))
},
},
{
name: "ShouldHandleInvalidRequestResponseModePostWithDebugAltURLImplicitToken",
debug: true,
err: ErrInvalidRequest.WithDebug("with-debug"),
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
setup: func(t *testing.T, provider *Fosite, rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) {
req.EXPECT().IsRedirectURIValid().Return(true)
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
req.EXPECT().GetState().Return("foostate")
Expand Down Expand Up @@ -439,7 +466,7 @@ func TestWriteAuthorizeError(t *testing.T) {

header := http.Header{}

tc.mock(rw, req, header)
tc.setup(t, provider, rw, req, header)
provider.WriteAuthorizeError(context.Background(), rw, req, tc.err)
tc.checkHeader(t, header)
})
Expand Down
5 changes: 5 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ type IDTokenIssuerProvider interface {
GetIDTokenIssuer(ctx context.Context) string
}

// AuthorizationServerIdentificationIssuerProvider provides OAuth 2.0 Authorization Server Issuer Identification related methods.
type AuthorizationServerIdentificationIssuerProvider interface {
GetAuthorizationServerIdentificationIssuer(ctx context.Context) (issuer string)
}

// JWTScopeFieldProvider returns the provider for configuring the JWT scope field.
type JWTScopeFieldProvider interface {
// GetJWTScopeField returns the JWT scope field.
Expand Down
8 changes: 8 additions & 0 deletions config_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var (
_ GetJWTMaxDurationProvider = (*Config)(nil)
_ IDTokenLifespanProvider = (*Config)(nil)
_ IDTokenIssuerProvider = (*Config)(nil)
_ AuthorizationServerIdentificationIssuerProvider = (*Config)(nil)
_ JWKSFetcherStrategyProvider = (*Config)(nil)
_ ClientAuthenticationStrategyProvider = (*Config)(nil)
_ SendDebugMessagesToClientsProvider = (*Config)(nil)
Expand Down Expand Up @@ -87,6 +88,9 @@ type Config struct {
// IDTokenIssuer sets the default issuer of the ID Token.
IDTokenIssuer string

// AuthorizationServerIdentificationIssuer string sets the issuer identifier for authorization responses.
AuthorizationServerIdentificationIssuer string

// HashCost sets the cost of the password hashing cost. Defaults to 12.
HashCost int

Expand Down Expand Up @@ -321,6 +325,10 @@ func (c *Config) GetIDTokenIssuer(ctx context.Context) string {
return c.IDTokenIssuer
}

func (c *Config) GetAuthorizationServerIdentificationIssuer(ctx context.Context) (issuer string) {
return c.AuthorizationServerIdentificationIssuer
}

// GetGrantTypeJWTBearerIssuedDateOptional returns the GrantTypeJWTBearerIssuedDateOptional field.
func (c *Config) GetGrantTypeJWTBearerIssuedDateOptional(ctx context.Context) bool {
return c.GrantTypeJWTBearerIssuedDateOptional
Expand Down
1 change: 1 addition & 0 deletions fosite.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ var _ Provider = (*Fosite)(nil)
type Configurator interface {
IDTokenIssuerProvider
IDTokenLifespanProvider
AuthorizationServerIdentificationIssuerProvider
AllowedPromptsProvider
EnforcePKCEProvider
EnforcePKCEForPublicClientsProvider
Expand Down
10 changes: 9 additions & 1 deletion response_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,19 @@ func (h *DefaultResponseModeHandler) handleWriteAuthorizeResponse(ctx context.Co

rm := requester.GetResponseMode()

if rm == ResponseModeJWT {
switch rm {
case ResponseModeJWT:
if requester.GetResponseTypes().ExactOne(consts.ResponseTypeAuthorizationCodeFlow) {
rm = ResponseModeQueryJWT
} else {
rm = ResponseModeFragmentJWT
}
case ResponseModeFormPost, ResponseModeQuery, ResponseModeFragment, ResponseModeDefault:
// RFC9207 OAuth 2.0 Authorization Server Issuer Identification.
// See Also: https://datatracker.ietf.org/doc/html/rfc9207.
if issuer := h.Config.GetAuthorizationServerIdentificationIssuer(ctx); len(issuer) != 0 {
parameters.Set(consts.FormParameterIssuer, issuer)
}
}

switch rm {
Expand Down Expand Up @@ -220,5 +227,6 @@ type ResponseModeHandlerConfigurator interface {
JWTSecuredAuthorizeResponseModeLifespanProvider
MessageCatalogProvider
SendDebugMessagesToClientsProvider
AuthorizationServerIdentificationIssuerProvider
UseLegacyErrorFormatProvider
}

0 comments on commit f30e737

Please sign in to comment.