diff --git a/ipa-core/benches/dzkp_convert_prover.rs b/ipa-core/benches/dzkp_convert_prover.rs index 57b557735..c8f820bab 100644 --- a/ipa-core/benches/dzkp_convert_prover.rs +++ b/ipa-core/benches/dzkp_convert_prover.rs @@ -1,41 +1,15 @@ -//! Benchmark for the convert_prover function in dzkp_field.rs. +//! Benchmark for the table_indices_prover function in dzkp_field.rs. use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -use ipa_core::{ - ff::Fp61BitPrime, - protocol::context::{dzkp_field::DZKPBaseField, dzkp_validator::MultiplicationInputsBlock}, -}; +use ipa_core::protocol::context::dzkp_validator::MultiplicationInputsBlock; use rand::{thread_rng, Rng}; fn convert_prover_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("dzkp_convert_prover"); group.bench_function("convert", |b| { b.iter_batched_ref( - || { - // Generate input - let mut rng = thread_rng(); - - MultiplicationInputsBlock { - x_left: rng.gen::<[u8; 32]>().into(), - x_right: rng.gen::<[u8; 32]>().into(), - y_left: rng.gen::<[u8; 32]>().into(), - y_right: rng.gen::<[u8; 32]>().into(), - prss_left: rng.gen::<[u8; 32]>().into(), - prss_right: rng.gen::<[u8; 32]>().into(), - z_right: rng.gen::<[u8; 32]>().into(), - } - }, - |input| { - let MultiplicationInputsBlock { - x_left, - x_right, - y_left, - y_right, - prss_right, - .. - } = input; - Fp61BitPrime::convert_prover(x_left, x_right, y_left, y_right, prss_right); - }, + || thread_rng().gen(), + |input: &mut MultiplicationInputsBlock| input.table_indices_prover(), BatchSize::SmallInput, ) }); diff --git a/ipa-core/src/protocol/context/dzkp_field.rs b/ipa-core/src/protocol/context/dzkp_field.rs index 368ac36cc..758ad008b 100644 --- a/ipa-core/src/protocol/context/dzkp_field.rs +++ b/ipa-core/src/protocol/context/dzkp_field.rs @@ -1,18 +1,13 @@ -use std::{iter::zip, sync::LazyLock}; +use std::{ops::Index, sync::LazyLock}; use bitvec::field::BitField; use crate::{ ff::{Field, Fp61BitPrime, PrimeField}, - protocol::context::dzkp_validator::{Array256Bit, SegmentEntry}, + protocol::context::dzkp_validator::{Array256Bit, MultiplicationInputsBlock, SegmentEntry}, secret_sharing::{FieldSimd, SharedValue, Vectorizable}, }; -// BlockSize is fixed to 32 -pub const BLOCK_SIZE: usize = 32; -// UVTupleBlock is a block of interleaved U and V values -pub type UVTupleBlock = ([F; BLOCK_SIZE], [F; BLOCK_SIZE]); - /// Trait for fields compatible with DZKPs /// Field needs to support conversion to `SegmentEntry`, i.e. `to_segment_entry` which is required by DZKPs pub trait DZKPCompatibleField: FieldSimd { @@ -25,35 +20,12 @@ pub trait DZKPBaseField: PrimeField { const INVERSE_OF_TWO: Self; const MINUS_ONE_HALF: Self; const MINUS_TWO: Self; +} - /// Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements. - /// This function is called by the prover. - fn convert_prover<'a>( - x_left: &'a Array256Bit, - x_right: &'a Array256Bit, - y_left: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - ) -> Vec>; - - /// This is similar to `convert_prover` except that it is called by the verifier to the left of the prover. - /// The verifier on the left uses its right shares, since they are consistent with the prover's left shares. - /// This produces the 'u' values. 
- fn convert_value_from_right_prover<'a>( - x_right: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - z_right: &'a Array256Bit, - ) -> Vec; - - /// This is similar to `convert_prover` except that it is called by the verifier to the right of the prover. - /// The verifier on the right uses its left shares, since they are consistent with the prover's right shares. - /// This produces the 'v' values - fn convert_value_from_left_prover<'a>( - x_left: &'a Array256Bit, - y_left: &'a Array256Bit, - prss_left: &'a Array256Bit, - ) -> Vec; +impl DZKPBaseField for Fp61BitPrime { + const INVERSE_OF_TWO: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_976u64); + const MINUS_ONE_HALF: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_975u64); + const MINUS_TWO: Self = Fp61BitPrime::const_truncate(2_305_843_009_213_693_949u64); } impl FromIterator for [Fp61BitPrime; P] { @@ -90,7 +62,7 @@ impl FromIterator for Vec<[Fp61BitPrime; P]> { } } -/// Construct indices for the `convert_values` lookup tables. +/// Construct indices for the `TABLE_U` and `TABLE_V` lookup tables. /// /// `b0` has the least significant bit of each index, and `b1` and `b2` the subsequent /// bits. This routine rearranges the bits so that there is one table index in each @@ -102,7 +74,7 @@ impl FromIterator for Vec<[Fp61BitPrime; P]> { /// (i%4) == j. The "0s", "1s", "2s", "3s" comments trace the movement from /// input to output. #[must_use] -fn convert_values_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { +fn bits_to_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { // 0x55 is 0b0101_0101. This mask selects bits having (i%2) == 0. const CONST_55: u128 = u128::from_le_bytes([0x55; 16]); // 0xaa is 0b1010_1010. This mask selects bits having (i%2) == 1. @@ -148,10 +120,31 @@ fn convert_values_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { [y0, y1, y2, y3] } +pub struct UVTable(pub [[F; 4]; 8]); + +impl Index for UVTable { + type Output = [F; 4]; + + fn index(&self, index: u8) -> &Self::Output { + self.0.index(usize::from(index)) + } +} + +impl Index for UVTable { + type Output = [F; 4]; + + fn index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} + // Table used for `convert_prover` and `convert_value_from_right_prover`. // -// The conversion to "g" and "h" values is from https://eprint.iacr.org/2023/909.pdf. -static TABLE_RIGHT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { +// This table is for "g" or "u" values. This table is "right" on a verifier when it is +// processing values for the prover on its right. On a prover, this table is "left". +// +// The conversion logic is from https://eprint.iacr.org/2023/909.pdf. +pub static TABLE_U: LazyLock> = LazyLock::new(|| { let mut result = Vec::with_capacity(8); for e in [false, true] { for c in [false, true] { @@ -172,13 +165,16 @@ static TABLE_RIGHT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { } } } - result.try_into().unwrap() + UVTable(result.try_into().unwrap()) }); // Table used for `convert_prover` and `convert_value_from_left_prover`. // -// The conversion to "g" and "h" values is from https://eprint.iacr.org/2023/909.pdf. -static TABLE_LEFT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { +// This table is for "h" or "v" values. This table is "left" on a verifier when it is +// processing values for the prover on its left. On a prover, this table is "right". +// +// The conversion logic is from https://eprint.iacr.org/2023/909.pdf. 
+pub static TABLE_V: LazyLock> = LazyLock::new(|| { let mut result = Vec::with_capacity(8); for f in [false, true] { for d in [false, true] { @@ -199,31 +195,31 @@ static TABLE_LEFT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { } } } - result.try_into().unwrap() + UVTable(result.try_into().unwrap()) }); -/// Lookup-table-based conversion logic used by `convert_prover`, -/// `convert_value_from_left_prover`, and `convert_value_from_left_prover`. +/// Lookup-table-based conversion logic used by `table_indices_prover`, +/// `table_indices_from_right_prover`, and `table_indices_from_left_prover`. /// /// Inputs `i0`, `i1`, and `i2` each contain the value of one of the "a" through "f" /// intermediates for each of 256 multiplies. `table` is the lookup table to use, -/// which should be either `TABLE_LEFT` or `TABLE_RIGHT`. +/// which should be either `TABLE_U` or `TABLE_V`. /// /// We want to interpret the 3-tuple of intermediates at each bit position in `i0`, `i1` /// and `i2` as an integer index in the range 0..8 into the table. The -/// `convert_values_table_indices` helper does this in bulk more efficiently than using +/// `bits_to_table_indices` helper does this in bulk more efficiently than using /// bit-manipulation to handle them one-by-one. /// /// Preserving the order from inputs to outputs is not necessary for correctness as long /// as the same order is used on all three helpers. We preserve the order anyways /// to simplify the end-to-end dataflow, even though it makes this routine slightly /// more complicated. -fn convert_values( +fn intermediates_to_table_indices<'a>( i0: &Array256Bit, i1: &Array256Bit, i2: &Array256Bit, - table: &[[Fp61BitPrime; 4]; 8], -) -> Vec { + mut out: impl Iterator, +) { // Split inputs to two `u128`s. We do this because `u128` is the largest integer // type rust supports. It is possible that using SIMD types here would improve // code generation for AVX-256/512. @@ -238,45 +234,49 @@ fn convert_values( // Output word `j` in each set contains the table indices for input positions `i` // having (i%4) == j. - let [mut z00, mut z01, mut z02, mut z03] = convert_values_table_indices(i00, i10, i20); - let [mut z10, mut z11, mut z12, mut z13] = convert_values_table_indices(i01, i11, i21); + let [mut z00, mut z01, mut z02, mut z03] = bits_to_table_indices(i00, i10, i20); + let [mut z10, mut z11, mut z12, mut z13] = bits_to_table_indices(i01, i11, i21); - let mut result = Vec::with_capacity(1024); + #[allow(clippy::cast_possible_truncation)] for _ in 0..32 { // Take one index in turn from each `z` to preserve the output order. 
- for z in [&mut z00, &mut z01, &mut z02, &mut z03] { - result.extend(table[(*z as usize) & 0x7]); - *z >>= 4; - } + *out.next().unwrap() = (z00 as u8) & 0x7; + z00 >>= 4; + *out.next().unwrap() = (z01 as u8) & 0x7; + z01 >>= 4; + *out.next().unwrap() = (z02 as u8) & 0x7; + z02 >>= 4; + *out.next().unwrap() = (z03 as u8) & 0x7; + z03 >>= 4; } + #[allow(clippy::cast_possible_truncation)] for _ in 0..32 { - for z in [&mut z10, &mut z11, &mut z12, &mut z13] { - result.extend(table[(*z as usize) & 0x7]); - *z >>= 4; - } + *out.next().unwrap() = (z10 as u8) & 0x7; + z10 >>= 4; + *out.next().unwrap() = (z11 as u8) & 0x7; + z11 >>= 4; + *out.next().unwrap() = (z12 as u8) & 0x7; + z12 >>= 4; + *out.next().unwrap() = (z13 as u8) & 0x7; + z13 >>= 4; } - debug_assert!(result.len() == 1024); - - result + debug_assert!(out.next().is_none()); } -impl DZKPBaseField for Fp61BitPrime { - const INVERSE_OF_TWO: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_976u64); - const MINUS_ONE_HALF: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_975u64); - const MINUS_TWO: Self = Fp61BitPrime::const_truncate(2_305_843_009_213_693_949u64); - - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements - // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf +impl MultiplicationInputsBlock { + /// Repack the intermediates in this block into lookup indices for `TABLE_U` and `TABLE_V`. + /// + /// This is the convert function called by the prover. // - // This function does not use any optimization. + // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and left can compute: + // Prover and the verifier on its left compute: // g1=-2ac(1-2e), // g2=c(1-2e), // g3=a(1-2e), // g4=-1/2(1-2e), // - // Prover and right can compute: + // Prover and the verifier on its right compute: // h1=bd(1-2f), // h2=d(1-2f), // h3=b(1-2f), @@ -292,33 +292,30 @@ impl DZKPBaseField for Fp61BitPrime { // therefore e = ab⊕cd⊕ f must hold. (alternatively, you can also see this by substituting z_left, // i.e. 
z_left = x_left · y_left ⊕ x_left · y_right ⊕ x_right · y_left ⊕ prss_left ⊕ prss_right #[allow(clippy::many_single_char_names)] - fn convert_prover<'a>( - x_left: &'a Array256Bit, - x_right: &'a Array256Bit, - y_left: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - ) -> Vec> { - let a = x_left; - let b = y_right; - let c = y_left; - let d = x_right; + #[must_use] + pub fn table_indices_prover(&self) -> Vec<(u8, u8)> { + let a = &self.x_left; + let b = &self.y_right; + let c = &self.y_left; + let d = &self.x_right; // e = ab ⊕ cd ⊕ f = x_left * y_right ⊕ y_left * x_right ⊕ prss_right - let e = (*x_left & y_right) ^ (*y_left & x_right) ^ prss_right; - let f = prss_right; - - let g = convert_values(a, c, &e, &TABLE_RIGHT); - let h = convert_values(b, d, f, &TABLE_LEFT); + let e = (self.x_left & self.y_right) ^ (self.y_left & self.x_right) ^ self.prss_right; + let f = &self.prss_right; - zip(g.chunks_exact(BLOCK_SIZE), h.chunks_exact(BLOCK_SIZE)) - .map(|(g_chunk, h_chunk)| (g_chunk.try_into().unwrap(), h_chunk.try_into().unwrap())) - .collect() + let mut output = vec![(0u8, 0u8); 256]; + intermediates_to_table_indices(a, c, &e, output.iter_mut().map(|tup| &mut tup.0)); + intermediates_to_table_indices(b, d, f, output.iter_mut().map(|tup| &mut tup.1)); + output } - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements + /// Repack the intermediates in this block into lookup indices for `TABLE_U`. + /// + /// This is the convert function called by the verifier when processing for the + /// prover on its right. + // // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and left can compute: + // Prover and the verifier on its left compute: // g1=-2ac(1-2e), // g2=c(1-2e), // g3=a(1-2e), @@ -328,25 +325,27 @@ impl DZKPBaseField for Fp61BitPrime { // (a,c,e) = (x_right, y_right, x_right * y_right ⊕ z_right ⊕ prss_right) // here e is defined as in the paper (since the the verifier does not have access to b,d,f, // he cannot use the simplified formula for e) - fn convert_value_from_right_prover<'a>( - x_right: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - z_right: &'a Array256Bit, - ) -> Vec { - let a = x_right; - let c = y_right; + #[must_use] + pub fn table_indices_from_right_prover(&self) -> Vec { + let a = &self.x_right; + let c = &self.y_right; // e = ac ⊕ zright ⊕ prssright // as defined in the paper - let e = (*a & *c) ^ prss_right ^ z_right; + let e = (self.x_right & self.y_right) ^ self.prss_right ^ self.z_right; - convert_values(a, c, &e, &TABLE_RIGHT) + let mut output = vec![0u8; 256]; + intermediates_to_table_indices(a, c, &e, output.iter_mut()); + output } - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements + /// Repack the intermediates in this block into lookup indices for `TABLE_V`. + /// + /// This is the convert function called by the verifier when processing for the + /// prover on its left. 
+ // // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and right can compute: + // The prover and the verifier on its right compute: // h1=bd(1-2f), // h2=d(1-2f), // h3=b(1-2f), @@ -354,31 +353,33 @@ impl DZKPBaseField for Fp61BitPrime { // // where // (b,d,f) = (y_left, x_left, prss_left) - fn convert_value_from_left_prover<'a>( - x_left: &'a Array256Bit, - y_left: &'a Array256Bit, - prss_left: &'a Array256Bit, - ) -> Vec { - let b = y_left; - let d = x_left; - let f = prss_left; - - convert_values(b, d, f, &TABLE_LEFT) + #[must_use] + pub fn table_indices_from_left_prover(&self) -> Vec { + let b = &self.y_left; + let d = &self.x_left; + let f = &self.prss_left; + + let mut output = vec![0u8; 256]; + intermediates_to_table_indices(b, d, f, output.iter_mut()); + output } } #[cfg(all(test, unit_test))] -mod tests { - use bitvec::{array::BitArray, macros::internal::funty::Fundamental, slice::BitSlice}; +pub mod tests { + + use bitvec::{array::BitArray, macros::internal::funty::Fundamental}; use proptest::proptest; use rand::{thread_rng, Rng}; use crate::{ ff::{Field, Fp61BitPrime, U128Conversions}, - protocol::context::dzkp_field::{ - convert_values_table_indices, DZKPBaseField, UVTupleBlock, BLOCK_SIZE, + protocol::context::{ + dzkp_field::{bits_to_table_indices, DZKPBaseField, TABLE_U, TABLE_V}, + dzkp_validator::MultiplicationInputsBlock, }, secret_sharing::SharedValue, + test_executor::run_random, }; #[test] @@ -386,7 +387,7 @@ mod tests { let b0 = 0xaa; let b1 = 0xcc; let b2 = 0xf0; - let [z0, z1, z2, z3] = convert_values_table_indices(b0, b1, b2); + let [z0, z1, z2, z3] = bits_to_table_indices(b0, b1, b2); assert_eq!(z0, 0x40_u128); assert_eq!(z1, 0x51_u128); assert_eq!(z2, 0x62_u128); @@ -396,7 +397,7 @@ mod tests { let b0 = rng.gen(); let b1 = rng.gen(); let b2 = rng.gen(); - let [z0, z1, z2, z3] = convert_values_table_indices(b0, b1, b2); + let [z0, z1, z2, z3] = bits_to_table_indices(b0, b1, b2); for i in (0..128).step_by(4) { fn check(i: u32, j: u32, b0: u128, b1: u128, b2: u128, z: u128) { @@ -417,85 +418,73 @@ mod tests { } } - #[test] - fn batch_convert() { - let mut rng = thread_rng(); + impl MultiplicationInputsBlock { + /// Rotate the "right" values into the "left" values, setting the right values + /// to zero. If the input represents a prover's block of intermediates, the + /// output represents the intermediates that the verifier on the prover's right + /// shares with it. 
+ #[must_use] + pub fn rotate_left(&self) -> Self { + Self { + x_left: self.x_right, + y_left: self.y_right, + prss_left: self.prss_right, + x_right: [0u8; 32].into(), + y_right: [0u8; 32].into(), + prss_right: [0u8; 32].into(), + z_right: [0u8; 32].into(), + } + } - // bitvecs - let mut vec_x_left = Vec::::new(); - let mut vec_x_right = Vec::::new(); - let mut vec_y_left = Vec::::new(); - let mut vec_y_right = Vec::::new(); - let mut vec_prss_left = Vec::::new(); - let mut vec_prss_right = Vec::::new(); - let mut vec_z_right = Vec::::new(); - - // gen 32 random values - for _i in 0..32 { - let x_left: u8 = rng.gen(); - let x_right: u8 = rng.gen(); - let y_left: u8 = rng.gen(); - let y_right: u8 = rng.gen(); - let prss_left: u8 = rng.gen(); - let prss_right: u8 = rng.gen(); - // we set this up to be equal to z_right for this local test - // local here means that only a single party is involved - // and we just test this against this single party - let z_right: u8 = (x_left & y_left) - ^ (x_left & y_right) - ^ (x_right & y_left) - ^ prss_left - ^ prss_right; - - // fill vec - vec_x_left.push(x_left); - vec_x_right.push(x_right); - vec_y_left.push(y_left); - vec_y_right.push(y_right); - vec_prss_left.push(prss_left); - vec_prss_right.push(prss_right); - vec_z_right.push(z_right); + /// Rotate the "left" values into the "right" values, setting the left values to + /// zero. `z_right` is calculated to be consistent with the other values. If the + /// input represents a prover's block of intermediates, the output represents + /// the intermediates that the verifier on the prover's left shares with it. + #[must_use] + pub fn rotate_right(&self) -> Self { + let z_right = (self.x_left & self.y_left) + ^ (self.x_left & self.y_right) + ^ (self.x_right & self.y_left) + ^ self.prss_left + ^ self.prss_right; + + Self { + x_right: self.x_left, + y_right: self.y_left, + prss_right: self.prss_left, + x_left: [0u8; 32].into(), + y_left: [0u8; 32].into(), + prss_left: [0u8; 32].into(), + z_right, + } } + } - // conv to BitVec - let x_left = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_x_left)).unwrap(); - let x_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_x_right)).unwrap(); - let y_left = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_y_left)).unwrap(); - let y_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_y_right)).unwrap(); - let prss_left = - BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_prss_left)).unwrap(); - let prss_right = - BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_prss_right)).unwrap(); - let z_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_z_right)).unwrap(); - - // check consistency of the polynomials - assert_convert( - Fp61BitPrime::convert_prover(&x_left, &x_right, &y_left, &y_right, &prss_right), - // flip intputs right to left since it is checked against itself and not party on the left - // z_right is set to match z_left - Fp61BitPrime::convert_value_from_right_prover(&x_left, &y_left, &prss_left, &z_right), - // flip intputs right to left since it is checked against itself and not party on the left - Fp61BitPrime::convert_value_from_left_prover(&x_right, &y_right, &prss_right), - ); + #[test] + fn batch_convert() { + run_random(|mut rng| async move { + let block = rng.gen::(); + + // When verifying, we rotate the intermediates to match what each prover + // would have. `rotate_right` also calculates z_right from the others. 
+ assert_convert( + block.table_indices_prover(), + block.rotate_right().table_indices_from_right_prover(), + block.rotate_left().table_indices_from_left_prover(), + ); + }); } + fn assert_convert(prover: P, verifier_left: L, verifier_right: R) where - P: IntoIterator>, - L: IntoIterator, - R: IntoIterator, + P: IntoIterator, + L: IntoIterator, + R: IntoIterator, { prover .into_iter() - .zip( - verifier_left - .into_iter() - .collect::>(), - ) - .zip( - verifier_right - .into_iter() - .collect::>(), - ) + .zip(verifier_left.into_iter().collect::>()) + .zip(verifier_right.into_iter().collect::>()) .for_each(|((prover, verifier_left), verifier_right)| { assert_eq!(prover.0, verifier_left); assert_eq!(prover.1, verifier_right); @@ -534,37 +523,15 @@ mod tests { } #[allow(clippy::fn_params_excessive_bools)] - fn correctness_prover_values( + #[must_use] + pub fn reference_convert( x_left: bool, x_right: bool, y_left: bool, y_right: bool, prss_left: bool, prss_right: bool, - ) { - let mut array_x_left = BitArray::<[u8; 32]>::ZERO; - let mut array_x_right = BitArray::<[u8; 32]>::ZERO; - let mut array_y_left = BitArray::<[u8; 32]>::ZERO; - let mut array_y_right = BitArray::<[u8; 32]>::ZERO; - let mut array_prss_left = BitArray::<[u8; 32]>::ZERO; - let mut array_prss_right = BitArray::<[u8; 32]>::ZERO; - - // initialize bits - array_x_left.set(0, x_left); - array_x_right.set(0, x_right); - array_y_left.set(0, y_left); - array_y_right.set(0, y_right); - array_prss_left.set(0, prss_left); - array_prss_right.set(0, prss_right); - - let prover = Fp61BitPrime::convert_prover( - &array_x_left, - &array_x_right, - &array_y_left, - &array_y_right, - &array_prss_right, - )[0]; - + ) -> ([Fp61BitPrime; 4], [Fp61BitPrime; 4]) { // compute expected // (a,b,c,d,f) = (x_left, y_right, y_left, x_right, prss_right) // e = x_left · y_left ⊕ z_left ⊕ prss_left @@ -601,17 +568,59 @@ mod tests { // h4=1-2f, let h4 = one_minus_two_f; + ([g1, g2, g3, g4], [h1, h2, h3, h4]) + } + + #[allow(clippy::fn_params_excessive_bools)] + fn correctness_prover_values( + x_left: bool, + x_right: bool, + y_left: bool, + y_right: bool, + prss_left: bool, + prss_right: bool, + ) { + let mut array_x_left = BitArray::<[u8; 32]>::ZERO; + let mut array_x_right = BitArray::<[u8; 32]>::ZERO; + let mut array_y_left = BitArray::<[u8; 32]>::ZERO; + let mut array_y_right = BitArray::<[u8; 32]>::ZERO; + let mut array_prss_left = BitArray::<[u8; 32]>::ZERO; + let mut array_prss_right = BitArray::<[u8; 32]>::ZERO; + + // initialize bits + array_x_left.set(0, x_left); + array_x_right.set(0, x_right); + array_y_left.set(0, y_left); + array_y_right.set(0, y_right); + array_prss_left.set(0, prss_left); + array_prss_right.set(0, prss_right); + + let block = MultiplicationInputsBlock { + x_left: array_x_left, + x_right: array_x_right, + y_left: array_y_left, + y_right: array_y_right, + prss_left: array_prss_left, + prss_right: array_prss_right, + z_right: BitArray::ZERO, + }; + + let prover = block.table_indices_prover()[0]; + + let ([g1, g2, g3, g4], [h1, h2, h3, h4]) = + reference_convert(x_left, x_right, y_left, y_right, prss_left, prss_right); + // check expected == computed // g polynomial - assert_eq!(g1, prover.0[0]); - assert_eq!(g2, prover.0[1]); - assert_eq!(g3, prover.0[2]); - assert_eq!(g4, prover.0[3]); + assert_eq!(g1, TABLE_U[prover.0][0]); + assert_eq!(g2, TABLE_U[prover.0][1]); + assert_eq!(g3, TABLE_U[prover.0][2]); + assert_eq!(g4, TABLE_U[prover.0][3]); // h polynomial - assert_eq!(h1, prover.1[0]); - assert_eq!(h2, prover.1[1]); - 
assert_eq!(h3, prover.1[2]); - assert_eq!(h4, prover.1[3]); + assert_eq!(h1, TABLE_V[prover.1][0]); + assert_eq!(h2, TABLE_V[prover.1][1]); + assert_eq!(h3, TABLE_V[prover.1][2]); + assert_eq!(h4, TABLE_V[prover.1][3]); } } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index fd616fc9c..5d8d58b2e 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -12,7 +12,7 @@ use crate::{ protocol::{ context::{ batcher::Batcher, - dzkp_field::{DZKPBaseField, UVTupleBlock}, + dzkp_field::{DZKPBaseField, TABLE_U, TABLE_V}, dzkp_malicious::DZKPUpgraded as MaliciousDZKPUpgraded, dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, step::DzkpValidationProtocolStep as Step, @@ -20,7 +20,8 @@ use crate::{ }, ipa_prf::{ validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, - CompressedProofGenerator, FirstProofGenerator, + CompressedProofGenerator, FirstProofGenerator, ProverTableIndices, + VerifierTableIndices, }, Gate, RecordId, RecordIdRange, }, @@ -33,7 +34,7 @@ pub type Array256Bit = BitArray<[u8; 32], Lsb0>; type BitSliceType = BitSlice; -const BIT_ARRAY_LEN: usize = 256; +pub const BIT_ARRAY_LEN: usize = 256; const BIT_ARRAY_MASK: usize = BIT_ARRAY_LEN - 1; const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize; @@ -58,6 +59,12 @@ pub const TARGET_PROOF_SIZE: usize = 8192; #[cfg(not(test))] pub const TARGET_PROOF_SIZE: usize = 50_000_000; +/// Minimum proof recursion depth. +/// +/// This minimum avoids special cases in the implementation that would be otherwise +/// required when the initial and final recursion steps overlap. +pub const MIN_PROOF_RECURSION: usize = 2; + /// Maximum proof recursion depth. // // This is a hard limit. Each GF(2) multiply generates four G values and four H values, @@ -74,8 +81,6 @@ pub const TARGET_PROOF_SIZE: usize = 50_000_000; // Because the number of records in a proof batch is often rounded up to a power of two // (and less significantly, because multiplication intermediate storage gets rounded up // to blocks of 256), leaving some margin is advised. -// -// The implementation requires that MAX_PROOF_RECURSION is at least 2. pub const MAX_PROOF_RECURSION: usize = 14; /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values @@ -159,34 +164,21 @@ impl MultiplicationInputsBlock { Ok(()) } +} - /// `Convert` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the prover. - fn convert_prover(&self) -> Vec> { - DF::convert_prover( - &self.x_left, - &self.x_right, - &self.y_left, - &self.y_right, - &self.prss_right, - ) - } - - /// `convert_values_from_right_prover` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the verifier on the left. - fn convert_values_from_right_prover(&self) -> Vec { - DF::convert_value_from_right_prover( - &self.x_right, - &self.y_right, - &self.prss_right, - &self.z_right, - ) - } - - /// `convert_values_from_left_prover` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the verifier on the right. 
- fn convert_values_from_left_prover(&self) -> Vec { - DF::convert_value_from_left_prover(&self.x_left, &self.y_left, &self.prss_left) +#[cfg(any(test, feature = "enable-benches"))] +impl rand::prelude::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> MultiplicationInputsBlock { + let sample = >::sample; + MultiplicationInputsBlock { + x_left: sample(self, rng).into(), + x_right: sample(self, rng).into(), + y_left: sample(self, rng).into(), + y_right: sample(self, rng).into(), + prss_left: sample(self, rng).into(), + prss_right: sample(self, rng).into(), + z_right: sample(self, rng).into(), + } } } @@ -472,34 +464,30 @@ impl MultiplicationInputsBatch { } } - /// `get_field_values_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the prover of the DZKPs - fn get_field_values_prover( - &self, - ) -> impl Iterator> + Clone + '_ { + /// `get_field_values_prover` converts a `MultiplicationInputsBatch` into an + /// iterator over pairs of indices for `TABLE_U` and `TABLE_V`. + fn get_field_values_prover(&self) -> impl Iterator + Clone + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_prover) } - /// `get_field_values_from_right_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the verifier of the DZKPs on the left side of the prover, i.e. the `u` values. - fn get_field_values_from_right_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_right_prover` converts a `MultiplicationInputsBatch` into + /// an iterator over table indices for `TABLE_U`, which is used by the verifier of + /// the DZKPs on the left side of the prover. + fn get_field_values_from_right_prover(&self) -> impl Iterator + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_values_from_right_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_from_right_prover) } - /// `get_field_values_from_left_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the verifier of the DZKPs on the right side of the prover, i.e. the `v` values. - fn get_field_values_from_left_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_left_prover` converts a `MultiplicationInputsBatch` into + /// an iterator over table indices for `TABLE_V`, which is used by the verifier of + /// the DZKPs on the right side of the prover. + fn get_field_values_from_left_prover(&self) -> impl Iterator + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_values_from_left_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_from_left_prover) } } @@ -565,36 +553,30 @@ impl Batch { .sum() } - /// `get_field_values_prover` converts a `Batch` into an iterator over field values - /// which is used by the prover of the DZKP - fn get_field_values_prover( - &self, - ) -> impl Iterator> + Clone + '_ { + /// `get_field_values_prover` converts a `Batch` into an iterator over pairs of + /// indices for `TABLE_U` and `TABLE_V`. + fn get_field_values_prover(&self) -> impl Iterator + Clone + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_prover) } - /// `get_field_values_from_right_prover` converts a `Batch` into an iterator over field values - /// which is used by the verifier of the DZKP on the left side of the prover. 
- /// This produces the `u` values. - fn get_field_values_from_right_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_right_prover` converts a `Batch` into an iterator over + /// table indices for `TABLE_U`, which is used by the verifier of the DZKP on the + /// left side of the prover. + fn get_field_values_from_right_prover(&self) -> impl Iterator + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_from_right_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_from_right_prover) } - /// `get_field_values_from_left_prover` converts a `Batch` into an iterator over field values - /// which is used by the verifier of the DZKP on the right side of the prover. - /// This produces the `v` values. - fn get_field_values_from_left_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_left_prover` converts a `Batch` into an iterator over + /// table indices for `TABLE_V`, which is used by the verifier of the DZKP on the + /// right side of the prover. + fn get_field_values_from_left_prover(&self) -> impl Iterator + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_from_left_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_from_left_prover) } /// ## Panics @@ -626,7 +608,11 @@ impl Batch { q_mask_from_left_prover, ) = { // generate BatchToVerify - ProofBatch::generate(&proof_ctx, prss_record_ids, self.get_field_values_prover()) + ProofBatch::generate( + &proof_ctx, + prss_record_ids, + ProverTableIndices(self.get_field_values_prover()), + ) }; let chunk_batch = BatchToVerify::generate_batch_to_verify( @@ -650,12 +636,7 @@ impl Batch { tracing::info!("validating {m} multiplications"); debug_assert_eq!( m, - self.get_field_values_prover::() - .flat_map(|(u_array, v_array)| { - u_array.into_iter().zip(v_array).map(|(u, v)| u * v) - }) - .count() - / 4, + self.get_field_values_prover().count(), "Number of multiplications is counted incorrectly" ); let sum_of_uv = Fp61BitPrime::truncate_from(u128::try_from(m).unwrap()) @@ -664,8 +645,14 @@ impl Batch { let (p_r_right_prover, q_r_left_prover) = chunk_batch.compute_p_and_q_r( &challenges_for_left_prover, &challenges_for_right_prover, - self.get_field_values_from_right_prover(), - self.get_field_values_from_left_prover(), + VerifierTableIndices { + input: self.get_field_values_from_right_prover(), + table: &TABLE_U, + }, + VerifierTableIndices { + input: self.get_field_values_from_left_prover(), + table: &TABLE_V, + }, ); (sum_of_uv, p_r_right_prover, q_r_left_prover) @@ -965,12 +952,11 @@ mod tests { ff::{ boolean::Boolean, boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8}, - Fp61BitPrime, }, protocol::{ basics::{select, BooleanArrayMul, SecureMul}, context::{ - dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, + dzkp_field::DZKPCompatibleField, dzkp_validator::{ Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, }, @@ -1814,16 +1800,16 @@ mod tests { fn assert_batch_convert(batch_prover: &Batch, batch_left: &Batch, batch_right: &Batch) { batch_prover - .get_field_values_prover::() + .get_field_values_prover() .zip( batch_left - .get_field_values_from_right_prover::() - .collect::>(), + .get_field_values_from_right_prover() + .collect::>(), ) .zip( batch_right - .get_field_values_from_left_prover::() - .collect::>(), + .get_field_values_from_left_prover() + .collect::>(), ) .for_each(|((prover, verifier_left), verifier_right)| { assert_eq!(prover.0, verifier_left); diff 
--git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 41abab508..1f598cb31 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -86,7 +86,7 @@ where /// The "x coordinate" of the output point is `x_output`. pub fn new(denominator: &CanonicalLagrangeDenominator, x_output: &F) -> Self { // assertion that table is not too large for the stack - assert!(::Size::USIZE * N < 2024); + debug_assert!(::Size::USIZE * N < 2024); let table = Self::compute_table_row(x_output, denominator); LagrangeTable:: { table: [table; 1] } diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs index 607827b1a..9d366735b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs @@ -1,3 +1,11 @@ pub mod lagrange; pub mod prover; pub mod verifier; + +pub type FirstProofGenerator = prover::SmallProofGenerator; +pub type CompressedProofGenerator = prover::SmallProofGenerator; +pub use lagrange::{CanonicalLagrangeDenominator, LagrangeTable}; +pub use prover::ProverTableIndices; +pub use verifier::VerifierTableIndices; + +pub const FIRST_RECURSION_FACTOR: usize = FirstProofGenerator::RECURSION_FACTOR; diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index 26ce5c5d0..aa1b25e9b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -1,18 +1,25 @@ -use std::{borrow::Borrow, iter::zip, marker::PhantomData}; +use std::{array, borrow::Borrow, marker::PhantomData}; use crate::{ error::Error::{self, DZKPMasks}, - ff::{Fp61BitPrime, PrimeField}, + ff::{Fp61BitPrime, MultiplyAccumulate, MultiplyAccumulatorArray, PrimeField}, helpers::hashing::{compute_hash, hash_to_field}, protocol::{ - context::Context, + context::{ + dzkp_field::{TABLE_U, TABLE_V}, + Context, + }, ipa_prf::{ - malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + malicious_security::{ + lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + FIRST_RECURSION_FACTOR as FRF, + }, CompressedProofGenerator, }, prss::SharedRandomness, RecordId, RecordIdRange, }, + secret_sharing::SharedValue, }; /// This struct stores intermediate `uv` values. @@ -95,6 +102,213 @@ where } } +/// Trait for inputs to Lagrange interpolation on the prover. +/// +/// Lagrange interpolation is used in the prover in two ways: +/// +/// 1. To extrapolate an additional λ - 1 y-values of degree-(λ-1) polynomials so that the +/// total 2λ - 1 y-values can be multiplied to obtain a representation of the product of +/// the polynomials. +/// 2. To evaluate polynomials at the randomly chosen challenge point _r_. +/// +/// The two methods in this trait correspond to those two uses. +/// +/// There are two implementations of this trait: `ProverTableIndices`, and +/// `ProverValues`. `ProverTableIndices` is used for the input to the first proof. Each +/// set of 4 _u_ or _v_ values input to the first proof has one of eight possible +/// values, determined by the values of the 3 associated multiplication intermediates. +/// The `ProverTableIndices` implementation uses a lookup table containing the output of +/// the Lagrange interpolation for each of these eight possible values. 
The +/// `ProverValues` implementation, which represents actual _u_ and _v_ values, is used +/// by the remaining recursive proofs. +/// +/// There is a similar trait `VerifierLagrangeInput` in `verifier.rs`. The difference is +/// that the prover operates on _u_ and _v_ values simultaneously (i.e. iterators of +/// tuples). The verifier operates on only one of _u_ or _v_ at a time. +pub trait ProverLagrangeInput { + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; +} + +/// Implementation of `ProverLagrangeInput` for table indices, used for the first proof. +#[derive(Clone)] +pub struct ProverTableIndices>(pub I); + +/// Iterator returned by `ProverTableIndices::extrapolate_y_values` and +/// `ProverTableIndices::eval_at_r`. +struct TableIndicesIterator> { + input: I, + u_table: [T; 8], + v_table: [T; 8], +} + +impl> ProverLagrangeInput + for ProverTableIndices +{ + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.0, + u_table: array::from_fn(|i| { + let mut result = [Fp61BitPrime::ZERO; P]; + let u = &TABLE_U[i]; + result[0..FRF].copy_from_slice(u); + result[FRF..].copy_from_slice(&lagrange_table.eval(u)); + result + }), + v_table: array::from_fn(|i| { + let mut result = [Fp61BitPrime::ZERO; P]; + let v = &TABLE_V[i]; + result[0..FRF].copy_from_slice(v); + result[FRF..].copy_from_slice(&lagrange_table.eval(v)); + result + }), + } + } + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.0, + u_table: array::from_fn(|i| lagrange_table.eval(&TABLE_U[i])[0]), + v_table: array::from_fn(|i| lagrange_table.eval(&TABLE_V[i])[0]), + } + } +} + +impl> Iterator for TableIndicesIterator { + type Item = (T, T); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_index, v_index)| { + ( + self.u_table[usize::from(u_index)].clone(), + self.v_table[usize::from(v_index)].clone(), + ) + }) + } +} + +/// Implementation of `ProverLagrangeInput` for _u_ and _v_ values, used for subsequent +/// recursive proofs. +#[derive(Clone)] +pub struct ProverValues>(pub I); + +/// Iterator returned by `ProverValues::extrapolate_y_values`. +struct ValuesExtrapolateIterator< + 'a, + F: PrimeField, + const L: usize, + const P: usize, + const M: usize, + I: Iterator, +> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl< + 'a, + F: PrimeField, + const L: usize, + const P: usize, + const M: usize, + I: Iterator, + > Iterator for ValuesExtrapolateIterator<'a, F, L, P, M, I> +{ + type Item = ([F; P], [F; P]); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_values, v_values)| { + let mut u = [F::ZERO; P]; + u[0..L].copy_from_slice(&u_values); + u[L..].copy_from_slice(&self.lagrange_table.eval(&u_values)); + let mut v = [F::ZERO; P]; + v[0..L].copy_from_slice(&v_values); + v[L..].copy_from_slice(&self.lagrange_table.eval(&v_values)); + (u, v) + }) + } +} + +/// Iterator returned by `ProverValues::eval_at_r`. 
+struct ValuesEvalAtRIterator< + 'a, + F: PrimeField, + const L: usize, + I: Iterator, +> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl<'a, F: PrimeField, const L: usize, I: Iterator> Iterator + for ValuesEvalAtRIterator<'a, F, L, I> +{ + type Item = (F, F); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_values, v_values)| { + ( + self.lagrange_table.eval(&u_values)[0], + self.lagrange_table.eval(&v_values)[0], + ) + }) + } +} + +impl> ProverLagrangeInput + for ProverValues +{ + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + debug_assert_eq!(L + M, P); + ValuesExtrapolateIterator { + input: self.0, + lagrange_table, + } + } + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + ValuesEvalAtRIterator { + input: self.0, + lagrange_table, + } + } +} + /// This struct sets up the parameter for the proof generation /// and provides several functions to generate zero knowledge proofs. /// @@ -119,40 +333,47 @@ impl ProofGenerat pub const PROOF_LENGTH: usize = P; pub const LAGRANGE_LENGTH: usize = M; + pub fn compute_proof_from_uv(uv: J, lagrange_table: &LagrangeTable) -> [F; P] + where + J: Iterator, + J::Item: Borrow<([F; L], [F; L])>, + { + Self::compute_proof(uv.map(|uv| { + let (u, v) = uv.borrow(); + let mut u_ex = [F::ZERO; P]; + let mut v_ex = [F::ZERO; P]; + u_ex[0..L].copy_from_slice(u); + v_ex[0..L].copy_from_slice(v); + u_ex[L..].copy_from_slice(&lagrange_table.eval(u)); + v_ex[L..].copy_from_slice(&lagrange_table.eval(v)); + (u_ex, v_ex) + })) + } + /// /// Distributed Zero Knowledge Proofs algorithm drawn from /// `https://eprint.iacr.org/2023/909.pdf` - pub fn compute_proof(uv_iterator: J, lagrange_table: &LagrangeTable) -> [F; P] + pub fn compute_proof(pq_iterator: J) -> [F; P] where J: Iterator, - J::Item: Borrow<([F; L], [F; L])>, + J::Item: Borrow<([F; P], [F; P])>, { - let mut proof = [F::ZERO; P]; - for uv_polynomial in uv_iterator { - for (i, proof_part) in proof.iter_mut().enumerate().take(L) { - *proof_part += uv_polynomial.borrow().0[i] * uv_polynomial.borrow().1[i]; - } - let p_extrapolated = lagrange_table.eval(&uv_polynomial.borrow().0); - let q_extrapolated = lagrange_table.eval(&uv_polynomial.borrow().1); - - for (i, (x, y)) in - zip(p_extrapolated.into_iter(), q_extrapolated.into_iter()).enumerate() - { - proof[L + i] += x * y; - } - } - proof + pq_iterator + .fold( + ::AccumulatorArray::
<P>
::new(), + |mut proof, pq| { + proof.multiply_accumulate(&pq.borrow().0, &pq.borrow().1); + proof + }, + ) + .take() } - fn gen_challenge_and_recurse( + fn gen_challenge_and_recurse, const N: usize>( proof_left: &[F; P], proof_right: &[F; P], - uv_iterator: J, - ) -> UVValues - where - J: Iterator, - J::Item: Borrow<([F; L], [F; L])>, - { + uv_iterator: I, + ) -> UVValues { let r: F = hash_to_field( &compute_hash(proof_left), &compute_hash(proof_right), @@ -162,17 +383,8 @@ impl ProofGenerat let denominator = CanonicalLagrangeDenominator::::new(); let lagrange_table_r = LagrangeTable::::new(&denominator, &r); - // iter and interpolate at x coordinate r uv_iterator - .map(|polynomial| { - let (u_chunk, v_chunk) = polynomial.borrow(); - ( - // new u value - lagrange_table_r.eval(u_chunk)[0], - // new v value - lagrange_table_r.eval(v_chunk)[0], - ) - }) + .eval_at_r(&lagrange_table_r) .collect::>() } @@ -202,29 +414,23 @@ impl ProofGenerat proof_other_share } - /// This function is a helper function that computes the next proof - /// from an iterator over uv values - /// It also computes the next uv values + /// This function is a helper function that, given the computed proof, computes the shares of + /// the proof, the challenge, and the next uv values. /// /// It output `(uv values, share_of_proof_from_prover_left, my_proof_left_share)` /// where /// `share_of_proof_from_prover_left` from left has type `Vec<[F; P]>`, /// `my_proof_left_share` has type `Vec<[F; P]>`, - pub fn gen_artefacts_from_recursive_step( + pub fn gen_artefacts_from_recursive_step( ctx: &C, record_ids: &mut RecordIdRange, - lagrange_table: &LagrangeTable, - uv_iterator: J, + my_proof: [F; P], + uv_iterator: I, ) -> (UVValues, [F; P], [F; P]) where C: Context, - J: Iterator + Clone, - J::Item: Borrow<([F; L], [F; L])>, + I: ProverLagrangeInput, { - // generate next proof - // from iterator - let my_proof = Self::compute_proof(uv_iterator.clone(), lagrange_table); - // generate proof shares from prss let (share_of_proof_from_prover_left, my_proof_right_share) = Self::gen_proof_shares_from_prss(ctx, record_ids); @@ -254,20 +460,26 @@ mod test { use std::iter::zip; use futures::future::try_join; + use rand::Rng; + use super::*; use crate::{ ff::{Fp31, Fp61BitPrime, PrimeField, U128Conversions}, helpers::{Direction, Role}, protocol::{ - context::Context, - ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - prover::{ProofGenerator, SmallProofGenerator, UVValues}, + context::{ + dzkp_field::tests::reference_convert, + dzkp_validator::{MultiplicationInputsBlock, BIT_ARRAY_LEN}, + Context, + }, + ipa_prf::{ + malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + FirstProofGenerator, }, RecordId, RecordIdRange, }, seq_join::SeqJoin, - test_executor::run, + test_executor::{run, run_random}, test_fixture::{Runner, TestWorld}, }; @@ -316,7 +528,7 @@ mod test { let uv_1 = zip_chunks(U_1, V_1); // first iteration - let proof_1 = TestProofGenerator::compute_proof(uv_1.iter(), &lagrange_table); + let proof_1 = TestProofGenerator::compute_proof_from_uv(uv_1.iter(), &lagrange_table); assert_eq!( proof_1.iter().map(Fp31::as_u128).collect::>(), PROOF_1, @@ -335,12 +547,12 @@ mod test { let uv_2 = TestProofGenerator::gen_challenge_and_recurse( &proof_left_1, &proof_right_1, - uv_1.iter(), + ProverValues(uv_1.iter().copied()), ); assert_eq!(uv_2, zip_chunks(U_2, V_2)); // next iteration - let proof_2 = TestProofGenerator::compute_proof(uv_2.iter(), &lagrange_table); + let 
proof_2 = TestProofGenerator::compute_proof_from_uv(uv_2.iter(), &lagrange_table); assert_eq!( proof_2.iter().map(Fp31::as_u128).collect::>(), PROOF_2, @@ -359,7 +571,7 @@ mod test { let uv_3 = TestProofGenerator::gen_challenge_and_recurse::<_, 4>( &proof_left_2, &proof_right_2, - uv_2.iter(), + ProverValues(uv_2.iter().copied()), ); assert_eq!(uv_3, zip_chunks(U_3, V_3)); @@ -369,7 +581,8 @@ mod test { ); // final iteration - let proof_3 = TestProofGenerator::compute_proof(masked_uv_3.iter(), &lagrange_table); + let proof_3 = + TestProofGenerator::compute_proof_from_uv(masked_uv_3.iter(), &lagrange_table); assert_eq!( proof_3.iter().map(Fp31::as_u128).collect::>(), PROOF_3, @@ -397,11 +610,12 @@ mod test { // first iteration let world = TestWorld::default(); let mut record_ids = RecordIdRange::ALL; + let proof = TestProofGenerator::compute_proof_from_uv(uv_1.iter(), &lagrange_table); let (uv_values, _, _) = TestProofGenerator::gen_artefacts_from_recursive_step::<_, _, 4>( &world.contexts()[0], &mut record_ids, - &lagrange_table, - uv_1.iter(), + proof, + ProverValues(uv_1.iter().copied()), ); assert_eq!(7, uv_values.len()); @@ -436,14 +650,14 @@ mod test { >::from(denominator); // compute proof - let proof = SmallProofGenerator::compute_proof(uv_before.iter(), &lagrange_table); + let proof = SmallProofGenerator::compute_proof_from_uv(uv_before.iter(), &lagrange_table); assert_eq!(proof.len(), SmallProofGenerator::PROOF_LENGTH); let uv_after = SmallProofGenerator::gen_challenge_and_recurse::<_, 8>( &proof, &proof, - uv_before.iter(), + ProverValues(uv_before.iter().copied()), ); assert_eq!( @@ -472,14 +686,14 @@ mod test { >::from(denominator); // compute proof - let proof = LargeProofGenerator::compute_proof(uv_before.iter(), &lagrange_table); + let proof = LargeProofGenerator::compute_proof_from_uv(uv_before.iter(), &lagrange_table); assert_eq!(proof.len(), LargeProofGenerator::PROOF_LENGTH); let uv_after = LargeProofGenerator::gen_challenge_and_recurse::<_, 8>( &proof, &proof, - uv_before.iter(), + ProverValues(uv_before.iter().copied()), ); assert_eq!( @@ -594,4 +808,51 @@ mod test { assert_two_part_secret_sharing(PROOF_2, h1_proof_right, h3_proof_left); assert_two_part_secret_sharing(PROOF_3, h2_proof_right, h1_proof_left); } + + #[test] + fn prover_table_indices_equivalence() { + run_random(|mut rng| async move { + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; + const FLL: usize = FirstProofGenerator::LAGRANGE_LENGTH; + + let block = rng.gen::(); + + // Test equivalence for extrapolate_y_values + let denominator = CanonicalLagrangeDenominator::new(); + let lagrange_table = LagrangeTable::from(denominator); + + assert!(ProverTableIndices(block.table_indices_prover().into_iter()) + .extrapolate_y_values::(&lagrange_table) + .eq(ProverValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + })) + .extrapolate_y_values::(&lagrange_table))); + + // Test equivalence for eval_at_r + let denominator = CanonicalLagrangeDenominator::new(); + let r = rng.gen(); + let lagrange_table_r = LagrangeTable::new(&denominator, &r); + + assert!(ProverTableIndices(block.table_indices_prover().into_iter()) + .eval_at_r(&lagrange_table_r) + .eq(ProverValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + })) + .eval_at_r(&lagrange_table_r))); + 
}); + } } diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs index e62465b73..0038d4b85 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs @@ -1,9 +1,18 @@ -use std::iter::{self}; +use std::{ + array, + iter::{self}, +}; use crate::{ ff::PrimeField, - protocol::ipa_prf::malicious_security::lagrange::{ - CanonicalLagrangeDenominator, LagrangeTable, + protocol::{ + context::{ + dzkp_field::UVTable, + dzkp_validator::{MAX_PROOF_RECURSION, MIN_PROOF_RECURSION}, + }, + ipa_prf::malicious_security::{ + CanonicalLagrangeDenominator, LagrangeTable, FIRST_RECURSION_FACTOR as FRF, + }, }, utils::arraychunks::ArrayChunkIterator, }; @@ -96,35 +105,30 @@ pub fn compute_final_sum_share(zk } /// This function compresses the `u_or_v` values and returns the next `u_or_v` values. -fn recurse_u_or_v<'a, F: PrimeField, J, const L: usize>( - u_or_v: J, +fn recurse_u_or_v<'a, F: PrimeField, const L: usize>( + u_or_v: impl Iterator + 'a, lagrange_table: &'a LagrangeTable, -) -> impl Iterator + 'a -where - J: Iterator + 'a, -{ - u_or_v - .chunk_array::() - .map(|x| lagrange_table.eval(&x)[0]) +) -> impl Iterator + 'a { + VerifierValues(u_or_v.chunk_array::()).eval_at_r(lagrange_table) } /// This function recursively compresses the `u_or_v` values. -/// The recursion factor (or compression) of the first recursion is `L_FIRST` -/// The recursion factor of all following recursions is `L`. -pub fn recursively_compute_final_check( - u_or_v: J, +/// +/// The recursion factor (or compression) of the first recursion is fixed at +/// `FIRST_RECURSION_FACTOR` (`FRF`). The recursion factor of all following recursions +/// is `L`. +pub fn recursively_compute_final_check( + input: impl VerifierLagrangeInput, challenges: &[F], p_or_q_0: F, -) -> F -where - J: Iterator, -{ +) -> F { + // This function requires MIN_PROOF_RECURSION be at least 2. + assert!(challenges.len() >= MIN_PROOF_RECURSION && challenges.len() <= MAX_PROOF_RECURSION); let recursions_after_first = challenges.len() - 1; // compute Lagrange tables - let denominator_p_or_q_first = CanonicalLagrangeDenominator::::new(); - let table_first = - LagrangeTable::::new(&denominator_p_or_q_first, &challenges[0]); + let denominator_p_or_q_first = CanonicalLagrangeDenominator::::new(); + let table_first = LagrangeTable::::new(&denominator_p_or_q_first, &challenges[0]); let denominator_p_or_q = CanonicalLagrangeDenominator::::new(); let tables = challenges[1..] .iter() @@ -133,11 +137,10 @@ where // generate & evaluate recursive streams // to compute last array - let mut iterator: Box> = - Box::new(recurse_u_or_v::<_, _, L_FIRST>(u_or_v, &table_first)); + let mut iterator: Box> = Box::new(input.eval_at_r(&table_first)); // all following recursion except last one for lagrange_table in tables.iter().take(recursions_after_first - 1) { - iterator = Box::new(recurse_u_or_v::<_, _, L>(iterator, lagrange_table)); + iterator = Box::new(recurse_u_or_v(iterator, lagrange_table)); } let last_u_or_v_values = iterator.collect::>(); // Make sure there are less than L last u or v values. The prover is expected to continue @@ -162,18 +165,127 @@ where tables.last().unwrap().eval(&last_array)[0] } +/// Trait for inputs to Lagrange interpolation on the verifier. +/// +/// Lagrange interpolation is used in the verifier to evaluate polynomials at the +/// randomly chosen challenge point _r_. 
+/// +/// There are two implementations of this trait: `VerifierTableIndices`, and +/// `VerifierValues`. `VerifierTableIndices` is used for the input to the first proof. +/// Each set of 4 _u_ or _v_ values input to the first proof has one of eight possible +/// values, determined by the values of the 3 associated multiplication intermediates. +/// The `VerifierTableIndices` implementation uses a lookup table containing the output +/// of the Lagrange interpolation for each of these eight possible values. The +/// `VerifierValues` implementation, which represents actual _u_ and _v_ values, is used +/// by the remaining recursive proofs. +/// +/// There is a similar trait `ProverLagrangeInput` in `prover.rs`. The difference is +/// that the prover operates on _u_ and _v_ values simultaneously (i.e. iterators of +/// tuples). The verifier operates on only one of _u_ or _v_ at a time. +pub trait VerifierLagrangeInput { + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; +} + +/// Implementation of `VerifierLagrangeInput` for table indices, used for the first proof. +pub struct VerifierTableIndices<'a, F: PrimeField, I: Iterator> { + pub input: I, + pub table: &'a UVTable, +} + +/// Iterator returned by `VerifierTableIndices::eval_at_r`. +struct TableIndicesIterator> { + input: I, + table: [F; 8], +} + +impl> VerifierLagrangeInput + for VerifierTableIndices<'_, F, I> +{ + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.input, + table: array::from_fn(|i| lagrange_table.eval(&self.table[i])[0]), + } + } +} + +impl> Iterator for TableIndicesIterator { + type Item = F; + + fn next(&mut self) -> Option { + self.input + .next() + .map(|index| self.table[usize::from(index)]) + } +} + +/// Implementation of `VerifierLagrangeInput` for _u_ and _v_ values, used for +/// subsequent recursive proofs. +pub struct VerifierValues>(pub I); + +/// Iterator returned by `ProverValues::eval_at_r`. 
+struct ValuesEvalAtRIterator<'a, F: PrimeField, const L: usize, I: Iterator> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl<'a, F: PrimeField, const L: usize, I: Iterator> Iterator + for ValuesEvalAtRIterator<'a, F, L, I> +{ + type Item = F; + + fn next(&mut self) -> Option { + self.input + .next() + .map(|values| self.lagrange_table.eval(&values)[0]) + } +} + +impl> VerifierLagrangeInput + for VerifierValues +{ + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + ValuesEvalAtRIterator { + input: self.0, + lagrange_table, + } + } +} + #[cfg(all(test, unit_test))] mod test { + use rand::Rng; + + use super::*; use crate::{ ff::{Fp31, U128Conversions}, - protocol::ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - verifier::{ - compute_g_differences, compute_sum_share, interpolate_at_r, recurse_u_or_v, - recursively_compute_final_check, + protocol::{ + context::{ + dzkp_field::{tests::reference_convert, TABLE_U, TABLE_V}, + dzkp_validator::{MultiplicationInputsBlock, BIT_ARRAY_LEN}, }, + ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, }, secret_sharing::SharedValue, + test_executor::run_random, + utils::arraychunks::ArrayChunkIterator, }; fn to_field(a: &[u128]) -> Vec { @@ -292,12 +404,11 @@ mod test { let tables: [LagrangeTable; 3] = CHALLENGES .map(|r| LagrangeTable::new(&denominator_p_or_q, &Fp31::try_from(r).unwrap())); - let u_or_v_2 = recurse_u_or_v::<_, _, 4>(to_field(&U_1).into_iter(), &tables[0]) - .collect::>(); + let u_or_v_2 = + recurse_u_or_v::<_, 4>(to_field(&U_1).into_iter(), &tables[0]).collect::>(); assert_eq!(u_or_v_2, to_field(&U_2)); - let u_or_v_3 = - recurse_u_or_v::<_, _, 4>(u_or_v_2.into_iter(), &tables[1]).collect::>(); + let u_or_v_3 = recurse_u_or_v::<_, 4>(u_or_v_2.into_iter(), &tables[1]).collect::>(); assert_eq!(u_or_v_3, to_field(&U_3[..2])); @@ -309,12 +420,12 @@ mod test { ]; let p_final = - recurse_u_or_v::<_, _, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); + recurse_u_or_v::<_, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); assert_eq!(p_final[0].as_u128(), EXPECTED_P_FINAL); - let p_final_another_way = recursively_compute_final_check::( - to_field(&U_1).into_iter(), + let p_final_another_way = recursively_compute_final_check::<_, 4>( + VerifierValues(to_field(&U_1).into_iter().chunk_array::<4>()), &CHALLENGES .map(|x| Fp31::try_from(x).unwrap()) .into_iter() @@ -396,15 +507,15 @@ mod test { // final iteration let p_final = - recurse_u_or_v::<_, _, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); + recurse_u_or_v::<_, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); assert_eq!(p_final[0].as_u128(), EXPECTED_Q_FINAL); // uv values in input format let v_1 = to_field(&V_1); - let q_final_another_way = recursively_compute_final_check::( - v_1.into_iter(), + let q_final_another_way = recursively_compute_final_check::( + VerifierValues(v_1.into_iter().chunk_array::<4>()), &CHALLENGES .map(|x| Fp31::try_from(x).unwrap()) .into_iter() @@ -447,4 +558,59 @@ mod test { assert_eq!(Fp31::ZERO, g_differences[0]); } + + #[test] + fn verifier_table_indices_equivalence() { + run_random(|mut rng| async move { + let block = rng.gen::(); + + let denominator = CanonicalLagrangeDenominator::new(); + let r = rng.gen(); + let lagrange_table_r = LagrangeTable::new(&denominator, &r); + + // Test equivalence for _u_ values + assert!(VerifierTableIndices { + input: block + .rotate_right() 
+ .table_indices_from_right_prover() + .into_iter(), + table: &TABLE_U, + } + .eval_at_r(&lagrange_table_r) + .eq(VerifierValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + .0 + })) + .eval_at_r(&lagrange_table_r))); + + // Test equivalence for _v_ values + assert!(VerifierTableIndices { + input: block + .rotate_left() + .table_indices_from_left_prover() + .into_iter(), + table: &TABLE_V, + } + .eval_at_r(&lagrange_table_r) + .eq(VerifierValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + .1 + })) + .eval_at_r(&lagrange_table_r))); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 0ffe2adbc..63dc34a6f 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -59,9 +59,10 @@ pub(crate) mod shuffle; pub(crate) mod step; pub mod validation_protocol; -pub type FirstProofGenerator = malicious_security::prover::SmallProofGenerator; -pub type CompressedProofGenerator = malicious_security::prover::SmallProofGenerator; - +pub use malicious_security::{ + CompressedProofGenerator, FirstProofGenerator, LagrangeTable, ProverTableIndices, + VerifierTableIndices, +}; pub use shuffle::Shuffle; /// Match key type diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs index 0a658614e..e496f97e4 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs @@ -7,13 +7,13 @@ use crate::{ ff::{Fp61BitPrime, Serializable}, helpers::{Direction, MpcMessage, TotalRecords}, protocol::{ - context::{ - dzkp_field::{UVTupleBlock, BLOCK_SIZE}, - dzkp_validator::MAX_PROOF_RECURSION, - Context, - }, + context::{dzkp_validator::MAX_PROOF_RECURSION, Context}, ipa_prf::{ - malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + malicious_security::{ + lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + prover::{ProverLagrangeInput, ProverValues}, + FIRST_RECURSION_FACTOR as FRF, + }, CompressedProofGenerator, FirstProofGenerator, }, prss::SharedRandomness, @@ -81,16 +81,14 @@ impl ProofBatch { /// ## Panics /// Panics when the function fails to set the masks without overwritting `u` and `v` values. /// This only happens when there is an issue in the recursion. 
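+    ///
+    /// A usage sketch (not compiled as a doctest; it mirrors the unit tests below, and
+    /// `uv_chunks` stands in for any `Vec<([Fp61BitPrime; FRF], [Fp61BitPrime; FRF])>`):
+    ///
+    /// ```ignore
+    /// let (my_batch_left_shares, shares_of_batch_from_left_prover, p_mask, q_mask) =
+    ///     ProofBatch::generate(
+    ///         &ctx.narrow("generate_batch"),
+    ///         RecordIdRange::ALL,
+    ///         ProverValues(uv_chunks.iter().copied()),
+    ///     );
+    /// ```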
-    pub fn generate<C, I>(
+    pub fn generate<C>(
         ctx: &C,
         mut prss_record_ids: RecordIdRange,
-        uv_tuple_inputs: I,
+        uv_inputs: impl ProverLagrangeInput<Fp61BitPrime, FRF> + Clone,
     ) -> (Self, Self, Fp61BitPrime, Fp61BitPrime)
     where
         C: Context,
-        I: Iterator<Item = UVTupleBlock<Fp61BitPrime>> + Clone,
     {
-        const FRF: usize = FirstProofGenerator::RECURSION_FACTOR;
         const FLL: usize = FirstProofGenerator::LAGRANGE_LENGTH;
         const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR;
         const CLL: usize = CompressedProofGenerator::LAGRANGE_LENGTH;
@@ -100,13 +98,19 @@ impl ProofBatch {
         let first_denominator = CanonicalLagrangeDenominator::<Fp61BitPrime, FRF>::new();
         let first_lagrange_table = LagrangeTable::<Fp61BitPrime, FRF, FLL>::from(first_denominator);
 
+        let first_proof = FirstProofGenerator::compute_proof(
+            uv_inputs
+                .clone()
+                .extrapolate_y_values(&first_lagrange_table),
+        );
+
         // generate first proof from input iterator
         let (mut uv_values, first_proof_from_left, my_first_proof_left_share) =
             FirstProofGenerator::gen_artefacts_from_recursive_step(
                 ctx,
                 &mut prss_record_ids,
-                &first_lagrange_table,
-                ProofBatch::polynomials_from_inputs(uv_tuple_inputs),
+                first_proof,
+                uv_inputs,
             );
 
         // `MAX_PROOF_RECURSION - 2` because:
@@ -156,12 +160,14 @@ impl ProofBatch {
                 did_set_masks = true;
                 uv_values.set_masks(my_p_mask, my_q_mask).unwrap();
             }
+            let my_proof =
+                CompressedProofGenerator::compute_proof_from_uv(uv_values.iter(), &lagrange_table);
             let (uv_values_new, share_of_proof_from_prover_left, my_proof_left_share) =
                 CompressedProofGenerator::gen_artefacts_from_recursive_step(
                     ctx,
                     &mut prss_record_ids,
-                    &lagrange_table,
-                    uv_values.iter(),
+                    my_proof,
+                    ProverValues(uv_values.iter().copied()),
                 );
             shares_of_proofs_from_prover_left.push(share_of_proof_from_prover_left);
             my_proofs_left_shares.push(my_proof_left_share);
@@ -225,42 +231,6 @@ impl ProofBatch {
             .take(length)
             .collect())
     }
-
-    /// This is a helper function that allows to split a `UVTupleInputs`
-    /// which consists of arrays of size `BLOCK_SIZE`
-    /// into an iterator over arrays of size `LargeProofGenerator::RECURSION_FACTOR`.
-    ///
-    /// ## Panics
-    /// Panics when `unwrap` panics, i.e. `try_from` fails to convert a slice to an array.
-    pub fn polynomials_from_inputs<I>(
-        inputs: I,
-    ) -> impl Iterator<
-        Item = (
-            [Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR],
-            [Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR],
-        ),
-    > + Clone
-    where
-        I: Iterator<Item = UVTupleBlock<Fp61BitPrime>> + Clone,
-    {
-        assert_eq!(BLOCK_SIZE % FirstProofGenerator::RECURSION_FACTOR, 0);
-        inputs.flat_map(|(u_block, v_block)| {
-            (0usize..(BLOCK_SIZE / FirstProofGenerator::RECURSION_FACTOR)).map(move |i| {
-                (
-                    <[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>::try_from(
-                        &u_block[i * FirstProofGenerator::RECURSION_FACTOR
-                            ..(i + 1) * FirstProofGenerator::RECURSION_FACTOR],
-                    )
-                    .unwrap(),
-                    <[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>::try_from(
-                        &v_block[i * FirstProofGenerator::RECURSION_FACTOR
-                            ..(i + 1) * FirstProofGenerator::RECURSION_FACTOR],
-                    )
-                    .unwrap(),
-                )
-            })
-        })
-    }
 }
 
 const ARRAY_LEN: usize = FirstProofGenerator::PROOF_LENGTH
@@ -297,19 +267,25 @@ impl MpcMessage for Box {}
 
 #[cfg(all(test, unit_test))]
 mod test {
-    use rand::{thread_rng, Rng};
+    use std::iter::repeat_with;
+
+    use rand::Rng;
 
     use crate::{
-        ff::{Fp61BitPrime, U128Conversions},
         protocol::{
-            context::{dzkp_field::BLOCK_SIZE, Context},
-            ipa_prf::validation_protocol::{
-                proof_generation::ProofBatch,
-                validation::{test::simple_proof_check, BatchToVerify},
+            context::Context,
+            ipa_prf::{
+                malicious_security::{
+                    prover::{ProverValues, UVValues},
+                    FIRST_RECURSION_FACTOR,
+                },
+                validation_protocol::{
+                    proof_generation::ProofBatch,
+                    validation::{test::simple_proof_check, BatchToVerify},
+                },
             },
             RecordId, RecordIdRange,
         },
-        secret_sharing::replicated::ReplicatedSecretSharing,
         test_executor::run,
         test_fixture::{Runner, TestWorld},
     };
@@ -319,37 +295,16 @@ mod test {
         run(|| async move {
             let world = TestWorld::default();
 
-            let mut rng = thread_rng();
+            let mut rng = world.rng();
 
-            // each helper samples a random value h
-            // which is later used to generate distinct values across helpers
-            let h = Fp61BitPrime::truncate_from(rng.gen_range(0u128..100));
+            let uv_values = repeat_with(|| (rng.gen(), rng.gen()))
+                .take(100)
+                .collect::<UVValues<_, FIRST_RECURSION_FACTOR>>();
+            let uv_values_iter = uv_values.iter().copied();
+            let uv_values_iter_ref = &uv_values_iter;
 
             let result = world
-                .semi_honest(h, |ctx, h| async move {
-                    let h = Fp61BitPrime::truncate_from(h.left().as_u128() % 100);
-                    // generate blocks of UV values
-                    // generate u values as (1h,2h,3h,....,10h*BlockSize) split into Blocksize chunks
-                    // where BlockSize = 32
-                    // v values are identical to u
-                    let uv_tuple_vec = (0usize..100)
-                        .map(|i| {
-                            (
-                                (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1))
-                                    .map(|j| {
-                                        Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h
-                                    })
-                                    .collect::<[Fp61BitPrime; BLOCK_SIZE]>(),
-                                (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1))
-                                    .map(|j| {
-                                        Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h
-                                    })
-                                    .collect::<[Fp61BitPrime; BLOCK_SIZE]>(),
-                            )
-                        })
-                        .collect::<Vec<_>>();
-
-                    // generate and output VerifierBatch together with h value
+                .semi_honest((), |ctx, ()| async move {
                     let (
                         my_batch_left_shares,
                         shares_of_batch_from_left_prover,
                         p_mask_from_right_prover,
                         q_mask_from_left_prover,
                     ) = ProofBatch::generate(
                         &ctx.narrow("generate_batch"),
                         RecordIdRange::ALL,
-                        uv_tuple_vec.into_iter(),
+                        ProverValues(uv_values_iter_ref.clone()),
                     );
 
-                    let batch_to_verify = BatchToVerify::generate_batch_to_verify(
+                    BatchToVerify::generate_batch_to_verify(
                         ctx.narrow("generate_batch"),
                         RecordId::FIRST,
                         my_batch_left_shares,
                         shares_of_batch_from_left_prover,
                         p_mask_from_right_prover,
                         q_mask_from_left_prover,
                     )
-                    .await;
-
-                    // generate and output VerifierBatch together with h value
-                    (h, batch_to_verify)
+                    .await
                 })
                 .await;
 
             // proof from first party
-            simple_proof_check(result[0].0, &result[2].1, &result[1].1);
+            simple_proof_check(uv_values.iter(), &result[2], &result[1]);
 
             // proof from second party
-            simple_proof_check(result[1].0, &result[0].1, &result[2].1);
+            simple_proof_check(uv_values.iter(), &result[0], &result[2]);
 
             // proof from third party
-            simple_proof_check(result[2].0, &result[1].1, &result[0].1);
+            simple_proof_check(uv_values.iter(), &result[1], &result[0]);
         });
     }
 }
diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
index 3f656b3eb..e4022a60c 100644
--- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
+++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
@@ -20,8 +20,11 @@ use crate::{
         dzkp_validator::MAX_PROOF_RECURSION, step::DzkpProofVerifyStep as Step, Context,
     },
     ipa_prf::{
-        malicious_security::verifier::{
-            compute_g_differences, recursively_compute_final_check,
+        malicious_security::{
+            verifier::{
+                compute_g_differences, recursively_compute_final_check, VerifierLagrangeInput,
+            },
+            FIRST_RECURSION_FACTOR as FRF,
         },
         validation_protocol::proof_generation::ProofBatch,
         CompressedProofGenerator, FirstProofGenerator,
@@ -105,7 +108,6 @@ impl BatchToVerify {
     where
         C: Context,
     {
-        const FRF: usize = FirstProofGenerator::RECURSION_FACTOR;
         const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR;
 
         // exclude for first proof
@@ -172,21 +174,20 @@ impl BatchToVerify {
         v_from_left_prover: V, // Prover P_i and verifier P_{i+1} both compute `v` and `q(x)`
     ) -> (Fp61BitPrime, Fp61BitPrime)
     where
-        U: Iterator<Item = Fp61BitPrime> + Send,
-        V: Iterator<Item = Fp61BitPrime> + Send,
+        U: VerifierLagrangeInput<Fp61BitPrime, FRF> + Send,
+        V: VerifierLagrangeInput<Fp61BitPrime, FRF> + Send,
     {
-        const FRF: usize = FirstProofGenerator::RECURSION_FACTOR;
         const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR;
 
         // compute p_r
-        let p_r_right_prover = recursively_compute_final_check::<_, _, FRF, CRF>(
-            u_from_right_prover.into_iter(),
+        let p_r_right_prover = recursively_compute_final_check::<_, CRF>(
+            u_from_right_prover,
             challenges_for_right_prover,
             self.p_mask_from_right_prover,
         );
         // compute q_r
-        let q_r_left_prover = recursively_compute_final_check::<_, _, FRF, CRF>(
-            v_from_left_prover.into_iter(),
+        let q_r_left_prover = recursively_compute_final_check::<_, CRF>(
+            v_from_left_prover,
            challenges_for_left_prover,
            self.q_mask_from_left_prover,
        );
@@ -242,7 +243,6 @@ impl BatchToVerify {
     where
         C: Context,
     {
-        const FRF: usize = FirstProofGenerator::RECURSION_FACTOR;
         const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR;
 
         const FPL: usize = FirstProofGenerator::PROOF_LENGTH;
@@ -445,21 +445,22 @@ impl MpcMessage for ProofDiff {}
 
 #[cfg(all(test, unit_test))]
 pub mod test {
+    use std::iter::repeat_with;
+
     use futures_util::future::try_join;
-    use rand::{thread_rng, Rng};
+    use rand::Rng;
 
     use crate::{
-        ff::{Fp61BitPrime, U128Conversions},
+        ff::Fp61BitPrime,
         helpers::Direction,
         protocol::{
-            context::{
-                dzkp_field::{UVTupleBlock, BLOCK_SIZE},
-                Context,
-            },
+            context::Context,
             ipa_prf::{
                 malicious_security::{
                     lagrange::CanonicalLagrangeDenominator,
-                    verifier::{compute_sum_share, interpolate_at_r},
+                    prover::{ProverValues, UVValues},
+                    verifier::{compute_sum_share, interpolate_at_r, VerifierValues},
+                    FIRST_RECURSION_FACTOR as FRF,
                 },
                 validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify},
                 CompressedProofGenerator, FirstProofGenerator,
@@ -467,16 +468,20 @@ pub mod test {
             prss::SharedRandomness,
             RecordId, RecordIdRange,
         },
-        secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue},
+        secret_sharing::SharedValue,
         test_executor::run,
         test_fixture::{Runner, TestWorld},
+        utils::arraychunks::ArrayChunkIterator,
     };
 
-    /// ## Panics
-    /// When proof is not generated correctly.
-    // todo: deprecate once validation protocol is implemented
-    pub fn simple_proof_check(
-        h: Fp61BitPrime,
+    // This is a helper for a test in `proof_generation.rs`, but is located here so it
+    // can access the internals of `BatchToVerify`.
+    //
+    // Possibly this (and the associated test) can be removed now that the validation
+    // protocol is implemented? (There was an old todo to that effect.) But it seems
+    // useful to keep around as a unit test.
+    pub(in crate::protocol::ipa_prf::validation_protocol) fn simple_proof_check<'a>(
+        uv_values: impl Iterator<Item = &'a ([Fp61BitPrime; FRF], [Fp61BitPrime; FRF])>,
         left_verifier: &BatchToVerify,
         right_verifier: &BatchToVerify,
     ) {
@@ -512,36 +517,12 @@ pub mod test {
 
         // check first proof,
         // compute simple proof without lagrange interpolated points
-        let simple_proof = {
-            let block_to_polynomial = BLOCK_SIZE / FirstProofGenerator::RECURSION_FACTOR;
-            let simple_proof_uv = (0usize..100 * block_to_polynomial)
-                .map(|i| {
-                    (
-                        (FirstProofGenerator::RECURSION_FACTOR * i
-                            ..FirstProofGenerator::RECURSION_FACTOR * (i + 1))
-                            .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h)
-                            .collect::<[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>(),
-                        (FirstProofGenerator::RECURSION_FACTOR * i
-                            ..FirstProofGenerator::RECURSION_FACTOR * (i + 1))
-                            .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h)
-                            .collect::<[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>(),
-                    )
-                })
-                .collect::<Vec<_>>();
-
-            simple_proof_uv.iter().fold(
-                [Fp61BitPrime::ZERO; FirstProofGenerator::RECURSION_FACTOR],
-                |mut acc, (left, right)| {
-                    for i in 0..FirstProofGenerator::RECURSION_FACTOR {
-                        acc[i] += left[i] * right[i];
-                    }
-                    acc
-                },
-            )
-        };
+        let simple_proof = uv_values.fold([Fp61BitPrime::ZERO; FRF], |mut acc, (left, right)| {
+            for i in 0..FRF {
+                acc[i] += left[i] * right[i];
+            }
+            acc
+        });
 
         // reconstruct computed proof
         // by adding shares left and right
@@ -554,13 +535,7 @@ pub mod test {
 
         // check for consistency
         // only check first R::USIZE field elements
-        assert_eq!(
-            (h.as_u128(), simple_proof.to_vec()),
-            (
-                h.as_u128(),
-                proof_computed[0..FirstProofGenerator::RECURSION_FACTOR].to_vec()
-            )
-        );
+        assert_eq!(simple_proof.to_vec(), proof_computed[0..FRF].to_vec());
     }
 
     #[test]
@@ -568,39 +543,17 @@ pub mod test {
         run(|| async move {
             let world = TestWorld::default();
 
-            let mut rng = thread_rng();
+            let mut rng = world.rng();
 
-            // each helper samples a random value h
-            // which is later used to generate distinct values across helpers
-            let h = Fp61BitPrime::truncate_from(rng.gen_range(0u128..100));
+            let uv_values = repeat_with(|| (rng.gen(), rng.gen()))
+                .take(100)
+                .collect::<UVValues<Fp61BitPrime, FRF>>();
+            let uv_values_iter = uv_values.iter().copied();
+            let uv_values_iter_ref = &uv_values_iter;
 
             let [(helper_1_left, helper_1_right), (helper_2_left, helper_2_right), (helper_3_left, helper_3_right)] = world
-                .semi_honest(h, |ctx, h| async move {
-                    let h = Fp61BitPrime::truncate_from(h.left().as_u128() % 100);
-                    // generate blocks of UV values
-                    // generate u values as (1h,2h,3h,....,10h*BlockSize) split into Blocksize chunks
-                    // where BlockSize = 32
-                    // v values are identical to u
-                    let uv_tuple_vec = (0usize..100)
-                        .map(|i| {
-                            (
-                                (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1))
-                                    .map(|j| {
-                                        Fp61BitPrime::truncate_from(u128::try_from(j).unwrap())
-                                            * h
-                                    })
-                                    .collect::<[Fp61BitPrime; BLOCK_SIZE]>(),
-                                (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1))
-                                    .map(|j| {
-                                        Fp61BitPrime::truncate_from(u128::try_from(j).unwrap())
-                                            * h
-                                    })
-                                    .collect::<[Fp61BitPrime; BLOCK_SIZE]>(),
-                            )
-                        })
-                        .collect::<Vec<_>>();
-
+                .semi_honest((), |ctx, ()| async move {
                     // generate and output VerifierBatch together with h value
                     let (
                         my_batch_left_shares,
                         shares_of_batch_from_left_prover,
                         p_mask_from_right_prover,
                         q_mask_from_left_prover,
                     ) = ProofBatch::generate(
                         &ctx.narrow("generate_batch"),
                         RecordIdRange::ALL,
-                        uv_tuple_vec.into_iter(),
+                        ProverValues(uv_values_iter_ref.clone()),
                    );
 
                    let batch_to_verify = BatchToVerify::generate_batch_to_verify(
@@ -644,31 +597,32 @@ pub mod test {
     /// Prover `P_i` and verifier `P_{i+1}` both generate `u`
     /// Prover `P_i` and verifier `P_{i-1}` both generate `v`
     ///
-    /// outputs `(my_u_and_v, u_from_right_prover, v_from_left_prover)`
+    /// outputs `(my_u_and_v, sum_of_uv, u_from_right_prover, v_from_left_prover)`
+    #[allow(clippy::type_complexity)]
     fn generate_u_v<C: Context>(
         ctx: &C,
         len: usize,
     ) -> (
-        Vec<UVTupleBlock<Fp61BitPrime>>,
+        Vec<([Fp61BitPrime; FRF], [Fp61BitPrime; FRF])>,
         Fp61BitPrime,
         Vec<Fp61BitPrime>,
         Vec<Fp61BitPrime>,
     ) {
         // outputs
-        let mut vec_u_from_right_prover = Vec::<Fp61BitPrime>::with_capacity(BLOCK_SIZE * len);
-        let mut vec_v_from_left_prover = Vec::<Fp61BitPrime>::with_capacity(BLOCK_SIZE * len);
+        let mut vec_u_from_right_prover = Vec::<Fp61BitPrime>::with_capacity(FRF * len);
+        let mut vec_v_from_left_prover = Vec::<Fp61BitPrime>::with_capacity(FRF * len);
         let mut vec_my_u_and_v =
-            Vec::<([Fp61BitPrime; BLOCK_SIZE], [Fp61BitPrime; BLOCK_SIZE])>::with_capacity(len);
+            Vec::<([Fp61BitPrime; FRF], [Fp61BitPrime; FRF])>::with_capacity(len);
         let mut sum_of_uv = Fp61BitPrime::ZERO;
 
         // generate random u, v values using PRSS
         let mut counter = RecordId::FIRST;
 
         for _ in 0..len {
-            let mut my_u_array = [Fp61BitPrime::ZERO; BLOCK_SIZE];
-            let mut my_v_array = [Fp61BitPrime::ZERO; BLOCK_SIZE];
-            for i in 0..BLOCK_SIZE {
+            let mut my_u_array = [Fp61BitPrime::ZERO; FRF];
+            let mut my_v_array = [Fp61BitPrime::ZERO; FRF];
+            for i in 0..FRF {
                 let (my_u, u_from_right_prover) = ctx.prss().generate_fields(counter);
                 counter += 1;
                 let (v_from_left_prover, my_v) = ctx.prss().generate_fields(counter);
@@ -725,7 +679,7 @@ pub mod test {
            ) = ProofBatch::generate(
                &ctx.narrow("generate_batch"),
                RecordIdRange::ALL,
-                vec_my_u_and_v.into_iter(),
+                ProverValues(vec_my_u_and_v.into_iter()),
            );
 
            let batch_to_verify = BatchToVerify::generate_batch_to_verify(
@@ -831,7 +785,7 @@ pub mod test {
            ) = ProofBatch::generate(
                &ctx.narrow("generate_batch"),
                RecordIdRange::ALL,
-                vec_my_u_and_v.into_iter(),
+                ProverValues(vec_my_u_and_v.into_iter()),
            );
 
            let batch_to_verify = BatchToVerify::generate_batch_to_verify(
@@ -858,8 +812,10 @@ pub mod test {
            let (p, q) = batch_to_verify.compute_p_and_q_r(
                &challenges_for_left_prover,
                &challenges_for_right_prover,
-                vec_u_from_right_prover.into_iter(),
-                vec_v_from_left_prover.into_iter(),
+                VerifierValues(
+                    vec_u_from_right_prover.into_iter().chunk_array::<FRF>(),
+                ),
+                VerifierValues(vec_v_from_left_prover.into_iter().chunk_array::<FRF>()),
            );
 
            let p_times_q =
@@ -921,7 +877,7 @@ pub mod test {
            ) = ProofBatch::generate(
                &ctx.narrow("generate_batch"),
                RecordIdRange::ALL,
-                vec_my_u_and_v.into_iter(),
+                ProverValues(vec_my_u_and_v.into_iter()),
            );
 
            let batch_to_verify = BatchToVerify::generate_batch_to_verify(
@@ -961,8 +917,8 @@ pub mod test {
            let (p, q) = batch_to_verify.compute_p_and_q_r(
                &challenges_for_left_prover,
                &challenges_for_right_prover,
-                    vec_u_from_right_prover.into_iter(),
-                    vec_v_from_left_prover.into_iter(),
+                    VerifierValues(vec_u_from_right_prover.into_iter().chunk_array::<FRF>()),
+                    VerifierValues(vec_v_from_left_prover.into_iter().chunk_array::<FRF>()),
                 );
 
                 batch_to_verify
@@ -989,14 +945,13 @@ pub mod test {
         // Test a batch that exercises the case where `uv_values.len() == 1` but `did_set_masks =
         // false` in `ProofBatch::generate`.
         //
-        // We divide by `BLOCK_SIZE` here because `generate_u_v`, which is used by
+        // We divide by `FRF` here because `generate_u_v`, which is used by
         // `verify_batch` to generate test data, generates `len` chunks of u/v values of
-        // length `BLOCK_SIZE`. We want the input u/v values to compress to exactly one
+        // length `FRF`. We want the input u/v values to compress to exactly one
        // u/v pair after some number of proof steps.
-        let num_inputs = FirstProofGenerator::RECURSION_FACTOR
-            * CompressedProofGenerator::RECURSION_FACTOR
-            * CompressedProofGenerator::RECURSION_FACTOR;
-        assert!(num_inputs % BLOCK_SIZE == 0);
-        verify_batch(num_inputs / BLOCK_SIZE);
+        let num_inputs =
+            FirstProofGenerator::RECURSION_FACTOR * CompressedProofGenerator::RECURSION_FACTOR;
+        assert!(num_inputs % FRF == 0);
+        verify_batch(num_inputs / FRF);
     }
 }
diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs
index 3a8357f02..b020ddb04 100644
--- a/ipa-core/src/secret_sharing/vector/transpose.rs
+++ b/ipa-core/src/secret_sharing/vector/transpose.rs
@@ -100,7 +100,7 @@ pub trait TransposeFrom {
 //
 // From Hacker's Delight (2nd edition), Figure 7-6.
 //
-// There are comments on `dzkp_field::convert_values_table_indices`, which implements a
+// There are comments on `dzkp_field::bits_to_table_indices`, which implements a
 // similar transformation, that may help to understand how this works.
 #[inline]
 pub fn transpose_8x8<B: Borrow<[u8; 8]>>(x: B) -> [u8; 8] {
@@ -125,7 +125,7 @@ pub fn transpose_8x8<B: Borrow<[u8; 8]>>(x: B) -> [u8; 8] {
 //
 // Loosely based on Hacker's Delight (2nd edition), Figure 7-6.
 //
-// There are comments on `dzkp_field::convert_values_table_indices`, which implements a
+// There are comments on `dzkp_field::bits_to_table_indices`, which implements a
 // similar transformation, that may help to understand how this works.
 #[inline]
 pub fn transpose_16x16(src: &[u8; 32]) -> [u8; 32] {