-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
alexioannides
committed
Jul 20, 2021
1 parent
46b5999
commit a6192e9
Showing
9 changed files
with
103 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
jupyterlab==3.0.16 | ||
seaborn==0.11.1 | ||
numpy==1.21.0 | ||
pandas==1.2.5 | ||
pandas==1.3.0 | ||
scikit-learn==0.24.2 | ||
boto3==1.17.101 | ||
joblib==1.0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,7 @@ | ||
numpy>=1.21.0 | ||
pandas>=1.2.0 | ||
scikit-learn>=0.24.0 | ||
boto3>=1.17.0 | ||
joblib>=1.0.0 | ||
numpy==1.21.0 | ||
pandas==1.2.5 | ||
scikit-learn==0.24.2 | ||
boto3==1.17.101 | ||
fastapi==0.65.2 | ||
uvicorn==0.14.0 | ||
git+https://github.com/bodywork-ml/[email protected].1 | ||
git+https://github.com/bodywork-ml/[email protected].4 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,66 @@ | ||
""" | ||
Tests for web API. | ||
""" | ||
import pickle | ||
from subprocess import run | ||
from unittest.mock import patch | ||
|
||
from bodywork_pipeline_utils.aws import Model | ||
from fastapi.testclient import TestClient | ||
from numpy import array | ||
|
||
from pipeline.serve_model import app | ||
|
||
test_client = TestClient(app) | ||
|
||
|
||
def wrapped_model() -> Model: | ||
with open("tests/resources/model.pkl", "r+b") as file: | ||
wrapped_model = pickle.load(file) | ||
return wrapped_model | ||
|
||
|
||
@patch("pipeline.serve_model.wrapped_model", new=wrapped_model(), create=True) | ||
def test_web_api_returns_valid_response_given_valid_data(): | ||
prediction_request = {"product_code": "SKU001", "orders_placed": 100} | ||
prediction_response = test_client.post( | ||
"/api/v0.1/time_to_dispatch", json=prediction_request | ||
) | ||
model_obj = wrapped_model() | ||
expected_prediction = model_obj.model.predict(array([[100, 0]])).tolist()[0] | ||
assert prediction_response.status_code == 200 | ||
assert "est_hours_to_dispatch" in prediction_response.json().keys() | ||
assert "model_version" in prediction_response.json().keys() | ||
assert prediction_response.json()["est_hours_to_dispatch"] == expected_prediction | ||
assert prediction_response.json()["model_version"] == str(model_obj) | ||
|
||
|
||
@patch("pipeline.serve_model.wrapped_model", new=wrapped_model(), create=True) | ||
def test_web_api_returns_error_code_given_invalid_data(): | ||
prediction_request = {"product_code": "SKU001", "foo": 100} | ||
prediction_response = test_client.post( | ||
"/api/v0.1/time_to_dispatch", json=prediction_request | ||
) | ||
assert prediction_response.status_code == 422 | ||
assert "value_error.missing" in prediction_response.text | ||
|
||
prediction_request = {"product_code": "SKU000", "orders_placed": 100} | ||
prediction_response = test_client.post( | ||
"/api/v0.1/time_to_dispatch", json=prediction_request | ||
) | ||
assert prediction_response.status_code == 422 | ||
assert "not a valid enumeration member" in prediction_response.text | ||
|
||
prediction_request = {"product_code": "SKU001", "orders_placed": -100} | ||
prediction_response = test_client.post( | ||
"/api/v0.1/time_to_dispatch", json=prediction_request | ||
) | ||
assert prediction_response.status_code == 422 | ||
assert "ensure this value is greater than or equal to 0" in prediction_response.text | ||
|
||
|
||
def test_web_server_raises_exception_if_passed_invalid_args(): | ||
process = run( | ||
["python", "-m", "pipeline.serve_model"], capture_output=True, encoding="utf-8" | ||
) | ||
assert process.returncode != 0 | ||
assert "ERROR" in process.stdout | ||
assert "Invalid arguments passed to serve_model.py" in process.stdout |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters