Skip to content

Commit

Permalink
refactor(test::recufier): Improve separation of tests and helper func…
Browse files Browse the repository at this point in the history
…tions

Move all functions that only serve as entrypoints for tests under a
`#[cfg(test)]` annotation to make it easier for developers to reuse code
from the `recufier` directory. The whole `recufier` directory is
technically all test-code for this compiler, but for now we actually use
these tests to build the recufier that Triton-VM needs.
  • Loading branch information
Sword-Smith committed Mar 21, 2024
1 parent 3a6992c commit c8df1e9
Show file tree
Hide file tree
Showing 9 changed files with 503 additions and 525 deletions.
1 change: 0 additions & 1 deletion src/tests_and_benchmarks/ozk/programs/recufier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ mod eval_arg;
mod fast_ntt;
mod fast_ntt_to_basic_snippet;
mod fri_verify;
mod host_machine_vm_proof_iter;
mod merkle_root;
mod merkle_root_autogen;
pub(crate) mod stark_parameters;
Expand Down
250 changes: 124 additions & 126 deletions src/tests_and_benchmarks/ozk/programs/recufier/fast_ntt.rs
Original file line number Diff line number Diff line change
@@ -1,149 +1,146 @@
#![allow(clippy::manual_swap)]

use num::One;
use tasm_lib::triton_vm::prelude::*;
use tasm_lib::twenty_first::prelude::ModPowU32;

use crate::tests_and_benchmarks::ozk::rust_shadows as tasm;

#[allow(clippy::ptr_arg)]
#[allow(clippy::vec_init_then_push)]
fn main() {
fn xfe_ntt(x: &mut Vec<XFieldElement>, omega: BFieldElement, log_2_of_n: u32) {
fn bitreverse(mut n: u32, l: u32) -> u32 {
let mut r: u32 = 0;
let mut i: u32 = 0;
while i < l {
r = (r << 1) | (n & 1);
n >>= 1;
i += 1;
}

return r;
}
#[cfg(test)]
mod test {
use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation;
use crate::tests_and_benchmarks::ozk::rust_shadows;
use crate::tests_and_benchmarks::test_helpers::shared_test::*;
use itertools::Itertools;
use num::One;
use tasm_lib::triton_vm::prelude::*;
use tasm_lib::twenty_first::prelude::ModPowU32;
use tasm_lib::twenty_first::shared_math::ntt;
use tasm_lib::twenty_first::shared_math::other::log_2_floor;
use tasm_lib::twenty_first::shared_math::other::random_elements;
use tasm_lib::twenty_first::shared_math::traits::PrimitiveRootOfUnity;

let n: u32 = x.len() as u32;

{
let mut k: u32 = 0;
while k != n {
let rk: u32 = bitreverse(k, log_2_of_n);
if k < rk {
// TODO: Use `swap` here instead, once it's implemented in `tasm-lib`
// That will give us a shorter cycle count
// x.swap(rk as usize, k as usize);
let rk_val: XFieldElement = x[rk as usize];
x[rk as usize] = x[k as usize];
x[k as usize] = rk_val;
use crate::tests_and_benchmarks::ozk::rust_shadows as tasm;

#[allow(clippy::ptr_arg)]
#[allow(clippy::manual_swap)]
#[allow(clippy::vec_init_then_push)]
fn main() {
fn xfe_ntt(x: &mut Vec<XFieldElement>, omega: BFieldElement, log_2_of_n: u32) {
fn bitreverse(mut n: u32, l: u32) -> u32 {
let mut r: u32 = 0;
let mut i: u32 = 0;
while i < l {
r = (r << 1) | (n & 1);
n >>= 1;
i += 1;
}

k += 1;
return r;
}
}

let mut m: u32 = 1;

let mut outer_count: u32 = 0;
while outer_count != log_2_of_n {
// for _ in 0..log_2_of_n {
let w_m: BFieldElement = omega.mod_pow_u32(n / (2 * m));
let mut k: u32 = 0;
while k < n {
let mut w: BFieldElement = BFieldElement::one();
let mut j: u32 = 0;
while j != m {
// for j in 0..m {
let u: XFieldElement = x[(k + j) as usize];
let mut v: XFieldElement = x[(k + j + m) as usize];
v *= w;
x[(k + j) as usize] = u + v;
x[(k + j + m) as usize] = u - v;
w *= w_m;

j += 1;
let n: u32 = x.len() as u32;

{
let mut k: u32 = 0;
while k != n {
let rk: u32 = bitreverse(k, log_2_of_n);
if k < rk {
// TODO: Use `swap` here instead, once it's implemented in `tasm-lib`
// That will give us a shorter cycle count
// x.swap(rk as usize, k as usize);
let rk_val: XFieldElement = x[rk as usize];
x[rk as usize] = x[k as usize];
x[k as usize] = rk_val;
}

k += 1;
}

k += 2 * m;
}

m *= 2;
let mut m: u32 = 1;

let mut outer_count: u32 = 0;
while outer_count != log_2_of_n {
// for _ in 0..log_2_of_n {
let w_m: BFieldElement = omega.mod_pow_u32(n / (2 * m));
let mut k: u32 = 0;
while k < n {
let mut w: BFieldElement = BFieldElement::one();
let mut j: u32 = 0;
while j != m {
// for j in 0..m {
let u: XFieldElement = x[(k + j) as usize];
let mut v: XFieldElement = x[(k + j + m) as usize];
v *= w;
x[(k + j) as usize] = u + v;
x[(k + j + m) as usize] = u - v;
w *= w_m;

j += 1;
}

k += 2 * m;
}

m *= 2;

outer_count += 1;
outer_count += 1;
}

return;
}

return;
}
fn xfe_intt(x: &mut Vec<XFieldElement>, omega: BFieldElement, log_2_of_n: u32) {
{
let omega_inv: BFieldElement = BFieldElement::one() / omega;
xfe_ntt(x, omega_inv, log_2_of_n);
}

let n_inv: BFieldElement = {
let n: BFieldElement = BFieldElement::new(x.len() as u64);
BFieldElement::one() / n
};

fn xfe_intt(x: &mut Vec<XFieldElement>, omega: BFieldElement, log_2_of_n: u32) {
{
let omega_inv: BFieldElement = BFieldElement::one() / omega;
xfe_ntt(x, omega_inv, log_2_of_n);
let mut i: usize = 0;
let len: usize = x.len();
while i < len {
x[i] *= n_inv;
i += 1;
}

return;
}

let n_inv: BFieldElement = {
let n: BFieldElement = BFieldElement::new(x.len() as u64);
BFieldElement::one() / n
};
// NTT is equivalent to polynomial evaluation over the field generated by the `omega` generator
// where the input values are interpreted as coefficients. So an input of `[C, 0]` must output
// `[C, C]`, as the output is $P(x) = C$.
let omega: BFieldElement = tasm::tasm_io_read_stdin___bfe();
let input_output_boxed: Box<Vec<XFieldElement>> = Vec::<XFieldElement>::decode(
&tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)),
)
.unwrap();
let mut input_output: Vec<XFieldElement> = *input_output_boxed;
let size: usize = input_output.len();
let log_2_size: u32 = u32::BITS - (size as u32).leading_zeros() - 1;
xfe_ntt(&mut input_output, omega, log_2_size);
assert!(BFieldElement::one() == omega.mod_pow_u32(size as u32));

let mut i: usize = 0;
let len: usize = x.len();
while i < len {
x[i] *= n_inv;

while i < size {
tasm::tasm_io_write_to_stdout___xfe(input_output[i]);
i += 1;
}

return;
}

// NTT is equivalent to polynomial evaluation over the field generated by the `omega` generator
// where the input values are interpreted as coefficients. So an input of `[C, 0]` must output
// `[C, C]`, as the output is $P(x) = C$.
let omega: BFieldElement = tasm::tasm_io_read_stdin___bfe();
let input_output_boxed: Box<Vec<XFieldElement>> = Vec::<XFieldElement>::decode(
&tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)),
)
.unwrap();
let mut input_output: Vec<XFieldElement> = *input_output_boxed;
let size: usize = input_output.len();
let log_2_size: u32 = u32::BITS - (size as u32).leading_zeros() - 1;
xfe_ntt(&mut input_output, omega, log_2_size);
assert!(BFieldElement::one() == omega.mod_pow_u32(size as u32));

let mut i: usize = 0;

while i < size {
tasm::tasm_io_write_to_stdout___xfe(input_output[i]);
i += 1;
}
// We only output the NTT for the test, but we test that `xfe_intt` produces the
// inverse of `xfe_ntt`.
xfe_intt(&mut input_output, omega, log_2_size);
let input_copied: Box<Vec<XFieldElement>> = Vec::<XFieldElement>::decode(
&tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)),
)
.unwrap();
i = 0;
while i < size {
assert!(input_copied[i] == input_output[i]);
i += 1;
}

// We only output the NTT for the test, but we test that `xfe_intt` produces the
// inverse of `xfe_ntt`.
xfe_intt(&mut input_output, omega, log_2_size);
let input_copied: Box<Vec<XFieldElement>> = Vec::<XFieldElement>::decode(
&tasm::load_from_memory(BFieldElement::new(0x1000_0000_0000_0000u64)),
)
.unwrap();
i = 0;
while i < size {
assert!(input_copied[i] == input_output[i]);
i += 1;
return;
}

return;
}

#[cfg(test)]
mod test {
use super::*;
use crate::tests_and_benchmarks::ozk::ozk_parsing::EntrypointLocation;
use crate::tests_and_benchmarks::ozk::rust_shadows;
use crate::tests_and_benchmarks::test_helpers::shared_test::*;
use itertools::Itertools;
use tasm_lib::twenty_first::shared_math::ntt;
use tasm_lib::twenty_first::shared_math::other::log_2_floor;
use tasm_lib::twenty_first::shared_math::other::random_elements;
use tasm_lib::twenty_first::shared_math::traits::PrimitiveRootOfUnity;

#[test]
fn fast_xfe_ntt_test() {
for input_length in [2, 4, 8, 16, 32, 64, 128] {
Expand Down Expand Up @@ -175,7 +172,8 @@ mod test {
}

// Test function in Triton VM
let entrypoint_location = EntrypointLocation::disk("recufier", "fast_ntt", "main");
let entrypoint_location =
EntrypointLocation::disk("recufier", "fast_ntt", "test::main");
let rust_ast = entrypoint_location.extract_entrypoint();
let expected_stack_diff = 0;
let (code, _fn_name) = compile_for_run_test(&rust_ast);
Expand Down Expand Up @@ -219,7 +217,7 @@ mod benches {
}
}

let entrypoint_location = EntrypointLocation::disk("recufier", "fast_ntt", "main");
let entrypoint_location = EntrypointLocation::disk("recufier", "fast_ntt", "test::main");
let code = ozk_parsing::compile_for_test(&entrypoint_location);

let common_case_input = get_input(32);
Expand Down
Loading

0 comments on commit c8df1e9

Please sign in to comment.