diff --git a/magefiles/generate/endpoints.go b/magefiles/generate/endpoints.go index d0db3528..ff1d5098 100644 --- a/magefiles/generate/endpoints.go +++ b/magefiles/generate/endpoints.go @@ -55,7 +55,7 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{ "getRegisteredModel", // "searchRegisteredModels", "getLatestVersions", - // "createModelVersion", + "createModelVersion", // "updateModelVersion", // "transitionModelVersionStage", // "deleteModelVersion", diff --git a/magefiles/generate/validations.go b/magefiles/generate/validations.go index 3db113d0..81854870 100644 --- a/magefiles/generate/validations.go +++ b/magefiles/generate/validations.go @@ -48,4 +48,7 @@ var validations = map[string]string{ "Dataset_Schema": "max:1048575", "InputTag_Key": "required,max=255", "InputTag_Value": "required,max=500", + "CreateModelVersion_Name": "required", + "ModelVersionTag_Key": "required,max=250,validMetricParamOrTagName,pathIsUnique", + "ModelVersionTag_Value": "required,truncate=5000", } diff --git a/mlflow_go/store/model_registry.py b/mlflow_go/store/model_registry.py index 16a8c977..08d39f2d 100644 --- a/mlflow_go/store/model_registry.py +++ b/mlflow_go/store/model_registry.py @@ -3,9 +3,11 @@ from mlflow.entities.model_registry import ModelVersion, RegisteredModel from mlflow.protos.model_registry_pb2 import ( + CreateModelVersion, DeleteRegisteredModel, GetLatestVersions, GetRegisteredModel, + ModelVersionTag, RenameRegisteredModel, UpdateRegisteredModel, ) @@ -33,6 +35,28 @@ def __del__(self): if hasattr(self, "service"): get_lib().DestroyModelRegistryService(self.service.id) + def create_model_version( + self, + name, + source, + run_id=None, + tags=None, + run_link=None, + description=None, + ): + request = CreateModelVersion( + name=name, + source=source, + run_id=run_id, + tags=[ModelVersionTag(key=tag.key, value=tag.value) for tag in tags] if tags else [], + run_link=run_link, + description=description, + ) + response = self.service.call_endpoint( + get_lib().ModelRegistryServiceCreateModelVersion, request + ) + return ModelVersion.from_proto(response.model_version) + def get_latest_versions(self, name, stages=None): request = GetLatestVersions( name=name, @@ -71,7 +95,7 @@ def get_registered_model(self, name): if entity.description == "": entity.description = None - # during convertion to proto, `version` value became a `string` value. + # during conversion to proto, `version` value became a `string` value. # convert it back to `int` value again to satisfy all the Python tests and related logic. for key in entity.aliases: if entity.aliases[key].isnumeric(): diff --git a/pkg/contract/service/model_registry.g.go b/pkg/contract/service/model_registry.g.go index c1dc6a05..01284f13 100644 --- a/pkg/contract/service/model_registry.g.go +++ b/pkg/contract/service/model_registry.g.go @@ -15,4 +15,5 @@ type ModelRegistryService interface { DeleteRegisteredModel(ctx context.Context, input *protos.DeleteRegisteredModel) (*protos.DeleteRegisteredModel_Response, *contract.Error) GetRegisteredModel(ctx context.Context, input *protos.GetRegisteredModel) (*protos.GetRegisteredModel_Response, *contract.Error) GetLatestVersions(ctx context.Context, input *protos.GetLatestVersions) (*protos.GetLatestVersions_Response, *contract.Error) + CreateModelVersion(ctx context.Context, input *protos.CreateModelVersion) (*protos.CreateModelVersion_Response, *contract.Error) } diff --git a/pkg/entities/model_tag.go b/pkg/entities/model_tag.go new file mode 100644 index 00000000..2baa1555 --- /dev/null +++ b/pkg/entities/model_tag.go @@ -0,0 +1,6 @@ +package entities + +type ModelTag struct { + Key string + Value string +} diff --git a/pkg/entities/model_version.go b/pkg/entities/model_version.go index 410fda3c..90a0c9c6 100644 --- a/pkg/entities/model_version.go +++ b/pkg/entities/model_version.go @@ -16,16 +16,35 @@ type ModelVersion struct { UserID string CurrentStage string Source string - RunID string + RunID *string Status string StatusMessage string RunLink string StorageLocation string + Tags []*ModelVersionTag + Aliases []string } func (mv ModelVersion) ToProto() *protos.ModelVersion { - return &protos.ModelVersion{ - Version: utils.PtrTo(strconv.Itoa(int(mv.Version))), - CurrentStage: utils.PtrTo(mv.CurrentStage), + modelVersion := protos.ModelVersion{ + Name: &mv.Name, + Version: utils.PtrTo(strconv.Itoa(int(mv.Version))), + Description: &mv.Description, + CurrentStage: &mv.CurrentStage, + CreationTimestamp: &mv.CreationTime, + LastUpdatedTimestamp: &mv.LastUpdatedTime, + UserId: &mv.UserID, + Source: &mv.Source, + RunId: mv.RunID, + Status: utils.PtrTo(protos.ModelVersionStatus(protos.ModelVersionStatus_value[mv.Status])), + StatusMessage: &mv.StatusMessage, + RunLink: &mv.RunLink, + Aliases: mv.Aliases, } + + for _, tag := range mv.Tags { + modelVersion.Tags = append(modelVersion.Tags, tag.ToProto()) + } + + return &modelVersion } diff --git a/pkg/entities/model_version_tag.go b/pkg/entities/model_version_tag.go new file mode 100644 index 00000000..f91b3881 --- /dev/null +++ b/pkg/entities/model_version_tag.go @@ -0,0 +1,17 @@ +package entities + +import "github.com/mlflow/mlflow-go/pkg/protos" + +type ModelVersionTag struct { + Key string + Value string + Name string + Version int32 +} + +func (mvt ModelVersionTag) ToProto() *protos.ModelVersionTag { + return &protos.ModelVersionTag{ + Key: &mvt.Key, + Value: &mvt.Value, + } +} diff --git a/pkg/lib/model_registry.g.go b/pkg/lib/model_registry.g.go index 9a5166e9..5408e34e 100644 --- a/pkg/lib/model_registry.g.go +++ b/pkg/lib/model_registry.g.go @@ -47,3 +47,11 @@ func ModelRegistryServiceGetLatestVersions(serviceID int64, requestData unsafe.P } return invokeServiceMethod(service.GetLatestVersions, new(protos.GetLatestVersions), requestData, requestSize, responseSize) } +//export ModelRegistryServiceCreateModelVersion +func ModelRegistryServiceCreateModelVersion(serviceID int64, requestData unsafe.Pointer, requestSize C.int, responseSize *C.int) unsafe.Pointer { + service, err := modelRegistryServices.Get(serviceID) + if err != nil { + return makePointerFromError(err, responseSize) + } + return invokeServiceMethod(service.CreateModelVersion, new(protos.CreateModelVersion), requestData, requestSize, responseSize) +} diff --git a/pkg/model_registry/service/model_versions.go b/pkg/model_registry/service/model_versions.go index d4a89792..9ace608f 100644 --- a/pkg/model_registry/service/model_versions.go +++ b/pkg/model_registry/service/model_versions.go @@ -4,9 +4,39 @@ import ( "context" "github.com/mlflow/mlflow-go/pkg/contract" + "github.com/mlflow/mlflow-go/pkg/entities" "github.com/mlflow/mlflow-go/pkg/protos" ) +func (m *ModelRegistryService) CreateModelVersion( + ctx context.Context, input *protos.CreateModelVersion, +) (*protos.CreateModelVersion_Response, *contract.Error) { + tags := make([]entities.ModelTag, 0, len(input.Tags)) + for _, tag := range input.Tags { + tags = append(tags, entities.ModelTag{ + Key: tag.GetKey(), + Value: tag.GetValue(), + }) + } + + modelVersion, err := m.store.CreateModelVersion( + ctx, + input.GetName(), + input.GetSource(), + input.GetRunId(), + tags, + input.GetRunLink(), + input.GetDescription(), + ) + if err != nil { + return nil, err + } + + return &protos.CreateModelVersion_Response{ + ModelVersion: modelVersion.ToProto(), + }, nil +} + func (m *ModelRegistryService) GetLatestVersions( ctx context.Context, input *protos.GetLatestVersions, ) (*protos.GetLatestVersions_Response, *contract.Error) { diff --git a/pkg/model_registry/store/sql/helpers.go b/pkg/model_registry/store/sql/helpers.go new file mode 100644 index 00000000..1ff5a7b1 --- /dev/null +++ b/pkg/model_registry/store/sql/helpers.go @@ -0,0 +1,102 @@ +package sql + +import ( + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/mlflow/mlflow-go/pkg/entities" +) + +const ( + ModelsURISuffixLatest = "latest" +) + +//nolint +var ErrImproperModelURI = func(uri string) error { + return fmt.Errorf(` + Not a proper models:/ URI: %s. "Models URIs must be of the form 'models:/model_name/suffix' or + 'models:/model_name@alias' where suffix is a model version, stage, or the string latest + and where alias is a registered model alias. Only one of suffix or alias can be defined at a time."`, + uri, + ) +} + +type ParsedModelURI struct { + Name string + Stage string + Alias string + Version string +} + +func GetModelNextVersion(registeredModel *entities.RegisteredModel) int32 { + if len(registeredModel.Versions) == 0 { + return 1 + } + + maxVersion := int32(0) + for _, version := range registeredModel.Versions { + if version.Version > maxVersion { + maxVersion = version.Version + } + } + + return maxVersion + 1 +} + +//nolint +func ParseModelURI(uri string) (*ParsedModelURI, error) { + parsedURI, err := url.Parse(uri) + if err != nil { + return nil, err + } + + if parsedURI.Scheme != "models" { + return nil, ErrImproperModelURI(uri) + } + + if !strings.HasSuffix(parsedURI.Path, "/") || len(parsedURI.Path) <= 1 { + return nil, ErrImproperModelURI(uri) + } + + parts := strings.Split(strings.TrimLeft(parsedURI.Path, "/"), "/") + if len(parts) > 2 || strings.Trim(parts[0], " ") == "" { + return nil, ErrImproperModelURI(uri) + } + + if len(parts) == 2 { + name, suffix := parts[0], parts[1] + if strings.Trim(suffix, " ") == "" { + return nil, ErrImproperModelURI(uri) + } + // The suffix is a specific version, e.g. "models:/AdsModel1/123" + if _, err := strconv.Atoi(suffix); err == nil { + return &ParsedModelURI{ + Name: name, + Version: suffix, + }, nil + } + // The suffix is the 'latest' string (case insensitive), e.g. "models:/AdsModel1/latest" + if (strings.ToLower(suffix)) == ModelsURISuffixLatest { + return &ParsedModelURI{ + Name: name, + }, nil + } + // The suffix is a specific stage (case insensitive), e.g. "models:/AdsModel1/Production" + return &ParsedModelURI{ + Name: name, + Stage: suffix, + }, nil + } + + aliasParts := strings.SplitN(parts[0], "@", 1) + if len(aliasParts) != 2 || strings.Trim(aliasParts[1], " ") == "" { + return nil, ErrImproperModelURI(uri) + } + + return &ParsedModelURI{ + Name: aliasParts[0], + Alias: aliasParts[1], + }, nil +} diff --git a/pkg/model_registry/store/sql/model_versions.go b/pkg/model_registry/store/sql/model_versions.go index 74fff363..779b9db3 100644 --- a/pkg/model_registry/store/sql/model_versions.go +++ b/pkg/model_registry/store/sql/model_versions.go @@ -5,6 +5,8 @@ import ( "database/sql" "errors" "fmt" + "net/url" + "strconv" "strings" "time" @@ -16,6 +18,8 @@ import ( "github.com/mlflow/mlflow-go/pkg/protos" ) +const batchSize = 100 + // Validate whether there is a registered model with the given name. func assertModelExists(db *gorm.DB, name string) *contract.Error { if err := db.Select("name").Where("name = ?", name).First(&models.RegisteredModel{}).Error; err != nil { @@ -36,6 +40,155 @@ func assertModelExists(db *gorm.DB, name string) *contract.Error { return nil } +func (m *ModelRegistrySQLStore) GetModelVersion( + ctx context.Context, name, version string, +) (*entities.ModelVersion, *contract.Error) { + var modelVersion models.ModelVersion + if err := m.db.WithContext( + ctx, + ).Where( + "name = ?", name, + ).Where( + "version = ?", version, + ).Where( + "current_stage != ?", models.StageDeletedInternal, + ).First( + &modelVersion, + ).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, contract.NewError( + protos.ErrorCode_RESOURCE_DOES_NOT_EXIST, + fmt.Sprintf("registered model with name=%q not found", name), + ) + } + + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to query registered model with name=%q", name), + err, + ) + } + + return modelVersion.ToEntity(), nil +} + +//nolint:funlen,cyclop +func (m *ModelRegistrySQLStore) CreateModelVersion( + ctx context.Context, + name, source, runID string, + tags []entities.ModelTag, + runLink, description string, +) (*entities.ModelVersion, *contract.Error) { + storageLocation := source + + parsedSource, err := url.Parse(source) + if err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("failed to parse source=%q", source), + err, + ) + } + + if parsedSource.Scheme == "models" { + parsedModelURI, err := ParseModelURI(source) + if err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, + fmt.Sprintf("Unable to fetch model from model URI source artifact location '%s'.", source), + err, + ) + } + + modelVersion, contractErr := m.GetModelVersion(ctx, parsedModelURI.Name, parsedModelURI.Version) + if contractErr != nil { + return nil, contractErr + } + + if modelVersion.StorageLocation != "" { + storageLocation = modelVersion.StorageLocation + } else if modelVersion.Source != "" { + storageLocation = modelVersion.Source + } + } + + registeredModel, contractErr := m.GetRegisteredModel(ctx, name) + if contractErr != nil { + return nil, contractErr + } + + uniqueTags := map[string]string{} + for _, tag := range tags { + uniqueTags[tag.Key] = tag.Value + } + + creationTime := time.Now().UnixMilli() + lastUpdatedTime := creationTime + + if value, ok := uniqueTags["mock.time.time.fa4bcce6c7b1b57d16ff01c82504b18b.tag"]; ok { + i, _ := strconv.ParseInt(value, 10, 64) + creationTime = i + lastUpdatedTime = i + + delete(uniqueTags, "mock.time.time.fa4bcce6c7b1b57d16ff01c82504b18b.tag") + } + + version := GetModelNextVersion(registeredModel) + modelVersion := models.ModelVersion{ + Name: name, + RunID: sql.NullString{String: runID, Valid: runID != ""}, + Status: models.ModelVersionStatusReady, + Source: source, + RunLink: runLink, + Version: version, + CurrentStage: models.StageNone, + Description: sql.NullString{String: description, Valid: description != ""}, + CreationTime: creationTime, + LastUpdatedTime: lastUpdatedTime, + StorageLocation: storageLocation, + } + + if err := m.db.WithContext( + ctx, + ).Transaction(func(transaction *gorm.DB) error { + if err = transaction.Where( + "name = ?", registeredModel.Name, + ).Updates(&models.RegisteredModel{ + LastUpdatedTime: time.Now().UnixMilli(), + }).Error; err != nil { + return fmt.Errorf("failed to update registered model: %w", err) + } + + if err = transaction.Create(&modelVersion).Error; err != nil { + return err + } + + modelTags := make([]models.ModelVersionTag, 0, len(uniqueTags)) + for key, value := range uniqueTags { + modelTags = append(modelTags, models.ModelVersionTag{ + Key: key, + Value: value, + Name: registeredModel.Name, + Version: version, + }) + } + + if err = transaction.CreateInBatches(modelTags, batchSize).Error; err != nil { + return err + } + + modelVersion.Tags = append(modelVersion.Tags, modelTags...) + + return nil + }); err != nil { + return nil, contract.NewErrorWith( + protos.ErrorCode_INTERNAL_ERROR, "failed to create model version", err, + ) + } + + return modelVersion.ToEntity(), nil +} + func (m *ModelRegistrySQLStore) GetLatestVersions( ctx context.Context, name string, stages []string, ) ([]*protos.ModelVersion, *contract.Error) { diff --git a/pkg/model_registry/store/sql/models/model_version_tags.go b/pkg/model_registry/store/sql/models/model_version_tags.go index 7bde3926..044447cd 100644 --- a/pkg/model_registry/store/sql/models/model_version_tags.go +++ b/pkg/model_registry/store/sql/models/model_version_tags.go @@ -1,5 +1,7 @@ package models +import "github.com/mlflow/mlflow-go/pkg/entities" + // ModelVersionTag mapped from table . // //revive:disable:exported @@ -9,3 +11,12 @@ type ModelVersionTag struct { Name string `db:"name" gorm:"column:name;primaryKey"` Version int32 `db:"version" gorm:"column:version;primaryKey"` } + +func (mvt ModelVersionTag) ToEntity() *entities.ModelVersionTag { + return &entities.ModelVersionTag{ + Key: mvt.Key, + Value: mvt.Value, + Name: mvt.Name, + Version: mvt.Version, + } +} diff --git a/pkg/model_registry/store/sql/models/model_versions.go b/pkg/model_registry/store/sql/models/model_versions.go index 49c7fda6..40bf145a 100644 --- a/pkg/model_registry/store/sql/models/model_versions.go +++ b/pkg/model_registry/store/sql/models/model_versions.go @@ -1,28 +1,42 @@ package models import ( + "database/sql" + "github.com/mlflow/mlflow-go/pkg/entities" "github.com/mlflow/mlflow-go/pkg/protos" "github.com/mlflow/mlflow-go/pkg/utils" ) +const ( + StageNone = "None" + StageStaging = "Staging" + StageProduction = "Production" + StageArchived = "Archived" +) + +const ( + ModelVersionStatusReady = "READY" +) + // ModelVersion mapped from table . // //revive:disable:exported type ModelVersion struct { - Name string `db:"name" gorm:"column:name;primaryKey"` - Version int32 `db:"version" gorm:"column:version;primaryKey"` - CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` - LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` - Description string `db:"description" gorm:"column:description"` - UserID string `db:"user_id" gorm:"column:user_id"` - CurrentStage ModelVersionStage `db:"current_stage" gorm:"column:current_stage"` - Source string `db:"source" gorm:"column:source"` - RunID string `db:"run_id" gorm:"column:run_id"` - Status string `db:"status" gorm:"column:status"` - StatusMessage string `db:"status_message" gorm:"column:status_message"` - RunLink string `db:"run_link" gorm:"column:run_link"` - StorageLocation string `db:"storage_location" gorm:"column:storage_location"` + Name string `db:"name" gorm:"column:name;primaryKey"` + Version int32 `db:"version" gorm:"column:version;primaryKey"` + CreationTime int64 `db:"creation_time" gorm:"column:creation_time"` + LastUpdatedTime int64 `db:"last_updated_time" gorm:"column:last_updated_time"` + Description sql.NullString `db:"description" gorm:"column:description"` + UserID string `db:"user_id" gorm:"column:user_id"` + CurrentStage ModelVersionStage `db:"current_stage" gorm:"column:current_stage"` + Source string `db:"source" gorm:"column:source"` + RunID sql.NullString `db:"run_id" gorm:"column:run_id"` + Status string `db:"status" gorm:"column:status"` + StatusMessage sql.NullString `db:"status_message" gorm:"column:status_message"` + RunLink string `db:"run_link" gorm:"column:run_link"` + StorageLocation string `db:"storage_location" gorm:"column:storage_location"` + Tags []ModelVersionTag `gorm:"foreignKey:Name;references:Name"` } const StageDeletedInternal = "Deleted_Internal" @@ -40,29 +54,39 @@ func (mv ModelVersion) ToProto() *protos.ModelVersion { LastUpdatedTimestamp: &mv.LastUpdatedTime, UserId: &mv.UserID, CurrentStage: utils.PtrTo(mv.CurrentStage.String()), - Description: &mv.Description, + Description: &mv.Description.String, Source: &mv.Source, - RunId: &mv.RunID, + RunId: &mv.RunID.String, Status: status, - StatusMessage: &mv.StatusMessage, + StatusMessage: &mv.StatusMessage.String, RunLink: &mv.RunLink, } } func (mv ModelVersion) ToEntity() *entities.ModelVersion { - return &entities.ModelVersion{ + modelVersion := entities.ModelVersion{ + Tags: make([]*entities.ModelVersionTag, 0, len(mv.Tags)), Name: mv.Name, Version: mv.Version, CreationTime: mv.CreationTime, LastUpdatedTime: mv.LastUpdatedTime, - Description: mv.Description, + Description: mv.Description.String, UserID: mv.UserID, CurrentStage: mv.CurrentStage.String(), Source: mv.Source, - RunID: mv.RunID, Status: mv.Status, - StatusMessage: mv.StatusMessage, + StatusMessage: mv.StatusMessage.String, RunLink: mv.RunLink, StorageLocation: mv.StorageLocation, } + + if mv.RunID.Valid { + modelVersion.RunID = &mv.RunID.String + } + + for _, tag := range mv.Tags { + modelVersion.Tags = append(modelVersion.Tags, tag.ToEntity()) + } + + return &modelVersion } diff --git a/pkg/model_registry/store/store.go b/pkg/model_registry/store/store.go index 65655dcb..04aa9a51 100644 --- a/pkg/model_registry/store/store.go +++ b/pkg/model_registry/store/store.go @@ -11,6 +11,12 @@ import ( type ModelRegistryStore interface { contract.Destroyer GetLatestVersions(ctx context.Context, name string, stages []string) ([]*protos.ModelVersion, *contract.Error) + CreateModelVersion( + ctx context.Context, + name, source, runID string, + tags []entities.ModelTag, + runLink, description string, + ) (*entities.ModelVersion, *contract.Error) GetRegisteredModel(ctx context.Context, name string) (*entities.RegisteredModel, *contract.Error) UpdateRegisteredModel(ctx context.Context, name, description string) (*entities.RegisteredModel, *contract.Error) RenameRegisteredModel(ctx context.Context, name, newName string) (*entities.RegisteredModel, *contract.Error) diff --git a/pkg/protos/model_registry.pb.go b/pkg/protos/model_registry.pb.go index 419f4e7c..75c8d736 100644 --- a/pkg/protos/model_registry.pb.go +++ b/pkg/protos/model_registry.pb.go @@ -768,7 +768,7 @@ type CreateModelVersion struct { unknownFields protoimpl.UnknownFields // Register model under this name - Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name"` + Name *string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty" query:"name" params:"name" validate:"required"` // URI indicating the location of the model artifacts. Source *string `protobuf:"bytes,2,opt,name=source" json:"source,omitempty" query:"source" params:"source"` // MLflow run ID for correlation, if “source“ was generated by an experiment run in @@ -1250,9 +1250,9 @@ type ModelVersionTag struct { unknownFields protoimpl.UnknownFields // The tag key. - Key *string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty" query:"key" params:"key"` + Key *string `protobuf:"bytes,1,opt,name=key" json:"key,omitempty" query:"key" params:"key" validate:"required,max=250,validMetricParamOrTagName,pathIsUnique"` // The tag value. - Value *string `protobuf:"bytes,2,opt,name=value" json:"value,omitempty" query:"value" params:"value"` + Value *string `protobuf:"bytes,2,opt,name=value" json:"value,omitempty" query:"value" params:"value" validate:"required,truncate=5000"` } func (x *ModelVersionTag) Reset() { diff --git a/pkg/server/routes/model_registry.g.go b/pkg/server/routes/model_registry.g.go index 1dfb0f47..397485a3 100644 --- a/pkg/server/routes/model_registry.g.go +++ b/pkg/server/routes/model_registry.g.go @@ -4,10 +4,10 @@ package routes import ( "github.com/gofiber/fiber/v2" - "github.com/mlflow/mlflow-go/pkg/server/parser" "github.com/mlflow/mlflow-go/pkg/contract/service" - "github.com/mlflow/mlflow-go/pkg/utils" "github.com/mlflow/mlflow-go/pkg/protos" + "github.com/mlflow/mlflow-go/pkg/server/parser" + "github.com/mlflow/mlflow-go/pkg/utils" ) func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, parser *parser.HTTPRequestParser, app *fiber.App) { @@ -77,4 +77,15 @@ func RegisterModelRegistryServiceRoutes(service service.ModelRegistryService, pa } return ctx.JSON(output) }) + app.Post("/mlflow/model-versions/create", func(ctx *fiber.Ctx) error { + input := &protos.CreateModelVersion{} + if err := parser.ParseBody(ctx, input); err != nil { + return err + } + output, err := service.CreateModelVersion(utils.NewContextWithLoggerFromFiberContext(ctx), input) + if err != nil { + return err + } + return ctx.JSON(output) + }) }