Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix database issues #66

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, timezone
from typing import AsyncIterator, Self, cast

from sqlalchemy import Row
from sqlalchemy import DateTime, Row
from sqlmodel import (
JSON,
Column,
Expand Down Expand Up @@ -46,15 +46,17 @@ class TimeStampedSQLModel(SQLModel):
"""

created: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
default_factory=lambda: datetime.now(timezone.utc),
nullable=False,
sa_type=DateTime(timezone=True), # type: ignore
)

modified: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
default_factory=lambda: datetime.now(timezone.utc),
nullable=False,
sa_type=DateTime(timezone=True), # type: ignore
sa_column_kwargs={
"onupdate": lambda: datetime.now(timezone.utc).replace(tzinfo=None),
"onupdate": lambda: datetime.now(timezone.utc),
},
)

Expand Down
6 changes: 5 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

from fastapi import APIRouter, FastAPI

from sqlmodel.ext.asyncio.session import AsyncSession

from . import VERSION
from .datasources.db.database import get_engine
from .datasources.queue.exceptions import QueueProviderUnableToConnectException
from .datasources.queue.queue_provider import QueueProvider
from .routers import about, admin, contracts, default
Expand Down Expand Up @@ -33,7 +36,8 @@ async def lifespan(app: FastAPI):
consume_task = asyncio.create_task(
queue_provider.consume(events_service.process_event)
)
await abi_service.load_local_abis_in_database()
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
await abi_service.load_local_abis_in_database(session)
yield
finally:
if consume_task:
Expand Down
2 changes: 0 additions & 2 deletions app/services/abis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
)
from app.datasources.abis.snapshot import snapshot_delegate_registry_abi
from app.datasources.abis.timelock import timelock_abi
from app.datasources.db.database import database_session
from app.datasources.db.models import Abi, AbiSource


Expand All @@ -81,7 +80,6 @@ async def _store_abis_in_database(
abi_json=abi_json, source_id=abi_source.id, relevance=relevance
).create(session)

@database_session
async def load_local_abis_in_database(self, session: AsyncSession) -> None:
abi_source, _ = await AbiSource.get_or_create(
session, "localstorage", "decoder-service"
Expand Down
5 changes: 3 additions & 2 deletions app/tests/routers/test_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ...datasources.db.database import database_session
from ...datasources.db.models import Abi, AbiSource, Contract
from ...main import app
from ...utils import datetime_to_str
from ..datasources.db.db_async_conn import DbAsyncConn
from ..mocks.abi_mock import mock_abi_json

Expand Down Expand Up @@ -42,11 +43,11 @@ async def test_view_contracts(self, session: AsyncSession):
self.assertEqual(results[0]["address"], address_expected)
self.assertEqual(results[0]["abi"]["abi_json"], mock_abi_json)
self.assertEqual(results[0]["abi"]["abi_hash"], "0xb4b61541")
self.assertEqual(results[0]["abi"]["modified"], abi.modified.isoformat())
self.assertEqual(results[0]["abi"]["modified"], datetime_to_str(abi.modified))
self.assertEqual(results[0]["display_name"], None)
self.assertEqual(results[0]["chain_id"], 1)
self.assertEqual(results[0]["project"], None)
self.assertEqual(results[0]["modified"], contract.modified.isoformat())
self.assertEqual(results[0]["modified"], datetime_to_str(contract.modified))
# Test filter by chain_id
contract = Contract(
address=address, name="A Test Contracts", chain_id=5, abi=abi
Expand Down
4 changes: 2 additions & 2 deletions app/tests/services/test_abis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def test_load_local_abis_in_database(self, session: AsyncSession):
self.assertEqual(await AbiSource.get_all(session), [])
self.assertEqual(await Abi.get_all(session), [])

await self.abi_service.load_local_abis_in_database()
await self.abi_service.load_local_abis_in_database(session)
self.assertEqual(len(await AbiSource.get_all(session)), 1)
abis = await Abi.get_all(session)
self.assertEqual(len(abis), 152)
Expand All @@ -28,7 +28,7 @@ async def test_load_local_abis_in_database(self, session: AsyncSession):
self.assertEqual(relevance_counts[90], 5)
self.assertEqual(relevance_counts[50], 142)

await self.abi_service.load_local_abis_in_database()
await self.abi_service.load_local_abis_in_database(session)
self.assertEqual(len(await Abi.get_all(session)), 152)

def test_get_safe_contracts_abis(self):
Expand Down
9 changes: 9 additions & 0 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from datetime import datetime


def datetime_to_str(value: datetime) -> str:
"""
:param value: `datetime.datetime` value
:return: ``ISO 8601`` date with ``Z`` format
"""
return value.isoformat().replace("+00:00", "Z")
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""add_timezone_datetime_fields
Revision ID: e5e94343c151
Revises: 8d2eacb17707
Create Date: 2025-01-16 10:14:21.505781
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "e5e94343c151"
down_revision: Union[str, None] = "8d2eacb17707"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"abi",
"created",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
op.alter_column(
"abi",
"modified",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
op.alter_column(
"contract",
"created",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
op.alter_column(
"contract",
"modified",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"contract",
"modified",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
op.alter_column(
"contract",
"created",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
op.alter_column(
"abi",
"modified",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
op.alter_column(
"abi",
"created",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
# ### end Alembic commands ###
Loading