Skip to content

Commit

Permalink
Implement multichain support for decoding (#70)
Browse files Browse the repository at this point in the history
- Closes #38



---------

Co-authored-by: Uxio Fuentefria <[email protected]>
Co-authored-by: Moisés <[email protected]>
  • Loading branch information
3 people authored Jan 20, 2025
1 parent 01e9832 commit aee01f7
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 34 deletions.
15 changes: 12 additions & 3 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,24 @@ async def get_or_create(

@classmethod
async def get_abi_by_contract_address(
cls, session: AsyncSession, address: bytes
cls, session: AsyncSession, address: bytes, chain_id: int | None
) -> ABI | None:
# TODO Add chain_id filter to support multichain
results = await session.exec(
"""
:return: Json ABI given the contract `address` and `chain_id`. If `chain_id` is not given,
sort the ABIs by `chain_id` and return the first one.
"""
query = (
select(Abi.abi_json)
.join(cls)
.where(cls.address == address)
.where(cls.abi_id == Abi.id)
)
if chain_id is not None:
query = query.where(cls.chain_id == chain_id)
else:
query = query.order_by(col(cls.chain_id))

results = await session.exec(query)
if result := results.first():
return cast(ABI, result)
return None
Expand Down
109 changes: 80 additions & 29 deletions app/services/data_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ class DataDecoderService:
dummy_w3 = Web3()
session: AsyncSession | None

fn_selectors_with_abis: dict[bytes, ABIFunction]
multisend_abis: list[ABI]
multisend_fn_selectors_with_abis: dict[bytes, ABIFunction]

async def init(self, session: AsyncSession) -> None:
"""
Initialize the data decoder service, loading the ABIs from the database and storing the 4byte selectors
Expand Down Expand Up @@ -129,27 +133,38 @@ async def get_multisend_abis(self) -> AsyncIterator[ABI]:
@alru_cache(maxsize=2048)
@database_session
async def get_contract_abi(
self, address: Address, session: AsyncSession | None = None
self,
address: Address,
chain_id: int | None,
session: AsyncSession | None = None,
) -> ABI | None:
"""
Retrieves the ABI for the contract at the given address.
:param address: Contract address
:param chain_id: Chain id for the contract
:param session: Database session, provided by the decorator
:return: List of ABI data if found, `None` otherwise.
"""
assert session is not None
return await Contract.get_abi_by_contract_address(session, HexBytes(address))
return await Contract.get_abi_by_contract_address(
session, HexBytes(address), chain_id
)

@alru_cache(maxsize=2048)
async def get_contract_fallback_function(
self, address: Address
self, address: Address, chain_id: int | None
) -> ABIFunction | None:
"""
:param address: Contract address
:return: Fallback ABIFunction if found, `None` otherwise.
"""
abi = await self.get_contract_abi(address)
:param chain_id: Chain for the contract
:return: Fallback `ABIFunction` if found, `None` otherwise.
If contract is not found for the chain, return the first one that matches in other chain.
"""
abi = await self.get_contract_abi(address, chain_id)
if not abi and chain_id is not None:
# Try to find an ABI in other network
abi = await self.get_contract_abi(address, None)
if abi:
return next(
(
Expand All @@ -163,25 +178,29 @@ async def get_contract_fallback_function(

@alru_cache(maxsize=2048)
async def get_contract_abi_selectors_with_functions(
self, address: Address
self, address: Address, chain_id: int | None
) -> dict[bytes, ABIFunction] | None:
"""
:param address: Contract address
:return: Dictionary of function selects with ABIFunction if found, `None` otherwise
"""
abi = await self.get_contract_abi(address)
:param chain_id: Chain for the contract
:return: Dictionary of function selects with `ABIFunction` if found, `None` otherwise
If contract is not found for the chain, return the first one that matches in other chain.
"""
abi = await self.get_contract_abi(address, chain_id)
if not abi and chain_id is not None:
# Try to find an ABI in other network
abi = await self.get_contract_abi(address, None)
if abi:
# TODO We should return that there's a fullMatch for this `data` and `address`, so we are sure
# we are decoding the `data` correctly
return self._generate_selectors_with_abis_from_abi(abi)
return None

async def get_abi_function(
self, data: bytes, address: Address | None = None
self, data: bytes, address: Address | None = None, chain_id: int | None = None
) -> ABIFunction | None:
"""
:param data: transaction data
:param address: contract address in case of ABI colliding
:param chain_id: Chain for the contract
:return: Abi function for data if it can be decoded, `None` if not found
"""
selector = data[:4]
Expand All @@ -190,7 +209,9 @@ async def get_abi_function(
# Try to use specific ABI if address provided
if address:
contract_selectors_with_abis = (
await self.get_contract_abi_selectors_with_functions(address)
await self.get_contract_abi_selectors_with_functions(
address, chain_id
)
)
if (
contract_selectors_with_abis
Expand All @@ -202,7 +223,7 @@ async def get_abi_function(
return self.fn_selectors_with_abis[selector]
# Check if the contract has a fallback call and return a minimal ABIFunction for fallback call
elif address:
return await self.get_contract_fallback_function(address)
return await self.get_contract_fallback_function(address, chain_id)
return None

def _parse_decoded_arguments(self, value_decoded: Any) -> Any:
Expand All @@ -220,13 +241,17 @@ def _parse_decoded_arguments(self, value_decoded: Any) -> Any:
return value_decoded

async def _decode_data(
self, data: bytes | str, address: Address | None = None
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> tuple[str, list[tuple[str, str, Any]]]:
"""
Decode tx data
:param data: Tx data as `hex string` or `bytes`
:param address: contract address in case of ABI colliding
:param chain_id: Chain for the contract
:return: Tuple with the `function name` and a List of sorted tuples with
the `name` of the argument, `type` and `value`
:raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function
Expand All @@ -238,7 +263,7 @@ async def _decode_data(

data = HexBytes(data)
params = data[4:]
fn_abi = await self.get_abi_function(data, address)
fn_abi = await self.get_abi_function(data, address, chain_id)
if not fn_abi:
raise CannotDecode(data.hex())
try:
Expand All @@ -254,12 +279,13 @@ async def _decode_data(
return fn_abi["name"], list(zip(names, types, values))

async def decode_multisend_data(
self, data: bytes | str
self, data: bytes | str, chain_id: int | None = None
) -> list[MultisendDecoded] | None:
"""
Decodes Multisend raw data to Multisend dictionary
:param data:
:param chain_id:
:return:
"""
try:
Expand All @@ -271,7 +297,9 @@ async def decode_multisend_data(
value=str(multisend_tx.value),
data=HexStr(multisend_tx.data.hex()) if multisend_tx.data else None,
data_decoded=await self.get_data_decoded(
multisend_tx.data, address=cast(Address, multisend_tx.to)
multisend_tx.data,
address=cast(Address, multisend_tx.to),
chain_id=chain_id,
),
)
for multisend_tx in multisend_txs
Expand All @@ -285,27 +313,34 @@ async def decode_multisend_data(
return None

async def get_data_decoded(
self, data: bytes | str, address: Address | None = None
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> DataDecoded | None:
"""
Return data prepared for serializing
:param data:
:param address: contract address in case of ABI colliding
:param chain_id: chain for contract
:return:
"""
if not data:
return None
try:
fn_name, parameters = await self.decode_transaction_with_types(
data, address=address
data, address=address, chain_id=chain_id
)
return {"method": fn_name, "parameters": parameters}
except DataDecoderException:
return None

async def decode_parameters_data(
self, data: bytes, parameters: list[ParameterDecoded]
self,
data: bytes,
parameters: list[ParameterDecoded],
chain_id: int | None = None,
) -> list[ParameterDecoded]:
"""
Decode inner data for function parameters for:
Expand All @@ -319,7 +354,9 @@ async def decode_parameters_data(
fn_selector = data[:4]
if fn_selector in self.multisend_fn_selectors_with_abis:
# If MultiSend, decode the transactions
parameters[0]["value_decoded"] = await self.decode_multisend_data(data)
parameters[0]["value_decoded"] = await self.decode_multisend_data(
data, chain_id=chain_id
)

elif (
fn_selector == self.EXEC_TRANSACTION_SELECTOR
Expand All @@ -331,49 +368,63 @@ async def decode_parameters_data(
# selector is `0x6a761202` and parameters[2] is data
try:
parameters[2]["value_decoded"] = await self.get_data_decoded(
data, address=parameters[0]["value"]
data, address=parameters[0]["value"], chain_id=chain_id
)
except DataDecoderException:
logger.warning("Cannot decode `execTransaction`", exc_info=True)
return parameters

async def decode_transaction_with_types(
self, data: bytes | str, address: Address | None = None
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> tuple[str, list[ParameterDecoded]]:
"""
Decode tx data and return a list of dictionaries
:param data: Tx data as `hex string` or `bytes`
:param address: contract address in case of ABI colliding
:param chain_id: chain for the contract
:return: Tuple with the `function name` and a list of dictionaries
[{'name': str, 'type': str, 'value': `depending on type`}...]
:raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function
:raises: UnexpectedProblemDecoding if there's an unexpected problem decoding (it shouldn't happen)
"""
data = HexBytes(data)
fn_name, raw_parameters = await self._decode_data(data, address=address)
fn_name, raw_parameters = await self._decode_data(
data, address=address, chain_id=chain_id
)
# Parameters are returned as tuple, convert it to a dictionary
parameters = [
ParameterDecoded(name=name, type=argument_type, value=value)
for name, argument_type, value in raw_parameters
]
nested_parameters = await self.decode_parameters_data(data, parameters)
nested_parameters = await self.decode_parameters_data(
data, parameters, chain_id=chain_id
)
return fn_name, nested_parameters

async def decode_transaction(
self, data: bytes | str, address: Address | None = None
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> tuple[str, dict[str, Any]]:
"""
Decode tx data and return all the parameters in the same dictionary
:param data: Tx data as `hex string` or `bytes`
:param address: contract address in case of ABI colliding
:param chain_id: chain for the contract
:return: Tuple with the `function name` and a dictionary with the arguments of the function
:raises: CannotDecode if data cannot be decoded. You should catch this exception when using this function
:raises: UnexpectedProblemDecoding if there's an unexpected problem decoding (it shouldn't happen)
"""
fn_name, decoded_transactions_with_types = (
await self.decode_transaction_with_types(data, address=address)
await self.decode_transaction_with_types(
data, address=address, chain_id=chain_id
)
)
decoded_transactions = {
d["name"]: d["value"] for d in decoded_transactions_with_types
Expand Down
21 changes: 19 additions & 2 deletions app/tests/datasources/db/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,27 @@ async def test_contract_get_abi_by_contract_address(self, session: AsyncSession)
await abi.create(session)
contract = Contract(address=b"a", name="A test contract", chain_id=1, abi=abi)
await contract.create(session)
result = await contract.get_abi_by_contract_address(session, contract.address)
result = await contract.get_abi_by_contract_address(
session, contract.address, 1
)
self.assertEqual(result, abi_json)

self.assertIsNone(await contract.get_abi_by_contract_address(session, b"b"))
# Check chain_id not matching
result = await contract.get_abi_by_contract_address(
session, contract.address, 2
)
self.assertIsNone(result)

# Ignoring chain_id
result = await contract.get_abi_by_contract_address(
session, contract.address, None
)
self.assertEqual(result, abi_json)

# Check address not matching
self.assertIsNone(
await contract.get_abi_by_contract_address(session, b"b", None)
)

@database_session
async def test_project(self, session: AsyncSession):
Expand Down
Loading

0 comments on commit aee01f7

Please sign in to comment.