From d66aa411119f1c31a01c59e29d66decf1f1a8d69 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Tue, 27 Aug 2024 19:51:27 +0800 Subject: [PATCH 1/4] add ntt for babybear and goldilocks --- algebra/src/baby_bear/mod.rs | 2 ++ algebra/src/goldilocks/mod.rs | 1 + 2 files changed, 3 insertions(+) diff --git a/algebra/src/baby_bear/mod.rs b/algebra/src/baby_bear/mod.rs index 291e3b0c..41c13578 100644 --- a/algebra/src/baby_bear/mod.rs +++ b/algebra/src/baby_bear/mod.rs @@ -1,6 +1,8 @@ +mod babybear_ntt; mod extension; pub use extension::BabyBearExetension; + use serde::{Deserialize, Serialize}; use std::{ diff --git a/algebra/src/goldilocks/mod.rs b/algebra/src/goldilocks/mod.rs index b9d1a209..155a838f 100644 --- a/algebra/src/goldilocks/mod.rs +++ b/algebra/src/goldilocks/mod.rs @@ -1,4 +1,5 @@ mod extension; +mod goldilocks_ntt; pub use extension::GoldilocksExtension; use serde::{Deserialize, Serialize}; From eadc7ec18474d90cd860d98cecdb5cfef7dff026 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Tue, 27 Aug 2024 19:51:54 +0800 Subject: [PATCH 2/4] add ntt for babybear and goldilocks --- algebra/src/baby_bear/babybear_ntt.rs | 176 ++++++++++++++++++++++ algebra/src/goldilocks/goldilocks_ntt.rs | 182 +++++++++++++++++++++++ 2 files changed, 358 insertions(+) create mode 100644 algebra/src/baby_bear/babybear_ntt.rs create mode 100644 algebra/src/goldilocks/goldilocks_ntt.rs diff --git a/algebra/src/baby_bear/babybear_ntt.rs b/algebra/src/baby_bear/babybear_ntt.rs new file mode 100644 index 00000000..02d04d39 --- /dev/null +++ b/algebra/src/baby_bear/babybear_ntt.rs @@ -0,0 +1,176 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use num_traits::{pow, Zero}; +use rand::{distributions, thread_rng}; + +use crate::{transformation::prime32::ConcreteTable, Field, NTTField}; + +use super::BabyBear; + +impl From for BabyBear { + #[inline] + fn from(value: usize) -> Self { + Self::new(value as u32) + } +} + +static mut NTT_TABLE: once_cell::sync::OnceCell::Table>>> = + once_cell::sync::OnceCell::new(); + +static NTT_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); + +impl NTTField for BabyBear { + type Table = ConcreteTable; + + type Root = Self; + + type Degree = u32; + + #[inline] + fn from_root(root: Self::Root) -> Self { + root + } + + #[inline] + fn to_root(self) -> Self::Root { + self + } + + #[inline] + fn mul_root(self, root: Self::Root) -> Self { + self * root + } + + #[inline] + fn mul_root_assign(&mut self, root: Self::Root) { + *self *= root; + } + + #[inline] + fn is_primitive_root(root: Self, degree: Self::Degree) -> bool { + debug_assert!( + degree > 1 && degree.is_power_of_two(), + "degree must be a power of two and bigger than 1" + ); + + if root == Self::zero() { + return false; + } + + pow(root, (degree >> 1) as usize) == Self::neg_one() + } + + fn try_primitive_root(degree: Self::Degree) -> Result { + let modulus_sub_one = BabyBear::MODULUS_VALUE - 1; + let quotient = modulus_sub_one / degree; + if modulus_sub_one != quotient * degree { + return Err(crate::AlgebraError::NoPrimitiveRoot { + degree: degree.to_string(), + modulus: BabyBear::MODULUS_VALUE.to_string(), + }); + } + + let mut rng = thread_rng(); + let distr = distributions::Uniform::new_inclusive(2, modulus_sub_one); + + let mut w = Self::zero(); + + if (0..100).any(|_| { + w = pow( + Self::new(rand::Rng::sample(&mut rng, distr)), + quotient as usize, + ); + Self::is_primitive_root(w, degree) + }) { + Ok(w) + } else { + Err(crate::AlgebraError::NoPrimitiveRoot { + degree: degree.to_string(), + modulus: BabyBear::MODULUS_VALUE.to_string(), + }) + } + } + + fn try_minimal_primitive_root(degree: Self::Degree) -> Result { + let mut root = Self::try_primitive_root(degree)?; + + let generator_sq = (root * root).to_root(); + let mut current_generator = root; + + for _ in 0..degree { + if current_generator < root { + root = current_generator; + } + current_generator.mul_root_assign(generator_sq); + } + + Ok(root) + } + + #[inline] + fn generate_ntt_table(log_n: u32) -> Result { + Self::Table::new(log_n) + } + + fn init_ntt_table(log_ns: &[u32]) -> Result<(), crate::AlgebraError> { + let _g = NTT_MUTEX.lock().unwrap(); + match unsafe { NTT_TABLE.get_mut() } { + Some(tables) => { + let new_log_ns: HashSet = log_ns.iter().copied().collect(); + let old_log_ns: HashSet = tables.keys().copied().collect(); + + let difference = new_log_ns.difference(&old_log_ns); + + for &log_n in difference { + let temp_table = Self::generate_ntt_table(log_n)?; + tables.insert(log_n, Arc::new(temp_table)); + } + Ok(()) + } + None => { + let log_ns: HashSet = log_ns.iter().copied().collect(); + let mut map = HashMap::with_capacity(log_ns.len()); + + for log_n in log_ns { + let temp_table = Self::generate_ntt_table(log_n)?; + map.insert(log_n, Arc::new(temp_table)); + } + + if unsafe { NTT_TABLE.set(map).is_err() } { + Err(crate::AlgebraError::NTTTableError) + } else { + Ok(()) + } + } + } + } + + fn get_ntt_table(log_n: u32) -> Result, crate::AlgebraError> { + if let Some(tables) = unsafe { NTT_TABLE.get() } { + if let Some(t) = tables.get(&log_n) { + return Ok(Arc::clone(t)); + } + } + + Self::init_ntt_table(&[log_n])?; + Ok(Arc::clone(unsafe { + NTT_TABLE.get().unwrap().get(&log_n).unwrap() + })) + } +} + +#[test] +fn ntt_test() { + use crate::{NTTPolynomial, Polynomial}; + let n = 1 << 10; + let mut rng = thread_rng(); + let poly = Polynomial::::random(n, &mut rng); + + let ntt_poly: NTTPolynomial = poly.clone().into(); + + let expect_poly: Polynomial = ntt_poly.into(); + assert_eq!(poly, expect_poly); +} diff --git a/algebra/src/goldilocks/goldilocks_ntt.rs b/algebra/src/goldilocks/goldilocks_ntt.rs new file mode 100644 index 00000000..dd912bcb --- /dev/null +++ b/algebra/src/goldilocks/goldilocks_ntt.rs @@ -0,0 +1,182 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use num_traits::{pow, Zero}; +use rand::{distributions, thread_rng}; + +use crate::{transformation::prime64::ConcreteTable, Field, NTTField}; + +use super::Goldilocks; + +impl From for Goldilocks { + #[inline] + fn from(value: usize) -> Self { + let modulus = Goldilocks::MODULUS_VALUE as usize; + if value < modulus { + Self(value as u64) + } else { + Self((value % modulus) as u64) + } + } +} + +static mut NTT_TABLE: once_cell::sync::OnceCell< + HashMap::Table>>, +> = once_cell::sync::OnceCell::new(); + +static NTT_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); + +impl NTTField for Goldilocks { + type Table = ConcreteTable; + + type Root = Self; + + type Degree = u64; + + #[inline] + fn from_root(root: Self::Root) -> Self { + root + } + + #[inline] + fn to_root(self) -> Self::Root { + self + } + + #[inline] + fn mul_root(self, root: Self::Root) -> Self { + self * root + } + + #[inline] + fn mul_root_assign(&mut self, root: Self::Root) { + *self *= root + } + + #[inline] + fn is_primitive_root(root: Self, degree: Self::Degree) -> bool { + debug_assert!( + degree > 1 && degree.is_power_of_two(), + "degree must be a power of two and bigger than 1" + ); + + if root == Self::zero() { + return false; + } + + pow(root, (degree >> 1) as usize) == Self::neg_one() + } + + fn try_primitive_root(degree: Self::Degree) -> Result { + let modulus_sub_one = Goldilocks::MODULUS_VALUE - 1; + let quotient = modulus_sub_one / degree; + if modulus_sub_one != quotient * degree { + return Err(crate::AlgebraError::NoPrimitiveRoot { + degree: degree.to_string(), + modulus: Goldilocks::MODULUS_VALUE.to_string(), + }); + } + + let mut rng = thread_rng(); + let distr = distributions::Uniform::new_inclusive(2, modulus_sub_one); + + let mut w = Self::zero(); + + if (0..100).any(|_| { + w = pow( + Self::new(rand::Rng::sample(&mut rng, distr)), + quotient as usize, + ); + Self::is_primitive_root(w, degree) + }) { + Ok(w) + } else { + Err(crate::AlgebraError::NoPrimitiveRoot { + degree: degree.to_string(), + modulus: Goldilocks::MODULUS_VALUE.to_string(), + }) + } + } + + fn try_minimal_primitive_root(degree: Self::Degree) -> Result { + let mut root = Self::try_primitive_root(degree)?; + + let generator_sq = (root * root).to_root(); + let mut current_generator = root; + + for _ in 0..degree { + if current_generator < root { + root = current_generator; + } + current_generator.mul_root_assign(generator_sq); + } + + Ok(root) + } + + #[inline] + fn generate_ntt_table(log_n: u32) -> Result { + Self::Table::new(log_n) + } + + fn init_ntt_table(log_ns: &[u32]) -> Result<(), crate::AlgebraError> { + let _g = NTT_MUTEX.lock().unwrap(); + match unsafe { NTT_TABLE.get_mut() } { + Some(tables) => { + let new_log_ns: HashSet = log_ns.iter().copied().collect(); + let old_log_ns: HashSet = tables.keys().copied().collect(); + + let difference = new_log_ns.difference(&old_log_ns); + + for &log_n in difference { + let temp_table = Self::generate_ntt_table(log_n)?; + tables.insert(log_n, Arc::new(temp_table)); + } + Ok(()) + } + None => { + let log_ns: HashSet = log_ns.iter().copied().collect(); + let mut map = HashMap::with_capacity(log_ns.len()); + + for log_n in log_ns { + let temp_table = Self::generate_ntt_table(log_n)?; + map.insert(log_n, Arc::new(temp_table)); + } + + if unsafe { NTT_TABLE.set(map).is_err() } { + Err(crate::AlgebraError::NTTTableError) + } else { + Ok(()) + } + } + } + } + + fn get_ntt_table(log_n: u32) -> Result, crate::AlgebraError> { + if let Some(tables) = unsafe { NTT_TABLE.get() } { + if let Some(t) = tables.get(&log_n) { + return Ok(Arc::clone(t)); + } + } + + Self::init_ntt_table(&[log_n])?; + Ok(Arc::clone(unsafe { + NTT_TABLE.get().unwrap().get(&log_n).unwrap() + })) + } +} + +#[test] +fn ntt_test() { + use crate::{NTTPolynomial, Polynomial}; + let n = 1 << 10; + let mut rng = thread_rng(); + let poly = Polynomial::::random(n, &mut rng); + + let ntt_poly: NTTPolynomial = poly.clone().into(); + + let expect_poly: Polynomial = ntt_poly.into(); + assert_eq!(poly, expect_poly); +} From 83eb022101b0620e0f27059edfdcebc0ead58faa Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Wed, 28 Aug 2024 09:30:36 +0800 Subject: [PATCH 3/4] small opt --- algebra/src/goldilocks/goldilocks_ntt.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algebra/src/goldilocks/goldilocks_ntt.rs b/algebra/src/goldilocks/goldilocks_ntt.rs index dd912bcb..6f8de04e 100644 --- a/algebra/src/goldilocks/goldilocks_ntt.rs +++ b/algebra/src/goldilocks/goldilocks_ntt.rs @@ -17,7 +17,7 @@ impl From for Goldilocks { if value < modulus { Self(value as u64) } else { - Self((value % modulus) as u64) + Self((value - modulus) as u64) } } } From 43403febd9863e3248bc28d09f3fc9d11ffc0353 Mon Sep 17 00:00:00 2001 From: Xiang Xie Date: Wed, 28 Aug 2024 20:07:50 +0800 Subject: [PATCH 4/4] add features --- algebra/src/lib.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/algebra/src/lib.rs b/algebra/src/lib.rs index df1e24b3..4a507303 100644 --- a/algebra/src/lib.rs +++ b/algebra/src/lib.rs @@ -2,13 +2,16 @@ #![deny(missing_docs)] //! Define arithmetic operations. - +#[cfg(feature = "concrete-ntt")] mod baby_bear; + +#[cfg(feature = "concrete-ntt")] +mod goldilocks; + mod decompose_basis; mod error; mod extension; mod field; -mod goldilocks; mod polynomial; mod primitive; mod random; @@ -19,12 +22,16 @@ pub mod reduce; pub mod transformation; pub mod utils; +#[cfg(feature = "concrete-ntt")] pub use baby_bear::{BabyBear, BabyBearExetension}; + +#[cfg(feature = "concrete-ntt")] +pub use goldilocks::{Goldilocks, GoldilocksExtension}; + pub use decompose_basis::Basis; pub use error::AlgebraError; pub use extension::*; pub use field::{DecomposableField, FheField, Field, NTTField, PrimeField}; -pub use goldilocks::{Goldilocks, GoldilocksExtension}; pub use polynomial::multivariate::{ DenseMultilinearExtension, ListOfProductsOfPolynomials, MultilinearExtension, PolynomialInfo, SparsePolynomial,