From f1b08cb1f5cfb36531b12638d7359a3008ddb885 Mon Sep 17 00:00:00 2001 From: Theodore Ni <3806110+tjni@users.noreply.github.com> Date: Tue, 10 Dec 2024 02:35:40 -0800 Subject: [PATCH] Handle interrupt/resume for subgraphs --- langgraph-tests/tests/test_pregel.py | 202 +++++++++++++++++++ langgraph-tests/tests/test_pregel_async.py | 219 +++++++++++++++++++++ tests/test_async.py | 19 +- tests/test_sync.py | 15 +- 4 files changed, 440 insertions(+), 15 deletions(-) diff --git a/langgraph-tests/tests/test_pregel.py b/langgraph-tests/tests/test_pregel.py index c883c34..ce0d617 100644 --- a/langgraph-tests/tests/test_pregel.py +++ b/langgraph-tests/tests/test_pregel.py @@ -6694,6 +6694,176 @@ def start(state: State) -> list[Union[Send, str]]: ) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_dynamic_interrupt_subgraph( + request: pytest.FixtureRequest, checkpointer_name: str +) -> None: + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class SubgraphState(TypedDict): + my_key: str + market: str + + tool_two_node_count = 0 + + def tool_two_node(s: SubgraphState) -> SubgraphState: + nonlocal tool_two_node_count + tool_two_node_count += 1 + if s["market"] == "DE": + answer = interrupt("Just because...") + else: + answer = " all good" + return {"my_key": answer} + + subgraph = StateGraph(SubgraphState) + subgraph.add_node("do", tool_two_node, retry=RetryPolicy()) + subgraph.add_edge(START, "do") + + class State(TypedDict): + my_key: Annotated[str, operator.add] + market: str + + tool_two_graph = StateGraph(State) + tool_two_graph.add_node("tool_two", subgraph.compile()) + tool_two_graph.add_edge(START, "tool_two") + tool_two = tool_two_graph.compile() + + tracer = FakeTracer() + assert tool_two.invoke( + {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} + ) == { + "my_key": "value", + "market": "DE", + } + assert tool_two_node_count == 1, "interrupts aren't retried" + assert len(tracer.runs) == 1 + run = tracer.runs[0] + assert run.end_time is not None + assert run.error is None + assert run.outputs == {"market": "DE", "my_key": "value"} + + assert tool_two.invoke({"my_key": "value", "market": "US"}) == { + "my_key": "value all good", + "market": "US", + } + + tool_two = tool_two_graph.compile(checkpointer=checkpointer) + + # missing thread_id + with pytest.raises(ValueError, match="thread_id"): + tool_two.invoke({"my_key": "value", "market": "DE"}) + + # flow: interrupt -> resume with answer + thread2 = {"configurable": {"thread_id": "2"}} + # stop when about to enter node + assert [ + c for c in tool_two.stream({"my_key": "value ⛰️", "market": "DE"}, thread2) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + # resume with answer + assert [c for c in tool_two.stream(Command(resume=" my answer"), thread2)] == [ + {"tool_two": {"my_key": " my answer", "market": "DE"}}, + ] + + # flow: interrupt -> clear tasks + thread1 = {"configurable": {"thread_id": "1"}} + # stop when about to enter node + assert tool_two.invoke({"my_key": "value ⛰️", "market": "DE"}, thread1) == { + "my_key": "value ⛰️", + "market": "DE", + } + assert [ + c.metadata + for c in tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + ) + ] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ), + state={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("tool_two:"), + } + }, + ), + ), + config=tool_two.checkpointer.get_tuple(thread1).config, + created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + parent_config=[ + *tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + ][-1].config, + ) + # clear the interrupt and next tasks + tool_two.update_state(thread1, None, as_node=END) + # interrupt and next tasks are cleared + assert tool_two.get_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=(), + tasks=(), + config=tool_two.checkpointer.get_tuple(thread1).config, + created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + metadata={ + "parents": {}, + "source": "update", + "step": 1, + "writes": {}, + "thread_id": "1", + }, + parent_config=[ + *tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + ][-1].config, + ) + + @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) def test_start_branch_then( snapshot: SnapshotAssertion, request: pytest.FixtureRequest, checkpointer_name: str @@ -11122,3 +11292,35 @@ class CustomParentState(TypedDict): }, tasks=(), ) + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_interrupt_subgraph(request: pytest.FixtureRequest, checkpointer_name: str): + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict): + baz: str + + def foo(state): + return {"baz": "foo"} + + def bar(state): + value = interrupt("Please provide baz value:") + return {"baz": value} + + child_builder = StateGraph(State) + child_builder.add_node(bar) + child_builder.add_edge(START, "bar") + + builder = StateGraph(State) + builder.add_node(foo) + builder.add_node("bar", child_builder.compile()) + builder.add_edge(START, "foo") + builder.add_edge("foo", "bar") + graph = builder.compile(checkpointer=checkpointer) + + thread1 = {"configurable": {"thread_id": "1"}} + # First run, interrupted at bar + assert graph.invoke({"baz": ""}, thread1) + # Resume with answer + assert graph.invoke(Command(resume="bar"), thread1) diff --git a/langgraph-tests/tests/test_pregel_async.py b/langgraph-tests/tests/test_pregel_async.py index 839e998..f18dc27 100644 --- a/langgraph-tests/tests/test_pregel_async.py +++ b/langgraph-tests/tests/test_pregel_async.py @@ -251,6 +251,189 @@ async def tool_two_node(s: State) -> State: ) +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.11+ is required for async contextvars support", +) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_dynamic_interrupt_subgraph(checkpointer_name: str) -> None: + class SubgraphState(TypedDict): + my_key: str + market: str + + tool_two_node_count = 0 + + def tool_two_node(s: SubgraphState) -> SubgraphState: + nonlocal tool_two_node_count + tool_two_node_count += 1 + if s["market"] == "DE": + answer = interrupt("Just because...") + else: + answer = " all good" + return {"my_key": answer} + + subgraph = StateGraph(SubgraphState) + subgraph.add_node("do", tool_two_node, retry=RetryPolicy()) + subgraph.add_edge(START, "do") + + class State(TypedDict): + my_key: Annotated[str, operator.add] + market: str + + tool_two_graph = StateGraph(State) + tool_two_graph.add_node("tool_two", subgraph.compile()) + tool_two_graph.add_edge(START, "tool_two") + tool_two = tool_two_graph.compile() + + tracer = FakeTracer() + assert await tool_two.ainvoke( + {"my_key": "value", "market": "DE"}, {"callbacks": [tracer]} + ) == { + "my_key": "value", + "market": "DE", + } + assert tool_two_node_count == 1, "interrupts aren't retried" + assert len(tracer.runs) == 1 + run = tracer.runs[0] + assert run.end_time is not None + assert run.error is None + assert run.outputs == {"market": "DE", "my_key": "value"} + + assert await tool_two.ainvoke({"my_key": "value", "market": "US"}) == { + "my_key": "value all good", + "market": "US", + } + + async with awith_checkpointer(checkpointer_name) as checkpointer: + tool_two = tool_two_graph.compile(checkpointer=checkpointer) + + # missing thread_id + with pytest.raises(ValueError, match="thread_id"): + await tool_two.ainvoke({"my_key": "value", "market": "DE"}) + + # flow: interrupt -> resume with answer + thread2 = {"configurable": {"thread_id": "2"}} + # stop when about to enter node + assert [ + c + async for c in tool_two.astream( + {"my_key": "value ⛰️", "market": "DE"}, thread2 + ) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + # resume with answer + assert [ + c async for c in tool_two.astream(Command(resume=" my answer"), thread2) + ] == [ + {"tool_two": {"my_key": " my answer", "market": "DE"}}, + ] + + # flow: interrupt -> clear + thread1 = {"configurable": {"thread_id": "1"}} + thread1root = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + # stop when about to enter node + assert [ + c + async for c in tool_two.astream( + {"my_key": "value ⛰️", "market": "DE"}, thread1 + ) + ] == [ + { + "__interrupt__": ( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ) + }, + ] + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1root)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + tup = await tool_two.checkpointer.aget_tuple(thread1) + assert await tool_two.aget_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=("tool_two",), + tasks=( + PregelTask( + AnyStr(), + "tool_two", + (PULL, "tool_two"), + interrupts=( + Interrupt( + value="Just because...", + resumable=True, + ns=[AnyStr("tool_two:"), AnyStr("do:")], + ), + ), + state={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("tool_two:"), + } + }, + ), + ), + config=tup.config, + created_at=tup.checkpoint["ts"], + metadata={ + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + parent_config=[ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config, + ) + + # clear the interrupt and next tasks + await tool_two.aupdate_state(thread1, None, as_node=END) + # interrupt is cleared, as well as the next tasks + tup = await tool_two.checkpointer.aget_tuple(thread1) + assert await tool_two.aget_state(thread1) == StateSnapshot( + values={"my_key": "value ⛰️", "market": "DE"}, + next=(), + tasks=(), + config=tup.config, + created_at=tup.checkpoint["ts"], + metadata={ + "parents": {}, + "source": "update", + "step": 1, + "writes": {}, + "thread_id": "1", + }, + parent_config=[ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config, + ) + + @pytest.mark.skipif(not FF_SEND_V2, reason="send v2 is not enabled") @pytest.mark.skipif( sys.version_info < (3, 11), @@ -10310,3 +10493,39 @@ class CustomParentState(TypedDict): }, tasks=(), ) + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.11+ is required for async contextvars support", +) +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_interrupt_subgraph(checkpointer_name: str): + class State(TypedDict): + baz: str + + def foo(state): + return {"baz": "foo"} + + def bar(state): + value = interrupt("Please provide baz value:") + return {"baz": value} + + child_builder = StateGraph(State) + child_builder.add_node(bar) + child_builder.add_edge(START, "bar") + + builder = StateGraph(State) + builder.add_node(foo) + builder.add_node("bar", child_builder.compile()) + builder.add_edge(START, "foo") + builder.add_edge("foo", "bar") + + async with awith_checkpointer(checkpointer_name) as checkpointer: + graph = builder.compile(checkpointer=checkpointer) + + thread1 = {"configurable": {"thread_id": "1"}} + # First run, interrupted at bar + assert await graph.ainvoke({"baz": ""}, thread1) + # Resume with answer + assert await graph.ainvoke(Command(resume="bar"), thread1) diff --git a/tests/test_async.py b/tests/test_async.py index 1b0d944..ec6e1d1 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any from uuid import uuid4 @@ -19,7 +20,7 @@ @asynccontextmanager -async def _pool_saver(): +async def _pool_saver() -> AsyncIterator[AIOMySQLSaver]: """Fixture for pool mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db @@ -49,7 +50,7 @@ async def _pool_saver(): @asynccontextmanager -async def _base_saver(): +async def _base_saver() -> AsyncIterator[AIOMySQLSaver]: """Fixture for regular connection mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db @@ -75,7 +76,7 @@ async def _base_saver(): @asynccontextmanager -async def _saver(name: str): +async def _saver(name: str) -> AsyncIterator[AIOMySQLSaver]: if name == "base": async with _base_saver() as saver: yield saver @@ -85,7 +86,7 @@ async def _saver(name: str): @pytest.fixture -def test_data(): +def test_data() -> dict[str, Any]: """Fixture providing test data for checkpoint tests.""" config_1: RunnableConfig = { "configurable": { @@ -136,7 +137,7 @@ def test_data(): @pytest.mark.parametrize("saver_name", ["base", "pool"]) -async def test_asearch(request, saver_name: str, test_data) -> None: +async def test_asearch(saver_name: str, test_data: dict[str, Any]) -> None: async with _saver(saver_name) as saver: configs = test_data["configs"] checkpoints = test_data["checkpoints"] @@ -181,7 +182,7 @@ async def test_asearch(request, saver_name: str, test_data) -> None: @pytest.mark.parametrize("saver_name", ["base", "pool"]) -async def test_null_chars(request, saver_name: str, test_data) -> None: +async def test_null_chars(saver_name: str, test_data: dict[str, Any]) -> None: async with _saver(saver_name) as saver: config = await saver.aput( test_data["configs"][0], @@ -197,7 +198,7 @@ async def test_null_chars(request, saver_name: str, test_data) -> None: @pytest.mark.parametrize("saver_name", ["base", "pool"]) async def test_write_and_read_pending_writes_and_sends( - request, saver_name: str, test_data + saver_name: str, test_data: dict[str, Any] ) -> None: async with _saver(saver_name) as saver: config: RunnableConfig = { @@ -233,7 +234,7 @@ async def test_write_and_read_pending_writes_and_sends( ], ) async def test_write_and_read_channel_values( - request, saver_name: str, channel_values: dict[str, Any] + saver_name: str, channel_values: dict[str, Any] ) -> None: async with _saver(saver_name) as saver: config: RunnableConfig = { @@ -260,7 +261,7 @@ async def test_write_and_read_channel_values( @pytest.mark.parametrize("saver_name", ["base", "pool"]) -async def test_write_and_read_pending_writes(request, saver_name: str) -> None: +async def test_write_and_read_pending_writes(saver_name: str) -> None: async with _saver(saver_name) as saver: config: RunnableConfig = { "configurable": { diff --git a/tests/test_sync.py b/tests/test_sync.py index 6f28613..eecee34 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,3 +1,4 @@ +from collections.abc import Iterator from contextlib import contextmanager from typing import Any from uuid import uuid4 @@ -19,7 +20,7 @@ @contextmanager -def _base_saver(): +def _base_saver() -> Iterator[PyMySQLSaver]: """Fixture for regular connection mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db @@ -43,14 +44,14 @@ def _base_saver(): @contextmanager -def _saver(name: str): +def _saver(name: str) -> Iterator[PyMySQLSaver]: if name == "base": with _base_saver() as saver: yield saver @pytest.fixture -def test_data(): +def test_data() -> dict[str, Any]: """Fixture providing test data for checkpoint tests.""" config_1: RunnableConfig = { "configurable": { @@ -101,7 +102,7 @@ def test_data(): @pytest.mark.parametrize("saver_name", ["base"]) -def test_search(saver_name: str, test_data) -> None: +def test_search(saver_name: str, test_data: dict[str, Any]) -> None: with _saver(saver_name) as saver: configs = test_data["configs"] checkpoints = test_data["checkpoints"] @@ -144,7 +145,7 @@ def test_search(saver_name: str, test_data) -> None: @pytest.mark.parametrize("saver_name", ["base"]) -def test_null_chars(saver_name: str, test_data) -> None: +def test_null_chars(saver_name: str, test_data: dict[str, Any]) -> None: with _saver(saver_name) as saver: config = saver.put( test_data["configs"][0], @@ -160,7 +161,9 @@ def test_null_chars(saver_name: str, test_data) -> None: @pytest.mark.parametrize("saver_name", ["base"]) -def test_write_and_read_pending_writes_and_sends(saver_name: str, test_data) -> None: +def test_write_and_read_pending_writes_and_sends( + saver_name: str, test_data: dict[str, Any] +) -> None: with _saver(saver_name) as saver: config: RunnableConfig = { "configurable": {