Skip to content

Commit

Permalink
Add test to ensure contracts are being updated
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 committed Jan 15, 2025
1 parent 2e7b883 commit ea3f411
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 4 deletions.
13 changes: 13 additions & 0 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,16 @@ async def get_contracts_without_abi(
result = await session.stream(query)
async for contract in result:
yield contract

@classmethod
async def get_proxy_contracts(cls, session: AsyncSession):
"""
Return all the contracts with implementation address, so proxy contracts.
:param session:
:return:
"""
query = select(cls).where(cls.implementation != None) # noqa: E711
result = await session.stream(query)
async for contract in result:
yield contract
34 changes: 34 additions & 0 deletions app/tests/datasources/db/test_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from eth_account import Account
from hexbytes import HexBytes
from safe_eth.eth.utils import fast_to_checksum_address
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
from app.datasources.db.models import Abi, AbiSource, Contract, Project
from app.services.contract_metadata_service import (
ClientSource,
ContractMetadataService,
EnhancedContractMetadata,
)

from ...mocks.contract_metadata_mocks import etherscan_proxy_metadata_mock
from .db_async_conn import DbAsyncConn


Expand Down Expand Up @@ -163,3 +170,30 @@ async def test_get_contracts_without_abi(self, session: AsyncSession):
await expected_contract.update(session)
async for contract in Contract.get_contracts_without_abi(session, 10):
self.fail("Expected no contracts, but found one.")

@database_session
async def test_get_proxy_contracts(self, session: AsyncSession):
# Test empty case
async for proxy_contract in Contract.get_proxy_contracts(session):
self.fail("Expected no proxies, but found one.")

random_address = Account.create().address
await AbiSource(name="Etherscan", url="").create(session)
enhanced_contract_metadata = EnhancedContractMetadata(
address=random_address,
metadata=etherscan_proxy_metadata_mock,
source=ClientSource.ETHERSCAN,
chain_id=1,
)
result = await ContractMetadataService.process_contract_metadata(
session, enhanced_contract_metadata
)
self.assertTrue(result)
async for proxy_contract in Contract.get_proxy_contracts(session):
self.assertEqual(
fast_to_checksum_address(proxy_contract[0].address), random_address
)
self.assertEqual(
fast_to_checksum_address(proxy_contract[0].implementation),
"0x43506849D7C04F9138D1A2050bbF3A0c054402dd",
)
44 changes: 44 additions & 0 deletions app/tests/services/test_contract_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,47 @@ def test_get_proxy_implementation_address(self):
self.assertIsNone(
ContractMetadataService.get_proxy_implementation_address(contract_data)
)

@database_session
async def test_process_metadata_should_update_contracts(
self, session: AsyncSession
):
contract_address = Account.create().address
chain_id = 1
await AbiSource(name="Etherscan", url="").create(session)
contract_metadata = EnhancedContractMetadata(
address=contract_address,
metadata=etherscan_proxy_metadata_mock,
source=ClientSource.ETHERSCAN,
chain_id=1,
)
result = await ContractMetadataService.process_contract_metadata(
session, contract_metadata
)
self.assertTrue(result)
contract = await Contract.get_contract(
session, address=HexBytes(contract_address), chain_id=chain_id
)
self.assertEqual(
fast_to_checksum_address(contract.address), contract_metadata.address
)
self.assertEqual(
fast_to_checksum_address(contract.implementation),
contract_metadata.metadata.implementation, # type: ignore
)
# Process metadata should update the db contract information
implementation_address = fast_to_checksum_address(Account.create().address)
contract_metadata.metadata.implementation = implementation_address # type: ignore
result = await ContractMetadataService.process_contract_metadata(
session, contract_metadata
)
self.assertTrue(result)
contract = await Contract.get_contract(
session, address=HexBytes(contract_address), chain_id=chain_id
)
self.assertEqual(
fast_to_checksum_address(contract.address), contract_metadata.address
)
self.assertEqual(
fast_to_checksum_address(contract.implementation), implementation_address
)
16 changes: 12 additions & 4 deletions app/tests/workers/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import unittest
from typing import Any, Awaitable
Expand Down Expand Up @@ -71,12 +72,16 @@ async def asyncSetUp(self):

async def asyncTearDown(self):
await super().asyncTearDown()
redis = get_redis()
redis.flushall()
self.worker.stop()

def _wait_tasks_execution(self):
async def _wait_tasks_execution(self):
redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1)
while len(redis_tasks) > 0:
redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1)
# TODO should check anyway the task status instead to randomly wait
await asyncio.sleep(2)

@mock.patch.object(ContractMetadataService, "enabled_clients")
@mock.patch.object(
Expand Down Expand Up @@ -139,6 +144,7 @@ async def test_get_contract_metadata_task(
async def test_get_contract_metadata_task_with_proxy(
self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession
):
await AbiSource(name="Etherscan", url="").create(session)
etherscan_get_contract_metadata_mock.side_effect = [
etherscan_proxy_metadata_mock,
etherscan_metadata_mock,
Expand All @@ -149,16 +155,18 @@ async def test_get_contract_metadata_task_with_proxy(

get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id)

self._wait_tasks_execution()
await self._wait_tasks_execution()

self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2)
contract = await Contract.get_contract(
session, HexBytes(contract_address), chain_id
)
self.assertIsNotNone(contract)
self.assertEqual(
HexBytes(contract.implementation), HexBytes(proxy_implementation_address)
)

proxy_implementation = await Contract.get_contract(
session, HexBytes(proxy_implementation_address), chain_id
)
self.assertIsNotNone(proxy_implementation)

self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2)
11 changes: 11 additions & 0 deletions app/workers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,14 @@ async def get_missing_contract_metadata_task(session: AsyncSession) -> None:
chain_id=contract[0].chain_id,
skip_attemp_download=True,
)


@dramatiq.actor(periodic=cron("0 5 * * *")) # Every day at 5 am
@database_session
async def update_proxies_task(session: AsyncSession) -> None:
async for proxy_contract in Contract.get_proxy_contracts(session):
get_contract_metadata_task.send(
address=HexBytes(proxy_contract[0].address).hex(),
chain_id=proxy_contract[0].chain_id,
skip_attemp_download=True,
)

0 comments on commit ea3f411

Please sign in to comment.