diff --git a/Makefile b/Makefile index c0b0ee9..307bb07 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ ###################### start-mysql: - MYSQL_VERSION=${MYSQL_VERSION} docker compose -f tests/compose-mysql.yml up -V --force-recreate --wait || ( \ + MYSQL_VERSION=$(MYSQL_VERSION) docker compose -f tests/compose-mysql.yml up -V --force-recreate --wait || ( \ echo "Failed to start MySQL, printing logs..."; \ docker compose -f tests/compose-mysql.yml logs; \ exit 1 \ @@ -38,7 +38,7 @@ test: TEST ?= . test_watch: - MYSQL_VERSION=${MYSQL_VERSION:-8} make start-mysql; \ + MYSQL_VERSION=$(MYSQL_VERSION) make start-mysql; \ poetry run ptw $(TEST); \ EXIT_CODE=$$?; \ make stop-mysql; \ diff --git a/langgraph-tests/tests/__snapshots__/test_large_cases.ambr b/langgraph-tests/tests/__snapshots__/test_large_cases.ambr index b64b761..48c3be7 100644 --- a/langgraph-tests/tests/__snapshots__/test_large_cases.ambr +++ b/langgraph-tests/tests/__snapshots__/test_large_cases.ambr @@ -101,6 +101,40 @@ ''' # --- +# name: test_branch_then[pymysql_shallow] + ''' + graph TD; + __start__ --> prepare; + finish --> __end__; + prepare -.-> tool_two_slow; + tool_two_slow --> finish; + prepare -.-> tool_two_fast; + tool_two_fast --> finish; + + ''' +# --- +# name: test_branch_then[pymysql_shallow].1 + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([
__start__
]):::first + prepare(prepare) + tool_two_slow(tool_two_slow) + tool_two_fast(tool_two_fast) + finish(finish) + __end__([__end__
]):::last + __start__ --> prepare; + finish --> __end__; + prepare -.-> tool_two_slow; + tool_two_slow --> finish; + prepare -.-> tool_two_fast; + tool_two_fast --> finish; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_conditional_graph[pymysql] ''' { @@ -926,6 +960,281 @@ ''' # --- +# name: test_conditional_graph[pymysql_shallow] + ''' + { + "nodes": [ + { + "id": "__start__", + "type": "schema", + "data": "__start__" + }, + { + "id": "agent", + "type": "runnable", + "data": { + "id": [ + "langchain", + "schema", + "runnable", + "RunnableAssign" + ], + "name": "agent" + } + }, + { + "id": "tools", + "type": "runnable", + "data": { + "id": [ + "langgraph", + "utils", + "runnable", + "RunnableCallable" + ], + "name": "tools" + }, + "metadata": { + "parents": {}, + "version": 2, + "variant": "b" + } + }, + { + "id": "__end__", + "type": "schema", + "data": "__end__" + } + ], + "edges": [ + { + "source": "__start__", + "target": "agent" + }, + { + "source": "tools", + "target": "agent" + }, + { + "source": "agent", + "target": "tools", + "data": "continue", + "conditional": true + }, + { + "source": "agent", + "target": "__end__", + "data": "exit", + "conditional": true + } + ] + } + ''' +# --- +# name: test_conditional_graph[pymysql_shallow].1 + ''' + graph TD; + __start__ --> agent; + tools --> agent; + agent -. continue .-> tools; + agent -. exit .-> __end__; + + ''' +# --- +# name: test_conditional_graph[pymysql_shallow].2 + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__
]):::first + agent(agent) + tools(tools__end__
]):::last + __start__ --> agent; + tools --> agent; + agent -. continue .-> tools; + agent -. exit .-> __end__; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- +# name: test_conditional_graph[pymysql_shallow].3 + ''' + { + "nodes": [ + { + "id": "__start__", + "type": "schema", + "data": "__start__" + }, + { + "id": "agent", + "type": "runnable", + "data": { + "id": [ + "langchain", + "schema", + "runnable", + "RunnableAssign" + ], + "name": "agent" + } + }, + { + "id": "tools", + "type": "runnable", + "data": { + "id": [ + "langgraph", + "utils", + "runnable", + "RunnableCallable" + ], + "name": "tools" + }, + "metadata": { + "parents": {}, + "version": 2, + "variant": "b" + } + }, + { + "id": "__end__", + "type": "schema", + "data": "__end__" + } + ], + "edges": [ + { + "source": "__start__", + "target": "agent" + }, + { + "source": "tools", + "target": "agent" + }, + { + "source": "agent", + "target": "tools", + "data": "continue", + "conditional": true + }, + { + "source": "agent", + "target": "__end__", + "data": "exit", + "conditional": true + } + ] + } + ''' +# --- +# name: test_conditional_graph[pymysql_shallow].4 + ''' + graph TD; + __start__ --> agent; + tools --> agent; + agent -. continue .-> tools; + agent -. exit .-> __end__; + + ''' +# --- +# name: test_conditional_graph[pymysql_shallow].5 + dict({ + 'edges': list([ + dict({ + 'source': '__start__', + 'target': 'agent', + }), + dict({ + 'source': 'tools', + 'target': 'agent', + }), + dict({ + 'conditional': True, + 'data': 'continue', + 'source': 'agent', + 'target': 'tools', + }), + dict({ + 'conditional': True, + 'data': 'exit', + 'source': 'agent', + 'target': '__end__', + }), + ]), + 'nodes': list([ + dict({ + 'data': '__start__', + 'id': '__start__', + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'schema', + 'runnable', + 'RunnableAssign', + ]), + 'name': 'agent', + }), + 'id': 'agent', + 'metadata': dict({ + '__interrupt': 'after', + }), + 'type': 'runnable', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langgraph', + 'utils', + 'runnable', + 'RunnableCallable', + ]), + 'name': 'tools', + }), + 'id': 'tools', + 'metadata': dict({ + 'parents': dict({ + }), + 'variant': 'b', + 'version': 2, + }), + 'type': 'runnable', + }), + dict({ + 'data': '__end__', + 'id': '__end__', + 'type': 'schema', + }), + ]), + }) +# --- +# name: test_conditional_graph[pymysql_shallow].6 + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__
]):::first + agent(agent__end__
]):::last + __start__ --> agent; + tools --> agent; + agent -. continue .-> tools; + agent -. exit .-> __end__; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_conditional_state_graph[pymysql] '{"$defs": {"AgentAction": {"description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. The log is used to pass along extra information about the action.", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"anyOf": [{"type": "string"}, {"type": "object"}], "title": "Tool Input"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentAction", "default": "AgentAction", "enum": ["AgentAction"], "title": "Type", "type": "string"}}, "required": ["tool", "tool_input", "log"], "title": "AgentAction", "type": "object"}, "AgentFinish": {"description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentFinish", "default": "AgentFinish", "enum": ["AgentFinish"], "title": "Type", "type": "string"}}, "required": ["return_values", "log"], "title": "AgentFinish", "type": "object"}}, "properties": {"input": {"default": null, "title": "Input", "type": "string"}, "agent_outcome": {"anyOf": [{"$ref": "#/$defs/AgentAction"}, {"$ref": "#/$defs/AgentFinish"}, {"type": "null"}], "default": null, "title": "Agent Outcome"}, "intermediate_steps": {"default": null, "items": {"maxItems": 2, "minItems": 2, "prefixItems": [{"$ref": "#/$defs/AgentAction"}, {"type": "string"}], "type": "array"}, "title": "Intermediate Steps", "type": "array"}}, "title": "LangGraphInput", "type": "object"}' # --- @@ -1172,6 +1481,88 @@ ''' # --- +# name: test_conditional_state_graph[pymysql_shallow] + '{"$defs": {"AgentAction": {"description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. The log is used to pass along extra information about the action.", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"anyOf": [{"type": "string"}, {"type": "object"}], "title": "Tool Input"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentAction", "default": "AgentAction", "enum": ["AgentAction"], "title": "Type", "type": "string"}}, "required": ["tool", "tool_input", "log"], "title": "AgentAction", "type": "object"}, "AgentFinish": {"description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentFinish", "default": "AgentFinish", "enum": ["AgentFinish"], "title": "Type", "type": "string"}}, "required": ["return_values", "log"], "title": "AgentFinish", "type": "object"}}, "properties": {"input": {"default": null, "title": "Input", "type": "string"}, "agent_outcome": {"anyOf": [{"$ref": "#/$defs/AgentAction"}, {"$ref": "#/$defs/AgentFinish"}, {"type": "null"}], "default": null, "title": "Agent Outcome"}, "intermediate_steps": {"default": null, "items": {"maxItems": 2, "minItems": 2, "prefixItems": [{"$ref": "#/$defs/AgentAction"}, {"type": "string"}], "type": "array"}, "title": "Intermediate Steps", "type": "array"}}, "title": "LangGraphInput", "type": "object"}' +# --- +# name: test_conditional_state_graph[pymysql_shallow].1 + '{"$defs": {"AgentAction": {"description": "Represents a request to execute an action by an agent.\\n\\nThe action consists of the name of the tool to execute and the input to pass\\nto the tool. The log is used to pass along extra information about the action.", "properties": {"tool": {"title": "Tool", "type": "string"}, "tool_input": {"anyOf": [{"type": "string"}, {"type": "object"}], "title": "Tool Input"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentAction", "default": "AgentAction", "enum": ["AgentAction"], "title": "Type", "type": "string"}}, "required": ["tool", "tool_input", "log"], "title": "AgentAction", "type": "object"}, "AgentFinish": {"description": "Final return value of an ActionAgent.\\n\\nAgents return an AgentFinish when they have reached a stopping condition.", "properties": {"return_values": {"title": "Return Values", "type": "object"}, "log": {"title": "Log", "type": "string"}, "type": {"const": "AgentFinish", "default": "AgentFinish", "enum": ["AgentFinish"], "title": "Type", "type": "string"}}, "required": ["return_values", "log"], "title": "AgentFinish", "type": "object"}}, "properties": {"input": {"default": null, "title": "Input", "type": "string"}, "agent_outcome": {"anyOf": [{"$ref": "#/$defs/AgentAction"}, {"$ref": "#/$defs/AgentFinish"}, {"type": "null"}], "default": null, "title": "Agent Outcome"}, "intermediate_steps": {"default": null, "items": {"maxItems": 2, "minItems": 2, "prefixItems": [{"$ref": "#/$defs/AgentAction"}, {"type": "string"}], "type": "array"}, "title": "Intermediate Steps", "type": "array"}}, "title": "LangGraphOutput", "type": "object"}' +# --- +# name: test_conditional_state_graph[pymysql_shallow].2 + ''' + { + "nodes": [ + { + "id": "__start__", + "type": "schema", + "data": "__start__" + }, + { + "id": "agent", + "type": "runnable", + "data": { + "id": [ + "langchain", + "schema", + "runnable", + "RunnableSequence" + ], + "name": "agent" + } + }, + { + "id": "tools", + "type": "runnable", + "data": { + "id": [ + "langgraph", + "utils", + "runnable", + "RunnableCallable" + ], + "name": "tools" + } + }, + { + "id": "__end__", + "type": "schema", + "data": "__end__" + } + ], + "edges": [ + { + "source": "__start__", + "target": "agent" + }, + { + "source": "tools", + "target": "agent" + }, + { + "source": "agent", + "target": "tools", + "data": "continue", + "conditional": true + }, + { + "source": "agent", + "target": "__end__", + "data": "exit", + "conditional": true + } + ] + } + ''' +# --- +# name: test_conditional_state_graph[pymysql_shallow].3 + ''' + graph TD; + __start__ --> agent; + tools --> agent; + agent -. continue .-> tools; + agent -. exit .-> __end__; + + ''' +# --- # name: test_message_graph[pymysql] '{"$defs": {"AIMessage": {"additionalProperties": true, "description": "Message from an AI.\\n\\nAIMessage is returned from a chat model as a response to a prompt.\\n\\nThis message represents the output of the model and consists of both\\nthe raw output as returned by the model together standardized fields\\n(e.g., tool calls, usage metadata) added by the LangChain framework.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ai", "default": "ai", "enum": ["ai"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}}, "required": ["content"], "title": "AIMessage", "type": "object"}, "AIMessageChunk": {"additionalProperties": true, "description": "Message chunk from an AI.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "AIMessageChunk", "default": "AIMessageChunk", "enum": ["AIMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}, "tool_call_chunks": {"default": [], "items": {"$ref": "#/$defs/ToolCallChunk"}, "title": "Tool Call Chunks", "type": "array"}}, "required": ["content"], "title": "AIMessageChunk", "type": "object"}, "ChatMessage": {"additionalProperties": true, "description": "Message that can be assigned an arbitrary speaker (i.e. role).", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "chat", "default": "chat", "enum": ["chat"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessage", "type": "object"}, "ChatMessageChunk": {"additionalProperties": true, "description": "Chat Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ChatMessageChunk", "default": "ChatMessageChunk", "enum": ["ChatMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessageChunk", "type": "object"}, "FunctionMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nFunctionMessage are an older version of the ToolMessage schema, and\\ndo not contain the tool_call_id field.\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "function", "default": "function", "enum": ["function"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessage", "type": "object"}, "FunctionMessageChunk": {"additionalProperties": true, "description": "Function Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "FunctionMessageChunk", "default": "FunctionMessageChunk", "enum": ["FunctionMessageChunk"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessageChunk", "type": "object"}, "HumanMessage": {"additionalProperties": true, "description": "Message from a human.\\n\\nHumanMessages are messages that are passed in from a human to the model.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Instantiate a chat model and invoke it with the messages\\n model = ...\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "human", "default": "human", "enum": ["human"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessage", "type": "object"}, "HumanMessageChunk": {"additionalProperties": true, "description": "Human Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "HumanMessageChunk", "default": "HumanMessageChunk", "enum": ["HumanMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessageChunk", "type": "object"}, "InputTokenDetails": {"description": "Breakdown of input token counts.\\n\\nDoes *not* need to sum to full input token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "cache_creation": {"title": "Cache Creation", "type": "integer"}, "cache_read": {"title": "Cache Read", "type": "integer"}}, "title": "InputTokenDetails", "type": "object"}, "InvalidToolCall": {"description": "Allowance for errors made by LLM.\\n\\nHere we add an `error` key to surface errors made during generation\\n(e.g., invalid JSON arguments.)", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "error": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Error"}, "type": {"const": "invalid_tool_call", "enum": ["invalid_tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "error"], "title": "InvalidToolCall", "type": "object"}, "OutputTokenDetails": {"description": "Breakdown of output token counts.\\n\\nDoes *not* need to sum to full output token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "reasoning": {"title": "Reasoning", "type": "integer"}}, "title": "OutputTokenDetails", "type": "object"}, "SystemMessage": {"additionalProperties": true, "description": "Message for priming AI behavior.\\n\\nThe system message is usually passed in as the first of a sequence\\nof input messages.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Define a chat model and invoke it with the messages\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "system", "default": "system", "enum": ["system"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessage", "type": "object"}, "SystemMessageChunk": {"additionalProperties": true, "description": "System Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "SystemMessageChunk", "default": "SystemMessageChunk", "enum": ["SystemMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessageChunk", "type": "object"}, "ToolCall": {"description": "Represents a request to call a tool.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"name\\": \\"foo\\",\\n \\"args\\": {\\"a\\": 1},\\n \\"id\\": \\"123\\"\\n }\\n\\n This represents a request to call the tool named \\"foo\\" with arguments {\\"a\\": 1}\\n and an identifier of \\"123\\".", "properties": {"name": {"title": "Name", "type": "string"}, "args": {"title": "Args", "type": "object"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "type": {"const": "tool_call", "enum": ["tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id"], "title": "ToolCall", "type": "object"}, "ToolCallChunk": {"description": "A chunk of a tool call (e.g., as part of a stream).\\n\\nWhen merging ToolCallChunks (e.g., via AIMessageChunk.__add__),\\nall string attributes are concatenated. Chunks are only merged if their\\nvalues of `index` are equal and not None.\\n\\nExample:\\n\\n.. code-block:: python\\n\\n left_chunks = [ToolCallChunk(name=\\"foo\\", args=\'{\\"a\\":\', index=0)]\\n right_chunks = [ToolCallChunk(name=None, args=\'1}\', index=0)]\\n\\n (\\n AIMessageChunk(content=\\"\\", tool_call_chunks=left_chunks)\\n + AIMessageChunk(content=\\"\\", tool_call_chunks=right_chunks)\\n ).tool_call_chunks == [ToolCallChunk(name=\'foo\', args=\'{\\"a\\":1}\', index=0)]", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "index": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Index"}, "type": {"const": "tool_call_chunk", "enum": ["tool_call_chunk"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "index"], "title": "ToolCallChunk", "type": "object"}, "ToolMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nToolMessages contain the result of a tool invocation. Typically, the result\\nis encoded inside the `content` field.\\n\\nExample: A ToolMessage representing a result of 42 from a tool call with id\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n ToolMessage(content=\'42\', tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\')\\n\\n\\nExample: A ToolMessage where only part of the tool output is sent to the model\\n and the full output is passed in to artifact.\\n\\n .. versionadded:: 0.2.17\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n tool_output = {\\n \\"stdout\\": \\"From the graph we can see that the correlation between x and y is ...\\",\\n \\"stderr\\": None,\\n \\"artifacts\\": {\\"type\\": \\"image\\", \\"base64_data\\": \\"/9j/4gIcSU...\\"},\\n }\\n\\n ToolMessage(\\n content=tool_output[\\"stdout\\"],\\n artifact=tool_output,\\n tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\',\\n )\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "tool", "default": "tool", "enum": ["tool"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessage", "type": "object"}, "ToolMessageChunk": {"additionalProperties": true, "description": "Tool Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ToolMessageChunk", "default": "ToolMessageChunk", "enum": ["ToolMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessageChunk", "type": "object"}, "UsageMetadata": {"description": "Usage metadata for a message, such as token counts.\\n\\nThis is a standard representation of token usage that is consistent across models.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"input_tokens\\": 350,\\n \\"output_tokens\\": 240,\\n \\"total_tokens\\": 590,\\n \\"input_token_details\\": {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n },\\n \\"output_token_details\\": {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n }\\n\\n.. versionchanged:: 0.3.9\\n\\n Added ``input_token_details`` and ``output_token_details``.", "properties": {"input_tokens": {"title": "Input Tokens", "type": "integer"}, "output_tokens": {"title": "Output Tokens", "type": "integer"}, "total_tokens": {"title": "Total Tokens", "type": "integer"}, "input_token_details": {"$ref": "#/$defs/InputTokenDetails"}, "output_token_details": {"$ref": "#/$defs/OutputTokenDetails"}}, "required": ["input_tokens", "output_tokens", "total_tokens"], "title": "UsageMetadata", "type": "object"}}, "default": null, "items": {"oneOf": [{"$ref": "#/$defs/AIMessage"}, {"$ref": "#/$defs/HumanMessage"}, {"$ref": "#/$defs/ChatMessage"}, {"$ref": "#/$defs/SystemMessage"}, {"$ref": "#/$defs/FunctionMessage"}, {"$ref": "#/$defs/ToolMessage"}, {"$ref": "#/$defs/AIMessageChunk"}, {"$ref": "#/$defs/HumanMessageChunk"}, {"$ref": "#/$defs/ChatMessageChunk"}, {"$ref": "#/$defs/SystemMessageChunk"}, {"$ref": "#/$defs/FunctionMessageChunk"}, {"$ref": "#/$defs/ToolMessageChunk"}]}, "title": "LangGraphInput", "type": "array"}' # --- @@ -1415,6 +1806,87 @@ ''' # --- +# name: test_message_graph[pymysql_shallow] + '{"$defs": {"AIMessage": {"additionalProperties": true, "description": "Message from an AI.\\n\\nAIMessage is returned from a chat model as a response to a prompt.\\n\\nThis message represents the output of the model and consists of both\\nthe raw output as returned by the model together standardized fields\\n(e.g., tool calls, usage metadata) added by the LangChain framework.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ai", "default": "ai", "enum": ["ai"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}}, "required": ["content"], "title": "AIMessage", "type": "object"}, "AIMessageChunk": {"additionalProperties": true, "description": "Message chunk from an AI.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "AIMessageChunk", "default": "AIMessageChunk", "enum": ["AIMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}, "tool_call_chunks": {"default": [], "items": {"$ref": "#/$defs/ToolCallChunk"}, "title": "Tool Call Chunks", "type": "array"}}, "required": ["content"], "title": "AIMessageChunk", "type": "object"}, "ChatMessage": {"additionalProperties": true, "description": "Message that can be assigned an arbitrary speaker (i.e. role).", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "chat", "default": "chat", "enum": ["chat"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessage", "type": "object"}, "ChatMessageChunk": {"additionalProperties": true, "description": "Chat Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ChatMessageChunk", "default": "ChatMessageChunk", "enum": ["ChatMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessageChunk", "type": "object"}, "FunctionMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nFunctionMessage are an older version of the ToolMessage schema, and\\ndo not contain the tool_call_id field.\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "function", "default": "function", "enum": ["function"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessage", "type": "object"}, "FunctionMessageChunk": {"additionalProperties": true, "description": "Function Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "FunctionMessageChunk", "default": "FunctionMessageChunk", "enum": ["FunctionMessageChunk"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessageChunk", "type": "object"}, "HumanMessage": {"additionalProperties": true, "description": "Message from a human.\\n\\nHumanMessages are messages that are passed in from a human to the model.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Instantiate a chat model and invoke it with the messages\\n model = ...\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "human", "default": "human", "enum": ["human"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessage", "type": "object"}, "HumanMessageChunk": {"additionalProperties": true, "description": "Human Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "HumanMessageChunk", "default": "HumanMessageChunk", "enum": ["HumanMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessageChunk", "type": "object"}, "InputTokenDetails": {"description": "Breakdown of input token counts.\\n\\nDoes *not* need to sum to full input token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "cache_creation": {"title": "Cache Creation", "type": "integer"}, "cache_read": {"title": "Cache Read", "type": "integer"}}, "title": "InputTokenDetails", "type": "object"}, "InvalidToolCall": {"description": "Allowance for errors made by LLM.\\n\\nHere we add an `error` key to surface errors made during generation\\n(e.g., invalid JSON arguments.)", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "error": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Error"}, "type": {"const": "invalid_tool_call", "enum": ["invalid_tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "error"], "title": "InvalidToolCall", "type": "object"}, "OutputTokenDetails": {"description": "Breakdown of output token counts.\\n\\nDoes *not* need to sum to full output token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "reasoning": {"title": "Reasoning", "type": "integer"}}, "title": "OutputTokenDetails", "type": "object"}, "SystemMessage": {"additionalProperties": true, "description": "Message for priming AI behavior.\\n\\nThe system message is usually passed in as the first of a sequence\\nof input messages.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Define a chat model and invoke it with the messages\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "system", "default": "system", "enum": ["system"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessage", "type": "object"}, "SystemMessageChunk": {"additionalProperties": true, "description": "System Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "SystemMessageChunk", "default": "SystemMessageChunk", "enum": ["SystemMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessageChunk", "type": "object"}, "ToolCall": {"description": "Represents a request to call a tool.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"name\\": \\"foo\\",\\n \\"args\\": {\\"a\\": 1},\\n \\"id\\": \\"123\\"\\n }\\n\\n This represents a request to call the tool named \\"foo\\" with arguments {\\"a\\": 1}\\n and an identifier of \\"123\\".", "properties": {"name": {"title": "Name", "type": "string"}, "args": {"title": "Args", "type": "object"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "type": {"const": "tool_call", "enum": ["tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id"], "title": "ToolCall", "type": "object"}, "ToolCallChunk": {"description": "A chunk of a tool call (e.g., as part of a stream).\\n\\nWhen merging ToolCallChunks (e.g., via AIMessageChunk.__add__),\\nall string attributes are concatenated. Chunks are only merged if their\\nvalues of `index` are equal and not None.\\n\\nExample:\\n\\n.. code-block:: python\\n\\n left_chunks = [ToolCallChunk(name=\\"foo\\", args=\'{\\"a\\":\', index=0)]\\n right_chunks = [ToolCallChunk(name=None, args=\'1}\', index=0)]\\n\\n (\\n AIMessageChunk(content=\\"\\", tool_call_chunks=left_chunks)\\n + AIMessageChunk(content=\\"\\", tool_call_chunks=right_chunks)\\n ).tool_call_chunks == [ToolCallChunk(name=\'foo\', args=\'{\\"a\\":1}\', index=0)]", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "index": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Index"}, "type": {"const": "tool_call_chunk", "enum": ["tool_call_chunk"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "index"], "title": "ToolCallChunk", "type": "object"}, "ToolMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nToolMessages contain the result of a tool invocation. Typically, the result\\nis encoded inside the `content` field.\\n\\nExample: A ToolMessage representing a result of 42 from a tool call with id\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n ToolMessage(content=\'42\', tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\')\\n\\n\\nExample: A ToolMessage where only part of the tool output is sent to the model\\n and the full output is passed in to artifact.\\n\\n .. versionadded:: 0.2.17\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n tool_output = {\\n \\"stdout\\": \\"From the graph we can see that the correlation between x and y is ...\\",\\n \\"stderr\\": None,\\n \\"artifacts\\": {\\"type\\": \\"image\\", \\"base64_data\\": \\"/9j/4gIcSU...\\"},\\n }\\n\\n ToolMessage(\\n content=tool_output[\\"stdout\\"],\\n artifact=tool_output,\\n tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\',\\n )\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "tool", "default": "tool", "enum": ["tool"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessage", "type": "object"}, "ToolMessageChunk": {"additionalProperties": true, "description": "Tool Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ToolMessageChunk", "default": "ToolMessageChunk", "enum": ["ToolMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessageChunk", "type": "object"}, "UsageMetadata": {"description": "Usage metadata for a message, such as token counts.\\n\\nThis is a standard representation of token usage that is consistent across models.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"input_tokens\\": 350,\\n \\"output_tokens\\": 240,\\n \\"total_tokens\\": 590,\\n \\"input_token_details\\": {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n },\\n \\"output_token_details\\": {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n }\\n\\n.. versionchanged:: 0.3.9\\n\\n Added ``input_token_details`` and ``output_token_details``.", "properties": {"input_tokens": {"title": "Input Tokens", "type": "integer"}, "output_tokens": {"title": "Output Tokens", "type": "integer"}, "total_tokens": {"title": "Total Tokens", "type": "integer"}, "input_token_details": {"$ref": "#/$defs/InputTokenDetails"}, "output_token_details": {"$ref": "#/$defs/OutputTokenDetails"}}, "required": ["input_tokens", "output_tokens", "total_tokens"], "title": "UsageMetadata", "type": "object"}}, "default": null, "items": {"oneOf": [{"$ref": "#/$defs/AIMessage"}, {"$ref": "#/$defs/HumanMessage"}, {"$ref": "#/$defs/ChatMessage"}, {"$ref": "#/$defs/SystemMessage"}, {"$ref": "#/$defs/FunctionMessage"}, {"$ref": "#/$defs/ToolMessage"}, {"$ref": "#/$defs/AIMessageChunk"}, {"$ref": "#/$defs/HumanMessageChunk"}, {"$ref": "#/$defs/ChatMessageChunk"}, {"$ref": "#/$defs/SystemMessageChunk"}, {"$ref": "#/$defs/FunctionMessageChunk"}, {"$ref": "#/$defs/ToolMessageChunk"}]}, "title": "LangGraphInput", "type": "array"}' +# --- +# name: test_message_graph[pymysql_shallow].1 + '{"$defs": {"AIMessage": {"additionalProperties": true, "description": "Message from an AI.\\n\\nAIMessage is returned from a chat model as a response to a prompt.\\n\\nThis message represents the output of the model and consists of both\\nthe raw output as returned by the model together standardized fields\\n(e.g., tool calls, usage metadata) added by the LangChain framework.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ai", "default": "ai", "enum": ["ai"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}}, "required": ["content"], "title": "AIMessage", "type": "object"}, "AIMessageChunk": {"additionalProperties": true, "description": "Message chunk from an AI.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "AIMessageChunk", "default": "AIMessageChunk", "enum": ["AIMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}, "tool_calls": {"default": [], "items": {"$ref": "#/$defs/ToolCall"}, "title": "Tool Calls", "type": "array"}, "invalid_tool_calls": {"default": [], "items": {"$ref": "#/$defs/InvalidToolCall"}, "title": "Invalid Tool Calls", "type": "array"}, "usage_metadata": {"anyOf": [{"$ref": "#/$defs/UsageMetadata"}, {"type": "null"}], "default": null}, "tool_call_chunks": {"default": [], "items": {"$ref": "#/$defs/ToolCallChunk"}, "title": "Tool Call Chunks", "type": "array"}}, "required": ["content"], "title": "AIMessageChunk", "type": "object"}, "ChatMessage": {"additionalProperties": true, "description": "Message that can be assigned an arbitrary speaker (i.e. role).", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "chat", "default": "chat", "enum": ["chat"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessage", "type": "object"}, "ChatMessageChunk": {"additionalProperties": true, "description": "Chat Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ChatMessageChunk", "default": "ChatMessageChunk", "enum": ["ChatMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "role": {"title": "Role", "type": "string"}}, "required": ["content", "role"], "title": "ChatMessageChunk", "type": "object"}, "FunctionMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nFunctionMessage are an older version of the ToolMessage schema, and\\ndo not contain the tool_call_id field.\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "function", "default": "function", "enum": ["function"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessage", "type": "object"}, "FunctionMessageChunk": {"additionalProperties": true, "description": "Function Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "FunctionMessageChunk", "default": "FunctionMessageChunk", "enum": ["FunctionMessageChunk"], "title": "Type", "type": "string"}, "name": {"title": "Name", "type": "string"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "name"], "title": "FunctionMessageChunk", "type": "object"}, "HumanMessage": {"additionalProperties": true, "description": "Message from a human.\\n\\nHumanMessages are messages that are passed in from a human to the model.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Instantiate a chat model and invoke it with the messages\\n model = ...\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "human", "default": "human", "enum": ["human"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessage", "type": "object"}, "HumanMessageChunk": {"additionalProperties": true, "description": "Human Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "HumanMessageChunk", "default": "HumanMessageChunk", "enum": ["HumanMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "example": {"default": false, "title": "Example", "type": "boolean"}}, "required": ["content"], "title": "HumanMessageChunk", "type": "object"}, "InputTokenDetails": {"description": "Breakdown of input token counts.\\n\\nDoes *not* need to sum to full input token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "cache_creation": {"title": "Cache Creation", "type": "integer"}, "cache_read": {"title": "Cache Read", "type": "integer"}}, "title": "InputTokenDetails", "type": "object"}, "InvalidToolCall": {"description": "Allowance for errors made by LLM.\\n\\nHere we add an `error` key to surface errors made during generation\\n(e.g., invalid JSON arguments.)", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "error": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Error"}, "type": {"const": "invalid_tool_call", "enum": ["invalid_tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "error"], "title": "InvalidToolCall", "type": "object"}, "OutputTokenDetails": {"description": "Breakdown of output token counts.\\n\\nDoes *not* need to sum to full output token count. Does *not* need to have all keys.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n\\n.. versionadded:: 0.3.9", "properties": {"audio": {"title": "Audio", "type": "integer"}, "reasoning": {"title": "Reasoning", "type": "integer"}}, "title": "OutputTokenDetails", "type": "object"}, "SystemMessage": {"additionalProperties": true, "description": "Message for priming AI behavior.\\n\\nThe system message is usually passed in as the first of a sequence\\nof input messages.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import HumanMessage, SystemMessage\\n\\n messages = [\\n SystemMessage(\\n content=\\"You are a helpful assistant! Your name is Bob.\\"\\n ),\\n HumanMessage(\\n content=\\"What is your name?\\"\\n )\\n ]\\n\\n # Define a chat model and invoke it with the messages\\n print(model.invoke(messages))", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "system", "default": "system", "enum": ["system"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessage", "type": "object"}, "SystemMessageChunk": {"additionalProperties": true, "description": "System Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "SystemMessageChunk", "default": "SystemMessageChunk", "enum": ["SystemMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content"], "title": "SystemMessageChunk", "type": "object"}, "ToolCall": {"description": "Represents a request to call a tool.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"name\\": \\"foo\\",\\n \\"args\\": {\\"a\\": 1},\\n \\"id\\": \\"123\\"\\n }\\n\\n This represents a request to call the tool named \\"foo\\" with arguments {\\"a\\": 1}\\n and an identifier of \\"123\\".", "properties": {"name": {"title": "Name", "type": "string"}, "args": {"title": "Args", "type": "object"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "type": {"const": "tool_call", "enum": ["tool_call"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id"], "title": "ToolCall", "type": "object"}, "ToolCallChunk": {"description": "A chunk of a tool call (e.g., as part of a stream).\\n\\nWhen merging ToolCallChunks (e.g., via AIMessageChunk.__add__),\\nall string attributes are concatenated. Chunks are only merged if their\\nvalues of `index` are equal and not None.\\n\\nExample:\\n\\n.. code-block:: python\\n\\n left_chunks = [ToolCallChunk(name=\\"foo\\", args=\'{\\"a\\":\', index=0)]\\n right_chunks = [ToolCallChunk(name=None, args=\'1}\', index=0)]\\n\\n (\\n AIMessageChunk(content=\\"\\", tool_call_chunks=left_chunks)\\n + AIMessageChunk(content=\\"\\", tool_call_chunks=right_chunks)\\n ).tool_call_chunks == [ToolCallChunk(name=\'foo\', args=\'{\\"a\\":1}\', index=0)]", "properties": {"name": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Name"}, "args": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Args"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "title": "Id"}, "index": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "Index"}, "type": {"const": "tool_call_chunk", "enum": ["tool_call_chunk"], "title": "Type", "type": "string"}}, "required": ["name", "args", "id", "index"], "title": "ToolCallChunk", "type": "object"}, "ToolMessage": {"additionalProperties": true, "description": "Message for passing the result of executing a tool back to a model.\\n\\nToolMessages contain the result of a tool invocation. Typically, the result\\nis encoded inside the `content` field.\\n\\nExample: A ToolMessage representing a result of 42 from a tool call with id\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n ToolMessage(content=\'42\', tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\')\\n\\n\\nExample: A ToolMessage where only part of the tool output is sent to the model\\n and the full output is passed in to artifact.\\n\\n .. versionadded:: 0.2.17\\n\\n .. code-block:: python\\n\\n from langchain_core.messages import ToolMessage\\n\\n tool_output = {\\n \\"stdout\\": \\"From the graph we can see that the correlation between x and y is ...\\",\\n \\"stderr\\": None,\\n \\"artifacts\\": {\\"type\\": \\"image\\", \\"base64_data\\": \\"/9j/4gIcSU...\\"},\\n }\\n\\n ToolMessage(\\n content=tool_output[\\"stdout\\"],\\n artifact=tool_output,\\n tool_call_id=\'call_Jja7J89XsjrOLA5r!MEOW!SL\',\\n )\\n\\nThe tool_call_id field is used to associate the tool call request with the\\ntool call response. This is useful in situations where a chat model is able\\nto request multiple tool calls in parallel.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "tool", "default": "tool", "enum": ["tool"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessage", "type": "object"}, "ToolMessageChunk": {"additionalProperties": true, "description": "Tool Message chunk.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"const": "ToolMessageChunk", "default": "ToolMessageChunk", "enum": ["ToolMessageChunk"], "title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}, "tool_call_id": {"title": "Tool Call Id", "type": "string"}, "artifact": {"default": null, "title": "Artifact"}, "status": {"default": "success", "enum": ["success", "error"], "title": "Status", "type": "string"}}, "required": ["content", "tool_call_id"], "title": "ToolMessageChunk", "type": "object"}, "UsageMetadata": {"description": "Usage metadata for a message, such as token counts.\\n\\nThis is a standard representation of token usage that is consistent across models.\\n\\nExample:\\n\\n .. code-block:: python\\n\\n {\\n \\"input_tokens\\": 350,\\n \\"output_tokens\\": 240,\\n \\"total_tokens\\": 590,\\n \\"input_token_details\\": {\\n \\"audio\\": 10,\\n \\"cache_creation\\": 200,\\n \\"cache_read\\": 100,\\n },\\n \\"output_token_details\\": {\\n \\"audio\\": 10,\\n \\"reasoning\\": 200,\\n }\\n }\\n\\n.. versionchanged:: 0.3.9\\n\\n Added ``input_token_details`` and ``output_token_details``.", "properties": {"input_tokens": {"title": "Input Tokens", "type": "integer"}, "output_tokens": {"title": "Output Tokens", "type": "integer"}, "total_tokens": {"title": "Total Tokens", "type": "integer"}, "input_token_details": {"$ref": "#/$defs/InputTokenDetails"}, "output_token_details": {"$ref": "#/$defs/OutputTokenDetails"}}, "required": ["input_tokens", "output_tokens", "total_tokens"], "title": "UsageMetadata", "type": "object"}}, "default": null, "items": {"oneOf": [{"$ref": "#/$defs/AIMessage"}, {"$ref": "#/$defs/HumanMessage"}, {"$ref": "#/$defs/ChatMessage"}, {"$ref": "#/$defs/SystemMessage"}, {"$ref": "#/$defs/FunctionMessage"}, {"$ref": "#/$defs/ToolMessage"}, {"$ref": "#/$defs/AIMessageChunk"}, {"$ref": "#/$defs/HumanMessageChunk"}, {"$ref": "#/$defs/ChatMessageChunk"}, {"$ref": "#/$defs/SystemMessageChunk"}, {"$ref": "#/$defs/FunctionMessageChunk"}, {"$ref": "#/$defs/ToolMessageChunk"}]}, "title": "LangGraphOutput", "type": "array"}' +# --- +# name: test_message_graph[pymysql_shallow].2 + ''' + { + "nodes": [ + { + "id": "__start__", + "type": "schema", + "data": "__start__" + }, + { + "id": "agent", + "type": "runnable", + "data": { + "id": [ + "tests", + "test_large_cases", + "FakeFuntionChatModel" + ], + "name": "agent" + } + }, + { + "id": "tools", + "type": "runnable", + "data": { + "id": [ + "langgraph", + "prebuilt", + "tool_node", + "ToolNode" + ], + "name": "tools" + } + }, + { + "id": "__end__", + "type": "schema", + "data": "__end__" + } + ], + "edges": [ + { + "source": "__start__", + "target": "agent" + }, + { + "source": "tools", + "target": "agent" + }, + { + "source": "agent", + "target": "tools", + "data": "continue", + "conditional": true + }, + { + "source": "agent", + "target": "__end__", + "data": "end", + "conditional": true + } + ] + } + ''' +# --- +# name: test_message_graph[pymysql_shallow].3 + ''' + graph TD; + __start__ --> agent; + tools --> agent; + agent -. continue .-> tools; + agent -. end .-> __end__; + + ''' +# --- # name: test_send_react_interrupt_control[pymysql] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% @@ -1460,6 +1932,21 @@ ''' # --- +# name: test_send_react_interrupt_control[pymysql_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__
]):::first + agent(agent) + foo([foo]):::last + __start__ --> agent; + agent -.-> foo; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_start_branch_then[pymysql] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% @@ -1514,6 +2001,24 @@ ''' # --- +# name: test_start_branch_then[pymysql_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__
]):::first + tool_two_slow(tool_two_slow) + tool_two_fast(tool_two_fast) + __end__([__end__
]):::last + __start__ -.-> tool_two_slow; + tool_two_slow --> __end__; + __start__ -.-> tool_two_fast; + tool_two_fast --> __end__; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + + ''' +# --- # name: test_weather_subgraph[pymysql] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% @@ -1587,5 +2092,30 @@ classDef first fill-opacity:0 classDef last fill:#bfb6fc + ''' +# --- +# name: test_weather_subgraph[pymysql_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__
]):::first + router_node(router_node) + normal_llm_node(normal_llm_node) + weather_graph_model_node(model_node) + weather_graph_weather_node(weather_node__end__
]):::last + __start__ --> router_node; + normal_llm_node --> __end__; + weather_graph_weather_node --> __end__; + router_node -.-> normal_llm_node; + router_node -.-> weather_graph_model_node; + router_node -.-> __end__; + subgraph weather_graph + weather_graph_model_node --> weather_graph_weather_node; + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + ''' # --- \ No newline at end of file diff --git a/langgraph-tests/tests/__snapshots__/test_large_cases_async.ambr b/langgraph-tests/tests/__snapshots__/test_large_cases_async.ambr index a83a355..2bfde6b 100644 --- a/langgraph-tests/tests/__snapshots__/test_large_cases_async.ambr +++ b/langgraph-tests/tests/__snapshots__/test_large_cases_async.ambr @@ -47,5 +47,30 @@ classDef first fill-opacity:0 classDef last fill:#bfb6fc + ''' +# --- +# name: test_weather_subgraph[aiomysql_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__
]):::first + router_node(router_node) + normal_llm_node(normal_llm_node) + weather_graph_model_node(model_node) + weather_graph_weather_node(weather_node__end__
]):::last + __start__ --> router_node; + normal_llm_node --> __end__; + weather_graph_weather_node --> __end__; + router_node -.-> normal_llm_node; + router_node -.-> weather_graph_model_node; + router_node -.-> __end__; + subgraph weather_graph + weather_graph_model_node --> weather_graph_weather_node; + end + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + ''' # --- \ No newline at end of file diff --git a/langgraph-tests/tests/__snapshots__/test_pregel.ambr b/langgraph-tests/tests/__snapshots__/test_pregel.ambr index de76b42..1b5e4ec 100644 --- a/langgraph-tests/tests/__snapshots__/test_pregel.ambr +++ b/langgraph-tests/tests/__snapshots__/test_pregel.ambr @@ -25,6 +25,19 @@ ''' # --- +# name: test_in_one_fan_out_state_graph_waiting_edge[pymysql_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query --> retriever_two; + + ''' +# --- # name: test_in_one_fan_out_state_graph_waiting_edge[pymysql_callable] ''' graph TD; @@ -248,6 +261,76 @@ 'type': 'object', }) # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_shallow].1 + dict({ + 'definitions': dict({ + 'InnerObject': dict({ + 'properties': dict({ + 'yo': dict({ + 'title': 'Yo', + 'type': 'integer', + }), + }), + 'required': list([ + 'yo', + ]), + 'title': 'InnerObject', + 'type': 'object', + }), + }), + 'properties': dict({ + 'inner': dict({ + '$ref': '#/definitions/InnerObject', + }), + 'query': dict({ + 'title': 'Query', + 'type': 'string', + }), + }), + 'required': list([ + 'query', + 'inner', + ]), + 'title': 'Input', + 'type': 'object', + }) +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_shallow].2 + dict({ + 'properties': dict({ + 'answer': dict({ + 'title': 'Answer', + 'type': 'string', + }), + 'docs': dict({ + 'items': dict({ + 'type': 'string', + }), + 'title': 'Docs', + 'type': 'array', + }), + }), + 'required': list([ + 'answer', + 'docs', + ]), + 'title': 'Output', + 'type': 'object', + }) +# --- # name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql] ''' graph TD; @@ -458,6 +541,76 @@ 'type': 'object', }) # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_shallow].1 + dict({ + '$defs': dict({ + 'InnerObject': dict({ + 'properties': dict({ + 'yo': dict({ + 'title': 'Yo', + 'type': 'integer', + }), + }), + 'required': list([ + 'yo', + ]), + 'title': 'InnerObject', + 'type': 'object', + }), + }), + 'properties': dict({ + 'inner': dict({ + '$ref': '#/$defs/InnerObject', + }), + 'query': dict({ + 'title': 'Query', + 'type': 'string', + }), + }), + 'required': list([ + 'query', + 'inner', + ]), + 'title': 'Input', + 'type': 'object', + }) +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_shallow].2 + dict({ + 'properties': dict({ + 'answer': dict({ + 'title': 'Answer', + 'type': 'string', + }), + 'docs': dict({ + 'items': dict({ + 'type': 'string', + }), + 'title': 'Docs', + 'type': 'array', + }), + }), + 'required': list([ + 'answer', + 'docs', + ]), + 'title': 'Output', + 'type': 'object', + }) +# --- # name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[pymysql] ''' graph TD; @@ -497,3 +650,16 @@ ''' # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[pymysql_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- diff --git a/langgraph-tests/tests/__snapshots__/test_pregel_async.ambr b/langgraph-tests/tests/__snapshots__/test_pregel_async.ambr index d0d5470..f76d66e 100644 --- a/langgraph-tests/tests/__snapshots__/test_pregel_async.ambr +++ b/langgraph-tests/tests/__snapshots__/test_pregel_async.ambr @@ -241,6 +241,127 @@ 'type': 'object', }) # --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[aiomysql_shallow] + ''' + graph TD; + __start__ --> rewrite_query; + analyzer_one --> retriever_one; + qa --> __end__; + retriever_one --> qa; + retriever_two --> qa; + rewrite_query --> analyzer_one; + rewrite_query -.-> retriever_two; + + ''' +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[aiomysql_shallow].1 + dict({ + '$defs': dict({ + 'InnerObject': dict({ + 'properties': dict({ + 'yo': dict({ + 'title': 'Yo', + 'type': 'integer', + }), + }), + 'required': list([ + 'yo', + ]), + 'title': 'InnerObject', + 'type': 'object', + }), + }), + 'properties': dict({ + 'answer': dict({ + 'anyOf': list([ + dict({ + 'type': 'string', + }), + dict({ + 'type': 'null', + }), + ]), + 'default': None, + 'title': 'Answer', + }), + 'docs': dict({ + 'items': dict({ + 'type': 'string', + }), + 'title': 'Docs', + 'type': 'array', + }), + 'inner': dict({ + '$ref': '#/$defs/InnerObject', + }), + 'query': dict({ + 'title': 'Query', + 'type': 'string', + }), + }), + 'required': list([ + 'query', + 'inner', + 'docs', + ]), + 'title': 'State', + 'type': 'object', + }) +# --- +# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[aiomysql_shallow].2 + dict({ + '$defs': dict({ + 'InnerObject': dict({ + 'properties': dict({ + 'yo': dict({ + 'title': 'Yo', + 'type': 'integer', + }), + }), + 'required': list([ + 'yo', + ]), + 'title': 'InnerObject', + 'type': 'object', + }), + }), + 'properties': dict({ + 'answer': dict({ + 'anyOf': list([ + dict({ + 'type': 'string', + }), + dict({ + 'type': 'null', + }), + ]), + 'default': None, + 'title': 'Answer', + }), + 'docs': dict({ + 'items': dict({ + 'type': 'string', + }), + 'title': 'Docs', + 'type': 'array', + }), + 'inner': dict({ + '$ref': '#/$defs/InnerObject', + }), + 'query': dict({ + 'title': 'Query', + 'type': 'string', + }), + }), + 'required': list([ + 'query', + 'inner', + 'docs', + ]), + 'title': 'State', + 'type': 'object', + }) +# --- # name: test_send_react_interrupt_control[aiomysql] ''' %%{init: {'flowchart': {'curve': 'linear'}}}%% @@ -269,5 +390,20 @@ classDef first fill-opacity:0 classDef last fill:#bfb6fc + ''' +# --- +# name: test_send_react_interrupt_control[aiomysql_shallow] + ''' + %%{init: {'flowchart': {'curve': 'linear'}}}%% + graph TD; + __start__([__start__
]):::first + agent(agent) + foo([foo]):::last + __start__ --> agent; + agent -.-> foo; + classDef default fill:#f2f0ff,line-height:1.2 + classDef first fill-opacity:0 + classDef last fill:#bfb6fc + ''' # --- \ No newline at end of file diff --git a/langgraph-tests/tests/conftest.py b/langgraph-tests/tests/conftest.py index e2b0e56..5e995fc 100644 --- a/langgraph-tests/tests/conftest.py +++ b/langgraph-tests/tests/conftest.py @@ -12,8 +12,8 @@ from sqlalchemy import Pool, create_pool_from_url from langgraph.checkpoint.base import BaseCheckpointSaver -from langgraph.checkpoint.mysql.aio import AIOMySQLSaver -from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver +from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver +from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver, ShallowPyMySQLSaver from langgraph.store.base import BaseStore from langgraph.store.mysql.aio import AIOMySQLStore from langgraph.store.mysql.pymysql import PyMySQLStore @@ -68,6 +68,28 @@ def checkpointer_pymysql(): cursor.execute(f"DROP DATABASE {database}") +@pytest.fixture(scope="function") +def checkpointer_pymysql_shallow(): + database = f"test_{uuid4().hex[:16]}" + + # create unique db + with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn: + with conn.cursor() as cursor: + cursor.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + with ShallowPyMySQLSaver.from_conn_string( + DEFAULT_MYSQL_URI + database + ) as checkpointer: + checkpointer.setup() + yield checkpointer + finally: + # drop unique db + with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP DATABASE {database}") + + @pytest.fixture(scope="function") def checkpointer_pymysql_sqlalchemy_pool(): database = f"test_{uuid4().hex[:16]}" @@ -138,6 +160,33 @@ async def _checkpointer_aiomysql(): await cursor.execute(f"DROP DATABASE {database}") +@asynccontextmanager +async def _checkpointer_aiomysql_shallow(): + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await aiomysql.connect( + **AIOMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + async with ShallowAIOMySQLSaver.from_conn_string( + DEFAULT_MYSQL_URI + database + ) as checkpointer: + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await aiomysql.connect( + **AIOMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + @asynccontextmanager async def _checkpointer_aiomysql_pool(): database = f"test_{uuid4().hex[:16]}" @@ -175,6 +224,9 @@ async def awith_checkpointer( if checkpointer_name == "aiomysql": async with _checkpointer_aiomysql() as checkpointer: yield checkpointer + elif checkpointer_name == "aiomysql_shallow": + async with _checkpointer_aiomysql_shallow() as checkpointer: + yield checkpointer elif checkpointer_name == "aiomysql_pool": async with _checkpointer_aiomysql_pool() as checkpointer: yield checkpointer @@ -268,6 +320,30 @@ async def _store_aiomysql(): await cursor.execute(f"DROP DATABASE {database}") +@asynccontextmanager +async def _store_aiomysql_shallow(): + database = f"test_{uuid4().hex[:16]}" + async with await aiomysql.connect( + **ShallowAIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + async with ShallowAIOMySQLStore.from_conn_string( + DEFAULT_MYSQL_URI + database + ) as store: + await store.setup() + yield store + finally: + async with await aiomysql.connect( + **ShallowAIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + @asynccontextmanager async def _store_aiomysql_pool(): database = f"test_{uuid4().hex[:16]}" @@ -307,7 +383,21 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]: raise NotImplementedError(f"Unknown store {store_name}") -ALL_CHECKPOINTERS_SYNC = ["pymysql", "pymysql_sqlalchemy_pool", "pymysql_callable"] -ALL_CHECKPOINTERS_ASYNC = ["aiomysql", "aiomysql_pool"] +SHALLOW_CHECKPOINTERS_SYNC = ["pymysql_shallow"] +REGULAR_CHECKPOINTERS_SYNC = [ + "pymysql", + "pymysql_sqlalchemy_pool", + "pymysql_callable" +] +ALL_CHECKPOINTERS_SYNC = [ + *REGULAR_CHECKPOINTERS_SYNC, + *SHALLOW_CHECKPOINTERS_SYNC, +] +SHALLOW_CHECKPOINTERS_ASYNC = ["aiomysql_shallow"] +REGULAR_CHECKPOINTERS_ASYNC = ["aiomysql", "aiomysql_pool"] +ALL_CHECKPOINTERS_ASYNC = [ + *REGULAR_CHECKPOINTERS_ASYNC, + *SHALLOW_CHECKPOINTERS_ASYNC, +] ALL_STORES_SYNC = ["pymysql", "pymysql_sqlalchemy_pool", "pymysql_callable"] ALL_STORES_ASYNC = ["aiomysql", "aiomysql_pool"] diff --git a/langgraph-tests/tests/test_large_cases.py b/langgraph-tests/tests/test_large_cases.py index 8c9a6de..39cdc91 100644 --- a/langgraph-tests/tests/test_large_cases.py +++ b/langgraph-tests/tests/test_large_cases.py @@ -37,7 +37,11 @@ ) from tests.agents import AgentAction, AgentFinish from tests.any_str import AnyDict, AnyStr, UnsortedSequence -from tests.conftest import ALL_CHECKPOINTERS_SYNC, SHOULD_CHECK_SNAPSHOTS +from tests.conftest import ( + ALL_CHECKPOINTERS_SYNC, + REGULAR_CHECKPOINTERS_SYNC, + SHOULD_CHECK_SNAPSHOTS, +) from tests.fake_tracer import FakeTracer from tests.messages import ( _AnyIdAIMessage, @@ -109,6 +113,9 @@ def test_invoke_two_processes_in_out_interrupt( snapshot = app.get_state(thread2) assert snapshot.next == () + if "shallow" in checkpointer_name: + return + # list history history = [c for c in app.get_state_history(thread1)] assert history == [ @@ -293,7 +300,7 @@ def test_invoke_two_processes_in_out_interrupt( ] -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_SYNC) def test_fork_always_re_runs_nodes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: @@ -739,8 +746,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], - config=app_w_interrupt.checkpointer.get_tuple(config).config, + created_at=AnyStr(), + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, metadata={ "parents": {}, "source": "loop", @@ -759,7 +772,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert ( app_w_interrupt.checkpointer.get_tuple(config).config["configurable"][ @@ -793,8 +810,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -811,7 +834,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -905,8 +932,14 @@ def should_continue(data: dict) -> str: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -932,7 +965,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test state get/update methods with interrupt_before @@ -968,8 +1005,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -988,7 +1031,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -1016,8 +1063,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1034,7 +1087,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1128,8 +1185,14 @@ def should_continue(data: dict) -> str: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1155,7 +1218,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test re-invoke to continue with interrupt_before @@ -1191,8 +1258,14 @@ def should_continue(data: dict) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1211,7 +1284,11 @@ def should_continue(data: dict) -> str: }, "thread_id": "3", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1580,8 +1657,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1597,7 +1680,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) with assert_ctx_once(): @@ -1623,8 +1710,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1640,7 +1733,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) with assert_ctx_once(): @@ -1701,8 +1798,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1717,7 +1820,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test state get/update methods with interrupt_before @@ -1752,8 +1859,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1769,7 +1882,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -1794,8 +1911,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1811,7 +1934,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1870,8 +1997,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1886,7 +2019,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # test w interrupt before all @@ -1910,8 +2047,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1919,7 +2062,11 @@ def should_continue(data: AgentState) -> str: "writes": None, "thread_id": "3", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1942,8 +2089,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1959,7 +2112,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "3", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -1998,8 +2155,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2020,7 +2183,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "3", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2066,8 +2233,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "4", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2083,7 +2256,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "4", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2122,8 +2299,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "4", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2144,7 +2327,11 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "4", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2478,8 +2665,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2)), ), next=("tools",), - config=(app_w_interrupt.checkpointer.get_tuple(config)).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2487,7 +2680,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: "writes": None, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) # modify ai message @@ -2517,8 +2714,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0)),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -2541,7 +2744,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2648,8 +2855,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3)), ), next=("tools", "tools"), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2665,7 +2878,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) app_w_interrupt.update_state( @@ -2702,8 +2919,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -2716,7 +2939,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) # interrupt before tools @@ -2795,8 +3022,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2)), ), next=("tools",), - config=(app_w_interrupt.checkpointer.get_tuple(config)).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2804,7 +3037,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: "writes": None, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) # modify ai message @@ -2835,7 +3072,7 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: tasks=(PregelTask(AnyStr(), "tools", (PUSH, (), 0)),), next=("tools",), config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -2858,7 +3095,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -2965,8 +3206,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 3)), ), next=("tools", "tools"), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2982,7 +3229,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) app_w_interrupt.update_state( @@ -3019,8 +3270,14 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3033,7 +3290,11 @@ def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config + ), ) @@ -3298,11 +3559,17 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], - metadata={ - "parents": {}, - "source": "loop", + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), + metadata={ + "parents": {}, + "source": "loop", "step": 1, "writes": { "agent": AIMessage( @@ -3319,7 +3586,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -3366,7 +3637,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -3428,8 +3703,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3449,7 +3730,11 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -3481,8 +3766,14 @@ def should_continue(messages): ], tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3490,7 +3781,11 @@ def should_continue(messages): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt = workflow.compile( @@ -3534,8 +3829,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3555,7 +3856,11 @@ def should_continue(messages): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -3581,8 +3886,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3602,7 +3913,11 @@ def should_continue(messages): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -3664,8 +3979,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3685,7 +4006,11 @@ def should_continue(messages): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -3718,8 +4043,14 @@ def should_continue(messages): ], tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3727,7 +4058,11 @@ def should_continue(messages): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # add an extra message as if it came from "tools" node @@ -3760,8 +4095,14 @@ def should_continue(messages): ], tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3769,7 +4110,11 @@ def should_continue(messages): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) @@ -4037,8 +4382,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4058,7 +4409,11 @@ class State(TypedDict): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -4105,7 +4460,11 @@ class State(TypedDict): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -4168,8 +4527,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4189,7 +4554,11 @@ class State(TypedDict): }, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -4222,8 +4591,14 @@ class State(TypedDict): ], tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4231,7 +4606,11 @@ class State(TypedDict): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt = workflow.compile( @@ -4275,8 +4654,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4296,7 +4681,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # modify ai message @@ -4322,8 +4711,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4343,7 +4738,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) assert [c for c in app_w_interrupt.stream(None, config)] == [ @@ -4406,8 +4805,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4427,7 +4832,11 @@ class State(TypedDict): }, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) app_w_interrupt.update_state( @@ -4459,8 +4868,14 @@ class State(TypedDict): ], tasks=(), next=(), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4468,7 +4883,11 @@ class State(TypedDict): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # add an extra message as if it came from "tools" node @@ -4501,8 +4920,14 @@ class State(TypedDict): ], tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4510,7 +4935,11 @@ class State(TypedDict): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ), ) # create new graph with one more state key, reuse previous thread history @@ -4574,8 +5003,14 @@ class MoreState(TypedDict): }, tasks=(PregelTask(AnyStr(), "agent", (PULL, "agent")),), next=("agent",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4583,7 +5018,11 @@ class MoreState(TypedDict): "writes": {"tools": UnsortedSequence("ai", "an extra message")}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(new_app.checkpointer.list(config, limit=2))[-1].config + ), ) # new input is merged to old state @@ -4705,22 +5144,25 @@ def tool_two_node(s: State) -> State: "my_key": "value ⛰️", "market": "DE", } - assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + + if "shallow" not in checkpointer_name: + assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=("tool_two",), @@ -4738,8 +5180,14 @@ def tool_two_node(s: State) -> State: ), ), ), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4747,7 +5195,11 @@ def tool_two_node(s: State) -> State: "writes": None, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # clear the interrupt and next tasks tool_two.update_state(thread1, None, as_node=END) @@ -4756,8 +5208,14 @@ def tool_two_node(s: State) -> State: values={"my_key": "value ⛰️", "market": "DE"}, next=(), tasks=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4765,7 +5223,11 @@ def tool_two_node(s: State) -> State: "writes": {}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) @@ -4859,22 +5321,24 @@ def start(state: State) -> list[Union[Send, str]]: "my_key": "value ⛰️ one", "market": "DE", } - assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": {"tool_one": {"my_key": " one"}}, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + if "shallow" not in checkpointer_name: + assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ one", "market": "DE"}, next=("tool_two",), @@ -4892,8 +5356,14 @@ def start(state: State) -> list[Union[Send, str]]: ), ), ), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4901,7 +5371,11 @@ def start(state: State) -> list[Union[Send, str]]: "writes": {"tool_one": {"my_key": " one"}}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*tool_two.checkpointer.list(thread1, limit=2)][-1].config + ), ) # clear the interrupt and next tasks tool_two.update_state(thread1, None) @@ -4917,8 +5391,14 @@ def start(state: State) -> list[Union[Send, str]]: interrupts=(), ), ), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -4926,7 +5406,11 @@ def start(state: State) -> list[Union[Send, str]]: "writes": {}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [*tool_two.checkpointer.list(thread1, limit=2)][-1].config + ), ) @@ -5017,27 +5501,30 @@ class State(TypedDict): "my_key": "value ⛰️", "market": "DE", } - assert [ - c.metadata - for c in tool_two.checkpointer.list( - {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} - ) - ] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + + if "shallow" not in checkpointer_name: + assert [ + c.metadata + for c in tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + ) + ] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, next=("tool_two",), @@ -5061,8 +5548,14 @@ class State(TypedDict): }, ), ), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5070,11 +5563,15 @@ class State(TypedDict): "writes": None, "thread_id": "1", }, - parent_config=[ - *tool_two.checkpointer.list( - {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 - ) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list( + tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + )[-1].config + ), ) # clear the interrupt and next tasks tool_two.update_state(thread1, None, as_node=END) @@ -5083,8 +5580,14 @@ class State(TypedDict): values={"my_key": "value ⛰️", "market": "DE"}, next=(), tasks=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -5092,11 +5595,15 @@ class State(TypedDict): "writes": {}, "thread_id": "1", }, - parent_config=[ - *tool_two.checkpointer.list( - {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 - ) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list( + tool_two.checkpointer.list( + {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}, limit=2 + ) + )[-1].config + ), ) @@ -5166,30 +5673,39 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "my_key": "value ⛰️", "market": "DE", } - assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "assistant_id": "a", - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "assistant_id": "a", - "thread_id": "1", - }, - ] + + if "shallow" not in checkpointer_name: + assert [c.metadata for c in tool_two.checkpointer.list(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "assistant_id": "a", + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "assistant_id": "a", + "thread_id": "1", + }, + ] + assert tool_two.get_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5198,7 +5714,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -5209,8 +5729,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value ⛰️ slow", "market": "DE"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5219,7 +5745,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}} @@ -5232,8 +5762,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5242,7 +5778,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -5253,8 +5793,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value fast", "market": "US"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5263,7 +5809,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}} @@ -5276,8 +5826,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5286,7 +5842,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # update state tool_two.update_state(thread3, {"my_key": "key"}) # appends to my_key @@ -5294,8 +5854,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "valuekey", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -5304,7 +5870,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread3, debug=1) == { @@ -5315,8 +5885,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "valuekey fast", "market": "US"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5325,7 +5901,11 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) @@ -5679,8 +6259,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5688,7 +6274,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -5699,8 +6289,14 @@ class State(TypedDict): values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5708,7 +6304,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "1", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "2"}} @@ -5721,8 +6321,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5730,7 +6336,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "2", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -5741,8 +6351,14 @@ class State(TypedDict): values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5750,7 +6366,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "2", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) tool_two = tool_two_graph.compile( @@ -5771,8 +6391,14 @@ class State(TypedDict): }, tasks=(PregelTask(AnyStr(), "finish", (PULL, "finish")),), next=("finish",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "11", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5780,7 +6406,11 @@ class State(TypedDict): "writes": {"tool_two_slow": {"my_key": " slow"}}, "thread_id": "11", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # update state @@ -5792,8 +6422,14 @@ class State(TypedDict): }, tasks=(PregelTask(AnyStr(), "finish", (PULL, "finish")),), next=("finish",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "11", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -5801,7 +6437,11 @@ class State(TypedDict): "writes": {"tool_two_slow": {"my_key": "er"}}, "thread_id": "11", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) tool_two = tool_two_graph.compile( @@ -5822,8 +6462,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "21", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5831,7 +6477,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "21", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread1, debug=1) == { @@ -5842,8 +6492,14 @@ class State(TypedDict): values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread1).config, - created_at=tool_two.checkpointer.get_tuple(thread1).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "21", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5851,7 +6507,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "21", }, - parent_config=[*tool_two.checkpointer.list(thread1, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread1, limit=2))[-1].config + ), ) thread2 = {"configurable": {"thread_id": "22"}} @@ -5864,8 +6524,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "22", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5873,7 +6539,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "22", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread2, debug=1) == { @@ -5884,8 +6554,14 @@ class State(TypedDict): values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread2).config, - created_at=tool_two.checkpointer.get_tuple(thread2).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "22", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5893,18 +6569,28 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "22", }, - parent_config=[*tool_two.checkpointer.list(thread2, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread2, limit=2))[-1].config + ), ) thread3 = {"configurable": {"thread_id": "23"}} # update an empty thread before first run - uconfig = tool_two.update_state(thread3, {"my_key": "key", "market": "DE"}) + tool_two.update_state(thread3, {"my_key": "key", "market": "DE"}) # check current state assert tool_two.get_state(thread3) == StateSnapshot( values={"my_key": "key", "market": "DE"}, tasks=(PregelTask(AnyStr(), "prepare", (PULL, "prepare")),), next=("prepare",), - config=uconfig, + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, created_at=AnyStr(), metadata={ "parents": {}, @@ -5925,8 +6611,14 @@ class State(TypedDict): values={"my_key": "key prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5934,7 +6626,11 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "23", }, - parent_config=uconfig, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) # resume, for same result as above assert tool_two.invoke(None, thread3, debug=1) == { @@ -5945,8 +6641,14 @@ class State(TypedDict): values={"my_key": "key prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=tool_two.checkpointer.get_tuple(thread3).config, - created_at=tool_two.checkpointer.get_tuple(thread3).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -5954,7 +6656,11 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "23", }, - parent_config=[*tool_two.checkpointer.list(thread3, limit=2)][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else list(tool_two.checkpointer.list(thread3, limit=2))[-1].config + ), ) @@ -6028,8 +6734,10 @@ def route_to_three(state) -> Literal["3"]: state = graph.get_state(thread1) assert state.next == ("flaky",) # check history - history = [c for c in graph.get_state_history(thread1)] - assert len(history) == 2 + if "shallow" not in checkpointer_name: + history = [c for c in graph.get_state_history(thread1)] + assert len(history) == 2 + # resume execution assert graph.invoke(None, thread1, debug=1) == [ "0", @@ -6050,7 +6758,7 @@ def route_to_three(state) -> Literal["3"]: assert state.next == () # check history history = [c for c in graph.get_state_history(thread1)] - assert history == [ + expected_history = [ StateSnapshot( values=[ "0", @@ -6078,13 +6786,17 @@ def route_to_three(state) -> Literal["3"]: "parents": {}, }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ), StateSnapshot( @@ -6277,6 +6989,10 @@ def route_to_three(state) -> Literal["3"]: ), ), ] + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + + assert history == expected_history @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) @@ -6362,13 +7078,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # now, get_state with subgraphs state assert app.get_state(config, subgraphs=True) == StateSnapshot( @@ -6422,16 +7142,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -6451,17 +7175,21 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # get_state_history returns outer graph checkpoints history = list(app.get_state_history(config)) - assert history == [ + expected_history = [ StateSnapshot( values={"my_key": "hi my value"}, tasks=( @@ -6493,13 +7221,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "my value"}, @@ -6564,9 +7296,15 @@ def outer_2(state: State): parent_config=None, ), ] + + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + + assert history == expected_history + # get_state_history for a subgraph returns its checkpoints child_history = [*app.get_state_history(history[0].tasks[0].state)] - assert child_history == [ + expected_child_history = [ StateSnapshot( values={"my_key": "hi my value here", "my_other_key": "hi my value"}, next=("inner_2",), @@ -6599,16 +7337,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),), ), StateSnapshot( @@ -6699,6 +7441,11 @@ def outer_2(state: State): ), ] + if "shallow" in checkpointer_name: + expected_child_history = expected_child_history[:1] + + assert child_history == expected_child_history + # resume app.invoke(None, config, debug=True) # test state w/ nested subgraph state (after resuming from interrupt) @@ -6723,13 +7470,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # test full history at the end actual_history = list(app.get_state_history(config)) @@ -6755,13 +7506,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "hi my value here and there"}, @@ -6897,6 +7652,9 @@ def outer_2(state: State): parent_config=None, ), ] + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + assert actual_history == expected_history # test looking up parent state by checkpoint ID for actual_snapshot, expected_snapshot in zip(actual_history, expected_history): @@ -7001,13 +7759,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) child_state = app.get_state(outer_state.tasks[0].state) assert ( @@ -7043,13 +7805,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + } } - }, + ), ).tasks[0] ) grandchild_state = app.get_state(child_state.tasks[0].state) @@ -7096,20 +7862,24 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), + } } - }, + ), ) # get state with subgraphs assert app.get_state(config, subgraphs=True) == StateSnapshot( @@ -7176,22 +7946,26 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr( - re.compile(r"child:.+|child1:") - ): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr( + re.compile(r"child:.+|child1:") + ): AnyStr(), + } + ), + } } - }, + ), ), ), ), @@ -7220,16 +7994,20 @@ def parent_2(state: State): "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -7249,13 +8027,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # # resume assert [c for c in app.stream(None, config, subgraphs=True)] == [ @@ -7292,15 +8074,23 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) ) + + if "shallow" in checkpointer_name: + return + # get outer graph history outer_history = list(app.get_state_history(config)) assert outer_history == [ @@ -8000,19 +8790,23 @@ def edit(state: JokeState): "langgraph_triggers": [PUSH], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("generate_joke:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("generate_joke:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("generate_joke:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("generate_joke:"): AnyStr(), + } + ), + } } - }, + ), tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),), ) assert graph.get_state(outer_state.tasks[2].state) == StateSnapshot( @@ -8045,19 +8839,23 @@ def edit(state: JokeState): "langgraph_triggers": [PUSH], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("generate_joke:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("generate_joke:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("generate_joke:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("generate_joke:"): AnyStr(), + } + ), + } } - }, + ), tasks=(PregelTask(id=AnyStr(""), name="generate", path=(PULL, "generate")),), ) # update state of dogs joke graph @@ -8101,16 +8899,23 @@ def edit(state: JokeState): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) assert actual_snapshot == expected_snapshot + if "shallow" in checkpointer_name: + return + # test full history actual_history = list(graph.get_state_history(config)) @@ -8372,13 +9177,17 @@ def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -8455,13 +9264,17 @@ def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -8531,13 +9344,17 @@ def foo(call: ToolCall): "thread_id": "3", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -8640,13 +9457,17 @@ def foo(call: ToolCall): "thread_id": "3", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -8841,13 +9662,17 @@ def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -8924,13 +9749,17 @@ def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -9104,13 +9933,17 @@ def weather_graph(state: RouterState): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9192,13 +10025,17 @@ def weather_graph(state: RouterState): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9241,19 +10078,23 @@ def weather_graph(state: RouterState): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9292,13 +10133,17 @@ def weather_graph(state: RouterState): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -9347,19 +10192,23 @@ def weather_graph(state: RouterState): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=(), ), ), diff --git a/langgraph-tests/tests/test_large_cases_async.py b/langgraph-tests/tests/test_large_cases_async.py index 6e80c9e..4e0fd76 100644 --- a/langgraph-tests/tests/test_large_cases_async.py +++ b/langgraph-tests/tests/test_large_cases_async.py @@ -35,7 +35,11 @@ from langgraph.store.memory import InMemoryStore from langgraph.types import PregelTask, Send, StateSnapshot, StreamWriter from tests.any_str import AnyDict, AnyStr -from tests.conftest import ALL_CHECKPOINTERS_ASYNC, awith_checkpointer +from tests.conftest import ( + ALL_CHECKPOINTERS_ASYNC, + REGULAR_CHECKPOINTERS_ASYNC, + awith_checkpointer, +) from tests.fake_tracer import FakeTracer from tests.messages import ( _AnyIdAIMessage, @@ -108,6 +112,9 @@ async def test_invoke_two_processes_in_out_interrupt( snapshot = await app.aget_state(thread2) assert snapshot.next == () + if "shallow" in checkpointer_name: + return + # list history history = [c async for c in app.aget_state_history(thread1)] assert history == [ @@ -308,7 +315,7 @@ async def test_invoke_two_processes_in_out_interrupt( ] -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_ASYNC) async def test_fork_always_re_runs_nodes( checkpointer_name: str, mocker: MockerFixture ) -> None: @@ -841,9 +848,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -871,10 +882,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -891,9 +906,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -987,10 +1006,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(), next=(), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1016,9 +1039,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # test state get/update methods with interrupt_before @@ -1061,10 +1088,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1083,9 +1114,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -1113,10 +1148,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1133,9 +1172,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -1229,10 +1272,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(), next=(), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1258,9 +1305,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # test re-invoke to continue with interrupt_before @@ -1303,10 +1354,14 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1325,9 +1380,13 @@ async def should_continue(data: dict, config: RunnableConfig) -> str: }, "thread_id": "3", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -1727,10 +1786,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1746,9 +1809,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) async with assert_ctx_once(): @@ -1774,10 +1841,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1793,9 +1864,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) async with assert_ctx_once(): @@ -1856,10 +1931,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(), next=(), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1874,9 +1953,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # test state get/update methods with interrupt_before @@ -1915,10 +1998,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -1934,9 +2021,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -1961,10 +2052,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(PregelTask(AnyStr(), "tools", (PULL, "tools")),), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -1980,9 +2075,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -2041,10 +2140,14 @@ def should_continue(data: AgentState) -> str: }, tasks=(), next=(), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -2059,9 +2162,13 @@ def should_continue(data: AgentState) -> str: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) @@ -2365,10 +2472,14 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2)), ), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -2376,9 +2487,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: "writes": None, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # modify ai message @@ -2430,9 +2545,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -2557,9 +2676,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -2607,9 +2730,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # interrupt before tools @@ -2644,7 +2771,7 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, {"__interrupt__": ()}, ] - + tup = await app_w_interrupt.checkpointer.aget_tuple(config) assert await app_w_interrupt.aget_state(config) == StateSnapshot( values={ "messages": [ @@ -2690,10 +2817,8 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: PregelTask(AnyStr(), "tools", (PUSH, ("__pregel_pull", "agent"), 2)), ), next=("tools",), - config=(await app_w_interrupt.checkpointer.aget_tuple(config)).config, - created_at=( - await app_w_interrupt.checkpointer.aget_tuple(config) - ).checkpoint["ts"], + config=tup.config, + created_at=tup.checkpoint["ts"], metadata={ "parents": {}, "source": "loop", @@ -2701,9 +2826,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: "writes": None, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # modify ai message @@ -2755,9 +2884,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -2884,9 +3017,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -2934,9 +3071,13 @@ async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: }, "thread_id": "2", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) @@ -3195,9 +3336,13 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) # modify ai message @@ -3245,9 +3390,13 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) assert [c async for c in app_w_interrupt.astream(None, config)] == [ @@ -3331,9 +3480,13 @@ def should_continue(messages): }, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) await app_w_interrupt.aupdate_state( @@ -3375,9 +3528,13 @@ def should_continue(messages): "writes": {"agent": AIMessage(content="answer", id="ai2")}, "thread_id": "1", }, - parent_config=[ - c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in app_w_interrupt.checkpointer.alist(config, limit=2) + ][-1].config + ), ) @@ -3444,32 +3601,38 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "my_key": "value", "market": "DE", } - assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "assistant_id": "a", - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value", "market": "DE"}}, - "assistant_id": "a", - "thread_id": "1", - }, - ] + if "shallow" not in checkpointer_name: + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "assistant_id": "a", + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value", "market": "DE"}}, + "assistant_id": "a", + "thread_id": "1", + }, + ] + assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3478,9 +3641,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { @@ -3491,10 +3658,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value slow", "market": "DE"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3503,9 +3674,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) thread2 = {"configurable": {"thread_id": "2", "assistant_id": "a"}} @@ -3518,10 +3693,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3530,9 +3709,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { @@ -3543,10 +3726,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value fast", "market": "US"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3555,9 +3742,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "a", "thread_id": "2", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) thread3 = {"configurable": {"thread_id": "3", "assistant_id": "b"}} @@ -3570,10 +3761,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "value", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3582,9 +3777,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread3, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread3, limit=2)][ + -1 + ].config + ), ) # update state await tool_two.aupdate_state(thread3, {"my_key": "key"}) # appends to my_key @@ -3592,10 +3791,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "valuekey", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -3604,9 +3807,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread3, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread3, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread3, debug=1) == { @@ -3617,10 +3824,14 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: values={"my_key": "valuekey fast", "market": "US"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "3", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -3629,9 +3840,13 @@ def tool_two_fast(data: State, config: RunnableConfig) -> State: "assistant_id": "b", "thread_id": "3", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread3, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread3, limit=2)][ + -1 + ].config + ), ) @@ -4149,10 +4364,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "11", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4160,9 +4379,13 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "11", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { @@ -4173,10 +4396,14 @@ class State(TypedDict): values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "11", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4184,9 +4411,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "11", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) thread2 = {"configurable": {"thread_id": "12"}} @@ -4199,10 +4430,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "12", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4210,9 +4445,13 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "12", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { @@ -4223,10 +4462,14 @@ class State(TypedDict): values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "12", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4234,9 +4477,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "12", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) tool_two = tool_two_graph.compile( @@ -4257,10 +4504,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "21", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4268,9 +4519,13 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "21", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread1, debug=1) == { @@ -4281,10 +4536,14 @@ class State(TypedDict): values={"my_key": "value prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread1)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread1)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "21", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4292,9 +4551,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "21", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) thread2 = {"configurable": {"thread_id": "22"}} @@ -4307,10 +4570,14 @@ class State(TypedDict): values={"my_key": "value prepared", "market": "US"}, tasks=(PregelTask(AnyStr(), "tool_two_fast", (PULL, "tool_two_fast")),), next=("tool_two_fast",), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "22", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4318,9 +4585,13 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "22", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread2, debug=1) == { @@ -4331,10 +4602,14 @@ class State(TypedDict): values={"my_key": "value prepared fast finished", "market": "US"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread2)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread2)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "22", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4342,9 +4617,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "22", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread2, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread2, limit=2)][ + -1 + ].config + ), ) thread3 = {"configurable": {"thread_id": "23"}} @@ -4378,10 +4657,14 @@ class State(TypedDict): values={"my_key": "key prepared", "market": "DE"}, tasks=(PregelTask(AnyStr(), "tool_two_slow", (PULL, "tool_two_slow")),), next=("tool_two_slow",), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4389,7 +4672,7 @@ class State(TypedDict): "writes": {"prepare": {"my_key": " prepared"}}, "thread_id": "23", }, - parent_config=uconfig, + parent_config=(None if "shallow" in checkpointer_name else uconfig), ) # resume, for same result as above assert await tool_two.ainvoke(None, thread3, debug=1) == { @@ -4400,10 +4683,14 @@ class State(TypedDict): values={"my_key": "key prepared slow finished", "market": "DE"}, tasks=(), next=(), - config=(await tool_two.checkpointer.aget_tuple(thread3)).config, - created_at=(await tool_two.checkpointer.aget_tuple(thread3)).checkpoint[ - "ts" - ], + config={ + "configurable": { + "thread_id": "23", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "loop", @@ -4411,9 +4698,13 @@ class State(TypedDict): "writes": {"finish": {"my_key": " finished"}}, "thread_id": "23", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread3, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread3, limit=2)][ + -1 + ].config + ), ) @@ -4499,13 +4790,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # now, get_state with subgraphs state assert await app.aget_state(config, subgraphs=True) == StateSnapshot( @@ -4559,16 +4854,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -4588,17 +4887,21 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # get_state_history returns outer graph checkpoints history = [c async for c in app.aget_state_history(config)] - assert history == [ + expected_history = [ StateSnapshot( values={"my_key": "hi my value"}, tasks=( @@ -4630,13 +4933,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "my value"}, @@ -4701,11 +5008,17 @@ def outer_2(state: State): parent_config=None, ), ] + + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + + assert history == expected_history + # get_state_history for a subgraph returns its checkpoints child_history = [ c async for c in app.aget_state_history(history[0].tasks[0].state) ] - assert child_history == [ + expected_child_history = [ StateSnapshot( values={"my_key": "hi my value here", "my_other_key": "hi my value"}, next=("inner_2",), @@ -4738,16 +5051,20 @@ def outer_2(state: State): "langgraph_checkpoint_ns": AnyStr("inner:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("inner:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("inner:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), tasks=(PregelTask(AnyStr(), "inner_2", (PULL, "inner_2")),), ), StateSnapshot( @@ -4838,6 +5155,11 @@ def outer_2(state: State): ), ] + if "shallow" in checkpointer_name: + expected_child_history = expected_child_history[:1] + + assert child_history == expected_child_history + # resume await app.ainvoke(None, config, debug=True) # test state w/ nested subgraph state (after resuming from interrupt) @@ -4862,13 +5184,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # test full history at the end actual_history = [c async for c in app.aget_state_history(config)] @@ -4896,13 +5222,17 @@ def outer_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ), StateSnapshot( values={"my_key": "hi my value here and there"}, @@ -5041,6 +5371,9 @@ def outer_2(state: State): parent_config=None, ), ] + if "shallow" in checkpointer_name: + expected_history = expected_history[:1] + assert actual_history == expected_history # test looking up parent state by checkpoint ID for actual_snapshot, expected_snapshot in zip(actual_history, expected_history): @@ -5144,13 +5477,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) child_state = await app.aget_state(outer_state.tasks[0].state) assert ( @@ -5186,13 +5523,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + } } - }, + ), ).tasks[0] ) grandchild_state = await app.aget_state(child_state.tasks[0].state) @@ -5239,20 +5580,24 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr(re.compile(r"child:.+|child1:")): AnyStr(), + } + ), + } } - }, + ), ) # get state with subgraphs assert await app.aget_state(config, subgraphs=True) == StateSnapshot( @@ -5321,22 +5666,28 @@ def parent_2(state: State): "langgraph_triggers": [AnyStr("start:child_1")], }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr(), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("child:"): AnyStr(), - AnyStr( - re.compile(r"child:.+|child1:") - ): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr(), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("child:"): AnyStr(), + AnyStr( + re.compile( + r"child:.+|child1:" + ) + ): AnyStr(), + } + ), + } } - }, + ), ), ), ), @@ -5365,16 +5716,20 @@ def parent_2(state: State): "langgraph_checkpoint_ns": AnyStr("child:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": AnyStr("child:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - {"": AnyStr(), AnyStr("child:"): AnyStr()} - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": AnyStr("child:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + {"": AnyStr(), AnyStr("child:"): AnyStr()} + ), + } } - }, + ), ), ), ), @@ -5394,13 +5749,17 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) # resume assert [c async for c in app.astream(None, config, subgraphs=True)] == [ @@ -5442,15 +5801,23 @@ def parent_2(state: State): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) ) + + if "shallow" in checkpointer_name: + return + # get outer graph history outer_history = [c async for c in app.aget_state_history(config)] assert ( @@ -6167,16 +6534,23 @@ async def edit(state: JokeState): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) assert actual_snapshot == expected_snapshot + if "shallow" in checkpointer_name: + return + # test full history actual_history = [c async for c in graph.aget_state_history(config)] expected_history = [ @@ -6449,13 +6823,17 @@ def get_first_in_list(): "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -6541,13 +6919,17 @@ def get_first_in_list(): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -6590,19 +6972,23 @@ def get_first_in_list(): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -6641,13 +7027,17 @@ def get_first_in_list(): "thread_id": "14", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -6698,19 +7088,23 @@ def get_first_in_list(): "langgraph_checkpoint_ns": AnyStr("weather_graph:"), }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "14", - "checkpoint_ns": AnyStr("weather_graph:"), - "checkpoint_id": AnyStr(), - "checkpoint_map": AnyDict( - { - "": AnyStr(), - AnyStr("weather_graph:"): AnyStr(), - } - ), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "14", + "checkpoint_ns": AnyStr("weather_graph:"), + "checkpoint_id": AnyStr(), + "checkpoint_map": AnyDict( + { + "": AnyStr(), + AnyStr("weather_graph:"): AnyStr(), + } + ), + } } - }, + ), tasks=(), ), ), diff --git a/langgraph-tests/tests/test_pregel.py b/langgraph-tests/tests/test_pregel.py index c5f95f9..27e855c 100644 --- a/langgraph-tests/tests/test_pregel.py +++ b/langgraph-tests/tests/test_pregel.py @@ -56,6 +56,7 @@ from tests.conftest import ( ALL_CHECKPOINTERS_SYNC, ALL_STORES_SYNC, + REGULAR_CHECKPOINTERS_SYNC, SHOULD_CHECK_SNAPSHOTS, ) from tests.messages import ( @@ -63,7 +64,7 @@ ) -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_SYNC) def test_run_from_checkpoint_id_retains_previous_writes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: @@ -319,6 +320,10 @@ def reset(self): # both the pending write and the new write were applied, 1 + 2 + 3 = 6 assert graph.invoke(None, thread1) == {"value": 6} + if "shallow" in checkpointer_name: + assert len(list(checkpointer.list(thread1))) == 1 + return + # check all final checkpoints checkpoints = [c for c in checkpointer.list(thread1)] # we should have 3 @@ -630,6 +635,9 @@ def raise_if_above_10(input: int) -> int: assert state.values.get("total") == 5 assert state.next == () + if "shallow" in checkpointer_name: + return + assert len(list(app.get_state_history(thread_1, limit=1))) == 1 # list all checkpoints for thread 1 thread_1_history = [c for c in app.get_state_history(thread_1)] @@ -861,6 +869,11 @@ def qa(data: State) -> State: ] app_w_interrupt.update_state(config, {"docs": ["doc5"]}) + expected_parent_config = ( + None + if "shallow" in checkpointer_name + else list(app_w_interrupt.checkpointer.list(config, limit=2))[-1].config + ) assert app_w_interrupt.get_state(config) == StateSnapshot( values={ "query": "analyzed: query: what is weather in sf", @@ -868,8 +881,14 @@ def qa(data: State) -> State: }, tasks=(PregelTask(AnyStr(), "qa", (PULL, "qa")),), next=("qa",), - config=app_w_interrupt.checkpointer.get_tuple(config).config, - created_at=app_w_interrupt.checkpointer.get_tuple(config).checkpoint["ts"], + config={ + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } + }, + created_at=AnyStr(), metadata={ "parents": {}, "source": "update", @@ -877,7 +896,7 @@ def qa(data: State) -> State: "writes": {"retriever_one": {"docs": ["doc5"]}}, "thread_id": "2", }, - parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, + parent_config=expected_parent_config, ) assert [c for c in app_w_interrupt.stream(None, config, debug=1)] == [ @@ -1946,13 +1965,17 @@ class CustomParentState(TypedDict): "parents": {}, }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -2260,3 +2283,57 @@ def node2(state: State): graph.invoke({"foo": "abc"}, config) result = graph.invoke(Command(goto=["node2"]), config) assert result == {"foo": "abc|node-1|node-2|node-2"} + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC) +def test_checkpoint_recovery(request: pytest.FixtureRequest, checkpointer_name: str): + """Test recovery from checkpoints after failures.""" + checkpointer = request.getfixturevalue(f"checkpointer_{checkpointer_name}") + + class State(TypedDict): + steps: Annotated[list[str], operator.add] + attempt: int # Track number of attempts + + def failing_node(state: State): + # Fail on first attempt, succeed on retry + if state["attempt"] == 1: + raise RuntimeError("Simulated failure") + return {"steps": ["node1"]} + + def second_node(state: State): + return {"steps": ["node2"]} + + builder = StateGraph(State) + builder.add_node("node1", failing_node) + builder.add_node("node2", second_node) + builder.add_edge(START, "node1") + builder.add_edge("node1", "node2") + + graph = builder.compile(checkpointer=checkpointer) + config = {"configurable": {"thread_id": "1"}} + + # First attempt should fail + with pytest.raises(RuntimeError): + graph.invoke({"steps": ["start"], "attempt": 1}, config) + + # Verify checkpoint state + state = graph.get_state(config) + assert state is not None + assert state.values == {"steps": ["start"], "attempt": 1} # input state saved + assert state.next == ("node1",) # Should retry failed node + assert "RuntimeError('Simulated failure')" in state.tasks[0].error + + # Retry with updated attempt count + result = graph.invoke({"steps": [], "attempt": 2}, config) + assert result == {"steps": ["start", "node1", "node2"], "attempt": 2} + + if "shallow" in checkpointer_name: + return + + # Verify checkpoint history shows both attempts + history = list(graph.get_state_history(config)) + assert len(history) == 6 # Initial + failed attempt + successful attempt + + # Verify the error was recorded in checkpoint + failed_checkpoint = next(c for c in history if c.tasks and c.tasks[0].error) + assert "RuntimeError('Simulated failure')" in failed_checkpoint.tasks[0].error diff --git a/langgraph-tests/tests/test_pregel_async.py b/langgraph-tests/tests/test_pregel_async.py index a5286c1..c4c36e8 100644 --- a/langgraph-tests/tests/test_pregel_async.py +++ b/langgraph-tests/tests/test_pregel_async.py @@ -61,6 +61,7 @@ from tests.conftest import ( ALL_CHECKPOINTERS_ASYNC, ALL_STORES_ASYNC, + REGULAR_CHECKPOINTERS_ASYNC, SHOULD_CHECK_SNAPSHOTS, awith_checkpointer, awith_store, @@ -172,22 +173,23 @@ async def tool_two_node(s: State) -> State: ) }, ] - assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + if "shallow" not in checkpointer_name: + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, @@ -215,9 +217,13 @@ async def tool_two_node(s: State) -> State: "writes": None, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # clear the interrupt and next tasks @@ -237,9 +243,13 @@ async def tool_two_node(s: State) -> State: "writes": {}, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) @@ -349,22 +359,25 @@ class State(TypedDict): ) }, ] - assert [c.metadata async for c in tool_two.checkpointer.alist(thread1root)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": None, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + if "shallow" not in checkpointer_name: + assert [ + c.metadata async for c in tool_two.checkpointer.alist(thread1root) + ] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": None, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️", "market": "DE"}, @@ -398,9 +411,13 @@ class State(TypedDict): "writes": None, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1root, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config + ), ) # clear the interrupt and next tasks @@ -420,9 +437,13 @@ class State(TypedDict): "writes": {}, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1root, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [ + c async for c in tool_two.checkpointer.alist(thread1root, limit=2) + ][-1].config + ), ) @@ -513,22 +534,24 @@ def start(state: State) -> list[Union[Send, str]]: "my_key": "value ⛰️ one", "market": "DE", } - assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ - { - "parents": {}, - "source": "loop", - "step": 0, - "writes": {"tool_one": {"my_key": " one"}}, - "thread_id": "1", - }, - { - "parents": {}, - "source": "input", - "step": -1, - "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, - "thread_id": "1", - }, - ] + + if "shallow" not in checkpointer_name: + assert [c.metadata async for c in tool_two.checkpointer.alist(thread1)] == [ + { + "parents": {}, + "source": "loop", + "step": 0, + "writes": {"tool_one": {"my_key": " one"}}, + "thread_id": "1", + }, + { + "parents": {}, + "source": "input", + "step": -1, + "writes": {"__start__": {"my_key": "value ⛰️", "market": "DE"}}, + "thread_id": "1", + }, + ] tup = await tool_two.checkpointer.aget_tuple(thread1) assert await tool_two.aget_state(thread1) == StateSnapshot( values={"my_key": "value ⛰️ one", "market": "DE"}, @@ -556,9 +579,13 @@ def start(state: State) -> list[Union[Send, str]]: "writes": {"tool_one": {"my_key": " one"}}, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) # clear the interrupt and next tasks await tool_two.aupdate_state(thread1, None) @@ -584,9 +611,13 @@ def start(state: State) -> list[Union[Send, str]]: "writes": {}, "thread_id": "1", }, - parent_config=[ - c async for c in tool_two.checkpointer.alist(thread1, limit=2) - ][-1].config, + parent_config=( + None + if "shallow" in checkpointer_name + else [c async for c in tool_two.checkpointer.alist(thread1, limit=2)][ + -1 + ].config + ), ) @@ -970,6 +1001,10 @@ def reset(self): # both the pending write and the new write were applied, 1 + 2 + 3 = 6 assert await graph.ainvoke(None, thread1) == {"value": 6} + if "shallow" in checkpointer_name: + assert len([c async for c in checkpointer.alist(thread1)]) == 1 + return + # check all final checkpoints checkpoints = [c async for c in checkpointer.alist(thread1)] # we should have 3 @@ -1127,7 +1162,7 @@ def reset(self): ) -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_ASYNC) async def test_run_from_checkpoint_id_retains_previous_writes( request: pytest.FixtureRequest, checkpointer_name: str, mocker: MockerFixture ) -> None: @@ -1467,7 +1502,7 @@ async def graph(state: dict) -> dict: ] -@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +@pytest.mark.parametrize("checkpointer_name", REGULAR_CHECKPOINTERS_ASYNC) async def test_send_dedupe_on_resume(checkpointer_name: str) -> None: if not FF_SEND_V2: pytest.skip("Send deduplication is only available in Send V2") @@ -1920,13 +1955,17 @@ async def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -2003,13 +2042,17 @@ async def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -2079,13 +2122,17 @@ async def foo(call: ToolCall): "thread_id": "3", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -2188,13 +2235,17 @@ async def foo(call: ToolCall): "thread_id": "3", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "3", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "3", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -2388,13 +2439,17 @@ async def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=( PregelTask( id=AnyStr(), @@ -2471,13 +2526,17 @@ async def foo(call: ToolCall): "thread_id": "2", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "2", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "2", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -2700,6 +2759,9 @@ def raise_if_above_10(input: int) -> int: assert state.values.get("total") == 5 assert state.next == () + if "shallow" in checkpointer_name: + return + assert len([c async for c in app.aget_state_history(thread_1, limit=1)]) == 1 # list all checkpoints for thread 1 thread_1_history = [c async for c in app.aget_state_history(thread_1)] @@ -3172,13 +3234,17 @@ async def decider(data: State) -> str: "thread_id": "1", }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), ) async with assert_ctx_once(): @@ -3932,13 +3998,17 @@ class CustomParentState(TypedDict): "parents": {}, }, created_at=AnyStr(), - parent_config={ - "configurable": { - "thread_id": "1", - "checkpoint_ns": "", - "checkpoint_id": AnyStr(), + parent_config=( + None + if "shallow" in checkpointer_name + else { + "configurable": { + "thread_id": "1", + "checkpoint_ns": "", + "checkpoint_id": AnyStr(), + } } - }, + ), tasks=(), ) @@ -4268,3 +4338,49 @@ def node2(state: State): await graph.ainvoke({"foo": "abc"}, config) result = await graph.ainvoke(Command(goto=["node2"]), config) assert result == {"foo": "abc|node-1|node-2|node-2"} + + +@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) +async def test_checkpoint_recovery_async(checkpointer_name: str): + """Test recovery from checkpoints after failures with async nodes.""" + class State(TypedDict): + steps: Annotated[list[str], operator.add] + attempt: int # Track number of attempts + async def failing_node(state: State): + # Fail on first attempt, succeed on retry + if state["attempt"] == 1: + raise RuntimeError("Simulated failure") + await asyncio.sleep(0.1) # Simulate async work + return {"steps": ["node1"]} + async def second_node(state: State): + await asyncio.sleep(0.1) # Simulate async work + return {"steps": ["node2"]} + builder = StateGraph(State) + builder.add_node("node1", failing_node) + builder.add_node("node2", second_node) + builder.add_edge(START, "node1") + builder.add_edge("node1", "node2") + async with awith_checkpointer(checkpointer_name) as checkpointer: + graph = builder.compile(checkpointer=checkpointer) + config = {"configurable": {"thread_id": "1"}} + # First attempt should fail + with pytest.raises(RuntimeError): + await graph.ainvoke({"steps": ["start"], "attempt": 1}, config) + # Verify checkpoint state + state = await graph.aget_state(config) + assert state is not None + assert state.values == {"steps": ["start"], "attempt": 1} # input state saved + assert state.next == ("node1",) # Should retry failed node + # Retry with updated attempt count + result = await graph.ainvoke({"steps": [], "attempt": 2}, config) + assert result == {"steps": ["start", "node1", "node2"], "attempt": 2} + + if "shallow" in checkpointer_name: + return + + # Verify checkpoint history shows both attempts + history = [c async for c in graph.aget_state_history(config)] + assert len(history) == 6 # Initial + failed attempt + successful attempt + # Verify the error was recorded in checkpoint + failed_checkpoint = next(c for c in history if c.tasks and c.tasks[0].error) + assert "RuntimeError('Simulated failure')" in failed_checkpoint.tasks[0].error diff --git a/langgraph/checkpoint/mysql/__init__.py b/langgraph/checkpoint/mysql/__init__.py index 5930f4c..d32228b 100644 --- a/langgraph/checkpoint/mysql/__init__.py +++ b/langgraph/checkpoint/mysql/__init__.py @@ -2,16 +2,7 @@ import threading from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import ( - Any, - ContextManager, - Generic, - Mapping, - Optional, - Protocol, - TypeVar, - Union, -) +from typing import Any, Generic, Optional from langchain_core.runnables import RunnableConfig @@ -32,34 +23,10 @@ ) from langgraph.checkpoint.serde.base import SerializerProtocol - -class DictCursor(ContextManager, Protocol): - """ - Protocol that a cursor should implement. - - Modeled after DBAPICursor from Typeshed. - """ - - def execute( - self, - operation: str, - parameters: Union[Sequence[Any], Mapping[str, Any]] = ..., - /, - ) -> object: ... - def executemany( - self, operation: str, seq_of_parameters: Sequence[Sequence[Any]], / - ) -> object: ... - def fetchone(self) -> Optional[dict[str, Any]]: ... - def fetchall(self) -> Sequence[dict[str, Any]]: ... - - -R = TypeVar("R", bound=DictCursor) # cursor type - - Conn = _internal.Conn # For backward compatibility -class BaseSyncMySQLSaver(BaseMySQLSaver, Generic[_internal.C, R]): +class BaseSyncMySQLSaver(BaseMySQLSaver, Generic[_internal.C, _internal.R]): lock: threading.Lock def __init__( @@ -73,11 +40,11 @@ def __init__( self.lock = threading.Lock() @staticmethod - def _get_cursor_from_connection(conn: _internal.C) -> R: + def _get_cursor_from_connection(conn: _internal.C) -> _internal.R: raise NotImplementedError @contextmanager - def _cursor(self, *, pipeline: bool = False) -> Iterator[R]: + def _cursor(self, *, pipeline: bool = False) -> Iterator[_internal.R]: """Create a database cursor as a context manager. Args: diff --git a/langgraph/checkpoint/mysql/_ainternal.py b/langgraph/checkpoint/mysql/_ainternal.py index 1e8f487..0e8b011 100644 --- a/langgraph/checkpoint/mysql/_ainternal.py +++ b/langgraph/checkpoint/mysql/_ainternal.py @@ -2,7 +2,43 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import AsyncContextManager, Generic, Protocol, TypeVar, Union, cast +from typing import ( + Any, + AsyncContextManager, + Generic, + Mapping, + Optional, + Protocol, + Sequence, + TypeVar, + Union, + cast, +) + + +class AsyncDictCursor(AsyncContextManager, Protocol): + """ + Protocol that a cursor should implement. + + Modeled after DBAPICursor from Typeshed. + """ + + async def execute( + self, + operation: str, + parameters: Union[Sequence[Any], Mapping[str, Any]] = ..., + /, + ) -> object: ... + async def executemany( + self, operation: str, seq_of_parameters: Sequence[Sequence[Any]], / + ) -> object: ... + async def fetchone(self) -> Optional[dict[str, Any]]: ... + async def fetchall(self) -> Sequence[dict[str, Any]]: ... + + def __aiter__(self) -> AsyncIterator[dict[str, Any]]: ... + + +R = TypeVar("R", bound=AsyncDictCursor) # cursor type class AIOMySQLConnection(AsyncContextManager, Protocol): diff --git a/langgraph/checkpoint/mysql/_internal.py b/langgraph/checkpoint/mysql/_internal.py index ac88ab2..e502bcf 100644 --- a/langgraph/checkpoint/mysql/_internal.py +++ b/langgraph/checkpoint/mysql/_internal.py @@ -2,7 +2,41 @@ from collections.abc import Callable, Iterator from contextlib import closing, contextmanager -from typing import ContextManager, Generic, Protocol, TypeVar, Union, cast +from typing import ( + Any, + ContextManager, + Generic, + Mapping, + Optional, + Protocol, + Sequence, + TypeVar, + Union, + cast, +) + + +class DictCursor(ContextManager, Protocol): + """ + Protocol that a cursor should implement. + + Modeled after DBAPICursor from Typeshed. + """ + + def execute( + self, + operation: str, + parameters: Union[Sequence[Any], Mapping[str, Any]] = ..., + /, + ) -> object: ... + def executemany( + self, operation: str, seq_of_parameters: Sequence[Sequence[Any]], / + ) -> object: ... + def fetchone(self) -> Optional[dict[str, Any]]: ... + def fetchall(self) -> Sequence[dict[str, Any]]: ... + + +R = TypeVar("R", bound=DictCursor) # cursor type class Connection(ContextManager, Protocol): diff --git a/langgraph/checkpoint/mysql/aio.py b/langgraph/checkpoint/mysql/aio.py index 2769057..c78b3b3 100644 --- a/langgraph/checkpoint/mysql/aio.py +++ b/langgraph/checkpoint/mysql/aio.py @@ -3,12 +3,13 @@ import urllib.parse from collections.abc import AsyncIterator, Iterator, Sequence from contextlib import asynccontextmanager -from typing import Any, Optional +from typing import Any, Optional, cast import aiomysql # type: ignore import pymysql import pymysql.connections from langchain_core.runnables import RunnableConfig +from typing_extensions import Self, override from langgraph.checkpoint.base import ( WRITES_IDX_MAP, @@ -20,6 +21,7 @@ ) from langgraph.checkpoint.mysql import _ainternal from langgraph.checkpoint.mysql.base import BaseMySQLSaver +from langgraph.checkpoint.mysql.shallow import BaseShallowAsyncMySQLSaver from langgraph.checkpoint.mysql.utils import ( deserialize_channel_values, deserialize_pending_sends, @@ -69,7 +71,7 @@ async def from_conn_string( conn_string: str, *, serde: Optional[SerializerProtocol] = None, - ) -> AsyncIterator["AIOMySQLSaver"]: + ) -> AsyncIterator[Self]: """Create a new AIOMySQLSaver instance from a connection string. Args: @@ -466,4 +468,42 @@ def put_writes( ).result() -__all__ = ["AIOMySQLSaver", "Conn"] +class ShallowAIOMySQLSaver( + BaseShallowAsyncMySQLSaver[aiomysql.Connection, aiomysql.DictCursor] +): + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, + conn_string: str, + *, + serde: Optional[SerializerProtocol] = None, + ) -> AsyncIterator[Self]: + """Create a new ShallowAIOMySQLSaver instance from a connection string. + + Args: + conn_string (str): The MySQL connection info string. + + Returns: + ShallowAIOMySQLSaver: A new ShallowAIOMySQLSaver instance. + + Example: + conn_string=mysql+aiomysql://user:password@localhost/db?unix_socket=/path/to/socket + """ + async with aiomysql.connect( + **AIOMySQLSaver.parse_conn_string(conn_string), + autocommit=True, + ) as conn: + # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119 + # is merged into aiomysql. + await conn.set_charset(pymysql.connections.DEFAULT_CHARSET) + + yield cls(conn=conn, serde=serde) + + @override + @staticmethod + def _get_cursor_from_connection(conn: aiomysql.Connection) -> aiomysql.DictCursor: + return cast(aiomysql.DictCursor, conn.cursor(aiomysql.DictCursor)) + + +__all__ = ["AIOMySQLSaver", "ShallowAIOMySQLSaver", "Conn"] diff --git a/langgraph/checkpoint/mysql/pymysql.py b/langgraph/checkpoint/mysql/pymysql.py index 04a8026..17b76c9 100644 --- a/langgraph/checkpoint/mysql/pymysql.py +++ b/langgraph/checkpoint/mysql/pymysql.py @@ -10,6 +10,7 @@ from langgraph.checkpoint.mysql import BaseSyncMySQLSaver from langgraph.checkpoint.mysql import Conn as BaseConn +from langgraph.checkpoint.mysql.shallow import BaseShallowSyncMySQLSaver Conn = BaseConn[pymysql.Connection] # type: ignore @@ -48,7 +49,7 @@ def from_conn_string( PyMySQLSaver: A new PyMySQLSaver instance. Example: - conn_string=mysql+aiomysql://user:password@localhost/db?unix_socket=/path/to/socket + conn_string=mysql://user:password@localhost/db?unix_socket=/path/to/socket """ with pymysql.connect( **cls.parse_conn_string(conn_string), @@ -62,4 +63,34 @@ def _get_cursor_from_connection(conn: pymysql.Connection) -> DictCursor: return conn.cursor(DictCursor) -__all__ = ["PyMySQLSaver", "Conn"] +class ShallowPyMySQLSaver(BaseShallowSyncMySQLSaver): + @classmethod + @contextmanager + def from_conn_string( + cls, + conn_string: str, + ) -> Iterator[Self]: + """Create a new ShallowPyMySQLSaver instance from a connection string. + + Args: + conn_string (str): The MySQL connection info string. + + Returns: + ShallowPyMySQLSaver: A new ShallowPyMySQLSaver instance. + + Example: + conn_string=mysql://user:password@localhost/db?unix_socket=/path/to/socket + """ + with pymysql.connect( + **PyMySQLSaver.parse_conn_string(conn_string), + autocommit=True, + ) as conn: + yield cls(conn) + + @override + @staticmethod + def _get_cursor_from_connection(conn: pymysql.Connection) -> DictCursor: + return conn.cursor(DictCursor) + + +__all__ = ["PyMySQLSaver", "ShallowPyMySQLSaver", "Conn"] diff --git a/langgraph/checkpoint/mysql/shallow.py b/langgraph/checkpoint/mysql/shallow.py new file mode 100644 index 0000000..efaf7b6 --- /dev/null +++ b/langgraph/checkpoint/mysql/shallow.py @@ -0,0 +1,796 @@ +import asyncio +import json +import threading +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import asynccontextmanager, contextmanager +from typing import Any, Generic, Optional + +from langchain_core.runnables import RunnableConfig + +from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, +) +from langgraph.checkpoint.mysql import _ainternal, _internal +from langgraph.checkpoint.mysql.base import BaseMySQLSaver +from langgraph.checkpoint.mysql.utils import ( + deserialize_channel_values, + deserialize_pending_sends, + deserialize_pending_writes, +) +from langgraph.checkpoint.serde.base import SerializerProtocol +from langgraph.checkpoint.serde.types import TASKS + +""" +To add a new migration, add a new string to the MIGRATIONS list. +The position of the migration in the list is the version number. +""" +MIGRATIONS = [ + """CREATE TABLE IF NOT EXISTS checkpoint_migrations ( + v INTEGER PRIMARY KEY +);""", + """CREATE TABLE IF NOT EXISTS checkpoints ( + thread_id VARCHAR(150) NOT NULL, + checkpoint_ns VARCHAR(2000) NOT NULL DEFAULT '', + checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED, + type VARCHAR(150), + checkpoint JSON NOT NULL, + metadata JSON NOT NULL DEFAULT ('{}'), + PRIMARY KEY (thread_id, checkpoint_ns_hash) +);""", + """CREATE TABLE IF NOT EXISTS checkpoint_blobs ( + thread_id VARCHAR(150) NOT NULL, + checkpoint_ns VARCHAR(2000) NOT NULL DEFAULT '', + checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED, + channel VARCHAR(150) NOT NULL, + type VARCHAR(150) NOT NULL, + `blob` LONGBLOB, + PRIMARY KEY (thread_id, checkpoint_ns_hash, channel) +);""", + """CREATE TABLE IF NOT EXISTS checkpoint_writes ( + thread_id VARCHAR(150) NOT NULL, + checkpoint_ns VARCHAR(2000) NOT NULL DEFAULT '', + checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED, + checkpoint_id VARCHAR(150) NOT NULL, + task_id VARCHAR(150) NOT NULL, + idx INTEGER NOT NULL, + channel VARCHAR(150) NOT NULL, + type VARCHAR(150), + `blob` LONGBLOB NOT NULL, + PRIMARY KEY (thread_id, checkpoint_ns_hash, checkpoint_id, task_id, idx) +);""", + """ + CREATE INDEX checkpoints_thread_id_idx ON checkpoints (thread_id); + """, + """ + CREATE INDEX checkpoint_blobs_thread_id_idx ON checkpoint_blobs (thread_id); + """, + """ + CREATE INDEX checkpoint_writes_thread_id_idx ON checkpoint_writes (thread_id); + """, +] + +SELECT_SQL = f""" +select + thread_id, + checkpoint, + checkpoint_ns, + metadata, + ( + select json_arrayagg(json_array(bl.channel, bl.type, bl.blob)) + from json_table( + json_keys(checkpoint, '$.channel_versions'), + '$[*]' columns (channel VARCHAR(150) PATH '$') + ) as channels + inner join checkpoint_blobs bl + on bl.channel = channels.channel + where bl.thread_id = checkpoints.thread_id + and bl.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash + ) as channel_values, + ( + select + json_arrayagg(json_array(cw.task_id, cw.channel, cw.type, cw.blob, cw.idx)) + from checkpoint_writes cw + where cw.thread_id = checkpoints.thread_id + and cw.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash + and cw.checkpoint_id = checkpoint->>'$.id' + ) as pending_writes, + ( + select json_arrayagg(json_array(cw.task_id, cw.type, cw.blob, cw.idx)) + from checkpoint_writes cw + where cw.thread_id = checkpoints.thread_id + and cw.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash + and cw.channel = '{TASKS}' + ) as pending_sends +from checkpoints """ + +UPSERT_CHECKPOINT_BLOBS_SQL = """ + INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, type, `blob`) + VALUES (%s, %s, %s, %s, %s) AS new + ON DUPLICATE KEY UPDATE + type = new.type, + `blob` = new.blob; +""" + +UPSERT_CHECKPOINTS_SQL = """ + INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint, metadata) + VALUES (%s, %s, %s, %s) AS new + ON DUPLICATE KEY UPDATE + checkpoint = new.checkpoint, + metadata = new.metadata; +""" + +UPSERT_CHECKPOINT_WRITES_SQL = """ + INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, `blob`) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) AS new + ON DUPLICATE KEY UPDATE + channel = new.channel, + type = new.type, + `blob` = new.blob; +""" + +INSERT_CHECKPOINT_WRITES_SQL = """ + INSERT IGNORE INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, `blob`) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) +""" + + +def _dump_blobs( + serde: SerializerProtocol, + thread_id: str, + checkpoint_ns: str, + values: dict[str, Any], + versions: ChannelVersions, +) -> list[tuple[str, str, str, str, Optional[bytes]]]: + if not versions: + return [] + + return [ + ( + thread_id, + checkpoint_ns, + k, + *(serde.dumps_typed(values[k]) if k in values else ("empty", None)), + ) + for k in versions + ] + + +class BaseShallowSyncMySQLSaver(BaseMySQLSaver, Generic[_internal.C, _internal.R]): + """A checkpoint saver that uses MySQL to store checkpoints. + This checkpointer ONLY stores the most recent checkpoint and does NOT retain any history. + It is meant to be a light-weight drop-in replacement for the PostgresSaver that + supports most of the LangGraph persistence functionality with the exception of time travel. + """ + + SELECT_SQL = SELECT_SQL + MIGRATIONS = MIGRATIONS + UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL + UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL + UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL + INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL + + lock: threading.Lock + + def __init__( + self, + conn: _internal.Conn, + serde: Optional[SerializerProtocol] = None, + ) -> None: + super().__init__(serde=serde) + + self.conn = conn + self.lock = threading.Lock() + + @staticmethod + def _get_cursor_from_connection(conn: _internal.C) -> _internal.R: + raise NotImplementedError + + @contextmanager + def _cursor(self, *, pipeline: bool = False) -> Iterator[_internal.R]: + """Create a database cursor as a context manager. + + Args: + pipeline (bool): whether to use transaction context manager and handle concurrency + """ + with _internal.get_connection(self.conn) as conn: + if pipeline: + with self.lock: + conn.begin() + try: + with self._get_cursor_from_connection(conn) as cur: + yield cur + conn.commit() + except: + conn.rollback() + raise + else: + with self.lock, self._get_cursor_from_connection(conn) as cur: + yield cur + + def setup(self) -> None: + """Set up the checkpoint database asynchronously. + + This method creates the necessary tables in the MySQL database if they don't + already exist and runs database migrations. It MUST be called directly by the user + the first time checkpointer is used. + """ + with self._cursor() as cur: + cur.execute(self.MIGRATIONS[0]) + cur.execute("SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1") + row = cur.fetchone() + if row is None: + version = -1 + else: + version = row["v"] + for v, migration in zip( + range(version + 1, len(self.MIGRATIONS)), + self.MIGRATIONS[version + 1 :], + ): + cur.execute(migration) + cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") + cur.execute("COMMIT") + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from the database. + + This method retrieves a list of checkpoint tuples from the MySQL database based + on the provided config. For shallow savers, this method returns a list with + ONLY the most recent checkpoint. + """ + where, args = self._search_where(config, filter, before) + query = self.SELECT_SQL + where + if limit: + query += f" LIMIT {limit}" + with self._cursor() as cur: + cur.execute(self.SELECT_SQL + where, args) + values = cur.fetchall() + for value in values: + checkpoint = self._load_checkpoint( + json.loads(value["checkpoint"]), + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), + ) + yield CheckpointTuple( + config={ + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=self._load_writes( + deserialize_pending_writes(value["pending_writes"]) + ), + ) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database. + + This method retrieves a checkpoint tuple from the MySQL database based on the + provided config (matching the thread ID in the config). + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + + Examples: + + Basic: + >>> config = {"configurable": {"thread_id": "1"}} + >>> checkpoint_tuple = memory.get_tuple(config) + >>> print(checkpoint_tuple) + CheckpointTuple(...) + + With timestamp: + + >>> config = { + ... "configurable": { + ... "thread_id": "1", + ... "checkpoint_ns": "", + ... "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875", + ... } + ... } + >>> checkpoint_tuple = memory.get_tuple(config) + >>> print(checkpoint_tuple) + CheckpointTuple(...) + """ # noqa + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + args = (thread_id, checkpoint_ns) + where = "WHERE thread_id = %s AND checkpoint_ns_hash = UNHEX(MD5(%s))" + + with self._cursor() as cur: + cur.execute( + self.SELECT_SQL + where, + args, + ) + values = cur.fetchall() + for value in values: + checkpoint = self._load_checkpoint( + json.loads(value["checkpoint"]), + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), + ) + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=self._load_writes( + deserialize_pending_writes(value["pending_writes"]) + ), + ) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database. + + This method saves a checkpoint to the MySQL database. The checkpoint is associated + with the provided config. For shallow savers, this method saves ONLY the most recent + checkpoint and overwrites a previous checkpoint, if it exists. + + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + + Examples: + + >>> from langgraph.checkpoint.mysql import PyMySQLSaver + >>> DB_URI = "mysql://mysql:mysql@localhost:5432/mysql" + >>> with ShallowPyMySQLSaver.from_conn_string(DB_URI) as memory: + >>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} + >>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "channel_values": {"key": "value"}} + >>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {}) + >>> print(saved_config) + {'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}} + """ + configurable = config["configurable"].copy() + thread_id = configurable.pop("thread_id") + checkpoint_ns = configurable.pop("checkpoint_ns") + copy = checkpoint.copy() + next_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + with self._cursor(pipeline=True) as cur: + cur.execute( + """DELETE FROM checkpoint_writes + WHERE thread_id = %s AND checkpoint_ns_hash = UNHEX(MD5(%s)) AND checkpoint_id NOT IN (%s, %s)""", + ( + thread_id, + checkpoint_ns, + checkpoint["id"], + configurable.get("checkpoint_id", ""), + ), + ) + cur.executemany( + self.UPSERT_CHECKPOINT_BLOBS_SQL, + _dump_blobs( + self.serde, + thread_id, + checkpoint_ns, + copy.pop("channel_values"), # type: ignore[misc] + new_versions, + ), + ) + cur.execute( + self.UPSERT_CHECKPOINTS_SQL, + ( + thread_id, + checkpoint_ns, + json.dumps(self._dump_checkpoint(copy)), + self._dump_metadata(metadata), + ), + ) + return next_config + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + ) -> None: + """Store intermediate writes linked to a checkpoint. + + This method saves intermediate writes associated with a checkpoint to the MySQL database. + + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (List[Tuple[str, Any]]): List of writes to store. + task_id (str): Identifier for the task creating the writes. + """ + query = ( + self.UPSERT_CHECKPOINT_WRITES_SQL + if all(w[0] in WRITES_IDX_MAP for w in writes) + else self.INSERT_CHECKPOINT_WRITES_SQL + ) + with self._cursor(pipeline=True) as cur: + cur.executemany( + query, + self._dump_writes( + config["configurable"]["thread_id"], + config["configurable"]["checkpoint_ns"], + config["configurable"]["checkpoint_id"], + task_id, + writes, + ), + ) + + +class BaseShallowAsyncMySQLSaver(BaseMySQLSaver, Generic[_ainternal.C, _ainternal.R]): + """A checkpoint saver that uses MySQL to store checkpoints asynchronously. + This checkpointer ONLY stores the most recent checkpoint and does NOT retain any history. + It is meant to be a light-weight drop-in replacement for the async MySQL saver that + supports most of the LangGraph persistence functionality with the exception of time travel. + """ + + SELECT_SQL = SELECT_SQL + MIGRATIONS = MIGRATIONS + UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL + UPSERT_CHECKPOINTS_SQL = UPSERT_CHECKPOINTS_SQL + UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL + INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL + lock: asyncio.Lock + + def __init__( + self, + conn: _ainternal.Conn, + serde: Optional[SerializerProtocol] = None, + ) -> None: + super().__init__(serde=serde) + + self.conn = conn + self.lock = asyncio.Lock() + self.loop = asyncio.get_running_loop() + + @staticmethod + def _get_cursor_from_connection(conn: _ainternal.C) -> _ainternal.R: + raise NotImplementedError + + @asynccontextmanager + async def _cursor(self, *, pipeline: bool = False) -> AsyncIterator[_ainternal.R]: + """Create a database cursor as a context manager. + + Args: + pipeline (bool): whether to use transaction context manager and handle concurrency + """ + async with _ainternal.get_connection(self.conn) as conn: + if pipeline: + async with self.lock: + await conn.begin() + try: + async with self._get_cursor_from_connection(conn) as cur: + yield cur + await conn.commit() + except: + await conn.rollback() + raise + else: + async with ( + self.lock, + self._get_cursor_from_connection(conn) as cur, + ): + yield cur + + async def setup(self) -> None: + """Set up the checkpoint database asynchronously. + This method creates the necessary tables in the MySQL database if they don't + already exist and runs database migrations. It MUST be called directly by the user + the first time checkpointer is used. + """ + async with self._cursor() as cur: + await cur.execute(self.MIGRATIONS[0]) + await cur.execute( + "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" + ) + row = await cur.fetchone() + if row is None: + version = -1 + else: + version = row["v"] + for v, migration in zip( + range(version + 1, len(self.MIGRATIONS)), + self.MIGRATIONS[version + 1 :], + ): + await cur.execute(migration) + await cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") + + async def alist( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + """List checkpoints from the database asynchronously. + This method retrieves a list of checkpoint tuples from the MySQL database based + on the provided config. For shallow savers, this method returns a list with + ONLY the most recent checkpoint. + """ + where, args = self._search_where(config, filter, before) + query = self.SELECT_SQL + where + if limit: + query += f" LIMIT {limit}" + async with self._cursor() as cur: + await cur.execute(self.SELECT_SQL + where, args) + async for value in cur: + checkpoint = await asyncio.to_thread( + self._load_checkpoint, + json.loads(value["checkpoint"]), + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), + ) + yield CheckpointTuple( + config={ + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=await asyncio.to_thread( + self._load_writes, + deserialize_pending_writes(value["pending_writes"]), + ), + ) + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database asynchronously. + This method retrieves a checkpoint tuple from the MySQL database based on the + provided config (matching the thread ID in the config). + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + args = (thread_id, checkpoint_ns) + where = "WHERE thread_id = %s AND checkpoint_ns_hash = UNHEX(MD5(%s))" + + async with self._cursor() as cur: + await cur.execute( + self.SELECT_SQL + where, + args, + ) + + async for value in cur: + checkpoint = await asyncio.to_thread( + self._load_checkpoint, + json.loads(value["checkpoint"]), + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), + ) + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + }, + checkpoint=checkpoint, + metadata=self._load_metadata(value["metadata"]), + pending_writes=await asyncio.to_thread( + self._load_writes, + deserialize_pending_writes(value["pending_writes"]), + ), + ) + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database asynchronously. + This method saves a checkpoint to the MySQL database. The checkpoint is associated + with the provided config. For shallow savers, this method saves ONLY the most recent + checkpoint and overwrites a previous checkpoint, if it exists. + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + configurable = config["configurable"].copy() + thread_id = configurable.pop("thread_id") + checkpoint_ns = configurable.pop("checkpoint_ns") + + copy = checkpoint.copy() + next_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + async with self._cursor(pipeline=True) as cur: + await cur.execute( + """DELETE FROM checkpoint_writes + WHERE thread_id = %s AND checkpoint_ns_hash = UNHEX(MD5(%s)) AND checkpoint_id NOT IN (%s, %s)""", + ( + thread_id, + checkpoint_ns, + checkpoint["id"], + configurable.get("checkpoint_id", ""), + ), + ) + await cur.executemany( + self.UPSERT_CHECKPOINT_BLOBS_SQL, + _dump_blobs( + self.serde, + thread_id, + checkpoint_ns, + copy.pop("channel_values"), # type: ignore[misc] + new_versions, + ), + ) + await cur.execute( + self.UPSERT_CHECKPOINTS_SQL, + ( + thread_id, + checkpoint_ns, + json.dumps(self._dump_checkpoint(copy)), + self._dump_metadata(metadata), + ), + ) + return next_config + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + ) -> None: + """Store intermediate writes linked to a checkpoint asynchronously. + This method saves intermediate writes associated with a checkpoint to the database. + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. + task_id (str): Identifier for the task creating the writes. + """ + query = ( + self.UPSERT_CHECKPOINT_WRITES_SQL + if all(w[0] in WRITES_IDX_MAP for w in writes) + else self.INSERT_CHECKPOINT_WRITES_SQL + ) + params = await asyncio.to_thread( + self._dump_writes, + config["configurable"]["thread_id"], + config["configurable"]["checkpoint_ns"], + config["configurable"]["checkpoint_id"], + task_id, + writes, + ) + async with self._cursor(pipeline=True) as cur: + await cur.executemany(query, params) + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from the database. + This method retrieves a list of checkpoint tuples from the MySQL database based + on the provided config. For shallow savers, this method returns a list with + ONLY the most recent checkpoint. + """ + aiter_ = self.alist(config, filter=filter, before=before, limit=limit) + while True: + try: + yield asyncio.run_coroutine_threadsafe( + anext(aiter_), # noqa: F821 + self.loop, + ).result() + except StopAsyncIteration: + break + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database. + This method retrieves a checkpoint tuple from the MySQL database based on the + provided config (matching the thread ID in the config). + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + try: + # check if we are in the main thread, only bg threads can block + # we don't check in other methods to avoid the overhead + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to asynchronous shallow savers are only allowed from a " + "different thread. From the main thread, use the async interface." + "For example, use `await checkpointer.aget_tuple(...)` or `await " + "graph.ainvoke(...)`." + ) + except RuntimeError: + pass + return asyncio.run_coroutine_threadsafe( + self.aget_tuple(config), self.loop + ).result() + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database. + This method saves a checkpoint to the MySQL database. The checkpoint is associated + with the provided config. For shallow savers, this method saves ONLY the most recent + checkpoint and overwrites a previous checkpoint, if it exists. + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + return asyncio.run_coroutine_threadsafe( + self.aput(config, checkpoint, metadata, new_versions), self.loop + ).result() + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + ) -> None: + """Store intermediate writes linked to a checkpoint. + This method saves intermediate writes associated with a checkpoint to the database. + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. + task_id (str): Identifier for the task creating the writes. + """ + return asyncio.run_coroutine_threadsafe( + self.aput_writes(config, writes, task_id), self.loop + ).result() diff --git a/tests/test_async.py b/tests/test_async.py index c5ce1c5..7818c52 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,7 +1,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from copy import deepcopy -from typing import Any +from typing import Any, Union from uuid import uuid4 import aiomysql # type: ignore @@ -15,7 +15,7 @@ create_checkpoint, empty_checkpoint, ) -from langgraph.checkpoint.mysql.aio import AIOMySQLSaver +from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver from langgraph.checkpoint.serde.types import TASKS from tests.conftest import DEFAULT_BASE_URI @@ -77,10 +77,41 @@ async def _base_saver() -> AsyncIterator[AIOMySQLSaver]: @asynccontextmanager -async def _saver(name: str) -> AsyncIterator[AIOMySQLSaver]: +async def _shallow_saver() -> AsyncIterator[ShallowAIOMySQLSaver]: + """Fixture for shallow connection mode testing.""" + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await aiomysql.connect( + **AIOMySQLSaver.parse_conn_string(DEFAULT_BASE_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + async with ShallowAIOMySQLSaver.from_conn_string( + DEFAULT_BASE_URI + database + ) as checkpointer: + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await aiomysql.connect( + **AIOMySQLSaver.parse_conn_string(DEFAULT_BASE_URI), autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + +@asynccontextmanager +async def _saver( + name: str, +) -> AsyncIterator[Union[AIOMySQLSaver, ShallowAIOMySQLSaver]]: if name == "base": async with _base_saver() as saver: yield saver + elif name == "shallow": + async with _shallow_saver() as saver: + yield saver elif name == "pool": async with _pool_saver() as saver: yield saver @@ -137,7 +168,7 @@ def test_data() -> dict[str, Any]: } -@pytest.mark.parametrize("saver_name", ["base", "pool"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) async def test_asearch(saver_name: str, test_data: dict[str, Any]) -> None: async with _saver(saver_name) as saver: configs = test_data["configs"] @@ -182,7 +213,7 @@ async def test_asearch(saver_name: str, test_data: dict[str, Any]) -> None: } == {"", "inner"} -@pytest.mark.parametrize("saver_name", ["base", "pool"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) async def test_null_chars(saver_name: str, test_data: dict[str, Any]) -> None: async with _saver(saver_name) as saver: config = await saver.aput( @@ -197,7 +228,7 @@ async def test_null_chars(saver_name: str, test_data: dict[str, Any]) -> None: ].metadata["my_key"] == "abc" -@pytest.mark.parametrize("saver_name", ["base", "pool"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) async def test_write_and_read_pending_writes_and_sends( saver_name: str, test_data: dict[str, Any] ) -> None: @@ -226,7 +257,7 @@ async def test_write_and_read_pending_writes_and_sends( assert result.checkpoint["pending_sends"] == ["w3v"] -@pytest.mark.parametrize("saver_name", ["base", "pool"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) @pytest.mark.parametrize( "channel_values", [ @@ -261,7 +292,7 @@ async def test_write_and_read_channel_values( assert result.checkpoint["channel_values"] == channel_values -@pytest.mark.parametrize("saver_name", ["base", "pool"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) async def test_write_and_read_pending_writes(saver_name: str) -> None: async with _saver(saver_name) as saver: config: RunnableConfig = { @@ -292,7 +323,7 @@ async def test_write_and_read_pending_writes(saver_name: str) -> None: ] -@pytest.mark.parametrize("saver_name", ["base", "pool"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) async def test_write_with_different_checkpoint_ns_inserts( saver_name: str, ) -> None: @@ -317,7 +348,7 @@ async def test_write_with_different_checkpoint_ns_inserts( assert len(results) == 2 -@pytest.mark.parametrize("saver_name", ["base", "pool"]) +@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) async def test_write_with_same_checkpoint_ns_updates( saver_name: str, ) -> None: diff --git a/tests/test_sync.py b/tests/test_sync.py index 1c2075e..e514360 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,7 +1,7 @@ from collections.abc import Iterator from contextlib import closing, contextmanager from copy import deepcopy -from typing import Any, cast +from typing import Any, Union, cast from uuid import uuid4 import pymysql @@ -15,7 +15,7 @@ create_checkpoint, empty_checkpoint, ) -from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver +from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver, ShallowPyMySQLSaver from langgraph.checkpoint.serde.types import TASKS from tests.conftest import DEFAULT_BASE_URI, get_pymysql_sqlalchemy_pool @@ -100,10 +100,39 @@ def _base_saver() -> Iterator[PyMySQLSaver]: @contextmanager -def _saver(name: str) -> Iterator[PyMySQLSaver]: +def _shallow_saver() -> Iterator[ShallowPyMySQLSaver]: + """Fixture for regular connection mode testing with a shallow checkpointer.""" + database = f"test_{uuid4().hex[:16]}" + # create unique db + with pymysql.connect( + **PyMySQLSaver.parse_conn_string(DEFAULT_BASE_URI), autocommit=True + ) as conn: + with conn.cursor() as cursor: + cursor.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + with ShallowPyMySQLSaver.from_conn_string( + DEFAULT_BASE_URI + database + ) as checkpointer: + checkpointer.setup() + yield checkpointer + finally: + # drop unique db + with pymysql.connect( + **PyMySQLSaver.parse_conn_string(DEFAULT_BASE_URI), autocommit=True + ) as conn: + with conn.cursor() as cursor: + cursor.execute(f"DROP DATABASE {database}") + + +@contextmanager +def _saver(name: str) -> Iterator[Union[PyMySQLSaver, ShallowPyMySQLSaver]]: if name == "base": with _base_saver() as saver: yield saver + if name == "shallow": + with _shallow_saver() as saver: + yield saver elif name == "sqlalchemy_pool": with _sqlalchemy_pool_saver() as saver: yield saver @@ -163,7 +192,9 @@ def test_data() -> dict[str, Any]: } -@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"]) +@pytest.mark.parametrize( + "saver_name", ["base", "sqlalchemy_pool", "callable", "shallow"] +) def test_search(saver_name: str, test_data: dict[str, Any]) -> None: with _saver(saver_name) as saver: configs = test_data["configs"] @@ -206,7 +237,9 @@ def test_search(saver_name: str, test_data: dict[str, Any]) -> None: } == {"", "inner"} -@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"]) +@pytest.mark.parametrize( + "saver_name", ["base", "sqlalchemy_pool", "callable", "shallow"] +) def test_null_chars(saver_name: str, test_data: dict[str, Any]) -> None: with _saver(saver_name) as saver: config = saver.put( @@ -222,7 +255,9 @@ def test_null_chars(saver_name: str, test_data: dict[str, Any]) -> None: ) -@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"]) +@pytest.mark.parametrize( + "saver_name", ["base", "sqlalchemy_pool", "callable", "shallow"] +) def test_write_and_read_pending_writes_and_sends( saver_name: str, test_data: dict[str, Any] ) -> None: @@ -252,7 +287,9 @@ def test_write_and_read_pending_writes_and_sends( assert result.checkpoint["pending_sends"] == ["w3v"] -@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"]) +@pytest.mark.parametrize( + "saver_name", ["base", "sqlalchemy_pool", "callable", "shallow"] +) @pytest.mark.parametrize( "channel_values", [ @@ -287,7 +324,9 @@ def test_write_and_read_channel_values( assert result.checkpoint["channel_values"] == channel_values -@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"]) +@pytest.mark.parametrize( + "saver_name", ["base", "sqlalchemy_pool", "callable", "shallow"] +) def test_write_and_read_pending_writes(saver_name: str) -> None: with _saver(saver_name) as saver: config: RunnableConfig = { @@ -318,7 +357,9 @@ def test_write_and_read_pending_writes(saver_name: str) -> None: ] -@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"]) +@pytest.mark.parametrize( + "saver_name", ["base", "sqlalchemy_pool", "callable", "shallow"] +) def test_write_with_different_checkpoint_ns_inserts(saver_name: str) -> None: with _saver(saver_name) as saver: config1: RunnableConfig = { @@ -341,7 +382,9 @@ def test_write_with_different_checkpoint_ns_inserts(saver_name: str) -> None: assert len(results) == 2 -@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"]) +@pytest.mark.parametrize( + "saver_name", ["base", "sqlalchemy_pool", "callable", "shallow"] +) def test_write_with_same_checkpoint_ns_updates(saver_name: str) -> None: with _saver(saver_name) as saver: config: RunnableConfig = {