-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ntt for babybear and goldilocks #144
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
use std::{ | ||
collections::{HashMap, HashSet}, | ||
sync::Arc, | ||
}; | ||
|
||
use num_traits::{pow, Zero}; | ||
use rand::{distributions, thread_rng}; | ||
|
||
use crate::{transformation::prime32::ConcreteTable, Field, NTTField}; | ||
|
||
use super::BabyBear; | ||
|
||
impl From<usize> for BabyBear { | ||
#[inline] | ||
fn from(value: usize) -> Self { | ||
Self::new(value as u32) | ||
} | ||
} | ||
|
||
static mut NTT_TABLE: once_cell::sync::OnceCell<HashMap<u32, Arc<<BabyBear as NTTField>::Table>>> = | ||
once_cell::sync::OnceCell::new(); | ||
|
||
static NTT_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); | ||
|
||
impl NTTField for BabyBear { | ||
type Table = ConcreteTable<Self>; | ||
|
||
type Root = Self; | ||
|
||
type Degree = u32; | ||
|
||
#[inline] | ||
fn from_root(root: Self::Root) -> Self { | ||
root | ||
} | ||
|
||
#[inline] | ||
fn to_root(self) -> Self::Root { | ||
self | ||
} | ||
|
||
#[inline] | ||
fn mul_root(self, root: Self::Root) -> Self { | ||
self * root | ||
} | ||
|
||
#[inline] | ||
fn mul_root_assign(&mut self, root: Self::Root) { | ||
*self *= root; | ||
} | ||
|
||
#[inline] | ||
fn is_primitive_root(root: Self, degree: Self::Degree) -> bool { | ||
debug_assert!( | ||
degree > 1 && degree.is_power_of_two(), | ||
"degree must be a power of two and bigger than 1" | ||
); | ||
|
||
if root == Self::zero() { | ||
return false; | ||
} | ||
|
||
pow(root, (degree >> 1) as usize) == Self::neg_one() | ||
} | ||
|
||
fn try_primitive_root(degree: Self::Degree) -> Result<Self, crate::AlgebraError> { | ||
let modulus_sub_one = BabyBear::MODULUS_VALUE - 1; | ||
let quotient = modulus_sub_one / degree; | ||
if modulus_sub_one != quotient * degree { | ||
return Err(crate::AlgebraError::NoPrimitiveRoot { | ||
degree: degree.to_string(), | ||
modulus: BabyBear::MODULUS_VALUE.to_string(), | ||
}); | ||
} | ||
|
||
let mut rng = thread_rng(); | ||
let distr = distributions::Uniform::new_inclusive(2, modulus_sub_one); | ||
|
||
let mut w = Self::zero(); | ||
|
||
if (0..100).any(|_| { | ||
w = pow( | ||
Self::new(rand::Rng::sample(&mut rng, distr)), | ||
quotient as usize, | ||
); | ||
Self::is_primitive_root(w, degree) | ||
}) { | ||
Ok(w) | ||
} else { | ||
Err(crate::AlgebraError::NoPrimitiveRoot { | ||
degree: degree.to_string(), | ||
modulus: BabyBear::MODULUS_VALUE.to_string(), | ||
}) | ||
} | ||
} | ||
Comment on lines
+66
to
+95
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is not necessary now. We just use |
||
|
||
fn try_minimal_primitive_root(degree: Self::Degree) -> Result<Self, crate::AlgebraError> { | ||
let mut root = Self::try_primitive_root(degree)?; | ||
|
||
let generator_sq = (root * root).to_root(); | ||
let mut current_generator = root; | ||
|
||
for _ in 0..degree { | ||
if current_generator < root { | ||
root = current_generator; | ||
} | ||
current_generator.mul_root_assign(generator_sq); | ||
} | ||
|
||
Ok(root) | ||
} | ||
Comment on lines
+97
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is not necessary now. We just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is used in zk, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, if zk use it, the wrong result will be returned. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not the right way to get the root for user. Concrete generate root with a different way. |
||
|
||
#[inline] | ||
fn generate_ntt_table(log_n: u32) -> Result<Self::Table, crate::AlgebraError> { | ||
Self::Table::new(log_n) | ||
} | ||
|
||
fn init_ntt_table(log_ns: &[u32]) -> Result<(), crate::AlgebraError> { | ||
let _g = NTT_MUTEX.lock().unwrap(); | ||
match unsafe { NTT_TABLE.get_mut() } { | ||
Some(tables) => { | ||
let new_log_ns: HashSet<u32> = log_ns.iter().copied().collect(); | ||
let old_log_ns: HashSet<u32> = tables.keys().copied().collect(); | ||
|
||
let difference = new_log_ns.difference(&old_log_ns); | ||
|
||
for &log_n in difference { | ||
let temp_table = Self::generate_ntt_table(log_n)?; | ||
tables.insert(log_n, Arc::new(temp_table)); | ||
} | ||
Ok(()) | ||
} | ||
None => { | ||
let log_ns: HashSet<u32> = log_ns.iter().copied().collect(); | ||
let mut map = HashMap::with_capacity(log_ns.len()); | ||
|
||
for log_n in log_ns { | ||
let temp_table = Self::generate_ntt_table(log_n)?; | ||
map.insert(log_n, Arc::new(temp_table)); | ||
} | ||
|
||
if unsafe { NTT_TABLE.set(map).is_err() } { | ||
Err(crate::AlgebraError::NTTTableError) | ||
} else { | ||
Ok(()) | ||
} | ||
} | ||
} | ||
} | ||
|
||
fn get_ntt_table(log_n: u32) -> Result<Arc<Self::Table>, crate::AlgebraError> { | ||
if let Some(tables) = unsafe { NTT_TABLE.get() } { | ||
if let Some(t) = tables.get(&log_n) { | ||
return Ok(Arc::clone(t)); | ||
} | ||
} | ||
|
||
Self::init_ntt_table(&[log_n])?; | ||
Ok(Arc::clone(unsafe { | ||
NTT_TABLE.get().unwrap().get(&log_n).unwrap() | ||
})) | ||
} | ||
} | ||
|
||
#[test] | ||
fn ntt_test() { | ||
use crate::{NTTPolynomial, Polynomial}; | ||
let n = 1 << 10; | ||
let mut rng = thread_rng(); | ||
let poly = Polynomial::<BabyBear>::random(n, &mut rng); | ||
|
||
let ntt_poly: NTTPolynomial<BabyBear> = poly.clone().into(); | ||
|
||
let expect_poly: Polynomial<BabyBear> = ntt_poly.into(); | ||
assert_eq!(poly, expect_poly); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
use std::{ | ||
collections::{HashMap, HashSet}, | ||
sync::Arc, | ||
}; | ||
|
||
use num_traits::{pow, Zero}; | ||
use rand::{distributions, thread_rng}; | ||
|
||
use crate::{transformation::prime64::ConcreteTable, Field, NTTField}; | ||
|
||
Comment on lines
+8
to
+10
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need a way to perform when |
||
use super::Goldilocks; | ||
|
||
impl From<usize> for Goldilocks { | ||
#[inline] | ||
fn from(value: usize) -> Self { | ||
let modulus = Goldilocks::MODULUS_VALUE as usize; | ||
if value < modulus { | ||
Self(value as u64) | ||
} else { | ||
Self((value - modulus) as u64) | ||
} | ||
} | ||
} | ||
|
||
static mut NTT_TABLE: once_cell::sync::OnceCell< | ||
HashMap<u32, Arc<<Goldilocks as NTTField>::Table>>, | ||
> = once_cell::sync::OnceCell::new(); | ||
|
||
static NTT_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); | ||
|
||
impl NTTField for Goldilocks { | ||
type Table = ConcreteTable<Self>; | ||
|
||
type Root = Self; | ||
|
||
type Degree = u64; | ||
|
||
#[inline] | ||
fn from_root(root: Self::Root) -> Self { | ||
root | ||
} | ||
|
||
#[inline] | ||
fn to_root(self) -> Self::Root { | ||
self | ||
} | ||
|
||
#[inline] | ||
fn mul_root(self, root: Self::Root) -> Self { | ||
self * root | ||
} | ||
|
||
#[inline] | ||
fn mul_root_assign(&mut self, root: Self::Root) { | ||
*self *= root | ||
} | ||
|
||
#[inline] | ||
fn is_primitive_root(root: Self, degree: Self::Degree) -> bool { | ||
debug_assert!( | ||
degree > 1 && degree.is_power_of_two(), | ||
"degree must be a power of two and bigger than 1" | ||
); | ||
|
||
if root == Self::zero() { | ||
return false; | ||
} | ||
|
||
pow(root, (degree >> 1) as usize) == Self::neg_one() | ||
} | ||
|
||
fn try_primitive_root(degree: Self::Degree) -> Result<Self, crate::AlgebraError> { | ||
let modulus_sub_one = Goldilocks::MODULUS_VALUE - 1; | ||
let quotient = modulus_sub_one / degree; | ||
if modulus_sub_one != quotient * degree { | ||
return Err(crate::AlgebraError::NoPrimitiveRoot { | ||
degree: degree.to_string(), | ||
modulus: Goldilocks::MODULUS_VALUE.to_string(), | ||
}); | ||
} | ||
|
||
let mut rng = thread_rng(); | ||
let distr = distributions::Uniform::new_inclusive(2, modulus_sub_one); | ||
|
||
let mut w = Self::zero(); | ||
|
||
if (0..100).any(|_| { | ||
w = pow( | ||
Self::new(rand::Rng::sample(&mut rng, distr)), | ||
quotient as usize, | ||
); | ||
Self::is_primitive_root(w, degree) | ||
}) { | ||
Ok(w) | ||
} else { | ||
Err(crate::AlgebraError::NoPrimitiveRoot { | ||
degree: degree.to_string(), | ||
modulus: Goldilocks::MODULUS_VALUE.to_string(), | ||
}) | ||
} | ||
} | ||
Comment on lines
+72
to
+101
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is not necessary now. We just use |
||
|
||
fn try_minimal_primitive_root(degree: Self::Degree) -> Result<Self, crate::AlgebraError> { | ||
let mut root = Self::try_primitive_root(degree)?; | ||
|
||
let generator_sq = (root * root).to_root(); | ||
let mut current_generator = root; | ||
|
||
for _ in 0..degree { | ||
if current_generator < root { | ||
root = current_generator; | ||
} | ||
current_generator.mul_root_assign(generator_sq); | ||
} | ||
|
||
Ok(root) | ||
} | ||
Comment on lines
+103
to
+117
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function is not necessary now. We just use |
||
|
||
#[inline] | ||
fn generate_ntt_table(log_n: u32) -> Result<Self::Table, crate::AlgebraError> { | ||
Self::Table::new(log_n) | ||
} | ||
|
||
fn init_ntt_table(log_ns: &[u32]) -> Result<(), crate::AlgebraError> { | ||
let _g = NTT_MUTEX.lock().unwrap(); | ||
match unsafe { NTT_TABLE.get_mut() } { | ||
Some(tables) => { | ||
let new_log_ns: HashSet<u32> = log_ns.iter().copied().collect(); | ||
let old_log_ns: HashSet<u32> = tables.keys().copied().collect(); | ||
|
||
let difference = new_log_ns.difference(&old_log_ns); | ||
|
||
for &log_n in difference { | ||
let temp_table = Self::generate_ntt_table(log_n)?; | ||
tables.insert(log_n, Arc::new(temp_table)); | ||
} | ||
Ok(()) | ||
} | ||
None => { | ||
let log_ns: HashSet<u32> = log_ns.iter().copied().collect(); | ||
let mut map = HashMap::with_capacity(log_ns.len()); | ||
|
||
for log_n in log_ns { | ||
let temp_table = Self::generate_ntt_table(log_n)?; | ||
map.insert(log_n, Arc::new(temp_table)); | ||
} | ||
|
||
if unsafe { NTT_TABLE.set(map).is_err() } { | ||
Err(crate::AlgebraError::NTTTableError) | ||
} else { | ||
Ok(()) | ||
} | ||
} | ||
} | ||
} | ||
|
||
fn get_ntt_table(log_n: u32) -> Result<Arc<Self::Table>, crate::AlgebraError> { | ||
if let Some(tables) = unsafe { NTT_TABLE.get() } { | ||
if let Some(t) = tables.get(&log_n) { | ||
return Ok(Arc::clone(t)); | ||
} | ||
} | ||
|
||
Self::init_ntt_table(&[log_n])?; | ||
Ok(Arc::clone(unsafe { | ||
NTT_TABLE.get().unwrap().get(&log_n).unwrap() | ||
})) | ||
} | ||
} | ||
|
||
#[test] | ||
fn ntt_test() { | ||
use crate::{NTTPolynomial, Polynomial}; | ||
let n = 1 << 10; | ||
let mut rng = thread_rng(); | ||
let poly = Polynomial::<Goldilocks>::random(n, &mut rng); | ||
|
||
let ntt_poly: NTTPolynomial<Goldilocks> = poly.clone().into(); | ||
|
||
let expect_poly: Polynomial<Goldilocks> = ntt_poly.into(); | ||
assert_eq!(poly, expect_poly); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
mod extension; | ||
mod goldilocks_ntt; | ||
|
||
pub use extension::GoldilocksExtension; | ||
use serde::{Deserialize, Serialize}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use
ConcreteTable
make this implementation is not correctly when featureconcrete-ntt
is disabled.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do not have a non-concrete version yet. That is why I do not use a feature here, which means we have to always use concrete for these two fields.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, we have.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can disable the default feature to disable
concrete-ntt
.