diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 7b0195b..b822add 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from typing import AsyncIterator, Self, cast -from sqlalchemy import DateTime, Row +from sqlalchemy import DateTime from sqlmodel import ( JSON, Column, @@ -286,7 +286,7 @@ async def get_abi_by_contract_address( @classmethod async def get_contracts_without_abi( cls, session: AsyncSession, max_retries: int = 0 - ) -> AsyncIterator[Row[tuple[Self]]]: + ) -> AsyncIterator[Self]: """ Fetches contracts without an ABI and fewer retries than max_retries, streaming results in batches to reduce memory usage for large datasets. @@ -303,5 +303,18 @@ async def get_contracts_without_abi( .where(cls.fetch_retries <= max_retries) ) result = await session.stream(query) - async for contract in result: + async for (contract,) in result: + yield contract + + @classmethod + async def get_proxy_contracts(cls, session: AsyncSession) -> AsyncIterator[Self]: + """ + Return all the contracts with implementation address, so proxy contracts. + + :param session: + :return: + """ + query = select(cls).where(cls.implementation.isnot(None)) # type: ignore + result = await session.stream(query) + async for (contract,) in result: yield contract diff --git a/app/tests/datasources/db/test_model.py b/app/tests/datasources/db/test_model.py index 4fcb2c8..de66b16 100644 --- a/app/tests/datasources/db/test_model.py +++ b/app/tests/datasources/db/test_model.py @@ -1,10 +1,19 @@ +from typing import cast + from eth_account import Account from hexbytes import HexBytes +from safe_eth.eth.utils import fast_to_checksum_address from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.database import database_session from app.datasources.db.models import Abi, AbiSource, Contract, Project +from app.services.contract_metadata_service import ( + ContractMetadataService, + ContractSource, + EnhancedContractMetadata, +) +from ...mocks.contract_metadata_mocks import etherscan_proxy_metadata_mock from .db_async_conn import DbAsyncConn @@ -150,7 +159,7 @@ async def test_get_contracts_without_abi(self, session: AsyncSession): address=random_address, name="A test contract", chain_id=1 ).create(session) async for contract in Contract.get_contracts_without_abi(session, 0): - self.assertEqual(expected_contract, contract[0]) + self.assertEqual(expected_contract, contract) # Contracts with more retries shouldn't be returned expected_contract.fetch_retries = 1 @@ -163,3 +172,30 @@ async def test_get_contracts_without_abi(self, session: AsyncSession): await expected_contract.update(session) async for contract in Contract.get_contracts_without_abi(session, 10): self.fail("Expected no contracts, but found one.") + + @database_session + async def test_get_proxy_contracts(self, session: AsyncSession): + # Test empty case + async for proxy_contract in Contract.get_proxy_contracts(session): + self.fail("Expected no proxies, but found one.") + + random_address = Account.create().address + await AbiSource(name="Etherscan", url="").create(session) + enhanced_contract_metadata = EnhancedContractMetadata( + address=random_address, + metadata=etherscan_proxy_metadata_mock, + source=ContractSource.ETHERSCAN, + chain_id=1, + ) + result = await ContractMetadataService.process_contract_metadata( + session, enhanced_contract_metadata + ) + self.assertTrue(result) + async for proxy_contract in Contract.get_proxy_contracts(session): + self.assertEqual( + fast_to_checksum_address(proxy_contract.address), random_address + ) + self.assertEqual( + fast_to_checksum_address(cast(bytes, proxy_contract.implementation)), + "0x43506849D7C04F9138D1A2050bbF3A0c054402dd", + ) diff --git a/app/tests/services/test_contract_metadata.py b/app/tests/services/test_contract_metadata.py index a53daaa..c50e764 100644 --- a/app/tests/services/test_contract_metadata.py +++ b/app/tests/services/test_contract_metadata.py @@ -1,3 +1,4 @@ +from copy import copy from unittest import mock from unittest.mock import MagicMock @@ -309,3 +310,47 @@ def test_get_proxy_implementation_address(self): self.assertIsNone( ContractMetadataService.get_proxy_implementation_address(contract_data) ) + + @database_session + async def test_process_metadata_should_update_contracts( + self, session: AsyncSession + ): + contract_address = Account.create().address + chain_id = 1 + await AbiSource(name="Etherscan", url="").create(session) + contract_metadata = EnhancedContractMetadata( + address=contract_address, + metadata=copy(etherscan_proxy_metadata_mock), # Avoid race condition + source=ContractSource.ETHERSCAN, + chain_id=1, + ) + result = await ContractMetadataService.process_contract_metadata( + session, contract_metadata + ) + self.assertTrue(result) + contract = await Contract.get_contract( + session, address=HexBytes(contract_address), chain_id=chain_id + ) + self.assertEqual( + fast_to_checksum_address(contract.address), contract_metadata.address + ) + self.assertEqual( + fast_to_checksum_address(contract.implementation), + contract_metadata.metadata.implementation, # type: ignore + ) + # Process metadata should update the db contract information + implementation_address = fast_to_checksum_address(Account.create().address) + contract_metadata.metadata.implementation = implementation_address # type: ignore + result = await ContractMetadataService.process_contract_metadata( + session, contract_metadata + ) + self.assertTrue(result) + contract = await Contract.get_contract( + session, address=HexBytes(contract_address), chain_id=chain_id + ) + self.assertEqual( + fast_to_checksum_address(contract.address), contract_metadata.address + ) + self.assertEqual( + fast_to_checksum_address(contract.implementation), implementation_address + ) diff --git a/app/tests/workers/test_tasks.py b/app/tests/workers/test_tasks.py index d9acb14..309d53d 100644 --- a/app/tests/workers/test_tasks.py +++ b/app/tests/workers/test_tasks.py @@ -9,6 +9,7 @@ from hexbytes import HexBytes from safe_eth.eth import EthereumNetwork from safe_eth.eth.clients import AsyncEtherscanClientV2 +from safe_eth.eth.utils import fast_to_checksum_address from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.database import database_session @@ -66,18 +67,25 @@ class TestAsyncTasks(DbAsyncConn): async def asyncSetUp(self): await super().asyncSetUp() - self.worker = Worker(redis_broker) + self.worker = Worker(redis_broker, worker_threads=1) self.worker.start() async def asyncTearDown(self): await super().asyncTearDown() self.worker.stop() + redis = get_redis() + redis.flushall() def _wait_tasks_execution(self): + # Ensure that all the messages on redis were consumed redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) while len(redis_tasks) > 0: redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) + # Wait for all the messages on the given queue to be processed. + # This method is only meant to be used in tests + self.worker.broker.join("default") + @mock.patch.object(ContractMetadataService, "enabled_clients") @mock.patch.object( AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True @@ -139,6 +147,7 @@ async def test_get_contract_metadata_task( async def test_get_contract_metadata_task_with_proxy( self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession ): + await AbiSource(name="Etherscan", url="").create(session) etherscan_get_contract_metadata_mock.side_effect = [ etherscan_proxy_metadata_mock, etherscan_metadata_mock, @@ -155,10 +164,14 @@ async def test_get_contract_metadata_task_with_proxy( session, HexBytes(contract_address), chain_id ) self.assertIsNotNone(contract) - + self.assertEqual( + fast_to_checksum_address(contract.implementation), + proxy_implementation_address, + ) proxy_implementation = await Contract.get_contract( session, HexBytes(proxy_implementation_address), chain_id ) self.assertIsNotNone(proxy_implementation) + self.assertEqual(contract.implementation, proxy_implementation.address) self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) diff --git a/app/workers/tasks.py b/app/workers/tasks.py index b66c6b2..85ded71 100644 --- a/app/workers/tasks.py +++ b/app/workers/tasks.py @@ -9,7 +9,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from app.config import settings -from app.datasources.db.database import database_session +from app.datasources.db.database import get_engine from app.datasources.db.models import Contract from app.services.contract_metadata_service import get_contract_metadata_service @@ -38,67 +38,78 @@ async def test_task(message: str) -> None: @dramatiq.actor -@database_session async def get_contract_metadata_task( - session: AsyncSession, - address: str, - chain_id: int, - skip_attemp_download: bool = False, -) -> None: - contract_metadata_service = get_contract_metadata_service() - # Just try the first time, following retries should be scheduled - if skip_attemp_download or await contract_metadata_service.should_attempt_download( - session, address, chain_id, 0 - ): - logger.info( - "Downloading contract metadata for contract=%s and chain=%s", - address, - chain_id, - ) - # TODO Check if contract is MultiSend. In that case, get the transaction and decode it - contract_metadata = await contract_metadata_service.get_contract_metadata( - fast_to_checksum_address(address), chain_id - ) - result = await contract_metadata_service.process_contract_metadata( - session, contract_metadata - ) - if result: - logger.info( - "Success download contract metadata for contract=%s and chain=%s", - address, - chain_id, + address: str, chain_id: int, skip_attemp_download: bool = False +): + async with AsyncSession(get_engine(), expire_on_commit=False) as session: + contract_metadata_service = get_contract_metadata_service() + # Just try the first time, following retries should be scheduled + if ( + skip_attemp_download + or await contract_metadata_service.should_attempt_download( + session, address, chain_id, 0 ) - else: + ): logger.info( - "Failed to download contract metadata for contract=%s and chain=%s", + "Downloading contract metadata for contract=%s and chain=%s", address, chain_id, ) + # TODO Check if contract is MultiSend. In that case, get the transaction and decode it + contract_metadata = await contract_metadata_service.get_contract_metadata( + fast_to_checksum_address(address), chain_id + ) + result = await contract_metadata_service.process_contract_metadata( + session, contract_metadata + ) + if result: + logger.info( + "Success download contract metadata for contract=%s and chain=%s", + address, + chain_id, + ) + else: + logger.info( + "Failed to download contract metadata for contract=%s and chain=%s", + address, + chain_id, + ) + + if proxy_implementation_address := contract_metadata_service.get_proxy_implementation_address( + contract_metadata + ): + logger.info( + "Adding task to download proxy implementation metadata from address=%s for contract=%s and chain=%s", + proxy_implementation_address, + address, + chain_id, + ) + get_contract_metadata_task.send( + address=proxy_implementation_address, chain_id=chain_id + ) + else: + logger.debug("Skipping contract=%s and chain=%s", address, chain_id) - if proxy_implementation_address := contract_metadata_service.get_proxy_implementation_address( - contract_metadata + +@dramatiq.actor(periodic=cron("0 0 * * *")) # Every midnight +async def get_missing_contract_metadata_task(): + async with AsyncSession(get_engine(), expire_on_commit=False) as session: + async for contract in Contract.get_contracts_without_abi( + session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES ): - logger.info( - "Adding task to download proxy implementation metadata from address=%s for contract=%s and chain=%s", - proxy_implementation_address, - address, - chain_id, - ) get_contract_metadata_task.send( - address=proxy_implementation_address, chain_id=chain_id + address=HexBytes(contract.address).hex(), + chain_id=contract.chain_id, + skip_attemp_download=True, ) - else: - logger.debug("Skipping contract=%s and chain=%s", address, chain_id) -@dramatiq.actor(periodic=cron("0 0 * * *")) # Every midnight -@database_session -async def get_missing_contract_metadata_task(session: AsyncSession) -> None: - async for contract in Contract.get_contracts_without_abi( - session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES - ): - get_contract_metadata_task.send( - address=HexBytes(contract[0].address).hex(), - chain_id=contract[0].chain_id, - skip_attemp_download=True, - ) +@dramatiq.actor(periodic=cron("0 5 * * *")) # Every day at 5 am +async def update_proxies_task(): + async with AsyncSession(get_engine(), expire_on_commit=False) as session: + async for proxy_contract in Contract.get_proxy_contracts(session): + get_contract_metadata_task.send( + address=HexBytes(proxy_contract.address).hex(), + chain_id=proxy_contract.chain_id, + skip_attemp_download=True, + )