Skip to content

Commit

Permalink
Lagrage evaluation improvements
Browse files Browse the repository at this point in the history
rustc didn't optimize the `eval` function and ended up doing
131072*32*992 loop iteration per Lagrange compute.

For some unknown reason to me, optimizer works only if we operate
on integers, not on fields. Maybe modulo reduction is to blame
  • Loading branch information
akoshelev committed Nov 7, 2024
1 parent 59a6fb3 commit de59194
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{borrow::Borrow, fmt::Debug};
use std::fmt::Debug;

use typenum::Unsigned;

Expand Down Expand Up @@ -79,8 +79,7 @@ pub struct LagrangeTable<F: Field, const N: usize, const M: usize> {

impl<F, const N: usize> LagrangeTable<F, N, 1>
where
F: Field + TryFrom<u128>,
<F as TryFrom<u128>>::Error: Debug,
F: PrimeField,
{
/// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point
/// The "x coordinate" of the output point is `x_output`.
Expand All @@ -95,25 +94,16 @@ where

impl<F, const N: usize, const M: usize> LagrangeTable<F, N, M>
where
F: Field,
F: PrimeField,
{
/// This function uses the `LagrangeTable` to evaluate `polynomial` on the _output_ "x coordinates"
/// that were used to generate this table.
/// It is assumed that the `y_coordinates` provided to this function correspond the values of the _input_ "x coordinates"
/// that were used to generate this table.
pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M]
{
pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M] {
self.table
.iter()
.map(|table_row| {
table_row
.iter()
.zip(y_coordinates)
.fold(F::ZERO, |acc, (&base, y)| acc + base * (*y.borrow()))
})
.collect::<Vec<F>>()
.try_into()
.unwrap()
.each_ref()
.map(|row| dot_product(row, y_coordinates))
}

/// helper function to compute a single row of `LagrangeTable`
Expand Down Expand Up @@ -170,6 +160,29 @@ where
}
}

/// Computes the dot product of two arrays of the same size.
/// It is isolated from Lagrange because there could be potential SIMD optimizations used
fn dot_product<F: PrimeField, const N: usize>(a: &[F; N], b: &[F; N]) -> F {
// Staying in integers allows rustc to optimize this code properly
// with any reasonable N, we won't run into overflow with dot product.
// (N can be as large as 2^32 and still no chance of overflow for 61 bit prime fields)
debug_assert!(
F::PRIME.into() < (1 << 64),
"The prime {} is too large for this dot product implementation",
F::PRIME.into()
);

let mut sum = 0;

// I am cautious about using zip in hot code
// https://github.com/rust-lang/rust/issues/103555
for i in 0..N {
sum += a[i].as_u128() * b[i].as_u128();
}

F::truncate_from(sum)
}

#[cfg(all(test, unit_test))]
mod test {
use std::{borrow::Borrow, fmt::Debug};
Expand Down

0 comments on commit de59194

Please sign in to comment.