Fix streaming for OpenAI clients (#1371)
nnarayen authored Feb 7, 2025
1 parent 911dddf commit d5db001
Showing 3 changed files with 77 additions and 22 deletions.
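Context for the fix: the OpenAI Python SDK sends an "accept: application/json" header even when it expects a streamed response, so the server previously gathered the generator into a single JSON body instead of streaming chunks back. A minimal client-side sketch of the behavior this commit accommodates (the base URL, port, model name, and API key are hypothetical placeholders, not part of this commit):

import openai

# Hypothetical client pointed at a locally served model. Even with
# stream=True, the OpenAI SDK sends "accept: application/json" plus a
# "user-agent: OpenAI/Python x.y.z" header on this request.
client = openai.OpenAI(
    base_url="http://localhost:8090/v1",  # placeholder endpoint
    api_key="placeholder",
)
stream = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    print(chunk)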
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.60rc005"
+version = "0.9.60rc006"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
56 changes: 35 additions & 21 deletions truss/templates/server/model_wrapper.py
@@ -716,19 +716,37 @@ async def _process_model_fn(
         result = await self._execute_async_model_fn(inputs, request, descriptor)

         if inspect.isgenerator(result) or inspect.isasyncgen(result):
-            if request.headers.get("accept") == "application/json":
-                return await _gather_generator(result)
-            else:
-                return await self._stream_with_background_task(
-                    result,
-                    fn_span,
-                    detached_ctx,
-                    # No semaphores needed for non-predict model functions.
-                    release_and_end=lambda: None,
-                )
+            return await self._handle_generator_response(
+                request, result, fn_span, detached_ctx, release_and_end=lambda: None
+            )

         return result

+    def _should_gather_generator(self, request: starlette.requests.Request) -> bool:
+        # The OpenAI SDK sends an accept header for JSON even in a streaming context,
+        # but we need to stream results back for client compatibility. Luckily,
+        # we can differentiate by looking at the user agent (e.g. OpenAI/Python 1.61.0)
+        user_agent = request.headers.get("user-agent", "")
+        if "openai" in user_agent.lower():
+            return False
+        # TODO(nikhil): determine if we can safely deprecate this behavior.
+        return request.headers.get("accept") == "application/json"
+
+    async def _handle_generator_response(
+        self,
+        request: starlette.requests.Request,
+        generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
+        span: trace.Span,
+        trace_ctx: trace.Context,
+        release_and_end: Callable[[], None],
+    ):
+        if self._should_gather_generator(request):
+            return await _gather_generator(generator)
+        else:
+            return await self._stream_with_background_task(
+                generator, span, trace_ctx, release_and_end
+            )
+
     async def completions(
         self, inputs: InputType, request: starlette.requests.Request
     ) -> OutputType:
@@ -801,17 +819,13 @@ async def __call__(
                         "the predict method."
                     )

-            if request.headers.get("accept") == "application/json":
-                # In the case of a streaming response, consume stream
-                # if the http accept header is set, and json is requested.
-                return await _gather_generator(predict_result)
-            else:
-                return await self._stream_with_background_task(
-                    predict_result,
-                    span_predict,
-                    detached_ctx,
-                    release_and_end=get_defer_fn(),
-                )
+            return await self._handle_generator_response(
+                request,
+                predict_result,
+                span_predict,
+                detached_ctx,
+                release_and_end=get_defer_fn(),
+            )

         if isinstance(predict_result, starlette.responses.Response):
             if self.model_descriptor.postprocess:
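The new _should_gather_generator helper encodes the compatibility rule: an OpenAI user agent always gets a stream, and only other clients that explicitly ask for JSON get a gathered response. A standalone sketch of that decision, using a plain dict in place of starlette's request headers (illustrative only, not the actual server code path):

def should_gather(headers: dict) -> bool:
    # Mirrors _should_gather_generator from the diff above.
    user_agent = headers.get("user-agent", "")
    if "openai" in user_agent.lower():
        return False  # OpenAI clients always receive a stream
    return headers.get("accept") == "application/json"

# OpenAI SDK request: streamed despite the JSON accept header.
assert not should_gather({"accept": "application/json", "user-agent": "OpenAI/Python 1.61.0"})
# Any other client explicitly requesting JSON: the generator is gathered.
assert should_gather({"accept": "application/json", "user-agent": "python-requests/2.31.0"})
# No accept preference: streamed.
assert not should_gather({"user-agent": "python-requests/2.31.0"})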
41 changes: 41 additions & 0 deletions truss/tests/test_model_inference.py
@@ -1818,3 +1818,44 @@ async def predict(self, nums: AsyncGenerator[str, None]) -> List[str]:
     response = requests.post(PREDICT_URL, json={"nums": ["1", "2"]})
     assert response.status_code == 200
     assert response.json() == ["1", "2"]
+
+
+@pytest.mark.integration
+def test_openai_client_streaming():
+    """
+    Test a Truss that exposes an OpenAI compatible endpoint.
+    """
+    model = """
+    from typing import Dict, AsyncGenerator
+
+    class Model:
+        def __init__(self):
+            pass
+
+        def load(self):
+            pass
+
+        async def chat_completions(self, inputs: Dict) -> AsyncGenerator[str, None]:
+            for num in inputs["nums"]:
+                yield num
+
+        async def predict(self, inputs: Dict):
+            pass
+    """
+    with ensure_kill_all(), _temp_truss(model) as tr:
+        tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
+
+        response = requests.post(
+            CHAT_COMPLETIONS_URL,
+            json={"nums": ["1", "2"]},
+            stream=True,
+            # Despite requesting json, we should still stream results back.
+            headers={
+                "accept": "application/json",
+                "user-agent": "OpenAI/Python 1.61.0",
+            },
+        )
+        assert response.headers.get("transfer-encoding") == "chunked"
+        assert [
+            byte_string.decode() for byte_string in list(response.iter_content())
+        ] == ["1", "2"]
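The transfer-encoding assertion is the key check here: over HTTP/1.1, a genuinely streamed response arrives chunked, whereas the previous behavior would have gathered the generator and returned a single JSON body without chunked transfer encoding.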
