Skip to content

Commit

Permalink
BED-5008 feat: enforce SSO role provision changes on every login (#1150)
Browse files Browse the repository at this point in the history
  • Loading branch information
mistahj67 authored Feb 18, 2025
1 parent d1a1be2 commit f5d9c10
Show file tree
Hide file tree
Showing 14 changed files with 2,243 additions and 348 deletions.
1 change: 1 addition & 0 deletions cmd/api/src/api/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ const (
ErrorResponseUserSelfDisable = "user attempted to disable themselves"
ErrorResponseUserSelfRoleChange = "user attempted to change own role"
ErrorResponseUserSelfSSOProviderChange = "user attempted to change own SSO Provider"
ErrorResponseUserSSOProviderRoleProvisionChange = "user attempted to change a role for a SSO Provider with role provision enabled"
ErrorResponseAGTagWhiteSpace = "asset group tags must not contain whitespace"
ErrorResponseAGNameTagEmpty = "asset group name or tag must not be empty"
ErrorResponseAGDuplicateName = "asset group name must be unique"
Expand Down
34 changes: 28 additions & 6 deletions cmd/api/src/api/v2/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,26 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
} else if roles, err := s.db.GetRoles(request.Context(), updateUserRequest.Roles); err != nil {
api.HandleDatabaseError(request, response, err)
} else {
user.Roles = roles
user.FirstName = null.StringFrom(updateUserRequest.FirstName)
user.LastName = null.StringFrom(updateUserRequest.LastName)
user.EmailAddress = null.StringFrom(updateUserRequest.EmailAddress)
user.PrincipalName = updateUserRequest.Principal
user.IsDisabled = updateUserRequest.IsDisabled
// PATCH requests may not contain every field, only conditionally update if fields exist
if updateUserRequest.FirstName != "" {
user.FirstName = null.StringFrom(updateUserRequest.FirstName)
}

if updateUserRequest.LastName != "" {
user.LastName = null.StringFrom(updateUserRequest.LastName)
}

if updateUserRequest.EmailAddress != "" {
user.EmailAddress = null.StringFrom(updateUserRequest.EmailAddress)
}

if updateUserRequest.Principal != "" {
user.PrincipalName = updateUserRequest.Principal
}

if updateUserRequest.IsDisabled != nil {
user.IsDisabled = *updateUserRequest.IsDisabled
}

loggedInUser, _ := auth.GetUserFromAuthCtx(authCtx.AuthCtx)

Expand Down Expand Up @@ -455,6 +469,14 @@ func (s ManagementResource) UpdateUser(response http.ResponseWriter, request *ht
}
}

// We have to wait until after SSOProvider updates are handled above to validate roles can be safely updated.
if user.SSOProviderHasRoleProvisionEnabled() && !slices.Equal(roles.IDs(), user.Roles.IDs()) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponseUserSSOProviderRoleProvisionChange, request), response)
return
} else if updateUserRequest.Roles != nil {
user.Roles = roles
}

if err := s.db.UpdateUser(request.Context(), user); err != nil {
if errors.Is(err, database.ErrDuplicateUserPrincipal) {
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusConflict, api.ErrorResponseUserDuplicatePrincipal, request), response)
Expand Down
44 changes: 36 additions & 8 deletions cmd/api/src/api/v2/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {

adminUser = model.User{Unique: model.Unique{ID: must.NewUUIDv4()}}
goodRoles = []int32{0}
badRoles = []int32{1}
goodUserID = must.NewUUIDv4()
badUserID = must.NewUUIDv4()

Expand All @@ -180,6 +181,8 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
Serial: model.Serial{ID: 1234},
SSOProviderID: null.Int32From(ssoProviderID),
},
Config: model.SSOProviderConfig{
AutoProvision: model.SSOProviderAutoProvisionConfig{Enabled: true, RoleProvision: true}},
}
)

Expand Down Expand Up @@ -228,6 +231,24 @@ func TestManagementResource_EnableUserSAML(t *testing.T) {
ResponseStatusCode(http.StatusOK)
})

t.Run("Fails if roles set", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(badRoles)).Return(model.Roles{model.Role{Serial: model.Serial{ID: 1}}}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{SSOProviderID: null.Int32From(ssoProviderID), SSOProvider: &ssoProvider, Roles: model.Roles{model.Role{Serial: model.Serial{ID: 0}}}}, nil)
mockDB.EXPECT().GetSSOProviderById(gomock.Any(), ssoProvider.ID).Return(ssoProvider, nil)

test.Request(t).
WithContext(bhCtx).
WithURLPathVars(map[string]string{"user_id": goodUserID.String()}).
WithBody(v2.UpdateUserRequest{
Principal: "tester",
Roles: badRoles,
SSOProviderID: null.Int32From(ssoProviderID),
}).
OnHandlerFunc(resources.UpdateUser).
Require().
ResponseStatusCode(http.StatusBadRequest)
})

t.Run("Successful user update with sso provider-saml", func(t *testing.T) {
mockDB.EXPECT().GetRoles(gomock.Any(), gomock.Eq(goodRoles)).Return(model.Roles{}, nil)
mockDB.EXPECT().GetUser(gomock.Any(), goodUserID).Return(model.User{}, nil)
Expand Down Expand Up @@ -1533,12 +1554,15 @@ func TestManagementResource_UpdateUser_SelfDisable(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

endpoint := "/api/v2/auth/users"
// logged in user has ID 00000000-0000-0000-0000-000000000000
// leaving ID blank here will make goodUser have the same ID, so this should fail
goodUser := model.User{PrincipalName: "good user"}
var (
endpoint = "/api/v2/auth/users"
// logged in user has ID 00000000-0000-0000-0000-000000000000
// leaving ID blank here will make goodUser have the same ID, so this should fail
goodUser = model.User{PrincipalName: "good user"}
isDisabled = true
resources, mockDB = apitest.NewAuthManagementResource(mockCtrl)
)

resources, mockDB := apitest.NewAuthManagementResource(mockCtrl)
mockDB.EXPECT().GetConfigurationParameter(gomock.Any(), appcfg.PasswordExpirationWindow).Return(appcfg.Parameter{
Key: appcfg.PasswordExpirationWindow,
Value: must.NewJSONBObject(appcfg.PasswordExpiration{
Expand Down Expand Up @@ -1585,7 +1609,7 @@ func TestManagementResource_UpdateUser_SelfDisable(t *testing.T) {
require.Nil(t, err)

payload, err = json.Marshal(v2.UpdateUserRequest{
IsDisabled: true,
IsDisabled: &isDisabled,
})
require.Nil(t, err)

Expand Down Expand Up @@ -1678,6 +1702,8 @@ func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) {
},
}

isDisabled := true

resources, mockDB := apitest.NewAuthManagementResource(mockCtrl)
mockDB.EXPECT().GetConfigurationParameter(gomock.Any(), appcfg.PasswordExpirationWindow).Return(appcfg.Parameter{
Key: appcfg.PasswordExpirationWindow,
Expand Down Expand Up @@ -1726,7 +1752,7 @@ func TestManagementResource_UpdateUser_LookupActiveSessionsError(t *testing.T) {
require.Nil(t, err)

payload, err = json.Marshal(v2.UpdateUserRequest{
IsDisabled: true,
IsDisabled: &isDisabled,
})
require.Nil(t, err)

Expand Down Expand Up @@ -1984,6 +2010,8 @@ func TestManagementResource_UpdateUser_Success(t *testing.T) {
},
}

isDisabled := true

resources, mockDB := apitest.NewAuthManagementResource(mockCtrl)
mockDB.EXPECT().GetConfigurationParameter(gomock.Any(), appcfg.PasswordExpirationWindow).Return(appcfg.Parameter{
Key: appcfg.PasswordExpirationWindow,
Expand Down Expand Up @@ -2036,7 +2064,7 @@ func TestManagementResource_UpdateUser_Success(t *testing.T) {
userID, err := uuid.NewV4()
require.Nil(t, err)
updateUserRequest := v2.UpdateUserRequest{
IsDisabled: true,
IsDisabled: &isDisabled,
}

payload, err = json.Marshal(updateUserRequest)
Expand Down
60 changes: 36 additions & 24 deletions cmd/api/src/api/v2/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ var (

type oidcClaims struct {
Name string `json:"name"`
FamilyName string `json:"family_name"`
DisplayName string `json:"given_name"`
LastName string `json:"family_name"`
FirstName string `json:"given_name"`
Email string `json:"email"` // Not always present
Verified bool `json:"email_verified"`
PreferredUsername string `json:"preferred_username"` // Present in Entra claims, may be an email
Expand Down Expand Up @@ -240,7 +240,7 @@ func (s ManagementResource) OIDCCallbackHandler(response http.ResponseWriter, re
api.RedirectToLoginURL(response, request, "Claims invalid: no valid email address found")
} else {
if ssoProvider.Config.AutoProvision.Enabled {
if err := jitOIDCUserCreation(request.Context(), ssoProvider, email, claims, s.db); err != nil {
if err := jitOIDCUserUpsert(request.Context(), ssoProvider, email, claims, s.db); err != nil {
// It is safe to let this request drop into the CreateSSOSession function below to ensure proper audit logging
slog.WarnContext(request.Context(), fmt.Sprintf("[OIDC] Error during JIT User Creation: %v", err))
}
Expand Down Expand Up @@ -349,36 +349,48 @@ func getEmailFromOIDCClaims(claims oidcClaims) (string, error) {
return "", ErrEmailMissing
}

func jitOIDCUserCreation(ctx context.Context, ssoProvider model.SSOProvider, email string, claims oidcClaims, u jitUserCreator) error {
func jitOIDCUserUpsert(ctx context.Context, ssoProvider model.SSOProvider, email string, claims oidcClaims, u jitUserUpserter) error {
if roles, err := SanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, claims.Roles, u); err != nil {
return fmt.Errorf("sanitize roles: %v", err)
} else if len(roles) != 1 {
return fmt.Errorf("invalid roles")
} else if _, err := u.LookupUser(ctx, email); err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("lookup user: %v", err)
} else if errors.Is(err, database.ErrNotFound) {
var user = model.User{
EmailAddress: null.StringFrom(email),
PrincipalName: email,
Roles: roles,
SSOProviderID: null.Int32From(ssoProvider.ID),
EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users
FirstName: null.StringFrom(email),
LastName: null.StringFrom("Last name not found"),
} else if user, err := u.LookupUser(ctx, email); err != nil {
if errors.Is(err, database.ErrNotFound) {
return jitOIDCUserCreate(ctx, ssoProvider, email, claims, u, roles)
}

if claims.DisplayName != "" {
user.FirstName = null.StringFrom(claims.DisplayName)
return fmt.Errorf("user lookup: %v", err)
} else if ssoProvider.Config.AutoProvision.RoleProvision && !user.Roles.Has(roles[0]) {
// roles should only ever have 1 role
user.Roles = roles
if err := u.UpdateUser(ctx, user); err != nil {
return fmt.Errorf("update user: %v", err)
}
}

if claims.FamilyName != "" {
user.LastName = null.StringFrom(claims.FamilyName)
}
return nil
}

if _, err := u.CreateUser(ctx, user); err != nil {
return fmt.Errorf("create user: %v", err)
}
func jitOIDCUserCreate(ctx context.Context, ssoProvider model.SSOProvider, email string, claims oidcClaims, u jitUserUpserter, roles model.Roles) error {
user := model.User{
EmailAddress: null.StringFrom(email),
PrincipalName: email,
Roles: roles,
SSOProviderID: null.Int32From(ssoProvider.ID),
EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users
FirstName: null.StringFrom(email),
LastName: null.StringFrom("Last name not found"),
}

if claims.FirstName != "" {
user.FirstName = null.StringFrom(claims.FirstName)
}

if claims.LastName != "" {
user.LastName = null.StringFrom(claims.LastName)
}

if _, err := u.CreateUser(ctx, user); err != nil {
return fmt.Errorf("create user: %v", err)
}
return nil
}
54 changes: 33 additions & 21 deletions cmd/api/src/api/v2/auth/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ func (s ManagementResource) SAMLCallbackHandler(response http.ResponseWriter, re
api.RedirectToLoginURL(response, request, "Invalid assertion: no valid email address found")
} else {
if ssoProvider.Config.AutoProvision.Enabled {
if err := jitSAMLUserCreation(request.Context(), ssoProvider, principalName, assertion, s.db); err != nil {
if err := jitSAMLUserUpsert(request.Context(), ssoProvider, principalName, assertion, s.db); err != nil {
// It is safe to let this request drop into the CreateSSOSession function below to ensure proper audit logging
slog.WarnContext(request.Context(), fmt.Sprintf("[SAML] Error during JIT User Creation: %v", err))
}
Expand All @@ -470,36 +470,48 @@ func (s ManagementResource) SAMLCallbackHandler(response http.ResponseWriter, re
}
}

func jitSAMLUserCreation(ctx context.Context, ssoProvider model.SSOProvider, principalName string, assertion *saml.Assertion, u jitUserCreator) error {
func jitSAMLUserUpsert(ctx context.Context, ssoProvider model.SSOProvider, principalName string, assertion *saml.Assertion, u jitUserUpserter) error {
if roles, err := SanitizeAndGetRoles(ctx, ssoProvider.Config.AutoProvision, ssoProvider.SAMLProvider.GetSAMLUserRolesFromAssertion(assertion), u); err != nil {
return fmt.Errorf("sanitize roles: %v", err)
} else if len(roles) != 1 {
return fmt.Errorf("invalid roles detected")
} else if _, err := u.LookupUser(ctx, principalName); err != nil && !errors.Is(err, database.ErrNotFound) {
} else if user, err := u.LookupUser(ctx, principalName); err != nil {
if errors.Is(err, database.ErrNotFound) {
return jitSAMLUserCreate(ctx, ssoProvider, principalName, assertion, u, roles)
}
return fmt.Errorf("lookup user: %v", err)
} else if errors.Is(err, database.ErrNotFound) {
user := model.User{
EmailAddress: null.StringFrom(principalName),
PrincipalName: principalName,
Roles: roles,
SSOProviderID: null.Int32From(ssoProvider.ID),
EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users
FirstName: null.StringFrom(principalName),
LastName: null.StringFrom("Last name not found"),
} else if ssoProvider.Config.AutoProvision.RoleProvision && !user.Roles.Has(roles[0]) {
// roles should only ever have 1 role
user.Roles = roles
if err := u.UpdateUser(ctx, user); err != nil {
return fmt.Errorf("update user: %v", err)
}
}

if givenName, err := ssoProvider.SAMLProvider.GetSAMLUserGivenNameFromAssertion(assertion); err == nil {
user.FirstName = null.StringFrom(givenName)
}
return nil
}

if surname, err := ssoProvider.SAMLProvider.GetSAMLUserSurnameFromAssertion(assertion); err == nil {
user.LastName = null.StringFrom(surname)
}
func jitSAMLUserCreate(ctx context.Context, ssoProvider model.SSOProvider, principalName string, assertion *saml.Assertion, u jitUserUpserter, roles model.Roles) error {
user := model.User{
EmailAddress: null.StringFrom(principalName),
PrincipalName: principalName,
Roles: roles,
SSOProviderID: null.Int32From(ssoProvider.ID),
EULAAccepted: true, // EULA Acceptance does not pertain to Bloodhound Community Edition; this flag is used for Bloodhound Enterprise users
FirstName: null.StringFrom(principalName),
LastName: null.StringFrom("Last name not found"),
}

if _, err := u.CreateUser(ctx, user); err != nil {
return fmt.Errorf("create user: %v", err)
}
if givenName, err := ssoProvider.SAMLProvider.GetSAMLUserGivenNameFromAssertion(assertion); err == nil {
user.FirstName = null.StringFrom(givenName)
}

if surname, err := ssoProvider.SAMLProvider.GetSAMLUserSurnameFromAssertion(assertion); err == nil {
user.LastName = null.StringFrom(surname)
}

if _, err := u.CreateUser(ctx, user); err != nil {
return fmt.Errorf("create user: %v", err)
}
return nil
}
3 changes: 2 additions & 1 deletion cmd/api/src/api/v2/auth/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,12 @@ type getAllRoler interface {
GetAllRoles(ctx context.Context, order string, filter model.SQLFilter) (model.Roles, error)
}

type jitUserCreator interface {
type jitUserUpserter interface {
getAllRoler

LookupUser(ctx context.Context, principalNameOrEmail string) (model.User, error)
CreateUser(ctx context.Context, user model.User) (model.User, error)
UpdateUser(ctx context.Context, user model.User) error
}

// ListAuthProviders lists all available SSO providers (SAML and OIDC) with sorting and filtering
Expand Down
3 changes: 1 addition & 2 deletions cmd/api/src/api/v2/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package v2
import (
"github.com/gorilla/schema"
"github.com/specterops/bloodhound/cache"
_ "github.com/specterops/bloodhound/dawgs/drivers/neo4j"
"github.com/specterops/bloodhound/dawgs/graph"
"github.com/specterops/bloodhound/src/api"
"github.com/specterops/bloodhound/src/auth"
Expand Down Expand Up @@ -68,7 +67,7 @@ type UpdateUserRequest struct {
Roles []int32 `json:"roles"`
SAMLProviderID string `json:"saml_provider_id"`
SSOProviderID null.Int32 `json:"sso_provider_id"`
IsDisabled bool `json:"is_disabled"`
IsDisabled *bool `json:"is_disabled,omitempty"`
}

type CreateUserRequest struct {
Expand Down
4 changes: 4 additions & 0 deletions cmd/api/src/model/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,10 @@ func (s *User) RemoveRole(role Role) {
s.Roles = s.Roles.RemoveByName(role.Name)
}

func (s *User) SSOProviderHasRoleProvisionEnabled() bool {
return s.SSOProvider != nil && s.SSOProvider.Config.AutoProvision.Enabled && s.SSOProvider.Config.AutoProvision.RoleProvision
}

type Users []User

func (s Users) IsSortable(column string) bool {
Expand Down
Loading

0 comments on commit f5d9c10

Please sign in to comment.