Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 committed Dec 19, 2024
1 parent 1d7b090 commit 42e5245
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
4 changes: 2 additions & 2 deletions app/routers/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ async def list_contracts(
request: Request,
address: str,
chain_ids: Annotated[list[int] | None, Query()] = None,
limit: int = Query(None),
offset: int = Query(None),
limit: int | None = Query(None),
offset: int | None = Query(None),
session: AsyncSession = Depends(get_database_session),
) -> PaginatedResponse[Contract]:
if not fast_is_checksum_address(address):
Expand Down
63 changes: 61 additions & 2 deletions app/tests/routers/test_contracts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from fastapi.testclient import TestClient

from hexbytes import HexBytes
from safe_eth.eth.utils import fast_to_checksum_address
from sqlmodel.ext.asyncio.session import AsyncSession

from ...datasources.db.database import database_session
Expand Down Expand Up @@ -31,7 +30,7 @@ async def test_view_contracts(self, session: AsyncSession):
)
await contract.create(session)
response = self.client.get(
f"/api/v1/contracts/{fast_to_checksum_address(address.hex())}",
f"/api/v1/contracts/{address_expected}",
)
self.assertEqual(response.status_code, 200)
response_json = response.json()
Expand All @@ -46,3 +45,63 @@ async def test_view_contracts(self, session: AsyncSession):
self.assertEqual(results[0]["display_name"], None)
self.assertEqual(results[0]["chain_id"], 1)
self.assertEqual(results[0]["project"], None)
# Test filter by chain_id
contract = Contract(
address=address, name="A Test Contracts", chain_id=5, abi_id=abi.abi_hash
)
await contract.create(session)

response = self.client.get(
f"/api/v1/contracts/{address_expected}?chain_ids=5",
)
self.assertEqual(response.status_code, 200)
response_json = response.json()
results = response_json["results"]
self.assertEqual(response_json["count"], 1)
self.assertEqual(len(results), 1)
self.assertEqual(results[0]["chain_id"], 5)

@database_session
async def test_contracts_pagination(self, session: AsyncSession):
source = AbiSource(name="Etherscan", url="https://api.etherscan.io/api")
await source.create(session)
abi = Abi(abi_json=mock_abi_json, source_id=source.id)
await abi.create(session)
address_expected = "0x6eEF70Da339a98102a642969B3956DEa71A1096e"
address = HexBytes(address_expected)
for chain_id in range(0, 10):
contract = Contract(
address=address,
name="A Test Contracts",
chain_id=chain_id,
abi_id=abi.abi_hash,
)
await contract.create(session)

response = self.client.get(
f"/api/v1/contracts/{address_expected}?limit=5",
)
self.assertEqual(response.status_code, 200)
response_json = response.json()
results = response_json["results"]
self.assertEqual(response_json["count"], 10)
self.assertEqual(
response_json["next"],
f"http://testserver/api/v1/contracts/{address_expected}?limit=5&offset=5",
)
self.assertEqual(response_json["previous"], None)
self.assertEqual(len(results), 5)

response = self.client.get(
f"/api/v1/contracts/{address_expected}?limit=5&offset=5",
)
self.assertEqual(response.status_code, 200)
response_json = response.json()
results = response_json["results"]
self.assertEqual(response_json["count"], 10)
self.assertEqual(response_json["next"], None)
self.assertEqual(
response_json["previous"],
f"http://testserver/api/v1/contracts/{address_expected}?limit=5&offset=0",
)
self.assertEqual(len(results), 5)

0 comments on commit 42e5245

Please sign in to comment.