-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
203 additions
and
70 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,80 @@ | ||
from typing import Optional | ||
from sqlmodel import ( | ||
JSON, | ||
Column, | ||
Field, | ||
Relationship, | ||
SQLModel, | ||
UniqueConstraint, | ||
select, | ||
) | ||
|
||
from sqlmodel import Field, SQLModel | ||
|
||
class SqlQueryBase: | ||
|
||
class Contract(SQLModel, table=True): | ||
address: bytes = Field(nullable=False, primary_key=True) | ||
@classmethod | ||
async def get_all(cls, session): | ||
result = await session.exec(select(cls)) | ||
return result.all() | ||
|
||
async def _save(self, session): | ||
session.add(self) | ||
await session.commit() | ||
return self | ||
|
||
async def update(self, session): | ||
return await self._save(session) | ||
|
||
async def create(self, session): | ||
return await self._save(session) | ||
|
||
|
||
class AbiSource(SqlQueryBase, SQLModel, table=True): | ||
id: int | None = Field(default=None, primary_key=True) | ||
name: str = Field(nullable=False) | ||
url: str = Field(nullable=False) | ||
|
||
abis: list["Abi"] = Relationship(back_populates="source") | ||
|
||
|
||
class Abi(SqlQueryBase, SQLModel, table=True): | ||
id: int | None = Field(default=None, primary_key=True) | ||
abi_hash: bytes = Field(nullable=False, index=True, unique=True) | ||
relevance: int | None = Field(nullable=False, default=0) | ||
abi_json: dict = Field(default_factory=dict, sa_column=Column(JSON)) | ||
source_id: int | None = Field( | ||
nullable=True, default=None, foreign_key="abisource.id" | ||
) | ||
|
||
source: AbiSource | None = Relationship(back_populates="abis") | ||
contracts: list["Contract"] = Relationship(back_populates="abi") | ||
|
||
|
||
class Project(SqlQueryBase, SQLModel, table=True): | ||
id: int | None = Field(default=None, primary_key=True) | ||
description: str = Field(nullable=False) | ||
logo_file: str = Field(nullable=False) | ||
contracts: list["Contract"] = Relationship(back_populates="project") | ||
|
||
|
||
class Contract(SqlQueryBase, SQLModel, table=True): | ||
__table_args__ = ( | ||
UniqueConstraint("address", "chain_id", name="address_chain_unique"), | ||
) | ||
|
||
id: int | None = Field(default=None, primary_key=True) | ||
address: bytes = Field(nullable=False, index=True) | ||
name: str = Field(nullable=False) | ||
description: Optional[str] = None | ||
display_name: str | None = None | ||
description: str | None = None | ||
trusted_for_delegate: bool = Field(nullable=False, default=False) | ||
proxy: bool = Field(nullable=False, default=False) | ||
fetch_retries: int = Field(nullable=False, default=0) | ||
abi_id: bytes | None = Field( | ||
nullable=True, default=None, foreign_key="abi.abi_hash" | ||
) | ||
abi: Abi | None = Relationship(back_populates="contracts") | ||
project_id: int | None = Field( | ||
nullable=True, default=None, foreign_key="project.id" | ||
) | ||
project: Project | None = Relationship(back_populates="contracts") | ||
chain_id: int = Field(default=None) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,45 @@ | ||
from sqlmodel import select | ||
from sqlmodel.ext.asyncio.session import AsyncSession | ||
|
||
from app.datasources.db.database import database_session | ||
from app.datasources.db.models import Contract | ||
from app.datasources.db.models import Abi, AbiSource, Contract, Project | ||
from app.tests.db.db_async_conn import DbAsyncConn | ||
|
||
|
||
class TestModel(DbAsyncConn): | ||
@database_session | ||
async def test_contract(self, session: AsyncSession): | ||
contract = Contract(address=b"a", name="A Test Contracts") | ||
session.add(contract) | ||
await session.commit() | ||
statement = select(Contract).where(Contract.address == b"a") | ||
result = await session.exec(statement) | ||
self.assertEqual(result.one(), contract) | ||
contract = Contract(address=b"a", name="A test contract", chain_id=1) | ||
await contract.create(session) | ||
await contract.create(session) | ||
result = await contract.get_all(session) | ||
self.assertEqual(result[0], contract) | ||
|
||
@database_session | ||
async def test_project(self, session: AsyncSession): | ||
project = Project(description="A Test Project", logo_file="logo.jpg") | ||
await project.create(session) | ||
result = await project.get_all(session) | ||
self.assertEqual(result[0], project) | ||
|
||
@database_session | ||
async def test_abi(self, session: AsyncSession): | ||
abi = Abi(abi_hash=b"A Test Abi", abi_json={"name": "A Test Project"}) | ||
await abi.create(session) | ||
result = await abi.get_all(session) | ||
self.assertEqual(result[0], abi) | ||
|
||
@database_session | ||
async def test_abi_source(self, session: AsyncSession): | ||
abi_source = AbiSource(name="A Test Source", url="https://test.com") | ||
await abi_source.create(session) | ||
result = await abi_source.get_all(session) | ||
self.assertEqual(result[0], abi_source) | ||
abi = Abi( | ||
abi_hash=b"A Test Abi", | ||
abi_json={"name": "A Test Project"}, | ||
source_id=abi_source.id, | ||
) | ||
await abi.create(session) | ||
result = await abi.get_all(session) | ||
self.assertEqual(result[0], abi) | ||
self.assertEqual(result[0].source, abi_source) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""init | ||
Revision ID: 9912fd3fc9ce | ||
Revises: | ||
Create Date: 2024-12-13 11:23:10.023773 | ||
""" | ||
|
||
from typing import Sequence, Union | ||
|
||
import sqlalchemy as sa | ||
import sqlmodel | ||
from alembic import op | ||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "9912fd3fc9ce" | ||
down_revision: Union[str, None] = None | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.create_table( | ||
"abisource", | ||
sa.Column("id", sa.Integer(), nullable=False), | ||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.Column("url", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.PrimaryKeyConstraint("id"), | ||
) | ||
op.create_table( | ||
"project", | ||
sa.Column("id", sa.Integer(), nullable=False), | ||
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.Column("logo_file", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.PrimaryKeyConstraint("id"), | ||
) | ||
op.create_table( | ||
"abi", | ||
sa.Column("id", sa.Integer(), nullable=False), | ||
sa.Column("abi_hash", sa.LargeBinary(), nullable=False), | ||
sa.Column("relevance", sa.Integer(), nullable=False), | ||
sa.Column("abi_json", sa.JSON(), nullable=True), | ||
sa.Column("source_id", sa.Integer(), nullable=False), | ||
sa.ForeignKeyConstraint( | ||
["source_id"], | ||
["abisource.id"], | ||
), | ||
sa.PrimaryKeyConstraint("id"), | ||
) | ||
op.create_index(op.f("ix_abi_abi_hash"), "abi", ["abi_hash"], unique=True) | ||
op.create_table( | ||
"contract", | ||
sa.Column("id", sa.Integer(), nullable=False), | ||
sa.Column("address", sa.LargeBinary(), nullable=False), | ||
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), | ||
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True), | ||
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), | ||
sa.Column("trusted_for_delegate", sa.Boolean(), nullable=False), | ||
sa.Column("proxy", sa.Boolean(), nullable=False), | ||
sa.Column("fetch_retries", sa.Integer(), nullable=False), | ||
sa.Column("abi_id", sa.LargeBinary(), nullable=True), | ||
sa.Column("project_id", sa.Integer(), nullable=True), | ||
sa.Column("chain_id", sa.Integer(), nullable=False), | ||
sa.ForeignKeyConstraint( | ||
["abi_id"], | ||
["abi.abi_hash"], | ||
), | ||
sa.ForeignKeyConstraint( | ||
["project_id"], | ||
["project.id"], | ||
), | ||
sa.PrimaryKeyConstraint("id"), | ||
sa.UniqueConstraint("address", "chain_id", name="address_chain_unique"), | ||
) | ||
op.create_index(op.f("ix_contract_address"), "contract", ["address"], unique=False) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_index(op.f("ix_contract_address"), table_name="contract") | ||
op.drop_table("contract") | ||
op.drop_index(op.f("ix_abi_abi_hash"), table_name="abi") | ||
op.drop_table("abi") | ||
op.drop_table("project") | ||
op.drop_table("abisource") | ||
# ### end Alembic commands ### |
This file was deleted.
Oops, something went wrong.