Skip to content

Commit

Permalink
perf: avoid cloning sumcheck terms
Browse files Browse the repository at this point in the history
  • Loading branch information
JayWhite2357 committed Dec 8, 2024
1 parent 0b0927f commit 4f34d7c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
17 changes: 8 additions & 9 deletions crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::base::{database::Column, if_rayon, scalar::Scalar, slice_ops};
use alloc::{rc::Rc, vec::Vec};
use alloc::vec::Vec;
use core::ffi::c_void;
use num_traits::Zero;
#[cfg(feature = "rayon")]
Expand All @@ -15,7 +15,7 @@ pub trait MultilinearExtension<S: Scalar> {
fn mul_add(&self, res: &mut [S], multiplier: &S);

/// convert the MLE to a form that can be used in sumcheck
fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>>;
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S>;

/// pointer to identify the slice forming the MLE
fn id(&self) -> *const c_void;
Expand All @@ -42,18 +42,17 @@ where
slice_ops::mul_add_assign(res, *multiplier, &slice_ops::slice_cast(self));
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
let values = self;
let n = 1 << num_vars;
assert!(n >= values.len());
let scalars = if_rayon!(values.par_iter(), values.iter())
if_rayon!(values.par_iter(), values.iter())
.map(Into::into)
.chain(if_rayon!(
rayon::iter::repeatn(Zero::zero(), n - values.len()),
itertools::repeat_n(Zero::zero(), n - values.len())
))
.collect();
Rc::new(scalars)
.collect()
}

fn id(&self) -> *const c_void {
Expand All @@ -72,7 +71,7 @@ macro_rules! slice_like_mle_impl {
(&self[..]).mul_add(res, multiplier)
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
(&self[..]).to_sumcheck_term(num_vars)
}

Expand Down Expand Up @@ -125,7 +124,7 @@ impl<S: Scalar> MultilinearExtension<S> for &Column<'_, S> {
}
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
match self {
Column::Boolean(c) => c.to_sumcheck_term(num_vars),
Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => {
Expand Down Expand Up @@ -163,7 +162,7 @@ impl<S: Scalar> MultilinearExtension<S> for Column<'_, S> {
(&self).mul_add(res, multiplier);
}

fn to_sumcheck_term(&self, num_vars: usize) -> Rc<Vec<S>> {
fn to_sumcheck_term(&self, num_vars: usize) -> Vec<S> {
(&self).to_sumcheck_term(num_vars)
}

Expand Down
4 changes: 2 additions & 2 deletions crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ impl<'a, S: Scalar> FlattenedMLEBuilder<'a, S> {
fn flattened_ml_extensions(self) -> Vec<Vec<S>> {
self.entrywise_multipliers
.into_iter()
.map(|mle| (&mle).to_sumcheck_term(self.num_vars).as_ref().clone())
.map(|mle| (&mle).to_sumcheck_term(self.num_vars))
.chain(
self.all_ml_extensions
.iter()
.map(|mle| mle.to_sumcheck_term(self.num_vars).as_ref().clone()),
.map(|mle| mle.to_sumcheck_term(self.num_vars)),
)
.collect()
}
Expand Down

0 comments on commit 4f34d7c

Please sign in to comment.