chore(langchain): auto-instrument with langgraph (#12208)
Adds gated auto-instrumentation for LangChain with LangGraph, along with
a couple of small fixes for LangGraph.

### LangGraph bug fixes
1. We were not properly recording errors from LangGraph spans, and were
not setting LLMObs tags when a graph run failed. We now call
`llmobs_set_tags` in the error paths to resolve this.
2. LangGraph will sometimes report the triggers for a step as something
like `["__pregel_push"]` when a `Send` is enqueued as a trigger, which
gives no hint about the actual invoker. To resolve this, for any such
queued task we check whether there is exactly one finished task. If so,
we set that task as the trigger.
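
The fallback in item 2 can be sketched roughly as follows. This is a minimal illustration and not the integration's code: `infer_trigger` and its arguments are hypothetical names, and the only detail taken from the description is the `__pregel_push` marker and the "exactly one finished task" rule.

```python
from typing import List, Optional

PREGEL_PUSH = "__pregel_push"  # trigger name quoted in the description above


def infer_trigger(triggers: List[str], finished_task_ids: List[str]) -> Optional[str]:
    """Return the finished task to treat as the trigger, or None if it cannot be inferred.

    Hypothetical helper: both argument shapes are assumptions made for this sketch.
    """
    # Only fall back when every trigger is the uninformative push marker.
    if not triggers or any(t != PREGEL_PUSH for t in triggers):
        return None
    # Exactly one finished task: assume it is the one that enqueued the Send.
    if len(finished_task_ids) == 1:
        return finished_task_ids[0]
    # Zero or several candidates: ambiguous, so leave the trigger unset.
    return None
```

With several finished tasks the sketch returns `None` instead of guessing, which mirrors the "only one finished task" condition above.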

### Complications with LangChain linking
The primary obstacle was that, unlike with LangGraph, where we had an
intermediary function to patch between tasks executed by the graph, we
don't have that for LangChain LCEL chains. Their steps are executed in a
loop inside their `invoke` method, which blocks us from jumping in
between the steps to make the links.

Additionally, LangChain elements that we trace are sometimes embedded in
other `Runnable` types:
- `RunnableBinding`, which binds an instance inside of it
- `RunnableAssign`, which embeds an instance of a `RunnableParallel`
- `RunnableParallel`, which can run multiple `Runnables` (some of which
we might trace) in parallel, and which stores the parallel items in a
`steps__` attribute

To overcome this, I tried flattening (flat-mapping) the items in the
list of steps to extract these embedded instances.
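
A rough sketch of that flattening, using stand-in objects so it runs without LangChain installed. `flatten_step` is a hypothetical helper; the `steps__` attribute comes from the description above, while the `bound` and `mapper` attribute names are assumptions about how `RunnableBinding` and `RunnableAssign` expose what they wrap.

```python
from types import SimpleNamespace
from typing import Any, List


def flatten_step(step: Any) -> List[Any]:
    """Unwrap one chain step into the innermost runnables it contains."""
    bound = getattr(step, "bound", None)  # assumed RunnableBinding-style wrapper
    if bound is not None:
        return flatten_step(bound)
    mapper = getattr(step, "mapper", None)  # assumed RunnableAssign-style wrapper
    if mapper is not None:
        return flatten_step(mapper)
    parallel = getattr(step, "steps__", None)  # RunnableParallel-style mapping
    if parallel is not None:
        return [inner for item in parallel.values() for inner in flatten_step(item)]
    return [step]


# Stand-ins for traced runnables and the wrappers around them.
llm = SimpleNamespace(name="llm")
retriever = SimpleNamespace(name="retriever")
bound_llm = SimpleNamespace(bound=llm)
parallel = SimpleNamespace(steps__={"docs": retriever, "answer": bound_llm})
assign = SimpleNamespace(mapper=parallel)

# One list per chain step, so the members of a parallel step stay grouped.
flattened = [flatten_step(s) for s in [retriever, assign, bound_llm]]
```

Keeping one list per step (instead of a single flat list) preserves the step index, which the linking logic relies on when it looks for the previous traced step.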

To make linking between steps possible, I recorded the instance of each
traced Runnable, mapping its ID to its span and vice versa, so that
instances can be looked up when needed. Additionally, I marked each
Runnable item in the chain as a chain step by adding it to a set of
steps to check against later.
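
The bookkeeping described above might look something like the sketch below. Only the `record_instance(instance, span)` call shape appears in this commit's diff; the class, its internals, the lookup helpers, and `FakeSpan` are hypothetical and exist only to make the idea concrete.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Dict, Optional, Set


@dataclass
class FakeSpan:
    """Hypothetical stand-in for a tracer span; only the ID matters here."""
    span_id: int


class InstanceRegistry:
    """Two-way instance/span mapping plus a set of known chain steps (sketch)."""

    def __init__(self) -> None:
        self._span_by_instance: Dict[int, FakeSpan] = {}
        self._instance_by_span: Dict[int, Any] = {}
        self._chain_steps: Set[int] = set()

    def record_instance(self, instance: Any, span: FakeSpan) -> None:
        self._span_by_instance[id(instance)] = span
        self._instance_by_span[span.span_id] = instance
        # If the instance is a chain, remember each of its steps for later checks.
        for step in getattr(instance, "steps", []):
            self._chain_steps.add(id(step))

    def span_for(self, instance: Any) -> Optional[FakeSpan]:
        return self._span_by_instance.get(id(instance))

    def instance_for(self, span: FakeSpan) -> Optional[Any]:
        return self._instance_by_span.get(span.span_id)

    def is_chain_step(self, instance: Any) -> bool:
        return id(instance) in self._chain_steps


llm = SimpleNamespace(name="llm")
chain = SimpleNamespace(steps=[llm])
registry = InstanceRegistry()
chain_span, llm_span = FakeSpan(1), FakeSpan(2)
registry.record_instance(chain, chain_span)
registry.record_instance(llm, llm_span)
```

This sketch never evicts entries, so it keeps every recorded instance alive; a production version would need some cleanup once spans finish.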

### Setting links
Link setting is split into setting the input links (`"to": "input"`) and
the output links (`"to": "output"`).

#### Input Links

Input links are set by:
1. Identifying whether the span represents a step in a chain. If
**not**, set its input link as `input --> input` from the parent span.
2. If it does represent a step in the chain, finding the index of the
previous traced step in the chain (the chain instance is grabbed from
the span-to-instance mapping referenced above) and setting it as the
`output --> input` link. If that step contains multiple spans (i.e. from
a `RunnableParallel`), add all of those spans as links with the same
`output --> input` attribute.

The index of the previously traced step in the chain (`-1` if none is
found or the span is not a chain step) is returned for use in output
linkage.
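
The input-link rules above can be sketched as a small pure function. Everything here is illustrative: the function name, the argument shapes, and the `"from"` key are assumptions (only `"to"` is quoted in the description), and the fallback to the parent's input when no earlier step was traced is also an assumption, since the description does not cover that case.

```python
from typing import Dict, List, Optional, Tuple

Link = Dict[str, object]


def input_links(
    step_idx: Optional[int],
    parent_span_id: int,
    span_ids_per_step: List[List[int]],
) -> Tuple[List[Link], int]:
    """Return (input links for the current span, index of the previous traced step or -1).

    step_idx is the current span's position in the chain, or None if it is not a chain step.
    span_ids_per_step[i] holds the span IDs traced for step i (empty if step i was not traced).
    """
    parent_link: List[Link] = [{"span_id": parent_span_id, "from": "input", "to": "input"}]
    if step_idx is None:
        return parent_link, -1
    # Walk backwards to the closest earlier step that produced at least one span.
    for prev_idx in range(step_idx - 1, -1, -1):
        prev_ids = span_ids_per_step[prev_idx]
        if prev_ids:
            links = [{"span_id": sid, "from": "output", "to": "input"} for sid in prev_ids]
            return links, prev_idx
    # Assumption: with no traced predecessor, link from the parent's input instead.
    return parent_link, -1


# Step 0 is a parallel step with two traced spans, step 1 is untraced, step 2 is current.
links, prev_idx = input_links(2, 100, [[1, 2], [], [3]])
```

In the usage line, the untraced middle step is skipped and both spans of the parallel step become `output --> input` links.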

#### Output Links

Output links are set by:
1. If the span does not represent a step in a chain, or the parent span
isn't a chain (a chain being an instance with a `steps` attribute), set
the `output --> output` link from the current span onto the parent span.
2. If the span does represent a step in a chain, remove the span links
that the previous traced step set on the parent span, and set the new
span link from the current span. We overwrite every time because we
don't know ahead of time which step in the chain will be the last one we
trace, so we have to remove previous span links whenever we find we need
to add a new one.
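
A sketch of the overwrite behaviour, under the same caveats as before: the function name, the list-of-dicts link format, and the `"from"` key are hypothetical, and the sketch assumes the links being replaced live on the parent (chain) span.

```python
from typing import Dict, List

Link = Dict[str, object]


def set_output_link(
    parent_links: List[Link],
    current_span_id: int,
    prev_step_span_ids: List[int],
    is_chain_step: bool,
    parent_is_chain: bool,
) -> List[Link]:
    """Return the parent's links after recording the current span's output link."""
    new_link: Link = {"span_id": current_span_id, "from": "output", "to": "output"}
    if not (is_chain_step and parent_is_chain):
        # Case 1: plain child span, so just add the output --> output link.
        return parent_links + [new_link]
    # Case 2: drop output links left by the previous traced step, then add ours.
    kept = [
        link
        for link in parent_links
        if not (link["to"] == "output" and link["span_id"] in prev_step_span_ids)
    ]
    return kept + [new_link]


# Two chain steps traced in order: the second replaces the first as the chain's output.
chain_links = set_output_link([], 1, [], True, True)
chain_links = set_output_link(chain_links, 2, [1], True, True)
```

Only links from the previous step are removed, so sibling spans from the same parallel step can each keep their own output link on the parent.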

## Checklist
- [x] PR author has checked that all the criteria below are met
- The PR description includes an overview of the change
- The PR description articulates the motivation for the change
- The change includes tests OR the PR description describes a testing
strategy
- The PR description notes risks associated with the change, if any
- Newly-added code is easy to change
- The change follows the [library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
- The change includes or references documentation updates if necessary
- Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))

## Reviewer Checklist
- [x] Reviewer has checked that all the criteria below are met 
- Title is accurate
- All changes are related to the pull request's stated goal
- Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- Testing strategy adequately addresses listed risks
- Newly-added code is easy to change
- Release note makes sense to a user of the library
- If necessary, author has acknowledged and discussed the performance
implications of this PR as reported in the benchmarks PR comment
- Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
sabrenner authored Feb 12, 2025
1 parent 0fcafa2 commit 8002eac
Showing 5 changed files with 313 additions and 2 deletions.
31 changes: 31 additions & 0 deletions ddtrace/contrib/internal/langchain/patch.py
@@ -180,6 +180,9 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
api_key=_extract_api_key(instance),
)
completions = None

integration.record_instance(instance, span)

try:
if integration.is_pc_sampled_span(span):
for idx, prompt in enumerate(prompts):
@@ -231,6 +234,9 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
model=model,
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

completions = None
try:
if integration.is_pc_sampled_span(span):
@@ -282,6 +288,9 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
model=_extract_model_name(instance),
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

chat_completions = None
try:
for message_set_idx, message_set in enumerate(chat_messages):
@@ -372,6 +381,9 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwar
model=_extract_model_name(instance),
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

chat_completions = None
try:
for message_set_idx, message_set in enumerate(chat_messages):
@@ -469,6 +481,9 @@ def traced_embedding(langchain, pin, func, instance, args, kwargs):
model=_extract_model_name(instance),
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

embeddings = None
try:
if isinstance(input_texts, str):
@@ -520,6 +535,9 @@ def traced_lcel_runnable_sequence(langchain, pin, func, instance, args, kwargs):
)
inputs = None
final_output = None

integration.record_instance(instance, span)

try:
try:
inputs = get_argument_value(args, kwargs, 0, "input")
@@ -564,6 +582,9 @@ async def traced_lcel_runnable_sequence_async(langchain, pin, func, instance, ar
)
inputs = None
final_output = None

integration.record_instance(instance, span)

try:
try:
inputs = get_argument_value(args, kwargs, 0, "input")
@@ -608,6 +629,9 @@ def traced_similarity_search(langchain, pin, func, instance, args, kwargs):
provider=provider,
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

documents = []
try:
if integration.is_pc_sampled_span(span):
@@ -655,6 +679,7 @@ def traced_chain_stream(langchain, pin, func, instance, args, kwargs):
integration: LangChainIntegration = langchain._datadog_integration

def _on_span_started(span: Span):
integration.record_instance(instance, span)
inputs = get_argument_value(args, kwargs, 0, "input")
if not integration.is_pc_sampled_span(span):
return
@@ -712,6 +737,7 @@ def traced_chat_stream(langchain, pin, func, instance, args, kwargs):
model = _extract_model_name(instance)

def _on_span_started(span: Span):
integration.record_instance(instance, span)
if not integration.is_pc_sampled_span(span):
return
chat_messages = get_argument_value(args, kwargs, 0, "input")
@@ -771,6 +797,7 @@ def traced_llm_stream(langchain, pin, func, instance, args, kwargs):
model = _extract_model_name(instance)

def _on_span_start(span: Span):
integration.record_instance(instance, span)
if not integration.is_pc_sampled_span(span):
return
inp = get_argument_value(args, kwargs, 0, "input")
@@ -818,6 +845,8 @@ def traced_base_tool_invoke(langchain, pin, func, instance, args, kwargs):
submit_to_llmobs=True,
)

integration.record_instance(instance, span)

tool_output = None
tool_info = {}
try:
@@ -869,6 +898,8 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
submit_to_llmobs=True,
)

integration.record_instance(instance, span)

tool_output = None
tool_info = {}
try:
8 changes: 8 additions & 0 deletions ddtrace/contrib/internal/langgraph/patch.py
@@ -121,6 +121,7 @@ def traced_pregel_stream(langgraph, pin, func, instance, args, kwargs):
result = func(*args, **kwargs)
except Exception:
span.set_exc_info(*sys.exc_info())
integration.llmobs_set_tags(span, args=args, kwargs={**kwargs, "name": name}, response=None, operation="graph")
span.finish()
raise

@@ -139,6 +140,9 @@ def _stream():
break
except Exception:
span.set_exc_info(*sys.exc_info())
integration.llmobs_set_tags(
span, args=args, kwargs={**kwargs, "name": name}, response=None, operation="graph"
)
span.finish()
raise

@@ -160,6 +164,7 @@ def traced_pregel_astream(langgraph, pin, func, instance, args, kwargs):
result = func(*args, **kwargs)
except Exception:
span.set_exc_info(*sys.exc_info())
integration.llmobs_set_tags(span, args=args, kwargs={**kwargs, "name": name}, response=None, operation="graph")
span.finish()
raise

@@ -178,6 +183,9 @@ async def _astream():
break
except Exception:
span.set_exc_info(*sys.exc_info())
integration.llmobs_set_tags(
span, args=args, kwargs={**kwargs, "name": name}, response=None, operation="graph"
)
span.finish()
raise

6 changes: 6 additions & 0 deletions ddtrace/llmobs/_integrations/base.py
@@ -65,6 +65,12 @@ def __init__(self, integration_config: IntegrationConfig) -> None:
self.start_log_writer()
self._llmobs_pc_sampler = RateSampler(sample_rate=config._llmobs_sample_rate)

@property
def span_linking_enabled(self) -> bool:
return asbool(os.getenv("_DD_LLMOBS_AUTO_SPAN_LINKING_ENABLED", "false")) or asbool(
os.getenv("_DD_TRACE_LANGGRAPH_ENABLED", "false")
)

@property
def metrics_enabled(self) -> bool:
"""