-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add validation for deployment using model observability
- Loading branch information
1 parent
c31b2aa
commit ae0ab16
Showing
3 changed files
with
827 additions
and
237 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
package api | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
|
||
"github.com/caraml-dev/merlin/config" | ||
"github.com/caraml-dev/merlin/models" | ||
"github.com/caraml-dev/merlin/pkg/protocol" | ||
"github.com/caraml-dev/merlin/service" | ||
"github.com/feast-dev/feast/sdk/go/protos/feast/core" | ||
) | ||
|
||
type requestValidator interface { | ||
validate() error | ||
} | ||
|
||
type funcValidate struct { | ||
f func() error | ||
} | ||
|
||
func newFuncValidate(f func() error) *funcValidate { | ||
return &funcValidate{ | ||
f: f, | ||
} | ||
} | ||
|
||
func (fv *funcValidate) validate() error { | ||
return fv.f() | ||
} | ||
|
||
var supportedUPIModelTypes = map[string]bool{ | ||
models.ModelTypePyFunc: true, | ||
models.ModelTypeCustom: true, | ||
models.ModelTypePyFuncV3: true, | ||
} | ||
|
||
func isModelSupportUPI(model *models.Model) bool { | ||
_, isSupported := supportedUPIModelTypes[model.Type] | ||
|
||
return isSupported | ||
} | ||
|
||
func validateRequest(validators ...requestValidator) error { | ||
for _, validator := range validators { | ||
if err := validator.validate(); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
} | ||
|
||
func customModelValidation(model *models.Model, version *models.Version) requestValidator { | ||
return newFuncValidate(func() error { | ||
if model.Type == models.ModelTypeCustom { | ||
if err := validateCustomPredictor(version); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
func upiModelValidation(model *models.Model, endpointProtocol protocol.Protocol) requestValidator { | ||
return newFuncValidate(func() error { | ||
if !isModelSupportUPI(model) && endpointProtocol == protocol.UpiV1 { | ||
return fmt.Errorf("%s model is not supported by UPI", model.Type) | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
func newVersionEndpointValidation(version *models.Version, envName string) requestValidator { | ||
return newFuncValidate(func() error { | ||
endpoint, ok := version.GetEndpointByEnvironmentName(envName) | ||
if ok && (endpoint.IsRunning() || endpoint.IsServing()) { | ||
return fmt.Errorf("there is `%s` deployment for the model version", endpoint.Status) | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
func deploymentQuotaValidation(ctx context.Context, model *models.Model, env *models.Environment, endpointSvc service.EndpointsService) requestValidator { | ||
return newFuncValidate(func() error { | ||
deployedModelVersionCount, err := endpointSvc.CountEndpoints(ctx, env, model) | ||
if err != nil { | ||
return fmt.Errorf("unable to count number of endpoints in env %s: %v", env.Name, err) | ||
} | ||
|
||
if deployedModelVersionCount >= config.MaxDeployedVersion { | ||
return fmt.Errorf("max deployed endpoint reached. Max: %d Current: %d, undeploy existing endpoint before continuing", config.MaxDeployedVersion, deployedModelVersionCount) | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
func transformerValidation( | ||
ctx context.Context, | ||
endpoint *models.VersionEndpoint, | ||
stdTransformerCfg config.StandardTransformerConfig, | ||
feastCore core.CoreServiceClient) requestValidator { | ||
return newFuncValidate(func() error { | ||
if endpoint.Transformer != nil && endpoint.Transformer.Enabled { | ||
err := validateTransformer(ctx, endpoint, stdTransformerCfg, feastCore) | ||
if err != nil { | ||
return fmt.Errorf("Error validating transformer: %v", err) | ||
} | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
func updateRequestValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) requestValidator { | ||
return newFuncValidate(func() error { | ||
if prev.EnvironmentName != new.EnvironmentName { | ||
return fmt.Errorf("updating environment is not allowed, previous: %s, new: %s", prev.EnvironmentName, new.EnvironmentName) | ||
} | ||
|
||
if prev.Status == models.EndpointPending { | ||
return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the pending state", new.Status) | ||
} | ||
|
||
if new.Status != prev.Status { | ||
if prev.Status == models.EndpointServing { | ||
return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the serving state", new.Status) | ||
} | ||
|
||
if new.Status != models.EndpointRunning && new.Status != models.EndpointTerminated { | ||
return fmt.Errorf("updating endpoint status to %s is not allowed", new.Status) | ||
} | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
func deploymentModeValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) requestValidator { | ||
return newFuncValidate(func() error { | ||
// Should not allow changing the deployment mode of a pending/running/serving model for 2 reasons: | ||
// * For "serving" models it's risky as, we can't guarantee graceful re-deployment | ||
// * Kserve uses slightly different deployment resource naming under the hood and doesn't clean up the older deployment | ||
if (prev.IsRunning() || prev.IsServing()) && new.DeploymentMode != "" && | ||
new.DeploymentMode != prev.DeploymentMode { | ||
return fmt.Errorf("changing deployment type of a %s model is not allowed, please terminate it first", prev.Status) | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
func modelObservabilityValidation(endpoint *models.VersionEndpoint, model *models.Model) requestValidator { | ||
return newFuncValidate(func() error { | ||
if endpoint.EnableModelObservability && model.Type != models.ModelTypePyFuncV3 { | ||
return fmt.Errorf("model type should be pyfunc_v3 if want to enable model observablity") | ||
} | ||
return nil | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.