diff --git a/src/tests_and_benchmarks/ozk/programs/recufier.rs b/src/tests_and_benchmarks/ozk/programs/recufier.rs index f9dbaa0c..a90e019a 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier.rs @@ -5,7 +5,6 @@ mod eval_arg; mod fast_ntt; mod fast_ntt_to_basic_snippet; mod fri_verify; -mod host_machine_vm_proof_iter; mod merkle_root; mod merkle_root_autogen; pub(crate) mod stark_parameters; diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/eval_arg.rs b/src/tests_and_benchmarks/ozk/programs/recufier/eval_arg.rs index 17dffd13..f57a9ac3 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier/eval_arg.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier/eval_arg.rs @@ -3,7 +3,7 @@ use num::One; use tasm_lib::triton_vm::prelude::*; #[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub struct EvalArg; +struct EvalArg; impl EvalArg { fn _default_initial() -> XFieldElement { @@ -14,6 +14,8 @@ impl EvalArg { /// and `symbols`. This amounts to evaluating polynomial /// `f(x) = initial·x^n + Σ_i symbols[n-i]·x^i` /// at point `challenge`, _i.e._, returns `f(challenge)`. + /// Consider using `tasm-lib` snippets directly instead of this code. The `tasm-lib` snippets + /// produce a much shorter execution trace. fn compute_terminal( symbols: Vec, initial: XFieldElement, diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt.rs b/src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt.rs index 38d2855d..78b0d619 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt.rs @@ -1,149 +1,146 @@ -#![allow(clippy::manual_swap)] - -use num::One; -use tasm_lib::triton_vm::prelude::*; -use tasm_lib::twenty_first::prelude::ModPowU32; - -use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; - -#[allow(clippy::ptr_arg)] -#[allow(clippy::vec_init_then_push)] -fn main() { - fn xfe_ntt(x: &mut Vec, omega: BFieldElement, log_2_of_n: u32) { - fn bitreverse(mut n: u32, l: u32) -> u32 { - let mut r: u32 = 0; - let mut i: u32 = 0; - while i < l { - r = (r << 1) | (n & 1); - n >>= 1; - i += 1; - } - - return r; - } +#[cfg(test)] +mod test { + use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation; + use crate::tests_and_benchmarks::ozk::rust_shadows; + use crate::tests_and_benchmarks::test_helpers::shared_test::*; + use itertools::Itertools; + use num::One; + use tasm_lib::triton_vm::prelude::*; + use tasm_lib::twenty_first::prelude::ModPowU32; + use tasm_lib::twenty_first::shared_math::ntt; + use tasm_lib::twenty_first::shared_math::other::log_2_floor; + use tasm_lib::twenty_first::shared_math::other::random_elements; + use tasm_lib::twenty_first::shared_math::traits::PrimitiveRootOfUnity; - let n: u32 = x.len() as u32; - - { - let mut k: u32 = 0; - while k != n { - let rk: u32 = bitreverse(k, log_2_of_n); - if k < rk { - // TODO: Use `swap` here instead, once it's implemented in `tasm-lib` - // That will give us a shorter cycle count - // x.swap(rk as usize, k as usize); - let rk_val: XFieldElement = x[rk as usize]; - x[rk as usize] = x[k as usize]; - x[k as usize] = rk_val; + use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; + + #[allow(clippy::ptr_arg)] + #[allow(clippy::manual_swap)] + #[allow(clippy::vec_init_then_push)] + fn main() { + fn xfe_ntt(x: &mut Vec, omega: BFieldElement, log_2_of_n: u32) { + fn bitreverse(mut n: u32, l: u32) -> u32 { + let mut r: u32 = 0; + let mut i: u32 = 0; + while i < l { + r = (r << 1) | (n & 1); + n >>= 1; + i += 1; } - k += 1; + return r; } - } - let mut m: u32 = 1; - - let mut outer_count: u32 = 0; - while outer_count != log_2_of_n { - // for _ in 0..log_2_of_n { - let w_m: BFieldElement = omega.mod_pow_u32(n / (2 * m)); - let mut k: u32 = 0; - while k < n { - let mut w: BFieldElement = BFieldElement::one(); - let mut j: u32 = 0; - while j != m { - // for j in 0..m { - let u: XFieldElement = x[(k + j) as usize]; - let mut v: XFieldElement = x[(k + j + m) as usize]; - v *= w; - x[(k + j) as usize] = u + v; - x[(k + j + m) as usize] = u - v; - w *= w_m; - - j += 1; + let n: u32 = x.len() as u32; + + { + let mut k: u32 = 0; + while k != n { + let rk: u32 = bitreverse(k, log_2_of_n); + if k < rk { + // TODO: Use `swap` here instead, once it's implemented in `tasm-lib` + // That will give us a shorter cycle count + // x.swap(rk as usize, k as usize); + let rk_val: XFieldElement = x[rk as usize]; + x[rk as usize] = x[k as usize]; + x[k as usize] = rk_val; + } + + k += 1; } - - k += 2 * m; } - m *= 2; + let mut m: u32 = 1; + + let mut outer_count: u32 = 0; + while outer_count != log_2_of_n { + // for _ in 0..log_2_of_n { + let w_m: BFieldElement = omega.mod_pow_u32(n / (2 * m)); + let mut k: u32 = 0; + while k < n { + let mut w: BFieldElement = BFieldElement::one(); + let mut j: u32 = 0; + while j != m { + // for j in 0..m { + let u: XFieldElement = x[(k + j) as usize]; + let mut v: XFieldElement = x[(k + j + m) as usize]; + v *= w; + x[(k + j) as usize] = u + v; + x[(k + j + m) as usize] = u - v; + w *= w_m; + + j += 1; + } + + k += 2 * m; + } + + m *= 2; - outer_count += 1; + outer_count += 1; + } + + return; } - return; - } + fn xfe_intt(x: &mut Vec, omega: BFieldElement, log_2_of_n: u32) { + { + let omega_inv: BFieldElement = BFieldElement::one() / omega; + xfe_ntt(x, omega_inv, log_2_of_n); + } + + let n_inv: BFieldElement = { + let n: BFieldElement = BFieldElement::new(x.len() as u64); + BFieldElement::one() / n + }; - fn xfe_intt(x: &mut Vec, omega: BFieldElement, log_2_of_n: u32) { - { - let omega_inv: BFieldElement = BFieldElement::one() / omega; - xfe_ntt(x, omega_inv, log_2_of_n); + let mut i: usize = 0; + let len: usize = x.len(); + while i < len { + x[i] *= n_inv; + i += 1; + } + + return; } - let n_inv: BFieldElement = { - let n: BFieldElement = BFieldElement::new(x.len() as u64); - BFieldElement::one() / n - }; + // NTT is equivalent to polynomial evaluation over the field generated by the `omega` generator + // where the input values are interpreted as coefficients. So an input of `[C, 0]` must output + // `[C, C]`, as the output is $P(x) = C$. + let omega: BFieldElement = tasm::tasm_io_read_stdin___bfe(); + let input_output_boxed: Box> = Vec::::decode( + &tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)), + ) + .unwrap(); + let mut input_output: Vec = *input_output_boxed; + let size: usize = input_output.len(); + let log_2_size: u32 = u32::BITS - (size as u32).leading_zeros() - 1; + xfe_ntt(&mut input_output, omega, log_2_size); + assert!(BFieldElement::one() == omega.mod_pow_u32(size as u32)); let mut i: usize = 0; - let len: usize = x.len(); - while i < len { - x[i] *= n_inv; + + while i < size { + tasm::tasm_io_write_to_stdout___xfe(input_output[i]); i += 1; } - return; - } - - // NTT is equivalent to polynomial evaluation over the field generated by the `omega` generator - // where the input values are interpreted as coefficients. So an input of `[C, 0]` must output - // `[C, C]`, as the output is $P(x) = C$. - let omega: BFieldElement = tasm::tasm_io_read_stdin___bfe(); - let input_output_boxed: Box> = Vec::::decode( - &tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)), - ) - .unwrap(); - let mut input_output: Vec = *input_output_boxed; - let size: usize = input_output.len(); - let log_2_size: u32 = u32::BITS - (size as u32).leading_zeros() - 1; - xfe_ntt(&mut input_output, omega, log_2_size); - assert!(BFieldElement::one() == omega.mod_pow_u32(size as u32)); - - let mut i: usize = 0; - - while i < size { - tasm::tasm_io_write_to_stdout___xfe(input_output[i]); - i += 1; - } + // We only output the NTT for the test, but we test that `xfe_intt` produces the + // inverse of `xfe_ntt`. + xfe_intt(&mut input_output, omega, log_2_size); + let input_copied: Box> = Vec::::decode( + &tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)), + ) + .unwrap(); + i = 0; + while i < size { + assert!(input_copied[i] == input_output[i]); + i += 1; + } - // We only output the NTT for the test, but we test that `xfe_intt` produces the - // inverse of `xfe_ntt`. - xfe_intt(&mut input_output, omega, log_2_size); - let input_copied: Box> = Vec::::decode( - &tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)), - ) - .unwrap(); - i = 0; - while i < size { - assert!(input_copied[i] == input_output[i]); - i += 1; + return; } - return; -} - -#[cfg(test)] -mod test { - use super::*; - use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation; - use crate::tests_and_benchmarks::ozk::rust_shadows; - use crate::tests_and_benchmarks::test_helpers::shared_test::*; - use itertools::Itertools; - use tasm_lib::twenty_first::shared_math::ntt; - use tasm_lib::twenty_first::shared_math::other::log_2_floor; - use tasm_lib::twenty_first::shared_math::other::random_elements; - use tasm_lib::twenty_first::shared_math::traits::PrimitiveRootOfUnity; - #[test] fn fast_xfe_ntt_test() { for input_length in [2, 4, 8, 16, 32, 64, 128] { @@ -175,7 +172,8 @@ mod test { } // Test function in Triton VM - let entrypoint_location = EntrypointLocation::disk("recufier", "fast_ntt", "main"); + let entrypoint_location = + EntrypointLocation::disk("recufier", "fast_ntt", "test::main"); let rust_ast = entrypoint_location.extract_entrypoint(); let expected_stack_diff = 0; let (code, _fn_name) = compile_for_run_test(&rust_ast); @@ -219,7 +217,7 @@ mod benches { } } - let entrypoint_location = EntrypointLocation::disk("recufier", "fast_ntt", "main"); + let entrypoint_location = EntrypointLocation::disk("recufier", "fast_ntt", "test::main"); let code = ozk_parsing::compile_for_test(&entrypoint_location); let common_case_input = get_input(32); diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt_to_basic_snippet.rs b/src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt_to_basic_snippet.rs index fa15ed6f..bd2f3d75 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt_to_basic_snippet.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt_to_basic_snippet.rs @@ -1,94 +1,88 @@ -#![allow(clippy::manual_swap)] - -use num::One; -use tasm_lib::triton_vm::prelude::*; -use tasm_lib::twenty_first::prelude::ModPowU32; - -#[allow(clippy::ptr_arg)] -#[allow(clippy::vec_init_then_push)] -#[allow(dead_code)] -fn xfe_ntt(x: &mut Vec, omega: BFieldElement) { - fn bitreverse(mut n: u32, l: u32) -> u32 { - let mut r: u32 = 0; - let mut i: u32 = 0; - while i < l { - r = (r << 1) | (n & 1); - n >>= 1; - i += 1; - } - - return r; - } +#[cfg(test)] +mod test { + use crate::tests_and_benchmarks::ozk::ozk_parsing::*; + use crate::tests_and_benchmarks::test_helpers::shared_test::bfe_lit; + use crate::tests_and_benchmarks::test_helpers::shared_test::compare_compiled_prop_with_stack_and_ins; + use num::One; + use std::collections::HashMap; + use tasm_lib::triton_vm::prelude::*; + use tasm_lib::twenty_first::prelude::ModPowU32; + use tasm_lib::twenty_first::prelude::XFieldElement; - let size: u32 = x.len() as u32; - let log_2_size: u32 = u32::BITS - size.leading_zeros() - 1; - - { - let mut k: u32 = 0; - while k != size { - let rk: u32 = bitreverse(k, log_2_size); - if k < rk { - // TODO: Use `swap` here instead, once it's implemented in `tasm-lib` - // That will give us a shorter cycle count - // x.swap(rk as usize, k as usize); - let rk_val: XFieldElement = x[rk as usize]; - x[rk as usize] = x[k as usize]; - x[k as usize] = rk_val; + #[allow(clippy::manual_swap)] + #[allow(clippy::ptr_arg)] + #[allow(clippy::vec_init_then_push)] + #[allow(dead_code)] + fn xfe_ntt(x: &mut Vec, omega: BFieldElement) { + fn bitreverse(mut n: u32, l: u32) -> u32 { + let mut r: u32 = 0; + let mut i: u32 = 0; + while i < l { + r = (r << 1) | (n & 1); + n >>= 1; + i += 1; } - k += 1; + return r; } - } - let mut m: u32 = 1; - - let mut outer_count: u32 = 0; - while outer_count != log_2_size { - // for _ in 0..log_2_of_n { - let w_m: BFieldElement = omega.mod_pow_u32(size / (2 * m)); - let mut k: u32 = 0; - while k < size { - let mut w: BFieldElement = BFieldElement::one(); - let mut j: u32 = 0; - while j != m { - // for j in 0..m { - let u: XFieldElement = x[(k + j) as usize]; - let mut v: XFieldElement = x[(k + j + m) as usize]; - v *= w; - x[(k + j) as usize] = u + v; - x[(k + j + m) as usize] = u - v; - w *= w_m; - - j += 1; + let size: u32 = x.len() as u32; + let log_2_size: u32 = u32::BITS - size.leading_zeros() - 1; + + { + let mut k: u32 = 0; + while k != size { + let rk: u32 = bitreverse(k, log_2_size); + if k < rk { + // TODO: Use `swap` here instead, once it's implemented in `tasm-lib` + // That will give us a shorter cycle count + // x.swap(rk as usize, k as usize); + let rk_val: XFieldElement = x[rk as usize]; + x[rk as usize] = x[k as usize]; + x[k as usize] = rk_val; + } + + k += 1; } - - k += 2 * m; } - m *= 2; - - outer_count += 1; - } - - return; -} - -#[cfg(test)] -mod test { - use std::collections::HashMap; + let mut m: u32 = 1; + + let mut outer_count: u32 = 0; + while outer_count != log_2_size { + // for _ in 0..log_2_of_n { + let w_m: BFieldElement = omega.mod_pow_u32(size / (2 * m)); + let mut k: u32 = 0; + while k < size { + let mut w: BFieldElement = BFieldElement::one(); + let mut j: u32 = 0; + while j != m { + // for j in 0..m { + let u: XFieldElement = x[(k + j) as usize]; + let mut v: XFieldElement = x[(k + j + m) as usize]; + v *= w; + x[(k + j) as usize] = u + v; + x[(k + j + m) as usize] = u - v; + w *= w_m; + + j += 1; + } + + k += 2 * m; + } - use tasm_lib::triton_vm::prelude::*; + m *= 2; - use tasm_lib::twenty_first::prelude::XFieldElement; + outer_count += 1; + } - use crate::tests_and_benchmarks::ozk::ozk_parsing::*; - use crate::tests_and_benchmarks::test_helpers::shared_test::bfe_lit; - use crate::tests_and_benchmarks::test_helpers::shared_test::compare_compiled_prop_with_stack_and_ins; + return; + } #[test] fn fast_xfe_ntt_to_basic_snippet_test() { let entrypoint_location = - EntrypointLocation::disk("recufier", "fast_ntt_to_basic_snippet", "xfe_ntt"); + EntrypointLocation::disk("recufier", "fast_ntt_to_basic_snippet", "test::xfe_ntt"); let compiled = compile_for_test(&entrypoint_location); let list_pointer = BFieldElement::new(100); let list_length = BFieldElement::new(1); @@ -115,7 +109,7 @@ mod test { // Output what we came for: A `BasicSnippet` implementation constructed by the compiler let entrypoint_location = - EntrypointLocation::disk("recufier", "fast_ntt_to_basic_snippet", "xfe_ntt"); + EntrypointLocation::disk("recufier", "fast_ntt_to_basic_snippet", "test::xfe_ntt"); let rust_ast = entrypoint_location.extract_entrypoint(); let as_bs = compile_to_basic_snippet(rust_ast, HashMap::default()); println!("{}", as_bs); diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/fri_verify.rs b/src/tests_and_benchmarks/ozk/programs/recufier/fri_verify.rs index 44e671e0..49eeed0d 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier/fri_verify.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier/fri_verify.rs @@ -40,12 +40,6 @@ impl FriVerify { } } -fn main() { - let _a: FriVerify = FriVerify::new(BFieldElement::new(7), 32, 4, 3); - - return; -} - #[cfg(test)] pub(crate) mod test { use rand::random; @@ -58,6 +52,12 @@ pub(crate) mod test { use super::*; + fn main() { + let _a: FriVerify = FriVerify::new(BFieldElement::new(7), 32, 4, 3); + + return; + } + #[test] fn fri_verify_test() { // Rust program on host machine @@ -67,7 +67,7 @@ pub(crate) mod test { rust_shadows::wrap_main_with_io(&main)(stdin.clone(), non_determinism.clone()); // Run test on Triton-VM - let entrypoint_location = EntrypointLocation::disk("recufier", "fri_verify", "main"); + let entrypoint_location = EntrypointLocation::disk("recufier", "fri_verify", "test::main"); let test_program = ozk_parsing::compile_for_test(&entrypoint_location); let expected_stack_diff = 0; diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/host_machine_vm_proof_iter.rs b/src/tests_and_benchmarks/ozk/programs/recufier/host_machine_vm_proof_iter.rs deleted file mode 100644 index 8b137891..00000000 --- a/src/tests_and_benchmarks/ozk/programs/recufier/host_machine_vm_proof_iter.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/merkle_root.rs b/src/tests_and_benchmarks/ozk/programs/recufier/merkle_root.rs index 693bd187..ba7ec9e4 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier/merkle_root.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier/merkle_root.rs @@ -1,49 +1,20 @@ -#![allow(clippy::explicit_auto_deref)] - -use tasm_lib::triton_vm::prelude::*; -use tasm_lib::twenty_first::prelude::AlgebraicHasher; - -use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; - -fn main() { - fn merkle_root(leafs: &Vec, start: usize, stop: usize) -> Digest { - let result: Digest = if stop == start + 1usize { - leafs[start] - } else { - let half: usize = (stop - start) / 2; - let left: Digest = merkle_root(leafs, start, stop - half); - let right: Digest = merkle_root(leafs, start + half, stop); - Tip5::hash_pair(left, right) - }; - - return result; - } - - let elements: Box> = - Vec::::decode(&tasm::load_from_memory(BFieldElement::new(2000))).unwrap(); - let length: usize = elements.len(); - - let root: Digest = merkle_root(&(*elements), 0usize, length); - tasm::tasm_io_write_to_stdout___digest(root); - - return; -} - +#[allow(clippy::explicit_auto_deref)] #[cfg(test)] mod test { + use tasm_lib::triton_vm::prelude::*; + use tasm_lib::twenty_first::prelude::AlgebraicHasher; use tasm_lib::twenty_first::shared_math::other::random_elements; use tasm_lib::twenty_first::util_types::merkle_tree::CpuParallel; use tasm_lib::twenty_first::util_types::merkle_tree::MerkleTree; use tasm_lib::twenty_first::util_types::merkle_tree_maker::MerkleTreeMaker; use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation; + use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; use crate::tests_and_benchmarks::ozk::rust_shadows; use crate::tests_and_benchmarks::test_helpers::shared_test::execute_compiled_with_stack_and_ins_for_test; use crate::tests_and_benchmarks::test_helpers::shared_test::init_memory_from; use crate::tests_and_benchmarks::test_helpers::shared_test::*; - use super::*; - #[test] fn merkle_root_test() { // Test function on host machine @@ -58,7 +29,8 @@ mod test { assert_eq!(native_output, expected_output); // Test function in Triton VM - let entrypoint_location = EntrypointLocation::disk("recufier", "merkle_root", "main"); + let entrypoint_location = + EntrypointLocation::disk("recufier", "merkle_root", "test::main"); let rust_ast = entrypoint_location.extract_entrypoint(); let expected_stack_diff = 0; let (code, _fn_name) = compile_for_run_test(&rust_ast); @@ -73,9 +45,34 @@ mod test { assert_eq!(expected_output, vm_output.public_output); } } + + fn main() { + fn merkle_root(leafs: &Vec, start: usize, stop: usize) -> Digest { + let result: Digest = if stop == start + 1usize { + leafs[start] + } else { + let half: usize = (stop - start) / 2; + let left: Digest = merkle_root(leafs, start, stop - half); + let right: Digest = merkle_root(leafs, start + half, stop); + Tip5::hash_pair(left, right) + }; + + return result; + } + + let elements: Box> = + Vec::::decode(&tasm::load_from_memory(BFieldElement::new(2000))).unwrap(); + let length: usize = elements.len(); + + let root: Digest = merkle_root(&(*elements), 0usize, length); + tasm::tasm_io_write_to_stdout___digest(root); + + return; + } } mod benches { + use tasm_lib::triton_vm::prelude::*; use tasm_lib::twenty_first::shared_math::other::random_elements; use crate::tests_and_benchmarks::benchmarks::execute_and_write_benchmark; @@ -85,8 +82,6 @@ mod benches { use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation; use crate::tests_and_benchmarks::test_helpers::shared_test::*; - use super::*; - #[test] fn merkle_root_bench() { fn get_input(length: usize) -> BenchmarkInput { @@ -99,7 +94,7 @@ mod benches { } } - let entrypoint_location = EntrypointLocation::disk("recufier", "merkle_root", "main"); + let entrypoint_location = EntrypointLocation::disk("recufier", "merkle_root", "test::main"); let code = ozk_parsing::compile_for_test(&entrypoint_location); let common_case_input = get_input(16); diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/merkle_root_autogen.rs b/src/tests_and_benchmarks/ozk/programs/recufier/merkle_root_autogen.rs index 408044d6..c3e3493a 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier/merkle_root_autogen.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier/merkle_root_autogen.rs @@ -1,39 +1,35 @@ -#![allow(clippy::manual_swap)] - -use tasm_lib::twenty_first::prelude::AlgebraicHasher; -use tasm_lib::Digest; - -use crate::twenty_first::prelude::*; - -#[allow(clippy::ptr_arg)] -#[allow(clippy::vec_init_then_push)] -#[allow(dead_code)] -fn merkle_root(leafs: &Vec, start: usize, stop: usize) -> Digest { - // #[allow(unused_assignments)] - // let mut result: Digest = Digest::default(); - let result: Digest = if stop == start + 1usize { - leafs[start] - } else { - let half: usize = (stop - start) / 2; - let left: Digest = merkle_root(leafs, start, stop - half); - let right: Digest = merkle_root(leafs, start + half, stop); - Tip5::hash_pair(left, right) - }; - - return result; -} - #[cfg(test)] mod test { use crate::tests_and_benchmarks::ozk::ozk_parsing::compile_to_basic_snippet; use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation; + use crate::twenty_first::prelude::*; + use tasm_lib::twenty_first::prelude::AlgebraicHasher; + use tasm_lib::Digest; + /// Output the `Merkle root` implementation as a `BasicSnippet` implementation. #[test] - fn merkle_root_to_basic_snippet_test() { + fn merkle_root_to_basic_snippet() { let entrypoint_location = - EntrypointLocation::disk("recufier", "merkle_root_autogen", "merkle_root"); + EntrypointLocation::disk("recufier", "merkle_root_autogen", "test::merkle_root"); let rust_ast = entrypoint_location.extract_entrypoint(); let as_bs = compile_to_basic_snippet(rust_ast, std::collections::HashMap::default()); println!("{}", as_bs); } + + #[allow(clippy::ptr_arg)] + #[allow(dead_code)] + fn merkle_root(leafs: &Vec, start: usize, stop: usize) -> Digest { + // #[allow(unused_assignments)] + // let mut result: Digest = Digest::default(); + let result: Digest = if stop == start + 1usize { + leafs[start] + } else { + let half: usize = (stop - start) / 2; + let left: Digest = merkle_root(leafs, start, stop - half); + let right: Digest = merkle_root(leafs, start + half, stop); + Tip5::hash_pair(left, right) + }; + + return result; + } } diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/vm_proof_iter_next_as.rs b/src/tests_and_benchmarks/ozk/programs/recufier/vm_proof_iter_next_as.rs index 7d84f89a..0219e889 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier/vm_proof_iter_next_as.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier/vm_proof_iter_next_as.rs @@ -1,114 +1,3 @@ -use tasm_lib::triton_vm::prelude::*; -use tasm_lib::triton_vm::proof_item::FriResponse; -use tasm_lib::triton_vm::proof_item::ProofItem; -use tasm_lib::triton_vm::proof_stream::ProofStream; - -use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; -use crate::tests_and_benchmarks::ozk::rust_shadows::Tip5WithState; -use crate::tests_and_benchmarks::ozk::rust_shadows::VmProofIter; - -fn call_all_next_methods() { - Tip5WithState::init(); - let vm_proof_iter_stack: VmProofIter = VmProofIter { - current_item_pointer: BFieldElement::new(2), - }; - let mut vm_proof_iter: Box = Box::::new(vm_proof_iter_stack); - let a_merkle_root: Box = vm_proof_iter.next_as_merkleroot(); - tasm::tasm_io_write_to_stdout___digest(*a_merkle_root); - - let out_of_domain_base_row: Box> = - vm_proof_iter.next_as_outofdomainbaserow(); - tasm::tasm_io_write_to_stdout___xfe(out_of_domain_base_row[0]); - tasm::tasm_io_write_to_stdout___xfe(out_of_domain_base_row[1]); - - let out_of_domain_ext_row: Box> = - vm_proof_iter.next_as_outofdomainextrow(); - tasm::tasm_io_write_to_stdout___xfe(out_of_domain_ext_row[0]); - tasm::tasm_io_write_to_stdout___xfe(out_of_domain_ext_row[1]); - - let out_of_domain_quotient_segments: Box<[XFieldElement; 4]> = - vm_proof_iter.next_as_outofdomainquotientsegments(); - tasm::tasm_io_write_to_stdout___xfe(out_of_domain_quotient_segments[0]); - tasm::tasm_io_write_to_stdout___xfe(out_of_domain_quotient_segments[1]); - tasm::tasm_io_write_to_stdout___xfe(out_of_domain_quotient_segments[2]); - tasm::tasm_io_write_to_stdout___xfe(out_of_domain_quotient_segments[3]); - - let authentication_structure: Box> = - vm_proof_iter.next_as_authenticationstructure(); - tasm::tasm_io_write_to_stdout___digest(authentication_structure[0]); - tasm::tasm_io_write_to_stdout___digest(authentication_structure[1]); - tasm::tasm_io_write_to_stdout___digest(authentication_structure[2]); - - let mbtw: Box> = vm_proof_iter.next_as_masterbasetablerows(); - { - let mut j: usize = 0; - while j < mbtw.len() { - let mut i: usize = 0; - while i < mbtw[j].len() { - tasm::tasm_io_write_to_stdout___bfe(mbtw[j][i]); - i += 1; - } - j += 1; - } - } - - let metr: Box> = vm_proof_iter.next_as_masterexttablerows(); - { - let mut j: usize = 0; - while j < metr.len() { - let mut i: usize = 0; - while i < metr[j].len() { - tasm::tasm_io_write_to_stdout___xfe(metr[j][i]); - i += 1; - } - j += 1; - } - } - - let log2paddedheight: Box = vm_proof_iter.next_as_log2paddedheight(); - tasm::tasm_io_write_to_stdout___u32(*log2paddedheight); - - let quotient_segments_elements: Box> = - vm_proof_iter.next_as_quotientsegmentselements(); - { - let mut j: usize = 0; - while j < quotient_segments_elements.len() { - tasm::tasm_io_write_to_stdout___xfe(quotient_segments_elements[j][0]); - tasm::tasm_io_write_to_stdout___xfe(quotient_segments_elements[j][1]); - tasm::tasm_io_write_to_stdout___xfe(quotient_segments_elements[j][2]); - tasm::tasm_io_write_to_stdout___xfe(quotient_segments_elements[j][3]); - j += 1; - } - } - - let fri_codeword: Box> = vm_proof_iter.next_as_fricodeword(); - { - let mut j: usize = 0; - while j < fri_codeword.len() { - tasm::tasm_io_write_to_stdout___xfe(fri_codeword[j]); - j += 1; - } - } - - let fri_response: Box = vm_proof_iter.next_as_friresponse(); - { - let mut j: usize = 0; - while j < fri_response.auth_structure.len() { - tasm::tasm_io_write_to_stdout___digest(fri_response.auth_structure[j]); - j += 1; - } - } - { - let mut j: usize = 0; - while j < fri_response.revealed_leaves.len() { - tasm::tasm_io_write_to_stdout___xfe(fri_response.revealed_leaves[j]); - j += 1; - } - } - - return; -} - #[cfg(test)] mod test { use itertools::Itertools; @@ -124,13 +13,124 @@ mod test { use crate::triton_vm::table::QuotientSegments; use crate::triton_vm::table::NUM_BASE_COLUMNS; use crate::triton_vm::table::NUM_EXT_COLUMNS; + use tasm_lib::triton_vm::prelude::*; + use tasm_lib::triton_vm::proof_item::ProofItem; + use tasm_lib::triton_vm::proof_stream::ProofStream; + + use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; + use crate::tests_and_benchmarks::ozk::rust_shadows::Tip5WithState; + use crate::tests_and_benchmarks::ozk::rust_shadows::VmProofIter; + + /// The function being tested here. Dual-compiled by `rustc` and `tasm-lang`. + fn call_all_next_methods() { + Tip5WithState::init(); + let vm_proof_iter_stack: VmProofIter = VmProofIter { + current_item_pointer: BFieldElement::new(2), + }; + let mut vm_proof_iter: Box = Box::::new(vm_proof_iter_stack); + let a_merkle_root: Box = vm_proof_iter.next_as_merkleroot(); + tasm::tasm_io_write_to_stdout___digest(*a_merkle_root); + + let out_of_domain_base_row: Box> = + vm_proof_iter.next_as_outofdomainbaserow(); + tasm::tasm_io_write_to_stdout___xfe(out_of_domain_base_row[0]); + tasm::tasm_io_write_to_stdout___xfe(out_of_domain_base_row[1]); + + let out_of_domain_ext_row: Box> = + vm_proof_iter.next_as_outofdomainextrow(); + tasm::tasm_io_write_to_stdout___xfe(out_of_domain_ext_row[0]); + tasm::tasm_io_write_to_stdout___xfe(out_of_domain_ext_row[1]); + + let out_of_domain_quotient_segments: Box<[XFieldElement; 4]> = + vm_proof_iter.next_as_outofdomainquotientsegments(); + tasm::tasm_io_write_to_stdout___xfe(out_of_domain_quotient_segments[0]); + tasm::tasm_io_write_to_stdout___xfe(out_of_domain_quotient_segments[1]); + tasm::tasm_io_write_to_stdout___xfe(out_of_domain_quotient_segments[2]); + tasm::tasm_io_write_to_stdout___xfe(out_of_domain_quotient_segments[3]); + + let authentication_structure: Box> = + vm_proof_iter.next_as_authenticationstructure(); + tasm::tasm_io_write_to_stdout___digest(authentication_structure[0]); + tasm::tasm_io_write_to_stdout___digest(authentication_structure[1]); + tasm::tasm_io_write_to_stdout___digest(authentication_structure[2]); + + let mbtw: Box> = vm_proof_iter.next_as_masterbasetablerows(); + { + let mut j: usize = 0; + while j < mbtw.len() { + let mut i: usize = 0; + while i < mbtw[j].len() { + tasm::tasm_io_write_to_stdout___bfe(mbtw[j][i]); + i += 1; + } + j += 1; + } + } + + let metr: Box> = vm_proof_iter.next_as_masterexttablerows(); + { + let mut j: usize = 0; + while j < metr.len() { + let mut i: usize = 0; + while i < metr[j].len() { + tasm::tasm_io_write_to_stdout___xfe(metr[j][i]); + i += 1; + } + j += 1; + } + } + + let log2paddedheight: Box = vm_proof_iter.next_as_log2paddedheight(); + tasm::tasm_io_write_to_stdout___u32(*log2paddedheight); - use super::*; + let quotient_segments_elements: Box> = + vm_proof_iter.next_as_quotientsegmentselements(); + { + let mut j: usize = 0; + while j < quotient_segments_elements.len() { + tasm::tasm_io_write_to_stdout___xfe(quotient_segments_elements[j][0]); + tasm::tasm_io_write_to_stdout___xfe(quotient_segments_elements[j][1]); + tasm::tasm_io_write_to_stdout___xfe(quotient_segments_elements[j][2]); + tasm::tasm_io_write_to_stdout___xfe(quotient_segments_elements[j][3]); + j += 1; + } + } + + let fri_codeword: Box> = vm_proof_iter.next_as_fricodeword(); + { + let mut j: usize = 0; + while j < fri_codeword.len() { + tasm::tasm_io_write_to_stdout___xfe(fri_codeword[j]); + j += 1; + } + } + + let fri_response: Box = vm_proof_iter.next_as_friresponse(); + { + let mut j: usize = 0; + while j < fri_response.auth_structure.len() { + tasm::tasm_io_write_to_stdout___digest(fri_response.auth_structure[j]); + j += 1; + } + } + { + let mut j: usize = 0; + while j < fri_response.revealed_leaves.len() { + tasm::tasm_io_write_to_stdout___xfe(fri_response.revealed_leaves[j]); + j += 1; + } + } + + return; + } #[test] fn test_all_next_as_methods() { - let entrypoint_location = - EntrypointLocation::disk("recufier", "vm_proof_iter_next_as", "call_all_next_methods"); + let entrypoint_location = EntrypointLocation::disk( + "recufier", + "vm_proof_iter_next_as", + "test::call_all_next_methods", + ); let test_case = TritonVMTestCase::new(entrypoint_location); let non_determinism = non_determinism(); { diff --git a/src/tests_and_benchmarks/ozk/programs/recufier/xfe_ntt_recursive.rs b/src/tests_and_benchmarks/ozk/programs/recufier/xfe_ntt_recursive.rs index 242b5ce2..1b22b5c6 100644 --- a/src/tests_and_benchmarks/ozk/programs/recufier/xfe_ntt_recursive.rs +++ b/src/tests_and_benchmarks/ozk/programs/recufier/xfe_ntt_recursive.rs @@ -1,150 +1,20 @@ -use num::One; -use tasm_lib::triton_vm::prelude::*; -use tasm_lib::twenty_first::prelude::ModPowU32; - -use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; - -#[allow(clippy::ptr_arg)] -#[allow(clippy::vec_init_then_push)] -fn main() { - fn base_case(x: &Vec, omega: BFieldElement) -> Vec { - let mut ret: Vec = Vec::::default(); - ret.push(x[0] + x[1]); - ret.push(x[0] + x[1] * omega); - - return ret; - } - - fn xfe_ntt(x: Vec, omega: BFieldElement) -> Vec { - let size: usize = x.len(); - let res: Vec = if size == 2 { - base_case(&x, omega) - } else { - // Split by parity - let mut x_even: Vec = Vec::::default(); - let mut x_odd: Vec = Vec::::default(); - let mut i: usize = 0; - while i < size { - if i % 2 == 0 { - x_even.push(x[i]); - } else { - x_odd.push(x[i]); - } - i += 1; - } - - // Recursive call - let omega_squared: BFieldElement = omega * omega; - let even: Vec = xfe_ntt(x_even, omega_squared); - let odd: Vec = xfe_ntt(x_odd, omega_squared); - - // Calculate all values omega^j, for j=0..size - let mut factor_values: Vec = Vec::::default(); - i = 0; - let mut pow: BFieldElement = BFieldElement::one(); - while i < size { - factor_values.push(pow); - pow *= omega; - i += 1; - } - - // Split by middle - let mut fst_half_factors: Vec = Vec::::default(); - let mut snd_half_factors: Vec = Vec::::default(); - i = 0; - while i != size / 2 { - fst_half_factors.push(factor_values[i]); - i += 1; - } - while i != size { - snd_half_factors.push(factor_values[i]); - i += 1; - } - - // hadamard products - let mut res: Vec = Vec::::default(); - i = 0; - while i != size / 2 { - res.push(even[i] + odd[i] * fst_half_factors[i]); - i += 1; - } - i = 0; - while i != size / 2 { - res.push(even[i] + odd[i] * snd_half_factors[i]); - i += 1; - } - - res - }; - - return res; - } - - fn xfe_intt(x: Vec, omega: BFieldElement) -> Vec { - let length: usize = x.len(); - let xfe_length: XFieldElement = BFieldElement::new(length as u64).lift(); - let omega_inv: BFieldElement = BFieldElement::one() / omega; - let mut res: Vec = xfe_ntt(x, omega_inv); - - let mut i: usize = 0; - while i < length { - res[i] = res[i] / xfe_length; - i += 1; - } - - return res; - } - - // NTT is equivalent to polynomial evaluation over the field generated by the `omega` generator - // where the input values are interpreted as coefficients. So an input of `[C, 0]` must output - // `[C, C]`, as the output is $P(x) = C$. - let omega: BFieldElement = tasm::tasm_io_read_stdin___bfe(); - let input: Box> = Vec::::decode(&tasm::load_from_memory( - BFieldElement::new(0x1000_0000_0000_0000u64), - )) - .unwrap(); - let output: Vec = xfe_ntt(*input, omega); - let size: usize = output.len(); - assert!(BFieldElement::one() == omega.mod_pow_u32(size as u32)); - - let mut i: usize = 0; - - while i < size { - tasm::tasm_io_write_to_stdout___xfe(output[i]); - i += 1; - } - - // We only output the NTT for the test, but we test that `xfe_intt` produces the - // inverse of `xfe_ntt`. - let input_again: Vec = xfe_intt(output, omega); - let input_copied: Box> = Vec::::decode( - &tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)), - ) - .unwrap(); - i = 0; - while i < size { - assert!(input_copied[i] == input_again[i]); - i += 1; - } - - return; -} - #[cfg(test)] mod test { use itertools::Itertools; + use num::One; use tasm_lib::triton_vm::prelude::*; + use tasm_lib::twenty_first::prelude::ModPowU32; use tasm_lib::twenty_first::shared_math::ntt; use tasm_lib::twenty_first::shared_math::other::log_2_floor; use tasm_lib::twenty_first::shared_math::other::random_elements; use tasm_lib::twenty_first::shared_math::traits::PrimitiveRootOfUnity; + use crate::tests_and_benchmarks::ozk::rust_shadows as tasm; + use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation; use crate::tests_and_benchmarks::ozk::rust_shadows; use crate::tests_and_benchmarks::test_helpers::shared_test::*; - use super::*; - #[test] fn recursive_xfe_ntt_test() { // Test function on host machine @@ -177,7 +47,7 @@ mod test { // Test function in Triton VM let entrypoint_location = - EntrypointLocation::disk("recufier", "xfe_ntt_recursive", "main"); + EntrypointLocation::disk("recufier", "xfe_ntt_recursive", "test::main"); let rust_ast = entrypoint_location.extract_entrypoint(); let expected_stack_diff = 0; let (code, _fn_name) = compile_for_run_test(&rust_ast); @@ -192,4 +62,131 @@ mod test { assert_eq!(native_output, vm_output.public_output); } } + + /// Dual-compiled function that implements NTT efficiently + #[allow(clippy::ptr_arg)] + #[allow(clippy::vec_init_then_push)] + fn main() { + fn base_case(x: &Vec, omega: BFieldElement) -> Vec { + let mut ret: Vec = Vec::::default(); + ret.push(x[0] + x[1]); + ret.push(x[0] + x[1] * omega); + + return ret; + } + + fn xfe_ntt(x: Vec, omega: BFieldElement) -> Vec { + let size: usize = x.len(); + let res: Vec = if size == 2 { + base_case(&x, omega) + } else { + // Split by parity + let mut x_even: Vec = Vec::::default(); + let mut x_odd: Vec = Vec::::default(); + let mut i: usize = 0; + while i < size { + if i % 2 == 0 { + x_even.push(x[i]); + } else { + x_odd.push(x[i]); + } + i += 1; + } + + // Recursive call + let omega_squared: BFieldElement = omega * omega; + let even: Vec = xfe_ntt(x_even, omega_squared); + let odd: Vec = xfe_ntt(x_odd, omega_squared); + + // Calculate all values omega^j, for j=0..size + let mut factor_values: Vec = Vec::::default(); + i = 0; + let mut pow: BFieldElement = BFieldElement::one(); + while i < size { + factor_values.push(pow); + pow *= omega; + i += 1; + } + + // Split by middle + let mut fst_half_factors: Vec = Vec::::default(); + let mut snd_half_factors: Vec = Vec::::default(); + i = 0; + while i != size / 2 { + fst_half_factors.push(factor_values[i]); + i += 1; + } + while i != size { + snd_half_factors.push(factor_values[i]); + i += 1; + } + + // hadamard products + let mut res: Vec = Vec::::default(); + i = 0; + while i != size / 2 { + res.push(even[i] + odd[i] * fst_half_factors[i]); + i += 1; + } + i = 0; + while i != size / 2 { + res.push(even[i] + odd[i] * snd_half_factors[i]); + i += 1; + } + + res + }; + + return res; + } + + fn xfe_intt(x: Vec, omega: BFieldElement) -> Vec { + let length: usize = x.len(); + let xfe_length: XFieldElement = BFieldElement::new(length as u64).lift(); + let omega_inv: BFieldElement = BFieldElement::one() / omega; + let mut res: Vec = xfe_ntt(x, omega_inv); + + let mut i: usize = 0; + while i < length { + res[i] = res[i] / xfe_length; + i += 1; + } + + return res; + } + + // NTT is equivalent to polynomial evaluation over the field generated by the `omega` generator + // where the input values are interpreted as coefficients. So an input of `[C, 0]` must output + // `[C, C]`, as the output is $P(x) = C$. + let omega: BFieldElement = tasm::tasm_io_read_stdin___bfe(); + let input: Box> = Vec::::decode(&tasm::load_from_memory( + BFieldElement::new(0x1000_0000_0000_0000u64), + )) + .unwrap(); + let output: Vec = xfe_ntt(*input, omega); + let size: usize = output.len(); + assert!(BFieldElement::one() == omega.mod_pow_u32(size as u32)); + + let mut i: usize = 0; + + while i < size { + tasm::tasm_io_write_to_stdout___xfe(output[i]); + i += 1; + } + + // We only output the NTT for the test, but we test that `xfe_intt` produces the + // inverse of `xfe_ntt`. + let input_again: Vec = xfe_intt(output, omega); + let input_copied: Box> = Vec::::decode( + &tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)), + ) + .unwrap(); + i = 0; + while i < size { + assert!(input_copied[i] == input_again[i]); + i += 1; + } + + return; + } }