Skip to content

Commit

Permalink
Merge pull request #27 from tjni/wfh/provisioning_cursor_reuse
Browse files Browse the repository at this point in the history
Fixup initial provisioning of aio postgres db
  • Loading branch information
tjni authored Dec 10, 2024
2 parents 2cc0c7a + 3b2692e commit 7b59157
Show file tree
Hide file tree
Showing 6 changed files with 502 additions and 391 deletions.
2 changes: 1 addition & 1 deletion langgraph-tests/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def checkpointer_pymysql():
# drop unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"drop databASE {database}")
cursor.execute(f"DROP DATABASE {database}")


@asynccontextmanager
Expand Down
42 changes: 13 additions & 29 deletions langgraph/checkpoint/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Protocol,
TypeVar,
Union,
cast,
)

from langchain_core.runnables import RunnableConfig
Expand All @@ -34,7 +33,7 @@
from langgraph.checkpoint.serde.base import SerializerProtocol


class DictCursor(Protocol):
class DictCursor(ContextManager, Protocol):
"""
Protocol that a cursor should implement.
Expand All @@ -54,7 +53,7 @@ def fetchone(self) -> Optional[dict[str, Any]]: ...
def fetchall(self) -> Sequence[dict[str, Any]]: ...


R = TypeVar("R", bound=ContextManager, covariant=True) # cursor type
R = TypeVar("R", bound=DictCursor) # cursor type


Conn = _internal.Conn # For backward compatibility
Expand All @@ -73,10 +72,6 @@ def __init__(
self.conn = conn
self.lock = threading.Lock()

@staticmethod
def _is_no_such_table_error(e: Exception) -> bool:
raise NotImplementedError

@staticmethod
def _get_cursor_from_connection(conn: _internal.C) -> R:
raise NotImplementedError
Expand Down Expand Up @@ -110,21 +105,14 @@ def setup(self) -> None:
already exist and runs database migrations. It MUST be called directly by the user
the first time checkpointer is used.
"""
with self._cursor() as cur_:
cur = cast(DictCursor, cur_)
try:
cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
)
row = cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
except Exception as e:
if not self._is_no_such_table_error(e):
raise
with self._cursor() as cur:
cur.execute(self.MIGRATIONS[0])
cur.execute("SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1")
row = cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
for v, migration in zip(
range(version + 1, len(self.MIGRATIONS)),
self.MIGRATIONS[version + 1 :],
Expand Down Expand Up @@ -177,8 +165,7 @@ def list(
if limit:
query += f" LIMIT {limit}"
# if we change this to use .stream() we need to make sure to close the cursor
with self._cursor() as cur_:
cur = cast(DictCursor, cur_)
with self._cursor() as cur:
cur.execute(query, args)
values = cur.fetchall()
for value in values:
Expand Down Expand Up @@ -257,8 +244,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
args = (thread_id, checkpoint_ns)
where = "WHERE thread_id = %s AND checkpoint_ns = %s ORDER BY checkpoint_id DESC LIMIT 1"

with self._cursor() as cur_:
cur = cast(DictCursor, cur_)
with self._cursor() as cur:
cur.execute(
self.SELECT_SQL + where,
args,
Expand Down Expand Up @@ -342,8 +328,7 @@ def put(
}
}

with self._cursor(pipeline=True) as cur_:
cur = cast(DictCursor, cur_)
with self._cursor(pipeline=True) as cur:
cur.executemany(
self.UPSERT_CHECKPOINT_BLOBS_SQL,
self._dump_blobs(
Expand Down Expand Up @@ -386,8 +371,7 @@ def put_writes(
if all(w[0] in WRITES_IDX_MAP for w in writes)
else self.INSERT_CHECKPOINT_WRITES_SQL
)
with self._cursor(pipeline=True) as cur_:
cur = cast(DictCursor, cur_)
with self._cursor(pipeline=True) as cur:
cur.executemany(
query,
self._dump_writes(
Expand Down
21 changes: 8 additions & 13 deletions langgraph/checkpoint/mysql/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import aiomysql # type: ignore
import pymysql
import pymysql.connections
import pymysql.constants.ER
from langchain_core.runnables import RunnableConfig

from langgraph.checkpoint.base import (
Expand Down Expand Up @@ -100,19 +99,15 @@ async def setup(self) -> None:
the first time checkpointer is used.
"""
async with self._cursor() as cur:
try:
await cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
)
row = await cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
except pymysql.ProgrammingError as e:
if e.args[0] != pymysql.constants.ER.NO_SUCH_TABLE:
raise
await cur.execute(self.MIGRATIONS[0])
await cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
)
row = await cur.fetchone()
if row is None:
version = -1
else:
version = row["v"]
for v, migration in zip(
range(version + 1, len(self.MIGRATIONS)),
self.MIGRATIONS[version + 1 :],
Expand Down
8 changes: 0 additions & 8 deletions langgraph/checkpoint/mysql/pymysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,6 @@ def from_conn_string(
) as conn:
yield cls(conn)

@override
@staticmethod
def _is_no_such_table_error(e: Exception) -> bool:
return (
isinstance(e, pymysql.ProgrammingError)
and e.args[0] == pymysql.constants.ER.NO_SUCH_TABLE
)

@override
@staticmethod
def _get_cursor_from_connection(conn: pymysql.Connection) -> DictCursor:
Expand Down
Loading

0 comments on commit 7b59157

Please sign in to comment.