feature: pass context to handler functions (#109)
update unit test

run black locally

fix signature for py2.7

pin flake8

update readme

Co-authored-by: Wei Chu <[email protected]>
waytrue17 and Wei Chu authored Aug 18, 2022
1 parent 52cd814 commit b43439c
Showing 7 changed files with 185 additions and 49 deletions.
23 changes: 19 additions & 4 deletions README.md
@@ -46,12 +46,13 @@ To use the SageMaker Inference Toolkit, you need to do the following:

class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):

def default_model_fn(self, model_dir):
def default_model_fn(self, model_dir, context=None):
"""Loads a model. For PyTorch, a default function to load a model cannot be provided.
Users should provide a customized model_fn() in their script.
Args:
model_dir: a directory where model is saved.
context (obj): the request context (default: None).
Returns: A PyTorch model.
"""
@@ -60,40 +61,54 @@ To use the SageMaker Inference Toolkit, you need to do the following:
See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk
"""))

def default_input_fn(self, input_data, content_type):
def default_input_fn(self, input_data, content_type, context=None):
"""A default input_fn that can handle JSON, CSV and NPZ formats.
Args:
input_data: the request payload serialized in the content_type format
content_type: the request content_type
context (obj): the request context (default: None).
Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor, depending on whether CUDA is available.
"""
return decoder.decode(input_data, content_type)

def default_predict_fn(self, data, model):
def default_predict_fn(self, data, model, context=None):
"""A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn.
Runs prediction on GPU if cuda is available.
Args:
data: input data (torch.Tensor) for prediction deserialized by input_fn
model: PyTorch model loaded in memory by model_fn
context (obj): the request context (default: None).
Returns: a prediction
"""
return model(data)

def default_output_fn(self, prediction, accept):
def default_output_fn(self, prediction, accept, context=None):
"""A default output_fn for PyTorch. Serializes predictions from predict_fn to JSON, CSV or NPY format.
Args:
prediction: a prediction result from predict_fn
accept: type which the output data needs to be serialized
context (obj): the request context (default: None).
Returns: output data serialized
"""
return encoder.encode(prediction, accept)
```
Note: passing context as an argument to the handler functions is optional. Customers can choose to omit context from the function declaration if it is not needed at runtime. For example, the following handler function declarations will also work:

```
def default_model_fn(self, model_dir)

def default_input_fn(self, input_data, content_type)

def default_predict_fn(self, data, model)

def default_output_fn(self, prediction, accept)
```
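
If the handler does need the request context, it can accept the extra argument. Below is a minimal sketch of a context-aware `input_fn`, assuming an MMS-style context whose `system_properties` may carry a `"gpu_id"` entry (the class name and device logic are illustrative, not part of the toolkit):

```python
import torch

from sagemaker_inference import decoder, default_inference_handler


class ContextAwarePytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):

    def default_input_fn(self, input_data, content_type, context=None):
        """Deserialize the payload, then move it to the device the server assigned."""
        tensor = torch.from_numpy(decoder.decode(input_data, content_type))
        if context is not None and torch.cuda.is_available():
            # Assumes MMS-style system_properties; "gpu_id" may be absent.
            gpu_id = context.system_properties.get("gpu_id")
            if gpu_id is not None:
                tensor = tensor.to("cuda:{}".format(int(gpu_id)))
        return tensor
```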

2. Implement a handler service that is executed by the model server.
([Here is an example](https://github.com/aws/sagemaker-pytorch-serving-container/blob/master/src/sagemaker_pytorch_serving_container/handler_service.py) of a handler service.)
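
A minimal handler service, in the spirit of the linked example, wires a `Transformer` to the inference handler defined above (the class name here is illustrative):

```python
from sagemaker_inference.default_handler_service import DefaultHandlerService
from sagemaker_inference.transformer import Transformer


class HandlerService(DefaultHandlerService):
    """Delegate requests to a Transformer built around the default inference handler."""

    def __init__(self):
        transformer = Transformer(default_inference_handler=DefaultPytorchInferenceHandler())
        super(HandlerService, self).__init__(transformer=transformer)
```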
2 changes: 1 addition & 1 deletion src/sagemaker_inference/default_handler_service.py
@@ -63,4 +63,4 @@ def initialize(self, context):
else:
os.environ[PYTHON_PATH_ENV] = code_dir_path

self._service.validate_and_initialize(model_dir=model_dir)
self._service.validate_and_initialize(model_dir=model_dir, context=context)
17 changes: 11 additions & 6 deletions src/sagemaker_inference/default_inference_handler.py
@@ -21,11 +21,12 @@
class DefaultInferenceHandler(object):
"""Bare-bones implementation of default inference functions."""

def default_model_fn(self, model_dir):
def default_model_fn(self, model_dir, context=None):
"""Function responsible for loading the model.
Args:
model_dir (str): The directory where model files are stored.
context (obj): the request context (default: None).
Returns:
obj: the loaded model.
@@ -40,25 +41,28 @@ def default_model_fn(self, model_dir):
)
)

def default_input_fn(self, input_data, content_type): # pylint: disable=no-self-use
def default_input_fn(self, input_data, content_type, context=None):
# pylint: disable=unused-argument, no-self-use
"""Function responsible for deserializing the input data into an object for prediction.
Args:
input_data (obj): the request data.
content_type (str): the request content type.
context (obj): the request context (default: None).
Returns:
obj: data ready for prediction.
"""
return decoder.decode(input_data, content_type)

def default_predict_fn(self, data, model):
def default_predict_fn(self, data, model, context=None):
"""Function responsible for model predictions.
Args:
model (obj): model loaded by the model_fn
data: deserialized data returned by the input_fn
model (obj): model loaded by the model_fn.
data: deserialized data returned by the input_fn.
context (obj): the request context (default: None).
Returns:
obj: prediction result.
@@ -73,12 +77,13 @@ def default_predict_fn(self, data, model):
)
)

def default_output_fn(self, prediction, accept): # pylint: disable=no-self-use
def default_output_fn(self, prediction, accept, context=None): # pylint: disable=no-self-use
"""Function responsible for serializing the prediction result to the desired accept type.
Args:
prediction (obj): prediction result returned by the predict_fn.
accept (str): accept header expected by the client.
context (obj): the request context (default: None).
Returns:
obj: prediction data.
62 changes: 52 additions & 10 deletions src/sagemaker_inference/transformer.py
@@ -19,6 +19,16 @@
import importlib
import traceback

try:
from inspect import signature # pylint: disable=ungrouped-imports
except ImportError:
# for Python2.7
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "inspect2"])
from inspect2 import signature

try:
from importlib.util import find_spec # pylint: disable=ungrouped-imports
except ImportError:
@@ -73,6 +83,7 @@ def __init__(self, default_inference_handler=None):
self._input_fn = None
self._predict_fn = None
self._output_fn = None
self._context = None

@staticmethod
def handle_error(context, inference_exception, trace):
@@ -109,7 +120,7 @@ def transform(self, data, context):
try:
properties = context.system_properties
model_dir = properties.get("model_dir")
self.validate_and_initialize(model_dir=model_dir)
self.validate_and_initialize(model_dir=model_dir, context=context)

input_data = data[0].get("body")

@@ -125,7 +136,9 @@
if content_type in content_types.UTF8_TYPES:
input_data = input_data.decode("utf-8")

result = self._transform_fn(self._model, input_data, content_type, accept)
result = self._run_handler_function(
self._transform_fn, *(self._model, input_data, content_type, accept)
)

response = result
response_content_type = accept
@@ -148,20 +161,25 @@
trace,
)

def validate_and_initialize(self, model_dir=environment.model_dir): # type: () -> None
def validate_and_initialize(self, model_dir=environment.model_dir, context=None):
"""Validates the user module against the SageMaker inference contract.
Load the model as defined by the ``model_fn`` to prepare handling predictions.
"""
if not self._initialized:
self._context = context
self._environment = environment.Environment()
self._validate_user_module_and_set_functions()

if self._pre_model_fn is not None:
self._pre_model_fn(model_dir)
self._model = self._model_fn(model_dir)
self._run_handler_function(self._pre_model_fn, *(model_dir,))

self._model = self._run_handler_function(self._model_fn, *(model_dir,))

if self._model_warmup_fn is not None:
self._model_warmup_fn(model_dir, self._model)
self._run_handler_function(self._model_warmup_fn, *(model_dir, self._model))

self._initialized = True

def _validate_user_module_and_set_functions(self):
@@ -214,7 +232,8 @@ def _validate_user_module_and_set_functions(self):

self._transform_fn = self._default_transform_fn

def _default_transform_fn(self, model, input_data, content_type, accept):
def _default_transform_fn(self, model, input_data, content_type, accept, context=None):
# pylint: disable=unused-argument
"""Make predictions against the model and return a serialized response.
This serves as the default implementation of transform_fn, used when the
user has not provided an implementation.
@@ -224,13 +243,36 @@ def _default_transform_fn(self, model, input_data, content_type, accept):
input_data (obj): the request data.
content_type (str): the request content type.
accept (str): accept header expected by the client.
context (obj): the request context (default: None).
Returns:
obj: the serialized prediction result or a tuple of the form
(response_data, content_type)
"""
data = self._input_fn(input_data, content_type)
prediction = self._predict_fn(data, model)
result = self._output_fn(prediction, accept)
data = self._run_handler_function(self._input_fn, *(input_data, content_type))
prediction = self._run_handler_function(self._predict_fn, *(data, model))
result = self._run_handler_function(self._output_fn, *(prediction, accept))
return result

def _run_handler_function(self, func, *argv):
"""Helper to call the handler function which covers 2 cases:
1. the handle function takes context
2. the handle function does not take context
"""
num_func_input = len(signature(func).parameters)
if num_func_input == len(argv):
# function does not take context
result = func(*argv)
elif num_func_input == len(argv) + 1:
# function takes context
argv_context = argv + (self._context,)
result = func(*argv_context)
else:
raise TypeError(
"{} takes {} arguments but {} were given.".format(
func.__name__, num_func_input, len(argv)
)
)

return result
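
To make the arity-based dispatch above concrete, here is a standalone sketch of the same rule (the names are illustrative, not part of the library):

```python
from inspect import signature


def run_handler_function(func, context, *argv):
    # Append context only when func declares one more parameter than was supplied.
    num_params = len(signature(func).parameters)
    if num_params == len(argv):
        return func(*argv)  # handler does not take context
    if num_params == len(argv) + 1:
        return func(*(argv + (context,)))  # handler takes context
    raise TypeError(
        "{} takes {} arguments but {} were given.".format(func.__name__, num_params, len(argv))
    )


def output_fn(prediction, accept):
    return prediction


def output_fn_with_context(prediction, accept, context=None):
    return prediction, context


print(run_handler_function(output_fn, "ctx", 2, "application/json"))               # 2
print(run_handler_function(output_fn_with_context, "ctx", 2, "application/json"))  # (2, 'ctx')
```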
8 changes: 5 additions & 3 deletions test/unit/test_default_inference_handler.py
Expand Up @@ -10,7 +10,7 @@
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from mock import patch
from mock import Mock, patch
import pytest

from sagemaker_inference import content_types
@@ -19,7 +19,8 @@

@patch("sagemaker_inference.decoder.decode")
def test_default_input_fn(loads):
assert DefaultInferenceHandler().default_input_fn(42, content_types.JSON)
context = Mock()
assert DefaultInferenceHandler().default_input_fn(42, content_types.JSON, context)

loads.assert_called_with(42, content_types.JSON)

@@ -34,7 +35,8 @@ def test_default_input_fn(loads):
)
@patch("sagemaker_inference.encoder.encode", lambda prediction, accept: prediction**2)
def test_default_output_fn(accept, expected_content_type):
result, content_type = DefaultInferenceHandler().default_output_fn(2, accept)
context = Mock()
result, content_type = DefaultInferenceHandler().default_output_fn(2, accept, context)
assert result == 4
assert content_type == expected_content_type
