From de591944e91e6ad92e9a08b498eb0ec1d363a10c Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 6 Nov 2024 17:51:48 -0800 Subject: [PATCH] Lagrage evaluation improvements rustc didn't optimize the `eval` function and ended up doing 131072*32*992 loop iteration per Lagrange compute. For some unknown reason to me, optimizer works only if we operate on integers, not on fields. Maybe modulo reduction is to blame --- .../ipa_prf/malicious_security/lagrange.rs | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index c31649a41..e578f0a51 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, fmt::Debug}; +use std::fmt::Debug; use typenum::Unsigned; @@ -79,8 +79,7 @@ pub struct LagrangeTable { impl LagrangeTable where - F: Field + TryFrom, - >::Error: Debug, + F: PrimeField, { /// generates a `CanonicalLagrangeTable` from `CanoncialLagrangeDenominators` for a single output point /// The "x coordinate" of the output point is `x_output`. @@ -95,25 +94,16 @@ where impl LagrangeTable where - F: Field, + F: PrimeField, { /// This function uses the `LagrangeTable` to evaluate `polynomial` on the _output_ "x coordinates" /// that were used to generate this table. /// It is assumed that the `y_coordinates` provided to this function correspond the values of the _input_ "x coordinates" /// that were used to generate this table. - pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M] - { + pub fn eval(&self, y_coordinates: &[F; N]) -> [F; M] { self.table - .iter() - .map(|table_row| { - table_row - .iter() - .zip(y_coordinates) - .fold(F::ZERO, |acc, (&base, y)| acc + base * (*y.borrow())) - }) - .collect::>() - .try_into() - .unwrap() + .each_ref() + .map(|row| dot_product(row, y_coordinates)) } /// helper function to compute a single row of `LagrangeTable` @@ -170,6 +160,29 @@ where } } +/// Computes the dot product of two arrays of the same size. +/// It is isolated from Lagrange because there could be potential SIMD optimizations used +fn dot_product(a: &[F; N], b: &[F; N]) -> F { + // Staying in integers allows rustc to optimize this code properly + // with any reasonable N, we won't run into overflow with dot product. + // (N can be as large as 2^32 and still no chance of overflow for 61 bit prime fields) + debug_assert!( + F::PRIME.into() < (1 << 64), + "The prime {} is too large for this dot product implementation", + F::PRIME.into() + ); + + let mut sum = 0; + + // I am cautious about using zip in hot code + // https://github.com/rust-lang/rust/issues/103555 + for i in 0..N { + sum += a[i].as_u128() * b[i].as_u128(); + } + + F::truncate_from(sum) +} + #[cfg(all(test, unit_test))] mod test { use std::{borrow::Borrow, fmt::Debug};