diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 85d76b2..73f4625 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -300,3 +300,16 @@ async def get_contracts_without_abi( result = await session.stream(query) async for contract in result: yield contract + + @classmethod + async def get_proxy_contracts(cls, session: AsyncSession): + """ + Return all the contracts with implementation address, so proxy contracts. + + :param session: + :return: + """ + query = select(cls).where(cls.implementation != None) # noqa: E711 + result = await session.stream(query) + async for contract in result: + yield contract diff --git a/app/tests/datasources/db/test_model.py b/app/tests/datasources/db/test_model.py index 4fcb2c8..24f9e0e 100644 --- a/app/tests/datasources/db/test_model.py +++ b/app/tests/datasources/db/test_model.py @@ -1,10 +1,17 @@ from eth_account import Account from hexbytes import HexBytes +from safe_eth.eth.utils import fast_to_checksum_address from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.database import database_session from app.datasources.db.models import Abi, AbiSource, Contract, Project +from app.services.contract_metadata_service import ( + ClientSource, + ContractMetadataService, + EnhancedContractMetadata, +) +from ...mocks.contract_metadata_mocks import etherscan_proxy_metadata_mock from .db_async_conn import DbAsyncConn @@ -163,3 +170,30 @@ async def test_get_contracts_without_abi(self, session: AsyncSession): await expected_contract.update(session) async for contract in Contract.get_contracts_without_abi(session, 10): self.fail("Expected no contracts, but found one.") + + @database_session + async def test_get_proxy_contracts(self, session: AsyncSession): + # Test empty case + async for proxy_contract in Contract.get_proxy_contracts(session): + self.fail("Expected no proxies, but found one.") + + random_address = Account.create().address + await AbiSource(name="Etherscan", url="").create(session) + enhanced_contract_metadata = EnhancedContractMetadata( + address=random_address, + metadata=etherscan_proxy_metadata_mock, + source=ClientSource.ETHERSCAN, + chain_id=1, + ) + result = await ContractMetadataService.process_contract_metadata( + session, enhanced_contract_metadata + ) + self.assertTrue(result) + async for proxy_contract in Contract.get_proxy_contracts(session): + self.assertEqual( + fast_to_checksum_address(proxy_contract[0].address), random_address + ) + self.assertEqual( + fast_to_checksum_address(proxy_contract[0].implementation), + "0x43506849D7C04F9138D1A2050bbF3A0c054402dd", + ) diff --git a/app/tests/services/test_contract_metadata.py b/app/tests/services/test_contract_metadata.py index dd15905..036087d 100644 --- a/app/tests/services/test_contract_metadata.py +++ b/app/tests/services/test_contract_metadata.py @@ -309,3 +309,47 @@ def test_get_proxy_implementation_address(self): self.assertIsNone( ContractMetadataService.get_proxy_implementation_address(contract_data) ) + + @database_session + async def test_process_metadata_should_update_contracts( + self, session: AsyncSession + ): + contract_address = Account.create().address + chain_id = 1 + await AbiSource(name="Etherscan", url="").create(session) + contract_metadata = EnhancedContractMetadata( + address=contract_address, + metadata=etherscan_proxy_metadata_mock, + source=ClientSource.ETHERSCAN, + chain_id=1, + ) + result = await ContractMetadataService.process_contract_metadata( + session, contract_metadata + ) + self.assertTrue(result) + contract = await Contract.get_contract( + session, address=HexBytes(contract_address), chain_id=chain_id + ) + self.assertEqual( + fast_to_checksum_address(contract.address), contract_metadata.address + ) + self.assertEqual( + fast_to_checksum_address(contract.implementation), + contract_metadata.metadata.implementation, # type: ignore + ) + # Process metadata should update the db contract information + implementation_address = fast_to_checksum_address(Account.create().address) + contract_metadata.metadata.implementation = implementation_address # type: ignore + result = await ContractMetadataService.process_contract_metadata( + session, contract_metadata + ) + self.assertTrue(result) + contract = await Contract.get_contract( + session, address=HexBytes(contract_address), chain_id=chain_id + ) + self.assertEqual( + fast_to_checksum_address(contract.address), contract_metadata.address + ) + self.assertEqual( + fast_to_checksum_address(contract.implementation), implementation_address + ) diff --git a/app/tests/workers/test_tasks.py b/app/tests/workers/test_tasks.py index d9acb14..f69e249 100644 --- a/app/tests/workers/test_tasks.py +++ b/app/tests/workers/test_tasks.py @@ -1,3 +1,4 @@ +import asyncio import json import unittest from typing import Any, Awaitable @@ -71,12 +72,16 @@ async def asyncSetUp(self): async def asyncTearDown(self): await super().asyncTearDown() + redis = get_redis() + redis.flushall() self.worker.stop() - def _wait_tasks_execution(self): + async def _wait_tasks_execution(self): redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) while len(redis_tasks) > 0: redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) + # TODO should check anyway the task status instead to randomly wait + await asyncio.sleep(2) @mock.patch.object(ContractMetadataService, "enabled_clients") @mock.patch.object( @@ -139,6 +144,7 @@ async def test_get_contract_metadata_task( async def test_get_contract_metadata_task_with_proxy( self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession ): + await AbiSource(name="Etherscan", url="").create(session) etherscan_get_contract_metadata_mock.side_effect = [ etherscan_proxy_metadata_mock, etherscan_metadata_mock, @@ -149,16 +155,18 @@ async def test_get_contract_metadata_task_with_proxy( get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id) - self._wait_tasks_execution() + await self._wait_tasks_execution() + self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) contract = await Contract.get_contract( session, HexBytes(contract_address), chain_id ) self.assertIsNotNone(contract) + self.assertEqual( + HexBytes(contract.implementation), HexBytes(proxy_implementation_address) + ) proxy_implementation = await Contract.get_contract( session, HexBytes(proxy_implementation_address), chain_id ) self.assertIsNotNone(proxy_implementation) - - self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) diff --git a/app/workers/tasks.py b/app/workers/tasks.py index 44874cb..946fff8 100644 --- a/app/workers/tasks.py +++ b/app/workers/tasks.py @@ -101,3 +101,14 @@ async def get_missing_contract_metadata_task(session: AsyncSession) -> None: chain_id=contract[0].chain_id, skip_attemp_download=True, ) + + +@dramatiq.actor(periodic=cron("0 5 * * *")) # Every day at 5 am +@database_session +async def update_proxies_task(session: AsyncSession) -> None: + async for proxy_contract in Contract.get_proxy_contracts(session): + get_contract_metadata_task.send( + address=HexBytes(proxy_contract[0].address).hex(), + chain_id=proxy_contract[0].chain_id, + skip_attemp_download=True, + )