Skip to content

Commit

Permalink
update from omr branch (#159)
Browse files Browse the repository at this point in the history
* Simplify fhe api (#158)

* remove modulus from some api

* improve polynomial

* update

* update

* update

* add back `concrete-ntt`

* update

* update

* update dependency

* simplify lwe public key generation

* remove some redundant codes  and file

* fix `test_lwe_pk`

* improve ntt

* improve ntt
  • Loading branch information
serendipity-crypto authored Feb 17, 2025
1 parent 8c43803 commit 13583cc
Show file tree
Hide file tree
Showing 64 changed files with 1,866 additions and 753 deletions.
6 changes: 4 additions & 2 deletions algebra/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ rand = { workspace = true }
rand_distr = { workspace = true }
itertools = { workspace = true }
bytemuck = { workspace = true }
concrete-ntt = { git = "https://github.com/primus-labs/concrete-ntt", branch = "dev", default-features = false, optional = true }

[dev-dependencies]
criterion = { workspace = true }

[features]
default = []
nightly = []
default = ["concrete-ntt"]
concrete-ntt = ["dep:concrete-ntt", "concrete-ntt/std"]
nightly = ["concrete-ntt?/nightly"]

[[bench]]
name = "gcd_bench"
Expand Down
57 changes: 43 additions & 14 deletions algebra/benches/field_ntt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use algebra::ntt::{NttTable, NumberTheoryTransform};
use algebra::{ntt::FieldTableWithShoupRoot, polynomial::FieldPolynomial};
use algebra::{Field, U32FieldEval};
use algebra::ntt::NumberTheoryTransform;
use algebra::polynomial::FieldPolynomial;
use algebra::{Field, NttField, U32FieldEval, U64FieldEval};
use criterion::{criterion_group, criterion_main, Criterion};
use rand::{distributions::Uniform, prelude::*};

Expand All @@ -9,36 +9,65 @@ type ValueT = u32;
const LOG_N: u32 = 10;
const N: usize = 1 << LOG_N;

type Fp = U32FieldEval<132120577>;
type F32 = U32FieldEval<132120577>;
type F64 = U64FieldEval<1125899906826241>;

pub fn criterion_benchmark(c: &mut Criterion) {
let table = <FieldTableWithShoupRoot<Fp>>::new(Fp::MODULUS, LOG_N).unwrap();

let mut rng = thread_rng();

let distr = Uniform::new_inclusive(0, Fp::MINUS_ONE);
let table32 = F32::generate_ntt_table(LOG_N).unwrap();

let distr = Uniform::new_inclusive(0, F32::MINUS_ONE);

let poly: Vec<ValueT> = distr.sample_iter(&mut rng).take(N).collect();
let mut poly = <FieldPolynomial<Fp>>::new(poly);
let mut poly = <FieldPolynomial<F32>>::new(poly);

let degree: usize = rng.gen_range(0..N);
let coeff = distr.sample(&mut rng);

c.bench_function(&format!("field ntt {}", N), |b| {
c.bench_function(&format!("field 32 ntt {}", N), |b| {
b.iter(|| {
table32.transform_slice(poly.as_mut_slice());
})
});

c.bench_function(&format!("field 32 intt {}", N), |b| {
b.iter(|| {
table32.inverse_transform_slice(poly.as_mut_slice());
})
});

c.bench_function(&format!("field 32 monomial ntt {}", N), |b| {
b.iter(|| {
table32.transform_monomial(coeff, degree, poly.as_mut_slice());
})
});

let table64 = F64::generate_ntt_table(LOG_N + 1).unwrap();

let distr = Uniform::new_inclusive(0, F64::MINUS_ONE);

let poly: Vec<_> = distr.sample_iter(&mut rng).take(N << 1).collect();
let mut poly = <FieldPolynomial<F64>>::new(poly);

let degree: usize = rng.gen_range(0..(N << 1));
let coeff = distr.sample(&mut rng);

c.bench_function(&format!("field 64 ntt {}", N << 1), |b| {
b.iter(|| {
table.transform_slice(poly.as_mut_slice());
table64.transform_slice(poly.as_mut_slice());
})
});

c.bench_function(&format!("field intt {}", N), |b| {
c.bench_function(&format!("field 64 intt {}", N << 1), |b| {
b.iter(|| {
table.inverse_transform_slice(poly.as_mut_slice());
table64.inverse_transform_slice(poly.as_mut_slice());
})
});

c.bench_function(&format!("field monomial ntt {}", N), |b| {
c.bench_function(&format!("field 64 monomial ntt {}", N << 1), |b| {
b.iter(|| {
table.transform_monomial(coeff, degree, poly.as_mut_slice());
table64.transform_monomial(coeff, degree, poly.as_mut_slice());
})
});
}
Expand Down
6 changes: 3 additions & 3 deletions algebra/benches/ntt_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use algebra::ntt::{NttTable, NumberTheoryTransform, TableWithShoupRoot};
use criterion::{criterion_group, criterion_main, Criterion};
use rand::{distributions::Uniform, prelude::*};

type ValueT = u32;
type ValueT = u64;

const LOG_N: u32 = 10;
const LOG_N: u32 = 11;
const N: usize = 1 << LOG_N;
const MODULUS: ValueT = 132120577;
const MODULUS: ValueT = 1125899906826241;

pub fn criterion_benchmark(c: &mut Criterion) {
let modulus = <BarrettModulus<ValueT>>::new(MODULUS);
Expand Down
155 changes: 155 additions & 0 deletions algebra/src/field/impls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
pub mod f32 {
use crate::reduce::*;

#[doc = r" This define a field based the barrett reduction."]
#[derive(Clone, Copy)]
pub struct U32FieldEval<const P: u32>;

impl<const P: u32> crate::Field for U32FieldEval<P> {
type ValueT = u32;
type Modulus = crate::modulus::BarrettModulus<u32>;
const MODULUS_VALUE: Self::ValueT = P;
const MODULUS: Self::Modulus = Self::Modulus::new(P);
const ZERO: Self::ValueT = 0;
const ONE: Self::ValueT = 1;
const MINUS_ONE: Self::ValueT = P - 1;
#[doc = r" Calculates `a + b`."]
#[inline]
fn add(a: Self::ValueT, b: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_add(a, b)
}
#[doc = r" Calculates `a += b`."]
#[inline]
fn add_assign(a: &mut Self::ValueT, b: Self::ValueT) {
Self::MODULUS_VALUE.reduce_add_assign(a, b);
}
#[doc = r" Calculates `2*value`."]
#[inline]
fn double(value: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_double(value)
}
#[doc = r" Calculates `value = 2*value`."]
#[inline]
fn double_assign(value: &mut Self::ValueT) {
Self::MODULUS_VALUE.reduce_double_assign(value);
}
#[doc = r" Calculates `a - b`."]
#[inline]
fn sub(a: Self::ValueT, b: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_sub(a, b)
}
#[doc = r" Calculates `a -= b`."]
#[inline]
fn sub_assign(a: &mut Self::ValueT, b: Self::ValueT) {
Self::MODULUS_VALUE.reduce_sub_assign(a, b);
}
#[doc = r" Calculates `-value`."]
#[inline]
fn neg(value: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_neg(value)
}
#[doc = r" Calculates `-value`."]
#[inline]
fn neg_assign(value: &mut Self::ValueT) {
Self::MODULUS_VALUE.reduce_neg_assign(value);
}
#[doc = r" Calculate the multiplicative inverse of `value`."]
#[inline]
fn inv(value: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_inv(value)
}
#[doc = r" Calculates `value^(-1)`."]
#[inline]
fn inv_assign(value: &mut Self::ValueT) {
Self::MODULUS_VALUE.reduce_inv_assign(value);
}
}
impl<const P: u32> crate::NttField for U32FieldEval<P> {
#[cfg(not(feature = "concrete-ntt"))]
type Table = crate::ntt::FieldTableWithShoupRoot<Self>;
#[cfg(feature = "concrete-ntt")]
type Table = crate::ntt::Concrete32Table<Self>;
#[inline]
fn generate_ntt_table(log_n: u32) -> Result<Self::Table, crate::AlgebraError> {
crate::ntt::NttTable::new(<Self as crate::Field>::MODULUS, log_n)
}
}
}

pub mod f64 {
use crate::reduce::*;

#[doc = r" This define a field based the barrett reduction."]
#[derive(Clone, Copy)]
pub struct U64FieldEval<const P: u64>;

impl<const P: u64> crate::Field for U64FieldEval<P> {
type ValueT = u64;
type Modulus = crate::modulus::BarrettModulus<u64>;
const MODULUS_VALUE: Self::ValueT = P;
const MODULUS: Self::Modulus = Self::Modulus::new(P);
const ZERO: Self::ValueT = 0;
const ONE: Self::ValueT = 1;
const MINUS_ONE: Self::ValueT = P - 1;
#[doc = r" Calculates `a + b`."]
#[inline]
fn add(a: Self::ValueT, b: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_add(a, b)
}
#[doc = r" Calculates `a += b`."]
#[inline]
fn add_assign(a: &mut Self::ValueT, b: Self::ValueT) {
Self::MODULUS_VALUE.reduce_add_assign(a, b);
}
#[doc = r" Calculates `2*value`."]
#[inline]
fn double(value: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_double(value)
}
#[doc = r" Calculates `value = 2*value`."]
#[inline]
fn double_assign(value: &mut Self::ValueT) {
Self::MODULUS_VALUE.reduce_double_assign(value);
}
#[doc = r" Calculates `a - b`."]
#[inline]
fn sub(a: Self::ValueT, b: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_sub(a, b)
}
#[doc = r" Calculates `a -= b`."]
#[inline]
fn sub_assign(a: &mut Self::ValueT, b: Self::ValueT) {
Self::MODULUS_VALUE.reduce_sub_assign(a, b);
}
#[doc = r" Calculates `-value`."]
#[inline]
fn neg(value: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_neg(value)
}
#[doc = r" Calculates `-value`."]
#[inline]
fn neg_assign(value: &mut Self::ValueT) {
Self::MODULUS_VALUE.reduce_neg_assign(value);
}
#[doc = r" Calculate the multiplicative inverse of `value`."]
#[inline]
fn inv(value: Self::ValueT) -> Self::ValueT {
Self::MODULUS_VALUE.reduce_inv(value)
}
#[doc = r" Calculates `value^(-1)`."]
#[inline]
fn inv_assign(value: &mut Self::ValueT) {
Self::MODULUS_VALUE.reduce_inv_assign(value);
}
}
impl<const P: u64> crate::NttField for U64FieldEval<P> {
#[cfg(not(feature = "concrete-ntt"))]
type Table = crate::ntt::FieldTableWithShoupRoot<Self>;
#[cfg(feature = "concrete-ntt")]
type Table = crate::ntt::Concrete64Table<Self>;
#[inline]
fn generate_ntt_table(log_n: u32) -> Result<Self::Table, crate::AlgebraError> {
crate::ntt::NttTable::new(<Self as crate::Field>::MODULUS, log_n)
}
}
}
5 changes: 3 additions & 2 deletions algebra/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ use crate::reduce::*;

#[macro_use]
mod macros;
mod impls;
mod ntt;

pub use impls::f32::U32FieldEval;
pub use impls::f64::U64FieldEval;
pub use ntt::NttField;

/// An abstract for field evaluator.
Expand Down Expand Up @@ -152,5 +155,3 @@ pub trait Field: Sized + Clone + Copy {

impl_barrett_field!(#[derive(Clone, Copy)] impl pub U8FieldEval<u8>);
impl_barrett_field!(#[derive(Clone, Copy)] impl pub U16FieldEval<u16>);
impl_barrett_field!(#[derive(Clone, Copy)] impl pub U32FieldEval<u32>);
impl_barrett_field!(#[derive(Clone, Copy)] impl pub U64FieldEval<u64>);
81 changes: 79 additions & 2 deletions algebra/src/modulus/barrett/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::fmt::Display;

use crate::{numeric::Numeric, reduce::Modulus};
use crate::{
integer::{AsFrom, AsInto},
numeric::Numeric,
reduce::{Modulus, ModulusValue},
};

#[macro_use]
mod macros;
Expand Down Expand Up @@ -29,6 +33,22 @@ impl<T: Numeric> Display for BarrettModulus<T> {
}

impl<T: Numeric> BarrettModulus<T> {
/// Creates a new [`BarrettModulus<T>`] with the given value.
pub fn new_generic(value: T) -> Self {
if value <= T::ONE {
panic!("modulus can't be 0 or 1.")
}
let bit_count = T::BITS - value.leading_zeros();
assert!(bit_count < T::BITS - 1);

let (numerator, _) = div_inplace(value);

Self {
value,
ratio: numerator,
}
}

/// Returns the value of this [`BarrettModulus<T>`].
#[inline]
pub const fn value(&self) -> T {
Expand All @@ -44,7 +64,22 @@ impl<T: Numeric> BarrettModulus<T> {

impl<T: Numeric> Modulus<T> for BarrettModulus<T> {
#[inline]
fn modulus_minus_one(self) -> T {
fn from_value(value: ModulusValue<T>) -> Self {
match value {
ModulusValue::Native => panic!("Not match for native"),
ModulusValue::PowerOf2(value)
| ModulusValue::Prime(value)
| ModulusValue::Others(value) => Self::new_generic(value),
}
}

#[inline]
fn modulus_value(&self) -> ModulusValue<T> {
ModulusValue::Others(self.value)
}

#[inline]
fn modulus_minus_one(&self) -> T {
self.value - T::ONE
}
}
Expand All @@ -53,3 +88,45 @@ impl_barrett_modulus!(impl BarrettModulus<u8>; WideType: u16);
impl_barrett_modulus!(impl BarrettModulus<u16>; WideType: u32);
impl_barrett_modulus!(impl BarrettModulus<u32>; WideType: u64);
impl_barrett_modulus!(impl BarrettModulus<u64>; WideType: u128);

#[inline]
fn div_rem<T: Numeric>(numerator: T, divisor: T) -> (T, T) {
(numerator / divisor, numerator % divisor)
}

#[inline]
fn div_wide<T: Numeric>(hi: T, divisor: T) -> (T, T) {
let lhs = T::WideT::as_from(hi) << <T>::BITS;
let rhs = T::WideT::as_from(divisor);
((lhs / rhs).as_into(), (lhs % rhs).as_into())
}

#[inline]
fn div_half<T: Numeric>(rem: T, divisor: T) -> (T, T) {
let half_bits: u32 = T::BITS >> 1;
let (hi, rem) = div_rem(rem << half_bits, divisor);
let (lo, rem) = div_rem(rem << half_bits, divisor);
((hi << half_bits) | lo, rem)
}

fn div_inplace<T: Numeric>(value: T) -> ([T; 2], T) {
let mut numerator = [T::ZERO, T::ZERO];
let rem;

if value <= (T::MAX >> (T::BITS >> 1)) {
let (q, r) = div_half(T::ONE, value);
numerator[1] = q;

let (q, r) = div_half(r, value);
numerator[0] = q;
rem = r;
} else {
let (q, r) = div_wide(T::ONE, value);
numerator[1] = q;

let (q, r) = div_wide(r, value);
numerator[0] = q;
rem = r;
}
(numerator, rem)
}
Loading

0 comments on commit 13583cc

Please sign in to comment.