Skip to content

Commit

Permalink
[sumcheck] Use Karatsuba interpolation (IrreducibleOSS#20)
Browse files Browse the repository at this point in the history
Support for the Karatsuba infinity point when interpolating, where we take a binary subspace evaluation domain and replace the third point (after zero and one) with an infinity.
  • Loading branch information
onesk authored Feb 24, 2025
1 parent e9991ce commit 4f56aca
Show file tree
Hide file tree
Showing 17 changed files with 332 additions and 112 deletions.
51 changes: 36 additions & 15 deletions crates/core/src/protocols/gkr_gpa/gpa_sumcheck/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

use std::ops::Range;

use binius_field::{util::eq, Field, PackedExtension, PackedField, PackedFieldIndexable};
use binius_field::{
util::eq, ExtensionField, Field, PackedExtension, PackedField, PackedFieldIndexable, TowerField,
};
use binius_hal::{ComputationBackend, SumcheckEvaluator};
use binius_math::{CompositionPoly, EvaluationDomainFactory, InterpolationDomain, MultilinearPoly};
use binius_maybe_rayon::prelude::*;
Expand All @@ -13,10 +15,10 @@ use tracing::{debug_span, instrument};

use super::error::Error;
use crate::{
polynomial::Error as PolynomialError,
polynomial::{ArithCircuitPoly, Error as PolynomialError},
protocols::{
sumcheck::{
immediate_switchover_heuristic,
get_nontrivial_evaluation_points, immediate_switchover_heuristic,
prove::{common, prover_state::ProverState, SumcheckInterpolator, SumcheckProver},
CompositeSumClaim, Error as SumcheckError, RoundCoeffs,
},
Expand Down Expand Up @@ -44,7 +46,7 @@ where

impl<'a, F, FDomain, P, Composition, M, Backend> GPAProver<'a, FDomain, P, Composition, M, Backend>
where
F: Field,
F: TowerField + ExtensionField<FDomain>,
FDomain: Field,
P: PackedFieldIndexable<Scalar = F> + PackedExtension<FDomain>,
Composition: CompositionPoly<P>,
Expand Down Expand Up @@ -85,7 +87,8 @@ where
.par_iter()
.map(|composite_claim| {
let degree = composite_claim.composition.degree();
let domain = evaluation_domain_factory.create(degree + 1)?;
let domain =
evaluation_domain_factory.create_with_infinity(degree + 1, degree >= 2)?;
Ok(domain.into())
})
.collect::<Result<Vec<InterpolationDomain<FDomain>>, _>>()
Expand All @@ -96,15 +99,12 @@ where
.map(|claim| claim.composition)
.collect();

let evaluation_points = domains
.iter()
.max_by_key(|domain| domain.size())
.map_or_else(|| Vec::new(), |domain| domain.finite_points().to_vec());
let nontrivial_evaluation_points = get_nontrivial_evaluation_points(&domains)?;

let state = ProverState::new(
multilinears,
claimed_sums,
evaluation_points,
nontrivial_evaluation_points,
// We use GPA protocol only for big fields, which is why switchover is trivial
immediate_switchover_heuristic,
backend,
Expand Down Expand Up @@ -187,7 +187,7 @@ where
impl<F, FDomain, P, Composition, M, Backend> SumcheckProver<F>
for GPAProver<'_, FDomain, P, Composition, M, Backend>
where
F: Field,
F: TowerField + ExtensionField<FDomain>,
FDomain: Field,
P: PackedFieldIndexable<Scalar = F>
+ PackedExtension<F, PackedSubfield = P>
Expand All @@ -213,8 +213,12 @@ where
.map(|first_round_eval_1s| first_round_eval_1s[index])
.filter(|_| round == 0);

let composition_at_infinity =
ArithCircuitPoly::new(composition.expression().leading_term());

GPAEvaluator {
composition,
composition_at_infinity,
interpolation_domain,
first_round_eval_1,
partial_eq_ind_evals: &self.partial_eq_ind_evals,
Expand Down Expand Up @@ -275,6 +279,7 @@ where
FDomain: Field,
{
composition: &'a Composition,
composition_at_infinity: ArithCircuitPoly<P::Scalar>,
interpolation_domain: &'a InterpolationDomain<FDomain>,
partial_eq_ind_evals: &'a [P],
first_round_eval_1: Option<P::Scalar>,
Expand All @@ -284,7 +289,7 @@ where
impl<F, P, FDomain, Composition> SumcheckEvaluator<P, Composition>
for GPAEvaluator<'_, P, FDomain, Composition>
where
F: Field,
F: TowerField + ExtensionField<FDomain>,
P: PackedField<Scalar = F> + PackedExtension<F, PackedSubfield = P> + PackedExtension<FDomain>,
FDomain: Field,
Composition: CompositionPoly<P>,
Expand All @@ -308,14 +313,21 @@ where
&self,
subcube_vars: usize,
subcube_index: usize,
is_infinity_point: bool,
batch_query: &[&[P]],
) -> P {
let row_len = batch_query.first().map_or(0, |row| row.len());

stackalloc_with_default(row_len, |evals| {
self.composition
.batch_evaluate(batch_query, evals)
.expect("correct by query construction invariant");
if is_infinity_point {
self.composition_at_infinity
.batch_evaluate(batch_query, evals)
.expect("correct by query construction invariant");
} else {
self.composition
.batch_evaluate(batch_query, evals)
.expect("correct by query construction invariant");
};

let subcube_start = subcube_index << subcube_vars.saturating_sub(P::LOG_WIDTH);
for (i, eval) in evals.iter_mut().enumerate() {
Expand Down Expand Up @@ -366,6 +378,15 @@ where

round_evals.insert(0, zero_evaluation);

if round_evals.len() > 3 {
// SumcheckRoundCalculator orders interpolation points as 0, 1, "infinity", then subspace points.
// InterpolationDomain expects "infinity" at the last position, thus reordering is needed.
// Putting "special" evaluation points at the beginning of domain allows benefitting from
// faster/skipped interpolation even in case of mixed degree compositions .
let infinity_round_eval = round_evals.remove(2);
round_evals.push(infinity_round_eval);
}

let coeffs = self.interpolation_domain.interpolate(&round_evals)?;
Ok(coeffs)
}
Expand Down
5 changes: 2 additions & 3 deletions crates/core/src/protocols/gkr_gpa/gpa_sumcheck/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use binius_field::{
arch::OptimalUnderlier512b,
as_packed_field::{PackScalar, PackedType},
underlier::UnderlierType,
BinaryField, BinaryField128b, BinaryField8b, ExtensionField, PackedField, PackedFieldIndexable,
TowerField,
BinaryField128b, BinaryField8b, ExtensionField, PackedField, PackedFieldIndexable, TowerField,
};
use binius_hal::make_portable_backend;
use binius_math::{
Expand All @@ -34,7 +33,7 @@ fn test_prove_verify_bivariate_product_helper<U, F, FDomain>(
) where
U: UnderlierType + PackScalar<F> + PackScalar<FDomain>,
F: TowerField + ExtensionField<FDomain>,
FDomain: BinaryField,
FDomain: TowerField,
PackedType<U, F>: PackedFieldIndexable,
{
let mut rng = StdRng::seed_from_u64(0);
Expand Down
6 changes: 3 additions & 3 deletions crates/core/src/protocols/gkr_gpa/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ fn process_finished_provers<F, P, Backend>(
reverse_sorted_final_layer_claims: &mut Vec<LayerClaim<F>>,
) -> Result<(), Error>
where
F: Field,
F: TowerField,
P: PackedFieldIndexable<Scalar = F> + PackedExtension<F, PackedSubfield = P>,
Backend: ComputationBackend,
{
Expand Down Expand Up @@ -176,7 +176,7 @@ where

impl<'a, F, P, Backend> GrandProductProverState<'a, F, P, Backend>
where
F: Field + From<P::Scalar>,
F: TowerField + From<P::Scalar>,
P: PackedFieldIndexable<Scalar = F> + PackedExtension<F, PackedSubfield = P>,
Backend: ComputationBackend,
{
Expand Down Expand Up @@ -314,7 +314,7 @@ where
if self.current_layer_no() >= self.input_vars() {
bail!(Error::TooManyRounds);
}
let new_eval = extrapolate_line_scalar(zero_eval, one_eval, gpa_challenge);
let new_eval = extrapolate_line_scalar::<F, F>(zero_eval, one_eval, gpa_challenge);
let mut layer_challenge = sumcheck_challenge;
layer_challenge.push(gpa_challenge);

Expand Down
35 changes: 34 additions & 1 deletion crates/core/src/protocols/sumcheck/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use binius_field::{
util::{inner_product_unchecked, powers},
ExtensionField, Field, PackedField,
};
use binius_math::{CompositionPoly, MultilinearPoly};
use binius_math::{CompositionPoly, InterpolationDomain, MultilinearPoly};
use binius_utils::bail;
use getset::{CopyGetters, Getters};
use tracing::instrument;
Expand Down Expand Up @@ -303,6 +303,39 @@ pub fn batch_weighted_value<F: Field>(batch_coeff: F, values: impl Iterator<Item
batch_coeff * inner_product_unchecked(powers(batch_coeff), values)
}

/// Validate the sumcheck evaluation domains to conform to the shape expected by the
/// `SumcheckRoundCalculator`:
/// 1) First three points are zero, one, and Karatsuba infinity (for degrees above 1)
/// 2) All finite evaluation point slices are proper prefixes of the largest evaluation domain
pub fn get_nontrivial_evaluation_points<F: Field>(
domains: &[InterpolationDomain<F>],
) -> Result<Vec<F>, Error> {
let Some(largest_domain) = domains.iter().max_by_key(|domain| domain.size()) else {
return Ok(Vec::new());
};

#[allow(clippy::get_first)]
if !domains.iter().all(|domain| {
(domain.size() <= 2 || domain.with_infinity())
&& domain.finite_points().get(0).unwrap_or(&F::ZERO) == &F::ZERO
&& domain.finite_points().get(1).unwrap_or(&F::ONE) == &F::ONE
}) {
bail!(Error::IncorrectSumcheckEvaluationDomain);
}

let finite_points = largest_domain.finite_points();

if domains
.iter()
.any(|domain| !finite_points.starts_with(domain.finite_points()))
{
bail!(Error::NonProperPrefixEvaluationDomain);
}

let nontrivial_evaluation_points = finite_points[2.min(finite_points.len())..].to_vec();
Ok(nontrivial_evaluation_points)
}

#[cfg(test)]
mod tests {
use binius_field::BinaryField64b;
Expand Down
4 changes: 4 additions & 0 deletions crates/core/src/protocols/sumcheck/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ pub enum Error {
oracle: String,
hypercube_index: usize,
},
#[error("evaluation domain should start with zero and one, and contain Karatsuba infinity for degrees above 1")]
IncorrectSumcheckEvaluationDomain,
#[error("evaluation domains are not proper prefixes of each other")]
NonProperPrefixEvaluationDomain,
#[error("constraint set contains multilinears of different heights")]
ConstraintSetNumberOfVariablesMismatch,
#[error("batching sumchecks and zerochecks is not supported yet")]
Expand Down
12 changes: 6 additions & 6 deletions crates/core/src/protocols/sumcheck/prove/prover_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ where
#[getset(get_copy = "pub")]
n_vars: usize,
multilinears: Vec<SumcheckMultilinear<P, M>>,
evaluation_points: Vec<FDomain>,
nontrivial_evaluation_points: Vec<FDomain>,
tensor_query: Option<MultilinearQuery<P>>,
last_coeffs_or_sums: ProverStateCoeffsOrSums<P::Scalar>,
backend: &'a Backend,
Expand All @@ -78,7 +78,7 @@ where
pub fn new(
multilinears: Vec<M>,
claimed_sums: Vec<F>,
evaluation_points: Vec<FDomain>,
nontrivial_evaluation_points: Vec<FDomain>,
switchover_fn: impl Fn(usize) -> usize,
backend: &'a Backend,
) -> Result<Self, Error> {
Expand All @@ -87,7 +87,7 @@ where
multilinears,
&switchover_rounds,
claimed_sums,
evaluation_points,
nontrivial_evaluation_points,
backend,
)
}
Expand All @@ -101,7 +101,7 @@ where
multilinears: Vec<M>,
switchover_rounds: &[usize],
claimed_sums: Vec<F>,
evaluation_points: Vec<FDomain>,
nontrivial_evaluation_points: Vec<FDomain>,
backend: &'a Backend,
) -> Result<Self, Error> {
let n_vars = equal_n_vars_check(&multilinears)?;
Expand All @@ -123,7 +123,7 @@ where
Ok(Self {
n_vars,
multilinears,
evaluation_points,
nontrivial_evaluation_points,
tensor_query: Some(tensor_query),
last_coeffs_or_sums: ProverStateCoeffsOrSums::Sums(claimed_sums),
backend,
Expand Down Expand Up @@ -260,7 +260,7 @@ where
self.tensor_query.as_ref().map(Into::into),
&self.multilinears,
evaluators,
&self.evaluation_points,
&self.nontrivial_evaluation_points,
)?)
}

Expand Down
Loading

0 comments on commit 4f56aca

Please sign in to comment.