Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/perf reduce memory usage #368

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
17 changes: 8 additions & 9 deletions crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::base::{database::Column, if_rayon, scalar::Scalar, slice_ops};
use alloc::{rc::Rc, vec::Vec};
use alloc::vec::Vec;
use core::ffi::c_void;
use num_traits::Zero;
#[cfg(feature = "rayon")]
Expand All @@ -15,7 +15,7 @@ pub trait MultilinearExtension<S: Scalar> {
fn mul_add(&self, res: &mut [S], multiplier: &S);

/// convert the MLE to a form that can be used in sumcheck
fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>>;
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S>;

/// pointer to identify the slice forming the MLE
fn id(&self) -> *const c_void;
Expand All @@ -42,18 +42,17 @@ where
slice_ops::mul_add_assign(res, *multiplier, &slice_ops::slice_cast(self));
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
let values = self;
let n = 1 << num_vars;
assert!(n >= values.len());
let scalars = if_rayon!(values.par_iter(), values.iter())
if_rayon!(values.par_iter(), values.iter())
.map(Into::into)
.chain(if_rayon!(
rayon::iter::repeatn(Zero::zero(), n - values.len()),
itertools::repeat_n(Zero::zero(), n - values.len())
))
.collect();
Rc::new(scalars)
.collect()
}

fn id(&self) -> *const c_void {
Expand All @@ -72,7 +71,7 @@ macro_rules! slice_like_mle_impl {
(&self[..]).mul_add(res, multiplier)
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
(&self[..]).to_sumcheck_term(num_vars)
}

Expand Down Expand Up @@ -125,7 +124,7 @@ impl<S: Scalar> MultilinearExtension<S> for &Column<'_, S> {
}
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
match self {
Column::Boolean(c) => c.to_sumcheck_term(num_vars),
Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => {
Expand Down Expand Up @@ -163,7 +162,7 @@ impl<S: Scalar> MultilinearExtension<S> for Column<'_, S> {
(&self).mul_add(res, multiplier);
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
(&self).to_sumcheck_term(num_vars)
}

Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod proof_test;
pub use proof::SumcheckProof;

mod prover_state;
use prover_state::ProverState;
pub(crate) use prover_state::ProverState;

mod prover_round;
use prover_round::prove_round;
Expand Down
20 changes: 7 additions & 13 deletions crates/proof-of-sql/src/proof_primitive/sumcheck/proof.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use crate::{
base::{
polynomial::{
interpolate_evaluations_to_reverse_coefficients, CompositePolynomial,
CompositePolynomialInfo,
},
polynomial::{interpolate_evaluations_to_reverse_coefficients, CompositePolynomialInfo},
proof::{ProofError, Transcript},
scalar::Scalar,
},
Expand Down Expand Up @@ -31,19 +28,16 @@ impl<S: Scalar> SumcheckProof<S> {
pub fn create(
transcript: &mut impl Transcript,
evaluation_point: &mut [S],
polynomial: &CompositePolynomial<S>,
mut state: ProverState<S>,
) -> Self {
assert_eq!(evaluation_point.len(), polynomial.num_variables);
transcript.extend_as_be([
polynomial.max_multiplicands as u64,
polynomial.num_variables as u64,
]);
let num_vars = state.num_vars;
assert_eq!(evaluation_point.len(), num_vars);
transcript.extend_as_be([state.max_multiplicands as u64, num_vars as u64]);
// This challenge is in order to keep transcript messages grouped. (This simplifies the Solidity implementation.)
transcript.scalar_challenge_as_be::<S>();
let mut r = None;
let mut state = ProverState::create(polynomial);
let mut coefficients = Vec::with_capacity(polynomial.num_variables);
for scalar in evaluation_point.iter_mut().take(polynomial.num_variables) {
let mut coefficients = Vec::with_capacity(num_vars);
for scalar in evaluation_point.iter_mut().take(num_vars) {
let round_evaluations = prove_round(&mut state, &r);
let round_coefficients =
interpolate_evaluations_to_reverse_coefficients(&round_evaluations);
Expand Down
27 changes: 19 additions & 8 deletions crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
use super::test_cases::sumcheck_test_cases;
use crate::base::{
polynomial::{CompositePolynomial, CompositePolynomialInfo},
proof::Transcript as _,
scalar::{test_scalar::TestScalar, Curve25519Scalar, MontScalar, Scalar},
};
/*
* Adopted from arkworks
*
* See third_party/license/arkworks.LICENSE
*/
use crate::proof_primitive::sumcheck::proof::*;
use crate::{
base::{
polynomial::{CompositePolynomial, CompositePolynomialInfo},
proof::Transcript as _,
scalar::{test_scalar::TestScalar, Curve25519Scalar, MontScalar, Scalar},
},
proof_primitive::sumcheck::prover_state::ProverState,
};
use alloc::rc::Rc;
use ark_std::UniformRand;
use merlin::Transcript;
Expand All @@ -29,7 +32,11 @@ fn test_create_verify_proof() {
let fa = Rc::new(a_vec.to_vec());
poly.add_product([fa], Curve25519Scalar::from(1u64));
let mut transcript = Transcript::new(b"sumchecktest");
let mut proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly);
let mut proof = SumcheckProof::create(
&mut transcript,
&mut evaluation_point,
ProverState::create(&poly),
);

// verify proof
let mut transcript = Transcript::new(b"sumchecktest");
Expand Down Expand Up @@ -130,7 +137,11 @@ fn test_polynomial(nv: usize, num_multiplicands_range: (usize, usize), num_produ
// create a proof
let mut transcript = Transcript::new(b"sumchecktest");
let mut evaluation_point = vec![Curve25519Scalar::zero(); poly_info.num_variables];
let proof = SumcheckProof::create(&mut transcript, &mut evaluation_point, &poly);
let proof = SumcheckProof::create(
&mut transcript,
&mut evaluation_point,
ProverState::create(&poly),
);

// verify proof
let mut transcript = Transcript::new(b"sumchecktest");
Expand Down Expand Up @@ -172,7 +183,7 @@ fn we_can_verify_many_random_test_cases() {
let proof = SumcheckProof::create(
&mut transcript,
&mut evaluation_point,
&test_case.polynomial,
ProverState::create(&test_case.polynomial),
);

let mut transcript = Transcript::new(b"sumchecktest");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ pub fn prove_round<S: Scalar>(prover_state: &mut ProverState<S>, r_maybe: &Optio
"first round should be prover first."
);

prover_state.randomness.push(*r);

// fix argument
let r_as_field = prover_state.randomness[prover_state.round - 1];
let r_as_field = *r;
if_rayon!(
prover_state.flattened_ml_extensions.par_iter_mut(),
prover_state.flattened_ml_extensions.iter_mut()
Expand Down
21 changes: 18 additions & 3 deletions crates/proof-of-sql/src/proof_primitive/sumcheck/prover_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ use crate::base::scalar::Scalar;
use alloc::vec::Vec;

pub struct ProverState<S: Scalar> {
/// sampled randomness given by the verifier
pub randomness: Vec<S>,
/// Stores the list of products that is meant to be added together. Each multiplicand is represented by
/// the index in `flattened_ml_extensions`
pub list_of_products: Vec<(S, Vec<usize>)>,
Expand All @@ -36,12 +34,29 @@ impl<S: Scalar> ProverState<S> {
.collect();

ProverState {
randomness: Vec::with_capacity(polynomial.num_variables),
list_of_products: polynomial.products.clone(),
flattened_ml_extensions,
num_vars: polynomial.num_variables,
max_multiplicands: polynomial.max_multiplicands,
round: 0,
}
}
pub fn new(
list_of_products: Vec<(S, Vec<usize>)>,
flattened_ml_extensions: Vec<Vec<S>>,
num_vars: usize,
) -> Self {
let max_multiplicands = list_of_products
.iter()
.map(|(_, product)| product.len())
.max()
.unwrap_or(0);
ProverState {
list_of_products,
flattened_ml_extensions,
num_vars,
max_multiplicands,
round: 0,
}
}
}
124 changes: 0 additions & 124 deletions crates/proof-of-sql/src/sql/proof/composite_polynomial_builder.rs

This file was deleted.

Loading
Loading