
[Feature] Batch proposal spend limits #3471

Draft
wants to merge 9 commits into base: staging
25 changes: 25 additions & 0 deletions node/bft/ledger-service/src/ledger.rs
@@ -404,4 +404,29 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
tracing::info!("\n\nAdvanced to block {} at round {} - {}\n", block.height(), block.round(), block.hash());
Ok(())
}

fn compute_cost(&self, _transaction_id: N::TransactionID, transaction: Data<Transaction<N>>) -> Result<u64> {
// TODO: move to VM or ledger?
let process = self.ledger.vm().process();

// Deserialize the transaction. If the transaction exceeds the maximum size, then return an error.
let transaction = match transaction {
Collaborator:

Deserialization is extremely expensive. I would consider moving this calculation into ledger.rs:check_transaction_basic, where we already deserialize. Perhaps that function can return the compute cost? Or some other design?
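
For illustration, a minimal sketch of that idea (hypothetical, not part of this diff): since check_transaction_basic already deserializes the transaction, it could return the execution cost it derives along the way, and the proposer would reuse that value instead of deserializing again inside compute_cost.

// Hypothetical sketch: return the execution cost from the existing check,
// so the transaction is deserialized only once on this path.
async fn check_transaction_basic(
    &self,
    transaction_id: N::TransactionID,
    transaction: Data<Transaction<N>>,
) -> Result<u64> {
    // Deserialize the transaction once, as the current check already does.
    let transaction = match transaction {
        Data::Object(transaction) => transaction,
        Data::Buffer(bytes) => Transaction::<N>::read_le(&mut bytes.take(N::MAX_TRANSACTION_SIZE as u64))?,
    };

    // ... run the existing validity checks against `transaction` ...

    // Derive the optional `Stack` exactly as `compute_cost` does in this diff,
    // then return the cost alongside a successful check.
    let stack = if let Transaction::Execute(_, ref execution, _) = transaction {
        let root_transition = execution.peek()?;
        Some(self.ledger.vm().process().read().get_stack(root_transition.program_id())?.clone())
    } else {
        None
    };
    use snarkvm::prelude::compute_cost;
    compute_cost(&transaction, stack)
}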

Collaborator (Author):

I did consider it and should revisit the idea, but intuitively we might want to avoid coupling the cost calculation with check_transaction_basic, as it's only called in propose_batch and not in process_batch_propose_from_peer, at least not directly.

Collaborator:

Ah, good point... That is unfortunate, because in the case of process_batch_propose_from_peer, we might be retrieving the transmission from disk, in which case we'll most certainly have to incur the deserialization cost.

Maybe if we let fn compute_cost take in a Transaction instead of Data<Transaction>, it can at least be made explicit. For our own proposal we call it from within check_transaction_basic; for incoming proposals we'll need to deserialize before calling it.

And if it turns out to be a bottleneck, we can always refactor the locations where we deserialize more comprehensively, and potentially add a cache for the compute cost if needed.
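
As a rough sketch of that direction (hypothetical, not part of this diff), the trait method could accept an already-deserialized transaction so the deserialization stays explicit at each call site; transaction_data below is a placeholder name for the transmission payload retrieved from the worker or from disk.

// Hypothetical sketch: accept an already-deserialized transaction, making the
// (expensive) deserialization explicit at each call site.
fn compute_cost(&self, transaction_id: N::TransactionID, transaction: &Transaction<N>) -> Result<u64>;

// At the incoming-proposal call site, the retrieved transmission would be
// deserialized up front, mirroring what `compute_cost` currently does internally:
let transaction = match transaction_data {
    Data::Object(transaction) => transaction,
    Data::Buffer(bytes) => Transaction::<N>::read_le(&mut bytes.take(N::MAX_TRANSACTION_SIZE as u64))?,
};
proposal_cost += self.ledger.compute_cost(transaction_id, &transaction)?;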

Data::Object(transaction) => transaction,
Data::Buffer(bytes) => Transaction::<N>::read_le(&mut bytes.take(N::MAX_TRANSACTION_SIZE as u64))?,
};

// Collect the optional Stack corresponding to the transaction if it's an Execution.
let stack = if let Transaction::Execute(_, ref execution, _) = transaction {
// Get the root transition from the execution.
let root_transition = execution.peek()?;
// Get the stack from the process.
Some(process.read().get_stack(root_transition.program_id())?.clone())
} else {
None
};

use snarkvm::prelude::compute_cost;

compute_cost(&transaction, stack)
}
}
6 changes: 6 additions & 0 deletions node/bft/ledger-service/src/mock.rs
@@ -235,4 +235,10 @@ impl<N: Network> LedgerService<N> for MockLedgerService<N> {
self.height_to_round_and_hash.lock().insert(block.height(), (block.round(), block.hash()));
Ok(())
}

/// Computes the execution cost in microcredits for a transaction.
fn compute_cost(&self, _transaction_id: N::TransactionID, _transaction: Data<Transaction<N>>) -> Result<u64> {
// Return 10 credits (10,000,000 microcredits) so this function can be used to test spend limits.
Ok(10_000_000)
}
}
5 changes: 5 additions & 0 deletions node/bft/ledger-service/src/prover.rs
@@ -186,4 +186,9 @@ impl<N: Network> LedgerService<N> for ProverLedgerService<N> {
fn advance_to_next_block(&self, block: &Block<N>) -> Result<()> {
bail!("Cannot advance to next block in prover - {block}")
}

/// Computes the execution cost in microcredits for a transaction.
fn compute_cost(&self, transaction_id: N::TransactionID, _transaction: Data<Transaction<N>>) -> Result<u64> {
bail!("Transaction '{transaction_id}' doesn't exist in prover")
}
}
3 changes: 3 additions & 0 deletions node/bft/ledger-service/src/traits.rs
@@ -120,4 +120,7 @@ pub trait LedgerService<N: Network>: Debug + Send + Sync {
/// Adds the given block as the next block in the ledger.
#[cfg(feature = "ledger-write")]
fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;

/// Computes the execution cost in microcredits for a transaction.
fn compute_cost(&self, transaction_id: N::TransactionID, transaction: Data<Transaction<N>>) -> Result<u64>;
}
5 changes: 5 additions & 0 deletions node/bft/ledger-service/src/translucent.rs
@@ -195,4 +195,9 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for TranslucentLedgerS
fn advance_to_next_block(&self, block: &Block<N>) -> Result<()> {
self.inner.advance_to_next_block(block)
}

/// Computes the execution cost in microcredits for a transaction.
fn compute_cost(&self, transaction_id: N::TransactionID, transaction: Data<Transaction<N>>) -> Result<u64> {
self.inner.compute_cost(transaction_id, transaction)
}
}
10 changes: 10 additions & 0 deletions node/bft/src/helpers/ready.rs
@@ -128,6 +128,16 @@ impl<N: Network> Ready<N> {
transmissions.drain(range).collect::<IndexMap<_, _>>()
}

/// Inserts the transmission at the front of the queue.
pub fn shift_insert_front(&self, key: TransmissionID<N>, value: Transmission<N>) {
self.transmissions.write().shift_insert(0, key, value);
}

/// Removes and returns the first transmission from the queue.
pub fn shift_remove_front(&self) -> Option<(TransmissionID<N>, Transmission<N>)> {
self.transmissions.write().shift_remove_index(0)
}

/// Clears all solutions from the ready queue.
pub(crate) fn clear_solutions(&self) {
// Acquire the write lock.
212 changes: 138 additions & 74 deletions node/bft/src/primary.rs
@@ -488,90 +488,96 @@ impl<N: Network> Primary<N> {
return Ok(());
}

// Determined the required number of transmissions per worker.
let num_transmissions_per_worker = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / self.num_workers() as usize;
// Initialize the map of transmissions.
let mut transmissions: IndexMap<_, _> = Default::default();
// Take the transmissions from the workers.
for worker in self.workers.iter() {
// Initialize a tracker for included transmissions for the current worker.
let mut num_transmissions_included_for_worker = 0;
// Keep draining the worker until the desired number of transmissions is reached or the worker is empty.
'outer: while num_transmissions_included_for_worker < num_transmissions_per_worker {
// Determine the number of remaining transmissions for the worker.
let num_remaining_transmissions =
num_transmissions_per_worker.saturating_sub(num_transmissions_included_for_worker);
// Drain the worker.
let mut worker_transmissions = worker.drain(num_remaining_transmissions).peekable();
// If the worker is empty, break early.
if worker_transmissions.peek().is_none() {
// Track the total execution costs of the batch proposal as it is being constructed.
let mut proposal_cost = 0u64;
// Note: worker draining and transaction inclusion need to be thought
// through carefully when there is more than one worker. The fairness
// provided by one worker (FIFO) is no longer guaranteed with multiple workers.
debug_assert_eq!(MAX_WORKERS, 1);

'outer: for worker in self.workers().iter() {
let mut num_worker_transmissions = 0usize;

// TODO(nkls): this is O(n), consider improving the underlying data structures.
while let Some((id, transmission)) = worker.shift_remove_front() {
if transmissions.len() == BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH {
break 'outer;
}
// Iterate through the worker transmissions.
'inner: for (id, transmission) in worker_transmissions {
// Check if the ledger already contains the transmission.
if self.ledger.contains_transmission(&id).unwrap_or(true) {
trace!("Proposing - Skipping transmission '{}' - Already in ledger", fmt_id(id));
continue 'inner;
}
// Check if the storage already contain the transmission.
// Note: We do not skip if this is the first transmission in the proposal, to ensure that
// the primary does not propose a batch with no transmissions.
if !transmissions.is_empty() && self.storage.contains_transmission(id) {
trace!("Proposing - Skipping transmission '{}' - Already in storage", fmt_id(id));
continue 'inner;

if num_worker_transmissions == Worker::<N>::MAX_TRANSMISSIONS_PER_WORKER {
continue 'outer;
}

// Check if the ledger already contains the transmission.
if self.ledger.contains_transmission(&id).unwrap_or(true) {
trace!("Proposing - Skipping transmission '{}' - Already in ledger", fmt_id(id));
continue;
}

// Check if the storage already contains the transmission.
// Note: We do not skip if this is the first transmission in the proposal, to ensure that
// the primary does not propose a batch with no transmissions.
if !transmissions.is_empty() && self.storage.contains_transmission(id) {
trace!("Proposing - Skipping transmission '{}' - Already in storage", fmt_id(id));
continue;
}

// Check the transmission is still valid.
match (id, transmission.clone()) {
(TransmissionID::Solution(solution_id, checksum), Transmission::Solution(solution)) => {
// Ensure the checksum matches. If not, skip the solution.
if !matches!(solution.to_checksum::<N>(), Ok(solution_checksum) if solution_checksum == checksum)
{
trace!("Proposing - Skipping solution '{}' - Checksum mismatch", fmt_id(solution_id));
continue;
}
// Check if the solution is still valid.
if let Err(e) = self.ledger.check_solution_basic(solution_id, solution).await {
trace!("Proposing - Skipping solution '{}' - {e}", fmt_id(solution_id));
continue;
}
}
// Check the transmission is still valid.
match (id, transmission.clone()) {
(TransmissionID::Solution(solution_id, checksum), Transmission::Solution(solution)) => {
// Ensure the checksum matches.
match solution.to_checksum::<N>() {
Ok(solution_checksum) if solution_checksum == checksum => (),
_ => {
trace!(
"Proposing - Skipping solution '{}' - Checksum mismatch",
fmt_id(solution_id)
);
continue 'inner;
}
}
// Check if the solution is still valid.
if let Err(e) = self.ledger.check_solution_basic(solution_id, solution).await {
trace!("Proposing - Skipping solution '{}' - {e}", fmt_id(solution_id));
continue 'inner;
}
(TransmissionID::Transaction(transaction_id, checksum), Transmission::Transaction(transaction)) => {
// Ensure the checksum matches. If not, skip the transaction.
if !matches!(transaction.to_checksum::<N>(), Ok(transaction_checksum) if transaction_checksum == checksum )
{
trace!("Proposing - Skipping transaction '{}' - Checksum mismatch", fmt_id(transaction_id));
continue;
}
(
TransmissionID::Transaction(transaction_id, checksum),
Transmission::Transaction(transaction),
) => {
// Ensure the checksum matches.
match transaction.to_checksum::<N>() {
Ok(transaction_checksum) if transaction_checksum == checksum => (),
_ => {
trace!(
"Proposing - Skipping transaction '{}' - Checksum mismatch",
fmt_id(transaction_id)
);
continue 'inner;
}
}
// Check if the transaction is still valid.
if let Err(e) = self.ledger.check_transaction_basic(transaction_id, transaction).await {
trace!("Proposing - Skipping transaction '{}' - {e}", fmt_id(transaction_id));
continue 'inner;
// Check if the transaction is still valid.
// TODO: check if clone is cheap, otherwise fix.
if let Err(e) = self.ledger.check_transaction_basic(transaction_id, transaction.clone()).await {
trace!("Proposing - Skipping transaction '{}' - {e}", fmt_id(transaction_id));
continue;
}

// Ensure the transaction doesn't bring the proposal above the spend limit.
match self.ledger.compute_cost(transaction_id, transaction) {
Ok(cost) if proposal_cost + cost <= N::BATCH_SPEND_LIMIT => proposal_cost += cost,
_ => {
trace!(
"Proposing - Skipping transaction '{}' - Batch spend limit surpassed",
fmt_id(transaction_id)
);

// Reinsert the transmission into the worker, O(n).
worker.shift_insert_front(id, transmission);
break 'outer;
}
}
// Note: We explicitly forbid including ratifications,
// as the protocol currently does not support ratifications.
(TransmissionID::Ratification, Transmission::Ratification) => continue,
// All other combinations are clearly invalid.
_ => continue 'inner,
}
// Insert the transmission into the map.
transmissions.insert(id, transmission);
num_transmissions_included_for_worker += 1;
// Note: We explicitly forbid including ratifications,
// as the protocol currently does not support ratifications.
(TransmissionID::Ratification, Transmission::Ratification) => continue,
// All other combinations are clearly invalid.
_ => continue,
}

// If the transmission is valid, insert it into the proposal's transmission list.
transmissions.insert(id, transmission);
num_worker_transmissions += 1;
}
}

@@ -755,6 +761,35 @@ impl<N: Network> Primary<N> {
// Inserts the missing transmissions into the workers.
self.insert_missing_transmissions_into_workers(peer_ip, missing_transmissions.into_iter())?;

// Ensure the proposed batch doesn't exceed the spend limit.
let mut proposal_cost = 0u64;
for transmission_id in batch_header.transmission_ids() {
let worker_id = assign_to_worker(*transmission_id, self.num_workers())?;
let Some(worker) = self.workers.get(worker_id as usize) else {
debug!("Unable to find worker {worker_id}");
return Ok(());
};

let Some(transmission) = worker.get_transmission(*transmission_id) else {
debug!("Unable to find transmission '{}' in worker '{worker_id}", fmt_id(transmission_id));
return Ok(());
};

// If the transmission is a transaction, compute its execution cost.
if let (TransmissionID::Transaction(transaction_id, _), Transmission::Transaction(transaction)) =
(transmission_id, transmission)
{
proposal_cost += self.ledger.compute_cost(*transaction_id, transaction)?
}
}

if proposal_cost > N::BATCH_SPEND_LIMIT {
debug!(
"Batch propose from peer '{peer_ip}' exceeds the batch spend limit — cost in microcredits: '{proposal_cost}'"
);
return Ok(());
}

/* Proceeding to sign the batch. */

// Retrieve the batch ID.
@@ -1995,6 +2030,35 @@ mod tests {
}
}

#[tokio::test]
async fn test_propose_batch_over_spend_limit() {
let mut rng = TestRng::default();
let (primary, _) = primary_without_handlers(&mut rng).await;

// Check there is no batch currently proposed.
assert!(primary.proposed_batch.read().is_none());
// Check the workers are empty.
primary.workers().iter().for_each(|worker| assert!(worker.transmissions().is_empty()));

// Generate and process a solution.
let (solution_id, solution) = sample_unconfirmed_solution(&mut rng);
primary.workers[0].process_unconfirmed_solution(solution_id, solution).await.unwrap();

// At 10 credits per execution, 10 transactions should max out the batch spend limit; add a few extra.
for _i in 0..15 {
let (transaction_id, transaction) = sample_unconfirmed_transaction(&mut rng);
// Store it on one of the workers.
primary.workers[0].process_unconfirmed_transaction(transaction_id, transaction).await.unwrap();
}

// Propose the batch; it should succeed.
assert!(primary.propose_batch().await.is_ok());
// Expect 10/15 transactions to be included in the proposal, along with the solution.
assert_eq!(primary.proposed_batch.read().as_ref().unwrap().transmissions().len(), 11);
// Check the transactions were correctly drained from the workers (15 + 1 - 11).
assert_eq!(primary.workers().iter().map(|worker| worker.transmissions().len()).sum::<usize>(), 5);
}

#[tokio::test]
async fn test_propose_batch() {
let mut rng = TestRng::default();
15 changes: 10 additions & 5 deletions node/bft/src/worker.rs
@@ -211,9 +211,14 @@ impl<N: Network> Worker<N> {
Ok((transmission_id, transmission))
}

/// Removes up to the specified number of transmissions from the ready queue, and returns them.
pub(crate) fn drain(&self, num_transmissions: usize) -> impl Iterator<Item = (TransmissionID<N>, Transmission<N>)> {
self.ready.drain(num_transmissions).into_iter()
/// Inserts the transmission at the front of the ready queue.
pub(crate) fn shift_insert_front(&self, key: TransmissionID<N>, value: Transmission<N>) {
self.ready.shift_insert_front(key, value)
}

/// Removes and returns the transmission at the front of the ready queue.
pub(crate) fn shift_remove_front(&self) -> Option<(TransmissionID<N>, Transmission<N>)> {
self.ready.shift_remove_front()
}

/// Reinserts the specified transmission into the ready queue.
@@ -602,6 +607,7 @@ mod tests {
transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
) -> Result<Block<N>>;
fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
fn compute_cost(&self, transaction_id: N::TransactionID, transaction: Data<Transaction<N>>) -> Result<u64>;
}
}

@@ -659,8 +665,7 @@
assert!(worker.ready.contains(transmission_id));
assert_eq!(worker.get_transmission(transmission_id), Some(transmission));
// Take the transmission from the ready set.
let transmission: Vec<_> = worker.drain(1).collect();
assert_eq!(transmission.len(), 1);
assert!(worker.ready.shift_remove_front().is_some());
assert!(!worker.ready.contains(transmission_id));
}
