Skip to content

Commit

Permalink
Fix helper chunk extraction method to be async compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
Yun-Kim committed Dec 31, 2024
1 parent 23e706b commit 480f3d0
Showing 1 changed file with 30 additions and 18 deletions.
48 changes: 30 additions & 18 deletions ddtrace/contrib/internal/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,6 @@ def __init__(self, wrapped, integration, span, kwargs, is_completion=False):
self._is_completion = is_completion
self._kwargs = kwargs

def _extract_token_chunk(self, chunk):
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
if not self._dd_span._get_ctx_item("openai_stream_magic"):
return
choice = getattr(chunk, "choices", [None])[0]
if not getattr(choice, "finish_reason", None):
return
try:
usage_chunk = next(self.__wrapped__)
self._streamed_chunks[0].insert(0, usage_chunk)
except (StopIteration, GeneratorExit):
pass


class TracedOpenAIStream(BaseTracedOpenAIStream):
def __enter__(self):
Expand Down Expand Up @@ -98,6 +85,18 @@ def __next__(self):
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
raise

def _extract_token_chunk(self, chunk):
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
if not self._dd_span._get_ctx_item("openai_stream_magic"):
return
choice = getattr(chunk, "choices", [None])[0]
if not getattr(choice, "finish_reason", None):
return
try:
usage_chunk = next(self.__wrapped__)
self._streamed_chunks[0].insert(0, usage_chunk)
except (StopIteration, GeneratorExit):
return

class TracedOpenAIAsyncStream(BaseTracedOpenAIStream):
async def __aenter__(self):
Expand All @@ -107,11 +106,11 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)

def __aiter__(self):
async def __aiter__(self):
exception_raised = False
try:
for chunk in self.__wrapped__:
self._extract_token_chunk(chunk)
async for chunk in self.__wrapped__:
await self._extract_token_chunk(chunk)
yield chunk
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
except Exception:
Expand All @@ -128,8 +127,8 @@ def __aiter__(self):

async def __anext__(self):
try:
chunk = await self.__wrapped__.__anext__()
self._extract_token_chunk(chunk)
chunk = await anext(self.__wrapped__)
await self._extract_token_chunk(chunk)
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
return chunk
except StopAsyncIteration:
Expand All @@ -145,6 +144,19 @@ async def __anext__(self):
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
raise

async def _extract_token_chunk(self, chunk):
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
if not self._dd_span._get_ctx_item("openai_stream_magic"):
return
choice = getattr(chunk, "choices", [None])[0]
if not getattr(choice, "finish_reason", None):
return
try:
usage_chunk = await anext(self.__wrapped__)
self._streamed_chunks[0].insert(0, usage_chunk)
except (StopAsyncIteration, GeneratorExit):
return


def _compute_token_count(content, model):
# type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int]
Expand Down

0 comments on commit 480f3d0

Please sign in to comment.