feat(openai): add token usage stream options to request #11606

Merged 16 commits on Jan 9, 2025
Changes from 3 commits
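For background: this change relies on the `stream_options={"include_usage": True}` flag that openai added in 1.26. Below is a minimal sketch of what that flag does on the client side, outside of ddtrace; the model, prompt, and API key handling are illustrative and not taken from the PR.

```python
# Sketch only (not part of this PR): with openai>=1.26, setting include_usage makes the
# server append one final chunk whose `choices` list is empty and whose `usage` field
# carries exact token counts for the whole streamed request.
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    if chunk.choices:  # normal content deltas
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage is not None:  # trailing usage-only chunk
        print(
            f"\nprompt={chunk.usage.prompt_tokens} "
            f"completion={chunk.usage.completion_tokens} "
            f"total={chunk.usage.total_tokens}"
        )
```

Before this PR the integration estimated token counts for streamed responses; with the option injected, the span can report the exact counts from that trailing chunk.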
7 changes: 7 additions & 0 deletions ddtrace/contrib/internal/openai/_endpoint_hooks.py
@@ -255,6 +255,13 @@ def _record_request(self, pin, integration, span, args, kwargs):
span.set_tag_str("openai.request.messages.%d.content" % idx, integration.trunc(str(content)))
span.set_tag_str("openai.request.messages.%d.role" % idx, str(role))
span.set_tag_str("openai.request.messages.%d.name" % idx, str(name))
if parse_version(OPENAI_VERSION) >= (1, 26) and kwargs.get("stream"):
if kwargs.get("stream_options", {}).get("include_usage", None) is not None:
return
span._set_ctx_item("openai_stream_magic", True)
stream_options = kwargs.get("stream_options", {})
stream_options["include_usage"] = True
kwargs["stream_options"] = stream_options

def _record_response(self, pin, integration, span, args, kwargs, resp, error):
resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
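As a side note, the hook above only adds the option when the caller has not expressed a preference. A rough standalone sketch of that decision, using a plain dict in place of the real request kwargs (the helper name is hypothetical):

```python
def _maybe_force_include_usage(kwargs: dict) -> bool:
    """Hypothetical helper mirroring the hook's injection logic.

    Returns True only when the request is streamed and the caller did not set
    include_usage themselves; an explicit True or False from the caller is respected.
    """
    if not kwargs.get("stream"):
        return False
    stream_options = kwargs.get("stream_options", {})
    if stream_options.get("include_usage") is not None:
        return False
    stream_options["include_usage"] = True
    kwargs["stream_options"] = stream_options
    return True


assert _maybe_force_include_usage({"stream": True}) is True
assert _maybe_force_include_usage({"stream": True, "stream_options": {"include_usage": False}}) is False
assert _maybe_force_include_usage({"stream": False}) is False
```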
53 changes: 50 additions & 3 deletions ddtrace/contrib/internal/openai/utils.py
@@ -38,6 +38,19 @@ def __init__(self, wrapped, integration, span, kwargs, is_completion=False):
self._is_completion = is_completion
self._kwargs = kwargs

def _extract_token_chunk(self, chunk):
"""Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
if not self._dd_span._get_ctx_item("openai_stream_magic"):
return
choice = getattr(chunk, "choices", [None])[0]
if not getattr(choice, "finish_reason", None):
return
try:
usage_chunk = next(self.__wrapped__)
self._streamed_chunks[0].insert(0, usage_chunk)
except (StopIteration, GeneratorExit):
pass


class TracedOpenAIStream(BaseTracedOpenAIStream):
def __enter__(self):
@@ -48,11 +61,28 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)

def __iter__(self):
return self
exception_raised = False
try:
for chunk in self.__wrapped__:
self._extract_token_chunk(chunk)
yield chunk
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
exception_raised = True
raise
finally:
if not exception_raised:
_process_finished_stream(
self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
)
self._dd_span.finish()
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)

def __next__(self):
try:
chunk = self.__wrapped__.__next__()
chunk = next(self.__wrapped__)
self._extract_token_chunk(chunk)
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
return chunk
except StopIteration:
@@ -78,11 +108,28 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)

def __aiter__(self):
return self
exception_raised = False
try:
for chunk in self.__wrapped__:
self._extract_token_chunk(chunk)
yield chunk
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
except Exception:
self._dd_span.set_exc_info(*sys.exc_info())
exception_raised = True
raise
finally:
if not exception_raised:
_process_finished_stream(
self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
)
self._dd_span.finish()
self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)

async def __anext__(self):
try:
chunk = await self.__wrapped__.__anext__()
self._extract_token_chunk(chunk)
_loop_handler(self._dd_span, chunk, self._streamed_chunks)
return chunk
except StopAsyncIteration:
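A note on the `_extract_token_chunk` helper above: with include_usage enabled, the usage-only chunk arrives immediately after the chunk that carries `finish_reason`, so the wrapper peeks one item ahead, stashes that extra chunk for span processing, and (when the integration injected the option itself) never yields it back to the caller. A toy illustration of that pattern on plain dicts, separate from the ddtrace classes:

```python
from typing import Iterator, List


def swallow_trailing_usage(wrapped: Iterator[dict], sink: List[dict]) -> Iterator[dict]:
    """Toy version of the peek-ahead: once a chunk carrying finish_reason is seen,
    the next item is assumed to be the usage-only chunk, so it is pulled off the
    wrapped iterator and stashed for the tracer instead of being yielded."""
    for chunk in wrapped:
        yield chunk
        if chunk.get("finish_reason"):
            try:
                sink.append(next(wrapped))  # consume the usage chunk here
            except StopIteration:
                pass


captured: List[dict] = []
chunks = iter(
    [
        {"delta": "The Dodgers"},
        {"delta": " won.", "finish_reason": "stop"},
        {"usage": {"prompt_tokens": 17, "completion_tokens": 19, "total_tokens": 36}},
    ]
)
assert [c.get("delta") for c in swallow_trailing_usage(chunks, captured)] == ["The Dodgers", " won."]
assert captured[0]["usage"]["total_tokens"] == 36
```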
10 changes: 7 additions & 3 deletions tests/contrib/openai/test_openai_llmobs.py
@@ -518,12 +518,16 @@ async def test_chat_completion_azure_async(
)
)

@pytest.mark.skipif(
parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
)
def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
"""Ensure llmobs records are emitted for chat completion endpoints when configured.

Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed.yaml"):

with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed_tokens.yaml"):
with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est:
mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
@@ -547,8 +551,8 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
model_provider="openai",
input_messages=input_messages,
output_messages=[{"content": expected_completion, "role": "assistant"}],
metadata={"stream": True, "user": "ddtrace-test"},
token_metrics={"input_tokens": 8, "output_tokens": 8, "total_tokens": 16},
metadata={"stream": True, "stream_options": {"include_usage": True}, "user": "ddtrace-test"},
token_metrics={"input_tokens": 17, "output_tokens": 19, "total_tokens": 36},
tags={"ml_app": "<ml-app-name>", "service": "tests.contrib.openai"},
)
)
56 changes: 54 additions & 2 deletions tests/contrib/openai/test_openai_v1.py
@@ -1042,7 +1042,59 @@ def test_completion_stream_context_manager(openai, openai_vcr, mock_metrics, moc
assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls


@pytest.mark.skipif(
parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
)
def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_tracer):
with openai_vcr.use_cassette("chat_completion_streamed_tokens.yaml"):
with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
client = openai.OpenAI()
resp = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Who won the world series in 2020?"},
],
stream=True,
user="ddtrace-test",
n=None,
)
span = snapshot_tracer.current_span()
chunks = [c for c in resp]
assert len(chunks) == 15
completion = "".join([c.choices[0].delta.content for c in chunks if c.choices[0].delta.content is not None])
assert completion == expected_completion

assert span.get_tag("openai.response.choices.0.message.content") == expected_completion
assert span.get_tag("openai.response.choices.0.message.role") == "assistant"
assert span.get_tag("openai.response.choices.0.finish_reason") == "stop"

expected_tags = [
"version:",
"env:",
"service:tests.contrib.openai",
"openai.request.model:gpt-3.5-turbo",
"model:gpt-3.5-turbo",
"openai.request.endpoint:/v1/chat/completions",
"openai.request.method:POST",
"openai.organization.id:",
"openai.organization.name:datadog-4",
"openai.user.api_key:sk-...key>",
"error:0",
]
assert mock.call.distribution("request.duration", span.duration_ns, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.gauge("ratelimit.requests", 3000, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.gauge("ratelimit.remaining.requests", 2999, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.prompt", 17, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.completion", 19, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.total", 36, tags=expected_tags) in mock_metrics.mock_calls


@pytest.mark.skipif(
parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
)
def test_chat_completion_stream_explicit_no_tokens(openai, openai_vcr, mock_metrics, snapshot_tracer):
with openai_vcr.use_cassette("chat_completion_streamed.yaml"):
with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
@@ -1054,10 +1106,10 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace
{"role": "user", "content": "Who won the world series in 2020?"},
],
stream=True,
stream_options={"include_usage": False},
user="ddtrace-test",
n=None,
)
prompt_tokens = 8
span = snapshot_tracer.current_span()
chunks = [c for c in resp]
assert len(chunks) == 15
@@ -1087,7 +1139,7 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace
expected_tags += ["openai.estimated:true"]
if TIKTOKEN_AVAILABLE:
expected_tags = expected_tags[:-1]
assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.prompt", 8, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.completion", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls
