Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add update proxies task #64

Merged
merged 9 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime, timezone
from typing import AsyncIterator, Self, cast

from sqlalchemy import DateTime, Row
from sqlalchemy import DateTime
from sqlmodel import (
JSON,
Column,
Expand Down Expand Up @@ -286,7 +286,7 @@ async def get_abi_by_contract_address(
@classmethod
async def get_contracts_without_abi(
cls, session: AsyncSession, max_retries: int = 0
) -> AsyncIterator[Row[tuple[Self]]]:
) -> AsyncIterator[Self]:
"""
Fetches contracts without an ABI and fewer retries than max_retries,
streaming results in batches to reduce memory usage for large datasets.
Expand All @@ -303,5 +303,18 @@ async def get_contracts_without_abi(
.where(cls.fetch_retries <= max_retries)
)
result = await session.stream(query)
async for contract in result:
async for (contract,) in result:
yield contract

@classmethod
async def get_proxy_contracts(cls, session: AsyncSession) -> AsyncIterator[Self]:
"""
Return all the contracts with implementation address, so proxy contracts.

:param session:
:return:
"""
query = select(cls).where(cls.implementation.isnot(None)) # type: ignore
result = await session.stream(query)
async for (contract,) in result:
yield contract
38 changes: 37 additions & 1 deletion app/tests/datasources/db/test_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import cast

from eth_account import Account
from hexbytes import HexBytes
from safe_eth.eth.utils import fast_to_checksum_address
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
from app.datasources.db.models import Abi, AbiSource, Contract, Project
from app.services.contract_metadata_service import (
ContractMetadataService,
ContractSource,
EnhancedContractMetadata,
)

from ...mocks.contract_metadata_mocks import etherscan_proxy_metadata_mock
from .db_async_conn import DbAsyncConn


Expand Down Expand Up @@ -150,7 +159,7 @@ async def test_get_contracts_without_abi(self, session: AsyncSession):
address=random_address, name="A test contract", chain_id=1
).create(session)
async for contract in Contract.get_contracts_without_abi(session, 0):
self.assertEqual(expected_contract, contract[0])
self.assertEqual(expected_contract, contract)

# Contracts with more retries shouldn't be returned
expected_contract.fetch_retries = 1
Expand All @@ -163,3 +172,30 @@ async def test_get_contracts_without_abi(self, session: AsyncSession):
await expected_contract.update(session)
async for contract in Contract.get_contracts_without_abi(session, 10):
self.fail("Expected no contracts, but found one.")

@database_session
async def test_get_proxy_contracts(self, session: AsyncSession):
# Test empty case
async for proxy_contract in Contract.get_proxy_contracts(session):
self.fail("Expected no proxies, but found one.")

random_address = Account.create().address
await AbiSource(name="Etherscan", url="").create(session)
enhanced_contract_metadata = EnhancedContractMetadata(
address=random_address,
metadata=etherscan_proxy_metadata_mock,
source=ContractSource.ETHERSCAN,
chain_id=1,
)
result = await ContractMetadataService.process_contract_metadata(
session, enhanced_contract_metadata
)
self.assertTrue(result)
async for proxy_contract in Contract.get_proxy_contracts(session):
self.assertEqual(
fast_to_checksum_address(proxy_contract.address), random_address
)
self.assertEqual(
fast_to_checksum_address(cast(bytes, proxy_contract.implementation)),
"0x43506849D7C04F9138D1A2050bbF3A0c054402dd",
)
45 changes: 45 additions & 0 deletions app/tests/services/test_contract_metadata.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import copy
from unittest import mock
from unittest.mock import MagicMock

Expand Down Expand Up @@ -309,3 +310,47 @@ def test_get_proxy_implementation_address(self):
self.assertIsNone(
ContractMetadataService.get_proxy_implementation_address(contract_data)
)

@database_session
async def test_process_metadata_should_update_contracts(
self, session: AsyncSession
):
contract_address = Account.create().address
chain_id = 1
await AbiSource(name="Etherscan", url="").create(session)
contract_metadata = EnhancedContractMetadata(
address=contract_address,
metadata=copy(etherscan_proxy_metadata_mock), # Avoid race condition
source=ContractSource.ETHERSCAN,
chain_id=1,
)
result = await ContractMetadataService.process_contract_metadata(
session, contract_metadata
)
self.assertTrue(result)
contract = await Contract.get_contract(
session, address=HexBytes(contract_address), chain_id=chain_id
)
self.assertEqual(
fast_to_checksum_address(contract.address), contract_metadata.address
)
self.assertEqual(
fast_to_checksum_address(contract.implementation),
contract_metadata.metadata.implementation, # type: ignore
)
# Process metadata should update the db contract information
implementation_address = fast_to_checksum_address(Account.create().address)
contract_metadata.metadata.implementation = implementation_address # type: ignore
result = await ContractMetadataService.process_contract_metadata(
session, contract_metadata
)
self.assertTrue(result)
contract = await Contract.get_contract(
session, address=HexBytes(contract_address), chain_id=chain_id
)
self.assertEqual(
fast_to_checksum_address(contract.address), contract_metadata.address
)
self.assertEqual(
fast_to_checksum_address(contract.implementation), implementation_address
)
17 changes: 15 additions & 2 deletions app/tests/workers/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from hexbytes import HexBytes
from safe_eth.eth import EthereumNetwork
from safe_eth.eth.clients import AsyncEtherscanClientV2
from safe_eth.eth.utils import fast_to_checksum_address
from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.database import database_session
Expand Down Expand Up @@ -66,18 +67,25 @@ class TestAsyncTasks(DbAsyncConn):

async def asyncSetUp(self):
await super().asyncSetUp()
self.worker = Worker(redis_broker)
self.worker = Worker(redis_broker, worker_threads=1)
self.worker.start()

async def asyncTearDown(self):
await super().asyncTearDown()
self.worker.stop()
redis = get_redis()
redis.flushall()

def _wait_tasks_execution(self):
# Ensure that all the messages on redis were consumed
redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1)
while len(redis_tasks) > 0:
redis_tasks = self.worker.broker.client.lrange("dramatiq:default", 0, -1)

# Wait for all the messages on the given queue to be processed.
# This method is only meant to be used in tests
self.worker.broker.join("default")

@mock.patch.object(ContractMetadataService, "enabled_clients")
@mock.patch.object(
AsyncEtherscanClientV2, "async_get_contract_metadata", autospec=True
Expand Down Expand Up @@ -139,6 +147,7 @@ async def test_get_contract_metadata_task(
async def test_get_contract_metadata_task_with_proxy(
self, etherscan_get_contract_metadata_mock: MagicMock, session: AsyncSession
):
await AbiSource(name="Etherscan", url="").create(session)
etherscan_get_contract_metadata_mock.side_effect = [
etherscan_proxy_metadata_mock,
etherscan_metadata_mock,
Expand All @@ -155,10 +164,14 @@ async def test_get_contract_metadata_task_with_proxy(
session, HexBytes(contract_address), chain_id
)
self.assertIsNotNone(contract)

self.assertEqual(
fast_to_checksum_address(contract.implementation),
proxy_implementation_address,
)
proxy_implementation = await Contract.get_contract(
session, HexBytes(proxy_implementation_address), chain_id
)
self.assertIsNotNone(proxy_implementation)
self.assertEqual(contract.implementation, proxy_implementation.address)

self.assertEqual(etherscan_get_contract_metadata_mock.call_count, 2)
117 changes: 64 additions & 53 deletions app/workers/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlmodel.ext.asyncio.session import AsyncSession

from app.config import settings
from app.datasources.db.database import database_session
from app.datasources.db.database import get_engine
from app.datasources.db.models import Contract
from app.services.contract_metadata_service import get_contract_metadata_service

Expand Down Expand Up @@ -38,67 +38,78 @@ async def test_task(message: str) -> None:


@dramatiq.actor
@database_session
async def get_contract_metadata_task(
session: AsyncSession,
address: str,
chain_id: int,
skip_attemp_download: bool = False,
) -> None:
contract_metadata_service = get_contract_metadata_service()
# Just try the first time, following retries should be scheduled
if skip_attemp_download or await contract_metadata_service.should_attempt_download(
session, address, chain_id, 0
):
logger.info(
"Downloading contract metadata for contract=%s and chain=%s",
address,
chain_id,
)
# TODO Check if contract is MultiSend. In that case, get the transaction and decode it
contract_metadata = await contract_metadata_service.get_contract_metadata(
fast_to_checksum_address(address), chain_id
)
result = await contract_metadata_service.process_contract_metadata(
session, contract_metadata
)
if result:
logger.info(
"Success download contract metadata for contract=%s and chain=%s",
address,
chain_id,
address: str, chain_id: int, skip_attemp_download: bool = False
):
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
contract_metadata_service = get_contract_metadata_service()
# Just try the first time, following retries should be scheduled
if (
skip_attemp_download
or await contract_metadata_service.should_attempt_download(
session, address, chain_id, 0
)
else:
):
logger.info(
"Failed to download contract metadata for contract=%s and chain=%s",
"Downloading contract metadata for contract=%s and chain=%s",
address,
chain_id,
)
# TODO Check if contract is MultiSend. In that case, get the transaction and decode it
contract_metadata = await contract_metadata_service.get_contract_metadata(
fast_to_checksum_address(address), chain_id
)
result = await contract_metadata_service.process_contract_metadata(
session, contract_metadata
)
if result:
logger.info(
"Success download contract metadata for contract=%s and chain=%s",
address,
chain_id,
)
else:
logger.info(
"Failed to download contract metadata for contract=%s and chain=%s",
address,
chain_id,
)

if proxy_implementation_address := contract_metadata_service.get_proxy_implementation_address(
contract_metadata
):
logger.info(
"Adding task to download proxy implementation metadata from address=%s for contract=%s and chain=%s",
proxy_implementation_address,
address,
chain_id,
)
get_contract_metadata_task.send(
address=proxy_implementation_address, chain_id=chain_id
)
else:
logger.debug("Skipping contract=%s and chain=%s", address, chain_id)

if proxy_implementation_address := contract_metadata_service.get_proxy_implementation_address(
contract_metadata

@dramatiq.actor(periodic=cron("0 0 * * *")) # Every midnight
async def get_missing_contract_metadata_task():
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
async for contract in Contract.get_contracts_without_abi(
moisses89 marked this conversation as resolved.
Show resolved Hide resolved
session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES
):
logger.info(
"Adding task to download proxy implementation metadata from address=%s for contract=%s and chain=%s",
proxy_implementation_address,
address,
chain_id,
)
get_contract_metadata_task.send(
address=proxy_implementation_address, chain_id=chain_id
address=HexBytes(contract.address).hex(),
chain_id=contract.chain_id,
skip_attemp_download=True,
)
else:
logger.debug("Skipping contract=%s and chain=%s", address, chain_id)


@dramatiq.actor(periodic=cron("0 0 * * *")) # Every midnight
@database_session
async def get_missing_contract_metadata_task(session: AsyncSession) -> None:
async for contract in Contract.get_contracts_without_abi(
session, settings.CONTRACT_MAX_DOWNLOAD_RETRIES
):
get_contract_metadata_task.send(
address=HexBytes(contract[0].address).hex(),
chain_id=contract[0].chain_id,
skip_attemp_download=True,
)
@dramatiq.actor(periodic=cron("0 5 * * *")) # Every day at 5 am
async def update_proxies_task():
async with AsyncSession(get_engine(), expire_on_commit=False) as session:
async for proxy_contract in Contract.get_proxy_contracts(session):
get_contract_metadata_task.send(
address=HexBytes(proxy_contract.address).hex(),
chain_id=proxy_contract.chain_id,
skip_attemp_download=True,
)
Loading