diff --git a/ipa-core/src/ff/prime_field.rs b/ipa-core/src/ff/prime_field.rs index cfe90bd28..750dbda89 100644 --- a/ipa-core/src/ff/prime_field.rs +++ b/ipa-core/src/ff/prime_field.rs @@ -57,18 +57,44 @@ pub trait PrimeField: Field + U128Conversions { #[error("Field value {0} provided is greater than prime: {1}")] pub struct GreaterThanPrimeError(V, u128); +/// Default modulo reduction implementation using the remainder operator. +/// Works for any prime field, but can be inefficient for large primes +/// and if it needs to be done often. Fields operating over Mersenne +/// primes provide an optimized implementation. +macro_rules! rem_modulo_impl { + ( $field:ident, $op_store:ty ) => { + impl $field { + #[must_use] + fn modulo_prime_base(input: $op_store) -> Self { + #[allow(clippy::cast_possible_truncation)] + Self((input % <$op_store>::from(Self::PRIME)) as ::Storage) + } + + #[must_use] + fn modulo_prime_u128(input: u128) -> Self + where + Self: U128Conversions, + { + const PRIME: u128 = $field::PRIME as u128; + #[allow(clippy::cast_possible_truncation)] + Self((input % PRIME) as ::Storage) + } + } + }; +} + macro_rules! field_impl { - ( $field:ident, $store:ty, $store_multiply:ty, $bits:expr, $prime:expr ) => { + ( $field:ident, $backend_store:ty, $op_store:ty, $bits:expr, $prime:expr ) => { use super::*; // check container for multiply is large enough - const_assert!((<$store_multiply>::MAX >> $bits) as u128 >= (<$store>::MAX) as u128); + const_assert!((<$op_store>::MAX >> $bits) as u128 >= (<$backend_store>::MAX) as u128); #[derive(Clone, Copy, PartialEq, Eq)] pub struct $field(::Storage); impl SharedValue for $field { - type Storage = $store; + type Storage = $backend_store; const BITS: u32 = $bits; const ZERO: Self = $field(0); @@ -101,8 +127,7 @@ macro_rules! field_impl { /// /// This method is simpler than rejection sampling for these small prime fields. fn truncate_from>(v: T) -> Self { - #[allow(clippy::cast_possible_truncation)] - Self((v.into() % u128::from(Self::PRIME)) as ::Storage) + Self::modulo_prime_u128(v.into()) } } @@ -121,10 +146,10 @@ macro_rules! field_impl { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - let c = u64::from; - debug_assert!(c(Self::PRIME) < (u64::MAX >> 1)); + let c = <$op_store>::from; + debug_assert!(c(Self::PRIME) < (<$op_store>::MAX >> 1)); #[allow(clippy::cast_possible_truncation)] - Self(((c(self.0) + c(rhs.0)) % c(Self::PRIME)) as ::Storage) + Self::modulo_prime_base(c(self.0) + c(rhs.0)) } } @@ -139,7 +164,11 @@ macro_rules! field_impl { type Output = Self; fn neg(self) -> Self::Output { - Self((Self::PRIME - self.0) % Self::PRIME) + // Invariant uphold by the construction + // self >= 0 + // self < Prime + // therefore it is safe to avoid remainder + Self(Self::PRIME - self.0) } } @@ -147,14 +176,10 @@ macro_rules! field_impl { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - let c = u64::from; - debug_assert!(c(Self::PRIME) < (u64::MAX >> 1)); + let c = <$op_store>::from; + debug_assert!(c(Self::PRIME) < (<$op_store>::MAX >> 1)); // TODO(mt) - constant time? - #[allow(clippy::cast_possible_truncation)] - Self( - ((c(Self::PRIME) + c(self.0) - c(rhs.0)) % c(Self::PRIME)) - as ::Storage, - ) + Self::modulo_prime_base(c(Self::PRIME) + c(self.0) - c(rhs.0)) } } @@ -169,12 +194,10 @@ macro_rules! field_impl { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - debug_assert!(<$store>::try_from(Self::PRIME).is_ok()); - let c = <$store_multiply>::from; + debug_assert!(<$backend_store>::try_from(Self::PRIME).is_ok()); + let c = <$op_store>::from; // TODO(mt) - constant time? - // TODO(dm) - optimize arithmetics? - #[allow(clippy::cast_possible_truncation)] - Self(((c(self.0) * c(rhs.0)) % c(Self::PRIME)) as ::Storage) + Self::modulo_prime_base(c(self.0) * c(rhs.0)) } } @@ -213,7 +236,7 @@ macro_rules! field_impl { } } - impl From<$field> for $store { + impl From<$field> for $backend_store { fn from(v: $field) -> Self { v.0 } @@ -239,7 +262,7 @@ macro_rules! field_impl { impl Serializable for $field { type Size = <::Storage as Block>::Size; - type DeserializationError = GreaterThanPrimeError<$store>; + type DeserializationError = GreaterThanPrimeError<$backend_store>; fn serialize(&self, buf: &mut GenericArray) { buf.copy_from_slice(&self.0.to_le_bytes()); @@ -248,7 +271,7 @@ macro_rules! field_impl { fn deserialize( buf: &GenericArray, ) -> Result { - let v = <$store>::from_le_bytes((*buf).into()); + let v = <$backend_store>::from_le_bytes((*buf).into()); if v < Self::PRIME { Ok(Self(v)) } else { @@ -361,6 +384,7 @@ macro_rules! field_impl { #[cfg(any(test, feature = "weak-field"))] mod fp31 { field_impl! { Fp31, u8, u16, 8, 31 } + rem_modulo_impl! { Fp31, u16 } #[cfg(all(test, unit_test))] mod specialized_tests { @@ -384,6 +408,7 @@ mod fp31 { mod fp32bit { field_impl! { Fp32BitPrime, u32, u64, 32, 4_294_967_291 } + rem_modulo_impl! { Fp32BitPrime, u64 } impl Vectorizable<32> for Fp32BitPrime { type Array = StdArray; @@ -454,10 +479,40 @@ mod fp61bit { pub fn from_bit(input: bool) -> Self { Self(input.into()) } + + #[must_use] + fn modulo_prime_base(val: u128) -> Self { + Self::modulo_prime_u128(val) + } + + /// Implements optimized modulus division operation for Mersenne fields. + /// Implementation taken from [`bit_twiddling`]. + /// + /// [`bit_twiddling`]: https://graphics.stanford.edu/~seander/bithacks.html#ModulusDivision + #[must_use] + #[allow(clippy::cast_possible_truncation)] + fn modulo_prime_u128(val: u128) -> Self + where + Self: U128Conversions, + { + const PRIME: u128 = Fp61BitPrime::PRIME as u128; + debug_assert_eq!(0, PRIME & (PRIME + 1), "{PRIME} is not a Mersenne prime"); + + let val = (val & PRIME) + (val >> Self::BITS); + // another round if val ended up being greater than PRIME + let val = (val & PRIME) + (val >> Self::BITS); + if val == PRIME { + Self::ZERO + } else { + Self(val as ::Storage) + } + } } #[cfg(all(test, unit_test))] mod specialized_tests { + use proptest::proptest; + use super::*; // copied from 32 bit prime field, adjusted wrap arounds, computed using wolframalpha.com @@ -506,6 +561,27 @@ mod fp61bit { let y = Fp61BitPrime::truncate_from((u64::MAX >> 3) - 1); // PRIME - 1 assert_eq!(x + y, Fp61BitPrime::truncate_from((u64::MAX >> 3) - 2)); } + + proptest! { + #[test] + fn add(a: Fp61BitPrime, b: Fp61BitPrime) { + let c = a + b; + assert!(c.0 < Fp61BitPrime::PRIME); + assert_eq!(c.0, (a.0 + b.0) % Fp61BitPrime::PRIME); + } + + #[test] + fn mul(a: Fp61BitPrime, b: Fp61BitPrime) { + let c = a * b; + assert!(c.0 < Fp61BitPrime::PRIME); + assert_eq!(c.0, u64::try_from((u128::from(a.0) * u128::from(b.0)) % u128::from(Fp61BitPrime::PRIME)).unwrap()); + } + + #[test] + fn truncate(a: u64) { + assert_eq!(Fp61BitPrime::truncate_from(a).0, a % Fp61BitPrime::PRIME); + } + } } }