Skip to content

Commit

Permalink
refactor login scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
hummerdmag committed Aug 18, 2024
1 parent 92bfd61 commit 96f59f7
Show file tree
Hide file tree
Showing 17 changed files with 101 additions and 67 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ build:
go build -o ./identifo

lint:
golangci-lint run -D deadcode,errcheck,unused,varcheck,govet
golangci-lint run -D errcheck,unused,govet

build_admin_panel:
rm -rf static/admin_panel
Expand Down
33 changes: 21 additions & 12 deletions jwt/service/jwt_token_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ func (ts *JWTokenService) ValidateTokenString(tstr string, v jwtValidator.Valida
// NewAccessToken creates new access token for user.
func (ts *JWTokenService) NewAccessToken(
user model.User,
scopes []string,
scopes model.AllowedScopesSet,
app model.AppData,
requireTFA bool,
tokenPayload map[string]interface{},
Expand All @@ -235,10 +235,11 @@ func (ts *JWTokenService) NewAccessToken(
payload[PayloadName] = user.Username
}

tokenType := model.TokenTypeAccess
scopesStr := scopes.String()
if requireTFA {
scopes = []string{model.TokenTypeTFAPreauth}
scopesStr = model.TokenTypeTFAPreauth
}

if len(tokenPayload) > 0 {
for k, v := range tokenPayload {
payload[k] = v
Expand All @@ -253,9 +254,9 @@ func (ts *JWTokenService) NewAccessToken(
}

claims := &model.Claims{
Scopes: strings.Join(scopes, " "),
Scopes: scopesStr,
Payload: payload,
Type: tokenType,
Type: model.TokenTypeAccess,
StandardClaims: jwt.StandardClaims{
ExpiresAt: (now + lifespan),
Issuer: ts.issuer,
Expand All @@ -278,22 +279,27 @@ func (ts *JWTokenService) NewAccessToken(
}

// NewRefreshToken creates new refresh token.
func (ts *JWTokenService) NewRefreshToken(u model.User, scopes []string, app model.AppData) (model.Token, error) {
func (ts *JWTokenService) NewRefreshToken(
user model.User,
scopes model.AllowedScopesSet,
app model.AppData,
) (model.Token, error) {
if !app.Active || !app.Offline {
return nil, ErrInvalidApp
}

// no offline request
if !model.SliceContains(scopes, model.OfflineScope) {
if !scopes.Contains(model.OfflineScope) {
return nil, ErrInvalidOfflineScope
}

if !u.Active {
if !user.Active {
return nil, ErrInvalidUser
}

payload := make(map[string]interface{})
if model.SliceContains(app.TokenPayload, PayloadName) {
payload[PayloadName] = u.Username
payload[PayloadName] = user.Username
}
now := ijwt.TimeFunc().Unix()

Expand All @@ -303,13 +309,13 @@ func (ts *JWTokenService) NewRefreshToken(u model.User, scopes []string, app mod
}

claims := &model.Claims{
Scopes: strings.Join(scopes, " "),
Scopes: scopes.String(),
Payload: payload,
Type: model.TokenTypeRefresh,
StandardClaims: jwt.StandardClaims{
ExpiresAt: (now + lifespan),
Issuer: ts.issuer,
Subject: u.ID,
Subject: user.ID,
Audience: app.ID,
IssuedAt: now,
},
Expand Down Expand Up @@ -338,7 +344,10 @@ func (ts *JWTokenService) NewRefreshToken(u model.User, scopes []string, app mod
}

// RefreshAccessToken issues new access token for provided refresh token.
func (ts *JWTokenService) RefreshAccessToken(refreshToken model.Token, tokenPayload map[string]interface{}) (model.Token, error) {
func (ts *JWTokenService) RefreshAccessToken(
refreshToken model.Token,
tokenPayload map[string]interface{},
) (model.Token, error) {
rt, ok := refreshToken.(*model.JWToken)
if !ok || rt == nil {
return nil, model.ErrTokenInvalid
Expand Down
6 changes: 5 additions & 1 deletion jwt/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ func TestNewToken(t *testing.T) {
}, "password", "admin", false)
scopes := []string{"scope1", "scope2"}
tokenPayload := []string{"name"}

app := model.AppData{
ID: "123456",
Secret: "1",
Expand All @@ -138,7 +139,10 @@ func TestNewToken(t *testing.T) {
RolesBlacklist: []string{},
NewUserDefaultRole: "",
}
token, err := tokenService.NewAccessToken(user, scopes, app, false, nil)

allowedScopes := model.AllowedScopes(scopes, scopes, false)

token, err := tokenService.NewAccessToken(user, allowedScopes, app, false, nil)
assert.NoError(t, err)

tokenString, err := tokenService.String(token)
Expand Down
6 changes: 5 additions & 1 deletion model/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ func SliceIntersect(a, b []string) []string {
}

func SliceContains(s []string, e string) bool {
el := strings.TrimSpace(e)

for _, a := range s {
if strings.TrimSpace(strings.ToLower(a)) == strings.TrimSpace(strings.ToLower(e)) {
if strings.EqualFold(strings.TrimSpace(a), el) {
return true
}

}

return false
}

Expand Down
33 changes: 25 additions & 8 deletions model/token.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package model

import (
"strings"
"time"

jwt "github.com/golang-jwt/jwt/v4"
Expand Down Expand Up @@ -182,15 +183,31 @@ type Claims struct {
// Full example of how to use JWT tokens:
// https://github.com/form3tech-oss/jwt-go/blob/master/cmd/jwt/app.go

func AllowedScopes(requestedScopes, userScopes []string, isOffline bool) []string {
scopes := []string{}
// This type is needed for guard against passing unchecked scopes to the token.
// Do not convert user provided scopes to this type directly.
type AllowedScopesSet struct {
scopes []string
}

func (a AllowedScopesSet) String() string {
return strings.Join(a.scopes, " ")
}

func (a AllowedScopesSet) Scopes() []string {
return a.scopes
}

func (a AllowedScopesSet) Contains(scope string) bool {
return SliceContains(a.scopes, scope)
}

func AllowedScopes(requestedScopes, userScopes []string, isOffline bool) AllowedScopesSet {
// if we requested any scope, let's provide all the scopes user has and requested
if len(requestedScopes) > 0 {
scopes = SliceIntersect(requestedScopes, userScopes)
}
if SliceContains(requestedScopes, "offline") && isOffline {
scopes = append(scopes, "offline")
scopes := SliceIntersect(requestedScopes, userScopes)

if SliceContains(requestedScopes, OfflineScope) && isOffline {
scopes = append(scopes, OfflineScope)
}

return scopes
return AllowedScopesSet{scopes}
}
4 changes: 2 additions & 2 deletions model/token_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ const (

// TokenService is an abstract token manager.
type TokenService interface {
NewAccessToken(u User, scopes []string, app AppData, requireTFA bool, tokenPayload map[string]interface{}) (Token, error)
NewRefreshToken(u User, scopes []string, app AppData) (Token, error)
NewAccessToken(u User, scopes AllowedScopesSet, app AppData, requireTFA bool, tokenPayload map[string]interface{}) (Token, error)
NewRefreshToken(u User, scopes AllowedScopesSet, app AppData) (Token, error)
RefreshAccessToken(token Token, tokenPayload map[string]interface{}) (Token, error)
NewInviteToken(email, role, audience string, data map[string]interface{}) (Token, error)
NewResetToken(userID string) (Token, error)
Expand Down
9 changes: 7 additions & 2 deletions storage/dynamodb/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,17 @@ func (db *DB) IsTableExists(table string) (bool, error) {
}

func (db *DB) DeleteTable(table string) error {
svc := dynamodb.New(session.New())
sess, err := session.NewSession()
if err != nil {
return err
}

svc := dynamodb.New(sess)
input := &dynamodb.DeleteTableInput{
TableName: aws.String(table),
}

_, err := svc.DeleteTable(input)
_, err = svc.DeleteTable(input)
return err
}

Expand Down
7 changes: 3 additions & 4 deletions web/api/2fa.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (ar *Router) EnableTFA() http.HandlerFunc {
return
}

accessToken, _, err := ar.loginUser(user, []string{}, app, false, true, tokenPayload)
accessToken, _, err := ar.loginUser(user, model.AllowedScopesSet{}, app, true, tokenPayload)
if err != nil {
ar.Error(w, locale, http.StatusInternalServerError, l.ErrorTokenUnableToCreateAccessTokenError, err)
return
Expand Down Expand Up @@ -274,8 +274,7 @@ func (ar *Router) FinalizeTFA() http.HandlerFunc {
return
}

createRefreshToken := contains(scopes, model.OfflineScope)
accessToken, refreshToken, err := ar.loginUser(user, scopes, app, createRefreshToken, false, tokenPayload)
accessToken, refreshToken, err := ar.loginUser(user, scopes, app, false, tokenPayload)
if err != nil {
ar.Error(w, locale, http.StatusInternalServerError, l.ErrorTokenUnableToCreateAccessTokenError, err)
return
Expand Down Expand Up @@ -308,7 +307,7 @@ func (ar *Router) FinalizeTFA() http.HandlerFunc {
}

ar.journal(JournalOperationLoginWith2FA,
user.ID, app.ID, r.UserAgent(), scopes)
user.ID, app.ID, r.UserAgent(), scopes.Scopes())

ar.server.Storages().User.UpdateLoginMetadata(user.ID)
ar.ServeJSON(w, locale, http.StatusOK, result)
Expand Down
2 changes: 1 addition & 1 deletion web/api/federated_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ func (ar *Router) FederatedLoginComplete() http.HandlerFunc {
authResult.Scopes = fsess.Scopes

ar.journal(JournalOperationFederatedLogin,
user.ID, app.ID, r.UserAgent(), resultScopes)
user.ID, app.ID, r.UserAgent(), resultScopes.Scopes())

ar.ServeJSON(w, locale, http.StatusOK, authResult)
}
Expand Down
4 changes: 2 additions & 2 deletions web/api/federated_oidc_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,11 @@ func (ar *Router) OIDCLoginComplete(useSession bool) http.HandlerFunc {
authResult.CallbackUrl = fsess.CallbackUrl
}

authResult.Scopes = resultScopes
authResult.Scopes = resultScopes.Scopes()
authResult.ProviderData = *providerData

ar.journal(JournalOperationOIDCLogin,
user.ID, app.ID, r.UserAgent(), resultScopes)
user.ID, app.ID, r.UserAgent(), resultScopes.Scopes())

ar.ServeJSON(w, locale, http.StatusOK, authResult)
}
Expand Down
2 changes: 1 addition & 1 deletion web/api/impersonate_as.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (ar *Router) ImpersonateAs() http.HandlerFunc {
authResult.RefreshToken = ""

ar.journal(JournalOperationImpersonatedAs,
userID, app.ID, r.UserAgent(), resultScopes)
userID, app.ID, r.UserAgent(), resultScopes.Scopes())

ar.ServeJSON(w, locale, http.StatusOK, authResult)
}
Expand Down
30 changes: 18 additions & 12 deletions web/api/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (ar *Router) LoginWithPassword() http.HandlerFunc {
}

ar.journal(JournalOperationLoginWithPassword,
user.ID, app.ID, r.UserAgent(), resultScopes)
user.ID, app.ID, r.UserAgent(), resultScopes.Scopes())

ar.ServeJSON(w, locale, http.StatusOK, authResult)
}
Expand Down Expand Up @@ -311,9 +311,9 @@ func (ar *Router) getTokenPayloadService(app model.AppData) (model.TokenPayloadP
// createRefreshToken boolean param tells if we should issue refresh token as well.
func (ar *Router) loginUser(
user model.User,
scopes []string,
scopes model.AllowedScopesSet,
app model.AppData,
createRefreshToken, require2FA bool,
require2FA bool,
tokenPayload map[string]interface{},
) (string, string, error) {
token, err := ar.server.Services().Token.NewAccessToken(user, scopes, app, require2FA, tokenPayload)
Expand All @@ -325,6 +325,9 @@ func (ar *Router) loginUser(
if err != nil {
return "", "", err
}

createRefreshToken := scopes.Contains(model.OfflineScope)

if !createRefreshToken || require2FA {
return accessTokenString, "", nil
}
Expand All @@ -335,10 +338,12 @@ func (ar *Router) loginUser(
logging.FieldError, err)
return accessTokenString, "", nil
}

refreshTokenString, err := ar.server.Services().Token.String(refresh)
if err != nil {
return "", "", err
}

return accessTokenString, refreshTokenString, nil
}

Expand All @@ -347,11 +352,11 @@ func (ar *Router) loginFlow(
user model.User,
requestedScopes []string,
additionalPayload map[string]any,
) (AuthResponse, []string, error) {
) (AuthResponse, model.AllowedScopesSet, error) {
// check if the user has the scope, that allows to login to the app
// user has to have at least one scope app expecting
if len(app.Scopes) > 0 && len(model.SliceIntersect(app.Scopes, user.Scopes)) == 0 {
return AuthResponse{}, nil, errors.New("user does not have required scope for the app")
return AuthResponse{}, model.AllowedScopesSet{}, errors.New("user does not have required scope for the app")
}

// Do login flow.
Expand All @@ -360,13 +365,12 @@ func (ar *Router) loginFlow(
// Check if we should require user to authenticate with 2FA.
require2FA, enabled2FA, err := ar.check2FA(app.TFAStatus, ar.tfaType, user)
if !require2FA && enabled2FA && err != nil {
return AuthResponse{}, nil, err
return AuthResponse{}, model.AllowedScopesSet{}, err
}

offline := contains(scopes, model.OfflineScope)
tokenPayload, err := ar.getTokenPayloadForApp(app, user.ID)
if err != nil {
return AuthResponse{}, nil, err
return AuthResponse{}, model.AllowedScopesSet{}, err
}

if tokenPayload == nil {
Expand All @@ -377,9 +381,9 @@ func (ar *Router) loginFlow(
}
}

accessToken, refreshToken, err := ar.loginUser(user, scopes, app, offline, require2FA, tokenPayload)
accessToken, refreshToken, err := ar.loginUser(user, scopes, app, require2FA, tokenPayload)
if err != nil {
return AuthResponse{}, nil, err
return AuthResponse{}, model.AllowedScopesSet{}, err
}

result := AuthResponse{
Expand All @@ -391,7 +395,7 @@ func (ar *Router) loginFlow(

if require2FA && enabled2FA {
if err := ar.sendOTPCode(app, user); err != nil {
return AuthResponse{}, nil, err
return AuthResponse{}, model.AllowedScopesSet{}, err
}
} else {
ar.server.Storages().User.UpdateLoginMetadata(user.ID)
Expand Down Expand Up @@ -461,12 +465,14 @@ func (ar *Router) GetImpersonateToken() http.HandlerFunc {
}

// getImpersonateAccessToken creates and returns access token for a user.
func (ar *Router) getImpersonateAccessToken(user model.User, scopes []string, app model.AppData) (string, error) {
func (ar *Router) getImpersonateAccessToken(user model.User, requestedScopes []string, app model.AppData) (string, error) {
tokenPayload, err := ar.getTokenPayloadForApp(app, user.ID)
if err != nil {
return "", err
}

scopes := model.AllowedScopes(requestedScopes, user.Scopes, app.Offline)

token, err := ar.server.Services().Token.NewAccessToken(user, scopes, app, false, tokenPayload)
if err != nil {
return "", err
Expand Down
Loading

0 comments on commit 96f59f7

Please sign in to comment.