From 2f5ec67b965cc751e617f507b3fe9e271e221e7d Mon Sep 17 00:00:00 2001 From: Jay White Date: Sat, 9 Nov 2024 11:16:40 -0500 Subject: [PATCH 1/9] refactor: remove unused `randomness` field from `ProverState --- .../proof-of-sql/src/proof_primitive/sumcheck/prover_round.rs | 4 +--- .../proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs | 3 --- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_round.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_round.rs index 9ceb42b00..6af59c08b 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_round.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_round.rs @@ -19,10 +19,8 @@ pub fn prove_round(prover_state: &mut ProverState, r_maybe: &Optio "first round should be prover first." ); - prover_state.randomness.push(*r); - // fix argument - let r_as_field = prover_state.randomness[prover_state.round - 1]; + let r_as_field = *r; if_rayon!( prover_state.flattened_ml_extensions.par_iter_mut(), prover_state.flattened_ml_extensions.iter_mut() diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs index 44138378c..0a506fbd6 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs @@ -8,8 +8,6 @@ use crate::base::scalar::Scalar; use alloc::vec::Vec; pub struct ProverState { - /// sampled randomness given by the verifier - pub randomness: Vec, /// Stores the list of products that is meant to be added together. Each multiplicand is represented by /// the index in `flattened_ml_extensions` pub list_of_products: Vec<(S, Vec)>, @@ -36,7 +34,6 @@ impl ProverState { .collect(); ProverState { - randomness: Vec::with_capacity(polynomial.num_variables), list_of_products: polynomial.products.clone(), flattened_ml_extensions, num_vars: polynomial.num_variables, From 38495814c124956b203c45ee5aa156e128218ac0 Mon Sep 17 00:00:00 2001 From: Jay White Date: Sat, 9 Nov 2024 11:33:16 -0500 Subject: [PATCH 2/9] refactor: make sumcheck proof be created from state rather than composite polynomial refactor: change make_sumcheck_polynomial to make_sumcheck_prover_state refactor: make make_sumcheck_prover_state standalone simplify refactor --- .../src/proof_primitive/sumcheck/mod.rs | 2 +- .../src/proof_primitive/sumcheck/proof.rs | 20 ++--- .../proof_primitive/sumcheck/proof_test.rs | 27 +++++-- .../src/sql/proof/final_round_builder.rs | 34 ++------- .../src/sql/proof/final_round_builder_test.rs | 73 ++----------------- .../src/sql/proof/make_sumcheck_state.rs | 26 +++++++ crates/proof-of-sql/src/sql/proof/mod.rs | 2 + .../proof-of-sql/src/sql/proof/query_proof.rs | 18 +++-- 8 files changed, 76 insertions(+), 126 deletions(-) create mode 100644 crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs index 36de38a5f..8a171629b 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs @@ -4,7 +4,7 @@ mod proof_test; pub use proof::SumcheckProof; mod prover_state; -use prover_state::ProverState; +pub(crate) use prover_state::ProverState; mod prover_round; use prover_round::prove_round; diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs index b5cc82416..7a28f17b1 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs @@ -1,9 +1,6 @@ use crate::{ base::{ - polynomial::{ - interpolate_evaluations_to_reverse_coefficients, CompositePolynomial, - CompositePolynomialInfo, - }, + polynomial::{interpolate_evaluations_to_reverse_coefficients, CompositePolynomialInfo}, proof::{ProofError, Transcript}, scalar::Scalar, }, @@ -31,19 +28,16 @@ impl SumcheckProof { pub fn create( transcript: &mut impl Transcript, evaluation_point: &mut [S], - polynomial: &CompositePolynomial, + mut state: ProverState, ) -> Self { - assert_eq!(evaluation_point.len(), polynomial.num_variables); - transcript.extend_as_be([ - polynomial.max_multiplicands as u64, - polynomial.num_variables as u64, - ]); + let num_vars = state.num_vars; + assert_eq!(evaluation_point.len(), num_vars); + transcript.extend_as_be([state.max_multiplicands as u64, num_vars as u64]); // This challenge is in order to keep transcript messages grouped. (This simplifies the Solidity implementation.) transcript.scalar_challenge_as_be::(); let mut r = None; - let mut state = ProverState::create(polynomial); - let mut coefficients = Vec::with_capacity(polynomial.num_variables); - for scalar in evaluation_point.iter_mut().take(polynomial.num_variables) { + let mut coefficients = Vec::with_capacity(num_vars); + for scalar in evaluation_point.iter_mut().take(num_vars) { let round_evaluations = prove_round(&mut state, &r); let round_coefficients = interpolate_evaluations_to_reverse_coefficients(&round_evaluations); diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs index 97f9ea5ab..f6410de14 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs @@ -1,15 +1,18 @@ use super::test_cases::sumcheck_test_cases; -use crate::base::{ - polynomial::{CompositePolynomial, CompositePolynomialInfo}, - proof::Transcript as _, - scalar::{test_scalar::TestScalar, Curve25519Scalar, MontScalar, Scalar}, -}; /* * Adopted from arkworks * * See third_party/license/arkworks.LICENSE */ use crate::proof_primitive::sumcheck::proof::*; +use crate::{ + base::{ + polynomial::{CompositePolynomial, CompositePolynomialInfo}, + proof::Transcript as _, + scalar::{test_scalar::TestScalar, Curve25519Scalar, MontScalar, Scalar}, + }, + proof_primitive::sumcheck::prover_state::ProverState, +}; use alloc::rc::Rc; use ark_std::UniformRand; use merlin::Transcript; @@ -29,7 +32,11 @@ fn test_create_verify_proof() { let fa = Rc::new(a_vec.to_vec()); poly.add_product([fa], Curve25519Scalar::from(1u64)); let mut transcript = Transcript::new(b"sumchecktest"); - let mut proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly); + let mut proof = SumcheckProof::create( + &mut transcript, + &mut evaluation_point, + ProverState::create(&poly), + ); // verify proof let mut transcript = Transcript::new(b"sumchecktest"); @@ -130,7 +137,11 @@ fn test_polynomial(nv: usize, num_multiplicands_range: (usize, usize), num_produ // create a proof let mut transcript = Transcript::new(b"sumchecktest"); let mut evaluation_point = vec![Curve25519Scalar::zero(); poly_info.num_variables]; - let proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly); + let proof = SumcheckProof::create( + &mut transcript, + &mut evaluation_point, + ProverState::create(&poly), + ); // verify proof let mut transcript = Transcript::new(b"sumchecktest"); @@ -172,7 +183,7 @@ fn we_can_verify_many_random_test_cases() { let proof = SumcheckProof::create( &mut transcript, &mut evaluation_point, - &test_case.polynomial, + ProverState::create(&test_case.polynomial), ); let mut transcript = Transcript::new(b"sumchecktest"); diff --git a/crates/proof-of-sql/src/sql/proof/final_round_builder.rs b/crates/proof-of-sql/src/sql/proof/final_round_builder.rs index 85f973256..1ce16dbd6 100644 --- a/crates/proof-of-sql/src/sql/proof/final_round_builder.rs +++ b/crates/proof-of-sql/src/sql/proof/final_round_builder.rs @@ -1,11 +1,8 @@ -use super::{ - CompositePolynomialBuilder, SumcheckRandomScalars, SumcheckSubpolynomial, - SumcheckSubpolynomialTerm, SumcheckSubpolynomialType, -}; +use super::{SumcheckSubpolynomial, SumcheckSubpolynomialTerm, SumcheckSubpolynomialType}; use crate::base::{ bit::BitDistribution, commitment::{Commitment, CommittableColumn, VecCommitmentExt}, - polynomial::{CompositePolynomial, MultilinearExtension}, + polynomial::MultilinearExtension, scalar::Scalar, }; use alloc::{boxed::Box, vec::Vec}; @@ -105,29 +102,10 @@ impl<'a, S: Scalar> FinalRoundBuilder<'a, S> { ) } - /// Given random multipliers, construct an aggregatated sumcheck polynomial from all - /// the individual subpolynomials. - #[tracing::instrument( - name = "FinalRoundBuilder::make_sumcheck_polynomial", - level = "debug", - skip_all - )] - pub fn make_sumcheck_polynomial( - &self, - scalars: &SumcheckRandomScalars, - ) -> CompositePolynomial { - let mut builder = CompositePolynomialBuilder::new( - self.num_sumcheck_variables, - &scalars.compute_entrywise_multipliers(), - ); - for (multiplier, subpoly) in scalars - .subpolynomial_multipliers - .iter() - .zip(self.sumcheck_subpolynomials.iter()) - { - subpoly.compose(&mut builder, *multiplier); - } - builder.make_composite_polynomial() + /// Produce a subpolynomial to be aggegated into sumcheck where the sum across binary + /// values of the variables is zero. + pub fn sumcheck_subpolynomials(&self) -> &[SumcheckSubpolynomial<'a, S>] { + &self.sumcheck_subpolynomials } /// Given the evaluation vector, compute evaluations of all the MLEs used in sumcheck except diff --git a/crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs b/crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs index 81c24fade..b1e8b9d2e 100644 --- a/crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs +++ b/crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs @@ -1,12 +1,8 @@ -use super::{FinalRoundBuilder, ProvableQueryResult, SumcheckRandomScalars}; -use crate::{ - base::{ - commitment::{Commitment, CommittableColumn}, - database::{Column, ColumnField, ColumnType}, - polynomial::{compute_evaluation_vector, CompositePolynomial, MultilinearExtension}, - scalar::Curve25519Scalar, - }, - sql::proof::SumcheckSubpolynomialType, +use super::{FinalRoundBuilder, ProvableQueryResult}; +use crate::base::{ + commitment::{Commitment, CommittableColumn}, + database::{Column, ColumnField, ColumnType}, + scalar::Curve25519Scalar, }; use alloc::sync::Arc; #[cfg(feature = "arrow")] @@ -16,7 +12,6 @@ use arrow::{ record_batch::RecordBatch, }; use curve25519_dalek::RistrettoPoint; -use num_traits::{One, Zero}; #[test] fn we_can_compute_commitments_for_intermediate_mles_using_a_zero_offset() { @@ -75,64 +70,6 @@ fn we_can_evaluate_pcs_proof_mles() { assert_eq!(evals, expected_evals); } -#[test] -fn we_can_form_an_aggregated_sumcheck_polynomial() { - let mle1 = [1, 2, -1]; - let mle2 = [10i64, 20, 100, 30]; - let mle3 = [2000i64, 3000, 5000, 7000]; - let mut builder = FinalRoundBuilder::new(2, Vec::new()); - builder.produce_anchored_mle(&mle1); - builder.produce_intermediate_mle(&mle2[..]); - builder.produce_intermediate_mle(&mle3[..]); - - builder.produce_sumcheck_subpolynomial( - SumcheckSubpolynomialType::Identity, - vec![(-Curve25519Scalar::one(), vec![Box::new(&mle1)])], - ); - builder.produce_sumcheck_subpolynomial( - SumcheckSubpolynomialType::Identity, - vec![(-Curve25519Scalar::from(10u64), vec![Box::new(&mle2)])], - ); - builder.produce_sumcheck_subpolynomial( - SumcheckSubpolynomialType::ZeroSum, - vec![(Curve25519Scalar::from(9876u64), vec![Box::new(&mle3)])], - ); - - let multipliers = [ - Curve25519Scalar::from(5u64), - Curve25519Scalar::from(2u64), - Curve25519Scalar::from(50u64), - Curve25519Scalar::from(25u64), - Curve25519Scalar::from(11u64), - ]; - - let mut evaluation_vector = vec![Zero::zero(); 4]; - compute_evaluation_vector(&mut evaluation_vector, &multipliers[..2]); - - let poly = builder.make_sumcheck_polynomial(&SumcheckRandomScalars::new(&multipliers, 4, 2)); - let mut expected_poly = CompositePolynomial::new(2); - let fr = (&evaluation_vector).to_sumcheck_term(2); - expected_poly.add_product( - [fr.clone(), (&mle1).to_sumcheck_term(2)], - -Curve25519Scalar::from(1u64) * multipliers[2], - ); - expected_poly.add_product( - [fr, (&mle2).to_sumcheck_term(2)], - -Curve25519Scalar::from(10u64) * multipliers[3], - ); - expected_poly.add_product( - [(&mle3).to_sumcheck_term(2)], - Curve25519Scalar::from(9876u64) * multipliers[4], - ); - let random_point = [ - Curve25519Scalar::from(123u64), - Curve25519Scalar::from(101_112_u64), - ]; - let eval = poly.evaluate(&random_point); - let expected_eval = expected_poly.evaluate(&random_point); - assert_eq!(eval, expected_eval); -} - #[cfg(feature = "arrow")] #[test] fn we_can_form_the_provable_query_result() { diff --git a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs new file mode 100644 index 000000000..d7399c76c --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs @@ -0,0 +1,26 @@ +use super::{CompositePolynomialBuilder, SumcheckRandomScalars, SumcheckSubpolynomial}; +use crate::{base::scalar::Scalar, proof_primitive::sumcheck::ProverState}; + +/// Given random multipliers, construct an aggregatated sumcheck polynomial from all +/// the individual subpolynomials. +#[tracing::instrument( + name = "query_proof::make_sumcheck_polynomial", + level = "debug", + skip_all +)] +pub fn make_sumcheck_prover_state( + subpolynomials: &[SumcheckSubpolynomial<'_, S>], + num_vars: usize, + scalars: &SumcheckRandomScalars, +) -> ProverState { + let mut builder = + CompositePolynomialBuilder::new(num_vars, &scalars.compute_entrywise_multipliers()); + for (multiplier, subpoly) in scalars + .subpolynomial_multipliers + .iter() + .zip(subpolynomials.iter()) + { + subpoly.compose(&mut builder, *multiplier); + } + ProverState::create(&builder.make_composite_polynomial()) +} diff --git a/crates/proof-of-sql/src/sql/proof/mod.rs b/crates/proof-of-sql/src/sql/proof/mod.rs index 8e191e1f7..f31e456ee 100644 --- a/crates/proof-of-sql/src/sql/proof/mod.rs +++ b/crates/proof-of-sql/src/sql/proof/mod.rs @@ -71,3 +71,5 @@ pub(crate) use first_round_builder::FirstRoundBuilder; #[cfg(all(test, feature = "arrow"))] mod provable_query_result_test; + +mod make_sumcheck_state; diff --git a/crates/proof-of-sql/src/sql/proof/query_proof.rs b/crates/proof-of-sql/src/sql/proof/query_proof.rs index 661df1e8d..e5dfba0cc 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof.rs @@ -1,6 +1,7 @@ use super::{ - CountBuilder, FinalRoundBuilder, ProofCounts, ProofPlan, ProvableQueryResult, QueryResult, - SumcheckMleEvaluations, SumcheckRandomScalars, VerificationBuilder, + make_sumcheck_state::make_sumcheck_prover_state, CountBuilder, FinalRoundBuilder, ProofCounts, + ProofPlan, ProvableQueryResult, QueryResult, SumcheckMleEvaluations, SumcheckRandomScalars, + VerificationBuilder, }; use crate::{ base::{ @@ -149,15 +150,16 @@ impl QueryProof { core::iter::repeat_with(|| transcript.scalar_challenge_as_be()) .take(num_random_scalars) .collect(); - let poly = builder.make_sumcheck_polynomial(&SumcheckRandomScalars::new( - &random_scalars, - range_length, + let sumcheck_state = make_sumcheck_prover_state( + builder.sumcheck_subpolynomials(), num_sumcheck_variables, - )); + &SumcheckRandomScalars::new(&random_scalars, range_length, num_sumcheck_variables), + ); // create the sumcheck proof -- this is the main part of proving a query - let mut evaluation_point = vec![Zero::zero(); poly.num_variables]; - let sumcheck_proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly); + let mut evaluation_point = vec![Zero::zero(); num_sumcheck_variables]; + let sumcheck_proof = + SumcheckProof::create(&mut transcript, &mut evaluation_point, sumcheck_state); // evaluate the MLEs used in sumcheck except for the result columns let mut evaluation_vec = vec![Zero::zero(); range_length]; From ada20dfec325e049ef2c2f36fd8e20c55ae67da0 Mon Sep 17 00:00:00 2001 From: Jay White Date: Sat, 9 Nov 2024 15:41:25 -0500 Subject: [PATCH 3/9] refactor: add helper implementations to `SumcheckSubpolynomial` --- .../src/sql/proof/query_proof_test.rs | 10 ++++----- .../src/sql/proof/sumcheck_subpolynomial.rs | 22 +++++++++++++++++++ .../src/sql/proof/verification_builder.rs | 2 +- .../sql/proof/verification_builder_test.rs | 4 ++-- .../src/sql/proof_exprs/and_expr.rs | 2 +- .../src/sql/proof_exprs/equals_expr.rs | 4 ++-- .../src/sql/proof_exprs/multiply_expr.rs | 2 +- .../src/sql/proof_exprs/or_expr.rs | 2 +- .../src/sql/proof_exprs/sign_expr.rs | 4 ++-- .../src/sql/proof_plans/filter_exec.rs | 6 ++--- .../src/sql/proof_plans/group_by_exec.rs | 6 ++--- 11 files changed, 43 insertions(+), 21 deletions(-) diff --git a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs index 6ee0eccac..72d38e22d 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs @@ -89,7 +89,7 @@ impl ProofPlan for TrivialTestProofPlan { ) -> Result, ProofError> { assert_eq!(builder.consume_intermediate_mle(), S::ZERO); builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::ZeroSum, + SumcheckSubpolynomialType::ZeroSum, S::from(self.evaluation), ); Ok(TableEvaluation::new( @@ -280,7 +280,7 @@ impl ProofPlan for SquareTestProofPlan { .unwrap(); let res_eval = builder.consume_intermediate_mle(); builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, res_eval - x_eval * x_eval, ); Ok(TableEvaluation::new( @@ -479,13 +479,13 @@ impl ProofPlan for DoubleSquareTestProofPlan { // poly1 builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, z_eval - x_eval * x_eval, ); // poly2 builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, res_eval - z_eval * z_eval, ); Ok(TableEvaluation::new( @@ -683,7 +683,7 @@ impl ProofPlan for ChallengeTestProofPlan { .unwrap(); let res_eval = builder.consume_intermediate_mle(); builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, alpha * res_eval - alpha * x_eval * x_eval, ); Ok(TableEvaluation::new( diff --git a/crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs b/crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs index 110389c5e..3b607f725 100644 --- a/crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs +++ b/crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs @@ -3,6 +3,7 @@ use crate::base::{polynomial::MultilinearExtension, scalar::Scalar}; use alloc::{boxed::Box, vec::Vec}; /// The type of a sumcheck subpolynomial +#[derive(Copy, Clone, Hash, Eq, PartialEq)] pub enum SumcheckSubpolynomialType { /// The subpolynomial should be zero at every entry/row Identity, @@ -39,6 +40,27 @@ impl<'a, S: Scalar> SumcheckSubpolynomial<'a, S> { } } + pub fn subpolynomial_type(&self) -> SumcheckSubpolynomialType { + self.subpolynomial_type + } + + /// Return an iterator over the Subpolynomialterms returning a tuple with the type, coefficient, and multiplicands. + /// The multiplier parameters is multiplied by every coefficient. + pub fn iter_mul_by( + &self, + multiplier: S, + ) -> impl Iterator< + Item = ( + SumcheckSubpolynomialType, + S, + &Vec + 'a>>, + ), + > { + self.terms.iter().map(move |(coeff, multiplicands)| { + (self.subpolynomial_type, multiplier * *coeff, multiplicands) + }) + } + /// Combine the subpolynomial into a combined composite polynomial pub fn compose( &self, diff --git a/crates/proof-of-sql/src/sql/proof/verification_builder.rs b/crates/proof-of-sql/src/sql/proof/verification_builder.rs index 7385bf353..bc0743167 100644 --- a/crates/proof-of-sql/src/sql/proof/verification_builder.rs +++ b/crates/proof-of-sql/src/sql/proof/verification_builder.rs @@ -111,7 +111,7 @@ impl<'a, S: Scalar> VerificationBuilder<'a, S> { /// Produce the evaluation of a subpolynomial used in sumcheck pub fn produce_sumcheck_subpolynomial_evaluation( &mut self, - subpolynomial_type: &SumcheckSubpolynomialType, + subpolynomial_type: SumcheckSubpolynomialType, eval: S, ) { self.sumcheck_evaluation += self.subpolynomial_multipliers[self.produced_subpolynomials] diff --git a/crates/proof-of-sql/src/sql/proof/verification_builder_test.rs b/crates/proof-of-sql/src/sql/proof/verification_builder_test.rs index ba71c2087..647294312 100644 --- a/crates/proof-of-sql/src/sql/proof/verification_builder_test.rs +++ b/crates/proof-of-sql/src/sql/proof/verification_builder_test.rs @@ -41,11 +41,11 @@ fn we_build_up_a_sumcheck_polynomial_evaluation_from_subpolynomial_evaluations() Vec::new(), ); builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::ZeroSum, + SumcheckSubpolynomialType::ZeroSum, Curve25519Scalar::from(2u64), ); builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::ZeroSum, + SumcheckSubpolynomialType::ZeroSum, Curve25519Scalar::from(3u64), ); let expected_sumcheck_evaluation = subpolynomial_multipliers[0] * Curve25519Scalar::from(2u64) diff --git a/crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs index 6cd6dcdcb..8f496a216 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs @@ -96,7 +96,7 @@ impl ProofExpr for AndExpr { // subpolynomial: lhs_and_rhs - lhs * rhs builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, lhs_and_rhs - lhs * rhs, ); diff --git a/crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs index c3ee789e9..aded010c8 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs @@ -160,13 +160,13 @@ pub fn verifier_evaluate_equals_zero( // subpolynomial: selection * lhs builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, selection_eval * lhs_eval, ); // subpolynomial: selection_not - lhs * lhs_pseudo_inv builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, selection_not_eval - lhs_eval * lhs_pseudo_inv_eval, ); diff --git a/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs index f4af9209e..1b1c203db 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs @@ -98,7 +98,7 @@ impl ProofExpr for MultiplyExpr { // subpolynomial: lhs_times_rhs - lhs * rhs builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, lhs_times_rhs - lhs * rhs, ); diff --git a/crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs index 5e25866ce..fb6d27a3b 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs @@ -138,7 +138,7 @@ pub fn verifier_evaluate_or( // subpolynomial: lhs_and_rhs - lhs * rhs builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, lhs_and_rhs - *lhs * *rhs, ); diff --git a/crates/proof-of-sql/src/sql/proof_exprs/sign_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/sign_expr.rs index 0ffccaa84..a8322652b 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/sign_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/sign_expr.rs @@ -191,7 +191,7 @@ fn prove_bits_are_binary<'a, S: Scalar>( fn verify_bits_are_binary(builder: &mut VerificationBuilder, bit_evals: &[S]) { for bit_eval in bit_evals { builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, *bit_eval - *bit_eval * *bit_eval, ); } @@ -257,5 +257,5 @@ fn verify_bit_decomposition( eval -= S::from(mult) * sign_eval * bit_eval; vary_index += 1; }); - builder.produce_sumcheck_subpolynomial_evaluation(&SumcheckSubpolynomialType::Identity, eval); + builder.produce_sumcheck_subpolynomial_evaluation(SumcheckSubpolynomialType::Identity, eval); } diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs index 5872564fe..5d2ea8332 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs @@ -267,19 +267,19 @@ pub(super) fn verify_filter( // sum c_star * s - d_star = 0 builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::ZeroSum, + SumcheckSubpolynomialType::ZeroSum, c_star_eval * s_eval - d_star_eval, ); // c_fold * c_star - input_ones = 0 builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, c_fold_eval * c_star_eval - one_eval, ); // d_bar_fold * d_star - chi = 0 builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, d_bar_fold_eval * d_star_eval - chi_eval, ); diff --git a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs index 2ac025ff6..2d8fb8d90 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs @@ -361,19 +361,19 @@ fn verify_group_by( // sum g_in_star * sel_in * sum_in_fold - g_out_star * sum_out_bar_fold = 0 builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::ZeroSum, + SumcheckSubpolynomialType::ZeroSum, g_in_star_eval * sel_in_eval * sum_in_fold_eval - g_out_star_eval * sum_out_bar_fold_eval, ); // g_in_star * g_in_fold - input_ones = 0 builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, g_in_star_eval * g_in_fold_eval - one_eval, ); // g_out_star * g_out_bar_fold - input_ones = 0 builder.produce_sumcheck_subpolynomial_evaluation( - &SumcheckSubpolynomialType::Identity, + SumcheckSubpolynomialType::Identity, g_out_star_eval * g_out_bar_fold_eval - one_eval, ); From ab5d009302b655de7f00106f6c994a749d006e55 Mon Sep 17 00:00:00 2001 From: Jay White Date: Mon, 11 Nov 2024 10:14:00 -0500 Subject: [PATCH 4/9] add ProverState::new --- .../proof_primitive/sumcheck/prover_state.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs index 0a506fbd6..27f24a82a 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs @@ -41,4 +41,22 @@ impl ProverState { round: 0, } } + pub fn new( + list_of_products: Vec<(S, Vec)>, + flattened_ml_extensions: Vec>, + num_vars: usize, + ) -> Self { + let max_multiplicands = list_of_products + .iter() + .map(|(_, product)| product.len()) + .max() + .unwrap_or(0); + ProverState { + list_of_products, + flattened_ml_extensions, + num_vars, + max_multiplicands, + round: 0, + } + } } From c5f9896f9f87ea1b6cc5bca8342662c383714867 Mon Sep 17 00:00:00 2001 From: Jay White Date: Tue, 12 Nov 2024 09:40:42 -0500 Subject: [PATCH 5/9] refactor: make make_sumcheck_prover_state standalone --- .../src/sql/proof/make_sumcheck_state.rs | 80 ++++++++++++++++--- .../proof-of-sql/src/sql/proof/query_proof.rs | 4 +- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs index d7399c76c..cbeebeded 100644 --- a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs +++ b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs @@ -1,5 +1,47 @@ -use super::{CompositePolynomialBuilder, SumcheckRandomScalars, SumcheckSubpolynomial}; -use crate::{base::scalar::Scalar, proof_primitive::sumcheck::ProverState}; +use super::{SumcheckRandomScalars, SumcheckSubpolynomial, SumcheckSubpolynomialType}; +use crate::{ + base::{polynomial::MultilinearExtension, scalar::Scalar}, + proof_primitive::sumcheck::ProverState, +}; +use alloc::vec::Vec; + +struct FlattenedMLEBuilder<'a, S: Scalar> { + multiplicand_count: usize, + all_ml_extensions: Vec<&'a dyn MultilinearExtension>, + entrywise_multipliers: Option>, + num_vars: usize, +} +impl<'a, S: Scalar> FlattenedMLEBuilder<'a, S> { + fn new(entrywise_multipliers: Option>, num_vars: usize) -> Self { + Self { + multiplicand_count: entrywise_multipliers.is_some().into(), + all_ml_extensions: Vec::new(), + entrywise_multipliers, + num_vars, + } + } + fn position_or_insert(&mut self, multiplicand: &'a dyn MultilinearExtension) -> usize { + self.all_ml_extensions.push(multiplicand); + self.multiplicand_count += 1; + self.multiplicand_count - 1 + } + #[tracing::instrument( + name = "FlattenedMLEBuilder::flattened_ml_extensions", + level = "debug", + skip_all + )] + fn flattened_ml_extensions(self) -> Vec> { + self.entrywise_multipliers + .into_iter() + .map(|mle| (&mle).to_sumcheck_term(self.num_vars).as_ref().clone()) + .chain( + self.all_ml_extensions + .iter() + .map(|mle| mle.to_sumcheck_term(self.num_vars).as_ref().clone()), + ) + .collect() + } +} /// Given random multipliers, construct an aggregatated sumcheck polynomial from all /// the individual subpolynomials. @@ -13,14 +55,32 @@ pub fn make_sumcheck_prover_state( num_vars: usize, scalars: &SumcheckRandomScalars, ) -> ProverState { - let mut builder = - CompositePolynomialBuilder::new(num_vars, &scalars.compute_entrywise_multipliers()); - for (multiplier, subpoly) in scalars + let needs_entrywise_multipliers = subpolynomials + .iter() + .any(|s| matches!(s.subpolynomial_type(), SumcheckSubpolynomialType::Identity)); + let all_terms = scalars .subpolynomial_multipliers .iter() - .zip(subpolynomials.iter()) - { - subpoly.compose(&mut builder, *multiplier); - } - ProverState::create(&builder.make_composite_polynomial()) + .zip(subpolynomials) + .flat_map(|(multiplier, terms)| terms.iter_mul_by(*multiplier)); + let mut builder = FlattenedMLEBuilder::new( + needs_entrywise_multipliers.then(|| scalars.compute_entrywise_multipliers()), + num_vars, + ); + let list_of_products = all_terms + .map(|(ty, coeff, term)| { + ( + coeff, + term.iter() + .map(|multiplicand| builder.position_or_insert(multiplicand.as_ref())) + .chain(matches!(ty, SumcheckSubpolynomialType::Identity).then_some(0)) + .collect(), + ) + }) + .collect(); + ProverState::new( + list_of_products, + builder.flattened_ml_extensions(), + num_vars, + ) } diff --git a/crates/proof-of-sql/src/sql/proof/query_proof.rs b/crates/proof-of-sql/src/sql/proof/query_proof.rs index e5dfba0cc..d605e0a09 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof.rs @@ -275,9 +275,7 @@ impl QueryProof { // verify sumcheck up to the evaluation check let poly_info = CompositePolynomialInfo { - // This needs to be at least 2 since `CompositePolynomialBuilder::make_composite_polynomial` - // always adds a degree 2 term. - max_multiplicands: core::cmp::max(counts.sumcheck_max_multiplicands, 2), + max_multiplicands: counts.sumcheck_max_multiplicands, num_variables: num_sumcheck_variables, }; let subclaim = self.sumcheck_proof.verify_without_evaluation( From fbcc240ee3c5e724cfdbe1b1c158b9d178390d9c Mon Sep 17 00:00:00 2001 From: Jay White Date: Tue, 12 Nov 2024 09:41:20 -0500 Subject: [PATCH 6/9] refactor: drop dead code --- .../sql/proof/composite_polynomial_builder.rs | 124 ------------------ .../composite_polynomial_builder_test.rs | 119 ----------------- crates/proof-of-sql/src/sql/proof/mod.rs | 5 - .../src/sql/proof/sumcheck_subpolynomial.rs | 18 --- 4 files changed, 266 deletions(-) delete mode 100644 crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs delete mode 100644 crates/proof-of-sql/src/sql/proof/composite_polynomial_builder_test.rs diff --git a/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs b/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs deleted file mode 100644 index c5e6878df..000000000 --- a/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs +++ /dev/null @@ -1,124 +0,0 @@ -use crate::base::{ - if_rayon, - map::IndexMap, - polynomial::{CompositePolynomial, MultilinearExtension}, - scalar::Scalar, -}; -use alloc::{boxed::Box, rc::Rc, vec, vec::Vec}; -use core::{ffi::c_void, iter}; -use num_traits::{One, Zero}; -#[cfg(feature = "rayon")] -use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; - -// Build up a composite polynomial from individual MLE expressions -pub struct CompositePolynomialBuilder { - num_sumcheck_variables: usize, - fr_multiplicands_degree1: Vec, - fr_multiplicands_rest: Vec<(S, Vec>>)>, - zerosum_multiplicands: Vec<(S, Vec>>)>, - fr: Rc>, - mles: IndexMap<*const c_void, Rc>>, -} - -impl CompositePolynomialBuilder { - #[allow( - clippy::missing_panics_doc, - reason = "The assertion ensures that the length of 'fr' does not exceed the allowable range based on 'num_sumcheck_variables', making the panic clear from context." - )] - pub fn new(num_sumcheck_variables: usize, fr: &[S]) -> Self { - assert!(1 << num_sumcheck_variables >= fr.len()); - Self { - num_sumcheck_variables, - fr_multiplicands_degree1: vec![Zero::zero(); fr.len()], - fr_multiplicands_rest: vec![], - zerosum_multiplicands: vec![], - fr: fr.to_sumcheck_term(num_sumcheck_variables), - mles: IndexMap::default(), - } - } - - /// Produce a polynomial term of the form - /// `mult * f_r(X1, .., Xr) * term1(X1, ..., Xr) * ... * termK(X1, ..., Xr)` - /// where `f_r` is an MLE of random scalars - pub fn produce_fr_multiplicand( - &mut self, - mult: &S, - terms: &[Box + '_>], - ) { - if terms.is_empty() { - if_rayon!( - self.fr_multiplicands_degree1 - .par_iter_mut() - .with_min_len(crate::base::slice_ops::MIN_RAYON_LEN), - self.fr_multiplicands_degree1.iter_mut() - ) - .for_each(|val| *val += *mult); - } else if terms.len() == 1 { - terms[0].mul_add(&mut self.fr_multiplicands_degree1, mult); - } else { - let multiplicand = self.create_multiplicand_with_deduplicated_mles(terms); - self.fr_multiplicands_rest.push((*mult, multiplicand)); - } - } - /// Produce a polynomial term of the form - /// mult * term1(X1, ..., Xr) * ... * termK(X1, ..., Xr) - #[allow( - clippy::missing_panics_doc, - reason = "The assertion guarantees that terms are not empty, which is inherently clear from the context of this function." - )] - pub fn produce_zerosum_multiplicand( - &mut self, - mult: &S, - terms: &[Box + '_>], - ) { - // There is a more efficient way of handling constant zerosum terms, - // since we know the sum will be constant * length, so this assertion should be here. - assert!(!terms.is_empty()); - let multiplicand = self.create_multiplicand_with_deduplicated_mles(terms); - self.zerosum_multiplicands.push((*mult, multiplicand)); - } - - fn create_multiplicand_with_deduplicated_mles( - &mut self, - terms: &[Box + '_>], - ) -> Vec>> { - let mut deduplicated_terms = Vec::with_capacity(terms.len()); - for term in terms { - let id = term.id(); - if let Some(cached_term) = self.mles.get(&id) { - deduplicated_terms.push(cached_term.clone()); - } else { - let new_term = term.to_sumcheck_term(self.num_sumcheck_variables); - self.mles.insert(id, new_term.clone()); - deduplicated_terms.push(new_term); - } - } - deduplicated_terms - } - - /// Create a composite polynomial that is the sum of all of the - /// produced MLE expressions - pub fn make_composite_polynomial(&self) -> CompositePolynomial { - let mut res = CompositePolynomial::new(self.num_sumcheck_variables); - res.add_product( - [ - self.fr.clone(), - (&self.fr_multiplicands_degree1).to_sumcheck_term(self.num_sumcheck_variables), - ], - One::one(), - ); - for (mult, terms) in &self.fr_multiplicands_rest { - let fr_iter = iter::once(self.fr.clone()); - let terms_iter = terms.iter().cloned(); - res.add_product(fr_iter.chain(terms_iter), *mult); - } - for (mult, terms) in &self.zerosum_multiplicands { - let terms_iter = terms.iter().cloned(); - res.add_product(terms_iter, *mult); - } - - res.annotate_trace(); - - res - } -} diff --git a/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder_test.rs b/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder_test.rs deleted file mode 100644 index 5f0b25b54..000000000 --- a/crates/proof-of-sql/src/sql/proof/composite_polynomial_builder_test.rs +++ /dev/null @@ -1,119 +0,0 @@ -use super::CompositePolynomialBuilder; -use crate::base::scalar::Curve25519Scalar; -use num_traits::One; - -#[test] -fn we_combine_single_degree_fr_multiplicands() { - let fr = [Curve25519Scalar::from(1u64), Curve25519Scalar::from(2u64)]; - let mle1 = [10, 20]; - let mle2 = [11, 21]; - let mut builder = CompositePolynomialBuilder::new(1, &fr); - builder.produce_fr_multiplicand(&One::one(), &[Box::new(&mle1)]); - builder.produce_fr_multiplicand(&-Curve25519Scalar::one(), &[Box::new(&mle2)]); - let p = builder.make_composite_polynomial(); - assert_eq!(p.products.len(), 1); - assert_eq!(p.flattened_ml_extensions.len(), 2); - let pt = [Curve25519Scalar::from(9_268_764_u64)]; - let m0 = Curve25519Scalar::one() - pt[0]; - let m1 = pt[0]; - let eval1 = Curve25519Scalar::from(mle1[0]) * m0 + Curve25519Scalar::from(mle1[1]) * m1; - let eval2 = Curve25519Scalar::from(mle2[0]) * m0 + Curve25519Scalar::from(mle2[1]) * m1; - let eval_fr = fr[0] * m0 + fr[1] * m1; - let expected = eval_fr * (eval1 - eval2); - assert_eq!(p.evaluate(&pt), expected); -} - -#[test] -fn we_dont_duplicate_repeated_mles() { - let fr = [Curve25519Scalar::from(1u64), Curve25519Scalar::from(2u64)]; - let mle1 = [10, 20]; - let mle2 = [11, 21]; - let mut builder = CompositePolynomialBuilder::new(1, &fr); - builder.produce_fr_multiplicand(&One::one(), &[Box::new(&mle1), Box::new(&mle1)]); - builder.produce_fr_multiplicand(&One::one(), &[Box::new(&mle1), Box::new(&mle2)]); - let p = builder.make_composite_polynomial(); - assert_eq!(p.products.len(), 3); - assert_eq!(p.flattened_ml_extensions.len(), 4); - let pt = [Curve25519Scalar::from(9_268_764_u64)]; - let m0 = Curve25519Scalar::one() - pt[0]; - let m1 = pt[0]; - let eval1 = Curve25519Scalar::from(mle1[0]) * m0 + Curve25519Scalar::from(mle1[1]) * m1; - let eval2 = Curve25519Scalar::from(mle2[0]) * m0 + Curve25519Scalar::from(mle2[1]) * m1; - let eval_fr = fr[0] * m0 + fr[1] * m1; - let expected = eval_fr * (eval1 * eval1 + eval1 * eval2); - assert_eq!(p.evaluate(&pt), expected); -} - -#[test] -fn we_can_combine_identity_with_zero_sum_polynomials() { - let fr = [Curve25519Scalar::from(1u64), Curve25519Scalar::from(2u64)]; - let mle1 = [10, 20]; - let mle2 = [11, 21]; - let mle3 = [12, 22]; - let mle4 = [13, 23]; - let mut builder = CompositePolynomialBuilder::new(1, &fr); - builder.produce_fr_multiplicand(&One::one(), &[Box::new(&mle1), Box::new(&mle2)]); - builder.produce_zerosum_multiplicand( - &-Curve25519Scalar::one(), - &[Box::new(&mle3), Box::new(&mle4)], - ); - let p = builder.make_composite_polynomial(); - assert_eq!(p.products.len(), 3); //1 for the linear term, 1 for the fr multiplicand, 1 for the zerosum multiplicand - assert_eq!(p.flattened_ml_extensions.len(), 6); //1 for fr, 1 for the linear term, and 4 for mle1-4 - let pt = [Curve25519Scalar::from(9_268_764_u64)]; - let m0 = Curve25519Scalar::one() - pt[0]; - let m1 = pt[0]; - let eval1 = Curve25519Scalar::from(mle1[0]) * m0 + Curve25519Scalar::from(mle1[1]) * m1; - let eval2 = Curve25519Scalar::from(mle2[0]) * m0 + Curve25519Scalar::from(mle2[1]) * m1; - let eval3 = Curve25519Scalar::from(mle3[0]) * m0 + Curve25519Scalar::from(mle3[1]) * m1; - let eval4 = Curve25519Scalar::from(mle4[0]) * m0 + Curve25519Scalar::from(mle4[1]) * m1; - let eval_fr = fr[0] * m0 + fr[1] * m1; - let expected = eval_fr * eval1 * eval2 - eval3 * eval4; - assert_eq!(p.evaluate(&pt), expected); -} - -#[test] -fn we_can_handle_only_an_empty_fr_multiplicand() { - let fr = [Curve25519Scalar::from(1u64), Curve25519Scalar::from(2u64)]; - let mut builder = CompositePolynomialBuilder::new(1, &fr); - builder.produce_fr_multiplicand(&Curve25519Scalar::from(17), &[]); - let p = builder.make_composite_polynomial(); - assert_eq!(p.products.len(), 1); //1 for the fr multiplicand - assert_eq!(p.flattened_ml_extensions.len(), 2); //1 for fr, 1 for the linear term - let pt = [Curve25519Scalar::from(9_268_764_u64)]; - let m0 = Curve25519Scalar::one() - pt[0]; - let m1 = pt[0]; - let eval1 = (m0 + m1) * Curve25519Scalar::from(17); - let eval_fr = fr[0] * m0 + fr[1] * m1; - let expected = eval_fr * eval1; - assert_eq!(p.evaluate(&pt), expected); -} - -#[test] -fn we_can_handle_empty_terms_with_other_terms() { - let fr = [Curve25519Scalar::from(1u64), Curve25519Scalar::from(2u64)]; - let mle1 = [10, 20]; - let mle2 = [11, 21]; - let mle3 = [12, 22]; - let mle4 = [13, 23]; - let mut builder = CompositePolynomialBuilder::new(1, &fr); - builder.produce_fr_multiplicand(&One::one(), &[Box::new(&mle1), Box::new(&mle2)]); - builder.produce_fr_multiplicand(&Curve25519Scalar::from(17), &[]); - builder.produce_zerosum_multiplicand( - &-Curve25519Scalar::one(), - &[Box::new(&mle3), Box::new(&mle4)], - ); - let p = builder.make_composite_polynomial(); - assert_eq!(p.products.len(), 3); //1 for the linear term, 1 for the fr multiplicand, 1 for the zerosum multiplicand - assert_eq!(p.flattened_ml_extensions.len(), 6); //1 for fr, 1 for the linear term, and 4 for mle1-4 - let pt = [Curve25519Scalar::from(9_268_764_u64)]; - let m0 = Curve25519Scalar::one() - pt[0]; - let m1 = pt[0]; - let eval1 = Curve25519Scalar::from(mle1[0]) * m0 + Curve25519Scalar::from(mle1[1]) * m1; - let eval2 = Curve25519Scalar::from(mle2[0]) * m0 + Curve25519Scalar::from(mle2[1]) * m1; - let eval3 = Curve25519Scalar::from(mle3[0]) * m0 + Curve25519Scalar::from(mle3[1]) * m1; - let eval4 = Curve25519Scalar::from(mle4[0]) * m0 + Curve25519Scalar::from(mle4[1]) * m1; - let eval_fr = fr[0] * m0 + fr[1] * m1; - let expected = eval_fr * (eval1 * eval2 + Curve25519Scalar::from(17)) - eval3 * eval4; - assert_eq!(p.evaluate(&pt), expected); -} diff --git a/crates/proof-of-sql/src/sql/proof/mod.rs b/crates/proof-of-sql/src/sql/proof/mod.rs index f31e456ee..c9c796f4d 100644 --- a/crates/proof-of-sql/src/sql/proof/mod.rs +++ b/crates/proof-of-sql/src/sql/proof/mod.rs @@ -7,11 +7,6 @@ pub(crate) use final_round_builder::FinalRoundBuilder; #[cfg(all(test, feature = "blitzar"))] mod final_round_builder_test; -mod composite_polynomial_builder; -pub(crate) use composite_polynomial_builder::CompositePolynomialBuilder; -#[cfg(test)] -mod composite_polynomial_builder_test; - mod proof_counts; pub(crate) use proof_counts::ProofCounts; diff --git a/crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs b/crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs index 3b607f725..1d50286e8 100644 --- a/crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs +++ b/crates/proof-of-sql/src/sql/proof/sumcheck_subpolynomial.rs @@ -1,4 +1,3 @@ -use super::CompositePolynomialBuilder; use crate::base::{polynomial::MultilinearExtension, scalar::Scalar}; use alloc::{boxed::Box, vec::Vec}; @@ -60,21 +59,4 @@ impl<'a, S: Scalar> SumcheckSubpolynomial<'a, S> { (self.subpolynomial_type, multiplier * *coeff, multiplicands) }) } - - /// Combine the subpolynomial into a combined composite polynomial - pub fn compose( - &self, - composite_polynomial: &mut CompositePolynomialBuilder, - group_multiplier: S, - ) { - for (mult, term) in &self.terms { - match self.subpolynomial_type { - SumcheckSubpolynomialType::Identity => { - composite_polynomial.produce_fr_multiplicand(&(*mult * group_multiplier), term); - } - SumcheckSubpolynomialType::ZeroSum => composite_polynomial - .produce_zerosum_multiplicand(&(*mult * group_multiplier), term), - } - } - } } From fef1cf1e4ad4ea3aae3273817f8d7cf0cd5ccd81 Mon Sep 17 00:00:00 2001 From: Jay White Date: Mon, 11 Nov 2024 15:29:55 -0500 Subject: [PATCH 7/9] deduplicate mles --- .../src/sql/proof/make_sumcheck_state.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs index cbeebeded..9d96d5fc0 100644 --- a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs +++ b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs @@ -1,15 +1,17 @@ use super::{SumcheckRandomScalars, SumcheckSubpolynomial, SumcheckSubpolynomialType}; use crate::{ - base::{polynomial::MultilinearExtension, scalar::Scalar}, + base::{map::IndexMap, polynomial::MultilinearExtension, scalar::Scalar}, proof_primitive::sumcheck::ProverState, }; use alloc::vec::Vec; +use core::ffi::c_void; struct FlattenedMLEBuilder<'a, S: Scalar> { multiplicand_count: usize, all_ml_extensions: Vec<&'a dyn MultilinearExtension>, entrywise_multipliers: Option>, num_vars: usize, + lookup: IndexMap<*const c_void, usize>, } impl<'a, S: Scalar> FlattenedMLEBuilder<'a, S> { fn new(entrywise_multipliers: Option>, num_vars: usize) -> Self { @@ -18,12 +20,15 @@ impl<'a, S: Scalar> FlattenedMLEBuilder<'a, S> { all_ml_extensions: Vec::new(), entrywise_multipliers, num_vars, + lookup: IndexMap::default(), } } fn position_or_insert(&mut self, multiplicand: &'a dyn MultilinearExtension) -> usize { - self.all_ml_extensions.push(multiplicand); - self.multiplicand_count += 1; - self.multiplicand_count - 1 + *self.lookup.entry(multiplicand.id()).or_insert_with(|| { + self.all_ml_extensions.push(multiplicand); + self.multiplicand_count += 1; + self.multiplicand_count - 1 + }) } #[tracing::instrument( name = "FlattenedMLEBuilder::flattened_ml_extensions", From 0b0927f4d8e7f49ffc856dd6ae9929fe6105f5e3 Mon Sep 17 00:00:00 2001 From: Jay White Date: Tue, 12 Nov 2024 10:56:13 -0500 Subject: [PATCH 8/9] perf: optimize sumcheck terms --- .../src/sql/proof/make_sumcheck_state.rs | 16 +- crates/proof-of-sql/src/sql/proof/mod.rs | 2 + .../src/sql/proof/sumcheck_term_optimizer.rs | 143 ++++++++++++++++++ 3 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 crates/proof-of-sql/src/sql/proof/sumcheck_term_optimizer.rs diff --git a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs index 9d96d5fc0..895a8e737 100644 --- a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs +++ b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs @@ -1,10 +1,14 @@ -use super::{SumcheckRandomScalars, SumcheckSubpolynomial, SumcheckSubpolynomialType}; +use super::{ + sumcheck_term_optimizer::SumcheckTermOptimizer, SumcheckRandomScalars, SumcheckSubpolynomial, + SumcheckSubpolynomialType, +}; use crate::{ base::{map::IndexMap, polynomial::MultilinearExtension, scalar::Scalar}, proof_primitive::sumcheck::ProverState, }; use alloc::vec::Vec; use core::ffi::c_void; +use tracing::Level; struct FlattenedMLEBuilder<'a, S: Scalar> { multiplicand_count: usize, @@ -68,11 +72,19 @@ pub fn make_sumcheck_prover_state( .iter() .zip(subpolynomials) .flat_map(|(multiplier, terms)| terms.iter_mul_by(*multiplier)); + + // Optimization should be very fast. We put this span to double check this. There is almost no copying being done. + let span = tracing::span!(Level::DEBUG, "optimize sumcheck terms").entered(); + let optimizer = SumcheckTermOptimizer::new(all_terms, scalars.table_length); + let optimized_terms = optimizer.terms(); + let optimized_term_iter = optimized_terms.into_iter(); + span.exit(); + let mut builder = FlattenedMLEBuilder::new( needs_entrywise_multipliers.then(|| scalars.compute_entrywise_multipliers()), num_vars, ); - let list_of_products = all_terms + let list_of_products = optimized_term_iter .map(|(ty, coeff, term)| { ( coeff, diff --git a/crates/proof-of-sql/src/sql/proof/mod.rs b/crates/proof-of-sql/src/sql/proof/mod.rs index c9c796f4d..cfb50f4e5 100644 --- a/crates/proof-of-sql/src/sql/proof/mod.rs +++ b/crates/proof-of-sql/src/sql/proof/mod.rs @@ -68,3 +68,5 @@ pub(crate) use first_round_builder::FirstRoundBuilder; mod provable_query_result_test; mod make_sumcheck_state; + +mod sumcheck_term_optimizer; diff --git a/crates/proof-of-sql/src/sql/proof/sumcheck_term_optimizer.rs b/crates/proof-of-sql/src/sql/proof/sumcheck_term_optimizer.rs new file mode 100644 index 000000000..2d5bf1639 --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof/sumcheck_term_optimizer.rs @@ -0,0 +1,143 @@ +use super::SumcheckSubpolynomialType; +use crate::base::{map::IndexMap, polynomial::MultilinearExtension, scalar::Scalar}; +use alloc::{boxed::Box, vec, vec::Vec}; +use core::{ + iter::{Chain, Copied, Flatten, Map}, + slice, +}; + +type SumcheckTerm<'a, S> = Vec + 'a>>; + +pub struct SumcheckTermOptimizer<'a, S: Scalar> { + merged_terms: Vec<(SumcheckSubpolynomialType, S, Vec>)>, + old_grouped_terms: Vec)>>, +} +pub struct OptimizedSumcheckTerms<'a, S: Scalar> { + old_grouped_terms: &'a Vec)>>, + new_mle_terms: Vec<(SumcheckSubpolynomialType, S, SumcheckTerm<'a, S>)>, +} + +fn merge_subquadratic_terms<'a, S: Scalar + 'a>( + maybe_constant_terms: Option)>>, + maybe_linear_terms: Option)>>, + merged_terms: &mut Vec<(SumcheckSubpolynomialType, S, Vec>)>, + term_length: usize, + ty: SumcheckSubpolynomialType, +) -> Option)>> { + let maybe_constant_sum = + maybe_constant_terms.map(|terms| terms.into_iter().map(|(_, coeff, _)| coeff).sum()); + + match (maybe_constant_sum, maybe_linear_terms) { + (Some(constant_sum), None) => { + merged_terms.push((ty, constant_sum, vec![])); + None + } + (maybe_constant_sum, Some(linear_terms)) + if maybe_constant_sum.is_some() || linear_terms.len() >= 2 => + { + let mut combined_term = vec![maybe_constant_sum.unwrap_or(S::ZERO); term_length]; + for (_, coeff, linear_term) in linear_terms { + linear_term[0].mul_add(&mut combined_term, &coeff); + } + merged_terms.push((ty, S::ONE, vec![combined_term])); + None + } + (_, maybe_linear_terms) => maybe_linear_terms, + } +} + +impl<'a, S: Scalar + 'a> SumcheckTermOptimizer<'a, S> { + pub fn new( + all_terms: impl Iterator)>, + term_length: usize, + ) -> Self { + let mut groups = all_terms.fold( + IndexMap::<_, Vec<_>>::default(), + |mut lookup, (ty, coeff, multiplicands)| { + lookup + .entry((ty, multiplicands.len().min(2))) + .or_default() + .push((ty, coeff, multiplicands)); + lookup + }, + ); + let mut merged_terms = Vec::with_capacity(2); + let old_grouped_terms = [ + SumcheckSubpolynomialType::ZeroSum, + SumcheckSubpolynomialType::Identity, + ] + .into_iter() + .flat_map(|ty| { + let maybe_constant_terms = groups.swap_remove(&(ty, 0)); + let maybe_linear_terms = groups.swap_remove(&(ty, 1)); + let maybe_superlinear_terms = groups.swap_remove(&(ty, 2)); + + let maybe_combined_terms = merge_subquadratic_terms( + maybe_constant_terms, + maybe_linear_terms, + &mut merged_terms, + term_length, + ty, + ); + + [maybe_combined_terms, maybe_superlinear_terms] + .into_iter() + .flatten() + }) + .collect(); + + Self { + merged_terms, + old_grouped_terms, + } + } +} + +impl<'a, S: Scalar + 'a> SumcheckTermOptimizer<'a, S> { + pub fn terms(&'a self) -> OptimizedSumcheckTerms<'a, S> { + OptimizedSumcheckTerms { + old_grouped_terms: &self.old_grouped_terms, + new_mle_terms: self + .merged_terms + .iter() + .map(|(ty, coeff, terms)| { + ( + *ty, + *coeff, + terms + .iter() + .map(|mle| -> Box> { Box::new(mle) }) + .collect::>(), + ) + }) + .collect(), + } + } +} + +impl<'a, S: Scalar + 'a> IntoIterator for &'a OptimizedSumcheckTerms<'a, S> { + type Item = (SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>); + + // Currently, `impl Trait` in associated types is unstable. We can change this to the following when it stabilizes: + // type IntoIter = impl Iterator)>; + type IntoIter = Chain< + Copied< + Flatten)>>>, + >, + Map< + slice::Iter<'a, (SumcheckSubpolynomialType, S, SumcheckTerm<'a, S>)>, + fn( + &'a (SumcheckSubpolynomialType, S, SumcheckTerm<'a, S>), + ) -> (SumcheckSubpolynomialType, S, &'a SumcheckTerm<'a, S>), + >, + >; + + fn into_iter(self) -> Self::IntoIter { + let result = self.old_grouped_terms.iter().flatten().copied().chain( + self.new_mle_terms + .iter() + .map((|(ty, coeff, terms)| (*ty, *coeff, terms)) as fn(&'a _) -> _), + ); + result + } +} From 4f34d7cbc83c1e40f4408bef04d313a55718066c Mon Sep 17 00:00:00 2001 From: Jay White Date: Tue, 12 Nov 2024 12:45:16 -0500 Subject: [PATCH 9/9] perf: avoid cloning sumcheck terms --- .../base/polynomial/multilinear_extension.rs | 17 ++++++++--------- .../src/sql/proof/make_sumcheck_state.rs | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs b/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs index 99720707c..ef9a52e3c 100644 --- a/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs +++ b/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs @@ -1,5 +1,5 @@ use crate::base::{database::Column, if_rayon, scalar::Scalar, slice_ops}; -use alloc::{rc::Rc, vec::Vec}; +use alloc::vec::Vec; use core::ffi::c_void; use num_traits::Zero; #[cfg(feature = "rayon")] @@ -15,7 +15,7 @@ pub trait MultilinearExtension { fn mul_add(&self, res: &mut [S], multiplier: &S); /// convert the MLE to a form that can be used in sumcheck - fn to_sumcheck_term(&self, num_vars: usize) -> Rc>; + fn to_sumcheck_term(&self, num_vars: usize) -> Vec; /// pointer to identify the slice forming the MLE fn id(&self) -> *const c_void; @@ -42,18 +42,17 @@ where slice_ops::mul_add_assign(res, *multiplier, &slice_ops::slice_cast(self)); } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { let values = self; let n = 1 << num_vars; assert!(n >= values.len()); - let scalars = if_rayon!(values.par_iter(), values.iter()) + if_rayon!(values.par_iter(), values.iter()) .map(Into::into) .chain(if_rayon!( rayon::iter::repeatn(Zero::zero(), n - values.len()), itertools::repeat_n(Zero::zero(), n - values.len()) )) - .collect(); - Rc::new(scalars) + .collect() } fn id(&self) -> *const c_void { @@ -72,7 +71,7 @@ macro_rules! slice_like_mle_impl { (&self[..]).mul_add(res, multiplier) } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { (&self[..]).to_sumcheck_term(num_vars) } @@ -125,7 +124,7 @@ impl MultilinearExtension for &Column<'_, S> { } } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { match self { Column::Boolean(c) => c.to_sumcheck_term(num_vars), Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => { @@ -163,7 +162,7 @@ impl MultilinearExtension for Column<'_, S> { (&self).mul_add(res, multiplier); } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { (&self).to_sumcheck_term(num_vars) } diff --git a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs index 895a8e737..6b04d3637 100644 --- a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs +++ b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs @@ -42,11 +42,11 @@ impl<'a, S: Scalar> FlattenedMLEBuilder<'a, S> { fn flattened_ml_extensions(self) -> Vec> { self.entrywise_multipliers .into_iter() - .map(|mle| (&mle).to_sumcheck_term(self.num_vars).as_ref().clone()) + .map(|mle| (&mle).to_sumcheck_term(self.num_vars)) .chain( self.all_ml_extensions .iter() - .map(|mle| mle.to_sumcheck_term(self.num_vars).as_ref().clone()), + .map(|mle| mle.to_sumcheck_term(self.num_vars)), ) .collect() }