diff --git a/authorize_error.go b/authorize_error.go index c7e97669..1a2bac0f 100644 --- a/authorize_error.go +++ b/authorize_error.go @@ -12,60 +12,40 @@ import ( "authelia.com/provider/oauth2/internal/consts" ) -func (f *Fosite) WriteAuthorizeError(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, err error) { +func (f *Fosite) WriteAuthorizeError(ctx context.Context, rw http.ResponseWriter, requester AuthorizeRequester, err error) { rw.Header().Set(consts.HeaderCacheControl, consts.CacheControlNoStore) rw.Header().Set(consts.HeaderPragma, consts.PragmaNoCache) - if f.ResponseModeHandler(ctx).ResponseModes().Has(ar.GetResponseMode()) { - f.ResponseModeHandler(ctx).WriteAuthorizeError(ctx, rw, ar, err) - return - } - - rfc := ErrorToRFC6749Error(err).WithLegacyFormat(f.Config.GetUseLegacyErrorFormat(ctx)).WithExposeDebug(f.Config.GetSendDebugMessagesToClients(ctx)).WithLocalizer(f.Config.GetMessageCatalog(ctx), getLangFromRequester(ar)) - if !ar.IsRedirectURIValid() { - rw.Header().Set(consts.HeaderContentType, consts.ContentTypeApplicationJSON) + for _, handler := range f.ResponseModeHandlers(ctx) { + if handler.ResponseModes().Has(requester.GetResponseMode()) { + handler.WriteAuthorizeError(ctx, rw, requester, err) - js, err := json.Marshal(rfc) - if err != nil { - if f.Config.GetSendDebugMessagesToClients(ctx) { - errorMessage := EscapeJSONString(err.Error()) - http.Error(rw, fmt.Sprintf(`{"error":"server_error","error_description":"%s"}`, errorMessage), http.StatusInternalServerError) - } else { - http.Error(rw, `{"error":"server_error"}`, http.StatusInternalServerError) - } return } - - rw.WriteHeader(rfc.CodeField) - _, _ = rw.Write(js) - return } - redirectURI := ar.GetRedirectURI() + f.handleWriteAuthorizeErrorJSON(ctx, rw, ErrServerError.WithHint("The Authorization Server was unable to process the requested Response Mode.")) +} - // The endpoint URI MUST NOT include a fragment component. - redirectURI.Fragment = "" +func (f *Fosite) handleWriteAuthorizeErrorJSON(ctx context.Context, rw http.ResponseWriter, rfc *RFC6749Error) { + rw.Header().Set(consts.HeaderContentType, consts.ContentTypeApplicationJSON) - errors := rfc.ToValues() - errors.Set(consts.FormParameterState, ar.GetState()) + var ( + data []byte + err error + ) - var redirectURIString string - if ar.GetResponseMode() == ResponseModeFormPost { - rw.Header().Set(consts.HeaderContentType, consts.ContentTypeTextHTML) - WriteAuthorizeFormPostResponse(redirectURI.String(), errors, GetPostFormHTMLTemplate(ctx, f), rw) - return - } else if ar.GetResponseMode() == ResponseModeFragment { - redirectURIString = redirectURI.String() + "#" + errors.Encode() - } else { - for key, values := range redirectURI.Query() { - for _, value := range values { - errors.Add(key, value) - } + if data, err = json.Marshal(rfc); err != nil { + if f.Config.GetSendDebugMessagesToClients(ctx) { + errorMessage := EscapeJSONString(err.Error()) + http.Error(rw, fmt.Sprintf(`{"error":"server_error","error_description":"%s"}`, errorMessage), http.StatusInternalServerError) + } else { + http.Error(rw, `{"error":"server_error"}`, http.StatusInternalServerError) } - redirectURI.RawQuery = errors.Encode() - redirectURIString = redirectURI.String() + + return } - rw.Header().Set(consts.HeaderLocation, redirectURIString) - rw.WriteHeader(http.StatusSeeOther) + rw.WriteHeader(rfc.CodeField) + _, _ = rw.Write(data) } diff --git a/authorize_error_test.go b/authorize_error_test.go index 20c20e0e..56939c44 100644 --- a/authorize_error_test.go +++ b/authorize_error_test.go @@ -5,7 +5,6 @@ package oauth2_test import ( "context" - "fmt" "net/http" "net/url" "testing" @@ -43,35 +42,35 @@ func TestWriteAuthorizeError(t *testing.T) { purls = append(purls, purl) } - header := http.Header{} - for k, c := range []struct { + testCases := []struct { + name string err *RFC6749Error debug bool doNotUseLegacyFormat bool - mock func(*MockResponseWriter, *MockAuthorizeRequester) - checkHeader func(*testing.T, int) + mock func(*MockResponseWriter, *MockAuthorizeRequester, http.Header) + checkHeader func(*testing.T, http.Header) }{ - // 0 { - err: ErrInvalidGrant, - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + name: "ShouldHandleInvalidGrantResponseModeDefault", + err: ErrInvalidGrant, + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(false) req.EXPECT().GetResponseMode().Return(ResponseModeDefault) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusBadRequest) rw.EXPECT().Write(gomock.Any()) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { assert.Equal(t, consts.ContentTypeApplicationJSON, header.Get(consts.HeaderContentType)) assert.Equal(t, consts.CacheControlNoStore, header.Get(consts.HeaderCacheControl)) assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 1 { + name: "ShouldHandleInvalidRequestResponseModeQueryWithDebug", debug: true, err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -80,7 +79,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?error=invalid_request&error_debug=with-debug&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b) @@ -88,12 +87,12 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 2 { + name: "ShouldHandleInvalidRequestResponseModeQueryWithDebugNonLegacy", debug: true, doNotUseLegacyFormat: true, err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -102,7 +101,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.+Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.+with-debug&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b) @@ -110,11 +109,11 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 3 { + name: "ShouldHandleInvalidRequestResponseModeQueryWithNonLegacy", doNotUseLegacyFormat: true, err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -123,7 +122,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.+Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b) @@ -131,10 +130,10 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 4 { - err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + name: "ShouldHandleInvalidRequestResponseModeDefault", + err: ErrInvalidRequest.WithDebug("with-debug"), + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -143,7 +142,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b) @@ -151,10 +150,10 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 5 { - err: ErrInvalidRequest, - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + name: "ShouldHandleInvalidRequestResponseModeQuery", + err: ErrInvalidRequest, + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") @@ -163,7 +162,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&foo=bar&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b) @@ -171,10 +170,10 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 6 { - err: ErrUnsupportedGrantType, - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + name: "ShouldHandleUnsupportedGrantTypeResponseModeFragment", + err: ErrUnsupportedGrantType, + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") @@ -183,7 +182,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?foo=bar#error=unsupported_grant_type&error_description=The+authorization+grant+type+is+not+supported+by+the+authorization+server.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b) @@ -191,10 +190,10 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 7 { - err: ErrInvalidRequest, - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + name: "ShouldHandleInvalidRequestResponseModeFragment", + err: ErrInvalidRequest, + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -203,7 +202,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/#error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String()) @@ -211,10 +210,10 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 8 { - err: ErrInvalidRequest, - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + name: "ShouldHandleInvalidRequestResponseModeFragmentAltURL", + err: ErrInvalidRequest, + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") @@ -223,7 +222,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?foo=bar#error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String()) @@ -231,10 +230,10 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 9 { - err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugOmitted", + err: ErrInvalidRequest.WithDebug("with-debug"), + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -243,7 +242,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/#error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String()) @@ -251,11 +250,11 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 10 { + name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebug", err: ErrInvalidRequest.WithDebug("with-debug"), debug: true, - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -264,7 +263,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/#error=invalid_request&error_debug=with-debug&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String()) @@ -272,12 +271,12 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 11 { + name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugWithNonLegacy", err: ErrInvalidRequest.WithDebug("with-debug"), debug: true, doNotUseLegacyFormat: true, - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -286,7 +285,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/#error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.+Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.+with-debug&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.NotContains(t, header.Get(consts.HeaderLocation), "error_hint") @@ -296,11 +295,11 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 12 { + name: "ShouldHandleInvalidRequestResponseModeFragmentWithoutLegacy", err: ErrInvalidRequest.WithDebug("with-debug"), doNotUseLegacyFormat: true, - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") @@ -309,7 +308,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/#error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.+Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.NotContains(t, header.Get(consts.HeaderLocation), "error_hint") @@ -320,10 +319,10 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 13 { - err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugOmittedAltURL", + err: ErrInvalidRequest.WithDebug("with-debug"), + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") @@ -332,7 +331,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?foo=bar#error=invalid_request&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String()) @@ -340,11 +339,11 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 14 { + name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugAltURL", debug: true, err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") @@ -353,7 +352,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?foo=bar#error=invalid_request&error_debug=with-debug&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String()) @@ -361,11 +360,11 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 15 { + name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugAltURLImplicitIDToken", debug: true, err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") @@ -374,7 +373,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?foo=bar#error=invalid_request&error_debug=with-debug&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String()) @@ -382,11 +381,11 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 16 { + name: "ShouldHandleInvalidRequestResponseModeFragmentWithDebugAltURLImplicitToken", debug: true, err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") @@ -395,7 +394,7 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { a, _ := url.Parse("https://foobar.com/?foo=bar#error=invalid_request&error_debug=with-debug&error_description=The+request+is+missing+a+required+parameter%2C+includes+an+invalid+parameter+value%2C+includes+a+parameter+more+than+once%2C+or+is+otherwise+malformed.&error_hint=Make+sure+that+the+various+parameters+are+correct%2C+be+aware+of+case+sensitivity+and+trim+your+parameters.+Make+sure+that+the+client+you+are+using+has+exactly+whitelisted+the+redirect_uri+you+specified.&state=foostate") b, _ := url.Parse(header.Get(consts.HeaderLocation)) assert.Equal(t, a, b, "\n\t%s\n\t%s", header.Get(consts.HeaderLocation), a.String()) @@ -403,11 +402,11 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) }, }, - // 17 { + name: "ShouldHandleInvalidRequestResponseModePostWithDebugAltURLImplicitToken", debug: true, err: ErrInvalidRequest.WithDebug("with-debug"), - mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester, header http.Header) { req.EXPECT().IsRedirectURIValid().Return(true) req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") @@ -416,18 +415,20 @@ func TestWriteAuthorizeError(t *testing.T) { rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().Write(gomock.Any()).AnyTimes() }, - checkHeader: func(t *testing.T, k int) { + checkHeader: func(t *testing.T, header http.Header) { assert.Equal(t, consts.CacheControlNoStore, header.Get(consts.HeaderCacheControl)) assert.Equal(t, consts.PragmaNoCache, header.Get(consts.HeaderPragma)) assert.Equal(t, consts.ContentTypeTextHTML, header.Get(consts.HeaderContentType)) }, }, - } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { provider := &Fosite{ Config: &Config{ - SendDebugMessagesToClients: c.debug, - UseLegacyErrorFormat: !c.doNotUseLegacyFormat, + SendDebugMessagesToClients: tc.debug, + UseLegacyErrorFormat: !tc.doNotUseLegacyFormat, }, } @@ -436,10 +437,11 @@ func TestWriteAuthorizeError(t *testing.T) { rw := NewMockResponseWriter(ctrl) req := NewMockAuthorizeRequester(ctrl) - c.mock(rw, req) - provider.WriteAuthorizeError(context.Background(), rw, req, c.err) - c.checkHeader(t, k) - header = http.Header{} + header := http.Header{} + + tc.mock(rw, req, header) + provider.WriteAuthorizeError(context.Background(), rw, req, tc.err) + tc.checkHeader(t, header) }) } } diff --git a/authorize_helper.go b/authorize_helper.go index 0f6c2f66..bffe10c4 100644 --- a/authorize_helper.go +++ b/authorize_helper.go @@ -198,9 +198,10 @@ func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, t }) } -func GetPostFormHTMLTemplate(ctx context.Context, f *Fosite) *template.Template { - if t := f.Config.GetFormPostHTMLTemplate(ctx); t != nil { +func GetPostFormHTMLTemplate(ctx context.Context, c FormPostHTMLTemplateProvider) *template.Template { + if t := c.GetFormPostHTMLTemplate(ctx); t != nil { return t } + return DefaultFormPostTemplate } diff --git a/authorize_request.go b/authorize_request.go index 4dff3aaf..b5783348 100644 --- a/authorize_request.go +++ b/authorize_request.go @@ -12,10 +12,14 @@ import ( type ResponseModeType string const ( - ResponseModeDefault = ResponseModeType("") - ResponseModeFormPost = ResponseModeType(consts.ResponseModeFormPost) - ResponseModeQuery = ResponseModeType(consts.ResponseModeQuery) - ResponseModeFragment = ResponseModeType(consts.ResponseModeFragment) + ResponseModeDefault = ResponseModeType("") + ResponseModeFormPost = ResponseModeType(consts.ResponseModeFormPost) + ResponseModeQuery = ResponseModeType(consts.ResponseModeQuery) + ResponseModeFragment = ResponseModeType(consts.ResponseModeFragment) + ResponseModeFormPostJWT = ResponseModeType(consts.ResponseModeFormPostJWT) + ResponseModeQueryJWT = ResponseModeType(consts.ResponseModeQueryJWT) + ResponseModeFragmentJWT = ResponseModeType(consts.ResponseModeFragmentJWT) + ResponseModeJWT = ResponseModeType(consts.ResponseModeJWT) ) // AuthorizeRequest is an implementation of AuthorizeRequester diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 60be4d9e..453d691d 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -228,25 +228,19 @@ func (f *Fosite) validateResponseTypes(r *http.Request, request *AuthorizeReques } func (f *Fosite) ParseResponseMode(ctx context.Context, r *http.Request, request *AuthorizeRequest) error { - switch responseMode := r.Form.Get(consts.FormParameterResponseMode); responseMode { - case string(ResponseModeDefault): - request.ResponseMode = ResponseModeDefault - case string(ResponseModeFragment): - request.ResponseMode = ResponseModeFragment - case string(ResponseModeQuery): - request.ResponseMode = ResponseModeQuery - case string(ResponseModeFormPost): - request.ResponseMode = ResponseModeFormPost - default: - rm := ResponseModeType(responseMode) - if f.ResponseModeHandler(ctx).ResponseModes().Has(rm) { - request.ResponseMode = rm - break + m := r.Form.Get(consts.FormParameterResponseMode) + + for _, handler := range f.ResponseModeHandlers(ctx) { + mode := ResponseModeType(m) + + if handler.ResponseModes().Has(mode) { + request.ResponseMode = mode + + return nil } - return errorsx.WithStack(ErrUnsupportedResponseMode.WithHintf("Request with unsupported response_mode \"%s\".", responseMode)) } - return nil + return errorsx.WithStack(ErrUnsupportedResponseMode.WithHintf("Request with unsupported response_mode \"%s\".", m)) } func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest) error { @@ -334,6 +328,7 @@ func (f *Fosite) newAuthorizeRequest(ctx context.Context, r *http.Request, isPAR if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart { return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error())) } + request.Form = r.Form // Save state to the request to be returned in error conditions (https://github.com/ory/hydra/issues/1642) @@ -355,6 +350,7 @@ func (f *Fosite) newAuthorizeRequest(ctx context.Context, r *http.Request, isPAR if err != nil { return request, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client does not exist.").WithWrap(err).WithDebug(err.Error())) } + request.Client = client // Now that the base fields (state and client) are populated, we extract all the information diff --git a/authorize_write.go b/authorize_write.go index a43ab782..36ec6359 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -10,52 +10,16 @@ import ( "authelia.com/provider/oauth2/internal/consts" ) -func (f *Fosite) WriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder) { - // Set custom headers, e.g. "X-MySuperCoolCustomHeader" or "X-DONT-CACHE-ME"... - wh := rw.Header() - rh := resp.GetHeader() - for k := range rh { - wh.Set(k, rh.Get(k)) - } - - wh.Set(consts.HeaderCacheControl, consts.CacheControlNoStore) - wh.Set(consts.HeaderPragma, consts.PragmaNoCache) - - redir := ar.GetRedirectURI() - switch rm := ar.GetResponseMode(); rm { - case ResponseModeFormPost: - //form_post - rw.Header().Add(consts.HeaderContentType, consts.ContentTypeTextHTML) - WriteAuthorizeFormPostResponse(redir.String(), resp.GetParameters(), GetPostFormHTMLTemplate(ctx, f), rw) - return - case ResponseModeQuery, ResponseModeDefault: - // Explicit grants - q := redir.Query() - rq := resp.GetParameters() - for k := range rq { - q.Set(k, rq.Get(k)) - } - redir.RawQuery = q.Encode() - sendRedirect(redir.String(), rw) - return - case ResponseModeFragment: - // Implicit grants - // The endpoint URI MUST NOT include a fragment component. - redir.Fragment = "" +func (f *Fosite) WriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, requester AuthorizeRequester, responder AuthorizeResponder) { + for _, handler := range f.ResponseModeHandlers(ctx) { + if handler.ResponseModes().Has(requester.GetResponseMode()) { + handler.WriteAuthorizeResponse(ctx, rw, requester, responder) - u := redir.String() - fr := resp.GetParameters() - if len(fr) > 0 { - u = u + "#" + fr.Encode() - } - sendRedirect(u, rw) - return - default: - if f.ResponseModeHandler(ctx).ResponseModes().Has(rm) { - f.ResponseModeHandler(ctx).WriteAuthorizeResponse(ctx, rw, ar, resp) return } } + + f.handleWriteAuthorizeErrorJSON(ctx, rw, ErrServerError.WithHint("The Authorization Server was unable to process the requested Response Mode.")) } // https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 diff --git a/authorize_write_test.go b/authorize_write_test.go index 22d5960c..18703210 100644 --- a/authorize_write_test.go +++ b/authorize_write_test.go @@ -18,30 +18,24 @@ import ( ) func TestWriteAuthorizeResponse(t *testing.T) { - provider := &Fosite{Config: new(Config)} - header := http.Header{} - ctrl := gomock.NewController(t) - rw := NewMockResponseWriter(ctrl) - ar := NewMockAuthorizeRequester(ctrl) - resp := NewMockAuthorizeResponder(ctrl) - defer ctrl.Finish() - - for k, c := range []struct { - setup func() - expect func() + testCases := []struct { + name string + setup func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) + expect func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) }{ { - setup: func() { + name: "ShouldWriteResponseModeDefault", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeDefault) - resp.EXPECT().GetParameters().Return(url.Values{}) - resp.EXPECT().GetHeader().Return(http.Header{}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeDefault).Times(2) + responder.EXPECT().GetParameters().Return(url.Values{}) + responder.EXPECT().GetHeader().Return(http.Header{}) rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { assert.Equal(t, http.Header{ consts.HeaderLocation: []string{"https://foobar.com/?foo=bar"}, consts.HeaderCacheControl: []string{consts.CacheControlNoStore}, @@ -50,17 +44,18 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, }, { - setup: func() { + name: "ShouldWriteResponseModeFragment", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) - resp.EXPECT().GetParameters().Return(url.Values{"bar": {"baz"}}) - resp.EXPECT().GetHeader().Return(http.Header{}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) + responder.EXPECT().GetParameters().Return(url.Values{"bar": {"baz"}}) + responder.EXPECT().GetHeader().Return(http.Header{}) rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { assert.Equal(t, http.Header{ consts.HeaderLocation: []string{"https://foobar.com/?foo=bar#bar=baz"}, consts.HeaderCacheControl: []string{consts.CacheControlNoStore}, @@ -69,17 +64,18 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, }, { - setup: func() { + name: "ShouldWriteResponseModeQuery", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeQuery) - resp.EXPECT().GetParameters().Return(url.Values{"bar": {"baz"}}) - resp.EXPECT().GetHeader().Return(http.Header{}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeQuery).Times(2) + responder.EXPECT().GetParameters().Return(url.Values{"bar": {"baz"}}) + responder.EXPECT().GetHeader().Return(http.Header{}) rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { expectedUrl, _ := url.Parse("https://foobar.com/?foo=bar&bar=baz") actualUrl, err := url.Parse(header.Get(consts.HeaderLocation)) assert.Nil(t, err) @@ -89,17 +85,18 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, }, { - setup: func() { + name: "ShouldWriteResponseModeFragmentWithCustomHeaders", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) - resp.EXPECT().GetParameters().Return(url.Values{"bar": {"b+az ab"}}) - resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) + responder.EXPECT().GetParameters().Return(url.Values{"bar": {"b+az ab"}}) + responder.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { assert.Equal(t, http.Header{ "X-Bar": {"baz"}, consts.HeaderLocation: {"https://foobar.com/?foo=bar#bar=b%2Baz+ab"}, @@ -109,17 +106,18 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, }, { - setup: func() { + name: "ShouldWriteResponseModeQueryWithCustomHeaders", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeQuery) - resp.EXPECT().GetParameters().Return(url.Values{"bar": {"b+az"}, consts.FormParameterScope: {"a b"}}) - resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeQuery).Times(2) + responder.EXPECT().GetParameters().Return(url.Values{"bar": {"b+az"}, consts.FormParameterScope: {"a b"}}) + responder.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { expectedUrl, err := url.Parse("https://foobar.com/?foo=bar&bar=b%2Baz&scope=a+b") assert.Nil(t, err) actualUrl, err := url.Parse(header.Get(consts.HeaderLocation)) @@ -131,17 +129,18 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, }, { - setup: func() { + name: "ShouldWriteResponseModeFragmentWithCustomHeadersAndSpecialChars", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) - resp.EXPECT().GetParameters().Return(url.Values{consts.FormParameterScope: {"api:*"}}) - resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) + responder.EXPECT().GetParameters().Return(url.Values{consts.FormParameterScope: {"api:*"}}) + responder.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { assert.Equal(t, http.Header{ "X-Bar": {"baz"}, consts.HeaderLocation: {"https://foobar.com/?foo=bar#scope=api%3A%2A"}, @@ -151,17 +150,18 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, }, { - setup: func() { + name: "ShouldWriteResponseModeFragmentWithCustomParameters", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar#bar=baz") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) - resp.EXPECT().GetParameters().Return(url.Values{"qux": {"quux"}}) - resp.EXPECT().GetHeader().Return(http.Header{}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) + responder.EXPECT().GetParameters().Return(url.Values{"qux": {"quux"}}) + responder.EXPECT().GetHeader().Return(http.Header{}) rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { assert.Equal(t, http.Header{ consts.HeaderLocation: {"https://foobar.com/?foo=bar#qux=quux"}, consts.HeaderCacheControl: []string{consts.CacheControlNoStore}, @@ -170,17 +170,18 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, }, { - setup: func() { + name: "ShouldWriteResponseModeFragmentWithEncodedState", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) - resp.EXPECT().GetParameters().Return(url.Values{consts.FormParameterState: {"{\"a\":\"b=c&d=e\"}"}}) - resp.EXPECT().GetHeader().Return(http.Header{}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) + responder.EXPECT().GetParameters().Return(url.Values{consts.FormParameterState: {"{\"a\":\"b=c&d=e\"}"}}) + responder.EXPECT().GetHeader().Return(http.Header{}) rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusSeeOther) }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { assert.Equal(t, http.Header{ consts.HeaderLocation: {"https://foobar.com/?foo=bar#state=%7B%22a%22%3A%22b%3Dc%26d%3De%22%7D"}, consts.HeaderCacheControl: []string{consts.CacheControlNoStore}, @@ -189,26 +190,39 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, }, { - setup: func() { + name: "ShouldWriteResponseModeFormPostWithValues", + setup: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { redir, _ := url.Parse("https://foobar.com/?foo=bar") - ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeFormPost) - resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) - resp.EXPECT().GetParameters().Return(url.Values{consts.FormParameterAuthorizationCode: {"poz65kqoneu"}, consts.FormParameterState: {"qm6dnsrn"}}) + requester.EXPECT().GetRedirectURI().Return(redir) + requester.EXPECT().GetResponseMode().Return(ResponseModeFormPost).Times(2) + responder.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) + responder.EXPECT().GetParameters().Return(url.Values{consts.FormParameterAuthorizationCode: {"poz65kqoneu"}, consts.FormParameterState: {"qm6dnsrn"}}) rw.EXPECT().Header().Return(header).AnyTimes() rw.EXPECT().Write(gomock.Any()).AnyTimes() }, - expect: func() { + expect: func(t *testing.T, rw *MockResponseWriter, requester *MockAuthorizeRequester, responder *MockAuthorizeResponder, header http.Header) { assert.Equal(t, consts.ContentTypeTextHTML, header.Get(consts.HeaderContentType)) }, }, - } { - t.Logf("Starting test case %d", k) - c.setup() - provider.WriteAuthorizeResponse(context.Background(), rw, ar, resp) - c.expect() - header = http.Header{} - t.Logf("Passed test case %d", k) + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + provider := &Fosite{Config: new(Config)} + ctrl := gomock.NewController(t) + + rw := NewMockResponseWriter(ctrl) + requester := NewMockAuthorizeRequester(ctrl) + responder := NewMockAuthorizeResponder(ctrl) + + defer ctrl.Finish() + + header := http.Header{} + + tc.setup(t, rw, requester, responder, header) + provider.WriteAuthorizeResponse(context.TODO(), rw, requester, responder) + tc.expect(t, rw, requester, responder, header) + }) } } diff --git a/client.go b/client.go index b110ce7c..933a3762 100644 --- a/client.go +++ b/client.go @@ -6,7 +6,7 @@ package oauth2 import ( "context" - jose "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3" "authelia.com/provider/oauth2/internal/consts" ) @@ -91,6 +91,16 @@ type RevokeFlowRevokeRefreshTokensExplicitClient interface { GetRevokeRefreshTokensExplicitly(ctx context.Context) bool } +// JARMClient is a client which supports JARM. +type JARMClient interface { + Client + + GetAuthorizationSignedResponseKeyID() (kid string) + GetAuthorizationSignedResponseAlg() (alg string) + GetAuthorizationEncryptedResponseAlg() (alg string) + GetAuthorizationEncryptedResponseEncryptionAlg() (alg string) +} + // ResponseModeClient represents a client capable of handling response_mode type ResponseModeClient interface { // GetResponseModes returns the response modes that client is allowed to send diff --git a/config.go b/config.go index 473a221f..acc49ac7 100644 --- a/config.go +++ b/config.go @@ -106,6 +106,23 @@ type JWTScopeFieldProvider interface { GetJWTScopeField(ctx context.Context) jwt.JWTScopeFieldEnum } +// JWTSecuredAuthorizeResponseModeIssuerProvider returns the provider for configuring the JARM issuer. +type JWTSecuredAuthorizeResponseModeIssuerProvider interface { + // GetJWTSecuredAuthorizeResponseModeIssuer returns the JARM issuer. + GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) string +} + +// JWTSecuredAuthorizeResponseModeSignerProvider returns the provider for configuring the JARM signer. +type JWTSecuredAuthorizeResponseModeSignerProvider interface { + // GetJWTSecuredAuthorizeResponseModeSigner returns the JARM signer. + GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer +} + +// JWTSecuredAuthorizeResponseModeLifespanProvider returns the provider for configuring the JWT Secured Authorize Response Mode token lifespan. +type JWTSecuredAuthorizeResponseModeLifespanProvider interface { + GetJWTSecuredAuthorizeResponseModeLifespan(ctx context.Context) time.Duration +} + // AllowedPromptsProvider returns the provider for configuring the allowed prompts. type AllowedPromptsProvider interface { // GetAllowedPrompts returns the allowed prompts. @@ -238,10 +255,10 @@ type ClientAuthenticationStrategyProvider interface { GetClientAuthenticationStrategy(ctx context.Context) ClientAuthenticationStrategy } -// ResponseModeHandlerExtensionProvider returns the provider for configuring the response mode handler extension. -type ResponseModeHandlerExtensionProvider interface { - // GetResponseModeHandlerExtension returns the response mode handler extension. - GetResponseModeHandlerExtension(ctx context.Context) ResponseModeHandler +// ResponseModeHandlerProvider returns the provider for configuring the response mode handlers. +type ResponseModeHandlerProvider interface { + // GetResponseModeHandlers returns the response mode handlers in order of execution. + GetResponseModeHandlers(ctx context.Context) []ResponseModeHandler } // MessageCatalogProvider returns the provider for configuring the message catalog. diff --git a/config_default.go b/config_default.go index fc555d8b..06193924 100644 --- a/config_default.go +++ b/config_default.go @@ -10,7 +10,7 @@ import ( "net/url" "time" - retryablehttp "github.com/hashicorp/go-retryablehttp" + "github.com/hashicorp/go-retryablehttp" "authelia.com/provider/oauth2/i18n" "authelia.com/provider/oauth2/internal/consts" @@ -23,45 +23,48 @@ const ( ) var ( - _ AuthorizeCodeLifespanProvider = (*Config)(nil) - _ RefreshTokenLifespanProvider = (*Config)(nil) - _ AccessTokenLifespanProvider = (*Config)(nil) - _ ScopeStrategyProvider = (*Config)(nil) - _ AudienceStrategyProvider = (*Config)(nil) - _ RedirectSecureCheckerProvider = (*Config)(nil) - _ RefreshTokenScopesProvider = (*Config)(nil) - _ DisableRefreshTokenValidationProvider = (*Config)(nil) - _ AccessTokenIssuerProvider = (*Config)(nil) - _ JWTScopeFieldProvider = (*Config)(nil) - _ AllowedPromptsProvider = (*Config)(nil) - _ OmitRedirectScopeParamProvider = (*Config)(nil) - _ MinParameterEntropyProvider = (*Config)(nil) - _ SanitationAllowedProvider = (*Config)(nil) - _ EnforcePKCEForPublicClientsProvider = (*Config)(nil) - _ EnablePKCEPlainChallengeMethodProvider = (*Config)(nil) - _ EnforcePKCEProvider = (*Config)(nil) - _ GrantTypeJWTBearerCanSkipClientAuthProvider = (*Config)(nil) - _ GrantTypeJWTBearerIDOptionalProvider = (*Config)(nil) - _ GrantTypeJWTBearerIssuedDateOptionalProvider = (*Config)(nil) - _ GetJWTMaxDurationProvider = (*Config)(nil) - _ IDTokenLifespanProvider = (*Config)(nil) - _ IDTokenIssuerProvider = (*Config)(nil) - _ JWKSFetcherStrategyProvider = (*Config)(nil) - _ ClientAuthenticationStrategyProvider = (*Config)(nil) - _ SendDebugMessagesToClientsProvider = (*Config)(nil) - _ ResponseModeHandlerExtensionProvider = (*Config)(nil) - _ MessageCatalogProvider = (*Config)(nil) - _ FormPostHTMLTemplateProvider = (*Config)(nil) - _ TokenURLProvider = (*Config)(nil) - _ GetSecretsHashingProvider = (*Config)(nil) - _ HTTPClientProvider = (*Config)(nil) - _ HMACHashingProvider = (*Config)(nil) - _ AuthorizeEndpointHandlersProvider = (*Config)(nil) - _ TokenEndpointHandlersProvider = (*Config)(nil) - _ TokenIntrospectionHandlersProvider = (*Config)(nil) - _ RevocationHandlersProvider = (*Config)(nil) - _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) - _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) + _ AuthorizeCodeLifespanProvider = (*Config)(nil) + _ RefreshTokenLifespanProvider = (*Config)(nil) + _ AccessTokenLifespanProvider = (*Config)(nil) + _ ScopeStrategyProvider = (*Config)(nil) + _ AudienceStrategyProvider = (*Config)(nil) + _ RedirectSecureCheckerProvider = (*Config)(nil) + _ RefreshTokenScopesProvider = (*Config)(nil) + _ DisableRefreshTokenValidationProvider = (*Config)(nil) + _ AccessTokenIssuerProvider = (*Config)(nil) + _ JWTScopeFieldProvider = (*Config)(nil) + _ JWTSecuredAuthorizeResponseModeIssuerProvider = (*Config)(nil) + _ JWTSecuredAuthorizeResponseModeSignerProvider = (*Config)(nil) + _ JWTSecuredAuthorizeResponseModeLifespanProvider = (*Config)(nil) + _ AllowedPromptsProvider = (*Config)(nil) + _ OmitRedirectScopeParamProvider = (*Config)(nil) + _ MinParameterEntropyProvider = (*Config)(nil) + _ SanitationAllowedProvider = (*Config)(nil) + _ EnforcePKCEForPublicClientsProvider = (*Config)(nil) + _ EnablePKCEPlainChallengeMethodProvider = (*Config)(nil) + _ EnforcePKCEProvider = (*Config)(nil) + _ GrantTypeJWTBearerCanSkipClientAuthProvider = (*Config)(nil) + _ GrantTypeJWTBearerIDOptionalProvider = (*Config)(nil) + _ GrantTypeJWTBearerIssuedDateOptionalProvider = (*Config)(nil) + _ GetJWTMaxDurationProvider = (*Config)(nil) + _ IDTokenLifespanProvider = (*Config)(nil) + _ IDTokenIssuerProvider = (*Config)(nil) + _ JWKSFetcherStrategyProvider = (*Config)(nil) + _ ClientAuthenticationStrategyProvider = (*Config)(nil) + _ SendDebugMessagesToClientsProvider = (*Config)(nil) + _ ResponseModeHandlerProvider = (*Config)(nil) + _ MessageCatalogProvider = (*Config)(nil) + _ FormPostHTMLTemplateProvider = (*Config)(nil) + _ TokenURLProvider = (*Config)(nil) + _ GetSecretsHashingProvider = (*Config)(nil) + _ HTTPClientProvider = (*Config)(nil) + _ HMACHashingProvider = (*Config)(nil) + _ AuthorizeEndpointHandlersProvider = (*Config)(nil) + _ TokenEndpointHandlersProvider = (*Config)(nil) + _ TokenIntrospectionHandlersProvider = (*Config)(nil) + _ RevocationHandlersProvider = (*Config)(nil) + _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) + _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) ) type Config struct { @@ -161,8 +164,8 @@ type Config struct { // ClientAuthenticationStrategy indicates the Strategy to authenticate client requests ClientAuthenticationStrategy ClientAuthenticationStrategy - // ResponseModeHandlerExtension provides a handler for custom response modes - ResponseModeHandlerExtension ResponseModeHandler + // ResponseModeHandlers provides the handlers for performing response mode formatting. + ResponseModeHandlers []ResponseModeHandler // MessageCatalog is the message bundle used for i18n MessageCatalog i18n.MessageCatalog @@ -180,6 +183,16 @@ type Config struct { // JWTScopeClaimKey defines the claim key to be used to set the scope in. Valid fields are "scope" or "scp" or both. JWTScopeClaimKey jwt.JWTScopeFieldEnum + // JWTSecuredAuthorizeResponseModeIssuer sets the default issuer for the JWT Secured Authorization Response Mode. + JWTSecuredAuthorizeResponseModeIssuer string + + // JWTSecuredAuthorizeResponseModeLifespan sets the default lifetime for the tokens issued in the + // JWT Secured Authorization Response Mode. Defaults to 10 minutes. + JWTSecuredAuthorizeResponseModeLifespan time.Duration + + // JWTSecuredAuthorizeResponseModeSigner is the signer for JWT Secured Authorization Response Mode. Has no default. + JWTSecuredAuthorizeResponseModeSigner jwt.Signer + // AccessTokenIssuer is the issuer to be used when generating access tokens. AccessTokenIssuer string @@ -260,6 +273,7 @@ func (c *Config) GetHTTPClient(ctx context.Context) *retryablehttp.Client { if c.HTTPClient == nil { return retryablehttp.NewClient() } + return c.HTTPClient } @@ -267,6 +281,7 @@ func (c *Config) GetSecretsHasher(ctx context.Context) Hasher { if c.ClientSecretsHasher == nil { c.ClientSecretsHasher = &BCrypt{Config: c} } + return c.ClientSecretsHasher } @@ -282,8 +297,12 @@ func (c *Config) GetMessageCatalog(ctx context.Context) i18n.MessageCatalog { return c.MessageCatalog } -func (c *Config) GetResponseModeHandlerExtension(ctx context.Context) ResponseModeHandler { - return c.ResponseModeHandlerExtension +func (c *Config) GetResponseModeHandlers(ctx context.Context) []ResponseModeHandler { + if len(c.ResponseModeHandlers) == 0 { + c.ResponseModeHandlers = []ResponseModeHandler{&DefaultResponseModeHandler{Config: c}} + } + + return c.ResponseModeHandlers } func (c *Config) GetSendDebugMessagesToClients(ctx context.Context) bool { @@ -350,6 +369,14 @@ func (c *Config) GetJWTScopeField(ctx context.Context) jwt.JWTScopeFieldEnum { return c.JWTScopeClaimKey } +func (c *Config) GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) string { + return c.IDTokenIssuer +} + +func (c *Config) GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer { + return c.JWTSecuredAuthorizeResponseModeSigner +} + func (c *Config) GetAllowedPrompts(_ context.Context) []string { return c.AllowedPromptValues } @@ -359,6 +386,7 @@ func (c *Config) GetScopeStrategy(_ context.Context) ScopeStrategy { if c.ScopeStrategy == nil { c.ScopeStrategy = WildcardScopeStrategy } + return c.ScopeStrategy } @@ -367,6 +395,7 @@ func (c *Config) GetAudienceStrategy(_ context.Context) AudienceMatchingStrategy if c.AudienceMatchingStrategy == nil { c.AudienceMatchingStrategy = DefaultAudienceMatchingStrategy } + return c.AudienceMatchingStrategy } @@ -375,6 +404,7 @@ func (c *Config) GetAuthorizeCodeLifespan(_ context.Context) time.Duration { if c.AuthorizeCodeLifespan == 0 { return time.Minute * 15 } + return c.AuthorizeCodeLifespan } @@ -383,6 +413,7 @@ func (c *Config) GetIDTokenLifespan(_ context.Context) time.Duration { if c.IDTokenLifespan == 0 { return time.Hour } + return c.IDTokenLifespan } @@ -391,6 +422,7 @@ func (c *Config) GetAccessTokenLifespan(_ context.Context) time.Duration { if c.AccessTokenLifespan == 0 { return time.Hour } + return c.AccessTokenLifespan } @@ -399,6 +431,7 @@ func (c *Config) GetVerifiableCredentialsNonceLifespan(_ context.Context) time.D if c.VerifiableCredentialsNonceLifespan == 0 { return time.Hour } + return c.VerifiableCredentialsNonceLifespan } @@ -408,14 +441,25 @@ func (c *Config) GetRefreshTokenLifespan(_ context.Context) time.Duration { if c.RefreshTokenLifespan == 0 { return time.Hour * 24 * 30 } + return c.RefreshTokenLifespan } +// GetJWTSecuredAuthorizeResponseModeLifespan returns how long a JWT issued by the JWT Secured Authorize Response Mode should be valid. Defaults to 10 minutes. +func (c *Config) GetJWTSecuredAuthorizeResponseModeLifespan(_ context.Context) time.Duration { + if c.JWTSecuredAuthorizeResponseModeLifespan == 0 { + return time.Minute * 10 + } + + return c.JWTSecuredAuthorizeResponseModeLifespan +} + // GetBCryptCost returns the bcrypt cost factor. Defaults to 12. func (c *Config) GetBCryptCost(_ context.Context) int { if c.HashCost == 0 { return DefaultBCryptWorkFactor } + return c.HashCost } @@ -424,6 +468,7 @@ func (c *Config) GetJWKSFetcherStrategy(_ context.Context) JWKSFetcherStrategy { if c.JWKSFetcherStrategy == nil { c.JWKSFetcherStrategy = NewDefaultJWKSFetcherStrategy() } + return c.JWKSFetcherStrategy } @@ -432,6 +477,7 @@ func (c *Config) GetTokenEntropy(_ context.Context) int { if c.TokenEntropy == 0 { return 32 } + return c.TokenEntropy } @@ -440,6 +486,7 @@ func (c *Config) GetRedirectSecureChecker(_ context.Context) func(context.Contex if c.RedirectSecureChecker == nil { return IsRedirectURISecure } + return c.RedirectSecureChecker } @@ -448,6 +495,7 @@ func (c *Config) GetRefreshTokenScopes(_ context.Context) []string { if c.RefreshTokenScopes == nil { return []string{consts.ScopeOffline, consts.ScopeOfflineAccess} } + return c.RefreshTokenScopes } @@ -455,9 +503,9 @@ func (c *Config) GetRefreshTokenScopes(_ context.Context) []string { func (c *Config) GetMinParameterEntropy(_ context.Context) int { if c.MinParameterEntropy == 0 { return MinParameterEntropy - } else { - return c.MinParameterEntropy } + + return c.MinParameterEntropy } // GetJWTMaxDuration specified the maximum amount of allowed `exp` time for a JWT. It compares @@ -468,6 +516,7 @@ func (c *Config) GetJWTMaxDuration(_ context.Context) time.Duration { if c.GrantTypeJWTBearerMaxDuration == 0 { return time.Hour * 24 } + return c.GrantTypeJWTBearerMaxDuration } diff --git a/errors.go b/errors.go index e57975e2..96052469 100644 --- a/errors.go +++ b/errors.go @@ -534,6 +534,15 @@ func (e *RFC6749Error) computeHintField() { e.HintField = i18n.GetMessageOrDefault(e.catalog, e.hintIDField, e.lang, e.HintField, e.hintArgs...) } +func ErrorToRFC6749ErrorFallback(err error, fallback *RFC6749Error) *RFC6749Error { + var e *RFC6749Error + if errors.As(err, &e) { + return e + } + + return fallback.WithWrap(err).WithDebug(err.Error()) +} + // ErrorToDebugRFC6749Error converts the provided error to a *DebugRFC6749Error provided it is not nil and can be // cast as a *RFC6749Error. func ErrorToDebugRFC6749Error(err error) (rfc error) { diff --git a/fosite.go b/fosite.go index 8fd06dd6..2d0bff50 100644 --- a/fosite.go +++ b/fosite.go @@ -10,8 +10,6 @@ import ( const MinParameterEntropy = 8 -var defaultResponseModeHandler = &DefaultResponseModeHandler{} - // AuthorizeEndpointHandlers is a list of AuthorizeEndpointHandler type AuthorizeEndpointHandlers []AuthorizeEndpointHandler @@ -101,6 +99,9 @@ type Configurator interface { OmitRedirectScopeParamProvider SanitationAllowedProvider JWTScopeFieldProvider + JWTSecuredAuthorizeResponseModeIssuerProvider + JWTSecuredAuthorizeResponseModeSignerProvider + JWTSecuredAuthorizeResponseModeLifespanProvider AccessTokenIssuerProvider DisableRefreshTokenValidationProvider RefreshTokenScopesProvider @@ -118,12 +119,11 @@ type Configurator interface { MinParameterEntropyProvider HMACHashingProvider ClientAuthenticationStrategyProvider - ResponseModeHandlerExtensionProvider + ResponseModeHandlerProvider SendDebugMessagesToClientsProvider RevokeRefreshTokensExplicitlyProvider JWKSFetcherStrategyProvider ClientAuthenticationStrategyProvider - ResponseModeHandlerExtensionProvider MessageCatalogProvider FormPostHTMLTemplateProvider TokenURLProvider @@ -155,9 +155,7 @@ func (f *Fosite) GetMinParameterEntropy(ctx context.Context) int { return MinParameterEntropy } -func (f *Fosite) ResponseModeHandler(ctx context.Context) ResponseModeHandler { - if ext := f.Config.GetResponseModeHandlerExtension(ctx); ext != nil { - return ext - } - return defaultResponseModeHandler +// ResponseModeHandlers returns the configured ResponseModeHandler implementations for this instance. +func (f *Fosite) ResponseModeHandlers(ctx context.Context) []ResponseModeHandler { + return f.Config.GetResponseModeHandlers(ctx) } diff --git a/handler/rfc7523/handler_test.go b/handler/rfc7523/handler_test.go index 691a5ffb..ecf5f120 100644 --- a/handler/rfc7523/handler_test.go +++ b/handler/rfc7523/handler_test.go @@ -15,7 +15,7 @@ import ( "testing" "time" - jose "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3" "github.com/go-jose/go-jose/v3/jwt" "github.com/stretchr/testify/suite" "go.uber.org/mock/gomock" diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go index 31db5d72..416c3a9c 100644 --- a/integration/authorize_form_post_test.go +++ b/integration/authorize_form_post_test.go @@ -5,13 +5,13 @@ package integration_test import ( "context" - "fmt" + "errors" "net/http" + "net/http/httptest" "net/url" "strings" "testing" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" xoauth2 "golang.org/x/oauth2" @@ -26,49 +26,29 @@ import ( ) type formPostTestCase struct { - description string - setup func() + name string + setup func(t *testing.T, server *httptest.Server) (state string, client *xoauth2.Config) check checkFunc responseType string } -type checkFunc func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) +type checkFunc func(t *testing.T, expectedState, actualState string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) func TestAuthorizeFormPostResponseMode(t *testing.T) { - session := &defaultSession{ - DefaultSession: &openid.DefaultSession{ - Claims: &jwt.IDTokenClaims{ - Subject: "peter", - }, - Headers: &jwt.Headers{}, - }, - } - config := &oauth2.Config{ResponseModeHandlerExtension: &decoratedFormPostResponse{}, GlobalSecret: []byte("some-secret-thats-random-some-secret-thats-random-")} - f := compose.ComposeAllEnabled(config, store, gen.MustRSAKey()) - ts := mockServer(t, f, session) - defer ts.Close() - - oauthClient := newOAuth2Client(ts) - defaultClient := store.Clients["my-client"].(*oauth2.DefaultClient) - defaultClient.RedirectURIs[0] = ts.URL + "/callback" - responseModeClient := &oauth2.DefaultResponseModeClient{ - DefaultClient: defaultClient, - ResponseModes: []oauth2.ResponseModeType{oauth2.ResponseModeFormPost, oauth2.ResponseModeFormPost, "decorated_form_post"}, - } - store.Clients["response-mode-client"] = responseModeClient - oauthClient.ClientID = "response-mode-client" - - var state string - for k, c := range []formPostTestCase{ + testCases := []struct { + name string + setup func(t *testing.T, client *xoauth2.Config, server *httptest.Server) + check checkFunc + responseType string + state string + }{ { - description: "implicit grant #1 test with form_post", + name: "ShouldHandleImplicitFlowBoth", responseType: "id_token%20token", - setup: func() { - state = "12345678901234567890" - oauthClient.Scopes = []string{consts.ScopeOpenID} - }, - check: func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { - assert.EqualValues(t, state, stateFromServer) + state: "12345678901234567890", + setup: func(t *testing.T, client *xoauth2.Config, server *httptest.Server) {}, + check: func(t *testing.T, expectedState, actualState string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { + assert.Equal(t, expectedState, actualState) assert.NotEmpty(t, token.TokenType) assert.NotEmpty(t, token.AccessToken) assert.NotEmpty(t, token.Expiry) @@ -76,37 +56,32 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) { }, }, { - description: "implicit grant #2 test with form_post", + name: "ShouldHandleImplicitFlowIDToken", responseType: consts.ResponseTypeImplicitFlowIDToken, - setup: func() { - state = "12345678901234567890" - oauthClient.Scopes = []string{consts.ScopeOpenID} - }, - check: func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { - assert.EqualValues(t, state, stateFromServer) + state: "12345678901234567890", + setup: func(t *testing.T, client *xoauth2.Config, server *httptest.Server) {}, + check: func(t *testing.T, expectedState, actualState string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { + assert.Equal(t, expectedState, actualState) assert.NotEmpty(t, iDToken) }, }, { - description: "Authorization code grant test with form_post", + name: "ShouldHandleAuthorizationCodeFlow", responseType: consts.ResponseTypeAuthorizationCodeFlow, - setup: func() { - state = "12345678901234567890" - }, - check: func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { - assert.EqualValues(t, state, stateFromServer) + state: "12345678901234567890", + setup: func(t *testing.T, client *xoauth2.Config, server *httptest.Server) {}, + check: func(t *testing.T, expectedState, actualState string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { + assert.Equal(t, expectedState, actualState) assert.NotEmpty(t, code) }, }, { - description: "Hybrid #1 grant test with form_post", + name: "ShouldHandleHybridFlowToken", responseType: "token%20code", - setup: func() { - state = "12345678901234567890" - oauthClient.Scopes = []string{consts.ScopeOpenID} - }, - check: func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { - assert.EqualValues(t, state, stateFromServer) + state: "12345678901234567890", + setup: func(t *testing.T, client *xoauth2.Config, server *httptest.Server) {}, + check: func(t *testing.T, expectedState, actualState string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { + assert.Equal(t, expectedState, actualState) assert.NotEmpty(t, code) assert.NotEmpty(t, token.TokenType) assert.NotEmpty(t, token.AccessToken) @@ -114,14 +89,12 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) { }, }, { - description: "Hybrid #2 grant test with form_post", + name: "ShouldHandleHybridFlowBoth", responseType: "token%20id_token%20code", - setup: func() { - state = "12345678901234567890" - oauthClient.Scopes = []string{consts.ScopeOpenID} - }, - check: func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { - assert.EqualValues(t, state, stateFromServer) + state: "12345678901234567890", + setup: func(t *testing.T, client *xoauth2.Config, server *httptest.Server) {}, + check: func(t *testing.T, expectedState, actualState string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { + assert.Equal(t, expectedState, actualState) assert.NotEmpty(t, code) assert.NotEmpty(t, iDToken) assert.NotEmpty(t, token.TokenType) @@ -130,66 +103,141 @@ func TestAuthorizeFormPostResponseMode(t *testing.T) { }, }, { - description: "Hybrid #3 grant test with form_post", + name: "ShouldHandleHybridFlowIDToken", responseType: "id_token%20code", - setup: func() { - state = "12345678901234567890" - oauthClient.Scopes = []string{consts.ScopeOpenID} - }, - check: func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { - assert.EqualValues(t, state, stateFromServer) + state: "12345678901234567890", + setup: func(t *testing.T, client *xoauth2.Config, server *httptest.Server) {}, + check: func(t *testing.T, expectedState, actualState string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { + assert.Equal(t, expectedState, actualState) assert.NotEmpty(t, code) assert.NotEmpty(t, iDToken) }, }, { - description: "error message test for form_post response", + name: "ShouldHandleFailure", responseType: "foo", - setup: func() { - state = "12345678901234567890" - }, - check: func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { - assert.EqualValues(t, state, stateFromServer) + state: "12345678901234567890", + setup: func(t *testing.T, client *xoauth2.Config, server *httptest.Server) {}, + check: func(t *testing.T, expectedState, actualState string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { + assert.Equal(t, expectedState, actualState) assert.NotEmpty(t, err["ErrorField"]) assert.NotEmpty(t, err["DescriptionField"]) }, }, - } { - // Test canonical form_post - t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), testFormPost(&state, false, c, oauthClient, consts.ResponseModeFormPost)) - - // Test decorated form_post response - c.check = decorateCheck(c.check) - t.Run(fmt.Sprintf("case=%d/description=decorated_%s", k, c.description), testFormPost(&state, true, c, oauthClient, "decorated_form_post")) } -} -func testFormPost(state *string, customResponse bool, c formPostTestCase, oauthClient *xoauth2.Config, responseMode string) func(t *testing.T) { - return func(t *testing.T) { - c.setup() - authURL := strings.Replace(oauthClient.AuthCodeURL(*state, xoauth2.SetAuthURLParam("response_mode", responseMode), xoauth2.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+c.responseType, -1) - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return errors.New("Dont follow redirects") - }, - } - resp, err := client.Get(authURL) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - code, state, token, iDToken, cparam, errResp, err := internal.ParseFormPostResponse(store.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body) - require.NoError(t, err) - c.check(t, state, code, iDToken, token, cparam, errResp) - } -} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Run("Canonical", func(t *testing.T) { + session := &defaultSession{ + DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "peter", + }, + Headers: &jwt.Headers{}, + }, + } + + config := &oauth2.Config{GlobalSecret: []byte("some-secret-thats-random-some-secret-thats-random-")} + config.ResponseModeHandlers = []oauth2.ResponseModeHandler{&oauth2.DefaultResponseModeHandler{Config: config}, &DecoratedFormPostResponse{}} + + f := compose.ComposeAllEnabled(config, store, gen.MustRSAKey()) + server := mockServer(t, f, session) -func decorateCheck(cf checkFunc) checkFunc { - return func(t *testing.T, stateFromServer string, code string, token xoauth2.Token, iDToken string, cparam url.Values, err map[string]string) { - cf(t, stateFromServer, code, token, iDToken, cparam, err) - if len(err) > 0 { - assert.Contains(t, cparam, "custom_err_param") - return - } - assert.Contains(t, cparam, "custom_param") + defer server.Close() + + store.Clients["response-mode-client"] = &oauth2.DefaultResponseModeClient{ + DefaultClient: &oauth2.DefaultClient{ + ID: "response-mode-client", + Secret: []byte(`$2a$10$IxMdI6d.LIRZPpSfEwNoeu4rY3FhDREsxFJXikcgdRRAStxUlsuEO`), // = "foobar" + RedirectURIs: []string{server.URL + "/callback"}, + ResponseTypes: []string{consts.ResponseTypeImplicitFlowIDToken, consts.ResponseTypeAuthorizationCodeFlow, consts.ResponseTypeImplicitFlowToken, consts.ResponseTypeImplicitFlowBoth, consts.ResponseTypeHybridFlowIDToken, consts.ResponseTypeHybridFlowToken, consts.ResponseTypeHybridFlowBoth}, + GrantTypes: []string{consts.GrantTypeImplicit, consts.GrantTypeRefreshToken, consts.GrantTypeAuthorizationCode, consts.GrantTypeResourceOwnerPasswordCredentials, consts.GrantTypeClientCredentials}, + Scopes: []string{"oauth2", consts.ScopeOffline, consts.ScopeOpenID}, + Audience: []string{tokenURL}, + }, + ResponseModes: []oauth2.ResponseModeType{oauth2.ResponseModeFormPost, oauth2.ResponseModeFormPost, "decorated_form_post"}, + } + + client := newOAuth2Client(server) + client.ClientID = "response-mode-client" + client.Scopes = []string{consts.ScopeOpenID} + + authURL := strings.Replace(client.AuthCodeURL(tc.state, xoauth2.SetAuthURLParam(consts.FormParameterResponseMode, consts.ResponseModeFormPost), xoauth2.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+tc.responseType, -1) + + c := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return errors.New("Dont follow redirects") + }, + } + + resp, err := c.Get(authURL) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + code, actualState, token, iDToken, cparam, errResp, err := internal.ParseFormPostResponse(store.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body) + require.NoError(t, err) + + tc.check(t, tc.state, actualState, code, iDToken, token, cparam, errResp) + }) + + t.Run("Decorated", func(t *testing.T) { + session := &defaultSession{ + DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "peter", + }, + Headers: &jwt.Headers{}, + }, + } + + config := &oauth2.Config{GlobalSecret: []byte("some-secret-thats-random-some-secret-thats-random-")} + config.ResponseModeHandlers = []oauth2.ResponseModeHandler{&oauth2.DefaultResponseModeHandler{Config: config}, &DecoratedFormPostResponse{}} + + f := compose.ComposeAllEnabled(config, store, gen.MustRSAKey()) + server := mockServer(t, f, session) + + defer server.Close() + + store.Clients["response-mode-client"] = &oauth2.DefaultResponseModeClient{ + DefaultClient: &oauth2.DefaultClient{ + ID: "response-mode-client", + Secret: []byte(`$2a$10$IxMdI6d.LIRZPpSfEwNoeu4rY3FhDREsxFJXikcgdRRAStxUlsuEO`), // = "foobar" + RedirectURIs: []string{server.URL + "/callback"}, + ResponseTypes: []string{consts.ResponseTypeImplicitFlowIDToken, consts.ResponseTypeAuthorizationCodeFlow, consts.ResponseTypeImplicitFlowToken, consts.ResponseTypeImplicitFlowBoth, consts.ResponseTypeHybridFlowIDToken, consts.ResponseTypeHybridFlowToken, consts.ResponseTypeHybridFlowBoth}, + GrantTypes: []string{consts.GrantTypeImplicit, consts.GrantTypeRefreshToken, consts.GrantTypeAuthorizationCode, consts.GrantTypeResourceOwnerPasswordCredentials, consts.GrantTypeClientCredentials}, + Scopes: []string{"oauth2", consts.ScopeOffline, consts.ScopeOpenID}, + Audience: []string{tokenURL}, + }, + ResponseModes: []oauth2.ResponseModeType{oauth2.ResponseModeFormPost, oauth2.ResponseModeFormPost, "decorated_form_post"}, + } + + client := newOAuth2Client(server) + client.ClientID = "response-mode-client" + client.Scopes = []string{consts.ScopeOpenID} + + authURL := strings.Replace(client.AuthCodeURL(tc.state, xoauth2.SetAuthURLParam(consts.FormParameterResponseMode, "decorated_form_post"), xoauth2.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+tc.responseType, -1) + + c := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return errors.New("Dont follow redirects") + }, + } + + resp, err := c.Get(authURL) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + code, actualState, token, iDToken, cparam, errResp, err := internal.ParseFormPostResponse(store.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body) + require.NoError(t, err) + + tc.check(t, tc.state, actualState, code, iDToken, token, cparam, errResp) + + if len(errResp) > 0 { + assert.Contains(t, cparam, "custom_err_param") + return + } + assert.Contains(t, cparam, "custom_param") + }) + }) } } @@ -197,25 +245,24 @@ func decorateCheck(cf checkFunc) checkFunc { // of a custom response mode handler. // In this case it decorates the `form_post` response mode // with some additional custom parameters -type decoratedFormPostResponse struct { -} +type DecoratedFormPostResponse struct{} -func (m *decoratedFormPostResponse) ResponseModes() oauth2.ResponseModeTypes { +func (m *DecoratedFormPostResponse) ResponseModes() oauth2.ResponseModeTypes { return oauth2.ResponseModeTypes{"decorated_form_post"} } -func (m *decoratedFormPostResponse) WriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, ar oauth2.AuthorizeRequester, resp oauth2.AuthorizeResponder) { +func (m *DecoratedFormPostResponse) WriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, ar oauth2.AuthorizeRequester, resp oauth2.AuthorizeResponder) { rw.Header().Add(consts.HeaderContentType, consts.ContentTypeTextHTML) resp.AddParameter("custom_param", "foo") oauth2.WriteAuthorizeFormPostResponse(ar.GetRedirectURI().String(), resp.GetParameters(), oauth2.GetPostFormHTMLTemplate(ctx, - oauth2.New(nil, new(oauth2.Config))), rw) + new(oauth2.Config)), rw) } -func (m *decoratedFormPostResponse) WriteAuthorizeError(ctx context.Context, rw http.ResponseWriter, ar oauth2.AuthorizeRequester, err error) { +func (m *DecoratedFormPostResponse) WriteAuthorizeError(ctx context.Context, rw http.ResponseWriter, ar oauth2.AuthorizeRequester, err error) { rfcerr := oauth2.ErrorToRFC6749Error(err) errors := rfcerr.ToValues() - errors.Set("state", ar.GetState()) + errors.Set(consts.FormParameterState, ar.GetState()) errors.Add("custom_err_param", "bar") oauth2.WriteAuthorizeFormPostResponse(ar.GetRedirectURI().String(), errors, oauth2.GetPostFormHTMLTemplate(ctx, - oauth2.New(nil, new(oauth2.Config))), rw) + new(oauth2.Config)), rw) } diff --git a/internal/reflection/field.go b/internal/reflection/field.go index c23c63c2..d87a66df 100644 --- a/internal/reflection/field.go +++ b/internal/reflection/field.go @@ -6,7 +6,7 @@ import ( "reflect" ) -func GetField(obj interface{}, name string) (interface{}, error) { +func GetField(obj any, name string) (any, error) { if !hasValidType(obj, []reflect.Kind{reflect.Struct, reflect.Ptr}) { return nil, errors.New("Cannot use GetField on a non-struct interface") } @@ -20,7 +20,7 @@ func GetField(obj interface{}, name string) (interface{}, error) { return field.Interface(), nil } -func hasValidType(obj interface{}, types []reflect.Kind) bool { +func hasValidType(obj any, types []reflect.Kind) bool { for _, t := range types { if reflect.TypeOf(obj).Kind() == t { return true @@ -30,7 +30,7 @@ func hasValidType(obj interface{}, types []reflect.Kind) bool { return false } -func reflectValue(obj interface{}) reflect.Value { +func reflectValue(obj any) reflect.Value { var val reflect.Value if reflect.TypeOf(obj).Kind() == reflect.Ptr { diff --git a/response_handler.go b/response_handler.go index 5733a800..3a076b1c 100644 --- a/response_handler.go +++ b/response_handler.go @@ -5,10 +5,179 @@ package oauth2 import ( "context" + "encoding/json" + "fmt" "net/http" + "net/url" + + "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/internal/errorsx" + "authelia.com/provider/oauth2/token/jarm" +) + +type DefaultResponseModeHandler struct { + Config ResponseModeHandlerConfigurator +} + +var ( + _ ResponseModeHandler = (*DefaultResponseModeHandler)(nil) ) -// ResponseModeHandler provides a contract for handling custom response modes +// ResponseModes returns the response modes this fosite.ResponseModeHandler is responsible for. +func (h *DefaultResponseModeHandler) ResponseModes() ResponseModeTypes { + return ResponseModeTypes{ + ResponseModeDefault, + ResponseModeQuery, + ResponseModeFragment, + ResponseModeFormPost, + ResponseModeJWT, + ResponseModeQueryJWT, + ResponseModeFragmentJWT, + ResponseModeFormPostJWT, + } +} + +// WriteAuthorizeResponse writes authorization responses. +func (h *DefaultResponseModeHandler) WriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, requester AuthorizeRequester, responder AuthorizeResponder) { + header := rw.Header() + + header.Set(consts.HeaderCacheControl, consts.CacheControlNoStore) + header.Set(consts.HeaderPragma, consts.PragmaNoCache) + + rheader := responder.GetHeader() + + for k := range rheader { + header.Set(k, rheader.Get(k)) + } + + h.handleWriteAuthorizeResponse(ctx, rw, requester, responder.GetParameters()) +} + +// WriteAuthorizeError writes authorization errors. +func (h *DefaultResponseModeHandler) WriteAuthorizeError(ctx context.Context, rw http.ResponseWriter, requester AuthorizeRequester, e error) { + rfc := ErrorToRFC6749Error(e). + WithLegacyFormat(h.Config.GetUseLegacyErrorFormat(ctx)). + WithExposeDebug(h.Config.GetSendDebugMessagesToClients(ctx)). + WithLocalizer(h.Config.GetMessageCatalog(ctx), getLangFromRequester(requester)) + + if !requester.IsRedirectURIValid() { + h.handleWriteAuthorizeErrorJSON(ctx, rw, rfc) + + return + } + + parameters := rfc.ToValues() + + if state := requester.GetState(); len(state) != 0 { + parameters.Set(consts.FormParameterState, state) + } + + h.handleWriteAuthorizeResponse(ctx, rw, requester, parameters) +} + +func (h *DefaultResponseModeHandler) handleWriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, requester AuthorizeRequester, parameters url.Values) { + redirectURI := requester.GetRedirectURI() + redirectURI.Fragment = "" + + var ( + form url.Values + err error + location string + ) + + rm := requester.GetResponseMode() + + if rm == ResponseModeJWT { + if requester.GetResponseTypes().ExactOne(consts.ResponseTypeAuthorizationCodeFlow) { + rm = ResponseModeQueryJWT + } else { + rm = ResponseModeFragmentJWT + } + } + + switch rm { + case ResponseModeFormPost, ResponseModeFormPostJWT: + if form, err = h.EncodeResponseForm(ctx, rm, requester, parameters); err != nil { + h.handleWriteAuthorizeErrorJSON(ctx, rw, ErrServerError.WithWrap(err).WithDebug(err.Error())) + + return + } + + rw.Header().Set(consts.HeaderContentType, consts.ContentTypeTextHTML) + WriteAuthorizeFormPostResponse(redirectURI.String(), form, GetPostFormHTMLTemplate(ctx, h.Config), rw) + + return + case ResponseModeQuery, ResponseModeDefault, ResponseModeQueryJWT, ResponseModeJWT: + for key, values := range redirectURI.Query() { + for _, value := range values { + parameters.Add(key, value) + } + } + + if form, err = h.EncodeResponseForm(ctx, rm, requester, parameters); err != nil { + h.handleWriteAuthorizeErrorJSON(ctx, rw, ErrServerError.WithWrap(err).WithDebug(err.Error())) + + return + } + + redirectURI.RawQuery = form.Encode() + + location = redirectURI.String() + case ResponseModeFragment, ResponseModeFragmentJWT: + if form, err = h.EncodeResponseForm(ctx, rm, requester, parameters); err != nil { + h.handleWriteAuthorizeErrorJSON(ctx, rw, ErrServerError.WithWrap(err).WithDebug(err.Error())) + + return + } + + location = redirectURI.String() + "#" + form.Encode() + } + + rw.Header().Set(consts.HeaderLocation, location) + rw.WriteHeader(http.StatusSeeOther) +} + +// EncodeResponseForm encodes the response form if necessary. +func (h *DefaultResponseModeHandler) EncodeResponseForm(ctx context.Context, rm ResponseModeType, requester AuthorizeRequester, parameters url.Values) (form url.Values, err error) { + switch rm { + case ResponseModeFormPostJWT, ResponseModeQueryJWT, ResponseModeFragmentJWT: + client := requester.GetClient() + + jclient, ok := client.(JARMClient) + if !ok { + return nil, errorsx.WithStack(ErrServerError.WithDebug("The client is not capable of handling the JWT-Secured Authorization Response Mode.")) + } + + return jarm.EncodeParameters(jarm.Generate(ctx, h.Config, jclient, requester.GetSession(), parameters)) + default: + return parameters, nil + } +} + +func (h *DefaultResponseModeHandler) handleWriteAuthorizeErrorJSON(ctx context.Context, rw http.ResponseWriter, rfc *RFC6749Error) { + rw.Header().Set(consts.HeaderContentType, consts.ContentTypeApplicationJSON) + + var ( + data []byte + err error + ) + + if data, err = json.Marshal(rfc); err != nil { + if h.Config.GetSendDebugMessagesToClients(ctx) { + errorMessage := EscapeJSONString(err.Error()) + http.Error(rw, fmt.Sprintf(`{"error":"server_error","error_description":"%s"}`, errorMessage), http.StatusInternalServerError) + } else { + http.Error(rw, `{"error":"server_error"}`, http.StatusInternalServerError) + } + + return + } + + rw.WriteHeader(rfc.CodeField) + _, _ = rw.Write(data) +} + +// ResponseModeHandler provides a contract for handling response modes. type ResponseModeHandler interface { // ResponseModes returns a set of supported response modes handled // by the interface implementation. @@ -20,17 +189,17 @@ type ResponseModeHandler interface { // WriteAuthorizeResponse writes successful responses // - // Following headers are expected to be set by default: + // The following headers are expected to be set by implementations of this interface: // header.Set(consts.HeaderCacheControl, consts.CacheControlNoStore) // header.Set(consts.HeaderPragma, consts.PragmaNoCache) - WriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder) + WriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, requester AuthorizeRequester, responder AuthorizeResponder) // WriteAuthorizeError writes error responses // - // Following headers are expected to be set by default: + // The following headers are expected to be set by implementations of this interface: // header.Set(consts.HeaderCacheControl, consts.CacheControlNoStore) // header.Set(consts.HeaderPragma, consts.PragmaNoCache) - WriteAuthorizeError(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, err error) + WriteAuthorizeError(ctx context.Context, rw http.ResponseWriter, requester AuthorizeRequester, err error) } type ResponseModeTypes []ResponseModeType @@ -44,16 +213,12 @@ func (rs ResponseModeTypes) Has(item ResponseModeType) bool { return false } -func NewDefaultResponseModeHandler() *DefaultResponseModeHandler { - return new(DefaultResponseModeHandler) -} - -type DefaultResponseModeHandler struct{} - -func (d *DefaultResponseModeHandler) ResponseModes() ResponseModeTypes { return nil } - -func (d *DefaultResponseModeHandler) WriteAuthorizeResponse(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder) { -} - -func (d *DefaultResponseModeHandler) WriteAuthorizeError(ctx context.Context, rw http.ResponseWriter, ar AuthorizeRequester, err error) { +type ResponseModeHandlerConfigurator interface { + FormPostHTMLTemplateProvider + JWTSecuredAuthorizeResponseModeIssuerProvider + JWTSecuredAuthorizeResponseModeSignerProvider + JWTSecuredAuthorizeResponseModeLifespanProvider + MessageCatalogProvider + SendDebugMessagesToClientsProvider + UseLegacyErrorFormatProvider } diff --git a/storage/memory.go b/storage/memory.go index 91b272e1..4fde92e6 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -9,7 +9,7 @@ import ( "sync" "time" - jose "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3" "authelia.com/provider/oauth2" "authelia.com/provider/oauth2/internal" diff --git a/token/jarm/generate.go b/token/jarm/generate.go new file mode 100644 index 00000000..bb4e18bd --- /dev/null +++ b/token/jarm/generate.go @@ -0,0 +1,83 @@ +package jarm + +import ( + "context" + "errors" + "net/url" + "time" + + "github.com/google/uuid" + + "authelia.com/provider/oauth2/internal/consts" + "authelia.com/provider/oauth2/token/jwt" +) + +// EncodeParameters takes the result from jarm.Generate and turns it into parameters in the form of url.Values. +func EncodeParameters(token, _ string, tErr error) (parameters url.Values, err error) { + if tErr != nil { + return nil, tErr + } + + return url.Values{consts.FormParameterResponse: []string{token}}, nil +} + +// Generate generates the token and signature for a JARM response. +func Generate(ctx context.Context, config Configurator, client Client, session any, in url.Values) (token, signature string, err error) { + headers := map[string]any{} + + if alg := client.GetAuthorizationSignedResponseAlg(); len(alg) > 0 { + headers[consts.JSONWebTokenHeaderAlgorithm] = alg + } + + if kid := client.GetAuthorizationSignedResponseKeyID(); len(kid) > 0 { + headers[consts.JSONWebTokenHeaderKeyIdentifier] = kid + } + + var issuer string + + issuer = config.GetJWTSecuredAuthorizeResponseModeIssuer(ctx) + + if len(issuer) == 0 { + var ( + src jwt.MapClaims + value any + ok bool + ) + + switch s := session.(type) { + case nil: + return "", "", errors.New("The JARM response modes require the Authorize Requester session to be set but it wasn't.") + case OpenIDSession: + src = s.IDTokenClaims().ToMapClaims() + case JWTSessionContainer: + src = s.GetJWTClaims().ToMapClaims() + default: + return "", "", errors.New("The JARM response modes require the Authorize Requester session to implement either the openid.Session or oauth2.JWTSessionContainer interfaces but it doesn't.") + } + + if value, ok = src[consts.ClaimIssuer]; ok { + issuer, _ = value.(string) + } + } + + claims := &jwt.JARMClaims{ + JTI: uuid.New().String(), + Issuer: issuer, + IssuedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(config.GetJWTSecuredAuthorizeResponseModeLifespan(ctx)), + Audience: []string{client.GetID()}, + Extra: map[string]any{}, + } + + for param := range in { + claims.Extra[param] = in.Get(param) + } + + var signer jwt.Signer + + if signer = config.GetJWTSecuredAuthorizeResponseModeSigner(ctx); signer == nil { + return "", "", errors.New("The JARM response modes require the JWTSecuredAuthorizeResponseModeSignerProvider to return a jwt.Signer but it didn't.") + } + + return signer.Generate(ctx, claims.ToMapClaims(), &jwt.Headers{Extra: headers}) +} diff --git a/token/jarm/types.go b/token/jarm/types.go new file mode 100644 index 00000000..c36ca4d2 --- /dev/null +++ b/token/jarm/types.go @@ -0,0 +1,32 @@ +package jarm + +import ( + "context" + "time" + + "authelia.com/provider/oauth2/token/jwt" +) + +type Configurator interface { + GetJWTSecuredAuthorizeResponseModeIssuer(ctx context.Context) string + GetJWTSecuredAuthorizeResponseModeSigner(ctx context.Context) jwt.Signer + GetJWTSecuredAuthorizeResponseModeLifespan(ctx context.Context) time.Duration +} + +type Client interface { + GetID() string + GetAuthorizationSignedResponseKeyID() (kid string) + GetAuthorizationSignedResponseAlg() (alg string) + GetAuthorizationEncryptedResponseAlg() (alg string) + GetAuthorizationEncryptedResponseEncryptionAlg() (alg string) +} + +type OpenIDSession interface { + IDTokenHeaders() *jwt.Headers + IDTokenClaims() *jwt.IDTokenClaims +} + +type JWTSessionContainer interface { + GetJWTHeader() *jwt.Headers + GetJWTClaims() jwt.JWTClaimsContainer +} diff --git a/token/jwt/claims.go b/token/jwt/claims.go index f6279985..501e1ad3 100644 --- a/token/jwt/claims.go +++ b/token/jwt/claims.go @@ -77,3 +77,28 @@ func Copy(elements map[string]any) (result map[string]any) { return result } + +// StringSliceFromMap asserts a map any value to a []string provided it has a good type. +func StringSliceFromMap(value any) (values []string, ok bool) { + switch v := value.(type) { + case nil: + return nil, true + case []string: + return v, true + case string: + return []string{v}, true + case []any: + for _, item := range v { + switch iv := item.(type) { + case string: + values = append(values, iv) + default: + return nil, false + } + } + + return values, true + default: + return nil, false + } +} diff --git a/token/jwt/claims_jarm.go b/token/jwt/claims_jarm.go new file mode 100644 index 00000000..d43c5a98 --- /dev/null +++ b/token/jwt/claims_jarm.go @@ -0,0 +1,107 @@ +package jwt + +import ( + "time" + + "github.com/google/uuid" + + "authelia.com/provider/oauth2/internal/consts" +) + +// JARMClaims represent a token's claims. +type JARMClaims struct { + Issuer string + Audience []string + JTI string + IssuedAt time.Time + ExpiresAt time.Time + Extra map[string]any +} + +// ToMap will transform the headers to a map structure +func (c *JARMClaims) ToMap() map[string]any { + var ret = Copy(c.Extra) + + if c.Issuer != "" { + ret[consts.ClaimIssuer] = c.Issuer + } else { + delete(ret, consts.ClaimIssuer) + } + + if c.JTI != "" { + ret[consts.ClaimJWTID] = c.JTI + } else { + ret[consts.ClaimJWTID] = uuid.New().String() + } + + if len(c.Audience) > 0 { + ret[consts.ClaimAudience] = c.Audience + } else { + ret[consts.ClaimAudience] = []string{} + } + + if !c.IssuedAt.IsZero() { + ret[consts.ClaimIssuedAt] = c.IssuedAt.Unix() + } else { + delete(ret, consts.ClaimIssuedAt) + } + + if !c.ExpiresAt.IsZero() { + ret[consts.ClaimExpirationTime] = c.ExpiresAt.Unix() + } else { + delete(ret, consts.ClaimExpirationTime) + } + + return ret +} + +// FromMap will set the claims based on a mapping +func (c *JARMClaims) FromMap(m map[string]any) { + c.Extra = make(map[string]any) + for k, v := range m { + switch k { + case consts.ClaimIssuer: + if s, ok := v.(string); ok { + c.Issuer = s + } + case consts.ClaimJWTID: + if s, ok := v.(string); ok { + c.JTI = s + } + case consts.ClaimAudience: + if aud, ok := StringSliceFromMap(v); ok { + c.Audience = aud + } + case consts.ClaimIssuedAt: + c.IssuedAt = toTime(v, c.IssuedAt) + case consts.ClaimExpirationTime: + c.ExpiresAt = toTime(v, c.ExpiresAt) + default: + c.Extra[k] = v + } + } +} + +// Add will add a key-value pair to the extra field +func (c *JARMClaims) Add(key string, value any) { + if c.Extra == nil { + c.Extra = make(map[string]any) + } + + c.Extra[key] = value +} + +// Get will get a value from the extra field based on a given key +func (c JARMClaims) Get(key string) any { + return c.ToMap()[key] +} + +// ToMapClaims will return a jwt-go MapClaims representation +func (c JARMClaims) ToMapClaims() MapClaims { + return c.ToMap() +} + +// FromMapClaims will populate claims from a jwt-go MapClaims representation +func (c *JARMClaims) FromMapClaims(mc MapClaims) { + c.FromMap(mc) +} diff --git a/token/jwt/claims_jarm_test.go b/token/jwt/claims_jarm_test.go new file mode 100644 index 00000000..941236d7 --- /dev/null +++ b/token/jwt/claims_jarm_test.go @@ -0,0 +1,63 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package jwt_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "authelia.com/provider/oauth2/internal/consts" + . "authelia.com/provider/oauth2/token/jwt" +) + +var jarmClaims = &JARMClaims{ + Issuer: "authelia", + Audience: []string{"tests"}, + JTI: "abcdef", + IssuedAt: time.Now().UTC().Round(time.Second), + ExpiresAt: time.Now().UTC().Add(time.Hour).Round(time.Second), + Extra: map[string]any{ + "foo": "bar", + "baz": "bar", + }, +} + +var jarmClaimsMap = map[string]any{ + consts.ClaimIssuer: jwtClaims.Issuer, + consts.ClaimAudience: jwtClaims.Audience, + consts.ClaimJWTID: jwtClaims.JTI, + consts.ClaimIssuedAt: jwtClaims.IssuedAt.Unix(), + consts.ClaimExpirationTime: jwtClaims.ExpiresAt.Unix(), + "foo": jwtClaims.Extra["foo"], + "baz": jwtClaims.Extra["baz"], +} + +func TestJARMClaimAddGetString(t *testing.T) { + jarmClaims.Add("foo", "bar") + assert.Equal(t, "bar", jarmClaims.Get("foo")) +} + +func TestJARMClaimsToMapSetsID(t *testing.T) { + assert.NotEmpty(t, (&JARMClaims{}).ToMap()[consts.ClaimJWTID]) +} + +func TestJARMAssert(t *testing.T) { + assert.Nil(t, (&JARMClaims{ExpiresAt: time.Now().UTC().Add(time.Hour)}). + ToMapClaims().Valid()) + assert.NotNil(t, (&JARMClaims{ExpiresAt: time.Now().UTC().Add(-2 * time.Hour)}). + ToMapClaims().Valid()) +} + +func TestJARMtClaimsToMap(t *testing.T) { + assert.Equal(t, jarmClaimsMap, jarmClaims.ToMap()) +} + +func TestJARMClaimsFromMap(t *testing.T) { + var claims JARMClaims + + claims.FromMap(jarmClaimsMap) + assert.Equal(t, jarmClaims, &claims) +} diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index 11acab81..f84c241a 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -175,10 +175,8 @@ func (c *JWTClaims) FromMap(m map[string]any) { c.Issuer = s } case consts.ClaimAudience: - if s, ok := v.(string); ok { - c.Audience = []string{s} - } else if s, ok := v.([]string); ok { - c.Audience = s + if aud, ok := StringSliceFromMap(v); ok { + c.Audience = aud } case consts.ClaimIssuedAt: c.IssuedAt = toTime(v, c.IssuedAt) diff --git a/token/jwt/claims_jwt_test.go b/token/jwt/claims_jwt_test.go index 3bd1f393..d746e4a7 100644 --- a/token/jwt/claims_jwt_test.go +++ b/token/jwt/claims_jwt_test.go @@ -48,7 +48,7 @@ func TestClaimAddGetString(t *testing.T) { } func TestClaimsToMapSetsID(t *testing.T) { - assert.NotEmpty(t, (&JWTClaims{}).ToMap()["jti"]) + assert.NotEmpty(t, (&JWTClaims{}).ToMap()[consts.ClaimJWTID]) } func TestAssert(t *testing.T) { diff --git a/token/jwt/map_claims.go b/token/jwt/claims_map.go similarity index 93% rename from token/jwt/map_claims.go rename to token/jwt/claims_map.go index e748d052..d0543cbb 100644 --- a/token/jwt/map_claims.go +++ b/token/jwt/claims_map.go @@ -28,24 +28,16 @@ type MapClaims map[string]any // VerifyAudience compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyAudience(cmp string, req bool) bool { - var aud []string - switch v := m[consts.ClaimAudience].(type) { - case []string: - aud = v - case []any: - for _, a := range v { - vs, ok := a.(string) - if !ok { - return false - } - aud = append(aud, vs) - } - case string: - aud = append(aud, v) - default: - return false + var ( + aud []string + ok bool + ) + + if aud, ok = StringSliceFromMap(m[consts.ClaimAudience]); ok { + return verifyAud(aud, cmp, req) } - return verifyAud(aud, cmp, req) + + return false } // VerifyExpiresAt compares the exp claim against cmp. @@ -94,6 +86,7 @@ func (m MapClaims) toInt64(claim string) (int64, bool) { if err == nil { return v, true } + vf, err := t.Float64() if err != nil { return 0, false @@ -101,6 +94,7 @@ func (m MapClaims) toInt64(claim string) (int64, bool) { return int64(vf), true } + return 0, false } @@ -160,6 +154,7 @@ func verifyAud(aud []string, cmp string, required bool) bool { return true } } + return false } @@ -167,6 +162,7 @@ func verifyExp(exp int64, now int64, required bool) bool { if exp == 0 { return !required } + return now <= exp } @@ -174,6 +170,7 @@ func verifyIat(iat int64, now int64, required bool) bool { if iat == 0 { return !required } + return now >= iat } @@ -181,6 +178,7 @@ func verifyIss(iss string, cmp string, required bool) bool { if iss == "" { return !required } + if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { return true } else { @@ -192,5 +190,6 @@ func verifyNbf(nbf int64, now int64, required bool) bool { if nbf == 0 { return !required } + return now >= nbf } diff --git a/token/jwt/map_claims_test.go b/token/jwt/claims_map_test.go similarity index 99% rename from token/jwt/map_claims_test.go rename to token/jwt/claims_map_test.go index 625d36ad..9613e461 100644 --- a/token/jwt/map_claims_test.go +++ b/token/jwt/claims_map_test.go @@ -96,7 +96,7 @@ func Test_mapClaims_string_aud_no_claim(t *testing.T) { func Test_mapClaims_string_aud_no_claim_not_required(t *testing.T) { mapClaims := MapClaims{} - want := false + want := true got := mapClaims.VerifyAudience("foo", false) if want != got {