Skip to content

Commit

Permalink
Merge pull request #50 from tjni/more-pool-support
Browse files Browse the repository at this point in the history
Simpler connection pool support.
  • Loading branch information
tjni authored Jan 21, 2025
2 parents a44e78e + a320f7f commit c8de897
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 970 deletions.
570 changes: 20 additions & 550 deletions langgraph-tests/tests/__snapshots__/test_large_cases.ambr

Large diffs are not rendered by default.

182 changes: 8 additions & 174 deletions langgraph-tests/tests/__snapshots__/test_pregel.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge[pymysql_sqlalchemy_pool]
# name: test_in_one_fan_out_state_graph_waiting_edge[pymysql_pool]
'''
graph TD;
__start__ --> rewrite_query;
Expand All @@ -38,19 +38,6 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge[pymysql_callable]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query --> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql]
'''
graph TD;
Expand Down Expand Up @@ -121,77 +108,7 @@
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_sqlalchemy_pool]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query -.-> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_sqlalchemy_pool].1
dict({
'definitions': dict({
'InnerObject': dict({
'properties': dict({
'yo': dict({
'title': 'Yo',
'type': 'integer',
}),
}),
'required': list([
'yo',
]),
'title': 'InnerObject',
'type': 'object',
}),
}),
'properties': dict({
'inner': dict({
'$ref': '#/definitions/InnerObject',
}),
'query': dict({
'title': 'Query',
'type': 'string',
}),
}),
'required': list([
'query',
'inner',
]),
'title': 'Input',
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_sqlalchemy_pool].2
dict({
'properties': dict({
'answer': dict({
'title': 'Answer',
'type': 'string',
}),
'docs': dict({
'items': dict({
'type': 'string',
}),
'title': 'Docs',
'type': 'array',
}),
}),
'required': list([
'answer',
'docs',
]),
'title': 'Output',
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_callable]
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_pool]
'''
graph TD;
__start__ --> rewrite_query;
Expand All @@ -204,7 +121,7 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_callable].1
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_pool].1
dict({
'definitions': dict({
'InnerObject': dict({
Expand Down Expand Up @@ -238,7 +155,7 @@
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_callable].2
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic1[pymysql_pool].2
dict({
'properties': dict({
'answer': dict({
Expand Down Expand Up @@ -401,7 +318,7 @@
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_sqlalchemy_pool]
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_pool]
'''
graph TD;
__start__ --> rewrite_query;
Expand All @@ -414,7 +331,7 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_sqlalchemy_pool].1
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_pool].1
dict({
'$defs': dict({
'InnerObject': dict({
Expand Down Expand Up @@ -448,77 +365,7 @@
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_sqlalchemy_pool].2
dict({
'properties': dict({
'answer': dict({
'title': 'Answer',
'type': 'string',
}),
'docs': dict({
'items': dict({
'type': 'string',
}),
'title': 'Docs',
'type': 'array',
}),
}),
'required': list([
'answer',
'docs',
]),
'title': 'Output',
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_callable]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query -.-> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_callable].1
dict({
'$defs': dict({
'InnerObject': dict({
'properties': dict({
'yo': dict({
'title': 'Yo',
'type': 'integer',
}),
}),
'required': list([
'yo',
]),
'title': 'InnerObject',
'type': 'object',
}),
}),
'properties': dict({
'inner': dict({
'$ref': '#/$defs/InnerObject',
}),
'query': dict({
'title': 'Query',
'type': 'string',
}),
}),
'required': list([
'query',
'inner',
]),
'title': 'Input',
'type': 'object',
})
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_callable].2
# name: test_in_one_fan_out_state_graph_waiting_edge_custom_state_class_pydantic2[pymysql_pool].2
dict({
'properties': dict({
'answer': dict({
Expand Down Expand Up @@ -624,20 +471,7 @@

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[pymysql_sqlalchemy_pool]
'''
graph TD;
__start__ --> rewrite_query;
analyzer_one --> retriever_one;
qa --> __end__;
retriever_one --> qa;
retriever_two --> qa;
rewrite_query --> analyzer_one;
rewrite_query -.-> retriever_two;

'''
# ---
# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[pymysql_callable]
# name: test_in_one_fan_out_state_graph_waiting_edge_via_branch[pymysql_pool]
'''
graph TD;
__start__ --> rewrite_query;
Expand Down
71 changes: 13 additions & 58 deletions langgraph-tests/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import asynccontextmanager, closing
from typing import AsyncIterator, Optional, cast
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional
from uuid import UUID, uuid4

import aiomysql # type: ignore
Expand All @@ -9,7 +9,7 @@
from langchain_core import __version__ as core_version
from packaging import version
from pytest_mock import MockerFixture
from sqlalchemy import Pool, create_pool_from_url
from sqlalchemy import Engine, create_engine

from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver
Expand All @@ -28,9 +28,9 @@
SHOULD_CHECK_SNAPSHOTS = IS_LANGCHAIN_CORE_030_OR_GREATER


def get_pymysql_sqlalchemy_pool(uri: str) -> Pool:
def get_pymysql_sqlalchemy_engine(uri: str) -> Engine:
updated_uri = uri.replace("mysql://", "mysql+pymysql://")
return create_pool_from_url(updated_uri)
return create_engine(updated_uri)


@pytest.fixture
Expand Down Expand Up @@ -91,39 +91,16 @@ def checkpointer_pymysql_shallow():


@pytest.fixture(scope="function")
def checkpointer_pymysql_sqlalchemy_pool():
def checkpointer_pymysql_pool():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
checkpointer = PyMySQLSaver(get_pymysql_sqlalchemy_pool(DEFAULT_MYSQL_URI + database))
checkpointer.setup()
yield checkpointer
finally:
# drop unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE {database}")


@pytest.fixture(scope="function")
def checkpointer_pymysql_callable():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLSaver.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
pool = get_pymysql_sqlalchemy_pool(DEFAULT_MYSQL_URI + database)

def callable() -> pymysql.Connection:
return cast(pymysql.Connection, closing(pool.connect()))

checkpointer = PyMySQLSaver(callable)
pool = get_pymysql_sqlalchemy_engine(DEFAULT_MYSQL_URI + database)
checkpointer = PyMySQLSaver(pool.raw_connection)
checkpointer.setup()
yield checkpointer
finally:
Expand Down Expand Up @@ -255,27 +232,7 @@ def store_pymysql():


@pytest.fixture(scope="function")
def store_pymysql_sqlalchemy_pool():
database = f"test_{uuid4().hex[:16]}"

# create unique db
with pymysql.connect(**PyMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE {database}")
try:
# yield store
store = PyMySQLStore(get_pymysql_sqlalchemy_pool(DEFAULT_MYSQL_URI + database))
store.setup()
yield store
finally:
# drop unique db
with pymysql.connect(**PyMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True) as conn:
with conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE {database}")


@pytest.fixture(scope="function")
def store_pymysql_callable():
def store_pymysql_pool():
database = f"test_{uuid4().hex[:16]}"

# create unique db
Expand All @@ -284,9 +241,8 @@ def store_pymysql_callable():
cursor.execute(f"CREATE DATABASE {database}")
try:
# yield store
pool = get_pymysql_sqlalchemy_pool(DEFAULT_MYSQL_URI + database)
callable = lambda: cast(pymysql.Connection, closing(pool.connect()))
store = PyMySQLStore(callable)
engine = get_pymysql_sqlalchemy_engine(DEFAULT_MYSQL_URI + database)
store = PyMySQLStore(engine.raw_connection)
store.setup()
yield store
finally:
Expand Down Expand Up @@ -386,8 +342,7 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
SHALLOW_CHECKPOINTERS_SYNC = ["pymysql_shallow"]
REGULAR_CHECKPOINTERS_SYNC = [
"pymysql",
"pymysql_sqlalchemy_pool",
"pymysql_callable"
"pymysql_pool",
]
ALL_CHECKPOINTERS_SYNC = [
*REGULAR_CHECKPOINTERS_SYNC,
Expand All @@ -399,5 +354,5 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
*REGULAR_CHECKPOINTERS_ASYNC,
*SHALLOW_CHECKPOINTERS_ASYNC,
]
ALL_STORES_SYNC = ["pymysql", "pymysql_sqlalchemy_pool", "pymysql_callable"]
ALL_STORES_SYNC = ["pymysql", "pymysql_pool"]
ALL_STORES_ASYNC = ["aiomysql", "aiomysql_pool"]
Loading

0 comments on commit c8de897

Please sign in to comment.