Skip to content

Commit

Permalink
Add validation
Browse files Browse the repository at this point in the history
  • Loading branch information
tiopramayudi committed Nov 28, 2023
1 parent fa31b27 commit c31b2aa
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 118 deletions.
267 changes: 170 additions & 97 deletions api/api/version_endpoints_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

merror "github.com/caraml-dev/merlin/pkg/errors"
"github.com/caraml-dev/merlin/pkg/protocol"
"github.com/feast-dev/feast/sdk/go/protos/feast/core"
"github.com/google/uuid"
"google.golang.org/protobuf/encoding/protojson"
"gorm.io/gorm"
Expand All @@ -32,6 +33,7 @@ import (
"github.com/caraml-dev/merlin/pkg/transformer/feast"
"github.com/caraml-dev/merlin/pkg/transformer/pipeline"
"github.com/caraml-dev/merlin/pkg/transformer/spec"
"github.com/caraml-dev/merlin/service"
)

type EndpointsController struct {
Expand Down Expand Up @@ -111,6 +113,130 @@ func (c *EndpointsController) GetEndpoint(r *http.Request, vars map[string]strin
return Ok(endpoint)
}

// endpointValidator is a single validation rule for a version endpoint request.
// validate receives the endpoint under validation together with its model and
// version, and returns a non-nil error when the rule rejects the request.
type endpointValidator interface {
	validate(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error
}

// funcValidate adapts a plain function to the endpointValidator interface.
type funcValidate struct {
	// f is the function that validate delegates to.
	f func(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error
}

// newFuncValidate wraps f in a funcValidate so that it can be used wherever an
// endpointValidator is expected.
func newFuncValidate(f func(*models.VersionEndpoint, *models.Model, *models.Version) error) *funcValidate {
	wrapped := new(funcValidate)
	wrapped.f = f
	return wrapped
}

// validate implements endpointValidator by invoking the wrapped function with
// the given endpoint, model and version and returning its result unchanged.
func (fv *funcValidate) validate(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error {
	rule := fv.f
	return rule(endpoint, model, version)
}

// validateEndpointRequest runs the validators in the order given against the
// endpoint, model and version. It stops at the first validator that fails and
// returns that error; it returns nil when every validator passes (or when no
// validators are supplied).
func validateEndpointRequest(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version, validators ...endpointValidator) error {
	for i := range validators {
		err := validators[i].validate(endpoint, model, version)
		if err != nil {
			return err
		}
	}
	return nil
}

// customModelValidation returns a validator that, for models of type
// models.ModelTypeCustom, checks the custom predictor of the version passed to
// validate. Models of any other type always pass.
func customModelValidation(model *models.Model) endpointValidator {
	return newFuncValidate(func(_ *models.VersionEndpoint, _ *models.Model, v *models.Version) error {
		// Only custom models carry a custom predictor that needs checking.
		if model.Type != models.ModelTypeCustom {
			return nil
		}
		return validateCustomPredictor(v)
	})
}

// upiModelValidation returns a validator that rejects an endpoint requesting
// the UPI v1 protocol when isModelSupportUPI reports that the model does not
// support it.
func upiModelValidation(model *models.Model) endpointValidator {
	return newFuncValidate(func(ve *models.VersionEndpoint, _ *models.Model, _ *models.Version) error {
		if isModelSupportUPI(model) {
			return nil
		}
		if ve.Protocol != protocol.UpiV1 {
			return nil
		}
		return fmt.Errorf("%s model is not supported by UPI", model.Type)
	})
}

// newVersionEndpointValidation returns a validator that rejects a new
// deployment when the version already has an endpoint in environment envName
// that is running or serving.
func newVersionEndpointValidation(version *models.Version, envName string) endpointValidator {
	return newFuncValidate(func(_ *models.VersionEndpoint, _ *models.Model, _ *models.Version) error {
		existing, found := version.GetEndpointByEnvironmentName(envName)
		if !found {
			return nil
		}
		if !existing.IsRunning() && !existing.IsServing() {
			return nil
		}
		return fmt.Errorf("there is `%s` deployment for the model version", existing.Status)
	})
}

// deploymentQuotaValidation returns a validator that enforces the per-model
// deployment quota in env.
//
// The validator asks endpointSvc for the number of endpoints of model in env
// and fails when that count has reached config.MaxDeployedVersion. A failure
// to obtain the count is returned wrapped (%w), so callers can still inspect
// the underlying cause with errors.Is / errors.As and tell an infrastructure
// failure apart from a genuine quota violation.
func deploymentQuotaValidation(ctx context.Context, model *models.Model, env *models.Environment, endpointSvc service.EndpointsService) endpointValidator {
	return newFuncValidate(func(_ *models.VersionEndpoint, _ *models.Model, _ *models.Version) error {
		deployedModelVersionCount, err := endpointSvc.CountEndpoints(ctx, env, model)
		if err != nil {
			// Wrap rather than flatten: keeps the original error in the chain.
			return fmt.Errorf("unable to count number of endpoints in env %s: %w", env.Name, err)
		}

		if deployedModelVersionCount >= config.MaxDeployedVersion {
			return fmt.Errorf("max deployed endpoint reached. Max: %d Current: %d, undeploy existing endpoint before continuing", config.MaxDeployedVersion, deployedModelVersionCount)
		}
		return nil
	})
}

// transformerValidation returns a validator that checks the transformer
// configuration of the endpoint passed to validate, using stdTransformerCfg
// and feastCore for the check performed by validateTransformer.
//
// Endpoints without a transformer, or with a disabled transformer, always pass.
//
// The endpoint parameter is retained for signature compatibility. The endpoint
// that is guarded and the endpoint that is validated are now the same object
// (the one handed to validate); previously the guard looked at one endpoint
// while validateTransformer received the captured one, which could dereference
// a nil Transformer if the two ever differed.
func transformerValidation(
	ctx context.Context,
	endpoint *models.VersionEndpoint,
	stdTransformerCfg config.StandardTransformerConfig,
	feastCore core.CoreServiceClient) endpointValidator {
	return newFuncValidate(func(ve *models.VersionEndpoint, _ *models.Model, _ *models.Version) error {
		if ve.Transformer == nil || !ve.Transformer.Enabled {
			return nil
		}
		if err := validateTransformer(ctx, ve, stdTransformerCfg, feastCore); err != nil {
			// %w keeps the underlying error available to errors.Is / errors.As.
			return fmt.Errorf("Error validating transformer: %w", err)
		}
		return nil
	})
}

// updateRequestValidation returns a validator that compares the stored
// endpoint (prev) with the requested one (new) and rejects the update when:
//   - the environment name differs,
//   - prev is in the pending state,
//   - the status changes while prev is serving, or
//   - the status changes to anything other than running or terminated.
func updateRequestValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) endpointValidator {
	return newFuncValidate(func(_ *models.VersionEndpoint, _ *models.Model, _ *models.Version) error {
		switch {
		case prev.EnvironmentName != new.EnvironmentName:
			return fmt.Errorf("updating environment is not allowed, previous: %s, new: %s", prev.EnvironmentName, new.EnvironmentName)
		case prev.Status == models.EndpointPending:
			return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the pending state", new.Status)
		case new.Status == prev.Status:
			// No status transition requested; nothing further to check.
			return nil
		case prev.Status == models.EndpointServing:
			return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the serving state", new.Status)
		case new.Status != models.EndpointRunning && new.Status != models.EndpointTerminated:
			return fmt.Errorf("updating endpoint status to %s is not allowed", new.Status)
		}
		return nil
	})
}

// deploymentModeValidation returns a validator that rejects a change of
// deployment mode while prev reports IsRunning() or IsServing().
//
// The change is disallowed for two reasons:
//   - for "serving" models it is risky, as graceful re-deployment cannot be
//     guaranteed;
//   - Kserve uses slightly different deployment resource naming under the hood
//     and does not clean up the older deployment.
//
// An empty DeploymentMode on new is treated as "no change requested".
func deploymentModeValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) endpointValidator {
	return newFuncValidate(func(_ *models.VersionEndpoint, _ *models.Model, _ *models.Version) error {
		if !prev.IsRunning() && !prev.IsServing() {
			return nil
		}
		if new.DeploymentMode == "" || new.DeploymentMode == prev.DeploymentMode {
			return nil
		}
		return fmt.Errorf("changing deployment type of a %s model is not allowed, please terminate it first", prev.Status)
	})
}

// CreateEndpoint create new endpoint from a model version and deploy to certain environment as specified by request
// If target environment is not set then fallback to default environment
func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]string, body interface{}) *Response {
Expand All @@ -126,15 +252,6 @@ func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]st
}
return InternalServerError(fmt.Sprintf("Error getting model / version: %v", err))
}

// validate custom predictor
if model.Type == models.ModelTypeCustom {
err := c.validateCustomPredictor(ctx, version)
if err != nil {
return BadRequest(fmt.Sprintf("Error validating custom predictor: %v", err))
}
}

env, err := c.AppContext.EnvironmentService.GetDefaultEnvironment()
if err != nil {
return InternalServerError(fmt.Sprintf("Unable to find default environment, specify environment target for deployment: %v", err))
Expand Down Expand Up @@ -164,38 +281,19 @@ func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]st
newEndpoint.EnvironmentName = env.Name
}

// check that UPI is supported
if !isModelSupportUPI(model) && newEndpoint.Protocol == protocol.UpiV1 {
return BadRequest(
fmt.Sprintf("%s model is not supported by UPI", model.Type))
}

// check that the endpoint is not deployed nor deploying
endpoint, ok := version.GetEndpointByEnvironmentName(env.Name)
if ok && (endpoint.IsRunning() || endpoint.IsServing()) {
return BadRequest(
fmt.Sprintf("There is `%s` deployment for the model version", endpoint.Status))
validationRules := []endpointValidator{
customModelValidation(model),
upiModelValidation(model),
newVersionEndpointValidation(version, env.Name),
deploymentQuotaValidation(ctx, model, env, c.EndpointsService),
transformerValidation(ctx, newEndpoint, c.StandardTransformerConfig, c.FeastCoreClient),
}

// check that the model version quota
deployedModelVersionCount, err := c.EndpointsService.CountEndpoints(ctx, env, model)
if err != nil {
return InternalServerError(fmt.Sprintf("Unable to count number of endpoints in env %s: %v", env.Name, err))
}

if deployedModelVersionCount >= config.MaxDeployedVersion {
return BadRequest(fmt.Sprintf("Max deployed endpoint reached. Max: %d Current: %d, undeploy existing endpoint before continuing", config.MaxDeployedVersion, deployedModelVersionCount))
}

// validate transformer
if newEndpoint.Transformer != nil && newEndpoint.Transformer.Enabled {
err := c.validateTransformer(ctx, newEndpoint.Transformer, newEndpoint.Protocol, newEndpoint.Logger)
if err != nil {
return BadRequest(fmt.Sprintf("Error validating transformer: %v", err))
}
if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil {
return BadRequest(fmt.Sprintf("Request validation failed: %v", err))
}

endpoint, err = c.EndpointsService.DeployEndpoint(ctx, env, model, version, newEndpoint)
endpoint, err := c.EndpointsService.DeployEndpoint(ctx, env, model, version, newEndpoint)
if err != nil {
if errors.Is(err, merror.InvalidInputError) {
return BadRequest(fmt.Sprintf("Unable to process model version input: %v", err))
Expand Down Expand Up @@ -234,14 +332,6 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st
return InternalServerError(fmt.Sprintf("Error getting model / version: %v", err))
}

// validate custom predictor
if model.Type == models.ModelTypeCustom {
err := c.validateCustomPredictor(ctx, version)
if err != nil {
return BadRequest(fmt.Sprintf("Error validating custom predictor: %v", err))
}
}

endpoint, err := c.EndpointsService.FindByID(ctx, endpointID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand All @@ -255,9 +345,9 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st
return BadRequest("Unable to parse body as version endpoint resource")
}

err = validateUpdateRequest(endpoint, newEndpoint)
if err != nil {
return BadRequest(fmt.Sprintf("Error validating request: %v", err))
validationRules := []endpointValidator{
customModelValidation(model),
updateRequestValidation(endpoint, newEndpoint),
}

env, err := c.AppContext.EnvironmentService.GetEnvironment(newEndpoint.EnvironmentName)
Expand All @@ -270,20 +360,20 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st

if newEndpoint.Status == models.EndpointRunning || newEndpoint.Status == models.EndpointServing {
// validate transformer
if newEndpoint.Transformer != nil && newEndpoint.Transformer.Enabled {
err := c.validateTransformer(ctx, newEndpoint.Transformer, newEndpoint.Protocol, newEndpoint.Logger)
if err != nil {
return BadRequest(fmt.Sprintf("Error validating the transformer: %v", err))
}
}

// Should not allow changing the deployment mode of a pending/running/serving model for 2 reasons:
// * For "serving" models it's risky as, we can't guarantee graceful re-deployment
// * Kserve uses slightly different deployment resource naming under the hood and doesn't clean up the older deployment
if (endpoint.IsRunning() || endpoint.IsServing()) && newEndpoint.DeploymentMode != "" &&
newEndpoint.DeploymentMode != endpoint.DeploymentMode {
return BadRequest(fmt.Sprintf("Changing deployment type of a %s model is not allowed, please terminate it first.",
endpoint.Status))
// if newEndpoint.Transformer != nil && newEndpoint.Transformer.Enabled {
// err := c.validateTransformer(ctx, newEndpoint.Transformer, newEndpoint.Protocol, newEndpoint.Logger)
// if err != nil {
// return BadRequest(fmt.Sprintf("Error validating the transformer: %v", err))
// }
// }
validationRules = append(
validationRules,
transformerValidation(ctx, newEndpoint, c.StandardTransformerConfig, c.FeastCoreClient),
deploymentModeValidation(endpoint, newEndpoint),
)

if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil {
return BadRequest(fmt.Sprintf("Request validation failed: %v", err))
}

endpoint, err = c.EndpointsService.DeployEndpoint(ctx, env, model, version, newEndpoint)
Expand All @@ -295,12 +385,15 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st
return InternalServerError(fmt.Sprintf("Unable to deploy model version: %v", err))
}
} else if newEndpoint.Status == models.EndpointTerminated {
if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil {
return BadRequest(fmt.Sprintf("Request validation failed: %v", err))
}
endpoint, err = c.EndpointsService.UndeployEndpoint(ctx, env, model, version, endpoint)
if err != nil {
return InternalServerError(fmt.Sprintf("Unable to undeploy version endpoint %s: %v", endpointID, err))
}
} else {
return InternalServerError(fmt.Sprintf("Updating endpoint status to %s is not allowed", newEndpoint.Status))
return BadRequest(fmt.Sprintf("Updating endpoint status to %s is not allowed", newEndpoint.Status))
}

return Ok(endpoint)
Expand Down Expand Up @@ -393,74 +486,54 @@ func (c *EndpointsController) ListContainers(r *http.Request, vars map[string]st
return Ok(containers)
}

func validateUpdateRequest(prev *models.VersionEndpoint, new *models.VersionEndpoint) error {
if prev.EnvironmentName != new.EnvironmentName {
return fmt.Errorf("Updating environment is not allowed, previous: %s, new: %s", prev.EnvironmentName, new.EnvironmentName)
}

if prev.Status == models.EndpointPending {
return fmt.Errorf("Updating endpoint status to %s is not allowed when the endpoint is currently in the pending state", new.Status)
}

if new.Status != prev.Status {
if prev.Status == models.EndpointServing {
return fmt.Errorf("Updating endpoint status to %s is not allowed when the endpoint is currently in the serving state", new.Status)
}

if new.Status != models.EndpointRunning && new.Status != models.EndpointTerminated {
return fmt.Errorf("Updating endpoint status to %s is not allowed", new.Status)
}
}

return nil
}

func (c *EndpointsController) validateTransformer(ctx context.Context, trans *models.Transformer, protocol protocol.Protocol, logger *models.Logger) error {
func validateTransformer(ctx context.Context, endpoint *models.VersionEndpoint, stdTransformerConfig config.StandardTransformerConfig, feastCore core.CoreServiceClient) error {
trans := endpoint.Transformer
protocol := endpoint.Protocol
logger := endpoint.Logger
switch trans.TransformerType {
case models.CustomTransformerType, models.DefaultTransformerType:
if trans.Image == "" {
return errors.New("Transformer image name is not specified")
return errors.New("transformer image name is not specified")
}
case models.StandardTransformerType:
envVars := trans.EnvVars.ToMap()
cfg, ok := envVars[transformer.StandardTransformerConfigEnvName]
if !ok {
return errors.New("Standard transformer config is not specified")
return errors.New("standard transformer config is not specified")
}

var predictionLogCfg *spec.PredictionLogConfig
if logger != nil && logger.Prediction != nil {
predictionLogCfg = logger.Prediction.ToPredictionLogConfig()
}

return c.validateStandardTransformerConfig(ctx, cfg, protocol, predictionLogCfg)
feastOptions := &feast.Options{
StorageConfigs: stdTransformerConfig.ToFeastStorageConfigs(),
}
return validateStandardTransformerConfig(ctx, cfg, protocol, predictionLogCfg, feastOptions, feastCore)
default:
return fmt.Errorf("Unknown transformer type: %s", trans.TransformerType)
return fmt.Errorf("unknown transformer type: %s", trans.TransformerType)
}

return nil
}

func (c *EndpointsController) validateCustomPredictor(ctx context.Context, version *models.Version) error {
func validateCustomPredictor(version *models.Version) error {
customPredictor := version.CustomPredictor
if customPredictor == nil {
return errors.New("custom predictor must be specified")
}
return customPredictor.IsValid()
}

func (c *EndpointsController) validateStandardTransformerConfig(ctx context.Context, cfg string, protocol protocol.Protocol, predictionLogConfig *spec.PredictionLogConfig) error {
func validateStandardTransformerConfig(ctx context.Context, cfg string, protocol protocol.Protocol, predictionLogConfig *spec.PredictionLogConfig, feastOpts *feast.Options, feastCore core.CoreServiceClient) error {
stdTransformerConfig := &spec.StandardTransformerConfig{}
err := protojson.Unmarshal([]byte(cfg), stdTransformerConfig)
if err != nil {
return err
}

feastOptions := &feast.Options{
StorageConfigs: c.StandardTransformerConfig.ToFeastStorageConfigs(),
}

stdTransformerConfig.PredictionLogConfig = predictionLogConfig

return pipeline.ValidateTransformerConfig(ctx, c.FeastCoreClient, stdTransformerConfig, feastOptions, protocol)
return pipeline.ValidateTransformerConfig(ctx, feastCore, stdTransformerConfig, feastOpts, protocol)
}
Loading

0 comments on commit c31b2aa

Please sign in to comment.