Skip to content

Commit

Permalink
Merge pull request #1255 from danielmasny/shuffle-verification
Browse files Browse the repository at this point in the history
Shuffle verification
  • Loading branch information
danielmasny authored Sep 6, 2024
2 parents adcd8dd + b7194fe commit c857eb2
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ipa-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ pub enum Error {
record_id: RecordId,
total_records: usize,
},
#[error("The verification of the shuffle failed: {0}")]
ShuffleValidationFailed(String),
}

impl Default for Error {
Expand Down
33 changes: 33 additions & 0 deletions ipa-core/src/protocol/ipa_prf/shuffle/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,39 @@ pub struct IntermediateShuffleMessages<S: SharedValue> {
x2_or_y2: Option<Vec<S>>,
}

#[allow(dead_code)]
impl<S: SharedValue> IntermediateShuffleMessages<S> {
/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `x1` when `Role = H1`
/// and `y1` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H2`, i.e. `x1_or_y1` is `None`.
pub fn get_x1_or_y1(self) -> Vec<S> {
self.x1_or_y1.unwrap()
}

/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `x2` when `Role = H2`
/// and `y2` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H1`, i.e. `x2_or_y2` is `None`.
pub fn get_x2_or_y2(self) -> Vec<S> {
self.x2_or_y2.unwrap()
}

/// When `IntermediateShuffleMessages` is initialized correctly,
/// this function returns `y1` and `y2` when `Role = H3`.
///
/// ## Panics
/// Panics when `Role = H1`, i.e. `x2_or_y2` is `None` or
/// when `Role = H2`, i.e. `x1_or_y1` is `None`.
pub fn get_both_x_or_ys(self) -> (Vec<S>, Vec<S>) {
(self.x1_or_y1.unwrap(), self.x2_or_y2.unwrap())
}
}

async fn run_h1<C, I, S, Zl, Zr>(
ctx: &C,
batch_size: NonZeroUsize,
Expand Down
310 changes: 310 additions & 0 deletions ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
use std::iter;

use futures_util::future::{try_join, try_join3};

use crate::{
error::Error,
ff::{boolean_array::BooleanArray, Field, Gf32Bit},
helpers::{
hashing::{compute_hash, Hash},
Direction, Role, TotalRecords,
},
protocol::{
basics::malicious_reveal,
context::Context,
ipa_prf::shuffle::{base::IntermediateShuffleMessages, step::OPRFShuffleStep},
RecordId,
},
secret_sharing::{
replicated::{semi_honest::AdditiveShare, ReplicatedSecretSharing},
SharedValue, SharedValueArray, StdArray,
},
};

/// This function verifies the `shuffled_shares` and the `IntermediateShuffleMessages`.
///
/// ## Errors
/// Propagates network errors.
/// Further, returns an error when messages are inconsistent with the MAC tags.
async fn verify_shuffle<C: Context, S: BooleanArray>(
ctx: C,
key_shares: &[AdditiveShare<Gf32Bit>],
shuffled_shares: &[AdditiveShare<S>],
messages: IntermediateShuffleMessages<S>,
) -> Result<(), Error> {
// reveal keys
let k_ctx = ctx
.narrow(&OPRFShuffleStep::RevealMACKey)
.set_total_records(TotalRecords::specified(key_shares.len())?);
let keys = reveal_keys(&k_ctx, key_shares).await?;

// verify messages and shares
match ctx.role() {
Role::H1 => h1_verify(ctx, &keys, shuffled_shares, messages.get_x1_or_y1()).await,
Role::H2 => h2_verify(ctx, &keys, shuffled_shares, messages.get_x2_or_y2()).await,
Role::H3 => {
let (y1, y2) = messages.get_both_x_or_ys();
h3_verify(ctx, &keys, shuffled_shares, y1, y2).await
}
}
}

/// This is the verification function run by `H1`.
/// `H1` computes the hash for `x1` and `a_xor_b`.
/// Further, he receives `hash_y1` and `hash_c_h3` from `H3`
/// and `hash_c_h2` from `H2`.
///
/// ## Errors
/// Propagates network errors. Further it returns an error when
/// `hash_x1 != hash_y1` or `hash_c_h2 != hash_a_xor_b`
/// or `hash_c_h3 != hash_a_xor_b`.
async fn h1_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_a_and_b: &[AdditiveShare<S>],
x1: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for x1
let hash_x1 = compute_row_hash(keys, x1);
// compute hash for A xor B
let hash_a_xor_b = compute_row_hash(
keys,
share_a_and_b
.iter()
.map(|share| share.left() + share.right()),
);

// setup channels
let h3_ctx = ctx
.narrow(&OPRFShuffleStep::HashesH3toH1)
.set_total_records(TotalRecords::specified(2)?);
let h2_ctx = ctx
.narrow(&OPRFShuffleStep::HashH2toH1)
.set_total_records(TotalRecords::ONE);
let channel_h3 = &h3_ctx.recv_channel::<Hash>(ctx.role().peer(Direction::Left));
let channel_h2 = &h2_ctx.recv_channel::<Hash>(ctx.role().peer(Direction::Right));

// receive hashes
let (hash_y1, hash_h3, hash_h2) = try_join3(
channel_h3.receive(RecordId::FIRST),
channel_h3.receive(RecordId::from(1usize)),
channel_h2.receive(RecordId::FIRST),
)
.await?;

// check y1
if hash_x1 != hash_y1 {
return Err(Error::ShuffleValidationFailed(format!(
"Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {hash_y1:?}"
)));
}

// check c from h3
if hash_a_xor_b != hash_h3 {
return Err(Error::ShuffleValidationFailed(format!(
"C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h3:?}"
)));
}

// check h2
if hash_a_xor_b != hash_h2 {
return Err(Error::ShuffleValidationFailed(format!(
"C from H2 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h2:?}"
)));
}

Ok(())
}

/// This is the verification function run by `H2`.
/// `H2` computes the hash for `x2` and `c`
/// and sends the latter to `H1`.
/// Further, he receives `hash_y2` from `H3`
///
/// ## Errors
/// Propagates network errors. Further it returns an error when
/// `hash_x2 != hash_y2`.
async fn h2_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_b_and_c: &[AdditiveShare<S>],
x2: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for x2
let hash_x2 = compute_row_hash(keys, x2);
// compute hash for C
let hash_c = compute_row_hash(
keys,
share_b_and_c.iter().map(ReplicatedSecretSharing::right),
);

// setup channels
let h1_ctx = ctx
.narrow(&OPRFShuffleStep::HashH2toH1)
.set_total_records(TotalRecords::specified(1)?);
let h3_ctx = ctx
.narrow(&OPRFShuffleStep::HashH3toH2)
.set_total_records(TotalRecords::specified(1)?);
let channel_h1 = &h1_ctx.send_channel::<Hash>(ctx.role().peer(Direction::Left));
let channel_h3 = &h3_ctx.recv_channel::<Hash>(ctx.role().peer(Direction::Right));

// send and receive hash
let ((), hash_h3) = try_join(
channel_h1.send(RecordId::FIRST, hash_c),
channel_h3.receive(RecordId::FIRST),
)
.await?;

// check x2
if hash_x2 != hash_h3 {
return Err(Error::ShuffleValidationFailed(format!(
"X2 is inconsistent: hash of x2: {hash_x2:?}, hash of y2: {hash_h3:?}"
)));
}

Ok(())
}

/// This is the verification function run by `H3`.
/// `H3` computes the hash for `y1`, `y2` and `c`
/// and sends `y1`, `c` to `H1` and `y2` to `H2`.
///
/// ## Errors
/// Propagates network errors.
async fn h3_verify<C: Context, S: BooleanArray>(
ctx: C,
keys: &[StdArray<Gf32Bit, 1>],
share_c_and_a: &[AdditiveShare<S>],
y1: Vec<S>,
y2: Vec<S>,
) -> Result<(), Error> {
// compute hashes
// compute hash for y1
let hash_y1 = compute_row_hash(keys, y1);
// compute hash for y2
let hash_y2 = compute_row_hash(keys, y2);
// compute hash for C
let hash_c = compute_row_hash(
keys,
share_c_and_a.iter().map(ReplicatedSecretSharing::left),
);

// setup channels
let h1_ctx = ctx
.narrow(&OPRFShuffleStep::HashesH3toH1)
.set_total_records(TotalRecords::specified(2)?);
let h2_ctx = ctx
.narrow(&OPRFShuffleStep::HashH3toH2)
.set_total_records(TotalRecords::specified(1)?);
let channel_h1 = &h1_ctx.send_channel::<Hash>(ctx.role().peer(Direction::Right));
let channel_h2 = &h2_ctx.send_channel::<Hash>(ctx.role().peer(Direction::Left));

// send and receive hash
let _ = try_join3(
channel_h1.send(RecordId::FIRST, hash_y1),
channel_h1.send(RecordId::from(1usize), hash_c),
channel_h2.send(RecordId::FIRST, hash_y2),
)
.await?;

Ok(())
}

/// This function computes for each item in the iterator the inner product with `keys`.
/// It concatenates all inner products and hashes them.
///
/// ## Panics
/// Panics when conversion from `BooleanArray` to `Vec<Gf32Bit` fails.
fn compute_row_hash<S, I>(keys: &[StdArray<Gf32Bit, 1>], row_iterator: I) -> Hash
where
S: BooleanArray,
I: IntoIterator<Item = S>,
{
let iterator = row_iterator
.into_iter()
.map(|row| <S as TryInto<Vec<Gf32Bit>>>::try_into(row).unwrap());
compute_hash(iterator.map(|row| {
row.into_iter()
.zip(keys)
.fold(Gf32Bit::ZERO, |acc, (row_entry, key)| {
acc + row_entry * *key.first()
})
}))
}

/// This function reveals the MAC keys,
/// stores them in a vector
/// and appends a `Gf32Bit::ONE`
///
/// It uses `parallel_join` and therefore vector elements are a `StdArray` of length `1`.
///
/// ## Errors
/// Propagates errors from `parallel_join` and `malicious_reveal`.
async fn reveal_keys<C: Context>(
ctx: &C,
key_shares: &[AdditiveShare<Gf32Bit>],
) -> Result<Vec<StdArray<Gf32Bit, 1>>, Error> {
// reveal MAC keys
let keys = ctx
.parallel_join(key_shares.iter().enumerate().map(|(i, key)| async move {
malicious_reveal(ctx.clone(), RecordId::from(i), None, key).await
}))
.await?
.into_iter()
.flatten()
// add a one, since last row element is tag which is not multiplied with a key
.chain(iter::once(StdArray::from_fn(|_| Gf32Bit::ONE)))
.collect::<Vec<_>>();

Ok(keys)
}

#[cfg(all(test, unit_test))]
mod tests {
use rand::{thread_rng, Rng};

use super::*;
use crate::{
ff::{boolean_array::BA64, Serializable},
protocol::ipa_prf::shuffle::base::shuffle,
test_executor::run,
test_fixture::{Runner, TestWorld},
};

/// This test checks the correctness of the malicious shuffle
/// when all parties behave honestly
/// and all the MAC keys are `Gf32Bit::ONE`.
/// Further, each row consists of a `BA32` and a `BA32` tag.
#[test]
fn check_shuffle_with_simple_mac() {
const RECORD_AMOUNT: usize = 10;
run(|| async {
let world = TestWorld::default();
let mut rng = thread_rng();
let records = (0..RECORD_AMOUNT)
.map(|_| {
let entry = rng.gen::<[u8; 4]>();
let mut entry_and_tag = [0u8; 8];
entry_and_tag[0..4].copy_from_slice(&entry);
entry_and_tag[4..8].copy_from_slice(&entry);
BA64::deserialize_from_slice(&entry_and_tag)
})
.collect::<Vec<BA64>>();

let _ = world
.semi_honest(records.into_iter(), |ctx, rows| async move {
// trivial shares of Gf32Bit::ONE
let key_shares = vec![AdditiveShare::new(Gf32Bit::ONE, Gf32Bit::ONE); 1];
// run shuffle
let (shares, messages) = shuffle(ctx.narrow("shuffle"), rows).await.unwrap();
// verify it
verify_shuffle(ctx.narrow("verify"), &key_shares, &shares, messages)
.await
.unwrap();
})
.await;
});
}
}
2 changes: 2 additions & 0 deletions ipa-core/src/protocol/ipa_prf/shuffle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use crate::{
};

pub mod base;
#[allow(dead_code)]
pub mod malicious;
#[cfg(descriptive_gate)]
mod sharded;
pub(crate) mod step;
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/ipa_prf/shuffle/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ pub(crate) enum OPRFShuffleStep {
TransferCHat,
TransferX2,
TransferY1,
RevealMACKey,
HashesH3toH1,
HashH2toH1,
HashH3toH2,
}

0 comments on commit c857eb2

Please sign in to comment.