Skip to content

Commit

Permalink
Add model and tests (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 authored Dec 17, 2024
1 parent 9d4aa1f commit bfa0850
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 70 deletions.
81 changes: 76 additions & 5 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,80 @@
from typing import Optional
from sqlmodel import (
JSON,
Column,
Field,
Relationship,
SQLModel,
UniqueConstraint,
select,
)

from sqlmodel import Field, SQLModel

class SqlQueryBase:

class Contract(SQLModel, table=True):
address: bytes = Field(nullable=False, primary_key=True)
@classmethod
async def get_all(cls, session):
result = await session.exec(select(cls))
return result.all()

async def _save(self, session):
session.add(self)
await session.commit()
return self

async def update(self, session):
return await self._save(session)

async def create(self, session):
return await self._save(session)


class AbiSource(SqlQueryBase, SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
name: str = Field(nullable=False)
url: str = Field(nullable=False)

abis: list["Abi"] = Relationship(back_populates="source")


class Abi(SqlQueryBase, SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
abi_hash: bytes = Field(nullable=False, index=True, unique=True)
relevance: int | None = Field(nullable=False, default=0)
abi_json: dict = Field(default_factory=dict, sa_column=Column(JSON))
source_id: int | None = Field(
nullable=True, default=None, foreign_key="abisource.id"
)

source: AbiSource | None = Relationship(back_populates="abis")
contracts: list["Contract"] = Relationship(back_populates="abi")


class Project(SqlQueryBase, SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
description: str = Field(nullable=False)
logo_file: str = Field(nullable=False)
contracts: list["Contract"] = Relationship(back_populates="project")


class Contract(SqlQueryBase, SQLModel, table=True):
__table_args__ = (
UniqueConstraint("address", "chain_id", name="address_chain_unique"),
)

id: int | None = Field(default=None, primary_key=True)
address: bytes = Field(nullable=False, index=True)
name: str = Field(nullable=False)
description: Optional[str] = None
display_name: str | None = None
description: str | None = None
trusted_for_delegate: bool = Field(nullable=False, default=False)
proxy: bool = Field(nullable=False, default=False)
fetch_retries: int = Field(nullable=False, default=0)
abi_id: bytes | None = Field(
nullable=True, default=None, foreign_key="abi.abi_hash"
)
abi: Abi | None = Relationship(back_populates="contracts")
project_id: int | None = Field(
nullable=True, default=None, foreign_key="project.id"
)
project: Project | None = Relationship(back_populates="contracts")
chain_id: int = Field(default=None)
17 changes: 1 addition & 16 deletions app/services/contract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Sequence

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.models import Contract
Expand All @@ -16,18 +15,4 @@ async def get_all(session: AsyncSession) -> Sequence[Contract]:
:param session: passed by the decorator
:return:
"""
result = await session.exec(select(Contract))
return result.all()

@staticmethod
async def create(contract: Contract, session: AsyncSession) -> Contract:
"""
Create a new contract
:param contract:
:param session:
:return:
"""
session.add(contract)
await session.commit()
return contract
return await Contract.get_all(session)
44 changes: 36 additions & 8 deletions app/tests/db/test_model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,45 @@
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
from app.datasources.db.models import Contract
from app.datasources.db.models import Abi, AbiSource, Contract, Project
from app.tests.db.db_async_conn import DbAsyncConn


class TestModel(DbAsyncConn):
@database_session
async def test_contract(self, session: AsyncSession):
contract = Contract(address=b"a", name="A Test Contracts")
session.add(contract)
await session.commit()
statement = select(Contract).where(Contract.address == b"a")
result = await session.exec(statement)
self.assertEqual(result.one(), contract)
contract = Contract(address=b"a", name="A test contract", chain_id=1)
await contract.create(session)
await contract.create(session)
result = await contract.get_all(session)
self.assertEqual(result[0], contract)

@database_session
async def test_project(self, session: AsyncSession):
project = Project(description="A Test Project", logo_file="logo.jpg")
await project.create(session)
result = await project.get_all(session)
self.assertEqual(result[0], project)

@database_session
async def test_abi(self, session: AsyncSession):
abi = Abi(abi_hash=b"A Test Abi", abi_json={"name": "A Test Project"})
await abi.create(session)
result = await abi.get_all(session)
self.assertEqual(result[0], abi)

@database_session
async def test_abi_source(self, session: AsyncSession):
abi_source = AbiSource(name="A Test Source", url="https://test.com")
await abi_source.create(session)
result = await abi_source.get_all(session)
self.assertEqual(result[0], abi_source)
abi = Abi(
abi_hash=b"A Test Abi",
abi_json={"name": "A Test Project"},
source_id=abi_source.id,
)
await abi.create(session)
result = await abi.get_all(session)
self.assertEqual(result[0], abi)
self.assertEqual(result[0].source, abi_source)
6 changes: 2 additions & 4 deletions app/tests/routers/test_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ...datasources.db.database import database_session
from ...datasources.db.models import Contract
from ...main import app
from ...services.contract import ContractService
from ..db.db_async_conn import DbAsyncConn


Expand All @@ -18,13 +17,12 @@ def setUpClass(cls):

@database_session
async def test_view_contracts(self, session: AsyncSession):
contract = Contract(address=b"a", name="A Test Contracts")
contract = Contract(address=b"a", name="A Test Contracts", chain_id=1)
expected_response = {
"name": "A Test Contracts",
"description": None,
"address": "a",
}
await ContractService.create(contract=contract, session=session)
await contract.create(session)
response = self.client.get("/api/v1/contracts")
self.assertEqual(response.status_code, 200)
self.assertDictEqual(response.json()[0], expected_response)
88 changes: 88 additions & 0 deletions migrations/versions/9912fd3fc9ce_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""init
Revision ID: 9912fd3fc9ce
Revises:
Create Date: 2024-12-13 11:23:10.023773
"""

from typing import Sequence, Union

import sqlalchemy as sa
import sqlmodel
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "9912fd3fc9ce"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"abisource",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("url", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"project",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("logo_file", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"abi",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("abi_hash", sa.LargeBinary(), nullable=False),
sa.Column("relevance", sa.Integer(), nullable=False),
sa.Column("abi_json", sa.JSON(), nullable=True),
sa.Column("source_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["source_id"],
["abisource.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_abi_abi_hash"), "abi", ["abi_hash"], unique=True)
op.create_table(
"contract",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("address", sa.LargeBinary(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("trusted_for_delegate", sa.Boolean(), nullable=False),
sa.Column("proxy", sa.Boolean(), nullable=False),
sa.Column("fetch_retries", sa.Integer(), nullable=False),
sa.Column("abi_id", sa.LargeBinary(), nullable=True),
sa.Column("project_id", sa.Integer(), nullable=True),
sa.Column("chain_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["abi_id"],
["abi.abi_hash"],
),
sa.ForeignKeyConstraint(
["project_id"],
["project.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("address", "chain_id", name="address_chain_unique"),
)
op.create_index(op.f("ix_contract_address"), "contract", ["address"], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f("ix_contract_address"), table_name="contract")
op.drop_table("contract")
op.drop_index(op.f("ix_abi_abi_hash"), table_name="abi")
op.drop_table("abi")
op.drop_table("project")
op.drop_table("abisource")
# ### end Alembic commands ###
37 changes: 0 additions & 37 deletions migrations/versions/d0c5d72aa50b_init.py

This file was deleted.

0 comments on commit bfa0850

Please sign in to comment.