Skip to content

Commit

Permalink
Add validation for deployment using model observability
Browse files Browse the repository at this point in the history
  • Loading branch information
tiopramayudi committed Nov 29, 2023
1 parent c31b2aa commit ae0ab16
Show file tree
Hide file tree
Showing 3 changed files with 827 additions and 237 deletions.
156 changes: 156 additions & 0 deletions api/api/validator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package api

import (
"context"
"fmt"

"github.com/caraml-dev/merlin/config"
"github.com/caraml-dev/merlin/models"
"github.com/caraml-dev/merlin/pkg/protocol"
"github.com/caraml-dev/merlin/service"
"github.com/feast-dev/feast/sdk/go/protos/feast/core"
)

type requestValidator interface {
validate() error
}

type funcValidate struct {
f func() error
}

func newFuncValidate(f func() error) *funcValidate {
return &funcValidate{
f: f,
}
}

func (fv *funcValidate) validate() error {
return fv.f()
}

var supportedUPIModelTypes = map[string]bool{
models.ModelTypePyFunc: true,
models.ModelTypeCustom: true,
models.ModelTypePyFuncV3: true,
}

func isModelSupportUPI(model *models.Model) bool {
_, isSupported := supportedUPIModelTypes[model.Type]

return isSupported
}

func validateRequest(validators ...requestValidator) error {
for _, validator := range validators {
if err := validator.validate(); err != nil {
return err
}
}
return nil
}

func customModelValidation(model *models.Model, version *models.Version) requestValidator {
return newFuncValidate(func() error {
if model.Type == models.ModelTypeCustom {
if err := validateCustomPredictor(version); err != nil {
return err
}
}
return nil
})
}

func upiModelValidation(model *models.Model, endpointProtocol protocol.Protocol) requestValidator {
return newFuncValidate(func() error {
if !isModelSupportUPI(model) && endpointProtocol == protocol.UpiV1 {
return fmt.Errorf("%s model is not supported by UPI", model.Type)
}
return nil
})
}

func newVersionEndpointValidation(version *models.Version, envName string) requestValidator {
return newFuncValidate(func() error {
endpoint, ok := version.GetEndpointByEnvironmentName(envName)
if ok && (endpoint.IsRunning() || endpoint.IsServing()) {
return fmt.Errorf("there is `%s` deployment for the model version", endpoint.Status)
}
return nil
})
}

func deploymentQuotaValidation(ctx context.Context, model *models.Model, env *models.Environment, endpointSvc service.EndpointsService) requestValidator {
return newFuncValidate(func() error {
deployedModelVersionCount, err := endpointSvc.CountEndpoints(ctx, env, model)
if err != nil {
return fmt.Errorf("unable to count number of endpoints in env %s: %v", env.Name, err)
}

if deployedModelVersionCount >= config.MaxDeployedVersion {
return fmt.Errorf("max deployed endpoint reached. Max: %d Current: %d, undeploy existing endpoint before continuing", config.MaxDeployedVersion, deployedModelVersionCount)
}
return nil
})
}

func transformerValidation(
ctx context.Context,
endpoint *models.VersionEndpoint,
stdTransformerCfg config.StandardTransformerConfig,
feastCore core.CoreServiceClient) requestValidator {
return newFuncValidate(func() error {
if endpoint.Transformer != nil && endpoint.Transformer.Enabled {
err := validateTransformer(ctx, endpoint, stdTransformerCfg, feastCore)
if err != nil {
return fmt.Errorf("Error validating transformer: %v", err)
}
}
return nil
})
}

func updateRequestValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) requestValidator {
return newFuncValidate(func() error {
if prev.EnvironmentName != new.EnvironmentName {
return fmt.Errorf("updating environment is not allowed, previous: %s, new: %s", prev.EnvironmentName, new.EnvironmentName)
}

if prev.Status == models.EndpointPending {
return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the pending state", new.Status)
}

if new.Status != prev.Status {
if prev.Status == models.EndpointServing {
return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the serving state", new.Status)
}

if new.Status != models.EndpointRunning && new.Status != models.EndpointTerminated {
return fmt.Errorf("updating endpoint status to %s is not allowed", new.Status)
}
}
return nil
})
}

func deploymentModeValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) requestValidator {
return newFuncValidate(func() error {
// Should not allow changing the deployment mode of a pending/running/serving model for 2 reasons:
// * For "serving" models it's risky as, we can't guarantee graceful re-deployment
// * Kserve uses slightly different deployment resource naming under the hood and doesn't clean up the older deployment
if (prev.IsRunning() || prev.IsServing()) && new.DeploymentMode != "" &&
new.DeploymentMode != prev.DeploymentMode {
return fmt.Errorf("changing deployment type of a %s model is not allowed, please terminate it first", prev.Status)
}
return nil
})
}

func modelObservabilityValidation(endpoint *models.VersionEndpoint, model *models.Model) requestValidator {
return newFuncValidate(func() error {
if endpoint.EnableModelObservability && model.Type != models.ModelTypePyFuncV3 {
return fmt.Errorf("model type should be pyfunc_v3 if want to enable model observablity")
}
return nil
})
}
169 changes: 14 additions & 155 deletions api/api/version_endpoints_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
"github.com/caraml-dev/merlin/pkg/transformer/feast"
"github.com/caraml-dev/merlin/pkg/transformer/pipeline"
"github.com/caraml-dev/merlin/pkg/transformer/spec"
"github.com/caraml-dev/merlin/service"
)

type EndpointsController struct {
Expand Down Expand Up @@ -113,130 +112,6 @@ func (c *EndpointsController) GetEndpoint(r *http.Request, vars map[string]strin
return Ok(endpoint)
}

type endpointValidator interface {
validate(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error
}

type funcValidate struct {
f func(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error
}

func newFuncValidate(f func(*models.VersionEndpoint, *models.Model, *models.Version) error) *funcValidate {
return &funcValidate{
f: f,
}
}

func (fv *funcValidate) validate(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error {
return fv.f(endpoint, model, version)
}

func validateEndpointRequest(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version, validators ...endpointValidator) error {
for _, validator := range validators {
if err := validator.validate(endpoint, model, version); err != nil {
return err
}
}
return nil
}

func customModelValidation(model *models.Model) endpointValidator {
return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error {
if model.Type == models.ModelTypeCustom {
if err := validateCustomPredictor(v); err != nil {
return err
}
}
return nil
})
}

func upiModelValidation(model *models.Model) endpointValidator {
return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error {
if !isModelSupportUPI(model) && ve.Protocol == protocol.UpiV1 {
return fmt.Errorf("%s model is not supported by UPI", model.Type)
}
return nil
})
}

func newVersionEndpointValidation(version *models.Version, envName string) endpointValidator {
return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error {
endpoint, ok := version.GetEndpointByEnvironmentName(envName)
if ok && (endpoint.IsRunning() || endpoint.IsServing()) {
return fmt.Errorf("there is `%s` deployment for the model version", endpoint.Status)
}
return nil
})
}

func deploymentQuotaValidation(ctx context.Context, model *models.Model, env *models.Environment, endpointSvc service.EndpointsService) endpointValidator {
return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error {
deployedModelVersionCount, err := endpointSvc.CountEndpoints(ctx, env, model)
if err != nil {
return fmt.Errorf("unable to count number of endpoints in env %s: %v", env.Name, err)
}

if deployedModelVersionCount >= config.MaxDeployedVersion {
return fmt.Errorf("max deployed endpoint reached. Max: %d Current: %d, undeploy existing endpoint before continuing", config.MaxDeployedVersion, deployedModelVersionCount)
}
return nil
})
}

func transformerValidation(
ctx context.Context,
endpoint *models.VersionEndpoint,
stdTransformerCfg config.StandardTransformerConfig,
feastCore core.CoreServiceClient) endpointValidator {
return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error {
if ve.Transformer != nil && ve.Transformer.Enabled {
err := validateTransformer(ctx, endpoint, stdTransformerCfg, feastCore)
if err != nil {
return fmt.Errorf("Error validating transformer: %v", err)
}
}
return nil
})
}

func updateRequestValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) endpointValidator {
return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error {
if prev.EnvironmentName != new.EnvironmentName {
return fmt.Errorf("updating environment is not allowed, previous: %s, new: %s", prev.EnvironmentName, new.EnvironmentName)
}

if prev.Status == models.EndpointPending {
return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the pending state", new.Status)
}

if new.Status != prev.Status {
if prev.Status == models.EndpointServing {
return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the serving state", new.Status)
}

if new.Status != models.EndpointRunning && new.Status != models.EndpointTerminated {
return fmt.Errorf("updating endpoint status to %s is not allowed", new.Status)
}
}
return nil
})
}

func deploymentModeValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) endpointValidator {
return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error {
// Should not allow changing the deployment mode of a pending/running/serving model for 2 reasons:
// * For "serving" models it's risky as, we can't guarantee graceful re-deployment
// * Kserve uses slightly different deployment resource naming under the hood and doesn't clean up the older deployment
if (prev.IsRunning() || prev.IsServing()) && new.DeploymentMode != "" &&
new.DeploymentMode != prev.DeploymentMode {
return fmt.Errorf("changing deployment type of a %s model is not allowed, please terminate it first", prev.Status)
}
return nil
})

}

// CreateEndpoint create new endpoint from a model version and deploy to certain environment as specified by request
// If target environment is not set then fallback to default environment
func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]string, body interface{}) *Response {
Expand Down Expand Up @@ -281,15 +156,16 @@ func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]st
newEndpoint.EnvironmentName = env.Name
}

validationRules := []endpointValidator{
customModelValidation(model),
upiModelValidation(model),
validationRules := []requestValidator{
customModelValidation(model, version),
upiModelValidation(model, newEndpoint.Protocol),
newVersionEndpointValidation(version, env.Name),
deploymentQuotaValidation(ctx, model, env, c.EndpointsService),
transformerValidation(ctx, newEndpoint, c.StandardTransformerConfig, c.FeastCoreClient),
modelObservabilityValidation(newEndpoint, model),
}

if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil {
if err := validateRequest(validationRules...); err != nil {
return BadRequest(fmt.Sprintf("Request validation failed: %v", err))
}

Expand All @@ -304,18 +180,6 @@ func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]st
return Created(endpoint)
}

var supportedUPIModelTypes = map[string]bool{
models.ModelTypePyFunc: true,
models.ModelTypeCustom: true,
models.ModelTypePyFuncV3: true,
}

func isModelSupportUPI(model *models.Model) bool {
_, isSupported := supportedUPIModelTypes[model.Type]

return isSupported
}

// UpdateEndpoint update a an existing endpoint i.e. trigger redeployment
func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]string, body interface{}) *Response {
ctx := r.Context()
Expand Down Expand Up @@ -345,11 +209,6 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st
return BadRequest("Unable to parse body as version endpoint resource")
}

validationRules := []endpointValidator{
customModelValidation(model),
updateRequestValidation(endpoint, newEndpoint),
}

env, err := c.AppContext.EnvironmentService.GetEnvironment(newEndpoint.EnvironmentName)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand All @@ -358,21 +217,20 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st
return InternalServerError(fmt.Sprintf("Error getting the specified environment: %v", err))
}

validationRules := []requestValidator{
customModelValidation(model, version),
updateRequestValidation(endpoint, newEndpoint),
modelObservabilityValidation(newEndpoint, model),
}

if newEndpoint.Status == models.EndpointRunning || newEndpoint.Status == models.EndpointServing {
// validate transformer
// if newEndpoint.Transformer != nil && newEndpoint.Transformer.Enabled {
// err := c.validateTransformer(ctx, newEndpoint.Transformer, newEndpoint.Protocol, newEndpoint.Logger)
// if err != nil {
// return BadRequest(fmt.Sprintf("Error validating the transformer: %v", err))
// }
// }
validationRules = append(
validationRules,
transformerValidation(ctx, newEndpoint, c.StandardTransformerConfig, c.FeastCoreClient),
deploymentModeValidation(endpoint, newEndpoint),
)

if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil {
if err := validateRequest(validationRules...); err != nil {
return BadRequest(fmt.Sprintf("Request validation failed: %v", err))
}

Expand All @@ -385,9 +243,10 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st
return InternalServerError(fmt.Sprintf("Unable to deploy model version: %v", err))
}
} else if newEndpoint.Status == models.EndpointTerminated {
if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil {
if err := validateRequest(validationRules...); err != nil {
return BadRequest(fmt.Sprintf("Request validation failed: %v", err))
}

endpoint, err = c.EndpointsService.UndeployEndpoint(ctx, env, model, version, endpoint)
if err != nil {
return InternalServerError(fmt.Sprintf("Unable to undeploy version endpoint %s: %v", endpointID, err))
Expand Down
Loading

0 comments on commit ae0ab16

Please sign in to comment.