diff --git a/tasm-lib/benchmarks/tasmlib_list_set_element___digest.json b/tasm-lib/benchmarks/tasmlib_list_set_element___digest.json index 04049508..0e1a4455 100644 --- a/tasm-lib/benchmarks/tasmlib_list_set_element___digest.json +++ b/tasm-lib/benchmarks/tasmlib_list_set_element___digest.json @@ -2,22 +2,22 @@ { "name": "tasmlib_list_set_element___digest", "benchmark_result": { - "clock_cycle_count": 10, - "hash_table_height": 12, - "u32_table_height": 0, - "op_stack_table_height": 11, - "ram_table_height": 5 + "clock_cycle_count": 24, + "hash_table_height": 24, + "u32_table_height": 15, + "op_stack_table_height": 21, + "ram_table_height": 6 }, "case": "CommonCase" }, { "name": "tasmlib_list_set_element___digest", "benchmark_result": { - "clock_cycle_count": 10, - "hash_table_height": 12, - "u32_table_height": 0, - "op_stack_table_height": 11, - "ram_table_height": 5 + "clock_cycle_count": 24, + "hash_table_height": 24, + "u32_table_height": 18, + "op_stack_table_height": 21, + "ram_table_height": 6 }, "case": "WorstCase" } diff --git a/tasm-lib/benchmarks/tasmlib_mmr_calculate_new_peaks_from_leaf_mutation.json b/tasm-lib/benchmarks/tasmlib_mmr_calculate_new_peaks_from_leaf_mutation.json index 655c34de..1462c4c2 100644 --- a/tasm-lib/benchmarks/tasmlib_mmr_calculate_new_peaks_from_leaf_mutation.json +++ b/tasm-lib/benchmarks/tasmlib_mmr_calculate_new_peaks_from_leaf_mutation.json @@ -2,22 +2,22 @@ { "name": "tasmlib_mmr_calculate_new_peaks_from_leaf_mutation", "benchmark_result": { - "clock_cycle_count": 2422, - "hash_table_height": 414, - "u32_table_height": 1263, - "op_stack_table_height": 1574, - "ram_table_height": 191 + "clock_cycle_count": 2436, + "hash_table_height": 426, + "u32_table_height": 1265, + "op_stack_table_height": 1584, + "ram_table_height": 192 }, "case": "CommonCase" }, { "name": "tasmlib_mmr_calculate_new_peaks_from_leaf_mutation", "benchmark_result": { - "clock_cycle_count": 4636, - "hash_table_height": 600, - "u32_table_height": 2158, - "op_stack_table_height": 3072, - "ram_table_height": 377 + "clock_cycle_count": 4650, + "hash_table_height": 612, + "u32_table_height": 2160, + "op_stack_table_height": 3082, + "ram_table_height": 378 }, "case": "WorstCase" } diff --git a/tasm-lib/src/assertion_error_ids.md b/tasm-lib/src/assertion_error_ids.md index 98386a7d..dfffd217 100644 --- a/tasm-lib/src/assertion_error_ids.md +++ b/tasm-lib/src/assertion_error_ids.md @@ -48,3 +48,4 @@ often. | 360..370 | [`u64::Pow2`](arithmetic/u64/pow2.rs) | | 370..380 | [`u64::ShiftLeft`](arithmetic/u64/shift_left.rs) | | 380..390 | [`list::get`](list/get.rs) | +| 390..400 | [`list::set`](list/set.rs) | diff --git a/tasm-lib/src/list/get.rs b/tasm-lib/src/list/get.rs index 2e6c18e4..f48867f3 100644 --- a/tasm-lib/src/list/get.rs +++ b/tasm-lib/src/list/get.rs @@ -21,13 +21,24 @@ impl Get { /// # Panics /// - /// Panics if the element has [dynamic length][BFieldCodec::static_length]. + /// Panics if the element has [dynamic length][BFieldCodec::static_length], or + /// if the static length is 0. pub fn new(element_type: DataType) -> Self { - let has_static_len = element_type.static_length().is_some(); - assert!(has_static_len, "element should have static length"); + Self::assert_element_type_is_supported(&element_type); Self { element_type } } + + /// # Panics + /// + /// Panics if the element has [dynamic length][BFieldCodec::static_length], or + /// if the static length is 0. + pub(crate) fn assert_element_type_is_supported(element_type: &DataType) { + let Some(static_len) = element_type.static_length() else { + panic!("element should have static length"); + }; + assert_ne!(0, static_len, "element must not be zero-sized"); + } } impl BasicSnippet for Get { @@ -50,13 +61,12 @@ impl BasicSnippet for Get { } fn code(&self, library: &mut Library) -> Vec { + let list_length = library.import(Box::new(Length)); let mul_with_element_size = match self.element_type.stack_size() { - 1 => triton_asm!(/* no-op */), + 1 => triton_asm!(), // no-op n => triton_asm!(push {n} mul), }; - let list_length = library.import(Box::new(Length)); - triton_asm!( // BEFORE: _ *list index // AFTER: _ [element: self.element_type] @@ -89,7 +99,7 @@ impl BasicSnippet for Get { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use triton_vm::error::OpStackError::FailedU32Conversion; use super::*; @@ -115,6 +125,23 @@ mod tests { AccessorInitialState { stack, memory } } + + pub fn random_len_idx_ptr( + bench_case: Option, + rng: &mut impl Rng, + ) -> (usize, usize, BFieldElement) { + let (index, list_length) = match bench_case { + Some(BenchmarkCase::CommonCase) => (16, 32), + Some(BenchmarkCase::WorstCase) => (63, 64), + None => { + let list_length = rng.gen_range(1..=100); + (rng.gen_range(0..list_length), list_length) + } + }; + let list_pointer = rng.gen(); + + (list_length, index, list_pointer) + } } impl Accessor for Get { @@ -138,18 +165,8 @@ mod tests { seed: [u8; 32], bench_case: Option, ) -> AccessorInitialState { - let mut rng = StdRng::from_seed(seed); - let list_length = match bench_case { - Some(BenchmarkCase::CommonCase) => 1 << 5, - Some(BenchmarkCase::WorstCase) => 1 << 6, - None => rng.gen_range(1..=100), - }; - let index = match bench_case { - Some(BenchmarkCase::CommonCase) => list_length / 2, - Some(BenchmarkCase::WorstCase) => list_length - 1, - None => rng.gen_range(0..list_length), - }; - let list_pointer = rng.gen(); + let (list_length, index, list_pointer) = + Self::random_len_idx_ptr(bench_case, &mut StdRng::from_seed(seed)); self.set_up_initial_state(list_length, index, list_pointer) } @@ -237,7 +254,7 @@ mod benches { use crate::test_prelude::*; #[test] - fn get_benchmark() { + fn benchmark() { ShadowedAccessor::new(Get::new(DataType::Digest)).bench(); } } diff --git a/tasm-lib/src/list/set.rs b/tasm-lib/src/list/set.rs index 74b045da..c51ab88b 100644 --- a/tasm-lib/src/list/set.rs +++ b/tasm-lib/src/list/set.rs @@ -1,313 +1,264 @@ -use std::collections::HashMap; - -use itertools::Itertools; -use rand::prelude::*; use triton_vm::prelude::*; -use twenty_first::math::other::random_elements; -use crate::empty_stack; +use crate::list::get::Get; +use crate::list::length::Length; use crate::prelude::*; -use crate::rust_shadowing_helper_functions::list::list_set; -use crate::rust_shadowing_helper_functions::list::untyped_insert_random_list; -use crate::traits::deprecated_snippet::DeprecatedSnippet; -use crate::InitVmState; +/// Write an element to a list. Performs bounds check. +/// +/// Only supports lists with [statically sized](BFieldCodec::static_length) +/// elements. +/// +/// ### Behavior +/// +/// ```text +/// BEFORE: _ [element: ElementType] *list [index: u32] +/// AFTER: _ +/// ``` +/// +/// ### Preconditions +/// +/// - the argument `*list` points to a properly [`BFieldCodec`]-encoded list +/// - all input arguments are properly [`BFieldCodec`] encoded +/// +/// ### Postconditions +/// +/// None. #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct Set { - pub element_type: DataType, + element_type: DataType, } impl Set { - pub fn new(data_type: DataType) -> Self { - Self { - element_type: data_type, - } - } -} + pub const INDEX_OUT_OF_BOUNDS_ERROR_ID: i128 = 390; -impl DeprecatedSnippet for Set { - fn entrypoint_name(&self) -> String { - format!( - "tasmlib_list_set_element___{}", - self.element_type.label_friendly_name() - ) - } + /// Any part of the list is outside the allocated memory page. + /// See the [memory convention][crate::memory] for more details. + pub const MEM_PAGE_ACCESS_VIOLATION_ERROR_ID: i128 = 391; - fn input_field_names(&self) -> Vec { - // _ elem{{N - 1}}, elem{{N - 2}}, ..., elem{{0}} *list index - [ - vec!["element".to_string(); self.element_type.stack_size()], - vec!["*list".to_string(), "index".to_string()], - ] - .concat() + /// # Panics + /// + /// Panics if the element has [dynamic length][BFieldCodec::static_length], or + /// if the static length is 0. + pub fn new(element_type: DataType) -> Self { + Get::assert_element_type_is_supported(&element_type); + + Self { element_type } } +} + +impl BasicSnippet for Set { + fn inputs(&self) -> Vec<(DataType, String)> { + let element_type = self.element_type.clone(); + let list_type = DataType::List(Box::new(element_type.clone())); + let index_type = DataType::U32; - fn input_types(&self) -> Vec { vec![ - self.element_type.clone(), - DataType::List(Box::new(self.element_type.clone())), - DataType::U32, + (element_type, "element".to_string()), + (list_type, "*list".to_string()), + (index_type, "index".to_string()), ] } - fn output_field_names(&self) -> Vec { + fn outputs(&self) -> Vec<(DataType, String)> { vec![] } - fn output_types(&self) -> Vec { - vec![] + fn entrypoint(&self) -> String { + let element_type = self.element_type.label_friendly_name(); + format!("tasmlib_list_set_element___{element_type}") } - fn stack_diff(&self) -> isize { - -2 - self.element_type.stack_size() as isize - } - - fn function_code(&self, _library: &mut Library) -> String { - let entrypoint = self.entrypoint_name(); - let element_size = self.element_type.stack_size(); - - let write_elements_to_memory_code = self.element_type.write_value_to_memory_leave_pointer(); - - let mul_with_size = if element_size != 1 { - triton_asm!(push {element_size} mul) - } else { - triton_asm!() + fn code(&self, library: &mut Library) -> Vec { + let list_length = library.import(Box::new(Length)); + let mul_with_element_size = match self.element_type.stack_size() { + 1 => triton_asm!(), // no-op + n => triton_asm!(push {n} mul), + }; + let add_element_size_minus_1 = match self.element_type.stack_size() { + 1 => triton_asm!(), // no-op + n => triton_asm!(addi {n - 1}), }; - triton_asm!( - // BEFORE: _ elem{{N - 1}}, elem{{N - 2}}, ..., elem{{0}} *list index - // AFTER: _ - {entrypoint}: - {&mul_with_size} - // _ [value] *list offset_for_previous_elements - - push 1 - add - // _ [value] *list offset_including_length_indicator - - add - // _ [value] *element - - {&write_elements_to_memory_code} - - // stack: _ *next_element - pop 1 - return + triton_asm!( + // BEFORE: _ [element: self.element_type] *list index + // AFTER: _ + {self.entrypoint()}: + /* assert access is in bounds */ + dup 1 + call {list_length} // _ [element] *list index len + dup 1 + lt // _ [element] *list index (index < len) + assert error_id {Self::INDEX_OUT_OF_BOUNDS_ERROR_ID} + // _ [element] *list index + + {&mul_with_element_size} + // _ [element] *list offset_for_previous_elements + addi 1 // _ [element] *list offset + + /* assert access is within one memory page */ + dup 0 + {&add_element_size_minus_1} + // _ [element] *list offset highest_word_idx + split + pop 1 + push 0 + eq + assert error_id {Self::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID} + // _ [element] *list offset_including_list_metadata + + add // _ [element] *element + {&self.element_type.write_value_to_memory_pop_pointer()} + // _ + return ) - .iter() - .join("\n") - } - - fn crash_conditions(&self) -> Vec { - vec![] - } - - fn gen_input_states(&self) -> Vec { - vec![ - prepare_state(&self.element_type), - prepare_state(&self.element_type), - prepare_state(&self.element_type), - ] } - - fn common_case_input_state(&self) -> InitVmState { - prepare_state(&self.element_type) - } - - fn worst_case_input_state(&self) -> InitVmState { - prepare_state(&self.element_type) - } - - fn rust_shadowing( - &self, - stack: &mut Vec, - _std_in: Vec, - _secret_in: Vec, - memory: &mut HashMap, - ) { - let index: u32 = stack.pop().unwrap().try_into().unwrap(); - let list_pointer = stack.pop().unwrap(); - let mut element: Vec = - vec![BFieldElement::new(0); self.element_type.stack_size()]; - for ee in element.iter_mut() { - *ee = stack.pop().unwrap(); - } - list_set(list_pointer, index as usize, element, memory); - } -} - -fn prepare_state(data_type: &DataType) -> InitVmState { - let list_length: usize = thread_rng().gen_range(1..100); - let index: usize = thread_rng().gen_range(0..list_length); - let mut stack = empty_stack(); - let mut push_value: Vec = random_elements(data_type.stack_size()); - while let Some(element) = push_value.pop() { - stack.push(element); - } - - let list_pointer: BFieldElement = random(); - stack.push(list_pointer); - stack.push(BFieldElement::new(index as u64)); - - let mut memory = HashMap::default(); - untyped_insert_random_list( - list_pointer, - list_length, - &mut memory, - data_type.stack_size(), - ); - InitVmState::with_stack_and_memory(stack, memory) } #[cfg(test)] mod tests { - use super::*; - use crate::test_helpers::test_rust_equivalence_given_input_values_deprecated; - use crate::test_helpers::test_rust_equivalence_multiple_deprecated; + use proptest::collection::vec; + use triton_vm::error::OpStackError::FailedU32Conversion; - #[test] - fn new_snippet_test() { - test_rust_equivalence_multiple_deprecated( - &Set { - element_type: DataType::Bool, - }, - true, - ); - test_rust_equivalence_multiple_deprecated( - &Set { - element_type: DataType::Bfe, - }, - true, - ); - test_rust_equivalence_multiple_deprecated( - &Set { - element_type: DataType::U32, - }, - true, - ); - test_rust_equivalence_multiple_deprecated( - &Set { - element_type: DataType::U64, - }, - true, - ); - test_rust_equivalence_multiple_deprecated( - &Set { - element_type: DataType::Xfe, - }, - true, - ); - test_rust_equivalence_multiple_deprecated( - &Set { - element_type: DataType::Digest, - }, - true, - ); + use super::*; + use crate::rust_shadowing_helper_functions::list::insert_random_list; + use crate::rust_shadowing_helper_functions::list::list_set; + use crate::test_helpers::negative_test; + use crate::test_prelude::*; + use crate::U32_TO_USIZE_ERR; + + impl Set { + fn set_up_initial_state( + &self, + list_length: usize, + index: usize, + list_pointer: BFieldElement, + element: Vec, + ) -> FunctionInitialState { + let mut memory = HashMap::default(); + insert_random_list(&self.element_type, list_pointer, list_length, &mut memory); + + let mut stack = self.init_stack_for_isolated_run(); + stack.extend(element.into_iter().rev()); + stack.push(list_pointer); + stack.push(bfe!(index)); + + FunctionInitialState { stack, memory } + } } - #[test] - fn list_u32_n_is_one_set() { - let list_address = BFieldElement::new(48); - let insert_value = vec![BFieldElement::new(1337)]; - prop_set(DataType::Bfe, list_address, 20, insert_value, 2); - } + impl Function for Set { + fn rust_shadow( + &self, + stack: &mut Vec, + memory: &mut HashMap, + ) { + let index = pop_encodable::(stack); + let list_pointer = stack.pop().unwrap(); + let element = (0..self.element_type.stack_size()) + .map(|_| stack.pop().unwrap()) + .collect_vec(); + + let index = index.try_into().expect(U32_TO_USIZE_ERR); + list_set(list_pointer, index, element, memory); + } - #[test] - fn list_u32_n_is_three_set() { - let list_address = BFieldElement::new(48); - let insert_value = vec![ - BFieldElement::new(1337), - BFieldElement::new(1337), - BFieldElement::new(1337), - ]; - prop_set(DataType::Xfe, list_address, 20, insert_value, 2); - } + fn pseudorandom_initial_state( + &self, + seed: [u8; 32], + bench_case: Option, + ) -> FunctionInitialState { + let mut rng = StdRng::from_seed(seed); + let (list_length, index, list_pointer) = Get::random_len_idx_ptr(bench_case, &mut rng); + let element = self.element_type.seeded_random_element(&mut rng); - #[test] - fn list_u32_n_is_two_set() { - let list_address = BFieldElement::new(1841); - let push_value = vec![BFieldElement::new(133700), BFieldElement::new(32)]; - prop_set(DataType::U64, list_address, 20, push_value, 0); + self.set_up_initial_state(list_length, index, list_pointer, element) + } } #[test] - fn list_u32_n_is_five_set() { - let list_address = BFieldElement::new(558); - let push_value = vec![ - BFieldElement::new(133700), - BFieldElement::new(32), - BFieldElement::new(133700), - BFieldElement::new(19990), - BFieldElement::new(88888888), - ]; - prop_set(DataType::Digest, list_address, 2313, push_value, 589); + fn rust_shadow() { + for ty in [ + DataType::Bool, + DataType::Bfe, + DataType::U32, + DataType::U64, + DataType::Xfe, + DataType::Digest, + ] { + ShadowedFunction::new(Set::new(ty)).test(); + } } - fn prop_set( - data_type: DataType, - list_address: BFieldElement, - init_list_length: u32, - push_value: Vec, - index: u32, + #[proptest] + fn out_of_bounds_access_crashes_vm( + #[strategy(0_usize..=1_000)] list_length: usize, + #[strategy(#list_length..1 << 32)] index: usize, + #[strategy(arb())] list_pointer: BFieldElement, + #[strategy(vec(arb(), 1))] element: Vec, ) { - let expected_end_stack = [empty_stack()].concat(); - let mut init_stack = empty_stack(); - - for i in 0..data_type.stack_size() { - init_stack.push(push_value[data_type.stack_size() - 1 - i]); - } - init_stack.push(list_address); - init_stack.push(BFieldElement::new(index as u64)); - - let mut vm_memory = HashMap::default(); - - // Insert length indicator of list, lives on offset = 0 from `list_address` - untyped_insert_random_list( - list_address, - init_list_length as usize, - &mut vm_memory, - data_type.stack_size(), + let set = Set::new(DataType::Bfe); + let initial_state = set.set_up_initial_state(list_length, index, list_pointer, element); + test_assertion_failure( + &ShadowedFunction::new(set), + initial_state.into(), + &[Set::INDEX_OUT_OF_BOUNDS_ERROR_ID], ); + } - let memory = test_rust_equivalence_given_input_values_deprecated( - &Set { - element_type: data_type.clone(), - }, - &init_stack, - &[], - vm_memory, - Some(&expected_end_stack), - ) - .ram; - - // Verify that length indicator is unchanged - assert_eq!( - BFieldElement::new(init_list_length as u64), - memory[&list_address] + #[proptest] + fn too_large_indices_crash_vm( + #[strategy(1_usize << 32..)] index: usize, + #[strategy(arb())] list_pointer: BFieldElement, + #[strategy(vec(arb(), 1))] element: Vec, + ) { + let list_length = 0; + let set = Set::new(DataType::Bfe); + let initial_state = set.set_up_initial_state(list_length, index, list_pointer, element); + let expected_error = InstructionError::OpStackError(FailedU32Conversion(bfe!(index))); + negative_test( + &ShadowedFunction::new(set), + initial_state.into(), + &[expected_error], ); + } - // verify that value was inserted at expected place - for i in 0..data_type.stack_size() { - assert_eq!( - push_value[i], - memory[&BFieldElement::new( - list_address.value() - + 1 - + data_type.stack_size() as u64 * index as u64 - + i as u64 - )] - ); - } + /// See mirroring test for [`Get`] for an explanation. + #[proptest(cases = 100)] + fn too_large_lists_crash_vm( + #[strategy(1_u64 << 22..1 << 32)] list_length: u64, + #[strategy((1 << 22) - 1..#list_length)] index: u64, + #[strategy(arb())] list_pointer: BFieldElement, + ) { + // spare host machine RAM: pretend every element is all-zeros + let mut memory = HashMap::default(); + memory.insert(list_pointer, bfe!(list_length)); + + // type with a large stack size in Triton VM without breaking the host machine + let tuple_ty = DataType::Tuple(vec![DataType::Bfe; 1 << 10]); + let set = Set::new(tuple_ty); + + // no element on stack: stack underflow implies things have gone wrong already + let mut stack = set.init_stack_for_isolated_run(); + stack.push(list_pointer); + stack.push(bfe!(index)); + let initial_state = AccessorInitialState { stack, memory }; + + test_assertion_failure( + &ShadowedFunction::new(set), + initial_state.into(), + &[Set::MEM_PAGE_ACCESS_VIOLATION_ERROR_ID], + ); } } #[cfg(test)] mod benches { use super::*; - use crate::snippet_bencher::bench_and_write; + use crate::test_prelude::*; #[test] - fn set_benchmark() { - bench_and_write(Set::new(DataType::Digest)); + fn benchmark() { + ShadowedFunction::new(Set::new(DataType::Digest)).bench(); } } diff --git a/tasm-lib/src/rust_shadowing_helper_functions/list.rs b/tasm-lib/src/rust_shadowing_helper_functions/list.rs index 6d090a88..9e4092ea 100644 --- a/tasm-lib/src/rust_shadowing_helper_functions/list.rs +++ b/tasm-lib/src/rust_shadowing_helper_functions/list.rs @@ -200,21 +200,21 @@ pub fn list_get( list_pointer: BFieldElement, index: usize, memory: &HashMap, - element_length: usize, + element_size: usize, ) -> Vec { let list_len = list_get_length(list_pointer, memory); assert!(index < list_len, "out of bounds: {index} >= {list_len}"); - let highest_access_index = LIST_METADATA_SIZE + element_length * (index + 1); + let highest_access_index = LIST_METADATA_SIZE + element_size * (index + 1); assert!(u64::try_from(highest_access_index).expect(USIZE_TO_U64_ERR) < DYN_MALLOC_PAGE_SIZE); let read_word = |i| { - let word_offset = LIST_METADATA_SIZE + element_length * index + i; + let word_offset = LIST_METADATA_SIZE + element_size * index + i; let word_index = list_pointer + bfe!(word_offset); memory[&word_index] }; - (0..element_length).map(read_word).collect() + (0..element_size).map(read_word).collect() } /// Write an element to a list. @@ -223,7 +223,11 @@ pub fn list_get( /// /// # Panics /// -/// Panics if the `index` is out of bounds. +/// Panics if +/// - the `index` is out of bounds, or +/// - the element that is to be read resides outside the list`s +/// [memory page][crate::memory], or +/// - the pointed-to-list is incorrectly encoded into `memory`. pub fn list_set( list_pointer: BFieldElement, index: usize, @@ -234,6 +238,9 @@ pub fn list_set( assert!(index < list_len, "out of bounds: {index} >= {list_len}"); let element_size = value.len(); + let highest_access_index = LIST_METADATA_SIZE + element_size * (index + 1); + assert!(u64::try_from(highest_access_index).expect(USIZE_TO_U64_ERR) < DYN_MALLOC_PAGE_SIZE); + for (i, word) in value.into_iter().enumerate() { let word_offset = LIST_METADATA_SIZE + element_size * index + i; let word_index = list_pointer + bfe!(word_offset);