diff --git a/truss-chains/tests/test_e2e.py b/truss-chains/tests/test_e2e.py index 174a359af..cefd307a0 100644 --- a/truss-chains/tests/test_e2e.py +++ b/truss-chains/tests/test_e2e.py @@ -5,6 +5,7 @@ import pytest import requests +import websockets from truss.tests.test_testing_utilities_for_other_tests import ( ensure_kill_all, get_container_logs_from_prefix, @@ -20,6 +21,11 @@ TEST_ROOT = Path(__file__).parent.resolve() +@pytest.fixture +def anyio_backend(): + return "asyncio" + + @pytest.mark.integration def test_chain(): with ensure_kill_all(): @@ -319,7 +325,7 @@ def test_custom_health_checks_chain(): assert "Health check failed." not in container_logs # Start failing health checks - response = service.run_remote({"fail": True}) + _ = service.run_remote({"fail": True}) response = requests.get(health_check_url) assert response.status_code == 503 container_logs = get_container_logs_from_prefix(entrypoint.name) @@ -328,3 +334,34 @@ def test_custom_health_checks_chain(): assert response.status_code == 503 container_logs = get_container_logs_from_prefix(entrypoint.name) assert container_logs.count("Health check failed.") == 2 + + +@pytest.mark.integration +async def test_websocket_chain(anyio_backend): + with ensure_kill_all(): + chain_name = "websocket_chain" + chain_root = TEST_ROOT / chain_name / f"{chain_name}.py" + with framework.ChainletImporter.import_target(chain_root) as entrypoint: + service = deployment_client.push( + entrypoint, + options=definitions.PushOptionsLocalDocker( + chain_name=chain_name, + only_generate_trusses=False, + use_local_chains_src=True, + ), + ) + # Get something like `ws://localhost:38605/v1/websocket`. + url = service.run_remote_url.replace("http", "ws").replace( + "v1/models/model:predict", "v1/websocket" + ) + print(url) + logging.warning(url) + time.sleep(1) + async with websockets.connect(url) as websocket: + await websocket.send("Test") + response = await websocket.recv() + assert response == "You said: Test." 
+ + await websocket.send("dep") + response = await websocket.recv() + assert response == "Hello from dependency, Head." diff --git a/truss-chains/tests/test_framework.py b/truss-chains/tests/test_framework.py index adafa70d6..a338f10ee 100644 --- a/truss-chains/tests/test_framework.py +++ b/truss-chains/tests/test_framework.py @@ -583,7 +583,7 @@ async def run_remote(self) -> AsyncIterator: yield "123" -def test_raises_is_healthy_not_a_method(): +def test_raises_is_healthy_not_a_method() -> None: match = rf"{TEST_FILE}:\d+ \(IsHealthyNotMethod\) \[kind: TYPE_ERROR\].* `is_healthy` must be a method." with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): @@ -686,3 +686,50 @@ def test_import_model_requires_single_entrypoint(): with pytest.raises(ValueError, match=match), _raise_errors(): with framework.ModelImporter.import_target(model_src): pass + + +def test_raises_websocket_with_other_args(): + match = ( + rf"{TEST_FILE}:\d+ \(WebsocketWithOtherArgs\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"When using a websocket as input, no other arguments are allowed" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class WebsocketWithOtherArgs(chains.ChainletBase): + def run_remote( + self, websocket: chains.WebSocketProtocol, name: str + ) -> None: + pass + + +def test_raises_websocket_as_output(): + match = ( + rf"{TEST_FILE}:\d+ \(WebsocketOutput\.run_remote\) \[kind: IO_TYPE_ERROR\].*" + r"Websockets cannot be used as output type" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class WebsocketOutput(chains.ChainletBase): + def run_remote(self) -> chains.WebSocketProtocol: ... 
# type: ignore[empty-body] + + +def test_raises_websocket_as_dependency(): + match = ( + rf"{TEST_FILE}:\d+ \(WebsocketAsDependency\.__init__\) \[kind: TYPE_ERROR\].*" + r"websockets can only be used in the entrypoint.*" + ) + + with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors(): + + class Dependency(chains.ChainletBase): + def run_remote(self, websocket: chains.WebSocketProtocol) -> None: + pass + + class WebsocketAsDependency(chains.ChainletBase): + def __init__(self, dependency=chains.depends(Dependency)): + self._dependency = dependency + + def run_remote(self) -> None: + pass diff --git a/truss-chains/tests/websocket_chain/websocket_chain.py b/truss-chains/tests/websocket_chain/websocket_chain.py new file mode 100644 index 000000000..d1b5409c6 --- /dev/null +++ b/truss-chains/tests/websocket_chain/websocket_chain.py @@ -0,0 +1,28 @@ +import fastapi + +import truss_chains as chains + + +class Dependency(chains.ChainletBase): + async def run_remote(self, name: str) -> str: + msg = f"Hello from dependency, {name}." + print(msg) + return msg + + +@chains.mark_entrypoint # ("My Chain Name") +class Head(chains.ChainletBase): + def __init__(self, dependency=chains.depends(Dependency)): + self._dependency = dependency + + async def run_remote(self, websocket: chains.WebSocketProtocol) -> None: + try: + while True: + text = await websocket.receive_text() + if text == "dep": + result = await self._dependency.run_remote("Head") + else: + result = f"You said: {text}." 
+ await websocket.send_text(result) + except fastapi.WebSocketDisconnect: + print("Disconnected.") diff --git a/truss-chains/truss_chains/__init__.py b/truss-chains/truss_chains/__init__.py index 94b62b1e4..bbcf4646b 100644 --- a/truss-chains/truss_chains/__init__.py +++ b/truss-chains/truss_chains/__init__.py @@ -24,6 +24,7 @@ RemoteConfig, RemoteErrorDetail, RPCOptions, + WebSocketProtocol, ) from truss_chains.framework import ChainletBase, ModelBase from truss_chains.public_api import ( @@ -42,18 +43,19 @@ "Assets", "BasetenImage", "ChainletBase", - "ModelBase", "ChainletOptions", "Compute", "CustomImage", + "DeployedServiceDescriptor", "DeploymentContext", "DockerImage", - "RPCOptions", "GenericRemoteException", + "ModelBase", + "RPCOptions", "RemoteConfig", "RemoteErrorDetail", - "DeployedServiceDescriptor", "StubBase", + "WebSocketProtocol", "depends", "depends_context", "make_abs_path_here", diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 7b669d19a..a03318b69 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -6,6 +6,7 @@ import traceback from typing import ( # type: ignore[attr-defined] # Chains uses Python >=3.9. 
Any, + AsyncIterator, Callable, ClassVar, Generic, @@ -576,6 +577,10 @@ def has_pydantic_args(self): for arg in args ) + @property + def is_websocket(self) -> bool: + return self.raw == WebSocketProtocol + class StreamingTypeDescriptor(TypeDescriptor): origin_type: type @@ -613,6 +618,10 @@ def streaming_type(self) -> StreamingTypeDescriptor: raise ValueError(f"{self} is not a streaming endpoint.") return cast(StreamingTypeDescriptor, self.output_types[0]) + @property + def is_websocket(self): + return any(arg.type.is_websocket for arg in self.input_args) + class DependencyDescriptor(SafeModelNonSerializable): chainlet_cls: Type[ABCChainlet] @@ -754,3 +763,24 @@ class PushOptionsLocalDocker(PushOptions): # in the docker image (which takes precedence over potential pip/site-packages). # This should be used for integration tests or quick local dev loops. use_local_chains_src: bool = False + + +class WebSocketProtocol(Protocol): + """Describes subset of starlette/fastAPIs websocket interface that we expose.""" + + headers: Mapping[str, str] + + async def accept(self) -> None: ... + async def close(self, code: int = 1000, reason: Optional[str] = None) -> None: ... + + async def receive_text(self) -> str: ... + async def receive_bytes(self) -> bytes: ... + async def receive_json(self) -> Any: ... + + async def send_text(self, data: str) -> None: ... + async def send_bytes(self, data: bytes) -> None: ... + async def send_json(self, data: Any) -> None: ... + + def iter_text(self) -> AsyncIterator[str]: ... + def iter_bytes(self) -> AsyncIterator[bytes]: ... + def iter_json(self) -> AsyncIterator[Any]: ... 
diff --git a/truss-chains/truss_chains/deployment/code_gen.py b/truss-chains/truss_chains/deployment/code_gen.py index 6f0eebfb5..4d85c1bfa 100644 --- a/truss-chains/truss_chains/deployment/code_gen.py +++ b/truss-chains/truss_chains/deployment/code_gen.py @@ -430,7 +430,6 @@ def leave_SimpleStatementLine( def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source: - """Generates AST for the `load` method of the truss model.""" imports = {"from truss_chains.remote_chainlet import stub", "import logging"} stub_args = [] for name, dep in chainlet_descriptor.dependencies.items(): @@ -460,7 +459,6 @@ def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _So def _gen_health_check_src( health_check: definitions.HealthCheckAPIDescriptor, ) -> _Source: - """Generates AST for the `is_healthy` method of the truss model.""" def_str = "async def" if health_check.is_async else "def" maybe_await = "await " if health_check.is_async else "" src = ( @@ -472,7 +470,6 @@ def _gen_health_check_src( def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source: - """Generates AST for the `predict` method of the truss model.""" imports: set[str] = { "from truss_chains.remote_chainlet import stub", "from truss_chains.remote_chainlet import utils", @@ -495,7 +492,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> f"request: starlette.requests.Request) -> {output_type_name}:" ) # Add error handling context manager: - parts.append(_indent("with utils.predict_context(request):")) + parts.append(_indent("with utils.predict_context(request.headers):")) # Invoke Chainlet. 
if ( chainlet_descriptor.endpoint.is_async @@ -521,6 +518,18 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> return _Source(src="\n".join(parts), imports=imports) +def _gen_websocket_src() -> _Source: + src = """ +async def websocket(self, websocket: fastapi.WebSocket) -> None: + with utils.predict_context(websocket.headers): + await self._chainlet.run_remote(websocket) +""" + return _Source( + src=src, + imports={"import fastapi", "from truss_chains.remote_chainlet import utils"}, + ) + + def _gen_truss_chainlet_model( chainlet_descriptor: definitions.ChainletAPIDescriptor, ) -> _Source: @@ -542,12 +551,16 @@ def _gen_truss_chainlet_model( load_src = _gen_load_src(chainlet_descriptor) imports.update(load_src.imports) - predict_src = _gen_predict_src(chainlet_descriptor) - imports.update(predict_src.imports) + if chainlet_descriptor.endpoint.is_websocket: + endpoint_src = _gen_websocket_src() + imports.update(endpoint_src.imports) + else: + endpoint_src = _gen_predict_src(chainlet_descriptor) + imports.update(endpoint_src.imports) new_body: list[Any] = list(class_definition.body.body) + [ libcst.parse_statement(load_src.src), - libcst.parse_statement(predict_src.src), + libcst.parse_statement(endpoint_src.src), ] if chainlet_descriptor.health_check is not None: @@ -581,11 +594,17 @@ def _gen_truss_chainlet_file( if maybe_stub_src := _gen_stub_src_for_deps(dependencies): _update_src(maybe_stub_src, src_parts, imports) - input_src = _gen_truss_input_pydantic(chainlet_descriptor) - _update_src(input_src, src_parts, imports) - if not chainlet_descriptor.endpoint.is_streaming: + if not chainlet_descriptor.endpoint.is_websocket: + input_src = _gen_truss_input_pydantic(chainlet_descriptor) + _update_src(input_src, src_parts, imports) + + if ( + not chainlet_descriptor.endpoint.is_streaming + and not chainlet_descriptor.endpoint.is_websocket + ): output_src = _gen_truss_output_pydantic(chainlet_descriptor) + _update_src(output_src,
src_parts, imports) + model_src = _gen_truss_chainlet_model(chainlet_descriptor) _update_src(model_src, src_parts, imports) @@ -670,6 +689,7 @@ def _write_truss_config_yaml( chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], model_name: str, use_local_chains_src: bool, + is_websocket_endpoint: bool, ): """Generate a truss config for a Chainlet.""" config = truss_config.TrussConfig() @@ -686,6 +706,7 @@ def _write_truss_config_yaml( config.resources.accelerator = compute.accelerator config.resources.use_gpu = bool(compute.accelerator.count) config.runtime.predict_concurrency = compute.predict_concurrency + config.runtime.is_websocket_endpoint = is_websocket_endpoint # Image. _inplace_fill_base_image(chains_config.docker_image, config) pip_requirements = _make_requirements(chains_config.docker_image) @@ -780,6 +801,7 @@ def gen_truss_chainlet( model_name=model_name or chain_name, chainlet_to_service=dep_services, use_local_chains_src=use_local_chains_src, + is_websocket_endpoint=chainlet_descriptor.endpoint.is_websocket, ) # This assumes all imports are absolute w.r.t chain root (or site-packages). 
truss_path.copy_tree_path( diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 3aab15d9b..f15ccfdb5 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -292,7 +292,7 @@ def _validate_io_type( location, ) return - if annotation in _SIMPLE_TYPES: + if annotation in _SIMPLE_TYPES or annotation == definitions.WebSocketProtocol: return error_msg = ( @@ -445,6 +445,27 @@ def _validate_endpoint_output_types( return output_types +def _validate_websocket_endpoint( + descriptor: definitions.EndpointAPIDescriptor, location: _ErrorLocation +): + if any(arg.is_websocket for arg in descriptor.output_types): + _collect_error( + "Websockets cannot be used as output type.", + _ErrorKind.IO_TYPE_ERROR, + location, + ) + if not any(arg.type.is_websocket for arg in descriptor.input_args): + return + + if len(descriptor.input_args) > 1: + _collect_error( + "When using a websocket as input, no other arguments are allowed.", + _ErrorKind.IO_TYPE_ERROR, + location, + ) + # TODO: add more validations here.. + + def _validate_and_describe_endpoint( cls: Type[definitions.ABCChainlet], location: _ErrorLocation ) -> definitions.EndpointAPIDescriptor: @@ -529,14 +550,15 @@ def _validate_and_describe_endpoint( DeprecationWarning, stacklevel=1, ) - - return definitions.EndpointAPIDescriptor( + descriptor = definitions.EndpointAPIDescriptor( name=cls.endpoint_method_name, input_args=input_args, output_types=output_types, is_async=is_async, is_streaming=is_streaming, ) + _validate_websocket_endpoint(descriptor, location) + return descriptor def _get_generic_class_type(var): @@ -678,7 +700,16 @@ def _validate_dependencies(self, params: list[inspect.Parameter]): _collect_error( f"The same Chainlet class cannot be used multiple times for " f"different arguments. 
Got previously used " - f"`{marker.chainlet_cls}` for `{param.name}`.", + f"`{marker.chainlet_cls.__name__}` for `{param.name}`.", + _ErrorKind.TYPE_ERROR, + self._location, + ) + + if get_descriptor(marker.chainlet_cls).endpoint.is_websocket: + _collect_error( + f"The dependency chainlet `{marker.chainlet_cls.__name__}` for " + f"`{param.name}` uses a websocket. But websockets can only be used " + "in the entrypoint, not in 'inner' chainlets.", _ErrorKind.TYPE_ERROR, self._location, ) diff --git a/truss-chains/truss_chains/remote_chainlet/model_skeleton.py b/truss-chains/truss_chains/remote_chainlet/model_skeleton.py index 92cbeed0b..a6076afbe 100644 --- a/truss-chains/truss_chains/remote_chainlet/model_skeleton.py +++ b/truss-chains/truss_chains/remote_chainlet/model_skeleton.py @@ -55,6 +55,6 @@ def __init__( # def predict( # self, inputs: TextToNumInput, request: starlette.requests.Request # ) -> TextToNumOutput: - # with utils.predict_context(request): + # with utils.predict_context(request.headers): # result = self._chainlet.run_remote(**utils.pydantic_set_field_dict(inputs)) # return TextToNumOutput(result) diff --git a/truss-chains/truss_chains/remote_chainlet/utils.py b/truss-chains/truss_chains/remote_chainlet/utils.py index f123b059c..b6239888b 100644 --- a/truss-chains/truss_chains/remote_chainlet/utils.py +++ b/truss-chains/truss_chains/remote_chainlet/utils.py @@ -14,7 +14,6 @@ import fastapi import httpx import pydantic -import starlette.requests from truss.templates.shared import dynamic_config_resolver from truss_chains import definitions @@ -119,9 +118,9 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None: @contextlib.contextmanager -def _trace_parent(request: starlette.requests.Request) -> Iterator[None]: +def _trace_parent(headers: Mapping[str, str]) -> Iterator[None]: token = _trace_parent_context.set( - request.headers.get(definitions.OTEL_TRACE_PARENT_HEADER_KEY, "") + headers.get(definitions.OTEL_TRACE_PARENT_HEADER_KEY, "") ) try: 
yield @@ -327,6 +326,6 @@ async def async_response_raise_errors( @contextlib.contextmanager -def predict_context(request: starlette.requests.Request) -> Iterator[None]: - with _trace_parent(request), _exception_to_http_error(): +def predict_context(headers: Mapping[str, str]) -> Iterator[None]: + with _trace_parent(headers), _exception_to_http_error(): yield diff --git a/truss/templates/server/common/tracing.py b/truss/templates/server/common/tracing.py index cf6273877..a9bf035a1 100644 --- a/truss/templates/server/common/tracing.py +++ b/truss/templates/server/common/tracing.py @@ -118,8 +118,15 @@ def get_truss_tracer(secrets: secrets_resolver.Secrets, config) -> trace.Tracer: @contextlib.contextmanager -def detach_context() -> Iterator[trace.Context]: - """Breaks opentelemetry's context propagation. +def section_as_event( + span: sdk_trace.Span, section_name: str, detach: bool = False +) -> Iterator[Optional[trace.Context]]: + """Helper to record the start and end of a sections as events and the duration. + + Note that events are much cheaper to create than dedicated spans. + + Optionally detaches the OpenTelemetry context to isolate tracing. + This intentionally breaks opentelemetry's context propagation. The goal is to separate truss-internal tracing instrumentation completely from potential user-defined tracing. Opentelemetry has a global state @@ -128,30 +135,23 @@ def detach_context() -> Iterator[trace.Context]: internal contexts. Therefore, all user code (predict and pre/post-processing) should be wrapped in this context for isolation. """ - current_context = context.get_current() - # Create an invalid tracing context. This forces that tracing code inside this - # context manager creates a new root tracing context. - transient_token = context.attach(trace.set_span_in_context(trace.INVALID_SPAN)) - try: - yield current_context - finally: - # Reattach original context. 
- context.detach(transient_token) - context.attach(current_context) - - -@contextlib.contextmanager -def section_as_event(span: sdk_trace.Span, section_name: str) -> Iterator[None]: - """Helper to record the start and end of a sections as events and the duration. - - Note that events are much cheaper to create than dedicated spans. - """ t0 = time.time() span.add_event(f"start: {section_name}") + detached_ctx = None + transient_token = None + + if detach: + detached_ctx = context.get_current() + transient_token = context.attach(trace.set_span_in_context(trace.INVALID_SPAN)) + try: - yield + yield detached_ctx finally: t1 = time.time() span.add_event( f"done: {section_name}", attributes={ATTR_NAME_DURATION: t1 - t0} ) + if detach: + assert detached_ctx is not None + context.detach(transient_token) + context.attach(detached_ctx) diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 3a3351c11..4fdc0ab21 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -220,7 +220,7 @@ def _is_generator(cls, method: Any): @dataclasses.dataclass class ModelDescriptor: preprocess: Optional[MethodDescriptor] - predict: MethodDescriptor + predict: Optional[MethodDescriptor] # Websocket may replace predict. 
postprocess: Optional[MethodDescriptor] truss_schema: Optional[TrussSchema] setup_environment: Optional[MethodDescriptor] @@ -231,8 +231,13 @@ class ModelDescriptor: @cached_property def skip_input_parsing(self) -> bool: - return self.predict.arg_config == ArgConfig.REQUEST_ONLY and ( - not self.preprocess or self.preprocess.arg_config == ArgConfig.REQUEST_ONLY + return bool( + self.predict + and self.predict.arg_config == ArgConfig.REQUEST_ONLY + and ( + not self.preprocess + or self.preprocess.arg_config == ArgConfig.REQUEST_ONLY + ) ) @classmethod @@ -266,25 +271,6 @@ def _safe_extract_descriptor( @classmethod def from_model(cls, model_cls) -> "ModelDescriptor": - preprocess = cls._safe_extract_descriptor(model_cls, MethodName.PREPROCESS) - predict = cls._safe_extract_descriptor(model_cls, MethodName.PREDICT) - if not predict: - raise errors.ModelDefinitionError( - f"Truss model must have a `{MethodName.PREDICT}` method." - ) - elif preprocess and predict.arg_config == ArgConfig.REQUEST_ONLY: - raise errors.ModelDefinitionError( - f"When using `{MethodName.PREPROCESS}`, the {MethodName.PREDICT} method " - f"cannot only have the request argument (because the result of `{MethodName.PREPROCESS}` " - "would be discarded)." - ) - - postprocess = cls._safe_extract_descriptor(model_cls, MethodName.POSTPROCESS) - if postprocess and postprocess.arg_config == ArgConfig.REQUEST_ONLY: - raise errors.ModelDefinitionError( - f"The `{MethodName.POSTPROCESS}` method cannot only have the request " - f"argument (because the result of `{MethodName.PREDICT}` would be discarded)." 
- ) setup = cls._safe_extract_descriptor(model_cls, MethodName.SETUP_ENVIRONMENT) completions = cls._safe_extract_descriptor(model_cls, MethodName.COMPLETIONS) chats = cls._safe_extract_descriptor(model_cls, MethodName.CHAT_COMPLETIONS) @@ -293,16 +279,53 @@ def from_model(cls, model_cls) -> "ModelDescriptor": raise errors.ModelDefinitionError( f"`{MethodName.IS_HEALTHY}` must have only one argument: `self`." ) - websocket = cls._safe_extract_descriptor(model_cls, MethodName.WEBSOCKET) - if websocket and websocket.arg_config != ArgConfig.INPUTS_ONLY: + predict = cls._safe_extract_descriptor(model_cls, MethodName.PREDICT) + truss_schema, preprocess, postprocess = None, None, None + + if websocket and predict: raise errors.ModelDefinitionError( - f"`{MethodName.WEBSOCKET}` must have only one argument: `websocket`." + f"Truss model cannot have both `{MethodName.PREDICT}` and " + f"`{MethodName.WEBSOCKET}` method." ) - truss_schema = cls._gen_truss_schema( - predict=predict, preprocess=preprocess, postprocess=postprocess - ) + if websocket: + assert predict is None + if websocket.arg_config != ArgConfig.INPUTS_ONLY: + raise errors.ModelDefinitionError( + f"`{MethodName.WEBSOCKET}` must have only one argument: `websocket`." + ) + elif predict: + assert websocket is None + preprocess = cls._safe_extract_descriptor(model_cls, MethodName.PREPROCESS) + if not predict: + raise errors.ModelDefinitionError( + f"Truss model must have a `{MethodName.PREDICT}` method." + ) + if preprocess and predict.arg_config == ArgConfig.REQUEST_ONLY: + raise errors.ModelDefinitionError( + f"When using `{MethodName.PREPROCESS}`, the {MethodName.PREDICT} method " + f"cannot only have the request argument (because the result of " + f"`{MethodName.PREPROCESS}` would be discarded)." 
+ ) + + postprocess = cls._safe_extract_descriptor( + model_cls, MethodName.POSTPROCESS + ) + if postprocess and postprocess.arg_config == ArgConfig.REQUEST_ONLY: + raise errors.ModelDefinitionError( + f"The `{MethodName.POSTPROCESS}` method cannot only have the request " + f"argument (because the result of `{MethodName.PREDICT}` would be discarded)." + ) + + truss_schema = cls._gen_truss_schema( + predict=predict, preprocess=preprocess, postprocess=postprocess + ) + + else: + raise errors.ModelDefinitionError( + f"Truss model must have a `{MethodName.PREDICT}` or `{MethodName.WEBSOCKET}` method." + ) return cls( preprocess=preprocess, @@ -617,6 +640,9 @@ async def _predict( # or, if `postprocessing` is used, anything. In the last case postprocessing # must convert the result to something serializable. descriptor = self.model_descriptor.predict + assert descriptor, ( + f"`{MethodName.PREDICT}` must only be called if model has it." + ) return await self._execute_user_model_fn(inputs, request, descriptor) async def postprocess( @@ -723,10 +749,9 @@ async def _execute_model_endpoint( Wraps the execution of any model code other than `predict`. """ fn_span = self._tracer.start_span(f"call-{descriptor.method_name}") - # TODO(nikhil): Make it easier to start a section with detached context. with tracing.section_as_event( - fn_span, descriptor.method_name - ), tracing.detach_context() as detached_ctx: + fn_span, descriptor.method_name, detach=True + ) as detached_ctx: result = await self._execute_user_model_fn(inputs, request, descriptor) if inspect.isgenerator(result) or inspect.isasyncgen(result): @@ -810,10 +835,7 @@ async def predict( """ if self.model_descriptor.preprocess: with self._tracer.start_as_current_span("call-pre") as span_pre: - # TODO(nikhil): Make it easier to start a section with detached context. 
- with tracing.section_as_event( - span_pre, "preprocess" - ), tracing.detach_context(): + with tracing.section_as_event(span_pre, "preprocess", detach=True): preprocess_result = await self.preprocess(inputs, request) else: preprocess_result = inputs @@ -822,10 +844,9 @@ async def predict( async with deferred_semaphore_and_span( self._predict_semaphore, span_predict ) as get_defer_fn: - # TODO(nikhil): Make it easier to start a section with detached context. with tracing.section_as_event( - span_predict, "predict" - ), tracing.detach_context() as detached_ctx: + span_predict, "predict", detach=True + ) as detached_ctx: # To prevent span pollution, we need to make sure spans created by user # code don't inherit context from our spans (which happens even if # different tracer instances are used). @@ -883,10 +904,7 @@ async def predict( if self.model_descriptor.postprocess: with self._tracer.start_as_current_span("call-post") as span_post: - # TODO(nikhil): Make it easier to start a section with detached context. - with tracing.section_as_event( - span_post, "postprocess" - ), tracing.detach_context(): + with tracing.section_as_event(span_post, "postprocess", detach=True): postprocess_result = await self.postprocess(predict_result, request) return postprocess_result else: diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index ac18ceea6..dfd2d1f23 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -1715,21 +1715,16 @@ def test_custom_openai_endpoints(): Test a Truss that exposes an OpenAI compatible endpoint. 
""" model = """ - from typing import Dict - class Model: - def __init__(self): - pass - def load(self): self._predict_count = 0 self._completions_count = 0 - async def predict(self, inputs: Dict) -> int: + async def predict(self, inputs) -> int: self._predict_count += inputs["increment"] return self._predict_count - async def completions(self, inputs: Dict) -> int: + async def completions(self, inputs) -> int: self._completions_count += inputs["increment"] return self._completions_count """ @@ -1754,16 +1749,10 @@ def test_postprocess_async_generator_streaming(): Test a Truss that exposes an OpenAI compatible endpoint. """ model = """ - from typing import Dict, List, Generator + from typing import List, Generator class Model: - def __init__(self): - pass - - def load(self): - pass - - async def predict(self, inputs: Dict) -> List[str]: + async def predict(self, inputs) -> List[str]: nums: List[int] = inputs["nums"] return nums @@ -1787,16 +1776,10 @@ def test_preprocess_async_generator(): Test a Truss that exposes an OpenAI compatible endpoint. """ model = """ - from typing import Dict, List, AsyncGenerator + from typing import List, AsyncGenerator class Model: - def __init__(self): - pass - - def load(self): - pass - - async def preprocess(self, inputs: Dict) -> AsyncGenerator[str, None]: + async def preprocess(self, inputs) -> AsyncGenerator[str, None]: for num in inputs["nums"]: yield num @@ -1817,20 +1800,14 @@ def test_openai_client_streaming(): Test a Truss that exposes an OpenAI compatible endpoint. 
""" model = """ - from typing import Dict, AsyncGenerator + from typing import AsyncGenerator class Model: - def __init__(self): - pass - - def load(self): - pass - - async def chat_completions(self, inputs: Dict) -> AsyncGenerator[str, None]: + async def chat_completions(self, inputs) -> AsyncGenerator[str, None]: for num in inputs["nums"]: yield num - async def predict(self, inputs: Dict): + async def predict(self, inputs): pass """ with ensure_kill_all(), _temp_truss(model) as tr: @@ -1854,28 +1831,60 @@ async def predict(self, inputs: Dict): @pytest.mark.asyncio @pytest.mark.integration -async def test_websocket_endpoint(): +async def test_raise_predict_and_websocket_endpoint(): model = """ - import fastapi - from typing import Dict, AsyncGenerator - class Model: - def __init__(self): + async def websocket(self, websocket): pass - def load(self): + async def predict(self, inputs): pass + """ + with ensure_kill_all(), _temp_truss(model, "") as tr: + container = tr.docker_run( + local_port=8090, detach=True, wait_for_server_ready=False + ) + time.sleep(1) + _assert_logs_contain_error( + container.logs(), + message="Exception while loading model", + error="cannot have both `predict` and `websocket` method", + ) + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_raise_no_endpoint(): + model = """ + class Model: + pass + """ + with ensure_kill_all(), _temp_truss(model, "") as tr: + container = tr.docker_run( + local_port=8090, detach=True, wait_for_server_ready=False + ) + time.sleep(1) + _assert_logs_contain_error( + container.logs(), + message="Exception while loading model", + error="must have a `predict` or `websocket` method", + ) + + +@pytest.mark.asyncio +@pytest.mark.integration +async def test_websocket_endpoint(): + model = """ + import fastapi + + class Model: async def websocket(self, websocket: fastapi.WebSocket): try: while True: text = await websocket.receive_text() await websocket.send_text(text + " pong") - except WebSocketDisconnect: 
+ except fastapi.WebSocketDisconnect: pass - - async def predict(self, inputs: Dict): - pass """ with ensure_kill_all(), _temp_truss(model) as tr: tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) @@ -1895,15 +1904,9 @@ async def predict(self, inputs: Dict): @pytest.mark.integration async def test_nonexistent_websocket_endpoint(): model = """ - from typing import Dict class Model: - def __init__(self): - pass - - def load(self): - pass - async def predict(self, inputs: Dict): + async def predict(self, inputs): pass """ with ensure_kill_all(), _temp_truss(model) as tr: