From f30bfe34aa2beecfe04a0f66fe217259a0b29f4a Mon Sep 17 00:00:00 2001 From: Ben Savage Date: Thu, 13 Jun 2024 15:50:05 +0800 Subject: [PATCH] Cleaning up test code --- .../ipa_prf/malicious_security/lagrange.rs | 1 + .../ipa_prf/malicious_security/prover.rs | 6 +- .../ipa_prf/malicious_security/verifier.rs | 176 +++++++----------- 3 files changed, 74 insertions(+), 109 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 466b915a4..2477f0867 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -62,6 +62,7 @@ where /// The "x coordinates" of the output points `x_N` to `x_(N+M-1)` are `N*F::ONE` to `(N+M-1)*F::ONE` /// when generated using `from(denominator)` /// unless generated using `new(denominator, x_output)` for a specific output "x coordinate" `x_output`. +#[derive(Debug)] pub struct LagrangeTable { table: [[F; N]; M], } diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index 49a015070..d0480183a 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -302,7 +302,7 @@ mod test { const U_3: [u128; 2] = [3, 3]; // will later be padded with zeroes const V_3: [u128; 2] = [5, 24]; // will later be padded with zeroes - const PROOF_3: [u128; 7] = [12, 15, 10, 0, 18, 6, 5]; + const PROOF_3: [u128; 7] = [12, 10, 0, 15, 16, 19, 5]; const P_RANDOM_WEIGHT: u128 = 12; const Q_RANDOM_WEIGHT: u128 = 1; @@ -361,8 +361,8 @@ mod test { assert_eq!(uv_3, zip_chunks(U_3, V_3)); let masked_uv_3 = zip_chunks( - [P_RANDOM_WEIGHT, U_3[0], U_3[1]], - [Q_RANDOM_WEIGHT, V_3[0], V_3[1]], + [P_RANDOM_WEIGHT, U_3[1], 0, U_3[0]], + [Q_RANDOM_WEIGHT, V_3[1], 0, V_3[0]], ); // final iteration diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs index d7423a585..772f91159 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs @@ -52,19 +52,19 @@ where }) } -pub async fn compute_p_or_q( +pub async fn recursively_compute_final_check( u_or_v_iterator: J, - r: Vec, + challenges: Vec, p_or_q_0: F, ) -> F where J: Stream + Unpin, { - let recursions = r.len(); + let recursions = challenges.len(); // compute Lagrange tables let denominator_p_or_q = CanonicalLagrangeDenominator::::new(); - let tables = r + let tables = challenges .iter() .map(|r| LagrangeTable::::new(&denominator_p_or_q, r)) .collect::>(); @@ -79,7 +79,16 @@ where // make sure stream is empty assert!(stream.next().await.is_none()); + // In the protocol, the prover is expected to continue recursively compressing the + // u and v vectors until it has length strictly less than λ. + // For this reason, we can safely assume that in the final proof, the last value + // in the ZKP is zero. + // A debug assert will help catch any errors while in development. // set mask + debug_assert!( + last_u_or_v_array[λ - 1] == F::ZERO, + "Should not be overwriting non-zero values" + ); last_u_or_v_array[λ - 1] = last_u_or_v_array[0]; last_u_or_v_array[0] = p_or_q_0; @@ -97,8 +106,12 @@ mod test { ff::{Fp31, PrimeField, U128Conversions}, protocol::ipa_prf::malicious_security::{ lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - verifier::{compute_p_or_q, compute_sum_share, interpolate_at_r, recurse_u_or_v}, + verifier::{ + compute_sum_share, interpolate_at_r, recurse_u_or_v, + recursively_compute_final_check, + }, }, + secret_sharing::SharedValue, }; fn make_chunks(a: &[u128]) -> Vec<[F; N]> { @@ -175,55 +188,63 @@ mod test { const U_2: [u128; 8] = [0, 0, 26, 0, 7, 18, 24, 13]; - const U_3: [u128; 2] = [3, 3]; + const U_3: [u128; 4] = [3, 3, 0, 0]; - const R: [u128; 3] = [22, 17, 30]; + const CHALLENGES: [u128; 3] = [22, 17, 30]; const P_RANDOM_WEIGHT: u128 = 12; - const EXPECTED_P_FINAL: u128 = 30; + const EXPECTED_P_FINAL: u128 = 27; // compute Lagrange tables let denominator_p_or_q = CanonicalLagrangeDenominator::::new(); - let table_1 = - LagrangeTable::::new(&denominator_p_or_q, &Fp31::try_from(R[0]).unwrap()); - let table_2 = - LagrangeTable::::new(&denominator_p_or_q, &Fp31::try_from(R[1]).unwrap()); - let denominator_p_or_q_final = CanonicalLagrangeDenominator::::new(); - let table_3 = LagrangeTable::::new( - &denominator_p_or_q_final, - &Fp31::try_from(R[2]).unwrap(), - ); + let tables: [LagrangeTable; 3] = CHALLENGES + .map(|r| LagrangeTable::new(&denominator_p_or_q, &Fp31::try_from(r).unwrap())); // uv values in input format let u_1 = make_chunks::<_, 4>(&U_1); - let u_or_v_2 = recurse_u_or_v(stream::iter(u_1), &table_1) + let u_or_v_2 = recurse_u_or_v(stream::iter(u_1), &tables[0]) .await .collect::>() .await; assert_eq!(u_or_v_2, make_chunks::(&U_2)); - let u_or_v_3_temp = recurse_u_or_v(stream::iter(u_or_v_2.into_iter()), &table_2) + let u_or_v_3 = recurse_u_or_v(stream::iter(u_or_v_2.into_iter()), &tables[1]) .await .collect::>() .await; - // final proof trim from U4 to U2 - let u_or_v_3 = [ - Fp31::try_from(P_RANDOM_WEIGHT).unwrap(), - u_or_v_3_temp[0][0], - u_or_v_3_temp[0][1], - ]; + assert_eq!(u_or_v_3, make_chunks::(&U_3)); - assert_eq!([u_or_v_3[1], u_or_v_3[2]], make_chunks::(&U_3)[0]); + let u_or_v_3_masked = [ + Fp31::try_from(P_RANDOM_WEIGHT).unwrap(), // set mask at index 0 + u_or_v_3[0][1], + Fp31::ZERO, + u_or_v_3[0][0], // move first element to the end + ]; - let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3)), &table_3) + let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3_masked)), &tables[2]) .await .collect::>() .await; assert_eq!(p_final[0][0].as_u128(), EXPECTED_P_FINAL); + + // uv values in input format + let u_1 = make_chunks::<_, 4>(&U_1); + + let p_final_another_way = recursively_compute_final_check::( + stream::iter(u_1), + CHALLENGES + .map(|x| Fp31::try_from(x).unwrap()) + .into_iter() + .collect::>(), + Fp31::try_from(P_RANDOM_WEIGHT).unwrap(), + ) + .await; + + assert_eq!(p_final_another_way.as_u128(), EXPECTED_P_FINAL); } #[test] @@ -290,120 +311,63 @@ mod test { 1, 1, 0, 0, 1, 1, ]; const V_2: [u128; 8] = [10, 21, 30, 28, 15, 21, 3, 3]; - const V_3: [u128; 2] = [5, 24]; + const V_3: [u128; 4] = [5, 24, 0, 0]; - const R: [u128; 3] = [22, 17, 30]; + const CHALLENGES: [u128; 3] = [22, 17, 30]; const Q_RANDOM_WEIGHT: u128 = 1; - const EXPECTED_Q_FINAL: u128 = 12; + const EXPECTED_Q_FINAL: u128 = 10; // compute Lagrange tables let denominator_p_or_q = CanonicalLagrangeDenominator::::new(); - let table_1 = - LagrangeTable::::new(&denominator_p_or_q, &Fp31::try_from(R[0]).unwrap()); - let table_2 = - LagrangeTable::::new(&denominator_p_or_q, &Fp31::try_from(R[1]).unwrap()); - let denominator_p_or_q_final = CanonicalLagrangeDenominator::::new(); - let table_3 = LagrangeTable::::new( - &denominator_p_or_q_final, - &Fp31::try_from(R[2]).unwrap(), - ); + let tables: [LagrangeTable; 3] = CHALLENGES + .map(|r| LagrangeTable::new(&denominator_p_or_q, &Fp31::try_from(r).unwrap())); // uv values in input format let v_1 = make_chunks::<_, 4>(&V_1); - let u_or_v_2 = recurse_u_or_v(stream::iter(v_1), &table_1) + let u_or_v_2 = recurse_u_or_v(stream::iter(v_1), &tables[0]) .await .collect::>() .await; assert_eq!(u_or_v_2, make_chunks::(&V_2)); - assert_eq!(u_or_v_2, make_chunks::(&V_2)); - let u_or_v_3_temp = recurse_u_or_v(stream::iter(u_or_v_2.into_iter()), &table_2) + let u_or_v_3 = recurse_u_or_v(stream::iter(u_or_v_2.into_iter()), &tables[1]) .await .collect::>() .await; - // final proof trim from U4 to U2 - let u_or_v_3 = [ - Fp31::try_from(Q_RANDOM_WEIGHT).unwrap(), - u_or_v_3_temp[0][0], - u_or_v_3_temp[0][1], - ]; + assert_eq!(u_or_v_3, make_chunks::(&V_3)); - assert_eq!([u_or_v_3[1], u_or_v_3[2]], make_chunks::(&V_3)[0]); + let u_or_v_3_masked = [ + Fp31::try_from(Q_RANDOM_WEIGHT).unwrap(), // set mask at index 0 + u_or_v_3[0][1], + Fp31::ZERO, + u_or_v_3[0][0], // move first element to the end + ]; // final iteration - - let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3)), &table_3) + let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3_masked)), &tables[2]) .await .collect::>() .await; assert_eq!(p_final[0][0].as_u128(), EXPECTED_Q_FINAL); - } - - #[tokio::test] - async fn recursive_compute_p() { - const U_1: [u128; 32] = [ - 0, 30, 0, 16, 0, 1, 0, 15, 0, 0, 0, 16, 0, 30, 0, 16, 29, 1, 1, 15, 0, 0, 1, 15, 2, 30, - 30, 16, 0, 0, 30, 16, - ]; - - const R: [u128; 3] = [22, 17, 30]; - - const P_RANDOM_WEIGHT: u128 = 12; // uv values in input format - let u_1 = make_chunks::<_, 4>(&U_1); + let v_1 = make_chunks::<_, 4>(&V_1); - // this approach of computing p is tested in test sample_u_recursion - let tested_p = { - // compute Lagrange tables - let denominator_p_or_q = CanonicalLagrangeDenominator::::new(); - let table_1 = LagrangeTable::::new( - &denominator_p_or_q, - &Fp31::try_from(R[0]).unwrap(), - ); - let table_2 = LagrangeTable::::new( - &denominator_p_or_q, - &Fp31::try_from(R[1]).unwrap(), - ); - let table_3 = LagrangeTable::::new( - &denominator_p_or_q, - &Fp31::try_from(R[2]).unwrap(), - ); - - let u_or_v_2 = recurse_u_or_v(stream::iter(u_1.clone()), &table_1).await; - - let mut u_or_v_3 = recurse_u_or_v(u_or_v_2, &table_2) - .await - .collect::>() - .await[0]; - - // set mask - u_or_v_3[3] = u_or_v_3[0]; - u_or_v_3[0] = Fp31::try_from(P_RANDOM_WEIGHT).unwrap(); - - let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3)), &table_3) - .await - .collect::>() - .await; - - // return tested p - p_final[0][0] - }; - - let p = compute_p_or_q::( - stream::iter(u_1), - R.map(|x| Fp31::try_from(x).unwrap()) + let q_final_another_way = recursively_compute_final_check::( + stream::iter(v_1), + CHALLENGES + .map(|x| Fp31::try_from(x).unwrap()) .into_iter() .collect::>(), - Fp31::try_from(P_RANDOM_WEIGHT).unwrap(), + Fp31::try_from(Q_RANDOM_WEIGHT).unwrap(), ) .await; - assert_eq!(p, tested_p); + assert_eq!(q_final_another_way.as_u128(), EXPECTED_Q_FINAL); } }