Skip to content

Commit

Permalink
EVEREST-1799 get groups claim and validate permissions
Browse files Browse the repository at this point in the history
  • Loading branch information
recharte committed Jan 13, 2025
1 parent e7ff34e commit 39f2145
Show file tree
Hide file tree
Showing 12 changed files with 341 additions and 78 deletions.
16 changes: 8 additions & 8 deletions internal/server/handlers/rbac/backup_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestRBAC_BackupStorage(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -228,7 +228,7 @@ func TestRBAC_BackupStorage(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -395,7 +395,7 @@ func TestRBAC_BackupStorage(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -563,7 +563,7 @@ func TestRBAC_BackupStorage(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -727,7 +727,7 @@ func TestRBAC_BackupStorage(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -768,10 +768,10 @@ func newConfigMapPolicy(policy string) *corev1.ConfigMap {
}
}

func testUserGetter(ctx context.Context) (string, error) {
user, ok := ctx.Value(common.UserCtxKey).(string)
func testUserGetter(ctx context.Context) (rbac.User, error) {
user, ok := ctx.Value(common.UserCtxKey).(rbac.User)
if !ok {
return "", errors.New("user not found in context")
return rbac.User{}, errors.New("user not found in context")
}
return user, nil
}
8 changes: 4 additions & 4 deletions internal/server/handlers/rbac/database_cluster_backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func TestRBAC_DatabaseClusterBackup(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -216,7 +216,7 @@ func TestRBAC_DatabaseClusterBackup(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -282,7 +282,7 @@ func TestRBAC_DatabaseClusterBackup(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -355,7 +355,7 @@ func TestRBAC_DatabaseClusterBackup(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down
10 changes: 5 additions & 5 deletions internal/server/handlers/rbac/database_cluster_restore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -248,7 +248,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -359,7 +359,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -430,7 +430,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down
16 changes: 8 additions & 8 deletions internal/server/handlers/rbac/database_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -434,7 +434,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -812,7 +812,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -892,7 +892,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) {
)
return h
}
ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -956,7 +956,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) {
&api.DatabaseClusterCredential{}, nil)
return h
}
ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -1011,7 +1011,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) {
[]api.DatabaseClusterComponent{}, nil)
return h
}
ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -1066,7 +1066,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) {
&api.DatabaseClusterPitr{}, nil)
return h
}
ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down
10 changes: 5 additions & 5 deletions internal/server/handlers/rbac/database_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) {
},
},
}
ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -257,7 +257,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) {
return &h
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -313,7 +313,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) {
return &h
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -448,7 +448,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down Expand Up @@ -567,7 +567,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) {
},
}

ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob")
ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"})
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
Expand Down
26 changes: 16 additions & 10 deletions internal/server/handlers/rbac/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type rbacHandler struct {
enforcer casbin.IEnforcer
log *zap.SugaredLogger
next handlers.Handler
userGetter func(ctx context.Context) (string, error)
userGetter func(ctx context.Context) (rbac.User, error)
}

// New returns a new RBAC handler.
Expand Down Expand Up @@ -55,17 +55,23 @@ func (h *rbacHandler) enforce(
action,
object string,
) error {
subject, err := h.userGetter(ctx)
user, err := h.userGetter(ctx)
if err != nil {
return err
}
ok, err := h.enforcer.Enforce(subject, resource, action, object)
if err != nil {
return fmt.Errorf("enforce error: %w", err)
}
if !ok {
h.log.Warnf("Permission denied: [%s %s %s %s]", subject, resource, action, object)
return ErrInsufficientPermissions

// User is allowed to perform the operation if the user's subject or any
// of its groups have the required permission.
for _, sub := range append([]string{user.Subject}, user.Groups...) {
ok, err := h.enforcer.Enforce(sub, resource, action, object)
if err != nil {
return fmt.Errorf("enforce error: %w", err)
}
if ok {
return nil
}
}
return nil

h.log.Warnf("Permission denied: [%s %s %s %s]", user.Subject, resource, action, object)
return ErrInsufficientPermissions
}
56 changes: 31 additions & 25 deletions internal/server/handlers/rbac/kubernetes.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package rbac

import (
"context"
"errors"
"fmt"

"github.com/AlekSi/pointer"
Expand All @@ -28,15 +27,39 @@ func (h *rbacHandler) GetUserPermissions(ctx context.Context) (*api.UserPermissi
if err != nil {
return nil, err
}
perms, err := h.enforcer.GetImplicitPermissionsForUser(user)
if err != nil {
return nil, fmt.Errorf("failed to GetImplicitPermissionsForUser: %w", err)

// Let's use a map to deduplicate the permissions after resolving all roles
permsMap := make(map[[4]string]struct{})

// Get permissions for the user and the groups it belongs to
for _, sub := range append([]string{user.Subject}, user.Groups...) {
perms, err := h.enforcer.GetImplicitPermissionsForUser(sub)
if err != nil {
return nil, fmt.Errorf("failed to GetImplicitPermissionsForUser: %w", err)
}

// GetImplicitPermissionsForUser returns all policies assigned to the
// user/group directly as well as the policies assigned to the roles
// the user/group has. We need to resolve all roles for the user.
for _, perm := range perms {
if len(perm) != 4 {
// This should never happen, but let's be safe
return nil, fmt.Errorf("invalid permission")
}

// We don't want to expose the groups or roles in the permissions
// so we replace them with the user
permsMap[[4]string{user.Subject, perm[1], perm[2], perm[3]}] = struct{}{}
}
}

if err := h.resolveRoles(user, perms); err != nil {
return nil, err
// Convert the map back to a slice
result := make([][]string, len(permsMap))
i := 0
for k := range permsMap {
result[i] = []string(k[:])
i++
}
result := pointer.To(perms)

nextRes, err := h.next.GetUserPermissions(ctx)
if err != nil {
Expand All @@ -49,25 +72,8 @@ func (h *rbacHandler) GetUserPermissions(ctx context.Context) (*api.UserPermissi
}

res := &api.UserPermissions{
Permissions: result,
Permissions: pointer.To(result),
Enabled: enabled,
}
return res, nil
}

// For a given set of `permissions` for a `user`, this function
// will resolve all roles for the user.
func (h *rbacHandler) resolveRoles(user string, permissions [][]string) error {
userRoles, err := h.enforcer.GetRolesForUser(user)
if err != nil {
return errors.Join(err, errors.New("cannot get user roles"))
}
for _, role := range userRoles {
for i, perm := range permissions {
if perm[0] == role {
permissions[i][0] = user
}
}
}
return nil
}
Loading

0 comments on commit 39f2145

Please sign in to comment.