Skip to content

Commit

Permalink
move DELETE /mlflow/model-versions/delete endpoint. (#96)
Browse files Browse the repository at this point in the history
Signed-off-by: Software Developer <[email protected]>
  • Loading branch information
dsuhinin authored Dec 11, 2024
1 parent 21dfd2e commit ab5854f
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 12 deletions.
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
// "createModelVersion",
// "updateModelVersion",
// "transitionModelVersionStage",
// "deleteModelVersion",
"deleteModelVersion",
// "getModelVersion",
// "searchModelVersions",
// "getModelVersionDownloadUri",
Expand Down
7 changes: 6 additions & 1 deletion mlflow_go/store/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from mlflow.entities.model_registry import ModelVersion, RegisteredModel
from mlflow.protos.model_registry_pb2 import (
DeleteModelVersion,
DeleteRegisteredModel,
GetLatestVersions,
GetRegisteredModel,
Expand Down Expand Up @@ -71,14 +72,18 @@ def get_registered_model(self, name):
if entity.description == "":
entity.description = None

# during convertion to proto, `version` value became a `string` value.
# during conversion to proto, `version` value became a `string` value.
# convert it back to `int` value again to satisfy all the Python tests and related logic.
for key in entity.aliases:
if entity.aliases[key].isnumeric():
entity.aliases[key] = int(entity.aliases[key])

return entity

def delete_model_version(self, name, version):
request = DeleteModelVersion(name=name, version=str(version))
self.service.call_endpoint(get_lib().ModelRegistryServiceDeleteModelVersion, request)


def ModelRegistryStore(cls):
return type(cls.__name__, (_ModelRegistryStore, cls), {})
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pkg/lib/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions pkg/model_registry/service/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,13 @@ func (m *ModelRegistryService) GetRegisteredModel(
RegisteredModel: registeredModel.ToProto(),
}, nil
}

func (m *ModelRegistryService) DeleteModelVersion(
ctx context.Context, input *protos.DeleteModelVersion,
) (*protos.DeleteModelVersion_Response, *contract.Error) {
if err := m.store.DeleteModelVersion(ctx, input.GetName(), input.GetVersion()); err != nil {
return nil, err
}

return &protos.DeleteModelVersion_Response{}, nil
}
93 changes: 92 additions & 1 deletion pkg/model_registry/store/sql/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (m *ModelRegistrySQLStore) GetRegisteredModel(
//nolint:perfsprint
return nil, contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR,
fmt.Sprintf("failed to get experiment by name %s", name),
fmt.Sprintf("failed to get Registered Model by name %s", name),
err,
)
}
Expand Down Expand Up @@ -262,3 +262,94 @@ func (m *ModelRegistrySQLStore) DeleteRegisteredModel(ctx context.Context, name

return nil
}

func (m *ModelRegistrySQLStore) GetModelVersion(
ctx context.Context, name, version string,
) (*entities.ModelVersion, *contract.Error) {
var modelVersion models.ModelVersion
if err := m.db.WithContext(
ctx,
).Where(
"name = ?", name,
).Where(
"version = ?", version,
).Where(
"current_stage != ?", models.StageDeletedInternal,
).First(
&modelVersion,
).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, contract.NewError(
protos.ErrorCode_RESOURCE_DOES_NOT_EXIST,
fmt.Sprintf("Model Version (name=%s, version=%s) not found", name, version),
)
}

return nil, contract.NewErrorWith(
protos.ErrorCode_INTERNAL_ERROR,
fmt.Sprintf("failed to get Model Version by name %s and version %s", name, version),
err,
)
}

return modelVersion.ToEntity(), nil
}

func (m *ModelRegistrySQLStore) DeleteModelVersion(ctx context.Context, name, version string) *contract.Error {
registeredModel, err := m.GetRegisteredModel(ctx, name)
if err != nil {
return err
}

modelVersion, err := m.GetModelVersion(ctx, name, version)
if err != nil {
return err
}

if err := m.db.WithContext(ctx).Transaction(func(transaction *gorm.DB) error {
if err := transaction.Model(
&models.RegisteredModel{},
).Where(
"name = ?", registeredModel.Name,
).Updates(&models.RegisteredModel{
LastUpdatedTime: time.Now().UnixMilli(),
}).Error; err != nil {
return err
}

if err := transaction.Where(
"name = ?", registeredModel.Name,
).Where(
"version = ?", version,
).Delete(
&models.RegisteredModelAlias{},
).Error; err != nil {
return err
}

if err := transaction.Model(
&models.ModelVersion{},
).Where(
"name = ?", modelVersion.Name,
).Where(
"version = ?", modelVersion.Version,
).Updates(&models.ModelVersion{
RunID: "REDACTED-RUN-ID",
UserID: sql.NullString{Valid: true},
Source: "REDACTED-SOURCE-PATH",
RunLink: "REDACTED-RUN-LINK",
CurrentStage: models.StageDeletedInternal,
Description: sql.NullString{Valid: true},
StatusMessage: sql.NullString{Valid: true},
LastUpdatedTime: time.Now().UnixMilli(),
}).Error; err != nil {
return err
}

return nil
}); err != nil {
return contract.NewErrorWith(protos.ErrorCode_INTERNAL_ERROR, "error deleting model version", err)
}

return nil
}
20 changes: 11 additions & 9 deletions pkg/model_registry/store/sql/models/model_versions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package models

import (
"database/sql"

"github.com/mlflow/mlflow-go/pkg/entities"
"github.com/mlflow/mlflow-go/pkg/protos"
"github.com/mlflow/mlflow-go/pkg/utils"
Expand All @@ -14,13 +16,13 @@ type ModelVersion struct {
Version int32 `db:"version" gorm:"column:version;primaryKey"`
CreationTime int64 `db:"creation_time" gorm:"column:creation_time"`
LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"`
Description string `db:"description" gorm:"column:description"`
UserID string `db:"user_id" gorm:"column:user_id"`
Description sql.NullString `db:"description" gorm:"column:description"`
UserID sql.NullString `db:"user_id" gorm:"column:user_id"`
CurrentStage ModelVersionStage `db:"current_stage" gorm:"column:current_stage"`
Source string `db:"source" gorm:"column:source"`
RunID string `db:"run_id" gorm:"column:run_id"`
Status string `db:"status" gorm:"column:status"`
StatusMessage string `db:"status_message" gorm:"column:status_message"`
StatusMessage sql.NullString `db:"status_message" gorm:"column:status_message"`
RunLink string `db:"run_link" gorm:"column:run_link"`
StorageLocation string `db:"storage_location" gorm:"column:storage_location"`
}
Expand All @@ -38,13 +40,13 @@ func (mv ModelVersion) ToProto() *protos.ModelVersion {
Version: utils.ConvertInt32PointerToStringPointer(&mv.Version),
CreationTimestamp: &mv.CreationTime,
LastUpdatedTimestamp: &mv.LastUpdatedTime,
UserId: &mv.UserID,
UserId: &mv.UserID.String,
CurrentStage: utils.PtrTo(mv.CurrentStage.String()),
Description: &mv.Description,
Description: &mv.Description.String,
Source: &mv.Source,
RunId: &mv.RunID,
Status: status,
StatusMessage: &mv.StatusMessage,
StatusMessage: &mv.StatusMessage.String,
RunLink: &mv.RunLink,
}
}
Expand All @@ -55,13 +57,13 @@ func (mv ModelVersion) ToEntity() *entities.ModelVersion {
Version: mv.Version,
CreationTime: mv.CreationTime,
LastUpdatedTime: mv.LastUpdatedTime,
Description: mv.Description,
UserID: mv.UserID,
Description: mv.Description.String,
UserID: mv.UserID.String,
CurrentStage: mv.CurrentStage.String(),
Source: mv.Source,
RunID: mv.RunID,
Status: mv.Status,
StatusMessage: mv.StatusMessage,
StatusMessage: mv.StatusMessage.String,
RunLink: mv.RunLink,
StorageLocation: mv.StorageLocation,
}
Expand Down
1 change: 1 addition & 0 deletions pkg/model_registry/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ type ModelRegistryStore interface {
UpdateRegisteredModel(ctx context.Context, name, description string) (*entities.RegisteredModel, *contract.Error)
RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error)
DeleteRegisteredModel(ctx context.Context, name string) *contract.Error
DeleteModelVersion(ctx context.Context, name, version string) *contract.Error
}
11 changes: 11 additions & 0 deletions pkg/server/routes/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ab5854f

Please sign in to comment.