From 7ee13b3fa6d1480ae0a0ae4fd13abb9f436decd0 Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 11:17:08 +0000 Subject: [PATCH 01/10] Define test for testing.predict function --- pyproject.toml | 1 + tests/test_inference_server.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 98b4a0b..2727f98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ docs = [ testing = [ "pytest", "pytest-cov", + "sagemaker", # For testing serializers/deserializers ] linting = [ "black", diff --git a/tests/test_inference_server.py b/tests/test_inference_server.py index 86c8267..26368eb 100644 --- a/tests/test_inference_server.py +++ b/tests/test_inference_server.py @@ -9,7 +9,11 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +from typing import Tuple + +import botocore.response import pytest +import sagemaker.serializers import inference_server import inference_server.testing @@ -70,13 +74,38 @@ def test_path_not_found(client): def test_invocations(): - # Test the default plugin which should just pass through any input bytes + """Test the default plugin (which passes through any input bytes) using low-level testing.post_invocations""" data = b"What's the shipping forecast for tomorrow" response = inference_server.testing.post_invocations(data=data, headers={"Accept": "application/octet-stream"}) assert response.data == data assert response.headers["Content-Type"] == "application/octet-stream" +def test_prediction_custom_serializer(): + """Test the default plugin again, now using high-level testing.predict""" + + class Serializer(sagemaker.serializers.BaseSerializer): + @property + def CONTENT_TYPE(self) -> str: + return "application/octet-stream" + + def serialize(self, data: str) -> bytes: + return data.encode() # Simple str to bytes serializer + + class Deserializer(sagemaker.deserializers.BaseDeserializer): + @property + def ACCEPT(self) -> Tuple[str]: + return ("application/json",) + + def deserialize(self, stream: botocore.response.StreamingBody, content_type: str) -> str: + assert content_type in self.ACCEPT + return stream.read().decode() # Simple bytes to str deserializer + + input_data = "What's the shipping forecast for tomorrow" # Simply pass a string + prediction = inference_server.testing.predict(data=input_data, serializer=Serializer(), deserializer=Deserializer()) + assert prediction == input_data # Receive a string + + def test_execution_parameters(client): response = client.get("/execution-parameters") assert response.data == b'{"BatchStrategy":"MultiRecord","MaxConcurrentTransforms":1,"MaxPayloadInMB":6}' From e4f2351ffb177f62280634c2895ccf7ca177ccb0 Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 12:14:32 +0000 Subject: [PATCH 02/10] Define test for testing.predict function without serializer/deserializer args --- tests/test_inference_server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_inference_server.py b/tests/test_inference_server.py index 26368eb..dfab840 100644 --- a/tests/test_inference_server.py +++ b/tests/test_inference_server.py @@ -106,6 +106,12 @@ def deserialize(self, stream: botocore.response.StreamingBody, content_type: str assert prediction == input_data # Receive a string +def test_prediction_no_serializer(): + input_data = b"What's the shipping forecast for tomorrow" + prediction = inference_server.testing.predict(input_data) # No serializer should be bytes pass through again + assert prediction == input_data + + def test_execution_parameters(client): response = client.get("/execution-parameters") assert response.data == b'{"BatchStrategy":"MultiRecord","MaxConcurrentTransforms":1,"MaxPayloadInMB":6}' From 3d082e607c65ab849730f973ccb98291fd28233f Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 12:30:43 +0000 Subject: [PATCH 03/10] Use SageMaker SDK's own string serializer/deserializer in tests --- tests/test_inference_server.py | 28 ++++++---------------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/tests/test_inference_server.py b/tests/test_inference_server.py index dfab840..430c8b1 100644 --- a/tests/test_inference_server.py +++ b/tests/test_inference_server.py @@ -9,10 +9,8 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. -from typing import Tuple - -import botocore.response import pytest +import sagemaker.deserializers import sagemaker.serializers import inference_server @@ -83,26 +81,12 @@ def test_invocations(): def test_prediction_custom_serializer(): """Test the default plugin again, now using high-level testing.predict""" - - class Serializer(sagemaker.serializers.BaseSerializer): - @property - def CONTENT_TYPE(self) -> str: - return "application/octet-stream" - - def serialize(self, data: str) -> bytes: - return data.encode() # Simple str to bytes serializer - - class Deserializer(sagemaker.deserializers.BaseDeserializer): - @property - def ACCEPT(self) -> Tuple[str]: - return ("application/json",) - - def deserialize(self, stream: botocore.response.StreamingBody, content_type: str) -> str: - assert content_type in self.ACCEPT - return stream.read().decode() # Simple bytes to str deserializer - input_data = "What's the shipping forecast for tomorrow" # Simply pass a string - prediction = inference_server.testing.predict(data=input_data, serializer=Serializer(), deserializer=Deserializer()) + prediction = inference_server.testing.predict( + data=input_data, + serializer=sagemaker.serializers.StringSerializer(), + deserializer=sagemaker.deserializers.StringDeserializer(), + ) assert prediction == input_data # Receive a string From 29fb937ed1ce9e18ac5c27f8dc54c0a22e9afa46 Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 14:55:48 +0000 Subject: [PATCH 04/10] Initial implementation of testing.predict --- src/inference_server/testing.py | 50 ++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/src/inference_server/testing.py b/src/inference_server/testing.py index 2e9470d..ddb7ce9 100644 --- a/src/inference_server/testing.py +++ b/src/inference_server/testing.py @@ -13,9 +13,11 @@ Functions for testing **inference-server** plugins """ +import io from types import ModuleType -from typing import Callable, Type, Union +from typing import Any, Callable, Protocol, Tuple, Type, Union +import botocore.response import pluggy import werkzeug.test @@ -23,6 +25,52 @@ import inference_server._plugin +class ImplementsSerialize(Protocol): + """Interface compatible with :class:`sagemaker.serializers.BaseSerializer`""" + + @property + def CONTENT_TYPE(self) -> str: + """The MIME type for the serialized data""" + + def serialize(self, data: Any) -> bytes: + """Return the serialized data""" + + +class ImplementsDeserialize(Protocol): + """Interface compatible with :class:`sagemaker.deserializers.BaseDeserializer`""" + + @property + def ACCEPT(self) -> Tuple[str]: + """The content types that are supported by this deserializer""" + + def deserialize(self, stream: "botocore.response.StreamingBody", content_type: str) -> Any: + """Return the deserialized data""" + + +def predict(data: Any, serializer: ImplementsSerialize, deserializer: ImplementsDeserialize) -> Any: + """ + Invoke the model and return a prediction + + :param data: Model input data + :param serializer: A serializer for sending the data as bytes to the model server. Should be compatible with + :class:`sagemaker.serializers.BaseSerializer`. + :param deserializer: A deserializer for processing the prediction as sent by the model server. Should be compatible + with :class:`sagemaker.deserializers.BaseDeserializer`. + """ + serialized_data = serializer.serialize(data) + http_headers = { + "Content-Type": serializer.CONTENT_TYPE, # The serializer declares the content-type of the input data + "Accept": ", ".join(deserializer.ACCEPT), # The deserializer dictates the content-type of the prediction + } + prediction_response = post_invocations(data=serialized_data, headers=http_headers) + prediction_stream = botocore.response.StreamingBody( + raw_stream=io.BytesIO(prediction_response.data), + content_length=prediction_response.content_length, + ) + prediction_deserialized = deserializer.deserialize(prediction_stream, content_type=prediction_response.content_type) + return prediction_deserialized + + def client() -> werkzeug.test.Client: """ Return an HTTP test client for :mod:`inference_server` From e2d5f2e7437bca2dc7435484a96a6471df955482 Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 15:05:48 +0000 Subject: [PATCH 05/10] Remove support for Python 3.7 (due to Protocol usage) --- pyproject.toml | 2 +- tox.ini | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2727f98..d8048f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ keywords = [ ] # Minimum supported Python version -requires-python = ">=3.7" +requires-python = ">=3.8" # All runtime dependencies that must be packaged, pin major version only. dependencies = [ "codetiming~=1.4", diff --git a/tox.ini b/tox.ini index c9f863c..b838fc7 100644 --- a/tox.ini +++ b/tox.ini @@ -13,7 +13,7 @@ # Tox virtual environment manager for testing and quality assurance [tox] -envlist = py37, py38, py39, py310, py311, py312, linting, docs +envlist = py38, py39, py310, py311, py312, linting, docs isolated_build = True # Developers may not have all Python versions skip_missing_interpreters = true @@ -21,7 +21,6 @@ skip_missing_interpreters = true [gh-actions] # Mapping from GitHub Actions Python versions to Tox environments python = - 3.7: py37 3.8: py38 3.9: py39 3.10: py310 From 5064a6692345c83aa45928480d04f23590f6266d Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 15:09:02 +0000 Subject: [PATCH 06/10] Make serializers and deserializer optional --- src/inference_server/testing.py | 48 ++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/src/inference_server/testing.py b/src/inference_server/testing.py index ddb7ce9..5ed4303 100644 --- a/src/inference_server/testing.py +++ b/src/inference_server/testing.py @@ -15,7 +15,7 @@ import io from types import ModuleType -from typing import Any, Callable, Protocol, Tuple, Type, Union +from typing import Any, Callable, Optional, Protocol, Tuple, Type, Union import botocore.response import pluggy @@ -47,16 +47,52 @@ def deserialize(self, stream: "botocore.response.StreamingBody", content_type: s """Return the deserialized data""" -def predict(data: Any, serializer: ImplementsSerialize, deserializer: ImplementsDeserialize) -> Any: +class _PassThroughSerializer: + """Serialize bytes as bytes""" + + @property + def CONTENT_TYPE(self) -> str: + """The MIME type for the serialized data""" + return "application/octet-stream" + + def serialize(self, data: bytes) -> bytes: + """Return the serialized data""" + assert isinstance(data, bytes) + return data + + +class _PassThroughDeserializer: + """Deserialize bytes as bytes""" + + @property + def ACCEPT(self) -> Tuple[str]: + """The content types that are supported by this deserializer""" + return ("application/octet-stream",) + + def deserialize(self, stream: "botocore.response.StreamingBody", content_type: str) -> Any: + """Return the deserialized data""" + assert content_type in self.ACCEPT + try: + return stream.read() + finally: + stream.close() + + +def predict( + data: Any, serializer: Optional[ImplementsSerialize] = None, deserializer: Optional[ImplementsDeserialize] = None +) -> Any: """ Invoke the model and return a prediction :param data: Model input data - :param serializer: A serializer for sending the data as bytes to the model server. Should be compatible with - :class:`sagemaker.serializers.BaseSerializer`. - :param deserializer: A deserializer for processing the prediction as sent by the model server. Should be compatible - with :class:`sagemaker.deserializers.BaseDeserializer`. + :param serializer: Optional. A serializer for sending the data as bytes to the model server. Should be compatible + with :class:`sagemaker.serializers.BaseSerializer`. Default: bytes pass-through. + :param deserializer: Optional. A deserializer for processing the prediction as sent by the model server. Should be + compatible with :class:`sagemaker.deserializers.BaseDeserializer`. Default: bytes pass-through. """ + serializer = serializer or _PassThroughSerializer() + deserializer = deserializer or _PassThroughDeserializer() + serialized_data = serializer.serialize(data) http_headers = { "Content-Type": serializer.CONTENT_TYPE, # The serializer declares the content-type of the input data From e3f14a0bf5c59d9a9655f11f7a2367f2b6a32f79 Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 15:43:15 +0000 Subject: [PATCH 07/10] Remove Python 3.7 from github actions --- .github/workflows/test-package.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test-package.yml b/.github/workflows/test-package.yml index 07ceaac..5a6eddc 100644 --- a/.github/workflows/test-package.yml +++ b/.github/workflows/test-package.yml @@ -27,7 +27,6 @@ jobs: fail-fast: false matrix: python-version: - - '3.7' - '3.8' - '3.9' - '3.10' From f3d5c92a9a8cd9014d59703d98ce1568bbb112f8 Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 21:00:51 +0000 Subject: [PATCH 08/10] Dont use sagemaker module for test serializer as it does not work in Py312 --- pyproject.toml | 2 +- src/inference_server/testing.py | 2 +- tests/test_inference_server.py | 27 +++++++++++++++++++++++---- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d8048f1..40ba7ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ keywords = [ requires-python = ">=3.8" # All runtime dependencies that must be packaged, pin major version only. dependencies = [ + "botocore", "codetiming~=1.4", "importlib-metadata<4; python_version<'3.8'", "orjson~=3.0", @@ -72,7 +73,6 @@ docs = [ testing = [ "pytest", "pytest-cov", - "sagemaker", # For testing serializers/deserializers ] linting = [ "black", diff --git a/src/inference_server/testing.py b/src/inference_server/testing.py index 5ed4303..f27d199 100644 --- a/src/inference_server/testing.py +++ b/src/inference_server/testing.py @@ -43,7 +43,7 @@ class ImplementsDeserialize(Protocol): def ACCEPT(self) -> Tuple[str]: """The content types that are supported by this deserializer""" - def deserialize(self, stream: "botocore.response.StreamingBody", content_type: str) -> Any: + def deserialize(self, stream: botocore.response.StreamingBody, content_type: str) -> Any: """Return the deserialized data""" diff --git a/tests/test_inference_server.py b/tests/test_inference_server.py index 430c8b1..8f0ccc9 100644 --- a/tests/test_inference_server.py +++ b/tests/test_inference_server.py @@ -9,9 +9,10 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the # specific language governing permissions and limitations under the License. +from typing import Tuple + +import botocore.response import pytest -import sagemaker.deserializers -import sagemaker.serializers import inference_server import inference_server.testing @@ -81,11 +82,29 @@ def test_invocations(): def test_prediction_custom_serializer(): """Test the default plugin again, now using high-level testing.predict""" + + class Serializer: + @property + def CONTENT_TYPE(self) -> str: + return "application/octet-stream" + + def serialize(self, data: str) -> bytes: + return data.encode() # Simple str to bytes serializer + + class Deserializer: + @property + def ACCEPT(self) -> Tuple[str]: + return ("application/octet-stream",) + + def deserialize(self, stream: botocore.response.StreamingBody, content_type: str) -> str: + assert content_type in self.ACCEPT + return stream.read().decode() # Simple bytes to str deserializer + input_data = "What's the shipping forecast for tomorrow" # Simply pass a string prediction = inference_server.testing.predict( data=input_data, - serializer=sagemaker.serializers.StringSerializer(), - deserializer=sagemaker.deserializers.StringDeserializer(), + serializer=Serializer(), + deserializer=Deserializer(), ) assert prediction == input_data # Receive a string From 28e9385be7635bf963000eef0d5ea03680291e6e Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 16 Feb 2024 21:05:52 +0000 Subject: [PATCH 09/10] Ignore botocore missing type hints --- src/inference_server/testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inference_server/testing.py b/src/inference_server/testing.py index f27d199..3ac8efd 100644 --- a/src/inference_server/testing.py +++ b/src/inference_server/testing.py @@ -17,7 +17,7 @@ from types import ModuleType from typing import Any, Callable, Optional, Protocol, Tuple, Type, Union -import botocore.response +import botocore.response # type: ignore[import-untyped] import pluggy import werkzeug.test From a4987cfa63238d99aef8fd39d5bcf39360cfccca Mon Sep 17 00:00:00 2001 From: faph Date: Wed, 20 Mar 2024 13:44:23 +0000 Subject: [PATCH 10/10] Document testing.predict() --- docs/conf.py | 3 +- docs/testing.rst | 85 ++++++++++++++++++++++++++++++++++-------------- 2 files changed, 62 insertions(+), 26 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index f86de80..0dffcaa 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -46,8 +46,9 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), - "werkzeug": ("https://werkzeug.palletsprojects.com", None), + "sagemaker": ("https://sagemaker.readthedocs.io/en/stable/", None), "sklearn": ("https://scikit-learn.org/stable/", None), + "werkzeug": ("https://werkzeug.palletsprojects.com", None), } # List of patterns, relative to source directory, that match files and diff --git a/docs/testing.rst b/docs/testing.rst index bf2b082..6ce5967 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -5,12 +5,70 @@ This page explains how to test implemented **inference-server** hooks using the :ref:`hooks:Implementing server hooks`. +Testing model predictions (high-level API) +------------------------------------------ + +To verify whether we have defined and registered all services hooks correctly, we use the +:mod:`inference_server.testing` module. + +A full example looks like this:: + + import sagemaker.deserializers + import sagemaker.serializers + + from inference_server import testing + + def test_prediction_is_ok(): + input_data = {"location": "Fair Isle"} + expected_prediction = { + "wind": "Southwesterly gale force 8 continuing", + "sea_state": "Rough or very rough, occasionally moderate in southeast.", + "weather": "Thundery showers.", + "visibility": "Good, occasionally poor." + } + + prediction = testing.predict( + input_data, + serializer=sagemaker.serializers.JSONSerializer(), + deserializer=sagemaker.deserializers.JSONDeserializer(), + ) + assert prediction == expected_prediction + +Here we can use any serializer compatible with :mod:`sagemaker.serializers` and any deserializer compatible with +:mod:`sagemaker.deserializers` from the AWS SageMaker SDK. + +If no serializer or deserializer is configured, bytes data are passed through as is for both input and output. + + +Testing model predictions (low-level API) +----------------------------------------- + +Instead of using the high-level testing API, we can also use invoke requests similar to the :mod:`requests` library:: + + def test_prediction_request_is_ok(): + input_data = {"location": "Fair Isle"} + expected_prediction = { + "wind": "Southwesterly gale force 8 continuing", + "sea_state": "Rough or very rough, occasionally moderate in southeast.", + "weather": "Thundery showers.", + "visibility": "Good, occasionally poor." + } + + response = testing.post_invocations( + json=input_data, + content_type="application/json", + headers={"Accept": "application/json"}, + ) + assert response.content_type == "application/json" + assert response.json() == expected_prediction + + + Verifying plugin registration ----------------------------- -To verify out model is registered correctly as a plugin, we use this:: +To verify the model is registered correctly as a plugin, we use this:: - from inference_server import testing import shipping_forecast def test_plugin_is_registered(): @@ -39,26 +97,3 @@ To verify our function hooks have been defined correctly, we use this:: def test_predict_fn_hook_is_valid(): assert testing.hookimpl_is_valid(shipping_forecast.predict_fn) - - -Testing model predictions -------------------------- - -To the test a complete model invocation, we use this:: - - def test_prediction_is_ok(): - input_data = {"location": "Fair Isle"} - expected_prediction = { - "wind": "Southwesterly gale force 8 continuing", - "sea_state": "Rough or very rough, occasionally moderate in southeast.", - "weather": "Thundery showers.", - "visibility": "Good, occasionally poor." - } - - response = testing.post_invocations( - json=input_data, - content_type="application/json", - headers={"Accept": "application/json"}, - ) - assert response.content_type == "application/json" - assert response.json() == expected_prediction