diff --git a/chain/chain-primitives/src/error.rs b/chain/chain-primitives/src/error.rs
index 085145cdeda..b7897f4c48c 100644
--- a/chain/chain-primitives/src/error.rs
+++ b/chain/chain-primitives/src/error.rs
@@ -193,6 +193,8 @@ pub enum Error {
     InvalidShardId(ShardId),
     #[error("Shard index {0} does not exist")]
     InvalidShardIndex(ShardIndex),
+    #[error("Shard id {0} does not have a parent")]
+    NoParentShardId(ShardId),
     /// Invalid shard id
     #[error("Invalid state request: {0}")]
     InvalidStateRequest(String),
@@ -326,6 +328,7 @@ impl Error {
            | Error::InvalidBandwidthRequests(_)
            | Error::InvalidShardId(_)
            | Error::InvalidShardIndex(_)
+            | Error::NoParentShardId(_)
            | Error::InvalidStateRequest(_)
            | Error::InvalidRandomnessBeaconOutput
            | Error::InvalidBlockMerkleRoot
@@ -407,6 +410,7 @@ impl Error {
            Error::InvalidBandwidthRequests(_) => "invalid_bandwidth_requests",
            Error::InvalidShardId(_) => "invalid_shard_id",
            Error::InvalidShardIndex(_) => "invalid_shard_index",
+            Error::NoParentShardId(_) => "no_parent_shard_id",
            Error::InvalidStateRequest(_) => "invalid_state_request",
            Error::InvalidRandomnessBeaconOutput => "invalid_randomness_beacon_output",
            Error::InvalidBlockMerkleRoot => "invalid_block_merkele_root",
@@ -450,6 +454,7 @@ impl From<ShardLayoutError> for Error {
            ShardLayoutError::InvalidShardIndexError { shard_index } => {
                Error::InvalidShardIndex(shard_index)
            }
+            ShardLayoutError::NoParentError { shard_id } => Error::NoParentShardId(shard_id),
        }
    }
 }
diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs
index 03dc1bc0c0f..948972e0f00 100644
--- a/chain/chain/src/chain.rs
+++ b/chain/chain/src/chain.rs
@@ -798,27 +798,6 @@ impl Chain {
             me,
             &prev_hash,
         )?;
-        let prev_block = self.get_block(&prev_hash)?;
-
-        if prev_block.chunks().len() != epoch_first_block.chunks().len()
-            && !shards_to_state_sync.is_empty()
-        {
-            // Currently, the state sync algorithm assumes that the number of chunks do not change
-            // between the epoch being synced to and the last epoch.
-            // For example, if shard layout changes at the beginning of epoch T, validators
-            // will not be able to sync states at epoch T for epoch T+1
-            // Fortunately, since all validators track all shards for now, this error will not be
-            // triggered in live yet
-            // Instead of propagating the error, we simply log the error here because the error
-            // do not affect processing blocks for this epoch. However, when the next epoch comes,
-            // the validator will not have the states ready so it will halt.
-            error!(
-                "Cannot download states for epoch {:?} because sharding just changed. I'm {:?}",
-                epoch_first_block.header().epoch_id(),
-                me
-            );
-            debug_assert!(false);
-        }
         if shards_to_state_sync.is_empty() {
             Ok(None)
         } else {
diff --git a/core/primitives/src/shard_layout.rs b/core/primitives/src/shard_layout.rs
index ff422c1ce95..7f611833b58 100644
--- a/core/primitives/src/shard_layout.rs
+++ b/core/primitives/src/shard_layout.rs
@@ -327,6 +327,7 @@ impl ShardLayoutV2 {
 pub enum ShardLayoutError {
     InvalidShardIdError { shard_id: ShardId },
     InvalidShardIndexError { shard_index: ShardIndex },
+    NoParentError { shard_id: ShardId },
 }

 impl fmt::Display for ShardLayoutError {
@@ -624,7 +625,7 @@ impl ShardLayout {
             return Err(ShardLayoutError::InvalidShardIdError { shard_id });
         }
         let parent_shard_id = match self {
-            Self::V0(_) => panic!("shard layout has no parent shard"),
+            Self::V0(_) => return Err(ShardLayoutError::NoParentError { shard_id }),
             Self::V1(v1) => match &v1.to_parent_shard_map {
                 // we can safely unwrap here because the construction of to_parent_shard_map guarantees
                 // that every shard has a parent shard
@@ -632,7 +633,7 @@ impl ShardLayout {
                     let shard_index = self.get_shard_index(shard_id)?;
                     *to_parent_shard_map.get(shard_index).unwrap()
                 }
-                None => panic!("shard_layout has no parent shard"),
+                None => return Err(ShardLayoutError::NoParentError { shard_id }),
             },
             Self::V2(v2) => match &v2.shards_parent_map {
                 Some(to_parent_shard_map) => {
@@ -641,7 +642,7 @@ impl ShardLayout {
                         .ok_or(ShardLayoutError::InvalidShardIdError { shard_id })?;
                     *parent_shard_id
                 }
-                None => panic!("shard_layout has no parent shard"),
+                None => return Err(ShardLayoutError::NoParentError { shard_id }),
             },
         };
         Ok(parent_shard_id)
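Note: with `get_parent_shard_id` now returning `ShardLayoutError::NoParentError` instead of panicking, callers can treat a missing parent as an ordinary, recoverable case; the `shard_was_split` helper added to the resharding test below does exactly that. Inside chain code the same error can also be bubbled up with `?`, which the `From<ShardLayoutError>` impl above turns into `Error::NoParentShardId`. A minimal illustrative sketch (the `parent_if_any` helper is hypothetical and not part of this change; the usual `ShardLayout`, `ShardLayoutError` and `ShardId` imports from `near_primitives` are assumed):

    fn parent_if_any(layout: &ShardLayout, shard_id: ShardId) -> Option<ShardId> {
        match layout.get_parent_shard_id(shard_id) {
            Ok(parent) => Some(parent),
            // V0 layouts (and layouts built without a parent map) simply have no parents.
            Err(ShardLayoutError::NoParentError { .. }) => None,
            Err(err) => panic!("unexpected error for shard {}: {}", shard_id, err),
        }
    }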
diff --git a/integration-tests/src/test_loop/tests/resharding_v3.rs b/integration-tests/src/test_loop/tests/resharding_v3.rs
index ee60c306bab..c2006f388d5 100644
--- a/integration-tests/src/test_loop/tests/resharding_v3.rs
+++ b/integration-tests/src/test_loop/tests/resharding_v3.rs
@@ -12,7 +12,7 @@ use near_primitives::epoch_manager::EpochConfigStore;
 use near_primitives::hash::CryptoHash;
 use near_primitives::shard_layout::{account_id_to_shard_uid, ShardLayout};
 use near_primitives::state_record::StateRecord;
-use near_primitives::types::{AccountId, BlockHeightDelta, Gas, ShardId};
+use near_primitives::types::{AccountId, BlockHeightDelta, EpochId, Gas, ShardId};
 use near_primitives::version::{ProtocolFeature, PROTOCOL_VERSION};
 use near_store::adapter::StoreAdapter;
 use near_store::db::refcount::decode_value_with_rc;
@@ -40,21 +40,24 @@ use near_primitives::views::FinalExecutionStatus;
 use std::cell::Cell;
 use std::u64;

-fn client_tracking_shard<'a>(clients: &'a [&Client], tip: &Tip, shard_id: ShardId) -> &'a Client {
+fn client_tracking_shard(client: &Client, shard_id: ShardId, parent_hash: &CryptoHash) -> bool {
+    let signer = client.validator_signer.get();
+    let account_id = signer.as_ref().map(|s| s.validator_id());
+    client.shard_tracker.care_about_shard(account_id, parent_hash, shard_id, true)
+}
+
+fn get_client_tracking_shard<'a>(
+    clients: &'a [&Client],
+    tip: &Tip,
+    shard_id: ShardId,
+) -> &'a Client {
     for client in clients {
-        let signer = client.validator_signer.get();
-        let cares_about_shard = client.shard_tracker.care_about_shard(
-            signer.as_ref().map(|s| s.validator_id()),
-            &tip.prev_block_hash,
-            shard_id,
-            true,
-        );
-        if cares_about_shard {
+        if client_tracking_shard(client, shard_id, &tip.prev_block_hash) {
             return client;
         }
     }
     panic!(
-        "client_tracking_shard() could not find client tracking shard {} at {} #{}",
+        "get_client_tracking_shard() could not find client tracking shard {} at {} #{}",
         shard_id, &tip.last_block_hash, tip.height
     );
 }
@@ -62,7 +65,7 @@ fn client_tracking_shard<'a>(clients: &'a [&Client], tip: &Tip, shard_id: ShardI
 fn print_and_assert_shard_accounts(clients: &[&Client], tip: &Tip) {
     let epoch_config = clients[0].epoch_manager.get_epoch_config(&tip.epoch_id).unwrap();
     for shard_uid in epoch_config.shard_layout.shard_uids() {
-        let client = client_tracking_shard(clients, tip, shard_uid.shard_id());
+        let client = get_client_tracking_shard(clients, tip, shard_uid.shard_id());
         let chunk_extra = client.chain.get_chunk_extra(&tip.prev_block_hash, &shard_uid).unwrap();
         let trie = client
             .runtime_adapter
@@ -549,25 +552,57 @@ fn get_memtrie_for_shard(
     memtrie
 }

+fn assert_state_equal(
+    values1: &HashSet<(Vec<u8>, Vec<u8>)>,
+    values2: &HashSet<(Vec<u8>, Vec<u8>)>,
+    shard_uid: ShardUId,
+    cmp_msg: &str,
+) {
+    let diff = values1.symmetric_difference(values2);
+    let mut has_diff = false;
+    for (key, value) in diff {
+        has_diff = true;
+        tracing::error!(target: "test", ?shard_uid, key=?key, ?value, "Difference in state between {}!", cmp_msg);
+    }
+    assert!(!has_diff, "{} state mismatch!", cmp_msg);
+}
+
+fn shard_was_split(shard_layout: &ShardLayout, shard_id: ShardId) -> bool {
+    let Ok(parent) = shard_layout.get_parent_shard_id(shard_id) else {
+        return false;
+    };
+    parent != shard_id
+}
+
 /// Asserts that for each child shard:
 /// MemTrie, FlatState and DiskTrie all contain the same key-value pairs.
-fn assert_state_sanity_for_children_shard(parent_shard_uid: ShardUId, client: &Client) {
-    let final_head = client.chain.final_head().unwrap();
+/// If `load_mem_tries_for_tracked_shards` is false, we only enforce memtries for split shards.
+/// Returns the ShardUIds that this client tracks and has sane memtries and flat storage for.
+fn assert_state_sanity(
+    client: &Client,
+    final_head: &Tip,
+    load_mem_tries_for_tracked_shards: bool,
+) -> Vec<ShardUId> {
+    let shard_layout = client.epoch_manager.get_shard_layout(&final_head.epoch_id).unwrap();
+    let mut checked_shards = Vec::new();
+
+    for shard_uid in shard_layout.shard_uids() {
+        if !load_mem_tries_for_tracked_shards
+            && !shard_was_split(&shard_layout, shard_uid.shard_id())
+        {
+            continue;
+        }
+        if !client_tracking_shard(client, shard_uid.shard_id(), &final_head.prev_block_hash) {
+            continue;
+        }

-    for child_shard_uid in client
-        .epoch_manager
-        .get_shard_layout(&final_head.epoch_id)
-        .unwrap()
-        .get_children_shards_uids(parent_shard_uid.shard_id())
-        .unwrap()
-    {
-        let memtrie = get_memtrie_for_shard(client, &child_shard_uid, &final_head.prev_block_hash);
+        let memtrie = get_memtrie_for_shard(client, &shard_uid, &final_head.prev_block_hash);
         let memtrie_state =
             memtrie.lock_for_iter().iter().unwrap().collect::<Result<HashSet<_>, _>>().unwrap();

         let state_root = *client
             .chain
-            .get_chunk_extra(&final_head.prev_block_hash, &child_shard_uid)
+            .get_chunk_extra(&final_head.prev_block_hash, &shard_uid)
             .unwrap()
             .state_root();

@@ -575,22 +610,21 @@ fn assert_state_sanity_for_children_shard(parent_shard_uid: ShardUId, client: &C
         // uses memtries.
         let trie = client
             .runtime_adapter
-            .get_view_trie_for_shard(
-                child_shard_uid.shard_id(),
-                &final_head.prev_block_hash,
-                state_root,
-            )
+            .get_view_trie_for_shard(shard_uid.shard_id(), &final_head.prev_block_hash, state_root)
             .unwrap();
         assert!(!trie.has_memtries());
         let trie_state =
             trie.lock_for_iter().iter().unwrap().collect::<Result<HashSet<_>, _>>().unwrap();
+        assert_state_equal(&memtrie_state, &trie_state, shard_uid, "memtrie and trie");

-        let flat_store_chunk_view = client
+        let Some(flat_store_chunk_view) = client
             .chain
             .runtime_adapter
             .get_flat_storage_manager()
-            .chunk_view(child_shard_uid, final_head.last_block_hash)
-            .unwrap();
+            .chunk_view(shard_uid, final_head.last_block_hash)
+        else {
+            continue;
+        };
         let flat_store_state = flat_store_chunk_view
             .iter_range(None, None)
             .map_ok(|(key, value)| {
@@ -600,7 +634,7 @@ fn assert_state_sanity_for_children_shard(parent_shard_uid: ShardUId, client: &C
                         .chain_store()
                         .store()
                         .trie_store()
-                        .get(child_shard_uid, &value.hash)
+                        .get(shard_uid, &value.hash)
                         .unwrap()
                         .to_vec(),
                     FlatStateValue::Inlined(data) => data,
@@ -610,16 +644,140 @@ fn assert_state_sanity_for_children_shard(parent_shard_uid: ShardUId, client: &C
             .collect::<Result<HashSet<_>, _>>()
             .unwrap();

-        let diff_memtrie_flat_store = memtrie_state.symmetric_difference(&flat_store_state);
-        let diff_memtrie_trie = memtrie_state.symmetric_difference(&trie_state);
-        let diff = diff_memtrie_flat_store.chain(diff_memtrie_trie);
-        if diff.clone().count() == 0 {
-            continue;
+        assert_state_equal(&memtrie_state, &flat_store_state, shard_uid, "memtrie and flat store");
+        checked_shards.push(shard_uid);
+    }
+    checked_shards
+}
+
+// For each epoch, keep a map from AccountId to a map with keys equal to
+// the set of shards that account tracks in that epoch, and bool values indicating
+// whether the equality of flat storage and memtries has been checked for that shard
+type EpochTrieCheck = HashMap<AccountId, HashMap<ShardUId, bool>>;
+
+/// Keeps track of the needed trie comparisons for each epoch. After we successfully call
+/// assert_state_sanity() for an account ID, we mark those shards as checked for that epoch,
+/// and then at the end of the test we check whether all expected shards for each account
+/// were checked at least once in that epoch. We do this because assert_state_sanity() isn't
+/// always able to perform the check if child shard flat storages are still being created, but
+/// we want to make sure that it's always eventually checked by the end of the epoch
+struct TrieSanityCheck {
+    accounts: Vec<AccountId>,
+    load_mem_tries_for_tracked_shards: bool,
+    checks: HashMap<EpochId, EpochTrieCheck>,
+}
+
+impl TrieSanityCheck {
+    fn new(clients: &[&Client], load_mem_tries_for_tracked_shards: bool) -> Self {
+        let accounts = clients
+            .iter()
+            .filter_map(|c| {
+                let signer = c.validator_signer.get();
+                signer.map(|s| s.validator_id().clone())
+            })
+            .collect();
+        Self { accounts, load_mem_tries_for_tracked_shards, checks: HashMap::new() }
+    }
+
+    // If it's not already stored, initialize it with the expected ShardUIds for each account
+    fn get_epoch_check(&mut self, client: &Client, tip: &Tip) -> &mut EpochTrieCheck {
+        match self.checks.entry(tip.epoch_id) {
+            std::collections::hash_map::Entry::Occupied(e) => e.into_mut(),
+            std::collections::hash_map::Entry::Vacant(e) => {
+                let shard_layout = client.epoch_manager.get_shard_layout(&tip.epoch_id).unwrap();
+                let shard_uids = shard_layout.shard_uids().collect_vec();
+                let mut check = HashMap::new();
+                for account_id in self.accounts.iter() {
+                    let tracked = shard_uids
+                        .iter()
+                        .filter_map(|uid| {
+                            if !self.load_mem_tries_for_tracked_shards
+                                && !shard_was_split(&shard_layout, uid.shard_id())
+                            {
+                                return None;
+                            }
+                            let cares = client.shard_tracker.care_about_shard(
+                                Some(account_id),
+                                &tip.prev_block_hash,
+                                uid.shard_id(),
+                                false,
+                            );
+                            if cares {
+                                Some((*uid, false))
+                            } else {
+                                None
+                            }
+                        })
+                        .collect();
+                    check.insert(account_id.clone(), tracked);
+                }
+                e.insert(check)
+            }
         }
-        for (key, value) in diff {
-            tracing::error!(target: "test", shard=?child_shard_uid, key=?key, ?value, "Difference in state between trie, memtrie and flat store!");
+    }
+
+    // Check trie sanity and keep track of which shards were successfully fully checked
+    fn assert_state_sanity(&mut self, clients: &[&Client]) {
+        for client in clients {
+            let signer = client.validator_signer.get();
+            let Some(account_id) = signer.as_ref().map(|s| s.validator_id()) else {
+                // For now this is never relevant, since all of them have account IDs, but
+                // if this changes in the future, here we'll just skip those.
+                continue;
+            };
+            let head = client.chain.head().unwrap();
+            if head.epoch_id == EpochId::default() {
+                continue;
+            }
+            let final_head = client.chain.final_head().unwrap();
+            // At the end of an epoch, we unload memtries for shards we'll no longer track. Also,
+            // the key/value equality comparison in assert_state_equal() is only guaranteed for
+            // final blocks. So these two together mean that we should only check this when the head
+            // and final head are in the same epoch.
+            if head.epoch_id != final_head.epoch_id {
+                continue;
+            }
+            let checked_shards =
+                assert_state_sanity(client, &final_head, self.load_mem_tries_for_tracked_shards);
+            let check = self.get_epoch_check(client, &head);
+            let check = check.get_mut(account_id).unwrap();
+            for shard_uid in checked_shards {
+                check.insert(shard_uid, true);
+            }
+        }
+    }
+
+    /// Look through all the epochs before the current one (because the current one will be early into the epoch,
+    /// and we won't have checked it yet) and make sure that for all accounts, all expected shards were checked at least once
+    fn check_epochs(&self, client: &Client) {
+        let tip = client.chain.head().unwrap();
+        let mut block_info = client.epoch_manager.get_block_info(&tip.last_block_hash).unwrap();
+
+        loop {
+            let epoch_id = client
+                .epoch_manager
+                .get_prev_epoch_id_from_prev_block(block_info.prev_hash())
+                .unwrap();
+            if epoch_id == EpochId::default() {
+                break;
+            }
+            let check = self.checks.get(&epoch_id).unwrap_or_else(|| {
+                panic!("No trie comparison checks made for epoch {}", &epoch_id.0)
+            });
+            for (account_id, checked_shards) in check.iter() {
+                for (shard_uid, checked) in checked_shards.iter() {
+                    assert!(
+                        checked,
+                        "No trie comparison checks made for account {} epoch {} shard {}",
+                        account_id, &epoch_id.0, shard_uid
+                    );
+                }
+            }
+
+            block_info =
+                client.epoch_manager.get_block_info(block_info.epoch_first_block()).unwrap();
+            block_info = client.epoch_manager.get_block_info(block_info.prev_hash()).unwrap();
         }
-        assert!(false, "trie, memtrie and flat store state mismatch!");
     }
 }

@@ -735,6 +893,11 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
     let client_handles =
         node_datas.iter().map(|data| data.client_sender.actor_handle()).collect_vec();

+    let clients =
+        client_handles.iter().map(|handle| &test_loop.data.get(handle).client).collect_vec();
+    let mut trie_sanity_check =
+        TrieSanityCheck::new(&clients, params.load_mem_tries_for_tracked_shards);
+
     let latest_block_height = std::cell::Cell::new(0u64);
     let success_condition = |test_loop_data: &mut TestLoopData| -> bool {
         params
@@ -755,6 +918,7 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
             println!("State before resharding:");
             print_and_assert_shard_accounts(&clients, &tip);
         }
+        trie_sanity_check.assert_state_sanity(&clients);
         latest_block_height.set(tip.height);
         println!("block: {} chunks: {:?}", tip.height, block_header.chunk_mask());
         if params.all_chunks_expected && params.chunk_ranges_to_drop.is_empty() {
@@ -784,16 +948,12 @@ fn test_resharding_v3_base(params: TestReshardingParameters) {
         // Give enough time to produce ~7 epochs.
         Duration::seconds((7 * params.epoch_length) as i64),
     );
+    let client = &test_loop.data.get(&client_handles[0]).client;
+    trie_sanity_check.check_epochs(client);
     // Wait for garbage collection to kick in, so that it is tested as well.
     test_loop
         .run_for(Duration::seconds((DEFAULT_GC_NUM_EPOCHS_TO_KEEP * params.epoch_length) as i64));

-    // At the end of the test we know for sure resharding has been completed.
-    // Verify that state is equal across tries and flat storage for all children shards.
-    let clients =
-        client_handles.iter().map(|handle| &test_loop.data.get(handle).client).collect_vec();
-    assert_state_sanity_for_children_shard(parent_shard_uid, &clients[0]);
-
     TestLoopEnv { test_loop, datas: node_datas, tempdir }
         .shutdown_and_drain_remaining_events(Duration::seconds(20));
 }
@@ -864,9 +1024,7 @@ fn test_resharding_v3_double_sign_resharding_block() {
     );
 }

-// TODO(resharding): fix nearcore and un-ignore this test
 #[test]
-#[ignore]
 fn test_resharding_v3_shard_shuffling() {
     let params = TestReshardingParameters::new()
         .shuffle_shard_assignment()
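For reference, the pieces added above fit together in three steps inside `test_resharding_v3_base`; a condensed, illustrative sketch only (the surrounding test-loop plumbing is elided):

    // 1. Build the checker once from the clients under test.
    let mut trie_sanity_check =
        TrieSanityCheck::new(&clients, params.load_mem_tries_for_tracked_shards);
    // 2. On every block observed by the success condition, compare memtrie, disk trie
    //    and flat storage for the shards it can, and record which were fully checked.
    trie_sanity_check.assert_state_sanity(&clients);
    // 3. After the main loop, require that every expected (account, shard) pair was
    //    checked at least once in each completed epoch.
    trie_sanity_check.check_epochs(&test_loop.data.get(&client_handles[0]).client);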