Skip to content

Commit

Permalink
Websocket support in chains (entrypoint only).
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Feb 22, 2025
1 parent 0c1d882 commit 5eeca91
Show file tree
Hide file tree
Showing 12 changed files with 354 additions and 137 deletions.
39 changes: 38 additions & 1 deletion truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import requests
import websockets
from truss.tests.test_testing_utilities_for_other_tests import (
ensure_kill_all,
get_container_logs_from_prefix,
Expand All @@ -20,6 +21,11 @@
TEST_ROOT = Path(__file__).parent.resolve()


@pytest.fixture
def anyio_backend():
    """Name the async backend used for coroutine tests in this module.

    Requested as a fixture argument by `test_websocket_chain`; presumably
    picked up by the anyio pytest plugin to choose the event loop
    implementation (TODO confirm against the plugin configuration).
    """
    backend_name = "asyncio"
    return backend_name


@pytest.mark.integration
def test_chain():
with ensure_kill_all():
Expand Down Expand Up @@ -319,7 +325,7 @@ def test_custom_health_checks_chain():
assert "Health check failed." not in container_logs

# Start failing health checks
response = service.run_remote({"fail": True})
_ = service.run_remote({"fail": True})
response = requests.get(health_check_url)
assert response.status_code == 503
container_logs = get_container_logs_from_prefix(entrypoint.name)
Expand All @@ -328,3 +334,34 @@ def test_custom_health_checks_chain():
assert response.status_code == 503
container_logs = get_container_logs_from_prefix(entrypoint.name)
assert container_logs.count("Health check failed.") == 2


@pytest.mark.integration
async def test_websocket_chain(anyio_backend):
    """Push the websocket chain to local docker and talk to it over a websocket.

    Sends two messages on one connection: a plain text message that the
    entrypoint answers itself, and the keyword ``"dep"`` whose reply is
    produced by the dependency chainlet.
    """
    # Function-scope import so this block is self-contained; the fixture above
    # pins the backend to asyncio, so `asyncio.sleep` is safe here.
    import asyncio

    with ensure_kill_all():
        chain_name = "websocket_chain"
        chain_root = TEST_ROOT / chain_name / f"{chain_name}.py"
        with framework.ChainletImporter.import_target(chain_root) as entrypoint:
            service = deployment_client.push(
                entrypoint,
                options=definitions.PushOptionsLocalDocker(
                    chain_name=chain_name,
                    only_generate_trusses=False,
                    use_local_chains_src=True,
                ),
            )
        # Get something like `ws://localhost:38605/v1/websocket`.
        # Only the leading scheme is rewritten (count=1), so a later "http"
        # substring in the URL can never be altered by accident.
        url = service.run_remote_url.replace("http", "ws", 1).replace(
            "v1/models/model:predict", "v1/websocket"
        )
        logging.warning(url)
        # Short grace period before connecting; awaited instead of
        # `time.sleep` so the running event loop is not blocked.
        await asyncio.sleep(1)
        async with websockets.connect(url) as websocket:
            await websocket.send("Test")
            response = await websocket.recv()
            assert response == "You said: Test."

            await websocket.send("dep")
            response = await websocket.recv()
            assert response == "Hello from dependency, Head."
49 changes: 48 additions & 1 deletion truss-chains/tests/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ async def run_remote(self) -> AsyncIterator:
yield "123"


def test_raises_is_healthy_not_a_method():
def test_raises_is_healthy_not_a_method() -> None:
match = rf"{TEST_FILE}:\d+ \(IsHealthyNotMethod\) \[kind: TYPE_ERROR\].* `is_healthy` must be a method."

with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
Expand Down Expand Up @@ -686,3 +686,50 @@ def test_import_model_requires_single_entrypoint():
with pytest.raises(ValueError, match=match), _raise_errors():
with framework.ModelImporter.import_target(model_src):
pass


def test_raises_websocket_with_other_args():
    """A websocket `run_remote` input must not be combined with other arguments."""
    location = (
        rf"{TEST_FILE}:\d+ \(WebsocketWithOtherArgs\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
    )
    reason = r"When using a websocket as input, no other arguments are allowed"
    expected_error = location + reason

    # Defining the class is what triggers the usage error.
    with pytest.raises(definitions.ChainsUsageError, match=expected_error):
        with _raise_errors():

            class WebsocketWithOtherArgs(chains.ChainletBase):
                def run_remote(
                    self, websocket: chains.WebSocketProtocol, name: str
                ) -> None:
                    pass


def test_raises_websocket_as_output():
    """Annotating `run_remote`'s return value as a websocket is a usage error."""
    location = (
        rf"{TEST_FILE}:\d+ \(WebsocketOutput\.run_remote\) \[kind: IO_TYPE_ERROR\].*"
    )
    reason = r"Websockets cannot be used as output type"
    expected_error = location + reason

    # Defining the class is what triggers the usage error.
    with pytest.raises(definitions.ChainsUsageError, match=expected_error):
        with _raise_errors():

            class WebsocketOutput(chains.ChainletBase):
                def run_remote(self) -> chains.WebSocketProtocol: ...  # type: ignore[empty-body]


def test_raises_websocket_as_dependency():
    """Depending on a chainlet that takes a websocket input is a usage error."""
    location = (
        rf"{TEST_FILE}:\d+ \(WebsocketAsDependency\.__init__\) \[kind: TYPE_ERROR\].*"
    )
    reason = r"websockets can only be used in the entrypoint.*"
    expected_error = location + reason

    with pytest.raises(definitions.ChainsUsageError, match=expected_error):
        with _raise_errors():

            # The websocket chainlet on its own is fine ...
            class Dependency(chains.ChainletBase):
                def run_remote(self, websocket: chains.WebSocketProtocol) -> None:
                    pass

            # ... the error is expected for the chainlet that depends on it.
            class WebsocketAsDependency(chains.ChainletBase):
                def __init__(self, dependency=chains.depends(Dependency)):
                    self._dependency = dependency

                def run_remote(self) -> None:
                    pass
28 changes: 28 additions & 0 deletions truss-chains/tests/websocket_chain/websocket_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import fastapi

import truss_chains as chains


class Dependency(chains.ChainletBase):
    """Chainlet that builds, prints and returns a greeting for a given name."""

    async def run_remote(self, name: str) -> str:
        greeting = f"Hello from dependency, {name}."
        print(greeting)
        return greeting


@chains.mark_entrypoint  # ("My Chain Name")
class Head(chains.ChainletBase):
    """Entrypoint chainlet that answers text messages over a websocket."""

    def __init__(self, dependency=chains.depends(Dependency)):
        self._dependency = dependency

    async def run_remote(self, websocket: chains.WebSocketProtocol) -> None:
        """Reply to each received text message until the client disconnects.

        The message ``"dep"`` is answered with the dependency's result for
        ``"Head"``; any other message is echoed back as ``You said: <text>.``.
        """
        try:
            while True:
                text = await websocket.receive_text()
                if text != "dep":
                    reply = f"You said: {text}."
                else:
                    reply = await self._dependency.run_remote("Head")
                await websocket.send_text(reply)
        except fastapi.WebSocketDisconnect:
            print("Disconnected.")
8 changes: 5 additions & 3 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
RemoteConfig,
RemoteErrorDetail,
RPCOptions,
WebSocketProtocol,
)
from truss_chains.framework import ChainletBase, ModelBase
from truss_chains.public_api import (
Expand All @@ -42,18 +43,19 @@
"Assets",
"BasetenImage",
"ChainletBase",
"ModelBase",
"ChainletOptions",
"Compute",
"CustomImage",
"DeployedServiceDescriptor",
"DeploymentContext",
"DockerImage",
"RPCOptions",
"GenericRemoteException",
"ModelBase",
"RPCOptions",
"RemoteConfig",
"RemoteErrorDetail",
"DeployedServiceDescriptor",
"StubBase",
"WebSocketProtocol",
"depends",
"depends_context",
"make_abs_path_here",
Expand Down
30 changes: 30 additions & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import traceback
from typing import ( # type: ignore[attr-defined] # Chains uses Python >=3.9.
Any,
AsyncIterator,
Callable,
ClassVar,
Generic,
Expand Down Expand Up @@ -576,6 +577,10 @@ def has_pydantic_args(self):
for arg in args
)

@property
def is_websocket(self) -> bool:
    """Whether the raw type held by this descriptor equals `WebSocketProtocol`."""
    raw_type = self.raw
    return raw_type == WebSocketProtocol


class StreamingTypeDescriptor(TypeDescriptor):
origin_type: type
Expand Down Expand Up @@ -613,6 +618,10 @@ def streaming_type(self) -> StreamingTypeDescriptor:
raise ValueError(f"{self} is not a streaming endpoint.")
return cast(StreamingTypeDescriptor, self.output_types[0])

@property
def is_websocket(self) -> bool:
    """Whether any input argument of this endpoint is a websocket.

    Returns:
        True if at least one entry of `self.input_args` has a type whose
        `is_websocket` flag is truthy; False otherwise, including when
        there are no input arguments.
    """
    return any(arg.type.is_websocket for arg in self.input_args)


class DependencyDescriptor(SafeModelNonSerializable):
chainlet_cls: Type[ABCChainlet]
Expand Down Expand Up @@ -754,3 +763,24 @@ class PushOptionsLocalDocker(PushOptions):
# in the docker image (which takes precedence over potential pip/site-packages).
# This should be used for integration tests or quick local dev loops.
use_local_chains_src: bool = False


class WebSocketProtocol(Protocol):
    """Describes the subset of starlette/FastAPI's websocket interface that we expose.

    This is a structural (`Protocol`) type: any object providing these
    attributes and methods satisfies it, without subclassing.
    """

    # Headers associated with the connection, as a string-to-string mapping
    # (presumably the handshake request headers; confirm against starlette).
    headers: Mapping[str, str]

    # Connection lifecycle. `close` defaults to code 1000 and no reason.
    async def accept(self) -> None: ...
    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ...

    # Receive a single message as text, bytes, or a decoded JSON value.
    async def receive_text(self) -> str: ...
    async def receive_bytes(self) -> bytes: ...
    async def receive_json(self) -> Any: ...

    # Send a single message as text, bytes, or a JSON-serializable value.
    async def send_text(self, data: str) -> None: ...
    async def send_bytes(self, data: bytes) -> None: ...
    async def send_json(self, data: Any) -> None: ...

    # Note: these are plain (non-`async`) methods that return async
    # iterators, i.e. use `async for msg in ws.iter_text()` without `await`.
    def iter_text(self) -> AsyncIterator[str]: ...
    def iter_bytes(self) -> AsyncIterator[bytes]: ...
    def iter_json(self) -> AsyncIterator[Any]: ...
42 changes: 32 additions & 10 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,6 @@ def leave_SimpleStatementLine(


def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
"""Generates AST for the `load` method of the truss model."""
imports = {"from truss_chains.remote_chainlet import stub", "import logging"}
stub_args = []
for name, dep in chainlet_descriptor.dependencies.items():
Expand Down Expand Up @@ -460,7 +459,6 @@ def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _So
def _gen_health_check_src(
health_check: definitions.HealthCheckAPIDescriptor,
) -> _Source:
"""Generates AST for the `is_healthy` method of the truss model."""
def_str = "async def" if health_check.is_async else "def"
maybe_await = "await " if health_check.is_async else ""
src = (
Expand All @@ -472,7 +470,6 @@ def _gen_health_check_src(


def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
"""Generates AST for the `predict` method of the truss model."""
imports: set[str] = {
"from truss_chains.remote_chainlet import stub",
"from truss_chains.remote_chainlet import utils",
Expand All @@ -495,7 +492,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
f"request: starlette.requests.Request) -> {output_type_name}:"
)
# Add error handling context manager:
parts.append(_indent("with utils.predict_context(request):"))
parts.append(_indent("with utils.predict_context(request.headers):"))
# Invoke Chainlet.
if (
chainlet_descriptor.endpoint.is_async
Expand All @@ -521,6 +518,18 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
return _Source(src="\n".join(parts), imports=imports)


def _gen_websocket_src() -> _Source:
    """Generates source for the `websocket` method of the truss model.

    The generated method runs the chainlet's `run_remote` with the incoming
    websocket, inside `utils.predict_context` built from the websocket's
    headers.

    Returns:
        A `_Source` holding the method source and the import statements it
        needs (`fastapi` and the remote chainlet `utils`).
    """
    src = """
async def websocket(self, websocket: fastapi.WebSocket) -> None:
    with utils.predict_context(websocket.headers):
        await self._chainlet.run_remote(websocket)
"""
    return _Source(
        src=src,
        # Import strings are collected in sets, so they must be spelled exactly
        # like the same import elsewhere (no trailing whitespace) to de-duplicate.
        imports={"import fastapi", "from truss_chains.remote_chainlet import utils"},
    )


def _gen_truss_chainlet_model(
chainlet_descriptor: definitions.ChainletAPIDescriptor,
) -> _Source:
Expand All @@ -542,12 +551,16 @@ def _gen_truss_chainlet_model(

load_src = _gen_load_src(chainlet_descriptor)
imports.update(load_src.imports)
predict_src = _gen_predict_src(chainlet_descriptor)
imports.update(predict_src.imports)
if chainlet_descriptor.endpoint.is_websocket:
endpoint_src = _gen_websocket_src()
imports.update(endpoint_src.imports)
else:
endpoint_src = _gen_predict_src(chainlet_descriptor)
imports.update(endpoint_src.imports)

new_body: list[Any] = list(class_definition.body.body) + [
libcst.parse_statement(load_src.src),
libcst.parse_statement(predict_src.src),
libcst.parse_statement(endpoint_src.src),
]

if chainlet_descriptor.health_check is not None:
Expand Down Expand Up @@ -581,11 +594,17 @@ def _gen_truss_chainlet_file(
if maybe_stub_src := _gen_stub_src_for_deps(dependencies):
_update_src(maybe_stub_src, src_parts, imports)

input_src = _gen_truss_input_pydantic(chainlet_descriptor)
_update_src(input_src, src_parts, imports)
if not chainlet_descriptor.endpoint.is_streaming:
if not chainlet_descriptor.endpoint.is_websocket:
input_src = _gen_truss_input_pydantic(chainlet_descriptor)
_update_src(input_src, src_parts, imports)

if (
not chainlet_descriptor.endpoint.is_streaming
and not chainlet_descriptor.endpoint.is_websocket
):
output_src = _gen_truss_output_pydantic(chainlet_descriptor)
_update_src(output_src, src_parts, imports)

model_src = _gen_truss_chainlet_model(chainlet_descriptor)
_update_src(model_src, src_parts, imports)

Expand Down Expand Up @@ -670,6 +689,7 @@ def _write_truss_config_yaml(
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
model_name: str,
use_local_chains_src: bool,
is_websocket_endpoint: bool,
):
"""Generate a truss config for a Chainlet."""
config = truss_config.TrussConfig()
Expand All @@ -686,6 +706,7 @@ def _write_truss_config_yaml(
config.resources.accelerator = compute.accelerator
config.resources.use_gpu = bool(compute.accelerator.count)
config.runtime.predict_concurrency = compute.predict_concurrency
config.runtime.is_websocket_endpoint = is_websocket_endpoint
# Image.
_inplace_fill_base_image(chains_config.docker_image, config)
pip_requirements = _make_requirements(chains_config.docker_image)
Expand Down Expand Up @@ -780,6 +801,7 @@ def gen_truss_chainlet(
model_name=model_name or chain_name,
chainlet_to_service=dep_services,
use_local_chains_src=use_local_chains_src,
is_websocket_endpoint=chainlet_descriptor.endpoint.is_websocket,
)
# This assumes all imports are absolute w.r.t chain root (or site-packages).
truss_path.copy_tree_path(
Expand Down
Loading

0 comments on commit 5eeca91

Please sign in to comment.