diff --git a/algebra/src/field/mod.rs b/algebra/src/field/mod.rs index 94124bd0..4b93293f 100644 --- a/algebra/src/field/mod.rs +++ b/algebra/src/field/mod.rs @@ -5,6 +5,7 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi use num_traits::{Inv, One, Pow, PrimInt, Zero}; use rand::{CryptoRng, Rng}; +use serde::Serialize; use crate::random::UniformBase; use crate::{AsFrom, AsInto, Basis, Widening, WrappingOps}; @@ -67,6 +68,7 @@ pub trait Field: + Neg + Inv + Pow + + Serialize { /// The inner type of this field. type Value: Debug @@ -78,7 +80,8 @@ pub trait Field: + Into + AsInto + AsFrom - + UniformBase; + + UniformBase + + Serialize; /// The type of the field's order. type Order: Copy; diff --git a/algebra/src/lib.rs b/algebra/src/lib.rs index 5b24db59..a3bcd78b 100644 --- a/algebra/src/lib.rs +++ b/algebra/src/lib.rs @@ -26,8 +26,8 @@ pub use extension::*; pub use field::{DecomposableField, FheField, Field, NTTField, PrimeField}; pub use goldilocks::{Goldilocks, GoldilocksExtension}; pub use polynomial::multivariate::{ - DenseMultilinearExtension, ListOfProductsOfPolynomials, MultilinearExtension, PolynomialInfo, - SparsePolynomial, + DenseMultilinearExtension, DenseMultilinearExtensionBase, ListOfProductsOfPolynomials, + MultilinearExtension, MultilinearExtensionBase, PolynomialInfo, SparsePolynomial, UF, }; pub use polynomial::univariate::{ ntt_add_mul_assign, ntt_add_mul_assign_fast, ntt_add_mul_inplace, ntt_mul_assign, diff --git a/algebra/src/polynomial/multivariate/data_structures.rs b/algebra/src/polynomial/multivariate/data_structures.rs index 76f373ed..b5df51d3 100644 --- a/algebra/src/polynomial/multivariate/data_structures.rs +++ b/algebra/src/polynomial/multivariate/data_structures.rs @@ -1,10 +1,12 @@ // It is derived from https://github.com/arkworks-rs/sumcheck/blob/master/src/ml_sumcheck/data_structures.rs . -use std::{collections::HashMap, rc::Rc}; +use std::{collections::HashMap, marker::PhantomData, rc::Rc}; -use crate::Field; +use serde::Serialize; -use super::{DenseMultilinearExtension, MultilinearExtension}; +use crate::{AbstractExtensionField, Field}; + +use super::{multilinear::UF, DenseMultilinearExtension, MultilinearExtension}; /// Stores a list of products of `DenseMultilinearExtension` that is meant to be added together. /// @@ -21,23 +23,25 @@ use super::{DenseMultilinearExtension, MultilinearExtension}; /// /// The resulting polynomial is used as the prover key. #[derive(Clone)] -pub struct ListOfProductsOfPolynomials { +pub struct ListOfProductsOfPolynomials> { /// max number of multiplicands in each product pub max_multiplicands: usize, /// number of variables of the polynomial pub num_variables: usize, /// list of reference to products (as usize) of multilinear extension - pub products: Vec<(F, Vec)>, + pub products: Vec<(EF, Vec)>, /// Stores the linear operations, each of which is successively (in the same order) perfomed over the each MLE of each product stored in the above `products` /// so each (a: F, b: F) can used to wrap a linear operation over the original MLE f, i.e. a \cdot f + b - pub linear_ops: Vec>, + #[allow(clippy::type_complexity)] + pub linear_ops: Vec, UF)>>, /// Stores multilinear extensions in which product multiplicand can refer to. - pub flattened_ml_extensions: Vec>>, - raw_pointers_lookup_table: HashMap<*const DenseMultilinearExtension, usize>, + pub flattened_ml_extensions: Vec>>, + raw_pointers_lookup_table: HashMap<*const DenseMultilinearExtension, usize>, + _marker: PhantomData, } /// Extract the max number of multiplicands and number of variables of the list of products. -impl ListOfProductsOfPolynomials { +impl> ListOfProductsOfPolynomials { /// Extract the max number of multiplicands and number of variables of the list of products. #[inline] pub fn info(&self) -> PolynomialInfo { @@ -48,7 +52,7 @@ impl ListOfProductsOfPolynomials { } } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Serialize)] /// Stores the number of variables and max number of multiplicands of the added polynomial used by the prover. /// This data structures will be used as the verifier key. pub struct PolynomialInfo { @@ -58,7 +62,7 @@ pub struct PolynomialInfo { pub num_variables: usize, } -impl ListOfProductsOfPolynomials { +impl> ListOfProductsOfPolynomials { /// Returns an empty polynomial #[inline] pub fn new(num_variables: usize) -> Self { @@ -69,6 +73,7 @@ impl ListOfProductsOfPolynomials { linear_ops: Vec::new(), flattened_ml_extensions: Vec::new(), raw_pointers_lookup_table: HashMap::new(), + _marker: PhantomData, } } @@ -76,16 +81,16 @@ impl ListOfProductsOfPolynomials { /// Here we wrap a linear operation on the same MLE so that we can add the /// product like f(x) \cdot (2f(x) + 3) \cdot (4f(x) + 4) with only one Rc. /// The resulting polynomial will be multiplied by the scalar `coefficient`. + #[allow(clippy::type_complexity)] pub fn add_product_with_linear_op( &mut self, - product: impl IntoIterator>>, - op_coefficient: &[(F, F)], - coefficient: F, + product: impl IntoIterator>>, + op_coefficient: &[(UF, UF)], + coefficient: EF, ) { - let product: Vec>> = product.into_iter().collect(); + let product: Vec>> = product.into_iter().collect(); self.max_multiplicands = self.max_multiplicands.max(product.len()); assert_eq!(product.len(), op_coefficient.len()); - assert_eq!(product.len(), op_coefficient.len()); assert!(!product.is_empty()); let mut indexed_product: Vec = Vec::with_capacity(op_coefficient.len()); let mut linear_ops = Vec::with_capacity(op_coefficient.len()); @@ -96,7 +101,7 @@ impl ListOfProductsOfPolynomials { m.num_vars, self.num_variables, "product has a multiplicand with wrong number of variables" ); - let m_ptr: *const DenseMultilinearExtension = Rc::as_ptr(m); + let m_ptr: *const DenseMultilinearExtension = Rc::as_ptr(m); if let Some(index) = self.raw_pointers_lookup_table.get(&m_ptr) { indexed_product.push(*index); linear_ops.push((*a, *b)); @@ -118,27 +123,27 @@ impl ListOfProductsOfPolynomials { #[inline] pub fn add_product( &mut self, - product: impl IntoIterator>>, - coefficient: F, + product: impl IntoIterator>>, + coefficient: EF, ) { - let product: Vec>> = product.into_iter().collect(); - let mut linear_ops: Vec<(F, F)> = Vec::with_capacity(product.len()); + let product: Vec>> = product.into_iter().collect(); + let mut linear_ops: Vec<(UF, UF)> = Vec::with_capacity(product.len()); for _ in 0..product.len() { - linear_ops.push((F::one(), F::zero())); + linear_ops.push((UF::BaseField(F::one()), UF::BaseField(F::zero()))); } self.add_product_with_linear_op(product, &linear_ops, coefficient); } /// Evaluate the polynomial at point `point` - pub fn evaluate(&self, point: &[F]) -> F { - self.products - .iter() - .zip(self.linear_ops.iter()) - .fold(F::zero(), |result, ((c, p), ops)| { + pub fn evaluate(&self, point: &[EF]) -> EF { + self.products.iter().zip(self.linear_ops.iter()).fold( + EF::zero(), + |result, ((c, p), ops)| { result - + p.iter().zip(ops.iter()).fold(*c, |acc, (&i, &(a, b))| { - acc * (self.flattened_ml_extensions[i].evaluate(point) * a + b) + + p.iter().zip(ops.iter()).fold(*c, |acc, (i, &(a, b))| { + acc * (b + a * self.flattened_ml_extensions[*i].evaluate(point)) }) - }) + }, + ) } } diff --git a/algebra/src/polynomial/multivariate/mod.rs b/algebra/src/polynomial/multivariate/mod.rs index 60caa360..b0d9ad22 100644 --- a/algebra/src/polynomial/multivariate/mod.rs +++ b/algebra/src/polynomial/multivariate/mod.rs @@ -2,4 +2,8 @@ mod data_structures; mod multilinear; pub use data_structures::{ListOfProductsOfPolynomials, PolynomialInfo}; -pub use multilinear::{DenseMultilinearExtension, MultilinearExtension, SparsePolynomial}; +pub use multilinear::UF; +pub use multilinear::{ + DenseMultilinearExtension, DenseMultilinearExtensionBase, MultilinearExtension, + MultilinearExtensionBase, SparsePolynomial, +}; diff --git a/algebra/src/polynomial/multivariate/multilinear/dense.rs b/algebra/src/polynomial/multivariate/multilinear/dense.rs index 079cad8c..acb3c7e7 100644 --- a/algebra/src/polynomial/multivariate/multilinear/dense.rs +++ b/algebra/src/polynomial/multivariate/multilinear/dense.rs @@ -1,32 +1,34 @@ // It is derived from https://github.com/arkworks-rs/sumcheck. use std::fmt::Debug; +use std::marker::PhantomData; use std::ops::{Add, AddAssign, Index, Neg, Sub, SubAssign}; use std::slice::{Iter, IterMut}; use num_traits::Zero; use rand_distr::Distribution; -use crate::{AbstractExtensionField, DecomposableField, Field, FieldUniformSampler}; +use crate::{AbstractExtensionField, Field, FieldUniformSampler}; +use super::dense_base::DenseMultilinearExtensionBase; use super::MultilinearExtension; -use std::rc::Rc; /// Stores a multilinear polynomial in dense evaluation form. #[derive(Clone, Default, PartialEq, Eq)] -pub struct DenseMultilinearExtension { +pub struct DenseMultilinearExtension> { /// The evaluation over {0,1}^`num_vars` - pub evaluations: Vec, + pub evaluations: Vec, /// Number of variables pub num_vars: usize, + _marker: PhantomData, } -impl DenseMultilinearExtension { +impl> DenseMultilinearExtension { /// Construct a new polynomial from a list of evaluations where the index /// represents a point in {0,1}^`num_vars` in little endian form. For /// example, `0b1011` represents `P(1,1,0,1)` #[inline] - pub fn from_evaluations_slice(num_vars: usize, evaluations: &[F]) -> Self { + pub fn from_evaluations_slice(num_vars: usize, evaluations: &[EF]) -> Self { assert_eq!( evaluations.len(), 1 << num_vars, @@ -39,7 +41,7 @@ impl DenseMultilinearExtension { /// represents a point in {0,1}^`num_vars` in little endian form. For /// example, `0b1011` represents `P(1,1,0,1)` #[inline] - pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec) -> Self { + pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec) -> Self { assert_eq!( evaluations.len(), 1 << num_vars, @@ -49,18 +51,33 @@ impl DenseMultilinearExtension { Self { num_vars, evaluations, + _marker: PhantomData, + } + } + + /// Construct a new polynomial from DenseMultilinearExtensionBase where the evaluations are in Field + #[inline] + pub fn from_base(mle_base: &DenseMultilinearExtensionBase) -> Self { + Self { + num_vars: mle_base.num_vars, + evaluations: mle_base + .evaluations + .iter() + .map(|x| EF::from_base(*x)) + .collect(), + _marker: PhantomData, } } /// Returns an iterator that iterates over the evaluations over {0,1}^`num_vars` #[inline] - pub fn iter(&self) -> Iter<'_, F> { + pub fn iter(&self) -> Iter<'_, EF> { self.evaluations.iter() } /// Returns a mutable iterator that iterates over the evaluations over {0,1}^`num_vars` #[inline] - pub fn iter_mut(&mut self) -> IterMut<'_, F> { + pub fn iter_mut(&mut self) -> IterMut<'_, EF> { self.evaluations.iter_mut() } @@ -78,74 +95,12 @@ impl DenseMultilinearExtension { ); (left, right) } - - /// Evaluate a point in the extension field. - #[inline] - pub fn evaluate_ext(&self, ext_point: &[EF]) -> EF - where - EF: AbstractExtensionField, - { - assert_eq!(ext_point.len(), self.num_vars, "The point size is invalid."); - let mut poly: Vec<_> = self - .evaluations - .iter() - .map(|&eval| EF::from_base(eval)) - .collect(); - let nv = self.num_vars; - let dim = ext_point.len(); - // evaluate nv variable of partial point from left to right - // with dim rounds and \sum_{i=1}^{dim} 2^(nv - i) - // (If dim = nv, then the complexity is 2^{nv}.) - for i in 1..dim + 1 { - // fix a single variable to evaluate (1 << (nv - i)) evaluations from the last round - // with complexity of 2^(1 << (nv - i)) field multiplications - let r = ext_point[i - 1]; - for b in 0..(1 << (nv - i)) { - let left = poly[b << 1]; - let right = poly[(b << 1) + 1]; - poly[b] = r * (right - left) + left; - } - } - poly.truncate(1 << (nv - dim)); - poly[0] - } } -impl DenseMultilinearExtension { - /// Decompose bits of each evaluation of the origianl MLE. - /// The bit deomposition is only applied for power-of-two base. - /// * base_len: the length of base, i.e. log_2(base) - /// * bits_len: the lenth of decomposed bits - /// - /// The resulting decomposition bits are respectively wrapped into `Rc` struct, which can be more easilier added into the ListsOfProducts. - #[inline] - pub fn get_decomposed_mles( - &self, - base_len: u32, - bits_len: u32, - ) -> Vec>> { - let mut val = self.evaluations.clone(); - let mask = F::mask(base_len); - - let mut bits = Vec::with_capacity(bits_len as usize); - - // extract `base_len` bits as one "bit" at a time - for _ in 0..bits_len { - let mut bit = vec![F::zero(); self.evaluations.len()]; - bit.iter_mut().zip(val.iter_mut()).for_each(|(b_i, v_i)| { - v_i.decompose_lsb_bits_at(b_i, mask, base_len); - }); - bits.push(Rc::new(DenseMultilinearExtension::from_evaluations_vec( - self.num_vars, - bit, - ))); - } - bits - } -} - -impl MultilinearExtension for DenseMultilinearExtension { - type Point = [F]; +impl> MultilinearExtension + for DenseMultilinearExtension +{ + type Point = [EF]; #[inline] fn num_vars(&self) -> usize { @@ -153,7 +108,7 @@ impl MultilinearExtension for DenseMultilinearExtension { } #[inline] - fn evaluate(&self, point: &Self::Point) -> F { + fn evaluate(&self, point: &Self::Point) -> EF { assert_eq!(point.len(), self.num_vars, "The point size is invalid."); self.fix_variables(point)[0] } @@ -169,10 +124,11 @@ impl MultilinearExtension for DenseMultilinearExtension { .sample_iter(rng) .take(1 << num_vars) .collect(), + _marker: PhantomData, } } - fn fix_variables(&self, partial_point: &[F]) -> Self { + fn fix_variables(&self, partial_point: &[EF]) -> Self { assert!( partial_point.len() <= self.num_vars, "invalid size of partial point" @@ -198,13 +154,13 @@ impl MultilinearExtension for DenseMultilinearExtension { } #[inline] - fn to_evaluations(&self) -> Vec { + fn to_evaluations(&self) -> Vec { self.evaluations.to_vec() } } -impl Index for DenseMultilinearExtension { - type Output = F; +impl> Index for DenseMultilinearExtension { + type Output = EF; /// Returns the evaluation of the polynomial at a point represented by index. /// @@ -218,7 +174,7 @@ impl Index for DenseMultilinearExtension { } } -impl Debug for DenseMultilinearExtension { +impl> Debug for DenseMultilinearExtension { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { write!(f, "DenseML(nv = {}, evaluations = [", self.num_vars)?; @@ -234,12 +190,13 @@ impl Debug for DenseMultilinearExtension { } } -impl Zero for DenseMultilinearExtension { +impl> Zero for DenseMultilinearExtension { #[inline] fn zero() -> Self { Self { num_vars: 0, - evaluations: vec![F::zero()], + evaluations: vec![EF::zero()], + _marker: PhantomData, } } @@ -249,29 +206,33 @@ impl Zero for DenseMultilinearExtension { } } -impl Add for DenseMultilinearExtension { - type Output = DenseMultilinearExtension; +impl> Add for DenseMultilinearExtension { + type Output = DenseMultilinearExtension; #[inline] - fn add(mut self, rhs: DenseMultilinearExtension) -> Self { + fn add(mut self, rhs: DenseMultilinearExtension) -> Self { self.add_assign(rhs); self } } -impl<'a, F: Field> Add<&'a DenseMultilinearExtension> for DenseMultilinearExtension { - type Output = DenseMultilinearExtension; +impl<'a, F: Field, EF: AbstractExtensionField> Add<&'a DenseMultilinearExtension> + for DenseMultilinearExtension +{ + type Output = DenseMultilinearExtension; #[inline] - fn add(mut self, rhs: &'a DenseMultilinearExtension) -> Self::Output { + fn add(mut self, rhs: &'a DenseMultilinearExtension) -> Self::Output { self.add_assign(rhs); self } } -impl<'a, 'b, F: Field> Add<&'a DenseMultilinearExtension> for &'b DenseMultilinearExtension { - type Output = DenseMultilinearExtension; +impl<'a, 'b, F: Field, EF: AbstractExtensionField> Add<&'a DenseMultilinearExtension> + for &'b DenseMultilinearExtension +{ + type Output = DenseMultilinearExtension; #[inline] - fn add(self, rhs: &'a DenseMultilinearExtension) -> Self::Output { + fn add(self, rhs: &'a DenseMultilinearExtension) -> Self::Output { // handle constant zero case if rhs.is_zero() { return self.clone(); @@ -280,49 +241,51 @@ impl<'a, 'b, F: Field> Add<&'a DenseMultilinearExtension> for &'b DenseMultil return rhs.clone(); } assert_eq!(self.num_vars, rhs.num_vars); - let result: Vec = self.iter().zip(rhs.iter()).map(|(&a, b)| a + b).collect(); + let result: Vec = self.iter().zip(rhs.iter()).map(|(&a, b)| a + b).collect(); Self::Output::from_evaluations_vec(self.num_vars, result) } } -impl AddAssign for DenseMultilinearExtension { +impl> AddAssign for DenseMultilinearExtension { #[inline] fn add_assign(&mut self, rhs: Self) { self.iter_mut().zip(rhs.iter()).for_each(|(x, y)| *x += y); } } -impl<'a, F: Field> AddAssign<&'a DenseMultilinearExtension> for DenseMultilinearExtension { +impl<'a, F: Field, EF: AbstractExtensionField> AddAssign<&'a DenseMultilinearExtension> + for DenseMultilinearExtension +{ #[inline] - fn add_assign(&mut self, rhs: &'a DenseMultilinearExtension) { + fn add_assign(&mut self, rhs: &'a DenseMultilinearExtension) { self.iter_mut().zip(rhs.iter()).for_each(|(x, y)| *x += y); } } -impl<'a, F: Field> AddAssign<(F, &'a DenseMultilinearExtension)> - for DenseMultilinearExtension +impl<'a, F: Field, EF: AbstractExtensionField> + AddAssign<(EF, &'a DenseMultilinearExtension)> for DenseMultilinearExtension { #[inline] - fn add_assign(&mut self, (f, rhs): (F, &'a DenseMultilinearExtension)) { + fn add_assign(&mut self, (f, rhs): (EF, &'a DenseMultilinearExtension)) { self.iter_mut() .zip(rhs.iter()) .for_each(|(x, y)| *x += f.mul(y)); } } -impl<'a, F: Field> AddAssign<(F, &'a Rc>)> - for DenseMultilinearExtension +impl<'a, F: Field, EF: AbstractExtensionField> + AddAssign<(EF, &'a DenseMultilinearExtensionBase)> for DenseMultilinearExtension { #[inline] - fn add_assign(&mut self, (f, rhs): (F, &'a Rc>)) { + fn add_assign(&mut self, (f, rhs): (EF, &'a DenseMultilinearExtensionBase)) { self.iter_mut() .zip(rhs.iter()) - .for_each(|(x, y)| *x += f.mul(y)); + .for_each(|(x, y)| *x += f * *y); } } -impl Neg for DenseMultilinearExtension { - type Output = DenseMultilinearExtension; +impl> Neg for DenseMultilinearExtension { + type Output = DenseMultilinearExtension; #[inline] fn neg(mut self) -> Self::Output { @@ -331,8 +294,8 @@ impl Neg for DenseMultilinearExtension { } } -impl Sub for DenseMultilinearExtension { - type Output = DenseMultilinearExtension; +impl> Sub for DenseMultilinearExtension { + type Output = DenseMultilinearExtension; #[inline] fn sub(mut self, rhs: Self) -> Self { @@ -341,21 +304,25 @@ impl Sub for DenseMultilinearExtension { } } -impl<'a, F: Field> Sub<&'a DenseMultilinearExtension> for DenseMultilinearExtension { - type Output = DenseMultilinearExtension; +impl<'a, F: Field, EF: AbstractExtensionField> Sub<&'a DenseMultilinearExtension> + for DenseMultilinearExtension +{ + type Output = DenseMultilinearExtension; #[inline] - fn sub(mut self, rhs: &'a DenseMultilinearExtension) -> Self::Output { + fn sub(mut self, rhs: &'a DenseMultilinearExtension) -> Self::Output { self.sub_assign(rhs); self } } -impl<'a, 'b, F: Field> Sub<&'a DenseMultilinearExtension> for &'b DenseMultilinearExtension { - type Output = DenseMultilinearExtension; +impl<'a, 'b, F: Field, EF: AbstractExtensionField> Sub<&'a DenseMultilinearExtension> + for &'b DenseMultilinearExtension +{ + type Output = DenseMultilinearExtension; #[inline] - fn sub(self, rhs: &'a DenseMultilinearExtension) -> Self::Output { + fn sub(self, rhs: &'a DenseMultilinearExtension) -> Self::Output { // handle constant zero case if rhs.is_zero() { return self.clone(); @@ -364,21 +331,23 @@ impl<'a, 'b, F: Field> Sub<&'a DenseMultilinearExtension> for &'b DenseMultil return rhs.clone(); } assert_eq!(self.num_vars, rhs.num_vars); - let result: Vec = self.iter().zip(rhs.iter()).map(|(&a, b)| a - b).collect(); + let result: Vec = self.iter().zip(rhs.iter()).map(|(&a, b)| a - b).collect(); Self::Output::from_evaluations_vec(self.num_vars, result) } } -impl SubAssign for DenseMultilinearExtension { +impl> SubAssign for DenseMultilinearExtension { #[inline] fn sub_assign(&mut self, rhs: Self) { self.iter_mut().zip(rhs.iter()).for_each(|(x, y)| *x -= y); } } -impl<'a, F: Field> SubAssign<&'a DenseMultilinearExtension> for DenseMultilinearExtension { +impl<'a, F: Field, EF: AbstractExtensionField> SubAssign<&'a DenseMultilinearExtension> + for DenseMultilinearExtension +{ #[inline] - fn sub_assign(&mut self, rhs: &'a DenseMultilinearExtension) { + fn sub_assign(&mut self, rhs: &'a DenseMultilinearExtension) { self.iter_mut().zip(rhs.iter()).for_each(|(x, y)| *x -= y); } } diff --git a/algebra/src/polynomial/multivariate/multilinear/dense_base.rs b/algebra/src/polynomial/multivariate/multilinear/dense_base.rs new file mode 100644 index 00000000..30a150ad --- /dev/null +++ b/algebra/src/polynomial/multivariate/multilinear/dense_base.rs @@ -0,0 +1,391 @@ +// It is derived from https://github.com/arkworks-rs/sumcheck. + +use std::fmt::Debug; +use std::ops::{Add, AddAssign, Index, Neg, Sub, SubAssign}; +use std::slice::{Iter, IterMut}; + +use num_traits::Zero; +use rand_distr::Distribution; + +use crate::{AbstractExtensionField, DecomposableField, Field, FieldUniformSampler}; + +use super::MultilinearExtensionBase; +use std::rc::Rc; + +/// Stores a multilinear polynomial in dense evaluation form. +#[derive(Clone, Default, PartialEq, Eq)] +pub struct DenseMultilinearExtensionBase { + /// The evaluation over {0,1}^`num_vars` + pub evaluations: Vec, + /// Number of variables + pub num_vars: usize, +} + +impl DenseMultilinearExtensionBase { + /// Construct a new polynomial from a list of evaluations where the index + /// represents a point in {0,1}^`num_vars` in little endian form. For + /// example, `0b1011` represents `P(1,1,0,1)` + #[inline] + pub fn from_evaluations_slice(num_vars: usize, evaluations: &[F]) -> Self { + assert_eq!( + evaluations.len(), + 1 << num_vars, + "The size of evaluations should be 2^num_vars." + ); + Self::from_evaluations_vec(num_vars, evaluations.to_vec()) + } + + /// Construct a new polynomial from a list of evaluations where the index + /// represents a point in {0,1}^`num_vars` in little endian form. For + /// example, `0b1011` represents `P(1,1,0,1)` + #[inline] + pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec) -> Self { + assert_eq!( + evaluations.len(), + 1 << num_vars, + "The size of evaluations should be 2^num_vars." + ); + + Self { + num_vars, + evaluations, + } + } + + /// Returns an iterator that iterates over the evaluations over {0,1}^`num_vars` + #[inline] + pub fn iter(&self) -> Iter<'_, F> { + self.evaluations.iter() + } + + /// Returns a mutable iterator that iterates over the evaluations over {0,1}^`num_vars` + #[inline] + pub fn iter_mut(&mut self) -> IterMut<'_, F> { + self.evaluations.iter_mut() + } + + /// Split the mle into two mles with one less variable, eliminating the far right variable + /// original evaluations: f(x, b) for x \in \{0, 1\}^{k-1} and b\{0, 1\} + /// resulting two mles: f0(x) = f(x, 0) for x \in \{0, 1\}^{k-1} and f1(x) = f(x, 1) for x \in \{0, 1\}^{k-1} + pub fn split_halves(&self) -> (Self, Self) { + let left = Self::from_evaluations_slice( + self.num_vars - 1, + &self.evaluations[0..1 << (self.num_vars - 1)], + ); + let right = Self::from_evaluations_slice( + self.num_vars - 1, + &self.evaluations[1 << (self.num_vars - 1)..], + ); + (left, right) + } + + /// Evaluate a point in the extension field. + #[inline] + pub fn evaluate_ext(&self, ext_point: &[EF]) -> EF + where + EF: AbstractExtensionField, + { + assert_eq!(ext_point.len(), self.num_vars, "The point size is invalid."); + let mut poly: Vec<_> = self + .evaluations + .iter() + .map(|&eval| EF::from_base(eval)) + .collect(); + let nv = self.num_vars; + let dim = ext_point.len(); + // evaluate nv variable of partial point from left to right + // with dim rounds and \sum_{i=1}^{dim} 2^(nv - i) + // (If dim = nv, then the complexity is 2^{nv}.) + for i in 1..dim + 1 { + // fix a single variable to evaluate (1 << (nv - i)) evaluations from the last round + // with complexity of 2^(1 << (nv - i)) field multiplications + let r = ext_point[i - 1]; + for b in 0..(1 << (nv - i)) { + let left = poly[b << 1]; + let right = poly[(b << 1) + 1]; + poly[b] = r * (right - left) + left; + } + } + poly.truncate(1 << (nv - dim)); + poly[0] + } +} + +impl DenseMultilinearExtensionBase { + /// Decompose bits of each evaluation of the origianl MLE. + /// The bit deomposition is only applied for power-of-two base. + /// * base_len: the length of base, i.e. log_2(base) + /// * bits_len: the lenth of decomposed bits + /// + /// The resulting decomposition bits are respectively wrapped into `Rc` struct, which can be more easilier added into the ListsOfProducts. + #[inline] + pub fn get_decomposed_mles( + &self, + base_len: u32, + bits_len: u32, + ) -> Vec>> { + let mut val = self.evaluations.clone(); + let mask = F::mask(base_len); + + let mut bits = Vec::with_capacity(bits_len as usize); + + // extract `base_len` bits as one "bit" at a time + for _ in 0..bits_len { + let mut bit = vec![F::zero(); self.evaluations.len()]; + bit.iter_mut().zip(val.iter_mut()).for_each(|(b_i, v_i)| { + v_i.decompose_lsb_bits_at(b_i, mask, base_len); + }); + bits.push(Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(self.num_vars, bit), + )); + } + bits + } +} + +impl MultilinearExtensionBase for DenseMultilinearExtensionBase { + type Point = [F]; + + #[inline] + fn num_vars(&self) -> usize { + self.num_vars + } + + #[inline] + fn evaluate(&self, point: &Self::Point) -> F { + assert_eq!(point.len(), self.num_vars, "The point size is invalid."); + self.fix_variables(point)[0] + } + + #[inline] + fn random(num_vars: usize, rng: &mut R) -> Self + where + R: rand::Rng + rand::CryptoRng, + { + Self { + num_vars, + evaluations: FieldUniformSampler::new() + .sample_iter(rng) + .take(1 << num_vars) + .collect(), + } + } + + fn fix_variables(&self, partial_point: &[F]) -> Self { + assert!( + partial_point.len() <= self.num_vars, + "invalid size of partial point" + ); + let mut poly = self.evaluations.to_vec(); + let nv = self.num_vars; + let dim = partial_point.len(); + // evaluate nv variable of partial point from left to right + // with dim rounds and \sum_{i=1}^{dim} 2^(nv - i) + // (If dim = nv, then the complexity is 2^{nv}.) + for i in 1..dim + 1 { + // fix a single variable to evaluate (1 << (nv - i)) evaluations from the last round + // with complexity of 2^(1 << (nv - i)) field multiplications + let r = partial_point[i - 1]; + for b in 0..(1 << (nv - i)) { + let left = poly[b << 1]; + let right = poly[(b << 1) + 1]; + poly[b] = left + r * (right - left); + } + } + poly.truncate(1 << (nv - dim)); + Self::from_evaluations_vec(nv - dim, poly) + } + + #[inline] + fn to_evaluations(&self) -> Vec { + self.evaluations.to_vec() + } +} + +impl Index for DenseMultilinearExtensionBase { + type Output = F; + + /// Returns the evaluation of the polynomial at a point represented by index. + /// + /// Index represents a vector in {0,1}^`num_vars` in little endian form. For + /// example, `0b1011` represents `P(1,1,0,1)` + /// + /// For dense multilinear polynomial, `index` takes constant time. + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.evaluations[index] + } +} + +impl Debug for DenseMultilinearExtensionBase { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> { + write!(f, "DenseML(nv = {}, evaluations = [", self.num_vars)?; + for i in 0..4.min(self.evaluations.len()) { + write!(f, "{:?}", self.evaluations[i])?; + } + if self.evaluations.len() < 4 { + write!(f, "])")?; + } else { + write!(f, "...])")?; + } + Ok(()) + } +} + +impl Zero for DenseMultilinearExtensionBase { + #[inline] + fn zero() -> Self { + Self { + num_vars: 0, + evaluations: vec![F::zero()], + } + } + + #[inline] + fn is_zero(&self) -> bool { + self.num_vars == 0 && self.evaluations[0].is_zero() + } +} + +impl Add for DenseMultilinearExtensionBase { + type Output = DenseMultilinearExtensionBase; + #[inline] + fn add(mut self, rhs: DenseMultilinearExtensionBase) -> Self { + self.add_assign(rhs); + self + } +} + +impl<'a, F: Field> Add<&'a DenseMultilinearExtensionBase> for DenseMultilinearExtensionBase { + type Output = DenseMultilinearExtensionBase; + #[inline] + fn add(mut self, rhs: &'a DenseMultilinearExtensionBase) -> Self::Output { + self.add_assign(rhs); + self + } +} + +impl<'a, 'b, F: Field> Add<&'a DenseMultilinearExtensionBase> + for &'b DenseMultilinearExtensionBase +{ + type Output = DenseMultilinearExtensionBase; + + #[inline] + fn add(self, rhs: &'a DenseMultilinearExtensionBase) -> Self::Output { + // handle constant zero case + if rhs.is_zero() { + return self.clone(); + } + if self.is_zero() { + return rhs.clone(); + } + assert_eq!(self.num_vars, rhs.num_vars); + let result: Vec = self.iter().zip(rhs.iter()).map(|(&a, b)| a + b).collect(); + Self::Output::from_evaluations_vec(self.num_vars, result) + } +} + +impl AddAssign for DenseMultilinearExtensionBase { + #[inline] + fn add_assign(&mut self, rhs: Self) { + self.iter_mut().zip(rhs.iter()).for_each(|(x, y)| *x += y); + } +} + +impl<'a, F: Field> AddAssign<&'a DenseMultilinearExtensionBase> + for DenseMultilinearExtensionBase +{ + #[inline] + fn add_assign(&mut self, rhs: &'a DenseMultilinearExtensionBase) { + self.iter_mut().zip(rhs.iter()).for_each(|(x, y)| *x += y); + } +} + +impl<'a, F: Field> AddAssign<(F, &'a DenseMultilinearExtensionBase)> + for DenseMultilinearExtensionBase +{ + #[inline] + fn add_assign(&mut self, (f, rhs): (F, &'a DenseMultilinearExtensionBase)) { + self.iter_mut() + .zip(rhs.iter()) + .for_each(|(x, y)| *x += f.mul(y)); + } +} + +impl<'a, F: Field> AddAssign<(F, &'a Rc>)> + for DenseMultilinearExtensionBase +{ + #[inline] + fn add_assign(&mut self, (f, rhs): (F, &'a Rc>)) { + self.iter_mut() + .zip(rhs.iter()) + .for_each(|(x, y)| *x += f.mul(y)); + } +} + +impl Neg for DenseMultilinearExtensionBase { + type Output = DenseMultilinearExtensionBase; + + #[inline] + fn neg(mut self) -> Self::Output { + self.evaluations.iter_mut().for_each(|x| *x = -(*x)); + self + } +} + +impl Sub for DenseMultilinearExtensionBase { + type Output = DenseMultilinearExtensionBase; + + #[inline] + fn sub(mut self, rhs: Self) -> Self { + self.sub_assign(rhs); + self + } +} + +impl<'a, F: Field> Sub<&'a DenseMultilinearExtensionBase> for DenseMultilinearExtensionBase { + type Output = DenseMultilinearExtensionBase; + + #[inline] + fn sub(mut self, rhs: &'a DenseMultilinearExtensionBase) -> Self::Output { + self.sub_assign(rhs); + self + } +} + +impl<'a, 'b, F: Field> Sub<&'a DenseMultilinearExtensionBase> + for &'b DenseMultilinearExtensionBase +{ + type Output = DenseMultilinearExtensionBase; + + #[inline] + fn sub(self, rhs: &'a DenseMultilinearExtensionBase) -> Self::Output { + // handle constant zero case + if rhs.is_zero() { + return self.clone(); + } + if self.is_zero() { + return rhs.clone(); + } + assert_eq!(self.num_vars, rhs.num_vars); + let result: Vec = self.iter().zip(rhs.iter()).map(|(&a, b)| a - b).collect(); + Self::Output::from_evaluations_vec(self.num_vars, result) + } +} + +impl SubAssign for DenseMultilinearExtensionBase { + #[inline] + fn sub_assign(&mut self, rhs: Self) { + self.iter_mut().zip(rhs.iter()).for_each(|(x, y)| *x -= y); + } +} + +impl<'a, F: Field> SubAssign<&'a DenseMultilinearExtensionBase> + for DenseMultilinearExtensionBase +{ + #[inline] + fn sub_assign(&mut self, rhs: &'a DenseMultilinearExtensionBase) { + self.iter_mut().zip(rhs.iter()).for_each(|(x, y)| *x -= y); + } +} diff --git a/algebra/src/polynomial/multivariate/multilinear/mod.rs b/algebra/src/polynomial/multivariate/multilinear/mod.rs index 4d4375d7..a19516fc 100644 --- a/algebra/src/polynomial/multivariate/multilinear/mod.rs +++ b/algebra/src/polynomial/multivariate/multilinear/mod.rs @@ -4,13 +4,17 @@ use std::ops::{Add, AddAssign, Index, Neg, Sub, SubAssign}; use num_traits::Zero; -use crate::Field; +use crate::{AbstractExtensionField, Field}; mod dense; +mod dense_base; mod sparse; +mod unified_field; pub use dense::DenseMultilinearExtension; +pub use dense_base::DenseMultilinearExtensionBase; pub use sparse::SparsePolynomial; +pub use unified_field::UF; /// This trait describes an interface for the multilinear extension /// of an array. @@ -19,7 +23,7 @@ pub use sparse::SparsePolynomial; /// /// Index represents a point, which is a vector in {0,1}^`num_vars` in little /// endian form. For example, `0b1011` represents `P(1,1,0,1)` -pub trait MultilinearExtension: +pub trait MultilinearExtensionBase: Sized + Clone + Debug @@ -55,3 +59,47 @@ pub trait MultilinearExtension: /// hypercube. The evaluations are in little-endian order. fn to_evaluations(&self) -> Vec; } + +/// This trait describes an interface for the multilinear extension +/// of an array. +/// The latter is a multilinear polynomial represented in terms of its +/// evaluations over the domain {0,1}^`num_vars` (i.e. the Boolean hypercube). +/// +/// Index represents a point, which is a vector in {0,1}^`num_vars` in little +/// endian form. For example, `0b1011` represents `P(1,1,0,1)` +pub trait MultilinearExtension>: + Sized + + Clone + + Debug + + Zero + + Index + + Add + + Neg + + Sub + + AddAssign + + SubAssign + + for<'a> AddAssign<&'a Self> + + for<'a> AddAssign<(EF, &'a Self)> + + for<'a> SubAssign<&'a Self> +{ + /// The type of evaluation points for this polynomial. + type Point: ?Sized + Debug; + + /// Return the number of variables in `self` + fn num_vars(&self) -> usize; + + /// Evaluates `self` at the given `point` in `Self::Point`. + fn evaluate(&self, point: &Self::Point) -> EF; + + /// Outputs an `l`-variate multilinear extension where value of evaluations + /// are sampled at random. + fn random(num_vars: usize, rng: &mut R) -> Self; + + /// Reduce the number of variables of `self` by fixing the + /// `partial_point.len()` variables at `partial_point`. + fn fix_variables(&self, partial_point: &[EF]) -> Self; + + /// Return a list of evaluations over the domain, which is the boolean + /// hypercube. The evaluations are in little-endian order. + fn to_evaluations(&self) -> Vec; +} diff --git a/algebra/src/polynomial/multivariate/multilinear/sparse.rs b/algebra/src/polynomial/multivariate/multilinear/sparse.rs index fd4c5d96..3b42ae3a 100644 --- a/algebra/src/polynomial/multivariate/multilinear/sparse.rs +++ b/algebra/src/polynomial/multivariate/multilinear/sparse.rs @@ -1,6 +1,8 @@ use std::slice::{Iter, IterMut}; -use crate::{DenseMultilinearExtension, Field}; +use crate::Field; + +use super::DenseMultilinearExtensionBase; /// Sparse polynomial #[derive(Clone, Default, PartialEq, Eq)] @@ -53,11 +55,11 @@ impl SparsePolynomial { /// Transform sparse representation into dense representation #[inline] - pub fn to_dense(&self) -> DenseMultilinearExtension { + pub fn to_dense(&self) -> DenseMultilinearExtensionBase { let mut evaluations = vec![F::zero(); 1 << self.num_vars]; self.evaluations.iter().for_each(|(idx, item)| { evaluations[*idx] = *item; }); - DenseMultilinearExtension::from_evaluations_vec(self.num_vars, evaluations) + DenseMultilinearExtensionBase::from_evaluations_vec(self.num_vars, evaluations) } } diff --git a/algebra/src/polynomial/multivariate/multilinear/unified_field.rs b/algebra/src/polynomial/multivariate/multilinear/unified_field.rs new file mode 100644 index 00000000..fdc55355 --- /dev/null +++ b/algebra/src/polynomial/multivariate/multilinear/unified_field.rs @@ -0,0 +1,78 @@ +use crate::{AbstractExtensionField, Field}; +use std::ops::{Add, Mul, Sub}; + +/// Unified Field +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum UF> { + /// Base Field Element + BaseField(F), + /// Extension Field Element + ExtensionField(EF), +} + +impl> UF { + /// Return one + pub fn one() -> UF { + UF::BaseField(F::one()) + } + /// Return zero + pub fn zero() -> UF { + UF::BaseField(F::zero()) + } +} + +impl> Mul for UF { + type Output = EF; + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (UF::BaseField(l), UF::BaseField(r)) => EF::from_base(l * r), + (UF::BaseField(l), UF::ExtensionField(r)) => r * l, + (UF::ExtensionField(l), UF::BaseField(r)) => l * r, + (UF::ExtensionField(l), UF::ExtensionField(r)) => l * r, + } + } +} + +impl> Add for UF { + type Output = EF; + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (UF::BaseField(l), UF::BaseField(r)) => EF::from_base(l + r), + (UF::BaseField(l), UF::ExtensionField(r)) => r + l, + (UF::ExtensionField(l), UF::BaseField(r)) => l + r, + (UF::ExtensionField(l), UF::ExtensionField(r)) => l + r, + } + } +} + +impl> Sub for UF { + type Output = EF; + fn sub(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (UF::BaseField(l), UF::BaseField(r)) => EF::from_base(l - r), + (UF::BaseField(l), UF::ExtensionField(r)) => EF::from_base(l) - r, + (UF::ExtensionField(l), UF::BaseField(r)) => l - r, + (UF::ExtensionField(l), UF::ExtensionField(r)) => l - r, + } + } +} + +impl> Mul for UF { + type Output = EF; + fn mul(self, r: EF) -> Self::Output { + match self { + UF::BaseField(l) => r * l, + UF::ExtensionField(l) => l * r, + } + } +} + +impl> Add for UF { + type Output = EF; + fn add(self, r: EF) -> Self::Output { + match self { + UF::BaseField(l) => r + l, + UF::ExtensionField(l) => l + r, + } + } +} diff --git a/algebra/src/utils/transcript.rs b/algebra/src/utils/transcript.rs index e1c7157e..ea884134 100644 --- a/algebra/src/utils/transcript.rs +++ b/algebra/src/utils/transcript.rs @@ -27,80 +27,54 @@ impl Transcript { } } } + impl Transcript { /// Append the message to the transcript. - #[inline] - pub fn append_message(&mut self, msg: &[u8]) { - self.transcript.append_message(b"", msg); - } - - /// Append elements to the transcript. - #[inline] - pub fn append_elements(&mut self, elems: &[F]) { - self.append_message(&bincode::serialize(elems).unwrap()); - } - - /// Append extension field elements to the transcript. - #[inline] - pub fn append_ext_field_elements>(&mut self, elems: &[EF]) { - let elems: Vec = elems - .iter() - .flat_map(|x| x.as_base_slice()) - .cloned() - .collect(); - self.append_message(&bincode::serialize(&elems).unwrap()); + pub fn append_message(&mut self, label: &'static [u8], msg: &M) { + self.transcript + .append_message(label, &bincode::serialize(msg).unwrap()); } /// Generate the challenge bytes from the current transcript #[inline] - pub fn get_challenge_bytes(&mut self, bytes: &mut [u8]) { - self.transcript.challenge_bytes(b"", bytes); + pub fn get_challenge_bytes(&mut self, label: &'static [u8], bytes: &mut [u8]) { + self.transcript.challenge_bytes(label, bytes); } /// Generate the challenge from the current transcript - /// and append it to the transcript. - pub fn get_and_append_challenge(&mut self) -> F { + pub fn get_challenge(&mut self, label: &'static [u8]) -> F { let mut seed = [0u8; 16]; - self.transcript.challenge_bytes(b"", &mut seed); + self.transcript.challenge_bytes(label, &mut seed); let mut prg = Prg::from_seed(Block::from(seed)); - let challenge: F = self.sampler.sample(&mut prg); - self.append_message(&bincode::serialize(&challenge).unwrap()); - - challenge + self.sampler.sample(&mut prg) } /// Generate the challenge vector from the current transcript - /// and append it to the transcript. - pub fn get_vec_and_append_challenge(&mut self, num: usize) -> Vec { + pub fn get_vec_challenge(&mut self, label: &'static [u8], num: usize) -> Vec { let mut seed = [0u8; 16]; - self.transcript.challenge_bytes(b"", &mut seed); + self.transcript.challenge_bytes(label, &mut seed); let mut prg = Prg::from_seed(Block::from(seed)); - let challenge = self.sampler.sample_iter(&mut prg).take(num).collect(); - self.append_message(&bincode::serialize(&challenge).unwrap()); - - challenge + self.sampler.sample_iter(&mut prg).take(num).collect() } /// Generate the challenge for extension field from the current transcript - /// and append it to the transcript. #[inline] - pub fn get_ext_field_and_append_challenge(&mut self) -> EF + pub fn get_ext_field_challenge(&mut self, label: &'static [u8]) -> EF where EF: AbstractExtensionField, { - let value = self.get_vec_and_append_challenge(EF::D); + let value = self.get_vec_challenge(label, EF::D); EF::from_base_slice(&value) } /// Generate the challenge vector for extension field from the current transcript - /// and append it to the transcript. #[inline] - pub fn get_vec_ext_field_and_append_challenge(&mut self, num: usize) -> Vec + pub fn get_vec_ext_field_challenge(&mut self, label: &'static [u8], num: usize) -> Vec where EF: AbstractExtensionField, { - let challenges = self.get_vec_and_append_challenge(num * EF::D); + let challenges = self.get_vec_challenge(label, num * EF::D); challenges .chunks_exact(EF::D) .map(|ext| EF::from_base_slice(ext)) diff --git a/algebra/tests/multivariate_test.rs b/algebra/tests/multivariate_test.rs index a3f514f1..2a8ab9cd 100644 --- a/algebra/tests/multivariate_test.rs +++ b/algebra/tests/multivariate_test.rs @@ -2,10 +2,11 @@ use std::vec; use algebra::{ derive::{DecomposableField, Field, Prime}, - Basis, DenseMultilinearExtension, Field, FieldUniformSampler, ListOfProductsOfPolynomials, - MultilinearExtension, + BabyBear, BabyBearExetension, Basis, DenseMultilinearExtension, DenseMultilinearExtensionBase, + Field, FieldUniformSampler, ListOfProductsOfPolynomials, MultilinearExtension, + MultilinearExtensionBase, UF, }; -use num_traits::{Pow, Zero}; +use num_traits::{One, Zero}; use rand::thread_rng; use rand_distr::Distribution; use std::rc::Rc; @@ -19,15 +20,21 @@ macro_rules! field_vec { } } +fn uf_new(val: u32) -> UF { + UF::BaseField(FF::new(val)) +} + #[derive(Field, DecomposableField, Prime)] #[modulus = 132120577] pub struct Fp32(u32); // field type -type FF = Fp32; -type PolyFf = DenseMultilinearExtension; +type FF = BabyBear; +type EF = BabyBearExetension; +type PolyFF = DenseMultilinearExtensionBase; +type PolyEF = DenseMultilinearExtension; -fn evaluate_mle_data_array(data: &[F], point: &[F]) -> F { +fn evaluate_mle_data_array(data: &[EF], point: &[EF]) -> EF { if data.len() != (1 << point.len()) { panic!("Data size mismatch with number of variables.") } @@ -37,7 +44,7 @@ fn evaluate_mle_data_array(data: &[F], point: &[F]) -> F { for i in 1..nv + 1 { let r = point[i - 1]; for b in 0..(1 << (nv - i)) { - a[b] = a[b << 1] * (F::one() - r) + a[(b << 1) + 1] * r; + a[b] = a[b << 1] * (EF::one() - r) + a[(b << 1) + 1] * r; } } @@ -46,17 +53,17 @@ fn evaluate_mle_data_array(data: &[F], point: &[F]) -> F { #[test] fn evaluate_mle_at_a_point() { - let poly = PolyFf::from_evaluations_vec(2, field_vec! {FF; 1, 2, 3, 4}); + let poly = PolyEF::from_evaluations_vec(2, field_vec! {EF; 1, 2, 3, 4}); - let point = vec![FF::new(0), FF::new(1)]; - assert_eq!(poly.evaluate(&point), FF::new(3)); + let point = vec![EF::new(0), EF::new(1)]; + assert_eq!(poly.evaluate(&point), EF::new(3)); } #[test] fn evaluate_mle_at_a_random_point() { let mut rng = thread_rng(); - let poly = PolyFf::random(2, &mut rng); - let uniform = >::new(); + let poly = PolyEF::random(2, &mut rng); + let uniform = >::new(); let point: Vec<_> = (0..2).map(|_| uniform.sample(&mut rng)).collect(); assert_eq!( poly.evaluate(&point), @@ -68,11 +75,12 @@ fn evaluate_mle_at_a_random_point() { fn mle_arithmetic() { const NV: usize = 10; let mut rng = thread_rng(); - let uniform = >::new(); + let uniform = >::new(); + let uniform_ff = >::new(); for _ in 0..20 { let point: Vec<_> = (0..NV).map(|_| uniform.sample(&mut rng)).collect(); - let poly1 = PolyFf::random(NV, &mut rng); - let poly2 = PolyFf::random(NV, &mut rng); + let poly1 = PolyEF::random(NV, &mut rng); + let poly2 = PolyEF::random(NV, &mut rng); let v1 = poly1.evaluate(&point); let v2 = poly2.evaluate(&point); // test add @@ -102,25 +110,33 @@ fn mle_arithmetic() { } // test additive identity { - assert_eq!(&poly1 + &PolyFf::zero(), poly1); - assert_eq!((&PolyFf::zero() + &poly1), poly1); + assert_eq!(&poly1 + &PolyEF::zero(), poly1); + assert_eq!((&PolyEF::zero() + &poly1), poly1); } // test decomposition of mle { + let poly_decomposed = PolyFF::random(NV, &mut rng); let base_len = 3; let base = FF::new(1 << base_len); let basis = >::new(base_len); let bits_len = basis.decompose_len(); - let decomposed_polys = poly1.get_decomposed_mles(base_len, bits_len as u32); - let point: Vec<_> = (0..NV).map(|_| uniform.sample(&mut rng)).collect(); + let decomposed_polys = poly_decomposed.get_decomposed_mles(base_len, bits_len as u32); + let point: Vec<_> = (0..NV).map(|_| uniform_ff.sample(&mut rng)).collect(); + + // base_pow = [1, B, ..., B^{l-1}] + let mut base_pow = vec![FF::one(); bits_len]; + base_pow.iter_mut().fold(FF::one(), |acc, pow| { + *pow *= acc; + acc * base + }); let evaluation = decomposed_polys .iter() - .enumerate() - .fold(FF::zero(), |acc, (i, bit)| { - acc + bit.evaluate(&point) * base.pow(i as u32) + .zip(base_pow.into_iter()) + .fold(FF::zero(), |acc, (bit, base_pow)| { + acc + bit.evaluate(&point) * base_pow }); - assert_eq!(poly1.evaluate(&point), evaluation); + assert_eq!(poly_decomposed.evaluate(&point), evaluation); } } } @@ -133,18 +149,25 @@ fn trivial_decomposed_mles() { let num_vars = 2; let val = field_vec!(FF; 0b001101, 0b100011, 0b101100, 0b111110); - let poly = DenseMultilinearExtension::from_evaluations_vec(num_vars, val); + let poly = DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, val); let decomposed_polys = poly.get_decomposed_mles(base_len, bits_len); let uniform = >::new(); - let point: Vec<_> = (0..num_vars) + let point: Vec = (0..num_vars) .map(|_| uniform.sample(&mut thread_rng())) .collect(); + + // base_pow = [1, B, ..., B^{l-1}] + let mut base_pow = vec![FF::one(); bits_len as usize]; + base_pow.iter_mut().fold(FF::one(), |acc, pow| { + *pow *= acc; + acc * base + }); let evaluation = decomposed_polys .iter() - .enumerate() - .fold(FF::zero(), |acc, (i, bit)| { - acc + bit.evaluate(&point) * base.pow(i as u32) + .zip(base_pow) + .fold(FF::zero(), |acc, (bit, base_pow)| { + acc + bit.evaluate(&point) * base_pow }); assert_eq!(poly.evaluate(&point), evaluation); @@ -154,32 +177,33 @@ fn trivial_decomposed_mles() { fn evaluate_lists_of_products_at_a_point() { let nv = 2; let mut poly = ListOfProductsOfPolynomials::new(nv); - let products = vec![field_vec!(FF; 1, 2, 3, 4), field_vec!(FF; 5, 4, 2, 9)]; - let products: Vec>> = products + let products = vec![field_vec!(EF; 1, 2, 3, 4), field_vec!(EF; 5, 4, 2, 9)]; + let products: Vec>> = products .into_iter() .map(|x| Rc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x))) .collect(); - let coeff = FF::new(4); + let coeff = EF::new(4); poly.add_product(products, coeff); - let point = field_vec!(FF; 0, 1); - assert_eq!(poly.evaluate(&point), FF::new(24)); + let point = field_vec!(EF; 0, 1); + assert_eq!(poly.evaluate(&point), EF::new(24)); } #[test] fn evaluate_lists_of_products_with_op_at_a_point() { let nv = 2; let mut poly = ListOfProductsOfPolynomials::new(nv); - let products = vec![field_vec!(FF; 1, 2, 3, 4), field_vec!(FF; 1, 2, 3, 4)]; - let products: Vec>> = products + let products = vec![field_vec!(EF; 1, 2, 3, 4), field_vec!(EF; 1, 2, 3, 4)]; + let products: Vec>> = products .into_iter() .map(|x| Rc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x))) .collect(); - let coeff = FF::new(4); - let op_coefficient = vec![(FF::new(2), FF::new(0)), (FF::new(1), FF::new(3))]; + let coeff = EF::new(4); + + let op_coefficient = vec![(uf_new(2), uf_new(0)), (uf_new(1), uf_new(3))]; // coeff \cdot [2f \cdot (g + 3)] poly.add_product_with_linear_op(products, &op_coefficient, coeff); // 4 * [2*2 * (2+3)] = 80 - let point = field_vec!(FF; 1, 0); - assert_eq!(poly.evaluate(&point), FF::new(80)); + let point = field_vec!(EF; 1, 0); + assert_eq!(poly.evaluate(&point), EF::new(80)); } diff --git a/pcs/benches/brakedown_pcs.rs b/pcs/benches/brakedown_pcs.rs index 75e9915b..5e41636b 100644 --- a/pcs/benches/brakedown_pcs.rs +++ b/pcs/benches/brakedown_pcs.rs @@ -1,7 +1,8 @@ use std::time::Duration; use algebra::{ - utils::Transcript, BabyBear, BabyBearExetension, DenseMultilinearExtension, FieldUniformSampler, + utils::Transcript, BabyBear, BabyBearExetension, DenseMultilinearExtensionBase, + FieldUniformSampler, }; use criterion::{criterion_group, criterion_main, Criterion}; use pcs::{ @@ -26,7 +27,7 @@ pub fn criterion_benchmark(c: &mut Criterion) { .take(1 << num_vars) .collect(); - let poly = DenseMultilinearExtension::from_evaluations_vec(num_vars, evaluations); + let poly = DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, evaluations); let code_spec = ExpanderCodeSpec::new(0.1195, 0.0284, 1.9, BASE_FIELD_BITS, 10); diff --git a/pcs/examples/brakedown_pcs.rs b/pcs/examples/brakedown_pcs.rs index c3db6bda..999a360f 100644 --- a/pcs/examples/brakedown_pcs.rs +++ b/pcs/examples/brakedown_pcs.rs @@ -1,7 +1,8 @@ use std::time::Instant; use algebra::{ - utils::Transcript, BabyBear, BabyBearExetension, DenseMultilinearExtension, FieldUniformSampler, + utils::Transcript, BabyBear, BabyBearExetension, DenseMultilinearExtensionBase, + FieldUniformSampler, }; use pcs::{ multilinear::brakedown::BrakedownPCS, @@ -23,7 +24,7 @@ fn main() { .take(1 << num_vars) .collect(); - let poly = DenseMultilinearExtension::from_evaluations_vec(num_vars, evaluations); + let poly = DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, evaluations); let code_spec = ExpanderCodeSpec::new(0.1195, 0.0284, 1.9, BASE_FIELD_BITS, 10); diff --git a/pcs/src/lib.rs b/pcs/src/lib.rs index c05c307f..c8f4a924 100644 --- a/pcs/src/lib.rs +++ b/pcs/src/lib.rs @@ -8,7 +8,7 @@ pub mod multilinear; /// utils, mainly used to implement linear time encodable code now pub mod utils; -use algebra::{utils::Transcript, Field, MultilinearExtension}; +use algebra::{utils::Transcript, Field, MultilinearExtensionBase}; // type Point =

>::Point; @@ -17,7 +17,7 @@ pub trait PolynomialCommitmentScheme { /// System parameters type Parameters; /// polynomial to commit - type Polynomial: MultilinearExtension; + type Polynomial: MultilinearExtensionBase; /// commitment type Commitment; /// Auxiliary state of the commitment, output by the `commit` phase. diff --git a/pcs/src/multilinear/brakedown/mod.rs b/pcs/src/multilinear/brakedown/mod.rs index 2498ec3c..b2e2310b 100644 --- a/pcs/src/multilinear/brakedown/mod.rs +++ b/pcs/src/multilinear/brakedown/mod.rs @@ -8,7 +8,7 @@ pub use data_structure::{ use algebra::{ utils::{Block, Prg, Transcript}, - AbstractExtensionField, DenseMultilinearExtension, Field, + AbstractExtensionField, DenseMultilinearExtensionBase, Field, }; use itertools::Itertools; use rand::SeedableRng; @@ -232,7 +232,7 @@ where let codeword_len = pp.code().codeword_len(); let mut seed = [0u8; 16]; - trans.get_challenge_bytes(&mut seed); + trans.get_challenge_bytes(b"Generate random queries", &mut seed); let mut prg = Prg::from_seed(Block::from(seed)); // Generate a random set of queries. @@ -253,7 +253,7 @@ where EF: AbstractExtensionField + Serialize + for<'de> Deserialize<'de>, { type Parameters = BrakedownParams; - type Polynomial = DenseMultilinearExtension; + type Polynomial = DenseMultilinearExtensionBase; type Commitment = BrakedownPolyCommitment; type CommitmentState = BrakedownCommitmentState; type Proof = BrakedownOpenProof; @@ -326,7 +326,8 @@ where ) -> Self::Proof { assert_eq!(points.len(), pp.num_vars()); // Hash the commitment to transcript. - trans.append_message(&commitment.to_bytes().unwrap()); + trans.append_message(b"commitment", &commitment); + // trans.append_message(&commitment.to_bytes().unwrap()); // Compute the tensor from the random point, see [DP23](https://eprint.iacr.org/2023/630.pdf). let tensor = Self::tensor_from_points(pp, points); @@ -334,7 +335,7 @@ where let rlc_msgs = Self::answer_challenge(pp, &tensor, state); // Hash rlc to transcript. - trans.append_ext_field_elements(&rlc_msgs); + trans.append_message(b"rlc", &rlc_msgs); // Sample random queries. let queries = Self::random_queries(pp, trans); @@ -360,7 +361,7 @@ where assert_eq!(points.len(), pp.num_vars()); // Hash the commitment to transcript. - trans.append_message(&commitment.to_bytes().unwrap()); + trans.append_message(b"commitment", &commitment); let (tensor, residual) = Self::tensor_decompose(pp, points); @@ -371,7 +372,7 @@ where pp.code().encode_ext(&mut encoded_msg); // Hash rlc to transcript. - trans.append_ext_field_elements(&proof.rlc_msgs); + trans.append_message(b"rlc", &proof.rlc_msgs); // Sample random queries. let queries = Self::random_queries(pp, trans); diff --git a/pcs/tests/test_pcs.rs b/pcs/tests/test_pcs.rs index 1f26605b..3555c77b 100644 --- a/pcs/tests/test_pcs.rs +++ b/pcs/tests/test_pcs.rs @@ -1,5 +1,6 @@ use algebra::{ - utils::Transcript, BabyBear, BabyBearExetension, DenseMultilinearExtension, FieldUniformSampler, + utils::Transcript, BabyBear, BabyBearExetension, DenseMultilinearExtensionBase, + FieldUniformSampler, }; use pcs::{ multilinear::{brakedown::BrakedownPCS, BrakedownOpenProof}, @@ -22,7 +23,7 @@ fn pcs_test() { .take(1 << num_vars) .collect(); - let poly = DenseMultilinearExtension::from_evaluations_vec(num_vars, evaluations); + let poly = DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, evaluations); // let code_spec = ExpanderCodeSpec::new(128, 0.1195, 0.0284, 1.9, 60, 10); let code_spec = ExpanderCodeSpec::new(0.1195, 0.0284, 1.9, BASE_FIELD_BITS, 10); diff --git a/zkp/Cargo.toml b/zkp/Cargo.toml index b2e6491f..98a192d0 100644 --- a/zkp/Cargo.toml +++ b/zkp/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] algebra = { path = "../algebra" } +fhe_core = { path = "../fhe_core" } rand = { workspace = true } thiserror = { workspace = true } diff --git a/zkp/src/piop/accumulator.rs b/zkp/src/piop/accumulator.rs index 376f63e9..d113128e 100644 --- a/zkp/src/piop/accumulator.rs +++ b/zkp/src/piop/accumulator.rs @@ -8,56 +8,61 @@ use crate::utils::eval_identity_function; use std::marker::PhantomData; use std::rc::Rc; +use algebra::utils::Transcript; +use algebra::AbstractExtensionField; +use algebra::DenseMultilinearExtensionBase; use algebra::{ DenseMultilinearExtension, Field, ListOfProductsOfPolynomials, MultilinearExtension, PolynomialInfo, }; use itertools::izip; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; +use serde::Serialize; use super::bit_decomposition::{BitDecomposition, BitDecompositionProof, BitDecompositionSubClaim}; +use super::ntt::NTTInstanceExt; use super::ntt::{NTTProof, NTTSubclaim}; -use super::{DecomposedBits, DecomposedBitsInfo, NTTInstance, NTTInstanceInfo, NTTIOP}; +use super::rlwe_mul_rgsw::RlweCiphertextExt; +use super::rlwe_mul_rgsw::RlweCiphertextsExt; +use super::{DecomposedBits, DecomposedBitsInfo, NTTInstanceInfo, NTTIOP}; use super::{RlweCiphertext, RlweCiphertexts}; /// SNARKs for Mutliplication between RLWE ciphertext and RGSW ciphertext -pub struct AccumulatorIOP(PhantomData); +pub struct AccumulatorIOP>(PhantomData, PhantomData); /// proof generated by prover -pub struct AccumulatorProof { +pub struct AccumulatorProof> { /// proof for bit decompostion - pub bit_decomposition_proof: BitDecompositionProof, + pub bit_decomposition_proof: BitDecompositionProof, /// proof for ntt - pub ntt_proof: NTTProof, + pub ntt_proof: NTTProof, /// proof for sumcheck - pub sumcheck_msg: Proof, + pub sumcheck_msg: Proof, } /// subclaim reutrned to verifier -pub struct AccumulatorSubclaim { +pub struct AccumulatorSubclaim> { /// subclaim returned from the Bit Decomposition IOP - pub bit_decomposition_subclaim: BitDecompositionSubClaim, + pub bit_decomposition_subclaim: BitDecompositionSubClaim, /// subclaim returned from the NTT IOP - pub ntt_subclaim: NTTSubclaim, + pub ntt_subclaim: NTTSubclaim, /// subclaim returned from the sumcheck protocol - pub sumcheck_subclaim: SubClaim, + pub sumcheck_subclaim: SubClaim, } /// accumulator witness when performing ACC = ACC + (X^{-a_u} + 1) * ACC * RGSW(Z_u) pub struct AccumulatorWitness { /// * Witness when performing input_rlwe_ntt := (X^{-a_u} + 1) * ACC - /// accumulator of ntt form + /// accumulator of ntt form pub accumulator_ntt: RlweCiphertext, /// scalar d = (X^{-a_u} + 1) of coefficient form - pub d: Rc>, + pub d: Rc>, /// scalar d = (X^{-a_u} + 1) of ntt form - pub d_ntt: Rc>, + pub d_ntt: Rc>, /// result d * ACC of ntt form pub input_rlwe_ntt: RlweCiphertext, /// * Witness when performing output_rlwe_ntt := input_rlwe * RGSW(Z_u) where input_rlwe = (X^{-a_u} + 1) * ACC - /// result d * ACC of coefficient form - /// rlwe = (a, b): store the input ciphertext (a, b) where a and b are two polynomials represented by N coefficients. + /// result d * ACC of coefficient form + /// rlwe = (a, b): store the input ciphertext (a, b) where a and b are two polynomials represented by N coefficients. pub input_rlwe: RlweCiphertext, /// bits_rlwe = (a_bits, b_bits): a_bits (b_bits) corresponds to the bit decomposition result of a (b) in the input rlwe ciphertext pub bits_rlwe: RlweCiphertexts, @@ -71,21 +76,46 @@ pub struct AccumulatorWitness { pub output_rlwe_ntt: RlweCiphertext, } +/// Store the corresponding MLE of AccumulatorWitness where the evaluations are over the extension field. +pub struct AccumulatorWitnessExt> { + /// bits_rlwe_ntt: ntt form of the above bit decomposition result + pub bits_rlwe_ntt: RlweCiphertextsExt, + /// bits_rgsw_c_ntt: the ntt form of the first part (c) in the RGSW ciphertext + pub bits_rgsw_c_ntt: RlweCiphertextsExt, + /// bits_rgsw_c_ntt: the ntt form of the second part (f) in the RGSW ciphertext + pub bits_rgsw_f_ntt: RlweCiphertextsExt, + /// output_rlwe_ntt: store the output ciphertext (g', h') in the NTT-form + pub output_rlwe_ntt: RlweCiphertextExt, +} + +impl> AccumulatorWitnessExt { + /// Construct an instance over the extension field from the original instance defined over the basic field + pub fn from_base(input_base: &AccumulatorWitness) -> Self { + Self { + bits_rlwe_ntt: >::from_base(&input_base.bits_rlwe_ntt), + bits_rgsw_c_ntt: >::from_base(&input_base.bits_rgsw_c_ntt), + bits_rgsw_f_ntt: >::from_base(&input_base.bits_rgsw_f_ntt), + output_rlwe_ntt: >::from_base(&input_base.output_rlwe_ntt), + } + } +} + /// Store the ntt instance, bit decomposition instance, and the sumcheck instance for an Accumulator upating t times -pub struct AccumulatorInstance { +pub struct AccumulatorInstance> { /// number of updations in Accumulator denoted by t pub num_updations: usize, /// number of ntt transformation in Accumulator pub num_ntt: usize, /// the (virtually) randomized ntt instance to be proved - pub ntt_instance: NTTInstance, + pub ntt_instance: NTTInstanceExt, /// all decomposed bits pub decomposed_bits: DecomposedBits, /// poly in the sumcheck instance - pub poly: ListOfProductsOfPolynomials, + pub poly: ListOfProductsOfPolynomials, } /// Store the Accumulator info used to verify +#[derive(Serialize)] pub struct AccumulatorInstanceInfo { /// number of updations in Accumulator denoted by t pub num_updations: usize, @@ -97,7 +127,7 @@ pub struct AccumulatorInstanceInfo { pub poly_info: PolynomialInfo, } -impl AccumulatorInstance { +impl> AccumulatorInstance { /// construct an accumulator instance based on ntt info and bit-decomposition info #[inline] pub fn new( @@ -108,9 +138,9 @@ impl AccumulatorInstance { Self { num_updations: 0, num_ntt: 0, - ntt_instance: >::from_info(ntt_info), + ntt_instance: >::from_info(ntt_info), decomposed_bits: >::from_info(decom_info), - poly: >::new(num_vars), + poly: >::new(num_vars), } } @@ -134,9 +164,9 @@ impl AccumulatorInstance { /// * witness: all intermediate witness when updating the accumulator once pub fn add_witness( &mut self, - randomness_ntt: &[F], - randomness_sumcheck: &[F], - identity_func_at_u: &Rc>, + randomness_ntt: &[EF], + randomness_sumcheck: &[EF], + identity_func_at_u: &Rc>, witness: &AccumulatorWitness, ) { self.num_updations += 1; @@ -187,6 +217,9 @@ impl AccumulatorInstance { .add_ntt(*r.next().unwrap(), coeffs, points); } + // Convert the original instance to a new instance over the extension field + let witness = >::from_base(witness); + // Integrate the Sumcheck Part let r_1 = randomness_sumcheck[0]; let r_2 = randomness_sumcheck[1]; @@ -235,28 +268,17 @@ impl AccumulatorInstance { } } -impl AccumulatorIOP { - /// prove the accumulator updation - pub fn prove(instance: &AccumulatorInstance, u: &[F]) -> AccumulatorProof { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, instance, u) - } - +impl> AccumulatorIOP { /// prove the accumulator updation - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, - instance: &AccumulatorInstance, - u: &[F], - ) -> AccumulatorProof { + pub fn prove( + trans: &mut Transcript, + instance: &AccumulatorInstance, + u: &[EF], + ) -> AccumulatorProof { AccumulatorProof { - bit_decomposition_proof: BitDecomposition::prove_as_subprotocol( - fs_rng, - &instance.decomposed_bits, - u, - ), - ntt_proof: NTTIOP::prove_as_subprotocol(fs_rng, &instance.ntt_instance, u), - sumcheck_msg: MLSumcheck::prove_as_subprotocol(fs_rng, &instance.poly) + bit_decomposition_proof: BitDecomposition::prove(trans, &instance.decomposed_bits, u), + ntt_proof: NTTIOP::prove(trans, &instance.ntt_instance, u), + sumcheck_msg: MLSumcheck::prove(trans, &instance.poly) .expect("sumcheck fail in accumulator updation") .0, } @@ -264,38 +286,22 @@ impl AccumulatorIOP { /// verify the proof pub fn verify( - proof: &AccumulatorProof, - u: &[F], + trans: &mut Transcript, + proof: &AccumulatorProof, + u: &[EF], info: &AccumulatorInstanceInfo, - ) -> AccumulatorSubclaim { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verify_as_subprotocol(&mut fs_rng, proof, u, info) - } - - /// verify the proof with provided RNG - pub fn verify_as_subprotocol( - fs_rng: &mut impl RngCore, - proof: &AccumulatorProof, - u: &[F], - info: &AccumulatorInstanceInfo, - ) -> AccumulatorSubclaim { + ) -> AccumulatorSubclaim { AccumulatorSubclaim { - bit_decomposition_subclaim: BitDecomposition::verifier_as_subprotocol( - fs_rng, + bit_decomposition_subclaim: BitDecomposition::verify( + trans, &proof.bit_decomposition_proof, &info.decomposed_bits_info, ), - ntt_subclaim: NTTIOP::verify_as_subprotocol( - fs_rng, - &proof.ntt_proof, - &info.ntt_info, - u, - ), - sumcheck_subclaim: MLSumcheck::verify_as_subprotocol( - fs_rng, + ntt_subclaim: NTTIOP::verify(trans, &proof.ntt_proof, &info.ntt_info, u), + sumcheck_subclaim: MLSumcheck::verify( + trans, &info.poly_info, - F::zero(), + EF::zero(), &proof.sumcheck_msg, ) .expect("sumcheck protocol in rlwe mult rgsw failed"), @@ -303,7 +309,7 @@ impl AccumulatorIOP { } } -impl AccumulatorSubclaim { +impl> AccumulatorSubclaim { /// verify the subclaim /// /// # Arguments @@ -317,11 +323,11 @@ impl AccumulatorSubclaim { #[allow(clippy::too_many_arguments)] pub fn verify_subclaim( &self, - u: &[F], - randomness_ntt: &[F], - randomness_sumcheck: &[F], - ntt_coeffs: &DenseMultilinearExtension, - ntt_points: &DenseMultilinearExtension, + u: &[EF], + randomness_ntt: &[EF], + randomness_sumcheck: &[EF], + ntt_coeffs: &DenseMultilinearExtension, + ntt_points: &DenseMultilinearExtension, witnesses: &Vec>, info: &AccumulatorInstanceInfo, ) -> bool { @@ -332,33 +338,33 @@ impl AccumulatorSubclaim { assert_eq!(randomness_sumcheck.len(), 2 * info.num_updations); // check 1: check the consistency of the randomized ntt instance and the original ntt instances - let mut coeffs_eval = F::zero(); - let mut points_eval = F::zero(); + let mut coeffs_eval = EF::zero(); + let mut points_eval = EF::zero(); let mut r_iter = randomness_ntt.iter(); for witness in witnesses { let r = r_iter.next().unwrap(); - coeffs_eval += *r * witness.d.evaluate(u); - points_eval += *r * witness.d_ntt.evaluate(u); + coeffs_eval += *r * witness.d.evaluate_ext(u); + points_eval += *r * witness.d_ntt.evaluate_ext(u); let r = r_iter.next().unwrap(); - coeffs_eval += *r * witness.input_rlwe.a.evaluate(u); - points_eval += *r * witness.input_rlwe_ntt.a.evaluate(u); + coeffs_eval += *r * witness.input_rlwe.a.evaluate_ext(u); + points_eval += *r * witness.input_rlwe_ntt.a.evaluate_ext(u); let r = r_iter.next().unwrap(); - coeffs_eval += *r * witness.input_rlwe.b.evaluate(u); - points_eval += *r * witness.input_rlwe_ntt.b.evaluate(u); + coeffs_eval += *r * witness.input_rlwe.b.evaluate_ext(u); + points_eval += *r * witness.input_rlwe_ntt.b.evaluate_ext(u); for (coeffs, points) in izip!(&witness.bits_rlwe.a_bits, &witness.bits_rlwe_ntt.a_bits) { let r = r_iter.next().unwrap(); - coeffs_eval += *r * coeffs.evaluate(u); - points_eval += *r * points.evaluate(u); + coeffs_eval += *r * coeffs.evaluate_ext(u); + points_eval += *r * points.evaluate_ext(u); } for (coeffs, points) in izip!(&witness.bits_rlwe.b_bits, &witness.bits_rlwe_ntt.b_bits) { let r = r_iter.next().unwrap(); - coeffs_eval += *r * coeffs.evaluate(u); - points_eval += *r * points.evaluate(u); + coeffs_eval += *r * coeffs.evaluate_ext(u); + points_eval += *r * points.evaluate_ext(u); } } if coeffs_eval != ntt_coeffs.evaluate(u) || points_eval != ntt_points.evaluate(u) { @@ -395,10 +401,10 @@ impl AccumulatorSubclaim { let mut r = randomness_sumcheck.iter(); // 4. check 4: check the subclaim returned from the sumcheck protocol consisting of two sub-sumcheck protocols - let mut sum_eval = F::zero(); + let mut sum_eval = EF::zero(); for witness in witnesses { - let mut sum1_eval = F::zero(); - let mut sum2_eval = F::zero(); + let mut sum1_eval = EF::zero(); + let mut sum2_eval = EF::zero(); // The first part is to evaluate at a random point g' = \sum_{i = 0}^{k-1} a_i' \cdot c_i + b_i' \cdot f_i // It is the reduction claim of prover asserting the sum \sum_{x} eq(u, x) (\sum_{i = 0}^{k-1} a_i'(x) \cdot c_i(x) + b_i'(x) \cdot f_i(x) - g'(x)) = 0 // where u is randomly sampled by the verifier. @@ -408,10 +414,10 @@ impl AccumulatorSubclaim { &witness.bits_rgsw_c_ntt.a_bits, &witness.bits_rgsw_f_ntt.a_bits ) { - sum1_eval += (a.evaluate(&self.sumcheck_subclaim.point) - * c.evaluate(&self.sumcheck_subclaim.point)) - + (b.evaluate(&self.sumcheck_subclaim.point) - * f.evaluate(&self.sumcheck_subclaim.point)); + sum1_eval += (a.evaluate_ext(&self.sumcheck_subclaim.point) + * c.evaluate_ext(&self.sumcheck_subclaim.point)) + + (b.evaluate_ext(&self.sumcheck_subclaim.point) + * f.evaluate_ext(&self.sumcheck_subclaim.point)); } // The second part is to evaluate at a random point h' = \sum_{i = 0}^{k-1} a_i' \cdot c_i' + b_i' \cdot f_i' @@ -423,10 +429,10 @@ impl AccumulatorSubclaim { &witness.bits_rgsw_c_ntt.b_bits, &witness.bits_rgsw_f_ntt.b_bits ) { - sum2_eval += (a.evaluate(&self.sumcheck_subclaim.point) - * c.evaluate(&self.sumcheck_subclaim.point)) - + (b.evaluate(&self.sumcheck_subclaim.point) - * f.evaluate(&self.sumcheck_subclaim.point)); + sum2_eval += (a.evaluate_ext(&self.sumcheck_subclaim.point) + * c.evaluate_ext(&self.sumcheck_subclaim.point)) + + (b.evaluate_ext(&self.sumcheck_subclaim.point) + * f.evaluate_ext(&self.sumcheck_subclaim.point)); } let r_1 = r.next().unwrap(); @@ -437,13 +443,13 @@ impl AccumulatorSubclaim { - witness .output_rlwe_ntt .a - .evaluate(&self.sumcheck_subclaim.point)) + .evaluate_ext(&self.sumcheck_subclaim.point)) + *r_2 * (sum2_eval - witness .output_rlwe_ntt .b - .evaluate(&self.sumcheck_subclaim.point))) + .evaluate_ext(&self.sumcheck_subclaim.point))) } sum_eval == self.sumcheck_subclaim.expected_evaluations } diff --git a/zkp/src/piop/addition_in_zq.rs b/zkp/src/piop/addition_in_zq.rs index 80107796..db50c107 100644 --- a/zkp/src/piop/addition_in_zq.rs +++ b/zkp/src/piop/addition_in_zq.rs @@ -24,31 +24,34 @@ use crate::utils::eval_identity_function; use crate::sumcheck::MLSumcheck; use crate::utils::gen_identity_evaluations; use algebra::{ - DecomposableField, DenseMultilinearExtension, Field, ListOfProductsOfPolynomials, - MultilinearExtension, PolynomialInfo, + utils::Transcript, AbstractExtensionField, DecomposableField, DenseMultilinearExtension, + DenseMultilinearExtensionBase, ExtensionField, Field, ListOfProductsOfPolynomials, Packable, + PolynomialInfo, UF, }; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; +use serde::Serialize; /// SNARKs for addition in Zq, i.e. a + b = c (mod q) -pub struct AdditionInZq(PhantomData); +pub struct AdditionInZq> { + _marker: PhantomData, + _stone: PhantomData, +} /// proof generated by prover -pub struct AdditionInZqProof { +pub struct AdditionInZqProof> { /// batched rangecheck proof for a, b, c \in Zq - pub rangecheck_msg: BitDecompositionProof, + pub rangecheck_msg: BitDecompositionProof, /// sumcheck proof for \sum_{x} eq(u, x) * k(x) * (1-k(x)) = 0, i.e. k(x)\in\{0,1\}^l - pub sumcheck_msg: Vec>, + pub sumcheck_msg: Vec>, } /// subclaim returned to verifier -pub struct AdditionInZqSubclaim { +pub struct AdditionInZqSubclaim> { /// rangecheck subclaim for a, b, c \in Zq - pub(crate) rangecheck_subclaim: BitDecompositionSubClaim, + pub(crate) rangecheck_subclaim: BitDecompositionSubClaim, /// subcliam for \sum_{x} eq(u, x) * k(x) * (1-k(x)) = 0 - pub sumcheck_point: Vec, + pub sumcheck_point: Vec, /// expected value returned in the last round of the sumcheck - pub sumcheck_expected_evaluations: F, + pub sumcheck_expected_evaluations: EF, } /// Stores the parameters used for addition in Zq and the inputs and witness for prover. @@ -58,14 +61,32 @@ pub struct AdditionInZqInstance { /// number of variables pub num_vars: usize, /// inputs a, b, and c - pub abc: Vec>>, + pub abc: Vec>>, /// introduced witness k - pub k: Rc>, + pub k: Rc>, /// introduced witness to check the range of a, b, c pub abc_bits: DecomposedBits, } +/// Store the corresponding MLE of AdditionInZqInstance where the evaluations are over the extension field. +pub struct AdditionInZqInstanceExt> { + /// introduced witness k + pub k: Rc>, +} + +impl> AdditionInZqInstanceExt { + /// Construct an instance over the extension field from the original instance defined over the basic field + pub fn from_base(input_base: &AdditionInZqInstance) -> Self { + Self { + k: Rc::new(>::from_base( + input_base.k.as_ref(), + )), + } + } +} + /// Stores the parameters used for addition in Zq and the public info for verifier. +#[derive(Serialize)] pub struct AdditionInZqInstanceInfo { /// modulus in addition pub q: F, @@ -88,8 +109,8 @@ impl AdditionInZqInstance { /// Construct a new instance from vector #[inline] pub fn from_vec( - abc: Vec>>, - k: &Rc>, + abc: Vec>>, + k: &Rc>, q: F, base: F, base_len: u32, @@ -123,8 +144,8 @@ impl AdditionInZqInstance { /// Construct a new instance from slice #[inline] pub fn from_slice( - abc: &[Rc>], - k: &Rc>, + abc: &[Rc>], + k: &Rc>, q: F, base: F, base_len: u32, @@ -156,7 +177,7 @@ impl AdditionInZqInstance { } } -impl AdditionInZqSubclaim { +impl> AdditionInZqSubclaim { /// verify the sumcliam /// * abc stores the inputs and the output to be added in Zq /// * k stores the introduced witness s.t. a + b = c + k\cdot q @@ -166,10 +187,10 @@ impl AdditionInZqSubclaim { pub fn verify_subclaim( &self, q: F, - abc: &[Rc>], - k: &DenseMultilinearExtension, - abc_bits: &[&Vec>>], - u: &[F], + abc: &[Rc>], + k: &DenseMultilinearExtensionBase, + abc_bits: &[&Vec>>], + u: &[EF], info: &AdditionInZqInstanceInfo, ) -> bool { assert_eq!(abc.len(), 3); @@ -184,56 +205,54 @@ impl AdditionInZqSubclaim { } // check 2: subclaim for sumcheck, i.e. eq(u, point) * k(point) * (1 - k(point)) = 0 - let eval_k = k.evaluate(&self.sumcheck_point); - if eval_identity_function(u, &self.sumcheck_point) * eval_k * (F::one() - eval_k) + let eval_k = k.evaluate_ext(&self.sumcheck_point); + if eval_identity_function(u, &self.sumcheck_point) * eval_k * (EF::one() - eval_k) != self.sumcheck_expected_evaluations { return false; } // check 3: a(u) + b(u) = c(u) + k(u) * q - abc[0].evaluate(u) + abc[1].evaluate(u) == abc[2].evaluate(u) + k.evaluate(u) * q + abc[0].evaluate_ext(u) + abc[1].evaluate_ext(u) + == abc[2].evaluate_ext(u) + k.evaluate_ext(u) * q } } -impl AdditionInZq { - /// Prove addition in Zq given a, b, c, k, and the decomposed bits for a, b, and c. - pub fn prove(addition_instance: &AdditionInZqInstance, u: &[F]) -> AdditionInZqProof { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, addition_instance, u) - } - +impl> AdditionInZq { /// Prove addition in Zq given a, b, c, k, and the decomposed bits for a, b, and c. /// This function does the same thing as `prove`, but it uses a `Fiat-Shamir RNG` as the transcript/to generate the /// verifier challenges. - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, + pub fn prove( + trans: &mut Transcript, addition_instance: &AdditionInZqInstance, - u: &[F], - ) -> AdditionInZqProof { + u: &[EF], + ) -> AdditionInZqProof { // 1. rangecheck - let rangecheck_msg = - BitDecomposition::prove_as_subprotocol(fs_rng, &addition_instance.abc_bits, u); + let rangecheck_msg = BitDecomposition::prove(trans, &addition_instance.abc_bits, u); let dim = u.len(); assert_eq!(dim, addition_instance.num_vars); - let mut poly = >::new(dim); + let mut poly = >::new(dim); - // 2. execute sumcheck for \sum_{x} eq(u, x) * k(x) * (1-k(x)) = 0, i.e. k(x)\in\{0,1\}^l - let mut product = Vec::with_capacity(3); - let mut op_coefficient = Vec::with_capacity(3); - product.push(Rc::new(gen_identity_evaluations(u))); - op_coefficient.push((F::one(), F::zero())); - - product.push(Rc::clone(&addition_instance.k)); - op_coefficient.push((F::one(), F::zero())); - product.push(Rc::clone(&addition_instance.k)); - op_coefficient.push((-F::one(), F::one())); + // Convert the MLE over Field in the original instance to a new instance containing the corresponding MLE over Extension Field + let instance_ext = >::from_base(addition_instance); - poly.add_product_with_linear_op(product, &op_coefficient, F::one()); - let sumcheck_proof = MLSumcheck::prove_as_subprotocol(fs_rng, &poly) - .expect("sumcheck for addition in Zq failed"); + // 2. execute sumcheck for \sum_{x} eq(u, x) * k(x) * (1-k(x)) = 0, i.e. k(x)\in\{0,1\}^l + poly.add_product_with_linear_op( + [ + Rc::new(gen_identity_evaluations(u)), + Rc::clone(&instance_ext.k), + Rc::clone(&instance_ext.k), + ], + &[ + (>::one(), >::zero()), + (>::one(), >::zero()), + (UF::BaseField(-F::one()), >::one()), + ], + EF::one(), + ); + let sumcheck_proof = + MLSumcheck::prove(trans, &poly).expect("sumcheck for addition in Zq failed"); AdditionInZqProof { rangecheck_msg, @@ -243,28 +262,12 @@ impl AdditionInZq { /// Verify addition in Zq given the proof and the verification key for bit decomposistion pub fn verify( - proof: &AdditionInZqProof, + trans: &mut Transcript, + proof: &AdditionInZqProof, decomposed_bits_info: &DecomposedBitsInfo, - ) -> AdditionInZqSubclaim { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verifier_as_subprotocol(&mut fs_rng, proof, decomposed_bits_info) - } - - /// Verify addition in Zq given the proof and the verification key for bit decomposistion - /// This function does the same thing as `prove`, but it uses a `Fiat-Shamir RNG` as the transcript/to generate the - /// verifier challenges. - pub fn verifier_as_subprotocol( - fs_rng: &mut impl RngCore, - proof: &AdditionInZqProof, - decomposed_bits_info: &DecomposedBitsInfo, - ) -> AdditionInZqSubclaim { - // TODO sample randomness via Fiat-Shamir RNG - let rangecheck_subclaim = BitDecomposition::verifier_as_subprotocol( - fs_rng, - &proof.rangecheck_msg, - decomposed_bits_info, - ); + ) -> AdditionInZqSubclaim { + let rangecheck_subclaim = + BitDecomposition::verify(trans, &proof.rangecheck_msg, decomposed_bits_info); // execute sumcheck for \sum_{x} eq(u, x) * k(x) * (1-k(x)) = 0, i.e. k(x)\in\{0,1\}^l let poly_info = PolynomialInfo { @@ -272,9 +275,8 @@ impl AdditionInZq { num_variables: decomposed_bits_info.num_vars, }; - let subclaim = - MLSumcheck::verify_as_subprotocol(fs_rng, &poly_info, F::zero(), &proof.sumcheck_msg) - .expect("sumcheck protocol in addition in Zq failed"); + let subclaim = MLSumcheck::verify(trans, &poly_info, EF::zero(), &proof.sumcheck_msg) + .expect("sumcheck protocol in addition in Zq failed"); AdditionInZqSubclaim { rangecheck_subclaim, sumcheck_point: subclaim.point, diff --git a/zkp/src/piop/bit_decomposition.rs b/zkp/src/piop/bit_decomposition.rs index 439c3239..2f1012b7 100644 --- a/zkp/src/piop/bit_decomposition.rs +++ b/zkp/src/piop/bit_decomposition.rs @@ -19,7 +19,12 @@ //! then the resulting purported sum is: //! $\sum_{x \in \{0, 1\}^\log M} \sum_{i = 0}^{l-1} r_i \cdot eq(u, x) \cdot [\prod_{k=0}^B (d_i(x) - k)] = 0$ //! where r_i (for i = 0..l) are sampled from the verifier. -use algebra::{DecomposableField, DenseMultilinearExtension, Field, MultilinearExtension}; +use algebra::utils::Transcript; +use algebra::{ + AbstractExtensionField, DecomposableField, DenseMultilinearExtension, + DenseMultilinearExtensionBase, Field, UF, +}; +use serde::Serialize; use std::marker::PhantomData; use std::rc::Rc; @@ -27,27 +32,29 @@ use crate::sumcheck::prover::ProverMsg; use crate::sumcheck::MLSumcheck; use crate::utils::{eval_identity_function, gen_identity_evaluations}; -use algebra::{FieldUniformSampler, ListOfProductsOfPolynomials, PolynomialInfo}; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; -use rand_distr::Distribution; +use algebra::{ListOfProductsOfPolynomials, PolynomialInfo}; /// SNARKs for bit decomposition -pub struct BitDecomposition(PhantomData); +pub struct BitDecomposition> { + _marker: PhantomData, + _stone: PhantomData, +} /// proof generated by prover -pub struct BitDecompositionProof { - pub(crate) sumcheck_msg: Vec>, +pub struct BitDecompositionProof> { + pub(crate) sumcheck_msg: Vec>, } /// subclaim returned to verifier -pub struct BitDecompositionSubClaim { +pub struct BitDecompositionSubClaim> { /// r - pub randomness: Vec, + pub randomness: Vec, /// reduced point from the sumcheck protocol - pub point: Vec, + pub point: Vec, /// expected value returned in sumcheck - pub expected_evaluation: F, + pub expected_evaluation: EF, + /// marker for F + _marker: PhantomData, } /// Stores the parameters used for bit decomposation and every instance of decomposed bits, @@ -65,7 +72,74 @@ pub struct DecomposedBits { /// number of variables of every polynomial pub num_vars: usize, /// batched plain deomposed bits, each of which corresponds to one bit decomposisiton instance - pub instances: Vec>>>, + pub instances: Vec>>>, +} + +/// Store the corresponding MLE of DecomposedBits where the evaluations are over the extension field. +pub struct DecomposedBitsExt> { + /// batched plain deomposed bits, each of which corresponds to one bit decomposisiton instance + pub instances: Vec>>>, +} + +impl> DecomposedBitsExt { + /// Construct an instance over the extension field from the original instance defined over the basic field + pub fn from_base(input_base: &DecomposedBits) -> Self { + DecomposedBitsExt { + instances: input_base + .instances + .iter() + .map(|instance| { + instance + .iter() + .map(|bit| Rc::new(DenseMultilinearExtension::from_base(bit.as_ref()))) + .collect::>() + }) + .collect::>(), + } + } + + #[inline] + /// Batch all the sumcheck protocol, each corresponding to range-check one single bit. + /// * randomness: randomness used to linearly combine bits_len * num_instances sumcheck protocols + /// * u is the common random challenge from the verifier, used to instantiate every sum. + pub fn randomized_sumcheck( + &self, + instance_base: &DecomposedBits, + randomness: &[EF], + u: &[EF], + ) -> ListOfProductsOfPolynomials { + assert_eq!( + randomness.len(), + self.instances.len() * instance_base.bits_len as usize + ); + assert_eq!(u.len(), instance_base.num_vars); + + let mut poly = >::new(instance_base.num_vars); + let identity_func_at_u = Rc::new(gen_identity_evaluations(u)); + let base = 1 << instance_base.base_len; + + let mut r_iter = randomness.iter(); + + for instance in &self.instances { + // For every bit, the reduced sum is $\sum_{x \in \{0, 1\}^\log M} eq(u, x) \cdot [\prod_{k=0}^B (d_i(x) - k)] = 0$ + // and the added product is r_i \cdot eq(u, x) \cdot [\prod_{k=0}^B (d_i(x) - k)] with the corresponding randomness + for bit in instance { + let mut product: Vec<_> = Vec::with_capacity(base + 1); + let mut op_coefficient = Vec::with_capacity(base + 1); + product.push(Rc::clone(&identity_func_at_u)); + op_coefficient.push((UF::one(), UF::zero())); + + let mut minus_k = F::zero(); + for _ in 0..base { + product.push(Rc::clone(bit)); + op_coefficient.push((UF::one(), UF::BaseField(minus_k))); + minus_k -= F::one(); + } + poly.add_product_with_linear_op(product, &op_coefficient, *r_iter.next().unwrap()); + } + } + poly + } } /// Stores the parameters used for bit decomposation. @@ -73,7 +147,7 @@ pub struct DecomposedBits { /// * It is required to decompose over a power-of-2 base. /// /// These parameters are used as the verifier key. -#[derive(Clone)] +#[derive(Clone, Serialize)] pub struct DecomposedBitsInfo { /// base pub base: F, @@ -129,7 +203,7 @@ impl DecomposedBits { /// * decomposed_bits: store each bit pub fn add_decomposed_bits_instance( &mut self, - decomposed_bits: &[Rc>], + decomposed_bits: &[Rc>], ) { assert_eq!(decomposed_bits.len(), self.bits_len as usize); for bit in decomposed_bits { @@ -137,43 +211,6 @@ impl DecomposedBits { } self.instances.push(decomposed_bits.to_vec()); } - - #[inline] - /// Batch all the sumcheck protocol, each corresponding to range-check one single bit. - /// * randomness: randomness used to linearly combine bits_len * num_instances sumcheck protocols - /// * u is the common random challenge from the verifier, used to instantiate every sum. - pub fn randomized_sumcheck(&self, randomness: &[F], u: &[F]) -> ListOfProductsOfPolynomials { - assert_eq!( - randomness.len(), - self.instances.len() * self.bits_len as usize - ); - assert_eq!(u.len(), self.num_vars); - - let mut poly = >::new(self.num_vars); - let identity_func_at_u = Rc::new(gen_identity_evaluations(u)); - let base = 1 << self.base_len; - - let mut r_iter = randomness.iter(); - for instance in &self.instances { - // For every bit, the reduced sum is $\sum_{x \in \{0, 1\}^\log M} eq(u, x) \cdot [\prod_{k=0}^B (d_i(x) - k)] = 0$ - // and the added product is r_i \cdot eq(u, x) \cdot [\prod_{k=0}^B (d_i(x) - k)] with the corresponding randomness - for bit in instance { - let mut product: Vec<_> = Vec::with_capacity(base + 1); - let mut op_coefficient: Vec<_> = Vec::with_capacity(base + 1); - product.push(Rc::clone(&identity_func_at_u)); - op_coefficient.push((F::one(), F::zero())); - - let mut minus_k = F::zero(); - for _ in 0..base { - product.push(Rc::clone(bit)); - op_coefficient.push((F::one(), minus_k)); - minus_k -= F::one(); - } - poly.add_product_with_linear_op(product, &op_coefficient, *r_iter.next().unwrap()); - } - } - poly - } } impl DecomposedBits { @@ -181,14 +218,14 @@ impl DecomposedBits { /// Then add the result into this instance, meaning to add l sumcheck protocols. /// * decomposed_bits: store each bit #[inline] - pub fn add_value_instance(&mut self, value: &DenseMultilinearExtension) { + pub fn add_value_instance(&mut self, value: &DenseMultilinearExtensionBase) { assert_eq!(self.num_vars, value.num_vars); self.instances .push(value.get_decomposed_mles(self.base_len, self.bits_len)); } } -impl BitDecompositionSubClaim { +impl> BitDecompositionSubClaim { /// verify the subclaim /// /// # Argument @@ -198,19 +235,26 @@ impl BitDecompositionSubClaim { /// * `u` is the common random challenge from the verifier, used to instantiate every sum. pub fn verify_subclaim( &self, - d_val: &[Rc>], - d_bits: &[&Vec>>], - u: &[F], + d_val: &[Rc>], + d_bits: &[&Vec>>], + u: &[EF], decomposed_bits_info: &DecomposedBitsInfo, ) -> bool { assert_eq!(d_val.len(), decomposed_bits_info.num_instances); assert_eq!(d_bits.len(), decomposed_bits_info.num_instances); assert_eq!(u.len(), decomposed_bits_info.num_vars); - let d_val_at_point: Vec<_> = d_val.iter().map(|val| val.evaluate(&self.point)).collect(); + let d_val_at_point: Vec<_> = d_val + .iter() + .map(|val| val.evaluate_ext(&self.point)) + .collect(); let d_bits_at_point: Vec> = d_bits .iter() - .map(|bits| bits.iter().map(|bit| bit.evaluate(&self.point)).collect()) + .map(|bits| { + bits.iter() + .map(|bit| bit.evaluate_ext(&self.point)) + .collect() + }) .collect(); // base_pow = [1, B, ..., B^{l-1}] @@ -228,7 +272,7 @@ impl BitDecompositionSubClaim { *val == bits .iter() .zip(base_pow.iter()) - .fold(F::zero(), |acc, (bit, pow)| acc + *pow * *bit) + .fold(EF::zero(), |acc, (bit, pow)| acc + *bit * *pow) }) { return false; @@ -236,7 +280,7 @@ impl BitDecompositionSubClaim { // check 2: expected value returned in sumcheck // each instance contributes value: eq(u, x) \cdot \sum_{i = 0}^{l-1} r_i \cdot [\prod_{k=0}^B (d_i(x) - k)] =? expected_evaluation - let mut evaluation = F::zero(); + let mut evaluation = EF::zero(); let mut r = self.randomness.iter(); d_bits_at_point.iter().for_each(|bits| { bits.iter().for_each(|bit| { @@ -253,69 +297,56 @@ impl BitDecompositionSubClaim { } } -impl BitDecomposition { +impl> BitDecomposition { /// Prove bit decomposition given the decomposed bits as prover key. - pub fn prove(decomposed_bits: &DecomposedBits, u: &[F]) -> BitDecompositionProof { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, decomposed_bits, u) - } + pub fn prove( + trans: &mut Transcript, + decomposed_bits_base: &DecomposedBits, + u: &[EF], + ) -> BitDecompositionProof { + let num_bits = + decomposed_bits_base.instances.len() * decomposed_bits_base.bits_len as usize; + trans.append_message(b"decomposed bits", &decomposed_bits_base.info()); - /// Prove bit decomposition given the decomposed bits as prover key. - /// This function does the same thing as `prove`, but it uses a `Fiat-Shamir RNG` as the transcript/to generate the - /// verifier challenges. - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, - decomposed_bits: &DecomposedBits, - u: &[F], - ) -> BitDecompositionProof { - let num_bits = decomposed_bits.instances.len() * decomposed_bits.bits_len as usize; - // TODO sample randomness via Fiat-Shamir RNG // batch `len_bits` sumcheck protocols into one with random linear combination - let sampler = >::new(); - let randomness: Vec<_> = (0..num_bits).map(|_| sampler.sample(fs_rng)).collect(); - let poly = decomposed_bits.randomized_sumcheck(&randomness, u); + let randomness = trans + .get_vec_ext_field_challenge(b"randomness to combine sumcheck protocols", num_bits); + + // Convert to a new instance defined over extension field + let decomposed_bits = >::from_base(decomposed_bits_base); + let poly = decomposed_bits.randomized_sumcheck(decomposed_bits_base, &randomness, u); + + trans.append_message(b"sumcheck protocol", &poly.info()); BitDecompositionProof { - sumcheck_msg: MLSumcheck::prove_as_subprotocol(fs_rng, &poly) + sumcheck_msg: MLSumcheck::prove(trans, &poly) .expect("bit decomposition failed") .0, } } /// Verify bit decomposition given the basic information of decomposed bits as verifier key. - pub fn verifier( - proof: &BitDecompositionProof, - decomposed_bits_info: &DecomposedBitsInfo, - ) -> BitDecompositionSubClaim { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verifier_as_subprotocol(&mut fs_rng, proof, decomposed_bits_info) - } - - /// Verify bit decomposition given the basic information of decomposed bits as verifier key. - /// This function does the same thing as `prove`, but it uses a `Fiat-Shamir RNG` as the transcript/to generate the - /// verifier challenges. - pub fn verifier_as_subprotocol( - fs_rng: &mut impl RngCore, - proof: &BitDecompositionProof, + pub fn verify( + trans: &mut Transcript, + proof: &BitDecompositionProof, decomposed_bits_info: &DecomposedBitsInfo, - ) -> BitDecompositionSubClaim { + ) -> BitDecompositionSubClaim { let num_bits = decomposed_bits_info.num_instances * decomposed_bits_info.bits_len as usize; - // TODO sample randomness via Fiat-Shamir RNG + trans.append_message(b"decomposed bits", decomposed_bits_info); // batch `len_bits` sumcheck protocols into one with random linear combination - let sampler = >::new(); - let randomness: Vec<_> = (0..num_bits).map(|_| sampler.sample(fs_rng)).collect(); + let randomness = trans + .get_vec_ext_field_challenge(b"randomness to combine sumcheck protocols", num_bits); let poly_info = PolynomialInfo { max_multiplicands: 1 + (1 << decomposed_bits_info.base_len), num_variables: decomposed_bits_info.num_vars, }; - let subclaim = - MLSumcheck::verify_as_subprotocol(fs_rng, &poly_info, F::zero(), &proof.sumcheck_msg) - .expect("bit decomposition verification failed"); + trans.append_message(b"sumcheck protocol", &poly_info); + let subclaim = MLSumcheck::verify(trans, &poly_info, EF::zero(), &proof.sumcheck_msg) + .expect("bit decomposition verification failed"); BitDecompositionSubClaim { randomness, point: subclaim.point, expected_evaluation: subclaim.expected_evaluations, + _marker: PhantomData, } } } diff --git a/zkp/src/piop/ntt/mod.rs b/zkp/src/piop/ntt/mod.rs index 7e98e4e1..fe504a89 100644 --- a/zkp/src/piop/ntt/mod.rs +++ b/zkp/src/piop/ntt/mod.rs @@ -38,42 +38,48 @@ use crate::utils::{eval_identity_function, gen_identity_evaluations}; use std::marker::PhantomData; use std::rc::Rc; +use algebra::utils::Transcript; +use algebra::AbstractExtensionField; +use algebra::DenseMultilinearExtensionBase; use algebra::{ DenseMultilinearExtension, Field, ListOfProductsOfPolynomials, MultilinearExtension, - PolynomialInfo, + PolynomialInfo, UF, }; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; use ntt_bare::{NTTBareIOP, NTTBareProof, NTTBareSubclaim}; +use serde::ser::SerializeSeq; +use serde::Serialize; pub mod ntt_bare; /// SNARKs for NTT, i.e. $$a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) }$$ -pub struct NTTIOP(PhantomData); +pub struct NTTIOP> { + _marker: PhantomData, + _stone: PhantomData, +} /// proof generated by prover -pub struct NTTProof { +pub struct NTTProof> { /// bare ntt proof for proving $$a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) }$$ - pub ntt_bare_proof: NTTBareProof, + pub ntt_bare_proof: NTTBareProof, /// sumcheck proof for $$a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) }$$ /// collective sumcheck proofs for delegation - pub delegation_sumcheck_msgs: Vec>, + pub delegation_sumcheck_msgs: Vec>, /// collective claimed sums for delegation - pub delegation_claimed_sums: Vec, + pub delegation_claimed_sums: Vec, /// final claim - pub final_claim: F, + pub final_claim: EF, } /// subclaim returned to verifier -pub struct NTTSubclaim { +pub struct NTTSubclaim> { /// subclaim returned in ntt bare for proving $$a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) }$$ - pub ntt_bare_subclaim: NTTBareSubclaim, + pub ntt_bare_subclaim: NTTBareSubclaim, /// the first claim in the delegation process, i.e. F(u, v) - pub delegation_first_claim: F, + pub delegation_first_claim: EF, /// the final claim in the delegation process - pub delegation_final_claim: F, + pub delegation_final_claim: EF, /// the requested point in the final claim - pub final_point: Vec, + pub final_point: Vec, } /// Stores the NTT instance with the corresponding NTT table @@ -84,9 +90,73 @@ pub struct NTTInstance { /// stores {ω^0, ω^1, ..., ω^{2N-1}} pub ntt_table: Rc>, /// coefficient representation of the polynomial - pub coeffs: DenseMultilinearExtension, + pub coeffs: DenseMultilinearExtensionBase, + /// point-evaluation representation of the polynomial + pub points: DenseMultilinearExtensionBase, +} + +/// Stores the NTT instance with the corresponding NTT table +/// where the NTT instance is generated by random linear combination and the randomness is over EF +pub struct NTTInstanceExt> { + /// log_n is the number of the variables + /// the degree of the polynomial is N - 1 + pub log_n: usize, + /// stores {ω^0, ω^1, ..., ω^{2N-1}} + pub ntt_table: Rc>, + /// coefficient representation of the polynomial + pub coeffs: DenseMultilinearExtension, /// point-evaluation representation of the polynomial - pub points: DenseMultilinearExtension, + pub points: DenseMultilinearExtension, +} + +impl> NTTInstanceExt { + /// Extract the information of the NTT Instance for verification + #[inline] + pub fn info(&self) -> NTTInstanceInfo { + NTTInstanceInfo { + log_n: self.log_n, + ntt_table: Rc::clone(&self.ntt_table), + } + } + + /// Construct a new instance from the original instance defined over the basic field + pub fn from_base(input_base: &NTTInstance) -> Self { + Self { + log_n: input_base.log_n, + ntt_table: input_base.ntt_table.clone(), + coeffs: >::from_base(&input_base.coeffs), + points: >::from_base(&input_base.points), + } + } + + /// Constuct a new instance from given info + #[inline] + pub fn from_info(info: &NTTInstanceInfo) -> Self { + Self { + log_n: info.log_n, + ntt_table: info.ntt_table.to_owned(), + coeffs: >::from_evaluations_vec( + info.log_n, + vec![EF::zero(); 1 << info.log_n], + ), + points: >::from_evaluations_vec( + info.log_n, + vec![EF::zero(); 1 << info.log_n], + ), + } + } + + /// add ntt_instance + #[inline] + pub fn add_ntt( + &mut self, + r: EF, + coeffs: &Rc>, + points: &Rc>, + ) { + self.coeffs += (r, coeffs.as_ref()); + self.points += (r, points.as_ref()); + } } /// Stores the corresponding NTT table for the verifier @@ -100,12 +170,26 @@ pub struct NTTInstanceInfo { } /// store the intermediate mles generated in each iteration in the `init_fourier_table_overall` algorithm -pub struct IntermediateMLEs { - f_mles: Vec>>, - w_mles: Vec>>, +pub struct IntermediateMLEs> { + f_mles: Vec>>, + w_mles: Vec>>, +} + +impl Serialize for NTTInstanceInfo { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_seq(Some(self.ntt_table.len() + 1))?; + seq.serialize_element(&self.log_n)?; + for e in self.ntt_table.iter() { + seq.serialize_element(e)?; + } + seq.end() + } } -impl IntermediateMLEs { +impl> IntermediateMLEs { /// Initiate the vector pub fn new(n_rounds: u32) -> Self { IntermediateMLEs { @@ -115,7 +199,7 @@ impl IntermediateMLEs { } /// Add the intermediate mles generated in each round - pub fn add_round_mles(&mut self, num_vars: usize, f_mle: &[F], w_mle: Vec) { + pub fn add_round_mles(&mut self, num_vars: usize, f_mle: &[EF], w_mle: Vec) { self.f_mles .push(Rc::new(DenseMultilinearExtension::from_evaluations_slice( num_vars, f_mle, @@ -140,25 +224,28 @@ impl IntermediateMLEs { /// Hence, the final equation is F(u, x) = \prod_{i=0}^{\log{N-1}} ((1 - u_i) + u_i * ω^{2^{i + 1} * X}) * ω^{2^i * x_i} /// /// * In order to comprehend this implementation, it is strongly recommended to read the pure version `naive_init_fourier_table` and `init_fourier_table` in the `ntt_bare.rs`. -/// `naive_init_fourier_table` shows the original formula of this algorithm. -/// `init_fourier_table` shows the dynamic programming version of this algorithm. -/// `init_fourier_table_overall` (this function) stores many intermediate evaluations for the ease of the delegation of F(u, v) +/// `naive_init_fourier_table` shows the original formula of this algorithm. +/// `init_fourier_table` shows the dynamic programming version of this algorithm. +/// `init_fourier_table_overall` (this function) stores many intermediate evaluations for the ease of the delegation of F(u, v) /// /// # Arguments /// * u: the random point /// * ntt_table: It stores the NTT table: ω^0, ω^1, ..., ω^{2N - 1} -pub fn init_fourier_table_overall(u: &[F], ntt_table: &[F]) -> IntermediateMLEs { +pub fn init_fourier_table_overall>( + u: &[EF], + ntt_table: &[F], +) -> IntermediateMLEs { let log_n = u.len(); // N = 1 << dim let m = ntt_table.len(); // M = 2N = 2 * (1 << dim) // It store the evaluations of all F(u, x) for x \in \{0, 1\}^dim. // Note that in our implementation, we use little endian form, so the index `0b1011` // represents the point `P(1,1,0,1)` in {0,1}^`dim` - let mut evaluations: Vec<_> = vec![F::zero(); 1 << log_n]; - evaluations[0] = F::one(); + let mut evaluations: Vec<_> = vec![EF::zero(); 1 << log_n]; + evaluations[0] = EF::one(); // stores all the intermediate evaluations of the table (i.e. F(u, x)) and the term ω^{2^{i + 1} * X} in each iteration - let mut intermediate_mles = >::new(log_n as u32); + let mut intermediate_mles = >::new(log_n as u32); // * Compute \prod_{i=0}^{\log{N-1}} ((1 - u_i) + u_i * ω^{2^{i + 1} * X}) * ω^{2^i * x_i} // The reason why we update the table with u_i in reverse order is that @@ -174,7 +261,7 @@ pub fn init_fourier_table_overall(u: &[F], ntt_table: &[F]) -> Interme let this_round_table_size = 1 << this_round_dim; let last_round_table_size = 1 << last_round_dim; - let mut evaluations_w_term = vec![F::zero(); this_round_table_size]; + let mut evaluations_w_term = vec![EF::zero(); this_round_table_size]; for x in (0..this_round_table_size).rev() { // idx is to indicate the power ω^{2^{i + 1} * X} in ntt_table let idx = (1 << (i + 1)) * x % m; @@ -182,16 +269,16 @@ pub fn init_fourier_table_overall(u: &[F], ntt_table: &[F]) -> Interme // If x >= last_round_table_size, meaning the bit = 1, we need to multiply by ω^{2^last_round_dim * 1} if x >= last_round_table_size { evaluations[x] = evaluations[x % last_round_table_size] - * (F::one() - u[i] + u[i] * ntt_table[idx]) + * (EF::one() - u[i] + u[i] * ntt_table[idx]) * ntt_table[1 << last_round_dim]; } // the bit index in this iteration is last_round_dim = this_round_dim - 1 // If x < last_round_table_size, meaning the bit = 0, we do not need to multiply because ω^{2^last_round_dim * 0} = 1 else { evaluations[x] = evaluations[x % last_round_table_size] - * (F::one() - u[i] + u[i] * ntt_table[idx]); + * (EF::one() - u[i] + u[i] * ntt_table[idx]); } - evaluations_w_term[x] = ntt_table[idx]; + evaluations_w_term[x] = EF::from_base(ntt_table[idx]); } intermediate_mles.add_round_mles( this_round_dim, @@ -211,18 +298,18 @@ pub fn init_fourier_table_overall(u: &[F], ntt_table: &[F]) -> Interme /// * log_m: log of M /// * x_dim: dimension of x or the num of variables of the outputted mle /// * exp: the exponent of the function defined above -pub fn naive_w_power_times_x_table( +pub fn naive_w_power_times_x_table>( ntt_table: &[F], log_m: usize, x_dim: usize, exp: usize, -) -> DenseMultilinearExtension { +) -> DenseMultilinearExtension { let m = 1 << log_m; // M = 2N = 2 * (1 << dim) assert_eq!(ntt_table.len(), m); - let mut evaluations: Vec<_> = (0..(1 << x_dim)).map(|_| F::one()).collect(); + let mut evaluations: Vec<_> = (0..(1 << x_dim)).map(|_| EF::one()).collect(); for x in 0..(1 << x_dim) { - evaluations[x] = ntt_table[(1 << exp) * x % m]; + evaluations[x] = EF::from_base(ntt_table[(1 << exp) * x % m]); } DenseMultilinearExtension::from_evaluations_vec(x_dim, evaluations) } @@ -233,7 +320,7 @@ pub fn naive_w_power_times_x_table( /// = \prod_i (1 - r_i + r_i * w^{2^ {(exp + i) % log_m}) /// * Note that the above equation only holds for exp <= logM - x_dim; /// * otherwise, the exponent 2^exp * x involves a modular addition, disabling the decomposition. -/// (Although I am not clearly making it out, the experiement result shows the above argument.) +/// (Although I am not clearly making it out, the experiement result shows the above argument.) /// /// # Arguments: /// @@ -242,21 +329,21 @@ pub fn naive_w_power_times_x_table( /// * x_dim: dimension of x or the num of variables of the outputted mle /// * exp: the exponent of the function defined above /// * r: random point in F^{x_dim} -pub fn eval_w_power_times_x( +pub fn eval_w_power_times_x>( ntt_table: &[F], log_m: usize, x_dim: usize, exp: usize, - r: &[F], -) -> F { + r: &[EF], +) -> EF { assert_eq!(ntt_table.len(), 1 << log_m); assert_eq!(x_dim, r.len()); assert!(exp + x_dim <= log_m); - let mut prod = F::one(); + let mut prod = EF::one(); for (i, &r_i) in r.iter().enumerate() { let log_exp = (exp + i) % log_m; - prod *= F::one() - r_i + r_i * ntt_table[1 << log_exp]; + prod *= EF::one() - r_i + r_i * ntt_table[1 << log_exp]; } prod @@ -277,8 +364,8 @@ impl NTTInstance { pub fn from_vec( log_n: usize, ntt_table: &Rc>, - coeffs: DenseMultilinearExtension, - points: DenseMultilinearExtension, + coeffs: DenseMultilinearExtensionBase, + points: DenseMultilinearExtensionBase, ) -> Self { Self { log_n, @@ -293,8 +380,8 @@ impl NTTInstance { pub fn from_slice( log_n: usize, ntt_table: &Rc>, - coeffs: &Rc>, - points: &Rc>, + coeffs: &Rc>, + points: &Rc>, ) -> Self { Self { log_n, @@ -310,11 +397,11 @@ impl NTTInstance { Self { log_n: info.log_n, ntt_table: info.ntt_table.to_owned(), - coeffs: >::from_evaluations_vec( + coeffs: >::from_evaluations_vec( info.log_n, vec![F::zero(); 1 << info.log_n], ), - points: >::from_evaluations_vec( + points: >::from_evaluations_vec( info.log_n, vec![F::zero(); 1 << info.log_n], ), @@ -326,22 +413,22 @@ impl NTTInstance { pub fn add_ntt( &mut self, r: F, - coeffs: &Rc>, - points: &Rc>, + coeffs: &Rc>, + points: &Rc>, ) { self.coeffs += (r, coeffs); self.points += (r, points); } } -impl NTTSubclaim { +impl> NTTSubclaim { /// verify the subcliam #[inline] pub fn verify_subcliam( &self, - points: &DenseMultilinearExtension, - coeffs: &DenseMultilinearExtension, - u: &[F], + points: &DenseMultilinearExtension, + coeffs: &DenseMultilinearExtension, + u: &[EF], info: &NTTInstanceInfo, ) -> bool { assert_eq!(u.len(), info.log_n); @@ -359,23 +446,16 @@ impl NTTSubclaim { // check2: check the final claim returned from the last round of delegation let idx = 1 << (info.log_n); - let eval = eval_identity_function(&self.final_point, &[F::zero()]) - + eval_identity_function(&self.final_point, &[F::one()]) - * (F::one() - u[info.log_n - 1] + u[info.log_n - 1] * info.ntt_table[idx]) + let eval = eval_identity_function(&self.final_point, &[EF::zero()]) + + eval_identity_function(&self.final_point, &[EF::one()]) + * (EF::one() - u[info.log_n - 1] + u[info.log_n - 1] * info.ntt_table[idx]) * info.ntt_table[1]; self.delegation_final_claim == eval } } -impl NTTIOP { - /// prove - pub fn prove(ntt_instance: &NTTInstance, u: &[F]) -> NTTProof { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, ntt_instance, u) - } - +impl> NTTIOP { /// The delegation of F(u, v) consists of logN - 1 rounds, each of which is a sumcheck protocol. /// /// We define $A_{F}^{(k)}:\{0,1\}^{k+1} -> \mathbb{F}$ and $ω^{(k)}_{i+1}:\{0,1\}^{k+1} -> \mathbb{F}$. @@ -396,22 +476,18 @@ impl NTTIOP { /// * f: MLE \tilde{A}_{F}^{(k-1)}(z) for z\in \{0,1\}^k /// * w: MLE \tilde{ω}^{(k)}_{i+1}(z, b) for z\in \{0,1\}^k and b\in \{0, 1\}, which will be divided into two smaller MLEs \tilde{ω}^{(k)}_{i+1}(z, 0) and \tilde{ω}^{(k)}_{i+1}(z, 1) pub fn delegation_prover_round( - fs_rng: &mut impl RngCore, + trans: &mut Transcript, round: usize, - point: &[F], - u_i: F, - w_coeff: F, - f: &Rc>, - w: &Rc>, - ) -> (Proof, ProverState) { + point: &[EF], + u_i: EF, + w_coeff: EF, + f: &Rc>, + w: &Rc>, + ) -> (Proof, ProverState) { assert_eq!(f.num_vars, round); assert_eq!(w.num_vars, round + 1); - let mut poly = >::new(round); - let mut product_left = Vec::with_capacity(3); - let mut product_right = Vec::with_capacity(3); - let mut ops_left = Vec::with_capacity(3); - let mut ops_right = Vec::with_capacity(3); + let mut poly = >::new(round); // the equality function defined by the random point $(x, b)\in \mathbb{F}^{k+1}$ // it is divided into two MLEs \tilde{\beta}((x, b),(z,0)) and \tilde{\beta}((x, b),(z,1)) @@ -424,32 +500,38 @@ impl NTTIOP { // construct the polynomial to be sumed // left product is \tilde{\beta}((x, b),(z,0)) * \tilde{A}_{F}^{(k-1)}(z) ( (1-u_{i})+u_{i} * \tilde{ω}^{(k)}_{i+1}(z, 0) // right product is \tilde{\beta}((x, b),(z,1)) * \tilde{A}_{F}^{(k-1)}(z) ( (1-u_{i})+u_{i} * \tilde{ω}^{(k)}_{i+1}(z, 1) * ω^{2^k} - product_left.push(Rc::new(eq_func_left)); - ops_left.push((F::one(), F::zero())); - product_left.push(Rc::clone(f)); - ops_left.push((F::one(), F::zero())); - product_left.push(Rc::new(w_left)); - ops_left.push((u_i, F::one() - u_i)); - poly.add_product_with_linear_op(product_left, &ops_left, F::one()); - - product_right.push(Rc::new(eq_func_right)); - ops_right.push((F::one(), F::zero())); - product_right.push(Rc::clone(f)); - ops_right.push((F::one(), F::zero())); - product_right.push(Rc::new(w_right)); - ops_right.push((u_i, F::one() - u_i)); - poly.add_product_with_linear_op(product_right, &ops_right, w_coeff); - - MLSumcheck::prove_as_subprotocol(fs_rng, &poly) - .expect("ntt proof of delegation failed in round {round}") + poly.add_product_with_linear_op( + [Rc::new(eq_func_left), Rc::clone(f), Rc::new(w_left)], + &[ + (UF::one(), UF::zero()), + (UF::one(), UF::zero()), + (UF::ExtensionField(u_i), UF::ExtensionField(EF::one() - u_i)), + ], + EF::one(), + ); + + poly.add_product_with_linear_op( + [Rc::new(eq_func_right), Rc::clone(f), Rc::new(w_right)], + &[ + (UF::one(), UF::zero()), + (UF::one(), UF::zero()), + (UF::ExtensionField(u_i), UF::ExtensionField(EF::one() - u_i)), + ], + w_coeff, + ); + + MLSumcheck::prove(trans, &poly).expect("ntt proof of delegation failed in round {round}") } /// prove NTT with delegation - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, - ntt_instance: &NTTInstance, - u: &[F], - ) -> NTTProof { + /// Note that this is the only interface that requires prover to provide an instance over Extension Field. + /// We deliberately design in this way since the final NTT instance proved in our scheme is a randomized ntt instance defined over EF. + pub fn prove( + trans: &mut Transcript, + ntt_instance: &NTTInstanceExt, + u: &[EF], + ) -> NTTProof { + trans.append_message(b"ntt", &ntt_instance.info()); let log_n = ntt_instance.log_n; let intermediate_mles = init_fourier_table_overall(u, &ntt_instance.ntt_table); @@ -457,8 +539,7 @@ impl NTTIOP { // 1. prove a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) } for a random point u let f_u = &f_mles[log_n - 1]; - let (ntt_bare_proof, state) = - NTTBareIOP::prove_as_subprotocol(fs_rng, f_u, ntt_instance, u); + let (ntt_bare_proof, state) = NTTBareIOP::prove(trans, f_u, ntt_instance, u); // the above sumcheck is reduced to prove F(u, v) where v is the requested point let mut requested_point = state.randomness; @@ -478,11 +559,11 @@ impl NTTIOP { let w_coeff = ntt_instance.ntt_table[1 << k]; let f = &f_mles[k - 1]; let (proof_round, state_round) = Self::delegation_prover_round( - fs_rng, + trans, k, &requested_point, u[i], - w_coeff, + EF::from_base(w_coeff), f, &w_mles[k], ); @@ -501,17 +582,6 @@ impl NTTIOP { } } - /// prove NTT with delegation - pub fn verify( - proof: &NTTProof, - ntt_instance_info: &NTTInstanceInfo, - u: &[F], - ) -> NTTSubclaim { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verify_as_subprotocol(&mut fs_rng, proof, ntt_instance_info, u) - } - /// The delegation of F(u, v) consists of logN - 1 rounds, each of which is a sumcheck protocol. /// /// We define $A_{F}^{(k)}:\{0,1\}^{k+1} -> \mathbb{F}$ and $ω^{(k)}_{i+1}:\{0,1\}^{k+1} -> \mathbb{F}$. @@ -533,22 +603,22 @@ impl NTTIOP { /// * reduced_claim: the given evaluation of \tilde{A}_{F}^{(k-1)}(z = r) so verify does not need to compute on his own pub fn delegation_verify_round( round: usize, - x_b_point: &[F], - u_i: F, - subclaim: &SubClaim, - reduced_claim: F, + x_b_point: &[EF], + u_i: EF, + subclaim: &SubClaim, + reduced_claim: EF, ntt_instance_info: &NTTInstanceInfo, ) -> bool { let log_n = ntt_instance_info.log_n; let ntt_table = &ntt_instance_info.ntt_table; // r_left = (r, 0) and r_right = (r, 0) - let mut r_left: Vec<_> = Vec::with_capacity(round + 1); - let mut r_right: Vec<_> = Vec::with_capacity(round + 1); + let mut r_left = Vec::with_capacity(round + 1); + let mut r_right = Vec::with_capacity(round + 1); r_left.extend(&subclaim.point); r_right.extend(&subclaim.point); - r_left.push(F::zero()); - r_right.push(F::one()); + r_left.push(EF::zero()); + r_right.push(EF::one()); // compute $\ω^{(k)}_{i+1}(x,b ) = \ω^{2^{i+1}\cdot j}$ for $j = X+2^{i+1}\cdot b$ at point (r, 0) and (r, 1) // exp: i + 1 = n - k @@ -559,30 +629,29 @@ impl NTTIOP { let eval = eval_identity_function(x_b_point, &r_left) * reduced_claim - * (F::one() - u_i + u_i * w_left) + * (EF::one() - u_i + u_i * w_left) + eval_identity_function(x_b_point, &r_right) * reduced_claim - * (F::one() - u_i + u_i * w_right) + * (EF::one() - u_i + u_i * w_right) * ntt_table[1 << round]; eval == subclaim.expected_evaluations } /// verify NTT with delegation - pub fn verify_as_subprotocol( - fs_rng: &mut impl RngCore, - proof: &NTTProof, + pub fn verify( + trans: &mut Transcript, + proof: &NTTProof, ntt_instance_info: &NTTInstanceInfo, - u: &[F], - ) -> NTTSubclaim { + u: &[EF], + ) -> NTTSubclaim { + trans.append_message(b"ntt", ntt_instance_info); let log_n = ntt_instance_info.log_n; assert_eq!(proof.delegation_sumcheck_msgs.len(), log_n - 1); assert_eq!(proof.delegation_claimed_sums.len(), log_n - 1); - // TODO sample randomness via Fiat-Shamir RNG // 1. verify a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) } for a random point u - let ntt_bare_subclaim = - NTTBareIOP::verify_as_subprotocol(fs_rng, &proof.ntt_bare_proof, ntt_instance_info); + let ntt_bare_subclaim = NTTBareIOP::verify(trans, &proof.ntt_bare_proof, ntt_instance_info); // 2. verify the computation of F(u, v) in log_n - 1 rounds let mut requested_point = ntt_bare_subclaim.point.clone(); @@ -594,8 +663,8 @@ impl NTTIOP { max_multiplicands: 3, num_variables: k, }; - let subclaim = MLSumcheck::verify_as_subprotocol( - fs_rng, + let subclaim = MLSumcheck::verify( + trans, &poly_info, proof.delegation_claimed_sums[cnt], &proof.delegation_sumcheck_msgs[cnt], @@ -624,7 +693,6 @@ impl NTTIOP { requested_point = subclaim.point; } - // TODO: handle the case that log = 1 // TODO: handle the case that log = 1 assert_eq!(requested_point.len(), 1); NTTSubclaim { @@ -641,7 +709,8 @@ mod test { use crate::piop::ntt::{eval_w_power_times_x, naive_w_power_times_x_table}; use algebra::{ derive::{DecomposableField, FheField, Field, Prime, NTT}, - DenseMultilinearExtension, FieldUniformSampler, MultilinearExtension, NTTField, + BabyBear, BabyBearExetension, DecomposableField, DenseMultilinearExtensionBase, Field, + FieldUniformSampler, MultilinearExtension, NTTField, }; use num_traits::{One, Zero}; use rand::thread_rng; @@ -650,14 +719,15 @@ mod test { use super::init_fourier_table_overall; #[derive(Field, DecomposableField, FheField, Prime, NTT)] - #[modulus = 132120577] - pub struct Fp32(u32); + #[modulus = 2013265921] + pub struct Fp32(u64); // field type - type FF = Fp32; + type FF = BabyBear; + type EF = BabyBearExetension; #[test] fn test_init_fourier_table_overall() { - let sampler = >::new(); + let sampler = >::new(); let mut rng = thread_rng(); let dim = 10; @@ -665,12 +735,13 @@ mod test { let u: Vec<_> = (0..dim).map(|_| sampler.sample(&mut rng)).collect(); let v: Vec<_> = (0..dim).map(|_| sampler.sample(&mut rng)).collect(); - let mut u_v: Vec<_> = Vec::with_capacity(dim << 1); + let mut u_v: Vec = Vec::with_capacity(dim << 1); u_v.extend(&u); u_v.extend(&v); // root is the M-th root of unity - let root = FF::try_minimal_primitive_root(m).unwrap(); + let root = Fp32::try_minimal_primitive_root(m).unwrap(); + let root = FF::new(root.value() as u32); let mut fourier_matrix: Vec<_> = (0..(1 << dim) * (1 << dim)).map(|_| FF::zero()).collect(); let mut ntt_table = Vec::with_capacity(m as usize); @@ -690,10 +761,14 @@ mod test { } } - let fourier_mle = DenseMultilinearExtension::from_evaluations_vec(dim << 1, fourier_matrix); + let fourier_mle = + DenseMultilinearExtensionBase::from_evaluations_vec(dim << 1, fourier_matrix); let partial_fourier_mle = &init_fourier_table_overall(&u, &ntt_table).f_mles[dim - 1]; - assert_eq!(fourier_mle.evaluate(&u_v), partial_fourier_mle.evaluate(&v)); + assert_eq!( + fourier_mle.evaluate_ext(&u_v), + partial_fourier_mle.evaluate(&v) + ); } #[test] @@ -703,7 +778,8 @@ mod test { let m = 1 << log_m; // M = 2N // root is the M-th root of unity - let root = FF::try_minimal_primitive_root(m).unwrap(); + let root = Fp32::try_minimal_primitive_root(m).unwrap(); + let root = FF::new(root.value() as u32); let mut ntt_table = Vec::with_capacity(m as usize); @@ -713,7 +789,7 @@ mod test { power *= root; } - let sampler = >::new(); + let sampler = >::new(); let mut rng = thread_rng(); for x_dim in 0..=dim { diff --git a/zkp/src/piop/ntt/ntt_bare.rs b/zkp/src/piop/ntt/ntt_bare.rs index 80faf6ac..e2d75b40 100644 --- a/zkp/src/piop/ntt/ntt_bare.rs +++ b/zkp/src/piop/ntt/ntt_bare.rs @@ -23,40 +23,45 @@ use crate::sumcheck::MLSumcheck; use std::marker::PhantomData; use std::rc::Rc; +use super::NTTInstanceExt; +use super::NTTInstanceInfo; +use algebra::utils::Transcript; +use algebra::AbstractExtensionField; +use algebra::DenseMultilinearExtensionBase; use algebra::{ DenseMultilinearExtension, Field, ListOfProductsOfPolynomials, MultilinearExtension, PolynomialInfo, }; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; - -use super::{NTTInstance, NTTInstanceInfo}; /// IOP for NTT, i.e. $$a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) }$$ -pub struct NTTBareIOP(PhantomData); +pub struct NTTBareIOP> { + _marker: PhantomData, + _stone: PhantomData, +} /// proof generated by prover in bare ntt, which only consists of the sumcheck without delegation for F(u, v) /// Without delegation, prover only needs to prove this sum /// $$a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) }$$ /// where u is a random point, given by verifier -pub struct NTTBareProof { +pub struct NTTBareProof> { /// sumcheck_msg when proving - pub sumcheck_msg: Vec>, + pub sumcheck_msg: Vec>, /// the claimed sum is a(u) - pub claimed_sum: F, + pub claimed_sum: EF, } /// subclaim returned in bare ntt, which only consists of the sumcheck without delegation for F(u, v) /// Without delegation, prover only needs to prove this sum /// $$a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) }$$ /// where u is a random point, given by verifier -pub struct NTTBareSubclaim { +pub struct NTTBareSubclaim> { /// the claimed sum is a(u) - pub claimed_sum: F, + pub claimed_sum: EF, /// the proof is reduced to the evaluation of this point (denoted by v) - pub point: Vec, + pub point: Vec, /// the proof is reduced to the evaluation equals to c(v) \cdot F(u, v) - pub expected_evaluation: F, + pub expected_evaluation: EF, + _marker: PhantomData, } /// Naive implementation for initializing F(u, x) in NTT, which helps readers to understand the following dynamic programming version (`init_fourier_table``). @@ -69,16 +74,16 @@ pub struct NTTBareSubclaim { /// # Arguments /// * u: the random point /// * ntt_table: It stores the NTT table: ω^0, ω^1, ..., ω^{2N - 1} -/// In order to delegate the computation F(u, v) to prover, we decompose the ω^X term into the grand product. -/// Hence, the final equation is = \prod_{i=0}^{\log{N-1}} ((1 - u_i) + u_i * ω^{2^{i + 1} * X}) * ω^{2^i * x_i} -pub fn naive_init_fourier_table( - u: &[F], +/// In order to delegate the computation F(u, v) to prover, we decompose the ω^X term into the grand product. +/// Hence, the final equation is = \prod_{i=0}^{\log{N-1}} ((1 - u_i) + u_i * ω^{2^{i + 1} * X}) * ω^{2^i * x_i} +pub fn naive_init_fourier_table>( + u: &[EF], ntt_table: &[F], -) -> DenseMultilinearExtension { +) -> DenseMultilinearExtension { let log_n = u.len(); let m = ntt_table.len(); // m = 2n = 2 * (1 << dim) - let mut evaluations = vec![F::one(); 1 << log_n]; + let mut evaluations = vec![EF::one(); 1 << log_n]; for (x, eval_at_x) in evaluations.iter_mut().enumerate() { for (i, &u_i) in u.iter().enumerate().take(log_n) { @@ -86,7 +91,7 @@ pub fn naive_init_fourier_table( let x_i = (x >> i) & 1; let x_i_idx = (1 << i) * x_i; - *eval_at_x *= ((F::one() - u_i) + u_i * ntt_table[idx]) * ntt_table[x_i_idx]; + *eval_at_x *= ((EF::one() - u_i) + u_i * ntt_table[idx]) * ntt_table[x_i_idx]; } } @@ -109,15 +114,18 @@ pub fn naive_init_fourier_table( /// # Arguments /// * u: the random point /// * ω: It stores the NTT table: ω^0, ω^1, ..., ω^{2N - 1} -pub fn init_fourier_table(u: &[F], ntt_table: &[F]) -> DenseMultilinearExtension { +pub fn init_fourier_table>( + u: &[EF], + ntt_table: &[F], +) -> DenseMultilinearExtension { let log_n = u.len(); // n = 1 << dim let m = ntt_table.len(); // m = 2n = 2 * (1 << dim) // It store the evaluations of all F(u, x) for x \in \{0, 1\}^dim. // Note that in our implementation, we use little endian form, so the index `0b1011` // represents the point `P(1,1,0,1)` in {0,1}^`dim` - let mut evaluations: Vec<_> = vec![F::zero(); 1 << log_n]; - evaluations[0] = F::one(); + let mut evaluations = vec![EF::zero(); 1 << log_n]; + evaluations[0] = EF::one(); // * Compute \prod_{i=0}^{\log{N-1}} ((1 - u_i) + u_i * ω^{2^{i + 1} * X}) * ω^{2^i * x_i} // The reason why we update the table with u_i in reverse order is that @@ -138,20 +146,20 @@ pub fn init_fourier_table(u: &[F], ntt_table: &[F]) -> DenseMultilinea let bit = j >> k; if bit == 1 { evaluations[j] = evaluations[j % last_table_size] - * (F::one() - u[i] + u[i] * ntt_table[idx]) + * (EF::one() - u[i] + u[i] * ntt_table[idx]) * ntt_table[last_table_size]; } // If bit = 0, we do not need to multiply because ω^{2^k * 0} = 1 else { evaluations[j] = - evaluations[j % last_table_size] * (F::one() - u[i] + u[i] * ntt_table[idx]); + evaluations[j % last_table_size] * (EF::one() - u[i] + u[i] * ntt_table[idx]); } } } DenseMultilinearExtension::from_evaluations_vec(log_n, evaluations) } -impl NTTBareSubclaim { +impl> NTTBareSubclaim { /// verify the subcliam for sumcheck /// $$a(u) = \sum_{x\in \{0, 1\}^{\log N} c(x)\cdot F(u, x) }$$ /// @@ -166,10 +174,10 @@ impl NTTBareSubclaim { #[inline] pub fn verify_subclaim( &self, - fourier_matrix: &DenseMultilinearExtension, - points: &DenseMultilinearExtension, - coeffs: &DenseMultilinearExtension, - u: &[F], + fourier_matrix: &DenseMultilinearExtensionBase, + points: &DenseMultilinearExtension, + coeffs: &DenseMultilinearExtension, + u: &[EF], info: &NTTInstanceInfo, ) -> bool { assert_eq!(u.len(), info.log_n); @@ -180,10 +188,10 @@ impl NTTBareSubclaim { } // check 2: c(v) * F(u, v) = expected_evaluation - let mut u_v: Vec<_> = Vec::with_capacity(info.log_n << 1); + let mut u_v: Vec = Vec::with_capacity(info.log_n << 1); u_v.extend(u); u_v.extend(&self.point); - self.expected_evaluation == coeffs.evaluate(&self.point) * fourier_matrix.evaluate(&u_v) + self.expected_evaluation == coeffs.evaluate(&self.point) * fourier_matrix.evaluate_ext(&u_v) } /// verify the subcliam for sumcheck @@ -202,10 +210,10 @@ impl NTTBareSubclaim { /// * u: the random point sampled by verifier before executing the sumcheck protocol pub fn verify_subclaim_with_delegation( &self, - f_delegation: F, - points: &DenseMultilinearExtension, - coeffs: &DenseMultilinearExtension, - u: &[F], + f_delegation: EF, + points: &DenseMultilinearExtension, + coeffs: &DenseMultilinearExtension, + u: &[EF], ) -> bool { // check 1: a(u) = claimed_sum if self.claimed_sum != points.evaluate(u) { @@ -217,34 +225,30 @@ impl NTTBareSubclaim { } } -impl NTTBareIOP { +impl> NTTBareIOP { /// prove pub fn prove( - ntt_instance: &NTTInstance, - f_u: &Rc>, - u: &[F], - ) -> NTTBareProof { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, f_u, ntt_instance, u).0 - } - - /// prove - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, - f_u: &Rc>, - ntt_instance: &NTTInstance, - u: &[F], - ) -> (NTTBareProof, ProverState) { + trans: &mut Transcript, + f_u: &Rc>, + ntt_instance: &NTTInstanceExt, + u: &[EF], + ) -> (NTTBareProof, ProverState) { + trans.append_message(b"ntt bare", &ntt_instance.info()); let log_n = ntt_instance.log_n; - let mut poly = >::new(log_n); + let mut poly = >::new(log_n); - let product = vec![Rc::clone(f_u), Rc::new(ntt_instance.coeffs.clone())]; - poly.add_product(product, F::one()); + poly.add_product( + [ + Rc::clone(f_u), + // Convert the original MLE over Field to a new MLE over Extension Field + Rc::new(ntt_instance.coeffs.clone()), + ], + EF::one(), + ); let (prover_msg, prover_state) = - MLSumcheck::prove_as_subprotocol(fs_rng, &poly).expect("ntt bare proof failed"); + MLSumcheck::prove(trans, &poly).expect("ntt bare proof failed"); ( NTTBareProof { @@ -257,27 +261,17 @@ impl NTTBareIOP { /// verify pub fn verify( - ntt_bare_proof: &NTTBareProof, - ntt_instance_info: &NTTInstanceInfo, - ) -> NTTBareSubclaim { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verify_as_subprotocol(&mut fs_rng, ntt_bare_proof, ntt_instance_info) - } - - /// verify - pub fn verify_as_subprotocol( - fs_rng: &mut impl RngCore, - ntt_bare_proof: &NTTBareProof, + trans: &mut Transcript, + ntt_bare_proof: &NTTBareProof, ntt_instance_info: &NTTInstanceInfo, - ) -> NTTBareSubclaim { - // TODO sample randomness via Fiat-Shamir RNG + ) -> NTTBareSubclaim { + trans.append_message(b"ntt bare", &ntt_instance_info); let poly_info = PolynomialInfo { max_multiplicands: 2, num_variables: ntt_instance_info.log_n, }; - let subclaim = MLSumcheck::verify_as_subprotocol( - fs_rng, + let subclaim = MLSumcheck::verify( + trans, &poly_info, ntt_bare_proof.claimed_sum, &ntt_bare_proof.sumcheck_msg, @@ -288,6 +282,7 @@ impl NTTBareIOP { claimed_sum: ntt_bare_proof.claimed_sum, point: subclaim.point, expected_evaluation: subclaim.expected_evaluations, + _marker: PhantomData, } } } @@ -296,7 +291,8 @@ impl NTTBareIOP { mod test { use algebra::{ derive::{DecomposableField, FheField, Field, Prime, NTT}, - DenseMultilinearExtension, Field, FieldUniformSampler, MultilinearExtension, NTTField, + BabyBear, BabyBearExetension, DecomposableField, DenseMultilinearExtensionBase, Field, + FieldUniformSampler, MultilinearExtension, NTTField, }; use num_traits::{One, Zero}; use rand::thread_rng; @@ -314,24 +310,27 @@ mod test { } #[derive(Field, DecomposableField, FheField, Prime, NTT)] - #[modulus = 132120577] - pub struct Fp32(u32); + #[modulus = 2013265921] + pub struct Fp32(u64); + // field type - type FF = Fp32; + type FF = BabyBear; + type EF = BabyBearExetension; #[test] fn test_naive_init_fourier_matrix() { let dim = 2; let m = 1 << (dim + 1); // M = 2N = 2 * (1 << dim) - let u = field_vec!(FF; 1, 1); - let v = field_vec!(FF; 0, 1); + let u = field_vec!(EF; 1, 1); + let v = field_vec!(EF; 0, 1); - let mut u_v = Vec::with_capacity(dim << 1); + let mut u_v: Vec = Vec::with_capacity(dim << 1); u_v.extend(&u); u_v.extend(&v); // root is the M-th root of unity - let root = FF::try_minimal_primitive_root(m).unwrap(); + let root = Fp32::try_minimal_primitive_root(m).unwrap(); + let root = FF::new(root.value() as u32); let mut fourier_matrix: Vec<_> = (0..(1 << dim) * (1 << dim)).map(|_| FF::zero()).collect(); let mut ntt_table = Vec::with_capacity(m as usize); @@ -351,16 +350,20 @@ mod test { } } - let fourier_mle = DenseMultilinearExtension::from_evaluations_vec(dim << 1, fourier_matrix); + let fourier_mle = + DenseMultilinearExtensionBase::from_evaluations_vec(dim << 1, fourier_matrix); // It includes the evaluations of f(u, x) for x \in \{0, 1\}^N let partial_fourier_mle = naive_init_fourier_table(&u, &ntt_table); - assert_eq!(fourier_mle.evaluate(&u_v), partial_fourier_mle.evaluate(&v)); + assert_eq!( + fourier_mle.evaluate_ext(&u_v), + partial_fourier_mle.evaluate(&v) + ); } #[test] fn test_init_fourier_matrix() { - let sampler = >::new(); + let sampler = >::new(); let mut rng = thread_rng(); let dim = 10; @@ -368,12 +371,13 @@ mod test { let u: Vec<_> = (0..dim).map(|_| sampler.sample(&mut rng)).collect(); let v: Vec<_> = (0..dim).map(|_| sampler.sample(&mut rng)).collect(); - let mut u_v: Vec<_> = Vec::with_capacity(dim << 1); + let mut u_v: Vec = Vec::with_capacity(dim << 1); u_v.extend(&u); u_v.extend(&v); // root is the M-th root of unity - let root = FF::try_minimal_primitive_root(m).unwrap(); + let root = Fp32::try_minimal_primitive_root(m).unwrap(); + let root = FF::new(root.value() as u32); let mut fourier_matrix: Vec<_> = (0..(1 << dim) * (1 << dim)).map(|_| FF::zero()).collect(); let mut ntt_table = Vec::with_capacity(m as usize); @@ -393,10 +397,14 @@ mod test { } } - let fourier_mle = DenseMultilinearExtension::from_evaluations_vec(dim << 1, fourier_matrix); + let fourier_mle = + DenseMultilinearExtensionBase::from_evaluations_vec(dim << 1, fourier_matrix); // It includes the evaluations of f(u, x) for x \in \{0, 1\}^N let partial_fourier_mle = &init_fourier_table(&u, &ntt_table); - assert_eq!(fourier_mle.evaluate(&u_v), partial_fourier_mle.evaluate(&v)); + assert_eq!( + fourier_mle.evaluate_ext(&u_v), + partial_fourier_mle.evaluate(&v) + ); } } diff --git a/zkp/src/piop/rlwe_mul_rgsw.rs b/zkp/src/piop/rlwe_mul_rgsw.rs index 1bff224b..60591532 100644 --- a/zkp/src/piop/rlwe_mul_rgsw.rs +++ b/zkp/src/piop/rlwe_mul_rgsw.rs @@ -32,43 +32,48 @@ use std::marker::PhantomData; use std::rc::Rc; use std::vec; +use algebra::utils::Transcript; +use algebra::AbstractExtensionField; +use algebra::DenseMultilinearExtensionBase; use algebra::{ - DenseMultilinearExtension, Field, FieldUniformSampler, ListOfProductsOfPolynomials, - MultilinearExtension, PolynomialInfo, + DenseMultilinearExtension, Field, ListOfProductsOfPolynomials, MultilinearExtension, + PolynomialInfo, }; use itertools::izip; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; -use rand_distr::Distribution; +use serde::Serialize; use super::bit_decomposition::{BitDecomposition, BitDecompositionProof, BitDecompositionSubClaim}; +use super::ntt::NTTInstanceExt; use super::ntt::{NTTProof, NTTSubclaim}; -use super::{DecomposedBits, DecomposedBitsInfo, NTTInstance, NTTInstanceInfo, NTTIOP}; +use super::{DecomposedBits, DecomposedBitsInfo, NTTInstanceInfo, NTTIOP}; /// SNARKs for Mutliplication between RLWE ciphertext and RGSW ciphertext -pub struct RlweMultRgswIOP(PhantomData); +pub struct RlweMultRgswIOP>( + PhantomData, + PhantomData, +); /// proof generated by prover -pub struct RlweMultRgswProof { +pub struct RlweMultRgswProof> { /// proof for bit decompostion - pub bit_decomposition_proof: BitDecompositionProof, + pub bit_decomposition_proof: BitDecompositionProof, /// proof for ntt - pub ntt_proof: NTTProof, + pub ntt_proof: NTTProof, /// proof for sumcheck - pub sumcheck_msg: Proof, + pub sumcheck_msg: Proof, } /// subclaim reutrned to verifier -pub struct RlweMultRgswSubclaim { +pub struct RlweMultRgswSubclaim> { /// subclaim returned from the Bit Decomposition IOP - pub bit_decomposition_subclaim: BitDecompositionSubClaim, + pub bit_decomposition_subclaim: BitDecompositionSubClaim, /// randomness used to randomize 2k + 2 ntt instances - pub randomness_ntt: Vec, + pub randomness_ntt: Vec, /// subclaim returned from the NTT IOP - pub ntt_subclaim: NTTSubclaim, + pub ntt_subclaim: NTTSubclaim, /// randomness used in combination of the two sumcheck protocol - pub randomness_sumcheck: Vec, + pub randomness_sumcheck: Vec, /// subclaim returned from the sumcheck protocol - pub sumcheck_subclaim: SubClaim, + pub sumcheck_subclaim: SubClaim, } /// RLWE ciphertext (a, b) where a and b represents two polynomials in some defined polynomial ring. @@ -76,18 +81,62 @@ pub struct RlweMultRgswSubclaim { #[derive(Clone)] pub struct RlweCiphertext { /// the first part of the ciphertext, chosen at random in the FHE scheme. - pub a: Rc>, + pub a: Rc>, /// the second part of the ciphertext, computed with the plaintext in the FHE scheme. - pub b: Rc>, + pub b: Rc>, } +/// Store the corresponding MLE of RlweCiphertext where the evaluations are over the extension field. +pub struct RlweCiphertextExt> { + /// the first part of the ciphertext, chosen at random in the FHE scheme. + pub a: Rc>, + /// the second part of the ciphertext, computed with the plaintext in the FHE scheme. + pub b: Rc>, +} + +impl> RlweCiphertextExt { + /// Construct an instance over the extension field from the original instance defined over the basic field + pub fn from_base(input_base: &RlweCiphertext) -> Self { + Self { + a: Rc::new(>::from_base(&input_base.a)), + b: Rc::new(>::from_base(&input_base.b)), + } + } +} /// RLWE' ciphertexts represented by two vectors, containing k RLWE ciphertext. #[derive(Clone)] pub struct RlweCiphertexts { /// store the first part of each RLWE ciphertext. - pub a_bits: Vec>>, + pub a_bits: Vec>>, /// store the second part of each RLWE ciphertext. - pub b_bits: Vec>>, + pub b_bits: Vec>>, +} + +/// RLWE' ciphertexts represented by two vectors, containing k RLWE ciphertext. +#[derive(Clone)] +pub struct RlweCiphertextsExt> { + /// store the first part of each RLWE ciphertext. + pub a_bits: Vec>>, + /// store the second part of each RLWE ciphertext. + pub b_bits: Vec>>, +} + +impl> RlweCiphertextsExt { + /// Construct an instance over the extension field from the original instance defined over the basic field + pub fn from_base(input_base: &RlweCiphertexts) -> Self { + Self { + a_bits: input_base + .a_bits + .iter() + .map(|bit| Rc::new(>::from_base(bit.as_ref()))) + .collect(), + b_bits: input_base + .b_bits + .iter() + .map(|bit| Rc::new(>::from_base(bit.as_ref()))) + .collect(), + } + } } impl RlweCiphertexts { @@ -100,7 +149,11 @@ impl RlweCiphertexts { } /// Add a RLWE ciphertext - pub fn add_rlwe(&mut self, a: DenseMultilinearExtension, b: DenseMultilinearExtension) { + pub fn add_rlwe( + &mut self, + a: DenseMultilinearExtensionBase, + b: DenseMultilinearExtensionBase, + ) { self.a_bits.push(Rc::new(a)); self.b_bits.push(Rc::new(b)); } @@ -109,9 +162,9 @@ impl RlweCiphertexts { /// Stores the multiplicaton instance between RLWE ciphertext and RGSW ciphertext with the corresponding NTT table /// Given (a, b) \in RLWE where a and b are two polynomials represented by N coefficients, /// and (c, f) \in RGSW = RLWE' \times RLWE' = (RLWE, ..., RLWE) \times (RLWE, ..., RLWE) where c = ((c0, c0'), ..., (ck-1, ck-1')) and f = ((f0, f0'), ..., (fk-1, fk-1')) -pub struct RlweMultRgswInstance { +pub struct RlweMultRgswInstance> { /// randomized ntt instance to be proved - pub ntt_instance: NTTInstance, + pub ntt_instance: NTTInstanceExt, /// info of decomposed bits pub decomposed_bits_info: DecomposedBitsInfo, /// rlwe = (a, b): store the input ciphertext (a, b) where a and b are two polynomials represented by N coefficients. @@ -130,8 +183,39 @@ pub struct RlweMultRgswInstance { pub output_rlwe: RlweCiphertext, } +/// Store the corresponding MLE of RlweMultRgswInstance where the evaluations are over the extension field. +pub struct RlweMultRgswInstanceExt> { + /// rlwe = (a, b): store the input ciphertext (a, b) where a and b are two polynomials represented by N coefficients. + pub input_rlwe: RlweCiphertextExt, + /// bits_rlwe_ntt: ntt form of the above bit decomposition result + pub bits_rlwe_ntt: RlweCiphertextsExt, + /// bits_rgsw_c_ntt: the ntt form of the first part (c) in the RGSW ciphertext + pub bits_rgsw_c_ntt: RlweCiphertextsExt, + /// bits_rgsw_c_ntt: the ntt form of the second part (f) in the RGSW ciphertext + pub bits_rgsw_f_ntt: RlweCiphertextsExt, + /// output_rlwe_ntt: store the output ciphertext (g', h') in the NTT-form + pub output_rlwe_ntt: RlweCiphertextExt, + /// output_rlwe: store the output ciphertext (g, h) in the coefficient-form + pub output_rlwe: RlweCiphertextExt, +} + +/// Construct an instance over the extension field from the original instance defined over the basic field +impl> RlweMultRgswInstanceExt { + /// Construct an instance over the extension field from the original instance defined over the basic field + pub fn from_base(input_base: &RlweMultRgswInstance) -> Self { + Self { + input_rlwe: >::from_base(&input_base.input_rlwe), + bits_rlwe_ntt: >::from_base(&input_base.bits_rlwe_ntt), + bits_rgsw_c_ntt: >::from_base(&input_base.bits_rgsw_c_ntt), + bits_rgsw_f_ntt: >::from_base(&input_base.bits_rgsw_f_ntt), + output_rlwe_ntt: >::from_base(&input_base.output_rlwe_ntt), + output_rlwe: >::from_base(&input_base.output_rlwe), + } + } +} + /// store the information used to verify -#[derive(Clone)] +#[derive(Clone, Serialize)] pub struct RlweMultRgswInfo { /// information of ntt instance pub ntt_info: NTTInstanceInfo, @@ -139,7 +223,7 @@ pub struct RlweMultRgswInfo { pub decomposed_bits_info: DecomposedBitsInfo, } -impl RlweMultRgswInstance { +impl> RlweMultRgswInstance { /// Extract the information #[inline] pub fn info(&self) -> RlweMultRgswInfo { @@ -155,7 +239,7 @@ impl RlweMultRgswInstance { pub fn from( decomposed_bits_info: &DecomposedBitsInfo, ntt_info: &NTTInstanceInfo, - randomness_ntt: &[F], + randomness_ntt: &[EF], input_rlwe: &RlweCiphertext, bits_rlwe: &RlweCiphertexts, bits_rlwe_ntt: &RlweCiphertexts, @@ -168,7 +252,7 @@ impl RlweMultRgswInstance { assert_eq!(randomness_ntt.len(), num_ntt_instance as usize); // randomize 2k + 1 ntt instances into a single one ntt instance - let mut ntt_instance = NTTInstance::from_info(ntt_info); + let mut ntt_instance = >::from_info(ntt_info); let mut r_iter = randomness_ntt.iter(); // k ntt instances for a_i =NTT equal= a_i' @@ -203,7 +287,7 @@ impl RlweMultRgswInstance { } } -impl RlweMultRgswSubclaim { +impl> RlweMultRgswSubclaim { /// verify the subclaim /// /// # Arguments @@ -222,10 +306,10 @@ impl RlweMultRgswSubclaim { #[allow(clippy::too_many_arguments)] pub fn verify_subclaim( &self, - u: &[F], - randomness_ntt: &[F], - ntt_coeffs: &DenseMultilinearExtension, - ntt_points: &DenseMultilinearExtension, + u: &[EF], + randomness_ntt: &[EF], + ntt_coeffs: &DenseMultilinearExtension, + ntt_points: &DenseMultilinearExtension, input_rlwe: &RlweCiphertext, bits_rlwe: &RlweCiphertexts, bits_rlwe_ntt: &RlweCiphertexts, @@ -241,26 +325,26 @@ impl RlweMultRgswSubclaim { assert_eq!(self.randomness_sumcheck.len(), 2); // check 1: check the consistency of the randomized ntt oracle and the original oracles - let mut coeffs_eval = F::zero(); - let mut points_eval = F::zero(); + let mut coeffs_eval = EF::zero(); + let mut points_eval = EF::zero(); let mut r_iter = randomness_ntt.iter(); for (coeffs, points) in izip!(&bits_rlwe.a_bits, &bits_rlwe_ntt.a_bits) { let r = r_iter.next().unwrap(); - coeffs_eval += *r * coeffs.evaluate(u); - points_eval += *r * points.evaluate(u); + coeffs_eval += *r * coeffs.evaluate_ext(u); + points_eval += *r * points.evaluate_ext(u); } for (coeffs, points) in izip!(&bits_rlwe.b_bits, &bits_rlwe_ntt.b_bits) { let r = r_iter.next().unwrap(); - coeffs_eval += *r * coeffs.evaluate(u); - points_eval += *r * points.evaluate(u); + coeffs_eval += *r * coeffs.evaluate_ext(u); + points_eval += *r * points.evaluate_ext(u); } let r = r_iter.next().unwrap(); - coeffs_eval += *r * output_rlwe.a.evaluate(u); - points_eval += *r * output_rlwe_ntt.a.evaluate(u); + coeffs_eval += *r * output_rlwe.a.evaluate_ext(u); + points_eval += *r * output_rlwe_ntt.a.evaluate_ext(u); let r = r_iter.next().unwrap(); - coeffs_eval += *r * output_rlwe.b.evaluate(u); - points_eval += *r * output_rlwe_ntt.b.evaluate(u); + coeffs_eval += *r * output_rlwe.b.evaluate_ext(u); + points_eval += *r * output_rlwe_ntt.b.evaluate_ext(u); if coeffs_eval != ntt_coeffs.evaluate(u) || points_eval != ntt_points.evaluate(u) { return false; @@ -288,8 +372,8 @@ impl RlweMultRgswSubclaim { } // 4. check 4: check the subclaim returned from the sumcheck protocol consisting of two sub-sumcheck protocols - let mut sum1_eval = F::zero(); - let mut sum2_eval = F::zero(); + let mut sum1_eval = EF::zero(); + let mut sum2_eval = EF::zero(); // The first part is to evaluate at a random point g' = \sum_{i = 0}^{k-1} a_i' \cdot c_i + b_i' \cdot f_i // It is the reduction claim of prover asserting the sum \sum_{x} eq(u, x) (\sum_{i = 0}^{k-1} a_i'(x) \cdot c_i(x) + b_i'(x) \cdot f_i(x) - g'(x)) = 0 @@ -300,10 +384,10 @@ impl RlweMultRgswSubclaim { &bits_rgsw_c_ntt.a_bits, &bits_rgsw_f_ntt.a_bits ) { - sum1_eval += (a.evaluate(&self.sumcheck_subclaim.point) - * c.evaluate(&self.sumcheck_subclaim.point)) - + (b.evaluate(&self.sumcheck_subclaim.point) - * f.evaluate(&self.sumcheck_subclaim.point)); + sum1_eval += (a.evaluate_ext(&self.sumcheck_subclaim.point) + * c.evaluate_ext(&self.sumcheck_subclaim.point)) + + (b.evaluate_ext(&self.sumcheck_subclaim.point) + * f.evaluate_ext(&self.sumcheck_subclaim.point)); } // The second part is to evaluate at a random point h' = \sum_{i = 0}^{k-1} a_i' \cdot c_i' + b_i' \cdot f_i' @@ -315,51 +399,52 @@ impl RlweMultRgswSubclaim { &bits_rgsw_c_ntt.b_bits, &bits_rgsw_f_ntt.b_bits ) { - sum2_eval += (a.evaluate(&self.sumcheck_subclaim.point) - * c.evaluate(&self.sumcheck_subclaim.point)) - + (b.evaluate(&self.sumcheck_subclaim.point) - * f.evaluate(&self.sumcheck_subclaim.point)); + sum2_eval += (a.evaluate_ext(&self.sumcheck_subclaim.point) + * c.evaluate_ext(&self.sumcheck_subclaim.point)) + + (b.evaluate_ext(&self.sumcheck_subclaim.point) + * f.evaluate_ext(&self.sumcheck_subclaim.point)); } self.sumcheck_subclaim.expected_evaluations == eval_identity_function(u, &self.sumcheck_subclaim.point) * (self.randomness_sumcheck[0] - * (sum1_eval - output_rlwe_ntt.a.evaluate(&self.sumcheck_subclaim.point)) + * (sum1_eval + - output_rlwe_ntt + .a + .evaluate_ext(&self.sumcheck_subclaim.point)) + self.randomness_sumcheck[1] - * (sum2_eval - output_rlwe_ntt.b.evaluate(&self.sumcheck_subclaim.point))) + * (sum2_eval + - output_rlwe_ntt + .b + .evaluate_ext(&self.sumcheck_subclaim.point))) } } -impl RlweMultRgswIOP { - /// prove the multiplication between RLWE ciphertext and RGSW ciphertext - pub fn prove(instance: &RlweMultRgswInstance, u: &[F]) -> RlweMultRgswProof { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, instance, u) - } - +impl> RlweMultRgswIOP { /// prove the multiplication between RLWE ciphertext and RGSW ciphertext /// This function does the same thing as `prove`, but it uses a `Fiat-Shamir RNG` as the transcript/to generate the /// verifier challenges. - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, - instance: &RlweMultRgswInstance, - u: &[F], - ) -> RlweMultRgswProof { + pub fn prove( + trans: &mut Transcript, + instance_base: &RlweMultRgswInstance, + u: &[EF], + ) -> RlweMultRgswProof { // construct the instance of bit decomposition - let mut decomposed_bits = DecomposedBits::from_info(&instance.decomposed_bits_info); - decomposed_bits.add_decomposed_bits_instance(&instance.bits_rlwe.a_bits); - decomposed_bits.add_decomposed_bits_instance(&instance.bits_rlwe.b_bits); - - let uniform = >::new(); + let mut decomposed_bits = DecomposedBits::from_info(&instance_base.decomposed_bits_info); + decomposed_bits.add_decomposed_bits_instance(&instance_base.bits_rlwe.a_bits); + decomposed_bits.add_decomposed_bits_instance(&instance_base.bits_rlwe.b_bits); - let mut poly = >::new(instance.ntt_instance.log_n); + let mut poly = >::new(instance_base.ntt_instance.log_n); let identity_func_at_u = Rc::new(gen_identity_evaluations(u)); // randomly combine two sumcheck protocols - // TODO sample randomness via Fiat-Shamir RNG - let r_1 = uniform.sample(fs_rng); - let r_2 = uniform.sample(fs_rng); + let r = trans.get_vec_ext_field_challenge(b"randomness to combine 2 sumchecks", 2); + let r_1 = r[0]; + let r_2 = r[1]; + + // Convert the corresponding MLEs over Field in the original instance to the MLE over Extension Field + let instance = >::from_base(instance_base); + // Sumcheck protocol for proving: g' = \sum_{i = 0}^{k-1} a_i' \cdot c_i + b_i' \cdot f_i // When proving g'(x) = \sum_{i = 0}^{k-1} a_i'(x) \cdot c_i(x) + b_i'(x) \cdot f_i(x) for x \in \{0, 1\}^\log n, // prover claims the sum \sum_{x} eq(u, x) (\sum_{i = 0}^{k-1} a_i'(x) \cdot c_i(x) + b_i'(x) \cdot f_i(x) - g'(x)) = 0 @@ -405,65 +490,43 @@ impl RlweMultRgswIOP { ); RlweMultRgswProof { - bit_decomposition_proof: BitDecomposition::prove_as_subprotocol( - fs_rng, - &decomposed_bits, - u, - ), - ntt_proof: NTTIOP::prove_as_subprotocol(fs_rng, &instance.ntt_instance, u), - sumcheck_msg: MLSumcheck::prove_as_subprotocol(fs_rng, &poly) + bit_decomposition_proof: BitDecomposition::prove(trans, &decomposed_bits, u), + ntt_proof: NTTIOP::prove(trans, &instance_base.ntt_instance, u), + sumcheck_msg: MLSumcheck::prove(trans, &poly) .expect("sumcheck fail in rlwe * rgsw") .0, } } - /// verify the proof - pub fn verify( - proof: &RlweMultRgswProof, - randomness_ntt: &[F], - u: &[F], - info: &RlweMultRgswInfo, - ) -> RlweMultRgswSubclaim { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verify_as_subprotocol(&mut fs_rng, proof, randomness_ntt, u, info) - } - /// verify the proof with provided RNG - pub fn verify_as_subprotocol( - fs_rng: &mut impl RngCore, - proof: &RlweMultRgswProof, - randomness_ntt: &[F], - u: &[F], + pub fn verify( + trans: &mut Transcript, + proof: &RlweMultRgswProof, + randomness_ntt: &[EF], + u: &[EF], info: &RlweMultRgswInfo, - ) -> RlweMultRgswSubclaim { - let uniform = >::new(); - // TODO sample randomness via Fiat-Shamir RNG - let r_1 = uniform.sample(fs_rng); - let r_2 = uniform.sample(fs_rng); + ) -> RlweMultRgswSubclaim { + let r = trans.get_vec_ext_field_challenge(b"randomness to combine 2 sumchecks", 2); + let r_1 = r[0]; + let r_2 = r[1]; let poly_info = PolynomialInfo { max_multiplicands: 3, num_variables: info.ntt_info.log_n, }; RlweMultRgswSubclaim { - bit_decomposition_subclaim: BitDecomposition::verifier_as_subprotocol( - fs_rng, + bit_decomposition_subclaim: BitDecomposition::verify( + trans, &proof.bit_decomposition_proof, &info.decomposed_bits_info, ), - ntt_subclaim: NTTIOP::verify_as_subprotocol( - fs_rng, - &proof.ntt_proof, - &info.ntt_info, - u, - ), + ntt_subclaim: NTTIOP::verify(trans, &proof.ntt_proof, &info.ntt_info, u), randomness_ntt: randomness_ntt.to_owned(), randomness_sumcheck: vec![r_1, r_2], - sumcheck_subclaim: MLSumcheck::verify_as_subprotocol( - fs_rng, + sumcheck_subclaim: MLSumcheck::verify( + trans, &poly_info, - F::zero(), + EF::zero(), &proof.sumcheck_msg, ) .expect("sumcheck protocol in rlwe mult rgsw failed"), diff --git a/zkp/src/piop/round.rs b/zkp/src/piop/round.rs index d90fdd9f..bb538297 100644 --- a/zkp/src/piop/round.rs +++ b/zkp/src/piop/round.rs @@ -21,34 +21,31 @@ use crate::sumcheck::MLSumcheck; use crate::sumcheck::Proof; use crate::utils::eval_identity_function; use crate::utils::gen_identity_evaluations; +use algebra::utils::Transcript; +use algebra::AbstractExtensionField; use algebra::DecomposableField; +use algebra::DenseMultilinearExtensionBase; use itertools::izip; use std::marker::PhantomData; use std::rc::Rc; use std::vec; -use algebra::{ - DenseMultilinearExtension, Field, FieldUniformSampler, ListOfProductsOfPolynomials, - MultilinearExtension, PolynomialInfo, -}; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; -use rand_distr::Distribution; +use algebra::{DenseMultilinearExtension, Field, ListOfProductsOfPolynomials, PolynomialInfo, UF}; use super::bit_decomposition::{BitDecomposition, BitDecompositionProof, BitDecompositionSubClaim}; use super::{DecomposedBits, DecomposedBitsInfo}; /// Round IOP -pub struct RoundIOP(PhantomData); +pub struct RoundIOP>(PhantomData, PhantomData); /// proof generated by prover -pub struct RoundIOPProof { +pub struct RoundIOPProof> { /// range check proof for output b - pub bit_decomp_proof_output: BitDecompositionProof, + pub bit_decomp_proof_output: BitDecompositionProof, /// range check proof for offset c - 1 - pub bit_decomp_proof_offset: BitDecompositionProof, + pub bit_decomp_proof_offset: BitDecompositionProof, /// sumcheck msg - pub sumcheck_msg: Proof, + pub sumcheck_msg: Proof, } /// Round Instance used as prover keys pub struct RoundInstance { @@ -59,21 +56,54 @@ pub struct RoundInstance { /// delta = 2^{k_bit_len} - k pub delta: F, /// input denoted by a \in F_Q - pub input: Rc>, + pub input: Rc>, /// output denoted by b \in F_q - pub output: Rc>, + pub output: Rc>, /// decomposed bits of ouput used for range check pub output_bits: DecomposedBits, /// offset denoted by c = a - b * k \in [1, k] such that c - 1 \in [0, k) - pub offset: Rc>, + pub offset: Rc>, /// offset_aux_bits contains two instances of bit decomposition /// decomposed bits of c - 1 \in [0, 2^k_bit_len) used for range check /// decomposed bits of c - 1 + delta \in [0, 2^k_bit_len) used for range check pub offset_aux_bits: DecomposedBits, /// option denoted by w \in {0, 1} - pub option: Rc>, + pub option: Rc>, +} + +/// Store the corresponding MLE of RoundInstance where the evaluations are over the extension field. +pub struct RoundInstanceExt> { + /// input denoted by a \in F_Q + pub input: Rc>, + /// output denoted by b \in F_q + pub output: Rc>, + + /// offset denoted by c = a - b * k \in [1, k] such that c - 1 \in [0, k) + pub offset: Rc>, + + /// option denoted by w \in {0, 1} + pub option: Rc>, +} + +impl> RoundInstanceExt { + fn from_base(input_base: &RoundInstance) -> Self { + Self { + input: Rc::new(>::from_base( + input_base.input.as_ref(), + )), + output: Rc::new(>::from_base( + input_base.output.as_ref(), + )), + offset: Rc::new(>::from_base( + input_base.offset.as_ref(), + )), + option: Rc::new(>::from_base( + input_base.option.as_ref(), + )), + } + } } /// Information about Round Instance used as verifier keys @@ -89,15 +119,15 @@ pub struct RoundInstanceInfo { } /// subclaim returned to verifier -pub struct RoundIOPSubclaim { +pub struct RoundIOPSubclaim> { /// subclaim returned from the range check for output b \in [q] - pub bit_decomp_output_subclaim: BitDecompositionSubClaim, + pub bit_decomp_output_subclaim: BitDecompositionSubClaim, /// subclaim returned from the range check for offset c - 1 \in [k] - pub bit_decomp_offset_subclaim: BitDecompositionSubClaim, + pub bit_decomp_offset_subclaim: BitDecompositionSubClaim, /// subclaim returned from the sumcheck protocol - pub sumcheck_subclaim: SubClaim, + pub sumcheck_subclaim: SubClaim, /// randomness used in the sumcheck - pub randomness: (F, F), + pub randomness: (EF, EF), } impl RoundInstance { @@ -119,8 +149,8 @@ impl RoundInstance { pub fn new( k: F, delta: F, - input: Rc>, - output: Rc>, + input: Rc>, + output: Rc>, output_bits_info: &DecomposedBitsInfo, offset_bits_info: &DecomposedBitsInfo, ) -> Self { @@ -135,7 +165,7 @@ impl RoundInstance { output_bits.add_value_instance(&output); // set w = 1 iff a = 0 & b = 0 - let option = Rc::new(DenseMultilinearExtension::::from_evaluations_vec( + let option = Rc::new(DenseMultilinearExtensionBase::::from_evaluations_vec( num_vars, input .iter() @@ -150,7 +180,7 @@ impl RoundInstance { // Note that we must set c \in [1, k] when w = 1 to ensure that c(x) \in [1, k] for all x \in {0,1}^logn // if w = 0: c = a - b * k // if w = 1: c = 1 defaultly - let offset = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let offset = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, izip!(option.iter(), input.iter(), output.iter()) .map(|(w, a, b)| match w.is_zero() { @@ -162,12 +192,12 @@ impl RoundInstance { let mut offset_aux_bits = DecomposedBits::from_info(offset_bits_info); // c - 1 - let c_minus_one = DenseMultilinearExtension::from_evaluations_vec( + let c_minus_one = DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, offset.iter().map(|x| *x - F::one()).collect(), ); // c - 1 + delta - let c_minus_one_delta = DenseMultilinearExtension::from_evaluations_vec( + let c_minus_one_delta = DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, c_minus_one.iter().map(|x| *x + delta).collect(), ); @@ -190,20 +220,20 @@ impl RoundInstance { } } -impl RoundIOPSubclaim { +impl> RoundIOPSubclaim { /// verify the subclaim #[allow(clippy::too_many_arguments)] pub fn verify_subclaim( &self, - u: &[F], - (lambda_1, lambda_2): (F, F), - input: &Rc>, - output: &Rc>, - output_bits: &Vec>>, - offset: &Rc>, - offset_aux_bits_1: &Vec>>, - offset_aux_bits_2: &Vec>>, - option: &Rc>, + u: &[EF], + (lambda_1, lambda_2): (EF, EF), + input: &Rc>, + output: &Rc>, + output_bits: &Vec>>, + offset: &Rc>, + offset_aux_bits_1: &Vec>>, + offset_aux_bits_2: &Vec>>, + option: &Rc>, info: &RoundInstanceInfo, ) -> bool { // check 1: check the subclaim returned from the range check for output b \in [q] @@ -220,11 +250,11 @@ impl RoundIOPSubclaim { // check 2: check the subclaim returned from the range check for offset c - 1 \in [k] let d_bits = vec![offset_aux_bits_1, offset_aux_bits_2]; - let c_minus_one = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let c_minus_one = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( offset.num_vars, offset.iter().map(|x| *x - F::one()).collect(), )); - let c_minus_one_delta = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let c_minus_one_delta = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( offset.num_vars, c_minus_one.iter().map(|x| *x + info.delta).collect(), )); @@ -241,61 +271,51 @@ impl RoundIOPSubclaim { // check 3: check the subclaim returned from the sumcheck protocol let (r_1, r_2) = self.randomness; let eq_val = eval_identity_function(u, &self.sumcheck_subclaim.point); - let option_eval = option.evaluate(&self.sumcheck_subclaim.point); - let input_eval = input.evaluate(&self.sumcheck_subclaim.point); - let output_eval = output.evaluate(&self.sumcheck_subclaim.point); - let offset_eval = offset.evaluate(&self.sumcheck_subclaim.point); + let option_eval = option.evaluate_ext(&self.sumcheck_subclaim.point); + let input_eval = input.evaluate_ext(&self.sumcheck_subclaim.point); + let output_eval = output.evaluate_ext(&self.sumcheck_subclaim.point); + let offset_eval = offset.evaluate_ext(&self.sumcheck_subclaim.point); self.sumcheck_subclaim.expected_evaluations - == r_1 * eq_val * option_eval * (F::one() - option_eval) + == r_1 * eq_val * option_eval * (EF::one() - option_eval) + r_2 * eq_val * (option_eval * (input_eval * lambda_1 + output_eval * lambda_2) - + (F::one() - option_eval) + + (EF::one() - option_eval) * (input_eval - output_eval * info.k - offset_eval)) } } -impl RoundIOP { +impl> RoundIOP { /// Prove round operation pub fn prove( + trans: &mut Transcript, instance: &RoundInstance, - u: &[F], - (lambda_1, lambda_2): (F, F), - ) -> RoundIOPProof { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, instance, u, (lambda_1, lambda_2)) - } - - /// Prove round operation - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, - instance: &RoundInstance, - u: &[F], - (lambda_1, lambda_2): (F, F), - ) -> RoundIOPProof { - let uniform = >::new(); + u: &[EF], + (lambda_1, lambda_2): (EF, EF), + ) -> RoundIOPProof { + // Convert the original instance over Field to a new instance over Extension Field + let instance_ext = >::from_base(instance); - let mut poly = >::new(instance.num_vars); + let mut poly = >::new(instance.num_vars); let identity_func_at_u = Rc::new(gen_identity_evaluations(u)); // randomly combine two sumcheck protocols - // TODO sample randomness via Fiat-Shamir RNG - let r_1 = uniform.sample(fs_rng); - let r_2 = uniform.sample(fs_rng); + let r = trans.get_vec_ext_field_challenge(b"randomness for sumcheck", 2); + let r_1 = r[0]; + let r_2 = r[1]; // sumcheck1 for \sum_{x} eq(u, x) * w(x) * (1-w(x)) = 0, i.e. w(x)\in\{0,1\}^l with random coefficient r_1 poly.add_product_with_linear_op( [ Rc::clone(&identity_func_at_u), - Rc::clone(&instance.option), - Rc::clone(&instance.option), + Rc::clone(&instance_ext.option), + Rc::clone(&instance_ext.option), ], &[ - (F::one(), F::zero()), - (F::one(), F::zero()), - (-F::one(), F::one()), + (UF::one(), UF::zero()), + (UF::one(), UF::zero()), + (UF::BaseField(-F::one()), UF::one()), ], r_1, ); @@ -308,13 +328,13 @@ impl RoundIOP { poly.add_product_with_linear_op( [ Rc::clone(&identity_func_at_u), - Rc::clone(&instance.option), - Rc::clone(&instance.input), + Rc::clone(&instance_ext.option), + Rc::clone(&instance_ext.input), ], &[ - (F::one(), F::zero()), - (F::one(), F::zero()), - (lambda_1, F::zero()), + (UF::one(), UF::zero()), + (UF::one(), UF::zero()), + (UF::ExtensionField(lambda_1), UF::zero()), ], r_2, ); @@ -322,13 +342,13 @@ impl RoundIOP { poly.add_product_with_linear_op( [ Rc::clone(&identity_func_at_u), - Rc::clone(&instance.option), - Rc::clone(&instance.output), + Rc::clone(&instance_ext.option), + Rc::clone(&instance_ext.output), ], &[ - (F::one(), F::zero()), - (F::one(), F::zero()), - (lambda_2, F::zero()), + (UF::one(), UF::zero()), + (UF::one(), UF::zero()), + (UF::ExtensionField(lambda_2), UF::zero()), ], r_2, ); @@ -336,13 +356,13 @@ impl RoundIOP { poly.add_product_with_linear_op( [ Rc::clone(&identity_func_at_u), - Rc::clone(&instance.option), - Rc::clone(&instance.input), + Rc::clone(&instance_ext.option), + Rc::clone(&instance_ext.input), ], &[ - (F::one(), F::zero()), - (-F::one(), F::one()), - (F::one(), F::zero()), + (UF::one(), UF::zero()), + (UF::BaseField(-F::one()), UF::one()), + (UF::one(), UF::zero()), ], r_2, ); @@ -350,13 +370,13 @@ impl RoundIOP { poly.add_product_with_linear_op( [ Rc::clone(&identity_func_at_u), - Rc::clone(&instance.option), - Rc::clone(&instance.output), + Rc::clone(&instance_ext.option), + Rc::clone(&instance_ext.output), ], &[ - (F::one(), F::zero()), - (-F::one(), F::one()), - (-instance.k, F::zero()), + (UF::one(), UF::zero()), + (UF::BaseField(-F::one()), UF::one()), + (UF::BaseField(-instance.k), UF::zero()), ], r_2, ); @@ -364,74 +384,59 @@ impl RoundIOP { poly.add_product_with_linear_op( [ Rc::clone(&identity_func_at_u), - Rc::clone(&instance.option), - Rc::clone(&instance.offset), + Rc::clone(&instance_ext.option), + Rc::clone(&instance_ext.offset), ], &[ - (F::one(), F::zero()), - (-F::one(), F::one()), - (-F::one(), F::zero()), + (UF::one(), UF::zero()), + (UF::BaseField(-F::one()), UF::one()), + (UF::BaseField(-F::one()), UF::zero()), ], r_2, ); RoundIOPProof { - bit_decomp_proof_output: BitDecomposition::prove_as_subprotocol( - fs_rng, - &instance.output_bits, - u, - ), - bit_decomp_proof_offset: BitDecomposition::prove_as_subprotocol( - fs_rng, - &instance.offset_aux_bits, - u, - ), - sumcheck_msg: MLSumcheck::prove_as_subprotocol(fs_rng, &poly) + bit_decomp_proof_output: BitDecomposition::prove(trans, &instance.output_bits, u), + bit_decomp_proof_offset: BitDecomposition::prove(trans, &instance.offset_aux_bits, u), + sumcheck_msg: MLSumcheck::prove(trans, &poly) .expect("sumcheck for round operation failed") .0, } } - /// verify - pub fn verify(proof: &RoundIOPProof, info: &RoundInstanceInfo) -> RoundIOPSubclaim { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verify_as_subprotocol(&mut fs_rng, proof, info) - } - /// verify with given rng - pub fn verify_as_subprotocol( - fs_rng: &mut impl RngCore, - proof: &RoundIOPProof, + pub fn verify( + trans: &mut Transcript, + proof: &RoundIOPProof, info: &RoundInstanceInfo, - ) -> RoundIOPSubclaim { + ) -> RoundIOPSubclaim { let num_vars = info.output_bits_info.num_vars; assert_eq!(num_vars, info.offset_bits_info.num_vars); - let uniform = >::new(); + // randomly combine two sumcheck protocols - // TODO sample randomness via Fiat-Shamir RNG - let r_1 = uniform.sample(fs_rng); - let r_2 = uniform.sample(fs_rng); + let r = trans.get_vec_ext_field_challenge(b"randomness for sumcheck", 2); + let r_1 = r[0]; + let r_2 = r[1]; let poly_info = PolynomialInfo { max_multiplicands: 3, num_variables: num_vars, }; RoundIOPSubclaim { - bit_decomp_output_subclaim: BitDecomposition::verifier_as_subprotocol( - fs_rng, + bit_decomp_output_subclaim: BitDecomposition::verify( + trans, &proof.bit_decomp_proof_output, &info.output_bits_info, ), - bit_decomp_offset_subclaim: BitDecomposition::verifier_as_subprotocol( - fs_rng, + bit_decomp_offset_subclaim: BitDecomposition::verify( + trans, &proof.bit_decomp_proof_offset, &info.offset_bits_info, ), - sumcheck_subclaim: MLSumcheck::verify_as_subprotocol( - fs_rng, + sumcheck_subclaim: MLSumcheck::verify( + trans, &poly_info, - F::zero(), + EF::zero(), &proof.sumcheck_msg, ) .expect("sumcheck protocol for round operation failed"), diff --git a/zkp/src/piop/zq_to_rq.rs b/zkp/src/piop/zq_to_rq.rs index 98818545..a6e6cb0e 100644 --- a/zkp/src/piop/zq_to_rq.rs +++ b/zkp/src/piop/zq_to_rq.rs @@ -39,33 +39,32 @@ use crate::utils::eval_identity_function; use crate::sumcheck::MLSumcheck; use crate::utils::gen_identity_evaluations; use algebra::{ - AsFrom, DecomposableField, DenseMultilinearExtension, Field, ListOfProductsOfPolynomials, - MultilinearExtension, PolynomialInfo, SparsePolynomial, + utils::Transcript, AbstractExtensionField, AsFrom, DecomposableField, + DenseMultilinearExtension, DenseMultilinearExtensionBase, Field, ListOfProductsOfPolynomials, + MultilinearExtension, PolynomialInfo, SparsePolynomial, UF, }; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; /// proof generated by prover -pub struct TransformZqtoRQProof { +pub struct TransformZqtoRQProof> { /// singe rangecheck proof for r - pub rangecheck_msg: BitDecompositionProof, + pub rangecheck_msg: BitDecompositionProof, /// sumcheck proofs for /// \sum_{x} eq(u,x) * k(x) * (1 - k(x)) = 0; /// \sum_{x} eq(u,x) * ((r(x) + 1) * (1 - 2k(x)) - s(x)) = 0; /// \sum_{y} c_u(y) * t(y) = s(u) - pub sumcheck_msgs: Vec>>, + pub sumcheck_msgs: Vec>>, /// the claimed sum of the third sumcheck i.e. s(u) - pub s_u: F, + pub s_u: EF, } /// subclaim returned to verifier -pub struct TransformZqtoRQSubclaim { +pub struct TransformZqtoRQSubclaim> { /// rangecheck subclaim for a, b, c \in Zq - pub(crate) rangecheck_subclaim: BitDecompositionSubClaim, + pub(crate) rangecheck_subclaim: BitDecompositionSubClaim, /// subcliam - pub sumcheck_points: Vec>, + pub sumcheck_points: Vec>, /// expected value returned in the last round of the sumcheck - pub sumcheck_expected_evaluations: Vec, + pub sumcheck_expected_evaluations: Vec, } /// Stores the parameters used for transformation from Zq to RQ and the inputs and witness for prover. @@ -82,17 +81,41 @@ pub struct TransformZqtoRQInstance { /// inputs c pub c: Vec>>, /// inputs a - pub a: Rc>, + pub a: Rc>, /// introduced witness k - pub k: Rc>, + pub k: Rc>, /// introduced witness r - pub r: Rc>, + pub r: Rc>, /// introduced witness s - pub s: Rc>, + pub s: Rc>, /// introduced witness to check the range of a, b, c pub r_bits: DecomposedBits, } +/// Instance of extension vesion +pub struct TransformZqtoRQInstanceExt> { + /// inputs a + pub a: Rc>, + /// introduced witness k + pub k: Rc>, + /// introduced witness r + pub r: Rc>, + /// introduced witness s + pub s: Rc>, +} + +impl> TransformZqtoRQInstanceExt { + /// Construct an instance over the extension field from the original instance defined over the basic field + pub fn from_base(input_base: &TransformZqtoRQInstance) -> Self { + Self { + a: Rc::new(DenseMultilinearExtension::from_base(&input_base.a)), + k: Rc::new(DenseMultilinearExtension::from_base(&input_base.k)), + r: Rc::new(DenseMultilinearExtension::from_base(&input_base.r)), + s: Rc::new(DenseMultilinearExtension::from_base(&input_base.s)), + } + } +} + /// Stores the parameters used for addition in Zq and the public info for verifier. pub struct TransformZqtoRQInstanceInfo { /// number of variables @@ -126,10 +149,10 @@ impl TransformZqtoRQInstance { pub fn from_vec( q: usize, c: Vec>>, - a: Rc>, - k: &Rc>, - r: &Rc>, - s: &Rc>, + a: Rc>, + k: &Rc>, + r: &Rc>, + s: &Rc>, base: F, base_len: u32, bits_len: u32, @@ -174,87 +197,80 @@ impl TransformZqtoRQInstance { } /// SNARKs for transformation from Zq to RQ i.e. R/QR -pub struct TransformZqtoRQ(PhantomData); +pub struct TransformZqtoRQ>( + PhantomData, + PhantomData, +); -impl TransformZqtoRQ { +impl> TransformZqtoRQ { /// Prove transformation from a \in Zq to c \in R/QR pub fn prove( - transform_instance: &TransformZqtoRQInstance, - u: &[F], - ) -> TransformZqtoRQProof { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, transform_instance, u) - } - - /// Prove transformation from Zq to R/QR given input a, c, witness k, r, s and the decomposed bits for r. - /// This function does the same thing as `prove`, but it uses a `Fiat-Shamir RNG` as the transcript/to generate the - /// verifier challenges. - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, - transform_instance: &TransformZqtoRQInstance, - u: &[F], - ) -> TransformZqtoRQProof { + trans: &mut Transcript, + instance_base: &TransformZqtoRQInstance, + u: &[EF], + ) -> TransformZqtoRQProof { let dim = u.len(); - assert_eq!(dim, transform_instance.num_vars); + assert_eq!(dim, instance_base.num_vars); // 1. rangecheck for r - let rangecheck_msg = - BitDecomposition::prove_as_subprotocol(fs_rng, &transform_instance.r_bits, u); + let rangecheck_msg = BitDecomposition::prove(trans, &instance_base.r_bits, u); + + // Convert the original instance defined over F to a new instance defined over EF + let instance = >::from_base(instance_base); // 2. execute sumcheck for \sum_{x \in {0,1}^logM} eq(u, x) * k(x) * (1-k(x)) = 0 i.e. k(x) \in \{0,1\}^dim - let mut poly = >::new(dim); + let mut poly = >::new(dim); let mut product = Vec::with_capacity(3); let mut op_coefficient = Vec::with_capacity(3); product.push(Rc::new(gen_identity_evaluations(u))); - op_coefficient.push((F::one(), F::zero())); - product.push(Rc::clone(&transform_instance.k)); - op_coefficient.push((F::one(), F::zero())); - product.push(Rc::clone(&transform_instance.k)); - op_coefficient.push((-F::one(), F::one())); - poly.add_product_with_linear_op(product, &op_coefficient, F::one()); - - let first_sumcheck_proof = MLSumcheck::prove_as_subprotocol(fs_rng, &poly) + op_coefficient.push((UF::one(), UF::zero())); + product.push(Rc::clone(&instance.k)); + op_coefficient.push((UF::one(), UF::zero())); + product.push(Rc::clone(&instance.k)); + op_coefficient.push((UF::BaseField(-F::one()), UF::one())); + poly.add_product_with_linear_op(product, &op_coefficient, EF::one()); + + let first_sumcheck_proof = MLSumcheck::prove(trans, &poly) .expect("sumcheck for transformation from Zq to RQ failed"); // 3. execute sumcheck for \sum_{x \in {0,1}^logM} eq(u,x)((r(x) + 1) * (1 - 2k(x)) - s(x)) = 0 i.e. (r(x) + 1)(1 - 2k(x)) = s(x) for x in \{0,1\}^dim - let mut poly = >::new(dim); + let mut poly = >::new(dim); let mut product = Vec::with_capacity(3); let mut op_coefficient = Vec::with_capacity(3); product.push(Rc::new(gen_identity_evaluations(u))); - op_coefficient.push((F::one(), F::zero())); - product.push(Rc::clone(&transform_instance.r)); - op_coefficient.push((F::one(), F::one())); - product.push(Rc::clone(&transform_instance.k)); - op_coefficient.push((-(F::one() + F::one()), F::one())); - poly.add_product_with_linear_op(product, &op_coefficient, F::one()); + op_coefficient.push((UF::one(), UF::zero())); + product.push(Rc::clone(&instance.r)); + op_coefficient.push((UF::one(), UF::one())); + product.push(Rc::clone(&instance.k)); + op_coefficient.push((UF::BaseField(-(F::one() + F::one())), UF::one())); + poly.add_product_with_linear_op(product, &op_coefficient, EF::one()); let mut product = Vec::with_capacity(2); let mut op_coefficient = Vec::with_capacity(2); product.push(Rc::new(gen_identity_evaluations(u))); - op_coefficient.push((F::one(), F::zero())); - product.push(Rc::clone(&transform_instance.s)); - op_coefficient.push((-F::one(), F::zero())); - poly.add_product_with_linear_op(product, &op_coefficient, F::one()); + op_coefficient.push((UF::one(), UF::zero())); + product.push(Rc::clone(&instance.s)); + op_coefficient.push((UF::BaseField(-F::one()), UF::zero())); + poly.add_product_with_linear_op(product, &op_coefficient, EF::one()); - let second_sumcheck_proof = MLSumcheck::prove_as_subprotocol(fs_rng, &poly) + let second_sumcheck_proof = MLSumcheck::prove(trans, &poly) .expect("sumcheck for transformation from Zq to RQ failed"); // 4. sumcheck for \sum_{y \in {0,1}^logN} c(u,y)t(y) = s(u) - let c_num_vars = transform_instance.n.ilog(2) as usize; + let c_num_vars = instance_base.n.ilog(2) as usize; // construct c_u let eq_u = gen_identity_evaluations(u).evaluations; - let mut c_u_evaluations = vec![F::zero(); transform_instance.n]; - transform_instance + let mut c_u_evaluations = vec![EF::zero(); instance_base.n]; + instance_base .c .iter() .enumerate() .for_each(|(x_idx, sparse_p)| { sparse_p.iter().for_each(|(y_idx, value)| { - c_u_evaluations[*y_idx] += eq_u[x_idx] * value; + c_u_evaluations[*y_idx] += eq_u[x_idx] * *value; }); }); let c_u = Rc::new(DenseMultilinearExtension::from_evaluations_vec( @@ -263,24 +279,27 @@ impl TransformZqtoRQ { )); // construct t - let t_evaluations = (1..=transform_instance.n) + let t_evaluations = (1..=instance_base.n) .map(|i| F::new(F::Value::as_from(i as f64))) .collect(); - let t = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let t = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( c_num_vars, t_evaluations, )); - let mut poly = >::new(c_num_vars); + // Convert to the MLE defined over the Extension Field + let t = Rc::new(>::from_base(&t)); + + let mut poly = >::new(c_num_vars); let mut product = Vec::with_capacity(2); let mut op_coefficient = Vec::with_capacity(2); product.push(Rc::clone(&c_u)); - op_coefficient.push((F::one(), F::zero())); + op_coefficient.push((UF::one(), UF::zero())); product.push(Rc::clone(&t)); - op_coefficient.push((F::one(), F::zero())); - poly.add_product_with_linear_op(product, &op_coefficient, F::one()); + op_coefficient.push((UF::one(), UF::zero())); + poly.add_product_with_linear_op(product, &op_coefficient, EF::one()); - let third_sumcheck_proof = MLSumcheck::prove_as_subprotocol(fs_rng, &poly) + let third_sumcheck_proof = MLSumcheck::prove(trans, &poly) .expect("sumcheck for transformation from Zq to RQ failed"); TransformZqtoRQProof { @@ -290,38 +309,20 @@ impl TransformZqtoRQ { second_sumcheck_proof.0, third_sumcheck_proof.0, ], - s_u: transform_instance.s.evaluate(u), + s_u: instance.s.evaluate(u), } } /// Verify transformation from Zq to RQ given the proof and the verification key for bit decomposistion pub fn verify( - proof: &TransformZqtoRQProof, + trans: &mut Transcript, + proof: &TransformZqtoRQProof, decomposed_bits_info: &DecomposedBitsInfo, c_num_vars: usize, - ) -> TransformZqtoRQSubclaim { - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verifier_as_subprotocol(&mut fs_rng, proof, decomposed_bits_info, c_num_vars) - } - - /// Verify transformation from Zq to RQ given the proof and the verification key for bit decomposistion - /// This function does the same thing as `prove`, but it uses a `Fiat-Shamir RNG` as the transcript/to generate the - /// verifier challenges. - pub fn verifier_as_subprotocol( - fs_rng: &mut impl RngCore, - proof: &TransformZqtoRQProof, - decomposed_bits_info: &DecomposedBitsInfo, - c_num_vars: usize, - ) -> TransformZqtoRQSubclaim { - //TODO sample randomness via Fiat-Shamir RNG - + ) -> TransformZqtoRQSubclaim { // 1. rangecheck - let rangecheck_subclaim = BitDecomposition::verifier_as_subprotocol( - fs_rng, - &proof.rangecheck_msg, - decomposed_bits_info, - ); + let rangecheck_subclaim = + BitDecomposition::verify(trans, &proof.rangecheck_msg, decomposed_bits_info); // 2. sumcheck let poly_info = PolynomialInfo { @@ -329,33 +330,21 @@ impl TransformZqtoRQ { num_variables: decomposed_bits_info.num_vars, }; - let first_subclaim = MLSumcheck::verify_as_subprotocol( - fs_rng, - &poly_info, - F::zero(), - &proof.sumcheck_msgs[0], - ) - .expect("sumcheck protocol for transformation from Zq to RQ failed"); - - let second_subclaim = MLSumcheck::verify_as_subprotocol( - fs_rng, - &poly_info, - F::zero(), - &proof.sumcheck_msgs[1], - ) - .expect("sumcheck protocol for transformation from Zq to RQ failed"); + let first_subclaim = + MLSumcheck::verify(trans, &poly_info, EF::zero(), &proof.sumcheck_msgs[0]) + .expect("sumcheck protocol for transformation from Zq to RQ failed"); + + let second_subclaim = + MLSumcheck::verify(trans, &poly_info, EF::zero(), &proof.sumcheck_msgs[1]) + .expect("sumcheck protocol for transformation from Zq to RQ failed"); let poly_info = PolynomialInfo { max_multiplicands: 2, num_variables: c_num_vars, }; - let third_subclaim = MLSumcheck::verify_as_subprotocol( - fs_rng, - &poly_info, - proof.s_u, - &proof.sumcheck_msgs[2], - ) - .expect("sumcheck protocol for transformation from Zq to RQ failed"); + let third_subclaim = + MLSumcheck::verify(trans, &poly_info, proof.s_u, &proof.sumcheck_msgs[2]) + .expect("sumcheck protocol for transformation from Zq to RQ failed"); TransformZqtoRQSubclaim { rangecheck_subclaim, @@ -373,7 +362,7 @@ impl TransformZqtoRQ { } } -impl TransformZqtoRQSubclaim { +impl> TransformZqtoRQSubclaim { /// verify the sumcliam /// * a stores the input and c stores the output of transformation from Zq to RQ /// * k, r, s stores the introduced witness @@ -384,13 +373,13 @@ impl TransformZqtoRQSubclaim { pub fn verify_subclaim( &self, q: usize, - a: Rc>, - c_dense: &DenseMultilinearExtension, - k: &DenseMultilinearExtension, - r: &[Rc>], - s: &DenseMultilinearExtension, - r_bits: &[&Vec>>], - u: &[F], + a: Rc>, + c_dense: &DenseMultilinearExtensionBase, + k: &DenseMultilinearExtensionBase, + r: &[Rc>], + s: &DenseMultilinearExtensionBase, + r_bits: &[&Vec>>], + u: &[EF], info: &TransformZqtoRQInstanceInfo, ) -> bool { assert_eq!(r_bits.len(), 1); @@ -405,8 +394,8 @@ impl TransformZqtoRQSubclaim { } // check 2: subclaim for sumcheck, i.e. eq(u, point) * k(point) * (1 - k(point)) = 0 - let eval_k = k.evaluate(&self.sumcheck_points[0]); - if eval_identity_function(u, &self.sumcheck_points[0]) * eval_k * (F::one() - eval_k) + let eval_k = k.evaluate_ext(&self.sumcheck_points[0]); + if eval_identity_function(u, &self.sumcheck_points[0]) * eval_k * (EF::one() - eval_k) != self.sumcheck_expected_evaluations[0] { return false; @@ -414,24 +403,25 @@ impl TransformZqtoRQSubclaim { // check 3: subclaim for sumcheck, i.e. eq(u, point) * ((r(point) + 1) * (1 - 2 * k(point)) - s(point)) = 0 if eval_identity_function(u, &self.sumcheck_points[1]) - * ((r[0].evaluate(&self.sumcheck_points[1]) + F::one()) - * (F::one() - (F::one() + F::one()) * k.evaluate(&self.sumcheck_points[1])) - - s.evaluate(&self.sumcheck_points[1])) + * ((r[0].evaluate_ext(&self.sumcheck_points[1]) + F::one()) + * (EF::one() - k.evaluate_ext(&self.sumcheck_points[1]) * (F::one() + F::one())) + - s.evaluate_ext(&self.sumcheck_points[1])) != self.sumcheck_expected_evaluations[1] { return false; } // check 4: subclaim for sumcheck, i.e. c(u, point) * t(point) = s(u) - let eval_c_u = c_dense.evaluate(&[&self.sumcheck_points[2], u].concat()); + let eval_c_u = c_dense.evaluate_ext(&[&self.sumcheck_points[2], u].concat()); let t_evaluations = (1..=info.n) .map(|i| F::new(F::Value::as_from(i as f64))) .collect(); - let t = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let t = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( info.n.ilog(2) as usize, t_evaluations, )); - if eval_c_u * t.evaluate(&self.sumcheck_points[2]) != self.sumcheck_expected_evaluations[2] + if eval_c_u * t.evaluate_ext(&self.sumcheck_points[2]) + != self.sumcheck_expected_evaluations[2] { return false; } @@ -440,8 +430,8 @@ impl TransformZqtoRQSubclaim { let n = F::new(F::Value::as_from(info.n as f64)); let n_divied_by_q = F::new(F::Value::as_from((info.n / q) as f64)); - (F::one() + F::one()) * n_divied_by_q * a.evaluate(u) - == n * k.evaluate(u) + r[0].evaluate(u) + a.evaluate_ext(u) * (F::one() + F::one()) * n_divied_by_q + == k.evaluate_ext(u) * n + r[0].evaluate_ext(u) } /// verify the sumcliam @@ -454,13 +444,13 @@ impl TransformZqtoRQSubclaim { pub fn verify_subclaim_without_oracle( &self, q: usize, - a: Rc>, + a: Rc>, c_sparse: &[Rc>], - k: &DenseMultilinearExtension, - r: &[Rc>], - s: &DenseMultilinearExtension, - r_bits: &[&Vec>>], - u: &[F], + k: &DenseMultilinearExtensionBase, + r: &[Rc>], + s: &DenseMultilinearExtensionBase, + r_bits: &[&Vec>>], + u: &[EF], info: &TransformZqtoRQInstanceInfo, ) -> bool { assert_eq!(r_bits.len(), 1); @@ -475,8 +465,8 @@ impl TransformZqtoRQSubclaim { } // check 2: subclaim for sumcheck, i.e. eq(u, point) * k(point) * (1 - k(point)) = 0 - let eval_k = k.evaluate(&self.sumcheck_points[0]); - if eval_identity_function(u, &self.sumcheck_points[0]) * eval_k * (F::one() - eval_k) + let eval_k = k.evaluate_ext(&self.sumcheck_points[0]); + if eval_identity_function(u, &self.sumcheck_points[0]) * eval_k * (EF::one() - eval_k) != self.sumcheck_expected_evaluations[0] { return false; @@ -484,9 +474,9 @@ impl TransformZqtoRQSubclaim { // check 3: subclaim for sumcheck, i.e. eq(u, point) * ((r(point) + 1) * (1 - 2 * k(point)) - s(point)) = 0 if eval_identity_function(u, &self.sumcheck_points[1]) - * ((r[0].evaluate(&self.sumcheck_points[1]) + F::one()) - * (F::one() - (F::one() + F::one()) * k.evaluate(&self.sumcheck_points[1])) - - s.evaluate(&self.sumcheck_points[1])) + * ((r[0].evaluate_ext(&self.sumcheck_points[1]) + F::one()) + * (EF::one() - k.evaluate_ext(&self.sumcheck_points[1]) * (F::one() + F::one())) + - s.evaluate_ext(&self.sumcheck_points[1])) != self.sumcheck_expected_evaluations[1] { return false; @@ -495,7 +485,7 @@ impl TransformZqtoRQSubclaim { // check 4: subclaim for sumcheck, i.e. c(u, point) * t(point) = s(u) let eq_u = gen_identity_evaluations(u); let eq_v = gen_identity_evaluations(&self.sumcheck_points[2]); - let mut eval_c_u = F::zero(); + let mut eval_c_u = EF::zero(); c_sparse.iter().enumerate().for_each(|(x_idx, c)| { assert_eq!(c.evaluations.len(), 1); let (y_idx, c_val) = c.evaluations[0]; @@ -505,11 +495,12 @@ impl TransformZqtoRQSubclaim { let t_evaluations = (1..=info.n) .map(|i| F::new(F::Value::as_from(i as f64))) .collect(); - let t = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let t = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( info.n.ilog(2) as usize, t_evaluations, )); - if eval_c_u * t.evaluate(&self.sumcheck_points[2]) != self.sumcheck_expected_evaluations[2] + if eval_c_u * t.evaluate_ext(&self.sumcheck_points[2]) + != self.sumcheck_expected_evaluations[2] { return false; } @@ -518,7 +509,7 @@ impl TransformZqtoRQSubclaim { let n = F::new(F::Value::as_from(info.n as f64)); let n_divied_by_q = F::new(F::Value::as_from((info.n / q) as f64)); - (F::one() + F::one()) * n_divied_by_q * a.evaluate(u) - == n * k.evaluate(u) + r[0].evaluate(u) + a.evaluate_ext(u) * (F::one() + F::one()) * n_divied_by_q + == k.evaluate_ext(u) * n + r[0].evaluate_ext(u) } } diff --git a/zkp/src/sumcheck/mod.rs b/zkp/src/sumcheck/mod.rs index bc9da551..aeb23959 100644 --- a/zkp/src/sumcheck/mod.rs +++ b/zkp/src/sumcheck/mod.rs @@ -1,29 +1,32 @@ //! Interactive Proof Protocol used for Multilinear Sumcheck // It is derived from https://github.com/arkworks-rs/sumcheck/blob/master/src/ml_sumcheck/protocol/mod.rs. -use algebra::{Field, ListOfProductsOfPolynomials, PolynomialInfo}; +use algebra::{AbstractExtensionField, Field, ListOfProductsOfPolynomials, PolynomialInfo}; use prover::{ProverMsg, ProverState}; -use rand::{RngCore, SeedableRng}; -use rand_chacha::ChaCha12Rng; use std::marker::PhantomData; use verifier::SubClaim; pub mod prover; pub mod verifier; +use algebra::utils::Transcript; /// IP for MLSumcheck -pub struct IPForMLSumcheck { +pub struct IPForMLSumcheck> { _marker: PhantomData, + _stone: PhantomData, } /// Sumcheck for products of multilinear polynomial -pub struct MLSumcheck(PhantomData); +pub struct MLSumcheck> { + _marker: PhantomData, + _stone: PhantomData, +} /// proof generated by prover -pub type Proof = Vec>; +pub type Proof = Vec>; -impl MLSumcheck { +impl> MLSumcheck { /// Extract sum from the proof - pub fn extract_sum(proof: &Proof) -> F { + pub fn extract_sum(proof: &Proof) -> EF { proof[0].evaluations[0] + proof[0].evaluations[1] } @@ -39,32 +42,22 @@ impl MLSumcheck { /// The resulting polynomial is /// /// $$\sum_{i=0}^{n}C_i\cdot\prod_{j=0}^{m_i}P_{ij}$$ + #[allow(clippy::type_complexity)] pub fn prove( - polynomial: &ListOfProductsOfPolynomials, - ) -> Result, crate::error::Error> { - // TODO switch to the Fiat-Shamir RNG - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::prove_as_subprotocol(&mut fs_rng, polynomial).map(|r| r.0) - } - - /// This function does the same thing as `prove`, but it uses a `Fiat-Shamir RNG` as the transcript/to generate the - /// verifier challenges. Additionally, it returns the prover's state in addition to the proof. - /// Both of these allow this sumcheck to be better used as a part of a larger protocol. - pub fn prove_as_subprotocol( - fs_rng: &mut impl RngCore, - polynomial: &ListOfProductsOfPolynomials, - ) -> Result<(Proof, ProverState), crate::error::Error> { - // TODO update Fiat-Shamir RNG using polynomial.info() + trans: &mut Transcript, + polynomial: &ListOfProductsOfPolynomials, + ) -> Result<(Proof, ProverState), crate::error::Error> { + trans.append_message(b"polynomial info", &polynomial.info()); let mut prover_state = IPForMLSumcheck::prover_init(polynomial); let mut verifier_msg = None; let mut prover_msgs = Vec::with_capacity(polynomial.num_variables); for _ in 0..polynomial.num_variables { let prover_msg = IPForMLSumcheck::prove_round(&mut prover_state, &verifier_msg); - // TODO update Fiat-Shamir RNG using prover's message + trans.append_message(b"sumcheck msg", &prover_msg); prover_msgs.push(prover_msg); - verifier_msg = Some(IPForMLSumcheck::sample_round(fs_rng)); + + verifier_msg = Some(IPForMLSumcheck::sample_round(trans)); } prover_state .randomness @@ -74,31 +67,20 @@ impl MLSumcheck { /// verify the proof using `polynomial_info` as the verifier key pub fn verify( + trans: &mut Transcript, polynomial_info: &PolynomialInfo, - claimed_sum: F, - proof: &Proof, - ) -> Result, crate::Error> { - // TODO switch to the Fiat-Shamir RNG - let seed: ::Seed = Default::default(); - let mut fs_rng = ChaCha12Rng::from_seed(seed); - Self::verify_as_subprotocol(&mut fs_rng, polynomial_info, claimed_sum, proof) - } + claimed_sum: EF, + proof: &Proof, + ) -> Result, crate::Error> { + // let mut trans = Transcript::::new(); + trans.append_message(b"polynomial info", polynomial_info); - /// This function does the same thing as `verify`, but it uses a `Fiat-Shamir RNG`` as the transcript to generate the - /// verifier challenges. This allows this sumcheck to be used as a part of a larger protocol. - pub fn verify_as_subprotocol( - fs_rng: &mut impl RngCore, - polynomial_info: &PolynomialInfo, - claimed_sum: F, - proof: &Proof, - ) -> Result, crate::Error> { - // TODO update Fiat-Shamir RNG using polynomial.info() let mut verifier_state = IPForMLSumcheck::verifier_init(polynomial_info); for i in 0..polynomial_info.num_variables { let prover_msg = proof.get(i).expect("proof is incomplete"); - // TODO update Fiat-Shamir RNG using prover's message + trans.append_message(b"sumcheck msg", prover_msg); - IPForMLSumcheck::verify_round((*prover_msg).clone(), &mut verifier_state, fs_rng); + IPForMLSumcheck::verify_round((*prover_msg).clone(), &mut verifier_state, trans); } IPForMLSumcheck::check_and_generate_subclaim(verifier_state, claimed_sum) diff --git a/zkp/src/sumcheck/prover.rs b/zkp/src/sumcheck/prover.rs index 66d91889..972cb910 100644 --- a/zkp/src/sumcheck/prover.rs +++ b/zkp/src/sumcheck/prover.rs @@ -2,10 +2,13 @@ // It is derived from https://github.com/arkworks-rs/sumcheck/blob/master/src/ml_sumcheck/protocol/prover.rs. use core::panic; +use std::marker::PhantomData; use std::vec; -use algebra::Field; +use algebra::{AbstractExtensionField, Field, UF}; use algebra::{DenseMultilinearExtension, ListOfProductsOfPolynomials, MultilinearExtension}; +use serde::ser::SerializeSeq; +use serde::Serialize; use super::verifier::VerifierMsg; use super::IPForMLSumcheck; @@ -13,23 +16,38 @@ use std::rc::Rc; /// Prover Message #[derive(Clone)] -pub struct ProverMsg { +pub struct ProverMsg> { /// evaluations on P(0), P(1), P(2), ... - pub(crate) evaluations: Rc>, + pub(crate) evaluations: Rc>, + _marker: PhantomData, +} + +impl> Serialize for ProverMsg { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_seq(Some(self.evaluations.len()))?; + for e in self.evaluations.iter() { + seq.serialize_element(e)?; + } + seq.end() + } } /// Prover State -pub struct ProverState { +pub struct ProverState> { /// sampled randomness given by the verifier - pub randomness: Vec, + pub randomness: Vec, /// Stores the list of products that is meant to be added together. /// Each multiplicand is represented by the index in flattened_ml_extensions - pub list_of_products: Vec<(F, Vec)>, + pub list_of_products: Vec<(EF, Vec)>, /// Stores the linear operations, each of which is successively (in the same order) perfomed over the each MLE of each product stored in the above `products` /// so each (a: F, b: F) can used to wrap a linear operation over the original MLE f, i.e. a \cdot f + b - pub linear_ops: Vec>, + #[allow(clippy::type_complexity)] + pub linear_ops: Vec, UF)>>, /// Stores a list of multilinear extensions in which `self.list_of_products` point to - pub flattened_ml_extensions: Vec>, + pub flattened_ml_extensions: Vec>, /// Number of variables pub num_vars: usize, /// Max number of multiplicands in a product @@ -38,7 +56,7 @@ pub struct ProverState { pub round: usize, } -impl IPForMLSumcheck { +impl> IPForMLSumcheck { /// Initilaize the prover to argue for the sum of polynomial over {0, 1}^`num_vars` /// /// The polynomial is represented by a list of products of polynomials along with its coefficient that is meant to be added together. @@ -51,7 +69,7 @@ impl IPForMLSumcheck { /// The resulting polynomial is /// /// $$\sum_{i=0}^{n} c_i \cdot \prod_{j=0}^{m_i}P_{ij}$$ - pub fn prover_init(polynomial: &ListOfProductsOfPolynomials) -> ProverState { + pub fn prover_init(polynomial: &ListOfProductsOfPolynomials) -> ProverState { if polynomial.num_variables == 0 { panic!("Attempt to prove a constant.") } @@ -78,9 +96,9 @@ impl IPForMLSumcheck { /// /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). pub fn prove_round( - prover_state: &mut ProverState, - v_msg: &Option>, - ) -> ProverMsg { + prover_state: &mut ProverState, + v_msg: &Option>, + ) -> ProverMsg { if let Some(msg) = v_msg { if prover_state.round == 0 { panic!("First round should be prover first.") @@ -111,7 +129,7 @@ impl IPForMLSumcheck { // the degree of univariate polynomial sent by prover at this round let degree = prover_state.max_multiplicands; - let zeros = (vec![F::zero(); degree + 1], vec![F::zero(); degree + 1]); + let zeros = (vec![EF::zero(); degree + 1], vec![EF::zero(); degree + 1]); // In effect, this fold is essentially doing simply: // for b in 0..1 << (nv - i) // The goal is to evaluate degree + 1 points for each b, all of which has been fixed with the same (i-1) variables. @@ -128,8 +146,9 @@ impl IPForMLSumcheck { for (&jth_product, &(op_a, op_b)) in products.iter().zip(linear_ops.iter()) { // (a, b) is a wrapped linear operation over original MLE let table = &prover_state.flattened_ml_extensions[jth_product]; - let mut start = (table[b << 1] * op_a) + op_b; - let step = (table[(b << 1) + 1] * op_a) + op_b - start; + // Reordering the operation is to handle the operation between UF and EF + let mut start = op_b + (op_a * table[b << 1]); + let step = op_b + (op_a * table[(b << 1) + 1]) - start; // Evaluate each point P(t) for t = 0..degree + 1 via the accumulated addition instead of the multiplication by t. // [t|b] = [0|b] + t * ([1|b] - [0|b]) represented by little-endian for p in product.iter_mut() { @@ -147,6 +166,7 @@ impl IPForMLSumcheck { ProverMsg { evaluations: Rc::new(products_sum), + _marker: PhantomData, } } } diff --git a/zkp/src/sumcheck/verifier.rs b/zkp/src/sumcheck/verifier.rs index c1a3e1b8..5223d162 100644 --- a/zkp/src/sumcheck/verifier.rs +++ b/zkp/src/sumcheck/verifier.rs @@ -1,46 +1,52 @@ //! Verifier for the multilinear sumcheck protocol // It is derived from https://github.com/arkworks-rs/sumcheck/blob/master/src/ml_sumcheck/protocol/verifier.rs. -use std::vec; +use std::{marker::PhantomData, vec}; -use algebra::{Field, FieldUniformSampler, PolynomialInfo}; -use rand::distributions::Distribution; +use algebra::{utils::Transcript, AbstractExtensionField, Field, PolynomialInfo}; +use serde::Serialize; use crate::error::Error; use std::rc::Rc; use super::{prover::ProverMsg, IPForMLSumcheck}; -#[derive(Clone)] +#[derive(Clone, Serialize)] /// verifier message -pub struct VerifierMsg { +pub struct VerifierMsg> { /// randomness sampled by verifier - pub randomness: F, + pub randomness: EF, + /// marker for F + _marker: PhantomData, } /// Verifier State -pub struct VerifierState { +pub struct VerifierState> { round: usize, nv: usize, max_multiplicands: usize, finished: bool, /// a list storing the univariate polynomial in evaluations sent by the prover at each round - polynomials_received: Vec>>, + polynomials_received: Vec>>, /// a list storing the randomness sampled by the verifier at each round - randomness: Vec, + randomness: Vec, + /// marker for F + _marker: PhantomData, } /// Subclaim when verifier is convinced -pub struct SubClaim { +pub struct SubClaim> { /// the multi-dimensional point that this multilinear extension is evaluated to - pub point: Vec, + pub point: Vec, /// the expected evaluation - pub expected_evaluations: F, + pub expected_evaluations: EF, + /// marker for F + _marker: PhantomData, } -impl IPForMLSumcheck { +impl> IPForMLSumcheck { /// initialize the verifier - pub fn verifier_init(index_info: &PolynomialInfo) -> VerifierState { + pub fn verifier_init(index_info: &PolynomialInfo) -> VerifierState { VerifierState { round: 1, nv: index_info.num_variables, @@ -48,6 +54,7 @@ impl IPForMLSumcheck { finished: false, polynomials_received: Vec::with_capacity(index_info.num_variables), randomness: Vec::with_capacity(index_info.num_variables), + _marker: PhantomData, } } @@ -56,18 +63,18 @@ impl IPForMLSumcheck { /// Normally, this function should perform actual verification. Instead, `verify_round` only samples /// and stores randomness and perform verifications altogether in `check_and_generate_subclaim` at /// the last step. - pub fn verify_round( - prover_msg: ProverMsg, - verifier_state: &mut VerifierState, - rng: &mut R, - ) -> Option> { + pub fn verify_round( + prover_msg: ProverMsg, + verifier_state: &mut VerifierState, + trans: &mut Transcript, + ) -> Option> { if verifier_state.finished { panic!("incorrect verifier state: Verifier is already finished.") } // Now, verifier should check if the received P(0) + P(1) = expected. The check is moved to // `check_and_generate_subclaim`, and will be done after the last round. - let msg = Self::sample_round(rng); + let msg = Self::sample_round(trans); verifier_state.randomness.push(msg.randomness); verifier_state .polynomials_received @@ -87,9 +94,9 @@ impl IPForMLSumcheck { /// check the proof and generate the reduced subclaim pub fn check_and_generate_subclaim( - verifier_state: VerifierState, - asserted_sum: F, - ) -> Result, Error> { + verifier_state: VerifierState, + asserted_sum: EF, + ) -> Result, Error> { if !verifier_state.finished { panic!("Verifier has not finished."); } @@ -116,14 +123,16 @@ impl IPForMLSumcheck { Ok(SubClaim { point: verifier_state.randomness, expected_evaluations: expected, + _marker: PhantomData, }) } /// Simulate a verifier message without doing verification #[inline] - pub fn sample_round(rng: &mut R) -> VerifierMsg { + pub fn sample_round(trans: &mut Transcript) -> VerifierMsg { VerifierMsg { - randomness: FieldUniformSampler::new().sample(rng), + randomness: trans.get_ext_field_challenge::(b"random point in each round"), + _marker: PhantomData, } } } @@ -133,7 +142,10 @@ impl IPForMLSumcheck { /// and evaluate this polynomial at `eval_at`. /// In other words, efficiently compute /// \sum_{i=0}^{len p_i - 1} p_i\[i\] * (\prod_{j!=i}(eval_at - j)/(i - j)) -pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { +pub(crate) fn interpolate_uni_poly>( + p_i: &[EF], + eval_at: EF, +) -> EF { let len = p_i.len(); let mut evals = vec![]; @@ -142,13 +154,13 @@ pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { let mut prod = eval_at; evals.push(eval_at); - let mut check = F::zero(); + let mut check = EF::zero(); // We return early if 0 <= eval_at < len, i.e. if the desired value has been passed for i in 1..len { if eval_at == check { return p_i[i - 1]; } - check += F::one(); + check += EF::one(); let tmp = eval_at - check; evals.push(tmp); @@ -159,7 +171,7 @@ pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { } // Now check = len - 1 - let mut res = F::zero(); + let mut res = EF::zero(); // We want to compute the denominator \prod (j!=i) (i-j) for a given i in 0..len // // we start from the last step for i = len - 1, which is @@ -192,11 +204,11 @@ pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { // = \sum_{i=0}^{len p_i - 1} * (prod / evals[i]) / denom[i] // = \sum_{i=0}^{len p_i - 1} * prod / (evals[i] * denom[i]) where denom[i-1] = - denom[i] * (len-i) / i. // So we use denom_up / denom_down to update denom[i] in reverse. - let mut denom_up = field_factorial::(len - 1); - let mut denom_down = F::one(); + let mut denom_up = EF::from_base(field_factorial::(len - 1)); + let mut denom_down = EF::one(); let mut i_as_field = check; // len-1 - let mut len_minust_i_as_field = F::one(); + let mut len_minust_i_as_field = EF::one(); for i in (0..len).rev() { res += p_i[i] * prod * denom_down / (denom_up * evals[i]); @@ -229,7 +241,7 @@ mod test { use crate::sumcheck::verifier::interpolate_uni_poly; use algebra::{ derive::{Field, Prime}, - Field, FieldUniformSampler, Polynomial, + BabyBear, BabyBearExetension, Field, FieldUniformSampler, Polynomial, }; use num_traits::{One, Zero}; use rand::SeedableRng; @@ -241,8 +253,9 @@ mod test { pub struct Fp32(u32); // field type - type FF = Fp32; - type UniPolyFf = Polynomial; + type FF = BabyBear; + type EF = BabyBearExetension; + type UniPolyFf = Polynomial; macro_rules! field_vec { ($t:ty; $elem:expr; $n:expr)=>{ @@ -260,20 +273,20 @@ mod test { // Test a polynomial with 20 known points, i.e., with degree 19 let poly = UniPolyFf::random(20 - 1, &mut prng); - let mut evals: Vec = Vec::with_capacity(20); - let mut point = FF::zero(); + let mut evals: Vec = Vec::with_capacity(20); + let mut point = EF::zero(); evals.push(poly.evaluate(point)); for _i in 1..20 { point += FF::one(); evals.push(poly.evaluate(point)); } - let query = >::new().sample(&mut prng); + let query = >::new().sample(&mut prng); assert_eq!(poly.evaluate(query), interpolate_uni_poly(&evals, query)); // Test interpolation when we ask for the value at an x-coordinate we are already passing, // i.e., in the range 0 <= x < len(values) - 1 - let evals = field_vec!(FF; 0, 1, 4, 9); - assert_eq!(interpolate_uni_poly(&evals, FF::new(3)), FF::new(9)); + let evals = field_vec!(EF; 0, 1, 4, 9); + assert_eq!(interpolate_uni_poly(&evals, EF::new(3)), EF::new(9)); } } diff --git a/zkp/src/utils.rs b/zkp/src/utils.rs index 6c3950a5..fb74b467 100644 --- a/zkp/src/utils.rs +++ b/zkp/src/utils.rs @@ -1,29 +1,31 @@ //! This module defines some useful utils that may invoked by piop. -use algebra::{DenseMultilinearExtension, Field}; +use algebra::{AbstractExtensionField, DenseMultilinearExtension, Field}; /// Generate MLE of the ideneity function eq(u,x) for x \in \{0, 1\}^dim -pub fn gen_identity_evaluations(u: &[F]) -> DenseMultilinearExtension { +pub fn gen_identity_evaluations>( + u: &[EF], +) -> DenseMultilinearExtension { let dim = u.len(); - let mut evaluations: Vec<_> = vec![F::zero(); 1 << dim]; - evaluations[0] = F::one(); + let mut evaluations: Vec<_> = vec![EF::zero(); 1 << dim]; + evaluations[0] = EF::one(); for i in 0..dim { // The index represents a point in {0,1}^`num_vars` in little endian form. // For example, `0b1011` represents `P(1,1,0,1)` let u_i_rev = u[dim - i - 1]; for b in (0..(1 << i)).rev() { evaluations[(b << 1) + 1] = evaluations[b] * u_i_rev; - evaluations[b << 1] = evaluations[b] * (F::one() - u_i_rev); + evaluations[b << 1] = evaluations[b] * (EF::one() - u_i_rev); } } DenseMultilinearExtension::from_evaluations_vec(dim, evaluations) } /// Evaluate eq(u, v) = \prod_i (u_i * v_i + (1 - u_i) * (1 - v_i)) -pub fn eval_identity_function(u: &[F], v: &[F]) -> F { +pub fn eval_identity_function>(u: &[EF], v: &[EF]) -> EF { assert_eq!(u.len(), v.len()); - let mut evaluation = F::one(); + let mut evaluation = EF::one(); for (u_i, v_i) in u.iter().zip(v) { - evaluation *= *u_i * *v_i + (F::one() - *u_i) * (F::one() - *v_i); + evaluation *= *u_i * *v_i + (EF::one() - *u_i) * (EF::one() - *v_i); } evaluation } @@ -33,7 +35,7 @@ mod test { use crate::utils::{eval_identity_function, gen_identity_evaluations}; use algebra::{ derive::{Field, Prime}, - FieldUniformSampler, MultilinearExtension, + BabyBearExetension, FieldUniformSampler, MultilinearExtension, }; use rand::thread_rng; use rand_distr::Distribution; @@ -42,11 +44,11 @@ mod test { #[modulus = 132120577] pub struct Fp32(u32); // field type - type FF = Fp32; + type EF = BabyBearExetension; #[test] fn test_gen_identity_evaluations() { - let sampler = >::new(); + let sampler = >::new(); let mut rng = thread_rng(); let dim = 10; let u: Vec<_> = (0..dim).map(|_| sampler.sample(&mut rng)).collect(); diff --git a/zkp/tests/test_accumulator.rs b/zkp/tests/test_accumulator.rs index d801f8bc..f35da042 100644 --- a/zkp/tests/test_accumulator.rs +++ b/zkp/tests/test_accumulator.rs @@ -1,8 +1,10 @@ use algebra::{ derive::{DecomposableField, FheField, Field, Prime, NTT}, - Basis, DenseMultilinearExtension, Field, FieldUniformSampler, MultilinearExtension, + utils::Transcript, + Basis, DenseMultilinearExtensionBase, Field, FieldUniformSampler, MultilinearExtensionBase, }; use algebra::{transformation::AbstractNTT, NTTField, NTTPolynomial, Polynomial}; +use fhe_core::{DefaultExtendsionFieldU32x4, DefaultFieldU32}; use itertools::izip; use num_traits::One; use rand_distr::Distribution; @@ -20,7 +22,8 @@ use zkp::{ #[modulus = 132120577] pub struct Fp32(u32); // field type -type FF = Fp32; +type FF = DefaultFieldU32; +type EF = DefaultExtendsionFieldU32x4; #[derive(Field, DecomposableField, Prime)] #[modulus = 59] @@ -31,8 +34,8 @@ where R: rand::Rng + rand::CryptoRng, { RlweCiphertext { - a: Rc::new(>::random(num_vars, rng)), - b: Rc::new(>::random(num_vars, rng)), + a: Rc::new(>::random(num_vars, rng)), + b: Rc::new(>::random(num_vars, rng)), } } @@ -46,10 +49,10 @@ where { RlweCiphertexts { a_bits: (0..bits_len) - .map(|_| Rc::new(>::random(num_vars, rng))) + .map(|_| Rc::new(>::random(num_vars, rng))) .collect(), b_bits: (0..bits_len) - .map(|_| Rc::new(>::random(num_vars, rng))) + .map(|_| Rc::new(>::random(num_vars, rng))) .collect(), } } @@ -111,20 +114,20 @@ fn ntt_inverse_transform_normal_order(log_n: u32, points: & // * input_rgsw_ntt: RGSW(Zu) of the ntt form fn update_accumulator( input_accumulator_ntt: &RlweCiphertext, - input_d: Rc>, + input_d: Rc>, input_rgsw_ntt: (RlweCiphertexts, RlweCiphertexts), basis_info: &DecomposedBitsInfo, ntt_info: &NTTInstanceInfo, ) -> AccumulatorWitness { // 1. Perform ntt transform on (x^{-a_u} - 1) - let input_d_ntt = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let input_d_ntt = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_transform_normal_order(ntt_info.log_n as u32, &input_d.evaluations), )); // 2. Perform point-multiplication to compute (x^{-a_u} - 1) * ACC let input_rlwe_ntt = RlweCiphertext { - a: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + a: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, izip!( &input_d_ntt.evaluations, @@ -133,7 +136,7 @@ fn update_accumulator( .map(|(d_i, a_i)| *d_i * *a_i) .collect(), )), - b: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + b: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, izip!( &input_d_ntt.evaluations, @@ -146,14 +149,14 @@ fn update_accumulator( // 3. Compute the RLWE of coefficient form as the input of the multiplication between RLWE and RGSW let input_rlwe = RlweCiphertext { - a: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + a: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_inverse_transform_normal_order( ntt_info.log_n as u32, &input_rlwe_ntt.a.evaluations, ), )), - b: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + b: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_inverse_transform_normal_order( ntt_info.log_n as u32, @@ -179,7 +182,7 @@ fn update_accumulator( .a_bits .iter() .map(|bit| { - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_transform_normal_order(ntt_info.log_n as u32, &bit.evaluations), )) @@ -189,7 +192,7 @@ fn update_accumulator( .b_bits .iter() .map(|bit| { - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_transform_normal_order(ntt_info.log_n as u32, &bit.evaluations), )) @@ -227,11 +230,11 @@ fn update_accumulator( } let output_rlwe_ntt = RlweCiphertext { - a: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + a: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, output_g_ntt, )), - b: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + b: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, output_h_ntt, )), @@ -283,7 +286,8 @@ fn test_trivial_accumulator() { let ntt_info = NTTInstanceInfo { log_n, ntt_table }; let mut accumulator = random_rlwe_ciphertext(&mut rng, num_vars); - let mut accumulator_instance = >::new(num_vars, &ntt_info, &basis_info); + let mut accumulator_instance = + >::new(num_vars, &ntt_info, &basis_info); // number of updations in ACC let num_updations = 10; @@ -295,23 +299,36 @@ fn test_trivial_accumulator() { // number of ntt in each updation let num_ntt_iter = ((bits_len << 1) + 3) as usize; let num_ntt = num_updations * num_ntt_iter; + + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); // randomness used to combine all ntt instances - let randomness_ntt = (0..num_ntt) - .map(|_| uniform.sample(&mut rng)) - .collect::>(); + let prover_randomness_ntt = + prover_trans.get_vec_ext_field_challenge(b"randomize ntt instances", num_ntt); + let verify_randomness_ntt = + verifier_trans.get_vec_ext_field_challenge(b"randomize ntt instances", num_ntt); + let num_sumcheck = num_updations * 2; // randomness used to combine all sumcheck protocols - let randomness_sumcheck = (0..num_sumcheck) - .map(|_| uniform.sample(&mut rng)) - .collect::>(); - let u = (0..num_vars) - .map(|_| uniform.sample(&mut rng)) - .collect::>(); - let identity_func_at_u = Rc::new(gen_identity_evaluations(&u)); + let prover_randomness_sumcheck = prover_trans.get_vec_ext_field_challenge( + b"randomness used to combine all sumcheck protocols", + num_sumcheck, + ); + let verify_randomness_sumcheck = verifier_trans.get_vec_ext_field_challenge( + b"randomness used to combine all sumcheck protocols", + num_sumcheck, + ); + + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verify_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let identity_func_at_u = Rc::new(gen_identity_evaluations(&prover_u)); // update accumulator for `num_updations` times for idx in 0..num_updations { - let input_d = Rc::new(DenseMultilinearExtension::from_evaluations_slice( + let input_d = Rc::new(DenseMultilinearExtensionBase::from_evaluations_slice( num_vars, &random_d, )); let rgsw_ntt = ( @@ -321,19 +338,19 @@ fn test_trivial_accumulator() { let witness = update_accumulator(&accumulator, input_d, rgsw_ntt, &basis_info, &ntt_info); accumulator_instance.add_witness( - &randomness_ntt[idx * num_ntt_iter..(idx + 1) * num_ntt_iter], - &randomness_sumcheck[idx * 2..(idx + 1) * 2], + &prover_randomness_ntt[idx * num_ntt_iter..(idx + 1) * num_ntt_iter], + &prover_randomness_sumcheck[idx * 2..(idx + 1) * 2], &identity_func_at_u, &witness, ); accumulator = RlweCiphertext { - a: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + a: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, izip!(accumulator.a.iter(), witness.output_rlwe_ntt.a.iter()) .map(|(acc, x)| *acc + *x) .collect(), )), - b: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + b: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, izip!(accumulator.b.iter(), witness.output_rlwe_ntt.b.iter()) .map(|(acc, x)| *acc + *x) @@ -344,12 +361,13 @@ fn test_trivial_accumulator() { } let info = accumulator_instance.info(); - let proof = >::prove(&accumulator_instance, &u); - let subclaim = >::verify(&proof, &u, &info); + let proof = + >::prove(&mut prover_trans, &accumulator_instance, &prover_u); + let subclaim = >::verify(&mut verifier_trans, &proof, &verify_u, &info); assert!(subclaim.verify_subclaim( - &u, - &randomness_ntt, - &randomness_sumcheck, + &verify_u, + &verify_randomness_ntt, + &verify_randomness_sumcheck, &accumulator_instance.ntt_instance.coeffs, &accumulator_instance.ntt_instance.points, &witnesses, diff --git a/zkp/tests/test_addition_in_zq.rs b/zkp/tests/test_addition_in_zq.rs index 09092744..63a51ebc 100644 --- a/zkp/tests/test_addition_in_zq.rs +++ b/zkp/tests/test_addition_in_zq.rs @@ -1,6 +1,8 @@ use algebra::{ derive::{DecomposableField, Field, Prime}, - Basis, DecomposableField, DenseMultilinearExtension, Field, FieldUniformSampler, + utils::Transcript, + BabyBear, BabyBearExetension, Basis, DecomposableField, DenseMultilinearExtensionBase, Field, + FieldUniformSampler, }; use num_traits::{One, Zero}; use rand::prelude::*; @@ -10,15 +12,16 @@ use std::vec; use zkp::piop::{AdditionInZq, AdditionInZqInstance}; #[derive(Field, DecomposableField, Prime)] -#[modulus = 132120577] -pub struct Fp32(u32); +#[modulus = 2013265921] +pub struct Fp32(u64); #[derive(Field, DecomposableField, Prime)] #[modulus = 59] pub struct Fq(u32); // field type -type FF = Fp32; +type FF = BabyBear; +type EF = BabyBearExetension; macro_rules! field_vec { ($t:ty; $elem:expr; $n:expr)=>{ @@ -31,29 +34,26 @@ macro_rules! field_vec { #[test] fn test_trivial_addition_in_zq() { - let mut rng = thread_rng(); - let sampler = >::new(); - let q = FF::new(9); let base_len: u32 = 1; let base: FF = FF::new(2); let num_vars = 2; let bits_len: u32 = 4; let abc = vec![ - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 4, 6, 8, 2), )), - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 7, 3, 0, 1), )), - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 2, 0, 8, 3), )), ]; - let k = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let k = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 1, 1, 0, 0), )); @@ -68,18 +68,33 @@ fn test_trivial_addition_in_zq() { let abc_instance = AdditionInZqInstance::from_slice(&abc, &k, q, base, base_len, bits_len); let addition_info = abc_instance.info(); - let u: Vec<_> = (0..num_vars).map(|_| sampler.sample(&mut rng)).collect(); - - let proof = AdditionInZq::prove(&abc_instance, &u); - let subclaim = AdditionInZq::verify(&proof, &addition_info.decomposed_bits_info); - assert!(subclaim.verify_subclaim(q, &abc, k.as_ref(), &abd_bits_ref, &u, &addition_info)); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let proof = >::prove(&mut prover_trans, &abc_instance, &prover_u); + let subclaim = >::verify( + &mut verifier_trans, + &proof, + &addition_info.decomposed_bits_info, + ); + assert!(subclaim.verify_subclaim( + q, + &abc, + k.as_ref(), + &abd_bits_ref, + &verifier_u, + &addition_info + )); } #[test] fn test_random_addition_in_zq() { let mut rng = thread_rng(); let uniform_fq = >::new(); - let uniform_fp = >::new(); let num_vars = 10; let q = FF::new(Fq::MODULUS_VALUE); let base_len: u32 = 3; @@ -108,22 +123,22 @@ fn test_random_addition_in_zq() { let (c, k): (Vec<_>, Vec<_>) = c_k.iter().cloned().unzip(); let abc = vec![ - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, // Convert to Fp a.iter().map(|x: &Fq| FF::new(x.value())).collect(), )), - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, b.iter().map(|x: &Fq| FF::new(x.value())).collect(), )), - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, c.iter().map(|x: &Fq| FF::new(x.value())).collect(), )), ]; - let k = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let k = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, k.iter().map(|x: &Fq| FF::new(x.value())).collect(), )); @@ -138,9 +153,25 @@ fn test_random_addition_in_zq() { let abc_instance = AdditionInZqInstance::from_slice(&abc, &k, q, base, base_len, bits_len); let addition_info = abc_instance.info(); - let u: Vec<_> = (0..num_vars).map(|_| uniform_fp.sample(&mut rng)).collect(); - - let proof = AdditionInZq::prove(&abc_instance, &u); - let subclaim = AdditionInZq::verify(&proof, &addition_info.decomposed_bits_info); - assert!(subclaim.verify_subclaim(q, &abc, k.as_ref(), &abc_bits_ref, &u, &addition_info)); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let proof = >::prove(&mut prover_trans, &abc_instance, &prover_u); + let subclaim = AdditionInZq::verify( + &mut verifier_trans, + &proof, + &addition_info.decomposed_bits_info, + ); + assert!(subclaim.verify_subclaim( + q, + &abc, + k.as_ref(), + &abc_bits_ref, + &verifier_u, + &addition_info + )); } diff --git a/zkp/tests/test_bit_decomposition.rs b/zkp/tests/test_bit_decomposition.rs index 3068a2e3..4ba38804 100644 --- a/zkp/tests/test_bit_decomposition.rs +++ b/zkp/tests/test_bit_decomposition.rs @@ -1,8 +1,9 @@ -use algebra::Basis; +use algebra::utils::Transcript; use algebra::{ derive::{DecomposableField, FheField, Field, Prime, NTT}, - DenseMultilinearExtension, Field, FieldUniformSampler, + Field, FieldUniformSampler, }; +use algebra::{BabyBear, BabyBearExetension, Basis, DenseMultilinearExtensionBase}; // use protocol::bit_decomposition::{BitDecomposition, DecomposedBits}; use rand::prelude::*; use rand_distr::Distribution; @@ -11,11 +12,12 @@ use std::vec; use zkp::piop::{BitDecomposition, DecomposedBits}; #[derive(Field, Prime, DecomposableField, FheField, NTT)] -#[modulus = 132120577] -pub struct Fp32(u32); +#[modulus = 2013265921] +pub struct Fp32(u64); // field type -type FF = Fp32; +type FF = BabyBear; +type EF = BabyBearExetension; macro_rules! field_vec { ($t:ty; $elem:expr; $n:expr)=>{ @@ -33,18 +35,18 @@ fn test_single_trivial_bit_decomposition_base_2() { let bits_len: u32 = 2; let num_vars = 2; - let d = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let d = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 1, 2, 3), )); let d_bits = vec![ // 0th bit - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 1, 0, 1), )), // 1st bit - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 0, 1, 1), )), @@ -57,9 +59,12 @@ fn test_single_trivial_bit_decomposition_base_2() { let d_bits_verifier = vec![&d_bits]; let decomposed_bits_info = prover_key.info(); - let u = field_vec!(FF; 0, 0); - let proof = BitDecomposition::prove(&prover_key, &u); - let subclaim = BitDecomposition::verifier(&proof, &decomposed_bits_info); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let u = field_vec!(EF; 0, 0); + let proof = >::prove(&mut prover_trans, &prover_key, &u); + let subclaim = + >::verify(&mut verifier_trans, &proof, &decomposed_bits_info); assert!(subclaim.verify_subclaim(&d_verifier, &d_bits_verifier, &u, &decomposed_bits_info)); } @@ -70,14 +75,12 @@ fn test_batch_trivial_bit_decomposition_base_2() { let bits_len: u32 = 2; let num_vars = 2; - let mut rng = thread_rng(); - let uniform = >::new(); let d = vec![ - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 1, 2, 3), )), - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 1, 2, 3), )), @@ -85,24 +88,24 @@ fn test_batch_trivial_bit_decomposition_base_2() { let d_bits = vec![ vec![ // 0th bit - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 1, 0, 1), )), // 1st bit - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 0, 1, 1), )), ], vec![ // 0th bit - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 1, 0, 1), )), // 1st bit - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 0, 1, 1), )), @@ -116,11 +119,17 @@ fn test_batch_trivial_bit_decomposition_base_2() { } let decomposed_bits_info = decomposed_bits.info(); - - let u: Vec<_> = (0..num_vars).map(|_| uniform.sample(&mut rng)).collect(); - let proof = BitDecomposition::prove(&decomposed_bits, &u); - let subclaim = BitDecomposition::verifier(&proof, &decomposed_bits_info); - assert!(subclaim.verify_subclaim(&d, &d_bits_ref, &u, &decomposed_bits_info)); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let proof = >::prove(&mut prover_trans, &decomposed_bits, &prover_u); + let subclaim = + >::verify(&mut verifier_trans, &proof, &decomposed_bits_info); + assert!(subclaim.verify_subclaim(&d, &d_bits_ref, &verifier_u, &decomposed_bits_info)); } #[test] @@ -132,7 +141,7 @@ fn test_single_bit_decomposition() { let mut rng = thread_rng(); let uniform = >::new(); - let d = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let d = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, (0..(1 << num_vars)) .map(|_| uniform.sample(&mut rng)) @@ -148,10 +157,21 @@ fn test_single_bit_decomposition() { let decomposed_bits_info = decomposed_bits.info(); - let u: Vec<_> = (0..num_vars).map(|_| uniform.sample(&mut rng)).collect(); - let proof = BitDecomposition::prove(&decomposed_bits, &u); - let subclaim = BitDecomposition::verifier(&proof, &decomposed_bits_info); - assert!(subclaim.verify_subclaim(&d_verifier, &d_bits_verifier, &u, &decomposed_bits_info)); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let proof = >::prove(&mut prover_trans, &decomposed_bits, &prover_u); + let subclaim = + >::verify(&mut verifier_trans, &proof, &decomposed_bits_info); + assert!(subclaim.verify_subclaim( + &d_verifier, + &d_bits_verifier, + &verifier_u, + &decomposed_bits_info + )); } #[test] @@ -164,25 +184,25 @@ fn test_batch_bit_decomposition() { let mut rng = thread_rng(); let uniform = >::new(); let d = vec![ - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, (0..(1 << num_vars)) .map(|_| uniform.sample(&mut rng)) .collect(), )), - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, (0..(1 << num_vars)) .map(|_| uniform.sample(&mut rng)) .collect(), )), - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, (0..(1 << num_vars)) .map(|_| uniform.sample(&mut rng)) .collect(), )), - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, (0..(1 << num_vars)) .map(|_| uniform.sample(&mut rng)) @@ -203,8 +223,14 @@ fn test_batch_bit_decomposition() { let decomposed_bits_info = decomposed_bits.info(); - let u: Vec<_> = (0..num_vars).map(|_| uniform.sample(&mut rng)).collect(); - let proof = BitDecomposition::prove(&decomposed_bits, &u); - let subclaim = BitDecomposition::verifier(&proof, &decomposed_bits_info); - assert!(subclaim.verify_subclaim(&d, &d_bits_ref, &u, &decomposed_bits_info)); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let proof = >::prove(&mut prover_trans, &decomposed_bits, &prover_u); + let subclaim = + >::verify(&mut verifier_trans, &proof, &decomposed_bits_info); + assert!(subclaim.verify_subclaim(&d, &d_bits_ref, &verifier_u, &decomposed_bits_info)); } diff --git a/zkp/tests/test_ntt.rs b/zkp/tests/test_ntt.rs index 7c625cbb..6472e9f3 100644 --- a/zkp/tests/test_ntt.rs +++ b/zkp/tests/test_ntt.rs @@ -1,32 +1,33 @@ use algebra::{ - derive::{DecomposableField, FheField, Field, Prime, NTT}, - DenseMultilinearExtension, Field, FieldUniformSampler, NTTPolynomial, + derive::{Field, Prime}, + utils::Transcript, + DecomposableField, DenseMultilinearExtension, DenseMultilinearExtensionBase, Field, + FieldUniformSampler, NTTPolynomial, }; use algebra::{transformation::AbstractNTT, NTTField, Polynomial}; +use fhe_core::{DefaultExtendsionFieldU32x4, DefaultFieldU32}; use num_traits::{One, Zero}; use rand::prelude::*; use rand_distr::Distribution; use std::rc::Rc; use std::vec; -use zkp::piop::ntt::ntt_bare::init_fourier_table; +use zkp::piop::ntt::{ntt_bare::init_fourier_table, NTTInstanceExt}; use zkp::piop::{NTTBareIOP, NTTInstance, NTTIOP}; -#[derive(Field, Prime, DecomposableField, FheField, NTT)] -#[modulus = 132120577] -pub struct Fp32(u32); - #[derive(Field, Prime)] #[modulus = 59] pub struct Fq(u32); -// field type -type FF = Fp32; +type FF = DefaultFieldU32; +type EF = DefaultExtendsionFieldU32x4; + type PolyFF = Polynomial; -fn obtain_fourier_matrix_oracle(log_n: u32) -> DenseMultilinearExtension { +fn obtain_fourier_matrix_oracle(log_n: u32) -> DenseMultilinearExtensionBase { let m = 1 << (log_n + 1); let mut ntt_table = Vec::with_capacity(m as usize); let root = FF::get_ntt_table(log_n).unwrap().root(); + let mut power = FF::one(); for _ in 0..m { ntt_table.push(power); @@ -42,7 +43,7 @@ fn obtain_fourier_matrix_oracle(log_n: u32) -> DenseMultilinearExtension { fourier_matrix[idx_fourier as usize] = ntt_table[idx_power as usize]; } } - DenseMultilinearExtension::from_evaluations_vec((log_n << 1) as usize, fourier_matrix) + DenseMultilinearExtensionBase::from_evaluations_vec((log_n << 1) as usize, fourier_matrix) } /// Given an `index` of `len` bits, output a new index where the bits are reversed. @@ -105,13 +106,13 @@ fn naive_ntt_transform_normal_order(log_n: u32, coeff: &[FF]) -> Vec { let m = 1 << (log_n + 1); let mut ntt_table = Vec::with_capacity(m as usize); let root = FF::get_ntt_table(log_n).unwrap().root(); - let mut power = FF::one(); + let mut power = DefaultFieldU32::one(); for _ in 0..m { ntt_table.push(power); power *= root; } - let mut fourier_matrix = vec![FF::zero(); (1 << log_n) * (1 << log_n)]; + let mut fourier_matrix = vec![DefaultFieldU32::zero(); (1 << log_n) * (1 << log_n)]; // In little endian, the index for F[i, j] is i + (j << dim) for i in 0..1 << log_n { for j in 0..1 << log_n { @@ -121,7 +122,7 @@ fn naive_ntt_transform_normal_order(log_n: u32, coeff: &[FF]) -> Vec { } } - let mut ntt_form = vec![FF::zero(); 1 << log_n]; + let mut ntt_form = vec![DefaultFieldU32::zero(); 1 << log_n]; for i in 0..1 << log_n { for j in 0..1 << log_n { ntt_form[i] += coeff[j] * fourier_matrix[i + (j << log_n)]; @@ -150,7 +151,7 @@ fn test_ntt_transform_normal_order() { let log_n = 10; let coeff = PolyFF::random(1 << log_n, &mut thread_rng()).data(); let points_naive = naive_ntt_transform_normal_order(log_n, &coeff); - let points = ntt_transform_normal_order(log_n, &coeff); + let points = ntt_transform_normal_order::(log_n, &coeff); assert_eq!(points, points_naive); } @@ -182,27 +183,46 @@ fn test_ntt_bare_without_delegation() { let ntt_table = Rc::new(ntt_table); let mut rng = thread_rng(); - let uniform = >::new(); let coeff = PolyFF::random(1 << log_n, &mut rng).data(); - let points = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let points = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( log_n, - ntt_transform_normal_order(log_n as u32, &coeff), + ntt_transform_normal_order(log_n as u32, &coeff) + .iter() + .map(|x| FF::new(x.value())) + .collect(), )); - let coeff = Rc::new(DenseMultilinearExtension::from_evaluations_vec( - log_n, coeff, + let coeff = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( + log_n, + coeff.iter().map(|x| FF::new(x.value())).collect(), )); let ntt_instance = NTTInstance::from_slice(log_n, &ntt_table, &coeff, &points); let ntt_instance_info = ntt_instance.info(); - let u: Vec<_> = (0..log_n).map(|_| uniform.sample(&mut rng)).collect(); - let f_u = Rc::new(init_fourier_table(&u, &ntt_instance.ntt_table)); - let proof = NTTBareIOP::prove(&ntt_instance, &f_u, &u); - let subclaim = NTTBareIOP::verify(&proof, &ntt_instance_info); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", log_n); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", log_n); + + let f_u = Rc::new(init_fourier_table(&prover_u, &ntt_instance.ntt_table)); + + let ntt_instance = >::from_base(&ntt_instance); + let proof = >::prove(&mut prover_trans, &f_u, &ntt_instance, &prover_u); + let subclaim = >::verify(&mut verifier_trans, &proof.0, &ntt_instance_info); // Without delegation, the verifier needs to compute F(u, v) on its own. let fourier_matrix = Rc::new(obtain_fourier_matrix_oracle(log_n as u32)); - assert!(subclaim.verify_subclaim(&fourier_matrix, &points, &coeff, &u, &ntt_instance_info)); + let points = >::from_base(points.as_ref()); + let coeff = >::from_base(coeff.as_ref()); + assert!(subclaim.verify_subclaim( + &fourier_matrix, + &points, + &coeff, + &verifier_u, + &ntt_instance_info + )); } #[test] @@ -211,6 +231,7 @@ fn test_ntt_with_delegation() { let m = 1 << (log_n + 1); let mut ntt_table = Vec::with_capacity(m as usize); let root = FF::get_ntt_table(log_n as u32).unwrap().root(); + let mut power = FF::one(); for _ in 0..m { ntt_table.push(power); @@ -219,24 +240,37 @@ fn test_ntt_with_delegation() { let ntt_table = Rc::new(ntt_table); let mut rng = thread_rng(); - let uniform = >::new(); let coeff = PolyFF::random(1 << log_n, &mut rng).data(); - let points = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let points = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( log_n, - ntt_transform_normal_order(log_n as u32, &coeff), + ntt_transform_normal_order(log_n as u32, &coeff) + .iter() + .map(|x| FF::new(x.value())) + .collect(), )); - let coeff = Rc::new(DenseMultilinearExtension::from_evaluations_vec( - log_n, coeff, + let coeff = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( + log_n, + coeff.iter().map(|x| FF::new(x.value())).collect(), )); let ntt_instance = NTTInstance::from_slice(log_n, &ntt_table, &coeff, &points); let ntt_instance_info = ntt_instance.info(); - let u: Vec<_> = (0..log_n).map(|_| uniform.sample(&mut rng)).collect(); - let proof = NTTIOP::prove(&ntt_instance, &u); - let subclaim = NTTIOP::verify(&proof, &ntt_instance_info, &u); - - assert!(subclaim.verify_subcliam(&points, &coeff, &u, &ntt_instance_info)); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", log_n); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", log_n); + + let ntt_instance = >::from_base(&ntt_instance); + let proof = >::prove(&mut prover_trans, &ntt_instance, &prover_u); + let subclaim = + >::verify(&mut verifier_trans, &proof, &ntt_instance_info, &verifier_u); + + let points = >::from_base(points.as_ref()); + let coeff = >::from_base(coeff.as_ref()); + assert!(subclaim.verify_subcliam(&points, &coeff, &verifier_u, &ntt_instance_info)); } #[test] @@ -255,15 +289,21 @@ fn test_ntt_combined_with_delegation() { let mut rng = thread_rng(); let uniform = >::new(); let mut coeff1 = PolyFF::random(1 << log_n, &mut rng); - let points1 = DenseMultilinearExtension::from_evaluations_vec( + let points1 = DenseMultilinearExtensionBase::from_evaluations_vec( log_n, - ntt_transform_normal_order(log_n as u32, coeff1.as_ref()), + ntt_transform_normal_order(log_n as u32, coeff1.as_ref()) + .iter() + .map(|x| FF::new(x.value())) + .collect(), ); let mut coeff2 = PolyFF::random(1 << log_n, &mut rng); - let points2 = DenseMultilinearExtension::from_evaluations_vec( + let points2 = DenseMultilinearExtensionBase::from_evaluations_vec( log_n, - ntt_transform_normal_order(log_n as u32, coeff2.as_ref()), + ntt_transform_normal_order(log_n as u32, coeff2.as_ref()) + .iter() + .map(|x| FF::new(x.value())) + .collect(), ); let r_1 = uniform.sample(&mut rng); @@ -272,22 +312,34 @@ fn test_ntt_combined_with_delegation() { coeff2.mul_scalar_assign(r_2); let coeff = coeff1 + coeff2; - let coeff = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let coeff = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( log_n, - coeff.data(), + coeff.data().iter().map(|x| FF::new(x.value())).collect(), )); - let mut points = - >::from_evaluations_vec(log_n, vec![FF::zero(); 1 << log_n]); - points += (r_1, &points1); - points += (r_2, &points2); + let mut points = >::from_evaluations_vec( + log_n, + vec![FF::zero(); 1 << log_n], + ); + points += (FF::new(r_1.value() as u32), &points1); + points += (FF::new(r_2.value() as u32), &points2); let points = Rc::new(points); let ntt_instance = NTTInstance::from_slice(log_n, &ntt_table, &coeff, &points); let ntt_instance_info = ntt_instance.info(); - let u: Vec<_> = (0..log_n).map(|_| uniform.sample(&mut rng)).collect(); - let proof = NTTIOP::prove(&ntt_instance, &u); - let subclaim = NTTIOP::verify(&proof, &ntt_instance_info, &u); - - assert!(subclaim.verify_subcliam(&points, &coeff, &u, &ntt_instance_info)); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", log_n); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", log_n); + + let ntt_instance = >::from_base(&ntt_instance); + let proof = >::prove(&mut prover_trans, &ntt_instance, &prover_u); + let subclaim = + >::verify(&mut verifier_trans, &proof, &ntt_instance_info, &verifier_u); + + let points = >::from_base(points.as_ref()); + let coeff = >::from_base(coeff.as_ref()); + assert!(subclaim.verify_subcliam(&points, &coeff, &verifier_u, &ntt_instance_info)); } diff --git a/zkp/tests/test_rlwe_mult_rgsw.rs b/zkp/tests/test_rlwe_mult_rgsw.rs index 0e2c14a3..6af31080 100644 --- a/zkp/tests/test_rlwe_mult_rgsw.rs +++ b/zkp/tests/test_rlwe_mult_rgsw.rs @@ -1,8 +1,9 @@ +use algebra::{transformation::AbstractNTT, NTTField, NTTPolynomial, Polynomial}; use algebra::{ - derive::{DecomposableField, FheField, Field, Prime, NTT}, - Basis, DenseMultilinearExtension, Field, FieldUniformSampler, + utils::Transcript, AbstractExtensionField, Basis, DenseMultilinearExtensionBase, Field, + FieldUniformSampler, }; -use algebra::{transformation::AbstractNTT, NTTField, NTTPolynomial, Polynomial}; +use fhe_core::{DefaultExtendsionFieldU32x4, DefaultFieldU32}; use itertools::izip; use num_traits::One; use rand_distr::Distribution; @@ -13,12 +14,10 @@ use zkp::piop::{ RlweMultRgswInstance, }; -#[derive(Field, Prime, DecomposableField, FheField, NTT)] -#[modulus = 132120577] -pub struct Fp32(u32); - // field type -type FF = Fp32; + +type FF = DefaultFieldU32; +type EF = DefaultExtendsionFieldU32x4; /// Given an `index` of `len` bits, output a new index where the bits are reversed. fn reverse_bits(index: usize, len: u32) -> usize { @@ -79,13 +78,13 @@ fn ntt_inverse_transform_normal_order(log_n: u32, points: & /// * basis_info: information used to decompose bits /// * ntt_info: information used to perform NTT /// * randomness_ntt: randomness used to generate a single randomized NTT instance -fn gen_rlwe_mult_rgsw_instance( +fn gen_rlwe_mult_rgsw_instance>( input_rlwe: RlweCiphertext, input_rgsw: (RlweCiphertexts, RlweCiphertexts), basis_info: &DecomposedBitsInfo, ntt_info: &NTTInstanceInfo, - randomness_ntt: &[F], -) -> RlweMultRgswInstance { + randomness_ntt: &[EF], +) -> RlweMultRgswInstance { // 1. Decompose the input of RLWE ciphertex let bits_rlwe = RlweCiphertexts { a_bits: input_rlwe @@ -103,7 +102,7 @@ fn gen_rlwe_mult_rgsw_instance( .a_bits .iter() .map(|bit| { - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_transform_normal_order(ntt_info.log_n as u32, &bit.evaluations), )) @@ -113,7 +112,7 @@ fn gen_rlwe_mult_rgsw_instance( .b_bits .iter() .map(|bit| { - Rc::new(DenseMultilinearExtension::from_evaluations_vec( + Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_transform_normal_order(ntt_info.log_n as u32, &bit.evaluations), )) @@ -152,22 +151,22 @@ fn gen_rlwe_mult_rgsw_instance( // 4. Compute the output of coefficient form let output_rlwe = RlweCiphertext { - a: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + a: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_inverse_transform_normal_order(ntt_info.log_n as u32, &output_g_ntt), )), - b: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + b: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, ntt_inverse_transform_normal_order(ntt_info.log_n as u32, &output_h_ntt), )), }; let output_rlwe_ntt = RlweCiphertext { - a: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + a: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, output_g_ntt, )), - b: Rc::new(DenseMultilinearExtension::from_evaluations_vec( + b: Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( ntt_info.log_n, output_h_ntt, )), @@ -228,33 +227,37 @@ fn test_trivial_rlwe_mult_rgsw() { .collect(); for _ in 0..bits_len { bits_rgsw_c_ntt.add_rlwe( - DenseMultilinearExtension::from_evaluations_slice(log_n, &points), - DenseMultilinearExtension::from_evaluations_slice(log_n, &points), + DenseMultilinearExtensionBase::from_evaluations_slice(log_n, &points), + DenseMultilinearExtensionBase::from_evaluations_slice(log_n, &points), ); } let mut bits_rgsw_f_ntt = >::new(bits_len as usize); for _ in 0..bits_len { bits_rgsw_f_ntt.add_rlwe( - DenseMultilinearExtension::from_evaluations_slice(log_n, &points), - DenseMultilinearExtension::from_evaluations_slice(log_n, &points), + DenseMultilinearExtensionBase::from_evaluations_slice(log_n, &points), + DenseMultilinearExtensionBase::from_evaluations_slice(log_n, &points), ); } // generate the random RLWE ciphertext let input_rlwe = RlweCiphertext { - a: Rc::new(DenseMultilinearExtension::from_evaluations_slice( + a: Rc::new(DenseMultilinearExtensionBase::from_evaluations_slice( log_n, &coeffs, )), - b: Rc::new(DenseMultilinearExtension::from_evaluations_slice( + b: Rc::new(DenseMultilinearExtensionBase::from_evaluations_slice( log_n, &coeffs, )), }; let num_ntt_instance = (basis_info.bits_len << 1) + 2; - let randomness_ntt = (0..num_ntt_instance) - .map(|_| uniform.sample(&mut rng)) - .collect::>(); + + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_randomness_ntt = prover_trans + .get_vec_ext_field_challenge(b"randomize ntt instances", num_ntt_instance as usize); + let verify_randomness_ntt = verifier_trans + .get_vec_ext_field_challenge(b"randomize ntt instances", num_ntt_instance as usize); // generate all the witness required let instance = gen_rlwe_mult_rgsw_instance( @@ -262,23 +265,33 @@ fn test_trivial_rlwe_mult_rgsw() { (bits_rgsw_c_ntt, bits_rgsw_f_ntt), &basis_info, &ntt_info, - &randomness_ntt, + &prover_randomness_ntt, ); // check the consistency of the randomized NTT instance - let ntt_points = - ntt_transform_normal_order(log_n as u32, &instance.ntt_instance.coeffs.evaluations); - assert_eq!(ntt_points, instance.ntt_instance.points.evaluations); + // let ntt_points = + // ntt_transform_normal_order(log_n as u32, &instance.ntt_instance.coeffs.evaluations); + // assert_eq!(ntt_points, instance.ntt_instance.points.evaluations); let instance_info = instance.info(); - let u: Vec<_> = (0..num_vars).map(|_| uniform.sample(&mut rng)).collect(); - let proof = RlweMultRgswIOP::prove(&instance, &u); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verify_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); - let subclaim = RlweMultRgswIOP::verify(&proof, &randomness_ntt, &u, &instance_info); + let proof = >::prove(&mut prover_trans, &instance, &prover_u); + + let subclaim = >::verify( + &mut verifier_trans, + &proof, + &verify_randomness_ntt, + &verify_u, + &instance_info, + ); assert!(subclaim.verify_subclaim( - &u, - &randomness_ntt, + &verify_u, + &verify_randomness_ntt, &instance.ntt_instance.coeffs, &instance.ntt_instance.points, &instance.input_rlwe, diff --git a/zkp/tests/test_round.rs b/zkp/tests/test_round.rs index 4ec8258e..9d6c121b 100644 --- a/zkp/tests/test_round.rs +++ b/zkp/tests/test_round.rs @@ -1,6 +1,8 @@ use algebra::{ derive::{DecomposableField, FheField, Field, Prime, NTT}, - DecomposableField, DenseMultilinearExtension, Field, FieldUniformSampler, + utils::Transcript, + BabyBear, BabyBearExetension, DecomposableField, DenseMultilinearExtensionBase, Field, + FieldUniformSampler, }; use rand_distr::Distribution; use std::rc::Rc; @@ -8,11 +10,13 @@ use std::vec; use zkp::piop::{DecomposedBitsInfo, RoundIOP, RoundInstance}; #[derive(Field, Prime, DecomposableField, FheField, NTT)] -#[modulus = 132120577] -pub struct Fp32(u32); +#[modulus = 2013265921] +pub struct Fp32(u64); -type FF = Fp32; // field type -const FP: u32 = 132120577; // ciphertext space +// field type +type FF = BabyBear; +type EF = BabyBearExetension; +const FP: u32 = 2013265921; // ciphertext space const FT: u32 = 4; // message space const FK: u32 = (FP - 1) / FT; @@ -41,20 +45,20 @@ fn test_round() { #[test] fn test_round_naive_iop() { - // k = (132120577 - 1) / FT = 33030144 = 2^25 - 2^19 + // delta = (1 << k_bits_len) - FK let k = FF::new(FK); - let k_bits_len: u32 = 25; - let delta: FF = FF::new(1 << 19); + let k_bits_len: u32 = 31; + let delta: FF = FF::new((1 << k_bits_len) - FK); let base_len: u32 = 1; let base: FF = FF::new(1 << base_len); let num_vars = 2; - let input = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let input = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, FP/4, FP/4 + 1, FP/2 + 1), )); - let output = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let output = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 0, 1, 2), )); @@ -85,18 +89,30 @@ fn test_round_naive_iop() { let info = instance.info(); - let mut rng = rand::thread_rng(); - let uniform = >::new(); - let u: Vec<_> = (0..num_vars).map(|_| uniform.sample(&mut rng)).collect(); - let lambda_1 = uniform.sample(&mut rng); - let lambda_2 = uniform.sample(&mut rng); - - let proof = RoundIOP::prove(&instance, &u, (lambda_1, lambda_2)); - let subclaim = RoundIOP::verify(&proof, &instance.info()); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let p_lambda: Vec = + prover_trans.get_vec_ext_field_challenge(b"random point to randomize sumcheck protocol", 2); + let v_lambda: Vec = verifier_trans + .get_vec_ext_field_challenge(b"random point to randomize sumcheck protocol", 2); + assert_eq!(p_lambda, v_lambda); + + let proof = >::prove( + &mut prover_trans, + &instance, + &prover_u, + (p_lambda[0], p_lambda[1]), + ); + let subclaim = >::verify(&mut verifier_trans, &proof, &instance.info()); assert!(subclaim.verify_subclaim( - &u, - (lambda_1, lambda_2), + &verifier_u, + (v_lambda[0], v_lambda[1]), &instance.input, &instance.output, &instance.output_bits.instances[0], @@ -113,22 +129,22 @@ fn test_round_random_iop() { let mut rng = rand::thread_rng(); let uniform = >::new(); - // k = (132120577 - 1) / FT = 33030144 = 2^25 - 2^19 + // delta = (1 << k_bits_len) - FK let k = FF::new(FK); - let k_bits_len: u32 = 25; - let delta: FF = FF::new(1 << 19); + let k_bits_len: u32 = 31; + let delta: FF = FF::new((1 << k_bits_len) - FK); let base_len: u32 = 1; let base: FF = FF::new(1 << base_len); let num_vars = 10; - let input = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let input = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, (0..1 << num_vars) .map(|_| uniform.sample(&mut rng)) .collect(), )); - let output = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let output = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, input.iter().map(|x| FF::new(decode(*x))).collect(), )); @@ -159,16 +175,29 @@ fn test_round_random_iop() { let info = instance.info(); - let u: Vec<_> = (0..num_vars).map(|_| uniform.sample(&mut rng)).collect(); - let lambda_1 = uniform.sample(&mut rng); - let lambda_2 = uniform.sample(&mut rng); - - let proof = RoundIOP::prove(&instance, &u, (lambda_1, lambda_2)); - let subclaim = RoundIOP::verify(&proof, &instance.info()); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verifier_u = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let p_lambda: Vec = + prover_trans.get_vec_ext_field_challenge(b"random point to randomize sumcheck protocol", 2); + let v_lambda: Vec = verifier_trans + .get_vec_ext_field_challenge(b"random point to randomize sumcheck protocol", 2); + + let proof = >::prove( + &mut prover_trans, + &instance, + &prover_u, + (p_lambda[0], p_lambda[1]), + ); + let subclaim = >::verify(&mut verifier_trans, &proof, &instance.info()); assert!(subclaim.verify_subclaim( - &u, - (lambda_1, lambda_2), + &verifier_u, + (v_lambda[0], v_lambda[1]), &instance.input, &instance.output, &instance.output_bits.instances[0], diff --git a/zkp/tests/test_sumcheck.rs b/zkp/tests/test_sumcheck.rs index ed50c1a8..1e525c67 100644 --- a/zkp/tests/test_sumcheck.rs +++ b/zkp/tests/test_sumcheck.rs @@ -1,6 +1,8 @@ use algebra::{ derive::{Field, Prime}, - DenseMultilinearExtension, Field, FieldUniformSampler, ListOfProductsOfPolynomials, + utils::Transcript, + AbstractExtensionField, BabyBear, BabyBearExetension, DenseMultilinearExtension, + DenseMultilinearExtensionBase, Field, FieldUniformSampler, ListOfProductsOfPolynomials, MultilinearExtension, }; use rand::prelude::*; @@ -11,24 +13,25 @@ use zkp::sumcheck::IPForMLSumcheck; use zkp::sumcheck::MLSumcheck; #[derive(Field, Prime)] -#[modulus = 132120577] -pub struct Fp32(u32); +#[modulus = 2013265921] +pub struct Fp32(u64); // field type -type FF = Fp32; +type FF = BabyBear; +type EF = BabyBearExetension; fn random_product( nv: usize, num_multiplicands: usize, rng: &mut R, -) -> (Vec>>, F) { +) -> (Vec>>, F) { let mut multiplicands = Vec::with_capacity(num_multiplicands); for _ in 0..num_multiplicands { multiplicands.push(Vec::with_capacity(1 << nv)); } let mut sum = F::zero(); - let uniform_sampler = FieldUniformSampler::new(); + let uniform_sampler = >::new(); for _ in 0..(1 << nv) { let mut product = F::one(); for multiplicand in &mut multiplicands { @@ -42,45 +45,51 @@ fn random_product( ( multiplicands .into_iter() - .map(|x| Rc::new(DenseMultilinearExtension::from_evaluations_vec(nv, x))) + .map(|x| Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec(nv, x))) .collect(), sum, ) } -fn random_list_of_products( +fn random_list_of_products, R: RngCore>( nv: usize, num_multiplicands_range: (usize, usize), num_products: usize, rng: &mut R, -) -> (ListOfProductsOfPolynomials, F) { +) -> (ListOfProductsOfPolynomials, E) { let mut sum = F::zero(); let mut poly = ListOfProductsOfPolynomials::new(nv); - let uniform_sampler = FieldUniformSampler::new(); + let uniform_sampler = >::new(); for _ in 0..num_products { let num_multiplicands: usize = rng.gen_range(num_multiplicands_range.0..num_multiplicands_range.1); let (product, product_sum) = random_product(nv, num_multiplicands, rng); let coefficient = uniform_sampler.sample(rng); - poly.add_product(product.into_iter(), coefficient); - sum += product_sum * coefficient; + let product: Vec>> = product + .iter() + .map(|x| Rc::new(>::from_base(x.as_ref()))) + .collect(); + poly.add_product(product.into_iter(), E::from_base(coefficient)); + sum += coefficient * product_sum; } - (poly, sum) + (poly, E::from_base(sum)) } fn test_protocol(nv: usize, num_multiplicands_range: (usize, usize), num_products: usize) { let mut rng = thread_rng(); let (poly, asserted_sum) = - random_list_of_products::(nv, num_multiplicands_range, num_products, &mut rng); + random_list_of_products::(nv, num_multiplicands_range, num_products, &mut rng); let poly_info = poly.info(); let mut prover_state = IPForMLSumcheck::prover_init(&poly); let mut verifier_state = IPForMLSumcheck::verifier_init(&poly_info); let mut verifier_msg = None; + let mut trans = Transcript::::new(); for _ in 0..poly.num_variables { let prover_message = IPForMLSumcheck::prove_round(&mut prover_state, &verifier_msg); - verifier_msg = IPForMLSumcheck::verify_round(prover_message, &mut verifier_state, &mut rng); + verifier_msg = + IPForMLSumcheck::verify_round(prover_message, &mut verifier_state, &mut trans); } let subclaim = IPForMLSumcheck::check_and_generate_subclaim(verifier_state, asserted_sum) .expect("fail to generate subclaim"); @@ -93,10 +102,13 @@ fn test_protocol(nv: usize, num_multiplicands_range: (usize, usize), num_product fn test_polynomial(nv: usize, num_multiplicands_range: (usize, usize), num_products: usize) { let mut rng = thread_rng(); let (poly, asserted_sum) = - random_list_of_products::(nv, num_multiplicands_range, num_products, &mut rng); + random_list_of_products::(nv, num_multiplicands_range, num_products, &mut rng); let poly_info = poly.info(); - let proof = MLSumcheck::prove(&poly).expect("fail to prove"); - let subclaim = MLSumcheck::verify(&poly_info, asserted_sum, &proof).expect("fail to verify"); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let proof = MLSumcheck::prove(&mut prover_trans, &poly).expect("fail to prove"); + let subclaim = MLSumcheck::verify(&mut verifier_trans, &poly_info, asserted_sum, &proof.0) + .expect("fail to verify"); assert!( poly.evaluate(&subclaim.point) == subclaim.expected_evaluations, "wrong subclaim" @@ -107,18 +119,16 @@ fn test_polynomial_as_subprotocol( nv: usize, num_multiplicands_range: (usize, usize), num_products: usize, - prover_rng: &mut impl RngCore, - verifier_rng: &mut impl RngCore, + prover_trans: &mut Transcript, + verifier_rng: &mut Transcript, ) { let mut rng = thread_rng(); let (poly, asserted_sum) = - random_list_of_products::(nv, num_multiplicands_range, num_products, &mut rng); + random_list_of_products::(nv, num_multiplicands_range, num_products, &mut rng); let poly_info = poly.info(); - let (proof, prover_state) = - MLSumcheck::prove_as_subprotocol(prover_rng, &poly).expect("fail to prove"); + let (proof, prover_state) = MLSumcheck::prove(prover_trans, &poly).expect("fail to prove"); let subclaim = - MLSumcheck::verify_as_subprotocol(verifier_rng, &poly_info, asserted_sum, &proof) - .expect("fail to verify"); + MLSumcheck::verify(verifier_rng, &poly_info, asserted_sum, &proof).expect("fail to verify"); assert!( poly.evaluate(&subclaim.point) == subclaim.expected_evaluations, "wrong subclaim" @@ -136,17 +146,15 @@ fn test_trivial_polynomial() { test_protocol(nv, num_multiplicands_range, num_products); test_polynomial(nv, num_multiplicands_range, num_products); - let mut seed: ::Seed = Default::default(); - thread_rng().fill(&mut seed); - let mut prover_rng = ChaCha12Rng::from_seed(seed); - let mut verifier_rng = ChaCha12Rng::from_seed(seed); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); test_polynomial_as_subprotocol( nv, num_multiplicands_range, num_products, - &mut prover_rng, - &mut verifier_rng, + &mut prover_trans, + &mut verifier_trans, ) } } @@ -163,15 +171,15 @@ fn test_normal_polynomial() { let mut seed: ::Seed = Default::default(); thread_rng().fill(&mut seed); - let mut prover_rng = ChaCha12Rng::from_seed(seed); - let mut verifier_rng = ChaCha12Rng::from_seed(seed); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); test_polynomial_as_subprotocol( nv, num_multiplicands_range, num_products, - &mut prover_rng, - &mut verifier_rng, + &mut prover_trans, + &mut verifier_trans, ) } } @@ -183,18 +191,16 @@ fn test_normal_polynomial_different_transcript_fails() { let num_multiplicands_range = (4, 9); let num_products = 5; - let mut prover_seed: ::Seed = Default::default(); - let mut verifier_seed: ::Seed = Default::default(); - thread_rng().fill(&mut prover_seed); - thread_rng().fill(&mut verifier_seed); - let mut prover_rng = ChaCha12Rng::from_seed(prover_seed); - let mut verifier_rng = ChaCha12Rng::from_seed(verifier_seed); + let mut prover_trans = Transcript::::new(); + prover_trans.append_message(b"msg", &"prover"); + let mut verifier_trans = Transcript::::new(); + verifier_trans.append_message(b"msg", &"verifier"); test_polynomial_as_subprotocol( nv, num_multiplicands_range, num_products, - &mut prover_rng, - &mut verifier_rng, + &mut prover_trans, + &mut verifier_trans, ) } @@ -211,10 +217,10 @@ fn zero_polynomial_should_error() { #[test] fn test_extract_sum() { let mut rng = thread_rng(); - let (poly, asserted_sum) = random_list_of_products::(8, (3, 4), 3, &mut rng); - - let proof = MLSumcheck::prove(&poly).expect("fail to prove"); - assert_eq!(MLSumcheck::extract_sum(&proof), asserted_sum); + let (poly, asserted_sum) = random_list_of_products::(8, (3, 4), 3, &mut rng); + let mut prover_trans = Transcript::::new(); + let proof = MLSumcheck::prove(&mut prover_trans, &poly).expect("fail to prove"); + assert_eq!(MLSumcheck::extract_sum(&proof.0), asserted_sum); } #[test] @@ -223,11 +229,11 @@ fn test_extract_sum() { fn test_shared_reference() { let mut rng = thread_rng(); let ml_extensions: Vec<_> = (0..5) - .map(|_| Rc::new(DenseMultilinearExtension::::random(8, &mut rng))) + .map(|_| Rc::new(DenseMultilinearExtension::::random(8, &mut rng))) .collect(); let mut poly = ListOfProductsOfPolynomials::new(8); - let uniform_sampler = >::new(); + let uniform_sampler = >::new(); poly.add_product( vec![ ml_extensions[0].clone(), @@ -269,9 +275,13 @@ fn test_shared_reference() { drop(prover); let poly_info = poly.info(); - let proof = MLSumcheck::prove(&poly).expect("fail to prove"); - let asserted_sum = MLSumcheck::extract_sum(&proof); - let subclaim = MLSumcheck::verify(&poly_info, asserted_sum, &proof).expect("fail to verify"); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + + let proof = MLSumcheck::prove(&mut prover_trans, &poly).expect("fail to prove"); + let asserted_sum = MLSumcheck::extract_sum(&proof.0); + let subclaim = MLSumcheck::verify(&mut verifier_trans, &poly_info, asserted_sum, &proof.0) + .expect("fail to verify"); assert!( poly.evaluate(&subclaim.point) == subclaim.expected_evaluations, "wrong subclaim" diff --git a/zkp/tests/test_zq_to_rq.rs b/zkp/tests/test_zq_to_rq.rs index 5522baba..29bfb1ba 100644 --- a/zkp/tests/test_zq_to_rq.rs +++ b/zkp/tests/test_zq_to_rq.rs @@ -1,7 +1,8 @@ use algebra::{ - derive::*, Basis, DecomposableField, DenseMultilinearExtension, Field, FieldUniformSampler, - SparsePolynomial, + derive::*, utils::Transcript, Basis, DecomposableField, DenseMultilinearExtensionBase, Field, + FieldUniformSampler, SparsePolynomial, }; +use fhe_core::{DefaultExtendsionFieldU32x4, DefaultFieldU32}; use num_traits::{One, Zero}; use rand::prelude::*; use rand_distr::Distribution; @@ -9,16 +10,13 @@ use std::rc::Rc; use std::vec; use zkp::piop::zq_to_rq::{TransformZqtoRQ, TransformZqtoRQInstance}; -#[derive(Field, Prime, DecomposableField)] -#[modulus = 132120577] -pub struct Fp32(u32); - #[derive(Field, DecomposableField)] #[modulus = 512] pub struct Fq(u32); // field type -type FF = Fp32; +type FF = DefaultFieldU32; +type EF = DefaultExtendsionFieldU32x4; macro_rules! field_vec { ($t:ty; $elem:expr; $n:expr)=>{ @@ -48,9 +46,7 @@ macro_rules! field_vec { #[test] fn test_trivial_zq_to_rq() { - let mut rng = thread_rng(); - let sampler = >::new(); - let p = 132120577; + let p = DefaultFieldU32::MODULUS_VALUE; let q = 8; let c_num_vars = 3; let base_len: u32 = 1; @@ -58,22 +54,22 @@ fn test_trivial_zq_to_rq() { let num_vars = 2; let bits_len: u32 = 3; - let a = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let a = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 3, 5, 7), )); - let k = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let k = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 0, 1, 1), )); - let r = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let r = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 6, 2, 6), )); - let s = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let s = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 1, 7, p-3, p-7), )); @@ -97,7 +93,7 @@ fn test_trivial_zq_to_rq() { )), ]; - let c_dense = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let c_dense = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( c_num_vars + num_vars, field_vec!(FF; 1, 0, 0, 0, 0, 0, 0, 0, @@ -120,9 +116,17 @@ fn test_trivial_zq_to_rq() { bits_len, ); let info = instance.info(); - let u: Vec<_> = (0..num_vars).map(|_| sampler.sample(&mut rng)).collect(); - let proof = TransformZqtoRQ::prove(&instance, &u); - let subclaim = TransformZqtoRQ::verify(&proof, &info.decomposed_bits_info, 3); + + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u: Vec = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verify_u: Vec = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let proof = TransformZqtoRQ::prove(&mut prover_trans, &instance, &prover_u); + let subclaim = + TransformZqtoRQ::verify(&mut verifier_trans, &proof, &info.decomposed_bits_info, 3); assert!(subclaim.verify_subclaim( q, @@ -132,7 +136,7 @@ fn test_trivial_zq_to_rq() { vec![r].as_ref(), s.as_ref(), &r_bits, - &u, + &verify_u, &info )); } @@ -141,7 +145,6 @@ fn test_trivial_zq_to_rq() { fn test_random_zq_to_rq() { let mut rng = thread_rng(); let uniform_fq = >::new(); - let uniform_fp = >::new(); let num_vars = 10; let q = FF::new(Fq::MODULUS_VALUE); let c_num_vars = (q.value() as usize).ilog2() as usize; @@ -189,15 +192,19 @@ fn test_random_zq_to_rq() { } }); - let a: Rc> = - Rc::new(DenseMultilinearExtension::from_evaluations_vec(num_vars, a)); - let k: Rc> = - Rc::new(DenseMultilinearExtension::from_evaluations_vec(num_vars, k)); - let r: Rc> = - Rc::new(DenseMultilinearExtension::from_evaluations_vec(num_vars, r)); - let s: Rc> = - Rc::new(DenseMultilinearExtension::from_evaluations_vec(num_vars, s)); - let c_dense = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let a: Rc> = Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, a), + ); + let k: Rc> = Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, k), + ); + let r: Rc> = Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, r), + ); + let s: Rc> = Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, s), + ); + let c_dense = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars + c_num_vars, c_dense_matrix, )); @@ -216,9 +223,20 @@ fn test_random_zq_to_rq() { bits_len, ); let info = instance.info(); - let u: Vec<_> = (0..num_vars).map(|_| uniform_fp.sample(&mut rng)).collect(); - let proof = TransformZqtoRQ::prove(&instance, &u); - let subclaim = TransformZqtoRQ::verify(&proof, &info.decomposed_bits_info, c_num_vars); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u: Vec = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verify_u: Vec = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let proof = TransformZqtoRQ::prove(&mut prover_trans, &instance, &prover_u); + let subclaim = TransformZqtoRQ::verify( + &mut verifier_trans, + &proof, + &info.decomposed_bits_info, + c_num_vars, + ); assert!(subclaim.verify_subclaim( q.value() as usize, @@ -228,16 +246,14 @@ fn test_random_zq_to_rq() { vec![r].as_ref(), s.as_ref(), &r_bits, - &u, + &verify_u, &info )); } #[test] fn test_trivial_zq_to_rq_without_oracle() { - let mut rng = thread_rng(); - let sampler = >::new(); - let p = 132120577; + let p = DefaultFieldU32::MODULUS_VALUE; let q = 8; let c_num_vars = 3; let base_len: u32 = 1; @@ -245,22 +261,22 @@ fn test_trivial_zq_to_rq_without_oracle() { let num_vars = 2; let bits_len: u32 = 3; - let a = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let a = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 3, 5, 7), )); - let k = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let k = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 0, 1, 1), )); - let r = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let r = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 0, 6, 2, 6), )); - let s = Rc::new(DenseMultilinearExtension::from_evaluations_vec( + let s = Rc::new(DenseMultilinearExtensionBase::from_evaluations_vec( num_vars, field_vec!(FF; 1, 7, p-3, p-7), )); @@ -298,9 +314,15 @@ fn test_trivial_zq_to_rq_without_oracle() { bits_len, ); let info = instance.info(); - let u: Vec<_> = (0..num_vars).map(|_| sampler.sample(&mut rng)).collect(); - let proof = TransformZqtoRQ::prove(&instance, &u); - let subclaim = TransformZqtoRQ::verify(&proof, &info.decomposed_bits_info, 3); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u: Vec = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verify_u: Vec = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let proof = TransformZqtoRQ::prove(&mut prover_trans, &instance, &prover_u); + let subclaim = + TransformZqtoRQ::verify(&mut verifier_trans, &proof, &info.decomposed_bits_info, 3); assert!(subclaim.verify_subclaim_without_oracle( q, @@ -310,7 +332,7 @@ fn test_trivial_zq_to_rq_without_oracle() { vec![r].as_ref(), s.as_ref(), &r_bits, - &u, + &verify_u, &info )); } @@ -319,7 +341,6 @@ fn test_trivial_zq_to_rq_without_oracle() { fn test_random_zq_to_rq_without_oracle() { let mut rng = thread_rng(); let uniform_fq = >::new(); - let uniform_fp = >::new(); let num_vars = 10; let q = FF::new(Fq::MODULUS_VALUE); let c_num_vars = (q.value() as usize).ilog2() as usize; @@ -360,14 +381,18 @@ fn test_random_zq_to_rq_without_oracle() { } }); - let a: Rc> = - Rc::new(DenseMultilinearExtension::from_evaluations_vec(num_vars, a)); - let k: Rc> = - Rc::new(DenseMultilinearExtension::from_evaluations_vec(num_vars, k)); - let r: Rc> = - Rc::new(DenseMultilinearExtension::from_evaluations_vec(num_vars, r)); - let s: Rc> = - Rc::new(DenseMultilinearExtension::from_evaluations_vec(num_vars, s)); + let a: Rc> = Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, a), + ); + let k: Rc> = Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, k), + ); + let r: Rc> = Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, r), + ); + let s: Rc> = Rc::new( + DenseMultilinearExtensionBase::from_evaluations_vec(num_vars, s), + ); let tmp = r.get_decomposed_mles(base_len, bits_len); let r_bits: Vec<_> = vec![&tmp]; @@ -383,9 +408,20 @@ fn test_random_zq_to_rq_without_oracle() { bits_len, ); let info = instance.info(); - let u: Vec<_> = (0..num_vars).map(|_| uniform_fp.sample(&mut rng)).collect(); - let proof = TransformZqtoRQ::prove(&instance, &u); - let subclaim = TransformZqtoRQ::verify(&proof, &info.decomposed_bits_info, c_num_vars); + let mut prover_trans = Transcript::::new(); + let mut verifier_trans = Transcript::::new(); + let prover_u: Vec = prover_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + let verify_u: Vec = verifier_trans + .get_vec_ext_field_challenge(b"random point to instantiate sumcheck protocol", num_vars); + + let proof = TransformZqtoRQ::prove(&mut prover_trans, &instance, &prover_u); + let subclaim = TransformZqtoRQ::verify( + &mut verifier_trans, + &proof, + &info.decomposed_bits_info, + c_num_vars, + ); assert!(subclaim.verify_subclaim_without_oracle( q.value() as usize, @@ -395,7 +431,7 @@ fn test_random_zq_to_rq_without_oracle() { vec![r].as_ref(), s.as_ref(), &r_bits, - &u, + &verify_u, &info )); }