diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index 8027024..0b9ef56 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -283,7 +283,7 @@ async def get_abi_by_contract_address( @classmethod async def get_contracts_without_abi( cls, session: AsyncSession, max_retries: int = 0 - ): + ) -> AsyncGenerator["Contract", None]: """ Fetches contracts without an ABI and fewer retries than max_retries, streaming results in batches to reduce memory usage for large datasets. More information about streaming results can be found here: https://docs.sqlalchemy.org/en/20/core/connections.html#streaming-with-a-dynamically-growing-buffer-using-stream-results @@ -298,13 +298,13 @@ async def get_contracts_without_abi( .where(cls.fetch_retries <= max_retries) ) result = await session.stream(query) - async for contract in result: + async for (contract,) in result: yield contract @classmethod async def get_proxy_contracts( cls, session: AsyncSession - ) -> AsyncGenerator["Contract"]: + ) -> AsyncGenerator["Contract", None]: """ Return all the contracts with implementation address, so proxy contracts. diff --git a/app/tests/datasources/db/test_model.py b/app/tests/datasources/db/test_model.py index 0295993..415bdc4 100644 --- a/app/tests/datasources/db/test_model.py +++ b/app/tests/datasources/db/test_model.py @@ -1,3 +1,5 @@ +from typing import cast + from eth_account import Account from hexbytes import HexBytes from safe_eth.eth.utils import fast_to_checksum_address @@ -157,7 +159,7 @@ async def test_get_contracts_without_abi(self, session: AsyncSession): address=random_address, name="A test contract", chain_id=1 ).create(session) async for contract in Contract.get_contracts_without_abi(session, 0): - self.assertEqual(expected_contract, contract[0]) + self.assertEqual(expected_contract, contract) # Contracts with more retries shouldn't be returned expected_contract.fetch_retries = 1 @@ -194,6 +196,6 @@ async def test_get_proxy_contracts(self, session: AsyncSession): fast_to_checksum_address(proxy_contract.address), random_address ) self.assertEqual( - fast_to_checksum_address(proxy_contract.implementation), + fast_to_checksum_address(cast(bytes, proxy_contract.implementation)), "0x43506849D7C04F9138D1A2050bbF3A0c054402dd", ) diff --git a/app/workers/tasks.py b/app/workers/tasks.py index b8407f5..acc26b4 100644 --- a/app/workers/tasks.py +++ b/app/workers/tasks.py @@ -97,8 +97,8 @@ async def get_missing_contract_metadata_task(): session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES ): get_contract_metadata_task.send( - address=HexBytes(contract[0].address).hex(), - chain_id=contract[0].chain_id, + address=HexBytes(contract.address).hex(), + chain_id=contract.chain_id, skip_attemp_download=True, )