From 3f2591292d148b6399c61a9105ff2f6bb80974ba Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Mon, 18 Nov 2024 11:57:58 +0000 Subject: [PATCH 1/5] chore: bump aidial-sdk from 0.9.0 to 0.15.0 --- poetry.lock | 10 +++++----- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/poetry.lock b/poetry.lock index ec2542b..b56a884 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,19 +2,19 @@ [[package]] name = "aidial-sdk" -version = "0.9.0" +version = "0.15.0" description = "Framework to create applications and model adapters for AI DIAL" optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "aidial_sdk-0.9.0-py3-none-any.whl", hash = "sha256:4353737f4b3ab2b41461303fca26a718564921d5c25fa8829e09dacab357a976"}, - {file = "aidial_sdk-0.9.0.tar.gz", hash = "sha256:ab611ed46c18ab07fb8584549f9c698599ad8e88ed22a67260c42c0b44eca0c0"}, + {file = "aidial_sdk-0.15.0-py3-none-any.whl", hash = "sha256:7b9b3e5ec9688be2919dcd7dd0312aac807dc7917393ee5f846332713ad2e26a"}, + {file = "aidial_sdk-0.15.0.tar.gz", hash = "sha256:6b47bb36e8c795300e0d4b61308c6a2f86b59abb97905390a02789b343460720"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" fastapi = ">=0.51,<1.0" -httpx = ">=0.25.0,<0.26.0" +httpx = ">=0.25.0,<1.0" opentelemetry-api = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-distro = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-exporter-otlp-proto-grpc = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""} @@ -1897,4 +1897,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "f1d62d5edcd9629b045cc744c403459cdba7129c1048ef5b2e5593870acacd3b" +content-hash = "8d8c524c2e71ec384086fc1b0e3006f22a47473174c761983d52814b2ca8c1c5" diff --git a/pyproject.toml b/pyproject.toml index 3794612..2fd3b1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ uvicorn = "0.23" aiohttp = "3.10.2" openai = "1.32.0" # NOTE: used solely for chat completion response types pydantic = "^1.10.12" -aidial-sdk = {version = "^0.9.0", extras = ["telemetry"]} +aidial-sdk = {version = "^0.15.0", extras = ["telemetry"]} respx = "^0.21.1" [tool.poetry.group.test.dependencies] From 832face4987ebf78e6b5a055f5a598545c498a62 Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Mon, 18 Nov 2024 12:16:31 +0000 Subject: [PATCH 2/5] feat: migrate to HTTPException class from DIAL SDK --- aidial_adapter_dial/app.py | 47 +++++-------- aidial_adapter_dial/utils/exceptions.py | 90 ++++--------------------- aidial_adapter_dial/utils/reflection.py | 8 +-- aidial_adapter_dial/utils/sse_stream.py | 47 ++----------- aidial_adapter_dial/utils/streaming.py | 41 ++--------- 5 files changed, 41 insertions(+), 192 deletions(-) diff --git a/aidial_adapter_dial/app.py b/aidial_adapter_dial/app.py index be715ab..fa6b685 100644 --- a/aidial_adapter_dial/app.py +++ b/aidial_adapter_dial/app.py @@ -2,6 +2,7 @@ import logging from urllib.parse import urlparse +from aidial_sdk.exceptions import InvalidRequestError from aidial_sdk.telemetry.init import init_telemetry from aidial_sdk.telemetry.types import TelemetryConfig from fastapi import FastAPI, Request @@ -14,20 +15,13 @@ from aidial_adapter_dial.transformer import AttachmentTransformer from aidial_adapter_dial.utils.dict import censor_ci_dict from aidial_adapter_dial.utils.env import get_env -from aidial_adapter_dial.utils.exceptions import ( - HTTPException, - dial_exception_decorator, -) +from aidial_adapter_dial.utils.exceptions import dial_exception_decorator from aidial_adapter_dial.utils.http_client import get_http_client from aidial_adapter_dial.utils.log_config import configure_loggers from aidial_adapter_dial.utils.reflection import call_with_extra_body from aidial_adapter_dial.utils.sse_stream import to_openai_sse_stream from aidial_adapter_dial.utils.storage import FileStorage -from aidial_adapter_dial.utils.streaming import ( - amap_stream, - generate_stream, - map_stream, -) +from aidial_adapter_dial.utils.streaming import amap_stream, map_stream app = FastAPI() @@ -73,16 +67,12 @@ async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient": local_dial_api_key = headers.get("api-key", None) if not local_dial_api_key: - raise HTTPException( - status_code=400, - message="The 'api-key' request header is missing", - ) + raise InvalidRequestError("The 'api-key' request header is missing") upstream_endpoint = headers.get(UPSTREAM_ENDPOINT_HEADER, None) if not upstream_endpoint: - raise HTTPException( - status_code=400, - message=f"The {UPSTREAM_ENDPOINT_HEADER!r} request header is missing", + raise InvalidRequestError( + f"The {UPSTREAM_ENDPOINT_HEADER!r} request header is missing" ) remote_dial_url = get_hostname(upstream_endpoint) @@ -90,29 +80,24 @@ async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient": if not remote_dial_api_key: if remote_dial_url != LOCAL_DIAL_URL: - raise HTTPException( - status_code=400, - message=( - f"Given that {UPSTREAM_KEY_HEADER!r} header is missing, " - f"it's expected that hostname of upstream endpoint ({upstream_endpoint!r}) is " - f"the same as the local DIAL URL ({LOCAL_DIAL_URL!r}) " - ), + raise InvalidRequestError( + f"Given that {UPSTREAM_KEY_HEADER!r} header is missing, " + f"it's expected that hostname of upstream endpoint ({upstream_endpoint!r}) is " + f"the same as the local DIAL URL ({LOCAL_DIAL_URL!r}) " ) local_dial_api_key = request.headers.get("api-key") if not local_dial_api_key: - raise HTTPException( - status_code=400, - message="The 'api-key' request header is missing", + raise InvalidRequestError( + "The 'api-key' request header is missing" ) remote_dial_api_key = local_dial_api_key endpoint_suffix = f"/{endpoint_name}" if not upstream_endpoint.endswith(endpoint_suffix): - raise HTTPException( - status_code=400, - message=f"The {UPSTREAM_ENDPOINT_HEADER!r} request header must end with {endpoint_suffix!r}", + raise InvalidRequestError( + f"The {UPSTREAM_ENDPOINT_HEADER!r} request header must end with {endpoint_suffix!r}" ) upstream_endpoint = upstream_endpoint.removesuffix(endpoint_suffix) @@ -192,9 +177,7 @@ async def modify_chunk(chunk: dict) -> dict: chunk_stream = map_stream(lambda obj: obj.to_dict(), response) return StreamingResponse( - to_openai_sse_stream( - amap_stream(modify_chunk, generate_stream(chunk_stream)) - ), + to_openai_sse_stream(amap_stream(modify_chunk, chunk_stream)), media_type="text/event-stream", ) else: diff --git a/aidial_adapter_dial/utils/exceptions.py b/aidial_adapter_dial/utils/exceptions.py index 34476ab..d35755b 100644 --- a/aidial_adapter_dial/utils/exceptions.py +++ b/aidial_adapter_dial/utils/exceptions.py @@ -1,70 +1,14 @@ import logging from functools import wraps -from typing import Optional +from aidial_sdk.exceptions import HTTPException as DialException from fastapi import HTTPException as FastAPIException from openai import APIConnectionError, APIStatusError, APITimeoutError log = logging.getLogger(__name__) -class HTTPException(Exception): - def __init__( - self, - message: str, - status_code: int = 500, - type: str = "runtime_error", - param: Optional[str] = None, - code: Optional[str] = None, - display_message: Optional[str] = None, - ) -> None: - self.message = message - self.status_code = status_code - self.type = type - self.param = param - self.code = code - self.display_message = display_message - - def __repr__(self): - return ( - "%s(message=%r, status_code=%r, type=%r, param=%r, code=%r, display_message=%r)" - % ( - self.__class__.__name__, - self.message, - self.status_code, - self.type, - self.param, - self.code, - self.display_message, - ) - ) - - -def remove_nones(d: dict) -> dict: - return {k: v for k, v in d.items() if v is not None} - - -def create_error( - message: str, - type: Optional[str] = None, - param: Optional[str] = None, - code: Optional[str] = None, - display_message: Optional[str] = None, -): - return { - "error": remove_nones( - { - "message": message, - "type": type, - "param": param, - "code": code, - "display_message": display_message, - } - ) - } - - -def to_dial_exception(e: Exception) -> HTTPException | FastAPIException: +def to_dial_exception(e: Exception) -> DialException | FastAPIException: if isinstance(e, APIStatusError): r = e.response headers = r.headers @@ -79,41 +23,30 @@ def to_dial_exception(e: Exception) -> HTTPException | FastAPIException: ) if isinstance(e, APITimeoutError): - return HTTPException("Request timed out", 504, "timeout") + return DialException("Request timed out", 504, "timeout") if isinstance(e, APIConnectionError): - return HTTPException( + return DialException( "Error communicating with OpenAI", 502, "connection" ) - if isinstance(e, HTTPException): + if isinstance(e, DialException): return e - return HTTPException( + return DialException( status_code=500, type="internal_server_error", message=str(e), - code=None, - param=None, ) -def to_starlette_exception( - e: HTTPException | FastAPIException, +def to_fastapi_exception( + e: DialException | FastAPIException, ) -> FastAPIException: if isinstance(e, FastAPIException): return e - - return FastAPIException( - status_code=e.status_code, - detail=create_error( - message=e.message, - type=e.type, - param=e.param, - code=e.code, - display_message=e.display_message, - ), - ) + else: + return e.to_fastapi_exception() def dial_exception_decorator(func): @@ -126,6 +59,7 @@ async def wrapper(*args, **kwargs): f"caught exception: {type(e).__module__}.{type(e).__name__}" ) dial_exception = to_dial_exception(e) - raise to_starlette_exception(dial_exception) from e + fastapi_exception = to_fastapi_exception(dial_exception) + raise fastapi_exception from e return wrapper diff --git a/aidial_adapter_dial/utils/reflection.py b/aidial_adapter_dial/utils/reflection.py index a300604..23b94ea 100644 --- a/aidial_adapter_dial/utils/reflection.py +++ b/aidial_adapter_dial/utils/reflection.py @@ -1,7 +1,7 @@ import inspect from typing import Any, Callable, Coroutine, TypeVar -from aidial_adapter_dial.utils.exceptions import HTTPException +from aidial_sdk.exceptions import InvalidRequestError T = TypeVar("T") @@ -18,10 +18,8 @@ async def call_with_extra_body( extra_args = actual_args - expected_args if extra_args and "extra_body" not in expected_args: - raise HTTPException( - f"Unrecognized request argument supplied: {extra_args}.", - 400, - "invalid_request_error", + raise InvalidRequestError( + f"Unrecognized request argument supplied: {extra_args}." ) arg["extra_body"] = arg.get("extra_body") or {} diff --git a/aidial_adapter_dial/utils/sse_stream.py b/aidial_adapter_dial/utils/sse_stream.py index c9e28e6..612c1c6 100644 --- a/aidial_adapter_dial/utils/sse_stream.py +++ b/aidial_adapter_dial/utils/sse_stream.py @@ -3,9 +3,8 @@ from typing import Any, AsyncIterator, Mapping from aidial_adapter_dial.utils.exceptions import ( - create_error, to_dial_exception, - to_starlette_exception, + to_fastapi_exception, ) DATA_PREFIX = "data: " @@ -22,43 +21,6 @@ def format_chunk(data: str | Mapping[str, Any]) -> str: END_CHUNK = format_chunk(OPENAI_END_MARKER) -async def parse_openai_sse_stream( - stream: AsyncIterator[bytes], -) -> AsyncIterator[dict]: - async for line in stream: - try: - payload = line.decode("utf-8-sig").lstrip() # type: ignore - except Exception: - yield create_error( - message="Can't decode chunk to a string", type="runtime_error" - ) - return - - if payload.strip() == "": - continue - - if not payload.startswith(DATA_PREFIX): - yield create_error( - message="Invalid chunk format", type="runtime_error" - ) - return - - payload = payload[len(DATA_PREFIX) :] - - if payload.strip() == OPENAI_END_MARKER: - break - - try: - chunk = json.loads(payload) - except json.JSONDecodeError: - yield create_error( - message="Can't parse chunk to JSON", type="runtime_error" - ) - return - - yield chunk - - log = logging.getLogger(__name__) @@ -68,13 +30,14 @@ async def to_openai_sse_stream( try: async for chunk in stream: yield format_chunk(chunk) - yield END_CHUNK except Exception as e: log.exception( f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}" ) dial_exception = to_dial_exception(e) - starlette_exception = to_starlette_exception(dial_exception) + fastapi_exception = to_fastapi_exception(dial_exception) + + yield format_chunk(fastapi_exception.detail) - yield format_chunk(starlette_exception.detail) + yield END_CHUNK diff --git a/aidial_adapter_dial/utils/streaming.py b/aidial_adapter_dial/utils/streaming.py index c845fec..6b371bd 100644 --- a/aidial_adapter_dial/utils/streaming.py +++ b/aidial_adapter_dial/utils/streaming.py @@ -1,41 +1,12 @@ -import logging from typing import AsyncIterator, Awaitable, Callable, Optional, TypeVar -from openai import APIError - -from aidial_adapter_dial.utils.exceptions import create_error - -log = logging.getLogger(__name__) - - -async def generate_stream(stream: AsyncIterator[dict]) -> AsyncIterator[dict]: - try: - async for chunk in stream: - yield chunk - except APIError as e: - log.error(f"error during steaming: {e.body}") - - display_message = None - if e.body is not None and isinstance(e.body, dict): - display_message = e.body.get("display_message", None) - - yield create_error( - message=e.message, - type=e.type, - param=e.param, - code=e.code, - display_message=display_message, - ) - return - - -T = TypeVar("T") -V = TypeVar("V") +_T = TypeVar("_T") +_V = TypeVar("_V") async def map_stream( - func: Callable[[T], Optional[V]], iterator: AsyncIterator[T] -) -> AsyncIterator[V]: + func: Callable[[_T], Optional[_V]], iterator: AsyncIterator[_T] +) -> AsyncIterator[_V]: async for item in iterator: new_item = func(item) if new_item is not None: @@ -43,8 +14,8 @@ async def map_stream( async def amap_stream( - func: Callable[[T], Awaitable[Optional[V]]], iterator: AsyncIterator[T] -) -> AsyncIterator[V]: + func: Callable[[_T], Awaitable[Optional[_V]]], iterator: AsyncIterator[_T] +) -> AsyncIterator[_V]: async for item in iterator: new_item = await func(item) if new_item is not None: From cc6f5e03a724ebee654d6341b31bcb3e995d396c Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Mon, 18 Nov 2024 12:48:58 +0000 Subject: [PATCH 3/5] feat: simplified exceptions module --- aidial_adapter_dial/app.py | 12 +++- aidial_adapter_dial/utils/exceptions.py | 73 +++++++++++++------------ aidial_adapter_dial/utils/sse_stream.py | 6 +- 3 files changed, 50 insertions(+), 41 deletions(-) diff --git a/aidial_adapter_dial/app.py b/aidial_adapter_dial/app.py index fa6b685..e2f23b0 100644 --- a/aidial_adapter_dial/app.py +++ b/aidial_adapter_dial/app.py @@ -15,7 +15,7 @@ from aidial_adapter_dial.transformer import AttachmentTransformer from aidial_adapter_dial.utils.dict import censor_ci_dict from aidial_adapter_dial.utils.env import get_env -from aidial_adapter_dial.utils.exceptions import dial_exception_decorator +from aidial_adapter_dial.utils.exceptions import to_dial_exception from aidial_adapter_dial.utils.http_client import get_http_client from aidial_adapter_dial.utils.log_config import configure_loggers from aidial_adapter_dial.utils.reflection import call_with_extra_body @@ -134,7 +134,6 @@ async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient": @app.post("/embeddings") @app.post("/openai/deployments/{deployment_id:path}/embeddings") -@dial_exception_decorator async def embeddings_proxy(request: Request): body = await request.json() az_client = await AzureClient.parse(request, "embeddings") @@ -148,7 +147,6 @@ async def embeddings_proxy(request: Request): @app.post("/chat/completions") @app.post("/openai/deployments/{deployment_id:path}/chat/completions") -@dial_exception_decorator async def chat_completions_proxy(request: Request): az_client = await AzureClient.parse(request, "chat/completions") @@ -188,6 +186,14 @@ async def modify_chunk(chunk: dict) -> dict: return resp +@app.exception_handler(Exception) +def exception_handler(request: Request, e: Exception): + log.exception(f"caught exception: {type(e).__module__}.{type(e).__name__}") + dial_exception = to_dial_exception(e) + fastapi_response = dial_exception.to_fastapi_response() + return fastapi_response + + @app.get("/health") def health(): return {"status": "ok"} diff --git a/aidial_adapter_dial/utils/exceptions.py b/aidial_adapter_dial/utils/exceptions.py index d35755b..d1f5e2c 100644 --- a/aidial_adapter_dial/utils/exceptions.py +++ b/aidial_adapter_dial/utils/exceptions.py @@ -1,65 +1,68 @@ +import dataclasses import logging -from functools import wraps +from typing import Any from aidial_sdk.exceptions import HTTPException as DialException -from fastapi import HTTPException as FastAPIException +from fastapi.responses import JSONResponse as FastAPIResponse +from httpx import Headers from openai import APIConnectionError, APIStatusError, APITimeoutError log = logging.getLogger(__name__) -def to_dial_exception(e: Exception) -> DialException | FastAPIException: - if isinstance(e, APIStatusError): - r = e.response +@dataclasses.dataclass +class ResponseWrapper: + status_code: int + headers: Headers | None + content: Any + + def to_fastapi_response(self) -> FastAPIResponse: + return FastAPIResponse( + content=self.content, + status_code=self.status_code, + headers=self.headers, + ) + + +def to_dial_exception(exc: Exception) -> DialException | ResponseWrapper: + if isinstance(exc, APIStatusError): + r = exc.response headers = r.headers if "Content-Length" in headers: del headers["Content-Length"] - return FastAPIException( - detail=r.text, + try: + content = r.json() + except Exception: + content = r.text + + return ResponseWrapper( status_code=r.status_code, - headers=dict(headers), + headers=headers, + content=content, ) - if isinstance(e, APITimeoutError): + if isinstance(exc, APITimeoutError): return DialException("Request timed out", 504, "timeout") - if isinstance(e, APIConnectionError): + if isinstance(exc, APIConnectionError): return DialException( "Error communicating with OpenAI", 502, "connection" ) - if isinstance(e, DialException): - return e + if isinstance(exc, DialException): + return exc return DialException( status_code=500, type="internal_server_error", - message=str(e), + message=str(exc), ) -def to_fastapi_exception( - e: DialException | FastAPIException, -) -> FastAPIException: - if isinstance(e, FastAPIException): - return e +def to_json_content(exc: DialException | ResponseWrapper) -> Any: + if isinstance(exc, DialException): + return exc.json_error() else: - return e.to_fastapi_exception() - - -def dial_exception_decorator(func): - @wraps(func) - async def wrapper(*args, **kwargs): - try: - return await func(*args, **kwargs) - except Exception as e: - log.exception( - f"caught exception: {type(e).__module__}.{type(e).__name__}" - ) - dial_exception = to_dial_exception(e) - fastapi_exception = to_fastapi_exception(dial_exception) - raise fastapi_exception from e - - return wrapper + return exc.content diff --git a/aidial_adapter_dial/utils/sse_stream.py b/aidial_adapter_dial/utils/sse_stream.py index 612c1c6..80b3af9 100644 --- a/aidial_adapter_dial/utils/sse_stream.py +++ b/aidial_adapter_dial/utils/sse_stream.py @@ -4,7 +4,7 @@ from aidial_adapter_dial.utils.exceptions import ( to_dial_exception, - to_fastapi_exception, + to_json_content, ) DATA_PREFIX = "data: " @@ -36,8 +36,8 @@ async def to_openai_sse_stream( ) dial_exception = to_dial_exception(e) - fastapi_exception = to_fastapi_exception(dial_exception) + error_chunk = to_json_content(dial_exception) - yield format_chunk(fastapi_exception.detail) + yield format_chunk(error_chunk) yield END_CHUNK From 74b744d8c0eff91516ad5ce564f6b09f050a5e2f Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Mon, 18 Nov 2024 15:52:54 +0000 Subject: [PATCH 4/5] fix: fixes issues with an incorrect handling of Content-Encoding header --- aidial_adapter_dial/utils/exceptions.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/aidial_adapter_dial/utils/exceptions.py b/aidial_adapter_dial/utils/exceptions.py index d1f5e2c..4947d4c 100644 --- a/aidial_adapter_dial/utils/exceptions.py +++ b/aidial_adapter_dial/utils/exceptions.py @@ -29,9 +29,20 @@ def to_dial_exception(exc: Exception) -> DialException | ResponseWrapper: r = exc.response headers = r.headers + # The original content length may have changed + # due to the response modification in the adapter. if "Content-Length" in headers: del headers["Content-Length"] + # httpx library (used by openai) automatically sets + # "Accept-Encoding:gzip,deflate" header in requests to the upstream. + # Therefore, we may receive from the upstream gzip-encoded + # response along with "Content-Encoding:gzip" header. + # We either need to encode the response, or + # remove the "Content-Encoding" header. + if "Content-Encoding" in headers: + del headers["Content-Encoding"] + try: content = r.json() except Exception: From 169ab3bb20e87205ebf9b719aca27d22a566f1fe Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Mon, 18 Nov 2024 16:48:51 +0000 Subject: [PATCH 5/5] fix: reverted Content-Encoding handling --- aidial_adapter_dial/utils/exceptions.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/aidial_adapter_dial/utils/exceptions.py b/aidial_adapter_dial/utils/exceptions.py index 4947d4c..d1f5e2c 100644 --- a/aidial_adapter_dial/utils/exceptions.py +++ b/aidial_adapter_dial/utils/exceptions.py @@ -29,20 +29,9 @@ def to_dial_exception(exc: Exception) -> DialException | ResponseWrapper: r = exc.response headers = r.headers - # The original content length may have changed - # due to the response modification in the adapter. if "Content-Length" in headers: del headers["Content-Length"] - # httpx library (used by openai) automatically sets - # "Accept-Encoding:gzip,deflate" header in requests to the upstream. - # Therefore, we may receive from the upstream gzip-encoded - # response along with "Content-Encoding:gzip" header. - # We either need to encode the response, or - # remove the "Content-Encoding" header. - if "Content-Encoding" in headers: - del headers["Content-Encoding"] - try: content = r.json() except Exception: