Skip to content

Commit

Permalink
feat(pkce): per client policy
Browse files Browse the repository at this point in the history
This adds a PKCE policy control on a per-client basis and overhauls the testing of this particular handler and the particular error messages it returns.
  • Loading branch information
james-d-elliott committed Mar 11, 2024
1 parent c4913a3 commit c8f4c46
Show file tree
Hide file tree
Showing 8 changed files with 1,372 additions and 384 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ following list of differences:
<sup>[commit](https://github.com/authelia/oauth2-provider/commit/6584d3495422a97ef9aba92e762ffaebce010dd0)</sup>
- [x] Original request id not set early enough
<sup>[commit](https://github.com/authelia/oauth2-provider/commit/6584d3495422a97ef9aba92e762ffaebce010dd0)</sup>
- PKCE Flow
- PKCE Flow:
- [x] Session generated needlessly
<sup>[commit](https://github.com/authelia/oauth2-provider/commit/dbdadf5dee92d13683eeacaa198c28d6704ddb1c)</sup>
- [x] Failure to fetch session causes an error even when not enforced
Expand All @@ -57,6 +57,8 @@ following list of differences:
- [x] Access Token iat and nbf in JWT Profile always original claims
<sup>[commit](https://github.com/authelia/oauth2-provider/commit/a87d91df762a8fe26282145ba9dace0461f31b4d)</sup>
- Features:
- PKCE Flow:
- [x] Per-Client Enforcement Policy
- CoreStrategy:
- [x] Customizable Token Prefix
<sup>[commit](https://github.com/authelia/oauth2-provider/commit/4f55dabdf5d87c34053992c3de3fe7b1bf1046f3)</sup>
Expand Down
1 change: 1 addition & 0 deletions authorize_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@ func (a *AuthorizeResponse) AddParameter(key, value string) {
if key == consts.FormParameterAuthorizationCode {
a.code = value
}

a.Parameters.Add(key, value)
}
9 changes: 9 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ type RotatedClientSecretsClient interface {
Client
}

// ProofKeyCodeExchangeClient is a Client implementation which provides PKCE client policy values.
type ProofKeyCodeExchangeClient interface {
GetEnforcePKCE() (enforce bool)
GetEnforcePKCEChallengeMethod() (enforce bool)
GetPKCEChallengeMethod() (method string)

Client
}

// ClientAuthenticationPolicyClient is a Client implementation which also provides client authentication policy values.
type ClientAuthenticationPolicyClient interface {
// GetAllowMultipleAuthenticationMethods should return true if the client policy allows multiple authentication
Expand Down
2 changes: 1 addition & 1 deletion compose/compose_pkce.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func OAuth2PKCEFactory(config oauth2.Configurator, storage any, strategy any) any {
return &pkce.Handler{
AuthorizeCodeStrategy: strategy.(hoauth2.AuthorizeCodeStrategy),
Storage: storage.(pkce.PKCERequestStorage),
Storage: storage.(pkce.Storage),
Config: config,
}
}
7 changes: 7 additions & 0 deletions handler/pkce/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package pkce

import "regexp"

var (
verifierWrongFormat = regexp.MustCompile(`[^\w.~-]`)
)
110 changes: 67 additions & 43 deletions handler/pkce/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"regexp"
"fmt"

"github.com/pkg/errors"

Expand All @@ -20,21 +20,15 @@ import (

type Handler struct {
AuthorizeCodeStrategy hoauth2.AuthorizeCodeStrategy
Storage PKCERequestStorage
Storage Storage
Config interface {
oauth2.EnforcePKCEProvider
oauth2.EnforcePKCEForPublicClientsProvider
oauth2.EnablePKCEPlainChallengeMethodProvider
}
}

var (
_ oauth2.TokenEndpointHandler = (*Handler)(nil)
)

var verifierWrongFormat = regexp.MustCompile(`[^\w.~-]`)

func (c *Handler) HandleAuthorizeEndpointRequest(ctx context.Context, requester oauth2.AuthorizeRequester, responder oauth2.AuthorizeResponder) error {
func (c *Handler) HandleAuthorizeEndpointRequest(ctx context.Context, requester oauth2.AuthorizeRequester, responder oauth2.AuthorizeResponder) (err error) {
// This let's us define multiple response types, for example the OpenID Connect 1.0 `id_token`.
if !requester.GetResponseTypes().Has(consts.ResponseTypeAuthorizationCodeFlow) {
return nil
Expand All @@ -44,12 +38,12 @@ func (c *Handler) HandleAuthorizeEndpointRequest(ctx context.Context, requester
method := requester.GetRequestForm().Get(consts.FormParameterCodeChallengeMethod)
client := requester.GetClient()

if err := c.validate(ctx, challenge, method, client); err != nil {
if err = c.validate(ctx, challenge, method, client); err != nil {
return err
}

// We don't need a session if it's not enforced and the PKCE parameters are not provided by the client.
if challenge == "" && method == "" {
if len(challenge) == 0 && len(method) == 0 {
return nil
}

Expand All @@ -61,25 +55,25 @@ func (c *Handler) HandleAuthorizeEndpointRequest(ctx context.Context, requester

signature := c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code)

if err := c.Storage.CreatePKCERequestSession(ctx, signature, requester.Sanitize([]string{
if err = c.Storage.CreatePKCERequestSession(ctx, signature, requester.Sanitize([]string{
consts.FormParameterCodeChallenge,
consts.FormParameterCodeChallengeMethod,
})); err != nil {
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugError(err))
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugError(fmt.Errorf("Error occurred attempting create PKCE request session: %w.", err)))
}

return nil
}

func (c *Handler) validate(ctx context.Context, challenge, method string, client oauth2.Client) error {
func (c *Handler) validate(ctx context.Context, challenge, method string, client oauth2.Client) (err error) {
if len(challenge) == 0 {
// If the server requires Proof Key for Code Exchange (PKCE) by OAuth
// clients and the client does not send the "code_challenge" in
// the request, the authorization endpoint MUST return the authorization
// error response with the "error" value set to "invalid_request". The
// "error_description" or the response of "error_uri" SHOULD explain the
// nature of error, e.g., code challenge required.
return c.validateNoPKCE(ctx, client)
return c.validateNoPKCE(ctx, consts.FormParameterCodeChallenge, client)
}

// If the server supporting PKCE does not support the requested
Expand All @@ -91,17 +85,30 @@ func (c *Handler) validate(ctx context.Context, challenge, method string, client
switch method {
case consts.PKCEChallengeMethodSHA256:
break
case consts.PKCEChallengeMethodPlain:
fallthrough
case "":
fallthrough
case consts.PKCEChallengeMethodPlain:
if !c.Config.GetEnablePKCEPlainChallengeMethod(ctx) {
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("Clients must use code_challenge_method=S256, plain is not allowed.").
WithDebug("The server is configured in a way that enforces PKCE S256 as challenge method for clients."))
WithHint("Authorization was requested with 'code_challenge_method' value 'plain', but the authorization server policy does not allow method 'plain' and requires method 'S256'.").
WithDebug("The authorization server is configured in a way that enforces the 'S256' PKCE 'code_challenge_method' for all clients."))
}
default:
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("The code_challenge_method is not supported, use S256 instead."))
WithHintf("Authorization was requested with 'code_challenge_method' value '%s', but the authorization server doesn't know how to handle this method, try 'S256' instead.", method))
}

if pkce, ok := client.(oauth2.ProofKeyCodeExchangeClient); ok {
if pkce.GetEnforcePKCEChallengeMethod() && method != pkce.GetPKCEChallengeMethod() {
switch cmethod := pkce.GetPKCEChallengeMethod(); {
case method == cmethod, method == "" && cmethod == consts.PKCEChallengeMethodPlain:
break
default:
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHintf("Authorization was requested with 'code_challenge_method' value '%s', but the authorization server policy does not allow method '%s' and requires method '%s'.", method, method, pkce.GetPKCEChallengeMethod()).
WithDebugf("The registered client with id '%s' is configured in a way that enforces the use of 'code_challenge_method' with a value of '%s' but the authorization request included method '%s'.", client.GetID(), cmethod, method))
}
}
}

return nil
Expand All @@ -112,7 +119,7 @@ func (c *Handler) validate(ctx context.Context, challenge, method string, client
// TODO: Refactor time permitting.
//
//nolint:gocyclo
func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester oauth2.AccessRequester) error {
func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester oauth2.AccessRequester) (err error) {
if !c.CanHandleTokenEndpointRequest(ctx, requester) {
return errorsx.WithStack(oauth2.ErrUnknownRequest)
}
Expand All @@ -125,31 +132,37 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester oaut
// endpoint MUST use to verify the "code_verifier".
verifier := requester.GetRequestForm().Get(consts.FormParameterCodeVerifier)

nv := len(verifier)

code := requester.GetRequestForm().Get(consts.FormParameterAuthorizationCode)
signature := c.AuthorizeCodeStrategy.AuthorizeCodeSignature(ctx, code)
requesterPKCE, err := c.Storage.GetPKCERequestSession(ctx, signature, requester.GetSession())

nv := len(verifier)
var requesterPKCE oauth2.Requester

if requesterPKCE, err = c.Storage.GetPKCERequestSession(ctx, signature, requester.GetSession()); err != nil {
if errors.Is(err, oauth2.ErrNotFound) {
if nv == 0 {
return c.validateNoPKCE(ctx, consts.FormParameterCodeVerifier, requester.GetClient())
}

if errors.Is(err, oauth2.ErrNotFound) {
if nv == 0 {
return c.validateNoPKCE(ctx, requester.GetClient())
return errorsx.WithStack(oauth2.ErrInvalidGrant.WithHint("Unable to find initial PKCE data tied to this request.").WithWrap(err).WithDebugError(err))
}

return errorsx.WithStack(oauth2.ErrInvalidGrant.WithHint("Unable to find initial PKCE data tied to this request.").WithWrap(err).WithDebugError(err))
} else if err != nil {
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugError(err))
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugError(fmt.Errorf("Error occurred attempting get PKCE request session: %w.", err)))
}

if err = c.Storage.DeletePKCERequestSession(ctx, signature); err != nil {
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugError(err))
return errorsx.WithStack(oauth2.ErrServerError.WithWrap(err).WithDebugError(fmt.Errorf("Error occurred attempting delete PKCE request session: %w.", err)))
}

challenge := requesterPKCE.GetRequestForm().Get(consts.FormParameterCodeChallenge)
method := requesterPKCE.GetRequestForm().Get(consts.FormParameterCodeChallengeMethod)
client := requesterPKCE.GetClient()

if err = c.validate(ctx, challenge, method, client); err != nil {
if err = c.validate(ctx, challenge, method, requester.GetClient()); err != nil {
return err
}

if err = c.validate(ctx, challenge, method, requesterPKCE.GetClient()); err != nil {
return err
}

Expand All @@ -172,7 +185,7 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester oaut
WithHint("The PKCE code verifier must be at least 43 characters."))
case nv > 128:
return errorsx.WithStack(oauth2.ErrInvalidGrant.
WithHint("The PKCE code verifier can not be longer than 128 characters."))
WithHint("The PKCE code verifier must be no more than 128 characters."))
case nc == 0:
return errorsx.WithStack(oauth2.ErrInvalidGrant.
WithHint("The PKCE code verifier was provided but the code challenge was absent from the authorization request."))
Expand Down Expand Up @@ -226,25 +239,23 @@ func (c *Handler) HandleTokenEndpointRequest(ctx context.Context, requester oaut
return nil
}

func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester oauth2.AccessRequester, responder oauth2.AccessResponder) error {
func (c *Handler) PopulateTokenEndpointResponse(ctx context.Context, requester oauth2.AccessRequester, responder oauth2.AccessResponder) (err error) {
return nil
}

func (c *Handler) CanSkipClientAuth(ctx context.Context, requester oauth2.AccessRequester) bool {
func (c *Handler) CanSkipClientAuth(ctx context.Context, requester oauth2.AccessRequester) (skip bool) {
return false
}

func (c *Handler) CanHandleTokenEndpointRequest(ctx context.Context, requester oauth2.AccessRequester) bool {
// grant_type REQUIRED.
// Value MUST be set to "authorization_code"
func (c *Handler) CanHandleTokenEndpointRequest(ctx context.Context, requester oauth2.AccessRequester) (handle bool) {
return requester.GetGrantTypes().ExactOne(consts.GrantTypeAuthorizationCode)
}

func (c *Handler) validateNoPKCE(ctx context.Context, client oauth2.Client) error {
func (c *Handler) validateNoPKCE(ctx context.Context, parameter string, client oauth2.Client) (err error) {
if c.Config.GetEnforcePKCE(ctx) {
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("Clients must include a code_challenge when performing the authorize code flow, but it is missing.").
WithDebug("The server is configured in a way that enforces PKCE for clients."))
WithHintf("Clients must include a '%s' when performing the authorize code flow, but it is missing.", parameter).
WithDebug("The authorization server is configured in a way that enforces PKCE for all clients."))
}

if c.Config.GetEnforcePKCEForPublicClients(ctx) {
Expand All @@ -254,10 +265,23 @@ func (c *Handler) validateNoPKCE(ctx context.Context, client oauth2.Client) erro

if client.IsPublic() {
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHint("This client must include a code_challenge when performing the authorize code flow, but it is missing.").
WithDebug("The server is configured in a way that enforces PKCE for this client."))
WithHintf("Clients must include a '%s' when performing the authorize code flow, but it is missing.", parameter).
WithDebugf("The authorization server is configured in a way that enforces PKCE for all public client type clients and the '%s' client is using the public client type.", client.GetID()))
}
}

if pkce, ok := client.(oauth2.ProofKeyCodeExchangeClient); ok {
if pkce.GetEnforcePKCE() {
return errorsx.WithStack(oauth2.ErrInvalidRequest.
WithHintf("Clients must include a '%s' when performing the authorize code flow, but it is missing.", parameter).
WithDebugf("The client with id '%s' is registered in a way that enforces PKCE.", client.GetID()))
}
}

return nil
}

var (
_ oauth2.AuthorizeEndpointHandler = (*Handler)(nil)
_ oauth2.TokenEndpointHandler = (*Handler)(nil)
)
Loading

0 comments on commit c8f4c46

Please sign in to comment.