From 4b309ed0192e2cd8484eb857f2f75306a90b550e Mon Sep 17 00:00:00 2001
From: Marcelo Diop-Gonzalez
Date: Tue, 3 Dec 2024 10:44:45 -0500
Subject: [PATCH] test(resharding): enable resharding state sync and modify the
 resharding state sanity checks accordingly (#12546)

Here we enable state sync after a resharding by removing the debug assert
that prevented it, and we unignore the corresponding resharding TestLoop
test.

To make this work, we have to modify `assert_state_sanity_for_children_shard()`,
because it currently assumes that the client tracks all shards. Simply checking
whether the shard is tracked before proceeding is not enough, because memtries
are unloaded at the end of an epoch, and this check used to run at the end of
the test, when the head of the chain is in a different epoch than the final
head. By that point, clients tracking a single shard no longer have those
memtries loaded.

So instead of making the check at the end, we make it on every new block. The
check is quick (under 2 ms), so the extra coverage comes at little cost. While
we're at it, we also run the check for every shard rather than just the child
shards, since that is simpler and covers more.

Performing the check on every block does require some extra bookkeeping to make
sure we're actually checking what we think we are. When comparing flat storage
and memtries on an arbitrary block, flat storage might not be ready yet because
it belongs to a child shard that is still being built. That by itself is fine:
we just skip the flat storage/memtrie comparison in that case (but still compare
the memtrie and the trie). However, if a bug prevented the flat storage
comparison from happening for an entire epoch, the test would silently stop
checking what it was supposed to. So we introduce a `struct TrieSanityCheck`
that keeps track, for each account and each epoch, of the shards for which we
expect flat storage and memtries to exist. At the end of the test we iterate
over the epoch IDs and make sure that every expected check was performed.
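As a rough illustration of the bookkeeping idea, here is a minimal, standalone
sketch. It is not the test code itself (that follows in the diff below):
`SanityTracker`, its methods, and the plain type aliases are hypothetical
stand-ins for the real `TrieSanityCheck`, `AccountId`, `EpochId` and `ShardUId`.

```rust
use std::collections::HashMap;

// Hypothetical stand-ins for nearcore's AccountId / EpochId / ShardUId types.
type AccountId = String;
type EpochId = u64;
type ShardUId = u64;

/// Per-epoch bookkeeping: account -> (shard -> "was the flat storage/memtrie comparison done?").
type EpochCheck = HashMap<AccountId, HashMap<ShardUId, bool>>;

#[derive(Default)]
struct SanityTracker {
    checks: HashMap<EpochId, EpochCheck>,
}

impl SanityTracker {
    /// Record the shards an account is expected to track in an epoch, all initially unchecked.
    fn expect_shards(&mut self, epoch: EpochId, account: AccountId, shards: &[ShardUId]) {
        let per_account = self.checks.entry(epoch).or_default();
        let entry = per_account.entry(account).or_default();
        for shard in shards {
            entry.entry(*shard).or_insert(false);
        }
    }

    /// Mark one shard as successfully compared for this account in this epoch.
    fn mark_checked(&mut self, epoch: EpochId, account: &str, shard: ShardUId) {
        if let Some(shards) = self.checks.get_mut(&epoch).and_then(|c| c.get_mut(account)) {
            shards.insert(shard, true);
        }
    }

    /// At the end of the test, every expected (account, shard) pair must have been checked.
    fn assert_all_checked(&self) {
        for (epoch, per_account) in &self.checks {
            for (account, shards) in per_account {
                for (shard, checked) in shards {
                    assert!(*checked, "no check for account {account} epoch {epoch} shard {shard}");
                }
            }
        }
    }
}

fn main() {
    let mut tracker = SanityTracker::default();
    // One epoch, one hypothetical validator expected to track shards 0 and 1.
    tracker.expect_shards(1, "validator0".to_string(), &[0, 1]);
    tracker.mark_checked(1, "validator0", 0);
    tracker.mark_checked(1, "validator0", 1);
    tracker.assert_all_checked(); // would panic if an expected comparison never happened
}
```

The real `TrieSanityCheck` derives the expected shard set from the shard tracker
and the epoch's shard layout instead of taking it as an argument, but the
record-then-verify flow at the end of the test is the same.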
---
 chain/chain-primitives/src/error.rs      |   5 +
 chain/chain/src/chain.rs                 |  21 --
 core/primitives/src/shard_layout.rs      |   7 +-
 .../src/test_loop/tests/resharding_v3.rs | 254 ++++++++++++++----
 4 files changed, 215 insertions(+), 72 deletions(-)

diff --git a/chain/chain-primitives/src/error.rs b/chain/chain-primitives/src/error.rs
index 085145cdeda..b7897f4c48c 100644
--- a/chain/chain-primitives/src/error.rs
+++ b/chain/chain-primitives/src/error.rs
@@ -193,6 +193,8 @@ pub enum Error {
     InvalidShardId(ShardId),
     #[error("Shard index {0} does not exist")]
     InvalidShardIndex(ShardIndex),
+    #[error("Shard id {0} does not have a parent")]
+    NoParentShardId(ShardId),
     /// Invalid shard id
     #[error("Invalid state request: {0}")]
     InvalidStateRequest(String),
@@ -326,6 +328,7 @@ impl Error {
             | Error::InvalidBandwidthRequests(_)
             | Error::InvalidShardId(_)
             | Error::InvalidShardIndex(_)
+            | Error::NoParentShardId(_)
             | Error::InvalidStateRequest(_)
             | Error::InvalidRandomnessBeaconOutput
             | Error::InvalidBlockMerkleRoot
@@ -407,6 +410,7 @@ impl Error {
             Error::InvalidBandwidthRequests(_) => "invalid_bandwidth_requests",
             Error::InvalidShardId(_) => "invalid_shard_id",
             Error::InvalidShardIndex(_) => "invalid_shard_index",
+            Error::NoParentShardId(_) => "no_parent_shard_id",
             Error::InvalidStateRequest(_) => "invalid_state_request",
             Error::InvalidRandomnessBeaconOutput => "invalid_randomness_beacon_output",
             Error::InvalidBlockMerkleRoot => "invalid_block_merkele_root",
@@ -450,6 +454,7 @@ impl From<ShardLayoutError> for Error {
             ShardLayoutError::InvalidShardIndexError { shard_index } => {
                 Error::InvalidShardIndex(shard_index)
             }
+            ShardLayoutError::NoParentError { shard_id } => Error::NoParentShardId(shard_id),
         }
     }
 }
diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs
index 03dc1bc0c0f..948972e0f00 100644
--- a/chain/chain/src/chain.rs
+++ b/chain/chain/src/chain.rs
@@ -798,27 +798,6 @@ impl Chain {
             me,
             &prev_hash,
         )?;
-        let prev_block = self.get_block(&prev_hash)?;
-
-        if prev_block.chunks().len() != epoch_first_block.chunks().len()
-            && !shards_to_state_sync.is_empty()
-        {
-            // Currently, the state sync algorithm assumes that the number of chunks do not change
-            // between the epoch being synced to and the last epoch.
-            // For example, if shard layout changes at the beginning of epoch T, validators
-            // will not be able to sync states at epoch T for epoch T+1
-            // Fortunately, since all validators track all shards for now, this error will not be
-            // triggered in live yet
-            // Instead of propagating the error, we simply log the error here because the error
-            // do not affect processing blocks for this epoch. However, when the next epoch comes,
-            // the validator will not have the states ready so it will halt.
-            error!(
-                "Cannot download states for epoch {:?} because sharding just changed. I'm {:?}",
-                epoch_first_block.header().epoch_id(),
-                me
-            );
-            debug_assert!(false);
-        }
         if shards_to_state_sync.is_empty() {
             Ok(None)
         } else {
diff --git a/core/primitives/src/shard_layout.rs b/core/primitives/src/shard_layout.rs
index ff422c1ce95..7f611833b58 100644
--- a/core/primitives/src/shard_layout.rs
+++ b/core/primitives/src/shard_layout.rs
@@ -327,6 +327,7 @@ impl ShardLayoutV2 {
 pub enum ShardLayoutError {
     InvalidShardIdError { shard_id: ShardId },
     InvalidShardIndexError { shard_index: ShardIndex },
+    NoParentError { shard_id: ShardId },
 }
 
 impl fmt::Display for ShardLayoutError {
@@ -624,7 +625,7 @@ impl ShardLayout {
             return Err(ShardLayoutError::InvalidShardIdError { shard_id });
         }
         let parent_shard_id = match self {
-            Self::V0(_) => panic!("shard layout has no parent shard"),
+            Self::V0(_) => return Err(ShardLayoutError::NoParentError { shard_id }),
             Self::V1(v1) => match &v1.to_parent_shard_map {
                 // we can safely unwrap here because the construction of to_parent_shard_map guarantees
                 // that every shard has a parent shard
@@ -632,7 +633,7 @@ impl ShardLayout {
                     let shard_index = self.get_shard_index(shard_id)?;
                     *to_parent_shard_map.get(shard_index).unwrap()
                 }
-                None => panic!("shard_layout has no parent shard"),
+                None => return Err(ShardLayoutError::NoParentError { shard_id }),
             },
             Self::V2(v2) => match &v2.shards_parent_map {
                 Some(to_parent_shard_map) => {
@@ -641,7 +642,7 @@ impl ShardLayout {
                         .ok_or(ShardLayoutError::InvalidShardIdError { shard_id })?;
                     *parent_shard_id
                 }
-                None => panic!("shard_layout has no parent shard"),
+                None => return Err(ShardLayoutError::NoParentError { shard_id }),
             },
         };
         Ok(parent_shard_id)
diff --git a/integration-tests/src/test_loop/tests/resharding_v3.rs b/integration-tests/src/test_loop/tests/resharding_v3.rs
index ee60c306bab..c2006f388d5 100644
--- a/integration-tests/src/test_loop/tests/resharding_v3.rs
+++ b/integration-tests/src/test_loop/tests/resharding_v3.rs
@@ -12,7 +12,7 @@ use near_primitives::epoch_manager::EpochConfigStore;
 use near_primitives::hash::CryptoHash;
 use near_primitives::shard_layout::{account_id_to_shard_uid, ShardLayout};
 use near_primitives::state_record::StateRecord;
-use near_primitives::types::{AccountId, BlockHeightDelta, Gas, ShardId};
+use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, Gas, ShardId};
 use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION};
 use near_store::adapter::StoreAdapter;
 use near_store::db::refcount::decode_value_with_rc;
@@ -40,21 +40,24 @@ use near_primitives::views::FinalExecutionStatus;
 use std::cell::Cell;
 use std::u64;
 
-fn client_tracking_shard<'a>(clients: &'a [&Client], tip: &Tip, shard_id: ShardId) -> &'a Client {
+fn client_tracking_shard(client: &Client, shard_id: ShardId, parent_hash: &CryptoHash) -> bool {
+    let signer = client.validator_signer.get();
+    let account_id = signer.as_ref().map(|s| s.validator_id());
+    client.shard_tracker.care_about_shard(account_id, parent_hash, shard_id, true)
+}
+
+fn get_client_tracking_shard<'a>(
+    clients: &'a [&Client],
+    tip: &Tip,
+    shard_id: ShardId,
+) -> &'a Client {
     for client in clients {
-        let signer = client.validator_signer.get();
-        let cares_about_shard = client.shard_tracker.care_about_shard(
-            signer.as_ref().map(|s| s.validator_id()),
-            &tip.prev_block_hash,
-            shard_id,
-            true,
-        );
-        if cares_about_shard {
+        if client_tracking_shard(client, shard_id, &tip.prev_block_hash) {
            return client;
         }
     }
     panic!(
-        "client_tracking_shard() could not find client tracking shard {} at {} #{}",
+        "get_client_tracking_shard() could not find client tracking shard {} at {} #{}",
         shard_id, &tip.last_block_hash, tip.height
     );
 }
@@ -62,7 +65,7 @@ fn client_tracking_shard<'a>(clients: &'a [&Client], tip: &Tip, shard_id: ShardI
 fn print_and_assert_shard_accounts(clients: &[&Client], tip: &Tip) {
     let epoch_config = clients[0].epoch_manager.get_epoch_config(&tip.epoch_id).unwrap();
     for shard_uid in epoch_config.shard_layout.shard_uids() {
-        let client = client_tracking_shard(clients, tip, shard_uid.shard_id());
+        let client = get_client_tracking_shard(clients, tip, shard_uid.shard_id());
         let chunk_extra = client.chain.get_chunk_extra(&tip.prev_block_hash, &shard_uid).unwrap();
         let trie = client
             .runtime_adapter
@@ -549,25 +552,57 @@ fn get_memtrie_for_shard(
     memtrie
 }
 
+fn assert_state_equal(
+    values1: &HashSet<(Vec<u8>, Vec<u8>)>,
+    values2: &HashSet<(Vec<u8>, Vec<u8>)>,
+    shard_uid: ShardUId,
+    cmp_msg: &str,
+) {
+    let diff = values1.symmetric_difference(values2);
+    let mut has_diff = false;
+    for (key, value) in diff {
+        has_diff = true;
+        tracing::error!(target: "test", ?shard_uid, key=?key, ?value, "Difference in state between {}!", cmp_msg);
+    }
+    assert!(!has_diff, "{} state mismatch!", cmp_msg);
+}
+
+fn shard_was_split(shard_layout: &ShardLayout, shard_id: ShardId) -> bool {
+    let Ok(parent) = shard_layout.get_parent_shard_id(shard_id) else {
+        return false;
+    };
+    parent != shard_id
+}
+
 /// Asserts that for each child shard:
 /// MemTrie, FlatState and DiskTrie all contain the same key-value pairs.
-fn assert_state_sanity_for_children_shard(parent_shard_uid: ShardUId, client: &Client) {
-    let final_head = client.chain.final_head().unwrap();
+/// If `load_mem_tries_for_tracked_shards` is false, we only enforce memtries for split shards
+/// Returns the ShardUIds that this client tracks and has sane memtries and flat storage for
+fn assert_state_sanity(
+    client: &Client,
+    final_head: &Tip,
+    load_mem_tries_for_tracked_shards: bool,
+) -> Vec<ShardUId> {
+    let shard_layout = client.epoch_manager.get_shard_layout(&final_head.epoch_id).unwrap();
+    let mut checked_shards = Vec::new();
+
+    for shard_uid in shard_layout.shard_uids() {
+        if !load_mem_tries_for_tracked_shards
+            && !shard_was_split(&shard_layout, shard_uid.shard_id())
+        {
+            continue;
+        }
+        if !client_tracking_shard(client, shard_uid.shard_id(), &final_head.prev_block_hash) {
+            continue;
+        }
 
-    for child_shard_uid in client
-        .epoch_manager
-        .get_shard_layout(&final_head.epoch_id)
-        .unwrap()
-        .get_children_shards_uids(parent_shard_uid.shard_id())
-        .unwrap()
-    {
-        let memtrie = get_memtrie_for_shard(client, &child_shard_uid, &final_head.prev_block_hash);
+        let memtrie = get_memtrie_for_shard(client, &shard_uid, &final_head.prev_block_hash);
         let memtrie_state =
             memtrie.lock_for_iter().iter().unwrap().collect::<Result<HashSet<_>, _>>().unwrap();
 
         let state_root = *client
             .chain
-            .get_chunk_extra(&final_head.prev_block_hash, &child_shard_uid)
+            .get_chunk_extra(&final_head.prev_block_hash, &shard_uid)
             .unwrap()
             .state_root();
 
@@ -575,22 +610,21 @@ fn assert_state_sanity_for_children_shard(parent_shard_uid: ShardUId, client: &C
         // uses memtries.
         let trie = client
             .runtime_adapter
-            .get_view_trie_for_shard(
-                child_shard_uid.shard_id(),
-                &final_head.prev_block_hash,
-                state_root,
-            )
+            .get_view_trie_for_shard(shard_uid.shard_id(), &final_head.prev_block_hash, state_root)
             .unwrap();
         assert!(!trie.has_memtries());
         let trie_state =
             trie.lock_for_iter().iter().unwrap().collect::<Result<HashSet<_>, _>>().unwrap();
+        assert_state_equal(&memtrie_state, &trie_state, shard_uid, "memtrie and trie");
 
-        let flat_store_chunk_view = client
+        let Some(flat_store_chunk_view) = client
             .chain
             .runtime_adapter
             .get_flat_storage_manager()
-            .chunk_view(child_shard_uid, final_head.last_block_hash)
-            .unwrap();
+            .chunk_view(shard_uid, final_head.last_block_hash)
+        else {
+            continue;
+        };
         let flat_store_state = flat_store_chunk_view
             .iter_range(None, None)
             .map_ok(|(key, value)| {
@@ -600,7 +634,7 @@ fn assert_state_sanity_for_children_shard(parent_shard_uid: ShardUId, client: &C
                         .chain_store()
                         .store()
                         .trie_store()
-                        .get(child_shard_uid, &value.hash)
+                        .get(shard_uid, &value.hash)
                        .unwrap()
                         .to_vec(),
                     FlatStateValue::Inlined(data) => data,
@@ -610,16 +644,140 @@ fn assert_state_sanity_for_children_shard(parent_shard_uid: ShardUId, client: &C
             .collect::<Result<HashSet<_>, _>>()
             .unwrap();
 
-        let diff_memtrie_flat_store = memtrie_state.symmetric_difference(&flat_store_state);
-        let diff_memtrie_trie = memtrie_state.symmetric_difference(&trie_state);
-        let diff = diff_memtrie_flat_store.chain(diff_memtrie_trie);
-        if diff.clone().count() == 0 {
-            continue;
+        assert_state_equal(&memtrie_state, &flat_store_state, shard_uid, "memtrie and flat store");
+        checked_shards.push(shard_uid);
+    }
+    checked_shards
+}
+
+// For each epoch, keep a map from AccountId to a map with keys equal to
+// the set of shards that account tracks in that epoch, and bool values indicating
+// whether the equality of flat storage and memtries has been checked for that shard
+type EpochTrieCheck = HashMap<AccountId, HashMap<ShardUId, bool>>;
+
+/// Keeps track of the needed trie comparisons for each epoch. After we successfully call
+/// assert_state_sanity() for an account ID, we mark those shards as checked for that epoch,
+/// and then at the end of the test we check whether all expected shards for each account
+/// were checked at least once in that epoch. We do this because assert_state_sanity() isn't
+/// always able to perform the check if child shard flat storages are still being created, but
+/// we want to make sure that it's always eventually checked by the end of the epoch
+struct TrieSanityCheck {
+    accounts: Vec<AccountId>,
+    load_mem_tries_for_tracked_shards: bool,
+    checks: HashMap<EpochId, EpochTrieCheck>,
+}
+
+impl TrieSanityCheck {
+    fn new(clients: &[&Client], load_mem_tries_for_tracked_shards: bool) -> Self {
+        let accounts = clients
+            .iter()
+            .filter_map(|c| {
+                let signer = c.validator_signer.get();
+                signer.map(|s| s.validator_id().clone())
+            })
+            .collect();
+        Self { accounts, load_mem_tries_for_tracked_shards, checks: HashMap::new() }
+    }
+
+    // If it's not already stored, initialize it with the expected ShardUIds for each account
+    fn get_epoch_check(&mut self, client: &Client, tip: &Tip) -> &mut EpochTrieCheck {
+        match self.checks.entry(tip.epoch_id) {
+            std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
+            std::collections::hash_map::Entry::Vacant(e) => {
+                let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
+                let shard_uids = shard_layout.shard_uids().collect_vec();
+                let mut check = HashMap::new();
+                for account_id in self.accounts.iter() {
+                    let tracked = shard_uids
+                        .iter()
+                        .filter_map(|uid| {
+                            if !self.load_mem_tries_for_tracked_shards
+                                && !shard_was_split(&shard_layout, uid.shard_id())
+                            {
+                                return None;
+                            }
+                            let cares = client.shard_tracker.care_about_shard(
+                                Some(account_id),
+                                &tip.prev_block_hash,
+                                uid.shard_id(),
+                                false,
+                            );
+                            if cares {
+                                Some((*uid, false))
+                            } else {
+                                None
+                            }
+                        })
+                        .collect();
+                    check.insert(account_id.clone(), tracked);
+                }
+                e.insert(check)
+            }
         }
-        for (key, value) in diff {
-            tracing::error!(target: "test", shard=?child_shard_uid, key=?key, ?value, "Difference in state between trie, memtrie and flat store!");
+    }
+
+    // Check trie sanity and keep track of which shards were successfully fully checked
+    fn assert_state_sanity(&mut self, clients: &[&Client]) {
+        for client in clients {
+            let signer = client.validator_signer.get();
+            let Some(account_id) = signer.as_ref().map(|s| s.validator_id()) else {
+                // For now this is never relevant, since all of them have account IDs, but
+                // if this changes in the future, here we'll just skip those.
+                continue;
+            };
+            let head = client.chain.head().unwrap();
+            if head.epoch_id == EpochId::default() {
+                continue;
+            }
+            let final_head = client.chain.final_head().unwrap();
+            // At the end of an epoch, we unload memtries for shards we'll no longer track. Also,
+            // the key/value equality comparison in assert_state_equal() is only guaranteed for
+            // final blocks. So these two together mean that we should only check this when the head
+            // and final head are in the same epoch.
+            if head.epoch_id != final_head.epoch_id {
+                continue;
+            }
+            let checked_shards =
+                assert_state_sanity(client, &final_head, self.load_mem_tries_for_tracked_shards);
+            let check = self.get_epoch_check(client, &head);
+            let check = check.get_mut(account_id).unwrap();
+            for shard_uid in checked_shards {
+                check.insert(shard_uid, true);
+            }
+        }
+    }
+
+    /// Look through all the epochs before the current one (because the current one will be early into the epoch,
+    /// and we won't have checked it yet) and make sure that for all accounts, all expected shards were checked at least once
+    fn check_epochs(&self, client: &Client) {
+        let tip = client.chain.head().unwrap();
+        let mut block_info = client.epoch_manager.get_block_info(&tip.last_block_hash).unwrap();
+
+        loop {
+            let epoch_id = client
+                .epoch_manager
+                .get_prev_epoch_id_from_prev_block(block_info.prev_hash())
+                .unwrap();
+            if epoch_id == EpochId::default() {
+                break;
+            }
+            let check = self.checks.get(&epoch_id).unwrap_or_else(|| {
+                panic!("No trie comparison checks made for epoch {}", &epoch_id.0)
+            });
+            for (account_id, checked_shards) in check.iter() {
+                for (shard_uid, checked) in checked_shards.iter() {
+                    assert!(
+                        checked,
+                        "No trie comparison checks made for account {} epoch {} shard {}",
+                        account_id, &epoch_id.0, shard_uid
+                    );
+                }
+            }
+
+            block_info =
+                client.epoch_manager.get_block_info(block_info.epoch_first_block()).unwrap();
+            block_info = client.epoch_manager.get_block_info(block_info.prev_hash()).unwrap();
         }
-        assert!(false, "trie, memtrie and flat store state mismatch!");
     }
 }
 
@@ -735,6 +893,11 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
 
     let client_handles =
         node_datas.iter().map(|data| data.client_sender.actor_handle()).collect_vec();
+    let clients =
+        client_handles.iter().map(|handle| &test_loop.data.get(handle).client).collect_vec();
+    let mut trie_sanity_check =
+        TrieSanityCheck::new(&clients, params.load_mem_tries_for_tracked_shards);
+
     let latest_block_height = std::cell::Cell::new(0u64);
     let success_condition = |test_loop_data: &mut TestLoopData| -> bool {
         params
@@ -755,6 +918,7 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
             println!("State before resharding:");
             print_and_assert_shard_accounts(&clients, &tip);
         }
+        trie_sanity_check.assert_state_sanity(&clients);
         latest_block_height.set(tip.height);
         println!("block: {} chunks: {:?}", tip.height, block_header.chunk_mask());
         if params.all_chunks_expected && params.chunk_ranges_to_drop.is_empty() {
@@ -784,16 +948,12 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
         // Give enough time to produce ~7 epochs.
         Duration::seconds((7 * params.epoch_length) as i64),
     );
+    let client = &test_loop.data.get(&client_handles[0]).client;
+    trie_sanity_check.check_epochs(client);
     // Wait for garbage collection to kick in, so that it is tested as well.
     test_loop
        .run_for(Duration::seconds((DEFAULT_GC_NUM_EPOCHS_TO_KEEP * params.epoch_length) as i64));
 
-    // At the end of the test we know for sure resharding has been completed.
-    // Verify that state is equal across tries and flat storage for all children shards.
-    let clients =
-        client_handles.iter().map(|handle| &test_loop.data.get(handle).client).collect_vec();
-    assert_state_sanity_for_children_shard(parent_shard_uid, &clients[0]);
-
     TestLoopEnv { test_loop, datas: node_datas, tempdir }
         .shutdown_and_drain_remaining_events(Duration::seconds(20));
 }
@@ -864,9 +1024,7 @@ fn test_resharding_v3_double_sign_resharding_block() {
     );
 }
 
-// TODO(resharding): fix nearcore and un-ignore this test
 #[test]
-#[ignore]
 fn test_resharding_v3_shard_shuffling() {
     let params = TestReshardingParameters::new()
         .shuffle_shard_assignment()