From 3e05f5d10072cf37feab7a9df74b923d1d09bde6 Mon Sep 17 00:00:00 2001 From: ashWhiteHat Date: Tue, 21 Nov 2023 05:35:34 +0900 Subject: [PATCH] Reduce duplicate code across different curve cycle providers (#255) * refactor: impl folding macro * refactor: generalize curve test * chore: rename impl_folding to impl_engine --- src/provider/bn256_grumpkin.rs | 36 +-------- src/provider/mod.rs | 136 +++++++++++++++++++++++++-------- src/provider/pasta.rs | 99 ++++-------------------- src/provider/secp_secq.rs | 36 +-------- 4 files changed, 121 insertions(+), 186 deletions(-) diff --git a/src/provider/bn256_grumpkin.rs b/src/provider/bn256_grumpkin.rs index d1831bb8..f19649b3 100644 --- a/src/provider/bn256_grumpkin.rs +++ b/src/provider/bn256_grumpkin.rs @@ -1,6 +1,6 @@ //! This module implements the Nova traits for `bn256::Point`, `bn256::Scalar`, `grumpkin::Point`, `grumpkin::Scalar`. use crate::{ - impl_traits, + impl_engine, impl_traits, provider::{ cpu_best_multiexp, keccak::Keccak256Transcript, @@ -69,37 +69,3 @@ impl_traits!( "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001" ); - -#[cfg(test)] -mod tests { - use super::*; - type G = bn256::Point; - - fn from_label_serial(label: &'static [u8], n: usize) -> Vec { - let mut shake = Shake256::default(); - shake.update(label); - let mut reader = shake.finalize_xof(); - let mut ck = Vec::new(); - for _ in 0..n { - let mut uniform_bytes = [0u8; 32]; - reader.read_exact(&mut uniform_bytes).unwrap(); - let hash = bn256::Point::hash_to_curve("from_uniform_bytes"); - ck.push(hash(&uniform_bytes).to_affine()); - } - ck - } - - #[test] - fn test_from_label() { - let label = b"test_from_label"; - for n in [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021, - ] { - let ck_par = ::from_label(label, n); - let ck_ser = from_label_serial(label, n); - assert_eq!(ck_par.len(), n); - assert_eq!(ck_ser.len(), n); - assert_eq!(ck_par, ck_ser); - } - } -} diff --git a/src/provider/mod.rs b/src/provider/mod.rs index fa4b9e07..29273743 100644 --- a/src/provider/mod.rs +++ b/src/provider/mod.rs @@ -229,29 +229,14 @@ macro_rules! impl_traits { $order_str:literal, $base_str:literal ) => { - impl Engine for $engine { - type Base = $name::Base; - type Scalar = $name::Scalar; - type GE = $name::Point; - type RO = PoseidonRO; - type ROCircuit = PoseidonROCircuit; - type TE = Keccak256Transcript; - type CE = CommitmentEngine; - } - - impl Group for $name::Point { - type Base = $name::Base; - type Scalar = $name::Scalar; - - fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) { - let A = $name::Point::a(); - let B = $name::Point::b(); - let order = BigInt::from_str_radix($order_str, 16).unwrap(); - let base = BigInt::from_str_radix($base_str, 16).unwrap(); - - (A, B, order, base) - } - } + impl_engine!( + $engine, + $name, + $name_compressed, + $name_curve, + $order_str, + $base_str + ); impl DlogGroup for $name::Point { type CompressedGroupElement = $name_compressed; @@ -335,10 +320,11 @@ macro_rules! impl_traits { } } - impl PrimeFieldExt for $name::Scalar { - fn from_uniform(bytes: &[u8]) -> Self { - let bytes_arr: [u8; 64] = bytes.try_into().unwrap(); - $name::Scalar::from_uniform_bytes(&bytes_arr) + impl CompressedGroup for $name_compressed { + type GroupElement = $name::Point; + + fn decompress(&self) -> Option<$name::Point> { + Some($name_curve::from_bytes(&self).unwrap()) } } @@ -347,12 +333,48 @@ macro_rules! impl_traits { self.as_ref().to_vec() } } + }; +} - impl CompressedGroup for $name_compressed { - type GroupElement = $name::Point; +/// Nova folding circuit engine and curve group ops +#[macro_export] +macro_rules! impl_engine { + ( + $engine:ident, + $name:ident, + $name_compressed:ident, + $name_curve:ident, + $order_str:literal, + $base_str:literal + ) => { + impl Engine for $engine { + type Base = $name::Base; + type Scalar = $name::Scalar; + type GE = $name::Point; + type RO = PoseidonRO; + type ROCircuit = PoseidonROCircuit; + type TE = Keccak256Transcript; + type CE = CommitmentEngine; + } - fn decompress(&self) -> Option<$name::Point> { - Some($name_curve::from_bytes(&self).unwrap()) + impl Group for $name::Point { + type Base = $name::Base; + type Scalar = $name::Scalar; + + fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) { + let A = $name::Point::a(); + let B = $name::Point::b(); + let order = BigInt::from_str_radix($order_str, 16).unwrap(); + let base = BigInt::from_str_radix($base_str, 16).unwrap(); + + (A, B, order, base) + } + } + + impl PrimeFieldExt for $name::Scalar { + fn from_uniform(bytes: &[u8]) -> Self { + let bytes_arr: [u8; 64] = bytes.try_into().unwrap(); + $name::Scalar::from_uniform_bytes(&bytes_arr) } } @@ -371,11 +393,44 @@ mod tests { use crate::provider::{ bn256_grumpkin::{bn256, grumpkin}, secp_secq::{secp256k1, secq256k1}, + DlogGroup, }; - use group::{ff::Field, Group}; - use halo2curves::CurveAffine; + use digest::{ExtendableOutput, Update}; + use group::{ff::Field, Curve, Group}; + use halo2curves::{CurveAffine, CurveExt}; use pasta_curves::{pallas, vesta}; use rand_core::OsRng; + use sha3::Shake256; + use std::io::Read; + + macro_rules! impl_cycle_pair_test { + ($curve:ident) => { + fn from_label_serial(label: &'static [u8], n: usize) -> Vec<$curve::Affine> { + let mut shake = Shake256::default(); + shake.update(label); + let mut reader = shake.finalize_xof(); + (0..n) + .map(|_| { + let mut uniform_bytes = [0u8; 32]; + reader.read_exact(&mut uniform_bytes).unwrap(); + let hash = $curve::Point::hash_to_curve("from_uniform_bytes"); + hash(&uniform_bytes).to_affine() + }) + .collect() + } + + let label = b"test_from_label"; + for n in [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021, + ] { + let ck_par = <$curve::Point as DlogGroup>::from_label(label, n); + let ck_ser = from_label_serial(label, n); + assert_eq!(ck_par.len(), n); + assert_eq!(ck_ser.len(), n); + assert_eq!(ck_par, ck_ser); + } + }; + } fn test_msm_with>() { let n = 8; @@ -403,4 +458,19 @@ mod tests { test_msm_with::(); test_msm_with::(); } + + #[test] + fn test_bn256_from_label() { + impl_cycle_pair_test!(bn256); + } + + #[test] + fn test_pallas_from_label() { + impl_cycle_pair_test!(pallas); + } + + #[test] + fn test_secp256k1_from_label() { + impl_cycle_pair_test!(secp256k1); + } } diff --git a/src/provider/pasta.rs b/src/provider/pasta.rs index 50e97c4a..902cf8d7 100644 --- a/src/provider/pasta.rs +++ b/src/provider/pasta.rs @@ -1,5 +1,6 @@ //! This module implements the Nova traits for `pallas::Point`, `pallas::Scalar`, `vesta::Point`, `vesta::Scalar`. use crate::{ + impl_engine, provider::{ cpu_best_multiexp, keccak::Keccak256Transcript, @@ -68,62 +69,41 @@ macro_rules! impl_traits { $order_str:literal, $base_str:literal ) => { - impl Engine for $engine { - type Base = $name::Base; - type Scalar = $name::Scalar; - type GE = $name::Point; - type RO = PoseidonRO; - type ROCircuit = PoseidonROCircuit; - type TE = Keccak256Transcript; - type CE = CommitmentEngine; - } - - impl Group for $name::Point { - type Base = $name::Base; - type Scalar = $name::Scalar; - - fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) { - let A = $name::Point::a(); - let B = $name::Point::b(); - let order = BigInt::from_str_radix($order_str, 16).unwrap(); - let base = BigInt::from_str_radix($base_str, 16).unwrap(); - - (A, B, order, base) - } - } + impl_engine!( + $engine, + $name, + $name_compressed, + $name_curve, + $order_str, + $base_str + ); impl DlogGroup for $name::Point { type CompressedGroupElement = $name_compressed; type PreprocessedGroupElement = $name::Affine; - #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] fn vartime_multiscalar_mul( scalars: &[Self::Scalar], bases: &[Self::PreprocessedGroupElement], ) -> Self { + #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))] if scalars.len() >= 128 { pasta_msm::$name(bases, scalars) } else { cpu_best_multiexp(scalars, bases) } + #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] + cpu_best_multiexp(scalars, bases) } - #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))] - fn vartime_multiscalar_mul( - scalars: &[Self::Scalar], - bases: &[Self::PreprocessedGroupElement], - ) -> Self { - cpu_best_multiexp(scalars, bases) + fn preprocessed(&self) -> Self::PreprocessedGroupElement { + self.to_affine() } fn compress(&self) -> Self::CompressedGroupElement { $name_compressed::new(self.to_bytes()) } - fn preprocessed(&self) -> Self::PreprocessedGroupElement { - self.to_affine() - } - fn from_label(label: &'static [u8], n: usize) -> Vec { let mut shake = Shake256::default(); shake.update(label); @@ -184,19 +164,6 @@ macro_rules! impl_traits { } } - impl PrimeFieldExt for $name::Scalar { - fn from_uniform(bytes: &[u8]) -> Self { - let bytes_arr: [u8; 64] = bytes.try_into().unwrap(); - $name::Scalar::from_uniform_bytes(&bytes_arr) - } - } - - impl TranscriptReprTrait for $name_compressed { - fn to_transcript_bytes(&self) -> Vec { - self.repr.to_vec() - } - } - impl CompressedGroup for $name_compressed { type GroupElement = $name::Point; @@ -205,9 +172,9 @@ macro_rules! impl_traits { } } - impl TranscriptReprTrait for $name::Scalar { + impl TranscriptReprTrait for $name_compressed { fn to_transcript_bytes(&self) -> Vec { - self.to_repr().to_vec() + self.repr.to_vec() } } }; @@ -232,37 +199,3 @@ impl_traits!( "40000000000000000000000000000000224698fc094cf91b992d30ed00000001", "40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001" ); - -#[cfg(test)] -mod tests { - use super::*; - type G = ::GE; - - fn from_label_serial(label: &'static [u8], n: usize) -> Vec { - let mut shake = Shake256::default(); - shake.update(label); - let mut reader = shake.finalize_xof(); - let mut ck = Vec::new(); - for _ in 0..n { - let mut uniform_bytes = [0u8; 32]; - reader.read_exact(&mut uniform_bytes).unwrap(); - let hash = Ep::hash_to_curve("from_uniform_bytes"); - ck.push(hash(&uniform_bytes).to_affine()); - } - ck - } - - #[test] - fn test_from_label() { - let label = b"test_from_label"; - for n in [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021, - ] { - let ck_par = ::from_label(label, n); - let ck_ser = from_label_serial(label, n); - assert_eq!(ck_par.len(), n); - assert_eq!(ck_ser.len(), n); - assert_eq!(ck_par, ck_ser); - } - } -} diff --git a/src/provider/secp_secq.rs b/src/provider/secp_secq.rs index ef625185..4f2aa71e 100644 --- a/src/provider/secp_secq.rs +++ b/src/provider/secp_secq.rs @@ -1,6 +1,6 @@ //! This module implements the Nova traits for `secp::Point`, `secp::Scalar`, `secq::Point`, `secq::Scalar`. use crate::{ - impl_traits, + impl_engine, impl_traits, provider::{ cpu_best_multiexp, keccak::Keccak256Transcript, @@ -66,37 +66,3 @@ impl_traits!( "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141" ); - -#[cfg(test)] -mod tests { - use super::*; - type G = secp256k1::Point; - - fn from_label_serial(label: &'static [u8], n: usize) -> Vec { - let mut shake = Shake256::default(); - shake.update(label); - let mut reader = shake.finalize_xof(); - let mut ck = Vec::new(); - for _ in 0..n { - let mut uniform_bytes = [0u8; 32]; - reader.read_exact(&mut uniform_bytes).unwrap(); - let hash = secp256k1::Point::hash_to_curve("from_uniform_bytes"); - ck.push(hash(&uniform_bytes).to_affine()); - } - ck - } - - #[test] - fn test_from_label() { - let label = b"test_from_label"; - for n in [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021, - ] { - let ck_par = ::from_label(label, n); - let ck_ser = from_label_serial(label, n); - assert_eq!(ck_par.len(), n); - assert_eq!(ck_ser.len(), n); - assert_eq!(ck_par, ck_ser); - } - } -}