[MINOR] testing.predict() (#16)
faph authored Mar 20, 2024
2 parents 33bac58 + a4987cf commit 4c58a1c
Showing 7 changed files with 189 additions and 32 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test-package.yml
@@ -27,7 +27,6 @@ jobs:
      fail-fast: false
      matrix:
        python-version:
-         - '3.7'
          - '3.8'
          - '3.9'
          - '3.10'
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -46,8 +46,9 @@

intersphinx_mapping = {
    "python": ("https://docs.python.org/3/", None),
-   "werkzeug": ("https://werkzeug.palletsprojects.com", None),
    "sagemaker": ("https://sagemaker.readthedocs.io/en/stable/", None),
+   "sklearn": ("https://scikit-learn.org/stable/", None),
+   "werkzeug": ("https://werkzeug.palletsprojects.com", None),
}

# List of patterns, relative to source directory, that match files and
85 changes: 60 additions & 25 deletions docs/testing.rst
@@ -5,12 +5,70 @@

This page explains how to test implemented **inference-server** hooks; see
:ref:`hooks:Implementing server hooks`.


Testing model predictions (high-level API)
------------------------------------------

To verify whether we have defined and registered all server hooks correctly, we use the
:mod:`inference_server.testing` module.

A full example looks like this::

    import sagemaker.deserializers
    import sagemaker.serializers

    from inference_server import testing

    def test_prediction_is_ok():
        input_data = {"location": "Fair Isle"}
        expected_prediction = {
            "wind": "Southwesterly gale force 8 continuing",
            "sea_state": "Rough or very rough, occasionally moderate in southeast.",
            "weather": "Thundery showers.",
            "visibility": "Good, occasionally poor."
        }

        prediction = testing.predict(
            input_data,
            serializer=sagemaker.serializers.JSONSerializer(),
            deserializer=sagemaker.deserializers.JSONDeserializer(),
        )
        assert prediction == expected_prediction

Here we can use any serializer compatible with :mod:`sagemaker.serializers` and any deserializer compatible with
:mod:`sagemaker.deserializers` from the AWS SageMaker SDK.

If no serializer or deserializer is configured, bytes data are passed through as is for both input and output.
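
For example, a model that accepts and returns raw bytes can be tested without any serializers; a minimal
sketch, assuming the default bytes pass-through::

    def test_prediction_no_serializer():
        input_data = b"What's the shipping forecast for tomorrow"
        prediction = testing.predict(input_data)  # No serializer: bytes pass through unchanged
        assert prediction == input_data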


Testing model predictions (low-level API)
-----------------------------------------

Instead of using the high-level testing API, we can also send invocation requests directly, using an interface similar to the :mod:`requests` library::

    def test_prediction_request_is_ok():
        input_data = {"location": "Fair Isle"}
        expected_prediction = {
            "wind": "Southwesterly gale force 8 continuing",
            "sea_state": "Rough or very rough, occasionally moderate in southeast.",
            "weather": "Thundery showers.",
            "visibility": "Good, occasionally poor."
        }

        response = testing.post_invocations(
            json=input_data,
            content_type="application/json",
            headers={"Accept": "application/json"},
        )
        assert response.content_type == "application/json"
        assert response.json() == expected_prediction
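
The low-level API also accepts raw bytes directly; a minimal sketch, again assuming the default
bytes pass-through plugin::

    def test_invocations_raw_bytes():
        data = b"What's the shipping forecast for tomorrow"
        response = testing.post_invocations(data=data, headers={"Accept": "application/octet-stream"})
        assert response.data == data
        assert response.headers["Content-Type"] == "application/octet-stream"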



Verifying plugin registration
-----------------------------

-To verify out model is registered correctly as a plugin, we use this::
+To verify the model is registered correctly as a plugin, we use this::

    from inference_server import testing
    import shipping_forecast

    def test_plugin_is_registered():
@@ -39,26 +97,3 @@

To verify our function hooks have been defined correctly, we use this::

    def test_predict_fn_hook_is_valid():
        assert testing.hookimpl_is_valid(shipping_forecast.predict_fn)


-Testing model predictions
--------------------------
-
-To the test a complete model invocation, we use this::
-
-    def test_prediction_is_ok():
-        input_data = {"location": "Fair Isle"}
-        expected_prediction = {
-            "wind": "Southwesterly gale force 8 continuing",
-            "sea_state": "Rough or very rough, occasionally moderate in southeast.",
-            "weather": "Thundery showers.",
-            "visibility": "Good, occasionally poor."
-        }
-
-        response = testing.post_invocations(
-            json=input_data,
-            content_type="application/json",
-            headers={"Accept": "application/json"},
-        )
-        assert response.content_type == "application/json"
-        assert response.json() == expected_prediction
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -43,9 +43,10 @@ keywords = [
]

# Minimum supported Python version
-requires-python = ">=3.7"
+requires-python = ">=3.8"
# All runtime dependencies that must be packaged, pin major version only.
dependencies = [
+   "botocore",
    "codetiming~=1.4",
    "importlib-metadata<4; python_version<'3.8'",
    "orjson~=3.0",
86 changes: 85 additions & 1 deletion src/inference_server/testing.py
@@ -13,16 +13,100 @@
Functions for testing **inference-server** plugins
"""

import io
from types import ModuleType
-from typing import Callable, Type, Union
+from typing import Any, Callable, Optional, Protocol, Tuple, Type, Union

import botocore.response # type: ignore[import-untyped]
import pluggy
import werkzeug.test

import inference_server
import inference_server._plugin


class ImplementsSerialize(Protocol):
    """Interface compatible with :class:`sagemaker.serializers.BaseSerializer`"""

    @property
    def CONTENT_TYPE(self) -> str:
        """The MIME type for the serialized data"""

    def serialize(self, data: Any) -> bytes:
        """Return the serialized data"""


class ImplementsDeserialize(Protocol):
    """Interface compatible with :class:`sagemaker.deserializers.BaseDeserializer`"""

    @property
    def ACCEPT(self) -> Tuple[str, ...]:
        """The content types that are supported by this deserializer"""

    def deserialize(self, stream: botocore.response.StreamingBody, content_type: str) -> Any:
        """Return the deserialized data"""


class _PassThroughSerializer:
    """Serialize bytes as bytes"""

    @property
    def CONTENT_TYPE(self) -> str:
        """The MIME type for the serialized data"""
        return "application/octet-stream"

    def serialize(self, data: bytes) -> bytes:
        """Return the serialized data"""
        assert isinstance(data, bytes)
        return data


class _PassThroughDeserializer:
    """Deserialize bytes as bytes"""

    @property
    def ACCEPT(self) -> Tuple[str, ...]:
        """The content types that are supported by this deserializer"""
        return ("application/octet-stream",)

    def deserialize(self, stream: "botocore.response.StreamingBody", content_type: str) -> Any:
        """Return the deserialized data"""
        assert content_type in self.ACCEPT
        try:
            return stream.read()
        finally:
            stream.close()


def predict(
    data: Any, serializer: Optional[ImplementsSerialize] = None, deserializer: Optional[ImplementsDeserialize] = None
) -> Any:
    """
    Invoke the model and return a prediction

    :param data: Model input data
    :param serializer: Optional. A serializer for sending the data as bytes to the model server. Should be compatible
        with :class:`sagemaker.serializers.BaseSerializer`. Default: bytes pass-through.
    :param deserializer: Optional. A deserializer for processing the prediction as sent by the model server. Should be
        compatible with :class:`sagemaker.deserializers.BaseDeserializer`. Default: bytes pass-through.
    """
    serializer = serializer or _PassThroughSerializer()
    deserializer = deserializer or _PassThroughDeserializer()

    serialized_data = serializer.serialize(data)
    http_headers = {
        "Content-Type": serializer.CONTENT_TYPE,  # The serializer declares the content-type of the input data
        "Accept": ", ".join(deserializer.ACCEPT),  # The deserializer dictates the content-type of the prediction
    }
    prediction_response = post_invocations(data=serialized_data, headers=http_headers)
    prediction_stream = botocore.response.StreamingBody(
        raw_stream=io.BytesIO(prediction_response.data),
        content_length=prediction_response.content_length,
    )
    prediction_deserialized = deserializer.deserialize(prediction_stream, content_type=prediction_response.content_type)
    return prediction_deserialized


def client() -> werkzeug.test.Client:
    """
    Return an HTTP test client for :mod:`inference_server`
40 changes: 39 additions & 1 deletion tests/test_inference_server.py
@@ -9,6 +9,9 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

from typing import Tuple

import botocore.response
import pytest

import inference_server
@@ -70,13 +73,48 @@ def test_path_not_found(client):


def test_invocations():
-    # Test the default plugin which should just pass through any input bytes
+    """Test the default plugin (which passes through any input bytes) using low-level testing.post_invocations"""
data = b"What's the shipping forecast for tomorrow"
response = inference_server.testing.post_invocations(data=data, headers={"Accept": "application/octet-stream"})
assert response.data == data
assert response.headers["Content-Type"] == "application/octet-stream"


def test_prediction_custom_serializer():
    """Test the default plugin again, now using high-level testing.predict"""

    class Serializer:
        @property
        def CONTENT_TYPE(self) -> str:
            return "application/octet-stream"

        def serialize(self, data: str) -> bytes:
            return data.encode()  # Simple str to bytes serializer

    class Deserializer:
        @property
        def ACCEPT(self) -> Tuple[str, ...]:
            return ("application/octet-stream",)

        def deserialize(self, stream: botocore.response.StreamingBody, content_type: str) -> str:
            assert content_type in self.ACCEPT
            return stream.read().decode()  # Simple bytes to str deserializer

    input_data = "What's the shipping forecast for tomorrow"  # Simply pass a string
    prediction = inference_server.testing.predict(
        data=input_data,
        serializer=Serializer(),
        deserializer=Deserializer(),
    )
    assert prediction == input_data  # Receive a string


def test_prediction_no_serializer():
    input_data = b"What's the shipping forecast for tomorrow"
    prediction = inference_server.testing.predict(input_data)  # No serializer: bytes pass through again
    assert prediction == input_data


def test_execution_parameters(client):
    response = client.get("/execution-parameters")
    assert response.data == b'{"BatchStrategy":"MultiRecord","MaxConcurrentTransforms":1,"MaxPayloadInMB":6}'
3 changes: 1 addition & 2 deletions tox.ini
@@ -13,15 +13,14 @@

# Tox virtual environment manager for testing and quality assurance

[tox]
-envlist = py37, py38, py39, py310, py311, py312, linting, docs
+envlist = py38, py39, py310, py311, py312, linting, docs
isolated_build = True
# Developers may not have all Python versions
skip_missing_interpreters = true

[gh-actions]
# Mapping from GitHub Actions Python versions to Tox environments
python =
-   3.7: py37
    3.8: py38
    3.9: py39
    3.10: py310
