Skip to content

Commit

Permalink
reintroduced ntt tests against risc0 and lambdaworks
Browse files Browse the repository at this point in the history
  • Loading branch information
yshekel committed Jul 2, 2024
1 parent 1d6850c commit 4ce39bc
Show file tree
Hide file tree
Showing 15 changed files with 154 additions and 32 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/v3_rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,8 @@ jobs:
- name: Run tests
working-directory: ./wrappers/rust_v3
if: needs.check-changed-files.outputs.rust == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true'
run: cargo build --release --verbose --features=g2,ec_ntt && cargo test --workspace --release --verbose --features=g2,ec_ntt
# tests are split to phases since NTT domain is global but tests have conflicting requirements
run: |
cargo build --release --verbose --features=g2,ec_ntt
cargo test --workspace --release --verbose --features=g2,ec_ntt -- --skip phase
cargo test phase2 --workspace --release --verbose --features=g2,ec_ntt
4 changes: 2 additions & 2 deletions icicle_v3/include/icicle/ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,12 @@ namespace icicle {
NTTConfig<scalar_t>& config,
extension_t* output)>;

void register_ntt_ext_field(const std::string& deviceType, NttExtFieldImpl impl);
void register_extension_ntt(const std::string& deviceType, NttExtFieldImpl impl);

#define REGISTER_NTT_EXT_FIELD_BACKEND(DEVICE_TYPE, FUNC) \
namespace { \
static bool UNIQUE(_reg_ntt_ext_field) = []() -> bool { \
register_ntt_ext_field(DEVICE_TYPE, FUNC); \
register_extension_ntt(DEVICE_TYPE, FUNC); \
return true; \
}(); \
}
Expand Down
6 changes: 3 additions & 3 deletions icicle_v3/src/ntt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ namespace icicle {
}

#ifdef EXT_FIELD
ICICLE_DISPATCHER_INST(NttExtFieldDispatcher, ntt_ext_field, NttExtFieldImpl);
ICICLE_DISPATCHER_INST(NttExtFieldDispatcher, extension_ntt, NttExtFieldImpl);

extern "C" eIcicleError CONCAT_EXPAND(FIELD, ntt_ext_field)(
extern "C" eIcicleError CONCAT_EXPAND(FIELD, extension_ntt)(
const extension_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, extension_t* output)
{
return NttExtFieldDispatcher::execute(input, size, dir, config, output);
Expand All @@ -30,7 +30,7 @@ namespace icicle {
template <>
eIcicleError ntt(const extension_t* input, int size, NTTDir dir, NTTConfig<scalar_t>& config, extension_t* output)
{
return CONCAT_EXPAND(FIELD, ntt_ext_field)(input, size, dir, config, output);
return CONCAT_EXPAND(FIELD, extension_ntt)(input, size, dir, config, output);
}
#endif // EXT_FIELD

Expand Down
11 changes: 6 additions & 5 deletions wrappers/rust_v3/icicle-core/src/ntt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,14 @@ macro_rules! impl_ntt_tests {
) => {
use icicle_core::test_utilities;
use icicle_runtime::{device::Device, runtime};
use std::sync::{Once};

const MAX_SIZE: u64 = 1 << 17;
static INIT: OnceLock<()> = OnceLock::new();
static INIT: Once = Once::new();
const FAST_TWIDDLES_MODE: bool = false;

pub fn initialize() {
INIT.get_or_init(move || {
pub fn initialize() {
INIT.call_once(move || {
test_utilities::test_load_and_init_devices();
// init domain for both devices
test_utilities::test_set_ref_device();
Expand Down Expand Up @@ -427,8 +428,8 @@ macro_rules! impl_ntt_tests {
#[test]
#[serial]
fn test_ntt_release_domain() {
initialize();
// check_release_domain::<$field>()
// initialize();
// check_release_domain::<$field>()
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ pub(crate) mod tests {
use crate::curve::{CurveCfg, ScalarField};

use icicle_core::ecntt::tests::*;
use icicle_core::impl_ecntt_tests;
use std::sync::OnceLock;
use icicle_core::impl_ecntt_tests;

impl_ecntt_tests!(ScalarField, CurveCfg);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub(crate) mod tests {
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use serial_test::{parallel, serial};
use std::sync::OnceLock;


impl_ntt_tests!(ScalarField);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ pub(crate) mod tests {
use crate::curve::{CurveCfg, ScalarField};

use icicle_core::ecntt::tests::*;
use icicle_core::impl_ecntt_tests;
use std::sync::OnceLock;
use icicle_core::impl_ecntt_tests;

impl_ecntt_tests!(ScalarField, CurveCfg);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pub(crate) mod tests {
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use serial_test::{parallel, serial};
use std::sync::OnceLock;

impl_ntt_tests!(ScalarField);
}
3 changes: 1 addition & 2 deletions wrappers/rust_v3/icicle-curves/icicle-bn254/src/ecntt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ pub(crate) mod tests {
use crate::curve::{CurveCfg, ScalarField};

use icicle_core::ecntt::tests::*;
use icicle_core::impl_ecntt_tests;
use std::sync::OnceLock;
use icicle_core::impl_ecntt_tests;

impl_ecntt_tests!(ScalarField, CurveCfg);
}
1 change: 0 additions & 1 deletion wrappers/rust_v3/icicle-curves/icicle-bn254/src/ntt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ pub(crate) mod tests {
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use serial_test::{parallel, serial};
use std::sync::OnceLock;

impl_ntt_tests!(ScalarField);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ pub(crate) mod tests {
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use serial_test::{parallel, serial};
use std::sync::OnceLock;

impl_ntt_tests!(ScalarField);
}
2 changes: 2 additions & 0 deletions wrappers/rust_v3/icicle-fields/icicle-babybear/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ icicle-runtime = { workspace = true }
[dev-dependencies]
criterion = "0.3"
serial_test = "3.0.0"
risc0-core = "0.21.0"
risc0-zkp = "0.21.0"

[build-dependencies]
cmake = "0.1.50"
85 changes: 79 additions & 6 deletions wrappers/rust_v3/icicle-fields/icicle-babybear/src/ntt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,84 @@ pub(crate) mod tests {
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use serial_test::{parallel, serial};
use std::sync::OnceLock;

impl_ntt_tests!(ScalarField);
}

// TODO Yuval : V2 has tests against plonky3, do we still need it?
// Note that the NTT tests could not work for babybear since they rely on arkworks which is not implementing babybear
// UPDATE: team decided to keep it
// Tests against risc0 and plonky3
use super::{ExtensionField};
use icicle_core::{
ntt::{initialize_domain, ntt_inplace, release_domain, NTTConfig, NTTInitDomainConfig, NTTDir},
traits::{FieldImpl, GenerateRandom}
};
use icicle_runtime::{memory::HostSlice};
use risc0_core::field::{
baby_bear::{Elem, ExtElem},
Elem as FieldElem, RootsOfUnity,
};

// Note that risc0 and plonky3 tests shouldn't be ran simultaneously in parallel to other ntt tests as they use different roots of unity.
#[test]
#[serial]
fn phase2_test_ntt_against_risc0() {
test_utilities::test_load_and_init_devices();
test_utilities::test_set_main_device();

release_domain::<ScalarField>().unwrap(); // release domain from previous tests, if exists

let log_sizes = [15, 20];
let risc0_rou = Elem::ROU_FWD[log_sizes[1]];
initialize_domain(ScalarField::from([risc0_rou.as_u32()]), &NTTInitDomainConfig::default()).unwrap();
for log_size in log_sizes {
let ntt_size = 1 << log_size;

let mut scalars: Vec<ScalarField> = <ScalarField as FieldImpl>::Config::generate_random(ntt_size);
let mut scalars_risc0: Vec<Elem> = scalars
.iter()
.map(|x| Elem::new(Into::<[u32; 1]>::into(*x)[0]))
.collect();

let ntt_cfg: NTTConfig<ScalarField> = NTTConfig::default();
ntt_inplace(HostSlice::from_mut_slice(&mut scalars[..]), NTTDir::kForward, &ntt_cfg).unwrap();

risc0_zkp::core::ntt::bit_reverse(&mut scalars_risc0[..]);
risc0_zkp::core::ntt::evaluate_ntt::<Elem, Elem>(&mut scalars_risc0[..], ntt_size);

for (s1, s2) in scalars
.iter()
.zip(scalars_risc0)
{
assert_eq!(Into::<[u32; 1]>::into(*s1)[0], s2.as_u32());
}

let mut ext_scalars: Vec<ExtensionField> = <ExtensionField as FieldImpl>::Config::generate_random(ntt_size);
let mut ext_scalars_risc0: Vec<ExtElem> = ext_scalars
.iter()
.map(|x| ExtElem::from_u32_words(&Into::<[u32; 4]>::into(*x)[..]))
.collect();

ntt_inplace(
HostSlice::from_mut_slice(&mut ext_scalars[..]),
NTTDir::kForward,
&ntt_cfg,
)
.unwrap();

risc0_zkp::core::ntt::bit_reverse(&mut ext_scalars_risc0[..]);
risc0_zkp::core::ntt::evaluate_ntt::<Elem, ExtElem>(&mut ext_scalars_risc0[..], ntt_size);

for (s1, s2) in ext_scalars
.iter()
.zip(ext_scalars_risc0)
{
assert_eq!(Into::<[u32; 4]>::into(*s1)[..], s2.to_u32_words()[..]);
}
}

release_domain::<ScalarField>().unwrap();
}

// TODO test from V2. For some reason importing plonky3 Babybear cause an error
// #[test]
// #[serial]
// fn test_against_plonky3() {
// }
}
1 change: 1 addition & 0 deletions wrappers/rust_v3/icicle-fields/icicle-stark252/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ icicle-runtime = { workspace = true }
[dev-dependencies]
criterion = "0.3"
serial_test = "3.0.0"
lambdaworks-math = "0.6.0"

[build-dependencies]
cmake = "0.1.50"
56 changes: 52 additions & 4 deletions wrappers/rust_v3/icicle-fields/icicle-stark252/src/ntt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,58 @@ pub(crate) mod tests {
use icicle_core::impl_ntt_tests;
use icicle_core::ntt::tests::*;
use serial_test::{parallel, serial};
use std::sync::OnceLock;

impl_ntt_tests!(ScalarField);
}

// TODO Yuval : V2 has tests against lambdaworks, do we still need it?
// UPDATE: team decided to keep it

use icicle_core::{
ntt::{initialize_domain, ntt_inplace, release_domain, NTTConfig, NTTInitDomainConfig, NTTDir},
traits::{FieldImpl, GenerateRandom},
};
use icicle_runtime::memory::HostSlice;
use lambdaworks_math::{
field::{
element::FieldElement, fields::fft_friendly::stark_252_prime_field::Stark252PrimeField, traits::IsFFTField,
},
polynomial::Polynomial,
traits::ByteConversion,
};

pub type FE = FieldElement<Stark252PrimeField>;

#[test]
#[serial]
fn phase2_test_ntt_against_lambdaworks() {
test_utilities::test_load_and_init_devices();
test_utilities::test_set_main_device();

release_domain::<ScalarField>().unwrap(); // release domain from previous tests, if exists

let log_sizes = [15, 20];
let lw_root_of_unity = Stark252PrimeField::get_primitive_root_of_unity(log_sizes[log_sizes.len() - 1]).unwrap();
initialize_domain(ScalarField::from_bytes_le(&lw_root_of_unity.to_bytes_le()), &NTTInitDomainConfig::default()).unwrap();
for log_size in log_sizes {
let ntt_size = 1 << log_size;

let mut scalars: Vec<ScalarField> = <ScalarField as FieldImpl>::Config::generate_random(ntt_size);
let scalars_lw: Vec<FE> = scalars
.iter()
.map(|x| FieldElement::from_bytes_le(&x.to_bytes_le()).unwrap())
.collect();

let ntt_cfg: NTTConfig<ScalarField> = NTTConfig::default();
ntt_inplace(HostSlice::from_mut_slice(&mut scalars[..]), NTTDir::kForward, &ntt_cfg).unwrap();

let poly = Polynomial::new(&scalars_lw[..]);
let evaluations = Polynomial::evaluate_fft::<Stark252PrimeField>(&poly, 1, None).unwrap();

for (s1, s2) in scalars
.iter()
.zip(evaluations.iter())
{
assert_eq!(s1.to_bytes_le(), s2.to_bytes_le());
}
}
release_domain::<ScalarField>().unwrap();
}
}

0 comments on commit 4ce39bc

Please sign in to comment.