Skip to content

Commit

Permalink
improving shuffle verification using Alex's suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
danielmasny committed Sep 6, 2024
1 parent 5a654f6 commit b7194fe
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 45 deletions.
18 changes: 14 additions & 4 deletions ipa-core/src/protocol/ipa_prf/shuffle/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ impl<S: SharedValue> IntermediateShuffleMessages<S> {
///
/// ## Panics
/// Panics when `Role = H2`, i.e. `x1_or_y1` is `None`.
pub fn get_x1_or_y1(&self) -> &Vec<S> {
self.x1_or_y1.as_ref().unwrap()
pub fn get_x1_or_y1(self) -> Vec<S> {
self.x1_or_y1.unwrap()
}

/// When `IntermediateShuffleMessages` is initialized correctly,
Expand All @@ -82,8 +82,18 @@ impl<S: SharedValue> IntermediateShuffleMessages<S> {
///
/// ## Panics
/// Panics when `Role = H1`, i.e. `x2_or_y2` is `None`.
pub fn get_x2_or_y2(&self) -> &Vec<S> {
self.x2_or_y2.as_ref().unwrap()
pub fn get_x2_or_y2(self) -> Vec<S> {
self.x2_or_y2.unwrap()
}

/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `y1` and `y2` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H1`, i.e. `x2_or_y2` is `None` or
/// when `Role = H2`, i.e. `x1_or_y1` is `None`.
pub fn get_both_x_or_ys(self) -> (Vec<S>, Vec<S>) {
(self.x1_or_y1.unwrap(), self.x2_or_y2.unwrap())
}
}

Expand Down
73 changes: 32 additions & 41 deletions ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{borrow::Borrow, iter};
use std::iter;

use futures_util::future::{try_join, try_join3};

Expand All @@ -17,7 +17,7 @@ use crate::{
},
secret_sharing::{
replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
SharedValue, StdArray,
SharedValue, SharedValueArray, StdArray,
},
};

Expand All @@ -43,14 +43,8 @@ async fn verify_shuffle<C: Context, S: BooleanArray>(
Role::H1 => h1_verify(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await,
Role::H2 => h2_verify(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await,
Role::H3 => {
h3_verify(
ctx,
&keys,
shuffled_shares,
messages.get_x1_or_y1(),
messages.get_x2_or_y2(),
)
.await
let (y1, y2) = messages.get_both_x_or_ys();
h3_verify(ctx, &keys, shuffled_shares, y1, y2).await
}
}
}
Expand All @@ -68,13 +62,13 @@ async fn h1_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_a_and_b: &[AdditiveShare<S>],
x1: &[S],
x1: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for x1
let hash_x1 = compute_row_hash::<S, _, _>(keys, x1);
let hash_x1 = compute_row_hash(keys, x1);
// compute hash for A xor B
let hash_a_xor_b = compute_row_hash::<S, _, _>(
let hash_a_xor_b = compute_row_hash(
keys,
share_a_and_b
.iter()
Expand All @@ -87,32 +81,29 @@ async fn h1_verify<C: Context, S: BooleanArray>(
.set_total_records(TotalRecords::specified(2)?);
let h2_ctx = ctx
.narrow(&OPRFShuffleStep::HashH2toH1)
.set_total_records(TotalRecords::specified(1)?);
.set_total_records(TotalRecords::ONE);
let channel_h3 = &h3_ctx.recv_channel::<Hash>(ctx.role().peer(Direction::Left));
let channel_h2 = &h2_ctx.recv_channel::<Hash>(ctx.role().peer(Direction::Right));

// receive hashes
let (hashes_h3, hash_h2) = try_join(
h3_ctx.parallel_join(
(0usize..=1).map(|i| async move { channel_h3.receive(RecordId::from(i)).await }),
),
let (hash_y1, hash_h3, hash_h2) = try_join3(
channel_h3.receive(RecordId::FIRST),
channel_h3.receive(RecordId::from(1usize)),
channel_h2.receive(RecordId::FIRST),
)
.await?;

Check warning on line 94 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L94

Added line #L94 was not covered by tests

// check y1
if hash_x1 != hashes_h3[0] {
if hash_x1 != hash_y1 {
return Err(Error::ShuffleValidationFailed(format!(
"Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {:?}",
hashes_h3[0]
"Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {hash_y1:?}"
)));

Check warning on line 100 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L98-L100

Added lines #L98 - L100 were not covered by tests
}

// check c from h3
if hash_a_xor_b != hashes_h3[1] {
if hash_a_xor_b != hash_h3 {
return Err(Error::ShuffleValidationFailed(format!(
"C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {:?}",
hashes_h3[1]
"C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h3:?}"
)));

Check warning on line 107 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L105-L107

Added lines #L105 - L107 were not covered by tests
}

Expand All @@ -138,13 +129,13 @@ async fn h2_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_b_and_c: &[AdditiveShare<S>],
x2: &[S],
x2: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for x2
let hash_x2 = compute_row_hash::<S, _, _>(keys, x2);
let hash_x2 = compute_row_hash(keys, x2);
// compute hash for C
let hash_c = compute_row_hash::<S, _, _>(
let hash_c = compute_row_hash(
keys,
share_b_and_c.iter().map(ReplicatedSecretSharing::right),
);
Expand Down Expand Up @@ -186,16 +177,16 @@ async fn h3_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_c_and_a: &[AdditiveShare<S>],
y1: &[S],
y2: &[S],
y1: Vec<S>,
y2: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for y1
let hash_y1 = compute_row_hash::<S, _, _>(keys, y1);
let hash_y1 = compute_row_hash(keys, y1);
// compute hash for y2
let hash_y2 = compute_row_hash::<S, _, _>(keys, y2);
let hash_y2 = compute_row_hash(keys, y2);
// compute hash for C
let hash_c = compute_row_hash::<S, _, _>(
let hash_c = compute_row_hash(
keys,
share_c_and_a.iter().map(ReplicatedSecretSharing::left),
);
Expand Down Expand Up @@ -226,20 +217,19 @@ async fn h3_verify<C: Context, S: BooleanArray>(
///
/// ## Panics
/// Panics when conversion from `BooleanArray` to `Vec<Gf32Bit` fails.
fn compute_row_hash<S, B, I>(keys: &[StdArray<Gf32Bit, 1>], row_iterator: I) -> Hash
fn compute_row_hash<S, I>(keys: &[StdArray<Gf32Bit, 1>], row_iterator: I) -> Hash
where
S: BooleanArray,
B: Borrow<S>,
I: IntoIterator<Item = B>,
I: IntoIterator<Item = S>,
{
let iterator = row_iterator
.into_iter()
.map(|s| (*(s.borrow())).try_into().unwrap());
.map(|row| <S as TryInto<Vec<Gf32Bit>>>::try_into(row).unwrap());
compute_hash(iterator.map(|row| {
row.iter()
row.into_iter()
.zip(keys)
.fold(Gf32Bit::ZERO, |acc, (row_entry, key)| {
acc + *row_entry * *key.first()
acc + row_entry * *key.first()
})
}))
}
Expand All @@ -257,16 +247,17 @@ async fn reveal_keys<C: Context>(
key_shares: &[AdditiveShare<Gf32Bit>],
) -> Result<Vec<StdArray<Gf32Bit, 1>>, Error> {
// reveal MAC keys
let mut keys = ctx
let keys = ctx
.parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move {
malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await
}))
.await?
.into_iter()
.flatten()
// add a one, since last row element is tag which is not multiplied with a key
.chain(iter::once(StdArray::from_fn(|_| Gf32Bit::ONE)))
.collect::<Vec<_>>();
// add a one, since last row element is tag which is not multiplied with a key
keys.push(iter::once(Gf32Bit::ONE).collect());

Ok(keys)
}

Expand Down

0 comments on commit b7194fe

Please sign in to comment.