From 51614cb3719e652215a0d4089b42899ba630a0f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s?= <7888669+moisses89@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:22:29 +0100 Subject: [PATCH] Add dependencies --- app/routers/contracts.py | 9 +++++++-- app/services/contract.py | 5 ++--- requirements/prod.txt | 2 ++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/app/routers/contracts.py b/app/routers/contracts.py index 5b5fb66..44726a3 100644 --- a/app/routers/contracts.py +++ b/app/routers/contracts.py @@ -1,7 +1,9 @@ from typing import Annotated -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, HTTPException, Query +from hexbytes import HexBytes +from safe_eth.eth.utils import fast_is_checksum_address from sqlmodel.ext.asyncio.session import AsyncSession from ..datasources.db.database import get_database_session @@ -24,5 +26,8 @@ async def list_contracts( offset: int = Query(None), session: AsyncSession = Depends(get_database_session), ) -> PaginatedResponse[Contract]: + if not fast_is_checksum_address(address): + raise HTTPException(status_code=400, detail="Address is not checksumed") + contracts_service = ContractService(limit, offset) - return await contracts_service.get_contract(session, address, chain_ids) + return await contracts_service.get_contract(session, HexBytes(address), chain_ids) diff --git a/app/services/contract.py b/app/services/contract.py index 8191e9d..9f91712 100644 --- a/app/services/contract.py +++ b/app/services/contract.py @@ -1,6 +1,5 @@ from typing import Any, Sequence -from hexbytes import HexBytes from sqlmodel.ext.asyncio.session import AsyncSession from app.datasources.db.models import Contract @@ -25,7 +24,7 @@ async def get_all(session: AsyncSession) -> Sequence[Contract]: return await Contract.get_all(session) async def get_contract( - self, session: AsyncSession, address: str, chain_ids: list[int] | None + self, session: AsyncSession, address: bytes, chain_ids: list[int] | None ) -> PaginatedResponse[Any]: """ Get the contract by address and/or chain_ids @@ -37,5 +36,5 @@ async def get_contract( """ return await self.pagination.paginate( - session, Contract.get_contract(HexBytes(address), chain_ids) + session, Contract.get_contract(address, chain_ids) ) diff --git a/requirements/prod.txt b/requirements/prod.txt index d576647..ab63ec0 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -3,8 +3,10 @@ alembic==1.14.0 asyncpg==0.30.0 dramatiq[redis, watch]==1.17.1 fastapi[all]==0.115.6 +hexbytes==1.2.1 periodiq==0.13.0 pydantic-settings==2.7.0 redis[hiredis]==5.2.1 +safe-eth-py==6.1.0 sqladmin[full]==0.20.1 sqlmodel==0.0.22