diff --git a/api/api/validator.go b/api/api/validator.go new file mode 100644 index 000000000..a3dd5522b --- /dev/null +++ b/api/api/validator.go @@ -0,0 +1,156 @@ +package api + +import ( + "context" + "fmt" + + "github.com/caraml-dev/merlin/config" + "github.com/caraml-dev/merlin/models" + "github.com/caraml-dev/merlin/pkg/protocol" + "github.com/caraml-dev/merlin/service" + "github.com/feast-dev/feast/sdk/go/protos/feast/core" +) + +type requestValidator interface { + validate() error +} + +type funcValidate struct { + f func() error +} + +func newFuncValidate(f func() error) *funcValidate { + return &funcValidate{ + f: f, + } +} + +func (fv *funcValidate) validate() error { + return fv.f() +} + +var supportedUPIModelTypes = map[string]bool{ + models.ModelTypePyFunc: true, + models.ModelTypeCustom: true, + models.ModelTypePyFuncV3: true, +} + +func isModelSupportUPI(model *models.Model) bool { + _, isSupported := supportedUPIModelTypes[model.Type] + + return isSupported +} + +func validateRequest(validators ...requestValidator) error { + for _, validator := range validators { + if err := validator.validate(); err != nil { + return err + } + } + return nil +} + +func customModelValidation(model *models.Model, version *models.Version) requestValidator { + return newFuncValidate(func() error { + if model.Type == models.ModelTypeCustom { + if err := validateCustomPredictor(version); err != nil { + return err + } + } + return nil + }) +} + +func upiModelValidation(model *models.Model, endpointProtocol protocol.Protocol) requestValidator { + return newFuncValidate(func() error { + if !isModelSupportUPI(model) && endpointProtocol == protocol.UpiV1 { + return fmt.Errorf("%s model is not supported by UPI", model.Type) + } + return nil + }) +} + +func newVersionEndpointValidation(version *models.Version, envName string) requestValidator { + return newFuncValidate(func() error { + endpoint, ok := version.GetEndpointByEnvironmentName(envName) + if ok && (endpoint.IsRunning() || endpoint.IsServing()) { + return fmt.Errorf("there is `%s` deployment for the model version", endpoint.Status) + } + return nil + }) +} + +func deploymentQuotaValidation(ctx context.Context, model *models.Model, env *models.Environment, endpointSvc service.EndpointsService) requestValidator { + return newFuncValidate(func() error { + deployedModelVersionCount, err := endpointSvc.CountEndpoints(ctx, env, model) + if err != nil { + return fmt.Errorf("unable to count number of endpoints in env %s: %v", env.Name, err) + } + + if deployedModelVersionCount >= config.MaxDeployedVersion { + return fmt.Errorf("max deployed endpoint reached. Max: %d Current: %d, undeploy existing endpoint before continuing", config.MaxDeployedVersion, deployedModelVersionCount) + } + return nil + }) +} + +func transformerValidation( + ctx context.Context, + endpoint *models.VersionEndpoint, + stdTransformerCfg config.StandardTransformerConfig, + feastCore core.CoreServiceClient) requestValidator { + return newFuncValidate(func() error { + if endpoint.Transformer != nil && endpoint.Transformer.Enabled { + err := validateTransformer(ctx, endpoint, stdTransformerCfg, feastCore) + if err != nil { + return fmt.Errorf("Error validating transformer: %v", err) + } + } + return nil + }) +} + +func updateRequestValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) requestValidator { + return newFuncValidate(func() error { + if prev.EnvironmentName != new.EnvironmentName { + return fmt.Errorf("updating environment is not allowed, previous: %s, new: %s", prev.EnvironmentName, new.EnvironmentName) + } + + if prev.Status == models.EndpointPending { + return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the pending state", new.Status) + } + + if new.Status != prev.Status { + if prev.Status == models.EndpointServing { + return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the serving state", new.Status) + } + + if new.Status != models.EndpointRunning && new.Status != models.EndpointTerminated { + return fmt.Errorf("updating endpoint status to %s is not allowed", new.Status) + } + } + return nil + }) +} + +func deploymentModeValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) requestValidator { + return newFuncValidate(func() error { + // Should not allow changing the deployment mode of a pending/running/serving model for 2 reasons: + // * For "serving" models it's risky as, we can't guarantee graceful re-deployment + // * Kserve uses slightly different deployment resource naming under the hood and doesn't clean up the older deployment + if (prev.IsRunning() || prev.IsServing()) && new.DeploymentMode != "" && + new.DeploymentMode != prev.DeploymentMode { + return fmt.Errorf("changing deployment type of a %s model is not allowed, please terminate it first", prev.Status) + } + return nil + }) +} + +func modelObservabilityValidation(endpoint *models.VersionEndpoint, model *models.Model) requestValidator { + return newFuncValidate(func() error { + if endpoint.EnableModelObservability && model.Type != models.ModelTypePyFuncV3 { + return fmt.Errorf("model type should be pyfunc_v3 if want to enable model observablity") + } + return nil + }) +} diff --git a/api/api/version_endpoints_api.go b/api/api/version_endpoints_api.go index fd1c629f1..10da316cf 100644 --- a/api/api/version_endpoints_api.go +++ b/api/api/version_endpoints_api.go @@ -33,7 +33,6 @@ import ( "github.com/caraml-dev/merlin/pkg/transformer/feast" "github.com/caraml-dev/merlin/pkg/transformer/pipeline" "github.com/caraml-dev/merlin/pkg/transformer/spec" - "github.com/caraml-dev/merlin/service" ) type EndpointsController struct { @@ -113,130 +112,6 @@ func (c *EndpointsController) GetEndpoint(r *http.Request, vars map[string]strin return Ok(endpoint) } -type endpointValidator interface { - validate(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error -} - -type funcValidate struct { - f func(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error -} - -func newFuncValidate(f func(*models.VersionEndpoint, *models.Model, *models.Version) error) *funcValidate { - return &funcValidate{ - f: f, - } -} - -func (fv *funcValidate) validate(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version) error { - return fv.f(endpoint, model, version) -} - -func validateEndpointRequest(endpoint *models.VersionEndpoint, model *models.Model, version *models.Version, validators ...endpointValidator) error { - for _, validator := range validators { - if err := validator.validate(endpoint, model, version); err != nil { - return err - } - } - return nil -} - -func customModelValidation(model *models.Model) endpointValidator { - return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error { - if model.Type == models.ModelTypeCustom { - if err := validateCustomPredictor(v); err != nil { - return err - } - } - return nil - }) -} - -func upiModelValidation(model *models.Model) endpointValidator { - return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error { - if !isModelSupportUPI(model) && ve.Protocol == protocol.UpiV1 { - return fmt.Errorf("%s model is not supported by UPI", model.Type) - } - return nil - }) -} - -func newVersionEndpointValidation(version *models.Version, envName string) endpointValidator { - return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error { - endpoint, ok := version.GetEndpointByEnvironmentName(envName) - if ok && (endpoint.IsRunning() || endpoint.IsServing()) { - return fmt.Errorf("there is `%s` deployment for the model version", endpoint.Status) - } - return nil - }) -} - -func deploymentQuotaValidation(ctx context.Context, model *models.Model, env *models.Environment, endpointSvc service.EndpointsService) endpointValidator { - return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error { - deployedModelVersionCount, err := endpointSvc.CountEndpoints(ctx, env, model) - if err != nil { - return fmt.Errorf("unable to count number of endpoints in env %s: %v", env.Name, err) - } - - if deployedModelVersionCount >= config.MaxDeployedVersion { - return fmt.Errorf("max deployed endpoint reached. Max: %d Current: %d, undeploy existing endpoint before continuing", config.MaxDeployedVersion, deployedModelVersionCount) - } - return nil - }) -} - -func transformerValidation( - ctx context.Context, - endpoint *models.VersionEndpoint, - stdTransformerCfg config.StandardTransformerConfig, - feastCore core.CoreServiceClient) endpointValidator { - return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error { - if ve.Transformer != nil && ve.Transformer.Enabled { - err := validateTransformer(ctx, endpoint, stdTransformerCfg, feastCore) - if err != nil { - return fmt.Errorf("Error validating transformer: %v", err) - } - } - return nil - }) -} - -func updateRequestValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) endpointValidator { - return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error { - if prev.EnvironmentName != new.EnvironmentName { - return fmt.Errorf("updating environment is not allowed, previous: %s, new: %s", prev.EnvironmentName, new.EnvironmentName) - } - - if prev.Status == models.EndpointPending { - return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the pending state", new.Status) - } - - if new.Status != prev.Status { - if prev.Status == models.EndpointServing { - return fmt.Errorf("updating endpoint status to %s is not allowed when the endpoint is currently in the serving state", new.Status) - } - - if new.Status != models.EndpointRunning && new.Status != models.EndpointTerminated { - return fmt.Errorf("updating endpoint status to %s is not allowed", new.Status) - } - } - return nil - }) -} - -func deploymentModeValidation(prev *models.VersionEndpoint, new *models.VersionEndpoint) endpointValidator { - return newFuncValidate(func(ve *models.VersionEndpoint, m *models.Model, v *models.Version) error { - // Should not allow changing the deployment mode of a pending/running/serving model for 2 reasons: - // * For "serving" models it's risky as, we can't guarantee graceful re-deployment - // * Kserve uses slightly different deployment resource naming under the hood and doesn't clean up the older deployment - if (prev.IsRunning() || prev.IsServing()) && new.DeploymentMode != "" && - new.DeploymentMode != prev.DeploymentMode { - return fmt.Errorf("changing deployment type of a %s model is not allowed, please terminate it first", prev.Status) - } - return nil - }) - -} - // CreateEndpoint create new endpoint from a model version and deploy to certain environment as specified by request // If target environment is not set then fallback to default environment func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]string, body interface{}) *Response { @@ -281,15 +156,16 @@ func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]st newEndpoint.EnvironmentName = env.Name } - validationRules := []endpointValidator{ - customModelValidation(model), - upiModelValidation(model), + validationRules := []requestValidator{ + customModelValidation(model, version), + upiModelValidation(model, newEndpoint.Protocol), newVersionEndpointValidation(version, env.Name), deploymentQuotaValidation(ctx, model, env, c.EndpointsService), transformerValidation(ctx, newEndpoint, c.StandardTransformerConfig, c.FeastCoreClient), + modelObservabilityValidation(newEndpoint, model), } - if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil { + if err := validateRequest(validationRules...); err != nil { return BadRequest(fmt.Sprintf("Request validation failed: %v", err)) } @@ -304,18 +180,6 @@ func (c *EndpointsController) CreateEndpoint(r *http.Request, vars map[string]st return Created(endpoint) } -var supportedUPIModelTypes = map[string]bool{ - models.ModelTypePyFunc: true, - models.ModelTypeCustom: true, - models.ModelTypePyFuncV3: true, -} - -func isModelSupportUPI(model *models.Model) bool { - _, isSupported := supportedUPIModelTypes[model.Type] - - return isSupported -} - // UpdateEndpoint update a an existing endpoint i.e. trigger redeployment func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]string, body interface{}) *Response { ctx := r.Context() @@ -345,11 +209,6 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st return BadRequest("Unable to parse body as version endpoint resource") } - validationRules := []endpointValidator{ - customModelValidation(model), - updateRequestValidation(endpoint, newEndpoint), - } - env, err := c.AppContext.EnvironmentService.GetEnvironment(newEndpoint.EnvironmentName) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -358,21 +217,20 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st return InternalServerError(fmt.Sprintf("Error getting the specified environment: %v", err)) } + validationRules := []requestValidator{ + customModelValidation(model, version), + updateRequestValidation(endpoint, newEndpoint), + modelObservabilityValidation(newEndpoint, model), + } + if newEndpoint.Status == models.EndpointRunning || newEndpoint.Status == models.EndpointServing { - // validate transformer - // if newEndpoint.Transformer != nil && newEndpoint.Transformer.Enabled { - // err := c.validateTransformer(ctx, newEndpoint.Transformer, newEndpoint.Protocol, newEndpoint.Logger) - // if err != nil { - // return BadRequest(fmt.Sprintf("Error validating the transformer: %v", err)) - // } - // } validationRules = append( validationRules, transformerValidation(ctx, newEndpoint, c.StandardTransformerConfig, c.FeastCoreClient), deploymentModeValidation(endpoint, newEndpoint), ) - if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil { + if err := validateRequest(validationRules...); err != nil { return BadRequest(fmt.Sprintf("Request validation failed: %v", err)) } @@ -385,9 +243,10 @@ func (c *EndpointsController) UpdateEndpoint(r *http.Request, vars map[string]st return InternalServerError(fmt.Sprintf("Unable to deploy model version: %v", err)) } } else if newEndpoint.Status == models.EndpointTerminated { - if err := validateEndpointRequest(newEndpoint, model, version, validationRules...); err != nil { + if err := validateRequest(validationRules...); err != nil { return BadRequest(fmt.Sprintf("Request validation failed: %v", err)) } + endpoint, err = c.EndpointsService.UndeployEndpoint(ctx, env, model, version, endpoint) if err != nil { return InternalServerError(fmt.Sprintf("Unable to undeploy version endpoint %s: %v", endpointID, err)) diff --git a/api/api/version_endpoints_api_test.go b/api/api/version_endpoints_api_test.go index 93c7c8be0..8d2067c01 100644 --- a/api/api/version_endpoints_api_test.go +++ b/api/api/version_endpoints_api_test.go @@ -1042,6 +1042,288 @@ func TestCreateEndpoint(t *testing.T) { }, }, }, + { + desc: "Should success create endpoint with pyfunc_v3 model observability enabled", + vars: map[string]string{ + "model_id": "1", + "version_id": "1", + }, + requestBody: &models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + ServiceName: "sample", + Namespace: "sample", + EnvironmentName: "dev", + Message: "", + ResourceRequest: &models.ResourceRequest{ + MinReplica: 1, + MaxReplica: 4, + CPURequest: resource.MustParse("1"), + MemoryRequest: resource.MustParse("1Gi"), + }, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + EnableModelObservability: true, + }, + modelService: func() *mocks.ModelsService { + svc := &mocks.ModelsService{} + svc.On("FindByID", mock.Anything, models.ID(1)).Return(&models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc_v3", + MlflowURL: "", + Endpoints: nil, + }, nil) + return svc + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + svc.On("FindByID", mock.Anything, models.ID(1), models.ID(1), mock.Anything).Return(&models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "", + Endpoints: nil, + }, + }, nil) + return svc + }, + envService: func() *mocks.EnvironmentService { + svc := &mocks.EnvironmentService{} + svc.On("GetDefaultEnvironment").Return(&models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, nil) + svc.On("GetEnvironment", "dev").Return(&models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, nil) + return svc + }, + endpointService: func() *mocks.EndpointsService { + svc := &mocks.EndpointsService{} + svc.On("CountEndpoints", context.Background(), mock.Anything, mock.Anything).Return(0, nil) + svc.On("DeployEndpoint", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + URL: "http://endpoint.svc", + ServiceName: "sample", + InferenceServiceName: "sample", + Namespace: "sample", + Environment: &models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, + EnvironmentName: "dev", + Message: "", + ResourceRequest: nil, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + CreatedUpdated: models.CreatedUpdated{}, + }, nil) + return svc + }, + monitoringConfig: config.MonitoringConfig{}, + feastCoreMock: func() *feastmocks.CoreServiceClient { + return &feastmocks.CoreServiceClient{} + }, + expected: &Response{ + code: http.StatusCreated, + data: &models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + URL: "http://endpoint.svc", + ServiceName: "sample", + InferenceServiceName: "sample", + Namespace: "sample", + Environment: &models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, + EnvironmentName: "dev", + Message: "", + ResourceRequest: nil, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + CreatedUpdated: models.CreatedUpdated{}, + }, + }, + }, + { + desc: "Should failed create endpoint with non pyfunc_v3 and model observability enabled", + vars: map[string]string{ + "model_id": "1", + "version_id": "1", + }, + requestBody: &models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + ServiceName: "sample", + Namespace: "sample", + EnvironmentName: "dev", + Message: "", + ResourceRequest: &models.ResourceRequest{ + MinReplica: 1, + MaxReplica: 4, + CPURequest: resource.MustParse("1"), + MemoryRequest: resource.MustParse("1Gi"), + }, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + EnableModelObservability: true, + }, + modelService: func() *mocks.ModelsService { + svc := &mocks.ModelsService{} + svc.On("FindByID", mock.Anything, models.ID(1)).Return(&models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "", + Endpoints: nil, + }, nil) + return svc + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + svc.On("FindByID", mock.Anything, models.ID(1), models.ID(1), mock.Anything).Return(&models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "", + Endpoints: nil, + }, + }, nil) + return svc + }, + envService: func() *mocks.EnvironmentService { + svc := &mocks.EnvironmentService{} + svc.On("GetDefaultEnvironment").Return(&models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, nil) + svc.On("GetEnvironment", "dev").Return(&models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, nil) + return svc + }, + endpointService: func() *mocks.EndpointsService { + svc := &mocks.EndpointsService{} + svc.On("CountEndpoints", context.Background(), mock.Anything, mock.Anything).Return(0, nil) + svc.On("DeployEndpoint", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + URL: "http://endpoint.svc", + ServiceName: "sample", + InferenceServiceName: "sample", + Namespace: "sample", + Environment: &models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, + EnvironmentName: "dev", + Message: "", + ResourceRequest: nil, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + CreatedUpdated: models.CreatedUpdated{}, + }, nil) + return svc + }, + monitoringConfig: config.MonitoringConfig{}, + feastCoreMock: func() *feastmocks.CoreServiceClient { + return &feastmocks.CoreServiceClient{} + }, + expected: BadRequest("Request validation failed: model type should be pyfunc_v3 if want to enable model observablity"), + }, { desc: "Should return 400 if UPI is not supported", vars: map[string]string{ @@ -3183,11 +3465,200 @@ func TestCreateEndpoint(t *testing.T) { }`, }, }, - }, + }, + }, + modelService: func() *mocks.ModelsService { + svc := &mocks.ModelsService{} + svc.On("FindByID", mock.Anything, models.ID(1)).Return(&models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "", + Endpoints: nil, + }, nil) + return svc + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + svc.On("FindByID", mock.Anything, models.ID(1), models.ID(1), mock.Anything).Return(&models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "", + Endpoints: nil, + }, + }, nil) + return svc + }, + envService: func() *mocks.EnvironmentService { + svc := &mocks.EnvironmentService{} + svc.On("GetDefaultEnvironment").Return(&models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, nil) + svc.On("GetEnvironment", "dev").Return(&models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, nil) + return svc + }, + endpointService: func() *mocks.EndpointsService { + svc := &mocks.EndpointsService{} + svc.On("CountEndpoints", context.Background(), mock.Anything, mock.Anything).Return(0, nil) + svc.On("DeployEndpoint", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + URL: "http://endpoint.svc", + ServiceName: "sample", + InferenceServiceName: "sample", + Namespace: "sample", + MonitoringURL: "http://monitoring.com", + Environment: &models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, + EnvironmentName: "dev", + Message: "", + ResourceRequest: nil, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + CreatedUpdated: models.CreatedUpdated{}, + }, nil) + return svc + }, + monitoringConfig: config.MonitoringConfig{ + MonitoringEnabled: true, + MonitoringBaseURL: "http://grafana", + }, + standardTransformerConfig: config.StandardTransformerConfig{ + FeastBigtableConfig: &config.FeastBigtableConfig{ + ServingURL: "localhost:6567", + }, + FeastRedisConfig: &config.FeastRedisConfig{ + ServingURL: "localhost:6566", + RedisAddresses: []string{ + "10.1.1.2", "10.1.1.3", + }, + PoolSize: 5, + }, + }, + feastCoreMock: func() *feastmocks.CoreServiceClient { + return &feastmocks.CoreServiceClient{} + }, + expected: &Response{ + code: http.StatusBadRequest, + data: Error{Message: "Request validation failed: Error validating transformer: feast source configuration is not valid, servingURL: localhost:6565 source: UNKNOWN"}, + }, + }, + } + for _, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + modelSvc := tC.modelService() + versionSvc := tC.versionService() + envSvc := tC.envService() + endpointSvc := tC.endpointService() + feastCoreMock := tC.feastCoreMock() + + ctl := &EndpointsController{ + AppContext: &AppContext{ + ModelsService: modelSvc, + VersionsService: versionSvc, + EnvironmentService: envSvc, + EndpointsService: endpointSvc, + FeatureToggleConfig: config.FeatureToggleConfig{ + AlertConfig: config.AlertConfig{ + AlertEnabled: true, + }, + MonitoringConfig: tC.monitoringConfig, + }, + StandardTransformerConfig: tC.standardTransformerConfig, + FeastCoreClient: feastCoreMock, + }, + } + resp := ctl.CreateEndpoint(&http.Request{}, tC.vars, tC.requestBody) + assertEqualResponses(t, tC.expected, resp) + }) + } +} + +func TestUpdateEndpoint(t *testing.T) { + uuid := uuid.New() + trueBoolean := true + testCases := []struct { + desc string + vars map[string]string + requestBody *models.VersionEndpoint + modelService func() *mocks.ModelsService + versionService func() *mocks.VersionsService + endpointService func() *mocks.EndpointsService + envService func() *mocks.EnvironmentService + expected *Response + }{ + { + desc: "Should success update endpoint", + vars: map[string]string{ + "model_id": "1", + "version_id": "1", + "endpoint_id": uuid.String(), + }, + requestBody: &models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + ServiceName: "sample", + Namespace: "sample", + EnvironmentName: "dev", + Message: "", + ResourceRequest: &models.ResourceRequest{ + MinReplica: 1, + MaxReplica: 4, + CPURequest: resource.MustParse("1"), + MemoryRequest: resource.MustParse("1Gi"), + }, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), }, modelService: func() *mocks.ModelsService { svc := &mocks.ModelsService{} - svc.On("FindByID", mock.Anything, models.ID(1)).Return(&models.Model{ + svc.On("FindByID", context.Background(), models.ID(1)).Return(&models.Model{ ID: models.ID(1), Name: "model-1", ProjectID: models.ID(1), @@ -3201,7 +3672,7 @@ func TestCreateEndpoint(t *testing.T) { }, versionService: func() *mocks.VersionsService { svc := &mocks.VersionsService{} - svc.On("FindByID", mock.Anything, models.ID(1), models.ID(1), mock.Anything).Return(&models.Version{ + svc.On("FindByID", context.Background(), models.ID(1), models.ID(1), mock.Anything).Return(&models.Version{ ID: models.ID(1), ModelID: models.ID(1), Model: &models.Model{ @@ -3219,16 +3690,6 @@ func TestCreateEndpoint(t *testing.T) { }, envService: func() *mocks.EnvironmentService { svc := &mocks.EnvironmentService{} - svc.On("GetDefaultEnvironment").Return(&models.Environment{ - ID: models.ID(1), - Name: "dev", - Cluster: "dev", - IsDefault: &trueBoolean, - Region: "id", - GcpProject: "dev-proj", - MaxCPU: "1", - MaxMemory: "1Gi", - }, nil) svc.On("GetEnvironment", "dev").Return(&models.Environment{ ID: models.ID(1), Name: "dev", @@ -3243,7 +3704,35 @@ func TestCreateEndpoint(t *testing.T) { }, endpointService: func() *mocks.EndpointsService { svc := &mocks.EndpointsService{} - svc.On("CountEndpoints", context.Background(), mock.Anything, mock.Anything).Return(0, nil) + svc.On("FindByID", context.Background(), uuid).Return(&models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + ServiceName: "sample", + InferenceServiceName: "sample", + Namespace: "sample", + URL: "http://endpoint.svc", + MonitoringURL: "http://monitoring.com", + Environment: &models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, EnvironmentName: "dev", + Message: "", + ResourceRequest: nil, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + }, nil) svc.On("DeployEndpoint", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.VersionEndpoint{ ID: uuid, VersionID: models.ID(1), @@ -3277,76 +3766,43 @@ func TestCreateEndpoint(t *testing.T) { }, nil) return svc }, - monitoringConfig: config.MonitoringConfig{ - MonitoringEnabled: true, - MonitoringBaseURL: "http://grafana", - }, - standardTransformerConfig: config.StandardTransformerConfig{ - FeastBigtableConfig: &config.FeastBigtableConfig{ - ServingURL: "localhost:6567", - }, - FeastRedisConfig: &config.FeastRedisConfig{ - ServingURL: "localhost:6566", - RedisAddresses: []string{ - "10.1.1.2", "10.1.1.3", + expected: &Response{ + code: http.StatusOK, + data: &models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + URL: "http://endpoint.svc", + ServiceName: "sample", + InferenceServiceName: "sample", + Namespace: "sample", + MonitoringURL: "http://monitoring.com", + Environment: &models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", }, - PoolSize: 5, + EnvironmentName: "dev", + Message: "", + ResourceRequest: nil, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + CreatedUpdated: models.CreatedUpdated{}, }, }, - feastCoreMock: func() *feastmocks.CoreServiceClient { - return &feastmocks.CoreServiceClient{} - }, - expected: &Response{ - code: http.StatusBadRequest, - data: Error{Message: "Request validation failed: Error validating transformer: feast source configuration is not valid, servingURL: localhost:6565 source: UNKNOWN"}, - }, }, - } - for _, tC := range testCases { - t.Run(tC.desc, func(t *testing.T) { - modelSvc := tC.modelService() - versionSvc := tC.versionService() - envSvc := tC.envService() - endpointSvc := tC.endpointService() - feastCoreMock := tC.feastCoreMock() - - ctl := &EndpointsController{ - AppContext: &AppContext{ - ModelsService: modelSvc, - VersionsService: versionSvc, - EnvironmentService: envSvc, - EndpointsService: endpointSvc, - FeatureToggleConfig: config.FeatureToggleConfig{ - AlertConfig: config.AlertConfig{ - AlertEnabled: true, - }, - MonitoringConfig: tC.monitoringConfig, - }, - StandardTransformerConfig: tC.standardTransformerConfig, - FeastCoreClient: feastCoreMock, - }, - } - resp := ctl.CreateEndpoint(&http.Request{}, tC.vars, tC.requestBody) - assertEqualResponses(t, tC.expected, resp) - }) - } -} - -func TestUpdateEndpoint(t *testing.T) { - uuid := uuid.New() - trueBoolean := true - testCases := []struct { - desc string - vars map[string]string - requestBody *models.VersionEndpoint - modelService func() *mocks.ModelsService - versionService func() *mocks.VersionsService - endpointService func() *mocks.EndpointsService - envService func() *mocks.EnvironmentService - expected *Response - }{ { - desc: "Should success update endpoint", + desc: "Should success update endpoint, pyfunc_v3 and model observablity enabled", vars: map[string]string{ "model_id": "1", "version_id": "1", @@ -3373,6 +3829,7 @@ func TestUpdateEndpoint(t *testing.T) { Value: "1", }, }), + EnableModelObservability: true, }, modelService: func() *mocks.ModelsService { svc := &mocks.ModelsService{} @@ -3382,7 +3839,7 @@ func TestUpdateEndpoint(t *testing.T) { ProjectID: models.ID(1), Project: mlp.Project{}, ExperimentID: 1, - Type: "pyfunc", + Type: "pyfunc_v3", MlflowURL: "", Endpoints: nil, }, nil) @@ -3450,6 +3907,7 @@ func TestUpdateEndpoint(t *testing.T) { Value: "1", }, }), + EnableModelObservability: false, }, nil) svc.On("DeployEndpoint", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&models.VersionEndpoint{ ID: uuid, @@ -3480,7 +3938,8 @@ func TestUpdateEndpoint(t *testing.T) { Value: "1", }, }), - CreatedUpdated: models.CreatedUpdated{}, + CreatedUpdated: models.CreatedUpdated{}, + EnableModelObservability: true, }, nil) return svc }, @@ -3515,7 +3974,8 @@ func TestUpdateEndpoint(t *testing.T) { Value: "1", }, }), - CreatedUpdated: models.CreatedUpdated{}, + CreatedUpdated: models.CreatedUpdated{}, + EnableModelObservability: true, }, }, }, @@ -4032,6 +4492,121 @@ func TestUpdateEndpoint(t *testing.T) { data: Error{Message: "Updating endpoint status to pending is not allowed"}, }, }, + { + desc: "Should 400 if new endpoint enable model observability but the model is not pyfunc_v3", + vars: map[string]string{ + "model_id": "1", + "version_id": "1", + "endpoint_id": uuid.String(), + }, + requestBody: &models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + ServiceName: "sample", + Namespace: "sample", + EnvironmentName: "dev", + Message: "", + ResourceRequest: &models.ResourceRequest{ + MinReplica: 1, + MaxReplica: 4, + CPURequest: resource.MustParse("1"), + MemoryRequest: resource.MustParse("1Gi"), + }, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + EnableModelObservability: true, + }, + modelService: func() *mocks.ModelsService { + svc := &mocks.ModelsService{} + svc.On("FindByID", context.Background(), models.ID(1)).Return(&models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "", + Endpoints: nil, + }, nil) + return svc + }, + versionService: func() *mocks.VersionsService { + svc := &mocks.VersionsService{} + svc.On("FindByID", context.Background(), models.ID(1), models.ID(1), mock.Anything).Return(&models.Version{ + ID: models.ID(1), + ModelID: models.ID(1), + Model: &models.Model{ + ID: models.ID(1), + Name: "model-1", + ProjectID: models.ID(1), + Project: mlp.Project{}, + ExperimentID: 1, + Type: "pyfunc", + MlflowURL: "", + Endpoints: nil, + }, + }, nil) + return svc + }, + envService: func() *mocks.EnvironmentService { + svc := &mocks.EnvironmentService{} + svc.On("GetEnvironment", "dev").Return(&models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, nil) + return svc + }, + endpointService: func() *mocks.EndpointsService { + svc := &mocks.EndpointsService{} + svc.On("FindByID", context.Background(), uuid).Return(&models.VersionEndpoint{ + ID: uuid, + VersionID: models.ID(1), + VersionModelID: models.ID(1), + Status: models.EndpointRunning, + ServiceName: "sample", + InferenceServiceName: "sample", + Namespace: "sample", + URL: "http://endpoint.svc", + MonitoringURL: "http://monitoring.com", + Environment: &models.Environment{ + ID: models.ID(1), + Name: "dev", + Cluster: "dev", + IsDefault: &trueBoolean, + Region: "id", + GcpProject: "dev-proj", + MaxCPU: "1", + MaxMemory: "1Gi", + }, EnvironmentName: "dev", + Message: "", + ResourceRequest: nil, + EnvVars: models.EnvVars([]models.EnvVar{ + { + Name: "WORKER", + Value: "1", + }, + }), + EnableModelObservability: false, + }, nil) + return svc + }, + expected: &Response{ + code: http.StatusBadRequest, + data: Error{Message: "Request validation failed: model type should be pyfunc_v3 if want to enable model observablity"}, + }, + }, { desc: "Should return 500 if endpoint not found", vars: map[string]string{