diff --git a/safe_transaction_service/account_abstraction/services/aa_processor_service.py b/safe_transaction_service/account_abstraction/services/aa_processor_service.py index d1ad072d8..030e5e818 100644 --- a/safe_transaction_service/account_abstraction/services/aa_processor_service.py +++ b/safe_transaction_service/account_abstraction/services/aa_processor_service.py @@ -27,6 +27,7 @@ from ..models import SafeOperationConfirmation as SafeOperationConfirmationModel from ..models import UserOperation as UserOperationModel from ..models import UserOperationReceipt as UserOperationReceiptModel +from ..utils import get_bundler_client logger = logging.getLogger(__name__) @@ -42,11 +43,7 @@ class ExecutionFromSafeModuleNotDetected(AaProcessorServiceException): @cache def get_aa_processor_service() -> "AaProcessorService": ethereum_client = EthereumClientProvider() - bundler_client = ( - BundlerClient(settings.ETHEREUM_4337_BUNDLER_URL) - if settings.ETHEREUM_4337_BUNDLER_URL - else None - ) + bundler_client = get_bundler_client() if not bundler_client: logger.warning("Ethereum 4337 bundler client was not configured") supported_entry_points = settings.ETHEREUM_4337_SUPPORTED_ENTRY_POINTS diff --git a/safe_transaction_service/account_abstraction/utils.py b/safe_transaction_service/account_abstraction/utils.py new file mode 100644 index 000000000..ce7d7adcd --- /dev/null +++ b/safe_transaction_service/account_abstraction/utils.py @@ -0,0 +1,20 @@ +import logging +from functools import cache +from typing import Optional + +from django.conf import settings + +from gnosis.eth.account_abstraction import BundlerClient + +logger = logging.getLogger(__name__) + + +@cache +def get_bundler_client() -> Optional[BundlerClient]: + """ + :return: Initialized `ERC4337 RPC Bundler Client` if configured, `None` otherwise + """ + if settings.ETHEREUM_4337_BUNDLER_URL: + return BundlerClient(settings.ETHEREUM_4337_BUNDLER_URL) + logger.warning("ETHEREUM_4337_BUNDLER_URL not set, cannot configure bundler client") + return None diff --git a/safe_transaction_service/history/management/commands/check_chainid_matches.py b/safe_transaction_service/history/management/commands/check_chainid_matches.py index 224bb32cf..0d3106023 100644 --- a/safe_transaction_service/history/management/commands/check_chainid_matches.py +++ b/safe_transaction_service/history/management/commands/check_chainid_matches.py @@ -1,5 +1,6 @@ from django.core.management.base import BaseCommand, CommandError +from safe_transaction_service.account_abstraction.utils import get_bundler_client from safe_transaction_service.utils.ethereum import get_chain_id from ...models import Chain @@ -16,11 +17,23 @@ def handle(self, *args, **options): except Chain.DoesNotExist: chain = Chain.objects.create(chain_id=chain_id) - if chain.chain_id == chain_id: - self.stdout.write( - self.style.SUCCESS(f"EthereumRPC chainId {chain_id} looks good") - ) - else: + if chain_id != chain.chain_id: raise CommandError( f"EthereumRPC chainId {chain_id} does not match previously used chainId {chain.chain_id}" ) + self.stdout.write( + self.style.SUCCESS(f"EthereumRPC chainId {chain_id} looks good") + ) + + if bundler_client := get_bundler_client(): + bundler_chain_id = bundler_client.get_chain_id() + if bundler_chain_id != chain.chain_id: + raise CommandError( + f"ERC4337 BundlerClient chainId {bundler_chain_id} does not match " + f"EthereumClient chainId {chain.chain_id}" + ) + self.stdout.write( + self.style.SUCCESS( + f"ERC4337 BundlerClient chainId {chain_id} looks good" + ) + ) diff --git a/safe_transaction_service/history/tests/test_commands.py b/safe_transaction_service/history/tests/test_commands.py index dbe8b5af1..3931484c5 100644 --- a/safe_transaction_service/history/tests/test_commands.py +++ b/safe_transaction_service/history/tests/test_commands.py @@ -10,6 +10,7 @@ from django_celery_beat.models import PeriodicTask from eth_account import Account +from gnosis.eth.account_abstraction import BundlerClient from gnosis.eth.ethereum_client import EthereumClient, EthereumNetwork from gnosis.safe.tests.safe_test_case import SafeTestCaseMixin @@ -421,10 +422,16 @@ def test_export_multisig_tx_data(self): call_command(command, arguments, stdout=buf) self.assertIn("Start exporting of 1", buf.getvalue()) + @mock.patch( + "safe_transaction_service.history.management.commands.check_chainid_matches.get_bundler_client", + return_value=None, + ) @mock.patch( "safe_transaction_service.history.management.commands.check_chainid_matches.get_chain_id" ) - def test_check_chainid_matches(self, get_chain_id_mock: MagicMock): + def test_check_chainid_matches( + self, get_chain_id_mock: MagicMock, get_bundler_client_mock: MagicMock + ): command = "check_chainid_matches" # Create ChainId model @@ -447,6 +454,36 @@ def test_check_chainid_matches(self, get_chain_id_mock: MagicMock): call_command(command, stdout=buf) self.assertIn("EthereumRPC chainId 1 looks good", buf.getvalue()) + @mock.patch.object(BundlerClient, "get_chain_id", return_value=1234) + @mock.patch( + "safe_transaction_service.history.management.commands.check_chainid_matches.get_bundler_client", + return_value=BundlerClient(""), + ) + @mock.patch( + "safe_transaction_service.history.management.commands.check_chainid_matches.get_chain_id", + return_value=EthereumNetwork.MAINNET.value, + ) + def test_check_chainid_bundler_matches( + self, + get_chain_id_mock: MagicMock, + get_bundler_client_mock: MagicMock, + bundler_get_chain_id_mock: MagicMock, + ): + command = "check_chainid_matches" + with self.assertRaisesMessage( + CommandError, + "ERC4337 BundlerClient chainId 1234 does not match EthereumClient chainId 1", + ): + call_command(command) + + bundler_get_chain_id_mock.return_value = EthereumNetwork.MAINNET.value + buf = StringIO() + call_command(command, stdout=buf) + self.assertEqual( + "EthereumRPC chainId 1 looks good\nERC4337 BundlerClient chainId 1 looks good\n", + buf.getvalue(), + ) + @mock.patch( "safe_transaction_service.history.management.commands.check_index_problems.settings.ETH_L2_NETWORK", return_value=True,