Skip to content

Commit

Permalink
Optimized modulo operator for Fp61BitPrime
Browse files Browse the repository at this point in the history
Zero-knowledge proofs use multiplication on this field extensively. For every bit of multiplication output, we perform [~10 multiplications](https://github.com/private-attribution/ipa/blob/cf25c69f7c641ed70de62c6536c0293e1e6b2db5/ipa-core/src/protocol/context/dzkp_field.rs#L122) in `Fp61BitPrime`. That's around $$86 \times 10^9$$ modulo reduction operations per 35 million MPC multiplications.

My preliminary testing shows a ~30% improvement in local runs; I will do some Draft runs to confirm.
  • Loading branch information
akoshelev committed Nov 4, 2024
1 parent 91aa71d commit 20e62c1
Showing 1 changed file with 100 additions and 24 deletions.
124 changes: 100 additions & 24 deletions ipa-core/src/ff/prime_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,44 @@ pub trait PrimeField: Field + U128Conversions {
#[error("Field value {0} provided is greater than prime: {1}")]
pub struct GreaterThanPrimeError<V: Display>(V, u128);

/// Generates the generic modulo reduction helpers for a prime field, built on
/// the `%` (remainder) operator.
///
/// The generated code is valid for any prime, but a remainder per operation can
/// be costly for large primes on hot paths. Fields over Mersenne primes supply
/// their own, faster implementation instead of invoking this macro.
///
/// * `$field` - the field type to add the helpers to.
/// * `$op_store` - the wide integer type used for intermediate arithmetic results.
macro_rules! rem_modulo_impl {
    ( $field:ident, $op_store:ty ) => {
        impl $field {
            /// Reduces an intermediate result held in the wide operation type
            /// modulo the field prime.
            #[must_use]
            #[allow(clippy::cast_possible_truncation)]
            fn modulo_prime_base(input: $op_store) -> Self {
                let prime = <$op_store>::from(Self::PRIME);
                let reduced = input % prime;
                // `reduced < prime`, so narrowing to the storage type is lossless.
                Self(reduced as <Self as SharedValue>::Storage)
            }

            /// Reduces an arbitrary `u128` modulo the field prime.
            #[must_use]
            #[allow(clippy::cast_possible_truncation)]
            fn modulo_prime_u128(input: u128) -> Self
            where
                Self: U128Conversions,
            {
                const PRIME: u128 = $field::PRIME as u128;
                let reduced = input % PRIME;
                // `reduced < PRIME`, so narrowing to the storage type is lossless.
                Self(reduced as <Self as SharedValue>::Storage)
            }
        }
    };
}

macro_rules! field_impl {
( $field:ident, $store:ty, $store_multiply:ty, $bits:expr, $prime:expr ) => {
( $field:ident, $backend_store:ty, $op_store:ty, $bits:expr, $prime:expr ) => {
use super::*;

// check container for multiply is large enough
const_assert!((<$store_multiply>::MAX >> $bits) as u128 >= (<$store>::MAX) as u128);
const_assert!((<$op_store>::MAX >> $bits) as u128 >= (<$backend_store>::MAX) as u128);

#[derive(Clone, Copy, PartialEq, Eq)]
pub struct $field(<Self as SharedValue>::Storage);

impl SharedValue for $field {
type Storage = $store;
type Storage = $backend_store;
const BITS: u32 = $bits;
const ZERO: Self = $field(0);

Expand Down Expand Up @@ -101,8 +127,7 @@ macro_rules! field_impl {
///
/// This method is simpler than rejection sampling for these small prime fields.
fn truncate_from<T: Into<u128>>(v: T) -> Self {
#[allow(clippy::cast_possible_truncation)]
Self((v.into() % u128::from(Self::PRIME)) as <Self as SharedValue>::Storage)
Self::modulo_prime_u128(v.into())
}
}

Expand All @@ -121,10 +146,10 @@ macro_rules! field_impl {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
let c = u64::from;
debug_assert!(c(Self::PRIME) < (u64::MAX >> 1));
let c = <$op_store>::from;
debug_assert!(c(Self::PRIME) < (<$op_store>::MAX >> 1));
#[allow(clippy::cast_possible_truncation)]
Self(((c(self.0) + c(rhs.0)) % c(Self::PRIME)) as <Self as SharedValue>::Storage)
Self::modulo_prime_base(c(self.0) + c(rhs.0))
}
}

Expand All @@ -139,22 +164,22 @@ macro_rules! field_impl {
type Output = Self;

/// Returns the additive inverse of `self` in the field.
fn neg(self) -> Self::Output {
    // By construction the stored value is in `[0, PRIME)`, so `PRIME - self.0`
    // never underflows and lies in `(0, PRIME]`. The only input that would map
    // to the non-canonical value `PRIME` is zero, whose inverse is zero itself.
    // Handling it explicitly keeps the result canonical without a remainder.
    if self.0 == 0 {
        self
    } else {
        Self(Self::PRIME - self.0)
    }
}
}

impl std::ops::Sub for $field {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
let c = u64::from;
debug_assert!(c(Self::PRIME) < (u64::MAX >> 1));
let c = <$op_store>::from;
debug_assert!(c(Self::PRIME) < (<$op_store>::MAX >> 1));
// TODO(mt) - constant time?
#[allow(clippy::cast_possible_truncation)]
Self(
((c(Self::PRIME) + c(self.0) - c(rhs.0)) % c(Self::PRIME))
as <Self as SharedValue>::Storage,
)
Self::modulo_prime_base(c(Self::PRIME) + c(self.0) - c(rhs.0))
}
}

Expand All @@ -169,12 +194,10 @@ macro_rules! field_impl {
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
debug_assert!(<$store>::try_from(Self::PRIME).is_ok());
let c = <$store_multiply>::from;
debug_assert!(<$backend_store>::try_from(Self::PRIME).is_ok());
let c = <$op_store>::from;
// TODO(mt) - constant time?
// TODO(dm) - optimize arithmetics?
#[allow(clippy::cast_possible_truncation)]
Self(((c(self.0) * c(rhs.0)) % c(Self::PRIME)) as <Self as SharedValue>::Storage)
Self::modulo_prime_base(c(self.0) * c(rhs.0))
}
}

Expand Down Expand Up @@ -213,7 +236,7 @@ macro_rules! field_impl {
}
}

impl From<$field> for $store {
impl From<$field> for $backend_store {
fn from(v: $field) -> Self {
v.0
}
Expand All @@ -239,7 +262,7 @@ macro_rules! field_impl {

impl Serializable for $field {
type Size = <<Self as SharedValue>::Storage as Block>::Size;
type DeserializationError = GreaterThanPrimeError<$store>;
type DeserializationError = GreaterThanPrimeError<$backend_store>;

fn serialize(&self, buf: &mut GenericArray<u8, Self::Size>) {
buf.copy_from_slice(&self.0.to_le_bytes());
Expand All @@ -248,7 +271,7 @@ macro_rules! field_impl {
fn deserialize(
buf: &GenericArray<u8, Self::Size>,
) -> Result<Self, Self::DeserializationError> {
let v = <$store>::from_le_bytes((*buf).into());
let v = <$backend_store>::from_le_bytes((*buf).into());
if v < Self::PRIME {
Ok(Self(v))
} else {
Expand Down Expand Up @@ -361,6 +384,7 @@ macro_rules! field_impl {
#[cfg(any(test, feature = "weak-field"))]
mod fp31 {
field_impl! { Fp31, u8, u16, 8, 31 }
rem_modulo_impl! { Fp31, u16 }

#[cfg(all(test, unit_test))]
mod specialized_tests {
Expand All @@ -384,6 +408,7 @@ mod fp31 {

mod fp32bit {
field_impl! { Fp32BitPrime, u32, u64, 32, 4_294_967_291 }
rem_modulo_impl! { Fp32BitPrime, u64 }

impl Vectorizable<32> for Fp32BitPrime {
type Array = StdArray<Fp32BitPrime, 32>;
Expand Down Expand Up @@ -454,10 +479,40 @@ mod fp61bit {
pub fn from_bit(input: bool) -> Self {
Self(input.into())
}

/// Reduces a `u128` intermediate result modulo the field prime.
///
/// This is the entry point the arithmetic operators generated by `field_impl!`
/// call (`Self::modulo_prime_base(..)`); here it simply forwards to
/// `modulo_prime_u128`, so both helpers share one reduction routine.
#[must_use]
fn modulo_prime_base(val: u128) -> Self {
Self::modulo_prime_u128(val)
}

/// Implements optimized modulus division operation for Mersenne fields.
/// Implementation taken from [`bit_twiddling`].
///
/// For a Mersenne prime `PRIME = 2^BITS - 1` we have `2^BITS ≡ 1 (mod PRIME)`,
/// so a value can be reduced by repeatedly adding its high part
/// (`val >> BITS`) to its low part (`val & PRIME`). Two folds followed by one
/// conditional subtraction yield the canonical representative in `[0, PRIME)`
/// for every `u128` input, not only for products of two field elements.
///
/// [`bit_twiddling`]: https://graphics.stanford.edu/~seander/bithacks.html#ModulusDivision
#[must_use]
#[allow(clippy::cast_possible_truncation)]
fn modulo_prime_u128(val: u128) -> Self
where
    Self: U128Conversions,
{
    const PRIME: u128 = Fp61BitPrime::PRIME as u128;
    debug_assert_eq!(0, PRIME & (PRIME + 1), "{PRIME} is not a Mersenne prime");

    // First fold: shrinks an arbitrary 128-bit value to a little over
    // `128 - BITS` bits.
    let val = (val & PRIME) + (val >> Self::BITS);
    // Second fold: the carried-in high part is now tiny (a handful of bits,
    // assuming `Self::BITS` is the bit length of `PRIME`, i.e. 61 here), so the
    // sum is strictly below `2 * PRIME`.
    let val = (val & PRIME) + (val >> Self::BITS);
    // The sum may still be anywhere in `[PRIME, 2 * PRIME)`, not just exactly
    // `PRIME` (e.g. for `u128::MAX`), so subtract instead of testing equality.
    let val = if val >= PRIME { val - PRIME } else { val };
    Self(val as <Self as SharedValue>::Storage)
}
}

#[cfg(all(test, unit_test))]
mod specialized_tests {
use proptest::proptest;

use super::*;

// copied from 32 bit prime field, adjusted wrap arounds, computed using wolframalpha.com
Expand Down Expand Up @@ -506,6 +561,27 @@ mod fp61bit {
let y = Fp61BitPrime::truncate_from((u64::MAX >> 3) - 1); // PRIME - 1
assert_eq!(x + y, Fp61BitPrime::truncate_from((u64::MAX >> 3) - 2));
}

proptest! {
    // The sum must be canonical (strictly below the prime) and match plain
    // integer arithmetic followed by `%`.
    #[test]
    fn add(a: Fp61BitPrime, b: Fp61BitPrime) {
        let sum = a + b;
        assert!(sum.0 < Fp61BitPrime::PRIME);
        let expected = (a.0 + b.0) % Fp61BitPrime::PRIME;
        assert_eq!(sum.0, expected);
    }

    // The product must be canonical and match a 128-bit multiplication
    // followed by `%`.
    #[test]
    fn mul(a: Fp61BitPrime, b: Fp61BitPrime) {
        let product = a * b;
        assert!(product.0 < Fp61BitPrime::PRIME);
        let wide = u128::from(a.0) * u128::from(b.0);
        let expected = u64::try_from(wide % u128::from(Fp61BitPrime::PRIME)).unwrap();
        assert_eq!(product.0, expected);
    }

    // Truncation of any `u64` must agree with the `%` operator.
    #[test]
    fn truncate(a: u64) {
        let truncated = Fp61BitPrime::truncate_from(a);
        assert_eq!(truncated.0, a % Fp61BitPrime::PRIME);
    }
}
}
}

Expand Down

0 comments on commit 20e62c1

Please sign in to comment.