From 4f34d7cbc83c1e40f4408bef04d313a55718066c Mon Sep 17 00:00:00 2001 From: Jay White Date: Tue, 12 Nov 2024 12:45:16 -0500 Subject: [PATCH] perf: avoid cloning sumcheck terms --- .../base/polynomial/multilinear_extension.rs | 17 ++++++++--------- .../src/sql/proof/make_sumcheck_state.rs | 4 ++-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs b/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs index 99720707c..ef9a52e3c 100644 --- a/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs +++ b/crates/proof-of-sql/src/base/polynomial/multilinear_extension.rs @@ -1,5 +1,5 @@ use crate::base::{database::Column, if_rayon, scalar::Scalar, slice_ops}; -use alloc::{rc::Rc, vec::Vec}; +use alloc::vec::Vec; use core::ffi::c_void; use num_traits::Zero; #[cfg(feature = "rayon")] @@ -15,7 +15,7 @@ pub trait MultilinearExtension { fn mul_add(&self, res: &mut [S], multiplier: &S); /// convert the MLE to a form that can be used in sumcheck - fn to_sumcheck_term(&self, num_vars: usize) -> Rc>; + fn to_sumcheck_term(&self, num_vars: usize) -> Vec; /// pointer to identify the slice forming the MLE fn id(&self) -> *const c_void; @@ -42,18 +42,17 @@ where slice_ops::mul_add_assign(res, *multiplier, &slice_ops::slice_cast(self)); } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { let values = self; let n = 1 << num_vars; assert!(n >= values.len()); - let scalars = if_rayon!(values.par_iter(), values.iter()) + if_rayon!(values.par_iter(), values.iter()) .map(Into::into) .chain(if_rayon!( rayon::iter::repeatn(Zero::zero(), n - values.len()), itertools::repeat_n(Zero::zero(), n - values.len()) )) - .collect(); - Rc::new(scalars) + .collect() } fn id(&self) -> *const c_void { @@ -72,7 +71,7 @@ macro_rules! slice_like_mle_impl { (&self[..]).mul_add(res, multiplier) } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { (&self[..]).to_sumcheck_term(num_vars) } @@ -125,7 +124,7 @@ impl MultilinearExtension for &Column<'_, S> { } } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { match self { Column::Boolean(c) => c.to_sumcheck_term(num_vars), Column::Scalar(c) | Column::VarChar((_, c)) | Column::Decimal75(_, _, c) => { @@ -163,7 +162,7 @@ impl MultilinearExtension for Column<'_, S> { (&self).mul_add(res, multiplier); } - fn to_sumcheck_term(&self, num_vars: usize) -> Rc> { + fn to_sumcheck_term(&self, num_vars: usize) -> Vec { (&self).to_sumcheck_term(num_vars) } diff --git a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs index 895a8e737..6b04d3637 100644 --- a/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs +++ b/crates/proof-of-sql/src/sql/proof/make_sumcheck_state.rs @@ -42,11 +42,11 @@ impl<'a, S: Scalar> FlattenedMLEBuilder<'a, S> { fn flattened_ml_extensions(self) -> Vec> { self.entrywise_multipliers .into_iter() - .map(|mle| (&mle).to_sumcheck_term(self.num_vars).as_ref().clone()) + .map(|mle| (&mle).to_sumcheck_term(self.num_vars)) .chain( self.all_ml_extensions .iter() - .map(|mle| mle.to_sumcheck_term(self.num_vars).as_ref().clone()), + .map(|mle| mle.to_sumcheck_term(self.num_vars)), ) .collect() }