Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shuffle verification #1255

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ipa-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ pub enum Error {
record_id: RecordId,
total_records: usize,
},
#[error("The verification of the shuffle failed: {0}")]
ShuffleValidationFailed(String),
}

impl Default for Error {
Expand Down
33 changes: 33 additions & 0 deletions ipa-core/src/protocol/ipa_prf/shuffle/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,39 @@ pub struct IntermediateShuffleMessages<S: SharedValue> {
x2_or_y2: Option<Vec<S>>,
}

#[allow(dead_code)]
impl<S: SharedValue> IntermediateShuffleMessages<S> {
/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `x1` when `Role = H1`
/// and `y1` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H2`, i.e. `x1_or_y1` is `None`.
pub fn get_x1_or_y1(self) -> Vec<S> {
self.x1_or_y1.unwrap()
}

/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `x2` when `Role = H2`
/// and `y2` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H1`, i.e. `x2_or_y2` is `None`.
pub fn get_x2_or_y2(self) -> Vec<S> {
self.x2_or_y2.unwrap()
}

/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `y1` and `y2` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H1`, i.e. `x2_or_y2` is `None` or
/// when `Role = H2`, i.e. `x1_or_y1` is `None`.
pub fn get_both_x_or_ys(self) -> (Vec<S>, Vec<S>) {
(self.x1_or_y1.unwrap(), self.x2_or_y2.unwrap())
}
}

async fn run_h1<C, I, S, Zl, Zr>(
ctx: &C,
batch_size: NonZeroUsize,
Expand Down
310 changes: 310 additions & 0 deletions ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
use std::iter;

use futures_util::future::{try_join, try_join3};

use crate::{
error::Error,
ff::{boolean_array::BooleanArray, Field, Gf32Bit},
helpers::{
hashing::{compute_hash, Hash},
Direction, Role, TotalRecords,
},
protocol::{
basics::malicious_reveal,
context::Context,
ipa_prf::shuffle::{base::IntermediateShuffleMessages, step::OPRFShuffleStep},
RecordId,
},
secret_sharing::{
replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
SharedValue, SharedValueArray, StdArray,
},
};

/// This function verifies the `shuffled_shares` and the `IntermediateShuffleMessages`.
///
/// ## Errors
/// Propagates network errors.
/// Further, returns an error when messages are inconsistent with the MAC tags.
async fn verify_shuffle<C: Context, S: BooleanArray>(
ctx: C,
key_shares: &[AdditiveShare<Gf32Bit>],
shuffled_shares: &[AdditiveShare<S>],
messages: IntermediateShuffleMessages<S>,
) -> Result<(), Error> {
// reveal keys
let k_ctx = ctx
.narrow(&OPRFShuffleStep::RevealMACKey)
.set_total_records(TotalRecords::specified(key_shares.len())?);
let keys = reveal_keys(&k_ctx, key_shares).await?;

// verify messages and shares
match ctx.role() {
Role::H1 => h1_verify(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await,
Role::H2 => h2_verify(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await,
Role::H3 => {
let (y1, y2) = messages.get_both_x_or_ys();
h3_verify(ctx, &keys, shuffled_shares, y1, y2).await
}
}
}

/// This is the verification function run by `H1`.
/// `H1` computes the hash for `x1` and `a_xor_b`.
/// Further, he receives `hash_y1` and `hash_c_h3` from `H3`
/// and `hash_c_h2` from `H2`.
///
/// ## Errors
/// Propagates network errors. Further it returns an error when
/// `hash_x1 != hash_y1` or `hash_c_h2 != hash_a_xor_b`
/// or `hash_c_h3 != hash_a_xor_b`.
async fn h1_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_a_and_b: &[AdditiveShare<S>],
x1: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for x1
let hash_x1 = compute_row_hash(keys, x1);
// compute hash for A xor B
let hash_a_xor_b = compute_row_hash(
keys,
share_a_and_b
.iter()
.map(|share| share.left() + share.right()),
);

// setup channels
let h3_ctx = ctx
.narrow(&OPRFShuffleStep::HashesH3toH1)
.set_total_records(TotalRecords::specified(2)?);
let h2_ctx = ctx
.narrow(&OPRFShuffleStep::HashH2toH1)
.set_total_records(TotalRecords::ONE);
let channel_h3 = &h3_ctx.recv_channel::<Hash>(ctx.role().peer(Direction::Left));
let channel_h2 = &h2_ctx.recv_channel::<Hash>(ctx.role().peer(Direction::Right));

// receive hashes
let (hash_y1, hash_h3, hash_h2) = try_join3(
channel_h3.receive(RecordId::FIRST),
channel_h3.receive(RecordId::from(1usize)),
channel_h2.receive(RecordId::FIRST),
)
.await?;

Check warning on line 94 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L94

Added line #L94 was not covered by tests

// check y1
if hash_x1 != hash_y1 {
return Err(Error::ShuffleValidationFailed(format!(
"Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {hash_y1:?}"
)));

Check warning on line 100 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L98-L100

Added lines #L98 - L100 were not covered by tests
}

// check c from h3
if hash_a_xor_b != hash_h3 {
return Err(Error::ShuffleValidationFailed(format!(
"C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h3:?}"
)));

Check warning on line 107 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L105-L107

Added lines #L105 - L107 were not covered by tests
}

// check h2
if hash_a_xor_b != hash_h2 {
return Err(Error::ShuffleValidationFailed(format!(
"C from H2 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h2:?}"
)));

Check warning on line 114 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L112-L114

Added lines #L112 - L114 were not covered by tests
}

Ok(())
}

/// This is the verification function run by `H2`.
/// `H2` computes the hash for `x2` and `c`
/// and sends the latter to `H1`.
/// Further, he receives `hash_y2` from `H3`
///
/// ## Errors
/// Propagates network errors. Further it returns an error when
/// `hash_x2 != hash_y2`.
async fn h2_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_b_and_c: &[AdditiveShare<S>],
x2: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for x2
let hash_x2 = compute_row_hash(keys, x2);
// compute hash for C
let hash_c = compute_row_hash(
keys,
share_b_and_c.iter().map(ReplicatedSecretSharing::right),
);

// setup channels
let h1_ctx = ctx
.narrow(&OPRFShuffleStep::HashH2toH1)
.set_total_records(TotalRecords::specified(1)?);
let h3_ctx = ctx
.narrow(&OPRFShuffleStep::HashH3toH2)
.set_total_records(TotalRecords::specified(1)?);
let channel_h1 = &h1_ctx.send_channel::<Hash>(ctx.role().peer(Direction::Left));
let channel_h3 = &h3_ctx.recv_channel::<Hash>(ctx.role().peer(Direction::Right));

// send and receive hash
let ((), hash_h3) = try_join(
channel_h1.send(RecordId::FIRST, hash_c),
channel_h3.receive(RecordId::FIRST),
)
.await?;

// check x2
if hash_x2 != hash_h3 {
return Err(Error::ShuffleValidationFailed(format!(
"X2 is inconsistent: hash of x2: {hash_x2:?}, hash of y2: {hash_h3:?}"
)));

Check warning on line 164 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L162-L164

Added lines #L162 - L164 were not covered by tests
}

Ok(())
}

/// This is the verification function run by `H3`.
/// `H3` computes the hash for `y1`, `y2` and `c`
/// and sends `y1`, `c` to `H1` and `y2` to `H2`.
///
/// ## Errors
/// Propagates network errors.
async fn h3_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_c_and_a: &[AdditiveShare<S>],
y1: Vec<S>,
y2: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for y1
let hash_y1 = compute_row_hash(keys, y1);
// compute hash for y2
let hash_y2 = compute_row_hash(keys, y2);
// compute hash for C
let hash_c = compute_row_hash(
keys,
share_c_and_a.iter().map(ReplicatedSecretSharing::left),
);

// setup channels
let h1_ctx = ctx
.narrow(&OPRFShuffleStep::HashesH3toH1)
.set_total_records(TotalRecords::specified(2)?);
let h2_ctx = ctx
.narrow(&OPRFShuffleStep::HashH3toH2)
.set_total_records(TotalRecords::specified(1)?);
let channel_h1 = &h1_ctx.send_channel::<Hash>(ctx.role().peer(Direction::Right));
let channel_h2 = &h2_ctx.send_channel::<Hash>(ctx.role().peer(Direction::Left));

// send and receive hash
let _ = try_join3(
channel_h1.send(RecordId::FIRST, hash_y1),
channel_h1.send(RecordId::from(1usize), hash_c),
channel_h2.send(RecordId::FIRST, hash_y2),
)
.await?;

Check warning on line 210 in ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs#L210

Added line #L210 was not covered by tests

Ok(())
}

/// This function computes for each item in the iterator the inner product with `keys`.
/// It concatenates all inner products and hashes them.
///
/// ## Panics
/// Panics when conversion from `BooleanArray` to `Vec<Gf32Bit` fails.
fn compute_row_hash<S, I>(keys: &[StdArray<Gf32Bit, 1>], row_iterator: I) -> Hash
where
S: BooleanArray,
I: IntoIterator<Item = S>,
{
let iterator = row_iterator
.into_iter()
.map(|row| <S as TryInto<Vec<Gf32Bit>>>::try_into(row).unwrap());
compute_hash(iterator.map(|row| {
row.into_iter()
.zip(keys)
.fold(Gf32Bit::ZERO, |acc, (row_entry, key)| {
acc + row_entry * *key.first()
})
}))
}

/// This function reveals the MAC keys,
/// stores them in a vector
/// and appends a `Gf32Bit::ONE`
///
/// It uses `parallel_join` and therefore vector elements are a `StdArray` of length `1`.
///
/// ## Errors
/// Propagates errors from `parallel_join` and `malicious_reveal`.
async fn reveal_keys<C: Context>(
ctx: &C,
key_shares: &[AdditiveShare<Gf32Bit>],
) -> Result<Vec<StdArray<Gf32Bit, 1>>, Error> {
// reveal MAC keys
let keys = ctx
.parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move {
malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await
}))
.await?
.into_iter()
.flatten()
// add a one, since last row element is tag which is not multiplied with a key
.chain(iter::once(StdArray::from_fn(|_| Gf32Bit::ONE)))
.collect::<Vec<_>>();

Ok(keys)
}

#[cfg(all(test, unit_test))]
mod tests {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there anything that prevents us from adding a test that ensures cheaters are caught?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use rand::{thread_rng, Rng};

use super::*;
use crate::{
ff::{boolean_array::BA64, Serializable},
protocol::ipa_prf::shuffle::base::shuffle,
test_executor::run,
test_fixture::{Runner, TestWorld},
};

/// This test checks the correctness of the malicious shuffle
/// when all parties behave honestly
/// and all the MAC keys are `Gf32Bit::ONE`.
/// Further, each row consists of a `BA32` and a `BA32` tag.
#[test]
fn check_shuffle_with_simple_mac() {
const RECORD_AMOUNT: usize = 10;
run(|| async {
let world = TestWorld::default();
let mut rng = thread_rng();
let records = (0..RECORD_AMOUNT)
.map(|_| {
let entry = rng.gen::<[u8; 4]>();
let mut entry_and_tag = [0u8; 8];
entry_and_tag[0..4].copy_from_slice(&entry);
entry_and_tag[4..8].copy_from_slice(&entry);
BA64::deserialize_from_slice(&entry_and_tag)
})
.collect::<Vec<BA64>>();

let _ = world
.semi_honest(records.into_iter(), |ctx, rows| async move {
// trivial shares of Gf32Bit::ONE
let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE, Gf32Bit::ONE); 1];
// run shuffle
let (shares, messages) = shuffle(ctx.narrow("shuffle"), rows).await.unwrap();
// verify it
verify_shuffle(ctx.narrow("verify"), &key_shares, &shares, messages)
.await
.unwrap();
})
.await;
});
}
}
2 changes: 2 additions & 0 deletions ipa-core/src/protocol/ipa_prf/shuffle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use crate::{
};

pub mod base;
#[allow(dead_code)]
pub mod malicious;
#[cfg(descriptive_gate)]
mod sharded;
pub(crate) mod step;
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/ipa_prf/shuffle/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ pub(crate) enum OPRFShuffleStep {
TransferCHat,
TransferX2,
TransferY1,
RevealMACKey,
HashesH3toH1,
HashH2toH1,
HashH3toH2,
}