Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SetTag endpoint #10

Merged
merged 27 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
41c7754
Add SetTag endpoint
jescalada Oct 4, 2024
97927f6
Fix line endings to LF
jescalada Oct 4, 2024
2cc3ee1
Update mock store
jescalada Oct 8, 2024
59f30ac
Clean up unused code
jescalada Oct 8, 2024
a3235d3
Refactor tests.go
jescalada Oct 8, 2024
b810993
Remove unused validation code
jescalada Oct 8, 2024
084c471
Add SetTag struct validation
jescalada Oct 9, 2024
9039d3f
Replace runId requirement with manual check
jescalada Oct 9, 2024
fd44cbc
Clean up SetTag store
jescalada Oct 9, 2024
4344210
Minor adjustments
jescalada Oct 9, 2024
02ed6d9
Update validations
jescalada Oct 9, 2024
bc25c96
Add preliminary SetTag missing logic
jescalada Oct 11, 2024
7086ea1
Add missing generated file
jescalada Oct 11, 2024
5bc3cd0
Merge remote-tracking branch 'origin/main' into implement-set-tag
jescalada Oct 11, 2024
8eeaacb
Fix protoc version issue
jescalada Oct 11, 2024
91660ed
Fix test_update_run_name test
jescalada Oct 14, 2024
f91e43b
Add pythonSpecific to docs, normalize MLflow spelling
jescalada Oct 14, 2024
a3bd537
Override test_set_tag execution
jescalada Oct 16, 2024
b39b3f0
Merge remote-tracking branch 'origin/main' into implement-set-tag
jescalada Oct 16, 2024
19404be
Extract helper functions from SetTag
jescalada Oct 16, 2024
87c2efe
Revert unnecessary changes
jescalada Oct 16, 2024
a7b7390
Fix postCreate.sh
jescalada Oct 16, 2024
b27e4f6
Fix linter error
jescalada Oct 16, 2024
fe585cc
Fix format issue
jescalada Oct 16, 2024
2f753ad
Simplify handleRunNameUpdate
jescalada Oct 19, 2024
d6081fd
Merge branch 'main' into implement-set-tag
jescalada Oct 19, 2024
069212c
Refactor PythonSpecific
nojaf Oct 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@
// "forwardPorts": [5432],

// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": ".devcontainer/postCreate.sh"
"postCreateCommand": ".devcontainer/postCreate.sh",
jescalada marked this conversation as resolved.
Show resolved Hide resolved

// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
"remoteUser": "root"
jescalada marked this conversation as resolved.
Show resolved Hide resolved
}
Empty file modified .devcontainer/postCreate.sh
100755 → 100644
jescalada marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
4 changes: 2 additions & 2 deletions magefiles/generate/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ var ServiceInfoMap = map[string]ServiceGenerationInfo{
"logMetric",
// "logParam",
// "setExperimentTag",
// "setTag",
"setTag",
// "setTraceTag",
// "deleteTraceTag",
// "deleteTag",
"deleteTag",
"searchRuns",
// "listArtifacts",
// "getMetricHistory",
Expand Down
4 changes: 4 additions & 0 deletions magefiles/generate/validations.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ var validations = map[string]string{
"LogMetric_Key": "required",
"LogMetric_Value": "required",
"LogMetric_Timestamp": "required",
"SetTag_RunId": "required",
"SetTag_Key": "required",
"DeleteTag_RunId": "required",
"DeleteTag_Key": "required",
}
27 changes: 27 additions & 0 deletions magefiles/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,33 @@ func (Test) Python() error {
return nil
}

// Run specific Python test against the Go backend.
func (Test) PythonSpecific(testName string) error {
jescalada marked this conversation as resolved.
Show resolved Hide resolved
jescalada marked this conversation as resolved.
Show resolved Hide resolved
nojaf marked this conversation as resolved.
Show resolved Hide resolved
libpath, err := os.MkdirTemp("", "")
if err != nil {
return err
}

defer os.RemoveAll(libpath)
defer cleanUpMemoryFile()

if err := sh.RunV("python", "-m", "mlflow_go.lib", ".", libpath); err != nil {
return nil
}

if err := sh.RunWithV(map[string]string{
"MLFLOW_GO_LIBRARY_PATH": libpath,
}, "pytest",
"--confcutdir=.",
".mlflow.repo/tests/tracking/test_rest_tracking.py",
"-k", testName,
); err != nil {
return err
}

return nil
}

// Run the Go unit tests.
func (Test) Unit() error {
return sh.RunV("go", "test", "./pkg/...")
Expand Down
9 changes: 9 additions & 0 deletions mlflow_go/store/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
CreateRun,
DeleteExperiment,
DeleteRun,
DeleteTag,
GetExperiment,
GetExperimentByName,
GetRun,
Expand All @@ -22,6 +23,7 @@
RestoreExperiment,
RestoreRun,
SearchRuns,
SetTag,
UpdateExperiment,
UpdateRun,
)
Expand Down Expand Up @@ -165,6 +167,13 @@ def log_metric(self, run_id, metric):
)
self.service.call_endpoint(get_lib().TrackingServiceLogMetric, request)

def set_tag(self, run_id, tag):
request = SetTag(run_id=run_id, key=tag.key, value=tag.value)
self.service.call_endpoint(get_lib().TrackingServiceSetTag, request)

def delete_tag(self, run_id, key):
request = DeleteTag(run_id=run_id, key=key)
self.service.call_endpoint(get_lib().TrackingServiceDeleteTag, request)

def TrackingStore(cls):
return type(cls.__name__, (_TrackingStore, cls), {})
Expand Down
2 changes: 2 additions & 0 deletions pkg/contract/service/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions pkg/lib/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/artifacts/mlflow_artifacts.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/databricks.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/databricks_artifacts.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/internal.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/model_registry.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/scalapb/scalapb.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions pkg/protos/service.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions pkg/server/routes/tracking.g.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 27 additions & 0 deletions pkg/tracking/service/tags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package service

import (
"context"
"fmt"

"github.com/mlflow/mlflow-go/pkg/contract"
"github.com/mlflow/mlflow-go/pkg/protos"
)

func (ts TrackingService) SetTag(ctx context.Context, input *protos.SetTag) (*protos.SetTag_Response, *contract.Error) {
// Print input
jescalada marked this conversation as resolved.
Show resolved Hide resolved
fmt.Println(input)
if err := ts.Store.SetTag(ctx, input.GetRunId(), input.GetKey(), input.GetValue()); err != nil {
return nil, err
}

return &protos.SetTag_Response{}, nil
}

func (ts TrackingService) DeleteTag(ctx context.Context, input *protos.DeleteTag) (*protos.DeleteTag_Response, *contract.Error) {
if err := ts.Store.DeleteTag(ctx, input.GetRunId(), input.GetKey()); err != nil {
return nil, err
}

return &protos.DeleteTag_Response{}, nil
}
Loading
Loading