Skip to content

Commit

Permalink
Add back hash of checkpoint_ns to primary keys.
Browse files Browse the repository at this point in the history
This is used in the upsert queries. I've also added a test to
verify this behavior.
  • Loading branch information
tjni committed Dec 22, 2024
1 parent 65161d9 commit 9fa1fb1
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 3 deletions.
30 changes: 27 additions & 3 deletions langgraph/checkpoint/mysql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,13 @@
"""
CREATE INDEX checkpoints_checkpoint_id_idx ON checkpoints (checkpoint_id);
""",
# The following three migrations were added to buy more room for
# nested subgraphs, since deeper nesting increases checkpoint_ns length.
"ALTER TABLE checkpoints MODIFY COLUMN `checkpoint_ns` VARCHAR(255) NOT NULL DEFAULT '';",
"ALTER TABLE checkpoint_blobs MODIFY COLUMN `checkpoint_ns` VARCHAR(255) NOT NULL DEFAULT '';",
"ALTER TABLE checkpoint_writes MODIFY COLUMN `checkpoint_ns` VARCHAR(255) NOT NULL DEFAULT '';",
# The following three migrations drastically increase the size of the
# checkpoint_ns field to support deeply nested subgraphs.
"""
ALTER TABLE checkpoints
DROP PRIMARY KEY,
Expand All @@ -90,6 +94,26 @@
ADD PRIMARY KEY (thread_id, checkpoint_id, task_id, idx),
MODIFY COLUMN `checkpoint_ns` VARCHAR(2000) NOT NULL DEFAULT '';
""",
# The following three migrations restore checkpoint_ns as part of the
# primary key, but hashed to fit into the primary key size limit.
"""
ALTER TABLE checkpoints
ADD COLUMN checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED,
DROP PRIMARY KEY,
ADD PRIMARY KEY (thread_id, checkpoint_ns_hash, checkpoint_id);
""",
"""
ALTER TABLE checkpoint_blobs
ADD COLUMN checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED,
DROP PRIMARY KEY,
ADD PRIMARY KEY (thread_id, checkpoint_ns_hash, channel, version);
""",
"""
ALTER TABLE checkpoint_writes
ADD COLUMN checkpoint_ns_hash BINARY(16) AS (UNHEX(MD5(checkpoint_ns))) STORED,
DROP PRIMARY KEY,
ADD PRIMARY KEY (thread_id, checkpoint_ns_hash, checkpoint_id, task_id, idx);
""",
]

SELECT_SQL = f"""
Expand All @@ -114,7 +138,7 @@
) as channel_versions
inner join checkpoint_blobs bl
on bl.thread_id = checkpoints.thread_id
and bl.checkpoint_ns = checkpoints.checkpoint_ns
and bl.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash
and bl.channel = channel_versions.channel
and bl.version = channel_versions.version
) as channel_values,
Expand All @@ -123,14 +147,14 @@
json_arrayagg(json_array(cw.task_id, cw.channel, cw.type, cw.blob, cw.idx))
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
and cw.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash
and cw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
select json_arrayagg(json_array(cw.task_id, cw.type, cw.blob, cw.idx))
from checkpoint_writes cw
where cw.thread_id = checkpoints.thread_id
and cw.checkpoint_ns = checkpoints.checkpoint_ns
and cw.checkpoint_ns_hash = checkpoints.checkpoint_ns_hash
and cw.checkpoint_id = checkpoints.parent_checkpoint_id
and cw.channel = '{TASKS}'
) as pending_sends
Expand Down
26 changes: 26 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Any
from uuid import uuid4

Expand Down Expand Up @@ -289,3 +290,28 @@ async def test_write_and_read_pending_writes(saver_name: str) -> None:
(task_id, "channel2", [1, 2, 3]),
(task_id, "channel3", None),
]


@pytest.mark.parametrize("saver_name", ["base", "pool"])
async def test_write_with_different_checkpoint_ns_does_an_update(
    saver_name: str,
) -> None:
    """Checkpoints differing only in checkpoint_ns are stored as distinct rows.

    Two puts that share thread_id and checkpoint_id but use different
    checkpoint_ns values must both survive, i.e. the second put must not
    upsert over the first.
    """
    async with _saver(saver_name) as saver:
        # Base config; the second config differs only in its namespace.
        first_config: RunnableConfig = {
            "configurable": {
                "thread_id": "thread-6",
                "checkpoint_id": "6",
                "checkpoint_ns": "first",
            }
        }
        second_config = deepcopy(first_config)
        second_config["configurable"]["checkpoint_ns"] = "second"

        checkpoint = empty_checkpoint()

        # Same checkpoint payload written under two namespaces.
        await saver.aput(first_config, checkpoint, {}, {})
        await saver.aput(second_config, checkpoint, {}, {})

        stored = [c async for c in saver.alist({})]

        # Both namespaced checkpoints must be present.
        assert len(stored) == 2
24 changes: 24 additions & 0 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Iterator
from contextlib import closing, contextmanager
from copy import deepcopy
from typing import Any, cast
from uuid import uuid4

Expand Down Expand Up @@ -315,3 +316,26 @@ def test_write_and_read_pending_writes(saver_name: str) -> None:
(task_id, "channel2", [1, 2, 3]),
(task_id, "channel3", None),
]


@pytest.mark.parametrize("saver_name", ["base", "sqlalchemy_pool", "callable"])
def test_write_with_different_checkpoint_ns_does_an_update(saver_name: str) -> None:
    """Checkpoints differing only in checkpoint_ns are stored as distinct rows.

    Two puts that share thread_id and checkpoint_id but use different
    checkpoint_ns values must both survive, i.e. the second put must not
    upsert over the first.
    """
    with _saver(saver_name) as saver:
        # Base config; the second config differs only in its namespace.
        first_config: RunnableConfig = {
            "configurable": {
                "thread_id": "thread-6",
                "checkpoint_id": "6",
                "checkpoint_ns": "first",
            }
        }
        second_config = deepcopy(first_config)
        second_config["configurable"]["checkpoint_ns"] = "second"

        checkpoint = empty_checkpoint()

        # Same checkpoint payload written under two namespaces.
        saver.put(first_config, checkpoint, {}, {})
        saver.put(second_config, checkpoint, {}, {})

        stored = list(saver.list({}))

        # Both namespaced checkpoints must be present.
        assert len(stored) == 2

0 comments on commit 9fa1fb1

Please sign in to comment.