Skip to content

Commit

Permalink
Move database queries to db datasource
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 committed Dec 11, 2024
1 parent b825684 commit ea2bc18
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 156 deletions.
34 changes: 24 additions & 10 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,22 @@
from sqlmodel import JSON, Column, Field, SQLModel, UniqueConstraint
from sqlmodel import JSON, Column, Field, SQLModel, UniqueConstraint, select


class SqlQueryBase:
@classmethod
async def get_all(cls, session):
result = await session.exec(select(cls))
return result.all()

async def _save(self, session):
session.add(self)
await session.commit()
return self

async def update(self, session):
return await self._save(session)

async def create(self, session):
return await self._save(session)


class AbiSource(SQLModel, table=True):
Expand All @@ -7,24 +25,20 @@ class AbiSource(SQLModel, table=True):
url: str = Field(nullable=False)


class Abi(SQLModel, table=True):
class Abi(SqlQueryBase, SQLModel, table=True):
abi_hash: bytes = Field(nullable=False, primary_key=True)
relevance: int = Field(nullable=False, default=0)
abi_json: dict = Field(default_factory=dict, sa_column=Column(JSON))
source_id: int = Field(default=None, foreign_key="abisource.id")


class Chain(SQLModel, table=True):
id: int = Field(primary_key=True) # Chain ID
name: str = Field(nullable=False)


class Contract(SQLModel, table=True):
class Contract(SqlQueryBase, SQLModel, table=True):
__table_args__ = (
UniqueConstraint("address", "chain_id", name="address_chain_unique"),
)

address: bytes = Field(nullable=False, primary_key=True)
id: int | None = Field(default=None, primary_key=True)
address: bytes = Field(nullable=False)
name: str = Field(nullable=False)
display_name: str | None = None
description: str | None = None
Expand All @@ -34,4 +48,4 @@ class Contract(SQLModel, table=True):
abi_id: bytes | None = Field(
nullable=True, default=None, foreign_key="abi.abi_hash"
)
chain_id: int = Field(default=None, foreign_key="chain.id")
chain_id: int = Field(default=None)
36 changes: 0 additions & 36 deletions app/services/chain.py

This file was deleted.

17 changes: 1 addition & 16 deletions app/services/contract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Sequence

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.models import Contract
Expand All @@ -16,18 +15,4 @@ async def get_all(session: AsyncSession) -> Sequence[Contract]:
:param session: passed by the decorator
:return:
"""
result = await session.exec(select(Contract))
return result.all()

@staticmethod
async def create(contract: Contract, session: AsyncSession) -> Contract:
"""
Create a new contract
:param contract:
:param session:
:return:
"""
session.add(contract)
await session.commit()
return contract
return await Contract.get_all(session)
9 changes: 3 additions & 6 deletions app/tests/db/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
from app.datasources.db.models import Chain, Contract
from app.datasources.db.models import Contract
from app.tests.db.db_async_conn import DbAsyncConn


class TestModel(DbAsyncConn):
@database_session
async def test_contract(self, session: AsyncSession):
chain = Chain(id=1, name="mainnet")
session.add(chain)
contract = Contract(address=b"a", name="A Test Contracts", chain_id=chain.id)
session.add(contract)
await session.commit()
contract = Contract(address=b"a", name="A Test Contracts", chain_id=1)
await contract.create(session)
statement = select(Contract).where(Contract.address == b"a")
result = await session.exec(statement)
self.assertEqual(result.one(), contract)
11 changes: 3 additions & 8 deletions app/tests/routers/test_contracts.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from fastapi.testclient import TestClient

from sqlmodel.ext.asyncio.session import AsyncSession

from ...datasources.db.database import database_session
from ...datasources.db.models import Chain, Contract
from ...datasources.db.models import Contract
from ...main import app
from ...services.chain import ChainService
from ...services.contract import ContractService
from ..db.db_async_conn import DbAsyncConn


Expand All @@ -16,18 +15,14 @@ class TestRouterContract(DbAsyncConn):
def setUpClass(cls):
cls.client = TestClient(app)


@database_session
async def test_view_contracts(self, session: AsyncSession):
chain = Chain(id=1, name="mainnet")
await ChainService.create(chain)
contract = Contract(address=b"a", name="A Test Contracts", chain_id=1)
expected_response = {
"name": "A Test Contracts",
"description": None,
"address": "a",
}
await ContractService.create(contract=contract, session=session)
await contract.create(session)
response = self.client.get("/api/v1/contracts")
self.assertEqual(response.status_code, 200)
self.assertDictEqual(response.json()[0], expected_response)
80 changes: 0 additions & 80 deletions migrations/versions/c52699d26409_init.py

This file was deleted.

0 comments on commit ea2bc18

Please sign in to comment.