Skip to content

Commit

Permalink
Fix database issues (#66)
Browse files Browse the repository at this point in the history
* Remove wrong database_session

* Add timezone to datetime fields
  • Loading branch information
falvaradorodriguez authored Jan 16, 2025
1 parent 398a226 commit b20ac18
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 11 deletions.
10 changes: 6 additions & 4 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, timezone
from typing import AsyncIterator, Self, cast

from sqlalchemy import Row
from sqlalchemy import DateTime, Row
from sqlmodel import (
JSON,
Column,
Expand Down Expand Up @@ -46,15 +46,17 @@ class TimeStampedSQLModel(SQLModel):
"""

created: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
default_factory=lambda: datetime.now(timezone.utc),
nullable=False,
sa_type=DateTime(timezone=True), # type: ignore
)

modified: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
default_factory=lambda: datetime.now(timezone.utc),
nullable=False,
sa_type=DateTime(timezone=True), # type: ignore
sa_column_kwargs={
"onupdate": lambda: datetime.now(timezone.utc).replace(tzinfo=None),
"onupdate": lambda: datetime.now(timezone.utc),
},
)

Expand Down
6 changes: 5 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from fastapi import APIRouter, FastAPI

from sqlmodel.ext.asyncio.session import AsyncSession

from . import VERSION
from .datasources.db.database import get_engine
from .datasources.queue.exceptions import QueueProviderUnableToConnectException
from .datasources.queue.queue_provider import QueueProvider
from .routers import about, admin, contracts, default
Expand Down Expand Up @@ -33,7 +36,8 @@ async def lifespan(app: FastAPI):
consume_task = asyncio.create_task(
queue_provider.consume(events_service.process_event)
)
await abi_service.load_local_abis_in_database()
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
await abi_service.load_local_abis_in_database(session)
yield
finally:
if consume_task:
Expand Down
2 changes: 0 additions & 2 deletions app/services/abis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
)
from app.datasources.abis.snapshot import snapshot_delegate_registry_abi
from app.datasources.abis.timelock import timelock_abi
from app.datasources.db.database import database_session
from app.datasources.db.models import Abi, AbiSource


Expand All @@ -81,7 +80,6 @@ async def _store_abis_in_database(
abi_json=abi_json, source_id=abi_source.id, relevance=relevance
).create(session)

@database_session
async def load_local_abis_in_database(self, session: AsyncSession) -> None:
abi_source, _ = await AbiSource.get_or_create(
session, "localstorage", "decoder-service"
Expand Down
5 changes: 3 additions & 2 deletions app/tests/routers/test_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ...datasources.db.database import database_session
from ...datasources.db.models import Abi, AbiSource, Contract
from ...main import app
from ...utils import datetime_to_str
from ..datasources.db.db_async_conn import DbAsyncConn
from ..mocks.abi_mock import mock_abi_json

Expand Down Expand Up @@ -42,11 +43,11 @@ async def test_view_contracts(self, session: AsyncSession):
self.assertEqual(results[0]["address"], address_expected)
self.assertEqual(results[0]["abi"]["abi_json"], mock_abi_json)
self.assertEqual(results[0]["abi"]["abi_hash"], "0xb4b61541")
self.assertEqual(results[0]["abi"]["modified"], abi.modified.isoformat())
self.assertEqual(results[0]["abi"]["modified"], datetime_to_str(abi.modified))
self.assertEqual(results[0]["display_name"], None)
self.assertEqual(results[0]["chain_id"], 1)
self.assertEqual(results[0]["project"], None)
self.assertEqual(results[0]["modified"], contract.modified.isoformat())
self.assertEqual(results[0]["modified"], datetime_to_str(contract.modified))
# Test filter by chain_id
contract = Contract(
address=address, name="A Test Contracts", chain_id=5, abi=abi
Expand Down
4 changes: 2 additions & 2 deletions app/tests/services/test_abis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def test_load_local_abis_in_database(self, session: AsyncSession):
self.assertEqual(await AbiSource.get_all(session), [])
self.assertEqual(await Abi.get_all(session), [])

await self.abi_service.load_local_abis_in_database()
await self.abi_service.load_local_abis_in_database(session)
self.assertEqual(len(await AbiSource.get_all(session)), 1)
abis = await Abi.get_all(session)
self.assertEqual(len(abis), 152)
Expand All @@ -28,7 +28,7 @@ async def test_load_local_abis_in_database(self, session: AsyncSession):
self.assertEqual(relevance_counts[90], 5)
self.assertEqual(relevance_counts[50], 142)

await self.abi_service.load_local_abis_in_database()
await self.abi_service.load_local_abis_in_database(session)
self.assertEqual(len(await Abi.get_all(session)), 152)

def test_get_safe_contracts_abis(self):
Expand Down
9 changes: 9 additions & 0 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from datetime import datetime


def datetime_to_str(value: datetime) -> str:
"""
:param value: `datetime.datetime` value
:return: ``ISO 8601`` date with ``Z`` format
"""
return value.isoformat().replace("+00:00", "Z")
85 changes: 85 additions & 0 deletions migrations/versions/e5e94343c151_add_timezone_datetime_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""add_timezone_datetime_fields
Revision ID: e5e94343c151
Revises: 8d2eacb17707
Create Date: 2025-01-16 10:14:21.505781
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "e5e94343c151"
down_revision: Union[str, None] = "8d2eacb17707"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"abi",
"created",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
op.alter_column(
"abi",
"modified",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
op.alter_column(
"contract",
"created",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
op.alter_column(
"contract",
"modified",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"contract",
"modified",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
op.alter_column(
"contract",
"created",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
op.alter_column(
"abi",
"modified",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
op.alter_column(
"abi",
"created",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
# ### end Alembic commands ###

0 comments on commit b20ac18

Please sign in to comment.