From b7194fe55f9e5a0ddc63e7a0becc3a92448899a3 Mon Sep 17 00:00:00 2001 From: danielmasny Date: Fri, 6 Sep 2024 15:28:38 -0700 Subject: [PATCH] improving shuffle verification using Alex's suggestions --- ipa-core/src/protocol/ipa_prf/shuffle/base.rs | 18 ++++- .../src/protocol/ipa_prf/shuffle/malicious.rs | 73 ++++++++----------- 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs index 8020bfebf..a34477d69 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/base.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/base.rs @@ -72,8 +72,8 @@ impl IntermediateShuffleMessages { /// /// ## Panics /// Panics when `Role = H2`, i.e. `x1_or_y1` is `None`. - pub fn get_x1_or_y1(&self) -> &Vec { - self.x1_or_y1.as_ref().unwrap() + pub fn get_x1_or_y1(self) -> Vec { + self.x1_or_y1.unwrap() } /// When `IntermediateShuffleMessages` is initialized correctly, @@ -82,8 +82,18 @@ impl IntermediateShuffleMessages { /// /// ## Panics /// Panics when `Role = H1`, i.e. `x2_or_y2` is `None`. - pub fn get_x2_or_y2(&self) -> &Vec { - self.x2_or_y2.as_ref().unwrap() + pub fn get_x2_or_y2(self) -> Vec { + self.x2_or_y2.unwrap() + } + + /// When `IntermediateShuffleMessages` is initialized correctly, + /// this function returns `y1` and `y2` when `Role = H3`. + /// + /// ## Panics + /// Panics when `Role = H1`, i.e. `x2_or_y2` is `None` or + /// when `Role = H2`, i.e. `x1_or_y1` is `None`. + pub fn get_both_x_or_ys(self) -> (Vec, Vec) { + (self.x1_or_y1.unwrap(), self.x2_or_y2.unwrap()) } } diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index a8f368b35..a58fcae58 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -1,4 +1,4 @@ -use std::{borrow::Borrow, iter}; +use std::iter; use futures_util::future::{try_join, try_join3}; @@ -17,7 +17,7 @@ use crate::{ }, secret_sharing::{ replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing}, - SharedValue, StdArray, + SharedValue, SharedValueArray, StdArray, }, }; @@ -43,14 +43,8 @@ async fn verify_shuffle( Role::H1 => h1_verify(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await, Role::H2 => h2_verify(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await, Role::H3 => { - h3_verify( - ctx, - &keys, - shuffled_shares, - messages.get_x1_or_y1(), - messages.get_x2_or_y2(), - ) - .await + let (y1, y2) = messages.get_both_x_or_ys(); + h3_verify(ctx, &keys, shuffled_shares, y1, y2).await } } } @@ -68,13 +62,13 @@ async fn h1_verify( ctx: C, keys: &[StdArray], share_a_and_b: &[AdditiveShare], - x1: &[S], + x1: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for x1 - let hash_x1 = compute_row_hash::(keys, x1); + let hash_x1 = compute_row_hash(keys, x1); // compute hash for A xor B - let hash_a_xor_b = compute_row_hash::( + let hash_a_xor_b = compute_row_hash( keys, share_a_and_b .iter() @@ -87,32 +81,29 @@ async fn h1_verify( .set_total_records(TotalRecords::specified(2)?); let h2_ctx = ctx .narrow(&OPRFShuffleStep::HashH2toH1) - .set_total_records(TotalRecords::specified(1)?); + .set_total_records(TotalRecords::ONE); let channel_h3 = &h3_ctx.recv_channel::(ctx.role().peer(Direction::Left)); let channel_h2 = &h2_ctx.recv_channel::(ctx.role().peer(Direction::Right)); // receive hashes - let (hashes_h3, hash_h2) = try_join( - h3_ctx.parallel_join( - (0usize..=1).map(|i| async move { channel_h3.receive(RecordId::from(i)).await }), - ), + let (hash_y1, hash_h3, hash_h2) = try_join3( + channel_h3.receive(RecordId::FIRST), + channel_h3.receive(RecordId::from(1usize)), channel_h2.receive(RecordId::FIRST), ) .await?; // check y1 - if hash_x1 != hashes_h3[0] { + if hash_x1 != hash_y1 { return Err(Error::ShuffleValidationFailed(format!( - "Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {:?}", - hashes_h3[0] + "Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {hash_y1:?}" ))); } // check c from h3 - if hash_a_xor_b != hashes_h3[1] { + if hash_a_xor_b != hash_h3 { return Err(Error::ShuffleValidationFailed(format!( - "C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {:?}", - hashes_h3[1] + "C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h3:?}" ))); } @@ -138,13 +129,13 @@ async fn h2_verify( ctx: C, keys: &[StdArray], share_b_and_c: &[AdditiveShare], - x2: &[S], + x2: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for x2 - let hash_x2 = compute_row_hash::(keys, x2); + let hash_x2 = compute_row_hash(keys, x2); // compute hash for C - let hash_c = compute_row_hash::( + let hash_c = compute_row_hash( keys, share_b_and_c.iter().map(ReplicatedSecretSharing::right), ); @@ -186,16 +177,16 @@ async fn h3_verify( ctx: C, keys: &[StdArray], share_c_and_a: &[AdditiveShare], - y1: &[S], - y2: &[S], + y1: Vec, + y2: Vec, ) -> Result<(), Error> { // compute hashes // compute hash for y1 - let hash_y1 = compute_row_hash::(keys, y1); + let hash_y1 = compute_row_hash(keys, y1); // compute hash for y2 - let hash_y2 = compute_row_hash::(keys, y2); + let hash_y2 = compute_row_hash(keys, y2); // compute hash for C - let hash_c = compute_row_hash::( + let hash_c = compute_row_hash( keys, share_c_and_a.iter().map(ReplicatedSecretSharing::left), ); @@ -226,20 +217,19 @@ async fn h3_verify( /// /// ## Panics /// Panics when conversion from `BooleanArray` to `Vec(keys: &[StdArray], row_iterator: I) -> Hash +fn compute_row_hash(keys: &[StdArray], row_iterator: I) -> Hash where S: BooleanArray, - B: Borrow, - I: IntoIterator, + I: IntoIterator, { let iterator = row_iterator .into_iter() - .map(|s| (*(s.borrow())).try_into().unwrap()); + .map(|row| >>::try_into(row).unwrap()); compute_hash(iterator.map(|row| { - row.iter() + row.into_iter() .zip(keys) .fold(Gf32Bit::ZERO, |acc, (row_entry, key)| { - acc + *row_entry * *key.first() + acc + row_entry * *key.first() }) })) } @@ -257,16 +247,17 @@ async fn reveal_keys( key_shares: &[AdditiveShare], ) -> Result>, Error> { // reveal MAC keys - let mut keys = ctx + let keys = ctx .parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move { malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await })) .await? .into_iter() .flatten() + // add a one, since last row element is tag which is not multiplied with a key + .chain(iter::once(StdArray::from_fn(|_| Gf32Bit::ONE))) .collect::>(); - // add a one, since last row element is tag which is not multiplied with a key - keys.push(iter::once(Gf32Bit::ONE).collect()); + Ok(keys) }