Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(llmobs): automatically set span links with decorators #12255

Merged
merged 8 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion ddtrace/llmobs/_llmobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from ddtrace.llmobs._constants import TAGS
from ddtrace.llmobs._evaluators.runner import EvaluatorRunner
from ddtrace.llmobs._utils import AnnotationContext
from ddtrace.llmobs._utils import LinkTracker
from ddtrace.llmobs._utils import _get_llmobs_parent_id
from ddtrace.llmobs._utils import _get_ml_app
from ddtrace.llmobs._utils import _get_session_id
Expand Down Expand Up @@ -109,6 +110,7 @@ def __init__(self, tracer=None):

forksafe.register(self._child_after_fork)

self._link_tracker = LinkTracker()
self._annotations = []
self._annotation_context_lock = forksafe.RLock()

Expand Down Expand Up @@ -204,7 +206,7 @@ def _llmobs_span_event(cls, span: Span) -> Dict[str, Any]:
llmobs_span_event["tags"] = cls._llmobs_tags(span, ml_app, session_id)

span_links = span._get_ctx_item(SPAN_LINKS)
if isinstance(span_links, list):
if isinstance(span_links, list) and span_links:
llmobs_span_event["span_links"] = span_links

return llmobs_span_event
Expand Down Expand Up @@ -391,6 +393,55 @@ def disable(cls) -> None:

log.debug("%s disabled", cls.__name__)

def _record_object(self, span, obj, input_or_output):
if obj is None:
return
span_links = []
for span_link in self._link_tracker.get_span_links_from_object(obj):
try:
if span_link["attributes"]["from"] == "input" and input_or_output == "output":
continue
except KeyError:
log.debug("failed to read span link: ", span_link)
continue
span_links.append(
{
"trace_id": span_link["trace_id"],
"span_id": span_link["span_id"],
"attributes": {
"from": span_link["attributes"]["from"],
"to": input_or_output,
},
}
)
self._tag_span_links(span, span_links)
self._link_tracker.add_span_links_to_object(
obj,
[
{
"trace_id": self.export_span(span)["trace_id"],
"span_id": self.export_span(span)["span_id"],
"attributes": {
"from": input_or_output,
},
}
],
)

def _tag_span_links(self, span, span_links):
if not span_links:
return
span_links = [
lievan marked this conversation as resolved.
Show resolved Hide resolved
span_link
for span_link in span_links
if span_link["span_id"] != LLMObs.export_span(span)["span_id"]
and span_link["trace_id"] == LLMObs.export_span(span)["trace_id"]
]
current_span_links = span._get_ctx_item(SPAN_LINKS)
if current_span_links:
span_links = current_span_links + span_links
span._set_ctx_item(SPAN_LINKS, span_links)

@classmethod
def annotation_context(
cls, tags: Optional[Dict[str, Any]] = None, prompt: Optional[dict] = None, name: Optional[str] = None
Expand Down
17 changes: 17 additions & 0 deletions ddtrace/llmobs/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,23 @@ def validate_prompt(prompt: dict) -> Dict[str, Union[str, dict, List[str]]]:
return validated_prompt


class LinkTracker:
def __init__(self, object_span_links=None):
self._object_span_links = object_span_links or {}

def get_object_id(self, obj):
return f"{type(obj).__name__}_{id(obj)}"

def add_span_links_to_object(self, obj, span_links):
obj_id = self.get_object_id(obj)
if obj_id not in self._object_span_links:
self._object_span_links[obj_id] = []
self._object_span_links[obj_id] += span_links

def get_span_links_from_object(self, obj):
return self._object_span_links.get(self.get_object_id(obj), [])


class AnnotationContext:
def __init__(self, _register_annotator, _deregister_annotator):
self._register_annotator = _register_annotator
Expand Down
20 changes: 18 additions & 2 deletions ddtrace/llmobs/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable
from typing import Optional

from ddtrace import config
from ddtrace.internal.compat import iscoroutinefunction
from ddtrace.internal.compat import isgeneratorfunction
from ddtrace.internal.logger import get_logger
Expand Down Expand Up @@ -138,8 +139,16 @@ def wrapper(*args, **kwargs):
name=span_name,
session_id=session_id,
ml_app=ml_app,
):
return func(*args, **kwargs)
) as span:
if config._llmobs_auto_span_linking_enabled:
for arg in args:
LLMObs._instance._record_object(span, arg, "input")
for arg in kwargs.values():
LLMObs._instance._record_object(span, arg, "input")
ret = func(*args, **kwargs)
if config._llmobs_auto_span_linking_enabled:
LLMObs._instance._record_object(span, ret, "output")
return ret

return generator_wrapper if (isgeneratorfunction(func) or isasyncgenfunction(func)) else wrapper

Expand Down Expand Up @@ -231,6 +240,11 @@ def wrapper(*args, **kwargs):
_, span_name = _get_llmobs_span_options(name, None, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.workflow)
with traced_operation(name=span_name, session_id=session_id, ml_app=ml_app) as span:
if config._llmobs_auto_span_linking_enabled:
for arg in args:
LLMObs._instance._record_object(span, arg, "input")
for arg in kwargs.values():
LLMObs._instance._record_object(span, arg, "input")
func_signature = signature(func)
bound_args = func_signature.bind_partial(*args, **kwargs)
if _automatic_io_annotation and bound_args.arguments:
Expand All @@ -243,6 +257,8 @@ def wrapper(*args, **kwargs):
and span._get_ctx_item(OUTPUT_VALUE) is None
):
LLMObs.annotate(span=span, output_data=resp)
if config._llmobs_auto_span_linking_enabled:
LLMObs._instance._record_object(span, resp, "output")
return resp

return generator_wrapper if (isgeneratorfunction(func) or isasyncgenfunction(func)) else wrapper
Expand Down
1 change: 1 addition & 0 deletions ddtrace/settings/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def __init__(self):
self._llmobs_sample_rate = _get_config("DD_LLMOBS_SAMPLE_RATE", 1.0, float)
self._llmobs_ml_app = _get_config("DD_LLMOBS_ML_APP")
self._llmobs_agentless_enabled = _get_config("DD_LLMOBS_AGENTLESS_ENABLED", False, asbool)
self._llmobs_auto_span_linking_enabled = _get_config("_DD_LLMOBS_AUTO_SPAN_LINKING_ENABLED", False, asbool)

self._inject_force = _get_config("DD_INJECT_FORCE", False, asbool)
self._lib_was_injected = False
Expand Down
16 changes: 15 additions & 1 deletion tests/llmobs/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,13 @@ def _expected_llmobs_non_llm_span_event(


def _llmobs_base_span_event(
span, span_kind, tags=None, session_id=None, error=None, error_message=None, error_stack=None
span,
span_kind,
tags=None,
session_id=None,
error=None,
error_message=None,
error_stack=None,
):
span_event = {
"trace_id": "{:x}".format(span.trace_id),
Expand Down Expand Up @@ -763,3 +769,11 @@ def _expected_ragas_answer_relevancy_spans(ragas_inputs=None):
"tags": expected_ragas_trace_tags(),
},
]


def _expected_span_link(span_event, link_from, link_to):
return {
"trace_id": span_event["trace_id"],
"span_id": span_event["span_id"],
"attributes": {"from": link_from, "to": link_to},
}
83 changes: 83 additions & 0 deletions tests/llmobs/test_llmobs_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
from ddtrace.llmobs.decorators import workflow
from tests.llmobs._utils import _expected_llmobs_llm_span_event
from tests.llmobs._utils import _expected_llmobs_non_llm_span_event
from tests.llmobs._utils import _expected_span_link
from tests.utils import override_global_config


@pytest.fixture
def auto_linking_enabled():
with override_global_config(dict(_llmobs_auto_span_linking_enabled=True)):
yield


@pytest.fixture
Expand Down Expand Up @@ -828,3 +836,78 @@ def get_next_element(alist):
error_message=span.get_tag("error.message"),
error_stack=span.get_tag("error.stack"),
)


def test_decorator_records_span_links(llmobs, llmobs_events, auto_linking_enabled):
@workflow
def one(inp):
return 1

@task
def two(inp):
return inp

with llmobs.agent("dummy_trace"):
two(one("test_input"))

one_span = llmobs_events[0]
two_span = llmobs_events[1]

assert "span_links" not in one_span
assert len(two_span["span_links"]) == 2
assert two_span["span_links"][0] == _expected_span_link(one_span, "output", "input")
assert two_span["span_links"][1] == _expected_span_link(one_span, "output", "output")


def test_decorator_records_span_links_for_multi_input_functions(llmobs, llmobs_events, auto_linking_enabled):
@agent
def some_agent(a, b):
pass

@workflow
def one():
return 1

@task
def two():
return 2

with llmobs.agent("dummy_trace"):
some_agent(one(), two())

one_span = llmobs_events[0]
two_span = llmobs_events[1]
three_span = llmobs_events[2]

assert "span_links" not in one_span
assert "span_links" not in two_span
assert len(three_span["span_links"]) == 2
assert three_span["span_links"][0] == _expected_span_link(one_span, "output", "input")
assert three_span["span_links"][1] == _expected_span_link(two_span, "output", "input")


def test_decorator_records_span_links_via_kwargs(llmobs, llmobs_events, auto_linking_enabled):
@agent
def some_agent(a=None, b=None):
pass

@workflow
def one():
return 1

@task
def two():
return 2

with llmobs.agent("dummy_trace"):
some_agent(one(), two())

one_span = llmobs_events[0]
two_span = llmobs_events[1]
three_span = llmobs_events[2]

assert "span_links" not in one_span
assert "span_links" not in two_span
assert len(three_span["span_links"]) == 2
assert three_span["span_links"][0] == _expected_span_link(one_span, "output", "input")
assert three_span["span_links"][1] == _expected_span_link(two_span, "output", "input")
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def override_global_config(values):
"_llmobs_sample_rate",
"_llmobs_ml_app",
"_llmobs_agentless_enabled",
"_llmobs_auto_span_linking_enabled",
"_data_streams_enabled",
]

Expand Down
Loading