diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 7570e8e..7b0195b 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from typing import AsyncIterator, Self, cast -from sqlalchemy import Row +from sqlalchemy import DateTime, Row from sqlmodel import ( JSON, Column, @@ -46,15 +46,17 @@ class TimeStampedSQLModel(SQLModel): """ created: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + default_factory=lambda: datetime.now(timezone.utc), nullable=False, + sa_type=DateTime(timezone=True), # type: ignore ) modified: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + default_factory=lambda: datetime.now(timezone.utc), nullable=False, + sa_type=DateTime(timezone=True), # type: ignore sa_column_kwargs={ - "onupdate": lambda: datetime.now(timezone.utc).replace(tzinfo=None), + "onupdate": lambda: datetime.now(timezone.utc), }, ) diff --git a/app/main.py b/app/main.py index 15a5150..9d9c6c4 100644 --- a/app/main.py +++ b/app/main.py @@ -4,7 +4,10 @@ from fastapi import APIRouter, FastAPI +from sqlmodel.ext.asyncio.session import AsyncSession + from . import VERSION +from .datasources.db.database import get_engine from .datasources.queue.exceptions import QueueProviderUnableToConnectException from .datasources.queue.queue_provider import QueueProvider from .routers import about, admin, contracts, default @@ -33,7 +36,8 @@ async def lifespan(app: FastAPI): consume_task = asyncio.create_task( queue_provider.consume(events_service.process_event) ) - await abi_service.load_local_abis_in_database() + async with AsyncSession(get_engine(), expire_on_commit=False) as session: + await abi_service.load_local_abis_in_database(session) yield finally: if consume_task: diff --git a/app/services/abis.py b/app/services/abis.py index 6b90188..fd127ac 100644 --- a/app/services/abis.py +++ b/app/services/abis.py @@ -58,7 +58,6 @@ ) from app.datasources.abis.snapshot import snapshot_delegate_registry_abi from app.datasources.abis.timelock import timelock_abi -from app.datasources.db.database import database_session from app.datasources.db.models import Abi, AbiSource @@ -81,7 +80,6 @@ async def _store_abis_in_database( abi_json=abi_json, source_id=abi_source.id, relevance=relevance ).create(session) - @database_session async def load_local_abis_in_database(self, session: AsyncSession) -> None: abi_source, _ = await AbiSource.get_or_create( session, "localstorage", "decoder-service" diff --git a/app/tests/routers/test_contracts.py b/app/tests/routers/test_contracts.py index fca7195..bfbaa06 100644 --- a/app/tests/routers/test_contracts.py +++ b/app/tests/routers/test_contracts.py @@ -6,6 +6,7 @@ from ...datasources.db.database import database_session from ...datasources.db.models import Abi, AbiSource, Contract from ...main import app +from ...utils import datetime_to_str from ..datasources.db.db_async_conn import DbAsyncConn from ..mocks.abi_mock import mock_abi_json @@ -42,11 +43,11 @@ async def test_view_contracts(self, session: AsyncSession): self.assertEqual(results[0]["address"], address_expected) self.assertEqual(results[0]["abi"]["abi_json"], mock_abi_json) self.assertEqual(results[0]["abi"]["abi_hash"], "0xb4b61541") - self.assertEqual(results[0]["abi"]["modified"], abi.modified.isoformat()) + self.assertEqual(results[0]["abi"]["modified"], datetime_to_str(abi.modified)) self.assertEqual(results[0]["display_name"], None) self.assertEqual(results[0]["chain_id"], 1) self.assertEqual(results[0]["project"], None) - self.assertEqual(results[0]["modified"], contract.modified.isoformat()) + self.assertEqual(results[0]["modified"], datetime_to_str(contract.modified)) # Test filter by chain_id contract = Contract( address=address, name="A Test Contracts", chain_id=5, abi=abi diff --git a/app/tests/services/test_abis.py b/app/tests/services/test_abis.py index 7b987b5..3296204 100644 --- a/app/tests/services/test_abis.py +++ b/app/tests/services/test_abis.py @@ -19,7 +19,7 @@ async def test_load_local_abis_in_database(self, session: AsyncSession): self.assertEqual(await AbiSource.get_all(session), []) self.assertEqual(await Abi.get_all(session), []) - await self.abi_service.load_local_abis_in_database() + await self.abi_service.load_local_abis_in_database(session) self.assertEqual(len(await AbiSource.get_all(session)), 1) abis = await Abi.get_all(session) self.assertEqual(len(abis), 152) @@ -28,7 +28,7 @@ async def test_load_local_abis_in_database(self, session: AsyncSession): self.assertEqual(relevance_counts[90], 5) self.assertEqual(relevance_counts[50], 142) - await self.abi_service.load_local_abis_in_database() + await self.abi_service.load_local_abis_in_database(session) self.assertEqual(len(await Abi.get_all(session)), 152) def test_get_safe_contracts_abis(self): diff --git a/app/utils.py b/app/utils.py new file mode 100644 index 0000000..693c34b --- /dev/null +++ b/app/utils.py @@ -0,0 +1,9 @@ +from datetime import datetime + + +def datetime_to_str(value: datetime) -> str: + """ + :param value: `datetime.datetime` value + :return: ``ISO 8601`` date with ``Z`` format + """ + return value.isoformat().replace("+00:00", "Z") diff --git a/migrations/versions/e5e94343c151_add_timezone_datetime_fields.py b/migrations/versions/e5e94343c151_add_timezone_datetime_fields.py new file mode 100644 index 0000000..c356593 --- /dev/null +++ b/migrations/versions/e5e94343c151_add_timezone_datetime_fields.py @@ -0,0 +1,85 @@ +"""add_timezone_datetime_fields + +Revision ID: e5e94343c151 +Revises: 8d2eacb17707 +Create Date: 2025-01-16 10:14:21.505781 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "e5e94343c151" +down_revision: Union[str, None] = "8d2eacb17707" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "abi", + "created", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + ) + op.alter_column( + "abi", + "modified", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + ) + op.alter_column( + "contract", + "created", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + ) + op.alter_column( + "contract", + "modified", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "contract", + "modified", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + ) + op.alter_column( + "contract", + "created", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + ) + op.alter_column( + "abi", + "modified", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + ) + op.alter_column( + "abi", + "created", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + ) + # ### end Alembic commands ###