Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: openai errors are correctly proxied #28

Merged
merged 5 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 23 additions & 34 deletions aidial_adapter_dial/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from urllib.parse import urlparse

from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.telemetry.init import init_telemetry
from aidial_sdk.telemetry.types import TelemetryConfig
from fastapi import FastAPI, Request
Expand All @@ -14,20 +15,13 @@
from aidial_adapter_dial.transformer import AttachmentTransformer
from aidial_adapter_dial.utils.dict import censor_ci_dict
from aidial_adapter_dial.utils.env import get_env
from aidial_adapter_dial.utils.exceptions import (
HTTPException,
dial_exception_decorator,
)
from aidial_adapter_dial.utils.exceptions import to_dial_exception
from aidial_adapter_dial.utils.http_client import get_http_client
from aidial_adapter_dial.utils.log_config import configure_loggers
from aidial_adapter_dial.utils.reflection import call_with_extra_body
from aidial_adapter_dial.utils.sse_stream import to_openai_sse_stream
from aidial_adapter_dial.utils.storage import FileStorage
from aidial_adapter_dial.utils.streaming import (
amap_stream,
generate_stream,
map_stream,
)
from aidial_adapter_dial.utils.streaming import amap_stream, map_stream

app = FastAPI()

Expand Down Expand Up @@ -73,46 +67,37 @@ async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient":

local_dial_api_key = headers.get("api-key", None)
if not local_dial_api_key:
raise HTTPException(
status_code=400,
message="The 'api-key' request header is missing",
)
raise InvalidRequestError("The 'api-key' request header is missing")

upstream_endpoint = headers.get(UPSTREAM_ENDPOINT_HEADER, None)
if not upstream_endpoint:
raise HTTPException(
status_code=400,
message=f"The {UPSTREAM_ENDPOINT_HEADER!r} request header is missing",
raise InvalidRequestError(
f"The {UPSTREAM_ENDPOINT_HEADER!r} request header is missing"
)

remote_dial_url = get_hostname(upstream_endpoint)
remote_dial_api_key = headers.get(UPSTREAM_KEY_HEADER, None)

if not remote_dial_api_key:
if remote_dial_url != LOCAL_DIAL_URL:
raise HTTPException(
status_code=400,
message=(
f"Given that {UPSTREAM_KEY_HEADER!r} header is missing, "
f"it's expected that hostname of upstream endpoint ({upstream_endpoint!r}) is "
f"the same as the local DIAL URL ({LOCAL_DIAL_URL!r}) "
),
raise InvalidRequestError(
f"Given that {UPSTREAM_KEY_HEADER!r} header is missing, "
f"it's expected that hostname of upstream endpoint ({upstream_endpoint!r}) is "
f"the same as the local DIAL URL ({LOCAL_DIAL_URL!r}) "
)

local_dial_api_key = request.headers.get("api-key")
if not local_dial_api_key:
raise HTTPException(
status_code=400,
message="The 'api-key' request header is missing",
raise InvalidRequestError(
"The 'api-key' request header is missing"
)

remote_dial_api_key = local_dial_api_key

endpoint_suffix = f"/{endpoint_name}"
if not upstream_endpoint.endswith(endpoint_suffix):
raise HTTPException(
status_code=400,
message=f"The {UPSTREAM_ENDPOINT_HEADER!r} request header must end with {endpoint_suffix!r}",
raise InvalidRequestError(
f"The {UPSTREAM_ENDPOINT_HEADER!r} request header must end with {endpoint_suffix!r}"
)
upstream_endpoint = upstream_endpoint.removesuffix(endpoint_suffix)

Expand Down Expand Up @@ -149,7 +134,6 @@ async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient":

@app.post("/embeddings")
@app.post("/openai/deployments/{deployment_id:path}/embeddings")
@dial_exception_decorator
async def embeddings_proxy(request: Request):
body = await request.json()
az_client = await AzureClient.parse(request, "embeddings")
Expand All @@ -163,7 +147,6 @@ async def embeddings_proxy(request: Request):

@app.post("/chat/completions")
@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
@dial_exception_decorator
async def chat_completions_proxy(request: Request):

az_client = await AzureClient.parse(request, "chat/completions")
Expand Down Expand Up @@ -192,9 +175,7 @@ async def modify_chunk(chunk: dict) -> dict:

chunk_stream = map_stream(lambda obj: obj.to_dict(), response)
return StreamingResponse(
to_openai_sse_stream(
amap_stream(modify_chunk, generate_stream(chunk_stream))
),
to_openai_sse_stream(amap_stream(modify_chunk, chunk_stream)),
media_type="text/event-stream",
)
else:
Expand All @@ -205,6 +186,14 @@ async def modify_chunk(chunk: dict) -> dict:
return resp


@app.exception_handler(Exception)
def exception_handler(request: Request, e: Exception):
    """Turn any otherwise-unhandled exception into an HTTP error response.

    The exception is logged, mapped through ``to_dial_exception`` and the
    mapped object's ``to_fastapi_response()`` result is returned.

    Args:
        request: The request being processed when ``e`` was raised (unused).
        e: The exception that escaped the route handler.

    Returns:
        The response produced by ``to_fastapi_response()``.
    """
    # Hand the exception to the logger explicitly. ``log.exception`` on its own
    # reads ``sys.exc_info()``, which is only populated inside the ``except``
    # block (and thread) that caught the error. NOTE(review): Starlette appears
    # to run synchronous handlers like this one in a worker thread, where that
    # state is empty and the traceback would be lost -- confirm against the
    # installed Starlette version.
    log.exception(
        f"caught exception: {type(e).__module__}.{type(e).__name__}",
        exc_info=e,
    )
    return to_dial_exception(e).to_fastapi_response()


@app.get("/health")
def health():
    """Health-check endpoint: always answers with a constant ``ok`` status."""
    status_report = {"status": "ok"}
    return status_report
141 changes: 39 additions & 102 deletions aidial_adapter_dial/utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,131 +1,68 @@
import dataclasses
import logging
from functools import wraps
from typing import Optional
from typing import Any

from fastapi import HTTPException as FastAPIException
from aidial_sdk.exceptions import HTTPException as DialException
from fastapi.responses import JSONResponse as FastAPIResponse
from httpx import Headers
from openai import APIConnectionError, APIStatusError, APITimeoutError

log = logging.getLogger(__name__)


class HTTPException(Exception):
    """Error carrying an HTTP status code and structured error fields.

    Attributes are stored exactly as given: ``message``, ``status_code``,
    ``type``, ``param``, ``code`` and ``display_message``.
    """

    def __init__(
        self,
        message: str,
        status_code: int = 500,
        type: str = "runtime_error",
        param: Optional[str] = None,
        code: Optional[str] = None,
        display_message: Optional[str] = None,
    ) -> None:
        # Initialise the base class with the message so that ``str(exc)`` and
        # ``exc.args`` are the same whether the caller used positional or
        # keyword arguments (previously ``str(exc)`` was empty for keyword
        # construction and a tuple dump for positional construction).
        super().__init__(message)
        self.message = message
        self.status_code = status_code
        self.type = type
        self.param = param
        self.code = code
        self.display_message = display_message

    def __repr__(self):
        """Return ``ClassName(field=value, ...)`` listing every field."""
        return (
            f"{self.__class__.__name__}("
            f"message={self.message!r}, "
            f"status_code={self.status_code!r}, "
            f"type={self.type!r}, "
            f"param={self.param!r}, "
            f"code={self.code!r}, "
            f"display_message={self.display_message!r})"
        )

@dataclasses.dataclass
class ResponseWrapper:
status_code: int
headers: Headers | None
content: Any

def remove_nones(d: dict) -> dict:
    """Return a new dict holding the entries of ``d`` whose value is not ``None``.

    Other falsy values (``0``, ``""``, ``False``) are kept; ``d`` itself is
    left untouched.
    """
    kept = {}
    for key, value in d.items():
        if value is None:
            continue
        kept[key] = value
    return kept


def create_error(
message: str,
type: Optional[str] = None,
param: Optional[str] = None,
code: Optional[str] = None,
display_message: Optional[str] = None,
):
return {
"error": remove_nones(
{
"message": message,
"type": type,
"param": param,
"code": code,
"display_message": display_message,
}
def to_fastapi_response(self) -> FastAPIResponse:
return FastAPIResponse(
content=self.content,
status_code=self.status_code,
headers=self.headers,
)
}


def to_dial_exception(e: Exception) -> HTTPException | FastAPIException:
if isinstance(e, APIStatusError):
r = e.response
def to_dial_exception(exc: Exception) -> DialException | ResponseWrapper:
if isinstance(exc, APIStatusError):
r = exc.response
headers = r.headers

if "Content-Length" in headers:
del headers["Content-Length"]

return FastAPIException(
detail=r.text,
try:
content = r.json()
except Exception:
content = r.text

return ResponseWrapper(
status_code=r.status_code,
headers=dict(headers),
headers=headers,
content=content,
)

if isinstance(e, APITimeoutError):
return HTTPException("Request timed out", 504, "timeout")
if isinstance(exc, APITimeoutError):
return DialException("Request timed out", 504, "timeout")

if isinstance(e, APIConnectionError):
return HTTPException(
if isinstance(exc, APIConnectionError):
return DialException(
"Error communicating with OpenAI", 502, "connection"
)

if isinstance(e, HTTPException):
return e
if isinstance(exc, DialException):
return exc

return HTTPException(
return DialException(
status_code=500,
type="internal_server_error",
message=str(e),
code=None,
param=None,
message=str(exc),
)


def to_starlette_exception(
    e: HTTPException | FastAPIException,
) -> FastAPIException:
    """Convert ``e`` into a ``FastAPIException``.

    A ``FastAPIException`` is returned unchanged. Anything else is wrapped in
    a new ``FastAPIException`` that reuses ``e.status_code`` and whose
    ``detail`` is the ``create_error`` payload built from the fields of ``e``.
    """
    if not isinstance(e, FastAPIException):
        status = e.status_code
        error_body = create_error(
            message=e.message,
            type=e.type,
            param=e.param,
            code=e.code,
            display_message=e.display_message,
        )
        return FastAPIException(status_code=status, detail=error_body)

    return e


def dial_exception_decorator(func):
    """Wrap the coroutine function ``func`` with uniform error translation.

    Any ``Exception`` raised while awaiting ``func`` is logged, mapped with
    ``to_dial_exception`` and then ``to_starlette_exception``, and the result
    is raised with the original exception chained as its cause.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            outcome = await func(*args, **kwargs)
        except Exception as e:
            log.exception(
                f"caught exception: {type(e).__module__}.{type(e).__name__}"
            )
            raise to_starlette_exception(to_dial_exception(e)) from e
        return outcome

    return wrapper
def to_json_content(exc: DialException | ResponseWrapper) -> Any:
    """Return the body to send for ``exc``.

    For a ``DialException`` this is the result of ``exc.json_error()``;
    for anything else it is the object's ``content`` attribute.
    """
    if not isinstance(exc, DialException):
        return exc.content
    return exc.json_error()
8 changes: 3 additions & 5 deletions aidial_adapter_dial/utils/reflection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from typing import Any, Callable, Coroutine, TypeVar

from aidial_adapter_dial.utils.exceptions import HTTPException
from aidial_sdk.exceptions import InvalidRequestError

T = TypeVar("T")

Expand All @@ -18,10 +18,8 @@ async def call_with_extra_body(
extra_args = actual_args - expected_args

if extra_args and "extra_body" not in expected_args:
raise HTTPException(
f"Unrecognized request argument supplied: {extra_args}.",
400,
"invalid_request_error",
raise InvalidRequestError(
f"Unrecognized request argument supplied: {extra_args}."
)

arg["extra_body"] = arg.get("extra_body") or {}
Expand Down
47 changes: 5 additions & 42 deletions aidial_adapter_dial/utils/sse_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from typing import Any, AsyncIterator, Mapping

from aidial_adapter_dial.utils.exceptions import (
create_error,
to_dial_exception,
to_starlette_exception,
to_json_content,
)

DATA_PREFIX = "data: "
Expand All @@ -22,43 +21,6 @@ def format_chunk(data: str | Mapping[str, Any]) -> str:
END_CHUNK = format_chunk(OPENAI_END_MARKER)


async def parse_openai_sse_stream(
    stream: AsyncIterator[bytes],
) -> AsyncIterator[dict]:
    """Parse a stream of byte lines carrying ``DATA_PREFIX``-framed JSON chunks.

    Each line is decoded with ``utf-8-sig`` and left-stripped. Blank lines are
    skipped. A line must start with ``DATA_PREFIX``; the remainder is either
    ``OPENAI_END_MARKER``, which ends the iteration, or a JSON document, which
    is yielded after parsing. On the first undecodable, badly framed or
    non-JSON line, a single ``create_error`` object is yielded and the
    iteration stops.
    """
    async for raw_line in stream:
        try:
            text = raw_line.decode("utf-8-sig").lstrip()  # type: ignore
        except Exception:
            yield create_error(
                message="Can't decode chunk to a string", type="runtime_error"
            )
            return

        # Whitespace-only lines carry no data.
        if not text.strip():
            continue

        if not text.startswith(DATA_PREFIX):
            yield create_error(
                message="Invalid chunk format", type="runtime_error"
            )
            return

        body = text[len(DATA_PREFIX) :]

        # The end marker terminates the stream; nothing follows the loop.
        if body.strip() == OPENAI_END_MARKER:
            return

        try:
            parsed = json.loads(body)
        except json.JSONDecodeError:
            yield create_error(
                message="Can't parse chunk to JSON", type="runtime_error"
            )
            return

        yield parsed


log = logging.getLogger(__name__)


Expand All @@ -68,13 +30,14 @@ async def to_openai_sse_stream(
try:
async for chunk in stream:
yield format_chunk(chunk)
yield END_CHUNK
except Exception as e:
log.exception(
f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}"
)

dial_exception = to_dial_exception(e)
starlette_exception = to_starlette_exception(dial_exception)
error_chunk = to_json_content(dial_exception)

yield format_chunk(error_chunk)

yield format_chunk(starlette_exception.detail)
yield END_CHUNK
adubovik marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading