Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ntt for babybear and goldilocks #144

Merged
merged 4 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions algebra/src/baby_bear/babybear_ntt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};

use num_traits::{pow, Zero};
use rand::{distributions, thread_rng};

use crate::{transformation::prime32::ConcreteTable, Field, NTTField};

Comment on lines +8 to +10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ConcreteTable make this implementation is not correctly when feature concrete-ntt is disabled.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not have a non-concrete version yet. That is why I do not use a feature here, which means we have to always use concrete for these two fields.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we have.
image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can disable the default feature to disable concrete-ntt.

use super::BabyBear;

impl From<usize> for BabyBear {
#[inline]
fn from(value: usize) -> Self {
Self::new(value as u32)
}
}

static mut NTT_TABLE: once_cell::sync::OnceCell<HashMap<u32, Arc<<BabyBear as NTTField>::Table>>> =
once_cell::sync::OnceCell::new();

static NTT_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

impl NTTField for BabyBear {
type Table = ConcreteTable<Self>;

type Root = Self;

type Degree = u32;

#[inline]
fn from_root(root: Self::Root) -> Self {
root
}

#[inline]
fn to_root(self) -> Self::Root {
self
}

#[inline]
fn mul_root(self, root: Self::Root) -> Self {
self * root
}

#[inline]
fn mul_root_assign(&mut self, root: Self::Root) {
*self *= root;
}

#[inline]
fn is_primitive_root(root: Self, degree: Self::Degree) -> bool {
debug_assert!(
degree > 1 && degree.is_power_of_two(),
"degree must be a power of two and bigger than 1"
);

if root == Self::zero() {
return false;
}

pow(root, (degree >> 1) as usize) == Self::neg_one()
}

fn try_primitive_root(degree: Self::Degree) -> Result<Self, crate::AlgebraError> {
let modulus_sub_one = BabyBear::MODULUS_VALUE - 1;
let quotient = modulus_sub_one / degree;
if modulus_sub_one != quotient * degree {
return Err(crate::AlgebraError::NoPrimitiveRoot {
degree: degree.to_string(),
modulus: BabyBear::MODULUS_VALUE.to_string(),
});
}

let mut rng = thread_rng();
let distr = distributions::Uniform::new_inclusive(2, modulus_sub_one);

let mut w = Self::zero();

if (0..100).any(|_| {
w = pow(
Self::new(rand::Rng::sample(&mut rng, distr)),
quotient as usize,
);
Self::is_primitive_root(w, degree)
}) {
Ok(w)
} else {
Err(crate::AlgebraError::NoPrimitiveRoot {
degree: degree.to_string(),
modulus: BabyBear::MODULUS_VALUE.to_string(),
})
}
}
Comment on lines +66 to +95
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is not necessary now. We just use unreachable!() in it.


fn try_minimal_primitive_root(degree: Self::Degree) -> Result<Self, crate::AlgebraError> {
let mut root = Self::try_primitive_root(degree)?;

let generator_sq = (root * root).to_root();
let mut current_generator = root;

for _ in 0..degree {
if current_generator < root {
root = current_generator;
}
current_generator.mul_root_assign(generator_sq);
}

Ok(root)
}
Comment on lines +97 to +111
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is not necessary now. We just use unreachable!() in it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is used in zk, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, if zk use it, the wrong result will be returned.

Copy link
Collaborator

@serendipity-crypto serendipity-crypto Aug 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not the right way to get the root for user. Concrete generate root with a different way.


#[inline]
fn generate_ntt_table(log_n: u32) -> Result<Self::Table, crate::AlgebraError> {
Self::Table::new(log_n)
}

fn init_ntt_table(log_ns: &[u32]) -> Result<(), crate::AlgebraError> {
let _g = NTT_MUTEX.lock().unwrap();
match unsafe { NTT_TABLE.get_mut() } {
Some(tables) => {
let new_log_ns: HashSet<u32> = log_ns.iter().copied().collect();
let old_log_ns: HashSet<u32> = tables.keys().copied().collect();

let difference = new_log_ns.difference(&old_log_ns);

for &log_n in difference {
let temp_table = Self::generate_ntt_table(log_n)?;
tables.insert(log_n, Arc::new(temp_table));
}
Ok(())
}
None => {
let log_ns: HashSet<u32> = log_ns.iter().copied().collect();
let mut map = HashMap::with_capacity(log_ns.len());

for log_n in log_ns {
let temp_table = Self::generate_ntt_table(log_n)?;
map.insert(log_n, Arc::new(temp_table));
}

if unsafe { NTT_TABLE.set(map).is_err() } {
Err(crate::AlgebraError::NTTTableError)
} else {
Ok(())
}
}
}
}

fn get_ntt_table(log_n: u32) -> Result<Arc<Self::Table>, crate::AlgebraError> {
if let Some(tables) = unsafe { NTT_TABLE.get() } {
if let Some(t) = tables.get(&log_n) {
return Ok(Arc::clone(t));
}
}

Self::init_ntt_table(&[log_n])?;
Ok(Arc::clone(unsafe {
NTT_TABLE.get().unwrap().get(&log_n).unwrap()
}))
}
}

#[test]
fn ntt_test() {
use crate::{NTTPolynomial, Polynomial};
let n = 1 << 10;
let mut rng = thread_rng();
let poly = Polynomial::<BabyBear>::random(n, &mut rng);

let ntt_poly: NTTPolynomial<BabyBear> = poly.clone().into();

let expect_poly: Polynomial<BabyBear> = ntt_poly.into();
assert_eq!(poly, expect_poly);
}
2 changes: 2 additions & 0 deletions algebra/src/baby_bear/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod babybear_ntt;
mod extension;

pub use extension::BabyBearExetension;

use serde::{Deserialize, Serialize};

use std::{
Expand Down
182 changes: 182 additions & 0 deletions algebra/src/goldilocks/goldilocks_ntt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};

use num_traits::{pow, Zero};
use rand::{distributions, thread_rng};

use crate::{transformation::prime64::ConcreteTable, Field, NTTField};

Comment on lines +8 to +10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a way to perform when concrete-ntt is not enable

use super::Goldilocks;

impl From<usize> for Goldilocks {
#[inline]
fn from(value: usize) -> Self {
let modulus = Goldilocks::MODULUS_VALUE as usize;
if value < modulus {
Self(value as u64)
} else {
Self((value - modulus) as u64)
}
}
}

static mut NTT_TABLE: once_cell::sync::OnceCell<
HashMap<u32, Arc<<Goldilocks as NTTField>::Table>>,
> = once_cell::sync::OnceCell::new();

static NTT_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());

impl NTTField for Goldilocks {
type Table = ConcreteTable<Self>;

type Root = Self;

type Degree = u64;

#[inline]
fn from_root(root: Self::Root) -> Self {
root
}

#[inline]
fn to_root(self) -> Self::Root {
self
}

#[inline]
fn mul_root(self, root: Self::Root) -> Self {
self * root
}

#[inline]
fn mul_root_assign(&mut self, root: Self::Root) {
*self *= root
}

#[inline]
fn is_primitive_root(root: Self, degree: Self::Degree) -> bool {
debug_assert!(
degree > 1 && degree.is_power_of_two(),
"degree must be a power of two and bigger than 1"
);

if root == Self::zero() {
return false;
}

pow(root, (degree >> 1) as usize) == Self::neg_one()
}

fn try_primitive_root(degree: Self::Degree) -> Result<Self, crate::AlgebraError> {
let modulus_sub_one = Goldilocks::MODULUS_VALUE - 1;
let quotient = modulus_sub_one / degree;
if modulus_sub_one != quotient * degree {
return Err(crate::AlgebraError::NoPrimitiveRoot {
degree: degree.to_string(),
modulus: Goldilocks::MODULUS_VALUE.to_string(),
});
}

let mut rng = thread_rng();
let distr = distributions::Uniform::new_inclusive(2, modulus_sub_one);

let mut w = Self::zero();

if (0..100).any(|_| {
w = pow(
Self::new(rand::Rng::sample(&mut rng, distr)),
quotient as usize,
);
Self::is_primitive_root(w, degree)
}) {
Ok(w)
} else {
Err(crate::AlgebraError::NoPrimitiveRoot {
degree: degree.to_string(),
modulus: Goldilocks::MODULUS_VALUE.to_string(),
})
}
}
Comment on lines +72 to +101
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is not necessary now. We just use unreachable!() in it.


fn try_minimal_primitive_root(degree: Self::Degree) -> Result<Self, crate::AlgebraError> {
let mut root = Self::try_primitive_root(degree)?;

let generator_sq = (root * root).to_root();
let mut current_generator = root;

for _ in 0..degree {
if current_generator < root {
root = current_generator;
}
current_generator.mul_root_assign(generator_sq);
}

Ok(root)
}
Comment on lines +103 to +117
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is not necessary now. We just use unreachable!() in it.


#[inline]
fn generate_ntt_table(log_n: u32) -> Result<Self::Table, crate::AlgebraError> {
Self::Table::new(log_n)
}

fn init_ntt_table(log_ns: &[u32]) -> Result<(), crate::AlgebraError> {
let _g = NTT_MUTEX.lock().unwrap();
match unsafe { NTT_TABLE.get_mut() } {
Some(tables) => {
let new_log_ns: HashSet<u32> = log_ns.iter().copied().collect();
let old_log_ns: HashSet<u32> = tables.keys().copied().collect();

let difference = new_log_ns.difference(&old_log_ns);

for &log_n in difference {
let temp_table = Self::generate_ntt_table(log_n)?;
tables.insert(log_n, Arc::new(temp_table));
}
Ok(())
}
None => {
let log_ns: HashSet<u32> = log_ns.iter().copied().collect();
let mut map = HashMap::with_capacity(log_ns.len());

for log_n in log_ns {
let temp_table = Self::generate_ntt_table(log_n)?;
map.insert(log_n, Arc::new(temp_table));
}

if unsafe { NTT_TABLE.set(map).is_err() } {
Err(crate::AlgebraError::NTTTableError)
} else {
Ok(())
}
}
}
}

fn get_ntt_table(log_n: u32) -> Result<Arc<Self::Table>, crate::AlgebraError> {
if let Some(tables) = unsafe { NTT_TABLE.get() } {
if let Some(t) = tables.get(&log_n) {
return Ok(Arc::clone(t));
}
}

Self::init_ntt_table(&[log_n])?;
Ok(Arc::clone(unsafe {
NTT_TABLE.get().unwrap().get(&log_n).unwrap()
}))
}
}

#[test]
fn ntt_test() {
use crate::{NTTPolynomial, Polynomial};
let n = 1 << 10;
let mut rng = thread_rng();
let poly = Polynomial::<Goldilocks>::random(n, &mut rng);

let ntt_poly: NTTPolynomial<Goldilocks> = poly.clone().into();

let expect_poly: Polynomial<Goldilocks> = ntt_poly.into();
assert_eq!(poly, expect_poly);
}
1 change: 1 addition & 0 deletions algebra/src/goldilocks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod extension;
mod goldilocks_ntt;

pub use extension::GoldilocksExtension;
use serde::{Deserialize, Serialize};
Expand Down
Loading