From 0ba1ec47be5df4621c72c19095984c132d1d8c71 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 2 Oct 2024 21:03:17 -0700 Subject: [PATCH] Support sharding in MAC and ZKP validations The next step to be able to run sharded protocols is to change the validators API and support both sharded and non-sharded contexts. This is a fairly straightforward change that just propagates `ShardBinding` trait bound through MAC and ZKP contexts and validators --- ipa-core/src/protocol/basics/mod.rs | 14 +++++++--- .../src/protocol/basics/mul/dzkp_malicious.rs | 14 ++++++---- ipa-core/src/protocol/basics/mul/mod.rs | 8 ++++-- ipa-core/src/protocol/basics/reveal.rs | 7 +++-- .../src/protocol/context/dzkp_malicious.rs | 21 +++++++------- .../src/protocol/context/dzkp_validator.rs | 28 +++++++++---------- ipa-core/src/protocol/context/malicious.rs | 14 +++++----- ipa-core/src/protocol/context/mod.rs | 2 +- ipa-core/src/protocol/context/validator.rs | 26 ++++++++--------- ipa-core/src/test_fixture/world.rs | 6 ++-- 10 files changed, 76 insertions(+), 64 deletions(-) diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index 278aded76..cd2e92be6 100644 --- a/ipa-core/src/protocol/basics/mod.rs +++ b/ipa-core/src/protocol/basics/mod.rs @@ -89,7 +89,10 @@ impl<'a, B: ShardBinding> BooleanProtocols> { } -impl<'a> BooleanProtocols> for AdditiveShare {} +impl<'a, B: ShardBinding> BooleanProtocols> + for AdditiveShare +{ +} // Used for aggregation tests impl<'a, B: ShardBinding> BooleanProtocols, 8> @@ -107,7 +110,7 @@ impl<'a, B: ShardBinding> BooleanProtocols, { } -impl<'a> BooleanProtocols, PRF_CHUNK> +impl<'a, B: ShardBinding> BooleanProtocols, PRF_CHUNK> for AdditiveShare { } @@ -124,7 +127,7 @@ impl<'a, B: ShardBinding> BooleanProtocols, { } -impl<'a> BooleanProtocols, AGG_CHUNK> +impl<'a, B: ShardBinding> BooleanProtocols, AGG_CHUNK> for AdditiveShare { } @@ -159,7 +162,10 @@ impl<'a, B: ShardBinding> BooleanProtocols, { } -impl<'a> BooleanProtocols, 32> for AdditiveShare {} +impl<'a, B: ShardBinding> BooleanProtocols, 32> + for AdditiveShare +{ +} const_assert_eq!( AGG_CHUNK, diff --git a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs index 23a96c982..4a618d507 100644 --- a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs +++ b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs @@ -13,6 +13,7 @@ use crate::{ RecordId, }, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Vectorizable}, + sharding::{NotSharded, ShardBinding}, }; /// This function implements an MPC multiply using the standard strategy, i.e. via computing the @@ -27,13 +28,14 @@ use crate::{ /// back via the error response /// ## Panics /// Panics if the mutex is found to be poisoned -pub async fn zkp_multiply<'a, F, const N: usize>( - ctx: DZKPUpgradedMaliciousContext<'a>, +pub async fn zkp_multiply<'a, B, F, const N: usize>( + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, a: &Replicated, b: &Replicated, ) -> Result, Error> where + B: ShardBinding, F: Field + DZKPCompatibleField, { // Shared randomness used to mask the values that are sent. @@ -62,17 +64,17 @@ where /// Implement secure multiplication for malicious contexts with replicated secret sharing. #[async_trait] -impl<'a, F: Field + DZKPCompatibleField, const N: usize> - SecureMul> for Replicated +impl<'a, B: ShardBinding, F: Field + DZKPCompatibleField, const N: usize> + SecureMul> for Replicated { async fn multiply<'fut>( &self, rhs: &Self, - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, ) -> Result where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, NotSharded>: 'fut, { zkp_multiply(ctx, record_id, self, rhs).await } diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index 260d7a4b8..89f5e107a 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -123,17 +123,19 @@ macro_rules! boolean_array_mul { } } - impl<'a> BooleanArrayMul> for Replicated<$vec> { + impl<'a, B: sharding::ShardBinding> BooleanArrayMul> + for Replicated<$vec> + { type Vectorized = Replicated; fn multiply<'fut>( - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, a: &'fut Self::Vectorized, b: &'fut Self::Vectorized, ) -> impl Future> + Send + 'fut where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, B>: 'fut, { use crate::protocol::basics::mul::dzkp_malicious::zkp_multiply; zkp_multiply(ctx, record_id, a, b) diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 75867046f..19363e1af 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -321,20 +321,21 @@ where } } -impl<'a, const N: usize> Reveal> for Replicated +impl<'a, B, const N: usize> Reveal> for Replicated where + B: ShardBinding, Boolean: Vectorizable, { type Output = >::Array; async fn generic_reveal<'fut>( &'fut self, - ctx: DZKPUpgradedMaliciousContext<'a>, + ctx: DZKPUpgradedMaliciousContext<'a, B>, record_id: RecordId, excluded: Option, ) -> Result, Error> where - DZKPUpgradedMaliciousContext<'a>: 'fut, + DZKPUpgradedMaliciousContext<'a, B>: 'fut, { malicious_reveal(ctx, record_id, excluded, self).await } diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 9f28239ba..2023f427a 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -20,21 +20,22 @@ use crate::{ Gate, RecordId, }, seq_join::SeqJoin, + sharding::ShardBinding, sync::{Arc, Weak}, }; /// Represents protocol context in malicious setting when using zero-knowledge proofs, /// i.e. secure against one active adversary in 3 party MPC ring. #[derive(Clone)] -pub struct DZKPUpgraded<'a> { - validator_inner: Weak>, - base_ctx: MaliciousContext<'a>, +pub struct DZKPUpgraded<'a, B: ShardBinding> { + validator_inner: Weak>, + base_ctx: MaliciousContext<'a, B>, } -impl<'a> DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> DZKPUpgraded<'a, B> { pub(super) fn new( - validator_inner: &Arc>, - base_ctx: MaliciousContext<'a>, + validator_inner: &Arc>, + base_ctx: MaliciousContext<'a, B>, ) -> Self { let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch(); let active_work = if records_per_batch == 1 { @@ -82,7 +83,7 @@ impl<'a> DZKPUpgraded<'a> { } #[async_trait] -impl<'a> DZKPContext for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> DZKPContext for DZKPUpgraded<'a, B> { async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> { let validator_inner = self.validator_inner.upgrade().expect("validator is active"); @@ -100,7 +101,7 @@ impl<'a> DZKPContext for DZKPUpgraded<'a> { } } -impl<'a> super::Context for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> super::Context for DZKPUpgraded<'a, B> { fn role(&self) -> Role { self.base_ctx.role() } @@ -152,13 +153,13 @@ impl<'a> super::Context for DZKPUpgraded<'a> { } } -impl<'a> SeqJoin for DZKPUpgraded<'a> { +impl<'a, B: ShardBinding> SeqJoin for DZKPUpgraded<'a, B> { fn active_work(&self) -> NonZeroUsize { self.base_ctx.active_work() } } -impl Debug for DZKPUpgraded<'_> { +impl Debug for DZKPUpgraded<'_, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "DZKPMaliciousContext") } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 835a32e9d..ec9393486 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -520,7 +520,7 @@ impl Batch { /// ## Panics /// If `usize` to `u128` conversion fails. - pub(super) async fn validate(self, ctx: Base<'_>) -> Result<(), Error> { + pub(super) async fn validate(self, ctx: Base<'_, B>) -> Result<(), Error> { let proof_ctx = ctx.narrow(&Step::GenerateProof); if self.is_empty() { @@ -701,26 +701,26 @@ type DzkpBatcher<'a> = Batcher<'a, Batch>; /// The DZKP validator, and all associated contexts, each hold a reference to a single /// instance of `MaliciousDZKPValidatorInner`. -pub(super) struct MaliciousDZKPValidatorInner<'a> { +pub(super) struct MaliciousDZKPValidatorInner<'a, B: ShardBinding> { pub(super) batcher: Mutex>, - pub(super) validate_ctx: Base<'a>, + pub(super) validate_ctx: Base<'a, B>, } /// `MaliciousDZKPValidator` corresponds to pub struct `Malicious` and implements the trait `DZKPValidator` /// The implementation of `validate` of the `DZKPValidator` trait depends on generic `DF` -pub struct MaliciousDZKPValidator<'a> { +pub struct MaliciousDZKPValidator<'a, B: ShardBinding> { // This is an `Option` because we want to consume it in `DZKPValidator::validate`, // but we also want to implement `Drop`. Note that the `is_verified` check in `Drop` // does nothing when `batcher_ref` is already `None`. - inner_ref: Option>>, - protocol_ctx: MaliciousDZKPUpgraded<'a>, + inner_ref: Option>>, + protocol_ctx: MaliciousDZKPUpgraded<'a, B>, } #[async_trait] -impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> { - type Context = MaliciousDZKPUpgraded<'a>; +impl<'a, B: ShardBinding> DZKPValidator for MaliciousDZKPValidator<'a, B> { + type Context = MaliciousDZKPUpgraded<'a, B>; - fn context(&self) -> MaliciousDZKPUpgraded<'a> { + fn context(&self) -> MaliciousDZKPUpgraded<'a, B> { self.protocol_ctx.clone() } @@ -774,11 +774,11 @@ impl<'a> DZKPValidator for MaliciousDZKPValidator<'a> { } } -impl<'a> MaliciousDZKPValidator<'a> { +impl<'a, B: ShardBinding> MaliciousDZKPValidator<'a, B> { #[must_use] #[allow(clippy::needless_pass_by_value)] pub fn new( - ctx: MaliciousContext<'a>, + ctx: MaliciousContext<'a, B>, steps: MaliciousProtocolSteps, max_multiplications_per_gate: usize, ) -> Self @@ -808,7 +808,7 @@ impl<'a> MaliciousDZKPValidator<'a> { } } -impl<'a> Drop for MaliciousDZKPValidator<'a> { +impl<'a, B: ShardBinding> Drop for MaliciousDZKPValidator<'a, B> { fn drop(&mut self) { if self.inner_ref.is_some() { self.is_verified().unwrap(); @@ -922,7 +922,7 @@ mod tests { async fn test_select_malicious() where V: BooleanArray, - for<'a> Replicated: BooleanArrayMul>, + for<'a> Replicated: BooleanArrayMul>, Standard: Distribution, { let world = TestWorld::default(); @@ -1040,7 +1040,7 @@ mod tests { async fn multi_select_malicious(count: usize, max_multiplications_per_gate: usize) where V: BooleanArray, - for<'a> Replicated: BooleanArrayMul>, + for<'a> Replicated: BooleanArrayMul>, Standard: Distribution, { let mut rng = thread_rng(); diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 401a6cb0e..def33e950 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -138,14 +138,14 @@ impl<'a, B: ShardBinding> super::Context for Context<'a, B> { } } -impl<'a> UpgradableContext for Context<'a, NotSharded> { - type Validator = BatchValidator<'a, F>; +impl<'a, B: ShardBinding> UpgradableContext for Context<'a, B> { + type Validator = BatchValidator<'a, F, B>; fn validator(self) -> Self::Validator { BatchValidator::new(self) } - type DZKPValidator = MaliciousDZKPValidator<'a>; + type DZKPValidator = MaliciousDZKPValidator<'a, B>; fn dzkp_validator( self, @@ -174,18 +174,18 @@ impl Debug for Context<'_, B> { use crate::sync::{Mutex, Weak}; -pub(super) type MacBatcher<'a, F> = Mutex>>; +pub(super) type MacBatcher<'a, F, B> = Mutex>>; /// Represents protocol context in malicious setting, i.e. secure against one active adversary /// in 3 party MPC ring. #[derive(Clone)] pub struct Upgraded<'a, F: ExtendableField, B: ShardBinding> { - batch: Weak>, + batch: Weak>, base_ctx: Context<'a, B>, } impl<'a, F: ExtendableField, B: ShardBinding> Upgraded<'a, F, B> { - pub(super) fn new(batch: &Arc>, ctx: Context<'a, B>) -> Self { + pub(super) fn new(batch: &Arc>, ctx: Context<'a, B>) -> Self { // The DZKP malicious context adjusts active_work to match records_per_batch. // The MAC validator currently configures the batcher with records_per_batch = // active_work. If the latter behavior changes, this code may need to be @@ -231,7 +231,7 @@ impl<'a, F: ExtendableField, B: ShardBinding> Upgraded<'a, F, B> { self.with_batch(record_id, |v| v.r_share().clone()) } - fn with_batch) -> T, T>( + fn with_batch) -> T, T>( &self, record_id: RecordId, action: C, diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index abd53b6ee..0651b74a4 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -26,7 +26,7 @@ pub type SemiHonestContext<'a, B = NotSharded> = semi_honest::Context<'a, B>; pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>; -pub type UpgradedMaliciousContext<'a, F> = malicious::Upgraded<'a, F, NotSharded>; +pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>; #[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] pub(crate) use malicious::TEST_DZKP_STEPS; diff --git a/ipa-core/src/protocol/context/validator.rs b/ipa-core/src/protocol/context/validator.rs index e57ae3c6a..a71b395c3 100644 --- a/ipa-core/src/protocol/context/validator.rs +++ b/ipa-core/src/protocol/context/validator.rs @@ -199,18 +199,18 @@ impl MaliciousAccumulator { /// When batch is validated, `r` is revealed and can never be /// used again. In fact, it gets out of scope after successful validation /// so no code can get access to it. -pub struct BatchValidator<'a, F: ExtendableField> { - batches_ref: Arc>, - protocol_ctx: MaliciousContext<'a>, +pub struct BatchValidator<'a, F: ExtendableField, B: ShardBinding> { + batches_ref: Arc>, + protocol_ctx: MaliciousContext<'a, B>, } -impl<'a, F: ExtendableField> BatchValidator<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> BatchValidator<'a, F, B> { /// Create a new validator for malicious context. /// /// ## Panics /// If total records is not set. #[must_use] - pub fn new(ctx: MaliciousContext<'a>) -> Self { + pub fn new(ctx: MaliciousContext<'a, B>) -> Self { let TotalRecords::Specified(total_records) = ctx.total_records() else { panic!("Total records must be specified before creating the validator"); }; @@ -230,14 +230,14 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { } } -pub struct Malicious<'a, F: ExtendableField> { +pub struct Malicious<'a, F: ExtendableField, B: ShardBinding> { r_share: Replicated, pub(super) accumulator: MaliciousAccumulator, - validate_ctx: Base<'a>, + validate_ctx: Base<'a, B>, offset: usize, } -impl Malicious<'_, F> { +impl Malicious<'_, F, B> { /// ## Errors /// If the two information theoretic MACs are not equal (after multiplying by `r`), this indicates that one of the parties /// must have launched an additive attack. At this point the honest parties should abort the protocol. This method throws an @@ -294,21 +294,21 @@ impl Malicious<'_, F> { } } -impl<'a, F> Validator for BatchValidator<'a, F> +impl<'a, F, B: ShardBinding> Validator for BatchValidator<'a, F, B> where F: ExtendableField, { - type Context = UpgradedMaliciousContext<'a, F>; + type Context = UpgradedMaliciousContext<'a, F, B>; fn context(&self) -> Self::Context { UpgradedMaliciousContext::new(&self.batches_ref, self.protocol_ctx.clone()) } } -impl<'a, F: ExtendableField> Malicious<'a, F> { +impl<'a, F: ExtendableField, B: ShardBinding> Malicious<'a, F, B> { #[must_use] #[allow(clippy::needless_pass_by_value)] - pub fn new(ctx: MaliciousContext<'a>, offset: usize) -> Self { + pub fn new(ctx: MaliciousContext<'a, B>, offset: usize) -> Self { // Each invocation requires 3 calls to PRSS to generate the state. // Validation occurs in batches and `offset` indicates which batch // we're in right now. @@ -386,7 +386,7 @@ impl<'a, F: ExtendableField> Malicious<'a, F> { } } -impl Debug for Malicious<'_, F> { +impl Debug for Malicious<'_, F, B> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "MaliciousValidator<{:?}>", type_name::()) } diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index cd6919580..1c337b10e 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -428,7 +428,7 @@ pub trait Runner { I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R + Send + Sync, R: Future + Send; } @@ -531,7 +531,7 @@ impl Runner> I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R + Send + Sync, + H: Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R + Send + Sync, R: Future + Send, { unimplemented!() @@ -672,7 +672,7 @@ impl Runner for TestWorld { I: IntoShares + Send + 'static, A: Send + 'static, O: Send + Debug, - H: (Fn(DZKPUpgradedMaliciousContext<'a>, A) -> R) + Send + Sync, + H: (Fn(DZKPUpgradedMaliciousContext<'a, NotSharded>, A) -> R) + Send + Sync, R: Future + Send, { self.malicious(input, |ctx, share| async {