Skip to content

Commit

Permalink
Abstract PyMySQL specifics into a separate file.
Browse files Browse the repository at this point in the history
This lays some groundwork for perhaps supporting multiple drivers in the
future.
  • Loading branch information
tjni committed Sep 25, 2024
1 parent f51a990 commit 3dc9d06
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 59 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ There is currently no support for other drivers.
> When manually creating MySQL connections and passing them to `PyMySQLSaver` or `AIOMySQLSaver`, make sure to include `autocommit=True`.
```python
from langgraph.checkpoint.mysql import PyMySQLSaver
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver

write_config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
read_config = {"configurable": {"thread_id": "1"}}
Expand Down
124 changes: 69 additions & 55 deletions langgraph/checkpoint/mysql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import json
import threading
import urllib.parse
from contextlib import contextmanager
from typing import Any, Iterator, Optional, Protocol, Sequence, Union
from typing import (
Any,
ContextManager,
Generic,
Iterator,
Mapping,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
cast,
)

import pymysql
import pymysql.constants.ER
import pymysql.cursors
from langchain_core.runnables import RunnableConfig

from langgraph.checkpoint.base import (
Expand All @@ -28,66 +36,72 @@
from langgraph.checkpoint.serde.base import SerializerProtocol


class ConnectionPool(Protocol):
class DictCursor(Protocol):
"""
Protocol that a cursor should implement.
Modeled after DBAPICursor from Typeshed.
"""

def execute(
self,
operation: str,
parameters: Union[Sequence[Any], Mapping[str, Any]] = ...,
/,
) -> object: ...
def executemany(
self, operation: str, seq_of_parameters: Sequence[Sequence[Any]], /
) -> object: ...
def fetchone(self) -> Optional[dict[str, Any]]: ...
def fetchall(self) -> Sequence[dict[str, Any]]: ...


C = TypeVar("C", bound=ContextManager, covariant=True) # connecion type
R = TypeVar("R", bound=ContextManager, covariant=True) # cursor type


class ConnectionPool(Protocol, Generic[C]):
"""Protocol that a MySQL connection pool should implement."""

def get_connection(self) -> pymysql.Connection:
def get_connection(self) -> C:
"""Gets a connection from the connection pool."""
...


Conn = Union[pymysql.Connection, ConnectionPool]
Conn = Union[C, ConnectionPool[C]]


@contextmanager
def _get_connection(conn: Conn) -> Iterator[pymysql.Connection]:
if isinstance(conn, pymysql.Connection):
yield conn
def _get_connection(conn: Conn[C]) -> Iterator[C]:
if hasattr(conn, "cursor"):
yield cast(C, conn)
elif hasattr(conn, "get_connection"):
with conn.get_connection() as conn:
yield conn
with cast(ConnectionPool[C], conn).get_connection() as _conn:
yield _conn
else:
raise TypeError(f"Invalid connection type: {type(conn)}")


class PyMySQLSaver(BaseMySQLSaver):
class BaseSyncMySQLSaver(BaseMySQLSaver, Generic[C, R]):
lock: threading.Lock

def __init__(
self,
conn: Conn,
conn: Conn[C],
serde: Optional[SerializerProtocol] = None,
) -> None:
super().__init__(serde=serde)

self.conn = conn
self.lock = threading.Lock()

@classmethod
@contextmanager
def from_conn_string(
cls,
conn_string: str,
) -> Iterator["PyMySQLSaver"]:
"""Create a new PyMySQLSaver instance from a connection string.
@staticmethod
def _is_no_such_table_error(e: Exception) -> bool:
raise NotImplementedError

Args:
conn_string (str): The MySQL connection info string.
Returns:
PyMySQLSaver: A new PyMySQLSaver instance.
"""
parsed = urllib.parse.urlparse(conn_string)

with pymysql.connect(
host=parsed.hostname,
user=parsed.username,
password=parsed.password or "",
database=parsed.path[1:],
port=parsed.port or 3306,
autocommit=True,
) as conn:
yield PyMySQLSaver(conn)
@contextmanager
def _cursor(self) -> Iterator[R]:
raise NotImplementedError

def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
Expand All @@ -96,7 +110,8 @@ def setup(self) -> None:
already exist and runs database migrations. It MUST be called directly by the user
the first time checkpointer is used.
"""
with self._cursor() as cur:
with self._cursor() as cur_:
cur = cast(DictCursor, cur_)
try:
cur.execute(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
Expand All @@ -106,8 +121,8 @@ def setup(self) -> None:
version = -1
else:
version = row["v"]
except pymysql.ProgrammingError as e:
if e.args[0] != pymysql.constants.ER.NO_SUCH_TABLE:
except Exception as e:
if not self._is_no_such_table_error(e):
raise
version = -1
for v, migration in zip(
Expand Down Expand Up @@ -162,9 +177,11 @@ def list(
if limit:
query += f" LIMIT {limit}"
# if we change this to use .stream() we need to make sure to close the cursor
with self._cursor() as cur:
with self._cursor() as cur_:
cur = cast(DictCursor, cur_)
cur.execute(query, args)
for value in cur:
values = cur.fetchall()
for value in values:
yield CheckpointTuple(
{
"configurable": {
Expand Down Expand Up @@ -238,13 +255,14 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
args = (thread_id, checkpoint_ns)
where = "WHERE thread_id = %s AND checkpoint_ns = %s ORDER BY checkpoint_id DESC LIMIT 1"

with self._cursor() as cur:
with self._cursor() as cur_:
cur = cast(DictCursor, cur_)
cur.execute(
self.SELECT_SQL + where,
args,
)

for value in cur:
values = cur.fetchall()
for value in values:
return CheckpointTuple(
{
"configurable": {
Expand Down Expand Up @@ -320,7 +338,8 @@ def put(
}
}

with self._cursor() as cur:
with self._cursor() as cur_:
cur = cast(DictCursor, cur_)
cur.executemany(
self.UPSERT_CHECKPOINT_BLOBS_SQL,
self._dump_blobs(
Expand Down Expand Up @@ -363,7 +382,8 @@ def put_writes(
if all(w[0] in WRITES_IDX_MAP for w in writes)
else self.INSERT_CHECKPOINT_WRITES_SQL
)
with self._cursor() as cur:
with self._cursor() as cur_:
cur = cast(DictCursor, cur_)
cur.executemany(
query,
self._dump_writes(
Expand All @@ -374,9 +394,3 @@ def put_writes(
writes,
),
)

@contextmanager
def _cursor(self) -> Iterator[pymysql.cursors.DictCursor]:
with _get_connection(self.conn) as conn:
with self.lock, conn.cursor(pymysql.cursors.DictCursor) as cur:
yield cur
53 changes: 53 additions & 0 deletions langgraph/checkpoint/mysql/pymysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import urllib.parse
from collections.abc import Iterator
from contextlib import contextmanager

import pymysql
import pymysql.constants.ER
from pymysql.cursors import DictCursor
from typing_extensions import Self, override

from langgraph.checkpoint.mysql import BaseSyncMySQLSaver, _get_connection


class PyMySQLSaver(BaseSyncMySQLSaver[pymysql.Connection, DictCursor]):
@classmethod
@contextmanager
def from_conn_string(
cls,
conn_string: str,
) -> Iterator[Self]:
"""Create a new PyMySQLSaver instance from a connection string.
Args:
conn_string (str): The MySQL connection info string.
Returns:
PyMySQLSaver: A new PyMySQLSaver instance.
"""
parsed = urllib.parse.urlparse(conn_string)

with pymysql.connect(
host=parsed.hostname,
user=parsed.username,
password=parsed.password or "",
database=parsed.path[1:],
port=parsed.port or 3306,
autocommit=True,
) as conn:
yield cls(conn)

@override
@staticmethod
def _is_no_such_table_error(e: Exception) -> bool:
return (
isinstance(e, pymysql.ProgrammingError)
and e.args[0] == pymysql.constants.ER.NO_SUCH_TABLE
)

@override
@contextmanager
def _cursor(self) -> Iterator[DictCursor]:
with _get_connection(self.conn) as conn:
with self.lock, conn.cursor(DictCursor) as cur:
yield cur
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint-mysql"
version = "1.0.1"
version = "1.0.2"
description = "Library with a MySQL implementation of LangGraph checkpoint saver."
authors = ["Theodore Ni <[email protected]>"]
license = "MIT"
Expand All @@ -13,6 +13,7 @@ python = "^3.9.0,<4.0"
langgraph-checkpoint = "^1.0.11"
pymysql = { version = "^1.1.1", optional = true }
aiomysql = { version = "^0.2.0", optional = true }
typing-extensions = "^4.12.2"

[tool.poetry.extras]
pymysql = ["pymysql"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
create_checkpoint,
empty_checkpoint,
)
from langgraph.checkpoint.mysql import PyMySQLSaver
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver
from langgraph.checkpoint.serde.types import TASKS


Expand Down

0 comments on commit 3dc9d06

Please sign in to comment.