Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move POST /mlflow/model-versions/create endpoint. #93

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
"getRegisteredModel",
// "searchRegisteredModels",
"getLatestVersions",
// "createModelVersion",
"createModelVersion",
// "updateModelVersion",
// "transitionModelVersionStage",
// "deleteModelVersion",
Expand Down
3 changes: 3 additions & 0 deletions magefiles/generate/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,7 @@ var validations = map[string]string{
"Dataset_Schema": "max:1048575",
"InputTag_Key": "required,max=255",
"InputTag_Value": "required,max=500",
"CreateModelVersion_Name": "required",
"ModelVersionTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique",
"ModelVersionTag_Value": "required,truncate=5000",
}
26 changes: 25 additions & 1 deletion mlflow_go/store/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

from mlflow.entities.model_registry import ModelVersion, RegisteredModel
from mlflow.protos.model_registry_pb2 import (
CreateModelVersion,
DeleteRegisteredModel,
GetLatestVersions,
GetRegisteredModel,
ModelVersionTag,
RenameRegisteredModel,
UpdateRegisteredModel,
)
Expand Down Expand Up @@ -33,6 +35,28 @@ def __del__(self):
if hasattr(self, "service"):
get_lib().DestroyModelRegistryService(self.service.id)

def create_model_version(
self,
name,
source,
run_id=None,
tags=None,
run_link=None,
description=None,
):
request = CreateModelVersion(
name=name,
source=source,
run_id=run_id,
tags=[ModelVersionTag(key=tag.key, value=tag.value) for tag in tags] if tags else [],
run_link=run_link,
description=description,
)
response = self.service.call_endpoint(
get_lib().ModelRegistryServiceCreateModelVersion, request
)
return ModelVersion.from_proto(response.model_version)

def get_latest_versions(self, name, stages=None):
request = GetLatestVersions(
name=name,
Expand Down Expand Up @@ -71,7 +95,7 @@ def get_registered_model(self, name):
if entity.description == "":
entity.description = None

# during convertion to proto, `version` value became a `string` value.
# during conversion to proto, `version` value became a `string` value.
# convert it back to `int` value again to satisfy all the Python tests and related logic.
for key in entity.aliases:
if entity.aliases[key].isnumeric():
Expand Down
1 change: 1 addition & 0 deletions pkg/contract/service/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions pkg/entities/model_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package entities

type ModelTag struct {
Key string
Value string
}
27 changes: 23 additions & 4 deletions pkg/entities/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,35 @@ type ModelVersion struct {
UserID string
CurrentStage string
Source string
RunID string
RunID *string
Status string
StatusMessage string
RunLink string
StorageLocation string
Tags []*ModelVersionTag
Aliases []string
}

func (mv ModelVersion) ToProto() *protos.ModelVersion {
return &protos.ModelVersion{
Version: utils.PtrTo(strconv.Itoa(int(mv.Version))),
CurrentStage: utils.PtrTo(mv.CurrentStage),
modelVersion := protos.ModelVersion{
Name: &mv.Name,
Version: utils.PtrTo(strconv.Itoa(int(mv.Version))),
Description: &mv.Description,
CurrentStage: &mv.CurrentStage,
CreationTimestamp: &mv.CreationTime,
LastUpdatedTimestamp: &mv.LastUpdatedTime,
UserId: &mv.UserID,
Source: &mv.Source,
RunId: mv.RunID,
Status: utils.PtrTo(protos.ModelVersionStatus(protos.ModelVersionStatus_value[mv.Status])),
StatusMessage: &mv.StatusMessage,
RunLink: &mv.RunLink,
Aliases: mv.Aliases,
}

for _, tag := range mv.Tags {
modelVersion.Tags = append(modelVersion.Tags, tag.ToProto())
}

return &modelVersion
}
17 changes: 17 additions & 0 deletions pkg/entities/model_version_tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package entities

import "github.com/mlflow/mlflow-go/pkg/protos"

type ModelVersionTag struct {
Key string
Value string
Name string
Version int32
}

func (mvt ModelVersionTag) ToProto() *protos.ModelVersionTag {
return &protos.ModelVersionTag{
Key: &mvt.Key,
Value: &mvt.Value,
}
}
8 changes: 8 additions & 0 deletions pkg/lib/model_registry.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions pkg/model_registry/service/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,39 @@ import (
"context"

"github.com/mlflow/mlflow-go/pkg/contract"
"github.com/mlflow/mlflow-go/pkg/entities"
"github.com/mlflow/mlflow-go/pkg/protos"
)

func (m *ModelRegistryService) CreateModelVersion(
ctx context.Context, input *protos.CreateModelVersion,
) (*protos.CreateModelVersion_Response, *contract.Error) {
tags := make([]entities.ModelTag, 0, len(input.Tags))
for _, tag := range input.Tags {
tags = append(tags, entities.ModelTag{
Key: tag.GetKey(),
Value: tag.GetValue(),
})
}

modelVersion, err := m.store.CreateModelVersion(
ctx,
input.GetName(),
input.GetSource(),
input.GetRunId(),
tags,
input.GetRunLink(),
input.GetDescription(),
)
if err != nil {
return nil, err
}

return &protos.CreateModelVersion_Response{
ModelVersion: modelVersion.ToProto(),
}, nil
}

func (m *ModelRegistryService) GetLatestVersions(
ctx context.Context, input *protos.GetLatestVersions,
) (*protos.GetLatestVersions_Response, *contract.Error) {
Expand Down
102 changes: 102 additions & 0 deletions pkg/model_registry/store/sql/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package sql

import (
"fmt"
"net/url"
"strconv"
"strings"

"github.com/mlflow/mlflow-go/pkg/entities"
)

const (
ModelsURISuffixLatest = "latest"
)

//nolint
var ErrImproperModelURI = func(uri string) error {
return fmt.Errorf(`
Not a proper models:/ URI: %s. "Models URIs must be of the form 'models:/model_name/suffix' or
'models:/model_name@alias' where suffix is a model version, stage, or the string latest
and where alias is a registered model alias. Only one of suffix or alias can be defined at a time."`,
uri,
)
}

type ParsedModelURI struct {
Name string
Stage string
Alias string
Version string
}

func GetModelNextVersion(registeredModel *entities.RegisteredModel) int32 {
if len(registeredModel.Versions) == 0 {
return 1
}

maxVersion := int32(0)
for _, version := range registeredModel.Versions {
if version.Version > maxVersion {
maxVersion = version.Version
}
}

return maxVersion + 1
}

//nolint
func ParseModelURI(uri string) (*ParsedModelURI, error) {
parsedURI, err := url.Parse(uri)
if err != nil {
return nil, err
}

if parsedURI.Scheme != "models" {
return nil, ErrImproperModelURI(uri)
}

if !strings.HasSuffix(parsedURI.Path, "/") || len(parsedURI.Path) <= 1 {
return nil, ErrImproperModelURI(uri)
}

parts := strings.Split(strings.TrimLeft(parsedURI.Path, "/"), "/")
if len(parts) > 2 || strings.Trim(parts[0], " ") == "" {
return nil, ErrImproperModelURI(uri)
}

if len(parts) == 2 {
name, suffix := parts[0], parts[1]
if strings.Trim(suffix, " ") == "" {
return nil, ErrImproperModelURI(uri)
}
// The suffix is a specific version, e.g. "models:/AdsModel1/123"
if _, err := strconv.Atoi(suffix); err == nil {
return &ParsedModelURI{
Name: name,
Version: suffix,
}, nil
}
// The suffix is the 'latest' string (case insensitive), e.g. "models:/AdsModel1/latest"
if (strings.ToLower(suffix)) == ModelsURISuffixLatest {
return &ParsedModelURI{
Name: name,
}, nil
}
// The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production"
return &ParsedModelURI{
Name: name,
Stage: suffix,
}, nil
}

aliasParts := strings.SplitN(parts[0], "@", 1)
if len(aliasParts) != 2 || strings.Trim(aliasParts[1], " ") == "" {
return nil, ErrImproperModelURI(uri)
}

return &ParsedModelURI{
Name: aliasParts[0],
Alias: aliasParts[1],
}, nil
}
Loading
Loading