diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 410369a..a0db6e7 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -1,4 +1,22 @@ -from sqlmodel import JSON, Column, Field, SQLModel, UniqueConstraint +from sqlmodel import JSON, Column, Field, SQLModel, UniqueConstraint, select + + +class SqlQueryBase: + @classmethod + async def get_all(cls, session): + result = await session.exec(select(cls)) + return result.all() + + async def _save(self, session): + session.add(self) + await session.commit() + return self + + async def update(self, session): + return await self._save(session) + + async def create(self, session): + return await self._save(session) class AbiSource(SQLModel, table=True): @@ -7,24 +25,20 @@ class AbiSource(SQLModel, table=True): url: str = Field(nullable=False) -class Abi(SQLModel, table=True): +class Abi(SqlQueryBase, SQLModel, table=True): abi_hash: bytes = Field(nullable=False, primary_key=True) relevance: int = Field(nullable=False, default=0) abi_json: dict = Field(default_factory=dict, sa_column=Column(JSON)) source_id: int = Field(default=None, foreign_key="abisource.id") -class Chain(SQLModel, table=True): - id: int = Field(primary_key=True) # Chain ID - name: str = Field(nullable=False) - - -class Contract(SQLModel, table=True): +class Contract(SqlQueryBase, SQLModel, table=True): __table_args__ = ( UniqueConstraint("address", "chain_id", name="address_chain_unique"), ) - address: bytes = Field(nullable=False, primary_key=True) + id: int | None = Field(default=None, primary_key=True) + address: bytes = Field(nullable=False) name: str = Field(nullable=False) display_name: str | None = None description: str | None = None @@ -34,4 +48,4 @@ class Contract(SQLModel, table=True): abi_id: bytes | None = Field( nullable=True, default=None, foreign_key="abi.abi_hash" ) - chain_id: int = Field(default=None, foreign_key="chain.id") + chain_id: int = Field(default=None) diff --git a/app/services/chain.py b/app/services/chain.py deleted file mode 100644 index 38b485e..0000000 --- a/app/services/chain.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Sequence - -from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession - -from app.datasources.db.database import get_database_session -from app.datasources.db.models import Chain - - -class ChainService: - - @staticmethod - @get_database_session - async def get_all(session: AsyncSession) -> Sequence[Chain]: - """ - Get all chains - - :param session: passed by the decorator - :return: - """ - result = await session.exec(select(Chain)) - return result.all() - - @staticmethod - @get_database_session - async def create(chain: Chain, session: AsyncSession) -> Chain: - """ - Create a new chain - - :param chain: - :param session: - :return: - """ - session.add(chain) - await session.commit() - return chain diff --git a/app/services/contract.py b/app/services/contract.py index 41cd588..16a5c59 100644 --- a/app/services/contract.py +++ b/app/services/contract.py @@ -1,6 +1,5 @@ from typing import Sequence -from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.models import Contract @@ -16,18 +15,4 @@ async def get_all(session: AsyncSession) -> Sequence[Contract]: :param session: passed by the decorator :return: """ - result = await session.exec(select(Contract)) - return result.all() - - @staticmethod - async def create(contract: Contract, session: AsyncSession) -> Contract: - """ - Create a new contract - - :param contract: - :param session: - :return: - """ - session.add(contract) - await session.commit() - return contract + return await Contract.get_all(session) diff --git a/app/tests/db/test_model.py b/app/tests/db/test_model.py index 3a2f653..1206f0e 100644 --- a/app/tests/db/test_model.py +++ b/app/tests/db/test_model.py @@ -2,18 +2,15 @@ from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.database import database_session -from app.datasources.db.models import Chain, Contract +from app.datasources.db.models import Contract from app.tests.db.db_async_conn import DbAsyncConn class TestModel(DbAsyncConn): @database_session async def test_contract(self, session: AsyncSession): - chain = Chain(id=1, name="mainnet") - session.add(chain) - contract = Contract(address=b"a", name="A Test Contracts", chain_id=chain.id) - session.add(contract) - await session.commit() + contract = Contract(address=b"a", name="A Test Contracts", chain_id=1) + await contract.create(session) statement = select(Contract).where(Contract.address == b"a") result = await session.exec(statement) self.assertEqual(result.one(), contract) diff --git a/app/tests/routers/test_contracts.py b/app/tests/routers/test_contracts.py index 5bb2961..fa5a911 100644 --- a/app/tests/routers/test_contracts.py +++ b/app/tests/routers/test_contracts.py @@ -1,11 +1,10 @@ from fastapi.testclient import TestClient from sqlmodel.ext.asyncio.session import AsyncSession + from ...datasources.db.database import database_session -from ...datasources.db.models import Chain, Contract +from ...datasources.db.models import Contract from ...main import app -from ...services.chain import ChainService -from ...services.contract import ContractService from ..db.db_async_conn import DbAsyncConn @@ -16,18 +15,14 @@ class TestRouterContract(DbAsyncConn): def setUpClass(cls): cls.client = TestClient(app) - @database_session async def test_view_contracts(self, session: AsyncSession): - chain = Chain(id=1, name="mainnet") - await ChainService.create(chain) contract = Contract(address=b"a", name="A Test Contracts", chain_id=1) expected_response = { "name": "A Test Contracts", "description": None, "address": "a", } - await ContractService.create(contract=contract, session=session) + await contract.create(session) response = self.client.get("/api/v1/contracts") self.assertEqual(response.status_code, 200) - self.assertDictEqual(response.json()[0], expected_response) diff --git a/migrations/versions/c52699d26409_init.py b/migrations/versions/c52699d26409_init.py deleted file mode 100644 index 1f956f3..0000000 --- a/migrations/versions/c52699d26409_init.py +++ /dev/null @@ -1,80 +0,0 @@ -"""init - -Revision ID: c52699d26409 -Revises: -Create Date: 2024-12-11 10:58:44.921603 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -import sqlmodel -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "c52699d26409" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "abisource", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column("url", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "chain", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) - op.create_table( - "abi", - sa.Column("abi_hash", sa.LargeBinary(), nullable=False), - sa.Column("relevance", sa.Integer(), nullable=False), - sa.Column("abi_json", sa.JSON(), nullable=True), - sa.Column("source_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["source_id"], - ["abisource.id"], - ), - sa.PrimaryKeyConstraint("abi_hash"), - ) - op.create_table( - "contract", - sa.Column("address", sa.LargeBinary(), nullable=False), - sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column("trusted_for_delegate", sa.Boolean(), nullable=False), - sa.Column("proxy", sa.Boolean(), nullable=False), - sa.Column("fetch_retries", sa.Integer(), nullable=False), - sa.Column("abi_id", sa.LargeBinary(), nullable=True), - sa.Column("chain_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["abi_id"], - ["abi.abi_hash"], - ), - sa.ForeignKeyConstraint( - ["chain_id"], - ["chain.id"], - ), - sa.PrimaryKeyConstraint("address"), - sa.UniqueConstraint("address", "chain_id", name="address_chain_unique"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("contract") - op.drop_table("abi") - op.drop_table("chain") - op.drop_table("abisource") - # ### end Alembic commands ###