From af6559a5cde25b20a8eff5745a14919e0edbc83a Mon Sep 17 00:00:00 2001
From: Aleksandr Logunov
Date: Sat, 23 Nov 2024 12:47:53 +0400
Subject: [PATCH] resharding: verify proof (#12485)

Completes the resharding validation flow on the chunk validator side.
The code replaces `MainTransition::ShardLayoutChange` with actual logic.
All of the preparation is needed to call
`let new_root = trie.retain_split_shard(&boundary_account, retain_mode)?;`
and verify the result against the state root provided in the transition
(a standalone sketch of this check is appended after the diff).

The main complexity is introducing the implicit resharding transition in
the right place: after applying the last chunk in the old shard layout
(which may be missing) and before applying the first chunk in the new
layout. To do that, I needed to modify `get_state_witness_block_range`
once again, and to keep it readable, I rewrote it using
`TraversalPosition`. The point of that is to change all loop parameters
at once instead of changing them one by one inside the loop. Then I
generate `ImplicitTransitionParams` right in the loop. Finally, this
allows inserting the resharding transition, if needed, before jumping to
the previous block hash and the child shard id. Note that this flow runs
in reverse of the chain order.

Also, unfortunately, this is not exercised by the production tests: when
nodes track all shards, the condition
`if let Ok(prev_chunk_extra) = chain.get_chunk_extra(prev_block_hash, &shard_uid) { ... }`
is triggered - if the latest chunk extra is already computed, we skip the
unnecessary validations. So for now I checked correctness by commenting
out this path and running the tests locally.
---
 chain/chain/src/chain.rs                         |   8 +-
 .../stateless_validation/chunk_validation.rs     | 348 ++++++++++++------
 chain/chain/src/update_shard.rs                  |   4 -
 .../chunk_validator/mod.rs                       |  26 +-
 tools/replay-archive/src/cli.rs                  |  17 +-
 5 files changed, 243 insertions(+), 160 deletions(-)

diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs
index 3f377d2803a..50d8030a4ef 100644
--- a/chain/chain/src/chain.rs
+++ b/chain/chain/src/chain.rs
@@ -3740,19 +3740,13 @@ impl Chain {
             self.shard_tracker.care_about_shard(me.as_ref(), prev_hash, shard_id, true);
         let cares_about_shard_next_epoch =
             self.shard_tracker.will_care_about_shard(me.as_ref(), prev_hash, shard_id, true);
-        let will_shard_layout_change = self.epoch_manager.will_shard_layout_change(prev_hash)?;
         let should_apply_chunk = get_should_apply_chunk(
             mode,
             cares_about_shard_this_epoch,
             cares_about_shard_next_epoch,
         );
         let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, epoch_id)?;
-        Ok(ShardContext {
-            shard_uid,
-            cares_about_shard_this_epoch,
-            will_shard_layout_change,
-            should_apply_chunk,
-        })
+        Ok(ShardContext { shard_uid, should_apply_chunk })
     }
 
     /// This method returns the closure that is responsible for updating a shard.
diff --git a/chain/chain/src/stateless_validation/chunk_validation.rs b/chain/chain/src/stateless_validation/chunk_validation.rs
index 10a0e0430ce..4dfca1c3cb5 100644
--- a/chain/chain/src/stateless_validation/chunk_validation.rs
+++ b/chain/chain/src/stateless_validation/chunk_validation.rs
@@ -3,6 +3,7 @@ use crate::chain::{
     ShardContext, StorageContext,
 };
 use crate::rayon_spawner::RayonAsyncComputationSpawner;
+use crate::resharding::event_type::ReshardingEventType;
 use crate::sharding::shuffle_receipt_proofs;
 use crate::stateless_validation::processing_tracker::ProcessingDoneTracker;
 use crate::types::{
@@ -29,9 +30,10 @@ use near_primitives::stateless_validation::state_witness::{
 };
 use near_primitives::transaction::SignedTransaction;
 use near_primitives::types::chunk_extra::ChunkExtra;
-use near_primitives::types::{ProtocolVersion, ShardId, ShardIndex};
+use near_primitives::types::{AccountId, ProtocolVersion, ShardId, ShardIndex};
 use near_primitives::utils::compression::CompressedData;
-use near_store::PartialStorage;
+use near_store::trie::ops::resharding::RetainMode;
+use near_store::{PartialStorage, Trie};
 use std::collections::HashMap;
 use std::num::NonZeroUsize;
 use std::sync::{Arc, Mutex};
@@ -41,12 +43,6 @@ use std::time::Instant;
 pub enum MainTransition {
     Genesis { chunk_extra: ChunkExtra, block_hash: CryptoHash, shard_id: ShardId },
     NewChunk(NewChunkData),
-    // TODO(#11881): this is temporary indicator that resharding happened in the
-    // state transition covered by state witness. Won't happen in production
-    // until resharding release.
-    // Instead, we can store a separate field `resharding_transition` in
-    // `ChunkStateWitness` and use it for proper validation of this case.
-    ShardLayoutChange,
 }
 
 impl MainTransition {
@@ -54,7 +50,6 @@ impl MainTransition {
         match self {
             Self::Genesis { block_hash, .. } => *block_hash,
             Self::NewChunk(data) => data.block.block_hash,
-            Self::ShardLayoutChange => panic!("block_hash called on ShardLayoutChange"),
         }
     }
 
@@ -64,14 +59,13 @@
             // It is ok to use the shard id from the header because it is a new
             // chunk. An old chunk may have the shard id from the parent shard.
             Self::NewChunk(data) => data.chunk_header.shard_id(),
-            Self::ShardLayoutChange => panic!("shard_id called on ShardLayoutChange"),
         }
     }
 }
 
 pub struct PreValidationOutput {
     pub main_transition_params: MainTransition,
-    pub implicit_transition_params: Vec<ApplyChunkBlockContext>,
+    pub implicit_transition_params: Vec<ImplicitTransitionParams>,
 }
 
 #[derive(Clone)]
@@ -113,74 +107,208 @@ pub fn validate_prepared_transactions(
     )
 }
 
+/// Parameters of an implicit state transition, i.e. one that does not result
+/// from applying a new chunk.
+pub enum ImplicitTransitionParams {
+    /// Transition resulting from application of an old chunk. Defined by the
+    /// block of that chunk and its shard.
+    ApplyOldChunk(ApplyChunkBlockContext, ShardUId),
+    /// Transition resulting from resharding. Defined by the boundary account,
+    /// the mode saying which of the child shards to retain, and the parent
+    /// shard uid.
+    Resharding(AccountId, RetainMode, ShardUId),
+}
+
 struct StateWitnessBlockRange {
-    /// Blocks from the last new chunk (exclusive) to the parent block
-    /// (inclusive).
-    blocks_after_last_chunk: Vec<Block>,
+    /// Transition parameters **after** the last chunk, corresponding to all
+    /// state transitions from the last new chunk (exclusive) to the parent
+    /// block (inclusive).
+    implicit_transition_params: Vec<ImplicitTransitionParams>,
     /// Blocks from the last last new chunk (exclusive) to the last new chunk
-    /// (inclusive).
+    /// (inclusive). Note they are in **reverse** order, from the newest to the
+    /// oldest. They are needed to validate the chunk's source receipt proofs.
     blocks_after_last_last_chunk: Vec<Block>,
+    /// Shard id and index of the last chunk before the chunk being validated.
     last_chunk_shard_id: ShardId,
     last_chunk_shard_index: ShardIndex,
 }
 
+/// Checks if a block has a new chunk with `shard_index`.
+fn block_has_new_chunk(block: &Block, shard_index: ShardIndex) -> Result<bool, Error> {
+    let chunks = block.chunks();
+    let chunk = chunks.get(shard_index).ok_or_else(|| {
+        Error::InvalidChunkStateWitness(format!(
+            "Shard {} does not exist in block {}",
+            shard_index,
+            block.hash()
+        ))
+    })?;
+
+    Ok(chunk.is_new_chunk(block.header().height()))
+}
+
+/// Gets ranges of blocks that are needed to validate a chunk state witness.
+/// Iterates backwards through the chain, from the chunk being validated to
+/// the second last chunk, if it exists.
 fn get_state_witness_block_range(
     store: &ChainStore,
     epoch_manager: &dyn EpochManagerAdapter,
     state_witness: &ChunkStateWitness,
 ) -> Result<StateWitnessBlockRange, Error> {
-    let mut blocks_after_last_chunk = Vec::new();
+    let mut implicit_transition_params = Vec::new();
     let mut blocks_after_last_last_chunk = Vec::new();
-    let mut block_hash = *state_witness.chunk_header.prev_block_hash();
-    let mut prev_chunks_seen = 0;
 
+    /// Position in the chain while traversing the blocks backwards.
+    struct TraversalPosition {
+        /// Shard ID of the chunk in the currently observed block, needed to
+        /// validate state transitions.
+        shard_id: ShardId,
+        /// Previous block.
+        prev_block: Block,
+        /// Number of new chunks seen during traversal.
+        num_new_chunks_seen: u32,
+        /// Current candidate shard id of the last chunk before the chunk being
+        /// validated.
+        last_chunk_shard_id: ShardId,
+        /// Current candidate shard index of the last chunk before the chunk
+        /// being validated.
+        last_chunk_shard_index: usize,
+    }
 
-    // It is ok to use the shard id from the header because it is a new
-    // chunk. An old chunk may have the shard id from the parent shard.
-    let (mut current_shard_id, mut current_shard_index) =
-        epoch_manager.get_prev_shard_id(&block_hash, state_witness.chunk_header.shard_id())?;
+    let initial_prev_hash = *state_witness.chunk_header.prev_block_hash();
+    let initial_prev_block = store.get_block(&initial_prev_hash)?;
+    let initial_shard_layout =
+        epoch_manager.get_shard_layout_from_prev_block(&initial_prev_hash)?;
+    let initial_shard_id = state_witness.chunk_header.shard_id();
+    let initial_shard_index = initial_shard_layout.get_shard_index(initial_shard_id)?;
+
+    let mut position = TraversalPosition {
+        shard_id: initial_shard_id,
+        prev_block: initial_prev_block,
+        num_new_chunks_seen: 0,
+        last_chunk_shard_id: initial_shard_id,
+        last_chunk_shard_index: initial_shard_index,
+    };
 
-    let mut last_chunk_shard_id = current_shard_id;
-    let mut last_chunk_shard_index = current_shard_index;
     loop {
-        let block = store.get_block(&block_hash)?;
-        let prev_hash = *block.header().prev_hash();
-        let chunks = block.chunks();
-        let Some(chunk) = chunks.get(current_shard_index) else {
-            return Err(Error::InvalidChunkStateWitness(format!(
-                "Shard {} does not exist in block {:?}",
-                current_shard_id, block_hash
-            )));
-        };
-        let is_new_chunk = chunk.is_new_chunk(block.header().height());
-        let is_genesis = block.header().is_genesis();
-        if is_new_chunk {
-            prev_chunks_seen += 1;
+        let prev_hash = position.prev_block.hash();
+        let prev_prev_hash = position.prev_block.header().prev_hash();
+        let epoch_id = epoch_manager.get_epoch_id_from_prev_block(prev_hash)?;
+        let shard_uid = epoch_manager.shard_id_to_uid(position.shard_id, &epoch_id)?;
+
+        if let Some(transition) = get_resharding_transition(
+            epoch_manager,
+            prev_hash,
+            shard_uid,
+            position.num_new_chunks_seen,
+        )? {
+            implicit_transition_params.push(transition);
         }
-        if prev_chunks_seen == 0 {
-            last_chunk_shard_id = current_shard_id;
-            last_chunk_shard_index = current_shard_index;
-            blocks_after_last_chunk.push(block);
-        } else if prev_chunks_seen == 1 {
-            blocks_after_last_last_chunk.push(block);
+
+        let (prev_shard_id, prev_shard_index) =
+            epoch_manager.get_prev_shard_id(prev_hash, position.shard_id)?;
+
+        let new_chunk_seen = block_has_new_chunk(&position.prev_block, prev_shard_index)?;
+        let new_chunks_seen_update =
+            position.num_new_chunks_seen + if new_chunk_seen { 1 } else { 0 };
+
+        match new_chunks_seen_update {
+            // If we have seen 0 chunks, the block contributes to implicit
+            // state transition.
+            0 => {
+                let block_context = Chain::get_apply_chunk_block_context(
+                    epoch_manager,
+                    &position.prev_block,
+                    &store.get_block_header(&prev_prev_hash)?,
+                    false,
+                )?;
+
+                implicit_transition_params
+                    .push(ImplicitTransitionParams::ApplyOldChunk(block_context, shard_uid));
+            }
+            // If we have seen 1 chunk, the block contributes to source receipt
+            // proofs.
+            1 => blocks_after_last_last_chunk.push(position.prev_block.clone()),
+            // If we have seen the 2nd chunk, we are done.
+            2 => break,
+            _ => unreachable!("chunks_seen should never exceed 2"),
        }
-        if prev_chunks_seen == 2 || is_genesis {
+
+        if position.prev_block.header().is_genesis() {
            break;
        }
-        block_hash = prev_hash;
-        (current_shard_id, current_shard_index) =
-            epoch_manager.get_prev_shard_id(&prev_hash, current_shard_id)?;
+        let prev_prev_block = store.get_block(&prev_prev_hash)?;
+        // If we have not seen chunks, switch to previous shard id, but
+        // once we just saw the first chunk, start keeping its shard id.
+        let (last_chunk_shard_id, last_chunk_shard_index) = if position.num_new_chunks_seen == 0 {
+            (prev_shard_id, prev_shard_index)
+        } else {
+            (position.last_chunk_shard_id, position.last_chunk_shard_index)
+        };
+        position = TraversalPosition {
+            shard_id: prev_shard_id,
+            prev_block: prev_prev_block,
+            num_new_chunks_seen: new_chunks_seen_update,
+            last_chunk_shard_id,
+            last_chunk_shard_index,
+        };
     }
 
+    implicit_transition_params.reverse();
     Ok(StateWitnessBlockRange {
-        blocks_after_last_chunk,
+        implicit_transition_params,
         blocks_after_last_last_chunk,
-        last_chunk_shard_id,
-        last_chunk_shard_index,
+        last_chunk_shard_id: position.last_chunk_shard_id,
+        last_chunk_shard_index: position.last_chunk_shard_index,
     })
 }
 
+/// Checks if chunk validation requires a transition to a new shard layout in
+/// the block with `prev_hash`, with a split resulting in the `shard_uid`, and
+/// if so, returns the corresponding resharding transition parameters.
+fn get_resharding_transition(
+    epoch_manager: &dyn EpochManagerAdapter,
+    prev_hash: &CryptoHash,
+    shard_uid: ShardUId,
+    num_new_chunks_seen: u32,
+) -> Result<Option<ImplicitTransitionParams>, Error> {
+    // If we have already seen a new chunk, we don't need to validate
+    // resharding transition.
+    if num_new_chunks_seen > 0 {
+        return Ok(None);
+    }
+
+    let shard_layout = epoch_manager.get_shard_layout_from_prev_block(prev_hash)?;
+    let prev_epoch_id = epoch_manager.get_prev_epoch_id_from_prev_block(prev_hash)?;
+    let prev_shard_layout = epoch_manager.get_shard_layout(&prev_epoch_id)?;
+    let block_has_new_shard_layout =
+        epoch_manager.is_next_block_epoch_start(prev_hash)? && shard_layout != prev_shard_layout;
+
+    if !block_has_new_shard_layout {
+        return Ok(None);
+    }
+
+    let params = match ReshardingEventType::from_shard_layout(&shard_layout, *prev_hash)? {
+        Some(ReshardingEventType::SplitShard(params)) => params,
+        None => return Ok(None),
+    };
+
+    if params.left_child_shard == shard_uid {
+        Ok(Some(ImplicitTransitionParams::Resharding(
+            params.boundary_account,
+            RetainMode::Left,
+            shard_uid,
+        )))
+    } else if params.right_child_shard == shard_uid {
+        Ok(Some(ImplicitTransitionParams::Resharding(
+            params.boundary_account,
+            RetainMode::Right,
+            shard_uid,
+        )))
+    } else {
+        Ok(None)
+    }
+}
+
 /// Pre-validates the chunk's receipts and transactions against the chain.
 /// We do this before handing off the computationally intensive part to a
 /// validation thread.
@@ -200,26 +328,12 @@ pub fn pre_validate_chunk_state_witness(
     // First, go back through the blockchain history to locate the last new chunk
     // and last last new chunk for the shard.
     let StateWitnessBlockRange {
-        blocks_after_last_chunk,
+        implicit_transition_params,
         blocks_after_last_last_chunk,
         last_chunk_shard_id,
         last_chunk_shard_index,
     } = get_state_witness_block_range(store, epoch_manager, state_witness)?;
 
-    let last_chunk_block = blocks_after_last_last_chunk.first().ok_or_else(|| {
-        Error::Other("blocks_after_last_last_chunk is empty, this should be impossible!".into())
-    })?;
-    let last_chunk_shard_layout =
-        epoch_manager.get_shard_layout(&last_chunk_block.header().epoch_id())?;
-    let chunk_shard_layout = epoch_manager
-        .get_shard_layout_from_prev_block(state_witness.chunk_header.prev_block_hash())?;
-    if last_chunk_shard_layout != chunk_shard_layout {
-        return Ok(PreValidationOutput {
-            main_transition_params: MainTransition::ShardLayoutChange,
-            implicit_transition_params: Vec::new(),
-        });
-    }
-
     let receipts_to_apply = validate_source_receipt_proofs(
         epoch_manager,
         &state_witness.source_receipt_proofs,
@@ -341,21 +455,7 @@ pub fn pre_validate_chunk_state_witness(
         })
     };
 
-    Ok(PreValidationOutput {
-        main_transition_params,
-        implicit_transition_params: blocks_after_last_chunk
-            .into_iter()
-            .rev()
-            .map(|block| -> Result<_, Error> {
-                Ok(Chain::get_apply_chunk_block_context(
-                    epoch_manager,
-                    &block,
-                    &store.get_block_header(block.header().prev_hash())?,
-                    false,
-                )?)
-            })
-            .collect::<Result<_, Error>>()?,
-    })
+    Ok(PreValidationOutput { main_transition_params, implicit_transition_params })
 }
 
 /// Validate that receipt proofs contain the receipts that should be applied during the
@@ -498,12 +598,7 @@ pub fn validate_chunk_state_witness(
                 ApplyChunkReason::ValidateChunkStateWitness,
                 &span,
                 new_chunk_data,
-                ShardContext {
-                    shard_uid,
-                    cares_about_shard_this_epoch: true,
-                    will_shard_layout_change: false,
-                    should_apply_chunk: true,
-                },
+                ShardContext { shard_uid, should_apply_chunk: true },
                 runtime_adapter,
             )?;
             let outgoing_receipts = std::mem::take(&mut main_apply_result.outgoing_receipts);
@@ -512,9 +607,6 @@ pub fn validate_chunk_state_witness(
             (chunk_extra, outgoing_receipts)
         }
-        (MainTransition::ShardLayoutChange, _) => {
-            panic!("shard layout change should not be validated")
-        }
         (_, Some(result)) => (result.chunk_extra, result.outgoing_receipts),
     };
     if chunk_extra.state_root() != &state_witness.main_state_transition.post_state_root {
@@ -549,43 +641,67 @@ pub fn validate_chunk_state_witness(
         );
     }
 
-    for (block, transition) in pre_validation_output
+    if pre_validation_output.implicit_transition_params.len()
+        != state_witness.implicit_transitions.len()
+    {
+        return Err(Error::InvalidChunkStateWitness(format!(
+            "Implicit transitions count mismatch. Expected {}, found {}",
+            pre_validation_output.implicit_transition_params.len(),
+            state_witness.implicit_transitions.len(),
+        )));
+    }
+
+    for (implicit_transition_params, transition) in pre_validation_output
         .implicit_transition_params
         .into_iter()
         .zip(state_witness.implicit_transitions.into_iter())
     {
-        let block_hash = block.block_hash;
-        let old_chunk_data = OldChunkData {
-            prev_chunk_extra: chunk_extra.clone(),
-            block,
-            storage_context: StorageContext {
-                storage_data_source: StorageDataSource::Recorded(PartialStorage {
-                    nodes: transition.base_state,
-                }),
-                state_patch: Default::default(),
-            },
+        let (shard_uid, new_state_root) = match implicit_transition_params {
+            ImplicitTransitionParams::ApplyOldChunk(block, shard_uid) => {
+                let shard_context = ShardContext { shard_uid, should_apply_chunk: false };
+                let old_chunk_data = OldChunkData {
+                    prev_chunk_extra: chunk_extra.clone(),
+                    block,
+                    storage_context: StorageContext {
+                        storage_data_source: StorageDataSource::Recorded(PartialStorage {
+                            nodes: transition.base_state,
+                        }),
+                        state_patch: Default::default(),
+                    },
+                };
+                let OldChunkResult { apply_result, .. } = apply_old_chunk(
+                    ApplyChunkReason::ValidateChunkStateWitness,
+                    &span,
+                    old_chunk_data,
+                    shard_context,
+                    runtime_adapter,
+                )?;
+                (shard_uid, apply_result.new_root)
+            }
+            ImplicitTransitionParams::Resharding(
+                boundary_account,
+                retain_mode,
+                child_shard_uid,
+            ) => {
+                let old_root = *chunk_extra.state_root();
+                let trie = Trie::from_recorded_storage(
+                    PartialStorage { nodes: transition.base_state },
+                    old_root,
+                    true,
+                );
+                let new_root = trie.retain_split_shard(&boundary_account, retain_mode)?;
+                (child_shard_uid, new_root)
+            }
         };
-        let OldChunkResult { apply_result, .. } = apply_old_chunk(
-            ApplyChunkReason::ValidateChunkStateWitness,
-            &span,
-            old_chunk_data,
-            ShardContext {
-                // Consider other shard uid in case of resharding.
-                shard_uid,
-                cares_about_shard_this_epoch: true,
-                will_shard_layout_change: false,
-                should_apply_chunk: false,
-            },
-            runtime_adapter,
-        )?;
-        *chunk_extra.state_root_mut() = apply_result.new_root;
+
+        *chunk_extra.state_root_mut() = new_state_root;
         if chunk_extra.state_root() != &transition.post_state_root {
             // This is an early check, it's not for correctness, only for better
             // error reporting in case of an invalid state witness due to a bug.
             // Only the final state root check against the chunk header is required.
             return Err(Error::InvalidChunkStateWitness(format!(
-                "Post state root {:?} for implicit transition at block {:?}, does not match expected state root {:?}",
-                chunk_extra.state_root(), block_hash, transition.post_state_root
+                "Post state root {:?} for implicit transition at block {:?} to shard {:?}, does not match expected state root {:?}",
+                chunk_extra.state_root(), transition.block_hash, shard_uid, transition.post_state_root
             )));
         }
     }
diff --git a/chain/chain/src/update_shard.rs b/chain/chain/src/update_shard.rs
index 5f171bcc7db..9f1c9dad937 100644
--- a/chain/chain/src/update_shard.rs
+++ b/chain/chain/src/update_shard.rs
@@ -70,10 +70,6 @@ pub enum ShardUpdateReason {
 /// Information about shard to update.
 pub struct ShardContext {
     pub shard_uid: ShardUId,
-    /// Whether node cares about shard in this epoch.
-    pub cares_about_shard_this_epoch: bool,
-    /// Whether shard layout changes in the next epoch.
-    pub will_shard_layout_change: bool,
     /// Whether transactions should be applied.
     pub should_apply_chunk: bool,
 }
 
diff --git a/chain/client/src/stateless_validation/chunk_validator/mod.rs b/chain/client/src/stateless_validation/chunk_validator/mod.rs
index 28a84577806..4cad6d7ef8a 100644
--- a/chain/client/src/stateless_validation/chunk_validator/mod.rs
+++ b/chain/client/src/stateless_validation/chunk_validator/mod.rs
@@ -95,30 +95,18 @@ impl ChunkValidator {
         let chunk_header = state_witness.chunk_header.clone();
         let network_sender = self.network_sender.clone();
         let epoch_manager = self.epoch_manager.clone();
-        if matches!(
-            pre_validation_result.main_transition_params,
-            chunk_validation::MainTransition::ShardLayoutChange
-        ) {
-            send_chunk_endorsement_to_block_producers(
-                &chunk_header,
-                epoch_manager.as_ref(),
-                signer,
-                &network_sender,
-            );
-            return Ok(());
-        }
-        // If we have the chunk extra for the previous block, we can validate the chunk without state witness.
+        // If we have the chunk extra for the previous block, we can validate
+        // the chunk without state witness.
         // This usually happens because we are a chunk producer and
         // therefore have the chunk extra for the previous block saved on disk.
         // We can also skip validating the chunk state witness in this case.
+        // We don't need to switch to parent shard uid, because resharding
+        // creates chunk extra for new shard uid.
+        let shard_uid = epoch_manager.shard_id_to_uid(shard_id, &epoch_id)?;
         let prev_block = chain.get_block(prev_block_hash)?;
-        let last_header = Chain::get_prev_chunk_header(
-            epoch_manager.as_ref(),
-            &prev_block,
-            chunk_header.shard_id(),
-        )?;
-        let shard_uid = epoch_manager.shard_id_to_uid(last_header.shard_id(), &epoch_id)?;
+        let last_header =
+            Chain::get_prev_chunk_header(epoch_manager.as_ref(), &prev_block, shard_id)?;
 
         if let Ok(prev_chunk_extra) = chain.get_chunk_extra(prev_block_hash, &shard_uid) {
             match validate_chunk_with_chunk_extra(
diff --git a/tools/replay-archive/src/cli.rs b/tools/replay-archive/src/cli.rs
index 9307a658606..a04a13b0c23 100644
--- a/tools/replay-archive/src/cli.rs
+++ b/tools/replay-archive/src/cli.rs
@@ -305,7 +305,7 @@ impl ReplayController {
         )?;
         let shard_id = shard_uid.shard_id();
 
-        let shard_context = self.get_shard_context(block_header, shard_uid)?;
+        let shard_context = self.get_shard_context(shard_uid)?;
         let storage_context = StorageContext {
             storage_data_source: StorageDataSource::DbTrieOnly,
@@ -488,19 +488,8 @@ impl ReplayController {
 
     /// Generates a ShardContext specific to replaying the blocks, which indicates that
     /// we care about all the shards and should always apply chunk.
-    fn get_shard_context(
-        &self,
-        block_header: &BlockHeader,
-        shard_uid: ShardUId,
-    ) -> Result<ShardContext> {
-        let prev_hash = block_header.prev_hash();
-        let will_shard_layout_change = self.epoch_manager.will_shard_layout_change(prev_hash)?;
-        let shard_context = ShardContext {
-            shard_uid,
-            cares_about_shard_this_epoch: true,
-            will_shard_layout_change: will_shard_layout_change,
-            should_apply_chunk: true,
-        };
+    fn get_shard_context(&self, shard_uid: ShardUId) -> Result<ShardContext> {
+        let shard_context = ShardContext { shard_uid, should_apply_chunk: true };
         Ok(shard_context)
     }
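
---

A minimal sketch (not part of the patch) of the resharding check referenced in the description above, under the assumption that the caller has already extracted the recorded proof, the old state root, the boundary account and the retain mode. The helper name, its parameters and the `String` error type are hypothetical; `Trie::from_recorded_storage`, `retain_split_shard` and `RetainMode` are the same calls used by the `ImplicitTransitionParams::Resharding` branch in the diff.

```rust
use near_primitives::hash::CryptoHash;
use near_primitives::types::AccountId;
use near_store::trie::ops::resharding::RetainMode;
use near_store::{PartialStorage, Trie};

/// Replays a resharding implicit transition from the recorded proof and checks
/// that retaining one child of the split yields the expected post state root.
/// Hypothetical helper for illustration only.
fn check_resharding_transition(
    recorded: PartialStorage,       // proof nodes shipped in the state witness
    old_root: CryptoHash,           // parent shard state root before the split
    boundary_account: &AccountId,   // split boundary from the new shard layout
    retain_mode: RetainMode,        // which child of the split this shard keeps
    expected_post_root: CryptoHash, // post state root claimed by the witness
) -> Result<(), String> {
    // Build a trie backed only by the recorded proof nodes.
    let trie = Trie::from_recorded_storage(recorded, old_root, true);
    // Retain the keys belonging to the chosen child and compute the new root.
    let new_root = trie
        .retain_split_shard(boundary_account, retain_mode)
        .map_err(|err| format!("retain_split_shard failed: {err:?}"))?;
    if new_root != expected_post_root {
        return Err(format!(
            "post state root mismatch: got {new_root:?}, expected {expected_post_root:?}"
        ));
    }
    Ok(())
}
```

In the patch itself this comparison is done indirectly: the computed root is written into `chunk_extra` and then checked against `transition.post_state_root`, so the final check against the chunk header remains the authoritative one.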