Skip to content

Commit

Permalink
Merge pull request private-attribution#1136 from benjaminsavage/ben_nits
Browse files Browse the repository at this point in the history
Cleaning up test code
  • Loading branch information
danielmasny authored Jun 13, 2024
2 parents c56449c + f30bfe3 commit 3633fbe
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ where
/// The "x coordinates" of the output points `x_N` to `x_(N+M-1)` are `N*F::ONE` to `(N+M-1)*F::ONE`
/// when generated using `from(denominator)`
/// unless generated using `new(denominator, x_output)` for a specific output "x coordinate" `x_output`.
#[derive(Debug)]
pub struct LagrangeTable<F: Field, const N: usize, const M: usize> {
table: [[F; N]; M],
}
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ mod test {
const U_3: [u128; 2] = [3, 3]; // will later be padded with zeroes
const V_3: [u128; 2] = [5, 24]; // will later be padded with zeroes

const PROOF_3: [u128; 7] = [12, 15, 10, 0, 18, 6, 5];
const PROOF_3: [u128; 7] = [12, 10, 0, 15, 16, 19, 5];
const P_RANDOM_WEIGHT: u128 = 12;
const Q_RANDOM_WEIGHT: u128 = 1;

Expand Down Expand Up @@ -361,8 +361,8 @@ mod test {
assert_eq!(uv_3, zip_chunks(U_3, V_3));

let masked_uv_3 = zip_chunks(
[P_RANDOM_WEIGHT, U_3[0], U_3[1]],
[Q_RANDOM_WEIGHT, V_3[0], V_3[1]],
[P_RANDOM_WEIGHT, U_3[1], 0, U_3[0]],
[Q_RANDOM_WEIGHT, V_3[1], 0, V_3[0]],
);

// final iteration
Expand Down
176 changes: 70 additions & 106 deletions ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,19 @@ where
})
}

pub async fn compute_p_or_q<F: PrimeField, J, const λ: usize>(
pub async fn recursively_compute_final_check<F: PrimeField, J, const λ: usize>(
u_or_v_iterator: J,
r: Vec<F>,
challenges: Vec<F>,
p_or_q_0: F,
) -> F
where
J: Stream<Item = [F; λ]> + Unpin,
{
let recursions = r.len();
let recursions = challenges.len();

// compute Lagrange tables
let denominator_p_or_q = CanonicalLagrangeDenominator::<F, λ>::new();
let tables = r
let tables = challenges
.iter()
.map(|r| LagrangeTable::<F, λ, 1>::new(&denominator_p_or_q, r))
.collect::<Vec<_>>();
Expand All @@ -79,7 +79,16 @@ where
// make sure stream is empty
assert!(stream.next().await.is_none());

// In the protocol, the prover is expected to continue recursively compressing the
// u and v vectors until it has length strictly less than λ.
// For this reason, we can safely assume that in the final proof, the last value
// in the ZKP is zero.
// A debug assert will help catch any errors while in development.
// set mask
debug_assert!(
last_u_or_v_array[λ - 1] == F::ZERO,
"Should not be overwriting non-zero values"
);
last_u_or_v_array[λ - 1] = last_u_or_v_array[0];
last_u_or_v_array[0] = p_or_q_0;

Expand All @@ -97,8 +106,12 @@ mod test {
ff::{Fp31, PrimeField, U128Conversions},
protocol::ipa_prf::malicious_security::{
lagrange::{CanonicalLagrangeDenominator, LagrangeTable},
verifier::{compute_p_or_q, compute_sum_share, interpolate_at_r, recurse_u_or_v},
verifier::{
compute_sum_share, interpolate_at_r, recurse_u_or_v,
recursively_compute_final_check,
},
},
secret_sharing::SharedValue,
};

fn make_chunks<F: PrimeField, const N: usize>(a: &[u128]) -> Vec<[F; N]> {
Expand Down Expand Up @@ -175,55 +188,63 @@ mod test {

const U_2: [u128; 8] = [0, 0, 26, 0, 7, 18, 24, 13];

const U_3: [u128; 2] = [3, 3];
const U_3: [u128; 4] = [3, 3, 0, 0];

const R: [u128; 3] = [22, 17, 30];
const CHALLENGES: [u128; 3] = [22, 17, 30];

const P_RANDOM_WEIGHT: u128 = 12;

const EXPECTED_P_FINAL: u128 = 30;
const EXPECTED_P_FINAL: u128 = 27;

// compute Lagrange tables
let denominator_p_or_q = CanonicalLagrangeDenominator::<Fp31, 4>::new();
let table_1 =
LagrangeTable::<Fp31, 4, 1>::new(&denominator_p_or_q, &Fp31::try_from(R[0]).unwrap());
let table_2 =
LagrangeTable::<Fp31, 4, 1>::new(&denominator_p_or_q, &Fp31::try_from(R[1]).unwrap());
let denominator_p_or_q_final = CanonicalLagrangeDenominator::<Fp31, 3>::new();
let table_3 = LagrangeTable::<Fp31, 3, 1>::new(
&denominator_p_or_q_final,
&Fp31::try_from(R[2]).unwrap(),
);
let tables: [LagrangeTable<Fp31, 4, 1>; 3] = CHALLENGES
.map(|r| LagrangeTable::new(&denominator_p_or_q, &Fp31::try_from(r).unwrap()));

// uv values in input format
let u_1 = make_chunks::<_, 4>(&U_1);

let u_or_v_2 = recurse_u_or_v(stream::iter(u_1), &table_1)
let u_or_v_2 = recurse_u_or_v(stream::iter(u_1), &tables[0])
.await
.collect::<Vec<_>>()
.await;
assert_eq!(u_or_v_2, make_chunks::<Fp31, 4>(&U_2));

let u_or_v_3_temp = recurse_u_or_v(stream::iter(u_or_v_2.into_iter()), &table_2)
let u_or_v_3 = recurse_u_or_v(stream::iter(u_or_v_2.into_iter()), &tables[1])
.await
.collect::<Vec<_>>()
.await;

// final proof trim from U4 to U2
let u_or_v_3 = [
Fp31::try_from(P_RANDOM_WEIGHT).unwrap(),
u_or_v_3_temp[0][0],
u_or_v_3_temp[0][1],
];
assert_eq!(u_or_v_3, make_chunks::<Fp31, 4>(&U_3));

assert_eq!([u_or_v_3[1], u_or_v_3[2]], make_chunks::<Fp31, 2>(&U_3)[0]);
let u_or_v_3_masked = [
Fp31::try_from(P_RANDOM_WEIGHT).unwrap(), // set mask at index 0
u_or_v_3[0][1],
Fp31::ZERO,
u_or_v_3[0][0], // move first element to the end
];

let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3)), &table_3)
let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3_masked)), &tables[2])
.await
.collect::<Vec<_>>()
.await;

assert_eq!(p_final[0][0].as_u128(), EXPECTED_P_FINAL);

// uv values in input format
let u_1 = make_chunks::<_, 4>(&U_1);

let p_final_another_way = recursively_compute_final_check::<Fp31, _, 4>(
stream::iter(u_1),
CHALLENGES
.map(|x| Fp31::try_from(x).unwrap())
.into_iter()
.collect::<Vec<_>>(),
Fp31::try_from(P_RANDOM_WEIGHT).unwrap(),
)
.await;

assert_eq!(p_final_another_way.as_u128(), EXPECTED_P_FINAL);
}

#[test]
Expand Down Expand Up @@ -290,120 +311,63 @@ mod test {
1, 1, 0, 0, 1, 1,
];
const V_2: [u128; 8] = [10, 21, 30, 28, 15, 21, 3, 3];
const V_3: [u128; 2] = [5, 24];
const V_3: [u128; 4] = [5, 24, 0, 0];

const R: [u128; 3] = [22, 17, 30];
const CHALLENGES: [u128; 3] = [22, 17, 30];

const Q_RANDOM_WEIGHT: u128 = 1;

const EXPECTED_Q_FINAL: u128 = 12;
const EXPECTED_Q_FINAL: u128 = 10;

// compute Lagrange tables
let denominator_p_or_q = CanonicalLagrangeDenominator::<Fp31, 4>::new();
let table_1 =
LagrangeTable::<Fp31, 4, 1>::new(&denominator_p_or_q, &Fp31::try_from(R[0]).unwrap());
let table_2 =
LagrangeTable::<Fp31, 4, 1>::new(&denominator_p_or_q, &Fp31::try_from(R[1]).unwrap());
let denominator_p_or_q_final = CanonicalLagrangeDenominator::<Fp31, 3>::new();
let table_3 = LagrangeTable::<Fp31, 3, 1>::new(
&denominator_p_or_q_final,
&Fp31::try_from(R[2]).unwrap(),
);
let tables: [LagrangeTable<Fp31, 4, 1>; 3] = CHALLENGES
.map(|r| LagrangeTable::new(&denominator_p_or_q, &Fp31::try_from(r).unwrap()));

// uv values in input format
let v_1 = make_chunks::<_, 4>(&V_1);

let u_or_v_2 = recurse_u_or_v(stream::iter(v_1), &table_1)
let u_or_v_2 = recurse_u_or_v(stream::iter(v_1), &tables[0])
.await
.collect::<Vec<_>>()
.await;
assert_eq!(u_or_v_2, make_chunks::<Fp31, 4>(&V_2));
assert_eq!(u_or_v_2, make_chunks::<Fp31, 4>(&V_2));

let u_or_v_3_temp = recurse_u_or_v(stream::iter(u_or_v_2.into_iter()), &table_2)
let u_or_v_3 = recurse_u_or_v(stream::iter(u_or_v_2.into_iter()), &tables[1])
.await
.collect::<Vec<_>>()
.await;

// final proof trim from U4 to U2
let u_or_v_3 = [
Fp31::try_from(Q_RANDOM_WEIGHT).unwrap(),
u_or_v_3_temp[0][0],
u_or_v_3_temp[0][1],
];
assert_eq!(u_or_v_3, make_chunks::<Fp31, 4>(&V_3));

assert_eq!([u_or_v_3[1], u_or_v_3[2]], make_chunks::<Fp31, 2>(&V_3)[0]);
let u_or_v_3_masked = [
Fp31::try_from(Q_RANDOM_WEIGHT).unwrap(), // set mask at index 0
u_or_v_3[0][1],
Fp31::ZERO,
u_or_v_3[0][0], // move first element to the end
];

// final iteration

let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3)), &table_3)
let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3_masked)), &tables[2])
.await
.collect::<Vec<_>>()
.await;

assert_eq!(p_final[0][0].as_u128(), EXPECTED_Q_FINAL);
}

#[tokio::test]
async fn recursive_compute_p() {
const U_1: [u128; 32] = [
0, 30, 0, 16, 0, 1, 0, 15, 0, 0, 0, 16, 0, 30, 0, 16, 29, 1, 1, 15, 0, 0, 1, 15, 2, 30,
30, 16, 0, 0, 30, 16,
];

const R: [u128; 3] = [22, 17, 30];

const P_RANDOM_WEIGHT: u128 = 12;

// uv values in input format
let u_1 = make_chunks::<_, 4>(&U_1);
let v_1 = make_chunks::<_, 4>(&V_1);

// this approach of computing p is tested in test sample_u_recursion
let tested_p = {
// compute Lagrange tables
let denominator_p_or_q = CanonicalLagrangeDenominator::<Fp31, 4>::new();
let table_1 = LagrangeTable::<Fp31, 4, 1>::new(
&denominator_p_or_q,
&Fp31::try_from(R[0]).unwrap(),
);
let table_2 = LagrangeTable::<Fp31, 4, 1>::new(
&denominator_p_or_q,
&Fp31::try_from(R[1]).unwrap(),
);
let table_3 = LagrangeTable::<Fp31, 4, 1>::new(
&denominator_p_or_q,
&Fp31::try_from(R[2]).unwrap(),
);

let u_or_v_2 = recurse_u_or_v(stream::iter(u_1.clone()), &table_1).await;

let mut u_or_v_3 = recurse_u_or_v(u_or_v_2, &table_2)
.await
.collect::<Vec<_>>()
.await[0];

// set mask
u_or_v_3[3] = u_or_v_3[0];
u_or_v_3[0] = Fp31::try_from(P_RANDOM_WEIGHT).unwrap();

let p_final = recurse_u_or_v(stream::iter(iter::once(u_or_v_3)), &table_3)
.await
.collect::<Vec<_>>()
.await;

// return tested p
p_final[0][0]
};

let p = compute_p_or_q::<Fp31, _, 4>(
stream::iter(u_1),
R.map(|x| Fp31::try_from(x).unwrap())
let q_final_another_way = recursively_compute_final_check::<Fp31, _, 4>(
stream::iter(v_1),
CHALLENGES
.map(|x| Fp31::try_from(x).unwrap())
.into_iter()
.collect::<Vec<_>>(),
Fp31::try_from(P_RANDOM_WEIGHT).unwrap(),
Fp31::try_from(Q_RANDOM_WEIGHT).unwrap(),
)
.await;

assert_eq!(p, tested_p);
assert_eq!(q_final_another_way.as_u128(), EXPECTED_Q_FINAL);
}
}

0 comments on commit 3633fbe

Please sign in to comment.