From 22619bd329d44acb231c43c27ea567ed96c9ecfe Mon Sep 17 00:00:00 2001 From: Software Developer <7852635+dsuhinin@users.noreply.github.com> Date: Fri, 3 Jan 2025 16:50:57 +0100 Subject: [PATCH] Move POST /mlflow/model-versions/transition-stage endpoint (#101) Signed-off-by: Software Developer Signed-off-by: dsuhinin Co-authored-by: DSuhinin --- magefiles/generate/endpoints.go | 2 +- mlflow_go/store/model_registry.py | 12 +++ pkg/contract/service/model_registry.g.go | 1 + pkg/lib/model_registry.g.go | 8 ++ pkg/model_registry/service/model_versions.go | 36 ++++++++ .../store/sql/model_versions.go | 82 ++++++++++++++++++- .../store/sql/models/model_version_stage.go | 9 +- pkg/model_registry/store/store.go | 4 + pkg/server/routes/model_registry.g.go | 11 +++ 9 files changed, 161 insertions(+), 4 deletions(-) diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index c0bcb2d5..480f57df 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -57,7 +57,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ "getLatestVersions", // "createModelVersion", "updateModelVersion", - // "transitionModelVersionStage", + "transitionModelVersionStage", "deleteModelVersion", // "getModelVersion", // "searchModelVersions", diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index 01d76e9f..fca9ccb1 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -8,6 +8,7 @@ GetLatestVersions, GetRegisteredModel, RenameRegisteredModel, + TransitionModelVersionStage, UpdateModelVersion, UpdateRegisteredModel, ) @@ -89,6 +90,17 @@ def update_model_version(self, name, version, description=None): request = UpdateModelVersion(name=name, version=str(version), description=description) self.service.call_endpoint(get_lib().ModelRegistryServiceUpdateModelVersion, request) + def transition_model_version_stage(self, name, version, stage, archive_existing_versions): + request = TransitionModelVersionStage( + name=name, + version=str(version), + stage=stage, + archive_existing_versions=archive_existing_versions, + ) + self.service.call_endpoint( + get_lib().ModelRegistryServiceTransitionModelVersionStage, request + ) + def ModelRegistryStore(cls): return type(cls.__name__, (_ModelRegistryStore, cls), {}) diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index 573383c7..13209c9c 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -16,5 +16,6 @@ type ModelRegistryService interface { GetRegisteredModel(ctx context.Context, input *protos.GetRegisteredModel) (*protos.GetRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) UpdateModelVersion(ctx context.Context, input *protos.UpdateModelVersion) (*protos.UpdateModelVersion_Response, *contract.Error) + TransitionModelVersionStage(ctx context.Context, input *protos.TransitionModelVersionStage) (*protos.TransitionModelVersionStage_Response, *contract.Error) DeleteModelVersion(ctx context.Context, input *protos.DeleteModelVersion) (*protos.DeleteModelVersion_Response, *contract.Error) } diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index 9a5bdba8..d9dffe7b 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -55,6 +55,14 @@ func ModelRegistryServiceUpdateModelVersion(serviceID int64, requestData unsafe. } return invokeServiceMethod(service.UpdateModelVersion, new(protos.UpdateModelVersion), requestData, requestSize, responseSize) } +//export ModelRegistryServiceTransitionModelVersionStage +func ModelRegistryServiceTransitionModelVersionStage(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.TransitionModelVersionStage, new(protos.TransitionModelVersionStage), requestData, requestSize, responseSize) +} //export ModelRegistryServiceDeleteModelVersion func ModelRegistryServiceDeleteModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { service, err := modelRegistryServices.Get(serviceID) diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index ac53eca7..49df47fc 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -2,8 +2,11 @@ package service import ( "context" + "fmt" + "strings" "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql/models" "github.com/mlflow/mlflow-go/pkg/protos" ) @@ -99,3 +102,36 @@ func (m *ModelRegistryService) UpdateModelVersion( ModelVersion: modelVersion.ToProto(), }, nil } + +func (m *ModelRegistryService) TransitionModelVersionStage( + ctx context.Context, input *protos.TransitionModelVersionStage, +) (*protos.TransitionModelVersionStage_Response, *contract.Error) { + stage, ok := models.CanonicalMapping[strings.ToLower(input.GetStage())] + if !ok { + return nil, contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + "Invalid Model Version stage: unknown. Value must be one of %s, %s, %s, %s.", + models.ModelVersionStageNone, + models.ModelVersionStageStaging, + models.ModelVersionStageProduction, + models.ModelVersionStageArchived, + ), + ) + } + + modelVersion, err := m.store.TransitionModelVersionStage( + ctx, + input.GetName(), + input.GetVersion(), + stage, + input.GetArchiveExistingVersions(), + ) + if err != nil { + return nil, err + } + + return &protos.TransitionModelVersionStage_Response{ + ModelVersion: modelVersion.ToProto(), + }, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 4b18973e..91b282f1 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -57,7 +57,7 @@ func (m *ModelRegistrySQLStore) GetLatestVersions( for idx, stage := range stages { stages[idx] = strings.ToLower(stage) if canonicalStage, ok := models.CanonicalMapping[stages[idx]]; ok { - stages[idx] = canonicalStage + stages[idx] = canonicalStage.String() continue } @@ -377,3 +377,83 @@ func (m *ModelRegistrySQLStore) UpdateModelVersion( return modelVersion, nil } + +//nolint:funlen,cyclop +func (m *ModelRegistrySQLStore) TransitionModelVersionStage( + ctx context.Context, name, version string, stage models.ModelVersionStage, archiveExistingVersions bool, +) (*entities.ModelVersion, *contract.Error) { + isActiveStage := false + if _, ok := models.DefaultStagesForGetLatestVersions[strings.ToLower(stage.String())]; ok { + isActiveStage = true + } + + if archiveExistingVersions && !isActiveStage { + return nil, contract.NewError( + protos.ErrorCode_INVALID_PARAMETER_VALUE, + fmt.Sprintf( + `Model version transition cannot archive existing model versions because '%s' is not an Active stage. +Valid stages are %s`, + stage, models.AllModelVersionStages(), + ), + ) + } + + modelVersion, err := m.GetModelVersion(ctx, name, version) + if err != nil { + return nil, err + } + + registeredModel, err := m.GetRegisteredModel(ctx, name) + if err != nil { + return nil, err + } + + if err := m.db.Transaction(func(transaction *gorm.DB) error { + lastUpdatedTime := time.Now().UnixMilli() + if err := transaction.Model( + &models.RegisteredModel{}, + ).Where( + "name = ?", registeredModel.Name, + ).Updates(&models.RegisteredModel{ + LastUpdatedTime: lastUpdatedTime, + }).Error; err != nil { + return err + } + + if err := transaction.Model( + &models.ModelVersion{}, + ).Where( + "name = ?", modelVersion.Name, + ).Where( + "version = ?", modelVersion.Version, + ).Updates(&models.ModelVersion{ + CurrentStage: stage, + LastUpdatedTime: lastUpdatedTime, + }).Error; err != nil { + return err + } + + if archiveExistingVersions { + if err := transaction.Where( + "name = ?", name, + ).Where( + "version != ?", version, + ).Where( + "current_stage = ?", stage, + ).Updates(&models.ModelVersion{ + CurrentStage: models.ModelVersionStageArchived, + LastUpdatedTime: lastUpdatedTime, + }).Error; err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, "error transitioning model version stage", err, + ) + } + + return modelVersion, nil +} diff --git a/pkg/model_registry/store/sql/models/model_version_stage.go b/pkg/model_registry/store/sql/models/model_version_stage.go index 4020f36c..bfbee7c5 100644 --- a/pkg/model_registry/store/sql/models/model_version_stage.go +++ b/pkg/model_registry/store/sql/models/model_version_stage.go @@ -15,18 +15,23 @@ const ( ModelVersionStageArchived = "Archived" ) -var CanonicalMapping = map[string]string{ +var CanonicalMapping = map[string]ModelVersionStage{ strings.ToLower(ModelVersionStageNone): ModelVersionStageNone, strings.ToLower(ModelVersionStageStaging): ModelVersionStageStaging, strings.ToLower(ModelVersionStageProduction): ModelVersionStageProduction, strings.ToLower(ModelVersionStageArchived): ModelVersionStageArchived, } +var DefaultStagesForGetLatestVersions = map[string]ModelVersionStage{ + strings.ToLower(ModelVersionStageStaging): ModelVersionStageStaging, + strings.ToLower(ModelVersionStageProduction): ModelVersionStageProduction, +} + func AllModelVersionStages() string { pairs := make([]string, 0, len(CanonicalMapping)) for _, v := range CanonicalMapping { - pairs = append(pairs, v) + pairs = append(pairs, v.String()) } return strings.Join(pairs, ",") diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index b1245e30..bf7dcf60 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -5,6 +5,7 @@ import ( "github.com/mlflow/mlflow-go/pkg/contract" "github.com/mlflow/mlflow-go/pkg/entities" + "github.com/mlflow/mlflow-go/pkg/model_registry/store/sql/models" "github.com/mlflow/mlflow-go/pkg/protos" ) @@ -17,4 +18,7 @@ type ModelRegistryStore interface { DeleteRegisteredModel(ctx context.Context, name string) *contract.Error DeleteModelVersion(ctx context.Context, name, version string) *contract.Error UpdateModelVersion(ctx context.Context, name, version, description string) (*entities.ModelVersion, *contract.Error) + TransitionModelVersionStage( + ctx context.Context, name, version string, stage models.ModelVersionStage, archiveExistingVersions bool, + ) (*entities.ModelVersion, *contract.Error) } diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index ca0f2489..b9903bed 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -88,6 +88,17 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Post("/mlflow/model-versions/transition-stage", func(ctx *fiber.Ctx) error { + input := &protos.TransitionModelVersionStage{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.TransitionModelVersionStage(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) app.Delete("/mlflow/model-versions/delete", func(ctx *fiber.Ctx) error { input := &protos.DeleteModelVersion{} if err := parser.ParseBody(ctx, input); err != nil {