diff --git a/langgraph/checkpoint/mysql/base.py b/langgraph/checkpoint/mysql/base.py index 727ee52..78ba717 100644 --- a/langgraph/checkpoint/mysql/base.py +++ b/langgraph/checkpoint/mysql/base.py @@ -69,9 +69,13 @@ """ CREATE INDEX checkpoints_checkpoint_id_idx ON checkpoints (checkpoint_id); """, + # The following three migrations widen checkpoint_ns to VARCHAR(255), because + # every level of subgraph nesting adds to the checkpoint_ns length. "ALTER TABLE checkpoints MODIFY COLUMN `checkpoint_ns` VARCHAR(255) NOT NULL DEFAULT '';", "ALTER TABLE checkpoint_blobs MODIFY COLUMN `checkpoint_ns` VARCHAR(255) NOT NULL DEFAULT '';", "ALTER TABLE checkpoint_writes MODIFY COLUMN `checkpoint_ns` VARCHAR(255) NOT NULL DEFAULT '';", + # The following three migrations widen checkpoint_ns further, to VARCHAR(2000), + # for deeply nested subgraphs, and take it out of the primary key. """ ALTER TABLE checkpoints DROP PRIMARY KEY, @@ -90,6 +94,26 @@ ADD PRIMARY KEY (thread_id, checkpoint_id, task_id, idx), MODIFY COLUMN `checkpoint_ns` VARCHAR(2000) NOT NULL DEFAULT ''; """, + # The following three migrations put checkpoint_ns back into the primary key + # as a 16-byte MD5 hash column, which fits within the key size limit.
+ """ + ALTER TABLE checkpoints + ADD COLUMN checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED, + DROP PRIMARY KEY, + ADD PRIMARY KEY (thread_id, checkpoint_ns_hash, checkpoint_id); + """, + """ + ALTER TABLE checkpoint_blobs + ADD COLUMN checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED, + DROP PRIMARY KEY, + ADD PRIMARY KEY (thread_id, checkpoint_ns_hash, channel, version); + """, + """ + ALTER TABLE checkpoint_writes + ADD COLUMN checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED, + DROP PRIMARY KEY, + ADD PRIMARY KEY (thread_id, checkpoint_ns_hash, checkpoint_id, task_id, idx); + """, ] SELECT_SQL = f""" @@ -114,7 +138,7 @@ ) as channel_versions inner join checkpoint_blobs bl on bl.thread_id = checkpoints.thread_id - and bl.checkpoint_ns = checkpoints.checkpoint_ns + and bl.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash and bl.channel = channel_versions.channel and bl.version = channel_versions.version ) as channel_values, @@ -123,14 +147,14 @@ json_arrayagg(json_array(cw.task_id, cw.channel, cw.type, cw.blob, cw.idx)) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id - and cw.checkpoint_ns = checkpoints.checkpoint_ns + and cw.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash and cw.checkpoint_id = checkpoints.checkpoint_id ) as pending_writes, ( select json_arrayagg(json_array(cw.task_id, cw.type, cw.blob, cw.idx)) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id - and cw.checkpoint_ns = checkpoints.checkpoint_ns + and cw.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash and cw.checkpoint_id = checkpoints.parent_checkpoint_id and cw.channel = '{TASKS}' ) as pending_sends diff --git a/tests/test_async.py b/tests/test_async.py index ec6e1d1..5a17e85 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,5 +1,6 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from copy import deepcopy from typing import Any from 
uuid import uuid4 @@ -289,3 +290,28 @@ async def test_write_and_read_pending_writes(saver_name: str) -> None: (task_id, "channel2", [1, 2, 3]), (task_id, "channel3", None), ] + + +@pytest.mark.parametrize("saver_name", ["base", "pool"]) +async def test_write_with_different_checkpoint_ns_does_an_update( + saver_name: str, +) -> None: + async with _saver(saver_name) as saver: + config1: RunnableConfig = { + "configurable": { + "thread_id": "thread-6", + "checkpoint_id": "6", + "checkpoint_ns": "first", + } + } + config2 = deepcopy(config1) + config2["configurable"]["checkpoint_ns"] = "second" # same thread_id and checkpoint_id as config1; only checkpoint_ns differs + + chkpnt = empty_checkpoint() + + await saver.aput(config1, chkpnt, {}, {}) + await saver.aput(config2, chkpnt, {}, {}) # NOTE(review): the test name says "does_an_update", but the assertion below expects two listed checkpoints + + results = [c async for c in saver.alist({})] # empty config: no thread_id/checkpoint_ns given; presumably lists all stored checkpoints - confirm + + assert len(results) == 2 diff --git a/tests/test_sync.py b/tests/test_sync.py index debc1bf..9760ea6 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,5 +1,6 @@ from collections.abc import Iterator from contextlib import closing, contextmanager +from copy import deepcopy from typing import Any, cast from uuid import uuid4 @@ -315,3 +316,26 @@ def test_write_and_read_pending_writes(saver_name: str) -> None: (task_id, "channel2", [1, 2, 3]), (task_id, "channel3", None), ] + + +@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"]) +def test_write_with_different_checkpoint_ns_does_an_update(saver_name: str) -> None: + with _saver(saver_name) as saver: + config1: RunnableConfig = { + "configurable": { + "thread_id": "thread-6", + "checkpoint_id": "6", + "checkpoint_ns": "first", + } + } + config2 = deepcopy(config1) + config2["configurable"]["checkpoint_ns"] = "second" # same thread_id and checkpoint_id as config1; only checkpoint_ns differs + + chkpnt = empty_checkpoint() + + saver.put(config1, chkpnt, {}, {}) + saver.put(config2, chkpnt, {}, {}) # NOTE(review): the test name says "does_an_update", but the assertion below expects two listed checkpoints + + results = list(saver.list({})) # empty config: no thread_id/checkpoint_ns given; presumably lists all stored checkpoints - confirm + + assert len(results) == 2