Skip to content

Commit

Permalink
Merge branch 'main' into dsuhinin/move-get-model-version-endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Software Developer <[email protected]>
  • Loading branch information
dsuhinin authored Jan 7, 2025
2 parents 02c83e3 + 22619bd commit e350e44
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 4 deletions.
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
"getLatestVersions",
// "createModelVersion",
"updateModelVersion",
// "transitionModelVersionStage",
"transitionModelVersionStage",
"deleteModelVersion",
"getModelVersion",
// "searchModelVersions",
Expand Down
12 changes: 12 additions & 0 deletions mlflow_go/store/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
GetModelVersion,
GetRegisteredModel,
RenameRegisteredModel,
TransitionModelVersionStage,
UpdateModelVersion,
UpdateRegisteredModel,
)
Expand Down Expand Up @@ -100,6 +101,17 @@ def update_model_version(self, name, version, description=None):
request = UpdateModelVersion(name=name, version=str(version), description=description)
self.service.call_endpoint(get_lib().ModelRegistryServiceUpdateModelVersion, request)

def transition_model_version_stage(self, name, version, stage, archive_existing_versions):
request = TransitionModelVersionStage(
name=name,
version=str(version),
stage=stage,
archive_existing_versions=archive_existing_versions,
)
self.service.call_endpoint(
get_lib().ModelRegistryServiceTransitionModelVersionStage, request
)


def ModelRegistryStore(cls):
return type(cls.__name__, (_ModelRegistryStore, cls), {})
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pkg/lib/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 36 additions & 0 deletions pkg/model_registry/service/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package service
import (
"context"
"strconv"
"fmt"
"strings"

"github.com/mlflow/mlflow-go/pkg/contract"
"github.com/mlflow/mlflow-go/pkg/model_registry/store/sql/models"
"github.com/mlflow/mlflow-go/pkg/protos"
)

Expand Down Expand Up @@ -122,3 +125,36 @@ func (m *ModelRegistryService) UpdateModelVersion(
ModelVersion: modelVersion.ToProto(),
}, nil
}

func (m *ModelRegistryService) TransitionModelVersionStage(
ctx context.Context, input *protos.TransitionModelVersionStage,
) (*protos.TransitionModelVersionStage_Response, *contract.Error) {
stage, ok := models.CanonicalMapping[strings.ToLower(input.GetStage())]
if !ok {
return nil, contract.NewError(
protos.ErrorCode_INVALID_PARAMETER_VALUE,
fmt.Sprintf(
"Invalid Model Version stage: unknown. Value must be one of %s, %s, %s, %s.",
models.ModelVersionStageNone,
models.ModelVersionStageStaging,
models.ModelVersionStageProduction,
models.ModelVersionStageArchived,
),
)
}

modelVersion, err := m.store.TransitionModelVersionStage(
ctx,
input.GetName(),
input.GetVersion(),
stage,
input.GetArchiveExistingVersions(),
)
if err != nil {
return nil, err
}

return &protos.TransitionModelVersionStage_Response{
ModelVersion: modelVersion.ToProto(),
}, nil
}
82 changes: 81 additions & 1 deletion pkg/model_registry/store/sql/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (m *ModelRegistrySQLStore) GetLatestVersions(
for idx, stage := range stages {
stages[idx] = strings.ToLower(stage)
if canonicalStage, ok := models.CanonicalMapping[stages[idx]]; ok {
stages[idx] = canonicalStage
stages[idx] = canonicalStage.String()

continue
}
Expand Down Expand Up @@ -402,3 +402,83 @@ func (m *ModelRegistrySQLStore) UpdateModelVersion(

return modelVersion, nil
}

//nolint:funlen,cyclop
func (m *ModelRegistrySQLStore) TransitionModelVersionStage(
ctx context.Context, name, version string, stage models.ModelVersionStage, archiveExistingVersions bool,
) (*entities.ModelVersion, *contract.Error) {
isActiveStage := false
if _, ok := models.DefaultStagesForGetLatestVersions[strings.ToLower(stage.String())]; ok {
isActiveStage = true
}

if archiveExistingVersions && !isActiveStage {
return nil, contract.NewError(
protos.ErrorCode_INVALID_PARAMETER_VALUE,
fmt.Sprintf(
`Model version transition cannot archive existing model versions because '%s' is not an Active stage.
Valid stages are %s`,
stage, models.AllModelVersionStages(),
),
)
}

modelVersion, err := m.GetModelVersion(ctx, name, version)

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Go (macos-latest)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / build / Build Python wheel (darwin, amd64)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (macos-latest, 3.9)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / lint / Lint

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / lint / Lint

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (macos-latest, 3.10)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Go (ubuntu-latest)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (macos-latest, 3.11)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (macos-latest, 3.12)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (ubuntu-latest, 3.9)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (ubuntu-latest, 3.10)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (ubuntu-latest, 3.11)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (ubuntu-latest, 3.12)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (windows-latest, 3.9)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (windows-latest, 3.10)

not enough arguments in call to m.GetModelVersion

Check failure on line 426 in pkg/model_registry/store/sql/model_versions.go

View workflow job for this annotation

GitHub Actions / test / Test Python (windows-latest, 3.11)

not enough arguments in call to m.GetModelVersion
if err != nil {
return nil, err
}

registeredModel, err := m.GetRegisteredModel(ctx, name)
if err != nil {
return nil, err
}

if err := m.db.Transaction(func(transaction *gorm.DB) error {
lastUpdatedTime := time.Now().UnixMilli()
if err := transaction.Model(
&models.RegisteredModel{},
).Where(
"name = ?", registeredModel.Name,
).Updates(&models.RegisteredModel{
LastUpdatedTime: lastUpdatedTime,
}).Error; err != nil {
return err
}

if err := transaction.Model(
&models.ModelVersion{},
).Where(
"name = ?", modelVersion.Name,
).Where(
"version = ?", modelVersion.Version,
).Updates(&models.ModelVersion{
CurrentStage: stage,
LastUpdatedTime: lastUpdatedTime,
}).Error; err != nil {
return err
}

if archiveExistingVersions {
if err := transaction.Where(
"name = ?", name,
).Where(
"version != ?", version,
).Where(
"current_stage = ?", stage,
).Updates(&models.ModelVersion{
CurrentStage: models.ModelVersionStageArchived,
LastUpdatedTime: lastUpdatedTime,
}).Error; err != nil {
return err
}
}

return nil
}); err != nil {
return nil, contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR, "error transitioning model version stage", err,
)
}

return modelVersion, nil
}
9 changes: 7 additions & 2 deletions pkg/model_registry/store/sql/models/model_version_stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,23 @@ const (
ModelVersionStageArchived = "Archived"
)

var CanonicalMapping = map[string]string{
var CanonicalMapping = map[string]ModelVersionStage{
strings.ToLower(ModelVersionStageNone): ModelVersionStageNone,
strings.ToLower(ModelVersionStageStaging): ModelVersionStageStaging,
strings.ToLower(ModelVersionStageProduction): ModelVersionStageProduction,
strings.ToLower(ModelVersionStageArchived): ModelVersionStageArchived,
}

var DefaultStagesForGetLatestVersions = map[string]ModelVersionStage{
strings.ToLower(ModelVersionStageStaging): ModelVersionStageStaging,
strings.ToLower(ModelVersionStageProduction): ModelVersionStageProduction,
}

func AllModelVersionStages() string {
pairs := make([]string, 0, len(CanonicalMapping))

for _, v := range CanonicalMapping {
pairs = append(pairs, v)
pairs = append(pairs, v.String())
}

return strings.Join(pairs, ",")
Expand Down
4 changes: 4 additions & 0 deletions pkg/model_registry/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/mlflow/mlflow-go/pkg/contract"
"github.com/mlflow/mlflow-go/pkg/entities"
"github.com/mlflow/mlflow-go/pkg/model_registry/store/sql/models"
"github.com/mlflow/mlflow-go/pkg/protos"
)

Expand All @@ -18,4 +19,7 @@ type ModelRegistryStore interface {
GetModelVersion(ctx context.Context, name, version string, eager bool) (*entities.ModelVersion, *contract.Error)
DeleteModelVersion(ctx context.Context, name, version string) *contract.Error
UpdateModelVersion(ctx context.Context, name, version, description string) (*entities.ModelVersion, *contract.Error)
TransitionModelVersionStage(
ctx context.Context, name, version string, stage models.ModelVersionStage, archiveExistingVersions bool,
) (*entities.ModelVersion, *contract.Error)
}
11 changes: 11 additions & 0 deletions pkg/server/routes/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e350e44

Please sign in to comment.