Skip to content

Commit

Permalink
Support sharding in MAC and ZKP validations
Browse files Browse the repository at this point in the history
The next step to be able to run sharded protocols is to change the
validators API and support both sharded and non-sharded contexts.

This is a fairly straightforward change that just propagates `ShardBinding`
trait bound through MAC and ZKP contexts and validators
  • Loading branch information
akoshelev committed Oct 3, 2024
1 parent 9c834cc commit 0ba1ec4
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 64 deletions.
14 changes: 10 additions & 4 deletions ipa-core/src/protocol/basics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedSemiHonestContext<'a, B>>
{
}

impl<'a> BooleanProtocols<DZKPUpgradedMaliciousContext<'a>> for AdditiveShare<Boolean> {}
impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedMaliciousContext<'a, B>>
for AdditiveShare<Boolean>
{
}

// Used for aggregation tests
impl<'a, B: ShardBinding> BooleanProtocols<UpgradedSemiHonestContext<'a, B, Boolean>, 8>
Expand All @@ -107,7 +110,7 @@ impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedSemiHonestContext<'a, B>,
{
}

impl<'a> BooleanProtocols<DZKPUpgradedMaliciousContext<'a>, PRF_CHUNK>
impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedMaliciousContext<'a, B>, PRF_CHUNK>
for AdditiveShare<Boolean, PRF_CHUNK>
{
}
Expand All @@ -124,7 +127,7 @@ impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedSemiHonestContext<'a, B>,
{
}

impl<'a> BooleanProtocols<DZKPUpgradedMaliciousContext<'a>, AGG_CHUNK>
impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedMaliciousContext<'a, B>, AGG_CHUNK>
for AdditiveShare<Boolean, AGG_CHUNK>
{
}
Expand Down Expand Up @@ -159,7 +162,10 @@ impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedSemiHonestContext<'a, B>,
{
}

impl<'a> BooleanProtocols<DZKPUpgradedMaliciousContext<'a>, 32> for AdditiveShare<Boolean, 32> {}
impl<'a, B: ShardBinding> BooleanProtocols<DZKPUpgradedMaliciousContext<'a, B>, 32>
for AdditiveShare<Boolean, 32>
{
}

const_assert_eq!(
AGG_CHUNK,
Expand Down
14 changes: 8 additions & 6 deletions ipa-core/src/protocol/basics/mul/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
RecordId,
},
secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Vectorizable},
sharding::{NotSharded, ShardBinding},
};

/// This function implements an MPC multiply using the standard strategy, i.e. via computing the
Expand All @@ -27,13 +28,14 @@ use crate::{
/// back via the error response
/// ## Panics
/// Panics if the mutex is found to be poisoned
pub async fn zkp_multiply<'a, F, const N: usize>(
ctx: DZKPUpgradedMaliciousContext<'a>,
pub async fn zkp_multiply<'a, B, F, const N: usize>(
ctx: DZKPUpgradedMaliciousContext<'a, B>,
record_id: RecordId,
a: &Replicated<F, N>,
b: &Replicated<F, N>,
) -> Result<Replicated<F, N>, Error>
where
B: ShardBinding,
F: Field + DZKPCompatibleField<N>,
{
// Shared randomness used to mask the values that are sent.
Expand Down Expand Up @@ -62,17 +64,17 @@ where

/// Implement secure multiplication for malicious contexts with replicated secret sharing.
#[async_trait]
impl<'a, F: Field + DZKPCompatibleField<N>, const N: usize>
SecureMul<DZKPUpgradedMaliciousContext<'a>> for Replicated<F, N>
impl<'a, B: ShardBinding, F: Field + DZKPCompatibleField<N>, const N: usize>
SecureMul<DZKPUpgradedMaliciousContext<'a, B>> for Replicated<F, N>
{
async fn multiply<'fut>(
&self,
rhs: &Self,
ctx: DZKPUpgradedMaliciousContext<'a>,
ctx: DZKPUpgradedMaliciousContext<'a, B>,
record_id: RecordId,
) -> Result<Self, Error>
where
DZKPUpgradedMaliciousContext<'a>: 'fut,
DZKPUpgradedMaliciousContext<'a, NotSharded>: 'fut,
{
zkp_multiply(ctx, record_id, self, rhs).await
}
Expand Down
8 changes: 5 additions & 3 deletions ipa-core/src/protocol/basics/mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,19 @@ macro_rules! boolean_array_mul {
}
}

impl<'a> BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>> for Replicated<$vec> {
impl<'a, B: sharding::ShardBinding> BooleanArrayMul<DZKPUpgradedMaliciousContext<'a, B>>
for Replicated<$vec>
{
type Vectorized = Replicated<Boolean, $dim>;

fn multiply<'fut>(
ctx: DZKPUpgradedMaliciousContext<'a>,
ctx: DZKPUpgradedMaliciousContext<'a, B>,
record_id: RecordId,
a: &'fut Self::Vectorized,
b: &'fut Self::Vectorized,
) -> impl Future<Output = Result<Self::Vectorized, Error>> + Send + 'fut
where
DZKPUpgradedMaliciousContext<'a>: 'fut,
DZKPUpgradedMaliciousContext<'a, B>: 'fut,
{
use crate::protocol::basics::mul::dzkp_malicious::zkp_multiply;
zkp_multiply(ctx, record_id, a, b)
Expand Down
7 changes: 4 additions & 3 deletions ipa-core/src/protocol/basics/reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,21 @@ where
}
}

impl<'a, const N: usize> Reveal<DZKPUpgradedMaliciousContext<'a>> for Replicated<Boolean, N>
impl<'a, B, const N: usize> Reveal<DZKPUpgradedMaliciousContext<'a, B>> for Replicated<Boolean, N>
where
B: ShardBinding,
Boolean: Vectorizable<N>,
{
type Output = <Boolean as Vectorizable<N>>::Array;

async fn generic_reveal<'fut>(
&'fut self,
ctx: DZKPUpgradedMaliciousContext<'a>,
ctx: DZKPUpgradedMaliciousContext<'a, B>,
record_id: RecordId,
excluded: Option<Role>,
) -> Result<Option<Self::Output>, Error>
where
DZKPUpgradedMaliciousContext<'a>: 'fut,
DZKPUpgradedMaliciousContext<'a, B>: 'fut,
{
malicious_reveal(ctx, record_id, excluded, self).await
}
Expand Down
21 changes: 11 additions & 10 deletions ipa-core/src/protocol/context/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@ use crate::{
Gate, RecordId,
},
seq_join::SeqJoin,
sharding::ShardBinding,
sync::{Arc, Weak},
};

/// Represents protocol context in malicious setting when using zero-knowledge proofs,
/// i.e. secure against one active adversary in 3 party MPC ring.
#[derive(Clone)]
pub struct DZKPUpgraded<'a> {
validator_inner: Weak<MaliciousDZKPValidatorInner<'a>>,
base_ctx: MaliciousContext<'a>,
pub struct DZKPUpgraded<'a, B: ShardBinding> {
validator_inner: Weak<MaliciousDZKPValidatorInner<'a, B>>,
base_ctx: MaliciousContext<'a, B>,
}

impl<'a> DZKPUpgraded<'a> {
impl<'a, B: ShardBinding> DZKPUpgraded<'a, B> {
pub(super) fn new(
validator_inner: &Arc<MaliciousDZKPValidatorInner<'a>>,
base_ctx: MaliciousContext<'a>,
validator_inner: &Arc<MaliciousDZKPValidatorInner<'a, B>>,
base_ctx: MaliciousContext<'a, B>,
) -> Self {
let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch();
let active_work = if records_per_batch == 1 {
Expand Down Expand Up @@ -82,7 +83,7 @@ impl<'a> DZKPUpgraded<'a> {
}

#[async_trait]
impl<'a> DZKPContext for DZKPUpgraded<'a> {
impl<'a, B: ShardBinding> DZKPContext for DZKPUpgraded<'a, B> {
async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> {
let validator_inner = self.validator_inner.upgrade().expect("validator is active");

Expand All @@ -100,7 +101,7 @@ impl<'a> DZKPContext for DZKPUpgraded<'a> {
}
}

impl<'a> super::Context for DZKPUpgraded<'a> {
impl<'a, B: ShardBinding> super::Context for DZKPUpgraded<'a, B> {
fn role(&self) -> Role {
self.base_ctx.role()
}
Expand Down Expand Up @@ -152,13 +153,13 @@ impl<'a> super::Context for DZKPUpgraded<'a> {
}
}

impl<'a> SeqJoin for DZKPUpgraded<'a> {
impl<'a, B: ShardBinding> SeqJoin for DZKPUpgraded<'a, B> {
fn active_work(&self) -> NonZeroUsize {
self.base_ctx.active_work()
}
}

impl Debug for DZKPUpgraded<'_> {
impl<B: ShardBinding> Debug for DZKPUpgraded<'_, B> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "DZKPMaliciousContext")
}
Expand Down
28 changes: 14 additions & 14 deletions ipa-core/src/protocol/context/dzkp_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ impl Batch {

/// ## Panics
/// If `usize` to `u128` conversion fails.
pub(super) async fn validate(self, ctx: Base<'_>) -> Result<(), Error> {
pub(super) async fn validate<B: ShardBinding>(self, ctx: Base<'_, B>) -> Result<(), Error> {
let proof_ctx = ctx.narrow(&Step::GenerateProof);

if self.is_empty() {
Expand Down Expand Up @@ -701,26 +701,26 @@ type DzkpBatcher<'a> = Batcher<'a, Batch>;

/// The DZKP validator, and all associated contexts, each hold a reference to a single
/// instance of `MaliciousDZKPValidatorInner`.
pub(super) struct MaliciousDZKPValidatorInner<'a> {
pub(super) struct MaliciousDZKPValidatorInner<'a, B: ShardBinding> {
pub(super) batcher: Mutex<DzkpBatcher<'a>>,
pub(super) validate_ctx: Base<'a>,
pub(super) validate_ctx: Base<'a, B>,
}

/// `MaliciousDZKPValidator` corresponds to pub struct `Malicious` and implements the trait `DZKPValidator`
/// The implementation of `validate` of the `DZKPValidator` trait depends on generic `DF`
pub struct MaliciousDZKPValidator<'a> {
pub struct MaliciousDZKPValidator<'a, B: ShardBinding> {
// This is an `Option` because we want to consume it in `DZKPValidator::validate`,
// but we also want to implement `Drop`. Note that the `is_verified` check in `Drop`
// does nothing when `batcher_ref` is already `None`.
inner_ref: Option<Arc<MaliciousDZKPValidatorInner<'a>>>,
protocol_ctx: MaliciousDZKPUpgraded<'a>,
inner_ref: Option<Arc<MaliciousDZKPValidatorInner<'a, B>>>,
protocol_ctx: MaliciousDZKPUpgraded<'a, B>,
}

#[async_trait]
impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> {
type Context = MaliciousDZKPUpgraded<'a>;
impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> {
type Context = MaliciousDZKPUpgraded<'a, B>;

fn context(&self) -> MaliciousDZKPUpgraded<'a> {
fn context(&self) -> MaliciousDZKPUpgraded<'a, B> {
self.protocol_ctx.clone()
}

Expand Down Expand Up @@ -774,11 +774,11 @@ impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> {
}
}

impl<'a> MaliciousDZKPValidator<'a> {
impl<'a, B: ShardBinding> MaliciousDZKPValidator<'a, B> {
#[must_use]
#[allow(clippy::needless_pass_by_value)]
pub fn new<S>(
ctx: MaliciousContext<'a>,
ctx: MaliciousContext<'a, B>,
steps: MaliciousProtocolSteps<S>,
max_multiplications_per_gate: usize,
) -> Self
Expand Down Expand Up @@ -808,7 +808,7 @@ impl<'a> MaliciousDZKPValidator<'a> {
}
}

impl<'a> Drop for MaliciousDZKPValidator<'a> {
impl<'a, B: ShardBinding> Drop for MaliciousDZKPValidator<'a, B> {
fn drop(&mut self) {
if self.inner_ref.is_some() {
self.is_verified().unwrap();
Expand Down Expand Up @@ -922,7 +922,7 @@ mod tests {
async fn test_select_malicious<V>()
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a, NotSharded>>,
Standard: Distribution<V>,
{
let world = TestWorld::default();
Expand Down Expand Up @@ -1040,7 +1040,7 @@ mod tests {
async fn multi_select_malicious<V>(count: usize, max_multiplications_per_gate: usize)
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a, NotSharded>>,
Standard: Distribution<V>,
{
let mut rng = thread_rng();
Expand Down
14 changes: 7 additions & 7 deletions ipa-core/src/protocol/context/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,14 @@ impl<'a, B: ShardBinding> super::Context for Context<'a, B> {
}
}

impl<'a> UpgradableContext for Context<'a, NotSharded> {
type Validator<F: ExtendableField> = BatchValidator<'a, F>;
impl<'a, B: ShardBinding> UpgradableContext for Context<'a, B> {
type Validator<F: ExtendableField> = BatchValidator<'a, F, B>;

fn validator<F: ExtendableField>(self) -> Self::Validator<F> {
BatchValidator::new(self)
}

type DZKPValidator = MaliciousDZKPValidator<'a>;
type DZKPValidator = MaliciousDZKPValidator<'a, B>;

fn dzkp_validator<S>(
self,
Expand Down Expand Up @@ -174,18 +174,18 @@ impl<B: ShardBinding> Debug for Context<'_, B> {

use crate::sync::{Mutex, Weak};

pub(super) type MacBatcher<'a, F> = Mutex<Batcher<'a, validator::Malicious<'a, F>>>;
pub(super) type MacBatcher<'a, F, B> = Mutex<Batcher<'a, validator::Malicious<'a, F, B>>>;

/// Represents protocol context in malicious setting, i.e. secure against one active adversary
/// in 3 party MPC ring.
#[derive(Clone)]
pub struct Upgraded<'a, F: ExtendableField, B: ShardBinding> {
batch: Weak<MacBatcher<'a, F>>,
batch: Weak<MacBatcher<'a, F, B>>,
base_ctx: Context<'a, B>,
}

impl<'a, F: ExtendableField, B: ShardBinding> Upgraded<'a, F, B> {
pub(super) fn new(batch: &Arc<MacBatcher<'a, F>>, ctx: Context<'a, B>) -> Self {
pub(super) fn new(batch: &Arc<MacBatcher<'a, F, B>>, ctx: Context<'a, B>) -> Self {
// The DZKP malicious context adjusts active_work to match records_per_batch.
// The MAC validator currently configures the batcher with records_per_batch =
// active_work. If the latter behavior changes, this code may need to be
Expand Down Expand Up @@ -231,7 +231,7 @@ impl<'a, F: ExtendableField, B: ShardBinding> Upgraded<'a, F, B> {
self.with_batch(record_id, |v| v.r_share().clone())
}

fn with_batch<C: FnOnce(&mut validator::Malicious<'a, F>) -> T, T>(
fn with_batch<C: FnOnce(&mut validator::Malicious<'a, F, B>) -> T, T>(
&self,
record_id: RecordId,
action: C,
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>;
pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>;

pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>;
pub type UpgradedMaliciousContext<'a, F> = malicious::Upgraded<'a, F, NotSharded>;
pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>;

#[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))]
pub(crate) use malicious::TEST_DZKP_STEPS;
Expand Down
Loading

0 comments on commit 0ba1ec4

Please sign in to comment.