Skip to content

Commit

Permalink
fix langgraph node invoke trigger detection from __pregel_push
Browse files Browse the repository at this point in the history
  • Loading branch information
sabrenner committed Feb 6, 2025
1 parent bc8ea4f commit 91fe626
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 21 deletions.
43 changes: 23 additions & 20 deletions ddtrace/llmobs/_integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Set
from typing import Optional
from typing import Set
from typing import Union
from weakref import WeakKeyDictionary

Expand Down Expand Up @@ -229,11 +229,16 @@ def _set_input_links(self, instance: Any, span: Span, parent_span: Union[Span, N
links = []

if not is_step:
self._set_span_links(span, [{
"trace_id": "{:x}".format(span.trace_id),
"span_id": str(invoker_spans[0].span_id),
"attributes": invoker_links_attributes[0],
}])
self._set_span_links(
span,
[
{
"trace_id": "{:x}".format(span.trace_id),
"span_id": str(invoker_spans[0].span_id),
"attributes": invoker_links_attributes[0],
}
],
)

return step_idx

Expand All @@ -252,7 +257,7 @@ def _set_input_links(self, instance: Any, span: Span, parent_span: Union[Span, N
invoker_span = self._spans[id(step)]
invoker_link_attributes = {"from": "output", "to": "input"}
break
if isinstance(step, list): # parallel steps in the list
if isinstance(step, list): # parallel steps in the list
for parallel_step in step:
if id(parallel_step) in self._spans:
if not has_parallel_steps:
Expand All @@ -279,16 +284,16 @@ def _set_input_links(self, instance: Any, span: Span, parent_span: Union[Span, N
self._set_span_links(span, links)

return step_idx

def _set_output_links(self, span: Span, parent_span: Union[Span, None], step_idx: int) -> None:
"""
Sets the output links for the parent span of the given span (to: output)
This is done by removing repeated span links from steps in a chain.
We add output->output span links at every step
We add output->output span links at every step
"""
if parent_span is None:
return

parent_links = parent_span._get_ctx_item(SPAN_LINKS) or []
pop_indecies = self._get_popped_span_link_indecies(parent_span, parent_links, step_idx)
parent_links = [link for i, link in enumerate(parent_links) if i not in pop_indecies]
Expand All @@ -305,7 +310,9 @@ def _set_output_links(self, span: Span, parent_span: Union[Span, None], step_idx
],
)

def _get_popped_span_link_indecies(self, parent_span: Span, parent_links: List[Dict[str, Any]], step_idx: int) -> List[int]:
def _get_popped_span_link_indecies(
self, parent_span: Span, parent_links: List[Dict[str, Any]], step_idx: int
) -> List[int]:
"""
Returns a list of indecies to pop from the parent span links list
This is determined by if the parent span represents a chain, and if there are steps before the step
Expand All @@ -318,11 +325,11 @@ def _get_popped_span_link_indecies(self, parent_span: Span, parent_links: List[D
parent_instance = self._instances.get(parent_span)
if not parent_instance:
return pop_indecies

parent_instance = _extract_bound(parent_instance)
if not hasattr(parent_instance, "steps"): # chain instance
return pop_indecies

steps = getattr(parent_instance, "steps", [])
flatmap_chain_steps = _flattened_chain_steps(steps)
for i in range(step_idx - 1, -1, -1):
Expand All @@ -340,19 +347,15 @@ def _get_popped_span_link_indecies(self, parent_span: Span, parent_links: List[D
if id(parallel_step) in self._spans:
invoker_span_id = self._spans[id(parallel_step)].span_id
link_idx = next(
(
i
for i, link in enumerate(parent_links)
if link["span_id"] == str(invoker_span_id)
),
(i for i, link in enumerate(parent_links) if link["span_id"] == str(invoker_span_id)),
None,
)
if link_idx is not None:
pop_indecies.append(link_idx)
break

return pop_indecies

def _set_span_links(self, span: Span, links: List[Dict[str, Any]]) -> None:
"""Sets the span links on the given span along with the existing links."""
existing_links = span._get_ctx_item(SPAN_LINKS) or []
Expand Down Expand Up @@ -454,7 +457,7 @@ def _llmobs_set_tags_from_chat_model(
content = (
message.get("content", "") if isinstance(message, dict) else getattr(message, "content", "")
)
role = getattr(message, "role", ROLE_MAPPING.get(message.type, ""))
role = getattr(message, "role", ROLE_MAPPING.get(getattr(message, "type", None), ""))
input_messages.append({"content": str(content), "role": str(role)})
span._set_ctx_item(input_tag_key, input_messages)

Expand Down
23 changes: 22 additions & 1 deletion ddtrace/llmobs/_integrations/langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ def _handle_finished_graph(self, graph_span, finished_tasks, is_subgraph_node):
def _link_task_to_parent(self, task_id, task, finished_task_names_to_ids):
"""Create the span links for a queued task from its triggering parent tasks."""
task_config = getattr(task, "config", {})
task_triggers = task_config.get("metadata", {}).get("langgraph_triggers", [])
task_triggers = _normalize_triggers(
triggers=task_config.get("metadata", {}).get("langgraph_triggers", []),
finished_tasks=finished_task_names_to_ids,
next_task=task,
)

trigger_node_names = [_extract_parent(trigger) for trigger in task_triggers]
trigger_node_ids: List[str] = [
Expand All @@ -132,6 +136,23 @@ def _link_task_to_parent(self, task_id, task, finished_task_names_to_ids):
span_links.append(span_link)


def _normalize_triggers(triggers, finished_tasks, next_task) -> List[str]:
"""
Return the default triggers for a LangGraph node.
For nodes queued up with `langgraph.types.Send`, the triggers are an unhelpful ['__pregel_push'].
In this case (and in any case with 1 finished task and 1 trigger), we can infer the trigger from
the one finished task.
"""
if len(finished_tasks) != 1 or len(triggers) != 1:
return []

finished_task_name = list(finished_tasks.keys())[0]
next_task_name = getattr(next_task, "name", "")

return [f"{finished_task_name}:{next_task_name}"]


def _extract_parent(trigger: str) -> str:
"""
Extract the parent node name from a trigger string.
Expand Down

0 comments on commit 91fe626

Please sign in to comment.