diff --git a/tasm-lib/benchmarks/tasmlib_arithmetic_u128_add.json b/tasm-lib/benchmarks/tasmlib_arithmetic_u128_safe_add.json similarity index 82% rename from tasm-lib/benchmarks/tasmlib_arithmetic_u128_add.json rename to tasm-lib/benchmarks/tasmlib_arithmetic_u128_safe_add.json index adaf21c9..ace06413 100644 --- a/tasm-lib/benchmarks/tasmlib_arithmetic_u128_add.json +++ b/tasm-lib/benchmarks/tasmlib_arithmetic_u128_safe_add.json @@ -1,6 +1,6 @@ [ { - "name": "tasmlib_arithmetic_u128_add", + "name": "tasmlib_arithmetic_u128_safe_add", "benchmark_result": { "clock_cycle_count": 29, "hash_table_height": 30, @@ -11,7 +11,7 @@ "case": "CommonCase" }, { - "name": "tasmlib_arithmetic_u128_add", + "name": "tasmlib_arithmetic_u128_safe_add", "benchmark_result": { "clock_cycle_count": 29, "hash_table_height": 30, diff --git a/tasm-lib/src/arithmetic/u128.rs b/tasm-lib/src/arithmetic/u128.rs index b7e52921..16d9c565 100644 --- a/tasm-lib/src/arithmetic/u128.rs +++ b/tasm-lib/src/arithmetic/u128.rs @@ -1,4 +1,4 @@ -pub mod add_u128; +pub mod safe_add; pub mod safe_mul_u128; pub mod shift_left_static_u128; pub mod shift_left_u128; diff --git a/tasm-lib/src/arithmetic/u128/add_u128.rs b/tasm-lib/src/arithmetic/u128/safe_add.rs similarity index 67% rename from tasm-lib/src/arithmetic/u128/add_u128.rs rename to tasm-lib/src/arithmetic/u128/safe_add.rs index ab302682..14a08853 100644 --- a/tasm-lib/src/arithmetic/u128/add_u128.rs +++ b/tasm-lib/src/arithmetic/u128/safe_add.rs @@ -4,12 +4,12 @@ use crate::data_type::DataType; use crate::library::Library; use crate::traits::basic_snippet::BasicSnippet; -#[derive(Clone, Debug)] -pub struct AddU128; +#[derive(Clone, Debug, Copy)] +pub struct SafeAddU128; -impl BasicSnippet for AddU128 { +impl BasicSnippet for SafeAddU128 { fn entrypoint(&self) -> String { - "tasmlib_arithmetic_u128_add".to_string() + "tasmlib_arithmetic_u128_safe_add".to_string() } fn inputs(&self) -> Vec<(DataType, String)> { @@ -104,26 +104,29 @@ impl BasicSnippet for AddU128 { #[cfg(test)] mod tests { use itertools::Itertools; + use num::Zero; use rand::rngs::StdRng; use rand::Rng; use rand::SeedableRng; use crate::snippet_bencher::BenchmarkCase; + use crate::test_helpers::negative_test; use crate::test_helpers::test_rust_equivalence_given_complete_state; use crate::traits::closure::Closure; use crate::traits::closure::ShadowedClosure; use crate::traits::rust_shadow::RustShadow; + use crate::InitVmState; use super::*; #[test] fn add_u128_test() { - ShadowedClosure::new(AddU128).test() + ShadowedClosure::new(SafeAddU128).test() } #[test] fn add_u128_unit_test() { - let snippet = AddU128; + let snippet = SafeAddU128; let mut expected = snippet.init_stack_for_isolated_run(); expected.push(BFieldElement::new(0)); expected.push(BFieldElement::new(1 << 4)); @@ -132,18 +135,63 @@ mod tests { snippet.prop_add(1u128 << 67, 1u128 << 67, Some(&expected)) } - impl AddU128 { + #[test] + fn add_u128_overflow_test() { + let snippet = SafeAddU128; + + for (a, b) in [ + (1u128 << 127, 1u128 << 127), + (u128::MAX, u128::MAX), + (u128::MAX, 1), + (u128::MAX, 1 << 31), + (u128::MAX, 1 << 32), + (u128::MAX, 1 << 33), + (u128::MAX, 1 << 63), + (u128::MAX, 1 << 64), + (u128::MAX, 1 << 65), + (u128::MAX, 1 << 95), + (u128::MAX, 1 << 96), + (u128::MAX, 1 << 97), + (u128::MAX - 1, 2), + ] { + negative_test( + &ShadowedClosure::new(snippet), + InitVmState::with_stack(snippet.setup_init_stack(a, b)), + &[InstructionError::AssertionFailed], + ); + negative_test( + &ShadowedClosure::new(snippet), + InitVmState::with_stack(snippet.setup_init_stack(b, a)), + &[InstructionError::AssertionFailed], + ); + } + + for i in 0..128 { + let a = u128::MAX - ((1u128 << i) - 1); + let b = 1u128 << i; + + // sanity check of test input values + assert!(a.wrapping_add(b).is_zero()); + + negative_test( + &ShadowedClosure::new(snippet), + InitVmState::with_stack(snippet.setup_init_stack(a, b)), + &[InstructionError::AssertionFailed], + ); + negative_test( + &ShadowedClosure::new(snippet), + InitVmState::with_stack(snippet.setup_init_stack(b, a)), + &[InstructionError::AssertionFailed], + ); + } + } + + impl SafeAddU128 { fn prop_add(&self, lhs: u128, rhs: u128, expected: Option<&[BFieldElement]>) { - let mut init_stack = self.init_stack_for_isolated_run(); - for elem in rhs.encode().into_iter().rev() { - init_stack.push(elem); - } - for elem in lhs.encode().into_iter().rev() { - init_stack.push(elem); - } + let init_stack = self.setup_init_stack(lhs, rhs); test_rust_equivalence_given_complete_state( - &ShadowedClosure::new(AddU128), + &ShadowedClosure::new(SafeAddU128), &init_stack, &[], &NonDeterminism::default(), @@ -162,7 +210,7 @@ mod tests { } } - impl Closure for AddU128 { + impl Closure for SafeAddU128 { fn rust_shadow(&self, stack: &mut Vec) { fn to_u128(a: u32, b: u32, c: u32, d: u32) -> u128 { a as u128 @@ -206,6 +254,21 @@ mod tests { self.setup_init_stack(lhs, rhs) } + + fn corner_case_initial_states(&self) -> Vec> { + let some_none_zero_value = (1u128 << 97) + (1u128 << 65) + u64::MAX as u128 - 456456; + let rhs_is_zero = self.setup_init_stack(some_none_zero_value, 0); + let lhs_is_zero = self.setup_init_stack(0, some_none_zero_value); + let rhs_is_zero_lhs_max = self.setup_init_stack(u128::MAX, 0); + let lhs_is_zero_rhs_max = self.setup_init_stack(0, u128::MAX); + + vec![ + rhs_is_zero, + lhs_is_zero, + rhs_is_zero_lhs_max, + lhs_is_zero_rhs_max, + ] + } } } @@ -218,6 +281,6 @@ mod benches { #[test] fn add_u128_benchmark() { - ShadowedClosure::new(AddU128).bench() + ShadowedClosure::new(SafeAddU128).bench() } } diff --git a/tasm-lib/src/exported_snippets.rs b/tasm-lib/src/exported_snippets.rs index 98114511..e919880a 100644 --- a/tasm-lib/src/exported_snippets.rs +++ b/tasm-lib/src/exported_snippets.rs @@ -5,7 +5,7 @@ use triton_vm::table::extension_table::Quotientable; use triton_vm::table::master_table::MasterExtTable; use triton_vm::table::NUM_QUOTIENT_SEGMENTS; -use crate::arithmetic::u128::add_u128::AddU128; +use crate::arithmetic::u128::safe_add::SafeAddU128; use crate::arithmetic::u128::safe_mul_u128::SafeMulU128; use crate::arithmetic::u128::shift_left_static_u128::ShiftLeftStaticU128; use crate::arithmetic::u128::shift_left_u128::ShiftLeftU128; @@ -161,7 +161,7 @@ pub fn name_to_snippet(fn_name: &str) -> Box { "tasmlib_arithmetic_u64_overflowing_sub" => Box::new(OverflowingSub), // u128 - "tasmlib_arithmetic_u128_add" => Box::new(AddU128), + "tasmlib_arithmetic_u128_safe_add" => Box::new(SafeAddU128), "tasmlib_arithmetic_u128_shift_left" => Box::new(ShiftLeftU128), "tasmlib_arithmetic_u128_shift_right" => Box::new(ShiftRightU128), "tasmlib_arithmetic_u128_sub" => Box::new(SubU128),