Skip to content

Commit

Permalink
Refactor database session (#79)
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 authored Jan 24, 2025
1 parent be533a6 commit f9da5b3
Show file tree
Hide file tree
Showing 21 changed files with 447 additions and 462 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ To open an interactive Python shell within a Docker container and query the data
```
Example usage:
```
In [11]: contracts = await Contract.get_all(session)
In [11]: contracts = await Contract.get_all()
In [12]: contracts[0].address
Out[12]: b'J\xdb\xaa\xc7\xbc#\x9e%\x19\xcb\xfd#\x97\xe0\xf7Z\x1d\xe3U\xc8'
```
Call `await restore_session()` to reopen a new session.

## Contributors
[See contributors](https://github.com/safe-global/safe-decoder-service/graphs/contributors)
76 changes: 59 additions & 17 deletions app/datasources/db/database.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
import logging
from collections.abc import AsyncGenerator
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from functools import cache, wraps
from typing import Generator

from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
async_scoped_session,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.pool import AsyncAdaptedQueuePool, NullPool
from sqlmodel.ext.asyncio.session import AsyncSession

from ...config import settings

logger = logging.getLogger(__name__)

pool_classes = {
NullPool.__name__: NullPool,
AsyncAdaptedQueuePool.__name__: AsyncAdaptedQueuePool,
}

_db_session_context: ContextVar[str] = ContextVar("db_session_context")


@cache
def get_engine() -> AsyncEngine:
Expand All @@ -35,28 +46,59 @@ def get_engine() -> AsyncEngine:
)


async def get_database_session() -> AsyncGenerator:
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
yield session
@contextmanager
def set_database_session_context(
session_id: str | None = None,
) -> Generator[None, None, None]:
"""
Set session ContextVar, at the end it will be removed.
This context is designed to be used with `async_scoped_session` to define a context scope.
:param session_id:
:return:
"""
_session_id: str = session_id or str(uuid.uuid4())
logger.debug(f"Storing db_session context: {_session_id}")
token = _db_session_context.set(_session_id)
try:
yield
finally:
logger.debug(f"Removing db_session context: {_session_id}")
_db_session_context.reset(token)


def database_session(func):
def _get_database_session_context() -> str:
"""
Decorator that creates a new database session for the given function
Get the database session id from the ContextVar.
Used as a scope function on `async_scoped_session`.
:param func:
:return:
:return: session_id for the current context
"""
return _db_session_context.get()


def db_session_context(func):
"""
Wrap the decorated function in the `set_database_session_context` context.
Decorated function will share the same database session.
Remove the session at the end of the context.
"""

@wraps(func)
async def wrapper(*args, **kwargs):
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
with set_database_session_context():
try:
return await func(*args, **kwargs, session=session)
except Exception as e:
# Rollback errors
await session.rollback()
logging.error(f"Error occurred: {e}")
raise
return await func(*args, **kwargs)
finally:
logger.debug(
f"Removing session context: {_get_database_session_context()}"
)
await db_session.remove()

return wrapper


async_session_factory = async_sessionmaker(get_engine(), expire_on_commit=False)
db_session = async_scoped_session(
session_factory=async_session_factory, scopefunc=_get_database_session_context
)
93 changes: 41 additions & 52 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,30 @@
col,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql._expression_select_cls import SelectBase
from web3.types import ABI

from .database import db_session
from .utils import get_md5_abi_hash


class SqlQueryBase:

@classmethod
async def get_all(cls, session: AsyncSession):
result = await session.exec(select(cls))
return result.all()
async def get_all(cls):
result = await db_session.execute(select(cls))
return result.scalars().all()

async def _save(self, session: AsyncSession):
session.add(self)
await session.commit()
async def _save(self):
db_session.add(self)
await db_session.commit()
return self

async def update(self, session: AsyncSession):
return await self._save(session)
async def update(self):
return await self._save()

async def create(self, session: AsyncSession):
return await self._save(session)
async def create(self):
return await self._save()


class TimeStampedSQLModel(SQLModel):
Expand Down Expand Up @@ -69,33 +69,30 @@ class AbiSource(SqlQueryBase, SQLModel, table=True):
abis: list["Abi"] = Relationship(back_populates="source")

@classmethod
async def get_or_create(
cls, session: AsyncSession, name: str, url: str
) -> tuple["AbiSource", bool]:
async def get_or_create(cls, name: str, url: str) -> tuple["AbiSource", bool]:
"""
Checks if an AbiSource with the given 'name' and 'url' exists.
If found, returns it with False. If not, creates and returns it with True.
:param session: The database session.
:param name: The name to check or create.
:param url: The URL to check or create.
:return: A tuple containing the AbiSource object and a boolean indicating
whether it was created `True` or already exists `False`.
"""
query = select(cls).where(cls.name == name, cls.url == url)
results = await session.exec(query)
if result := results.first():
results = await db_session.execute(query)
if result := results.scalars().first():
return result, False
else:
new_item = cls(name=name, url=url)
await new_item.create(session)
await new_item.create()
return new_item, True

@classmethod
async def get_abi_source(cls, session: AsyncSession, name: str):
async def get_abi_source(cls, name: str):
query = select(cls).where(cls.name == name)
results = await session.exec(query)
if result := results.first():
results = await db_session.execute(query)
if result := results.scalars().first():
return result
return None

Expand All @@ -113,48 +110,45 @@ class Abi(SqlQueryBase, TimeStampedSQLModel, table=True):
contracts: list["Contract"] = Relationship(back_populates="abi")

@classmethod
async def get_abis_sorted_by_relevance(
cls, session: AsyncSession
) -> AsyncIterator[ABI]:
async def get_abis_sorted_by_relevance(cls) -> AsyncIterator[ABI]:
"""
:return: Abi JSON, with the ones with less relevance first
"""
results = await session.exec(select(cls.abi_json).order_by(col(cls.relevance)))
for result in results:
results = await db_session.execute(
select(cls.abi_json).order_by(col(cls.relevance))
)
for result in results.scalars().all():
yield cast(ABI, result)

async def create(self, session):
async def create(self):
self.abi_hash = get_md5_abi_hash(self.abi_json)
return await self._save(session)
return await self._save()

@classmethod
async def get_abi(
cls,
session: AsyncSession,
abi_json: list[dict] | dict,
):
"""
Checks if an Abi exists based on the 'abi_json' by its calculated 'abi_hash'.
If it exists, returns the existing Abi. If not,
returns None.
:param session: The database session.
:param abi_json: The ABI JSON to check.
:return: The Abi object if it exists, or None if it doesn't.
"""
abi_hash = get_md5_abi_hash(abi_json)
query = select(cls).where(cls.abi_hash == abi_hash)
result = await session.exec(query)
result = await db_session.execute(query)

if existing_abi := result.first():
if existing_abi := result.scalars().first():
return existing_abi

return None

@classmethod
async def get_or_create_abi(
cls,
session: AsyncSession,
abi_json: list[dict] | dict,
source_id: int | None,
relevance: int | None = 0,
Expand All @@ -163,18 +157,17 @@ async def get_or_create_abi(
Checks if an Abi with the given 'abi_json' exists.
If found, returns it with False. If not, creates and returns it with True.
:param session: The database session.
:param abi_json: The ABI JSON to check.
:param relevance:
:param source_id:
:return: A tuple containing the Abi object and a boolean indicating
whether it was created `True` or already exists `False`.
"""
if abi := await cls.get_abi(session, abi_json):
if abi := await cls.get_abi(abi_json):
return abi, False
else:
new_item = cls(abi_json=abi_json, relevance=relevance, source_id=source_id)
await new_item.create(session)
await new_item.create()
return new_item, True


Expand Down Expand Up @@ -230,47 +223,45 @@ def get_contracts_with_abi_query(
return query

@classmethod
async def get_contract(cls, session: AsyncSession, address: bytes, chain_id: int):
async def get_contract(cls, address: bytes, chain_id: int):
query = (
select(cls).where(cls.address == address).where(cls.chain_id == chain_id)
)
results = await session.exec(query)
if result := results.first():
results = await db_session.execute(query)
if result := results.scalars().first():
return result
return None

@classmethod
async def get_or_create(
cls,
session: AsyncSession,
address: bytes,
chain_id: int,
**kwargs,
) -> tuple["Contract", bool]:
"""
Update or create the given params.
:param session: The database session.
:param address:
:param chain_id:
:param kwargs:
:return: A tuple containing the Contract object and a boolean indicating
whether it was created `True` or already exists `False`.
"""
if contract := await cls.get_contract(session, address, chain_id):
if contract := await cls.get_contract(address, chain_id):
return contract, False
else:
contract = cls(address=address, chain_id=chain_id)
# Add optional fields
for key, value in kwargs.items():
setattr(contract, key, value)

await contract.create(session)
await contract.create()
return contract, True

@classmethod
async def get_abi_by_contract_address(
cls, session: AsyncSession, address: bytes, chain_id: int | None
cls, address: bytes, chain_id: int | None
) -> ABI | None:
"""
:return: Json ABI given the contract `address` and `chain_id`. If `chain_id` is not given,
Expand All @@ -287,22 +278,21 @@ async def get_abi_by_contract_address(
else:
query = query.order_by(col(cls.chain_id))

results = await session.exec(query)
if result := results.first():
results = await db_session.execute(query)
if result := results.scalars().first():
return cast(ABI, result)
return None

@classmethod
async def get_contracts_without_abi(
cls, session: AsyncSession, max_retries: int = 0
cls, max_retries: int = 0
) -> AsyncIterator[Self]:
"""
Fetches contracts without an ABI and fewer retries than max_retries,
streaming results in batches to reduce memory usage for large datasets.
More information about streaming results can be found here:
https://docs.sqlalchemy.org/en/20/core/connections.html#streaming-with-a-dynamically-growing-buffer-using-stream-results
:param session:
:param max_retries:
:return:
"""
Expand All @@ -311,19 +301,18 @@ async def get_contracts_without_abi(
.where(cls.abi_id == None) # noqa: E711
.where(cls.fetch_retries <= max_retries)
)
result = await session.stream(query)
result = await db_session.stream(query)
async for (contract,) in result:
yield contract

@classmethod
async def get_proxy_contracts(cls, session: AsyncSession) -> AsyncIterator[Self]:
async def get_proxy_contracts(cls) -> AsyncIterator[Self]:
"""
Return all the contracts with implementation address, so proxy contracts.
:param session:
:return:
"""
query = select(cls).where(cls.implementation.isnot(None)) # type: ignore
result = await session.stream(query)
result = await db_session.stream(query)
async for (contract,) in result:
yield contract
Loading

0 comments on commit f9da5b3

Please sign in to comment.