Skip to content

Commit

Permalink
Move GET /mlflow/model-versions/get endpoint. (#99)
Browse files Browse the repository at this point in the history
Signed-off-by: Software Developer <[email protected]>
Signed-off-by: dsuhinin <[email protected]>
Signed-off-by: Software Developer <[email protected]>
Co-authored-by: DSuhinin <[email protected]>
  • Loading branch information
dsuhinin and DSuhinin authored Jan 7, 2025
1 parent 22619bd commit 5a9605f
Show file tree
Hide file tree
Showing 28 changed files with 885 additions and 570 deletions.
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
"updateModelVersion",
"transitionModelVersionStage",
"deleteModelVersion",
// "getModelVersion",
"getModelVersion",
// "searchModelVersions",
// "getModelVersionDownloadUri",
// "setRegisteredModelTag",
Expand Down
Empty file added mlflow.db
Empty file.
11 changes: 11 additions & 0 deletions mlflow_go/store/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DeleteModelVersion,
DeleteRegisteredModel,
GetLatestVersions,
GetModelVersion,
GetRegisteredModel,
RenameRegisteredModel,
TransitionModelVersionStage,
Expand Down Expand Up @@ -86,6 +87,16 @@ def delete_model_version(self, name, version):
request = DeleteModelVersion(name=name, version=str(version))
self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteModelVersion, request)

def get_model_version(self, name, version):
request = GetModelVersion(name=name, version=str(version))
response = self.service.call_endpoint(
get_lib().ModelRegistryServiceGetModelVersion, request
)
entity = ModelVersion.from_proto(response.model_version)
if entity.description == "":
entity.description = None
return entity

def update_model_version(self, name, version, description=None):
request = UpdateModelVersion(name=name, version=str(version), description=description)
self.service.call_endpoint(get_lib().ModelRegistryServiceUpdateModelVersion, request)
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 32 additions & 3 deletions pkg/entities/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,40 @@ type ModelVersion struct {
StatusMessage string
RunLink string
StorageLocation string
Tags []*ModelVersionTag
Aliases []*RegisteredModelAlias
}

func (mv ModelVersion) ToProto() *protos.ModelVersion {
return &protos.ModelVersion{
Version: utils.PtrTo(strconv.Itoa(int(mv.Version))),
CurrentStage: utils.PtrTo(mv.CurrentStage),
modelVersion := protos.ModelVersion{
Name: utils.PtrTo(mv.Name),
Version: utils.PtrTo(strconv.Itoa(int(mv.Version))),
CurrentStage: utils.PtrTo(mv.CurrentStage),
CreationTimestamp: utils.PtrTo(mv.CreationTime),
LastUpdatedTimestamp: utils.PtrTo(mv.LastUpdatedTime),
Description: utils.PtrTo(mv.Description),
UserId: utils.PtrTo(mv.UserID),
Source: utils.PtrTo(mv.Source),
Status: utils.PtrTo(protos.ModelVersionStatus(protos.ModelVersionStatus_value[mv.Status])),
Tags: make([]*protos.ModelVersionTag, 0, len(mv.Tags)),
RunLink: utils.PtrTo(mv.RunLink),
}

if mv.RunID != "" {
modelVersion.RunId = utils.PtrTo(mv.RunID)
}

if mv.StatusMessage != "" {
modelVersion.StatusMessage = utils.PtrTo(mv.StatusMessage)
}

for _, tag := range mv.Tags {
modelVersion.Tags = append(modelVersion.Tags, tag.ToProto())
}

for _, alias := range mv.Aliases {
modelVersion.Aliases = append(modelVersion.Aliases, alias.Alias)
}

return &modelVersion
}
18 changes: 18 additions & 0 deletions pkg/entities/model_version_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package entities

import (
"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
)

type ModelVersionTag struct {
Key string
Value string
}

func (mvt ModelVersionTag) ToProto() *protos.ModelVersionTag {
return &protos.ModelVersionTag{
Key: utils.PtrTo(mvt.Key),
Value: utils.PtrTo(mvt.Value),
}
}
8 changes: 8 additions & 0 deletions pkg/lib/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions pkg/model_registry/service/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"strconv"
"strings"

"github.com/mlflow/mlflow-go/pkg/contract"
Expand Down Expand Up @@ -90,6 +91,28 @@ func (m *ModelRegistryService) DeleteModelVersion(
return &protos.DeleteModelVersion_Response{}, nil
}

func (m *ModelRegistryService) GetModelVersion(
ctx context.Context, input *protos.GetModelVersion,
) (*protos.GetModelVersion_Response, *contract.Error) {
// by some strange reason GetModelVersion.Version has a string type so we can't apply our validation,
// that's why such a custom validation exists to satisfy Python tests.
version := input.GetVersion()
if _, err := strconv.Atoi(version); err != nil {
return nil, contract.NewErrorWith(
protos.ErrorCode_INVALID_PARAMETER_VALUE, "Model version must be an integer", err,
)
}

modelVersion, err := m.store.GetModelVersion(ctx, input.GetName(), version, true)
if err != nil {
return nil, err
}

return &protos.GetModelVersion_Response{
ModelVersion: modelVersion.ToProto(),
}, nil
}

func (m *ModelRegistryService) UpdateModelVersion(
ctx context.Context, input *protos.UpdateModelVersion,
) (*protos.UpdateModelVersion_Response, *contract.Error) {
Expand Down
37 changes: 31 additions & 6 deletions pkg/model_registry/store/sql/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,18 +264,26 @@ func (m *ModelRegistrySQLStore) DeleteRegisteredModel(ctx context.Context, name
}

func (m *ModelRegistrySQLStore) GetModelVersion(
ctx context.Context, name, version string,
ctx context.Context, name, version string, eager bool,
) (*entities.ModelVersion, *contract.Error) {
var modelVersion models.ModelVersion
if err := m.db.WithContext(

query := m.db.WithContext(
ctx,
).Where(
"name = ?", name,
).Where(
"version = ?", version,
).Where(
"current_stage != ?", models.StageDeletedInternal,
).First(
)

// preload Tags only by demand.
if eager {
query = query.Preload("Tags")
}

if err := query.First(
&modelVersion,
).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
Expand All @@ -292,6 +300,23 @@ func (m *ModelRegistrySQLStore) GetModelVersion(
)
}

var registeredModelAliases []models.RegisteredModelAlias
if err := m.db.WithContext(ctx).Where(
"name = ?", modelVersion.Name,
).Where(
"version = ?", modelVersion.Version,
).Find(
&registeredModelAliases,
).Error; err != nil {
return nil, contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR,
fmt.Sprintf("failed to get Registered Model Aliases by name %s and version %s", name, version),
err,
)
}

modelVersion.Aliases = append(modelVersion.Aliases, registeredModelAliases...)

return modelVersion.ToEntity(), nil
}

Expand All @@ -301,7 +326,7 @@ func (m *ModelRegistrySQLStore) DeleteModelVersion(ctx context.Context, name, ve
return err
}

modelVersion, err := m.GetModelVersion(ctx, name, version)
modelVersion, err := m.GetModelVersion(ctx, name, version, false)
if err != nil {
return err
}
Expand Down Expand Up @@ -357,7 +382,7 @@ func (m *ModelRegistrySQLStore) DeleteModelVersion(ctx context.Context, name, ve
func (m *ModelRegistrySQLStore) UpdateModelVersion(
ctx context.Context, name, version, description string,
) (*entities.ModelVersion, *contract.Error) {
modelVersion, err := m.GetModelVersion(ctx, name, version)
modelVersion, err := m.GetModelVersion(ctx, name, version, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -398,7 +423,7 @@ Valid stages are %s`,
)
}

modelVersion, err := m.GetModelVersion(ctx, name, version)
modelVersion, err := m.GetModelVersion(ctx, name, version, false)
if err != nil {
return nil, err
}
Expand Down
17 changes: 13 additions & 4 deletions pkg/model_registry/store/sql/models/model_version_tags.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
package models

import "github.com/mlflow/mlflow-go/pkg/entities"

// ModelVersionTag mapped from table <model_version_tags>.
//
//revive:disable:exported
type ModelVersionTag struct {
Key string `db:"key" gorm:"column:key;primaryKey"`
Value string `db:"value" gorm:"column:value"`
Name string `db:"name" gorm:"column:name;primaryKey"`
Version int32 `db:"version" gorm:"column:version;primaryKey"`
Key string `gorm:"column:key;primaryKey"`
Value string `gorm:"column:value"`
Name string `gorm:"column:name;primaryKey"`
Version int32 `gorm:"column:version;primaryKey"`
}

func (mvt ModelVersionTag) ToEntity() *entities.ModelVersionTag {
return &entities.ModelVersionTag{
Key: mvt.Key,
Value: mvt.Value,
}
}
46 changes: 30 additions & 16 deletions pkg/model_registry/store/sql/models/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,29 @@ import (
"github.com/mlflow/mlflow-go/pkg/utils"
)

const StageDeletedInternal = "Deleted_Internal"

// ModelVersion mapped from table <model_versions>.
//
//revive:disable:exported
type ModelVersion struct {
Name string `db:"name" gorm:"column:name;primaryKey"`
Version int32 `db:"version" gorm:"column:version;primaryKey"`
CreationTime int64 `db:"creation_time" gorm:"column:creation_time"`
LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"`
Description sql.NullString `db:"description" gorm:"column:description"`
UserID sql.NullString `db:"user_id" gorm:"column:user_id"`
CurrentStage ModelVersionStage `db:"current_stage" gorm:"column:current_stage"`
Source string `db:"source" gorm:"column:source"`
RunID string `db:"run_id" gorm:"column:run_id"`
Status string `db:"status" gorm:"column:status"`
StatusMessage sql.NullString `db:"status_message" gorm:"column:status_message"`
RunLink string `db:"run_link" gorm:"column:run_link"`
StorageLocation string `db:"storage_location" gorm:"column:storage_location"`
Name string `gorm:"column:name;primaryKey"`
Version int32 `gorm:"column:version;primaryKey"`
CreationTime int64 `gorm:"column:creation_time"`
LastUpdatedTime int64 `gorm:"column:last_updated_time"`
Description sql.NullString `gorm:"column:description"`
UserID sql.NullString `gorm:"column:user_id"`
CurrentStage ModelVersionStage `gorm:"column:current_stage"`
Source string `gorm:"column:source"`
RunID string `gorm:"column:run_id"`
Status string `gorm:"column:status"`
StatusMessage sql.NullString `gorm:"column:status_message"`
RunLink string `gorm:"column:run_link"`
StorageLocation string `gorm:"column:storage_location"`
Tags []ModelVersionTag `gorm:"foreignKey:Name,Version"`
Aliases []RegisteredModelAlias `gorm:"-"`
}

const StageDeletedInternal = "Deleted_Internal"

func (mv ModelVersion) ToProto() *protos.ModelVersion {
var status *protos.ModelVersionStatus
if s, ok := protos.ModelVersionStatus_value[mv.Status]; ok {
Expand All @@ -52,7 +54,7 @@ func (mv ModelVersion) ToProto() *protos.ModelVersion {
}

func (mv ModelVersion) ToEntity() *entities.ModelVersion {
return &entities.ModelVersion{
modelVersion := entities.ModelVersion{
Name: mv.Name,
Version: mv.Version,
CreationTime: mv.CreationTime,
Expand All @@ -66,5 +68,17 @@ func (mv ModelVersion) ToEntity() *entities.ModelVersion {
StatusMessage: mv.StatusMessage.String,
RunLink: mv.RunLink,
StorageLocation: mv.StorageLocation,
Tags: make([]*entities.ModelVersionTag, 0, len(mv.Tags)),
Aliases: make([]*entities.RegisteredModelAlias, 0, len(mv.Aliases)),
}

for _, tag := range mv.Tags {
modelVersion.Tags = append(modelVersion.Tags, tag.ToEntity())
}

for _, alias := range mv.Aliases {
modelVersion.Aliases = append(modelVersion.Aliases, alias.ToEntity())
}

return &modelVersion
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (

// RegisteredModelAlias mapped from table <registered_model_aliases>.
type RegisteredModelAlias struct {
Name string `db:"name" gorm:"column:name;primaryKey"`
Alias string `db:"alias" gorm:"column:alias;primaryKey"`
Version int32 `db:"version" gorm:"column:version;not null"`
Name string `gorm:"column:name;primaryKey"`
Alias string `gorm:"column:alias;primaryKey"`
Version int32 `gorm:"column:version;not null"`
}

func (a RegisteredModelAlias) ToEntity() *entities.RegisteredModelAlias {
Expand Down
1 change: 1 addition & 0 deletions pkg/model_registry/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type ModelRegistryStore interface {
UpdateRegisteredModel(ctx context.Context, name, description string) (*entities.RegisteredModel, *contract.Error)
RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error)
DeleteRegisteredModel(ctx context.Context, name string) *contract.Error
GetModelVersion(ctx context.Context, name, version string, eager bool) (*entities.ModelVersion, *contract.Error)
DeleteModelVersion(ctx context.Context, name, version string) *contract.Error
UpdateModelVersion(ctx context.Context, name, version, description string) (*entities.ModelVersion, *contract.Error)
TransitionModelVersionStage(
Expand Down
11 changes: 11 additions & 0 deletions pkg/server/routes/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/tracking/store/sql/models/alembic_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package models

// AlembicVersion mapped from table <alembic_version>.
type AlembicVersion struct {
VersionNum *string `db:"version_num" gorm:"column:version_num;primaryKey"`
VersionNum *string `gorm:"column:version_num;primaryKey"`
}

// TableName AlembicVersion's table name.
Expand Down
16 changes: 8 additions & 8 deletions pkg/tracking/store/sql/models/datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import (

// Dataset mapped from table <datasets>.
type Dataset struct {
ID string `db:"dataset_uuid" gorm:"column:dataset_uuid;not null"`
ExperimentID int32 `db:"experiment_id" gorm:"column:experiment_id;primaryKey"`
Name string `db:"name" gorm:"column:name;primaryKey"`
Digest string `db:"digest" gorm:"column:digest;primaryKey"`
SourceType string `db:"dataset_source_type" gorm:"column:dataset_source_type;not null"`
Source string `db:"dataset_source" gorm:"column:dataset_source;not null"`
Schema string `db:"dataset_schema" gorm:"column:dataset_schema"`
Profile string `db:"dataset_profile" gorm:"column:dataset_profile"`
ID string `gorm:"column:dataset_uuid;not null"`
ExperimentID int32 `gorm:"column:experiment_id;primaryKey"`
Name string `gorm:"column:name;primaryKey"`
Digest string `gorm:"column:digest;primaryKey"`
SourceType string `gorm:"column:dataset_source_type;not null"`
Source string `gorm:"column:dataset_source;not null"`
Schema string `gorm:"column:dataset_schema"`
Profile string `gorm:"column:dataset_profile"`
}

func (d *Dataset) ToEntity() *entities.Dataset {
Expand Down
Loading

0 comments on commit 5a9605f

Please sign in to comment.