Skip to content

Commit

Permalink
Merge pull request #28 from tjni/nc/15nov/command-subgraph
Browse files Browse the repository at this point in the history
Handle interrupt/resume for subgraphs
  • Loading branch information
tjni authored Dec 10, 2024
2 parents 7b59157 + f1b08cb commit 6ddac30
Show file tree
Hide file tree
Showing 4 changed files with 440 additions and 15 deletions.
202 changes: 202 additions & 0 deletions langgraph-tests/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6694,6 +6694,176 @@ def start(state: State) -> list[Union[Send, str]]:
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_dynamic_interrupt_subgraph(
request: pytest.FixtureRequest, checkpointer_name: str
) -> None:
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class SubgraphState(TypedDict):
my_key: str
market: str

tool_two_node_count = 0

def tool_two_node(s: SubgraphState) -> SubgraphState:
nonlocal tool_two_node_count
tool_two_node_count += 1
if s["market"] == "DE":
answer = interrupt("Just because...")
else:
answer = " all good"
return {"my_key": answer}

subgraph = StateGraph(SubgraphState)
subgraph.add_node("do", tool_two_node, retry=RetryPolicy())
subgraph.add_edge(START, "do")

class State(TypedDict):
my_key: Annotated[str, operator.add]
market: str

tool_two_graph = StateGraph(State)
tool_two_graph.add_node("tool_two", subgraph.compile())
tool_two_graph.add_edge(START, "tool_two")
tool_two = tool_two_graph.compile()

tracer = FakeTracer()
assert tool_two.invoke(
{"my_key": "value", "market": "DE"}, {"callbacks": [tracer]}
) == {
"my_key": "value",
"market": "DE",
}
assert tool_two_node_count == 1, "interrupts aren't retried"
assert len(tracer.runs) == 1
run = tracer.runs[0]
assert run.end_time is not None
assert run.error is None
assert run.outputs == {"market": "DE", "my_key": "value"}

assert tool_two.invoke({"my_key": "value", "market": "US"}) == {
"my_key": "value all good",
"market": "US",
}

tool_two = tool_two_graph.compile(checkpointer=checkpointer)

# missing thread_id
with pytest.raises(ValueError, match="thread_id"):
tool_two.invoke({"my_key": "value", "market": "DE"})

# flow: interrupt -> resume with answer
thread2 = {"configurable": {"thread_id": "2"}}
# stop when about to enter node
assert [
c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2)
] == [
{
"__interrupt__": (
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
)
},
]
# resume with answer
assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [
{"tool_two": {"my_key": " my answer", "market": "DE"}},
]

# flow: interrupt -> clear tasks
thread1 = {"configurable": {"thread_id": "1"}}
# stop when about to enter node
assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == {
"my_key": "value ⛰️",
"market": "DE",
}
assert [
c.metadata
for c in tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
)
] == [
{
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
{
"parents": {},
"source": "input",
"step": -1,
"writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}},
"thread_id": "1",
},
]
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=("tool_two",),
tasks=(
PregelTask(
AnyStr(),
"tool_two",
(PULL, "tool_two"),
interrupts=(
Interrupt(
value="Just because...",
resumable=True,
ns=[AnyStr("tool_two:"), AnyStr("do:")],
),
),
state={
"configurable": {
"thread_id": "1",
"checkpoint_ns": AnyStr("tool_two:"),
}
},
),
),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "loop",
"step": 0,
"writes": None,
"thread_id": "1",
},
parent_config=[
*tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2
)
][-1].config,
)
# clear the interrupt and next tasks
tool_two.update_state(thread1, None, as_node=END)
# interrupt and next tasks are cleared
assert tool_two.get_state(thread1) == StateSnapshot(
values={"my_key": "value ⛰️", "market": "DE"},
next=(),
tasks=(),
config=tool_two.checkpointer.get_tuple(thread1).config,
created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"],
metadata={
"parents": {},
"source": "update",
"step": 1,
"writes": {},
"thread_id": "1",
},
parent_config=[
*tool_two.checkpointer.list(
{"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2
)
][-1].config,
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_start_branch_then(
snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str
Expand Down Expand Up @@ -11122,3 +11292,35 @@ class CustomParentState(TypedDict):
},
tasks=(),
)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
def test_interrupt_subgraph(request: pytest.FixtureRequest, checkpointer_name: str):
checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}")

class State(TypedDict):
baz: str

def foo(state):
return {"baz": "foo"}

def bar(state):
value = interrupt("Please provide baz value:")
return {"baz": value}

child_builder = StateGraph(State)
child_builder.add_node(bar)
child_builder.add_edge(START, "bar")

builder = StateGraph(State)
builder.add_node(foo)
builder.add_node("bar", child_builder.compile())
builder.add_edge(START, "foo")
builder.add_edge("foo", "bar")
graph = builder.compile(checkpointer=checkpointer)

thread1 = {"configurable": {"thread_id": "1"}}
# First run, interrupted at bar
assert graph.invoke({"baz": ""}, thread1)
# Resume with answer
assert graph.invoke(Command(resume="bar"), thread1)
Loading

0 comments on commit 6ddac30

Please sign in to comment.