Skip to content

Commit

Permalink
Move aiomysql references out of _ainternal.py.
Browse files Browse the repository at this point in the history
This lays the groundwork to apply the new shallow checkpointer logic
using a single shallow.py file like how upstream is doing today.
  • Loading branch information
tjni committed Jan 2, 2025
1 parent 1cf8e50 commit 61a512f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
52 changes: 41 additions & 11 deletions langgraph/checkpoint/mysql/_ainternal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,55 @@

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Union
from typing import AsyncContextManager, Generic, Protocol, TypeVar, Union, cast

import aiomysql # type: ignore
import pymysql.connections

Conn = Union[aiomysql.Connection, aiomysql.Pool]
class AIOMySQLConnection(AsyncContextManager, Protocol):
"""From aiomysql package."""

async def begin(self) -> None:
"""Begin transaction."""
...

async def commit(self) -> None:
"""Commit changes to stable storage."""
...

async def rollback(self) -> None:
"""Roll back the current transaction."""
...

async def set_charset(self, charset: str) -> None:
"""Sets the character set for the current connection"""
...


C = TypeVar("C", bound=AIOMySQLConnection) # connection type
COut = TypeVar("COut", bound=AIOMySQLConnection, covariant=True) # connection type


class AIOMySQLPool(Protocol, Generic[COut]):
"""From aiomysql package."""

def acquire(self) -> COut:
"""Gets a connection from the connection pool."""
...


Conn = Union[C, AIOMySQLPool[C]]


@asynccontextmanager
async def get_connection(
conn: Conn,
) -> AsyncIterator[aiomysql.Connection]:
if isinstance(conn, aiomysql.Connection):
yield conn
elif isinstance(conn, aiomysql.Pool):
async with conn.acquire() as _conn:
conn: Conn[C],
) -> AsyncIterator[C]:
if hasattr(conn, "cursor"):
yield cast(C, conn)
elif hasattr(conn, "acquire"):
async with cast(AIOMySQLPool[C], conn).acquire() as _conn:
# This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119
# is merged into aiomysql.
await _conn.set_charset(pymysql.connections.DEFAULT_CHARSET)
await _conn.set_charset("utf8mb4")
yield _conn
else:
raise TypeError(f"Invalid connection type: {type(conn)}")
2 changes: 1 addition & 1 deletion langgraph/checkpoint/mysql/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from langgraph.checkpoint.serde.base import SerializerProtocol

Conn = _ainternal.Conn # For backward compatibility
Conn = _ainternal.Conn[aiomysql.Connection] # For backward compatibility


class AIOMySQLSaver(BaseMySQLSaver):
Expand Down

0 comments on commit 61a512f

Please sign in to comment.