-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add run_pyfunc_model() to build and run pyfunc model locally (#514)
<!-- Thanks for sending a pull request! Here are some tips for you: 1. Run unit tests and ensure that they are passing 2. If your change introduces any API changes, make sure to update the e2e tests 3. Make sure documentation is updated for your PR! --> # Description <!-- Briefly describe the motivation for the change. Please include illustrations where appropriate. --> To improve the user experience in developing the PyFunc model, I'm adding a new `merlin.run_pyfunc_model()` helper function to build and run the PyFunc model's Docker image locally. The new function has the similar arguments as `merlin.log_pyfunc_model()`. Currently, merlin-sdk has (outdated) [`ModelVersion.start_server()`](https://github.com/caraml-dev/merlin/blob/main/python/sdk/merlin/model.py#L1387) that also can be used to run standard and PyFunc models locally. However, there are some drawbacks to the current implementation: 1. To use `start_server()` function, users need to connect to the Merlin server, create a project (if not created yet), create a new model version, and log the model artifacts to the remote MLflow tracking server. This could lead to some unfinished model versions getting uploaded onto the Merlin and MLflow servers. 2. The implementation of `start_server()` will download the artifacts from the MLflow tracking server first to be used to build the Docker image locally. ``` # Before: from merlin from merlin.model import ModelType, PyFuncModel # Implement PyFuncModel class MyModel(PyFuncModel): .... # Connecting to Merlin server merlin.set_url("...") merlin.set_project("...") merlin.set_model("my-model", ModelType.PYFUNC) # Create new model version, log it, and run server: with merlin.new_model_version() as v: v.log_pyfunc_model( model_instance=MyModel(), ... ) # run pyfunc server v.start_server(...) # Or, if users already have logged existing model version on Merlin, # they can get the latest model version and run it locally: versions = merlin.active_model().list_version() versions.sort(key=lambda v: v.id, reverse=True) last_version = versions[0] last_version.start_server() ``` The new `merlin.run_pyfunc_model()` is more straightforward as it can build and run the pyfunc model without creating a new model version and uploading the model artifact: ``` # After: from merlin from merlin.model import ModelType, PyFuncModel # Implement PyFuncModel class MyModel(PyFuncModel): .... # Run pyfunc server merlin.run_pyfunc_model( model_instance=MyModel(), ... ) ``` # Modifications <!-- Summarize the key code changes. --> 1. Introduce `run_pyfunc_model()` in `pyfunc` package 2. Import `pyfunc.run_pyfunc_model()` into merlin's `__all__` so it can be called via `merlin.run_pyfunc_model()` 3. Refactor `ModelVersion.start_server()` to use `pyfunc.run_pyfunc_model()` for PyFunc model 4. Add an example of how to run the PyFunc model locally on PyFunc notebook example 5. Add some simple PyFunc model examples in `pyfunc/examples` folder # Tests <!-- Besides the existing / updated automated tests, what specific scenarios should be tested? Consider the backward compatibility of the changes, whether corner cases are covered, etc. Please describe the tests and check the ones that have been completed. Eg: - [x] Deploying new and existing standard models - [ ] Deploying PyFunc models --> 1. Add test_examples.py in pyfunc-server package # Checklist - [x] Added PR label - [x] Added unit test, integration, and/or e2e tests - [x] Tested locally - [x] Updated documentation - [ ] Update Swagger spec if the PR introduce API changes - [ ] Regenerated Golang and Python client if the PR introduces API changes # Release Notes <!-- Does this PR introduce a user-facing change? If no, just write "NONE" in the release-note block below. If yes, a release note is required. Enter your extended release note in the block below. If the PR requires additional action from users switching to the new release, include the string "action required". For more information about release notes, see kubernetes' guide here: http://git.k8s.io/community/contributors/guide/release-notes.md --> ```release-note Add merlin.run_pyfunc_model() function to build and run the PyFunc model locally. ```
- Loading branch information
1 parent
367a4ec
commit 904a894
Showing
34 changed files
with
1,120 additions
and
402 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Binary file not shown.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import logging | ||
|
||
import merlin | ||
from merlin.model import PyFuncModel | ||
|
||
|
||
class EchoModel(PyFuncModel): | ||
def initialize(self, artifacts): | ||
pass | ||
|
||
def infer(self, request): | ||
logging.info("request: %s", request) | ||
return request | ||
|
||
|
||
if __name__ == "__main__": | ||
# Run pyfunc model locally without uploading to Merlin server | ||
merlin.run_pyfunc_model( | ||
model_instance=EchoModel(), | ||
conda_env="env.yaml", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
dependencies: | ||
- python=3.10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Echo UPI Model Examples | ||
|
||
Run the server locally: | ||
|
||
``` | ||
python upi_server.py | ||
``` | ||
|
||
In different terminal session, run the client: | ||
|
||
``` | ||
python upi_client.py | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
dependencies: | ||
- python=3.10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import grpc | ||
import pandas as pd | ||
from caraml.upi.utils import df_to_table | ||
from caraml.upi.v1 import upi_pb2, upi_pb2_grpc | ||
|
||
|
||
def create_upi_request() -> upi_pb2.PredictValuesRequest: | ||
target_name = "echo" | ||
df = pd.DataFrame( | ||
[[4, 1, "hi"]] * 3, | ||
columns=["int_value", "int_value_2", "string_value"], | ||
index=["0000", "1111", "2222"], | ||
) | ||
prediction_id = "12345" | ||
|
||
return upi_pb2.PredictValuesRequest( | ||
target_name=target_name, | ||
prediction_table=df_to_table(df, "predict"), | ||
metadata=upi_pb2.RequestMetadata(prediction_id=prediction_id), | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
channel = grpc.insecure_channel(f"localhost:8080") | ||
stub = upi_pb2_grpc.UniversalPredictionServiceStub(channel) | ||
|
||
request = create_upi_request() | ||
response = stub.PredictValues(request=request) | ||
print(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
dependencies: | ||
- python=3.10.* | ||
- pip: | ||
- joblib>=0.13.0,<1.2.0 # >=1.2.0 upon upgrade of kserve's version | ||
- numpy<=1.23.5 # Temporary pin numpy due to https://numpy.org/doc/stable/release/1.20.0-notes.html#numpy-1-20-0-release-notes | ||
- scikit-learn>=1.1.2 | ||
- xgboost==1.6.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import os | ||
|
||
import joblib | ||
import merlin | ||
import numpy as np | ||
import xgboost as xgb | ||
from joblib import dump | ||
from merlin.model import PyFuncModel | ||
from sklearn import svm | ||
from sklearn.datasets import load_iris | ||
|
||
XGB_PATH = os.path.join("models/", "model_1.bst") | ||
SKLEARN_PATH = os.path.join("models/", "model_2.joblib") | ||
|
||
|
||
class IrisModel(PyFuncModel): | ||
def initialize(self, artifacts): | ||
self._model_1 = xgb.Booster(model_file=artifacts["xgb_model"]) | ||
self._model_2 = joblib.load(artifacts["sklearn_model"]) | ||
|
||
def infer(self, model_input): | ||
inputs = np.array(model_input["instances"]) | ||
dmatrix = xgb.DMatrix(model_input["instances"]) | ||
result_1 = self._model_1.predict(dmatrix) | ||
result_2 = self._model_2.predict_proba(inputs) | ||
return {"predictions": ((result_1 + result_2) / 2).tolist()} | ||
|
||
|
||
def train_models(): | ||
iris = load_iris() | ||
y = iris["target"] | ||
X = iris["data"] | ||
|
||
# train xgboost model | ||
dtrain = xgb.DMatrix(X, label=y) | ||
param = { | ||
"max_depth": 6, | ||
"eta": 0.1, | ||
"silent": 1, | ||
"nthread": 4, | ||
"num_class": 3, | ||
"objective": "multi:softprob", | ||
} | ||
xgb_model = xgb.train(params=param, dtrain=dtrain) | ||
xgb_model.save_model(XGB_PATH) | ||
|
||
# train sklearn model | ||
clf = svm.SVC(gamma="scale", probability=True) | ||
clf.fit(X, y) | ||
dump(clf, SKLEARN_PATH) | ||
|
||
|
||
if __name__ == "__main__": | ||
train_models() | ||
|
||
# test pyfunc model locally | ||
iris_model = IrisModel() | ||
iris_model.initialize( | ||
artifacts={ | ||
"xgb_model": XGB_PATH, | ||
"sklearn_model": SKLEARN_PATH, | ||
} | ||
) | ||
pred = iris_model.infer({"instances": [[2.8, 1.0, 6.8, 0.4], [3.1, 1.4, 4.5, 1.6]]}) | ||
print(pred) | ||
|
||
# run pyfunc model locally | ||
merlin.run_pyfunc_model( | ||
model_instance=IrisModel(), | ||
conda_env="env.yaml", | ||
artifacts={ | ||
"xgb_model": XGB_PATH, | ||
"sklearn_model": SKLEARN_PATH, | ||
}, | ||
) |
Binary file not shown.
Binary file not shown.
Oops, something went wrong.