diff --git a/auth/service.go b/auth/service.go index b5c23672e0..8d6b3ec6a3 100644 --- a/auth/service.go +++ b/auth/service.go @@ -497,15 +497,17 @@ func (svc service) RetrieveDomain(ctx context.Context, token, id string) (Domain if err != nil { return Domain{}, errors.Wrap(svcerr.ErrViewEntity, err) } - if err = svc.Authorize(ctx, policies.Policy{ - Subject: EncodeDomainUserID(id, res.User), - SubjectType: policies.UserType, - SubjectKind: policies.UsersKind, - Object: id, - ObjectType: policies.DomainType, - Permission: policies.MembershipPermission, - }); err != nil { - return Domain{ID: domain.ID, Name: domain.Name, Alias: domain.Alias}, nil + if err := svc.checkSuperAdmin(ctx, res.User); err != nil { + if err = svc.Authorize(ctx, policies.Policy{ + Subject: EncodeDomainUserID(id, res.User), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: id, + ObjectType: policies.DomainType, + Permission: policies.MembershipPermission, + }); err != nil { + return Domain{ID: domain.ID, Name: domain.Name, Alias: domain.Alias}, nil + } } return domain, nil } @@ -515,21 +517,25 @@ func (svc service) RetrieveDomainPermissions(ctx context.Context, token, id stri if err != nil { return []string{}, err } - domainUserSubject := EncodeDomainUserID(id, res.User) - if err := svc.Authorize(ctx, policies.Policy{ - Subject: domainUserSubject, - SubjectType: policies.UserType, - SubjectKind: policies.UsersKind, - Object: id, - ObjectType: policies.DomainType, - Permission: policies.MembershipPermission, - }); err != nil { - return []string{}, err + subject := res.User + if err := svc.checkSuperAdmin(ctx, res.User); err != nil { + domainUserSubject := EncodeDomainUserID(id, res.User) + if err := svc.Authorize(ctx, policies.Policy{ + Subject: domainUserSubject, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: id, + ObjectType: policies.DomainType, + Permission: policies.MembershipPermission, + }); err != nil { + return []string{}, err + } + subject = domainUserSubject } lp, err := svc.policysvc.ListPermissions(ctx, policies.Policy{ SubjectType: policies.UserType, - Subject: domainUserSubject, + Subject: subject, Object: id, ObjectType: policies.DomainType, }, []string{policies.AdminPermission, policies.EditPermission, policies.ViewPermission, policies.MembershipPermission, policies.CreatePermission}) @@ -544,15 +550,17 @@ func (svc service) UpdateDomain(ctx context.Context, token, id string, d DomainR if err != nil { return Domain{}, err } - if err := svc.Authorize(ctx, policies.Policy{ - Subject: EncodeDomainUserID(id, key.User), - SubjectType: policies.UserType, - SubjectKind: policies.UsersKind, - Object: id, - ObjectType: policies.DomainType, - Permission: policies.EditPermission, - }); err != nil { - return Domain{}, err + if err := svc.checkSuperAdmin(ctx, key.User); err != nil { + if err := svc.Authorize(ctx, policies.Policy{ + Subject: EncodeDomainUserID(id, key.User), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: id, + ObjectType: policies.DomainType, + Permission: policies.EditPermission, + }); err != nil { + return Domain{}, err + } } dom, err := svc.domains.Update(ctx, id, key.User, d) @@ -567,15 +575,17 @@ func (svc service) ChangeDomainStatus(ctx context.Context, token, id string, d D if err != nil { return Domain{}, errors.Wrap(svcerr.ErrAuthentication, err) } - if err := svc.Authorize(ctx, policies.Policy{ - Subject: EncodeDomainUserID(id, key.User), - SubjectType: policies.UserType, - SubjectKind: policies.UsersKind, - Object: id, - ObjectType: policies.DomainType, - Permission: policies.AdminPermission, - }); err != nil { - return Domain{}, err + if err := svc.checkSuperAdmin(ctx, key.User); err != nil { + if err := svc.Authorize(ctx, policies.Policy{ + Subject: EncodeDomainUserID(id, key.User), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: id, + ObjectType: policies.DomainType, + Permission: policies.AdminPermission, + }); err != nil { + return Domain{}, err + } } dom, err := svc.domains.Update(ctx, id, key.User, d) @@ -591,13 +601,7 @@ func (svc service) ListDomains(ctx context.Context, token string, p Page) (Domai return DomainsPage{}, errors.Wrap(svcerr.ErrAuthentication, err) } p.SubjectID = key.User - if err := svc.Authorize(ctx, policies.Policy{ - Subject: key.User, - SubjectType: policies.UserType, - Permission: policies.AdminPermission, - ObjectType: policies.PlatformType, - Object: policies.MagistralaObject, - }); err == nil { + if err := svc.checkSuperAdmin(ctx, key.User); err == nil { p.SubjectID = "" } dp, err := svc.domains.ListDomains(ctx, p) @@ -618,27 +622,29 @@ func (svc service) AssignUsers(ctx context.Context, token, id string, userIds [] return errors.Wrap(svcerr.ErrAuthentication, err) } - domainUserID := EncodeDomainUserID(id, res.User) - if err := svc.Authorize(ctx, policies.Policy{ - Subject: domainUserID, - SubjectType: policies.UserType, - SubjectKind: policies.UsersKind, - Object: id, - ObjectType: policies.DomainType, - Permission: policies.SharePermission, - }); err != nil { - return err - } + if err := svc.checkSuperAdmin(ctx, res.User); err != nil { + domainUserID := EncodeDomainUserID(id, res.User) + if err := svc.Authorize(ctx, policies.Policy{ + Subject: domainUserID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: id, + ObjectType: policies.DomainType, + Permission: policies.SharePermission, + }); err != nil { + return err + } - if err := svc.Authorize(ctx, policies.Policy{ - Subject: domainUserID, - SubjectType: policies.UserType, - SubjectKind: policies.UsersKind, - Object: id, - ObjectType: policies.DomainType, - Permission: SwitchToPermission(relation), - }); err != nil { - return err + if err := svc.Authorize(ctx, policies.Policy{ + Subject: domainUserID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: id, + ObjectType: policies.DomainType, + Permission: SwitchToPermission(relation), + }); err != nil { + return err + } } for _, userID := range userIds { @@ -662,27 +668,29 @@ func (svc service) UnassignUser(ctx context.Context, token, id, userID string) e return errors.Wrap(svcerr.ErrAuthentication, err) } - domainUserID := EncodeDomainUserID(id, res.User) - pr := policies.Policy{ - Subject: domainUserID, - SubjectType: policies.UserType, - SubjectKind: policies.UsersKind, - Object: id, - ObjectType: policies.DomainType, - Permission: policies.SharePermission, - } - if err := svc.Authorize(ctx, pr); err != nil { - return err - } + if err := svc.checkSuperAdmin(ctx, res.User); err != nil { + domainUserID := EncodeDomainUserID(id, res.User) + pr := policies.Policy{ + Subject: domainUserID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: id, + ObjectType: policies.DomainType, + Permission: policies.SharePermission, + } + if err := svc.Authorize(ctx, pr); err != nil { + return err + } - pr.Permission = policies.AdminPermission - if err := svc.Authorize(ctx, pr); err != nil { - pr.SubjectKind = policies.UsersKind - // User is not admin. - pr.Subject = userID - if err := svc.Authorize(ctx, pr); err == nil { - // Non admin attempts to remove admin. - return errors.Wrap(svcerr.ErrAuthorization, err) + pr.Permission = policies.AdminPermission + if err := svc.Authorize(ctx, pr); err != nil { + pr.SubjectKind = policies.UsersKind + // User is not admin. + pr.Subject = userID + if err := svc.Authorize(ctx, pr); err == nil { + // Non admin attempts to remove admin. + return errors.Wrap(svcerr.ErrAuthorization, err) + } } } @@ -713,13 +721,7 @@ func (svc service) ListUserDomains(ctx context.Context, token, userID string, p if err != nil { return DomainsPage{}, errors.Wrap(svcerr.ErrAuthentication, err) } - if err := svc.Authorize(ctx, policies.Policy{ - Subject: res.User, - SubjectType: policies.UserType, - Permission: policies.AdminPermission, - Object: policies.MagistralaObject, - ObjectType: policies.PlatformType, - }); err != nil { + if err := svc.checkSuperAdmin(ctx, res.User); err != nil { return DomainsPage{}, errors.Wrap(svcerr.ErrAuthorization, err) } if userID != "" && res.User != userID { @@ -906,3 +908,17 @@ func (svc service) DeleteUserFromDomains(ctx context.Context, id string) (err er return nil } + +func (svc service) checkSuperAdmin(ctx context.Context, userID string) error { + if err := svc.evaluator.CheckPolicy(ctx, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: policies.MagistralaObject, + ObjectType: policies.PlatformType, + }); err != nil { + return svcerr.ErrAuthorization + } + + return nil +} diff --git a/auth/service_test.go b/auth/service_test.go index f17c114322..b338a44eb9 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -1396,15 +1396,23 @@ func TestRetrieveDomain(t *testing.T) { domainID string domainRepoErr error domainRepoErr1 error + checkAdminErr error checkPolicyErr error err error }{ { - desc: "retrieve domain successfully", + desc: "retrieve domain successfully as super admin", token: accessToken, domainID: validID, err: nil, }, + { + desc: "retrieve domain successfully as domain admin", + token: accessToken, + domainID: validID, + checkAdminErr: svcerr.ErrAuthorization, + err: nil, + }, { desc: "retrieve domain with invalid token", token: inValidToken, @@ -1438,13 +1446,44 @@ func TestRetrieveDomain(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := drepo.On("RetrieveByID", mock.Anything, groupName).Return(auth.Domain{}, tc.domainRepoErr) - repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr) + policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: policies.MagistralaObject, + ObjectType: policies.PlatformType, + }).Return(tc.checkAdminErr) + policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: auth.EncodeDomainUserID(tc.domainID, userID), + SubjectType: policies.UserType, + Permission: policies.MembershipPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) + policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) + policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: auth.EncodeDomainUserID(tc.domainID, userID), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Permission: policies.MembershipPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) repoCall2 := drepo.On("RetrieveByID", mock.Anything, tc.domainID).Return(auth.Domain{}, tc.domainRepoErr1) _, err := svc.RetrieveDomain(context.Background(), tc.token, tc.domainID) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) repoCall.Unset() - repoCall1.Unset() repoCall2.Unset() + policyCall.Unset() + policyCall1.Unset() + policyCall2.Unset() + policyCall3.Unset() }) } } @@ -1458,15 +1497,23 @@ func TestRetrieveDomainPermissions(t *testing.T) { domainID string retreivePermissionsErr error retreiveByIDErr error + checkAdminErr error checkPolicyErr error err error }{ { - desc: "retrieve domain permissions successfully", + desc: "retrieve domain permissions successfully as platform admin", token: accessToken, domainID: validID, err: nil, }, + { + desc: "retrieve domain permissions successfully as domain admin", + token: accessToken, + domainID: validID, + checkAdminErr: svcerr.ErrAuthorization, + err: nil, + }, { desc: "retrieve domain permissions with invalid token", token: inValidToken, @@ -1474,11 +1521,11 @@ func TestRetrieveDomainPermissions(t *testing.T) { err: svcerr.ErrAuthentication, }, { - desc: "retrieve domain permissions with empty domainID", - token: accessToken, - domainID: "", - checkPolicyErr: svcerr.ErrAuthorization, - err: svcerr.ErrDomainAuthorization, + desc: "retrieve domain permissions with empty domainID", + token: accessToken, + domainID: "", + retreivePermissionsErr: svcerr.ErrAuthorization, + err: svcerr.ErrAuthorization, }, { desc: "retrieve domain permissions with failed to retrieve permissions", @@ -1491,6 +1538,7 @@ func TestRetrieveDomainPermissions(t *testing.T) { desc: "retrieve domain permissions with failed to retrieve by id", token: accessToken, domainID: validID, + checkAdminErr: svcerr.ErrAuthorization, retreiveByIDErr: repoerr.ErrNotFound, err: svcerr.ErrNotFound, }, @@ -1500,12 +1548,43 @@ func TestRetrieveDomainPermissions(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { repoCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(policies.Permissions{}, tc.retreivePermissionsErr) repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retreiveByIDErr) - repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr) + policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: policies.MagistralaObject, + ObjectType: policies.PlatformType, + }).Return(tc.checkAdminErr) + policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: auth.EncodeDomainUserID(tc.domainID, userID), + SubjectType: policies.UserType, + Permission: policies.MembershipPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) + policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) + policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: auth.EncodeDomainUserID(tc.domainID, userID), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Permission: policies.MembershipPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) _, err := svc.RetrieveDomainPermissions(context.Background(), tc.token, tc.domainID) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) repoCall.Unset() repoCall1.Unset() - repoCall2.Unset() + policyCall.Unset() + policyCall1.Unset() + policyCall2.Unset() + policyCall3.Unset() }) } } @@ -1521,10 +1600,11 @@ func TestUpdateDomain(t *testing.T) { checkPolicyErr error retrieveByIDErr error updateErr error + checkAdminErr error err error }{ { - desc: "update domain successfully", + desc: "update domain successfully as platform admin", token: accessToken, domainID: validID, domReq: auth.DomainReq{ @@ -1533,6 +1613,17 @@ func TestUpdateDomain(t *testing.T) { }, err: nil, }, + { + desc: "update domain successfully as domain admin", + token: accessToken, + domainID: validID, + domReq: auth.DomainReq{ + Name: &valid, + Alias: &valid, + }, + checkAdminErr: svcerr.ErrAuthorization, + err: nil, + }, { desc: "update domain with invalid token", token: inValidToken, @@ -1552,6 +1643,7 @@ func TestUpdateDomain(t *testing.T) { Alias: &valid, }, checkPolicyErr: svcerr.ErrAuthorization, + checkAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrDomainAuthorization, }, { @@ -1563,6 +1655,7 @@ func TestUpdateDomain(t *testing.T) { Alias: &valid, }, retrieveByIDErr: repoerr.ErrNotFound, + checkAdminErr: svcerr.ErrAuthorization, err: svcerr.ErrNotFound, }, { @@ -1580,14 +1673,45 @@ func TestUpdateDomain(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr) + policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: policies.MagistralaObject, + ObjectType: policies.PlatformType, + }).Return(tc.checkAdminErr) + policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: auth.EncodeDomainUserID(tc.domainID, userID), + SubjectType: policies.UserType, + Permission: policies.MembershipPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) + policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) + policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: auth.EncodeDomainUserID(tc.domainID, userID), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Permission: policies.EditPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retrieveByIDErr) repoCall2 := drepo.On("Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(auth.Domain{}, tc.updateErr) _, err := svc.UpdateDomain(context.Background(), tc.token, tc.domainID, tc.domReq) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) - repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + policyCall.Unset() + policyCall1.Unset() + policyCall2.Unset() + policyCall3.Unset() }) } } @@ -1604,11 +1728,12 @@ func TestChangeDomainStatus(t *testing.T) { domainReq auth.DomainReq retreieveByIDErr error checkPolicyErr error + checkAdminErr error updateErr error err error }{ { - desc: "change domain status successfully", + desc: "change domain status successfully as platform admin", token: accessToken, domainID: validID, domainReq: auth.DomainReq{ @@ -1616,6 +1741,16 @@ func TestChangeDomainStatus(t *testing.T) { }, err: nil, }, + { + desc: "change domain status successfully as platform admin", + token: accessToken, + domainID: validID, + domainReq: auth.DomainReq{ + Status: &disabledStatus, + }, + checkAdminErr: svcerr.ErrAuthorization, + err: nil, + }, { desc: "change domain status with invalid token", token: inValidToken, @@ -1632,6 +1767,7 @@ func TestChangeDomainStatus(t *testing.T) { domainReq: auth.DomainReq{ Status: &disabledStatus, }, + checkAdminErr: svcerr.ErrAuthorization, retreieveByIDErr: repoerr.ErrNotFound, err: svcerr.ErrNotFound, }, @@ -1642,6 +1778,7 @@ func TestChangeDomainStatus(t *testing.T) { domainReq: auth.DomainReq{ Status: &disabledStatus, }, + checkAdminErr: svcerr.ErrAuthorization, checkPolicyErr: svcerr.ErrAuthorization, err: svcerr.ErrDomainAuthorization, }, @@ -1660,13 +1797,44 @@ func TestChangeDomainStatus(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retreieveByIDErr) - repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(tc.checkPolicyErr) + policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: policies.MagistralaObject, + ObjectType: policies.PlatformType, + }).Return(tc.checkAdminErr) + policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: auth.EncodeDomainUserID(tc.domainID, userID), + SubjectType: policies.UserType, + Permission: policies.MembershipPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) + policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) + policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: auth.EncodeDomainUserID(tc.domainID, userID), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Permission: policies.AdminPermission, + Object: tc.domainID, + ObjectType: policies.DomainType, + }).Return(tc.checkPolicyErr) repoCall2 := drepo.On("Update", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(auth.Domain{}, tc.updateErr) _, err := svc.ChangeDomainStatus(context.Background(), tc.token, tc.domainID, tc.domainReq) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) repoCall.Unset() - repoCall1.Unset() repoCall2.Unset() + policyCall.Unset() + policyCall1.Unset() + policyCall2.Unset() + policyCall3.Unset() }) } } @@ -1743,25 +1911,26 @@ func TestAssignUsers(t *testing.T) { svc, accessToken := newService() cases := []struct { - desc string - token string - domainID string - userIDs []string - relation string - checkPolicyReq policies.Policy - checkAdminPolicyReq policies.Policy - checkDomainPolicyReq policies.Policy - checkPolicyReq1 policies.Policy - checkpolicyErr error - checkPolicyErr1 error - checkPolicyErr2 error - addPoliciesErr error - savePoliciesErr error - deletePoliciesErr error - err error + desc string + token string + domainID string + userIDs []string + relation string + checkPolicyReq policies.Policy + checkAdminPolicyReq policies.Policy + checkDomainPolicyReq policies.Policy + checkPolicyReq1 policies.Policy + checkpolicyErr error + checkPolicyErr1 error + checkPolicyErr2 error + addPoliciesErr error + savePoliciesErr error + deletePoliciesErr error + checkPlatformAdminErr error + err error }{ { - desc: "assign users successfully", + desc: "assign users successfully as platform admin", token: accessToken, domainID: validID, userIDs: []string{validID}, @@ -1798,6 +1967,45 @@ func TestAssignUsers(t *testing.T) { }, err: nil, }, + { + desc: "assign users successfully", + token: accessToken, + domainID: validID, + userIDs: []string{validID}, + relation: policies.ContributorRelation, + checkPolicyReq: policies.Policy{ + Subject: auth.EncodeDomainUserID(validID, userID), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: validID, + ObjectType: policies.DomainType, + Permission: policies.SharePermission, + }, + checkAdminPolicyReq: policies.Policy{ + Subject: auth.EncodeDomainUserID(validID, userID), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: validID, + ObjectType: policies.DomainType, + Permission: policies.ViewPermission, + }, + checkDomainPolicyReq: policies.Policy{ + Subject: validID, + SubjectType: policies.UserType, + Object: policies.MagistralaObject, + ObjectType: policies.PlatformType, + Permission: policies.MembershipPermission, + }, + checkPolicyReq1: policies.Policy{ + Subject: auth.EncodeDomainUserID(validID, userID), + SubjectType: policies.UserType, + Object: validID, + ObjectType: policies.DomainType, + Permission: policies.MembershipPermission, + }, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: nil, + }, { desc: "assign users with invalid token", token: inValidToken, @@ -1828,7 +2036,8 @@ func TestAssignUsers(t *testing.T) { ObjectType: policies.PlatformType, Permission: policies.MembershipPermission, }, - err: svcerr.ErrAuthentication, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: svcerr.ErrAuthentication, }, { desc: "assign users with invalid domainID", @@ -1858,8 +2067,9 @@ func TestAssignUsers(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.MembershipPermission, }, - checkPolicyErr1: svcerr.ErrAuthorization, - err: svcerr.ErrAuthorization, + checkPolicyErr1: svcerr.ErrAuthorization, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: svcerr.ErrAuthorization, }, { desc: "assign users with invalid userIDs", @@ -1897,8 +2107,9 @@ func TestAssignUsers(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.MembershipPermission, }, - checkPolicyErr2: svcerr.ErrMalformedEntity, - err: svcerr.ErrDomainAuthorization, + checkPolicyErr2: svcerr.ErrMalformedEntity, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: svcerr.ErrDomainAuthorization, }, { desc: "assign users with failed to add policies to agent", @@ -1936,8 +2147,9 @@ func TestAssignUsers(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.MembershipPermission, }, - addPoliciesErr: svcerr.ErrAuthorization, - err: errAddPolicies, + addPoliciesErr: svcerr.ErrAuthorization, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: errAddPolicies, }, { desc: "assign users with failed to save policies to domain", @@ -1975,8 +2187,9 @@ func TestAssignUsers(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.MembershipPermission, }, - savePoliciesErr: repoerr.ErrCreateEntity, - err: errAddPolicies, + checkPlatformAdminErr: svcerr.ErrAuthorization, + savePoliciesErr: repoerr.ErrCreateEntity, + err: errAddPolicies, }, { desc: "assign users with failed to save policies to domain and failed to delete", @@ -2014,15 +2227,23 @@ func TestAssignUsers(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.MembershipPermission, }, - savePoliciesErr: repoerr.ErrCreateEntity, - deletePoliciesErr: svcerr.ErrDomainAuthorization, - err: errAddPolicies, + savePoliciesErr: repoerr.ErrCreateEntity, + deletePoliciesErr: svcerr.ErrDomainAuthorization, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: errAddPolicies, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, nil) + policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: policies.MagistralaObject, + ObjectType: policies.PlatformType, + }).Return(tc.checkPlatformAdminErr) repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq).Return(tc.checkpolicyErr) repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1) repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr2) @@ -2040,6 +2261,7 @@ func TestAssignUsers(t *testing.T) { repoCall5.Unset() repoCall6.Unset() repoCall7.Unset() + policyCall.Unset() }) } } @@ -2059,10 +2281,11 @@ func TestUnassignUser(t *testing.T) { checkPolicyErr1 error deletePolicyFilterErr error deletePoliciesErr error + checkPlatformAdminErr error err error }{ { - desc: "unassign user successfully", + desc: "unassign user successfully as platform admin", token: accessToken, domainID: validID, userID: validID, @@ -2091,6 +2314,37 @@ func TestUnassignUser(t *testing.T) { }, err: nil, }, + { + desc: "unassign user successfully as domain admin", + token: accessToken, + domainID: validID, + userID: validID, + checkPolicyReq: policies.Policy{ + Subject: auth.EncodeDomainUserID(validID, userID), + SubjectType: policies.UserType, + Object: validID, + ObjectType: policies.DomainType, + Permission: policies.MembershipPermission, + }, + checkAdminPolicyReq: policies.Policy{ + Subject: auth.EncodeDomainUserID(validID, userID), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: validID, + ObjectType: policies.DomainType, + Permission: policies.AdminPermission, + }, + checkDomainPolicyReq: policies.Policy{ + Subject: auth.EncodeDomainUserID(validID, userID), + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Object: validID, + ObjectType: policies.DomainType, + Permission: policies.SharePermission, + }, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: nil, + }, { desc: "unassign users with invalid token", token: inValidToken, @@ -2112,7 +2366,8 @@ func TestUnassignUser(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.AdminPermission, }, - err: svcerr.ErrAuthentication, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: svcerr.ErrAuthentication, }, { desc: "unassign users with invalid domainID", @@ -2142,8 +2397,9 @@ func TestUnassignUser(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.MembershipPermission, }, - checkPolicyErr1: svcerr.ErrAuthorization, - err: svcerr.ErrDomainAuthorization, + checkPolicyErr1: svcerr.ErrAuthorization, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: svcerr.ErrDomainAuthorization, }, { desc: "unassign users with failed to delete policies from agent", @@ -2174,6 +2430,7 @@ func TestUnassignUser(t *testing.T) { Permission: policies.MembershipPermission, }, deletePolicyFilterErr: errors.ErrMalformedEntity, + checkPlatformAdminErr: svcerr.ErrAuthorization, err: errors.ErrMalformedEntity, }, { @@ -2206,6 +2463,7 @@ func TestUnassignUser(t *testing.T) { }, deletePoliciesErr: errors.ErrMalformedEntity, deletePolicyFilterErr: errors.ErrMalformedEntity, + checkPlatformAdminErr: svcerr.ErrAuthorization, err: errors.ErrMalformedEntity, }, { @@ -2236,26 +2494,35 @@ func TestUnassignUser(t *testing.T) { ObjectType: policies.DomainType, Permission: policies.SharePermission, }, - deletePoliciesErr: errors.ErrMalformedEntity, - err: errors.ErrMalformedEntity, + deletePoliciesErr: errors.ErrMalformedEntity, + checkPlatformAdminErr: svcerr.ErrAuthorization, + err: errors.ErrMalformedEntity, }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { repoCall := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, nil) - repoCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq).Return(tc.checkPolicyErr) - repoCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1) - repoCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr1) + policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ + Subject: userID, + SubjectType: policies.UserType, + Permission: policies.AdminPermission, + Object: policies.MagistralaObject, + ObjectType: policies.PlatformType, + }).Return(tc.checkPlatformAdminErr) + policyCall1 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkPolicyReq).Return(tc.checkPolicyErr) + policyCall2 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkAdminPolicyReq).Return(tc.checkPolicyErr1) + policyCall3 := pEvaluator.On("CheckPolicy", mock.Anything, tc.checkDomainPolicyReq).Return(tc.checkPolicyErr1) repoCall4 := pService.On("DeletePolicyFilter", mock.Anything, mock.Anything).Return(tc.deletePolicyFilterErr) repoCall5 := drepo.On("DeletePolicies", mock.Anything, mock.Anything, mock.Anything).Return(tc.deletePoliciesErr) err := svc.UnassignUser(context.Background(), tc.token, tc.domainID, tc.userID) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) repoCall.Unset() - repoCall1.Unset() - repoCall2.Unset() - repoCall3.Unset() + policyCall.Unset() + policyCall1.Unset() + policyCall2.Unset() repoCall4.Unset() + policyCall3.Unset() repoCall5.Unset() }) } diff --git a/invitations/middleware/authorization.go b/invitations/middleware/authorization.go index 1f89b1fef2..84d0a32934 100644 --- a/invitations/middleware/authorization.go +++ b/invitations/middleware/authorization.go @@ -30,9 +30,6 @@ func AuthorizationMiddleware(authz authz.Authorization, svc invitations.Service) } func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) { - if err := am.checkAdmin(ctx, session.UserID, session.DomainID); err != nil { - return err - } session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) domainUserId := auth.EncodeDomainUserID(invitation.DomainID, invitation.UserID) if err := am.authorize(ctx, domainUserId, policies.MembershipPermission, policies.DomainType, invitation.DomainID); err == nil { @@ -40,7 +37,7 @@ func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session a return errors.Wrap(svcerr.ErrConflict, ErrMemberExist) } - if err := am.checkAdmin(ctx, session.DomainUserID, invitation.DomainID); err != nil { + if err := am.checkAdmin(ctx, session); err != nil { return err } @@ -50,7 +47,7 @@ func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session a func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session authn.Session, userID, domain string) (invitation invitations.Invitation, err error) { session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) if session.UserID != userID { - if err := am.checkAdmin(ctx, session.DomainUserID, domain); err != nil { + if err := am.checkAdmin(ctx, session); err != nil { return invitations.Invitation{}, err } } @@ -60,7 +57,7 @@ func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session a func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) { session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) - if err := am.authorize(ctx, session.DomainUserID, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject); err == nil { + if err := am.authorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject); err == nil { session.SuperAdmin = true } @@ -88,7 +85,7 @@ func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) { session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) - if err := am.checkAdmin(ctx, session.DomainUserID, domainID); err != nil { + if err := am.checkAdmin(ctx, session); err != nil { return err } @@ -96,12 +93,12 @@ func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session } // checkAdmin checks if the given user is a domain or platform administrator. -func (am *authorizationMiddleware) checkAdmin(ctx context.Context, userID, domainID string) error { - if err := am.authorize(ctx, userID, policies.AdminPermission, policies.DomainType, domainID); err == nil { +func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn.Session) error { + if err := am.authorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, session.DomainID); err == nil { return nil } - if err := am.authorize(ctx, userID, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject); err == nil { + if err := am.authorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject); err == nil { return nil } diff --git a/invitations/service.go b/invitations/service.go index 5b81d7ea68..6bc636d57c 100644 --- a/invitations/service.go +++ b/invitations/service.go @@ -49,7 +49,10 @@ func (svc *service) SendInvitation(ctx context.Context, session authn.Session, i invitation.CreatedAt = time.Now() - return svc.repo.Create(ctx, invitation) + if err := svc.repo.Create(ctx, invitation); err != nil { + return err + } + return nil } func (svc *service) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation Invitation, err error) { diff --git a/users/middleware/authorization.go b/users/middleware/authorization.go index 53c552ff4d..24f796e68a 100644 --- a/users/middleware/authorization.go +++ b/users/middleware/authorization.go @@ -67,21 +67,23 @@ func (am *authorizationMiddleware) ListMembers(ctx context.Context, session auth if session.DomainUserID == "" { return users.MembersPage{}, svcerr.ErrDomainAuthorization } - switch objectKind { - case policies.GroupsKind: - if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.GroupType, objectID); err != nil { - return users.MembersPage{}, err + if err := am.checkSuperAdmin(ctx, session.UserID); err != nil { + switch objectKind { + case policies.GroupsKind: + if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.GroupType, objectID); err != nil { + return users.MembersPage{}, err + } + case policies.DomainsKind: + if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.DomainType, objectID); err != nil { + return users.MembersPage{}, err + } + case policies.ThingsKind: + if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.ThingType, objectID); err != nil { + return users.MembersPage{}, err + } + default: + return users.MembersPage{}, svcerr.ErrAuthorization } - case policies.DomainsKind: - if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.DomainType, objectID); err != nil { - return users.MembersPage{}, err - } - case policies.ThingsKind: - if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, mgauth.SwitchToPermission(pm.Permission), policies.ThingType, objectID); err != nil { - return users.MembersPage{}, err - } - default: - return users.MembersPage{}, svcerr.ErrAuthorization } return am.svc.ListMembers(ctx, session, objectKind, objectID, pm)