feat(openai): add token usage stream options to request #11606

Merged · 16 commits · Jan 9, 2025
Changes from all commits
8 changes: 8 additions & 0 deletions ddtrace/contrib/internal/openai/_endpoint_hooks.py
@@ -255,6 +255,14 @@ def _record_request(self, pin, integration, span, args, kwargs):
span.set_tag_str("openai.request.messages.%d.content" % idx, integration.trunc(str(content)))
span.set_tag_str("openai.request.messages.%d.role" % idx, str(role))
span.set_tag_str("openai.request.messages.%d.name" % idx, str(name))
if parse_version(OPENAI_VERSION) >= (1, 26) and kwargs.get("stream"):
if kwargs.get("stream_options", {}).get("include_usage", None) is not None:
# Only perform token chunk auto-extraction if this option is not explicitly set
return
span._set_ctx_item("_dd.auto_extract_token_chunk", True)
stream_options = kwargs.get("stream_options", {})
stream_options["include_usage"] = True
kwargs["stream_options"] = stream_options

def _record_response(self, pin, integration, span, args, kwargs, resp, error):
resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
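For reference, a minimal sketch of what the hook above effectively does to the request keyword arguments; maybe_enable_usage and its return shape are illustrative names, not part of the integration:

# Simplified re-statement of the _record_request logic above (illustrative only).
# If the caller streams without choosing include_usage themselves, opt in on their
# behalf and signal that the integration should swallow the extra usage chunk.
def maybe_enable_usage(kwargs, openai_version):
    """Return (kwargs, auto_extract) mirroring the hook's decision."""
    if openai_version < (1, 26) or not kwargs.get("stream"):
        return kwargs, False
    stream_options = kwargs.get("stream_options", {})
    if stream_options.get("include_usage") is not None:
        # The caller set include_usage explicitly; leave the request untouched.
        return kwargs, False
    stream_options["include_usage"] = True
    kwargs["stream_options"] = stream_options
    return kwargs, True


if __name__ == "__main__":
    kwargs, auto = maybe_enable_usage({"model": "gpt-4o-mini", "stream": True}, (1, 30))
    print(kwargs, auto)
    # {'model': 'gpt-4o-mini', 'stream': True, 'stream_options': {'include_usage': True}} True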
69 changes: 66 additions & 3 deletions ddtrace/contrib/internal/openai/utils.py
@@ -48,11 +48,28 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)

def __iter__(self):
return self
exception_raised = False
try:
for chunk in self.__wrapped__:
self._extract_token_chunk(chunk)
yield chunk
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
exception_raised = True
raise
finally:
if not exception_raised:
_process_finished_stream(
self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
)
self._dd_span.finish()
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)

def __next__(self):
try:
chunk = self.__wrapped__.__next__()
self._extract_token_chunk(chunk)
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
return chunk
except StopIteration:
@@ -68,6 +85,22 @@ def __next__(self):
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
raise

def _extract_token_chunk(self, chunk):
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
if not self._dd_span._get_ctx_item("_dd.auto_extract_token_chunk"):
return
choice = getattr(chunk, "choices", [None])[0]
if not getattr(choice, "finish_reason", None):
# Only the second-last chunk in the stream with token usage enabled will have finish_reason set
return
try:
# User isn't expecting last token chunk to be present since it's not part of the default streamed response,
# so we consume it and extract the token usage metadata before it reaches the user.
usage_chunk = self.__wrapped__.__next__()
self._streamed_chunks[0].insert(0, usage_chunk)
except (StopIteration, GeneratorExit):
return


class TracedOpenAIAsyncStream(BaseTracedOpenAIStream):
async def __aenter__(self):
@@ -77,12 +110,29 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)

def __aiter__(self):
return self
async def __aiter__(self):
exception_raised = False
try:
async for chunk in self.__wrapped__:
await self._extract_token_chunk(chunk)
yield chunk
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
exception_raised = True
raise
finally:
if not exception_raised:
_process_finished_stream(
self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
)
self._dd_span.finish()
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)

async def __anext__(self):
try:
chunk = await self.__wrapped__.__anext__()
await self._extract_token_chunk(chunk)
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
return chunk
except StopAsyncIteration:
@@ -98,6 +148,19 @@ async def __anext__(self):
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
raise

async def _extract_token_chunk(self, chunk):
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
if not self._dd_span._get_ctx_item("_dd.auto_extract_token_chunk"):
return
choice = getattr(chunk, "choices", [None])[0]
if not getattr(choice, "finish_reason", None):
return
try:
usage_chunk = await self.__wrapped__.__anext__()
self._streamed_chunks[0].insert(0, usage_chunk)
except (StopAsyncIteration, GeneratorExit):
return


def _compute_token_count(content, model):
# type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int]
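As a rough standalone illustration of the _extract_token_chunk pattern above (class and attribute names here are hypothetical, not ddtrace's), a wrapper can consume the chunk that follows the one carrying finish_reason, so the usage-only chunk never reaches the caller:

# Hedged sketch of the "hide the trailing usage chunk" pattern used above.
class UsageHidingStream:
    """Wraps an iterable of streamed chunks and swallows the final usage-only chunk."""

    def __init__(self, inner):
        self._inner = iter(inner)
        self.usage_chunk = None  # the consumed usage-only chunk is stashed here

    def __iter__(self):
        for chunk in self._inner:
            choices = getattr(chunk, "choices", None) or [None]
            if getattr(choices[0], "finish_reason", None):
                # With include_usage enabled, the next chunk carries only token usage.
                self.usage_chunk = next(self._inner, None)
            yield chunk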
@@ -0,0 +1,6 @@
---
features:
- |
openai: Introduces automatic extraction of token usage from streamed chat completions.
Unless ``stream_options: {"include_usage": False}`` is explicitly set on your streamed chat completion request,
the OpenAI integration will add ``stream_options: {"include_usage": True}`` to your request and automatically extract the token usage chunk from the streamed response.
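
Assuming an application traced by ddtrace (for example via ``ddtrace-run``), a plain streamed request like the sketch below would therefore get token metrics recorded without the caller ever seeing the extra usage chunk; the model and prompt are placeholders:

import openai

client = openai.OpenAI()
# No stream_options passed: per the note above, the integration adds
# {"include_usage": True} and consumes the trailing usage-only chunk itself.
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "Who won the World Series in 2020?"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")

# Opting out: pass stream_options={"include_usage": False} to keep the default behavior.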
27 changes: 15 additions & 12 deletions tests/contrib/openai/test_openai_llmobs.py
@@ -518,11 +518,17 @@ async def test_chat_completion_azure_async(
)
)

def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
@pytest.mark.skipif(
parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
)
def test_chat_completion_stream_explicit_no_tokens(
self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer
):
"""Ensure llmobs records are emitted for chat completion endpoints when configured.

Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""

with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed.yaml"):
with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est:
@@ -534,7 +540,11 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
client = openai.OpenAI()
resp = client.chat.completions.create(
model=model, messages=input_messages, stream=True, user="ddtrace-test"
model=model,
messages=input_messages,
stream=True,
user="ddtrace-test",
stream_options={"include_usage": False},
)
for chunk in resp:
resp_model = chunk.model
@@ -547,7 +557,7 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
model_provider="openai",
input_messages=input_messages,
output_messages=[{"content": expected_completion, "role": "assistant"}],
metadata={"stream": True, "user": "ddtrace-test"},
metadata={"stream": True, "stream_options": {"include_usage": False}, "user": "ddtrace-test"},
token_metrics={"input_tokens": 8, "output_tokens": 8, "total_tokens": 16},
tags={"ml_app": "<ml-app-name>", "service": "tests.contrib.openai"},
)
@@ -557,20 +567,14 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
parse_version(openai_module.version.VERSION) < (1, 26, 0), reason="Streamed tokens available in 1.26.0+"
)
def test_chat_completion_stream_tokens(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
"""
Ensure llmobs records are emitted for chat completion endpoints when configured
with the `stream_options={"include_usage": True}`.
Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
"""Assert that streamed token chunk extraction logic works when options are not explicitly passed from user."""
with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed_tokens.yaml"):
model = "gpt-3.5-turbo"
resp_model = model
input_messages = [{"role": "user", "content": "Who won the world series in 2020?"}]
expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
client = openai.OpenAI()
resp = client.chat.completions.create(
model=model, messages=input_messages, stream=True, stream_options={"include_usage": True}
)
resp = client.chat.completions.create(model=model, messages=input_messages, stream=True)
for chunk in resp:
resp_model = chunk.model
span = mock_tracer.pop_traces()[0][0]
@@ -671,7 +675,6 @@ def test_chat_completion_tool_call_stream(self, openai, ddtrace_global_config, m
messages=[{"role": "user", "content": chat_completion_input_description}],
user="ddtrace-test",
stream=True,
stream_options={"include_usage": True},
)
for chunk in resp:
resp_model = chunk.model