From 6af07ea3f0868de63e9e2c55c4db7e1c9e7b42fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Wed, 15 Jan 2025 13:34:03 +0100 Subject: [PATCH 1/9] Add test to ensure contracts are being updated --- app/datasources/db/models.py | 13 ++++++ app/tests/datasources/db/test_model.py | 34 +++++++++++++++ app/tests/services/test_contract_metadata.py | 44 ++++++++++++++++++++ app/tests/workers/test_tasks.py | 16 +++++-- app/workers/tasks.py | 11 +++++ 5 files changed, 114 insertions(+), 4 deletions(-) diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 7b0195b..561ca7b 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -305,3 +305,16 @@ async def get_contracts_without_abi( result = await session.stream(query) async for contract in result: yield contract + + @classmethod + async def get_proxy_contracts(cls, session: AsyncSession): + """ + Return all the contracts with implementation address, so proxy contracts. + + :param session: + :return: + """ + query = select(cls).where(cls.implementation != None) # noqa: E711 + result = await session.stream(query) + async for contract in result: + yield contract diff --git a/app/tests/datasources/db/test_model.py b/app/tests/datasources/db/test_model.py index 4fcb2c8..24f9e0e 100644 --- a/app/tests/datasources/db/test_model.py +++ b/app/tests/datasources/db/test_model.py @@ -1,10 +1,17 @@ from eth_account import Account from hexbytes import HexBytes +from safe_eth.eth.utils import fast_to_checksum_address from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.database import database_session from app.datasources.db.models import Abi, AbiSource, Contract, Project +from app.services.contract_metadata_service import ( + ClientSource, + ContractMetadataService, + EnhancedContractMetadata, +) +from ...mocks.contract_metadata_mocks import etherscan_proxy_metadata_mock from .db_async_conn import DbAsyncConn @@ -163,3 +170,30 @@ async def test_get_contracts_without_abi(self, session: AsyncSession): await expected_contract.update(session) async for contract in Contract.get_contracts_without_abi(session, 10): self.fail("Expected no contracts, but found one.") + + @database_session + async def test_get_proxy_contracts(self, session: AsyncSession): + # Test empty case + async for proxy_contract in Contract.get_proxy_contracts(session): + self.fail("Expected no proxies, but found one.") + + random_address = Account.create().address + await AbiSource(name="Etherscan", url="").create(session) + enhanced_contract_metadata = EnhancedContractMetadata( + address=random_address, + metadata=etherscan_proxy_metadata_mock, + source=ClientSource.ETHERSCAN, + chain_id=1, + ) + result = await ContractMetadataService.process_contract_metadata( + session, enhanced_contract_metadata + ) + self.assertTrue(result) + async for proxy_contract in Contract.get_proxy_contracts(session): + self.assertEqual( + fast_to_checksum_address(proxy_contract[0].address), random_address + ) + self.assertEqual( + fast_to_checksum_address(proxy_contract[0].implementation), + "0x43506849D7C04F9138D1A2050bbF3A0c054402dd", + ) diff --git a/app/tests/services/test_contract_metadata.py b/app/tests/services/test_contract_metadata.py index a53daaa..1f9b4e4 100644 --- a/app/tests/services/test_contract_metadata.py +++ b/app/tests/services/test_contract_metadata.py @@ -309,3 +309,47 @@ def test_get_proxy_implementation_address(self): self.assertIsNone( ContractMetadataService.get_proxy_implementation_address(contract_data) ) + + @database_session + async def test_process_metadata_should_update_contracts( + self, session: AsyncSession + ): + contract_address = Account.create().address + chain_id = 1 + await AbiSource(name="Etherscan", url="").create(session) + contract_metadata = EnhancedContractMetadata( + address=contract_address, + metadata=etherscan_proxy_metadata_mock, + source=ClientSource.ETHERSCAN, + chain_id=1, + ) + result = await ContractMetadataService.process_contract_metadata( + session, contract_metadata + ) + self.assertTrue(result) + contract = await Contract.get_contract( + session, address=HexBytes(contract_address), chain_id=chain_id + ) + self.assertEqual( + fast_to_checksum_address(contract.address), contract_metadata.address + ) + self.assertEqual( + fast_to_checksum_address(contract.implementation), + contract_metadata.metadata.implementation, # type: ignore + ) + # Process metadata should update the db contract information + implementation_address = fast_to_checksum_address(Account.create().address) + contract_metadata.metadata.implementation = implementation_address # type: ignore + result = await ContractMetadataService.process_contract_metadata( + session, contract_metadata + ) + self.assertTrue(result) + contract = await Contract.get_contract( + session, address=HexBytes(contract_address), chain_id=chain_id + ) + self.assertEqual( + fast_to_checksum_address(contract.address), contract_metadata.address + ) + self.assertEqual( + fast_to_checksum_address(contract.implementation), implementation_address + ) diff --git a/app/tests/workers/test_tasks.py b/app/tests/workers/test_tasks.py index d9acb14..f69e249 100644 --- a/app/tests/workers/test_tasks.py +++ b/app/tests/workers/test_tasks.py @@ -1,3 +1,4 @@ +import asyncio import json import unittest from typing import Any, Awaitable @@ -71,12 +72,16 @@ async def asyncSetUp(self): async def asyncTearDown(self): await super().asyncTearDown() + redis = get_redis() + redis.flushall() self.worker.stop() - def _wait_tasks_execution(self): + async def _wait_tasks_execution(self): redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) while len(redis_tasks) > 0: redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) + # TODO should check anyway the task status instead to randomly wait + await asyncio.sleep(2) @mock.patch.object(ContractMetadataService, "enabled_clients") @mock.patch.object( @@ -139,6 +144,7 @@ async def test_get_contract_metadata_task( async def test_get_contract_metadata_task_with_proxy( self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession ): + await AbiSource(name="Etherscan", url="").create(session) etherscan_get_contract_metadata_mock.side_effect = [ etherscan_proxy_metadata_mock, etherscan_metadata_mock, @@ -149,16 +155,18 @@ async def test_get_contract_metadata_task_with_proxy( get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id) - self._wait_tasks_execution() + await self._wait_tasks_execution() + self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) contract = await Contract.get_contract( session, HexBytes(contract_address), chain_id ) self.assertIsNotNone(contract) + self.assertEqual( + HexBytes(contract.implementation), HexBytes(proxy_implementation_address) + ) proxy_implementation = await Contract.get_contract( session, HexBytes(proxy_implementation_address), chain_id ) self.assertIsNotNone(proxy_implementation) - - self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) diff --git a/app/workers/tasks.py b/app/workers/tasks.py index b66c6b2..e231f84 100644 --- a/app/workers/tasks.py +++ b/app/workers/tasks.py @@ -102,3 +102,14 @@ async def get_missing_contract_metadata_task(session: AsyncSession) -> None: chain_id=contract[0].chain_id, skip_attemp_download=True, ) + + +@dramatiq.actor(periodic=cron("0 5 * * *")) # Every day at 5 am +@database_session +async def update_proxies_task(session: AsyncSession) -> None: + async for proxy_contract in Contract.get_proxy_contracts(session): + get_contract_metadata_task.send( + address=HexBytes(proxy_contract[0].address).hex(), + chain_id=proxy_contract[0].chain_id, + skip_attemp_download=True, + ) From a4ce45df1333059ea77b2d198535d23185f26053 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Wed, 15 Jan 2025 13:52:00 +0100 Subject: [PATCH 2/9] Restore weak test --- app/tests/workers/test_tasks.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/app/tests/workers/test_tasks.py b/app/tests/workers/test_tasks.py index f69e249..d9acb14 100644 --- a/app/tests/workers/test_tasks.py +++ b/app/tests/workers/test_tasks.py @@ -1,4 +1,3 @@ -import asyncio import json import unittest from typing import Any, Awaitable @@ -72,16 +71,12 @@ async def asyncSetUp(self): async def asyncTearDown(self): await super().asyncTearDown() - redis = get_redis() - redis.flushall() self.worker.stop() - async def _wait_tasks_execution(self): + def _wait_tasks_execution(self): redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) while len(redis_tasks) > 0: redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) - # TODO should check anyway the task status instead to randomly wait - await asyncio.sleep(2) @mock.patch.object(ContractMetadataService, "enabled_clients") @mock.patch.object( @@ -144,7 +139,6 @@ async def test_get_contract_metadata_task( async def test_get_contract_metadata_task_with_proxy( self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession ): - await AbiSource(name="Etherscan", url="").create(session) etherscan_get_contract_metadata_mock.side_effect = [ etherscan_proxy_metadata_mock, etherscan_metadata_mock, @@ -155,18 +149,16 @@ async def test_get_contract_metadata_task_with_proxy( get_contract_metadata_task.fn(address=contract_address, chain_id=chain_id) - await self._wait_tasks_execution() + self._wait_tasks_execution() - self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) contract = await Contract.get_contract( session, HexBytes(contract_address), chain_id ) self.assertIsNotNone(contract) - self.assertEqual( - HexBytes(contract.implementation), HexBytes(proxy_implementation_address) - ) proxy_implementation = await Contract.get_contract( session, HexBytes(proxy_implementation_address), chain_id ) self.assertIsNotNone(proxy_implementation) + + self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) From 4054ebb73ab8159f2a9257342d3f7b5fbbf4f229 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Wed, 15 Jan 2025 16:41:13 +0100 Subject: [PATCH 3/9] Flush redis --- app/tests/workers/test_tasks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/tests/workers/test_tasks.py b/app/tests/workers/test_tasks.py index d9acb14..627c960 100644 --- a/app/tests/workers/test_tasks.py +++ b/app/tests/workers/test_tasks.py @@ -72,6 +72,8 @@ async def asyncSetUp(self): async def asyncTearDown(self): await super().asyncTearDown() self.worker.stop() + redis = get_redis() + redis.flushall() def _wait_tasks_execution(self): redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) From 13f4f347aad6a2621161d01862eb92437abc0410 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:15:34 +0100 Subject: [PATCH 4/9] Fix test --- app/tests/workers/test_tasks.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/app/tests/workers/test_tasks.py b/app/tests/workers/test_tasks.py index 627c960..309d53d 100644 --- a/app/tests/workers/test_tasks.py +++ b/app/tests/workers/test_tasks.py @@ -9,6 +9,7 @@ from hexbytes import HexBytes from safe_eth.eth import EthereumNetwork from safe_eth.eth.clients import AsyncEtherscanClientV2 +from safe_eth.eth.utils import fast_to_checksum_address from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.database import database_session @@ -66,7 +67,7 @@ class TestAsyncTasks(DbAsyncConn): async def asyncSetUp(self): await super().asyncSetUp() - self.worker = Worker(redis_broker) + self.worker = Worker(redis_broker, worker_threads=1) self.worker.start() async def asyncTearDown(self): @@ -76,10 +77,15 @@ async def asyncTearDown(self): redis.flushall() def _wait_tasks_execution(self): + # Ensure that all the messages on redis were consumed redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) while len(redis_tasks) > 0: redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1) + # Wait for all the messages on the given queue to be processed. + # This method is only meant to be used in tests + self.worker.broker.join("default") + @mock.patch.object(ContractMetadataService, "enabled_clients") @mock.patch.object( AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True @@ -141,6 +147,7 @@ async def test_get_contract_metadata_task( async def test_get_contract_metadata_task_with_proxy( self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession ): + await AbiSource(name="Etherscan", url="").create(session) etherscan_get_contract_metadata_mock.side_effect = [ etherscan_proxy_metadata_mock, etherscan_metadata_mock, @@ -157,10 +164,14 @@ async def test_get_contract_metadata_task_with_proxy( session, HexBytes(contract_address), chain_id ) self.assertIsNotNone(contract) - + self.assertEqual( + fast_to_checksum_address(contract.implementation), + proxy_implementation_address, + ) proxy_implementation = await Contract.get_contract( session, HexBytes(proxy_implementation_address), chain_id ) self.assertIsNotNone(proxy_implementation) + self.assertEqual(contract.implementation, proxy_implementation.address) self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2) From 7b471bace18c47d994643343df33328e74b656f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:29:29 +0100 Subject: [PATCH 5/9] Remove database_session from tasks --- app/datasources/db/models.py | 2 +- app/tests/datasources/db/test_model.py | 4 +- app/workers/tasks.py | 123 ++++++++++++------------- 3 files changed, 64 insertions(+), 65 deletions(-) diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 561ca7b..66e8ad1 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -316,5 +316,5 @@ async def get_proxy_contracts(cls, session: AsyncSession): """ query = select(cls).where(cls.implementation != None) # noqa: E711 result = await session.stream(query) - async for contract in result: + async for (contract,) in result: yield contract diff --git a/app/tests/datasources/db/test_model.py b/app/tests/datasources/db/test_model.py index 24f9e0e..0295993 100644 --- a/app/tests/datasources/db/test_model.py +++ b/app/tests/datasources/db/test_model.py @@ -191,9 +191,9 @@ async def test_get_proxy_contracts(self, session: AsyncSession): self.assertTrue(result) async for proxy_contract in Contract.get_proxy_contracts(session): self.assertEqual( - fast_to_checksum_address(proxy_contract[0].address), random_address + fast_to_checksum_address(proxy_contract.address), random_address ) self.assertEqual( - fast_to_checksum_address(proxy_contract[0].implementation), + fast_to_checksum_address(proxy_contract.implementation), "0x43506849D7C04F9138D1A2050bbF3A0c054402dd", ) diff --git a/app/workers/tasks.py b/app/workers/tasks.py index e231f84..b8407f5 100644 --- a/app/workers/tasks.py +++ b/app/workers/tasks.py @@ -9,7 +9,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from app.config import settings -from app.datasources.db.database import database_session +from app.datasources.db.database import get_engine from app.datasources.db.models import Contract from app.services.contract_metadata_service import get_contract_metadata_service @@ -38,78 +38,77 @@ async def test_task(message: str) -> None: @dramatiq.actor -@database_session async def get_contract_metadata_task( - session: AsyncSession, - address: str, - chain_id: int, - skip_attemp_download: bool = False, -) -> None: - contract_metadata_service = get_contract_metadata_service() - # Just try the first time, following retries should be scheduled - if skip_attemp_download or await contract_metadata_service.should_attempt_download( - session, address, chain_id, 0 - ): - logger.info( - "Downloading contract metadata for contract=%s and chain=%s", - address, - chain_id, - ) - # TODO Check if contract is MultiSend. In that case, get the transaction and decode it - contract_metadata = await contract_metadata_service.get_contract_metadata( - fast_to_checksum_address(address), chain_id - ) - result = await contract_metadata_service.process_contract_metadata( - session, contract_metadata - ) - if result: - logger.info( - "Success download contract metadata for contract=%s and chain=%s", - address, - chain_id, + address: str, chain_id: int, skip_attemp_download: bool = False +): + async with AsyncSession(get_engine(), expire_on_commit=False) as session: + contract_metadata_service = get_contract_metadata_service() + # Just try the first time, following retries should be scheduled + if ( + skip_attemp_download + or await contract_metadata_service.should_attempt_download( + session, address, chain_id, 0 ) - else: - logger.info( - "Failed to download contract metadata for contract=%s and chain=%s", - address, - chain_id, - ) - - if proxy_implementation_address := contract_metadata_service.get_proxy_implementation_address( - contract_metadata ): logger.info( - "Adding task to download proxy implementation metadata from address=%s for contract=%s and chain=%s", - proxy_implementation_address, + "Downloading contract metadata for contract=%s and chain=%s", address, chain_id, ) - get_contract_metadata_task.send( - address=proxy_implementation_address, chain_id=chain_id + contract_metadata = await contract_metadata_service.get_contract_metadata( + fast_to_checksum_address(address), chain_id + ) + result = await contract_metadata_service.process_contract_metadata( + session, contract_metadata ) - else: - logger.debug("Skipping contract=%s and chain=%s", address, chain_id) + if result: + logger.info( + "Success download contract metadata for contract=%s and chain=%s", + address, + chain_id, + ) + else: + logger.info( + "Failed to download contract metadata for contract=%s and chain=%s", + address, + chain_id, + ) + + if proxy_implementation_address := contract_metadata_service.get_proxy_implementation_address( + contract_metadata + ): + logger.info( + "Adding task to download proxy implementation metadata from address=%s for contract=%s and chain=%s", + proxy_implementation_address, + address, + chain_id, + ) + get_contract_metadata_task.send( + address=proxy_implementation_address, chain_id=chain_id + ) + else: + logger.debug("Skipping contract=%s and chain=%s", address, chain_id) @dramatiq.actor(periodic=cron("0 0 * * *")) # Every midnight -@database_session -async def get_missing_contract_metadata_task(session: AsyncSession) -> None: - async for contract in Contract.get_contracts_without_abi( - session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES - ): - get_contract_metadata_task.send( - address=HexBytes(contract[0].address).hex(), - chain_id=contract[0].chain_id, - skip_attemp_download=True, - ) +async def get_missing_contract_metadata_task(): + async with AsyncSession(get_engine(), expire_on_commit=False) as session: + async for contract in Contract.get_contracts_without_abi( + session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES + ): + get_contract_metadata_task.send( + address=HexBytes(contract[0].address).hex(), + chain_id=contract[0].chain_id, + skip_attemp_download=True, + ) @dramatiq.actor(periodic=cron("0 5 * * *")) # Every day at 5 am -@database_session -async def update_proxies_task(session: AsyncSession) -> None: - async for proxy_contract in Contract.get_proxy_contracts(session): - get_contract_metadata_task.send( - address=HexBytes(proxy_contract[0].address).hex(), - chain_id=proxy_contract[0].chain_id, - skip_attemp_download=True, - ) +async def update_proxies_task(): + async with AsyncSession(get_engine(), expire_on_commit=False) as session: + async for proxy_contract in Contract.get_proxy_contracts(session): + get_contract_metadata_task.send( + address=HexBytes(proxy_contract.address).hex(), + chain_id=proxy_contract.chain_id, + skip_attemp_download=True, + ) From 195ab7c900aeafb2c07b9822be42015cac20d02a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:39:26 +0100 Subject: [PATCH 6/9] Fix race condition --- app/tests/services/test_contract_metadata.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/tests/services/test_contract_metadata.py b/app/tests/services/test_contract_metadata.py index 1f9b4e4..22cca30 100644 --- a/app/tests/services/test_contract_metadata.py +++ b/app/tests/services/test_contract_metadata.py @@ -1,3 +1,4 @@ +from copy import copy from unittest import mock from unittest.mock import MagicMock @@ -319,7 +320,7 @@ async def test_process_metadata_should_update_contracts( await AbiSource(name="Etherscan", url="").create(session) contract_metadata = EnhancedContractMetadata( address=contract_address, - metadata=etherscan_proxy_metadata_mock, + metadata=copy(etherscan_proxy_metadata_mock), # Avoid race condition source=ClientSource.ETHERSCAN, chain_id=1, ) From 2ec0cbbaf46eeaa36bb9ac8a9a3e1e876edda34b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Thu, 16 Jan 2025 09:42:02 +0100 Subject: [PATCH 7/9] Add return type --- app/datasources/db/models.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 66e8ad1..9d3a450 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import AsyncIterator, Self, cast +from typing import AsyncGenerator, AsyncIterator, cast from sqlalchemy import DateTime, Row from sqlmodel import ( @@ -80,7 +80,7 @@ async def get_or_create( :param name: The name to check or create. :param url: The URL to check or create. :return: A tuple containing the AbiSource object and a boolean indicating - whether it was created `True` or already exists `False`. + whether it was created (True) or already exists (False). """ query = select(cls).where(cls.name == name, cls.url == url) results = await session.exec(query) @@ -168,7 +168,7 @@ async def get_or_create_abi( :param relevance: :param source_id: :return: A tuple containing the Abi object and a boolean indicating - whether it was created `True` or already exists `False`. + whether it was created (True) or already exists (False). """ if abi := await cls.get_abi(session, abi_json): return abi, False @@ -218,7 +218,7 @@ def get_contracts_query( Return a statement to get contracts for the provided address and chain_id :param address: - :param chain_ids: list of chain_ids, `None` for all chains + :param chain_ids: list of chain_ids, None for all chains :return: """ query = select(cls).where(cls.address == address) @@ -307,7 +307,9 @@ async def get_contracts_without_abi( yield contract @classmethod - async def get_proxy_contracts(cls, session: AsyncSession): + async def get_proxy_contracts( + cls, session: AsyncSession + ) -> AsyncGenerator["Contract"]: """ Return all the contracts with implementation address, so proxy contracts. From da5bd1e3e8f8d0a49fab78599324ed4e2020862a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Thu, 16 Jan 2025 09:55:37 +0100 Subject: [PATCH 8/9] Fix return types --- app/datasources/db/models.py | 6 +++--- app/tests/datasources/db/test_model.py | 6 ++++-- app/workers/tasks.py | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 9d3a450..2910b03 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -286,7 +286,7 @@ async def get_abi_by_contract_address( @classmethod async def get_contracts_without_abi( cls, session: AsyncSession, max_retries: int = 0 - ) -> AsyncIterator[Row[tuple[Self]]]: + ) -> AsyncGenerator["Contract", None]: """ Fetches contracts without an ABI and fewer retries than max_retries, streaming results in batches to reduce memory usage for large datasets. @@ -303,13 +303,13 @@ async def get_contracts_without_abi( .where(cls.fetch_retries <= max_retries) ) result = await session.stream(query) - async for contract in result: + async for (contract,) in result: yield contract @classmethod async def get_proxy_contracts( cls, session: AsyncSession - ) -> AsyncGenerator["Contract"]: + ) -> AsyncGenerator["Contract", None]: """ Return all the contracts with implementation address, so proxy contracts. diff --git a/app/tests/datasources/db/test_model.py b/app/tests/datasources/db/test_model.py index 0295993..415bdc4 100644 --- a/app/tests/datasources/db/test_model.py +++ b/app/tests/datasources/db/test_model.py @@ -1,3 +1,5 @@ +from typing import cast + from eth_account import Account from hexbytes import HexBytes from safe_eth.eth.utils import fast_to_checksum_address @@ -157,7 +159,7 @@ async def test_get_contracts_without_abi(self, session: AsyncSession): address=random_address, name="A test contract", chain_id=1 ).create(session) async for contract in Contract.get_contracts_without_abi(session, 0): - self.assertEqual(expected_contract, contract[0]) + self.assertEqual(expected_contract, contract) # Contracts with more retries shouldn't be returned expected_contract.fetch_retries = 1 @@ -194,6 +196,6 @@ async def test_get_proxy_contracts(self, session: AsyncSession): fast_to_checksum_address(proxy_contract.address), random_address ) self.assertEqual( - fast_to_checksum_address(proxy_contract.implementation), + fast_to_checksum_address(cast(bytes, proxy_contract.implementation)), "0x43506849D7C04F9138D1A2050bbF3A0c054402dd", ) diff --git a/app/workers/tasks.py b/app/workers/tasks.py index b8407f5..acc26b4 100644 --- a/app/workers/tasks.py +++ b/app/workers/tasks.py @@ -97,8 +97,8 @@ async def get_missing_contract_metadata_task(): session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES ): get_contract_metadata_task.send( - address=HexBytes(contract[0].address).hex(), - chain_id=contract[0].chain_id, + address=HexBytes(contract.address).hex(), + chain_id=contract.chain_id, skip_attemp_download=True, ) From 8a64c8d288cdd8b862e9280493b0cd31342c44d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Thu, 16 Jan 2025 13:03:40 +0100 Subject: [PATCH 9/9] Add commit suggestions --- app/datasources/db/models.py | 18 ++++++++---------- app/tests/datasources/db/test_model.py | 4 ++-- app/tests/services/test_contract_metadata.py | 2 +- app/workers/tasks.py | 1 + 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 2910b03..b822add 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone -from typing import AsyncGenerator, AsyncIterator, cast +from typing import AsyncIterator, Self, cast -from sqlalchemy import DateTime, Row +from sqlalchemy import DateTime from sqlmodel import ( JSON, Column, @@ -80,7 +80,7 @@ async def get_or_create( :param name: The name to check or create. :param url: The URL to check or create. :return: A tuple containing the AbiSource object and a boolean indicating - whether it was created (True) or already exists (False). + whether it was created `True` or already exists `False`. """ query = select(cls).where(cls.name == name, cls.url == url) results = await session.exec(query) @@ -168,7 +168,7 @@ async def get_or_create_abi( :param relevance: :param source_id: :return: A tuple containing the Abi object and a boolean indicating - whether it was created (True) or already exists (False). + whether it was created `True` or already exists `False`. """ if abi := await cls.get_abi(session, abi_json): return abi, False @@ -218,7 +218,7 @@ def get_contracts_query( Return a statement to get contracts for the provided address and chain_id :param address: - :param chain_ids: list of chain_ids, None for all chains + :param chain_ids: list of chain_ids, `None` for all chains :return: """ query = select(cls).where(cls.address == address) @@ -286,7 +286,7 @@ async def get_abi_by_contract_address( @classmethod async def get_contracts_without_abi( cls, session: AsyncSession, max_retries: int = 0 - ) -> AsyncGenerator["Contract", None]: + ) -> AsyncIterator[Self]: """ Fetches contracts without an ABI and fewer retries than max_retries, streaming results in batches to reduce memory usage for large datasets. @@ -307,16 +307,14 @@ async def get_contracts_without_abi( yield contract @classmethod - async def get_proxy_contracts( - cls, session: AsyncSession - ) -> AsyncGenerator["Contract", None]: + async def get_proxy_contracts(cls, session: AsyncSession) -> AsyncIterator[Self]: """ Return all the contracts with implementation address, so proxy contracts. :param session: :return: """ - query = select(cls).where(cls.implementation != None) # noqa: E711 + query = select(cls).where(cls.implementation.isnot(None)) # type: ignore result = await session.stream(query) async for (contract,) in result: yield contract diff --git a/app/tests/datasources/db/test_model.py b/app/tests/datasources/db/test_model.py index 415bdc4..de66b16 100644 --- a/app/tests/datasources/db/test_model.py +++ b/app/tests/datasources/db/test_model.py @@ -8,8 +8,8 @@ from app.datasources.db.database import database_session from app.datasources.db.models import Abi, AbiSource, Contract, Project from app.services.contract_metadata_service import ( - ClientSource, ContractMetadataService, + ContractSource, EnhancedContractMetadata, ) @@ -184,7 +184,7 @@ async def test_get_proxy_contracts(self, session: AsyncSession): enhanced_contract_metadata = EnhancedContractMetadata( address=random_address, metadata=etherscan_proxy_metadata_mock, - source=ClientSource.ETHERSCAN, + source=ContractSource.ETHERSCAN, chain_id=1, ) result = await ContractMetadataService.process_contract_metadata( diff --git a/app/tests/services/test_contract_metadata.py b/app/tests/services/test_contract_metadata.py index 22cca30..c50e764 100644 --- a/app/tests/services/test_contract_metadata.py +++ b/app/tests/services/test_contract_metadata.py @@ -321,7 +321,7 @@ async def test_process_metadata_should_update_contracts( contract_metadata = EnhancedContractMetadata( address=contract_address, metadata=copy(etherscan_proxy_metadata_mock), # Avoid race condition - source=ClientSource.ETHERSCAN, + source=ContractSource.ETHERSCAN, chain_id=1, ) result = await ContractMetadataService.process_contract_metadata( diff --git a/app/workers/tasks.py b/app/workers/tasks.py index acc26b4..85ded71 100644 --- a/app/workers/tasks.py +++ b/app/workers/tasks.py @@ -55,6 +55,7 @@ async def get_contract_metadata_task( address, chain_id, ) + # TODO Check if contract is MultiSend. In that case, get the transaction and decode it contract_metadata = await contract_metadata_service.get_contract_metadata( fast_to_checksum_address(address), chain_id )