Skip to content

Commit

Permalink
chore: fix update vertex validation (#1327)
Browse files Browse the repository at this point in the history
Signed-off-by: Sidhant Kohli <[email protected]>
  • Loading branch information
kohlisid authored Nov 2, 2023
1 parent d5c3b07 commit 4715867
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
18 changes: 18 additions & 0 deletions server/apis/v1/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,8 @@ func (h *handler) UpdateVertex(c *gin.Context) {
inputVertexName = c.Param("vertex")
pipeline = c.Param("pipeline")
ns = c.Param("namespace")
// dryRun is used to check if the operation is just a validation or an actual create
dryRun = strings.EqualFold("true", c.DefaultQuery("dry-run", "false"))
)

pl, err := h.numaflowClient.Pipelines(ns).Get(context.Background(), pipeline, metav1.GetOptions{})
Expand Down Expand Up @@ -665,6 +667,22 @@ func (h *handler) UpdateVertex(c *gin.Context) {
break
}
}
err = validateNamespace(h, pl, ns)
if err != nil {
h.respondWithError(c, err.Error())
return
}
pl.Namespace = ns
err = validatePipelineSpec(h, nil, pl, ValidTypeCreate)
if err != nil {
h.respondWithError(c, fmt.Sprintf("Failed to validate pipeline spec, %s", err.Error()))
return
}
// if Validation flag "dryRun" is set to true, return without creating the pipeline
if dryRun {
c.JSON(http.StatusOK, NewNumaflowAPIResponse(nil, nil))
return
}

if _, err := h.numaflowClient.Pipelines(ns).Update(context.Background(), pl, metav1.UpdateOptions{}); err != nil {
h.respondWithError(c, fmt.Sprintf("Failed to update the vertex: namespace %q pipeline %q vertex %q: %s",
Expand Down
32 changes: 25 additions & 7 deletions server/authz/rbac.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"path"
"strings"
"sync"

"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
Expand Down Expand Up @@ -50,6 +51,7 @@ type CasbinObject struct {
currentScopes []string
policyDefault string
configReader *viper.Viper
permCountLock *sync.RWMutex
}

func NewCasbinObject() (*CasbinObject, error) {
Expand All @@ -75,6 +77,7 @@ func NewCasbinObject() (*CasbinObject, error) {
currentScopes: currentScopes,
policyDefault: policyDefault,
configReader: configReader,
permCountLock: &sync.RWMutex{},
}

// Watch for changes in the config file.
Expand All @@ -101,7 +104,7 @@ func (cas *CasbinObject) Authorize(c *gin.Context, userInfo *authn.UserInfo) boo
// Check for the given scoped list if the user is authorized using any of the subjects in the list.
for _, scopedSubject := range scopedList {
// Check if the user has permissions in the policy for the given scoped subject.
userHasPolicies = userHasPolicies || hasPermissionsDefined(cas.enforcer, scopedSubject, cas.userPermCount)
userHasPolicies = userHasPolicies || cas.hasPermissionsDefined(scopedSubject)
if ok := enforceCheck(cas.enforcer, scopedSubject, resource, object, action); ok {
return ok
}
Expand Down Expand Up @@ -295,26 +298,41 @@ func getDefaultPolicy(config *viper.Viper) string {
// We have a cache userPermCount to store the count of permissions for a user. If the user has permissions in the
// policy, we store the count in the cache and return based on the value.
// If the user does not have permissions in the policy, we add it to the cache before returning
func hasPermissionsDefined(enforcer *casbin.Enforcer, user string, userPermCount map[string]int) bool {
func (cas *CasbinObject) hasPermissionsDefined(user string) bool {
// check if user exists in userPermCount
if userPermCount == nil {
userPermCount = make(map[string]int)
if cas.userPermCount == nil {
cas.userPermCount = make(map[string]int)
}
val, ok := userPermCount[user]
val, ok := cas.getPermCount(user)
// If the key exists
if ok {
// Return true if the user has permissions in the policy
// and false if the user does not have permissions in the policy.
return val > 0
}
// get the permissions for the user
cnt, err := enforcer.GetImplicitPermissionsForUser(user)
cnt, err := cas.enforcer.GetImplicitPermissionsForUser(user)
if err != nil {
logger.Errorw("Failed to get permissions for user", "user", user, "error", err)
return false
}
count := len(cnt)
// store the count in userPermCount
userPermCount[user] = count
cas.updatePermCount(user, count)
return count > 0
}

// updatePermCount updates the permission count for the user in the cache.
func (cas *CasbinObject) updatePermCount(user string, count int) {
cas.permCountLock.Lock()
defer cas.permCountLock.Unlock()
cas.userPermCount[user] = count
}

// getPermCount returns the permission count for the user from the cache.
func (cas *CasbinObject) getPermCount(user string) (int, bool) {
cas.permCountLock.RLock()
defer cas.permCountLock.RUnlock()
val, ok := cas.userPermCount[user]
return val, ok
}

0 comments on commit 4715867

Please sign in to comment.