From 39f2145c3c07df1dcbe631fff9f16aa6dbef8509 Mon Sep 17 00:00:00 2001 From: Diogo Recharte Date: Fri, 3 Jan 2025 12:24:29 +0000 Subject: [PATCH] EVEREST-1799 get groups claim and validate permissions --- .../handlers/rbac/backup_storage_test.go | 16 +- .../rbac/database_cluster_backup_test.go | 8 +- .../rbac/database_cluster_restore_test.go | 10 +- .../handlers/rbac/database_cluster_test.go | 16 +- .../handlers/rbac/database_engine_test.go | 10 +- internal/server/handlers/rbac/handler.go | 26 +-- internal/server/handlers/rbac/kubernetes.go | 56 ++++--- .../server/handlers/rbac/kubernetes_test.go | 150 ++++++++++++++++++ .../handlers/rbac/monitoring_instance_test.go | 10 +- pkg/rbac/rbac.go | 47 +++++- pkg/rbac/rbac_test.go | 68 ++++++++ ui/apps/everest/src/App.tsx | 2 +- 12 files changed, 341 insertions(+), 78 deletions(-) create mode 100644 internal/server/handlers/rbac/kubernetes_test.go create mode 100644 pkg/rbac/rbac_test.go diff --git a/internal/server/handlers/rbac/backup_storage_test.go b/internal/server/handlers/rbac/backup_storage_test.go index 2f7fe9bed..fc477ca99 100644 --- a/internal/server/handlers/rbac/backup_storage_test.go +++ b/internal/server/handlers/rbac/backup_storage_test.go @@ -142,7 +142,7 @@ func TestRBAC_BackupStorage(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -228,7 +228,7 @@ func TestRBAC_BackupStorage(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -395,7 +395,7 @@ func TestRBAC_BackupStorage(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -563,7 +563,7 @@ func TestRBAC_BackupStorage(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -727,7 +727,7 @@ func TestRBAC_BackupStorage(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -768,10 +768,10 @@ func newConfigMapPolicy(policy string) *corev1.ConfigMap { } } -func testUserGetter(ctx context.Context) (string, error) { - user, ok := ctx.Value(common.UserCtxKey).(string) +func testUserGetter(ctx context.Context) (rbac.User, error) { + user, ok := ctx.Value(common.UserCtxKey).(rbac.User) if !ok { - return "", errors.New("user not found in context") + return rbac.User{}, errors.New("user not found in context") } return user, nil } diff --git a/internal/server/handlers/rbac/database_cluster_backup_test.go b/internal/server/handlers/rbac/database_cluster_backup_test.go index 51b0f149b..2470c4d8d 100644 --- a/internal/server/handlers/rbac/database_cluster_backup_test.go +++ b/internal/server/handlers/rbac/database_cluster_backup_test.go @@ -138,7 +138,7 @@ func TestRBAC_DatabaseClusterBackup(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -216,7 +216,7 @@ func TestRBAC_DatabaseClusterBackup(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -282,7 +282,7 @@ func TestRBAC_DatabaseClusterBackup(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -355,7 +355,7 @@ func TestRBAC_DatabaseClusterBackup(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() diff --git a/internal/server/handlers/rbac/database_cluster_restore_test.go b/internal/server/handlers/rbac/database_cluster_restore_test.go index 7f77e9608..55c4d87ba 100644 --- a/internal/server/handlers/rbac/database_cluster_restore_test.go +++ b/internal/server/handlers/rbac/database_cluster_restore_test.go @@ -91,7 +91,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -156,7 +156,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -248,7 +248,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -359,7 +359,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -430,7 +430,7 @@ func TestRBAC_DatabaseClusterRestore(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() diff --git a/internal/server/handlers/rbac/database_cluster_test.go b/internal/server/handlers/rbac/database_cluster_test.go index 824e0ab04..d4e403132 100644 --- a/internal/server/handlers/rbac/database_cluster_test.go +++ b/internal/server/handlers/rbac/database_cluster_test.go @@ -197,7 +197,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -306,7 +306,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -434,7 +434,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -812,7 +812,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -892,7 +892,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) { ) return h } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -956,7 +956,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) { &api.DatabaseClusterCredential{}, nil) return h } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -1011,7 +1011,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) { []api.DatabaseClusterComponent{}, nil) return h } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -1066,7 +1066,7 @@ func TestRBAC_DatabaseCluster(t *testing.T) { &api.DatabaseClusterPitr{}, nil) return h } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "test-user") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "test-user"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() diff --git a/internal/server/handlers/rbac/database_engine_test.go b/internal/server/handlers/rbac/database_engine_test.go index 084ecb615..bc4f5eba9 100644 --- a/internal/server/handlers/rbac/database_engine_test.go +++ b/internal/server/handlers/rbac/database_engine_test.go @@ -198,7 +198,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) { }, }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -257,7 +257,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) { return &h } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -313,7 +313,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) { return &h } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -448,7 +448,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -567,7 +567,7 @@ func TestRBAC_DatabaseEngines(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() diff --git a/internal/server/handlers/rbac/handler.go b/internal/server/handlers/rbac/handler.go index a8020f9c2..37e0172eb 100644 --- a/internal/server/handlers/rbac/handler.go +++ b/internal/server/handlers/rbac/handler.go @@ -21,7 +21,7 @@ type rbacHandler struct { enforcer casbin.IEnforcer log *zap.SugaredLogger next handlers.Handler - userGetter func(ctx context.Context) (string, error) + userGetter func(ctx context.Context) (rbac.User, error) } // New returns a new RBAC handler. @@ -55,17 +55,23 @@ func (h *rbacHandler) enforce( action, object string, ) error { - subject, err := h.userGetter(ctx) + user, err := h.userGetter(ctx) if err != nil { return err } - ok, err := h.enforcer.Enforce(subject, resource, action, object) - if err != nil { - return fmt.Errorf("enforce error: %w", err) - } - if !ok { - h.log.Warnf("Permission denied: [%s %s %s %s]", subject, resource, action, object) - return ErrInsufficientPermissions + + // User is allowed to perform the operation if the user's subject or any + // of its groups have the required permission. + for _, sub := range append([]string{user.Subject}, user.Groups...) { + ok, err := h.enforcer.Enforce(sub, resource, action, object) + if err != nil { + return fmt.Errorf("enforce error: %w", err) + } + if ok { + return nil + } } - return nil + + h.log.Warnf("Permission denied: [%s %s %s %s]", user.Subject, resource, action, object) + return ErrInsufficientPermissions } diff --git a/internal/server/handlers/rbac/kubernetes.go b/internal/server/handlers/rbac/kubernetes.go index 83385fd3c..5e8353150 100644 --- a/internal/server/handlers/rbac/kubernetes.go +++ b/internal/server/handlers/rbac/kubernetes.go @@ -3,7 +3,6 @@ package rbac import ( "context" - "errors" "fmt" "github.com/AlekSi/pointer" @@ -28,15 +27,39 @@ func (h *rbacHandler) GetUserPermissions(ctx context.Context) (*api.UserPermissi if err != nil { return nil, err } - perms, err := h.enforcer.GetImplicitPermissionsForUser(user) - if err != nil { - return nil, fmt.Errorf("failed to GetImplicitPermissionsForUser: %w", err) + + // Let's use a map to deduplicate the permissions after resolving all roles + permsMap := make(map[[4]string]struct{}) + + // Get permissions for the user and the groups it belongs to + for _, sub := range append([]string{user.Subject}, user.Groups...) { + perms, err := h.enforcer.GetImplicitPermissionsForUser(sub) + if err != nil { + return nil, fmt.Errorf("failed to GetImplicitPermissionsForUser: %w", err) + } + + // GetImplicitPermissionsForUser returns all policies assigned to the + // user/group directly as well as the policies assigned to the roles + // the user/group has. We need to resolve all roles for the user. + for _, perm := range perms { + if len(perm) != 4 { + // This should never happen, but let's be safe + return nil, fmt.Errorf("invalid permission") + } + + // We don't want to expose the groups or roles in the permissions + // so we replace them with the user + permsMap[[4]string{user.Subject, perm[1], perm[2], perm[3]}] = struct{}{} + } } - if err := h.resolveRoles(user, perms); err != nil { - return nil, err + // Convert the map back to a slice + result := make([][]string, len(permsMap)) + i := 0 + for k := range permsMap { + result[i] = []string(k[:]) + i++ } - result := pointer.To(perms) nextRes, err := h.next.GetUserPermissions(ctx) if err != nil { @@ -49,25 +72,8 @@ func (h *rbacHandler) GetUserPermissions(ctx context.Context) (*api.UserPermissi } res := &api.UserPermissions{ - Permissions: result, + Permissions: pointer.To(result), Enabled: enabled, } return res, nil } - -// For a given set of `permissions` for a `user`, this function -// will resolve all roles for the user. -func (h *rbacHandler) resolveRoles(user string, permissions [][]string) error { - userRoles, err := h.enforcer.GetRolesForUser(user) - if err != nil { - return errors.Join(err, errors.New("cannot get user roles")) - } - for _, role := range userRoles { - for i, perm := range permissions { - if perm[0] == role { - permissions[i][0] = user - } - } - } - return nil -} diff --git a/internal/server/handlers/rbac/kubernetes_test.go b/internal/server/handlers/rbac/kubernetes_test.go new file mode 100644 index 000000000..621459431 --- /dev/null +++ b/internal/server/handlers/rbac/kubernetes_test.go @@ -0,0 +1,150 @@ +package rbac + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/percona/everest/api" + "github.com/percona/everest/internal/server/handlers" + "github.com/percona/everest/pkg/common" + "github.com/percona/everest/pkg/rbac" +) + +func TestRBAC_Kubernetes(t *testing.T) { + t.Parallel() + + data := func() *handlers.MockHandler { + next := handlers.MockHandler{} + next.On("GetUserPermissions", + mock.Anything, + ).Return( + &api.UserPermissions{ + Enabled: true, + }, + nil, + ) + return &next + } + + t.Run("GetUserPermissions", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + desc string + user rbac.User + policy string + outPerms [][]string + }{ + { + desc: "default admin permissions", + user: rbac.User{ + Subject: "test-user", + }, + policy: newPolicy( + "g, test-user, role:admin", + ), + outPerms: [][]string{ + {"test-user", "monitoring-instances", "*", "*/*"}, + {"test-user", "database-cluster-backups", "*", "*/*"}, + {"test-user", "database-cluster-restores", "*", "*/*"}, + {"test-user", "database-clusters", "*", "*/*"}, + {"test-user", "database-cluster-credentials", "*", "*/*"}, + {"test-user", "database-engines", "*", "*/*"}, + {"test-user", "namespaces", "*", "*"}, + {"test-user", "backup-storages", "*", "*/*"}, + }, + }, + { + desc: "permissions from different roles are merged", + user: rbac.User{ + Subject: "test-user", + }, + policy: newPolicy( + "p, test-user, database-clusters, *, */*", + "p, role:creater, database-clusters, create, */*", + "p, role:reader, database-clusters, read, */*", + "p, role:updater, database-clusters, update, */*", + "p, role:deleter, database-clusters, delete, */*", + "g, test-user, role:creater", + "g, test-user, role:reader", + "g, test-user, role:updater", + "g, another-user, role:deleter", + ), + outPerms: [][]string{ + {"test-user", "database-clusters", "*", "*/*"}, + {"test-user", "database-clusters", "create", "*/*"}, + {"test-user", "database-clusters", "read", "*/*"}, + {"test-user", "database-clusters", "update", "*/*"}, + }, + }, + { + desc: "permissions from different groups are merged", + user: rbac.User{ + Subject: "test-user", + Groups: []string{"test-group-1", "test-group-2"}, + }, + policy: newPolicy( + "p, test-user, database-clusters, read, */*", + "p, test-group-1, database-clusters, create, */*", + "p, test-group-2, database-clusters, update, */*", + "p, test-group-3, database-clusters, delete, */*", + ), + outPerms: [][]string{ + {"test-user", "database-clusters", "read", "*/*"}, + {"test-user", "database-clusters", "create", "*/*"}, + {"test-user", "database-clusters", "update", "*/*"}, + }, + }, + { + desc: "duplicate permissions are removed", + user: rbac.User{ + Subject: "test-user", + }, + policy: newPolicy( + "p, test-user, database-clusters, *, */*", + "p, role:test, database-clusters, *, */*", + "g, test-user, role:test", + ), + outPerms: [][]string{ + {"test-user", "database-clusters", "*", "*/*"}, + }, + }, + { + desc: "no policy", + user: rbac.User{ + Subject: "test-user", + }, + policy: newPolicy(), + outPerms: [][]string{}, + }, + } + + for _, tc := range testCases { + ctx := context.WithValue(context.Background(), common.UserCtxKey, tc.user) + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + k8sMock := newConfigMapMock(tc.policy) + enf, err := rbac.NewEnforcer(ctx, k8sMock, zap.NewNop().Sugar()) + require.NoError(t, err) + next := data() + + h := &rbacHandler{ + next: next, + log: zap.NewNop().Sugar(), + enforcer: enf, + userGetter: testUserGetter, + } + + perms, err := h.GetUserPermissions(ctx) + require.NoError(t, err) + assert.True(t, perms.Enabled) + assert.ElementsMatch(t, tc.outPerms, *perms.Permissions) + }) + } + }) +} diff --git a/internal/server/handlers/rbac/monitoring_instance_test.go b/internal/server/handlers/rbac/monitoring_instance_test.go index a39b2fcaa..fbfac2c26 100644 --- a/internal/server/handlers/rbac/monitoring_instance_test.go +++ b/internal/server/handlers/rbac/monitoring_instance_test.go @@ -159,7 +159,7 @@ func TestRBAC_MonitoringInstance(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -245,7 +245,7 @@ func TestRBAC_MonitoringInstance(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -412,7 +412,7 @@ func TestRBAC_MonitoringInstance(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -580,7 +580,7 @@ func TestRBAC_MonitoringInstance(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() @@ -744,7 +744,7 @@ func TestRBAC_MonitoringInstance(t *testing.T) { }, } - ctx := context.WithValue(context.Background(), common.UserCtxKey, "bob") + ctx := context.WithValue(context.Background(), common.UserCtxKey, rbac.User{Subject: "bob"}) for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { t.Parallel() diff --git a/pkg/rbac/rbac.go b/pkg/rbac/rbac.go index 2205bbf0e..f1e06f058 100644 --- a/pkg/rbac/rbac.go +++ b/pkg/rbac/rbac.go @@ -68,6 +68,11 @@ const ( rbacEnabledValueTrue = "true" ) +type User struct { + Subject string + Groups []string +} + // Setup a new informer that watches our RBAC ConfigMap. // This informer reloads the policy whenever the ConfigMap is updated. func refreshEnforcerInBackground( @@ -178,31 +183,59 @@ func NewEnforcer(ctx context.Context, kubeClient kubernetes.KubernetesConnector, } // GetUser extracts the user from the JWT token in the context. -func GetUser(ctx context.Context) (string, error) { +func GetUser(ctx context.Context) (User, error) { token, ok := ctx.Value(common.UserCtxKey).(*jwt.Token) if !ok { - return "", errors.New("failed to get token from context") + return User{}, errors.New("failed to get token from context") } claims, ok := token.Claims.(jwt.MapClaims) // by default claims is of type `jwt.MapClaims` if !ok { - return "", errors.New("failed to get claims from token") + return User{}, errors.New("failed to get claims from token") } subject, err := claims.GetSubject() if err != nil { - return "", errors.Join(err, errors.New("failed to get subject from claims")) + return User{}, errors.Join(err, errors.New("failed to get subject from claims")) } issuer, err := claims.GetIssuer() if err != nil { - return "", errors.Join(err, errors.New("failed to get issuer from claims")) + return User{}, errors.Join(err, errors.New("failed to get issuer from claims")) } if issuer == session.SessionManagerClaimsIssuer { - return strings.Split(subject, ":")[0], nil + subject = strings.Split(subject, ":")[0] + } + + groups := getScopeValues(claims, []string{"groups"}) + return User{Subject: subject, Groups: groups}, nil +} + +func getScopeValues(claims jwt.MapClaims, scopes []string) []string { + groups := []string{} + for i := range scopes { + scopeIf, ok := claims[scopes[i]] + if !ok { + continue + } + + switch val := scopeIf.(type) { + case []interface{}: + for _, groupIf := range val { + group, ok := groupIf.(string) + if ok { + groups = append(groups, group) + } + } + case []string: + groups = append(groups, val...) + case string: + groups = append(groups, val) + } } - return subject, nil + + return groups } func loadAdminPolicy(enf casbin.IEnforcer) error { diff --git a/pkg/rbac/rbac_test.go b/pkg/rbac/rbac_test.go new file mode 100644 index 000000000..de271fae6 --- /dev/null +++ b/pkg/rbac/rbac_test.go @@ -0,0 +1,68 @@ +package rbac + +import ( + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" +) + +func TestGetScopeValues(t *testing.T) { + t.Parallel() + testcases := []struct { + desc string + claims jwt.MapClaims + scopes []string + out []string + }{ + { + desc: "empty claims", + claims: jwt.MapClaims{}, + scopes: []string{"groups"}, + out: []string{}, + }, + { + desc: "empty scopes", + claims: jwt.MapClaims{"groups": []string{"my-org:my-team"}}, + scopes: nil, + out: []string{}, + }, + { + desc: "empty groups", + claims: jwt.MapClaims{"groups": []string{}}, + scopes: []string{"groups"}, + out: []string{}, + }, + { + desc: "single group", + claims: jwt.MapClaims{"groups": []string{"my-org:my-team"}}, + scopes: []string{"groups"}, + out: []string{"my-org:my-team"}, + }, + { + desc: "multiple groups", + claims: jwt.MapClaims{"groups": []string{"my-org:my-team1", "my-org:my-team2"}}, + scopes: []string{"groups"}, + out: []string{"my-org:my-team1", "my-org:my-team2"}, + }, + { + desc: "multiple groups and other", + claims: jwt.MapClaims{"groups": []string{"my-org:my-team1", "my-org:my-team2"}, "other": []string{"other1", "other2"}}, + scopes: []string{"groups"}, + out: []string{"my-org:my-team1", "my-org:my-team2"}, + }, + { + desc: "multiple groups and other with all scopes", + claims: jwt.MapClaims{"groups": []string{"my-org:my-team1", "my-org:my-team2"}, "other": []string{"other1", "other2"}}, + scopes: []string{"groups", "other"}, + out: []string{"my-org:my-team1", "my-org:my-team2", "other1", "other2"}, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.out, getScopeValues(tc.claims, tc.scopes)) + }) + } +} diff --git a/ui/apps/everest/src/App.tsx b/ui/apps/everest/src/App.tsx index 2e79fab83..5d85ed797 100644 --- a/ui/apps/everest/src/App.tsx +++ b/ui/apps/everest/src/App.tsx @@ -72,7 +72,7 @@ const App = () => { oidcConfig={{ ...configs?.oidc, redirectUri: `${window.location.protocol}//${window.location.host}/login-callback`, - scope: 'openid profile email', + scope: 'openid profile email groups', responseType: 'code', autoSignIn: false, automaticSilentRenew: false,