Skip to content
This repository has been archived by the owner on Jan 8, 2025. It is now read-only.

Commit

Permalink
provider: embed StarknetProvider in EthDataProvider (#1369)
Browse files Browse the repository at this point in the history
* provider: embed StarknetProvider in EthDataProvider

* fix comments

* clean up

* fix test

* clean up make file
  • Loading branch information
tcoratger authored Sep 10, 2024
1 parent 281eba0 commit d64d633
Show file tree
Hide file tree
Showing 16 changed files with 91 additions and 61 deletions.
3 changes: 0 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ test: katana-genesis load-env
test-target: load-env
cargo test --tests --all-features $(TARGET) -- --nocapture

test-target1: load-env
cargo test --package kakarot-rpc --test entry --all-features -- tests::eth_provider::test_send_raw_transaction --exact --show-output

benchmark:
cd benchmarks && bun i && bun run benchmark

Expand Down
30 changes: 19 additions & 11 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@ use crate::{
mempool::{KakarotPool, TransactionOrdering},
validate::KakarotTransactionValidatorBuilder,
},
providers::eth_provider::{
chain::ChainProvider,
database::{
ethereum::{EthereumBlockStore, EthereumTransactionStore},
state::EthDatabase,
Database,
providers::{
eth_provider::{
chain::ChainProvider,
database::{
ethereum::{EthereumBlockStore, EthereumTransactionStore},
state::EthDatabase,
Database,
},
error::{EthApiError, EthereumDataFormatError, KakarotError, SignatureError, TransactionError},
provider::{EthApiResult, EthDataProvider},
starknet::kakarot_core::to_starknet_transaction,
},
error::{EthApiError, EthereumDataFormatError, KakarotError, SignatureError, TransactionError},
provider::{EthApiResult, EthDataProvider},
starknet::kakarot_core::to_starknet_transaction,
sn_provider::StarknetProvider,
},
};
use alloy_rlp::Decodable;
Expand Down Expand Up @@ -47,6 +50,11 @@ impl<SP> EthClient<SP>
where
SP: Provider + Clone + Sync + Send,
{
/// Get the Starknet provider from the Ethereum provider.
pub const fn starknet_provider(&self) -> &StarknetProvider<SP> {
self.eth_provider.starknet_provider()
}

/// Tries to start a [`EthClient`] by fetching the current chain id, initializing a [`EthDataProvider`] and a [`Pool`].
pub async fn try_new(starknet_provider: SP, database: Database) -> eyre::Result<Self> {
let chain = (starknet_provider.chain_id().await.map_err(KakarotError::from)?.to_bigint()
Expand All @@ -55,7 +63,8 @@ where
.unwrap();

// Create a new EthDataProvider instance with the initialized database and Starknet provider.
let eth_provider = EthDataProvider::try_new(database, starknet_provider).await?;
let eth_provider =
EthDataProvider::try_new(database, StarknetProvider::new(Arc::new(starknet_provider))).await?;

let validator =
KakarotTransactionValidatorBuilder::new(Arc::new(ChainSpec { chain: chain.into(), ..Default::default() }))
Expand Down Expand Up @@ -130,7 +139,6 @@ where
// Add the transaction to the Starknet provider
let span = tracing::span!(tracing::Level::INFO, "sn::add_invoke_transaction");
let res = self
.eth_provider
.starknet_provider()
.add_invoke_transaction(starknet_transaction)
.instrument(span)
Expand Down
2 changes: 1 addition & 1 deletion src/eth_rpc/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ where
{
pub fn new(eth_client: EthClient<SP>) -> Self {
let eth_provider = eth_client.eth_provider().clone();
let starknet_provider = eth_provider.starknet_provider().clone();
let starknet_provider = eth_provider.starknet_provider_inner().clone();

let alchemy_provider = Arc::new(AlchemyDataProvider::new(eth_provider.clone()));
let pool_provider = Arc::new(PoolDataProvider::new(eth_provider.clone()));
Expand Down
7 changes: 2 additions & 5 deletions src/pool/mempool.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::validate::KakarotTransactionValidator;
use crate::{client::EthClient, providers::sn_provider::StarknetProvider};
use crate::client::EthClient;
use reth_primitives::{BlockId, U256};
use reth_transaction_pool::{
blobstore::NoopBlobStore, CoinbaseTipOrdering, EthPooledTransaction, Pool, TransactionPool,
Expand Down Expand Up @@ -62,7 +62,6 @@ impl<SP: starknet::providers::Provider + Send + Sync + Clone + 'static> AccountM
accounts.insert(
felt_address,
eth_client
.eth_provider()
.starknet_provider()
.get_nonce(starknet_block_id, felt_address)
.await
Expand Down Expand Up @@ -125,10 +124,8 @@ impl<SP: starknet::providers::Provider + Send + Sync + Clone + 'static> AccountM
async fn get_balance(&self, account_address: Felt) -> eyre::Result<U256> {
// Convert the optional Ethereum block ID to a Starknet block ID.
let starknet_block_id = self.eth_client.eth_provider().to_starknet_block_id(Some(BlockId::default())).await?;
// Create a new Starknet provider wrapper.
let starknet_provider = StarknetProvider::new(Arc::new(self.eth_client.eth_provider().starknet_provider()));
// Get the balance of the address at the given block ID.
starknet_provider.balance_at(account_address, starknet_block_id).await.map_err(Into::into)
self.eth_client.starknet_provider().balance_at(account_address, starknet_block_id).await.map_err(Into::into)
}

/// Processes a transaction for the given account if the balance is sufficient.
Expand Down
4 changes: 3 additions & 1 deletion src/providers/eth_provider/blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ where
// In case the database is empty, use the starknet provider
None => {
let span = tracing::span!(tracing::Level::INFO, "sn::block_number");
U64::from(self.starknet_provider().block_number().instrument(span).await.map_err(KakarotError::from)?)
U64::from(
self.starknet_provider_inner().block_number().instrument(span).await.map_err(KakarotError::from)?,
)
}
Some(header) => {
let number = header.number.ok_or(EthApiError::UnknownBlockNumber(None))?;
Expand Down
2 changes: 1 addition & 1 deletion src/providers/eth_provider/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ where
{
async fn syncing(&self) -> EthApiResult<SyncStatus> {
let span = tracing::span!(tracing::Level::INFO, "sn::syncing");
Ok(match self.starknet_provider().syncing().instrument(span).await.map_err(KakarotError::from)? {
Ok(match self.starknet_provider_inner().syncing().instrument(span).await.map_err(KakarotError::from)? {
SyncStatusType::NotSyncing => SyncStatus::None,
SyncStatusType::Syncing(data) => SyncStatus::Info(SyncInfo {
starting_block: U256::from(data.starting_block_num),
Expand Down
2 changes: 1 addition & 1 deletion src/providers/eth_provider/gas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ where
}

async fn gas_price(&self) -> EthApiResult<U256> {
let kakarot_contract = KakarotCoreReader::new(*KAKAROT_ADDRESS, self.starknet_provider());
let kakarot_contract = KakarotCoreReader::new(*KAKAROT_ADDRESS, self.starknet_provider_inner());
let span = tracing::span!(tracing::Level::INFO, "sn::base_fee");
let gas_price =
kakarot_contract.get_base_fee().call().instrument(span).await.map_err(ExecutionError::from)?.base_fee;
Expand Down
25 changes: 17 additions & 8 deletions src/providers/eth_provider/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ use crate::{
block::{EthBlockId, EthBlockNumberOrTag},
felt::Felt252Wrapper,
},
providers::eth_provider::{
BlockProvider, GasProvider, LogProvider, ReceiptProvider, StateProvider, TransactionProvider, TxPoolProvider,
providers::{
eth_provider::{
BlockProvider, GasProvider, LogProvider, ReceiptProvider, StateProvider, TransactionProvider,
TxPoolProvider,
},
sn_provider::StarknetProvider,
},
};
use cainome::cairo_serde::CairoArrayLegacy;
Expand Down Expand Up @@ -61,7 +65,7 @@ impl<T> EthereumProvider for T where
#[derive(Debug, Clone)]
pub struct EthDataProvider<SP: starknet::providers::Provider + Send + Sync> {
database: Database,
starknet_provider: SP,
starknet_provider: StarknetProvider<SP>,
pub(crate) chain_id: u64,
}

Expand All @@ -75,7 +79,12 @@ where
}

/// Returns a reference to the Starknet provider.
pub const fn starknet_provider(&self) -> &SP {
pub const fn starknet_provider(&self) -> &StarknetProvider<SP> {
&self.starknet_provider
}

/// Returns a reference to the underlying SP provider.
pub fn starknet_provider_inner(&self) -> &SP {
&self.starknet_provider
}
}
Expand All @@ -84,7 +93,7 @@ impl<SP> EthDataProvider<SP>
where
SP: starknet::providers::Provider + Send + Sync,
{
pub async fn try_new(database: Database, starknet_provider: SP) -> Result<Self> {
pub async fn try_new(database: Database, starknet_provider: StarknetProvider<SP>) -> Result<Self> {
// We take the chain_id modulo u32::MAX to ensure compatibility with tooling
// see: https://github.com/ethereum/EIPs/issues/2294
// Note: Metamask is breaking for a chain_id = u64::MAX - 1
Expand Down Expand Up @@ -157,7 +166,7 @@ where
let starknet_block_id = self.to_starknet_block_id(block_id).await?;
let call_input = self.prepare_call_input(request, block_id).await?;

let kakarot_contract = KakarotCoreReader::new(*KAKAROT_ADDRESS, &self.starknet_provider);
let kakarot_contract = KakarotCoreReader::new(*KAKAROT_ADDRESS, self.starknet_provider_inner());
let span = tracing::span!(tracing::Level::INFO, "sn::eth_call");
let call_output = kakarot_contract
.eth_call(
Expand Down Expand Up @@ -194,7 +203,7 @@ where
let starknet_block_id = self.to_starknet_block_id(block_id).await?;
let call_input = self.prepare_call_input(request, block_id).await?;

let kakarot_contract = KakarotCoreReader::new(*KAKAROT_ADDRESS, &self.starknet_provider);
let kakarot_contract = KakarotCoreReader::new(*KAKAROT_ADDRESS, self.starknet_provider_inner());
let span = tracing::span!(tracing::Level::INFO, "sn::eth_estimate_gas");
let estimate_gas_output = kakarot_contract
.eth_estimate_gas(
Expand Down Expand Up @@ -305,7 +314,7 @@ where
};

let signer_starknet_address = starknet_address(signer);
let account_contract = AccountContractReader::new(signer_starknet_address, &self.starknet_provider);
let account_contract = AccountContractReader::new(signer_starknet_address, self.starknet_provider_inner());
let maybe_is_initialized = account_contract
.is_initialized()
.block_id(starknet::core::types::BlockId::Tag(BlockTag::Latest))
Expand Down
19 changes: 6 additions & 13 deletions src/providers/eth_provider/state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::sync::Arc;

use super::{
database::state::{EthCacheDatabase, EthDatabase},
error::{EthApiError, ExecutionError, TransactionError},
Expand All @@ -9,12 +7,9 @@ use super::{
use crate::{
into_via_wrapper,
models::felt::Felt252Wrapper,
providers::{
eth_provider::{
provider::{EthApiResult, EthDataProvider},
BlockProvider, ChainProvider,
},
sn_provider::StarknetProvider,
providers::eth_provider::{
provider::{EthApiResult, EthDataProvider},
BlockProvider, ChainProvider,
},
};
use async_trait::async_trait;
Expand Down Expand Up @@ -72,10 +67,8 @@ where
async fn balance(&self, address: Address, block_id: Option<BlockId>) -> EthApiResult<U256> {
// Convert the optional Ethereum block ID to a Starknet block ID.
let starknet_block_id = self.to_starknet_block_id(block_id).await?;
// Create a new Starknet provider wrapper.
let starknet_provider = StarknetProvider::new(Arc::new(self.starknet_provider()));
// Get the balance of the address at the given block ID.
starknet_provider.balance_at(starknet_address(address), starknet_block_id).await.map_err(Into::into)
self.starknet_provider().balance_at(starknet_address(address), starknet_block_id).await.map_err(Into::into)
}

async fn storage_at(
Expand All @@ -87,7 +80,7 @@ where
let starknet_block_id = self.to_starknet_block_id(block_id).await?;

let address = starknet_address(address);
let contract = AccountContractReader::new(address, self.starknet_provider());
let contract = AccountContractReader::new(address, self.starknet_provider_inner());

let keys = split_u256(index.0);
let storage_address = get_storage_var_address("Account_storage", &keys).expect("Storage var name is not ASCII");
Expand All @@ -112,7 +105,7 @@ where
let starknet_block_id = self.to_starknet_block_id(block_id).await?;

let address = starknet_address(address);
let account_contract = AccountContractReader::new(address, self.starknet_provider());
let account_contract = AccountContractReader::new(address, self.starknet_provider_inner());
let span = tracing::span!(tracing::Level::INFO, "sn::code");
let bytecode = account_contract.bytecode().block_id(starknet_block_id).call().instrument(span).await;

Expand Down
10 changes: 7 additions & 3 deletions src/providers/eth_provider/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ where
let starknet_block_id = self.to_starknet_block_id(block_id).await?;

let address = starknet_address(address);
let account_contract = AccountContractReader::new(address, self.starknet_provider());
let account_contract = AccountContractReader::new(address, self.starknet_provider_inner());
let span = tracing::span!(tracing::Level::INFO, "sn::kkrt_nonce");
let maybe_nonce = account_contract.get_nonce().block_id(starknet_block_id).call().instrument(span).await;

Expand All @@ -130,8 +130,12 @@ where
// This can happen when an underlying Starknet transaction reverts => Account storage changes are reverted,
// but the protocol nonce is still incremented.
let span = tracing::span!(tracing::Level::INFO, "sn::protocol_nonce");
let protocol_nonce =
self.starknet_provider().get_nonce(starknet_block_id, address).instrument(span).await.unwrap_or_default();
let protocol_nonce = self
.starknet_provider_inner()
.get_nonce(starknet_block_id, address)
.instrument(span)
.await
.unwrap_or_default();
let nonce = nonce.max(protocol_nonce);

Ok(into_via_wrapper!(nonce))
Expand Down
19 changes: 15 additions & 4 deletions src/providers/sn_provider/starknet_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,31 @@ use crate::{
},
};
use reth_primitives::U256;
use starknet::core::types::{BlockId, Felt};
use std::sync::Arc;
use starknet::{
core::types::{BlockId, Felt},
providers::Provider,
};
use std::{ops::Deref, sync::Arc};
use tracing::Instrument;

/// A provider wrapper around the Starknet provider to expose utility methods.
#[derive(Debug, Clone)]
pub struct StarknetProvider<SP: starknet::providers::Provider + Send + Sync> {
pub struct StarknetProvider<SP: Provider + Send + Sync> {
/// The underlying Starknet provider wrapped in an [`Arc`] for shared ownership across threads.
provider: Arc<SP>,
}

impl<SP: Provider + Send + Sync> Deref for StarknetProvider<SP> {
type Target = SP;

fn deref(&self) -> &Self::Target {
&self.provider
}
}

impl<SP> StarknetProvider<SP>
where
SP: starknet::providers::Provider + Send + Sync,
SP: Provider + Send + Sync,
{
/// Creates a new [`StarknetProvider`] instance from an [`Arc`]-wrapped Starknet provider.
pub const fn new(provider: Arc<SP>) -> Self {
Expand Down
6 changes: 3 additions & 3 deletions src/test_utils/eoa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl<P: Provider + Send + Sync + Clone> Eoa<P> for KakarotEOA<P> {

impl<P: Provider + Send + Sync + Clone> KakarotEOA<P> {
fn starknet_provider(&self) -> &P {
self.eth_client.eth_provider().starknet_provider()
self.eth_client.starknet_provider()
}

/// Deploys an EVM contract given a contract name and constructor arguments
Expand Down Expand Up @@ -140,7 +140,7 @@ impl<P: Provider + Send + Sync + Clone> KakarotEOA<P> {
let tx_hash: Felt252Wrapper = tx_hash.into();

watch_tx(
self.eth_client.eth_provider().starknet_provider(),
self.eth_client.eth_provider().starknet_provider_inner(),
tx_hash.clone().into(),
std::time::Duration::from_millis(300),
60,
Expand Down Expand Up @@ -197,7 +197,7 @@ impl<P: Provider + Send + Sync + Clone> KakarotEOA<P> {
let starknet_tx_hash = Felt::from_bytes_be(&bytes);

watch_tx(
self.eth_client.eth_provider().starknet_provider(),
self.eth_client.eth_provider().starknet_provider_inner(),
starknet_tx_hash,
std::time::Duration::from_millis(300),
60,
Expand Down
2 changes: 1 addition & 1 deletion src/test_utils/katana/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ impl<'a> Katana {
}

pub fn starknet_provider(&self) -> Arc<JsonRpcClient<HttpTransport>> {
self.eoa.eth_client.eth_provider().starknet_provider().clone()
self.eoa.eth_client.eth_provider().starknet_provider_inner().clone()
}

pub fn eoa(&self) -> KakarotEOA<Arc<JsonRpcClient<HttpTransport>>> {
Expand Down
8 changes: 6 additions & 2 deletions src/tracing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,10 @@ fn env_with_tx(env: &EnvWithHandlerCfg, tx: reth_rpc_types::Transaction) -> Trac
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::eth_provider::{database::Database, provider::EthDataProvider};
use crate::providers::{
eth_provider::{database::Database, provider::EthDataProvider},
sn_provider::StarknetProvider,
};
use builder::TracerBuilder;
use mongodb::options::{DatabaseOptions, ReadConcern, WriteConcern};
use starknet::providers::{jsonrpc::HttpTransport, JsonRpcClient};
Expand Down Expand Up @@ -422,7 +425,8 @@ mod tests {
),
);

let eth_provider = Arc::new(EthDataProvider::try_new(db, starknet_provider).await.unwrap());
let eth_provider =
Arc::new(EthDataProvider::try_new(db, StarknetProvider::new(Arc::new(starknet_provider))).await.unwrap());
let tracer = TracerBuilder::new(eth_provider)
.await
.unwrap()
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/eth_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ async fn test_send_raw_transaction_pre_eip_155(#[future] katana: Katana, _setup:
assert_eq!(mempool_size_after_send.pending, 1);
assert_eq!(mempool_size_after_send.total, 1);

watch_tx(eth_provider.starknet_provider(), starknet_tx_hash, std::time::Duration::from_millis(300), 60)
watch_tx(eth_provider.starknet_provider_inner(), starknet_tx_hash, std::time::Duration::from_millis(300), 60)
.await
.expect("Tx polling failed");

Expand Down
Loading

0 comments on commit d64d633

Please sign in to comment.