diff --git a/aidial_adapter_dial/app.py b/aidial_adapter_dial/app.py
index be715ab..e2f23b0 100644
--- a/aidial_adapter_dial/app.py
+++ b/aidial_adapter_dial/app.py
@@ -2,6 +2,7 @@ import logging
 
 from urllib.parse import urlparse
 
+from aidial_sdk.exceptions import InvalidRequestError
 from aidial_sdk.telemetry.init import init_telemetry
 from aidial_sdk.telemetry.types import TelemetryConfig
 from fastapi import FastAPI, Request
@@ -14,20 +15,13 @@
 from aidial_adapter_dial.transformer import AttachmentTransformer
 from aidial_adapter_dial.utils.dict import censor_ci_dict
 from aidial_adapter_dial.utils.env import get_env
-from aidial_adapter_dial.utils.exceptions import (
-    HTTPException,
-    dial_exception_decorator,
-)
+from aidial_adapter_dial.utils.exceptions import to_dial_exception
 from aidial_adapter_dial.utils.http_client import get_http_client
 from aidial_adapter_dial.utils.log_config import configure_loggers
 from aidial_adapter_dial.utils.reflection import call_with_extra_body
 from aidial_adapter_dial.utils.sse_stream import to_openai_sse_stream
 from aidial_adapter_dial.utils.storage import FileStorage
-from aidial_adapter_dial.utils.streaming import (
-    amap_stream,
-    generate_stream,
-    map_stream,
-)
+from aidial_adapter_dial.utils.streaming import amap_stream, map_stream
 
 app = FastAPI()
 
@@ -73,16 +67,12 @@ async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient":
 
         local_dial_api_key = headers.get("api-key", None)
         if not local_dial_api_key:
-            raise HTTPException(
-                status_code=400,
-                message="The 'api-key' request header is missing",
-            )
+            raise InvalidRequestError("The 'api-key' request header is missing")
 
         upstream_endpoint = headers.get(UPSTREAM_ENDPOINT_HEADER, None)
         if not upstream_endpoint:
-            raise HTTPException(
-                status_code=400,
-                message=f"The {UPSTREAM_ENDPOINT_HEADER!r} request header is missing",
+            raise InvalidRequestError(
+                f"The {UPSTREAM_ENDPOINT_HEADER!r} request header is missing"
             )
 
         remote_dial_url = get_hostname(upstream_endpoint)
@@ -90,29 +80,24 @@ async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient":
 
         if not remote_dial_api_key:
             if remote_dial_url != LOCAL_DIAL_URL:
-                raise HTTPException(
-                    status_code=400,
-                    message=(
-                        f"Given that {UPSTREAM_KEY_HEADER!r} header is missing, "
-                        f"it's expected that hostname of upstream endpoint ({upstream_endpoint!r}) is "
-                        f"the same as the local DIAL URL ({LOCAL_DIAL_URL!r}) "
-                    ),
+                raise InvalidRequestError(
+                    f"Given that {UPSTREAM_KEY_HEADER!r} header is missing, "
+                    f"it's expected that hostname of upstream endpoint ({upstream_endpoint!r}) is "
+                    f"the same as the local DIAL URL ({LOCAL_DIAL_URL!r}) "
                 )
 
             local_dial_api_key = request.headers.get("api-key")
             if not local_dial_api_key:
-                raise HTTPException(
-                    status_code=400,
-                    message="The 'api-key' request header is missing",
+                raise InvalidRequestError(
+                    "The 'api-key' request header is missing"
                 )
 
             remote_dial_api_key = local_dial_api_key
 
         endpoint_suffix = f"/{endpoint_name}"
         if not upstream_endpoint.endswith(endpoint_suffix):
-            raise HTTPException(
-                status_code=400,
-                message=f"The {UPSTREAM_ENDPOINT_HEADER!r} request header must end with {endpoint_suffix!r}",
+            raise InvalidRequestError(
+                f"The {UPSTREAM_ENDPOINT_HEADER!r} request header must end with {endpoint_suffix!r}"
             )
 
         upstream_endpoint = upstream_endpoint.removesuffix(endpoint_suffix)
@@ -149,7 +134,6 @@ async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient":
 
 @app.post("/embeddings")
 @app.post("/openai/deployments/{deployment_id:path}/embeddings")
-@dial_exception_decorator
 async def embeddings_proxy(request: Request):
     body = await request.json()
     az_client = await AzureClient.parse(request, "embeddings")
@@ -163,7 +147,6 @@ async def embeddings_proxy(request: Request):
 
 @app.post("/chat/completions")
 @app.post("/openai/deployments/{deployment_id:path}/chat/completions")
-@dial_exception_decorator
 async def chat_completions_proxy(request: Request):
     az_client = await AzureClient.parse(request, "chat/completions")
 
@@ -192,9 +175,7 @@ async def modify_chunk(chunk: dict) -> dict:
 
         chunk_stream = map_stream(lambda obj: obj.to_dict(), response)
         return StreamingResponse(
-            to_openai_sse_stream(
-                amap_stream(modify_chunk, generate_stream(chunk_stream))
-            ),
+            to_openai_sse_stream(amap_stream(modify_chunk, chunk_stream)),
             media_type="text/event-stream",
         )
     else:
@@ -205,6 +186,14 @@ async def modify_chunk(chunk: dict) -> dict:
         return resp
 
 
+@app.exception_handler(Exception)
+def exception_handler(request: Request, e: Exception):
+    log.exception(f"caught exception: {type(e).__module__}.{type(e).__name__}")
+    dial_exception = to_dial_exception(e)
+    fastapi_response = dial_exception.to_fastapi_response()
+    return fastapi_response
+
+
 @app.get("/health")
 def health():
     return {"status": "ok"}
diff --git a/aidial_adapter_dial/utils/exceptions.py b/aidial_adapter_dial/utils/exceptions.py
index 34476ab..d1f5e2c 100644
--- a/aidial_adapter_dial/utils/exceptions.py
+++ b/aidial_adapter_dial/utils/exceptions.py
@@ -1,131 +1,68 @@
+import dataclasses
 import logging
-from functools import wraps
-from typing import Optional
+from typing import Any
 
-from fastapi import HTTPException as FastAPIException
+from aidial_sdk.exceptions import HTTPException as DialException
+from fastapi.responses import JSONResponse as FastAPIResponse
+from httpx import Headers
 from openai import APIConnectionError, APIStatusError, APITimeoutError
 
 log = logging.getLogger(__name__)
 
 
-class HTTPException(Exception):
-    def __init__(
-        self,
-        message: str,
-        status_code: int = 500,
-        type: str = "runtime_error",
-        param: Optional[str] = None,
-        code: Optional[str] = None,
-        display_message: Optional[str] = None,
-    ) -> None:
-        self.message = message
-        self.status_code = status_code
-        self.type = type
-        self.param = param
-        self.code = code
-        self.display_message = display_message
-
-    def __repr__(self):
-        return (
-            "%s(message=%r, status_code=%r, type=%r, param=%r, code=%r, display_message=%r)"
-            % (
-                self.__class__.__name__,
-                self.message,
-                self.status_code,
-                self.type,
-                self.param,
-                self.code,
-                self.display_message,
-            )
-        )
-
+@dataclasses.dataclass
+class ResponseWrapper:
+    status_code: int
+    headers: Headers | None
+    content: Any
 
-def remove_nones(d: dict) -> dict:
-    return {k: v for k, v in d.items() if v is not None}
-
-
-def create_error(
-    message: str,
-    type: Optional[str] = None,
-    param: Optional[str] = None,
-    code: Optional[str] = None,
-    display_message: Optional[str] = None,
-):
-    return {
-        "error": remove_nones(
-            {
-                "message": message,
-                "type": type,
-                "param": param,
-                "code": code,
-                "display_message": display_message,
-            }
+    def to_fastapi_response(self) -> FastAPIResponse:
+        return FastAPIResponse(
+            content=self.content,
+            status_code=self.status_code,
+            headers=self.headers,
         )
-    }
 
 
-def to_dial_exception(e: Exception) -> HTTPException | FastAPIException:
-    if isinstance(e, APIStatusError):
-        r = e.response
+def to_dial_exception(exc: Exception) -> DialException | ResponseWrapper:
+    if isinstance(exc, APIStatusError):
+        r = exc.response
 
         headers = r.headers
         if "Content-Length" in headers:
             del headers["Content-Length"]
 
-        return FastAPIException(
-            detail=r.text,
+        try:
+            content = r.json()
+        except Exception:
+            content = r.text
+
+        return ResponseWrapper(
             status_code=r.status_code,
-            headers=dict(headers),
+            headers=headers,
+            content=content,
        )
 
-    if isinstance(e, APITimeoutError):
-        return HTTPException("Request timed out", 504, "timeout")
+    if isinstance(exc, APITimeoutError):
+        return DialException("Request timed out", 504, "timeout")
 
-    if isinstance(e, APIConnectionError):
-        return HTTPException(
+    if isinstance(exc, APIConnectionError):
+        return DialException(
             "Error communicating with OpenAI", 502, "connection"
         )
 
-    if isinstance(e, HTTPException):
-        return e
+    if isinstance(exc, DialException):
+        return exc
 
-    return HTTPException(
+    return DialException(
         status_code=500,
         type="internal_server_error",
-        message=str(e),
-        code=None,
-        param=None,
+        message=str(exc),
     )
 
 
-def to_starlette_exception(
-    e: HTTPException | FastAPIException,
-) -> FastAPIException:
-    if isinstance(e, FastAPIException):
-        return e
-
-    return FastAPIException(
-        status_code=e.status_code,
-        detail=create_error(
-            message=e.message,
-            type=e.type,
-            param=e.param,
-            code=e.code,
-            display_message=e.display_message,
-        ),
-    )
-
-
-def dial_exception_decorator(func):
-    @wraps(func)
-    async def wrapper(*args, **kwargs):
-        try:
-            return await func(*args, **kwargs)
-        except Exception as e:
-            log.exception(
-                f"caught exception: {type(e).__module__}.{type(e).__name__}"
-            )
-            dial_exception = to_dial_exception(e)
-            raise to_starlette_exception(dial_exception) from e
-
-    return wrapper
+def to_json_content(exc: DialException | ResponseWrapper) -> Any:
+    if isinstance(exc, DialException):
+        return exc.json_error()
+    else:
+        return exc.content
diff --git a/aidial_adapter_dial/utils/reflection.py b/aidial_adapter_dial/utils/reflection.py
index a300604..23b94ea 100644
--- a/aidial_adapter_dial/utils/reflection.py
+++ b/aidial_adapter_dial/utils/reflection.py
@@ -1,7 +1,7 @@
 import inspect
 from typing import Any, Callable, Coroutine, TypeVar
 
-from aidial_adapter_dial.utils.exceptions import HTTPException
+from aidial_sdk.exceptions import InvalidRequestError
 
 T = TypeVar("T")
 
@@ -18,10 +18,8 @@ async def call_with_extra_body(
     extra_args = actual_args - expected_args
 
     if extra_args and "extra_body" not in expected_args:
-        raise HTTPException(
-            f"Unrecognized request argument supplied: {extra_args}.",
-            400,
-            "invalid_request_error",
+        raise InvalidRequestError(
+            f"Unrecognized request argument supplied: {extra_args}."
         )
 
     arg["extra_body"] = arg.get("extra_body") or {}
diff --git a/aidial_adapter_dial/utils/sse_stream.py b/aidial_adapter_dial/utils/sse_stream.py
index c9e28e6..80b3af9 100644
--- a/aidial_adapter_dial/utils/sse_stream.py
+++ b/aidial_adapter_dial/utils/sse_stream.py
@@ -3,9 +3,8 @@
 from typing import Any, AsyncIterator, Mapping
 
 from aidial_adapter_dial.utils.exceptions import (
-    create_error,
     to_dial_exception,
-    to_starlette_exception,
+    to_json_content,
 )
 
 DATA_PREFIX = "data: "
@@ -22,43 +21,6 @@ def format_chunk(data: str | Mapping[str, Any]) -> str:
 END_CHUNK = format_chunk(OPENAI_END_MARKER)
 
 
-async def parse_openai_sse_stream(
-    stream: AsyncIterator[bytes],
-) -> AsyncIterator[dict]:
-    async for line in stream:
-        try:
-            payload = line.decode("utf-8-sig").lstrip()  # type: ignore
-        except Exception:
-            yield create_error(
-                message="Can't decode chunk to a string", type="runtime_error"
-            )
-            return
-
-        if payload.strip() == "":
-            continue
-
-        if not payload.startswith(DATA_PREFIX):
-            yield create_error(
-                message="Invalid chunk format", type="runtime_error"
-            )
-            return
-
-        payload = payload[len(DATA_PREFIX) :]
-
-        if payload.strip() == OPENAI_END_MARKER:
-            break
-
-        try:
-            chunk = json.loads(payload)
-        except json.JSONDecodeError:
-            yield create_error(
-                message="Can't parse chunk to JSON", type="runtime_error"
-            )
-            return
-
-        yield chunk
-
-
 log = logging.getLogger(__name__)
 
 
@@ -68,13 +30,14 @@ async def to_openai_sse_stream(
     try:
         async for chunk in stream:
            yield format_chunk(chunk)
-        yield END_CHUNK
     except Exception as e:
         log.exception(
             f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}"
         )
 
         dial_exception = to_dial_exception(e)
-        starlette_exception = to_starlette_exception(dial_exception)
+        error_chunk = to_json_content(dial_exception)
+
+        yield format_chunk(error_chunk)
 
-        yield format_chunk(starlette_exception.detail)
+    yield END_CHUNK
diff --git a/aidial_adapter_dial/utils/streaming.py b/aidial_adapter_dial/utils/streaming.py
index c845fec..6b371bd 100644
--- a/aidial_adapter_dial/utils/streaming.py
+++ b/aidial_adapter_dial/utils/streaming.py
@@ -1,41 +1,12 @@
-import logging
 from typing import AsyncIterator, Awaitable, Callable, Optional, TypeVar
 
-from openai import APIError
-
-from aidial_adapter_dial.utils.exceptions import create_error
-
-log = logging.getLogger(__name__)
-
-
-async def generate_stream(stream: AsyncIterator[dict]) -> AsyncIterator[dict]:
-    try:
-        async for chunk in stream:
-            yield chunk
-    except APIError as e:
-        log.error(f"error during steaming: {e.body}")
-
-        display_message = None
-        if e.body is not None and isinstance(e.body, dict):
-            display_message = e.body.get("display_message", None)
-
-        yield create_error(
-            message=e.message,
-            type=e.type,
-            param=e.param,
-            code=e.code,
-            display_message=display_message,
-        )
-        return
-
-
-T = TypeVar("T")
-V = TypeVar("V")
+_T = TypeVar("_T")
+_V = TypeVar("_V")
 
 
 async def map_stream(
-    func: Callable[[T], Optional[V]], iterator: AsyncIterator[T]
-) -> AsyncIterator[V]:
+    func: Callable[[_T], Optional[_V]], iterator: AsyncIterator[_T]
+) -> AsyncIterator[_V]:
     async for item in iterator:
         new_item = func(item)
         if new_item is not None:
@@ -43,8 +14,8 @@ async def map_stream(
 
 
 async def amap_stream(
-    func: Callable[[T], Awaitable[Optional[V]]], iterator: AsyncIterator[T]
-) -> AsyncIterator[V]:
+    func: Callable[[_T], Awaitable[Optional[_V]]], iterator: AsyncIterator[_T]
+) -> AsyncIterator[_V]:
     async for item in iterator:
         new_item = await func(item)
         if new_item is not None:
diff --git a/poetry.lock b/poetry.lock
index ec2542b..b56a884 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2,19 +2,19 @@
 
 [[package]]
 name = "aidial-sdk"
-version = "0.9.0"
+version = "0.15.0"
 description = "Framework to create applications and model adapters for AI DIAL"
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "aidial_sdk-0.9.0-py3-none-any.whl", hash = "sha256:4353737f4b3ab2b41461303fca26a718564921d5c25fa8829e09dacab357a976"},
-    {file = "aidial_sdk-0.9.0.tar.gz", hash = "sha256:ab611ed46c18ab07fb8584549f9c698599ad8e88ed22a67260c42c0b44eca0c0"},
+    {file = "aidial_sdk-0.15.0-py3-none-any.whl", hash = "sha256:7b9b3e5ec9688be2919dcd7dd0312aac807dc7917393ee5f846332713ad2e26a"},
+    {file = "aidial_sdk-0.15.0.tar.gz", hash = "sha256:6b47bb36e8c795300e0d4b61308c6a2f86b59abb97905390a02789b343460720"},
 ]
 
 [package.dependencies]
 aiohttp = ">=3.8.3,<4.0.0"
 fastapi = ">=0.51,<1.0"
-httpx = ">=0.25.0,<0.26.0"
+httpx = ">=0.25.0,<1.0"
 opentelemetry-api = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""}
 opentelemetry-distro = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""}
 opentelemetry-exporter-otlp-proto-grpc = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""}
@@ -1897,4 +1897,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.11,<3.13"
-content-hash = "f1d62d5edcd9629b045cc744c403459cdba7129c1048ef5b2e5593870acacd3b"
+content-hash = "8d8c524c2e71ec384086fc1b0e3006f22a47473174c761983d52814b2ca8c1c5"
diff --git a/pyproject.toml b/pyproject.toml
index 3794612..2fd3b1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ uvicorn = "0.23"
 aiohttp = "3.10.2"
 openai = "1.32.0" # NOTE: used solely for chat completion response types
 pydantic = "^1.10.12"
-aidial-sdk = {version = "^0.9.0", extras = ["telemetry"]}
+aidial-sdk = {version = "^0.15.0", extras = ["telemetry"]}
 respx = "^0.21.1"
 
 [tool.poetry.group.test.dependencies]