diff --git a/CHANGELOG.md b/CHANGELOG.md index ee102fe8..4fcd128a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ ### Improvements +- [\#72](https://github.com/arkworks-rs/r1cs-std/pull/72) Implement `pow_by_constant` with NAF for `FpVar`. + ### Bug Fixes - [\#70](https://github.com/arkworks-rs/r1cs-std/pull/70) Fix soundness issues of `mul_by_inverse` for field gadgets. diff --git a/src/fields/fp/mod.rs b/src/fields/fp/mod.rs index c85bcbda..a3b6e6d8 100644 --- a/src/fields/fp/mod.rs +++ b/src/fields/fp/mod.rs @@ -1,4 +1,4 @@ -use ark_ff::{BigInteger, FpParameters, PrimeField}; +use ark_ff::{BigInteger, BitIteratorBE, FpParameters, PrimeField}; use ark_relations::r1cs::{ ConstraintSystemRef, LinearCombination, Namespace, SynthesisError, Variable, }; @@ -764,6 +764,128 @@ impl FieldVar for FpVar { *self = self.frobenius_map(power)?; Ok(self) } + + /// Computes `self^S`, where S is interpreted as an little-endian + /// u64-decomposition of an integer. + #[tracing::instrument(target = "r1cs", skip(exp))] + fn pow_by_constant>(&self, exp: S) -> Result { + use ark_ff::biginteger::arithmetic::find_wnaf; + + // first check if exp = 0 + let mut is_nonzero = false; + for limb in exp.as_ref() { + if *limb != 0u64 { + is_nonzero = true; + } + } + + // handle the case when exp = 0 + if !is_nonzero { + return Ok(FpVar::Constant(F::one())); + } + + // now we consider the case when exp != 0 + + // if `self` is constant, we compute it directly. + if self.is_constant() { + return Ok(FpVar::Constant(self.value()?.pow(exp))); + } + + // now we consider the case when exp != 0 and `self` is not a constant + + // obtain the NAF representation + let naf_be = find_wnaf(exp.as_ref()); + let found_minus_one_in_naf = naf_be.contains(&-1i64); + + // now discuss whether or not we should use NAF + let mut use_naf = true; + let mut standard_be = None; + + // if the NAF does not contain `-1`, it cannot be faster than the square-and-multiply + if !found_minus_one_in_naf { + use_naf = false; + standard_be = Some(BitIteratorBE::without_leading_zeros(&exp).collect::>()); + } + + // since NAF needs to compute the inverse, which incurs additional overhead, + // it might not be better than the standard square-and-multiply + + if use_naf { + // obtain the standard representation + let standard_be_bits = + BitIteratorBE::without_leading_zeros(&exp).collect::>(); + + // compute the cost of the NAF representation + let mut naf_cost = naf_be.len() + naf_be.iter().filter(|x| **x != 0i64).count(); + if found_minus_one_in_naf { + // computing the inverse_or_any incurs additional overhead + // two for computing the inverse-or-any, one for ensuring 0 ^ exp = 0 + naf_cost += 3; + } + + // compute the cost of the standard representation + let standard_cost = + standard_be_bits.len() + standard_be_bits.iter().filter(|x| **x == true).count(); + + if standard_cost <= naf_cost { + use_naf = false; + } + + standard_be = Some(standard_be_bits); + } + + if !use_naf { + // use simple square-and-multiple + let mut res = Self::one(); + for i in standard_be.unwrap() { + res.square_in_place()?; + if i { + res *= self; + } + } + Ok(res) + } else { + // use NAF + + // first compute `inverse_or_any` + // if `self` != 0, it implies that `self` * `inverse_or_any` = 1 + // if `self` == 0, `inverse_or_any` can be any value + let self_inverse_or_any = { + let inverse_or_any = Self::new_witness(self.cs().clone(), || { + Ok(self.value()?.inverse().unwrap_or_else(F::zero)) + })?; + + // self * self = tmp + let tmp = self.square()?; + + // tmp * inverse_or_any = self + tmp.mul_equals(&inverse_or_any, &self)?; + + inverse_or_any + }; + + // the initial `res` = 1 if `self` != 0, or `res` = 0 if `self` == 0 + let mut res = self * &self_inverse_or_any; + + let mut found_non_zero = false; + for &value in naf_be.iter().rev() { + if found_non_zero { + res = res.square()?; + } + + if value != 0 { + found_non_zero = true; + if value > 0 { + res *= self; + } else { + res *= &self_inverse_or_any; + } + } + } + + Ok(res) + } + } } impl_ops!( @@ -1091,3 +1213,111 @@ mod test { assert_eq!(sum.value().unwrap(), sum_expected); } } + +#[cfg(test)] +mod test_pow_by_constant { + use crate::alloc::AllocVar; + use crate::fields::fp::FpVar; + use crate::fields::FieldVar; + use crate::R1CSVar; + use ark_ff::Field; + use ark_relations::r1cs::ConstraintSystem; + use ark_std::{One, UniformRand, Zero}; + use ark_test_curves::bls12_381::Fr; + + #[test] + fn test_rand() { + let mut rng = ark_std::test_rng(); + let cs = ConstraintSystem::new_ref(); + + let mut rand_base = Fr::rand(&mut rng); + // ensure that rand_base is not zero. + if rand_base == Fr::zero() { + rand_base = rand_base + &Fr::one(); + } + + let rand_exp = [ + u64::rand(&mut rng), + u64::rand(&mut rng), + u64::rand(&mut rng), + u64::rand(&mut rng), + ]; + + { + let rand_base_g = FpVar::::new_witness(cs.clone(), || Ok(rand_base)).unwrap(); + let res_expected = rand_base.pow(rand_exp); + let res = rand_base_g + .pow_by_constant(&rand_exp) + .unwrap() + .value() + .unwrap(); + assert_eq!(res, res_expected); + } + + { + let rand_base_g = FpVar::::new_constant(cs.clone(), rand_base).unwrap(); + let res_expected = rand_base.pow(rand_exp); + let res = rand_base_g + .pow_by_constant(&rand_exp) + .unwrap() + .value() + .unwrap(); + assert_eq!(res, res_expected); + } + + assert!(cs.is_satisfied().unwrap()); + } + + #[test] + fn test_zero_base() { + let cs = ConstraintSystem::new_ref(); + let exp = [1u64, 2u64, 3u64, 4u64]; + + { + let base_g = FpVar::::new_witness(cs.clone(), || Ok(Fr::zero())).unwrap(); + let res_expected = Fr::zero(); + let res = base_g.pow_by_constant(exp).unwrap().value().unwrap(); + assert_eq!(res, res_expected); + } + + { + let base_g = FpVar::::new_constant(cs.clone(), Fr::zero()).unwrap(); + let res_expected = Fr::zero(); + let res = base_g.pow_by_constant(exp).unwrap().value().unwrap(); + assert_eq!(res, res_expected); + } + + assert!(cs.is_satisfied().unwrap()); + } + + #[test] + fn test_zero_exp() { + let mut rng = ark_std::test_rng(); + let cs = ConstraintSystem::new_ref(); + let exp = [0u64, 0u64, 0u64, 0u64]; + + let mut rand_base = Fr::rand(&mut rng); + + // ensure that rand_base is not zero. + if rand_base == Fr::zero() { + rand_base = rand_base + &Fr::one(); + } + + { + let rand_base_g = + FpVar::::new_witness(cs.clone(), || Ok(rand_base.clone())).unwrap(); + let res_expected = Fr::one(); + let res = rand_base_g.pow_by_constant(&exp).unwrap().value().unwrap(); + assert_eq!(res, res_expected); + } + + { + let rand_base_g = FpVar::::new_constant(cs.clone(), rand_base).unwrap(); + let res_expected = Fr::one(); + let res = rand_base_g.pow_by_constant(&exp).unwrap().value().unwrap(); + assert_eq!(res, res_expected); + } + + assert!(cs.is_satisfied().unwrap()); + } +}