Skip to content

Commit

Permalink
Add accuracy for decoding endpoint
Browse files Browse the repository at this point in the history
- Closes #72
  • Loading branch information
Uxio0 committed Jan 22, 2025
1 parent c4861a6 commit 93f6985
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 5 deletions.
16 changes: 13 additions & 3 deletions app/routers/data_decoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException

from app.routers.models import DataDecodedPublic, DataDecoderInput
from app.services.data_decoder import DataDecoded, get_data_decoder_service
from app.services.data_decoder import get_data_decoder_service

router = APIRouter(
prefix="/data-decoder",
Expand All @@ -12,7 +12,7 @@
@router.post("", response_model=DataDecodedPublic)
async def data_decoder(
input_data: DataDecoderInput,
) -> DataDecoded:
) -> DataDecodedPublic:
data_decoder_service = await get_data_decoder_service()
data_decoded = await data_decoder_service.get_data_decoded(
input_data.data,
Expand All @@ -25,4 +25,14 @@ async def data_decoder(
status_code=404, detail="Cannot find function selector to decode data"
)

return data_decoded
decoding_accuracy = await data_decoder_service.get_decoding_accuracy(
input_data.data,
address=input_data.to,
chain_id=input_data.chain_id,
)

return DataDecodedPublic(
method=data_decoded["method"],
parameters=data_decoded["parameters"],
accuracy=decoding_accuracy,
)
3 changes: 3 additions & 0 deletions app/routers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
fast_to_checksum_address,
)

from ..services.data_decoder import DecodingAccuracyEnum


class About(BaseModel):
version: str
Expand Down Expand Up @@ -104,6 +106,7 @@ class ParameterDecodedPublic(BaseModel):
class DataDecodedPublic(BaseModel):
method: str
parameters: list[ParameterDecodedPublic]
accuracy: DecodingAccuracyEnum


class MultisendDecodedPublic(BaseModel):
Expand Down
40 changes: 40 additions & 0 deletions app/services/data_decoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from enum import Enum
from typing import Any, AsyncIterator, NotRequired, TypedDict, Union, cast

from async_lru import alru_cache
Expand Down Expand Up @@ -45,6 +46,15 @@ class DataDecoded(TypedDict):
parameters: list[ParameterDecoded]


class DecodingAccuracyEnum(Enum):
FULL_MATCH = "FULL_MATCH" # Matched contract and chain id
PARTIAL_MATCH = "PARTIAL_MATCH" # Matched contract
ONLY_FUNCTION_MATCH = (
"ONLY_FUNCTION_MATCH" # Matched function from another contract
)
NO_MATCH = "NO_MATCH" # Selector cannot be decoded


class MultisendDecoded(TypedDict):
operation: int
to: ChecksumAddress
Expand Down Expand Up @@ -431,6 +441,36 @@ async def decode_transaction(
}
return fn_name, decoded_transactions

async def get_decoding_accuracy(
self,
data: bytes | str,
address: Address | None = None,
chain_id: int | None = None,
) -> DecodingAccuracyEnum:
"""
Get decoding accuracy:
- FULL_MATCH: Contract `address` and `chain_id` matching
- PARTIAL_MATCH: Only contract `address` matching
- ONLY_FUNCTION_MATCH: Match with a function of another contract
- NO_MATCH: Cannot decode `data`
:param data:
:param address:
:param chain_id:
:return: DecodingAccuracyEnum
"""
selector = HexBytes(data)[:4]
if selector not in self.fn_selectors_with_abis:
return DecodingAccuracyEnum.NO_MATCH
if address is not None:
if chain_id is not None and await self.get_contract_abi(
address, chain_id=chain_id
):
return DecodingAccuracyEnum.FULL_MATCH
if self.get_contract_abi(address):
return DecodingAccuracyEnum.PARTIAL_MATCH
return DecodingAccuracyEnum.ONLY_FUNCTION_MATCH

def add_abi(self, abi: ABI) -> bool:
"""
Add a new abi without rebuilding the entire decoder
Expand Down
36 changes: 35 additions & 1 deletion app/tests/routers/test_data_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ...datasources.db.models import Abi, AbiSource, Contract
from ...main import app
from ...services.abis import AbiService
from ...services.data_decoder import get_data_decoder_service
from ...services.data_decoder import DecodingAccuracyEnum, get_data_decoder_service
from ..datasources.db.db_async_conn import DbAsyncConn


Expand Down Expand Up @@ -54,6 +54,7 @@ async def test_view_data_decoder(self, session: AsyncSession):
self.assertEqual(
response.json(),
{
"accuracy": DecodingAccuracyEnum.ONLY_FUNCTION_MATCH.name,
"method": "addOwnerWithThreshold",
"parameters": [
{
Expand Down Expand Up @@ -185,6 +186,7 @@ async def test_view_data_decoder_with_chain_id(self, session: AsyncSession):
self.assertEqual(
response.json(),
{
"accuracy": DecodingAccuracyEnum.ONLY_FUNCTION_MATCH.name,
"method": "buyDroid",
"parameters": [
{
Expand Down Expand Up @@ -218,6 +220,7 @@ async def test_view_data_decoder_with_chain_id(self, session: AsyncSession):
self.assertEqual(
response.json(),
{
"accuracy": DecodingAccuracyEnum.FULL_MATCH.name,
"method": "buyDroid",
"parameters": [
{
Expand All @@ -235,3 +238,34 @@ async def test_view_data_decoder_with_chain_id(self, session: AsyncSession):
],
},
)

response = self.client.post(
"/api/v1/data-decoder/",
json={
"data": example_data,
"to": contract_address,
"chainId": 3,
},
)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(),
{
"accuracy": DecodingAccuracyEnum.PARTIAL_MATCH.name,
"method": "buyDroid",
"parameters": [
{
"name": "droidId",
"type": "uint256",
"value": "4",
"value_decoded": None,
},
{
"name": "numberOfDroids",
"type": "uint256",
"value": "10",
"value_decoded": None,
},
],
},
)
30 changes: 29 additions & 1 deletion app/tests/services/test_data_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ...services.data_decoder import (
CannotDecode,
DataDecoderService,
DecodingAccuracyEnum,
UnexpectedProblemDecoding,
get_data_decoder_service,
)
Expand Down Expand Up @@ -615,28 +616,44 @@ async def test_db_tx_decoder_multichain(self, session: AsyncSession):
fn_name, arguments = await decoder_service.decode_transaction(
example_data, address=contract_address, chain_id=1
)
accuracy = await decoder_service.get_decoding_accuracy(
example_data, address=contract_address, chain_id=1
)
self.assertEqual(fn_name, "buyDroid")
self.assertEqual(arguments, expected_arguments)
self.assertEqual(accuracy, DecodingAccuracyEnum.FULL_MATCH)

fn_name, arguments = await decoder_service.decode_transaction(
example_data, address=Address(contract.address), chain_id=2
)
accuracy = await decoder_service.get_decoding_accuracy(
example_data, address=contract_address, chain_id=2
)
self.assertEqual(fn_name, "buyDroid")
self.assertEqual(arguments, expected_arguments_reversed)
self.assertEqual(accuracy, DecodingAccuracyEnum.FULL_MATCH)

# If chain_id is not matching, lower chain_id contract must be used
fn_name, arguments = await decoder_service.decode_transaction(
example_data, address=contract_address, chain_id=5
)
accuracy = await decoder_service.get_decoding_accuracy(
example_data, address=contract_address, chain_id=5
)
self.assertEqual(fn_name, "buyDroid")
self.assertEqual(arguments, expected_arguments)
self.assertEqual(accuracy, DecodingAccuracyEnum.PARTIAL_MATCH)

# If chain_id is not provided, lower chain_id contract must be used
fn_name, arguments = await decoder_service.decode_transaction(
example_data, address=contract_address, chain_id=None
)
accuracy = await decoder_service.get_decoding_accuracy(
example_data, address=contract_address, chain_id=None
)
self.assertEqual(fn_name, "buyDroid")
self.assertEqual(arguments, expected_arguments)
self.assertEqual(accuracy, DecodingAccuracyEnum.PARTIAL_MATCH)

# If no contract is matching but abi is on the database, it should be decoded using the more relevant ABI
contract.address = b"b"
Expand All @@ -648,8 +665,12 @@ async def test_db_tx_decoder_multichain(self, session: AsyncSession):
fn_name, arguments = await decoder_service.decode_transaction(
example_data, address=contract_address, chain_id=1
)
accuracy = await decoder_service.get_decoding_accuracy(
example_data, address=contract_address, chain_id=1
)
self.assertEqual(fn_name, "buyDroid")
self.assertEqual(arguments, expected_arguments)
self.assertEqual(accuracy, DecodingAccuracyEnum.FULL_MATCH)

# Init a new service to remove caches
decoder_service = DataDecoderService()
Expand All @@ -658,8 +679,12 @@ async def test_db_tx_decoder_multichain(self, session: AsyncSession):
fn_name, arguments = await decoder_service.decode_transaction(
example_data, address=contract_address, chain_id=1
)
accuracy = await decoder_service.get_decoding_accuracy(
example_data, address=contract_address, chain_id=1
)
self.assertEqual(fn_name, "buyDroid")
self.assertEqual(arguments, expected_arguments_reversed)
self.assertEqual(accuracy, DecodingAccuracyEnum.PARTIAL_MATCH)

@database_session
async def test_decode_fallback_calls_db_tx_decoder(self, session: AsyncSession):
Expand Down Expand Up @@ -710,4 +735,7 @@ async def test_decode_fallback_calls_db_tx_decoder(self, session: AsyncSession):
)
self.assertEqual(fn_name, "fallback")
self.assertEqual(arguments, {})
# self.assertIn((contract_fallback.address,), decoder_service.cache_abis_by_address)
accuracy = await decoder_service.get_decoding_accuracy(
example_not_matched_data, address=Address(contract_fallback.address)
)
self.assertEqual(accuracy, DecodingAccuracyEnum.NO_MATCH)

0 comments on commit 93f6985

Please sign in to comment.