From 49100268b86631f69d71fdd1a1092b1feb73e7a1 Mon Sep 17 00:00:00 2001 From: Conghao Shen Date: Fri, 13 May 2022 21:48:25 -0700 Subject: [PATCH 1/3] change version to 0.3.1 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 1bc2f2c..c84a585 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ark-linear-sumcheck" -version = "0.3.0" +version = "0.3.1" authors = [ "Tom Shen ", "arkworks contributors" From 27656fa7dc3ec1cd1cc4483cb121a90c58173072 Mon Sep 17 00:00:00 2001 From: Conghao Shen Date: Tue, 17 May 2022 14:36:18 -0700 Subject: [PATCH 2/3] use sponge for transcript also: refmt code, update dependency --- Cargo.toml | 18 ++- rustfmt.toml | 10 ++ src/gkr_round_sumcheck/data_structures.rs | 10 +- src/gkr_round_sumcheck/mod.rs | 74 ++++----- src/gkr_round_sumcheck/test.rs | 6 +- src/lib.rs | 1 - src/ml_sumcheck/data_structures.rs | 54 ++++--- src/ml_sumcheck/mod.rs | 50 +++--- src/ml_sumcheck/protocol/mod.rs | 4 +- src/ml_sumcheck/protocol/prover.rs | 56 ++++--- src/ml_sumcheck/protocol/verifier.rs | 58 ++++--- src/ml_sumcheck/test.rs | 88 ++++++++--- src/rng.rs | 176 ---------------------- 13 files changed, 272 insertions(+), 333 deletions(-) create mode 100644 rustfmt.toml delete mode 100644 src/rng.rs diff --git a/Cargo.toml b/Cargo.toml index c84a585..52f2ae8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,11 +22,23 @@ ark-poly = { version = "^0.3.0", default-features = false } blake2 = { version = "0.9", default-features = false } hashbrown = { version = "0.11.2" } rayon = { version = "1", optional = true } +ark-sponge = { version = "^0.3.0", default-features = false } [dev-dependencies] -ark-test-curves = { version = "^0.3.0", default-features = false, features = ["bls12_381_scalar_field", "bls12_381_curve"] } - +#ark-test-curves = { version = "^0.3.0", default-features = false, features = ["bls12_381_scalar_field", "bls12_381_curve"] } +ark-bls12-381 = { version = "^0.3.0", default-features = false, features = ["scalar_field"] } [features] default = ["std"] -std = ["ark-ff/std", "ark-serialize/std", "blake2/std", "ark-std/std", "ark-poly/std"] +std = ["ark-ff/std", "ark-serialize/std", "blake2/std", "ark-std/std", "ark-poly/std", "ark-sponge/std", "ark-bls12-381/std"] parallel = ["std", "ark-ff/parallel", "ark-poly/parallel", "ark-std/parallel", "rayon"] + +# To be removed in the new release. +[patch.crates-io] +ark-sponge = { git = "https://github.com/arkworks-rs/sponge" } +ark-std = { git = "https://github.com/arkworks-rs/std" } +ark-ff = { git = "https://github.com/arkworks-rs/algebra" } +ark-ec = { git = "https://github.com/arkworks-rs/algebra" } +ark-serialize = { git = "https://github.com/arkworks-rs/algebra" } +ark-poly = { git = "https://github.com/arkworks-rs/algebra" } +ark-bls12-381 = { git = "https://github.com/arkworks-rs/curves" } +ark-r1cs-std = { git = "https://github.com/arkworks-rs/r1cs-std" } \ No newline at end of file diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..0955465 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,10 @@ +condense_wildcard_suffixes = true +edition = "2018" +imports_granularity = "Crate" +match_block_trailing_comma = true +normalize_comments = true +reorder_imports = true +use_field_init_shorthand = true +use_try_shorthand = true +wrap_comments = true +max_width = 100 \ No newline at end of file diff --git a/src/gkr_round_sumcheck/data_structures.rs b/src/gkr_round_sumcheck/data_structures.rs index e3f8d95..9561cc9 100644 --- a/src/gkr_round_sumcheck/data_structures.rs +++ b/src/gkr_round_sumcheck/data_structures.rs @@ -1,17 +1,17 @@ //! Data structures used by GKR Round Sumcheck use crate::ml_sumcheck::protocol::prover::ProverMsg; -use ark_ff::Field; +use ark_ff::PrimeField; use ark_poly::{DenseMultilinearExtension, MultilinearExtension, SparseMultilinearExtension}; use ark_std::vec::Vec; /// Proof for GKR Round Function -pub struct GKRProof { +pub struct GKRProof { pub(crate) phase1_sumcheck_msgs: Vec>, pub(crate) phase2_sumcheck_msgs: Vec>, } -impl GKRProof { +impl GKRProof { /// Extract the witness (i.e. the sum of GKR) pub fn extract_sum(&self) -> F { self.phase1_sumcheck_msgs[0].evaluations[0] + self.phase1_sumcheck_msgs[0].evaluations[1] @@ -19,7 +19,7 @@ impl GKRProof { } /// Subclaim for GKR Round Function -pub struct GKRRoundSumcheckSubClaim { +pub struct GKRRoundSumcheckSubClaim { /// u pub u: Vec, /// v @@ -28,7 +28,7 @@ pub struct GKRRoundSumcheckSubClaim { pub expected_evaluation: F, } -impl GKRRoundSumcheckSubClaim { +impl GKRRoundSumcheckSubClaim { /// Verify that the subclaim is true by evaluating the GKR Round function. pub fn verify_subclaim( &self, diff --git a/src/gkr_round_sumcheck/mod.rs b/src/gkr_round_sumcheck/mod.rs index 4d4d3da..37acf1b 100644 --- a/src/gkr_round_sumcheck/mod.rs +++ b/src/gkr_round_sumcheck/mod.rs @@ -3,21 +3,23 @@ //! GKR Round Sumcheck will use `ml_sumcheck` as a subroutine. pub mod data_structures; -#[cfg(test)] -mod test; - -use crate::gkr_round_sumcheck::data_structures::{GKRProof, GKRRoundSumcheckSubClaim}; -use crate::ml_sumcheck::protocol::prover::ProverState; -use crate::ml_sumcheck::protocol::{IPForMLSumcheck, ListOfProductsOfPolynomials, PolynomialInfo}; -use crate::rng::{Blake2s512Rng, FeedableRNG}; -use ark_ff::{Field, Zero}; +// #[cfg(test)] +// mod test; + +use crate::{ + gkr_round_sumcheck::data_structures::{GKRProof, GKRRoundSumcheckSubClaim}, + ml_sumcheck::protocol::{ + prover::ProverState, IPForMLSumcheck, ListOfProductsOfPolynomials, PolynomialInfo, + }, +}; +use ark_ff::{PrimeField, Zero}; use ark_poly::{DenseMultilinearExtension, MultilinearExtension, SparseMultilinearExtension}; -use ark_std::marker::PhantomData; -use ark_std::rc::Rc; -use ark_std::vec::Vec; +use ark_sponge::{Absorb, CryptographicSponge}; +use ark_std::{marker::PhantomData, rc::Rc, vec::Vec}; -/// Takes multilinear f1, f3, and input g = g1,...,gl. Returns h_g, and f1 fixed at g. -pub fn initialize_phase_one( +/// Takes multilinear f1, f3, and input g = g1,...,gl. Returns h_g, and f1 fixed +/// at g. +pub fn initialize_phase_one( f1: &SparseMultilinearExtension, f3: &DenseMultilinearExtension, g: &[F], @@ -38,7 +40,7 @@ pub fn initialize_phase_one( } /// Takes h_g and returns a sumcheck state -pub fn start_phase1_sumcheck( +pub fn start_phase1_sumcheck( h_g: &DenseMultilinearExtension, f2: &DenseMultilinearExtension, ) -> ProverState { @@ -49,8 +51,9 @@ pub fn start_phase1_sumcheck( IPForMLSumcheck::prover_init(&poly) } -/// Takes multilinear f1 fixed at g, phase one randomness u. Returns f1 fixed at g||u -pub fn initialize_phase_two( +/// Takes multilinear f1 fixed at g, phase one randomness u. Returns f1 fixed at +/// g||u +pub fn initialize_phase_two( f1_g: &SparseMultilinearExtension, u: &[F], ) -> DenseMultilinearExtension { @@ -59,7 +62,7 @@ pub fn initialize_phase_two( } /// Takes f1 fixed at g||u, f3, and f2 evaluated at u. -pub fn start_phase2_sumcheck( +pub fn start_phase2_sumcheck( f1_gu: &DenseMultilinearExtension, f3: &DenseMultilinearExtension, f2_u: F, @@ -78,15 +81,16 @@ pub fn start_phase2_sumcheck( } /// Sumcheck Argument for GKR Round Function -pub struct GKRRoundSumcheck { +pub struct GKRRoundSumcheck { _marker: PhantomData, } -impl GKRRoundSumcheck { +impl GKRRoundSumcheck { /// Takes a GKR Round Function and input, prove the sum. /// * `f1`,`f2`,`f3`: represents the GKR round function /// * `g`: represents the fixed input. - pub fn prove( + pub fn prove( + sponge: &mut S, f1: &SparseMultilinearExtension, f2: &DenseMultilinearExtension, f3: &DenseMultilinearExtension, @@ -98,8 +102,6 @@ impl GKRRoundSumcheck { let dim = f2.num_vars; let g = g.to_vec(); - let mut rng = Blake2s512Rng::setup(); - let (h_g, f1_g) = initialize_phase_one(f1, f3, &g); let mut phase1_ps = start_phase1_sumcheck(&h_g, f2); let mut phase1_vm = None; @@ -108,9 +110,9 @@ impl GKRRoundSumcheck { for _ in 0..dim { let (pm, ps) = IPForMLSumcheck::prove_round(phase1_ps, &phase1_vm); phase1_ps = ps; - rng.feed(&pm).unwrap(); + sponge.absorb(&pm); phase1_prover_msgs.push(pm); - let vm = IPForMLSumcheck::sample_round(&mut rng); + let vm = IPForMLSumcheck::sample_round(sponge); phase1_vm = Some(vm.clone()); u.push(vm.randomness); } @@ -123,9 +125,9 @@ impl GKRRoundSumcheck { for _ in 0..dim { let (pm, ps) = IPForMLSumcheck::prove_round(phase2_ps, &phase2_vm); phase2_ps = ps; - rng.feed(&pm).unwrap(); + sponge.absorb(&pm); phase2_prover_msgs.push(pm); - let vm = IPForMLSumcheck::sample_round(&mut rng); + let vm = IPForMLSumcheck::sample_round(sponge); phase2_vm = Some(vm.clone()); v.push(vm.randomness); } @@ -138,11 +140,13 @@ impl GKRRoundSumcheck { /// Takes a GKR Round Function, input, and proof, and returns a subclaim. /// - /// If the `claimed_sum` is correct, then it is `subclaim.verify_subclaim` will return true. - /// Otherwise, it is very likely that `subclaim.verify_subclaim` will return false. - /// Larger field size guarantees smaller soundness error. + /// If the `claimed_sum` is correct, then it is `subclaim.verify_subclaim` + /// will return true. Otherwise, it is very likely that + /// `subclaim.verify_subclaim` will return false. Larger field size + /// guarantees smaller soundness error. /// * `f2_num_vars`: represents number of variables of f2 - pub fn verify( + pub fn verify( + sponge: &mut S, f2_num_vars: usize, proof: &GKRProof, claimed_sum: F, @@ -150,8 +154,6 @@ impl GKRRoundSumcheck { // verify first sumcheck let dim = f2_num_vars; - let mut rng = Blake2s512Rng::setup(); - let mut phase1_vs = IPForMLSumcheck::verifier_init(&PolynomialInfo { max_multiplicands: 2, num_variables: dim, @@ -159,8 +161,8 @@ impl GKRRoundSumcheck { for i in 0..dim { let pm = &proof.phase1_sumcheck_msgs[i]; - rng.feed(pm).unwrap(); - let result = IPForMLSumcheck::verify_round((*pm).clone(), phase1_vs, &mut rng); + sponge.absorb(&pm); + let result = IPForMLSumcheck::verify_round((*pm).clone(), phase1_vs, sponge); phase1_vs = result.1; } let phase1_subclaim = IPForMLSumcheck::check_and_generate_subclaim(phase1_vs, claimed_sum)?; @@ -172,8 +174,8 @@ impl GKRRoundSumcheck { }); for i in 0..dim { let pm = &proof.phase2_sumcheck_msgs[i]; - rng.feed(pm).unwrap(); - let result = IPForMLSumcheck::verify_round((*pm).clone(), phase2_vs, &mut rng); + sponge.absorb(&pm); + let result = IPForMLSumcheck::verify_round((*pm).clone(), phase2_vs, sponge); phase2_vs = result.1; } let phase2_subclaim = IPForMLSumcheck::check_and_generate_subclaim( diff --git a/src/gkr_round_sumcheck/test.rs b/src/gkr_round_sumcheck/test.rs index 9d9181d..21cbbb1 100644 --- a/src/gkr_round_sumcheck/test.rs +++ b/src/gkr_round_sumcheck/test.rs @@ -1,11 +1,11 @@ use crate::gkr_round_sumcheck::GKRRoundSumcheck; -use ark_ff::Field; +use ark_ff::{Field, PrimeField}; use ark_poly::{DenseMultilinearExtension, MultilinearExtension, SparseMultilinearExtension}; use ark_std::rand::RngCore; use ark_std::{test_rng, UniformRand}; use ark_test_curves::bls12_381::Fr; -fn random_gkr_instance( +fn random_gkr_instance( dim: usize, rng: &mut R, ) -> ( @@ -20,7 +20,7 @@ fn random_gkr_instance( ) } -fn calculate_sum_naive( +fn calculate_sum_naive( f1: &SparseMultilinearExtension, f2: &DenseMultilinearExtension, f3: &DenseMultilinearExtension, diff --git a/src/lib.rs b/src/lib.rs index ff158f6..89a3804 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,7 +22,6 @@ mod error; pub mod gkr_round_sumcheck; pub mod ml_sumcheck; -pub mod rng; #[cfg(test)] mod tests {} diff --git a/src/ml_sumcheck/data_structures.rs b/src/ml_sumcheck/data_structures.rs index 6abd724..5ef4d8f 100644 --- a/src/ml_sumcheck/data_structures.rs +++ b/src/ml_sumcheck/data_structures.rs @@ -1,17 +1,19 @@ //! Defines the data structures used by the `MLSumcheck` protocol. -use ark_ff::Field; +use ark_ff::PrimeField; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write}; -use ark_std::cmp::max; -use ark_std::rc::Rc; -use ark_std::vec::Vec; +use ark_sponge::Absorb; +use ark_std::{cmp::max, rc::Rc, vec::Vec}; use hashbrown::HashMap; -/// Stores a list of products of `DenseMultilinearExtension` that is meant to be added together. +/// Stores a list of products of `DenseMultilinearExtension` that is meant to be +/// added together. /// -/// The polynomial is represented by a list of products of polynomials along with its coefficient that is meant to be added together. +/// The polynomial is represented by a list of products of polynomials along +/// with its coefficient that is meant to be added together. /// -/// This data structure of the polynomial is a list of list of `(coefficient, DenseMultilinearExtension)`. +/// This data structure of the polynomial is a list of list of `(coefficient, +/// DenseMultilinearExtension)`. /// * Number of products n = `self.products.len()`, /// * Number of multiplicands of ith product m_i = `self.products[i].1.len()`, /// * Coefficient of ith product c_i = `self.products[i].0` @@ -22,20 +24,24 @@ use hashbrown::HashMap; /// /// The result polynomial is used as the prover key. #[derive(Clone)] -pub struct ListOfProductsOfPolynomials { +pub struct ListOfProductsOfPolynomials { /// max number of multiplicands in each product pub max_multiplicands: usize, /// number of variables of the polynomial pub num_variables: usize, /// list of reference to products (as usize) of multilinear extension pub products: Vec<(F, Vec)>, - /// Stores multilinear extensions in which product multiplicand can refer to. + // TODO: unnecessary to use pointer. Fix it. + /// Stores multilinear extensions in which product multiplicand can refer + /// to. pub flattened_ml_extensions: Vec>>, + /// Store the index of each `ml_extension` in `flattened_ml_extensions`. raw_pointers_lookup_table: HashMap<*const DenseMultilinearExtension, usize>, } -impl ListOfProductsOfPolynomials { - /// Extract the max number of multiplicands and number of variables of the list of products. +impl ListOfProductsOfPolynomials { + /// Extract the max number of multiplicands and number of variables of the + /// list of products. pub fn info(&self) -> PolynomialInfo { PolynomialInfo { max_multiplicands: self.max_multiplicands, @@ -45,8 +51,9 @@ impl ListOfProductsOfPolynomials { } #[derive(CanonicalSerialize, CanonicalDeserialize, Clone)] -/// Stores the number of variables and max number of multiplicands of the added polynomial used by the prover. -/// This data structures will is used as the verifier key. +/// Stores the number of variables and max number of multiplicands of the added +/// polynomial used by the prover. This data structures will is used as the +/// verifier key. pub struct PolynomialInfo { /// max number of multiplicands in each product pub max_multiplicands: usize, @@ -54,7 +61,18 @@ pub struct PolynomialInfo { pub num_variables: usize, } -impl ListOfProductsOfPolynomials { +impl Absorb for PolynomialInfo { + fn to_sponge_bytes(&self, dest: &mut Vec) { + self.serialize(dest).expect("serialization failed"); + } + + fn to_sponge_field_elements(&self, dest: &mut Vec) { + dest.push(F::from(self.max_multiplicands as u128)); + dest.push(F::from(self.num_variables as u128)); + } +} + +impl ListOfProductsOfPolynomials { /// Returns an empty polynomial pub fn new(num_variables: usize) -> Self { ListOfProductsOfPolynomials { @@ -66,8 +84,9 @@ impl ListOfProductsOfPolynomials { } } - /// Add a list of multilinear extensions that is meant to be multiplied together. - /// The resulting polynomial will be multiplied by the scalar `coefficient`. + /// Add a list of multilinear extensions that is meant to be multiplied + /// together. The resulting polynomial will be multiplied by the scalar + /// `coefficient`. pub fn add_product( &mut self, product: impl IntoIterator>>, @@ -75,7 +94,7 @@ impl ListOfProductsOfPolynomials { ) { let product: Vec>> = product.into_iter().collect(); let mut indexed_product = Vec::with_capacity(product.len()); - assert!(product.len() > 0); + assert!(!product.is_empty(), "product is empty"); self.max_multiplicands = max(self.max_multiplicands, product.len()); for m in product { assert_eq!( @@ -84,6 +103,7 @@ impl ListOfProductsOfPolynomials { ); let m_ptr: *const DenseMultilinearExtension = Rc::as_ptr(&m); if let Some(index) = self.raw_pointers_lookup_table.get(&m_ptr) { + // if same multilinear extension is already added, just add a reference to it indexed_product.push(*index) } else { let curr_index = self.flattened_ml_extensions.len(); diff --git a/src/ml_sumcheck/mod.rs b/src/ml_sumcheck/mod.rs index 143beb7..3002763 100644 --- a/src/ml_sumcheck/mod.rs +++ b/src/ml_sumcheck/mod.rs @@ -1,13 +1,12 @@ //! Sumcheck Protocol for multilinear extension -use crate::ml_sumcheck::data_structures::{ListOfProductsOfPolynomials, PolynomialInfo}; -use crate::ml_sumcheck::protocol::prover::ProverMsg; -use crate::ml_sumcheck::protocol::verifier::SubClaim; -use crate::ml_sumcheck::protocol::IPForMLSumcheck; -use crate::rng::{Blake2s512Rng, FeedableRNG}; -use ark_ff::Field; -use ark_std::marker::PhantomData; -use ark_std::vec::Vec; +use crate::ml_sumcheck::{ + data_structures::{ListOfProductsOfPolynomials, PolynomialInfo}, + protocol::{prover::ProverMsg, verifier::SubClaim, IPForMLSumcheck}, +}; +use ark_ff::PrimeField; +use ark_sponge::{Absorb, CryptographicSponge}; +use ark_std::{marker::PhantomData, vec::Vec}; pub mod protocol; @@ -16,12 +15,12 @@ pub mod data_structures; mod test; /// Sumcheck for products of multilinear polynomial -pub struct MLSumcheck(#[doc(hidden)] PhantomData); +pub struct MLSumcheck(#[doc(hidden)] PhantomData); /// proof generated by prover pub type Proof = Vec>; -impl MLSumcheck { +impl MLSumcheck { /// extract sum from the proof pub fn extract_sum(proof: &Proof) -> F { proof[0].evaluations[0] + proof[0].evaluations[1] @@ -29,19 +28,24 @@ impl MLSumcheck { /// generate proof of the sum of polynomial over {0,1}^`num_vars` /// - /// The polynomial is represented by a list of products of polynomials along with its coefficient that is meant to be added together. + /// The polynomial is represented by a list of products of polynomials along + /// with its coefficient that is meant to be added together. /// - /// This data structure of the polynomial is a list of list of `(coefficient, DenseMultilinearExtension)`. + /// This data structure of the polynomial is a list of list of + /// `(coefficient, DenseMultilinearExtension)`. /// * Number of products n = `polynomial.products.len()`, - /// * Number of multiplicands of ith product m_i = `polynomial.products[i].1.len()`, + /// * Number of multiplicands of ith product m_i = + /// `polynomial.products[i].1.len()`, /// * Coefficient of ith product c_i = `polynomial.products[i].0` /// /// The resulting polynomial is /// /// $$\sum_{i=0}^{n}C_i\cdot\prod_{j=0}^{m_i}P_{ij}$$ - pub fn prove(polynomial: &ListOfProductsOfPolynomials) -> Result, crate::Error> { - let mut fs_rng = Blake2s512Rng::setup(); - fs_rng.feed(&polynomial.info())?; + pub fn prove( + sponge: &mut S, + polynomial: &ListOfProductsOfPolynomials, + ) -> Result, crate::Error> { + sponge.absorb(&polynomial.info()); let mut prover_state = IPForMLSumcheck::prover_init(&polynomial); let mut verifier_msg = None; @@ -50,28 +54,28 @@ impl MLSumcheck { let (prover_msg, prover_state_new) = IPForMLSumcheck::prove_round(prover_state, &verifier_msg); prover_state = prover_state_new; - fs_rng.feed(&prover_msg)?; + sponge.absorb(&prover_msg); prover_msgs.push(prover_msg); - verifier_msg = Some(IPForMLSumcheck::sample_round(&mut fs_rng)); + verifier_msg = Some(IPForMLSumcheck::sample_round(sponge)); } Ok(prover_msgs) } /// verify the claimed sum using the proof - pub fn verify( + pub fn verify( + sponge: &mut S, polynomial_info: &PolynomialInfo, claimed_sum: F, proof: &Proof, ) -> Result, crate::Error> { - let mut fs_rng = Blake2s512Rng::setup(); - fs_rng.feed(polynomial_info)?; + sponge.absorb(&polynomial_info); let mut verifier_state = IPForMLSumcheck::verifier_init(polynomial_info); for i in 0..polynomial_info.num_variables { let prover_msg = proof.get(i).expect("proof is incomplete"); - fs_rng.feed(prover_msg)?; + sponge.absorb(&prover_msg); let result = - IPForMLSumcheck::verify_round((*prover_msg).clone(), verifier_state, &mut fs_rng); + IPForMLSumcheck::verify_round((*prover_msg).clone(), verifier_state, sponge); verifier_state = result.1; } diff --git a/src/ml_sumcheck/protocol/mod.rs b/src/ml_sumcheck/protocol/mod.rs index 05d3c25..ff7240d 100644 --- a/src/ml_sumcheck/protocol/mod.rs +++ b/src/ml_sumcheck/protocol/mod.rs @@ -1,13 +1,13 @@ //! Interactive Proof Protocol used for Multilinear Sumcheck -use ark_ff::Field; +use ark_ff::PrimeField; use ark_std::marker::PhantomData; pub mod prover; pub mod verifier; pub use crate::ml_sumcheck::data_structures::{ListOfProductsOfPolynomials, PolynomialInfo}; /// Interactive Proof for Multilinear Sumcheck -pub struct IPForMLSumcheck { +pub struct IPForMLSumcheck { #[doc(hidden)] _marker: PhantomData, } diff --git a/src/ml_sumcheck/protocol/prover.rs b/src/ml_sumcheck/protocol/prover.rs index 2150d8a..3590e4a 100644 --- a/src/ml_sumcheck/protocol/prover.rs +++ b/src/ml_sumcheck/protocol/prover.rs @@ -1,46 +1,63 @@ //! Prover -use crate::ml_sumcheck::data_structures::ListOfProductsOfPolynomials; -use crate::ml_sumcheck::protocol::verifier::VerifierMsg; -use crate::ml_sumcheck::protocol::IPForMLSumcheck; -use ark_ff::Field; +use crate::ml_sumcheck::{ + data_structures::ListOfProductsOfPolynomials, + protocol::{verifier::VerifierMsg, IPForMLSumcheck}, +}; +use ark_ff::PrimeField; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write}; +use ark_sponge::Absorb; use ark_std::vec::Vec; /// Prover Message #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] -pub struct ProverMsg { +pub struct ProverMsg { /// evaluations on P(0), P(1), P(2), ... pub(crate) evaluations: Vec, } + +impl Absorb for ProverMsg { + fn to_sponge_bytes(&self, dest: &mut Vec) { + self.evaluations.to_sponge_bytes(dest); + } + + fn to_sponge_field_elements(&self, dest: &mut Vec) { + self.evaluations.to_sponge_field_elements(dest); + } +} + /// Prover State -pub struct ProverState { +pub struct ProverState { /// sampled randomness given by the verifier pub randomness: Vec, - /// Stores the list of products that is meant to be added together. Each multiplicand is represented by - /// the index in flattened_ml_extensions + /// Stores the list of products that is meant to be added together. Each + /// multiplicand is represented by the index in flattened_ml_extensions pub list_of_products: Vec<(F, Vec)>, - /// Stores a list of multilinear extensions in which `self.list_of_products` points to + /// Stores a list of multilinear extensions in which `self.list_of_products` + /// points to pub flattened_ml_extensions: Vec>, num_vars: usize, max_multiplicands: usize, round: usize, } -impl IPForMLSumcheck { - /// initialize the prover to argue for the sum of polynomial over {0,1}^`num_vars` +impl IPForMLSumcheck { + /// initialize the prover to argue for the sum of polynomial over + /// {0,1}^`num_vars` /// - /// The polynomial is represented by a list of products of polynomials along with its coefficient that is meant to be added together. + /// The polynomial is represented by a list of products of polynomials along + /// with its coefficient that is meant to be added together. /// - /// This data structure of the polynomial is a list of list of `(coefficient, DenseMultilinearExtension)`. + /// This data structure of the polynomial is a list of list of + /// `(coefficient, DenseMultilinearExtension)`. /// * Number of products n = `polynomial.products.len()`, - /// * Number of multiplicands of ith product m_i = `polynomial.products[i].1.len()`, + /// * Number of multiplicands of ith product m_i = + /// `polynomial.products[i].1.len()`, /// * Coefficient of ith product c_i = `polynomial.products[i].0` /// /// The resulting polynomial is /// /// $$\sum_{i=0}^{n}C_i\cdot\prod_{j=0}^{m_i}P_{ij}$$ - /// pub fn prover_init(polynomial: &ListOfProductsOfPolynomials) -> ProverState { if polynomial.num_variables == 0 { panic!("Attempt to prove a constant.") @@ -63,7 +80,8 @@ impl IPForMLSumcheck { } } - /// receive message from verifier, generate prover message, and proceed to next round + /// receive message from verifier, generate prover message, and proceed to + /// next round /// /// Main algorithm used is from section 3.2 of [XZZPS19](https://eprint.iacr.org/2019/317.pdf#subsection.3.2). pub fn prove_round( @@ -82,10 +100,8 @@ impl IPForMLSumcheck { for multiplicand in prover_state.flattened_ml_extensions.iter_mut() { *multiplicand = multiplicand.fix_variables(&[r]); } - } else { - if prover_state.round > 0 { - panic!("verifier message is empty"); - } + } else if prover_state.round > 0 { + panic!("verifier message is empty"); } prover_state.round += 1; diff --git a/src/ml_sumcheck/protocol/verifier.rs b/src/ml_sumcheck/protocol/verifier.rs index 185edf3..535f3df 100644 --- a/src/ml_sumcheck/protocol/verifier.rs +++ b/src/ml_sumcheck/protocol/verifier.rs @@ -1,39 +1,43 @@ //! Verifier -use crate::ml_sumcheck::data_structures::PolynomialInfo; -use crate::ml_sumcheck::protocol::prover::ProverMsg; -use crate::ml_sumcheck::protocol::IPForMLSumcheck; -use ark_ff::Field; +use crate::ml_sumcheck::{ + data_structures::PolynomialInfo, + protocol::{prover::ProverMsg, IPForMLSumcheck}, +}; +use ark_ff::{Field, PrimeField}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, Read, SerializationError, Write}; -use ark_std::rand::RngCore; +use ark_sponge::CryptographicSponge; use ark_std::vec::Vec; #[derive(Clone, CanonicalSerialize, CanonicalDeserialize)] /// Verifier Message -pub struct VerifierMsg { +pub struct VerifierMsg { /// randomness sampled by verifier pub randomness: F, } /// Verifier State -pub struct VerifierState { +pub struct VerifierState { round: usize, nv: usize, max_multiplicands: usize, finished: bool, - /// a list storing the univariate polynomial in evaluation form sent by the prover at each round + /// a list storing the univariate polynomial in evaluation form sent by the + /// prover at each round polynomials_received: Vec>, /// a list storing the randomness sampled by the verifier at each round randomness: Vec, } + /// Subclaim when verifier is convinced -pub struct SubClaim { - /// the multi-dimensional point that this multilinear extension is evaluated to +pub struct SubClaim { + /// the multi-dimensional point that this multilinear extension is evaluated + /// to pub point: Vec, /// the expected evaluation pub expected_evaluation: F, } -impl IPForMLSumcheck { +impl IPForMLSumcheck { /// initialize the verifier pub fn verifier_init(index_info: &PolynomialInfo) -> VerifierState { VerifierState { @@ -48,22 +52,24 @@ impl IPForMLSumcheck { /// Run verifier at current round, given prover message /// - /// Normally, this function should perform actual verification. Instead, `verify_round` only samples - /// and stores randomness and perform verifications altogether in `check_and_generate_subclaim` at + /// Normally, this function should perform actual verification. Instead, + /// `verify_round` only samples and stores randomness and perform + /// verifications altogether in `check_and_generate_subclaim` at /// the last step. - pub fn verify_round( + pub fn verify_round( prover_msg: ProverMsg, mut verifier_state: VerifierState, - rng: &mut R, + sponge: &mut S, ) -> (Option>, VerifierState) { if verifier_state.finished { panic!("Incorrect verifier state: Verifier is already finished."); } - // Now, verifier should check if the received P(0) + P(1) = expected. The check is moved to - // `check_and_generate_subclaim`, and will be done after the last round. + // Now, verifier should check if the received P(0) + P(1) = expected. The check + // is moved to `check_and_generate_subclaim`, and will be done after the + // last round. - let msg = Self::sample_round(rng); + let msg = Self::sample_round(sponge); verifier_state.randomness.push(msg.randomness); verifier_state .polynomials_received @@ -84,8 +90,9 @@ impl IPForMLSumcheck { /// verify the sumcheck phase, and generate the subclaim /// - /// If the asserted sum is correct, then the multilinear polynomial evaluated at `subclaim.point` - /// is `subclaim.expected_evaluation`. Otherwise, it is highly unlikely that those two will be equal. + /// If the asserted sum is correct, then the multilinear polynomial + /// evaluated at `subclaim.point` is `subclaim.expected_evaluation`. + /// Otherwise, it is highly unlikely that those two will be equal. /// Larger field size guarantees smaller soundness error. pub fn check_and_generate_subclaim( verifier_state: VerifierState, @@ -122,17 +129,18 @@ impl IPForMLSumcheck { /// simulate a verifier message without doing verification /// - /// Given the same calling context, `random_oracle_round` output exactly the same message as - /// `verify_round` + /// Given the same calling context, `random_oracle_round` output exactly the + /// same message as `verify_round` #[inline] - pub fn sample_round(rng: &mut R) -> VerifierMsg { + pub fn sample_round(sponge: &mut S) -> VerifierMsg { VerifierMsg { - randomness: F::rand(rng), + randomness: sponge.squeeze_field_elements(1)[0], } } } -/// interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this polynomial at `eval_at`. +/// interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this +/// polynomial at `eval_at`. pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { let mut result = F::zero(); let mut i = F::zero(); diff --git a/src/ml_sumcheck/test.rs b/src/ml_sumcheck/test.rs index 1f71a3d..f2214eb 100644 --- a/src/ml_sumcheck/test.rs +++ b/src/ml_sumcheck/test.rs @@ -1,16 +1,22 @@ -use crate::ml_sumcheck::data_structures::ListOfProductsOfPolynomials; -use crate::ml_sumcheck::protocol::IPForMLSumcheck; -use crate::ml_sumcheck::MLSumcheck; -use ark_ff::Field; +use crate::ml_sumcheck::{ + data_structures::ListOfProductsOfPolynomials, protocol::IPForMLSumcheck, MLSumcheck, +}; +use ark_bls12_381::Fr; +use ark_ff::PrimeField; use ark_poly::{DenseMultilinearExtension, MultilinearExtension}; -use ark_std::rand::Rng; -use ark_std::rand::RngCore; -use ark_std::rc::Rc; -use ark_std::vec::Vec; -use ark_std::{test_rng, UniformRand}; -use ark_test_curves::bls12_381::Fr; - -fn random_product( +use ark_sponge::{ + poseidon::{find_poseidon_ark_and_mds, PoseidonConfig, PoseidonSponge}, + CryptographicSponge, +}; +use ark_std::{ + rand::{Rng, RngCore}, + rc::Rc, + test_rng, + vec::Vec, + UniformRand, +}; + +fn random_product( nv: usize, num_multiplicands: usize, rng: &mut R, @@ -40,7 +46,7 @@ fn random_product( ); } -fn random_list_of_products( +fn random_list_of_products( nv: usize, num_multiplicands_range: (usize, usize), num_products: usize, @@ -64,8 +70,11 @@ fn test_polynomial(nv: usize, num_multiplicands_range: (usize, usize), num_produ let (poly, asserted_sum) = random_list_of_products::(nv, num_multiplicands_range, num_products, &mut rng); let poly_info = poly.info(); - let proof = MLSumcheck::prove(&poly).expect("fail to prove"); - let subclaim = MLSumcheck::verify(&poly_info, asserted_sum, &proof).expect("fail to verify"); + let sponge_param = poseidon_parameters(); + let mut sponge = PoseidonSponge::new(&sponge_param); + let proof = MLSumcheck::prove(&mut sponge.clone(), &poly).expect("fail to prove"); + let subclaim = + MLSumcheck::verify(&mut sponge, &poly_info, asserted_sum, &proof).expect("fail to verify"); assert!( poly.evaluate(&subclaim.point) == subclaim.expected_evaluation, "wrong subclaim" @@ -80,11 +89,12 @@ fn test_protocol(nv: usize, num_multiplicands_range: (usize, usize), num_product let mut prover_state = IPForMLSumcheck::prover_init(&poly); let mut verifier_state = IPForMLSumcheck::verifier_init(&poly_info); let mut verifier_msg = None; + let mut sponge = PoseidonSponge::new(&poseidon_parameters()); for _ in 0..poly.num_variables { let result = IPForMLSumcheck::prove_round(prover_state, &verifier_msg); prover_state = result.1; let (verifier_msg2, verifier_state2) = - IPForMLSumcheck::verify_round(result.0, verifier_state, &mut rng); + IPForMLSumcheck::verify_round(result.0, verifier_state, &mut sponge); verifier_msg = verifier_msg2; verifier_state = verifier_state2; } @@ -129,14 +139,15 @@ fn zero_polynomial_should_error() { fn test_extract_sum() { let mut rng = test_rng(); let (poly, asserted_sum) = random_list_of_products::(8, (3, 4), 3, &mut rng); - - let proof = MLSumcheck::prove(&poly).expect("fail to prove"); + let sponge_param = poseidon_parameters(); + let mut sponge = PoseidonSponge::new(&sponge_param); + let proof = MLSumcheck::prove(&mut sponge, &poly).expect("fail to prove"); assert_eq!(MLSumcheck::extract_sum(&proof), asserted_sum); } #[test] -/// Test that the memory usage of shared-reference is linear to number of unique MLExtensions -/// instead of total number of multiplicands. +/// Test that the memory usage of shared-reference is linear to number of unique +/// MLExtensions instead of total number of multiplicands. fn test_shared_reference() { let mut rng = test_rng(); let ml_extensions: Vec<_> = (0..5) @@ -181,11 +192,44 @@ fn test_shared_reference() { drop(prover); let poly_info = poly.info(); - let proof = MLSumcheck::prove(&poly).expect("fail to prove"); + let sponge_param = poseidon_parameters(); + let proof = + MLSumcheck::prove(&mut PoseidonSponge::new(&sponge_param), &poly).expect("fail to prove"); let asserted_sum = MLSumcheck::extract_sum(&proof); - let subclaim = MLSumcheck::verify(&poly_info, asserted_sum, &proof).expect("fail to verify"); + let subclaim = MLSumcheck::verify( + &mut PoseidonSponge::new(&sponge_param), + &poly_info, + asserted_sum, + &proof, + ) + .expect("fail to verify"); assert!( poly.evaluate(&subclaim.point) == subclaim.expected_evaluation, "wrong subclaim" ); } + +pub(crate) fn poseidon_parameters() -> PoseidonConfig { + let full_rounds = 8; + let partial_rounds = 31; + let alpha = 5; + let rate = 2; + + let (ark, mds) = find_poseidon_ark_and_mds::( + ::MODULUS_BIT_SIZE as u64, + rate, + full_rounds, + partial_rounds, + 0, + ); + + PoseidonConfig::new( + full_rounds as usize, + partial_rounds as usize, + alpha, + mds, + ark, + rate, + 1, + ) +} diff --git a/src/rng.rs b/src/rng.rs deleted file mode 100644 index 074e8ef..0000000 --- a/src/rng.rs +++ /dev/null @@ -1,176 +0,0 @@ -//! Fiat-Shamir Random Generator -use ark_serialize::CanonicalSerialize; -use ark_std::rand::RngCore; -use ark_std::vec::Vec; -use blake2::{Blake2s, Digest}; -/// Random Field Element Generator where randomness `feed` adds entropy for the output. -/// -/// Implementation should support all types of input that has `ToBytes` trait. -/// -/// Same sequence of `feed` and `get` call should yield same result! -pub trait FeedableRNG: RngCore { - /// Error type - type Error: ark_std::error::Error + From; - /// Setup should not have any parameter. - fn setup() -> Self; - - /// Provide randomness for the generator, given the message. - fn feed(&mut self, msg: &M) -> Result<(), Self::Error>; -} - -/// 512-bits digest hash pseudorandom generator -pub struct Blake2s512Rng { - /// current digest instance - current_digest: Blake2s, -} - -impl FeedableRNG for Blake2s512Rng { - type Error = crate::Error; - - fn setup() -> Self { - Self { - current_digest: Blake2s::new(), - } - } - - fn feed(&mut self, msg: &M) -> Result<(), Self::Error> { - let mut buf = Vec::new(); - msg.serialize(&mut buf)?; - self.current_digest.update(&buf); - Ok(()) - } -} - -impl RngCore for Blake2s512Rng { - fn next_u32(&mut self) -> u32 { - let mut temp = [0u8; 4]; - self.fill_bytes(&mut temp); - u32::from_le_bytes(temp) - } - - fn next_u64(&mut self) -> u64 { - let mut temp = [0u8; 8]; - self.fill_bytes(&mut temp); - u64::from_le_bytes(temp) - } - - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.try_fill_bytes(dest).unwrap() - } - - fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), ark_std::rand::Error> { - let mut digest = self.current_digest.clone(); - let mut output = digest.finalize(); - let output_size = Blake2s::output_size(); - let mut ptr = 0; - let mut digest_ptr = 0; - while ptr < dest.len() { - dest[ptr] = output[digest_ptr]; - ptr += 1usize; - digest_ptr += 1; - if digest_ptr == output_size { - self.current_digest.update(output); - digest = self.current_digest.clone(); - output = digest.finalize(); - digest_ptr = 0; - } - } - self.current_digest.update(output); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use ark_ff::Field; - use ark_std::rand::Rng; - use ark_std::rand::RngCore; - - use crate::rng::{Blake2s512Rng, FeedableRNG}; - use ark_serialize::{CanonicalSerialize, SerializationError, Write}; - use ark_std::test_rng; - use ark_std::vec::Vec; - use ark_test_curves::bls12_381::Fr; - - /// Special type of input used for test. - #[derive(CanonicalSerialize)] - struct TestMessage { - data: Vec, - } - - impl TestMessage { - fn rand(rng: &mut R, size: usize) -> TestMessage { - let mut data = Vec::with_capacity(size); - data.resize_with(size, || rng.gen()); - TestMessage { data } - } - } - - /// Test that same sequence of `feed` and `get` call should yield same result. - /// - /// * `rng_test`: the pseudorandom RNG to be tested - /// * `num_iterations`: number of independent tests - fn test_deterministic_pseudorandom_generator(num_iterations: u32) - where - F: Field, - G: FeedableRNG, - { - let mut rng = test_rng(); - for _ in 0..num_iterations { - // generate write messages - let mut msgs = Vec::with_capacity(7); - msgs.resize_with(7, || TestMessage::rand(&mut rng, 128)); - - let rw_sequence = |r: &mut G, o: &mut Vec| { - r.feed(&msgs[0]).unwrap(); - o.push(F::rand(r)); - o.push(F::rand(r)); - r.feed(&msgs[1]).unwrap(); - r.feed(&msgs[2]).unwrap(); - o.push(F::rand(r)); - r.feed(&msgs[3]).unwrap(); - o.push(F::rand(r)); - o.push(F::rand(r)); - r.feed(&msgs[4]).unwrap(); - r.feed(&msgs[5]).unwrap(); - r.feed(&msgs[6]).unwrap(); - let f1 = F::rand(r); - o.push(f1); - let f2 = F::rand(r); - o.push(f2); - assert_ne!(f1, f2, "Producing same element"); - o.push(F::rand(r)); - o.push(F::rand(r)); - // edge case: not aligned bytes - let mut buf1 = [0u8; 127]; - let mut buf2 = [0u8; 128]; - let mut buf3 = [0u8; 777]; - r.fill_bytes(&mut buf1); - r.feed(&buf1.to_vec()).unwrap(); - r.fill_bytes(&mut buf2); - r.fill_bytes(&mut buf3); - assert_ne!(&buf2[..64], &buf3[..64]); - o.push(F::rand(r)); - r.feed(&buf3.to_vec()).unwrap(); - o.push(F::rand(r)); - }; - let mut rng_test = G::setup(); - let mut random_output = Vec::with_capacity(8); - - rw_sequence(&mut rng_test, &mut random_output); - - // test that it is deterministic - for _ in 0..10 { - let mut another_rng_test = G::setup(); - let mut another_random_output = Vec::with_capacity(8); - rw_sequence(&mut another_rng_test, &mut another_random_output); - assert_eq!(random_output, another_random_output); - } - } - } - - #[test] - fn test_blake2s_hashing() { - test_deterministic_pseudorandom_generator::(5) - } -} From 2c12d26629156fdd10a9bdfddc2fba33a109885a Mon Sep 17 00:00:00 2001 From: Conghao Shen Date: Wed, 18 May 2022 08:19:38 -0700 Subject: [PATCH 3/3] nit --- src/ml_sumcheck/mod.rs | 4 ++-- src/ml_sumcheck/protocol/prover.rs | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/ml_sumcheck/mod.rs b/src/ml_sumcheck/mod.rs index 3002763..dc7ba4c 100644 --- a/src/ml_sumcheck/mod.rs +++ b/src/ml_sumcheck/mod.rs @@ -79,9 +79,9 @@ impl MLSumcheck { verifier_state = result.1; } - Ok(IPForMLSumcheck::check_and_generate_subclaim( + IPForMLSumcheck::check_and_generate_subclaim( verifier_state, claimed_sum, - )?) + ) } } diff --git a/src/ml_sumcheck/protocol/prover.rs b/src/ml_sumcheck/protocol/prover.rs index 3590e4a..569462d 100644 --- a/src/ml_sumcheck/protocol/prover.rs +++ b/src/ml_sumcheck/protocol/prover.rs @@ -114,8 +114,7 @@ impl IPForMLSumcheck { let nv = prover_state.num_vars; let degree = prover_state.max_multiplicands; // the degree of univariate polynomial sent by prover at this round - let mut products_sum = Vec::with_capacity(degree + 1); - products_sum.resize(degree + 1, F::zero()); + let mut products_sum = vec![F::zero(); degree + 1]; // generate sum for b in 0..1 << (nv - i) {