From aee01f7b995191f6119715cd2bf74e8daf7e1b68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ux=C3=ADo?= Date: Mon, 20 Jan 2025 12:13:10 +0100 Subject: [PATCH] Implement multichain support for decoding (#70) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Closes #38 --------- Co-authored-by: Uxio Fuentefria <6909403+Uxio0@users.noreply.github.com> Co-authored-by: Moisés <7888669+moisses89@users.noreply.github.com> --- app/datasources/db/models.py | 15 ++- app/services/data_decoder.py | 109 +++++++++++++----- app/tests/datasources/db/test_model.py | 21 +++- app/tests/services/test_data_decoder.py | 144 ++++++++++++++++++++++++ 4 files changed, 255 insertions(+), 34 deletions(-) diff --git a/app/datasources/db/models.py b/app/datasources/db/models.py index b822add..49afa5f 100644 --- a/app/datasources/db/models.py +++ b/app/datasources/db/models.py @@ -270,15 +270,24 @@ async def get_or_create( @classmethod async def get_abi_by_contract_address( - cls, session: AsyncSession, address: bytes + cls, session: AsyncSession, address: bytes, chain_id: int | None ) -> ABI | None: - # TODO Add chain_id filter to support multichain - results = await session.exec( + """ + :return: Json ABI given the contract `address` and `chain_id`. If `chain_id` is not given, + sort the ABIs by `chain_id` and return the first one. + """ + query = ( select(Abi.abi_json) .join(cls) .where(cls.address == address) .where(cls.abi_id == Abi.id) ) + if chain_id is not None: + query = query.where(cls.chain_id == chain_id) + else: + query = query.order_by(col(cls.chain_id)) + + results = await session.exec(query) if result := results.first(): return cast(ABI, result) return None diff --git a/app/services/data_decoder.py b/app/services/data_decoder.py index 06bd761..82faa70 100644 --- a/app/services/data_decoder.py +++ b/app/services/data_decoder.py @@ -67,6 +67,10 @@ class DataDecoderService: dummy_w3 = Web3() session: AsyncSession | None + fn_selectors_with_abis: dict[bytes, ABIFunction] + multisend_abis: list[ABI] + multisend_fn_selectors_with_abis: dict[bytes, ABIFunction] + async def init(self, session: AsyncSession) -> None: """ Initialize the data decoder service, loading the ABIs from the database and storing the 4byte selectors @@ -129,27 +133,38 @@ async def get_multisend_abis(self) -> AsyncIterator[ABI]: @alru_cache(maxsize=2048) @database_session async def get_contract_abi( - self, address: Address, session: AsyncSession | None = None + self, + address: Address, + chain_id: int | None, + session: AsyncSession | None = None, ) -> ABI | None: """ Retrieves the ABI for the contract at the given address. :param address: Contract address + :param chain_id: Chain id for the contract :param session: Database session, provided by the decorator :return: List of ABI data if found, `None` otherwise. """ assert session is not None - return await Contract.get_abi_by_contract_address(session, HexBytes(address)) + return await Contract.get_abi_by_contract_address( + session, HexBytes(address), chain_id + ) @alru_cache(maxsize=2048) async def get_contract_fallback_function( - self, address: Address + self, address: Address, chain_id: int | None ) -> ABIFunction | None: """ :param address: Contract address - :return: Fallback ABIFunction if found, `None` otherwise. - """ - abi = await self.get_contract_abi(address) + :param chain_id: Chain for the contract + :return: Fallback `ABIFunction` if found, `None` otherwise. + If contract is not found for the chain, return the first one that matches in other chain. + """ + abi = await self.get_contract_abi(address, chain_id) + if not abi and chain_id is not None: + # Try to find an ABI in other network + abi = await self.get_contract_abi(address, None) if abi: return next( ( @@ -163,25 +178,29 @@ async def get_contract_fallback_function( @alru_cache(maxsize=2048) async def get_contract_abi_selectors_with_functions( - self, address: Address + self, address: Address, chain_id: int | None ) -> dict[bytes, ABIFunction] | None: """ :param address: Contract address - :return: Dictionary of function selects with ABIFunction if found, `None` otherwise - """ - abi = await self.get_contract_abi(address) + :param chain_id: Chain for the contract + :return: Dictionary of function selects with `ABIFunction` if found, `None` otherwise + If contract is not found for the chain, return the first one that matches in other chain. + """ + abi = await self.get_contract_abi(address, chain_id) + if not abi and chain_id is not None: + # Try to find an ABI in other network + abi = await self.get_contract_abi(address, None) if abi: - # TODO We should return that there's a fullMatch for this `data` and `address`, so we are sure - # we are decoding the `data` correctly return self._generate_selectors_with_abis_from_abi(abi) return None async def get_abi_function( - self, data: bytes, address: Address | None = None + self, data: bytes, address: Address | None = None, chain_id: int | None = None ) -> ABIFunction | None: """ :param data: transaction data :param address: contract address in case of ABI colliding + :param chain_id: Chain for the contract :return: Abi function for data if it can be decoded, `None` if not found """ selector = data[:4] @@ -190,7 +209,9 @@ async def get_abi_function( # Try to use specific ABI if address provided if address: contract_selectors_with_abis = ( - await self.get_contract_abi_selectors_with_functions(address) + await self.get_contract_abi_selectors_with_functions( + address, chain_id + ) ) if ( contract_selectors_with_abis @@ -202,7 +223,7 @@ async def get_abi_function( return self.fn_selectors_with_abis[selector] # Check if the contract has a fallback call and return a minimal ABIFunction for fallback call elif address: - return await self.get_contract_fallback_function(address) + return await self.get_contract_fallback_function(address, chain_id) return None def _parse_decoded_arguments(self, value_decoded: Any) -> Any: @@ -220,13 +241,17 @@ def _parse_decoded_arguments(self, value_decoded: Any) -> Any: return value_decoded async def _decode_data( - self, data: bytes | str, address: Address | None = None + self, + data: bytes | str, + address: Address | None = None, + chain_id: int | None = None, ) -> tuple[str, list[tuple[str, str, Any]]]: """ Decode tx data :param data: Tx data as `hex string` or `bytes` :param address: contract address in case of ABI colliding + :param chain_id: Chain for the contract :return: Tuple with the `function name` and a List of sorted tuples with the `name` of the argument, `type` and `value` :raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function @@ -238,7 +263,7 @@ async def _decode_data( data = HexBytes(data) params = data[4:] - fn_abi = await self.get_abi_function(data, address) + fn_abi = await self.get_abi_function(data, address, chain_id) if not fn_abi: raise CannotDecode(data.hex()) try: @@ -254,12 +279,13 @@ async def _decode_data( return fn_abi["name"], list(zip(names, types, values)) async def decode_multisend_data( - self, data: bytes | str + self, data: bytes | str, chain_id: int | None = None ) -> list[MultisendDecoded] | None: """ Decodes Multisend raw data to Multisend dictionary :param data: + :param chain_id: :return: """ try: @@ -271,7 +297,9 @@ async def decode_multisend_data( value=str(multisend_tx.value), data=HexStr(multisend_tx.data.hex()) if multisend_tx.data else None, data_decoded=await self.get_data_decoded( - multisend_tx.data, address=cast(Address, multisend_tx.to) + multisend_tx.data, + address=cast(Address, multisend_tx.to), + chain_id=chain_id, ), ) for multisend_tx in multisend_txs @@ -285,27 +313,34 @@ async def decode_multisend_data( return None async def get_data_decoded( - self, data: bytes | str, address: Address | None = None + self, + data: bytes | str, + address: Address | None = None, + chain_id: int | None = None, ) -> DataDecoded | None: """ Return data prepared for serializing :param data: :param address: contract address in case of ABI colliding + :param chain_id: chain for contract :return: """ if not data: return None try: fn_name, parameters = await self.decode_transaction_with_types( - data, address=address + data, address=address, chain_id=chain_id ) return {"method": fn_name, "parameters": parameters} except DataDecoderException: return None async def decode_parameters_data( - self, data: bytes, parameters: list[ParameterDecoded] + self, + data: bytes, + parameters: list[ParameterDecoded], + chain_id: int | None = None, ) -> list[ParameterDecoded]: """ Decode inner data for function parameters for: @@ -319,7 +354,9 @@ async def decode_parameters_data( fn_selector = data[:4] if fn_selector in self.multisend_fn_selectors_with_abis: # If MultiSend, decode the transactions - parameters[0]["value_decoded"] = await self.decode_multisend_data(data) + parameters[0]["value_decoded"] = await self.decode_multisend_data( + data, chain_id=chain_id + ) elif ( fn_selector == self.EXEC_TRANSACTION_SELECTOR @@ -331,49 +368,63 @@ async def decode_parameters_data( # selector is `0x6a761202` and parameters[2] is data try: parameters[2]["value_decoded"] = await self.get_data_decoded( - data, address=parameters[0]["value"] + data, address=parameters[0]["value"], chain_id=chain_id ) except DataDecoderException: logger.warning("Cannot decode `execTransaction`", exc_info=True) return parameters async def decode_transaction_with_types( - self, data: bytes | str, address: Address | None = None + self, + data: bytes | str, + address: Address | None = None, + chain_id: int | None = None, ) -> tuple[str, list[ParameterDecoded]]: """ Decode tx data and return a list of dictionaries :param data: Tx data as `hex string` or `bytes` :param address: contract address in case of ABI colliding + :param chain_id: chain for the contract :return: Tuple with the `function name` and a list of dictionaries [{'name': str, 'type': str, 'value': `depending on type`}...] :raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function :raises: UnexpectedProblemDecoding if there's an unexpected problem decoding (it shouldn't happen) """ data = HexBytes(data) - fn_name, raw_parameters = await self._decode_data(data, address=address) + fn_name, raw_parameters = await self._decode_data( + data, address=address, chain_id=chain_id + ) # Parameters are returned as tuple, convert it to a dictionary parameters = [ ParameterDecoded(name=name, type=argument_type, value=value) for name, argument_type, value in raw_parameters ] - nested_parameters = await self.decode_parameters_data(data, parameters) + nested_parameters = await self.decode_parameters_data( + data, parameters, chain_id=chain_id + ) return fn_name, nested_parameters async def decode_transaction( - self, data: bytes | str, address: Address | None = None + self, + data: bytes | str, + address: Address | None = None, + chain_id: int | None = None, ) -> tuple[str, dict[str, Any]]: """ Decode tx data and return all the parameters in the same dictionary :param data: Tx data as `hex string` or `bytes` :param address: contract address in case of ABI colliding + :param chain_id: chain for the contract :return: Tuple with the `function name` and a dictionary with the arguments of the function :raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function :raises: UnexpectedProblemDecoding if there's an unexpected problem decoding (it shouldn't happen) """ fn_name, decoded_transactions_with_types = ( - await self.decode_transaction_with_types(data, address=address) + await self.decode_transaction_with_types( + data, address=address, chain_id=chain_id + ) ) decoded_transactions = { d["name"]: d["value"] for d in decoded_transactions_with_types diff --git a/app/tests/datasources/db/test_model.py b/app/tests/datasources/db/test_model.py index de66b16..07ce4ad 100644 --- a/app/tests/datasources/db/test_model.py +++ b/app/tests/datasources/db/test_model.py @@ -38,10 +38,27 @@ async def test_contract_get_abi_by_contract_address(self, session: AsyncSession) await abi.create(session) contract = Contract(address=b"a", name="A test contract", chain_id=1, abi=abi) await contract.create(session) - result = await contract.get_abi_by_contract_address(session, contract.address) + result = await contract.get_abi_by_contract_address( + session, contract.address, 1 + ) self.assertEqual(result, abi_json) - self.assertIsNone(await contract.get_abi_by_contract_address(session, b"b")) + # Check chain_id not matching + result = await contract.get_abi_by_contract_address( + session, contract.address, 2 + ) + self.assertIsNone(result) + + # Ignoring chain_id + result = await contract.get_abi_by_contract_address( + session, contract.address, None + ) + self.assertEqual(result, abi_json) + + # Check address not matching + self.assertIsNone( + await contract.get_abi_by_contract_address(session, b"b", None) + ) @database_session async def test_project(self, session: AsyncSession): diff --git a/app/tests/services/test_data_decoder.py b/app/tests/services/test_data_decoder.py index 47317f3..c5b2edf 100644 --- a/app/tests/services/test_data_decoder.py +++ b/app/tests/services/test_data_decoder.py @@ -518,6 +518,150 @@ async def test_db_tx_decoder(self, session: AsyncSession): # self.assertIn((contract.address,), decoder_service.cache_abis_by_address) # self.assertIn( (contract.address,), decoder_service.cache_contract_abi_selectors_with_functions_by_address, ) + @database_session + async def test_db_tx_decoder_multichain(self, session: AsyncSession): + # Both ABIs generate the same function selector, but with differently ordered parameter names, so + # decoding will be different + example_abi = cast( + ABI, + [ + { + "inputs": [ + { + "internalType": "uint256", + "name": "droidId", + "type": "uint256", + }, + { + "internalType": "uint256", + "name": "numberOfDroids", + "type": "uint256", + }, + ], + "name": "buyDroid", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + }, + ], + ) + + # Swap ABI parameters + example_abi_reversed = cast( + ABI, + [ + { + "inputs": [ + { + "internalType": "uint256", + "name": "numberOfDroids", + "type": "uint256", + }, + { + "internalType": "uint256", + "name": "droidId", + "type": "uint256", + }, + ], + "name": "buyDroid", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function", + }, + ], + ) + + example_data = ( + Web3() + .eth.contract(abi=example_abi) + .functions.buyDroid(4, 10) + .build_transaction( + get_empty_tx_params() | {"to": NULL_ADDRESS, "chainId": 1} + )["data"] + ) + + source = AbiSource(name="local", url="") + await source.create(session) + + abi = Abi( + abi_hash=b"ExampleABI", + abi_json=example_abi, + relevance=1, + source_id=source.id, + ) + await abi.create(session) + + abi_reversed = Abi( + abi_hash=b"ExampleABIReversed", + abi_json=example_abi_reversed, + relevance=100, + source_id=source.id, + ) + await abi_reversed.create(session) + + contract = Contract(address=b"a", abi=abi, name="ExampleContract", chain_id=1) + await contract.create(session) + contract_reversed = Contract( + address=b"a", abi=abi_reversed, name="ExampleContractReversed", chain_id=2 + ) + await contract_reversed.create(session) + + decoder_service = DataDecoderService() + await decoder_service.init(session) + + expected_arguments = {"droidId": "4", "numberOfDroids": "10"} + expected_arguments_reversed = {"numberOfDroids": "4", "droidId": "10"} + + contract_address = Address(b"a") + fn_name, arguments = await decoder_service.decode_transaction( + example_data, address=contract_address, chain_id=1 + ) + self.assertEqual(fn_name, "buyDroid") + self.assertEqual(arguments, expected_arguments) + + fn_name, arguments = await decoder_service.decode_transaction( + example_data, address=Address(contract.address), chain_id=2 + ) + self.assertEqual(fn_name, "buyDroid") + self.assertEqual(arguments, expected_arguments_reversed) + + # If chain_id is not matching, lower chain_id contract must be used + fn_name, arguments = await decoder_service.decode_transaction( + example_data, address=contract_address, chain_id=5 + ) + self.assertEqual(fn_name, "buyDroid") + self.assertEqual(arguments, expected_arguments) + + # If chain_id is not provided, lower chain_id contract must be used + fn_name, arguments = await decoder_service.decode_transaction( + example_data, address=contract_address, chain_id=None + ) + self.assertEqual(fn_name, "buyDroid") + self.assertEqual(arguments, expected_arguments) + + # If no contract is matching but abi is on the database, it should be decoded using the more relevant ABI + contract.address = b"b" + await contract.update(session) + contract_reversed.address = b"b" + await contract_reversed.update(session) + + # Check caches are working even if contract was updated on DB + fn_name, arguments = await decoder_service.decode_transaction( + example_data, address=contract_address, chain_id=1 + ) + self.assertEqual(fn_name, "buyDroid") + self.assertEqual(arguments, expected_arguments) + + # Init a new service to remove caches + decoder_service = DataDecoderService() + await decoder_service.init(session) + + fn_name, arguments = await decoder_service.decode_transaction( + example_data, address=contract_address, chain_id=1 + ) + self.assertEqual(fn_name, "buyDroid") + self.assertEqual(arguments, expected_arguments_reversed) + @database_session async def test_decode_fallback_calls_db_tx_decoder(self, session: AsyncSession): example_not_matched_abi = [