Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse the session between requests #21

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.test
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
TEST=True
REDIS_URL=redis://localhost:6379/0
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/postgres
DATABASE_POOL_CLASS=NullPool
2 changes: 2 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class Settings(BaseSettings):
REDIS_URL: str = "redis://"
DATABASE_URL: str = "psql://postgres:"
DATABASE_POOL_CLASS: str = "AsyncAdaptedQueuePool"
DATABASE_POOL_SIZE: int = 10
TEST: bool = False


settings = Settings()
45 changes: 31 additions & 14 deletions app/datasources/db/database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from functools import wraps
from collections.abc import AsyncGenerator
from functools import cache, wraps

from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
from sqlmodel.ext.asyncio.session import AsyncSession

Expand All @@ -12,15 +13,34 @@
AsyncAdaptedQueuePool.__name__: AsyncAdaptedQueuePool,
}

engine = create_async_engine(
settings.DATABASE_URL,
echo=True,
future=True,
poolclass=pool_classes.get(settings.DATABASE_POOL_CLASS),
)


def get_database_session(func):
@cache
def get_engine() -> AsyncEngine:
"""
Establish connection to database
:return:
"""
if settings.TEST:
return create_async_engine(
settings.DATABASE_URL,
future=True,
poolclass=NullPool,
)
else:
return create_async_engine(
settings.DATABASE_URL,
future=True,
poolclass=pool_classes.get(settings.DATABASE_POOL_CLASS),
pool_size=settings.DATABASE_POOL_SIZE,
)


async def get_database_session() -> AsyncGenerator:
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
yield session


def database_session(func):
"""
Decorator that creates a new database session for the given function

Expand All @@ -30,16 +50,13 @@ def get_database_session(func):

@wraps(func)
async def wrapper(*args, **kwargs):
async with AsyncSession(engine) as session:
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
try:
return await func(*args, **kwargs, session=session)
except Exception as e:
# Rollback errors
await session.rollback()
logging.error(f"Error occurred: {e}")
raise
finally:
# Ensure that session is closed
await session.close()

return wrapper
12 changes: 8 additions & 4 deletions app/routers/contracts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Sequence

from fastapi import APIRouter
from fastapi import APIRouter, Depends

from sqlmodel.ext.asyncio.session import AsyncSession

from ..datasources.db.database import get_database_session
from ..datasources.db.models import Contract
from ..services.contract import ContractService

Expand All @@ -12,6 +15,7 @@


@router.get("", response_model=Sequence[Contract])
async def list_contracts() -> Sequence[Contract]:
contract_service = ContractService()
return await contract_service.get_all()
async def list_contracts(
session: AsyncSession = Depends(get_database_session),
) -> Sequence[Contract]:
return await ContractService.get_all(session)
3 changes: 0 additions & 3 deletions app/services/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import get_database_session
from app.datasources.db.models import Contract


class ContractService:

@staticmethod
@get_database_session
async def get_all(session: AsyncSession) -> Sequence[Contract]:
"""
Get all contracts
Expand All @@ -22,7 +20,6 @@ async def get_all(session: AsyncSession) -> Sequence[Contract]:
return result.all()

@staticmethod
@get_database_session
async def create(contract: Contract, session: AsyncSession) -> Contract:
"""
Create a new contract
Expand Down
4 changes: 2 additions & 2 deletions app/tests/db/db_async_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from sqlmodel import SQLModel

from app.datasources.db.database import engine
from app.datasources.db.database import get_engine


class DbAsyncConn(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.engine = engine
self.engine = get_engine()
# Create the database tables
async with self.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
Expand Down
4 changes: 2 additions & 2 deletions app/tests/db/test_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import get_database_session
from app.datasources.db.database import database_session
from app.datasources.db.models import Contract
from app.tests.db.db_async_conn import DbAsyncConn


class TestModel(DbAsyncConn):
@get_database_session
@database_session
async def test_contract(self, session: AsyncSession):
contract = Contract(address=b"a", name="A Test Contracts")
session.add(contract)
Expand Down
8 changes: 6 additions & 2 deletions app/tests/routers/test_contracts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from fastapi.testclient import TestClient

from sqlmodel.ext.asyncio.session import AsyncSession

from ...datasources.db.database import database_session
from ...datasources.db.models import Contract
from ...main import app
from ...services.contract import ContractService
Expand All @@ -13,14 +16,15 @@ class TestRouterContract(DbAsyncConn):
def setUpClass(cls):
cls.client = TestClient(app)

async def test_view_contracts(self):
@database_session
async def test_view_contracts(self, session: AsyncSession):
contract = Contract(address=b"a", name="A Test Contracts")
expected_response = {
"name": "A Test Contracts",
"description": None,
"address": "a",
}
await ContractService.create(contract=contract)
await ContractService.create(contract=contract, session=session)
response = self.client.get("/api/v1/contracts")
self.assertEqual(response.status_code, 200)
self.assertDictEqual(response.json()[0], expected_response)
Loading