Skip to content

Commit

Permalink
Add small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Uxio0 committed Jan 16, 2025
1 parent a0aa0fc commit 398a226
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 52 deletions.
19 changes: 11 additions & 8 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import datetime, timezone
from typing import AsyncIterator, cast
from typing import AsyncIterator, Self, cast

from sqlalchemy import Row
from sqlmodel import (
JSON,
Column,
Expand Down Expand Up @@ -77,7 +78,7 @@ async def get_or_create(
:param name: The name to check or create.
:param url: The URL to check or create.
:return: A tuple containing the AbiSource object and a boolean indicating
whether it was created (True) or already exists (False).
whether it was created `True` or already exists `False`.
"""
query = select(cls).where(cls.name == name, cls.url == url)
results = await session.exec(query)
Expand Down Expand Up @@ -165,7 +166,7 @@ async def get_or_create_abi(
:param relevance:
:param source_id:
:return: A tuple containing the Abi object and a boolean indicating
whether it was created (True) or already exists (False).
whether it was created `True` or already exists `False`.
"""
if abi := await cls.get_abi(session, abi_json):
return abi, False
Expand Down Expand Up @@ -215,7 +216,7 @@ def get_contracts_query(
Return a statement to get contracts for the provided address and chain_id
:param address:
:param chain_ids: list of chain_ids, None for all chains
:param chain_ids: list of chain_ids, `None` for all chains
:return:
"""
query = select(cls).where(cls.address == address)
Expand Down Expand Up @@ -252,7 +253,7 @@ async def get_or_create(
:param chain_id:
:param kwargs:
:return: A tuple containing the Contract object and a boolean indicating
whether it was created (True) or already exists (False).
whether it was created `True` or already exists `False`.
"""
if contract := await cls.get_contract(session, address, chain_id):
return contract, False
Expand Down Expand Up @@ -283,10 +284,12 @@ async def get_abi_by_contract_address(
@classmethod
async def get_contracts_without_abi(
cls, session: AsyncSession, max_retries: int = 0
):
) -> AsyncIterator[Row[tuple[Self]]]:
"""
Fetches contracts without an ABI and fewer retries than max_retries, streaming results in batches to reduce memory usage for large datasets.
More information about streaming results can be found here: https://docs.sqlalchemy.org/en/20/core/connections.html#streaming-with-a-dynamically-growing-buffer-using-stream-results
Fetches contracts without an ABI and fewer retries than max_retries,
streaming results in batches to reduce memory usage for large datasets.
More information about streaming results can be found here:
https://docs.sqlalchemy.org/en/20/core/connections.html#streaming-with-a-dynamically-growing-buffer-using-stream-results
:param session:
:param max_retries:
Expand Down
58 changes: 23 additions & 35 deletions app/services/contract_metadata_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
logger = logging.getLogger(__name__)


class ClientSource(enum.Enum):
class ContractSource(enum.Enum):
ETHERSCAN = "Etherscan"
SOURCIFY = "Sourcify"
BLOCKSCOUT = "Blockscout"
Expand All @@ -37,7 +37,7 @@ class ClientSource(enum.Enum):
class EnhancedContractMetadata:
address: ChecksumAddress
metadata: ContractMetadata | None
source: ClientSource | None
source: ContractSource | None
chain_id: int


Expand Down Expand Up @@ -93,34 +93,28 @@ def enabled_clients(
self, chain_id: int
) -> list[AsyncEtherscanClientV2 | AsyncBlockscoutClient | AsyncSourcifyClient]:
"""
Return a list of available chains for the provided chain_id.
First etherscan, second sourcify, third blockscout.
:param chain_id:
:return:
:return: List of available clients for the provided `chain_id`.
First Etherscan, second Sourcify, third Blockscout.
"""
enabled_clients: list[
AsyncEtherscanClientV2 | AsyncBlockscoutClient | AsyncSourcifyClient
] = []
if etherscan_client := self._get_etherscan_client(chain_id):
enabled_clients.append(etherscan_client)
if sourcify_client := self._get_sourcify_client(chain_id):
enabled_clients.append(sourcify_client)
if blockscout_client := self._get_blockscout_client(chain_id):
enabled_clients.append(blockscout_client)
return enabled_clients
clients = (
self._get_etherscan_client(chain_id),
self._get_sourcify_client(chain_id),
self._get_blockscout_client(chain_id),
)
return [client for client in clients if client]

@staticmethod
@cache
def get_client_enum(
client: AsyncEtherscanClientV2 | AsyncSourcifyClient | AsyncBlockscoutClient,
) -> ClientSource:
) -> ContractSource:
if isinstance(client, AsyncEtherscanClientV2):
return ClientSource.ETHERSCAN
elif isinstance(client, AsyncSourcifyClient):
return ClientSource.SOURCIFY
elif isinstance(client, AsyncBlockscoutClient):
return ClientSource.BLOCKSCOUT
return ContractSource.ETHERSCAN
if isinstance(client, AsyncSourcifyClient):
return ContractSource.SOURCIFY
if isinstance(client, AsyncBlockscoutClient):
return ContractSource.BLOCKSCOUT

async def get_contract_metadata(
self, contract_address: ChecksumAddress, chain_id: int
Expand All @@ -138,7 +132,6 @@ async def get_contract_metadata(
contract_address
)
if contract_metadata:

return EnhancedContractMetadata(
address=contract_address,
metadata=contract_metadata,
Expand Down Expand Up @@ -174,7 +167,7 @@ async def process_contract_metadata(
address=HexBytes(contract_metadata.address),
chain_id=contract_metadata.chain_id,
)
with_metadata: bool

if contract_metadata.metadata:
if contract_metadata.source:
source = await AbiSource.get_abi_source(
Expand All @@ -196,13 +189,10 @@ async def process_contract_metadata(
contract.implementation = HexBytes(
contract_metadata.metadata.implementation
)
with_metadata = True
else:
with_metadata = False

contract.fetch_retries += 1
await contract.update(session=session)
return with_metadata
return bool(contract_metadata.metadata)

@staticmethod
def get_proxy_implementation_address(
Expand All @@ -220,23 +210,21 @@ async def should_attempt_download(
max_retries: int,
) -> bool:
"""
Return True if fetch retries is less than the number of retries and there is not ABI, False otherwise.
False is being cached to avoid query the database in the future for the same number of retries.
:param session:
:param contract_address:
:param chain_id:
:param max_retries:
:return:
:return: `True` if `fetch retries` are less than the number of `max_retries` and there is not ABI, `False` otherwise.
`False` is being cached to avoid query the database in the future for the same number of retries.
"""
redis = get_redis()
cache_key = (
f"should_attempt_download:{contract_address}:{chain_id}:{max_retries}"
)
# Try from cache first
cached_retries = cast(str, redis.get(cache_key))
if cached_retries:
return bool(int(cached_retries))
cached_retries: bytes = cast(bytes, redis.get(cache_key))
if cached_retries is not None:
return bool(int(cached_retries.decode()))
else:
contract = await Contract.get_contract(
session, address=HexBytes(contract_address), chain_id=chain_id
Expand Down
5 changes: 2 additions & 3 deletions app/services/events.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import logging
from typing import Dict

from safe_eth.eth.utils import fast_is_checksum_address

Expand All @@ -27,13 +26,13 @@ def process_event(self, message: str) -> None:
logging.error(f"Unsupported message. Cannot parse as JSON: {message}")

@staticmethod
def _is_processable_event(tx_service_event: Dict) -> bool:
def _is_processable_event(tx_service_event: dict) -> bool:
"""
Validates if the event has the required fields 'chainId', 'type', and 'to' as strings,
and if the event type and address meet the expected criteria.
:param tx_service_event: The event object to validate.
:return: True if the event is valid, False otherwise.
:return: `True` if the event is valid, `False` otherwise.
"""
chain_id = tx_service_event.get("chainId")
event_type = tx_service_event.get("type")
Expand Down
12 changes: 6 additions & 6 deletions app/tests/services/test_contract_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from app.datasources.db.database import database_session
from app.datasources.db.models import Abi, AbiSource, Contract
from app.services.contract_metadata_service import (
ClientSource,
ContractMetadataService,
ContractSource,
EnhancedContractMetadata,
)

Expand Down Expand Up @@ -165,7 +165,7 @@ async def test_process_contract_metadata(self, session: AsyncSession):
contract_data = EnhancedContractMetadata(
address=random_address,
metadata=etherscan_metadata_mock,
source=ClientSource.ETHERSCAN,
source=ContractSource.ETHERSCAN,
chain_id=1,
)
await ContractMetadataService.process_contract_metadata(session, contract_data)
Expand All @@ -184,7 +184,7 @@ async def test_process_contract_metadata(self, session: AsyncSession):
proxy_contract_data = EnhancedContractMetadata(
address=random_address,
metadata=etherscan_proxy_metadata_mock,
source=ClientSource.ETHERSCAN,
source=ContractSource.ETHERSCAN,
chain_id=1,
)
await ContractMetadataService.process_contract_metadata(
Expand Down Expand Up @@ -223,7 +223,7 @@ async def test_process_contract_metadata(self, session: AsyncSession):

await AbiSource.get_or_create(session, "Blockscout", "")
contract_data.metadata = blockscout_metadata_mock
contract_data.source = ClientSource.BLOCKSCOUT
contract_data.source = ContractSource.BLOCKSCOUT
await ContractMetadataService.process_contract_metadata(session, contract_data)
new_contract = await Contract.get_contract(
session, address=HexBytes(contract_data.address), chain_id=1
Expand Down Expand Up @@ -288,7 +288,7 @@ def test_get_proxy_implementation_address(self):
proxy_contract_data = EnhancedContractMetadata(
address=random_address,
metadata=etherscan_proxy_metadata_mock,
source=ClientSource.ETHERSCAN,
source=ContractSource.ETHERSCAN,
chain_id=1,
)
proxy_implementation_address = (
Expand All @@ -303,7 +303,7 @@ def test_get_proxy_implementation_address(self):
contract_data = EnhancedContractMetadata(
address=random_address,
metadata=etherscan_metadata_mock,
source=ClientSource.ETHERSCAN,
source=ContractSource.ETHERSCAN,
chain_id=1,
)
self.assertIsNone(
Expand Down
1 change: 1 addition & 0 deletions app/workers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def get_contract_metadata_task(
address,
chain_id,
)
# TODO Check if contract is MultiSend. In that case, get the transaction and decode it
contract_metadata = await contract_metadata_service.get_contract_metadata(
fast_to_checksum_address(address), chain_id
)
Expand Down

0 comments on commit 398a226

Please sign in to comment.