Skip to content

Commit

Permalink
model server might have already done a serialization. honor that by n…
Browse files Browse the repository at this point in the history
…ot decoding the request again if it is not already bytes or bytestream (#4987)
  • Loading branch information
gwang111 authored Jan 7, 2025
1 parent 2102bb7 commit daa5518
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 18 deletions.
22 changes: 16 additions & 6 deletions src/sagemaker/serve/model_server/multi_model_server/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,28 @@ def input_fn(input_data, content_type, context=None):
if hasattr(schema_builder, "custom_input_translator"):
deserialized_data = schema_builder.custom_input_translator.deserialize(
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
io.BytesIO(input_data.encode("utf-8"))
if not any(
[
isinstance(input_data, bytes),
isinstance(input_data, bytearray),
]
)
else io.BytesIO(input_data)
),
content_type,
)
else:
deserialized_data = schema_builder.input_deserializer.deserialize(
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
io.BytesIO(input_data.encode("utf-8"))
if not any(
[
isinstance(input_data, bytes),
isinstance(input_data, bytearray),
]
)
else io.BytesIO(input_data)
),
content_type[0],
)
Expand Down
22 changes: 16 additions & 6 deletions src/sagemaker/serve/model_server/torchserve/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,28 @@ def input_fn(input_data, content_type):
if hasattr(schema_builder, "custom_input_translator"):
deserialized_data = schema_builder.custom_input_translator.deserialize(
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
io.BytesIO(input_data.encode("utf-8"))
if not any(
[
isinstance(input_data, bytes),
isinstance(input_data, bytearray),
]
)
else io.BytesIO(input_data)
),
content_type,
)
else:
deserialized_data = schema_builder.input_deserializer.deserialize(
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
io.BytesIO(input_data.encode("utf-8"))
if not any(
[
isinstance(input_data, bytes),
isinstance(input_data, bytearray),
]
)
else io.BytesIO(input_data)
),
content_type[0],
)
Expand Down
22 changes: 16 additions & 6 deletions src/sagemaker/serve/model_server/torchserve/xgboost_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,28 @@ def input_fn(input_data, content_type):
if hasattr(schema_builder, "custom_input_translator"):
return schema_builder.custom_input_translator.deserialize(
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
io.BytesIO(input_data.encode("utf-8"))
if not any(
[
isinstance(input_data, bytes),
isinstance(input_data, bytearray),
]
)
else io.BytesIO(input_data)
),
content_type,
)
else:
return schema_builder.input_deserializer.deserialize(
(
io.BytesIO(input_data)
if type(input_data) == bytes
else io.BytesIO(input_data.encode("utf-8"))
io.BytesIO(input_data.encode("utf-8"))
if not any(
[
isinstance(input_data, bytes),
isinstance(input_data, bytearray),
]
)
else io.BytesIO(input_data)
),
content_type[0],
)
Expand Down

0 comments on commit daa5518

Please sign in to comment.