Skip to content

Commit

Permalink
feat: client auth rework
Browse files Browse the repository at this point in the history
This reworks client authentication preventing situations where authentication will pass when it shouldn't.
  • Loading branch information
james-d-elliott committed Feb 18, 2024
1 parent 2c1d3f2 commit ff1cd54
Show file tree
Hide file tree
Showing 12 changed files with 1,064 additions and 546 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
go-version: 1.21
go-version: 1.22
- run: make test
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ following list of differences:
- [x] Higher Debug error information visibility (Debug Field includes the
complete RFC6749 error with debug information if available)
- Fixes:
- [x] Basic Scheme Rejects Special Characters
- [x] ~~Basic Scheme Rejects Special Characters~~
<sup>[commit](https://github.com/authelia/oauth2-provider/commit/2314625eb1f21987a9199fb1cdf6da6cee4df965)</sup>
- [x] RFC9068 must condition ignored
<sup>[commit](https://github.com/authelia/oauth2-provider/commit/c6e7a18ee9066b8c17c6f30a180d44507e2e2ff1)</sup>
Expand Down Expand Up @@ -85,6 +85,8 @@ following list of differences:
- [x] Revocation Flow per policy can decide to revoke Refresh Tokens on
request <sup>[commit](e3ffc451f1c7056494f9dc3e51d47e84f12357de)</sup>
- Client Authentication Rework:
- [x] General Refactor
- [x] Prevent Multiple Client Authentication Methods
- [ ] Client Secret Validation Interface
- [ ] JWE support for Client Authentication and Issuance
- [ ] Clock Drift Support
Expand Down
148 changes: 81 additions & 67 deletions access_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,55 +22,44 @@ import (
)

func TestNewAccessRequest(t *testing.T) {
ctrl := gomock.NewController(t)
store := internal.NewMockStorage(ctrl)
handler := internal.NewMockTokenEndpointHandler(ctrl)
handler.EXPECT().CanHandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(true).AnyTimes()
handler.EXPECT().CanSkipClientAuth(gomock.Any(), gomock.Any()).Return(false).AnyTimes()
hasher := internal.NewMockHasher(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

client := &DefaultClient{}
config := &Config{ClientSecretsHasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
provider := &Fosite{Store: store, Config: config}
for k, c := range []struct {
testCases := []struct {
name string
header http.Header
form url.Values
mock func()
mock func(ctx gomock.Matcher, handler *internal.MockTokenEndpointHandler, store *internal.MockStorage, hasher *internal.MockHasher, client *DefaultClient)
method string
expectErr error
expect *AccessRequest
handlers TokenEndpointHandlers
expect func(client *DefaultClient) *AccessRequest
handlers func(handler *internal.MockTokenEndpointHandler) TokenEndpointHandlers
}{
{
name: "ShouldReturnInvalidRequestWhenNoValues",
header: http.Header{},
expectErr: ErrInvalidRequest,
form: url.Values{},
method: "POST",
mock: func() {},
},
{
name: "ShouldReturnInvalidRequestWhenOnlyGrantType",
header: http.Header{},
method: "POST",
form: url.Values{
consts.FormParameterGrantType: {"foo"},
},
mock: func() {},
expectErr: ErrInvalidRequest,
},
{
name: "ShouldReturnInvalidRequestWhenEmptyClientID",
header: http.Header{},
method: "POST",
form: url.Values{
consts.FormParameterGrantType: {"foo"},
consts.FormParameterClientID: {""},
},
expectErr: ErrInvalidRequest,
mock: func() {},
},
{
name: "ShouldReturnInvalidClientWhenGetClientError",
header: http.Header{
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
Expand All @@ -79,12 +68,15 @@ func TestNewAccessRequest(t *testing.T) {
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrInvalidClient,
mock: func() {
mock: func(ctx gomock.Matcher, handler *internal.MockTokenEndpointHandler, store *internal.MockStorage, hasher *internal.MockHasher, client *DefaultClient) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
},
handlers: TokenEndpointHandlers{handler},
handlers: func(handler *internal.MockTokenEndpointHandler) TokenEndpointHandlers {
return TokenEndpointHandlers{handler}
},
},
{
name: "ShouldReturnInvalidRequestWhenInvalidMethod",
header: http.Header{
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
Expand All @@ -93,9 +85,9 @@ func TestNewAccessRequest(t *testing.T) {
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrInvalidRequest,
mock: func() {},
},
{
name: "ShouldReturnInvalidClientWhenBadClientSecret",
header: http.Header{
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
Expand All @@ -104,29 +96,18 @@ func TestNewAccessRequest(t *testing.T) {
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrInvalidClient,
mock: func() {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
},
handlers: TokenEndpointHandlers{handler},
},
{
header: http.Header{
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "POST",
form: url.Values{
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrInvalidClient,
mock: func() {
mock: func(ctx gomock.Matcher, handler *internal.MockTokenEndpointHandler, store *internal.MockStorage, hasher *internal.MockHasher, client *DefaultClient) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
},
handlers: TokenEndpointHandlers{handler},
handlers: func(handler *internal.MockTokenEndpointHandler) TokenEndpointHandlers {
return TokenEndpointHandlers{handler}
},
},
{
name: "ShouldReturnErrorWhenHandleTokenEndpointError",
header: http.Header{
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
Expand All @@ -135,78 +116,111 @@ func TestNewAccessRequest(t *testing.T) {
consts.FormParameterGrantType: {"foo"},
},
expectErr: ErrServerError,
mock: func() {
mock: func(ctx gomock.Matcher, handler *internal.MockTokenEndpointHandler, store *internal.MockStorage, hasher *internal.MockHasher, client *DefaultClient) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(ErrServerError)
},
handlers: TokenEndpointHandlers{handler},
handlers: func(handler *internal.MockTokenEndpointHandler) TokenEndpointHandlers {
return TokenEndpointHandlers{handler}
},
},
{
name: "ShouldHandleConfidentialClientSuccessfully",
header: http.Header{
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
},
method: "POST",
form: url.Values{
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
mock: func(ctx gomock.Matcher, handler *internal.MockTokenEndpointHandler, store *internal.MockStorage, hasher *internal.MockHasher, client *DefaultClient) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = false
client.Secret = []byte("foo")
hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
handlers: TokenEndpointHandlers{handler},
expect: &AccessRequest{
GrantTypes: Arguments{"foo"},
Request: Request{
Client: client,
},
handlers: func(handler *internal.MockTokenEndpointHandler) TokenEndpointHandlers {
return TokenEndpointHandlers{handler}
},
expect: func(client *DefaultClient) *AccessRequest {
return &AccessRequest{
GrantTypes: Arguments{"foo"},
Request: Request{
Client: client,
},
}
},
},
{
name: "ShouldHandlePublicClientTypeSuccessfully",
header: http.Header{
consts.HeaderAuthorization: {basicAuth("foo", "bar")},
consts.HeaderAuthorization: {basicAuth("foo", "")},
},
method: "POST",
form: url.Values{
consts.FormParameterGrantType: {"foo"},
},
mock: func() {
mock: func(ctx gomock.Matcher, handler *internal.MockTokenEndpointHandler, store *internal.MockStorage, hasher *internal.MockHasher, client *DefaultClient) {
store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
client.Public = true
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
},
handlers: TokenEndpointHandlers{handler},
expect: &AccessRequest{
GrantTypes: Arguments{"foo"},
Request: Request{
Client: client,
},
handlers: func(handler *internal.MockTokenEndpointHandler) TokenEndpointHandlers {
return TokenEndpointHandlers{handler}
},
expect: func(client *DefaultClient) *AccessRequest {
return &AccessRequest{
GrantTypes: Arguments{"foo"},
Request: Request{
Client: client,
},
}
},
},
} {
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
store := internal.NewMockStorage(ctrl)
handler := internal.NewMockTokenEndpointHandler(ctrl)
handler.EXPECT().CanHandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(true).AnyTimes()
handler.EXPECT().CanSkipClientAuth(gomock.Any(), gomock.Any()).Return(false).AnyTimes()
hasher := internal.NewMockHasher(ctrl)
defer ctrl.Finish()

ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))

client := &DefaultClient{}
config := &Config{ClientSecretsHasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
provider := &Fosite{Store: store, Config: config}

r := &http.Request{
Header: c.header,
PostForm: c.form,
Form: c.form,
Method: c.method,
Header: tc.header,
PostForm: tc.form,
Form: tc.form,
Method: tc.method,
}

c.mock()
config.TokenEndpointHandlers = c.handlers
if tc.mock != nil {
tc.mock(ctx, handler, store, hasher, client)
}

if tc.handlers != nil {
config.TokenEndpointHandlers = tc.handlers(handler)
}

ar, err := provider.NewAccessRequest(context.TODO(), r, new(DefaultSession))

if c.expectErr != nil {
assert.EqualError(t, err, c.expectErr.Error())
if tc.expectErr != nil {
assert.EqualError(t, err, tc.expectErr.Error())
} else {
require.NoError(t, err)
AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client")
AssertObjectKeysEqual(t, tc.expect(client), ar, "GrantTypes", "Client")
assert.NotNil(t, ar.GetRequestedAt())
}
})
Expand Down
23 changes: 23 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package oauth2

import (
"context"
"fmt"

"github.com/go-jose/go-jose/v3"

Expand Down Expand Up @@ -47,8 +48,21 @@ type ClientWithSecretRotation interface {
GetRotatedHashes() [][]byte
}

// ClientAuthenticationPolicyClient is a Client implementation which also provides client authentication policy values.
type ClientAuthenticationPolicyClient interface {
Client

// GetAllowMultipleAuthenticationMethods should return true if the client policy allows multiple authentication
// methods due to the client implementation breaching RFC6749 Section 2.3.
//
// See: https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.
GetAllowMultipleAuthenticationMethods(ctx context.Context) bool
}

// OpenIDConnectClient represents a client capable of performing OpenID Connect requests.
type OpenIDConnectClient interface {
Client

// GetRequestURIs is an array of request_uri values that are pre-registered by the RP for use at the OP. Servers MAY
// cache the contents of the files referenced by these URIs and not retrieve them at the time they are used in a request.
// OPs can require that request_uri values used be pre-registered with the require_request_uri_registration
Expand All @@ -73,6 +87,11 @@ type OpenIDConnectClient interface {
// JWS [JWS] alg algorithm [JWA] that MUST be used for signing the JWT [JWT] used to authenticate the
// Client at the Token Endpoint for the private_key_jwt and client_secret_jwt authentication methods.
GetTokenEndpointAuthSigningAlgorithm() string

// GetSecretPlainText returns the client secret in plain text.
// This is used to validate the 'token_endpoint_client_auth_method' with a value of 'client_secret_jwt'. If this
// client does NOT have a plain text secret then it MUST return an error.
GetSecretPlainText() (secret []byte, err error)
}

// RefreshFlowScopeClient is a client which can be customized to ignore scopes that were not originally granted.
Expand Down Expand Up @@ -209,6 +228,10 @@ func (c *DefaultOpenIDConnectClient) GetRequestObjectSigningAlgorithm() string {
return c.RequestObjectSigningAlgorithm
}

func (c *DefaultOpenIDConnectClient) GetSecretPlainText() (secret []byte, err error) {
return nil, fmt.Errorf("this registered client does not suport plain text")
}

func (c *DefaultOpenIDConnectClient) GetTokenEndpointAuthMethod() string {
return c.TokenEndpointAuthMethod
}
Expand Down
Loading

0 comments on commit ff1cd54

Please sign in to comment.