From 94b7cbf6685beddc225cb359ed2a217eff79cc56 Mon Sep 17 00:00:00 2001
From: Andy Leiserson
Date: Fri, 25 Oct 2024 12:34:15 -0700
Subject: [PATCH 01/47] Use constant-time comparisons

---
 ipa-core/Cargo.toml                           |  1 +
 ipa-core/src/ff/ec_prime_field.rs             |  7 ++++++
 ipa-core/src/ff/galois_field.rs               | 13 ++++++++++
 ipa-core/src/ff/prime_field.rs                | 10 ++++++++
 ipa-core/src/helpers/hashing.rs               |  7 ++++++
 ipa-core/src/protocol/basics/check_zero.rs    |  6 +++--
 ipa-core/src/protocol/basics/mod.rs           |  2 +-
 ipa-core/src/protocol/basics/mul/mod.rs       | 25 ++-----------------
 .../src/protocol/basics/share_validation.rs   |  3 ++-
 .../src/protocol/ipa_prf/shuffle/malicious.rs |  9 ++++---
 .../ipa_prf/validation_protocol/validation.rs | 13 ++++++----
 .../replicated/malicious/additive_share.rs    |  5 ++--
 12 files changed, 63 insertions(+), 38 deletions(-)

diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml
index 017f67ab6..0b2447c9c 100644
--- a/ipa-core/Cargo.toml
+++ b/ipa-core/Cargo.toml
@@ -143,6 +143,7 @@ serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 sha2 = "0.10"
 shuttle-crate = { package = "shuttle", version = "0.6.1", optional = true }
+subtle = "2.6"
 thiserror = "1.0"
 time = { version = "0.3", optional = true }
 tokio = { version = "1.35", features = ["fs", "rt", "rt-multi-thread", "macros"] }
diff --git a/ipa-core/src/ff/ec_prime_field.rs b/ipa-core/src/ff/ec_prime_field.rs
index 64de59f9e..35ad47a2c 100644
--- a/ipa-core/src/ff/ec_prime_field.rs
+++ b/ipa-core/src/ff/ec_prime_field.rs
@@ -2,6 +2,7 @@ use std::convert::Infallible;
 
 use curve25519_dalek::scalar::Scalar;
 use generic_array::GenericArray;
+use subtle::{Choice, ConstantTimeEq};
 use typenum::{U2, U32};
 
 use crate::{
@@ -75,6 +76,12 @@ impl Serializable for Fp25519 {
     }
 }
 
+impl ConstantTimeEq for Fp25519 {
+    fn ct_eq(&self, other: &Self) -> Choice {
+        self.0.ct_eq(&other.0)
+    }
+}
+
 ///generate random elements in Fp25519
 impl rand::distributions::Distribution for rand::distributions::Standard {
     fn sample(&self, rng: &mut R) -> Fp25519 {
diff --git a/ipa-core/src/ff/galois_field.rs b/ipa-core/src/ff/galois_field.rs
index 3a84f2b2a..2c99ed164 100644
--- a/ipa-core/src/ff/galois_field.rs
+++ b/ipa-core/src/ff/galois_field.rs
@@ -8,6 +8,7 @@ use bitvec::{
     prelude::{bitarr, BitArr, Lsb0},
 };
 use generic_array::GenericArray;
+use subtle::{Choice, ConstantTimeEq};
 use typenum::{Unsigned, U1, U2, U3, U4, U5};
 
 use crate::{
@@ -227,6 +228,18 @@ macro_rules! bit_array_impl {
                 const POLYNOMIAL: u128 = $polynomial;
             }
 
+            // If the field value fits in a machine word, a naive comparison should be fine.
+            // But this impl is important for `[T]`, and useful to document where a
+            // constant-time compare is intended.
+            impl ConstantTimeEq for $name {
+                fn ct_eq(&self, other: &Self) -> Choice {
+                    // Note that this will compare the padding bits. That should not be
+                    // a problem, because we should not allow the padding bits to become
+                    // non-zero.
+                    self.0.as_raw_slice().ct_eq(&other.0.as_raw_slice())
+                }
+            }
+
             impl rand::distributions::Distribution<$name> for rand::distributions::Standard {
                 fn sample(&self, rng: &mut R) -> $name {
                     <$name>::truncate_from(rng.gen::())
                 }
             }
diff --git a/ipa-core/src/ff/prime_field.rs b/ipa-core/src/ff/prime_field.rs
index 3d398fc6a..dfc353659 100644
--- a/ipa-core/src/ff/prime_field.rs
+++ b/ipa-core/src/ff/prime_field.rs
@@ -1,6 +1,7 @@
 use std::{fmt::Display, mem};
 
 use generic_array::GenericArray;
+use subtle::{Choice, ConstantTimeEq};
 
 use super::Field;
 use crate::{
@@ -241,6 +242,15 @@ macro_rules! field_impl {
             }
         }
 
+        // If the field value fits in a machine word, a naive comparison should be fine.
+        // But this impl is important for `[T]`, and useful to document where a
+        // constant-time compare is intended.
+        impl ConstantTimeEq for $field {
+            fn ct_eq(&self, other: &Self) -> Choice {
+                self.0.ct_eq(&other.0)
+            }
+        }
+
         impl rand::distributions::Distribution<$field> for rand::distributions::Standard {
             fn sample(&self, rng: &mut R) -> $field {
                 <$field>::truncate_from(rng.gen::())
diff --git a/ipa-core/src/helpers/hashing.rs b/ipa-core/src/helpers/hashing.rs
index 741958c65..eaf1ac3ff 100644
--- a/ipa-core/src/helpers/hashing.rs
+++ b/ipa-core/src/helpers/hashing.rs
@@ -5,6 +5,7 @@ use sha2::{
     digest::{Output, OutputSizeUser},
     Digest, Sha256,
 };
+use subtle::{Choice, ConstantTimeEq};
 
 use crate::{
     ff::{PrimeField, Serializable},
@@ -15,6 +16,12 @@ use crate::{
 #[derive(Clone, Debug, Default, PartialEq)]
 pub struct Hash(Output);
 
+impl ConstantTimeEq for Hash {
+    fn ct_eq(&self, other: &Self) -> Choice {
+        self.0.as_slice().ct_eq(other.0.as_slice())
+    }
+}
+
 impl Serializable for Hash {
     type Size = ::OutputSize;
 
diff --git a/ipa-core/src/protocol/basics/check_zero.rs b/ipa-core/src/protocol/basics/check_zero.rs
index a7aa29916..7cbe8b8b3 100644
--- a/ipa-core/src/protocol/basics/check_zero.rs
+++ b/ipa-core/src/protocol/basics/check_zero.rs
@@ -1,3 +1,5 @@
+use subtle::ConstantTimeEq;
+
 use crate::{
     error::Error,
     ff::Field,
@@ -49,7 +51,7 @@ pub async fn malicious_check_zero(
 ) -> Result
 where
     C: Context,
-    F: Field + FromRandom,
+    F: Field + FromRandom + ConstantTimeEq,
 {
     let r_sharing: Replicated = ctx.prss().generate(record_id);
 
@@ -61,7 +63,7 @@ where
             .expect("full reveal should always return a value"),
     );
 
-    Ok(rv == F::ZERO)
+    Ok(rv.ct_eq(&F::ZERO).into())
 }
 
 #[cfg(all(test, unit_test))]
diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs
index 594c246f1..ebe34cb34 100644
--- a/ipa-core/src/protocol/basics/mod.rs
+++ b/ipa-core/src/protocol/basics/mod.rs
@@ -19,7 +19,7 @@ pub use share_known_value::ShareKnownValue;
 
 use crate::{
     const_assert_eq,
-    ff::{boolean::Boolean, ec_prime_field::Fp25519, PrimeField},
+    ff::{boolean::Boolean, ec_prime_field::Fp25519},
     protocol::{
         context::{
             Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext,
diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs
index 958bafbcd..a69b7faf4 100644
--- a/ipa-core/src/protocol/basics/mul/mod.rs
+++ b/ipa-core/src/protocol/basics/mul/mod.rs
@@ -13,10 +13,9 @@ use crate::{
         Expand,
     },
     protocol::{
-        basics::PrimeField,
         context::{
-            dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded,
-            semi_honest::Upgraded as SemiHonestUpgraded, Context, DZKPUpgradedMaliciousContext,
+            dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, Context,
+            DZKPUpgradedMaliciousContext,
         },
         RecordId,
     },
@@ -84,26 +83,6 @@ where
 
 macro_rules! boolean_array_mul {
     ($dim:expr, $vec:ty) => {
-        impl<'a, B, F> BooleanArrayMul> for Replicated<$vec>
-        where
-            B: sharding::ShardBinding,
-            F: PrimeField,
-        {
-            type Vectorized = Replicated;
-
-            fn multiply<'fut>(
-                ctx: SemiHonestUpgraded<'a, B, F>,
-                record_id: RecordId,
-                a: &'fut Self::Vectorized,
-                b: &'fut Self::Vectorized,
-            ) -> impl Future> + Send + 'fut
-            where
-                SemiHonestUpgraded<'a, B, F>: 'fut,
-            {
-                semi_honest_multiply(ctx, record_id, a, b)
-            }
-        }
-
         impl<'a, B> BooleanArrayMul> for Replicated<$vec>
         where
             B: sharding::ShardBinding,
diff --git a/ipa-core/src/protocol/basics/share_validation.rs b/ipa-core/src/protocol/basics/share_validation.rs
index 43d1f34e6..492b99973 100644
--- a/ipa-core/src/protocol/basics/share_validation.rs
+++ b/ipa-core/src/protocol/basics/share_validation.rs
@@ -1,4 +1,5 @@
 use futures_util::future::try_join;
+use subtle::ConstantTimeEq;
 
 use crate::{
     error::Error,
@@ -50,7 +51,7 @@ where
     )
     .await?;
 
-    if hash_left == hash_received {
+    if hash_left.ct_eq(&hash_received).into() {
         Ok(())
     } else {
         Err(Error::InconsistentShares)
diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
index a8b6d63c0..210acbdd4 100644
--- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
+++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
@@ -9,6 +9,7 @@ use futures::{
 };
 use futures_util::future::{try_join, try_join3};
 use generic_array::GenericArray;
+use subtle::ConstantTimeEq;
 use typenum::Const;
 
 use crate::{
@@ -280,21 +281,21 @@ async fn h1_verify(
     .await?;
 
     // check y1
-    if hash_x1 != hash_y1 {
+    if hash_x1.ct_ne(&hash_y1).into() {
         return Err(Error::ShuffleValidationFailed(format!(
             "Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {hash_y1:?}"
         )));
     }
 
     // check c from h3
-    if hash_a_xor_b != hash_h3 {
+    if hash_a_xor_b.ct_ne(&hash_h3).into() {
        return Err(Error::ShuffleValidationFailed(format!(
            "C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h3:?}"
        )));
     }
 
     // check h2
-    if hash_a_xor_b != hash_h2 {
+    if hash_a_xor_b.ct_ne(&hash_h2).into() {
         return Err(Error::ShuffleValidationFailed(format!(
             "C from H2 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h2:?}"
         )));
@@ -341,7 +342,7 @@ async fn h2_verify(
     .await?;
 
     // check x2
-    if hash_x2 != hash_h3 {
+    if hash_x2.ct_ne(&hash_h3).into() {
         return Err(Error::ShuffleValidationFailed(format!(
             "X2 is inconsistent: hash of x2: {hash_x2:?}, hash of y2: {hash_h3:?}"
         )));
diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
index f0430e996..b197dcfa3 100644
--- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
+++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs
@@ -4,6 +4,7 @@ use std::{
 };
 
 use futures_util::future::{try_join, try_join4};
+use subtle::ConstantTimeEq;
 use typenum::{Unsigned, U288, U80};
 
 use crate::{
@@ -293,11 +294,13 @@ impl BatchToVerify {
             .await?;
         let diff_right_from_other_verifier = receive_data[0..length].to_vec();
 
-        // compare recombined dif to zero
-        for i in 0..length {
-            if diff_right[i] + diff_right_from_other_verifier[i] != Fp61BitPrime::ZERO {
-                return Err(Error::DZKPValidationFailed);
-            }
+        // compare recombined diff to zero
+        let diff = zip(diff_right, diff_right_from_other_verifier)
+            .map(|(a, b)| a + b)
+            .collect::>();
+
+        if diff.ct_ne(&vec![Fp61BitPrime::ZERO; length]).into() {
+            return Err(Error::DZKPValidationFailed);
         }
 
         Ok(())
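The validation change above replaces an early-exit loop with a single constant-time comparison over the whole batch. A minimal standalone sketch of the same idiom, using plain u64 words instead of Fp61BitPrime (the function name and types here are illustrative, not the crate's API; only the `subtle` crate, which this patch adds, is assumed):

    use subtle::ConstantTimeEq;

    /// Recombine the two verifiers' shares of each difference and check that
    /// every element is zero, in time independent of where (or whether) a
    /// mismatch occurs.
    fn recombined_diff_is_zero(diff_left: &[u64], diff_right: &[u64]) -> bool {
        let recombined: Vec<u64> = diff_left
            .iter()
            .zip(diff_right)
            .map(|(a, b)| a.wrapping_add(*b))
            .collect();
        let zeros = vec![0_u64; recombined.len()];
        // `ct_eq` on slices inspects every element; there is no early exit
        // that could leak the position of the first non-zero difference.
        recombined.ct_eq(&zeros).into()
    }

The `Choice` returned by `ct_eq` is converted to `bool` only at the very end, which is the same shape as the `if diff.ct_ne(...).into()` check in the hunk above.
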
diff --git a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs
index ece21a2b6..a782d9a27 100644
--- a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs
+++ b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs
@@ -10,6 +10,7 @@ use futures::{
     stream::{iter as stream_iter, StreamExt},
 };
 use generic_array::{ArrayLength, GenericArray};
+use subtle::ConstantTimeEq;
 use typenum::Unsigned;
 
 use crate::{
@@ -41,7 +42,7 @@ pub struct AdditiveShare, const N: usize
 }
 
 pub trait ExtendableField: Field {
-    type ExtendedField: Field + FromRandom;
+    type ExtendedField: Field + FromRandom + ConstantTimeEq;
 
     fn to_extended(&self) -> Self::ExtendedField;
 }
@@ -57,7 +58,7 @@ impl> + FieldSimd, const N: us
 {
 }
 
-impl ExtendableField for F {
+impl ExtendableField for F {
     type ExtendedField = F;
 
     fn to_extended(&self) -> Self::ExtendedField {

From 1eda2ffe241039464c46858a86fb170d73b6902e Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Tue, 12 Nov 2024 10:09:42 -0800
Subject: [PATCH 02/47] Support sharded shuffle in executor

This does not get it to the end, but it implements the necessary parts
to plumb sharded contexts and streams into the sharded shuffle toy
protocol. We still need to test it end-to-end and do some blackbox
testing inside the transport module, but this work is currently blocked
behind changes being made inside the processor.
---
 ipa-core/src/helpers/cross_shard_prss.rs        |   2 +-
 ipa-core/src/helpers/gateway/mod.rs             |  12 +-
 .../src/helpers/gateway/stall_detection.rs      |  12 +-
 ipa-core/src/helpers/gateway/transport.rs       |   4 +
 ipa-core/src/helpers/mod.rs                     |   1 +
 ipa-core/src/helpers/transport/mod.rs           |   7 ++
 ipa-core/src/net/transport.rs                   |   8 ++
 ipa-core/src/protocol/prss/mod.rs               |  20 +++
 ipa-core/src/protocol/step.rs                   |   6 +-
 ipa-core/src/query/executor.rs                  |   4 +-
 ipa-core/src/query/runner/mod.rs                |   4 +
 ipa-core/src/query/runner/sharded_shuffle.rs    | 119 ++++++++++++++++++
 12 files changed, 191 insertions(+), 8 deletions(-)
 create mode 100644 ipa-core/src/query/runner/sharded_shuffle.rs

diff --git a/ipa-core/src/helpers/cross_shard_prss.rs b/ipa-core/src/helpers/cross_shard_prss.rs
index 9328d626c..53ebf5f45 100644
--- a/ipa-core/src/helpers/cross_shard_prss.rs
+++ b/ipa-core/src/helpers/cross_shard_prss.rs
@@ -18,7 +18,7 @@ use crate::{
 /// ## Errors
 /// If shard communication channels fail
 #[allow(dead_code)] // until this is used in real sharded protocol
-async fn gen_and_distribute(
+pub async fn gen_and_distribute(
     gateway: &Gateway,
     gate: &Gate,
     prss: R,
diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs
index 3135bd5fe..a6119081e 100644
--- a/ipa-core/src/helpers/gateway/mod.rs
+++ b/ipa-core/src/helpers/gateway/mod.rs
@@ -28,7 +28,7 @@ use crate::{
         ShardChannelId, TotalRecords, Transport,
     },
     protocol::QueryId,
-    sharding::ShardIndex,
+    sharding::{ShardConfiguration, ShardIndex},
     sync::{Arc, Mutex},
     utils::NonZeroU32PowerOfTwo,
 };
@@ -106,6 +106,16 @@ pub struct GatewayConfig {
     pub progress_check_interval: std::time::Duration,
 }
 
+impl ShardConfiguration for Gateway {
+    fn shard_id(&self) -> ShardIndex {
+        self.transports.shard.identity()
+    }
+
+    fn shard_count(&self) -> ShardIndex {
+        ShardIndex::from(self.transports.shard.peer_count() + 1)
+    }
+}
+
 impl Gateway {
     #[must_use]
     pub fn new(
diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs
index 4a844386f..30d641a4b 100644
--- a/ipa-core/src/helpers/gateway/stall_detection.rs
+++ b/ipa-core/src/helpers/gateway/stall_detection.rs
@@ -78,7 +78,7 @@ mod gateway {
         Role, RoleAssignment, SendingEnd, ShardChannelId, ShardReceivingEnd, TotalRecords,
     },
     protocol::QueryId,
-    sharding::ShardIndex,
+    sharding::{ShardConfiguration, ShardIndex},
     sync::Arc,
     utils::NonZeroU32PowerOfTwo,
 };
@@ -207,6 +207,16 @@ mod gateway {
         }
     }
 
+    impl ShardConfiguration for &Observed {
+        fn shard_id(&self) -> ShardIndex {
+            self.inner().gateway.shard_id()
+        }
+
+        fn shard_count(&self) -> ShardIndex {
+            self.inner().gateway.shard_count()
+        }
+    }
+
     pub struct GatewayWaitingTasks {
         mpc_send: Option,
         mpc_recv: Option,
diff --git a/ipa-core/src/helpers/gateway/transport.rs b/ipa-core/src/helpers/gateway/transport.rs
index dfbc9d328..09a1053cb 100644
--- a/ipa-core/src/helpers/gateway/transport.rs
+++ b/ipa-core/src/helpers/gateway/transport.rs
@@ -46,6 +46,10 @@ impl Transport for RoleResolvingTransport {
         Role::all().iter().filter(move |&v| v != &this).copied()
     }
 
+    fn peer_count(&self) -> u32 {
+        self.inner.peer_count()
+    }
+
     async fn send<
         D: Stream> + Send + 'static,
         Q: QueryIdBinding,
diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs
index 370c42b05..76fcb4046 100644
--- a/ipa-core/src/helpers/mod.rs
+++ b/ipa-core/src/helpers/mod.rs
@@ -56,6 +56,7 @@ mod gateway_exports {
     pub type ShardReceivingEnd = gateway::ShardReceivingEnd;
 }
 
+pub use cross_shard_prss::gen_and_distribute as setup_cross_shard_prss;
 pub use gateway::GatewayConfig;
 // TODO: this type should only be available within infra. Right now several infra modules
 // are exposed at the root level. That makes it impossible to have a proper hierarchy here.
diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs
index b3cfb862f..b71875966 100644
--- a/ipa-core/src/helpers/transport/mod.rs
+++ b/ipa-core/src/helpers/transport/mod.rs
@@ -312,6 +312,13 @@ pub trait Transport: Clone + Send + Sync + 'static {
     /// Returns all the other identities, besides me, in this network.
     fn peers(&self) -> impl Iterator;
 
+    /// The number of peers on the network. Default implementation may not be efficient,
+    /// because it uses [`Self::peers`] to count, so implementations are encouraged to
+    /// override it
+    fn peer_count(&self) -> u32 {
+        u32::try_from(self.peers().count()).expect("Number of peers is less than 4B")
+    }
+
     /// Sends a new request to the given destination helper party.
     /// Depending on the specific request, it may or may not require acknowledgment by the remote
     /// party
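To make the relationship concrete: a network of N participants has N - 1 peers from the point of view of any one of them, so the `ShardConfiguration` impl for `Gateway` above recovers the shard count as `peer_count() + 1`. A small self-contained sketch of that invariant (the `ToyTransport` type is hypothetical, for illustration only — not the crate's trait):

    // Hypothetical stand-in for a shard transport; the real trait derives the
    // default `peer_count` from `peers()` exactly like this.
    struct ToyTransport {
        me: u32,
        network_size: u32,
    }

    impl ToyTransport {
        // Everyone on the network except me.
        fn peers(&self) -> impl Iterator<Item = u32> + '_ {
            (0..self.network_size).filter(move |&ix| ix != self.me)
        }

        // Mirrors the default implementation added to `Transport` above.
        fn peer_count(&self) -> u32 {
            u32::try_from(self.peers().count()).expect("fits in u32")
        }

        // Mirrors `ShardConfiguration::shard_count` for `Gateway`.
        fn shard_count(&self) -> u32 {
            self.peer_count() + 1
        }
    }

    #[test]
    fn shard_count_recovers_network_size() {
        let t = ToyTransport { me: 3, network_size: 20 };
        assert_eq!(t.peer_count(), 19);
        assert_eq!(t.shard_count(), 20);
    }

This also explains the constant override for `MpcHttpTransport` below: an MPC network always has exactly three helpers, hence two peers.
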
diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs
index 6ed523093..b6d726a43 100644
--- a/ipa-core/src/net/transport.rs
+++ b/ipa-core/src/net/transport.rs
@@ -266,6 +266,10 @@ impl Transport for MpcHttpTransport {
             .filter(move |&id| id != this)
     }
 
+    fn peer_count(&self) -> u32 {
+        2
+    }
+
     async fn send<
         D: Stream> + Send + 'static,
         Q: QueryIdBinding,
@@ -336,6 +340,10 @@ impl Transport for ShardHttpTransport {
         self.shard_count.iter().filter(move |&v| v != this)
     }
 
+    fn peer_count(&self) -> u32 {
+        u32::from(self.shard_count).saturating_sub(1)
+    }
+
     async fn send(
         &self,
         dest: Self::Identity,
diff --git a/ipa-core/src/protocol/prss/mod.rs b/ipa-core/src/protocol/prss/mod.rs
index 9c7e75ea7..dd3d2fbe6 100644
--- a/ipa-core/src/protocol/prss/mod.rs
+++ b/ipa-core/src/protocol/prss/mod.rs
@@ -209,6 +209,26 @@ impl SharedRandomness for IndexedSharedRandomness {
     }
 }
 
+impl SharedRandomness for Arc {
+    type ChunkIter<'a, Z: ArrayLength> =
+        ::ChunkIter<'a, Z>;
+
+    fn generate_chunks_one_side, Z: ArrayLength>(
+        &self,
+        index: I,
+        direction: Direction,
+    ) -> Self::ChunkIter<'_, Z> {
+        IndexedSharedRandomness::generate_chunks_one_side(self, index, direction)
+    }
+
+    fn generate_chunks_iter, Z: ArrayLength>(
+        &self,
+        index: I,
+    ) -> impl Iterator, GenericArray)> {
+        IndexedSharedRandomness::generate_chunks_iter(self, index)
+    }
+}
+
 /// Specialized implementation for chunks that are generated using both left and right
 /// randomness. The functionality is the same as [`std::iter::zip`], but it does not use
 /// `Iterator` trait to call `left` and `right` next. It uses inlined method calls to
diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs
index ee2ef726a..351eb23c0 100644
--- a/ipa-core/src/protocol/step.rs
+++ b/ipa-core/src/protocol/step.rs
@@ -5,18 +5,20 @@ use ipa_step_derive::{CompactGate, CompactStep};
 #[derive(CompactStep, CompactGate)]
 pub enum ProtocolStep {
     Prss,
+    CrossShardPrss,
     #[step(child = crate::protocol::ipa_prf::step::IpaPrfStep)]
     IpaPrf,
     #[step(child = crate::protocol::hybrid::step::HybridStep)]
     Hybrid,
     Multiply,
     PrimeFieldAddition,
+    #[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)]
+    ShardedShuffle,
     /// Steps used in unit tests are grouped under this one. Ideally it should be
     /// gated behind test configuration, but it does not work with build.rs that
     /// does not enable any features when creating protocol gate file
     #[step(child = TestExecutionStep)]
     Test,
-
     /// This step includes all the steps that are currently not linked into a top-level protocol.
     ///
     /// This allows those steps to be compiled. However, any use of them will fail at run time.
@@ -39,8 +41,6 @@ pub enum DeadCodeStep {
     FeatureLabelDotProduct,
     #[step(child = crate::protocol::ipa_prf::boolean_ops::step::MultiplicationStep)]
     Multiplication,
-    #[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)]
-    ShardedShuffle,
 }
 
 /// Provides a unique per-iteration context in tests.
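The `SharedRandomness for Arc<IndexedSharedRandomness>` impl above is the standard forwarding idiom: every method delegates to the inner type so an `Arc` handle is usable wherever the trait is expected (here, so the cross-shard PRSS can be shared via `Arc` in the sharded context). A generic sketch of the same idiom under simplified, illustrative types (not the crate's API):

    use std::sync::Arc;

    trait Randomness {
        fn generate(&self, index: u64) -> u64;
    }

    struct Indexed {
        seed: u64,
    }

    impl Randomness for Indexed {
        fn generate(&self, index: u64) -> u64 {
            // Toy mixing function standing in for real PRSS output.
            self.seed.wrapping_mul(0x9E37_79B9_7F4A_7C15).rotate_left(17) ^ index
        }
    }

    // Forwarding impl: `Arc<Indexed>` satisfies `Randomness` by delegating,
    // which is what the PRSS change above does for the indexed randomness.
    impl Randomness for Arc<Indexed> {
        fn generate(&self, index: u64) -> u64 {
            Indexed::generate(self, index)
        }
    }
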
diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs
index edd1662e4..c620b9565 100644
--- a/ipa-core/src/query/executor.rs
+++ b/ipa-core/src/query/executor.rs
@@ -39,7 +39,7 @@ use crate::{
         Gate,
     },
     query::{
-        runner::{OprfIpaQuery, QueryResult},
+        runner::{execute_sharded_shuffle, OprfIpaQuery, QueryResult},
         state::RunningQuery,
     },
     sync::Arc,
 };
@@ -108,7 +108,7 @@ pub fn execute(
             config,
             gateway,
             input,
-            |_prss, _gateway, _config, _input| unimplemented!(),
+            |prss, gateway, _config, input| Box::pin(execute_sharded_shuffle(prss, gateway, input)),
         ),
         #[cfg(any(test, feature = "weak-field"))]
         (QueryType::TestAddInPrimeField, FieldType::Fp31) => do_query(
diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs
index 3f1b59f55..83f033fe4 100644
--- a/ipa-core/src/query/runner/mod.rs
+++ b/ipa-core/src/query/runner/mod.rs
@@ -4,11 +4,15 @@ mod hybrid;
 mod oprf_ipa;
 mod reshard_tag;
 #[cfg(any(test, feature = "cli", feature = "test-fixture"))]
+mod sharded_shuffle;
+#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
 mod test_multiply;
 
 #[cfg(any(test, feature = "cli", feature = "test-fixture"))]
 pub(super) use add_in_prime_field::execute as test_add_in_prime_field;
 #[cfg(any(test, feature = "cli", feature = "test-fixture"))]
+pub(super) use sharded_shuffle::execute_sharded_shuffle;
+#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
 pub(super) use test_multiply::execute_test_multiply;
 
 pub use self::oprf_ipa::OprfIpaQuery;
diff --git a/ipa-core/src/query/runner/sharded_shuffle.rs b/ipa-core/src/query/runner/sharded_shuffle.rs
new file mode 100644
index 000000000..0025f07cb
--- /dev/null
+++ b/ipa-core/src/query/runner/sharded_shuffle.rs
@@ -0,0 +1,119 @@
+use futures_util::TryStreamExt;
+use ipa_step::StepNarrow;
+
+use crate::{
+    error::Error,
+    ff::boolean_array::BA64,
+    helpers::{setup_cross_shard_prss, BodyStream, Gateway, SingleRecordStream},
+    protocol::{
+        context::{Context, ShardedContext, ShardedSemiHonestContext},
+        ipa_prf::Shuffle,
+        prss::Endpoint as PrssEndpoint,
+        step::ProtocolStep,
+        Gate,
+    },
+    query::runner::QueryResult,
+    secret_sharing::replicated::semi_honest::AdditiveShare,
+    sharding::{ShardConfiguration, Sharded},
+    sync::Arc,
+};
+
+pub async fn execute_sharded_shuffle<'a>(
+    prss: &'a PrssEndpoint,
+    gateway: &'a Gateway,
+    input: BodyStream,
+) -> QueryResult {
+    let gate = Gate::default().narrow(&ProtocolStep::CrossShardPrss);
+    let cross_shard_prss =
+        setup_cross_shard_prss(gateway, &gate, prss.indexed(&gate), gateway).await?;
+    let ctx = ShardedSemiHonestContext::new_sharded(
+        prss,
+        gateway,
+        Sharded {
+            shard_id: gateway.shard_id(),
+            shard_count: gateway.shard_count(),
+            prss: Arc::new(cross_shard_prss),
+        },
+    )
+    .narrow(&ProtocolStep::ShardedShuffle);
+
+    Ok(Box::new(execute(ctx, input).await?))
+}
+
+#[tracing::instrument("sharded_shuffle", skip_all)]
+pub async fn execute(ctx: C, input_stream: BodyStream) -> Result>, Error>
+where
+    C: ShardedContext + Shuffle,
+{
+    let input = SingleRecordStream::, _>::new(input_stream)
+        .try_collect::>()
+        .await?;
+    ctx.shuffle(input).await
+}
+
+#[cfg(all(test, unit_test))]
+mod tests {
+    use futures_util::future::try_join_all;
+    use generic_array::GenericArray;
+    use typenum::Unsigned;
+
+    use crate::{
+        ff::{boolean_array::BA64, Serializable, U128Conversions},
+        query::runner::sharded_shuffle::execute,
+        secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
+        test_executor::run,
+        test_fixture::{try_join3_array, Reconstruct, TestWorld, TestWorldConfig, WithShards},
+        utils::array::zip3,
+    };
+
+    #[test]
+    fn basic() {
+        run(|| async {
+            const SHARDS: usize = 20;
+            let world: TestWorld> =
+                TestWorld::with_shards(TestWorldConfig::default());
+            let contexts = world.contexts();
+            let input = (0..20_u128).map(BA64::truncate_from).collect::>();
+
+            #[allow(clippy::redundant_closure_for_method_calls)]
+            let shard_shares: [Vec>>; 3] =
+                input.clone().into_iter().share().map(|helper_shares| {
+                    helper_shares
+                        .chunks(SHARDS / 3)
+                        .map(|v| v.to_vec())
+                        .collect()
+                });
+
+            let result =
+                try_join3_array(zip3(contexts, shard_shares).map(|(h_contexts, h_shares)| {
+                    try_join_all(
+                        h_contexts
+                            .into_iter()
+                            .zip(h_shares)
+                            .map(|(ctx, shard_shares)| {
+                                let shard_stream = shard_shares
+                                    .into_iter()
+                                    .flat_map(|share| {
+                                        const SIZE: usize =
+                                             as Serializable>::Size::USIZE;
+                                        let mut slice = [0_u8; SIZE];
+                                        share.serialize(GenericArray::from_mut_slice(&mut slice));
+                                        slice
+                                    })
+                                    .collect::>()
+                                    .into();
+
+                                execute(ctx, shard_stream)
+                            }),
+                    )
+                }))
+                .await
+                .unwrap()
+                .map(|v| v.into_iter().flatten().collect::>())
+                .reconstruct();
+
+            // 1/20! probability of this permutation to be the same
+            assert_ne!(input, result);
+        });
+    }
+}

From 42a60a2c73d3439157adc289b47f518e6b11ef98 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Tue, 12 Nov 2024 17:20:47 -0800
Subject: [PATCH 03/47] Fix release and web builds

---
 ipa-core/src/helpers/gateway/mod.rs | 10 ++++++++++
 ipa-core/src/query/executor.rs      |  5 +++--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs
index a6119081e..9ecde1377 100644
--- a/ipa-core/src/helpers/gateway/mod.rs
+++ b/ipa-core/src/helpers/gateway/mod.rs
@@ -107,6 +107,16 @@ pub struct GatewayConfig {
 }
 
 impl ShardConfiguration for Gateway {
+    fn shard_id(&self) -> ShardIndex {
+        ShardConfiguration::shard_id(&self)
+    }
+
+    fn shard_count(&self) -> ShardIndex {
+        ShardConfiguration::shard_count(&self)
+    }
+}
+
+impl ShardConfiguration for &Gateway {
     fn shard_id(&self) -> ShardIndex {
         self.transports.shard.identity()
     }
diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs
index c620b9565..bb11e66db 100644
--- a/ipa-core/src/query/executor.rs
+++ b/ipa-core/src/query/executor.rs
@@ -39,14 +39,15 @@ use crate::{
         Gate,
     },
     query::{
-        runner::{execute_sharded_shuffle, OprfIpaQuery, QueryResult},
+        runner::{OprfIpaQuery, QueryResult},
         state::RunningQuery,
     },
     sync::Arc,
 };
 #[cfg(any(test, feature = "cli", feature = "test-fixture"))]
 use crate::{
-    ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field,
+    ff::Fp32BitPrime, query::runner::execute_sharded_shuffle, query::runner::execute_test_multiply,
+    query::runner::test_add_in_prime_field,
 };
 
 pub trait Result: Send + Debug {
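The gateway part of the build fix above is easy to misread as infinite recursion. Inside `impl ShardConfiguration for Gateway`, the receiver `&self` has type `&Gateway`, so the call `ShardConfiguration::shard_id(&self)` passes a `&&Gateway` and dispatches to the `impl ShardConfiguration for &Gateway`, not back into itself. A minimal sketch of that owned-impl-forwards-to-reference-impl pattern (illustrative types, not the crate's):

    trait Config {
        fn id(&self) -> u32;
    }

    struct Gateway {
        shard: u32,
    }

    // The impl that actually reads the data, on the reference type.
    impl Config for &Gateway {
        fn id(&self) -> u32 {
            self.shard
        }
    }

    // The owned impl forwards. `&self` here has type `&Gateway`, so
    // `Config::id(&self)` receives `&&Gateway` and resolves to the impl
    // above rather than recursing.
    impl Config for Gateway {
        fn id(&self) -> u32 {
            Config::id(&self)
        }
    }

    #[test]
    fn owned_and_borrowed_agree() {
        let g = Gateway { shard: 7 };
        assert_eq!(g.id(), 7);
        assert_eq!((&g).id(), 7);
    }
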
From 4900f0f035fee3f5df68300e1e2db1a2e3b695e5 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Tue, 12 Nov 2024 23:19:51 -0800
Subject: [PATCH 04/47] Remove SharedValue trait bound from
 MaliciousShuffleable

The only reason it was there was to be able to use `AdditiveShare`
inside malicious shuffle, but it felt too constraining to me.
AdditiveShare has a lot of stuff built in, and a lot of it is not
required for Shuffle, so a simple wrapper could do the same trick.

While working on this, I noticed that it gets a bit clunky to use in
tests that call `verify_shuffle`. I tried to fix that, but it is a much
bigger refactoring. The idea was to add more traits supporting
malicious shuffle. Initially we could consume `MaliciousShuffleable`
that could later be upgraded to a `MaliciousWithTags` type and bound to
the original through `ShuffleWithTag`. That type could expose shares
with tags attached. More importantly, it should provide a means to go
back to the original `MaliciousShuffleable`.

I got halfway through and abandoned that effort for now
---
 .../src/protocol/ipa_prf/shuffle/malicious.rs | 96 ++++++++++++++-----
 .../src/protocol/ipa_prf/shuffle/sharded.rs   |  8 +-
 2 files changed, 76 insertions(+), 28 deletions(-)

diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
index a8b6d63c0..e31d952f0 100644
--- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
+++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
@@ -42,6 +42,45 @@ use crate::{
     seq_join::seq_join,
     sharding::ShardIndex,
 };
+// use crate::protocol::ipa_prf::shuffle::sharded::MaliciousShuffleShare;
+
+/// Container for left and right shares with tags attached to them.
+/// Looks like an additive share, but it is not because it does not need
+/// many traits that additive shares require to implement
+#[derive(Clone, Debug, Default)]
+struct Pair {
+    left: S,
+    right: S,
+}
+
+impl Shuffleable for Pair {
+    type Share = S;
+
+    fn left(&self) -> Self::Share {
+        self.left.clone()
+    }
+
+    fn right(&self) -> Self::Share {
+        self.right.clone()
+    }
+
+    fn new(l: Self::Share, r: Self::Share) -> Self {
+        Self { left: l, right: r }
+    }
+}
+
+impl From> for Pair {
+    fn from(value: AdditiveShare) -> Self {
+        let (l, r) = value.as_tuple();
+        Shuffleable::new(l, r)
+    }
+}
+
+impl From> for AdditiveShare {
+    fn from(value: Pair) -> Self {
+        ReplicatedSecretSharing::new(value.left, value.right)
+    }
+}
 
 /// This function executes the maliciously secure shuffle protocol on the input: `shares`.
 ///
@@ -65,7 +104,7 @@ where
         .collect::>>();
 
     // compute and append tags to rows
-    let shares_and_tags: Vec> =
+    let shares_and_tags: Vec> =
         compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?;
 
     // shuffle
@@ -75,7 +114,7 @@ where
     verify_shuffle::<_, S>(
         ctx.narrow(&OPRFShuffleStep::VerifyShuffle),
         &keys,
-        &shuffled_shares,
+        shuffled_shares.as_slice(),
         messages,
     )
     .await?;
@@ -143,7 +182,7 @@ where
     let keys = setup_keys(ctx.narrow(&OPRFShuffleStep::SetupKeys), amount_of_keys).await?;
 
     // compute and append tags to rows
-    let shares_and_tags: Vec> =
+    let shares_and_tags: Vec> =
         compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?;
 
     let (shuffled_shares, messages) = match ctx.role() {
@@ -170,7 +209,7 @@ where
 ///
 /// ## Panics
 /// Panics when `S::Bits > B::Bits`.
-fn truncate_tags(shares_and_tags: &[AdditiveShare]) -> Vec
+fn truncate_tags(shares_and_tags: &[Pair]) -> Vec
 where
     S: MaliciousShuffleable,
 {
@@ -178,8 +217,8 @@ where
         .iter()
         .map(|row_with_tag| {
             Shuffleable::new(
-                split_row_and_tag::(ReplicatedSecretSharing::left(row_with_tag)).0,
-                split_row_and_tag::(ReplicatedSecretSharing::right(row_with_tag)).0,
+                split_row_and_tag::(&row_with_tag.left).0,
+                split_row_and_tag::(&row_with_tag.right).0,
             )
         })
         .collect()
@@ -191,7 +230,9 @@ where
 /// When `row_with_tag` does not have the correct format,
 /// i.e. deserialization returns an error,
 /// the output row and tag will be zero.
-fn split_row_and_tag(row_with_tag: S::ShareAndTag) -> (S::Share, Gf32Bit) {
+fn split_row_and_tag(
+    row_with_tag: &S::ShareAndTag,
+) -> (S::Share, Gf32Bit) {
     let mut buf = GenericArray::default();
     row_with_tag.serialize(&mut buf);
     (
@@ -210,7 +251,7 @@ fn split_row_and_tag(row_with_tag: S::ShareAndTag) -> (
 async fn verify_shuffle(
     ctx: C,
     key_shares: &[AdditiveShare],
-    shuffled_shares: &[AdditiveShare],
+    shuffled_shares: &[Pair],
     messages: IntermediateShuffleMessages,
 ) -> Result<(), Error> {
     // reveal keys
@@ -247,7 +288,7 @@ async fn verify_shuffle(
 async fn h1_verify(
     ctx: C,
     keys: &[Gf32Bit],
-    share_a_and_b: &[AdditiveShare],
+    share_a_and_b: &[Pair],
     x1: Vec,
 ) -> Result<(), Error> {
     // compute hashes
@@ -256,9 +297,9 @@ async fn h1_verify(
     // compute hash for A xor B
     let hash_a_xor_b = compute_and_hash_tags::(
         keys,
-        share_a_and_b.iter().map(|share| {
-            ReplicatedSecretSharing::left(share) + ReplicatedSecretSharing::right(share)
-        }),
+        share_a_and_b
+            .iter()
+            .map(|share| Shuffleable::left(share) + Shuffleable::right(share)),
     );
 
     // setup channels
@@ -314,7 +355,7 @@ async fn h2_verify(
 async fn h2_verify(
     ctx: C,
     keys: &[Gf32Bit],
-    share_b_and_c: &[AdditiveShare],
+    share_b_and_c: &[Pair],
     x2: Vec,
 ) -> Result<(), Error> {
     // compute hashes
@@ -359,7 +400,7 @@ async fn h3_verify(
 async fn h3_verify(
     ctx: C,
     keys: &[Gf32Bit],
-    share_c_and_a: &[AdditiveShare],
+    share_c_and_a: &[Pair],
     y1: Vec,
     y2: Vec,
 ) -> Result<(), Error> {
@@ -405,7 +446,7 @@ where
     let iterator = row_iterator.into_iter().map(|row_with_tag| {
         // when split_row_and_tags returns the default value, the verification will fail
         // except 2^-security_parameter, i.e. 2^-32
-        let (row, tag) = split_row_and_tag::(row_with_tag);
+        let (row, tag) = split_row_and_tag::(&row_with_tag);
         >>::try_into(row)
             .unwrap()
             .into_iter()
@@ -470,7 +511,7 @@ async fn compute_and_add_tags(
     ctx: C,
     keys: &[AdditiveShare],
     rows: Vec,
-) -> Result>, Error>
+) -> Result>, Error>
 where
     C: Context,
     S: MaliciousShuffleable,
@@ -537,7 +578,7 @@ where
 fn concatenate_row_and_tag(
     row: &S,
     tag: &AdditiveShare,
-) -> AdditiveShare {
+) -> Pair {
     let mut row_left = GenericArray::default();
     let mut row_right = GenericArray::default();
     let mut tag_left = GenericArray::default();
@@ -600,7 +641,10 @@ mod tests {
             vec![record],
         )
         .await
-        .unwrap();
+        .unwrap()
+        .into_iter()
+        .map(AdditiveShare::from)
+        .collect();
 
             (keys, shares_and_tags)
         })
@@ -701,7 +745,11 @@ mod tests {
             verify_shuffle::<_, AdditiveShare>(
                 ctx.narrow("verify"),
                 &key_shares,
-                &shares,
+                shares
+                    .into_iter()
+                    .map(Pair::from)
+                    .collect::>()
+                    .as_slice(),
                 messages,
             )
             .await
@@ -726,21 +774,21 @@ mod tests {
     {
         let row = ::new(rng.gen(), rng.gen());
         let tag = AdditiveShare::::new(rng.gen::(), rng.gen::());
-        let row_and_tag: AdditiveShare = concatenate_row_and_tag(&row, &tag);
+        let row_and_tag: Pair = concatenate_row_and_tag(&row, &tag);
 
         let mut buf = GenericArray::default();
         let mut buf_row = GenericArray::default();
         let mut buf_tag = GenericArray::default();
 
         // check left shares
-        ReplicatedSecretSharing::left(&row_and_tag).serialize(&mut buf);
+        Shuffleable::left(&row_and_tag).serialize(&mut buf);
         Shuffleable::left(&row).serialize(&mut buf_row);
         assert_eq!(buf[0..S::TAG_OFFSET], buf_row[..]);
         ReplicatedSecretSharing::left(&tag).serialize(&mut buf_tag);
         assert_eq!(buf[S::TAG_OFFSET..], buf_tag[..]);
 
         // check right shares
-        ReplicatedSecretSharing::right(&row_and_tag).serialize(&mut buf);
+        Shuffleable::right(&row_and_tag).serialize(&mut buf);
         Shuffleable::right(&row).serialize(&mut buf_row);
         assert_eq!(buf[0..S::TAG_OFFSET], buf_row[..]);
         ReplicatedSecretSharing::right(&tag).serialize(&mut buf_tag);
         assert_eq!(buf[S::TAG_OFFSET..], buf_tag[..]);
@@ -765,6 +813,7 @@ mod tests {
     where
         S: MaliciousShuffleable,
         S::Share: IntoShares,
+        S::ShareAndTag: SharedValue,
         Standard: Distribution,
     {
         const RECORD_AMOUNT: usize = 10;
@@ -813,6 +862,9 @@ mod tests {
             )
             .await
             .unwrap()
+            .into_iter()
+            .map(AdditiveShare::from)
+            .collect::>()
             },
         )
         .await
diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs
index e69fc0093..165366826 100644
--- a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs
+++ b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs
@@ -279,7 +279,7 @@ pub trait MaliciousShuffleable:
     ///
     /// Having an alias here makes it easier to reference in the code, because the
     /// shuffle routines have an `S: MaliciousShuffleable` type parameter.
-    type ShareAndTag: ShuffleShare + SharedValue;
+    type ShareAndTag: ShuffleShare;
 
     /// Same as `Self::MaliciousShare::TAG_OFFSET`.
     ///
@@ -316,11 +316,7 @@ where
 /// automatically.
 pub trait MaliciousShuffleShare: TryInto, Error = LengthError> {
     /// A type that can hold `::Share` along with a 32-bit MAC.
-    ///
-    /// The `SharedValue` bound is required because some of the malicious shuffle
-    /// routines use `AdditiveShare`. It might be possible to refactor
-    /// those routines to avoid the `SharedValue` bound.
-    type ShareAndTag: ShuffleShare + SharedValue;
+    type ShareAndTag: ShuffleShare;
 
     /// The offset to the MAC in `ShareAndTag`.
     const TAG_OFFSET: usize;

From dba4f6a9d944b54d01b97b3e67373e039684e462 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Wed, 13 Nov 2024 17:10:55 -0800
Subject: [PATCH 05/47] Update ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs

Co-authored-by: Andy Leiserson
---
 ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
index e31d952f0..2c33ae5ec 100644
--- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
+++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
@@ -42,7 +42,6 @@ use crate::{
     seq_join::seq_join,
     sharding::ShardIndex,
 };
-// use crate::protocol::ipa_prf::shuffle::sharded::MaliciousShuffleShare;

From 561950c81af8293d5b9d82257647c2a591e119a8 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Fri, 8 Nov 2024 12:25:06 -0800
Subject: [PATCH 06/47] Query status

---
 ipa-core/src/app.rs                         | 26 +++++++--
 ipa-core/src/helpers/transport/query/mod.rs | 26 +++++++++
 ipa-core/src/query/processor.rs             | 64 +++++++++++++++++----
 ipa-core/src/test_fixture/app.rs            | 14 +++--
 4 files changed, 108 insertions(+), 22 deletions(-)

diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs
index fb6f9fdb7..dda2f1321 100644
--- a/ipa-core/src/app.rs
+++ b/ipa-core/src/app.rs
@@ -149,8 +149,13 @@ impl HelperApp {
     ///
     /// ## Errors
     /// Propagates errors from the helper.
-    pub fn query_status(&self, query_id: QueryId) -> Result {
-        Ok(self.inner.query_processor.query_status(query_id)?)
+    pub async fn query_status(&self, query_id: QueryId) -> Result {
+        let shard_transport = self.inner.shard_transport.clone_ref();
+        Ok(self
+            .inner
+            .query_processor
+            .query_status(shard_transport, query_id)
+            .await?)
     }
 
     /// Waits for a query to complete and returns the result.
@@ -186,12 +191,23 @@ impl RequestHandler for Inner { let req = req.into::()?; HelperResponse::from(qp.prepare_shard(&self.shard_transport, req)?) } + RouteId::QueryStatus => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.shard_ready(query_id)?) + } r => { return Err(ApiError::BadRequest( format!("{r:?} request must not be handled by shard query processing flow") .into(), )) - } + } /*RouteId::CompleteQuery => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.complete(query_id).await?) + } + RouteId::KillQuery => { + let query_id = ext_query_id(&req)?; + HelperResponse::from(qp.kill(query_id)?) + }*/ }) } } @@ -247,7 +263,9 @@ impl RequestHandler for Inner { } RouteId::QueryStatus => { let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.query_status(query_id)?) + let shard_transport = Transport::clone_ref(&self.shard_transport); + let query_status = qp.query_status(shard_transport, query_id).await?; + HelperResponse::from(query_status) } RouteId::CompleteQuery => { let query_id = ext_query_id(&req)?; diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index ac70209b3..71d3e20db 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -194,6 +194,32 @@ impl Debug for QueryInput { } } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub struct QueryStatusRequest { + pub query_id: QueryId, +} + +impl RouteParams for QueryStatusRequest { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::QueryStatus + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + serde_json::to_string(self).unwrap() + } +} + #[derive(Copy, Clone, Debug, Serialize, Deserialize)] #[cfg_attr(test, derive(PartialEq, Eq))] pub enum QueryType { diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 120e1c5ca..6a09e98be 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -10,7 +10,7 @@ use crate::{ error::Error as ProtocolError, executor::IpaRuntime, helpers::{ - query::{PrepareQuery, QueryConfig, QueryInput}, + query::{PrepareQuery, QueryConfig, QueryInput, QueryStatusRequest}, BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, @@ -110,6 +110,8 @@ pub enum QueryInputError { pub enum QueryStatusError { #[error("The query with id {0:?} does not exist")] NoSuchQuery(QueryId), + #[error(transparent)] + ShardBroadcastError(#[from] BroadcastError), } #[derive(thiserror::Error, Debug)] @@ -349,14 +351,37 @@ impl Processor { Some(status) } - /// Returns the query status. + /// Returns the query status in this helper, by querying all shards. + /// + /// ## Errors + /// If query is not registered on this helper. + /// + /// ## Panics + /// If the query collection mutex is poisoned. + pub async fn query_status( + &self, + shard_transport: ShardTransportImpl, + query_id: QueryId, + ) -> Result { + let status = self + .get_status(query_id) + .ok_or(QueryStatusError::NoSuchQuery(query_id))?; + + let shard_query_status_req = QueryStatusRequest { query_id }; + + shard_transport.broadcast(shard_query_status_req).await?; + + Ok(status) + } + + /// Returns the staus of this single shard. /// /// ## Errors /// If query is not registered on this helper. 
     ///
     /// ## Panics
     /// If the query collection mutex is poisoned.
-    pub fn query_status(&self, query_id: QueryId) -> Result {
+    pub fn shard_ready(&self, query_id: QueryId) -> Result {
         let status = self
             .get_status(query_id)
             .ok_or(QueryStatusError::NoSuchQuery(query_id))?;
@@ -608,9 +633,11 @@ mod tests {
         });
         args.mpc_handlers = [None, Some(h2), Some(h3)];
         let t = TestComponents::new(args);
-        let qc_future = t
-            .processor
-            .new_query(t.first_transport, t.shard_transport, t.query_config);
+        let qc_future = t.processor.new_query(
+            t.first_transport,
+            t.shard_transport.clone_ref(),
+            t.query_config,
+        );
         pin_mut!(qc_future);
 
         // poll future once to trigger query status change
@@ -618,7 +645,10 @@ mod tests {
 
         assert_eq!(
             QueryStatus::Preparing,
-            t.processor.query_status(QueryId).unwrap()
+            t.processor
+                .query_status(t.shard_transport.clone_ref(), QueryId)
+                .await
+                .unwrap()
         );
         // unblock sends
         barrier.wait().await;
@@ -636,7 +666,10 @@ mod tests {
         );
         assert_eq!(
             QueryStatus::AwaitingInputs,
-            t.processor.query_status(QueryId).unwrap()
+            t.processor
+                .query_status(t.shard_transport.clone_ref(), QueryId)
+                .await
+                .unwrap()
         );
     }
@@ -772,16 +805,22 @@ mod tests {
         let req = prepare_query();
         let t = TestComponents::new(TestComponentsArgs::default());
         assert!(matches!(
-            t.processor.query_status(QueryId).unwrap_err(),
+            t.processor
+                .query_status(t.shard_transport.clone_ref(), QueryId)
+                .await
+                .unwrap_err(),
             QueryStatusError::NoSuchQuery(_)
        ));
         t.processor
-            .prepare_helper(t.second_transport, t.shard_transport, req)
+            .prepare_helper(t.second_transport, t.shard_transport.clone_ref(), req)
             .await
             .unwrap();
         assert_eq!(
             QueryStatus::AwaitingInputs,
-            t.processor.query_status(QueryId).unwrap()
+            t.processor
+                .query_status(t.shard_transport, QueryId)
+                .await
+                .unwrap()
         );
     }
@@ -1012,7 +1051,8 @@ mod tests {
             .await?;
 
         while !app
-            .query_status(query_id)?
+            .query_status(query_id)
+            .await?
             .into_iter()
             .all(|s| s == QueryStatus::Completed)
         {
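The coordination pattern this patch introduces — the leader shard answers with its own state only after every shard has been asked — can be summarized with a hedged sketch. The `LeaderView` type and `confirm` function below are illustrative simplifications (synchronous, no transport), not the crate's API:

    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum Status { Preparing, AwaitingInputs, Running, Completed }

    struct LeaderView {
        my_status: Status,
        shard_statuses: Vec<Status>, // as reported by follower shards
    }

    impl LeaderView {
        // Every shard must agree with the leader before a status is reported
        // upstream; a disagreeing shard is surfaced as the error, much like
        // the broadcast error the real `query_status` propagates.
        fn confirm(&self) -> Result<Status, (usize, Status)> {
            for (ix, &s) in self.shard_statuses.iter().enumerate() {
                if s != self.my_status {
                    return Err((ix, s));
                }
            }
            Ok(self.my_status)
        }
    }

Later patches in this series refine the follower side of this exchange so the comparison itself happens on each shard.
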
diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs
index 4eb51cc9c..cfe2e9622 100644
--- a/ipa-core/src/test_fixture/app.rs
+++ b/ipa-core/src/test_fixture/app.rs
@@ -1,5 +1,6 @@
 use std::{array, iter::zip};
 
+use futures::future::join_all;
 use generic_array::GenericArray;
 use typenum::Unsigned;
 
@@ -118,12 +119,13 @@ impl TestApp {
     /// Propagates errors retrieving the query status.
     /// ## Panics
     /// Never.
-    pub fn query_status(&self, query_id: QueryId) -> Result<[QueryStatus; 3], ApiError> {
-        Ok((0..3)
-            .map(|i| self.drivers[i].query_status(query_id))
-            .collect::, _>>()?
-            .try_into()
-            .unwrap())
+    #[allow(clippy::disallowed_methods)]
+    pub async fn query_status(&self, query_id: QueryId) -> Result<[QueryStatus; 3], ApiError> {
+        join_all((0..3).map(|i| self.drivers[i].query_status(query_id)))
+            .await
+            .into_iter()
+            .collect::, _>>()
+            .map(|vec| vec.try_into().unwrap())
     }
 
     /// ## Errors

From a6333ce473faee2e3a5d0e1fcb7c1c7297c638b7 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Fri, 8 Nov 2024 12:25:21 -0800
Subject: [PATCH 07/47] Fixing tests

---
 ipa-core/src/query/processor.rs | 71 +++++++++++++++++++++++++++------
 1 file changed, 59 insertions(+), 12 deletions(-)

diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs
index 6a09e98be..9a015f535 100644
--- a/ipa-core/src/query/processor.rs
+++ b/ipa-core/src/query/processor.rs
@@ -458,7 +458,7 @@ pub enum QueryKillStatus {
 
 #[cfg(all(test, unit_test))]
 mod tests {
-    use std::{array, future::Future, sync::Arc};
+    use std::{array, future::Future, marker::PhantomData, sync::Arc};
 
     use futures::pin_mut;
     use futures_util::future::poll_immediate;
 
     use crate::{
         ff::FieldType,
         helpers::{
             make_owned_handler, query::{PrepareQuery, QueryConfig, QueryStatusRequest, QueryType::TestMultiply}, routing::RouteId, ApiError, HandlerBox, HelperIdentity, HelperResponse, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, RequestHandler, RoleAssignment, Transport, TransportIdentity
         },
         protocol::QueryId,
         query::{
             processor::Processor, state::StateError, NewQueryError, PrepareQueryError,
             QueryStatus, QueryStatusError,
         },
         sharding::ShardIndex,
     };
 
+    /*fn prepare_query_handler(cb: F) -> Arc>
     where
         F: Fn(PrepareQuery) -> Fut + Send + Sync + 'static,
         Fut: Future> + Send + Sync + 'static,
     {
         make_owned_handler(move |req, _| {
             let prepare_query = req.into().unwrap();
             cb(prepare_query)
         })
+    }*/
+
+    fn create_handler(prepare_handler: FPQ, status_handler: FQS) -> Arc>
+    where
+        FPQ: Fn(PrepareQuery) -> Fut + Send + Sync + 'static,
+        FQS: Fn(QueryStatusRequest) -> Fut + Send + Sync + 'static,
+        Fut: Future> + Send + Sync + 'static,
+        I: TransportIdentity,
+    {
+        make_owned_handler(move |req, _| {
+            match req.route {
+                RouteId::PrepareQuery => prepare_handler(req.into().unwrap()),
+                RouteId::QueryStatus => status_handler(req.into().unwrap()),
+                _ => panic!("unexpected route {:?}", req.route)
+            }
+        })
+    }
+
+    async fn respond_ok(_: T) -> Result {
+        Ok(HelperResponse::ok())
+    }
+
+    struct TestHandler {
+        phantom: PhantomData,
+        prepare_handle: Option,
+        status_handle: Option,
+    }
+
+    impl TestHandler {
+        fn create_response(opterror: &mut Option) -> Result {
+            if let Some(error) = opterror.take() {
+                Err(error)
+            } else {
+                Ok(HelperResponse::ok())
+            }
+        }
+    }
+
+    #[async_trait::async_trait]
+    impl RequestHandler for TestHandler {
+        async fn handle(
+            &self,
+            req: crate::helpers::routing::Addr,
+            _data: crate::helpers::BodyStream,
+        ) -> Result {
+            match req.route {
+                //RouteId::PrepareQuery => Self::create_response(&mut self.prepare_error),
+                //RouteId::QueryStatus => Self::create_response(&mut self.status_error),
+                _ => panic!("unexpected route {:?}", req.route)
+            }
+        }
     }
 
     fn helper_respond_ok() -> Arc> {
-        prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })
+        create_handler(respond_ok, respond_ok)
     }
 
     fn shard_respond_ok(_si: ShardIndex) -> Arc> {
-        prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })
+        create_handler(respond_ok::, respond_ok::)
     }
 
     fn test_multiply_config() -> QueryConfig {
@@ -617,13 +664,13 @@ mod tests {
         let barrier = Arc::new(Barrier::new(3));
         let h2_barrier = Arc::clone(&barrier);
         let h3_barrier = Arc::clone(&barrier);
-        let h2 = prepare_query_handler(move |_| {
+        let h2 = create_handler(move |_| {
             let barrier = Arc::clone(&h2_barrier);
-            async move {
+            async move |_| {
                 barrier.wait().await;
                 Ok(HelperResponse::ok())
             }
-        });
+        }, respond_ok);
         let h3 = prepare_query_handler(move |_| {
             let barrier = Arc::clone(&h3_barrier);
             async move {
                 barrier.wait().await;
                 Ok(HelperResponse::ok())
             }
         });

From e557df91d8f4c1b578591a2757c0dbebc10d175d Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Mon, 11 Nov 2024 11:21:13 -0800
Subject: [PATCH 08/47] Undid tests. Need to handle different addrs.

---
 ipa-core/src/query/processor.rs | 82 +++++++++++++++------------------
 1 file changed, 21 insertions(+), 61 deletions(-)

diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs
index 9a015f535..efbe67f18 100644
--- a/ipa-core/src/query/processor.rs
+++ b/ipa-core/src/query/processor.rs
@@ -458,7 +458,7 @@ pub enum QueryKillStatus {
 
 #[cfg(all(test, unit_test))]
 mod tests {
-    use std::{array, future::Future, marker::PhantomData, sync::Arc};
+    use std::{array, future::Future, sync::Arc};
 
     use futures::pin_mut;
     use futures_util::future::poll_immediate;
@@ -467,11 +467,7 @@ mod tests {
     use crate::{
         ff::FieldType,
         helpers::{
-            make_owned_handler, query::{PrepareQuery, QueryConfig, QueryStatusRequest, QueryType::TestMultiply}, routing::RouteId, ApiError, HandlerBox, HelperIdentity, HelperResponse, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, RequestHandler, RoleAssignment, Transport, TransportIdentity
+            make_owned_handler,
+            query::{PrepareQuery, QueryConfig, QueryType::TestMultiply},
+            ApiError, HandlerBox, HelperIdentity, HelperResponse, InMemoryMpcNetwork,
+            InMemoryShardNetwork, InMemoryTransport, RequestHandler, RoleAssignment, Transport,
+            TransportIdentity,
         },
         protocol::QueryId,
         query::{
@@ -481,7 +477,7 @@ mod tests {
         sharding::ShardIndex,
     };
 
-    /*fn prepare_query_handler(cb: F) -> Arc>
+    fn prepare_query_handler(cb: F) -> Arc>
     where
         F: Fn(PrepareQuery) -> Fut + Send + Sync + 'static,
         Fut: Future> + Send + Sync + 'static,
     {
         make_owned_handler(move |req, _| {
             let prepare_query = req.into().unwrap();
             cb(prepare_query)
         })
-    }*/
-
-    fn create_handler(prepare_handler: FPQ, status_handler: FQS) -> Arc>
-    where
-        FPQ: Fn(PrepareQuery) -> Fut + Send + Sync + 'static,
-        FQS: Fn(QueryStatusRequest) -> Fut + Send + Sync + 'static,
-        Fut: Future> + Send + Sync + 'static,
-        I: TransportIdentity,
-    {
-        make_owned_handler(move |req, _| {
-            match req.route {
-                RouteId::PrepareQuery => prepare_handler(req.into().unwrap()),
-                RouteId::QueryStatus => status_handler(req.into().unwrap()),
-                _ => panic!("unexpected route {:?}", req.route)
-            }
-        })
-    }
-
-    async fn respond_ok(_: T) -> Result {
-        Ok(HelperResponse::ok())
-    }
-
-    struct TestHandler {
-        phantom: PhantomData,
-        prepare_handle: Option,
-        status_handle: Option,
-    }
-
-    impl TestHandler {
-        fn create_response(opterror: &mut Option) -> Result {
-            if let Some(error) = opterror.take() {
-                Err(error)
-            } else {
-                Ok(HelperResponse::ok())
-            }
-        }
-    }
-
-    #[async_trait::async_trait]
-    impl RequestHandler for TestHandler {
-        async fn handle(
-            &self,
-            req: crate::helpers::routing::Addr,
-            _data: crate::helpers::BodyStream,
-        ) -> Result {
-            match req.route {
-                //RouteId::PrepareQuery => Self::create_response(&mut self.prepare_error),
-                //RouteId::QueryStatus => Self::create_response(&mut self.status_error),
-                _ => panic!("unexpected route {:?}", req.route)
-            }
-        }
-    }
 
     fn helper_respond_ok() -> Arc> {
-        create_handler(respond_ok, respond_ok)
+        prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })
     }
 
     fn shard_respond_ok(_si: ShardIndex) -> Arc> {
-        create_handler(respond_ok::, respond_ok::)
+        prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })
     }
 
     fn test_multiply_config() -> QueryConfig {
@@ -719,7 +719,7 @@ mod tests {
     #[tokio::test]
     async fn shard_prepare_error() {
         fn shard_handle(si: ShardIndex) -> Arc> {
-            create_handler(move |_| async move {
+            prepare_query_handler(move |_| async move {
                 if si == ShardIndex(2) {
                     Err(ApiError::QueryPrepare(PrepareQueryError::AlreadyRunning))
                 } else {
                     Ok(HelperResponse::ok())
                 }
             })
         }
         let t = TestComponents::new(args);
         let r = t
             .processor
-            .new_query(t.first_transport, t.shard_transport, t.query_config)
+            .new_query(
+                t.first_transport,
+                t.shard_transport.clone_ref(),
+                t.query_config,
+            )
             .await;
         // The following makes sure the error is a broadcast error from shard 2
         assert!(r.is_err());
         if let Err(e) = r {
             if let NewQueryError::ShardBroadcastError(be) = e {
                 assert_eq!(be.failures[0].0, ShardIndex(2));
             } else {
                 panic!("Unexpected error type");
             }
         }
         assert!(matches!(
-            t.processor.query_status(QueryId).unwrap_err(),
+            t.processor
+                .query_status(t.shard_transport, QueryId)
+                .await
+                .unwrap_err(),
             QueryStatusError::NoSuchQuery(_)
         ));
     }
@@ -770,7 +770,7 @@ mod tests {
         // First we setup MPC handlers that will return some error
         let mut args = TestComponentsArgs::default();
         let h2 = helper_respond_ok();
-        let h3 = create_handler(|_| async move {
+        let h3 = prepare_query_handler(|_| async move {
             Err(ApiError::QueryPrepare(PrepareQueryError::WrongTarget))
         });
         args.mpc_handlers = [None, Some(h2), Some(h3)];

From 03dc557cdb4261c54c79d7083edce1b78513a363 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Mon, 11 Nov 2024 11:37:51 -0800
Subject: [PATCH 09/47] Simple fix for tests

---
 ipa-core/src/query/processor.rs | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs
index efbe67f18..873b9a327 100644
--- a/ipa-core/src/query/processor.rs
+++ b/ipa-core/src/query/processor.rs
@@ -469,6 +469,7 @@ mod tests {
     helpers::{
         make_owned_handler,
         query::{PrepareQuery, QueryConfig, QueryType::TestMultiply},
+        routing::Addr,
         ApiError, HandlerBox, HelperIdentity, HelperResponse, InMemoryMpcNetwork,
         InMemoryShardNetwork, InMemoryTransport, RequestHandler, RoleAssignment, Transport,
         TransportIdentity,
     },
@@ -484,13 +485,10 @@ mod tests {
     fn prepare_query_handler(cb: F) -> Arc>
     where
-        F: Fn(PrepareQuery) -> Fut + Send + Sync + 'static,
+        F: Fn(Addr) -> Fut + Send + Sync + 'static,
         Fut: Future> + Send + Sync + 'static,
     {
-        make_owned_handler(move |req, _| {
-            let prepare_query = req.into().unwrap();
-            cb(prepare_query)
-        })
+        make_owned_handler(move |req, _| cb(req))
     }

From a2fb5f5724d9e89a96eb2e443c0a9f6e6903cfe2 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Tue, 12 Nov 2024 08:58:24 -0800
Subject: [PATCH 10/47] combined status temp

---
 ipa-core/src/app.rs             |  2 +-
 ipa-core/src/query/processor.rs | 63 +++++++++++++++++++++++++++------
 ipa-core/src/query/state.rs     |  3 ++
 3 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs
index dda2f1321..9d2d6fd00 100644
--- a/ipa-core/src/app.rs
+++ b/ipa-core/src/app.rs
@@ -193,7 +193,7 @@ impl RequestHandler for Inner {
             RouteId::QueryStatus => {
                 let query_id = ext_query_id(&req)?;
-                HelperResponse::from(qp.shard_ready(query_id)?)
+                HelperResponse::from(qp.shard_status(query_id)?)
             }
             r => {
                 return Err(ApiError::BadRequest(
diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs
index 873b9a327..35dbfa9c4 100644
--- a/ipa-core/src/query/processor.rs
+++ b/ipa-core/src/query/processor.rs
@@ -374,14 +374,14 @@ impl Processor {
         Ok(status)
     }
 
-    /// Returns the status of this single shard.
+    /// Compares this shard status against the given type. Returns an error if different.
     ///
     /// ## Errors
     /// If query is not registered on this helper.
     ///
     /// ## Panics
     /// If the query collection mutex is poisoned.
-    pub fn shard_ready(&self, query_id: QueryId) -> Result {
+    pub fn shard_status(&self, query_id: QueryId) -> Result {
         let status = self
             .get_status(query_id)
             .ok_or(QueryStatusError::NoSuchQuery(query_id))?;
@@ -482,7 +482,7 @@ mod tests {
-    fn prepare_query_handler(cb: F) -> Arc>
+    fn create_handler(cb: F) -> Arc>
     where
         F: Fn(Addr) -> Fut + Send + Sync + 'static,
         Fut: Future> + Send + Sync + 'static,
@@ -491,11 +491,11 @@ mod tests {
     }
 
     fn helper_respond_ok() -> Arc> {
-        prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })
+        create_handler(|_| async { Ok(HelperResponse::ok()) })
     }
 
     fn shard_respond_ok(_si: ShardIndex) -> Arc> {
-        prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })
+        create_handler(|_| async { Ok(HelperResponse::ok()) })
     }
 
     fn test_multiply_config() -> QueryConfig {
@@ -615,14 +615,14 @@ mod tests {
         let barrier = Arc::new(Barrier::new(3));
         let h2_barrier = Arc::clone(&barrier);
         let h3_barrier = Arc::clone(&barrier);
-        let h2 = prepare_query_handler(move |_| {
+        let h2 = create_handler(move |_| {
             let barrier = Arc::clone(&h2_barrier);
             async move {
                 barrier.wait().await;
                 Ok(HelperResponse::ok())
             }
         });
-        let h3 = prepare_query_handler(move |_| {
+        let h3 = create_handler(move |_| {
             let barrier = Arc::clone(&h3_barrier);
             async move {
                 barrier.wait().await;
                 Ok(HelperResponse::ok())
             }
         });
@@ -696,7 +696,7 @@ mod tests {
     async fn prepare_error() {
         let mut args = TestComponentsArgs::default();
         let h2 = helper_respond_ok();
-        let h3 = prepare_query_handler(|_| async move {
+        let h3 = create_handler(|_| async move {
             Err(ApiError::QueryPrepare(PrepareQueryError::WrongTarget))
         });
         args.mpc_handlers = [None, Some(h2), Some(h3)];
@@ -719,7 +719,7 @@ mod tests {
     #[tokio::test]
     async fn shard_prepare_error() {
         fn shard_handle(si: ShardIndex) -> Arc> {
-            prepare_query_handler(move |_| async move {
+            create_handler(move |_| async move {
                 if si == ShardIndex(2) {
                     Err(ApiError::QueryPrepare(PrepareQueryError::AlreadyRunning))
                 } else {
                     Ok(HelperResponse::ok())
                 }
             })
        }
@@ -923,6 +923,49 @@ mod tests {
     }
 
+    mod query_status {
+        use crate::protocol::QueryId;
+
+        use super::*;
+
+        /// * From the standpoint of leader shard in Helper 1
+        /// * On query_status
+        ///
+        /// If one of my shards isn't ready
+        #[tokio::test]
+        async fn combined_status_response() {
+            fn shard_handle(si: ShardIndex) -> Arc> {
+                create_handler(move |_| async move {
+                    match si {
+                        ShardIndex(1) => Ok(HelperResponse::from(QueryStatus::AwaitingInputs)),
+                        ShardIndex(2) => Ok(HelperResponse::from(QueryStatus::Running)),
+                        ShardIndex(3) => Ok(HelperResponse::from(QueryStatus::Completed)),
+                        _ => Ok(HelperResponse::from(QueryStatus::Running))
+                    }
+                })
+            }
+            let mut args = TestComponentsArgs {
+                shard_count: 4,
+                ..Default::default()
+            };
+            args.set_shard_handler(shard_handle);
+            let t = TestComponents::new(args);
+            let _ = t.processor
+                .new_query(
+                    Transport::clone_ref(&t.first_transport),
+                    Transport::clone_ref(&t.shard_transport),
+                    t.query_config,
+                )
+                .await
+                .unwrap();
+            assert_eq!(t.processor
+                .query_status(t.shard_transport.clone_ref(), QueryId)
+                .await
+                .unwrap(),
+                QueryStatus::MixedStatus);
+        }
+    }
+
     mod kill {
         use std::sync::Arc;
diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs
index 460296022..74aead2be 100644
--- a/ipa-core/src/query/state.rs
+++ b/ipa-core/src/query/state.rs
@@ -33,6 +33,9 @@ pub enum QueryStatus {
     AwaitingCompletion,
     /// Query has finished and results are available.
     Completed,
+    /// This is used when there are different states for multiple queries, for instance when
+    /// querying a sharded helper, and shards are in different states.
+    MixedStatus,
 }
 
 impl From<&QueryState> for QueryStatus {

From db3f90986173df7de4f41b037a2dc671e4448059 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Thu, 14 Nov 2024 11:27:30 -0800
Subject: [PATCH 11/47] added combined status test

---
 ipa-core/src/app.rs                         |  6 +-
 ipa-core/src/helpers/transport/query/mod.rs | 30 ++++++-
 ipa-core/src/query/processor.rs             | 90 ++++++++++++++-------
 ipa-core/src/query/state.rs                 |  3 -
 4 files changed, 91 insertions(+), 38 deletions(-)

diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs
index 9d2d6fd00..082bea3a6 100644
--- a/ipa-core/src/app.rs
+++ b/ipa-core/src/app.rs
@@ -5,7 +5,7 @@ use async_trait::async_trait;
 use crate::{
     executor::IpaRuntime,
     helpers::{
-        query::{PrepareQuery, QueryConfig, QueryInput},
+        query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput},
         routing::{Addr, RouteId},
         ApiError, BodyStream, HandlerBox, HandlerRef, HelperIdentity, HelperResponse,
         MpcTransportImpl, RequestHandler, ShardTransportImpl, Transport, TransportIdentity,
@@ -192,8 +192,8 @@ impl RequestHandler for Inner {
             RouteId::QueryStatus => {
-                let query_id = ext_query_id(&req)?;
-                HelperResponse::from(qp.shard_status(query_id)?)
+                let req = req.into::()?;
+                HelperResponse::from(qp.shard_status(&req)?)
} r => { return Err(ApiError::BadRequest( diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index 71d3e20db..26f90be2e 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -15,6 +15,7 @@ use crate::{ RoleAssignment, RouteParams, }, protocol::QueryId, + query::QueryStatus, }; #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Serialize)] @@ -194,7 +195,7 @@ impl Debug for QueryInput { } } -#[derive(Clone, Debug, Serialize, Deserialize)] +/*#[derive(Clone, Debug, Serialize, Deserialize)] #[cfg_attr(test, derive(PartialEq, Eq))] pub struct QueryStatusRequest { pub query_id: QueryId, @@ -215,6 +216,33 @@ impl RouteParams for QueryStatusRequest { NoStep } + fn extra(&self) -> Self::Params { + serde_json::to_string(self).unwrap() + } +}*/ + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub struct CompareStatusRequest { + pub query_id: QueryId, + pub status: QueryStatus, +} + +impl RouteParams for CompareStatusRequest { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::QueryStatus + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + fn extra(&self) -> Self::Params { serde_json::to_string(self).unwrap() } diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 35dbfa9c4..ea383e4f1 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -10,7 +10,7 @@ use crate::{ error::Error as ProtocolError, executor::IpaRuntime, helpers::{ - query::{PrepareQuery, QueryConfig, QueryInput, QueryStatusRequest}, + query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, @@ -112,6 +112,12 @@ pub enum QueryStatusError { NoSuchQuery(QueryId), #[error(transparent)] ShardBroadcastError(#[from] BroadcastError), + #[error("My status {my_status:?} for query {query_id:?} differs from {other_status:?}")] + DifferentStatus { + query_id: QueryId, + my_status: QueryStatus, + other_status: QueryStatus, + }, } #[derive(thiserror::Error, Debug)] @@ -367,7 +373,7 @@ impl Processor { .get_status(query_id) .ok_or(QueryStatusError::NoSuchQuery(query_id))?; - let shard_query_status_req = QueryStatusRequest { query_id }; + let shard_query_status_req = CompareStatusRequest { query_id, status }; shard_transport.broadcast(shard_query_status_req).await?; @@ -381,10 +387,20 @@ impl Processor { /// /// ## Panics /// If the query collection mutex is poisoned. 
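// With this change the leader stops aggregating statuses itself: it broadcasts
// its own status in a `CompareStatusRequest` and treats any follower failure
// as a mismatch. A sketch of how a caller might surface the per-shard
// disagreements, assuming `BroadcastError` exposes its failures as
// `(ShardIndex, error)` pairs, which is what a later test in this series
// indexes via `be.failures[0].0`:
match shard_transport.broadcast(shard_query_status_req).await {
    Ok(_) => { /* every shard reported the same status as the leader */ }
    Err(err) => {
        for (shard, failure) in &err.failures {
            tracing::warn!("shard {shard:?} disagrees on query status: {failure:?}");
        }
        return Err(QueryStatusError::ShardBroadcastError(err).into());
    }
}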
- pub fn shard_status(&self, query_id: QueryId) -> Result { + pub fn shard_status( + &self, + req: &CompareStatusRequest, + ) -> Result { let status = self - .get_status(query_id) - .ok_or(QueryStatusError::NoSuchQuery(query_id))?; + .get_status(req.query_id) + .ok_or(QueryStatusError::NoSuchQuery(req.query_id))?; + if req.status != status { + return Err(QueryStatusError::DifferentStatus { + query_id: req.query_id, + my_status: status, + other_status: req.status, + }); + } Ok(status) } @@ -482,6 +498,14 @@ mod tests { sharding::ShardIndex, }; + fn prepare_query() -> PrepareQuery { + PrepareQuery { + query_id: QueryId, + config: test_multiply_config(), + roles: RoleAssignment::new(HelperIdentity::make_three()), + } + } + fn create_handler(cb: F) -> Arc> where F: Fn(Addr) -> Fut + Send + Sync + 'static, @@ -797,14 +821,6 @@ mod tests { use super::*; use crate::query::QueryStatusError; - fn prepare_query() -> PrepareQuery { - PrepareQuery { - query_id: QueryId, - config: test_multiply_config(), - roles: RoleAssignment::new(HelperIdentity::make_three()), - } - } - #[tokio::test] async fn happy_case() { let req = prepare_query(); @@ -924,23 +940,26 @@ mod tests { } mod query_status { - use crate::protocol::QueryId; - use super::*; + use crate::protocol::QueryId; /// * From the standpoint of leader shard in Helper 1 /// * On query_status - /// + /// /// If one of my shards isn't ready #[tokio::test] async fn combined_status_response() { fn shard_handle(si: ShardIndex) -> Arc> { create_handler(move |_| async move { match si { - ShardIndex(1) => Ok(HelperResponse::from(QueryStatus::AwaitingInputs)), - ShardIndex(2) => Ok(HelperResponse::from(QueryStatus::Running)), - ShardIndex(3) => Ok(HelperResponse::from(QueryStatus::Completed)), - _ => Ok(HelperResponse::from(QueryStatus::Running)) + ShardIndex(3) => { + Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { + query_id: QueryId, + my_status: QueryStatus::Completed, + other_status: QueryStatus::Preparing, + })) + } + _ => Ok(HelperResponse::ok()), } }) } @@ -950,19 +969,28 @@ mod tests { }; args.set_shard_handler(shard_handle); let t = TestComponents::new(args); - let _ = t.processor - .new_query( - Transport::clone_ref(&t.first_transport), - Transport::clone_ref(&t.shard_transport), - t.query_config, + let req = prepare_query(); + // Using prepare shard to set the inner state, but in reality we should be using prepare_helper + // Prepare helper will use the shard_handle defined above though and will fail. The following + // achieves the same state. + t.processor + .prepare_shard( + &t.shard_network + .transport(HelperIdentity::ONE, ShardIndex::from(1)), + req, ) - .await .unwrap(); - assert_eq!(t.processor - .query_status(t.shard_transport.clone_ref(),QueryId) - .await - .unwrap(), - QueryStatus::MixedStatus); + let r = t + .processor + .query_status(t.shard_transport.clone_ref(), QueryId) + .await; + if let Err(e) = r { + if let QueryStatusError::ShardBroadcastError(be) = e { + assert_eq!(be.failures[0].0, ShardIndex(3)); + } else { + panic!("Unexpected error type"); + } + } } } diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 74aead2be..460296022 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -33,9 +33,6 @@ pub enum QueryStatus { AwaitingCompletion, /// Query has finished and results are available. Completed, - /// This is used when there are different states for multiple queries, for instance when - /// querying a sharded helper, and shards are in different states. 
- MixedStatus, } impl From<&QueryState> for QueryStatus { From 22277cd31d6c850d390e77911383b657006a7150 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Thu, 14 Nov 2024 16:49:21 -0800 Subject: [PATCH 12/47] Cleanup --- ipa-core/src/app.rs | 9 +------ ipa-core/src/helpers/transport/query/mod.rs | 26 --------------------- 2 files changed, 1 insertion(+), 34 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 082bea3a6..da0e83409 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -200,14 +200,7 @@ impl RequestHandler for Inner { format!("{r:?} request must not be handled by shard query processing flow") .into(), )) - } /*RouteId::CompleteQuery => { - let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.complete(query_id).await?) - } - RouteId::KillQuery => { - let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.kill(query_id)?) - }*/ + } }) } } diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index 26f90be2e..e491fcb2a 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -195,32 +195,6 @@ impl Debug for QueryInput { } } -/*#[derive(Clone, Debug, Serialize, Deserialize)] -#[cfg_attr(test, derive(PartialEq, Eq))] -pub struct QueryStatusRequest { - pub query_id: QueryId, -} - -impl RouteParams for QueryStatusRequest { - type Params = String; - - fn resource_identifier(&self) -> RouteId { - RouteId::QueryStatus - } - - fn query_id(&self) -> QueryId { - self.query_id - } - - fn gate(&self) -> NoStep { - NoStep - } - - fn extra(&self) -> Self::Params { - serde_json::to_string(self).unwrap() - } -}*/ - #[derive(Clone, Debug, Serialize, Deserialize)] #[cfg_attr(test, derive(PartialEq, Eq))] pub struct CompareStatusRequest { From 997eec45134b5b2a20f0556c9ab9240976a88f6b Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Thu, 14 Nov 2024 18:21:42 -0800 Subject: [PATCH 13/47] Adding a few simple tests --- ipa-core/src/app.rs | 2 +- ipa-core/src/query/processor.rs | 62 +++++++++++++++++++++++++++++++-- 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index da0e83409..1c3371738 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -193,7 +193,7 @@ impl RequestHandler for Inner { } RouteId::QueryStatus => { let req = req.into::()?; - HelperResponse::from(qp.shard_status(&req)?) + HelperResponse::from(qp.shard_status(&self.shard_transport, &req)?) 
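// Threading the shard transport into `shard_status` lets the processor ask
// which shard it is before answering. The split this sets up, in miniature
// (sketch; `identity()` returning the local `ShardIndex` matches the
// processor hunks below):
let me = shard_transport.identity();
if me == ShardIndex::FIRST {
    // leader shard: only `query_status` is legal here; it broadcasts a
    // `CompareStatusRequest` to every other shard.
} else {
    // follower shard: only `shard_status` is legal here; it compares the
    // leader's claimed status against its own local state.
}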
} r => { return Err(ApiError::BadRequest( diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index ea383e4f1..4f860f011 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -112,6 +112,10 @@ pub enum QueryStatusError { NoSuchQuery(QueryId), #[error(transparent)] ShardBroadcastError(#[from] BroadcastError), + #[error("This shard {0:?} isn't the leader (shard 0)")] + NotLeader(ShardIndex), + #[error("This is the leader shard")] + Leader, #[error("My status {my_status:?} for query {query_id:?} differs from {other_status:?}")] DifferentStatus { query_id: QueryId, @@ -369,6 +373,11 @@ impl Processor { shard_transport: ShardTransportImpl, query_id: QueryId, ) -> Result { + let shard_index = shard_transport.identity(); + if shard_index != ShardIndex::FIRST { + return Err(QueryStatusError::NotLeader(shard_index)); + } + let status = self .get_status(query_id) .ok_or(QueryStatusError::NoSuchQuery(query_id))?; @@ -383,14 +392,19 @@ impl Processor { /// Compares this shard status against the given type. Returns an error if different. /// /// ## Errors - /// If query is not registered on this helper. + /// If query is not registered on this helper or /// /// ## Panics /// If the query collection mutex is poisoned. pub fn shard_status( &self, + shard_transport: &ShardTransportImpl, req: &CompareStatusRequest, ) -> Result { + let shard_index = shard_transport.identity(); + if shard_index == ShardIndex::FIRST { + return Err(QueryStatusError::Leader); + } let status = self .get_status(req.query_id) .ok_or(QueryStatusError::NoSuchQuery(req.query_id))?; @@ -941,7 +955,7 @@ mod tests { mod query_status { use super::*; - use crate::protocol::QueryId; + use crate::{helpers::query::CompareStatusRequest, protocol::QueryId}; /// * From the standpoint of leader shard in Helper 1 /// * On query_status @@ -992,6 +1006,50 @@ mod tests { } } } + + /// Context: + /// * From the standpoint of the second shard in Helper 2 + /// + /// This test makes sure that an error is returned if I get a [`Processor::query_status`] + /// call. Only the shard leader (shard 0) should handle those calls. + #[tokio::test] + async fn rejects_if_not_shard_leader() { + let t = TestComponents::new(TestComponentsArgs::default()); + assert!(matches!( + t.processor + .query_status( + t.shard_network + .transport(HelperIdentity::TWO, ShardIndex::from(1)), + QueryId + ) + .await, + Err(QueryStatusError::NotLeader(_)) + )); + } + + /// Context: + /// * From the standpoint of the leader shard in Helper 2 + /// + /// This test makes sure that an error is returned if I get a [`Processor::shard_status`] + /// call. Only non-leaders (1,2,3...) should handle those calls. + #[tokio::test] + async fn shard_not_leader() { + let req = CompareStatusRequest { + query_id: QueryId, + status: QueryStatus::Running, + }; + let t = TestComponents::new(TestComponentsArgs::default()); + assert!(matches!( + t.processor + .shard_status( + &t.shard_network + .transport(HelperIdentity::TWO, ShardIndex::FIRST), + &req + ) + .unwrap_err(), + QueryStatusError::Leader + )); + } } mod kill { From 44f699d87100ce19bedf888eac6b35c73043f51e Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Fri, 15 Nov 2024 16:21:38 -0800 Subject: [PATCH 14/47] Revert "using shard_url instead of port" This reverts commit 8c201ae1b955da2e38f6a0c634e7a23925e834a3. 
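For readers following this revert: it trades a full per-peer `shard_url` back
for a bare `shard_port` spliced onto the MPC host, so shard traffic can no
longer target a different hostname. The two peer shapes side by side
(illustrative hosts; TOML carried in Rust raw strings purely for the sketch):

// After this revert: shard traffic reuses the MPC host, only the port differs.
const PEER_WITH_SHARD_PORT: &str = r#"
[[peers]]
url = "helper1.example.com:443"
shard_port = 444
"#;

// Before this revert: the shard endpoint was an independent authority.
const PEER_WITH_SHARD_URL: &str = r#"
[[peers]]
url = "helper1.example.com:443"
shard_url = "helper1.example.com:444"
"#;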
--- ipa-core/src/bin/helper.rs | 1 - ipa-core/src/cli/clientconf.rs | 8 +- ipa-core/src/config.rs | 197 ++++++++++++++++----------------- ipa-core/src/serde.rs | 21 ---- ipa-core/src/utils/mod.rs | 17 +++ 5 files changed, 115 insertions(+), 129 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 7de9de90c..8eb006e94 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -71,7 +71,6 @@ struct ServerArgs { #[arg(short, long, default_value = "3000")] port: Option, - /// Port to use for shard-to-shard communication, if sharded MPC is used #[arg(default_value = "6000")] shard_port: Option, diff --git a/ipa-core/src/cli/clientconf.rs b/ipa-core/src/cli/clientconf.rs index 222b26420..a57fd59d4 100644 --- a/ipa-core/src/cli/clientconf.rs +++ b/ipa-core/src/cli/clientconf.rs @@ -139,12 +139,8 @@ pub fn gen_client_config<'a>( )), ); peer.insert( - String::from("shard_url"), - Value::String(format!( - "{host}:{port}", - host = client_conf.host, - port = client_conf.shard_port - )), + String::from("shard_port"), + Value::Integer(client_conf.shard_port.into()), ); peer.insert(String::from("certificate"), Value::String(certificate)); peer.insert( diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 6b5df3c84..7f087de7d 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -3,6 +3,7 @@ use std::{ fmt::{Debug, Formatter}, iter::zip, path::PathBuf, + str::FromStr, time::Duration, }; @@ -38,8 +39,8 @@ pub enum Error { InvalidNetworkSize(usize), #[error(transparent)] IOError(#[from] std::io::Error), - #[error("Missing shard URLs for peers {0:?}")] - MissingShardUrls(Vec), + #[error("Missing shard ports for peers {0:?}")] + MissingShardPorts(Vec), } /// Configuration describing either 3 peers in a Ring or N shard peers. In a non-sharded case a @@ -200,12 +201,12 @@ struct ShardedNetworkToml { } impl ShardedNetworkToml { - fn missing_shard_urls(&self) -> Vec { + fn missing_shard_ports(&self) -> Vec { self.peers .iter() .enumerate() .filter_map(|(i, peer)| { - if peer.shard_url.is_some() { + if peer.shard_port.is_some() { None } else { Some(i) @@ -216,14 +217,12 @@ impl ShardedNetworkToml { } /// This struct is only used by [`parse_sharded_network_toml`] to generate [`PeerConfig`]. It -/// contains an optional `shard_url`. +/// contains an optional `shard_port`. #[derive(Clone, Debug, Deserialize)] struct ShardedPeerConfigToml { #[serde(flatten)] pub config: PeerConfig, - - #[serde(default, with = "crate::serde::option::uri")] - pub shard_url: Option, + pub shard_port: Option, } impl ShardedPeerConfigToml { @@ -232,15 +231,21 @@ impl ShardedPeerConfigToml { self.config.clone() } - /// Create a new Peer but its url using [`ShardedPeerConfigToml::shard_url`]. + /// Create a new Peer but its url using [`ShardedPeerConfigToml::shard_port`]. fn to_shard_peer(&self) -> PeerConfig { + let url = self.config.url.to_string(); + let new_url = format!( + "{}{}", + &url[..=url.find(':').unwrap()], + self.shard_port.expect("Shard port should be set") + ); let mut shard_peer = self.config.clone(); - shard_peer.url = self.shard_url.clone().expect("Shard URL should be set"); + shard_peer.url = Uri::from_str(&new_url).expect("Problem creating uri with sharded port"); shard_peer } } -/// Parses a [`ShardedNetworkToml`] from a network.toml file. Validates that sharding urls are set +/// Parses a [`ShardedNetworkToml`] from a network.toml file. Validates that sharding ports are set /// if necessary. The number of peers needs to be a multiple of 3. 
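// A note on `to_shard_peer` above: it splices the shard port in by slicing up
// to the first `:` of the rendered URL, which silently assumes every peer URL
// is authority-form with an explicit port. A more defensive variant could go
// through the parsed components instead (sketch; not what the source does):
use std::str::FromStr;
use hyper::Uri;

// Rebuild "host:port" from `Uri::host` rather than byte offsets.
fn with_port(url: &Uri, port: u16) -> Uri {
    let host = url.host().expect("peer url should carry an authority");
    Uri::from_str(&format!("{host}:{port}")).expect("host:port parses as an authority-form uri")
}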
fn parse_sharded_network_toml(input: &str) -> Result { use config::{Config, File, FileFormat}; @@ -255,11 +260,11 @@ fn parse_sharded_network_toml(input: &str) -> Result } // Validate sharding config is set - let any_shard_url_set = parsed.peers.iter().any(|peer| peer.shard_url.is_some()); - if any_shard_url_set || parsed.peers.len() > 3 { - let missing_urls = parsed.missing_shard_urls(); - if !missing_urls.is_empty() { - return Err(Error::MissingShardUrls(missing_urls)); + let any_shard_port_set = parsed.peers.iter().any(|peer| peer.shard_port.is_some()); + if any_shard_port_set || parsed.peers.len() > 3 { + let missing_ports = parsed.missing_shard_ports(); + if !missing_ports.is_empty() { + return Err(Error::MissingShardPorts(missing_ports)); } } @@ -268,7 +273,7 @@ fn parse_sharded_network_toml(input: &str) -> Result /// Reads a the config for a specific, single, sharded server from string. Expects config to be /// toml format. The server in the network is specified via `id`, `shard_index` and -/// `shard_count`. This function expects shard urls to be set for all peers. +/// `shard_count`. This function expects shard ports to be set for all peers. /// /// The first 3 peers corresponds to the leaders Ring. H1 shard 0, H2 shard 0, and H3 shard 0. /// The next 3 correspond to the next ring with `shard_index` equals 1 and so on. @@ -287,9 +292,9 @@ pub fn sharded_server_from_toml_str( shard_count: ShardIndex, ) -> Result<(NetworkConfig, NetworkConfig), Error> { let all_network = parse_sharded_network_toml(input)?; - let missing_urls = all_network.missing_shard_urls(); - if !missing_urls.is_empty() { - return Err(Error::MissingShardUrls(missing_urls)); + let missing_ports = all_network.missing_shard_ports(); + if !missing_ports.is_empty() { + return Err(Error::MissingShardPorts(missing_ports)); } let ix: usize = shard_index.as_index(); @@ -681,6 +686,7 @@ mod tests { helpers::HelperIdentity, net::test::TestConfigBuilder, sharding::ShardIndex, + utils::replace_all, }; const URI_1: &str = "http://localhost:3000"; @@ -724,10 +730,7 @@ mod tests { let mut rng = StdRng::seed_from_u64(1); let (_, public_key) = X25519HkdfSha256::gen_keypair(&mut rng); let config = HpkeClientConfig { public_key }; - assert_eq!( - format!("{config:?}"), - r#"HpkeClientConfig { public_key: "2bd9da78f01d8bc6948bbcbe44ec1e7163d05083e267d110cdb2e75d847e3b6f" }"# - ); + assert_eq!(format!("{config:?}"), "HpkeClientConfig { public_key: \"2bd9da78f01d8bc6948bbcbe44ec1e7163d05083e267d110cdb2e75d847e3b6f\" }"); } #[test] @@ -792,9 +795,9 @@ mod tests { .unwrap(); assert_eq!( vec![ - "helper1.shard1.org:443", - "helper2.shard1.org:443", - "helper3.shard1.org:443" + "helper1.prod.ipa-helper.shard1.dev:443", + "helper2.prod.ipa-helper.shard1.dev:443", + "helper3.prod.ipa-helper.shard1.dev:443" ], mpc.peers .into_iter() @@ -803,9 +806,9 @@ mod tests { ); assert_eq!( vec![ - "helper2.shard0.org:555", - "helper2.shard1.org:555", - "helper2.shard2.org:555" + "helper2.prod.ipa-helper.shard0.dev:555", + "helper2.prod.ipa-helper.shard1.dev:555", + "helper2.prod.ipa-helper.shard2.dev:555" ], shard .peers @@ -815,16 +818,16 @@ mod tests { ); } - /// Tests that the url of a shard gets updated with the shard url. + /// Tests that the url of a shard gets updated with the shard port. 
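// The peer selection in `sharded_server_from_toml_str` above is pure index
// arithmetic over a ring-major list: rings of three MPC peers are
// concatenated, so `skip(ix * 3).take(3)` picks my ring and
// `skip(mpc_id).step_by(3)` walks my helper across every shard. The same
// arithmetic on toy data, for helper 2 on shard 1 (sketch):
let peers = [
    "h1s0", "h2s0", "h3s0",
    "h1s1", "h2s1", "h3s1",
    "h1s2", "h2s2", "h3s2",
];
let (ix, mpc_id) = (1usize, 1usize);

let ring: Vec<_> = peers.iter().skip(ix * 3).take(3).copied().collect();
assert_eq!(ring, ["h1s1", "h2s1", "h3s1"]); // my MPC ring

let mine: Vec<_> = peers.iter().skip(mpc_id).step_by(3).take(3).copied().collect();
assert_eq!(mine, ["h2s0", "h2s1", "h2s2"]); // helper 2 across all shards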
#[test] fn transform_sharded_peers() { let mut n = parse_sharded_network_toml(&SHARDED_OK_REPEAT).unwrap(); assert_eq!( - "helper3.shard2.org:666", + "helper3.prod.ipa-helper.shard2.dev:666", n.peers.pop().unwrap().to_shard_peer().url ); assert_eq!( - "helper2.shard2.org:555", + "helper2.prod.ipa-helper.shard2.dev:555", n.peers.pop().unwrap().to_shard_peer().url ); } @@ -838,27 +841,27 @@ mod tests { )); } - /// If any sharded url is set (indicating this is a sharding config), then ALL urls must be set. + /// If any sharded port is set (indicating this is a sharding config), then ALL ports must be set. #[test] - fn parse_network_toml_shard_urls_some_set() { + fn parse_network_toml_shard_port_some_set() { assert!(matches!( - parse_sharded_network_toml(&SHARDED_COMPAT_ONE_URL), - Err(Error::MissingShardUrls(_)) + parse_sharded_network_toml(&SHARDED_COMPAT_ONE_PORT), + Err(Error::MissingShardPorts(_)) )); } - /// If there are more than 3 peers configured (indicating this is a sharding config), then ALL urls must be set. + /// If there are more than 3 peers configured (indicating this is a sharding config), then ALL ports must be set. #[test] - fn parse_network_toml_shard_urls_set() { + fn parse_network_toml_shard_port_set() { assert!(matches!( - parse_sharded_network_toml(&SHARDED_MISSING_URLS_REPEAT), - Err(Error::MissingShardUrls(_)) + parse_sharded_network_toml(&SHARDED_MISSING_PORTS_REPEAT), + Err(Error::MissingShardPorts(_)) )); } - /// Check that shard urls are given for [`sharded_server_from_toml_str`] or error is returned. + /// Check that shard ports are given for [`sharded_server_from_toml_str`] or error is returned. #[test] - fn parse_sharded_without_shard_urls() { + fn parse_sharded_without_shard_ports() { // Second, I test the networkconfig parsing assert!(matches!( sharded_server_from_toml_str( @@ -867,7 +870,7 @@ mod tests { ShardIndex::FIRST, ShardIndex::from(1) ), - Err(Error::MissingShardUrls(_)) + Err(Error::MissingShardPorts(_)) )); } @@ -882,15 +885,11 @@ mod tests { HttpClientConfigurator::Http2(_) )); assert_eq!(3, entire_network.peers.len()); - assert_eq!("helper3.shard0.org:443", entire_network.peers[2].config.url); assert_eq!( - "helper3.shard0.org:666", - entire_network.peers[2] - .shard_url - .as_ref() - .unwrap() - .to_string() + "helper3.prod.ipa-helper.shard0.dev:443", + entire_network.peers[2].config.url ); + assert_eq!(Some(666), entire_network.peers[2].shard_port); } /// Testing happy case of a longer sharded network config @@ -900,14 +899,7 @@ mod tests { assert!(r_entire_network.is_ok()); let entire_network = r_entire_network.unwrap(); assert_eq!(9, entire_network.peers.len()); - assert_eq!( - "helper3.shard2.org:666", - entire_network.peers[8] - .shard_url - .as_ref() - .unwrap() - .to_string() - ); + assert_eq!(Some(666), entire_network.peers[8].shard_port); } /// This test validates that the new logic that handles sharded configurations can also handle the previous version @@ -921,7 +913,10 @@ mod tests { HttpClientConfigurator::Http2(_) )); assert_eq!(3, entire_network.peers.len()); - assert_eq!("helper3.org:443", entire_network.peers[2].config.url); + assert_eq!( + "helper3.prod.ipa-helper.dev:443", + entire_network.peers[2].config.url + ); } // Following are some large &str const used for tests @@ -930,20 +925,20 @@ mod tests { static NON_SHARDED_COMPAT: Lazy = Lazy::new(|| format!("{CLIENT}{P1}{REST}")); /// Invalid: Same as [`NON_SHARDED_COMPAT`] but with a single `shard_port` set. 
- static SHARDED_COMPAT_ONE_URL: Lazy = - Lazy::new(|| format!("{CLIENT}{P1}\nshard_url = \"helper1.org:777\"\n{REST}")); + static SHARDED_COMPAT_ONE_PORT: Lazy = + Lazy::new(|| format!("{CLIENT}{P1}\nshard_port = 777\n{REST}")); /// Helper const used to create client configs - const CLIENT: &str = r#"[client.http_config] + const CLIENT: &str = "[client.http_config] ping_interval_secs = 90.0 -version = "http2" -"#; +version = \"http2\" +"; /// Helper const that has the first part of a Peer, just before were `shard_port` should be /// specified. - const P1: &str = r#" + const P1: &str = " [[peers]] -certificate = """ +certificate = \"\"\" -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIIMlnveFys5QUwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwM1oXDTI0 @@ -955,16 +950,16 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAKVdDCQeXLRXDYXy4b1N1UxD/JPuD9H7zeRb8/nmIDTfAiBL a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== -----END CERTIFICATE----- -""" -url = "helper1.org:443""#; +\"\"\" +url = \"helper1.prod.ipa-helper.dev:443\""; /// The rest of a configuration - const REST: &str = r#" + const REST: &str = " [peers.hpke] -public_key = "f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756" +public_key = \"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756\" [[peers]] -certificate = """ +certificate = \"\"\" -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIITOtoca16QckwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwOFoXDTI0 @@ -976,14 +971,14 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== -----END CERTIFICATE----- -""" -url = "helper2.org:443" +\"\"\" +url = \"helper2.prod.ipa-helper.dev:443\" [peers.hpke] -public_key = "62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22" +public_key = \"62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22\" [[peers]] -certificate = """ +certificate = \"\"\" -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIIaf7eDCnXh2swCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMxMloXDTI0 @@ -995,17 +990,17 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAOTSQWbN7kfIatNJEwWTBL4xOY88E3+SOnBNExCsTkQuAiBB /cwOQQUEeE4llrDp+EnyGbzmVm5bINz8gePIxkKqog== -----END CERTIFICATE----- -""" -url = "helper3.org:443" +\"\"\" +url = \"helper3.prod.ipa-helper.dev:443\" [peers.hpke] -public_key = "55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61" -"#; +public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\" +"; /// Valid: A sharded configuration - const SHARDED_OK: &str = r#" + const SHARDED_OK: &str = " [[peers]] -certificate = """ +certificate = \"\"\" -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIIMlnveFys5QUwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwM1oXDTI0 @@ -1017,15 +1012,15 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAKVdDCQeXLRXDYXy4b1N1UxD/JPuD9H7zeRb8/nmIDTfAiBL a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== -----END CERTIFICATE----- -""" -url = "helper1.shard0.org:443" -shard_url = "helper1.shard0.org:444" +\"\"\" +url = \"helper1.prod.ipa-helper.shard0.dev:443\" +shard_port = 444 [peers.hpke] -public_key = "f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756" +public_key = 
\"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756\" [[peers]] -certificate = """ +certificate = \"\"\" -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIITOtoca16QckwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwOFoXDTI0 @@ -1037,15 +1032,15 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== -----END CERTIFICATE----- -""" -url = "helper2.shard0.org:443" -shard_url = "helper2.shard0.org:555" +\"\"\" +url = \"helper2.prod.ipa-helper.shard0.dev:443\" +shard_port = 555 [peers.hpke] -public_key = "62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22" +public_key = \"62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22\" [[peers]] -certificate = """ +certificate = \"\"\" -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIIaf7eDCnXh2swCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMxMloXDTI0 @@ -1057,21 +1052,21 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAOTSQWbN7kfIatNJEwWTBL4xOY88E3+SOnBNExCsTkQuAiBB /cwOQQUEeE4llrDp+EnyGbzmVm5bINz8gePIxkKqog== -----END CERTIFICATE----- -""" -url = "helper3.shard0.org:443" -shard_url = "helper3.shard0.org:666" +\"\"\" +url = \"helper3.prod.ipa-helper.shard0.dev:443\" +shard_port = 666 [peers.hpke] -public_key = "55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61" -"#; +public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\" +"; /// Valid: Three sharded configs together for 9 static SHARDED_OK_REPEAT: Lazy = Lazy::new(|| { format!( "{}{}{}", SHARDED_OK, - SHARDED_OK.replace("shard0", "shard1"), - SHARDED_OK.replace("shard0", "shard2") + replace_all(SHARDED_OK, "shard0", "shard1"), + replace_all(SHARDED_OK, "shard0", "shard2") ) }); @@ -1082,11 +1077,11 @@ public_key = "55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61" }); /// Invalid: Same as [`SHARDED_OK_REPEAT`] but without the expected ports - static SHARDED_MISSING_URLS_REPEAT: Lazy = Lazy::new(|| { + static SHARDED_MISSING_PORTS_REPEAT: Lazy = Lazy::new(|| { let lines: Vec<&str> = SHARDED_OK_REPEAT.lines().collect(); let new_lines: Vec = lines .iter() - .filter(|line| !line.starts_with("shard_url =")) + .filter(|line| !line.starts_with("shard_port =")) .map(std::string::ToString::to_string) .collect(); new_lines.join("\n") diff --git a/ipa-core/src/serde.rs b/ipa-core/src/serde.rs index ed65273d7..0acc2d925 100644 --- a/ipa-core/src/serde.rs +++ b/ipa-core/src/serde.rs @@ -13,27 +13,6 @@ pub mod uri { } } -#[cfg(feature = "web-app")] -pub mod option { - pub mod uri { - use hyper::Uri; - use serde::{de::Error, Deserialize, Deserializer}; - - /// # Errors - /// if deserializing from string fails, or if string is not a [`Uri`] - pub fn deserialize<'de, D: Deserializer<'de>>( - deserializer: D, - ) -> Result, D::Error> { - let opt_s: Option = Deserialize::deserialize(deserializer)?; - if let Some(s) = opt_s { - s.parse().map(Some).map_err(D::Error::custom) - } else { - Ok(None) - } - } - } -} - pub mod duration { use std::time::Duration; diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index e8dfd95ae..c19ada348 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -5,3 +5,20 @@ mod power_of_two; #[cfg(target_pointer_width = "64")] pub use power_of_two::NonZeroU32PowerOfTwo; + +/// Replaces all occurrences of `from` with `to` in `s`. 
+#[allow(dead_code)] +pub fn replace_all(s: &str, from: &str, to: &str) -> String { + let mut result = String::new(); + let mut i = 0; + while i < s.len() { + if s[i..].starts_with(from) { + result.push_str(to); + i += from.len(); + } else { + result.push(s.chars().nth(i).unwrap()); + i += 1; + } + } + result +} From 5bae2bf5dc189819058882b8ec4a6f1008cfd750 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Fri, 15 Nov 2024 16:22:11 -0800 Subject: [PATCH 15/47] Revert "Fixing HTTP/gen tests" This reverts commit 00fc91f65eeb16de9ba5a5747e098f22cd3d2d65. --- ipa-core/src/cli/clientconf.rs | 13 +---- ipa-core/src/cli/test_setup.rs | 8 +-- ipa-core/tests/common/mod.rs | 87 +++++++++++++++---------------- ipa-core/tests/helper_networks.rs | 4 +- 4 files changed, 46 insertions(+), 66 deletions(-) diff --git a/ipa-core/src/cli/clientconf.rs b/ipa-core/src/cli/clientconf.rs index a57fd59d4..341a4253a 100644 --- a/ipa-core/src/cli/clientconf.rs +++ b/ipa-core/src/cli/clientconf.rs @@ -17,9 +17,6 @@ pub struct ConfGenArgs { #[arg(short, long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])] ports: Vec, - #[arg(short, long, num_args = 3, value_name = "SHARD_PORTS", default_values = vec!["6000", "6001", "6002"])] - shard_ports: Vec, - #[arg(long, num_args = 3, default_values = vec!["localhost", "localhost", "localhost"])] hosts: Vec, @@ -57,14 +54,13 @@ pub struct ConfGenArgs { /// [`ConfGenArgs`]: ConfGenArgs /// [`Paths`]: crate::cli::paths::PathExt pub fn setup(args: ConfGenArgs) -> Result<(), BoxError> { - let clients_conf: [_; 3] = zip(args.hosts.iter(), zip(args.ports, args.shard_ports)) + let clients_conf: [_; 3] = zip(args.hosts.iter(), args.ports) .enumerate() - .map(|(id, (host, (port, shard_port)))| { + .map(|(id, (host, port))| { let id: u8 = u8::try_from(id).unwrap() + 1; HelperClientConf { host, port, - shard_port, tls_cert_file: args.keys_dir.helper_tls_cert(id), mk_public_key_file: args.keys_dir.helper_mk_public_key(id), } @@ -100,7 +96,6 @@ pub fn setup(args: ConfGenArgs) -> Result<(), BoxError> { pub struct HelperClientConf<'a> { pub(crate) host: &'a str, pub(crate) port: u16, - pub(crate) shard_port: u16, pub(crate) tls_cert_file: PathBuf, pub(crate) mk_public_key_file: PathBuf, } @@ -138,10 +133,6 @@ pub fn gen_client_config<'a>( port = client_conf.port )), ); - peer.insert( - String::from("shard_port"), - Value::Integer(client_conf.shard_port.into()), - ); peer.insert(String::from("certificate"), Value::String(certificate)); peer.insert( String::from("hpke"), diff --git a/ipa-core/src/cli/test_setup.rs b/ipa-core/src/cli/test_setup.rs index a3aa93cc4..538faf180 100644 --- a/ipa-core/src/cli/test_setup.rs +++ b/ipa-core/src/cli/test_setup.rs @@ -36,9 +36,6 @@ pub struct TestSetupArgs { #[arg(short, long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])] ports: Vec, - - #[arg(short, long, num_args = 3, value_name = "SHARD_PORT", default_values = vec!["6000", "6001", "6002"])] - shard_ports: Vec, } /// Prepare a test network of three helpers. 
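// The hunks below drop the dedicated `--shard-ports` flag and fall back to
// pairing: six prebound sockets laid out as adjacent (MPC, shard) pairs, with
// `chunks(2)` recovering the split and only the first of each pair advertised
// as a port. Toy check of that indexing (sketch; the adjacent-pair layout is
// what `spawn_helpers` assumes via `socket[0]`/`socket[1]`):
let ports = [3000_u16, 6000, 3001, 6001, 3002, 6002]; // (mpc, shard) per helper
let mpc: Vec<String> = ports.chunks(2).map(|p| p[0].to_string()).collect();
assert_eq!(mpc, ["3000", "3001", "3002"]);
let shard: Vec<String> = ports.chunks(2).map(|p| p[1].to_string()).collect();
assert_eq!(shard, ["6000", "6001", "6002"]);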
@@ -59,8 +56,8 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { let localhost = String::from("localhost"); - let clients_config: [_; 3] = zip([1, 2, 3], zip(args.ports, args.shard_ports)) - .map(|(id, (port, shard_port))| { + let clients_config: [_; 3] = zip([1, 2, 3], args.ports) + .map(|(id, port)| { let keygen_args = KeygenArgs { name: localhost.clone(), tls_cert: args.output_dir.helper_tls_cert(id), @@ -75,7 +72,6 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { Ok(HelperClientConf { host: &localhost, port, - shard_port, tls_cert_file: keygen_args.tls_cert, mk_public_key_file: keygen_args.mk_public_key, }) diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index dae743987..be582537c 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -121,9 +121,7 @@ fn test_setup(config_path: &Path) -> [TcpListener; 6] { .arg("test-setup") .args(["--output-dir".as_ref(), config_path.as_os_str()]) .arg("--ports") - .args(ports.iter().take(3).map(|p| p.to_string())) - .arg("--shard-ports") - .args(ports.iter().skip(3).take(3).map(|p| p.to_string())); + .args(ports.chunks(2).map(|p| p[0].to_string())); command.status().unwrap_status(); sockets @@ -134,50 +132,47 @@ pub fn spawn_helpers( sockets: &[TcpListener; 6], https: bool, ) -> Vec { - zip( - [1, 2, 3], - zip(sockets.iter().take(3), sockets.iter().skip(3).take(3)), - ) - .map(|(id, (socket, shard_socket))| { - let mut command = Command::new(HELPER_BIN); - command - .args(["-i", &id.to_string()]) - .args(["--network".into(), config_path.join("network.toml")]) - .silent(); - - if https { + zip([1, 2, 3], sockets.chunks(2)) + .map(|(id, socket)| { + let mut command = Command::new(HELPER_BIN); command - .args(["--tls-cert".into(), config_path.join(format!("h{id}.pem"))]) - .args(["--tls-key".into(), config_path.join(format!("h{id}.key"))]) - .args([ - "--mk-public-key".into(), - config_path.join(format!("h{id}_mk.pub")), - ]) - .args([ - "--mk-private-key".into(), - config_path.join(format!("h{id}_mk.key")), - ]); - } else { - command.arg("--disable-https"); - } - - command.preserved_fds(vec![socket.as_raw_fd()]); - command.args(["--server-socket-fd", &socket.as_raw_fd().to_string()]); - command.preserved_fds(vec![shard_socket.as_raw_fd()]); - command.args([ - "--shard-server-socket-fd", - &shard_socket.as_raw_fd().to_string(), - ]); - - // something went wrong if command is terminated at this point. - let mut child = command.spawn().unwrap(); - if let Ok(Some(status)) = child.try_wait() { - panic!("Helper binary terminated early with status = {status}"); - } - - child.terminate_on_drop() - }) - .collect::>() + .args(["-i", &id.to_string()]) + .args(["--network".into(), config_path.join("network.toml")]) + .silent(); + + if https { + command + .args(["--tls-cert".into(), config_path.join(format!("h{id}.pem"))]) + .args(["--tls-key".into(), config_path.join(format!("h{id}.key"))]) + .args([ + "--mk-public-key".into(), + config_path.join(format!("h{id}_mk.pub")), + ]) + .args([ + "--mk-private-key".into(), + config_path.join(format!("h{id}_mk.key")), + ]); + } else { + command.arg("--disable-https"); + } + + command.preserved_fds(vec![socket[0].as_raw_fd()]); + command.args(["--server-socket-fd", &socket[0].as_raw_fd().to_string()]); + command.preserved_fds(vec![socket[1].as_raw_fd()]); + command.args([ + "--shard-server-socket-fd", + &socket[1].as_raw_fd().to_string(), + ]); + + // something went wrong if command is terminated at this point. 
+ let mut child = command.spawn().unwrap(); + if let Ok(Some(status)) = child.try_wait() { + panic!("Helper binary terminated early with status = {status}"); + } + + child.terminate_on_drop() + }) + .collect::>() } pub fn test_multiply(config_dir: &Path, https: bool) { diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index 06adb56a7..4eb59a38e 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -85,9 +85,7 @@ fn keygen_confgen() { .args(["--output-dir".as_ref(), path.as_os_str()]) .args(["--keys-dir".as_ref(), path.as_os_str()]) .arg("--ports") - .args(ports.iter().take(3).map(|p| p.to_string())) - .arg("--shard-ports") - .args(ports.iter().skip(3).take(3).map(|p| p.to_string())) + .args(ports.chunks(2).map(|p| p[0].to_string())) .arg("--hosts") .args(["localhost", "localhost", "localhost"]); if overwrite { From 3e533510f704c8ba47881e60151653d196fca9cb Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Fri, 15 Nov 2024 16:22:20 -0800 Subject: [PATCH 16/47] Revert "Network.toml requires shard_port" This reverts commit 205df354f753420bdca24df4855037ddfc4e4da3. --- ipa-core/src/config.rs | 54 ++++++++++---------------------------- ipa-core/src/net/config.rs | 25 ++++++++++++++++++ 2 files changed, 39 insertions(+), 40 deletions(-) create mode 100644 ipa-core/src/net/config.rs diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 7f087de7d..560f91039 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -200,22 +200,6 @@ struct ShardedNetworkToml { pub client: ClientConfig, } -impl ShardedNetworkToml { - fn missing_shard_ports(&self) -> Vec { - self.peers - .iter() - .enumerate() - .filter_map(|(i, peer)| { - if peer.shard_port.is_some() { - None - } else { - Some(i) - } - }) - .collect() - } -} - /// This struct is only used by [`parse_sharded_network_toml`] to generate [`PeerConfig`]. It /// contains an optional `shard_port`. #[derive(Clone, Debug, Deserialize)] @@ -262,7 +246,18 @@ fn parse_sharded_network_toml(input: &str) -> Result // Validate sharding config is set let any_shard_port_set = parsed.peers.iter().any(|peer| peer.shard_port.is_some()); if any_shard_port_set || parsed.peers.len() > 3 { - let missing_ports = parsed.missing_shard_ports(); + let missing_ports: Vec = parsed + .peers + .iter() + .enumerate() + .filter_map(|(i, peer)| { + if peer.shard_port.is_some() { + None + } else { + Some(i) + } + }) + .collect(); if !missing_ports.is_empty() { return Err(Error::MissingShardPorts(missing_ports)); } @@ -273,10 +268,8 @@ fn parse_sharded_network_toml(input: &str) -> Result /// Reads a the config for a specific, single, sharded server from string. Expects config to be /// toml format. The server in the network is specified via `id`, `shard_index` and -/// `shard_count`. This function expects shard ports to be set for all peers. -/// -/// The first 3 peers corresponds to the leaders Ring. H1 shard 0, H2 shard 0, and H3 shard 0. -/// The next 3 correspond to the next ring with `shard_index` equals 1 and so on. +/// `shard_count`. +/// The first 3 entries corresponds to the leaders Ring. H1 shard 0, H2, shard 0, and H3 shard 0. /// /// Other methods to read the network.toml exist depending on the use, for example /// [`NetworkConfig::from_toml_str`] reads a non-sharded config. 
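// The validation inlined above gathers the indices of peers that lack a
// `shard_port`, so the resulting error can name exactly which entries are
// incomplete. The same idiom on toy data (sketch):
let shard_ports = [Some(6000_u16), None, Some(6002), None];
let missing: Vec<usize> = shard_ports
    .iter()
    .enumerate()
    .filter_map(|(i, port)| port.is_none().then_some(i))
    .collect();
assert_eq!(missing, [1, 3]); // these peer indices end up in the error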
@@ -292,10 +285,6 @@ pub fn sharded_server_from_toml_str( shard_count: ShardIndex, ) -> Result<(NetworkConfig, NetworkConfig), Error> { let all_network = parse_sharded_network_toml(input)?; - let missing_ports = all_network.missing_shard_ports(); - if !missing_ports.is_empty() { - return Err(Error::MissingShardPorts(missing_ports)); - } let ix: usize = shard_index.as_index(); let ix_count: usize = shard_count.as_index(); @@ -859,21 +848,6 @@ mod tests { )); } - /// Check that shard ports are given for [`sharded_server_from_toml_str`] or error is returned. - #[test] - fn parse_sharded_without_shard_ports() { - // Second, I test the networkconfig parsing - assert!(matches!( - sharded_server_from_toml_str( - &NON_SHARDED_COMPAT, - HelperIdentity::TWO, - ShardIndex::FIRST, - ShardIndex::from(1) - ), - Err(Error::MissingShardPorts(_)) - )); - } - /// Testing happy case of a sharded network config #[test] fn happy_parse_sharded_network_toml() { diff --git a/ipa-core/src/net/config.rs b/ipa-core/src/net/config.rs new file mode 100644 index 000000000..2c5d3e5ed --- /dev/null +++ b/ipa-core/src/net/config.rs @@ -0,0 +1,25 @@ +use std::{ + fmt::Debug, + io::{self, BufRead}, + sync::Arc, +}; + +use config::{Config, File, FileFormat}; +use hyper::{header::HeaderName, Uri}; +use once_cell::sync::Lazy; +use rustls::crypto::CryptoProvider; +use rustls_pki_types::CertificateDer; +use ::serde::Deserialize; + +use crate::{ + config::{ClientConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig}, helpers::{HelperIdentity, TransportIdentity}, serde, sharding::ShardIndex +}; + + + +#[cfg(all(test, unit_test))] +mod tests { + + + +} From 4ac03a4064ac4ca826d7ec64cd693fb0d5035692 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Fri, 15 Nov 2024 16:22:30 -0800 Subject: [PATCH 17/47] Revert "Using 3 peer configs for a ring instead of 6" This reverts commit ccdd1f0d8d5f704616df42d04c0d5fea92e03a69. --- ipa-core/src/bin/helper.rs | 8 +- ipa-core/src/config.rs | 513 ++++++------------------------------- ipa-core/src/net/config.rs | 25 -- ipa-core/src/utils/mod.rs | 17 -- 4 files changed, 88 insertions(+), 475 deletions(-) delete mode 100644 ipa-core/src/net/config.rs diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 8eb006e94..9b558c862 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -159,10 +159,10 @@ fn create_client_identity( } } -/// Creates a [`TcpListener`] from an optional raw file descriptor. Safety notes: -/// 1. The `--server-socket-fd` option is only intended for use in tests, not in production. -/// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has -/// only one owner. +// SAFETY: +// 1. The `--server-socket-fd` option is only intended for use in tests, not in production. +// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has +// only one owner. fn create_listener(server_socket_fd: Option) -> Result, BoxError> { server_socket_fd .map(|fd| { diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 560f91039..408213b3e 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -3,7 +3,6 @@ use std::{ fmt::{Debug, Formatter}, iter::zip, path::PathBuf, - str::FromStr, time::Duration, }; @@ -39,8 +38,6 @@ pub enum Error { InvalidNetworkSize(usize), #[error(transparent)] IOError(#[from] std::io::Error), - #[error("Missing shard ports for peers {0:?}")] - MissingShardPorts(Vec), } /// Configuration describing either 3 peers in a Ring or N shard peers. 
In a non-sharded case a @@ -122,6 +119,88 @@ impl NetworkConfig { } } +/// Reads a the config for a specific, single, sharded server from string. Expects config to be +/// toml format. The server in the network is specified via `id`, `shard_index` and +/// `shard_count`. +/// +/// First we read the configuration without assigning any identities. The number of peers in the +/// configuration must be a multiple of 6, or 3 as a special case to support older, non-sharded +/// configurations. +/// +/// If there are 3 entries, we assign helper identities for them. We create a dummy sharded +/// configuration. +/// +/// If there are any multiple of 6 peers, then peer assignment is as follows: +/// By rings (to be reminiscent of the previous config). The first 6 entries corresponds to the +/// leaders Ring. H1 shard 0, H2, shard 0, and H3 shard 0. The next 6 correspond increases the +/// shard index by one. +/// +/// Other methods to read the network.toml exist depending on the use, for example +/// [`NetworkConfig::from_toml_str`] reads a non-sharded config. +/// TODO: There will be one to read the information relevant for the RC (doesn't need shard +/// info) +/// +/// # Errors +/// if `input` is in an invalid format +pub fn sharded_server_from_toml_str( + input: &str, + id: HelperIdentity, + shard_index: ShardIndex, + shard_count: ShardIndex, +) -> Result<(NetworkConfig, NetworkConfig), Error> { + use config::{Config, File, FileFormat}; + + let all_network: NetworkConfig = Config::builder() + .add_source(File::from_str(input, FileFormat::Toml)) + .build()? + .try_deserialize()?; + + let ix: usize = shard_index.as_index(); + let ix_count: usize = shard_count.as_index(); + let mpc_id: usize = id.as_index(); + + let total_peers = all_network.peers.len(); + if total_peers == 3 { + let mpc_network = NetworkConfig { + peers: all_network.peers.clone(), + client: all_network.client.clone(), + identities: HelperIdentity::make_three().to_vec(), + }; + let shard_network = NetworkConfig { + peers: vec![all_network.peers[mpc_id].clone()], + client: all_network.client, + identities: vec![ShardIndex(0)], + }; + Ok((mpc_network, shard_network)) + } else if total_peers > 0 && total_peers % 6 == 0 { + let mpc_network = NetworkConfig { + peers: all_network + .peers + .clone() + .into_iter() + .skip(ix * 6) + .take(3) + .collect(), + client: all_network.client.clone(), + identities: HelperIdentity::make_three().to_vec(), + }; + let shard_network = NetworkConfig { + peers: all_network + .peers + .into_iter() + .skip(3 + mpc_id) + .step_by(6) + .take(ix_count) + .collect(), + client: all_network.client, + identities: shard_count.iter().collect(), + }; + Ok((mpc_network, shard_network)) + } else { + Err(Error::InvalidNetworkSize(total_peers)) + } +} + impl NetworkConfig { /// # Panics /// In the unexpected case there are more than max usize shards. @@ -189,135 +268,6 @@ impl NetworkConfig { } } -/// This struct is only used by [`parse_sharded_network_toml`] to parse the entire network. -/// Unlike [`NetworkConfig`], this one doesn't have identities. -#[derive(Clone, Debug, Deserialize)] -struct ShardedNetworkToml { - pub peers: Vec, - - /// HTTP client configuration. - #[serde(default)] - pub client: ClientConfig, -} - -/// This struct is only used by [`parse_sharded_network_toml`] to generate [`PeerConfig`]. It -/// contains an optional `shard_port`. 
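// The selection logic restored above works over rings of six entries:
// positions 0-2 hold a ring's MPC endpoints and 3-5 its shard endpoints, so
// `skip(ix * 6).take(3)` picks my ring and `skip(3 + mpc_id).step_by(6)` walks
// my helper's shard endpoints. On toy data, for helper 2 on shard 1 of 2
// (sketch):
let peers = [
    "h1s0", "h2s0", "h3s0", "h1s0-sh", "h2s0-sh", "h3s0-sh",
    "h1s1", "h2s1", "h3s1", "h1s1-sh", "h2s1-sh", "h3s1-sh",
];
let (ix, mpc_id, ix_count) = (1usize, 1usize, 2usize);

let ring: Vec<_> = peers.iter().skip(ix * 6).take(3).copied().collect();
assert_eq!(ring, ["h1s1", "h2s1", "h3s1"]);

let shards: Vec<_> = peers.iter().skip(3 + mpc_id).step_by(6).take(ix_count).copied().collect();
assert_eq!(shards, ["h2s0-sh", "h2s1-sh"]);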
-#[derive(Clone, Debug, Deserialize)] -struct ShardedPeerConfigToml { - #[serde(flatten)] - pub config: PeerConfig, - pub shard_port: Option, -} - -impl ShardedPeerConfigToml { - /// Clones the inner Peer. - fn to_mpc_peer(&self) -> PeerConfig { - self.config.clone() - } - - /// Create a new Peer but its url using [`ShardedPeerConfigToml::shard_port`]. - fn to_shard_peer(&self) -> PeerConfig { - let url = self.config.url.to_string(); - let new_url = format!( - "{}{}", - &url[..=url.find(':').unwrap()], - self.shard_port.expect("Shard port should be set") - ); - let mut shard_peer = self.config.clone(); - shard_peer.url = Uri::from_str(&new_url).expect("Problem creating uri with sharded port"); - shard_peer - } -} - -/// Parses a [`ShardedNetworkToml`] from a network.toml file. Validates that sharding ports are set -/// if necessary. The number of peers needs to be a multiple of 3. -fn parse_sharded_network_toml(input: &str) -> Result { - use config::{Config, File, FileFormat}; - - let parsed: ShardedNetworkToml = Config::builder() - .add_source(File::from_str(input, FileFormat::Toml)) - .build()? - .try_deserialize()?; - - if parsed.peers.len() % 3 != 0 { - return Err(Error::InvalidNetworkSize(parsed.peers.len())); - } - - // Validate sharding config is set - let any_shard_port_set = parsed.peers.iter().any(|peer| peer.shard_port.is_some()); - if any_shard_port_set || parsed.peers.len() > 3 { - let missing_ports: Vec = parsed - .peers - .iter() - .enumerate() - .filter_map(|(i, peer)| { - if peer.shard_port.is_some() { - None - } else { - Some(i) - } - }) - .collect(); - if !missing_ports.is_empty() { - return Err(Error::MissingShardPorts(missing_ports)); - } - } - - Ok(parsed) -} - -/// Reads a the config for a specific, single, sharded server from string. Expects config to be -/// toml format. The server in the network is specified via `id`, `shard_index` and -/// `shard_count`. -/// The first 3 entries corresponds to the leaders Ring. H1 shard 0, H2, shard 0, and H3 shard 0. -/// -/// Other methods to read the network.toml exist depending on the use, for example -/// [`NetworkConfig::from_toml_str`] reads a non-sharded config. 
-/// TODO: There will be one to read the information relevant for the RC (doesn't need shard -/// info) -/// -/// # Errors -/// if `input` is in an invalid format -pub fn sharded_server_from_toml_str( - input: &str, - id: HelperIdentity, - shard_index: ShardIndex, - shard_count: ShardIndex, -) -> Result<(NetworkConfig, NetworkConfig), Error> { - let all_network = parse_sharded_network_toml(input)?; - - let ix: usize = shard_index.as_index(); - let ix_count: usize = shard_count.as_index(); - let mpc_id: usize = id.as_index(); - - let mpc_network = NetworkConfig { - peers: all_network - .peers - .iter() - .map(ShardedPeerConfigToml::to_mpc_peer) - .skip(ix * 3) - .take(3) - .collect(), - client: all_network.client.clone(), - identities: HelperIdentity::make_three().to_vec(), - }; - - let shard_network = NetworkConfig { - peers: all_network - .peers - .iter() - .map(ShardedPeerConfigToml::to_shard_peer) - .skip(mpc_id) - .step_by(3) - .take(ix_count) - .collect(), - client: all_network.client, - identities: shard_count.iter().collect(), - }; - - Ok((mpc_network, shard_network)) -} - #[derive(Clone, Debug, Deserialize)] pub struct PeerConfig { /// Peer URL @@ -661,21 +611,15 @@ mod tests { use hpke::{kem::X25519HkdfSha256, Kem}; use hyper::Uri; - use once_cell::sync::Lazy; use rand::rngs::StdRng; use rand_core::SeedableRng; - use super::{ - parse_sharded_network_toml, sharded_server_from_toml_str, NetworkConfig, PeerConfig, - }; + use super::{NetworkConfig, PeerConfig}; use crate::{ - config::{ - ClientConfig, Error, HpkeClientConfig, Http2Configurator, HttpClientConfigurator, - }, + config::{ClientConfig, HpkeClientConfig, Http2Configurator, HttpClientConfigurator}, helpers::HelperIdentity, net::test::TestConfigBuilder, sharding::ShardIndex, - utils::replace_all, }; const URI_1: &str = "http://localhost:3000"; @@ -771,293 +715,4 @@ mod tests { let conf = NetworkConfig::new_shards(vec![pc1.clone()], client); assert_eq!(conf.peers[ShardIndex(0)].url, pc1.url); } - - #[test] - fn parse_sharded_server_happy() { - // Asuming position of the second helper in the second shard (the middle server in the 3 x 3) - let (mpc, shard) = sharded_server_from_toml_str( - &SHARDED_OK_REPEAT, - HelperIdentity::TWO, - ShardIndex::from(1), - ShardIndex::from(3), - ) - .unwrap(); - assert_eq!( - vec![ - "helper1.prod.ipa-helper.shard1.dev:443", - "helper2.prod.ipa-helper.shard1.dev:443", - "helper3.prod.ipa-helper.shard1.dev:443" - ], - mpc.peers - .into_iter() - .map(|p| p.url.to_string()) - .collect::>() - ); - assert_eq!( - vec![ - "helper2.prod.ipa-helper.shard0.dev:555", - "helper2.prod.ipa-helper.shard1.dev:555", - "helper2.prod.ipa-helper.shard2.dev:555" - ], - shard - .peers - .into_iter() - .map(|p| p.url.to_string()) - .collect::>() - ); - } - - /// Tests that the url of a shard gets updated with the shard port. - #[test] - fn transform_sharded_peers() { - let mut n = parse_sharded_network_toml(&SHARDED_OK_REPEAT).unwrap(); - assert_eq!( - "helper3.prod.ipa-helper.shard2.dev:666", - n.peers.pop().unwrap().to_shard_peer().url - ); - assert_eq!( - "helper2.prod.ipa-helper.shard2.dev:555", - n.peers.pop().unwrap().to_shard_peer().url - ); - } - - /// Expects an error if the number of peers isn't a multiple of 3 - #[test] - fn invalid_nr_of_peers() { - assert!(matches!( - parse_sharded_network_toml(&SHARDED_8), - Err(Error::InvalidNetworkSize(_)) - )); - } - - /// If any sharded port is set (indicating this is a sharding config), then ALL ports must be set. 
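// The rule these hunks delete, restated once as a standalone predicate for
// reference: a config counts as sharded when any peer sets `shard_port` or
// more than three peers are listed, and a sharded config must then set a port
// on every peer (sketch):
fn shard_ports_consistent(ports: &[Option<u16>]) -> bool {
    let is_sharded = ports.iter().any(Option::is_some) || ports.len() > 3;
    !is_sharded || ports.iter().all(Option::is_some)
}

assert!(shard_ports_consistent(&[None, None, None]));          // legacy, non-sharded
assert!(shard_ports_consistent(&[Some(1), Some(2), Some(3)])); // fully sharded
assert!(!shard_ports_consistent(&[Some(1), None, None]));      // the rejected mix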
- #[test] - fn parse_network_toml_shard_port_some_set() { - assert!(matches!( - parse_sharded_network_toml(&SHARDED_COMPAT_ONE_PORT), - Err(Error::MissingShardPorts(_)) - )); - } - - /// If there are more than 3 peers configured (indicating this is a sharding config), then ALL ports must be set. - #[test] - fn parse_network_toml_shard_port_set() { - assert!(matches!( - parse_sharded_network_toml(&SHARDED_MISSING_PORTS_REPEAT), - Err(Error::MissingShardPorts(_)) - )); - } - - /// Testing happy case of a sharded network config - #[test] - fn happy_parse_sharded_network_toml() { - let r_entire_network = parse_sharded_network_toml(SHARDED_OK); - assert!(r_entire_network.is_ok()); - let entire_network = r_entire_network.unwrap(); - assert!(matches!( - entire_network.client.http_config, - HttpClientConfigurator::Http2(_) - )); - assert_eq!(3, entire_network.peers.len()); - assert_eq!( - "helper3.prod.ipa-helper.shard0.dev:443", - entire_network.peers[2].config.url - ); - assert_eq!(Some(666), entire_network.peers[2].shard_port); - } - - /// Testing happy case of a longer sharded network config - #[test] - fn happy_parse_larger_sharded_network_toml() { - let r_entire_network = parse_sharded_network_toml(&SHARDED_OK_REPEAT); - assert!(r_entire_network.is_ok()); - let entire_network = r_entire_network.unwrap(); - assert_eq!(9, entire_network.peers.len()); - assert_eq!(Some(666), entire_network.peers[8].shard_port); - } - - /// This test validates that the new logic that handles sharded configurations can also handle the previous version - #[test] - fn parse_non_sharded_network_toml() { - let r_entire_network = parse_sharded_network_toml(&NON_SHARDED_COMPAT); - assert!(r_entire_network.is_ok()); - let entire_network = r_entire_network.unwrap(); - assert!(matches!( - entire_network.client.http_config, - HttpClientConfigurator::Http2(_) - )); - assert_eq!(3, entire_network.peers.len()); - assert_eq!( - "helper3.prod.ipa-helper.dev:443", - entire_network.peers[2].config.url - ); - } - - // Following are some large &str const used for tests - - /// Valid: A non-sharded network toml, just how they used to be - static NON_SHARDED_COMPAT: Lazy = Lazy::new(|| format!("{CLIENT}{P1}{REST}")); - - /// Invalid: Same as [`NON_SHARDED_COMPAT`] but with a single `shard_port` set. - static SHARDED_COMPAT_ONE_PORT: Lazy = - Lazy::new(|| format!("{CLIENT}{P1}\nshard_port = 777\n{REST}")); - - /// Helper const used to create client configs - const CLIENT: &str = "[client.http_config] -ping_interval_secs = 90.0 -version = \"http2\" -"; - - /// Helper const that has the first part of a Peer, just before were `shard_port` should be - /// specified. 
- const P1: &str = " -[[peers]] -certificate = \"\"\" ------BEGIN CERTIFICATE----- -MIIBmzCCAUGgAwIBAgIIMlnveFys5QUwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb -aGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwM1oXDTI0 -MTIwNDAzMzMwM1owJjEkMCIGA1UEAwwbaGVscGVyMS5wcm9kLmlwYS1oZWxwZXIu -ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEWmrrkaKM7HQ0Y3ZGJtHB7vfG -cT/hDCXCoob4pJ/fpPDMrqhiwTTck3bNOuzv9QIx+p5C2Qp8u67rYfK78w86NaNZ -MFcwJgYDVR0RBB8wHYIbaGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud -DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI -zj0EAwIDSAAwRQIhAKVdDCQeXLRXDYXy4b1N1UxD/JPuD9H7zeRb8/nmIDTfAiBL -a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== ------END CERTIFICATE----- -\"\"\" -url = \"helper1.prod.ipa-helper.dev:443\""; - - /// The rest of a configuration - const REST: &str = " -[peers.hpke] -public_key = \"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756\" - -[[peers]] -certificate = \"\"\" ------BEGIN CERTIFICATE----- -MIIBmzCCAUGgAwIBAgIITOtoca16QckwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb -aGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwOFoXDTI0 -MTIwNDAzMzMwOFowJjEkMCIGA1UEAwwbaGVscGVyMi5wcm9kLmlwYS1oZWxwZXIu -ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAETxOH4ATz6kBxLuRznKDFRugm -XKmH7mzRB9wn5vaVlVpDzf4nDHJ+TTzSS6Lb3YLsA7jrXDx+W7xPLGow1+9FNqNZ -MFcwJgYDVR0RBB8wHYIbaGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud -DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI -zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS -XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== ------END CERTIFICATE----- -\"\"\" -url = \"helper2.prod.ipa-helper.dev:443\" - -[peers.hpke] -public_key = \"62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22\" - -[[peers]] -certificate = \"\"\" ------BEGIN CERTIFICATE----- -MIIBmzCCAUGgAwIBAgIIaf7eDCnXh2swCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb -aGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMxMloXDTI0 -MTIwNDAzMzMxMlowJjEkMCIGA1UEAwwbaGVscGVyMy5wcm9kLmlwYS1oZWxwZXIu -ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIMqxCCtu4joFr8YtOrEtq230 -NuTtUAaJHIHNtv4CvpUcbtlFMWFYUUum7d22A8YTfUeccG5PsjjCoQG/dhhSbKNZ -MFcwJgYDVR0RBB8wHYIbaGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud -DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI -zj0EAwIDSAAwRQIhAOTSQWbN7kfIatNJEwWTBL4xOY88E3+SOnBNExCsTkQuAiBB -/cwOQQUEeE4llrDp+EnyGbzmVm5bINz8gePIxkKqog== ------END CERTIFICATE----- -\"\"\" -url = \"helper3.prod.ipa-helper.dev:443\" - -[peers.hpke] -public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\" -"; - - /// Valid: A sharded configuration - const SHARDED_OK: &str = " -[[peers]] -certificate = \"\"\" ------BEGIN CERTIFICATE----- -MIIBmzCCAUGgAwIBAgIIMlnveFys5QUwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb -aGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwM1oXDTI0 -MTIwNDAzMzMwM1owJjEkMCIGA1UEAwwbaGVscGVyMS5wcm9kLmlwYS1oZWxwZXIu -ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEWmrrkaKM7HQ0Y3ZGJtHB7vfG -cT/hDCXCoob4pJ/fpPDMrqhiwTTck3bNOuzv9QIx+p5C2Qp8u67rYfK78w86NaNZ -MFcwJgYDVR0RBB8wHYIbaGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud -DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI -zj0EAwIDSAAwRQIhAKVdDCQeXLRXDYXy4b1N1UxD/JPuD9H7zeRb8/nmIDTfAiBL -a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== ------END CERTIFICATE----- -\"\"\" -url = \"helper1.prod.ipa-helper.shard0.dev:443\" -shard_port = 444 - -[peers.hpke] -public_key = \"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756\" - -[[peers]] -certificate = \"\"\" ------BEGIN CERTIFICATE----- 
-MIIBmzCCAUGgAwIBAgIITOtoca16QckwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb -aGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwOFoXDTI0 -MTIwNDAzMzMwOFowJjEkMCIGA1UEAwwbaGVscGVyMi5wcm9kLmlwYS1oZWxwZXIu -ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAETxOH4ATz6kBxLuRznKDFRugm -XKmH7mzRB9wn5vaVlVpDzf4nDHJ+TTzSS6Lb3YLsA7jrXDx+W7xPLGow1+9FNqNZ -MFcwJgYDVR0RBB8wHYIbaGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud -DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI -zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS -XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== ------END CERTIFICATE----- -\"\"\" -url = \"helper2.prod.ipa-helper.shard0.dev:443\" -shard_port = 555 - -[peers.hpke] -public_key = \"62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22\" - -[[peers]] -certificate = \"\"\" ------BEGIN CERTIFICATE----- -MIIBmzCCAUGgAwIBAgIIaf7eDCnXh2swCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb -aGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMxMloXDTI0 -MTIwNDAzMzMxMlowJjEkMCIGA1UEAwwbaGVscGVyMy5wcm9kLmlwYS1oZWxwZXIu -ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIMqxCCtu4joFr8YtOrEtq230 -NuTtUAaJHIHNtv4CvpUcbtlFMWFYUUum7d22A8YTfUeccG5PsjjCoQG/dhhSbKNZ -MFcwJgYDVR0RBB8wHYIbaGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud -DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI -zj0EAwIDSAAwRQIhAOTSQWbN7kfIatNJEwWTBL4xOY88E3+SOnBNExCsTkQuAiBB -/cwOQQUEeE4llrDp+EnyGbzmVm5bINz8gePIxkKqog== ------END CERTIFICATE----- -\"\"\" -url = \"helper3.prod.ipa-helper.shard0.dev:443\" -shard_port = 666 - -[peers.hpke] -public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\" -"; - - /// Valid: Three sharded configs together for 9 - static SHARDED_OK_REPEAT: Lazy = Lazy::new(|| { - format!( - "{}{}{}", - SHARDED_OK, - replace_all(SHARDED_OK, "shard0", "shard1"), - replace_all(SHARDED_OK, "shard0", "shard2") - ) - }); - - /// Invalid: A network toml with 8 entries - static SHARDED_8: Lazy = Lazy::new(|| { - let last_peers_index = SHARDED_OK_REPEAT.rfind("[[peers]]").unwrap(); - SHARDED_OK_REPEAT[..last_peers_index].to_string() - }); - - /// Invalid: Same as [`SHARDED_OK_REPEAT`] but without the expected ports - static SHARDED_MISSING_PORTS_REPEAT: Lazy = Lazy::new(|| { - let lines: Vec<&str> = SHARDED_OK_REPEAT.lines().collect(); - let new_lines: Vec = lines - .iter() - .filter(|line| !line.starts_with("shard_port =")) - .map(std::string::ToString::to_string) - .collect(); - new_lines.join("\n") - }); } diff --git a/ipa-core/src/net/config.rs b/ipa-core/src/net/config.rs deleted file mode 100644 index 2c5d3e5ed..000000000 --- a/ipa-core/src/net/config.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::{ - fmt::Debug, - io::{self, BufRead}, - sync::Arc, -}; - -use config::{Config, File, FileFormat}; -use hyper::{header::HeaderName, Uri}; -use once_cell::sync::Lazy; -use rustls::crypto::CryptoProvider; -use rustls_pki_types::CertificateDer; -use ::serde::Deserialize; - -use crate::{ - config::{ClientConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig}, helpers::{HelperIdentity, TransportIdentity}, serde, sharding::ShardIndex -}; - - - -#[cfg(all(test, unit_test))] -mod tests { - - - -} diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index c19ada348..e8dfd95ae 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -5,20 +5,3 @@ mod power_of_two; #[cfg(target_pointer_width = "64")] pub use power_of_two::NonZeroU32PowerOfTwo; - -/// Replaces all occurrences of `from` with `to` in `s`. 
-#[allow(dead_code)] -pub fn replace_all(s: &str, from: &str, to: &str) -> String { - let mut result = String::new(); - let mut i = 0; - while i < s.len() { - if s[i..].starts_with(from) { - result.push_str(to); - i += from.len(); - } else { - result.push(s.chars().nth(i).unwrap()); - i += 1; - } - } - result -} From 5904d98f0ffc6826895263af0dd29c8445cae487 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Fri, 15 Nov 2024 16:22:39 -0800 Subject: [PATCH 18/47] Revert "Starting sharded helpers" This reverts commit a3d4097add03ec88b95af105e1169f1357d30865. --- ipa-core/src/bin/helper.rs | 155 +++++++++--------------------- ipa-core/src/config.rs | 88 +---------------- ipa-core/tests/common/mod.rs | 21 ++-- ipa-core/tests/helper_networks.rs | 6 +- 4 files changed, 61 insertions(+), 209 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 9b558c862..734db1bc5 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -8,23 +8,17 @@ use std::{ }; use clap::{self, Parser, Subcommand}; -use futures::future::join; use hyper::http::uri::Scheme; use ipa_core::{ cli::{ client_config_setup, keygen, test_setup, ConfGenArgs, KeygenArgs, LoggingHandle, TestSetupArgs, Verbosity, }, - config::{ - hpke_registry, sharded_server_from_toml_str, HpkeServerConfig, ServerConfig, TlsConfig, - }, + config::{hpke_registry, HpkeServerConfig, NetworkConfig, ServerConfig, TlsConfig}, error::BoxError, executor::IpaRuntime, helpers::HelperIdentity, - net::{ - ClientIdentity, ConnectionFlavor, IpaHttpClient, MpcHttpTransport, Shard, - ShardHttpTransport, - }, + net::{ClientIdentity, IpaHttpClient, MpcHttpTransport, ShardHttpTransport}, sharding::ShardIndex, AppConfig, AppSetup, NonZeroU32PowerOfTwo, }; @@ -61,31 +55,16 @@ struct ServerArgs { #[arg(short, long, required = true)] identity: Option, - #[arg(default_value = "0")] - shard_index: Option, - - #[arg(default_value = "1")] - shard_count: Option, - /// Port to listen on #[arg(short, long, default_value = "3000")] port: Option, - #[arg(default_value = "6000")] - shard_port: Option, - - /// Use the supplied prebound socket instead of binding a new socket for mpc + /// Use the supplied prebound socket instead of binding a new socket /// /// This is only intended for avoiding port conflicts in tests. #[arg(hide = true, long)] server_socket_fd: Option, - /// Use the supplied prebound socket instead of binding a new socket for shard server - /// - /// This is only intended for avoiding port conflicts in tests. - #[arg(hide = true, long)] - shard_server_socket_fd: Option, - /// Use insecure HTTP #[arg(short = 'k', long)] disable_https: bool, @@ -94,7 +73,7 @@ struct ServerArgs { #[arg(long, required = true)] network: Option, - /// TLS certificate for helper-to-helper and shard-to-shard communication + /// TLS certificate for helper-to-helper communication #[arg( long, visible_alias("cert"), @@ -103,7 +82,7 @@ struct ServerArgs { )] tls_cert: Option, - /// TLS key for helper-to-helper and shard-to-shard communication + /// TLS key for helper-to-helper communication #[arg(long, visible_alias("key"), requires = "tls_cert")] tls_key: Option, @@ -135,58 +114,24 @@ fn read_file(path: &Path) -> Result, BoxError> { .map_err(|e| format!("failed to open file {}: {e:?}", path.display()))?) } -/// Helper function that creates the client identity; either with certificates if they are provided -/// or just with headers otherwise. This works both for sharded and helper configs. 
-fn create_client_identity( - id: F::Identity, - tls_cert: Option, - tls_key: Option, -) -> Result<(ClientIdentity, Option), BoxError> { - match (tls_cert, tls_key) { +async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), BoxError> { + let my_identity = HelperIdentity::try_from(args.identity.expect("enforced by clap")).unwrap(); + + let (identity, server_tls) = match (args.tls_cert, args.tls_key) { (Some(cert_file), Some(key_file)) => { let mut key = read_file(&key_file)?; let mut certs = read_file(&cert_file)?; - Ok(( - ClientIdentity::::from_pkcs8(&mut certs, &mut key)?, + ( + ClientIdentity::from_pkcs8(&mut certs, &mut key)?, Some(TlsConfig::File { certificate_file: cert_file, private_key_file: key_file, }), - )) + ) } - (None, None) => Ok((ClientIdentity::Header(id), None)), - _ => Err("should have been rejected by clap".into()), - } -} - -// SAFETY: -// 1. The `--server-socket-fd` option is only intended for use in tests, not in production. -// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has -// only one owner. -fn create_listener(server_socket_fd: Option) -> Result, BoxError> { - server_socket_fd - .map(|fd| { - let listener = unsafe { TcpListener::from_raw_fd(fd) }; - if listener.local_addr().is_ok() { - info!("adopting fd {fd} as listening socket"); - Ok(listener) - } else { - Err(BoxError::from(format!("the server was asked to listen on fd {fd}, but it does not appear to be a valid socket"))) - } - }) - .transpose() -} - -async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), BoxError> { - let my_identity = HelperIdentity::try_from(args.identity.expect("enforced by clap")).unwrap(); - let shard_index = ShardIndex::from(args.shard_index.expect("enforced by clap")); - let shard_count = ShardIndex::from(args.shard_count.expect("enforced by clap")); - assert!(shard_index < shard_count); - - let (identity, server_tls) = - create_client_identity(my_identity, args.tls_cert.clone(), args.tls_key.clone())?; - let (shard_identity, shard_server_tls) = - create_client_identity(shard_index, args.tls_cert, args.tls_key)?; + (None, None) => (ClientIdentity::Header(my_identity), None), + _ => panic!("should have been rejected by clap"), + }; let mk_encryption = args.mk_private_key.map(|sk_path| HpkeServerConfig::File { private_key_file: sk_path, @@ -204,13 +149,6 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B port: args.port, disable_https: args.disable_https, tls: server_tls, - hpke_config: mk_encryption.clone(), - }; - - let shard_server_config = ServerConfig { - port: args.shard_port, - disable_https: args.disable_https, - tls: shard_server_tls, hpke_config: mk_encryption, }; @@ -219,48 +157,60 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B } else { Scheme::HTTPS }; - let network_config_path = args.network.as_deref().unwrap(); - let network_config_string = &fs::read_to_string(network_config_path)?; - let (mut mpc_network, mut shard_network) = - sharded_server_from_toml_str(network_config_string, my_identity, shard_index, shard_count)?; - mpc_network = mpc_network.override_scheme(&scheme); - shard_network = shard_network.override_scheme(&scheme); + let network_config = NetworkConfig::from_toml_str(&fs::read_to_string(network_config_path)?)? + .override_scheme(&scheme); + + // TODO: Following is just temporary until Shard Transport is actually used. 
+ let shard_clients_config = network_config.client.clone(); + let shard_server_config = server_config.clone(); + // --- let http_runtime = new_http_runtime(&logging_handle); let clients = IpaHttpClient::from_conf( &IpaRuntime::from_tokio_runtime(&http_runtime), - &mpc_network, + &network_config, &identity, ); let (transport, server) = MpcHttpTransport::new( IpaRuntime::from_tokio_runtime(&http_runtime), my_identity, server_config, - mpc_network, + network_config, &clients, Some(handler), ); - let shard_clients = IpaHttpClient::::shards_from_conf( - &IpaRuntime::from_tokio_runtime(&http_runtime), - &shard_network, - &shard_identity, - ); - let (shard_transport, shard_server) = ShardHttpTransport::new( + // TODO: Following is just temporary until Shard Transport is actually used. + let shard_network_config = NetworkConfig::new_shards(vec![], shard_clients_config); + let (shard_transport, _shard_server) = ShardHttpTransport::new( IpaRuntime::from_tokio_runtime(&http_runtime), - shard_index, - shard_count, + ShardIndex::FIRST, + ShardIndex::from(1), shard_server_config, - shard_network, - shard_clients, + shard_network_config, + vec![], Some(shard_handler), ); + // --- let _app = setup.connect(transport.clone(), shard_transport.clone()); - let listener = create_listener(args.server_socket_fd)?; - let shard_listener = create_listener(args.shard_server_socket_fd)?; + let listener = args.server_socket_fd + .map(|fd| { + // SAFETY: + // 1. The `--server-socket-fd` option is only intended for use in tests, not in production. + // 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has + // only one owner. + let listener = unsafe { TcpListener::from_raw_fd(fd) }; + if listener.local_addr().is_ok() { + info!("adopting fd {fd} as listening socket"); + Ok(listener) + } else { + Err(BoxError::from(format!("the server was asked to listen on fd {fd}, but it does not appear to be a valid socket"))) + } + }) + .transpose()?; let (_addr, server_handle) = server .start_on( @@ -270,17 +220,8 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B None as Option<()>, ) .await; - let (_saddr, shard_server_handle) = shard_server - .start_on( - &IpaRuntime::from_tokio_runtime(&http_runtime), - shard_listener, - // TODO, trace based on the content of the query. - None as Option<()>, - ) - .await; - - join(server_handle, shard_server_handle).await; + server_handle.await; [query_runtime, http_runtime].map(Runtime::shutdown_background); Ok(()) diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 408213b3e..49384b814 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -16,7 +16,7 @@ use tokio::fs; use crate::{ error::BoxError, - helpers::{HelperIdentity, TransportIdentity}, + helpers::HelperIdentity, hpke::{ Deserializable as _, IpaPrivateKey, IpaPublicKey, KeyRegistry, PrivateKeyOnly, PublicKeyOnly, Serializable as _, @@ -32,10 +32,8 @@ pub type OwnedPrivateKey = PrivateKeyDer<'static>; pub enum Error { #[error(transparent)] ParseError(#[from] config::ConfigError), - #[error("Invalid uri: {0}")] + #[error("invalid uri: {0}")] InvalidUri(#[from] hyper::http::uri::InvalidUri), - #[error("Invalid network size {0}")] - InvalidNetworkSize(usize), #[error(transparent)] IOError(#[from] std::io::Error), } @@ -119,88 +117,6 @@ impl NetworkConfig { } } -/// Reads a the config for a specific, single, sharded server from string. Expects config to be -/// toml format. 
The server in the network is specified via `id`, `shard_index` and -/// `shard_count`. -/// -/// First we read the configuration without assigning any identities. The number of peers in the -/// configuration must be a multiple of 6, or 3 as a special case to support older, non-sharded -/// configurations. -/// -/// If there are 3 entries, we assign helper identities for them. We create a dummy sharded -/// configuration. -/// -/// If there are any multiple of 6 peers, then peer assignment is as follows: -/// By rings (to be reminiscent of the previous config). The first 6 entries corresponds to the -/// leaders Ring. H1 shard 0, H2, shard 0, and H3 shard 0. The next 6 correspond increases the -/// shard index by one. -/// -/// Other methods to read the network.toml exist depending on the use, for example -/// [`NetworkConfig::from_toml_str`] reads a non-sharded config. -/// TODO: There will be one to read the information relevant for the RC (doesn't need shard -/// info) -/// -/// # Errors -/// if `input` is in an invalid format -pub fn sharded_server_from_toml_str( - input: &str, - id: HelperIdentity, - shard_index: ShardIndex, - shard_count: ShardIndex, -) -> Result<(NetworkConfig, NetworkConfig), Error> { - use config::{Config, File, FileFormat}; - - let all_network: NetworkConfig = Config::builder() - .add_source(File::from_str(input, FileFormat::Toml)) - .build()? - .try_deserialize()?; - - let ix: usize = shard_index.as_index(); - let ix_count: usize = shard_count.as_index(); - let mpc_id: usize = id.as_index(); - - let total_peers = all_network.peers.len(); - if total_peers == 3 { - let mpc_network = NetworkConfig { - peers: all_network.peers.clone(), - client: all_network.client.clone(), - identities: HelperIdentity::make_three().to_vec(), - }; - let shard_network = NetworkConfig { - peers: vec![all_network.peers[mpc_id].clone()], - client: all_network.client, - identities: vec![ShardIndex(0)], - }; - Ok((mpc_network, shard_network)) - } else if total_peers > 0 && total_peers % 6 == 0 { - let mpc_network = NetworkConfig { - peers: all_network - .peers - .clone() - .into_iter() - .skip(ix * 6) - .take(3) - .collect(), - client: all_network.client.clone(), - identities: HelperIdentity::make_three().to_vec(), - }; - let shard_network = NetworkConfig { - peers: all_network - .peers - .into_iter() - .skip(3 + mpc_id) - .step_by(6) - .take(ix_count) - .collect(), - client: all_network.client, - identities: shard_count.iter().collect(), - }; - Ok((mpc_network, shard_network)) - } else { - Err(Error::InvalidNetworkSize(total_peers)) - } -} - impl NetworkConfig { /// # Panics /// In the unexpected case there are more than max usize shards. 
diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index be582537c..ca1d5e08a 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -109,9 +109,9 @@ impl CommandExt for Command { } } -fn test_setup(config_path: &Path) -> [TcpListener; 6] { - let sockets: [_; 6] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); - let ports: [u16; 6] = sockets +fn test_setup(config_path: &Path) -> [TcpListener; 3] { + let sockets: [_; 3] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); + let ports: [u16; 3] = sockets .each_ref() .map(|sock| sock.local_addr().unwrap().port()); @@ -121,7 +121,7 @@ fn test_setup(config_path: &Path) -> [TcpListener; 6] { .arg("test-setup") .args(["--output-dir".as_ref(), config_path.as_os_str()]) .arg("--ports") - .args(ports.chunks(2).map(|p| p[0].to_string())); + .args(ports.map(|p| p.to_string())); command.status().unwrap_status(); sockets @@ -129,10 +129,10 @@ fn test_setup(config_path: &Path) -> [TcpListener; 6] { pub fn spawn_helpers( config_path: &Path, - sockets: &[TcpListener; 6], + sockets: &[TcpListener; 3], https: bool, ) -> Vec { - zip([1, 2, 3], sockets.chunks(2)) + zip([1, 2, 3], sockets) .map(|(id, socket)| { let mut command = Command::new(HELPER_BIN); command @@ -156,13 +156,8 @@ pub fn spawn_helpers( command.arg("--disable-https"); } - command.preserved_fds(vec![socket[0].as_raw_fd()]); - command.args(["--server-socket-fd", &socket[0].as_raw_fd().to_string()]); - command.preserved_fds(vec![socket[1].as_raw_fd()]); - command.args([ - "--shard-server-socket-fd", - &socket[1].as_raw_fd().to_string(), - ]); + command.preserved_fds(vec![socket.as_raw_fd()]); + command.args(["--server-socket-fd", &socket.as_raw_fd().to_string()]); // something went wrong if command is terminated at this point. 
let mut child = command.spawn().unwrap(); diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index 4eb59a38e..7775ffba4 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -71,8 +71,8 @@ fn keygen_confgen() { let dir = TempDir::new_delete_on_drop(); let path = dir.path(); - let sockets: [_; 6] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); - let ports: [u16; 6] = sockets + let sockets: [_; 3] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); + let ports: [u16; 3] = sockets .each_ref() .map(|sock| sock.local_addr().unwrap().port()); @@ -85,7 +85,7 @@ fn keygen_confgen() { .args(["--output-dir".as_ref(), path.as_os_str()]) .args(["--keys-dir".as_ref(), path.as_os_str()]) .arg("--ports") - .args(ports.chunks(2).map(|p| p[0].to_string())) + .args(ports.map(|p| p.to_string())) .arg("--hosts") .args(["localhost", "localhost", "localhost"]); if overwrite { From 9a1b6e9b1f83849fee5da9ea65f56a5fe22c79cf Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Mon, 18 Nov 2024 15:32:20 -0800 Subject: [PATCH 19/47] compiling version of aggregate_reports (#1434) --- ipa-core/src/protocol/hybrid/agg.rs | 537 ++++++++++++++++++ ipa-core/src/protocol/hybrid/mod.rs | 10 +- ipa-core/src/protocol/hybrid/step.rs | 12 + .../src/protocol/ipa_prf/oprf_padding/mod.rs | 3 +- ipa-core/src/query/runner/hybrid.rs | 1 + ipa-core/src/report/hybrid.rs | 24 +- ipa-core/src/test_fixture/hybrid.rs | 34 +- 7 files changed, 614 insertions(+), 7 deletions(-) create mode 100644 ipa-core/src/protocol/hybrid/agg.rs diff --git a/ipa-core/src/protocol/hybrid/agg.rs b/ipa-core/src/protocol/hybrid/agg.rs new file mode 100644 index 000000000..85f59f58a --- /dev/null +++ b/ipa-core/src/protocol/hybrid/agg.rs @@ -0,0 +1,537 @@ +use std::collections::BTreeMap; + +use futures::{stream, StreamExt, TryStreamExt}; + +use crate::{ + error::Error, + ff::{boolean::Boolean, boolean_array::BooleanArray, ArrayAccess}, + helpers::TotalRecords, + protocol::{ + boolean::step::EightBitStep, + context::{ + dzkp_validator::{validated_seq_join, DZKPValidator, TARGET_PROOF_SIZE}, + Context, DZKPUpgraded, MaliciousProtocolSteps, ShardedContext, UpgradableContext, + }, + hybrid::step::{AggregateReportsStep, HybridStep}, + ipa_prf::boolean_ops::addition_sequential::integer_add, + BooleanProtocols, + }, + report::hybrid::{AggregateableHybridReport, PrfHybridReport}, + secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, +}; + +enum MatchEntry +where + BK: BooleanArray, + V: BooleanArray, +{ + Single(AggregateableHybridReport), + Pair( + AggregateableHybridReport, + AggregateableHybridReport, + ), + MoreThanTwo, +} + +impl MatchEntry +where + BK: BooleanArray, + V: BooleanArray, +{ + pub fn add_report(&mut self, new_report: AggregateableHybridReport) { + match self { + Self::Single(old_report) => { + *self = Self::Pair(old_report.clone(), new_report); + } + Self::Pair { .. } | Self::MoreThanTwo => *self = Self::MoreThanTwo, + } + } + + pub fn into_pair(self) -> Option<[AggregateableHybridReport; 2]> { + match self { + Self::Pair(r1, r2) => Some([r1, r2]), + _ => None, + } + } +} + +/// This function takes in a vector of `PrfHybridReports`, groups them by the oprf of the `match_key`, +/// and collects all pairs of reports with the same `match_key` into a vector of paris (as an array.) +/// +/// *Note*: Any `match_key` which appears once or more than twice is removed. 
+/// An honest report collector will only provide a single impression report per `match_key` and
+/// an honest client will only provide a single conversion report per `match_key`.
+/// Also note that a malicious client (intentional or a bug) could provide exactly two conversions.
+/// This would put the sum of conversion values into `breakdown_key` 0. As this is undetectable,
+/// this makes `breakdown_key = 0` *unreliable*.
+///
+/// Note: Possible performance opportunity by removing the `collect()`.
+/// See [#1443](https://github.com/private-attribution/ipa/issues/1443).
+///
+/// *Note*: In order to add the pairs, the vector of pairs must be in the same order across all
+/// three helpers. A standard `HashMap` uses system randomness for insertion placement, so we
+/// use a `BTreeMap` to maintain consistent ordering across the helpers.
+///
+fn group_report_pairs_ordered<BK, V>(
+    reports: Vec<PrfHybridReport<BK, V>>,
+) -> Vec<[AggregateableHybridReport<BK, V>; 2]>
+where
+    BK: BooleanArray,
+    V: BooleanArray,
+{
+    let mut reports_by_matchkey: BTreeMap<u64, MatchEntry<BK, V>> = BTreeMap::new();
+
+    for report in reports {
+        reports_by_matchkey
+            .entry(report.match_key)
+            .and_modify(|e| e.add_report(report.clone().into()))
+            .or_insert(MatchEntry::Single(report.into()));
+    }
+
+    // we only keep the reports from match_keys that provided exactly 2 reports
+    reports_by_matchkey
+        .into_values()
+        .filter_map(MatchEntry::into_pair)
+        .collect::<Vec<_>>()
+}
+
+/// This protocol aggregates `PRFHybridReports` and returns `AggregateableHybridReports`.
+/// It groups all the reports by the PRF of the `match_key`, keeps only the `match_key`s
+/// that provided exactly 2 reports, then adds each such pair of reports.
+/// TODO (Performance opportunity): These additions are not currently vectorized.
+/// We are currently deferring that work until the protocol is complete.
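+///
+/// As a minimal illustration of the grouping rule this protocol relies on
+/// (a sketch only; plain integers stand in for the secret-shared report types):
+///
+/// ```ignore
+/// use std::collections::BTreeMap;
+/// // (prf_of_match_key, value) pairs
+/// let reports = [(7u64, 2u32), (9, 4), (7, 3), (12, 1), (12, 5), (12, 6)];
+/// let mut grouped = BTreeMap::<u64, Vec<u32>>::new();
+/// for (mk, v) in reports {
+///     grouped.entry(mk).or_default().push(v);
+/// }
+/// // only match keys seen exactly twice survive; each surviving pair is added
+/// let sums: Vec<u32> = grouped
+///     .into_values()
+///     .filter(|vs| vs.len() == 2)
+///     .map(|vs| vs.iter().sum())
+///     .collect();
+/// assert_eq!(sums, vec![5]); // key 7: 2 + 3; keys 9 and 12 are dropped
+/// ```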
+pub async fn aggregate_reports<BK, V, C>(
+    ctx: C,
+    reports: Vec<PrfHybridReport<BK, V>>,
+) -> Result<Vec<AggregateableHybridReport<BK, V>>, Error>
+where
+    C: UpgradableContext + ShardedContext,
+    BK: BooleanArray,
+    V: BooleanArray,
+    Replicated<Boolean>: BooleanProtocols<DZKPUpgraded<C>>,
+{
+    let report_pairs = group_report_pairs_ordered(reports);
+
+    let chunk_size: usize = TARGET_PROOF_SIZE / (BK::BITS as usize + V::BITS as usize);
+
+    let ctx = ctx.set_total_records(TotalRecords::specified(report_pairs.len())?);
+
+    let dzkp_validator = ctx.dzkp_validator(
+        MaliciousProtocolSteps {
+            protocol: &HybridStep::GroupBySum,
+            validate: &HybridStep::GroupBySumValidate,
+        },
+        chunk_size.next_power_of_two(),
+    );
+
+    let agg_ctx = dzkp_validator.context();
+
+    let agg_work = stream::iter(report_pairs)
+        .enumerate()
+        .map(|(idx, reports)| {
+            let agg_ctx = agg_ctx.clone();
+            async move {
+                let (breakdown_key, _) = integer_add::<_, EightBitStep, 1>(
+                    agg_ctx.narrow(&AggregateReportsStep::AddBK),
+                    idx.into(),
+                    &reports[0].breakdown_key.to_bits(),
+                    &reports[1].breakdown_key.to_bits(),
+                )
+                .await?;
+                let (value, _) = integer_add::<_, EightBitStep, 1>(
+                    agg_ctx.narrow(&AggregateReportsStep::AddV),
+                    idx.into(),
+                    &reports[0].value.to_bits(),
+                    &reports[1].value.to_bits(),
+                )
+                .await?;
+                Ok::<_, Error>(AggregateableHybridReport::<BK, V> {
+                    match_key: (),
+                    breakdown_key: breakdown_key.collect_bits(),
+                    value: value.collect_bits(),
+                })
+            }
+        });
+
+    validated_seq_join(dzkp_validator, agg_work)
+        .try_collect()
+        .await
+}
+
+#[cfg(all(test, unit_test))]
+pub mod test {
+    use rand::Rng;
+
+    use super::{aggregate_reports, group_report_pairs_ordered};
+    use crate::{
+        ff::{
+            boolean_array::{BA3, BA8},
+            U128Conversions,
+        },
+        helpers::Role,
+        protocol::hybrid::step::AggregateReportsStep,
+        report::hybrid::{
+            AggregateableHybridReport, IndistinguishableHybridReport, PrfHybridReport,
+        },
+        secret_sharing::replicated::{
+            semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing,
+        },
+        sharding::{ShardConfiguration, ShardIndex},
+        test_executor::{run, run_random},
+        test_fixture::{
+            hybrid::{TestAggregateableHybridReport, TestHybridRecord},
+            Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards,
+        },
+    };
+
+    // the inputs are laid out to work with exactly 2 shards,
+    // as if they were resharded by match_key/prf
+    const SHARDS: usize = 2;
+
+    // we re-use these as the "prf" of the match_key
+    // to avoid needing to actually do the prf here
+    const SHARD1_MKS: [u64; 7] = [12345, 12345, 34567, 34567, 78901, 78901, 78901];
+    const SHARD2_MKS: [u64; 7] = [23456, 23456, 45678, 56789, 67890, 67890, 67890];
+
+    fn get_records() -> Vec<TestHybridRecord> {
+        let shard1_records = [
+            TestHybridRecord::TestImpression {
+                match_key: SHARD1_MKS[0],
+                breakdown_key: 45,
+            },
+            TestHybridRecord::TestConversion {
+                match_key: SHARD1_MKS[1],
+                value: 1,
+            }, // attributed
+            TestHybridRecord::TestConversion {
+                match_key: SHARD1_MKS[2],
+                value: 3,
+            },
+            TestHybridRecord::TestConversion {
+                match_key: SHARD1_MKS[3],
+                value: 4,
+            }, // not attributed, but duplicated conversion.
will land in breakdown_key 0 + TestHybridRecord::TestImpression { + match_key: SHARD1_MKS[4], + breakdown_key: 1, + }, // duplicated impression with same match_key + TestHybridRecord::TestImpression { + match_key: SHARD1_MKS[4], + breakdown_key: 2, + }, // duplicated impression with same match_key + TestHybridRecord::TestConversion { + match_key: SHARD1_MKS[5], + value: 7, + }, // removed + ]; + let shard2_records = [ + TestHybridRecord::TestImpression { + match_key: SHARD2_MKS[0], + breakdown_key: 56, + }, + TestHybridRecord::TestConversion { + match_key: SHARD2_MKS[1], + value: 2, + }, // attributed + TestHybridRecord::TestImpression { + match_key: SHARD2_MKS[2], + breakdown_key: 78, + }, // NOT attributed + TestHybridRecord::TestConversion { + match_key: SHARD2_MKS[3], + value: 5, + }, // NOT attributed + TestHybridRecord::TestImpression { + match_key: SHARD2_MKS[4], + breakdown_key: 90, + }, // attributed twice, removed + TestHybridRecord::TestConversion { + match_key: SHARD2_MKS[5], + value: 6, + }, // attributed twice, removed + TestHybridRecord::TestConversion { + match_key: SHARD2_MKS[6], + value: 7, + }, // attributed twice, removed + ]; + + shard1_records + .chunks(1) + .zip(shard2_records.chunks(1)) + .flat_map(|(a, b)| a.iter().chain(b)) + .cloned() + .collect() + } + + #[test] + fn group_reports_mpc() { + run(|| async { + let records = get_records(); + let expected = vec![ + [ + TestAggregateableHybridReport { + match_key: (), + value: 0, + breakdown_key: 45, + }, + TestAggregateableHybridReport { + match_key: (), + value: 1, + breakdown_key: 0, + }, + ], + [ + TestAggregateableHybridReport { + match_key: (), + value: 3, + breakdown_key: 0, + }, + TestAggregateableHybridReport { + match_key: (), + value: 4, + breakdown_key: 0, + }, + ], + [ + TestAggregateableHybridReport { + match_key: (), + value: 0, + breakdown_key: 56, + }, + TestAggregateableHybridReport { + match_key: (), + value: 2, + breakdown_key: 0, + }, + ], + ]; + + let world = TestWorld::>::with_shards(TestWorldConfig::default()); + #[allow(clippy::type_complexity)] + let results: Vec<[Vec<[AggregateableHybridReport; 2]>; 3]> = world + .malicious(records.clone().into_iter(), |ctx, input| { + let match_keys = match ctx.shard_id() { + ShardIndex(0) => SHARD1_MKS, + ShardIndex(1) => SHARD2_MKS, + _ => panic!("invalid shard_id"), + }; + async move { + let indistinguishable_reports: Vec< + IndistinguishableHybridReport, + > = input.iter().map(|r| r.clone().into()).collect::>(); + + let prf_reports: Vec> = indistinguishable_reports + .iter() + .zip(match_keys) + .map(|(indist_report, match_key)| PrfHybridReport { + match_key, + value: indist_report.value.clone(), + breakdown_key: indist_report.breakdown_key.clone(), + }) + .collect::>(); + group_report_pairs_ordered(prf_reports) + } + }) + .await; + + let results: Vec<[TestAggregateableHybridReport; 2]> = results + .into_iter() + .flat_map(|shard_result| { + shard_result[0] + .clone() + .into_iter() + .zip(shard_result[1].clone()) + .zip(shard_result[2].clone()) + .map(|((r1, r2), r3)| { + [ + [&r1[0], &r2[0], &r3[0]].reconstruct(), + [&r1[1], &r2[1], &r3[1]].reconstruct(), + ] + }) + .collect::>() + }) + .collect::>(); + + assert_eq!(results, expected); + }); + } + + #[test] + fn aggregate_reports_test() { + run(|| async { + let records = get_records(); + let expected = vec![ + TestAggregateableHybridReport { + match_key: (), + value: 1, + breakdown_key: 45, + }, + TestAggregateableHybridReport { + match_key: (), + value: 7, + breakdown_key: 0, + }, + 
TestAggregateableHybridReport { + match_key: (), + value: 2, + breakdown_key: 56, + }, + ]; + + let world = TestWorld::>::with_shards(TestWorldConfig::default()); + + let results: Vec<[Vec>; 3]> = world + .malicious(records.clone().into_iter(), |ctx, input| { + let match_keys = match ctx.shard_id() { + ShardIndex(0) => SHARD1_MKS, + ShardIndex(1) => SHARD2_MKS, + _ => panic!("invalid shard_id"), + }; + async move { + let indistinguishable_reports: Vec< + IndistinguishableHybridReport, + > = input.iter().map(|r| r.clone().into()).collect::>(); + + let prf_reports: Vec> = indistinguishable_reports + .iter() + .zip(match_keys) + .map(|(indist_report, match_key)| PrfHybridReport { + match_key, + value: indist_report.value.clone(), + breakdown_key: indist_report.breakdown_key.clone(), + }) + .collect::>(); + + aggregate_reports(ctx.clone(), prf_reports).await.unwrap() + } + }) + .await; + + let results: Vec = results + .into_iter() + .flat_map(|shard_result| { + shard_result[0] + .clone() + .into_iter() + .zip(shard_result[1].clone()) + .zip(shard_result[2].clone()) + .map(|((r1, r2), r3)| [&r1, &r2, &r3].reconstruct()) + .collect::>() + }) + .collect::>(); + + assert_eq!(results, expected); + }); + } + + fn build_prf_hybrid_report( + match_key: u64, + value: u8, + breakdown_key: u8, + ) -> PrfHybridReport { + PrfHybridReport:: { + match_key, + value: Replicated::new(BA3::truncate_from(value), BA3::truncate_from(0_u128)), + breakdown_key: Replicated::new( + BA8::truncate_from(breakdown_key), + BA8::truncate_from(0_u128), + ), + } + } + + fn build_aggregateable_report( + value: u8, + breakdown_key: u8, + ) -> AggregateableHybridReport { + AggregateableHybridReport:: { + match_key: (), + value: Replicated::new(BA3::truncate_from(value), BA3::truncate_from(0_u128)), + breakdown_key: Replicated::new( + BA8::truncate_from(breakdown_key), + BA8::truncate_from(0_u128), + ), + } + } + + #[test] + fn group_reports() { + let reports = vec![ + build_prf_hybrid_report(42, 2, 0), // pair: index (1,0) + build_prf_hybrid_report(42, 0, 3), // pair: index (1,1) + build_prf_hybrid_report(17, 4, 0), // pair: index (0,0) + build_prf_hybrid_report(17, 0, 13), // pair: index (0,1) + build_prf_hybrid_report(13, 0, 5), // single + build_prf_hybrid_report(11, 2, 0), // single + build_prf_hybrid_report(31, 1, 2), // triple + build_prf_hybrid_report(31, 3, 4), // triple + build_prf_hybrid_report(31, 5, 6), // triple + ]; + + let expected = vec![ + [ + build_aggregateable_report(4, 0), + build_aggregateable_report(0, 13), + ], + [ + build_aggregateable_report(2, 0), + build_aggregateable_report(0, 3), + ], + ]; + + let results = group_report_pairs_ordered(reports); + assert_eq!(results, expected); + } + + /// This test checks that the sharded malicious `aggregate_reports` fails + /// under a simple bit flip attack by H1. 
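+    ///
+    /// (Mechanism sketch, for orientation: the `MaliciousHelper` interceptor
+    /// below flips one bit of the data exchanged on the `AddBK` step for a
+    /// randomly chosen target shard; that single flipped bit in a
+    /// breakdown-key share must make validation fail with
+    /// `DZKPValidationFailed`.)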
+ #[test] + #[should_panic(expected = "DZKPValidationFailed")] + fn sharded_fail_under_bit_flip_attack_on_breakdown_key() { + use crate::helpers::in_memory_config::MaliciousHelper; + run_random(|mut rng| async move { + let target_shard = ShardIndex::from(rng.gen_range(0..u32::try_from(SHARDS).unwrap())); + let mut config = TestWorldConfig::default(); + + let step = format!("{}/{}", AggregateReportsStep::AddBK.as_ref(), "bit0",); + config.stream_interceptor = + MaliciousHelper::new(Role::H2, config.role_assignment(), move |ctx, data| { + // flip a bit of the match_key on the target shard, H1 + if ctx.gate.as_ref().contains(&step) + && ctx.dest == Role::H1 + && ctx.shard == Some(target_shard) + { + data[0] ^= 1u8; + } + }); + + let world = TestWorld::>::with_shards(config); + let records = get_records(); + let _results: Vec<[Vec>; 3]> = world + .malicious(records.clone().into_iter(), |ctx, input| { + let match_keys = match ctx.shard_id() { + ShardIndex(0) => SHARD1_MKS, + ShardIndex(1) => SHARD2_MKS, + _ => panic!("invalid shard_id"), + }; + async move { + let indistinguishable_reports: Vec< + IndistinguishableHybridReport, + > = input.iter().map(|r| r.clone().into()).collect::>(); + + let prf_reports: Vec> = indistinguishable_reports + .iter() + .zip(match_keys) + .map(|(indist_report, match_key)| PrfHybridReport { + match_key, + value: indist_report.value.clone(), + breakdown_key: indist_report.breakdown_key.clone(), + }) + .collect::>(); + + aggregate_reports(ctx.clone(), prf_reports).await.unwrap() + } + }) + .await; + }); + } +} diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs index 5aa14ed1c..cc63df8a5 100644 --- a/ipa-core/src/protocol/hybrid/mod.rs +++ b/ipa-core/src/protocol/hybrid/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod agg; pub(crate) mod oprf; pub(crate) mod step; @@ -11,9 +12,10 @@ use crate::{ }, helpers::query::DpMechanism, protocol::{ - basics::{BooleanProtocols, Reveal}, + basics::Reveal, context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, hybrid::{ + agg::aggregate_reports, oprf::{compute_prf_and_reshard, BreakdownKey, CONV_CHUNK, PRF_CHUNK}, step::HybridStep as Step, }, @@ -23,6 +25,7 @@ use crate::{ shuffle::Shuffle, }, prss::FromPrss, + BooleanProtocols, }, report::hybrid::{IndistinguishableHybridReport, PrfHybridReport}, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Vectorizable}, @@ -70,6 +73,7 @@ where Replicated: Reveal, Output = >::Array>, PrfHybridReport: Serializable, + Replicated: BooleanProtocols>, { if input_rows.is_empty() { return Ok(vec![Replicated::ZERO; B]); @@ -89,7 +93,9 @@ where .instrument(info_span!("shuffle_inputs")) .await?; - let _sharded_reports = compute_prf_and_reshard(ctx.clone(), shuffled_input_rows).await?; + let sharded_reports = compute_prf_and_reshard(ctx.clone(), shuffled_input_rows).await?; + + let _aggregated_reports = aggregate_reports::(ctx.clone(), sharded_reports); unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented") } diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs index 93dcb0aee..b4022551d 100644 --- a/ipa-core/src/protocol/hybrid/step.rs +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -15,4 +15,16 @@ pub(crate) enum HybridStep { #[step(child = crate::protocol::context::step::MaliciousProtocolStep)] EvalPrf, ReshardByPrf, + #[step(child = AggregateReportsStep)] + GroupBySum, + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + GroupBySumValidate, +} + 
+#[derive(CompactStep)] +pub(crate) enum AggregateReportsStep { + #[step(child = crate::protocol::boolean::step::EightBitStep)] + AddBK, + #[step(child = crate::protocol::boolean::step::EightBitStep)] + AddV, } diff --git a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs index 0d74ad6a7..6b0ee88f6 100644 --- a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs @@ -194,7 +194,8 @@ where total_number_of_fake_rows: u32, ) { padding_input_rows.extend( - repeat(IndistinguishableHybridReport::ZERO).take(total_number_of_fake_rows as usize), + repeat(IndistinguishableHybridReport::::ZERO) + .take(total_number_of_fake_rows as usize), ); } } diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index f16d3fac2..0ce953f03 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -72,6 +72,7 @@ where PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, Replicated: Reveal, Output = >::Array>, + Replicated: BooleanProtocols>, { #[tracing::instrument("hybrid_query", skip_all, fields(sz=%query_size))] pub async fn execute( diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index b3842b6b5..be2bd1d3f 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -48,7 +48,7 @@ use crate::{ open_in_place, seal_in_place, CryptError, EncapsulationSize, PrivateKeyRegistry, PublicKeyRegistry, TagSize, }, - protocol::ipa_prf::shuffle::Shuffleable, + protocol::ipa_prf::{boolean_ops::expand_shared_array_in_place, shuffle::Shuffleable}, report::hybrid_info::{HybridConversionInfo, HybridImpressionInfo, HybridInfo}, secret_sharing::{ replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, @@ -681,6 +681,28 @@ where /// Converted report where shares of match key are replaced with OPRF value pub type PrfHybridReport = IndistinguishableHybridReport; +/// After grouping `IndistinguishableHybridReport`s by the OPRF of thier `match_key`, +/// that OPRF value is no longer required. +pub type AggregateableHybridReport = IndistinguishableHybridReport; + +/// When aggregating reports, we need to lift the value from `V` to `HV`. +impl From> for AggregateableHybridReport +where + BK: SharedValue + BooleanArray, + V: SharedValue + BooleanArray, + HV: SharedValue + BooleanArray, +{ + fn from(report: PrfHybridReport) -> Self { + let mut value = Replicated::::ZERO; + expand_shared_array_in_place(&mut value, &report.value, 0); + Self { + match_key: (), + breakdown_key: report.breakdown_key, + value, + } + } +} + /// This struct is designed to fit both `HybridConversionReport`s /// and `HybridImpressionReport`s so that they can be made indistingushable. 
/// Note: these need to be shuffled (and secret shares need to be rerandomized) diff --git a/ipa-core/src/test_fixture/hybrid.rs b/ipa-core/src/test_fixture/hybrid.rs index a28fa7232..7a4d86007 100644 --- a/ipa-core/src/test_fixture/hybrid.rs +++ b/ipa-core/src/test_fixture/hybrid.rs @@ -22,13 +22,15 @@ pub enum TestHybridRecord { TestConversion { match_key: u64, value: u32 }, } -#[derive(PartialEq, Eq)] -pub struct TestIndistinguishableHybridReport { - pub match_key: u64, +#[derive(PartialEq, Eq, Debug)] +pub struct TestIndistinguishableHybridReport { + pub match_key: MK, pub value: u32, pub breakdown_key: u32, } +pub type TestAggregateableHybridReport = TestIndistinguishableHybridReport<()>; + impl Reconstruct for [&IndistinguishableHybridReport; 3] where @@ -60,6 +62,32 @@ where } } +impl Reconstruct + for [&IndistinguishableHybridReport; 3] +where + BK: BooleanArray + U128Conversions + IntoShares>, + V: BooleanArray + U128Conversions + IntoShares>, +{ + fn reconstruct(&self) -> TestAggregateableHybridReport { + let breakdown_key = self + .each_ref() + .map(|v| v.breakdown_key.clone()) + .reconstruct() + .as_u128(); + let value = self + .each_ref() + .map(|v| v.value.clone()) + .reconstruct() + .as_u128(); + + TestAggregateableHybridReport { + match_key: (), + breakdown_key: breakdown_key.try_into().unwrap(), + value: value.try_into().unwrap(), + } + } +} + impl IntoShares> for TestHybridRecord where BK: BooleanArray + U128Conversions + IntoShares>, From e831a219bb549a7e7b15e15625c2f3d4384b9227 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Mon, 18 Nov 2024 17:11:48 -0800 Subject: [PATCH 20/47] copy IPA breakdown_key reveal verbatim (#1444) --- .../src/protocol/hybrid/breakdown_reveal.rs | 447 ++++++++++++++++++ 1 file changed, 447 insertions(+) create mode 100644 ipa-core/src/protocol/hybrid/breakdown_reveal.rs diff --git a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs new file mode 100644 index 000000000..00d4c36af --- /dev/null +++ b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs @@ -0,0 +1,447 @@ +use std::{convert::Infallible, pin::pin}; + +use futures::stream; +use futures_util::{StreamExt, TryStreamExt}; +use tracing::{info_span, Instrument}; + +use super::aggregate_values; +use crate::{ + error::{Error, UnwrapInfallible}, + ff::{ + boolean::Boolean, + boolean_array::{BooleanArray, BooleanArrayReader, BooleanArrayWriter, BA32}, + U128Conversions, + }, + helpers::TotalRecords, + protocol::{ + basics::{reveal, Reveal}, + context::{ + dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, + UpgradableContext, + }, + ipa_prf::{ + aggregation::{ + aggregate_values_proof_chunk, step::AggregationStep as Step, AGGREGATE_DEPTH, + }, + oprf_padding::{apply_dp_padding, PaddingParameters}, + prf_sharding::{AttributionOutputs, SecretSharedAttributionOutputs}, + shuffle::{Shuffle, Shuffleable}, + BreakdownKey, + }, + BooleanProtocols, RecordId, + }, + secret_sharing::{ + replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, + BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, + }, + seq_join::seq_join, +}; + +impl AttributionOutputs, Replicated> +where + BK: BooleanArray, + TV: BooleanArray, +{ + fn join_fields(breakdown_key: BK, trigger_value: TV) -> ::Share { + let mut share = ::Share::ZERO; + + BooleanArrayWriter::new(&mut share) + .write(&breakdown_key) + .write(&trigger_value); + + share + } + + fn split_fields(share: &::Share) -> (BK, TV) { + let bits = 
BooleanArrayReader::new(share);
+        let (breakdown_key, bits) = bits.read();
+        let (trigger_value, _bits) = bits.read();
+        (breakdown_key, trigger_value)
+    }
+}
+
+impl<BK, TV> Shuffleable for AttributionOutputs<Replicated<BK>, Replicated<TV>>
+where
+    BK: BooleanArray,
+    TV: BooleanArray,
+{
+    /// TODO: Use a smaller BA type to contain BK and TV
+    type Share = BA32;
+
+    fn left(&self) -> Self::Share {
+        Self::join_fields(
+            ReplicatedSecretSharing::left(&self.attributed_breakdown_key_bits),
+            ReplicatedSecretSharing::left(&self.capped_attributed_trigger_value),
+        )
+    }
+
+    fn right(&self) -> Self::Share {
+        Self::join_fields(
+            ReplicatedSecretSharing::right(&self.attributed_breakdown_key_bits),
+            ReplicatedSecretSharing::right(&self.capped_attributed_trigger_value),
+        )
+    }
+
+    fn new(l: Self::Share, r: Self::Share) -> Self {
+        debug_assert!(
+            BK::BITS + TV::BITS <= Self::Share::BITS,
+            "share type {} is too small",
+            std::any::type_name::<Self::Share>(),
+        );
+
+        let left = Self::split_fields(&l);
+        let right = Self::split_fields(&r);
+
+        Self {
+            attributed_breakdown_key_bits: ReplicatedSecretSharing::new(left.0, right.0),
+            capped_attributed_trigger_value: ReplicatedSecretSharing::new(left.1, right.1),
+        }
+    }
+}
+
+/// Improved Aggregation, a.k.a. Aggregation revealing breakdown.
+///
+/// Aggregation steps happen after attribution. The input for Aggregation is a
+/// list of tuples containing Trigger Values (TV) and their corresponding
+/// Breakdown Keys (BK), which were attributed in the previous step of IPA. The
+/// output of Aggregation is a histogram, where each "bin" or "bucket" is a BK
+/// and the value is the addition of all the TVs for it, hence the name
+/// Aggregation. This can be thought of as a SQL GROUP BY operation.
+///
+/// The protocol involves three main steps:
+/// 1. Shuffle the data to protect privacy (see [`shuffle_attributions`]).
+/// 2. Reveal breakdown keys. This is the key difference from the previous
+///    aggregation (see [`reveal_breakdowns`]).
+/// 3. Add all values for each breakdown.
+///
+/// This protocol explicitly manages proof batches for DZKP-based malicious security by
+/// processing chunks of values from `intermediate_results.chunks()`. Progression
+/// through record IDs is not uniform for all of the gates in the protocol. The first
+/// layer of the reduction adds N pairs of records, the second layer adds N/2 pairs of
+/// records, etc. This has a few consequences:
+/// * We must specify a batch size of `usize::MAX` when calling `dzkp_validator`.
+/// * We must track record IDs across chunks, so that subsequent chunks can
+///   start from the last record ID that was used in the previous chunk.
+/// * Because the first record ID in the proof batch is set implicitly, we must
+///   guarantee that it submits multiplication intermediates before any other
+///   record. This is currently ensured by the serial operation of the aggregation
+///   protocol (i.e. by not using `seq_join`).
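+///
+/// For intuition about the uneven record-ID progression (illustrative numbers
+/// only): with 8 input vectors, the reduction performs 4, 2, and then 1
+/// additions, which is why `record_ids` is tracked per layer rather than as a
+/// single counter:
+///
+/// ```ignore
+/// let mut layers = vec![];
+/// let mut n = 8usize;
+/// while n > 1 {
+///     n /= 2;         // each layer halves the number of vectors
+///     layers.push(n); // and performs that many additions
+/// }
+/// assert_eq!(layers, vec![4, 2, 1]);
+/// ```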
+#[tracing::instrument(name = "breakdown_reveal_aggregation", skip_all, fields(total = attributed_values.len()))] +pub async fn breakdown_reveal_aggregation( + ctx: C, + attributed_values: Vec>, + padding_params: &PaddingParameters, +) -> Result>, Error> +where + C: UpgradableContext + Shuffle, + Boolean: FieldSimd, + Replicated: BooleanProtocols, B>, + BK: BreakdownKey, + Replicated: Reveal, Output = >::Array>, + TV: BooleanArray + U128Conversions, + HV: BooleanArray + U128Conversions, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, +{ + // Apply DP padding for Breakdown Reveal Aggregation + let attributed_values_padded = + apply_dp_padding::<_, AttributionOutputs, Replicated>, B>( + ctx.narrow(&Step::PaddingDp), + attributed_values, + padding_params, + ) + .await?; + + let attributions = ctx + .narrow(&Step::Shuffle) + .shuffle(attributed_values_padded) + .instrument(info_span!("shuffle_attribution_outputs")) + .await?; + + // Revealing the breakdowns doesn't do any multiplies, so won't make it as far as + // doing a proof, but we need the validator to obtain an upgraded malicious context. + let validator = ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::Reveal, + validate: &Step::RevealValidate, + }, + usize::MAX, + ); + let grouped_tvs = reveal_breakdowns(&validator.context(), attributions).await?; + validator.validate().await?; + let mut intermediate_results: Vec>> = grouped_tvs.into(); + + // Any real-world aggregation should be able to complete in two layers (two + // iterations of the `while` loop below). Tests with small `TARGET_PROOF_SIZE` + // may exceed that. + let mut chunk_counter = 0; + let mut depth = 0; + let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()); + + while intermediate_results.len() > 1 { + let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; + let mut next_intermediate_results = Vec::new(); + for chunk in intermediate_results.chunks(agg_proof_chunk) { + let chunk_len = chunk.len(); + let validator = ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::aggregate(depth), + validate: &Step::AggregateValidate, + }, + usize::MAX, // See note about batching above. + ); + let result = aggregate_values::<_, HV, B>( + validator.context(), + stream::iter(chunk).map(|v| Ok(v.clone())).boxed(), + chunk_len, + Some(&mut record_ids), + ) + .await?; + validator.validate_indexed(chunk_counter).await?; + chunk_counter += 1; + next_intermediate_results.push(result); + } + depth += 1; + intermediate_results = next_intermediate_results; + } + + Ok(intermediate_results + .into_iter() + .next() + .expect("aggregation input must not be empty")) +} + +/// Transforms the Breakdown key from a secret share into a revealed `usize`. +/// The input are the Atrributions and the output is a list of lists of secret +/// shared Trigger Values. Since Breakdown Keys are assumed to be dense the +/// first list contains all the possible Breakdowns, the index in the list +/// representing the Breakdown value. The second list groups all the Trigger +/// Values for that particular Breakdown. 
+#[tracing::instrument(name = "reveal_breakdowns", skip_all, fields( + total = attributions.len(), +))] +async fn reveal_breakdowns( + parent_ctx: &C, + attributions: Vec>, +) -> Result, Error> +where + C: Context, + Replicated: BooleanProtocols, + Boolean: FieldSimd, + BK: BreakdownKey, + Replicated: Reveal>::Array>, + TV: BooleanArray + U128Conversions, +{ + let reveal_ctx = parent_ctx.set_total_records(TotalRecords::specified(attributions.len())?); + + let reveal_work = stream::iter(attributions).enumerate().map(|(i, ao)| { + let record_id = RecordId::from(i); + let reveal_ctx = reveal_ctx.clone(); + async move { + let revealed_bk = + reveal(reveal_ctx, record_id, &ao.attributed_breakdown_key_bits).await?; + let revealed_bk = BK::from_array(&revealed_bk); + let Ok(bk) = usize::try_from(revealed_bk.as_u128()) else { + return Err(Error::Internal); + }; + Ok::<_, Error>((bk, ao.capped_attributed_trigger_value)) + } + }); + let mut grouped_tvs = GroupedTriggerValues::::new(); + let mut stream = pin!(seq_join(reveal_ctx.active_work(), reveal_work)); + while let Some((bk, tv)) = stream.try_next().await? { + grouped_tvs.push(bk, tv); + } + + Ok(grouped_tvs) +} + +/// Helper type that hold all the Trigger Values, grouped by their Breakdown +/// Key. The main functionality is to turn into a stream that can be given to +/// [`aggregate_values`]. +struct GroupedTriggerValues { + tvs: [Vec>; B], + max_len: usize, +} + +impl GroupedTriggerValues { + fn new() -> Self { + Self { + tvs: std::array::from_fn(|_| vec![]), + max_len: 0, + } + } + + fn push(&mut self, bk: usize, value: Replicated) { + self.tvs[bk].push(value); + if self.tvs[bk].len() > self.max_len { + self.max_len = self.tvs[bk].len(); + } + } +} + +impl From> + for Vec>> +where + Boolean: FieldSimd, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, +{ + fn from( + mut grouped_tvs: GroupedTriggerValues, + ) -> Vec>> { + let iter = (0..grouped_tvs.max_len).map(move |_| { + let slice: [Replicated; B] = grouped_tvs + .tvs + .each_mut() + .map(|tv| tv.pop().unwrap_or(Replicated::ZERO)); + + BitDecomposed::transposed_from(&slice).unwrap_infallible() + }); + iter.collect() + } +} + +#[cfg(all(test, any(unit_test, feature = "shuttle")))] +pub mod tests { + use futures::TryFutureExt; + use rand::seq::SliceRandom; + + #[cfg(not(feature = "shuttle"))] + use crate::{ff::boolean_array::BA16, test_executor::run}; + use crate::{ + ff::{ + boolean::Boolean, + boolean_array::{BA3, BA5, BA8}, + U128Conversions, + }, + protocol::ipa_prf::{ + aggregation::breakdown_reveal::breakdown_reveal_aggregation, + oprf_padding::PaddingParameters, + prf_sharding::{AttributionOutputsTestInput, SecretSharedAttributionOutputs}, + }, + rand::Rng, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, + }, + test_executor::run_with, + test_fixture::{Reconstruct, Runner, TestWorld}, + }; + + fn input_row(bk: usize, tv: u128) -> AttributionOutputsTestInput { + let bk: u128 = bk.try_into().unwrap(); + AttributionOutputsTestInput { + bk: BA5::truncate_from(bk), + tv: BA3::truncate_from(tv), + } + } + + #[test] + fn semi_honest_happy_path() { + // if shuttle executor is enabled, run this test only once. + // it is a very expensive test to explore all possible states, + // sometimes github bails after 40 minutes of running it + // (workers there are really slow). 
+ run_with::<_, _, 3>(|| async { + let world = TestWorld::default(); + let mut rng = world.rng(); + let mut expectation = Vec::new(); + for _ in 0..32 { + expectation.push(rng.gen_range(0u128..256)); + } + let expectation = expectation; // no more mutability for safety + let mut inputs = Vec::new(); + for (bk, expected_hv) in expectation.iter().enumerate() { + let mut remainder = *expected_hv; + while remainder > 7 { + let tv = rng.gen_range(0u128..8); + remainder -= tv; + inputs.push(input_row(bk, tv)); + } + inputs.push(input_row(bk, remainder)); + } + inputs.shuffle(&mut rng); + let result: Vec<_> = world + .semi_honest(inputs.into_iter(), |ctx, input_rows| async move { + let aos = input_rows + .into_iter() + .map(|ti| SecretSharedAttributionOutputs { + attributed_breakdown_key_bits: ti.0, + capped_attributed_trigger_value: ti.1, + }) + .collect(); + let r: Vec> = + breakdown_reveal_aggregation::<_, BA5, BA3, BA8, 32>( + ctx, + aos, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap(); + r + }) + .await + .reconstruct(); + let result = result.iter().map(|&v| v.as_u128()).collect::>(); + assert_eq!(32, result.len()); + assert_eq!(result, expectation); + }); + } + + #[test] + #[cfg(not(feature = "shuttle"))] // too slow + fn malicious_happy_path() { + type HV = BA16; + run(|| async { + let world = TestWorld::default(); + let mut rng = world.rng(); + let mut expectation = Vec::new(); + for _ in 0..32 { + expectation.push(rng.gen_range(0u128..512)); + } + // The size of input needed here to get complete coverage (more precisely, + // the size of input to the final aggregation using `aggregate_values`) + // depends on `TARGET_PROOF_SIZE`. + let expectation = expectation; // no more mutability for safety + let mut inputs = Vec::new(); + for (bk, expected_hv) in expectation.iter().enumerate() { + let mut remainder = *expected_hv; + while remainder > 7 { + let tv = rng.gen_range(0u128..8); + remainder -= tv; + inputs.push(input_row(bk, tv)); + } + inputs.push(input_row(bk, remainder)); + } + inputs.shuffle(&mut rng); + let result: Vec<_> = world + .malicious(inputs.into_iter(), |ctx, input_rows| async move { + let aos = input_rows + .into_iter() + .map(|ti| SecretSharedAttributionOutputs { + attributed_breakdown_key_bits: ti.0, + capped_attributed_trigger_value: ti.1, + }) + .collect(); + breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( + ctx, + aos, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap() + }) + .await + .reconstruct(); + let result = result.iter().map(|v: &HV| v.as_u128()).collect::>(); + assert_eq!(32, result.len()); + assert_eq!(result, expectation); + }); + } +} From 9336e5982da75cd9564236064693c41241202381 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 19 Nov 2024 10:59:20 -0800 Subject: [PATCH 21/47] Use recursion factor of 4 for all proofs (#1440) --- ipa-core/src/helpers/gateway/send.rs | 9 +- .../src/protocol/context/dzkp_validator.rs | 11 +- ipa-core/src/protocol/hybrid/oprf.rs | 29 +++-- .../src/protocol/ipa_prf/aggregation/mod.rs | 9 +- .../boolean_ops/share_conversion_aby.rs | 4 +- .../ipa_prf/malicious_security/prover.rs | 22 ++-- ipa-core/src/protocol/ipa_prf/mod.rs | 31 +++-- .../src/protocol/ipa_prf/prf_sharding/mod.rs | 6 +- ipa-core/src/protocol/ipa_prf/quicksort.rs | 3 +- .../validation_protocol/proof_generation.rs | 98 +++++++-------- .../ipa_prf/validation_protocol/validation.rs | 116 
+++++++++--------- ipa-core/src/utils/mod.rs | 2 +- ipa-core/src/utils/power_of_two.rs | 25 ++++ 13 files changed, 206 insertions(+), 159 deletions(-) diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index d1b88af9f..bbd913031 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -25,6 +25,7 @@ use crate::{ labels::{ROLE, STEP}, metrics::{BYTES_SENT, RECORDS_SENT}, }, + utils::non_zero_prev_power_of_two, }; /// Sending end of the gateway channel. @@ -256,14 +257,6 @@ impl SendChannelConfig { total_records: TotalRecords, record_size: usize, ) -> Self { - // this computes the greatest positive power of 2 that is - // less than or equal to target. - fn non_zero_prev_power_of_two(target: usize) -> usize { - let bits = usize::BITS - target.leading_zeros(); - - 1 << (std::cmp::max(1, bits) - 1) - } - assert!(record_size > 0, "Message size cannot be 0"); let total_capacity = gateway_config.active.get() * record_size; diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 63a412265..fd616fc9c 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -20,7 +20,7 @@ use crate::{ }, ipa_prf::{ validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, - LargeProofGenerator, SmallProofGenerator, + CompressedProofGenerator, FirstProofGenerator, }, Gate, RecordId, RecordIdRange, }, @@ -50,6 +50,9 @@ const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize; // A smaller value is used for tests, to enable covering some corner cases with a // reasonable runtime. Some of these tests use TARGET_PROOF_SIZE directly, so for tests // it does need to be a power of two. +// +// TARGET_PROOF_SIZE is closely related to MAX_PROOF_RECURSION; see the assertion that +// `uv_values.len() <= max_uv_values` in `ProofBatch` for more detail. #[cfg(test)] pub const TARGET_PROOF_SIZE: usize = 8192; #[cfg(not(test))] @@ -73,7 +76,7 @@ pub const TARGET_PROOF_SIZE: usize = 50_000_000; // to blocks of 256), leaving some margin is advised. // // The implementation requires that MAX_PROOF_RECURSION is at least 2. -pub const MAX_PROOF_RECURSION: usize = 9; +pub const MAX_PROOF_RECURSION: usize = 14; /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values /// that occur duringa multiplication. 
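A quick sanity check on the new bound (illustrative only; the helper below is
hypothetical and not part of the codebase): with the recursion factor of 4 that
this patch adopts for all proofs, a batch of TARGET_PROOF_SIZE = 50,000,000
multiplications is compressed to a single value in 13 factor-4 steps, so
MAX_PROOF_RECURSION = 14 leaves one step of margin.

    // hypothetical helper: count factor-4 recursion steps down to one value
    fn recursion_steps(mut n: usize, factor: usize) -> usize {
        let mut steps = 0;
        while n > 1 {
            n = n.div_ceil(factor);
            steps += 1;
        }
        steps
    }
    assert_eq!(recursion_steps(50_000_000, 4), 13); // < MAX_PROOF_RECURSION = 14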
@@ -601,8 +604,8 @@ impl Batch { ctx: Base<'_, B>, batch_index: usize, ) -> Result<(), Error> { - const PRSS_RECORDS_PER_BATCH: usize = LargeProofGenerator::PROOF_LENGTH - + (MAX_PROOF_RECURSION - 1) * SmallProofGenerator::PROOF_LENGTH + const PRSS_RECORDS_PER_BATCH: usize = FirstProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * CompressedProofGenerator::PROOF_LENGTH + 2; // P and Q masks let proof_ctx = ctx.narrow(&Step::GenerateProof); diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index adc1da9ff..bf1a14f3b 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -1,3 +1,5 @@ +use std::cmp::max; + use futures::{stream, StreamExt, TryStreamExt}; use typenum::Const; @@ -17,8 +19,9 @@ use crate::{ protocol::{ basics::{BooleanProtocols, Reveal}, context::{ - dzkp_validator::DZKPValidator, reshard_try_stream, DZKPUpgraded, MacUpgraded, - MaliciousProtocolSteps, ShardedContext, UpgradableContext, Validator, + dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, + reshard_try_stream, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, ShardedContext, + UpgradableContext, Validator, }, hybrid::step::HybridStep, ipa_prf::{ @@ -34,7 +37,9 @@ use crate::{ Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; + // In theory, we could support (runtime-configured breakdown count) ≤ (compile-time breakdown count) // ≤ 2^|bk|, with all three values distinct, but at present, there is no runtime configuration and // the latter two must be equal. The implementation of `move_single_value_to_bucket` does support a @@ -64,13 +69,17 @@ pub const CONV_CHUNK: usize = 256; /// Vectorization dimension for PRF pub const PRF_CHUNK: usize = 16; -// We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor -// is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of -// multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a -// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so -// we pick the closest power of two close to 381 but less than that value. 256 gives us around 33M -// multiplications per batch -const CONV_PROOF_CHUNK: usize = 256; +/// Returns a suitable proof chunk size (in records) for use with `convert_to_fp25519`. +/// +/// We expect 2*256 = 512 gates in total for two additions per conversion. The +/// vectorization factor is `CONV_CHUNK`. Let `len` equal the number of converted +/// shares. The total amount of multiplications is `CONV_CHUNK`*512*len. We want +/// `CONV_CHUNK`*512*len ≈ 50M for a reasonably-sized proof. There is also a constraint +/// on proof chunks to be powers of two, and we don't want to compute a proof chunk +/// of zero when `TARGET_PROOF_SIZE` is smaller for tests. +fn conv_proof_chunk() -> usize { + non_zero_prev_power_of_two(max(2, TARGET_PROOF_SIZE / CONV_CHUNK / 512)) +} /// This computes the Dodis-Yampolsky PRF value on every match key from input, /// and reshards the reports according to the computed PRF. 
At the end, reports with the @@ -101,7 +110,7 @@ where protocol: &HybridStep::ConvertFp25519, validate: &HybridStep::ConvertFp25519Validate, }, - CONV_PROOF_CHUNK, + conv_proof_chunk(), ); let m_ctx = validator.context(); diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index d3e78fd8c..6a9adb345 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -25,6 +25,7 @@ use crate::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, + utils::non_zero_prev_power_of_two, }; pub(crate) mod breakdown_reveal; @@ -96,8 +97,14 @@ pub type AggResult = Result /// saturating the output) is: /// /// $\sum_{i = 1}^k 2^{k - i} (b + i - 1) \approx 2^k (b + 1) = N (b + 1)$ +/// +/// We set a floor of 2 to avoid computing a chunk of zero when `TARGET_PROOF_SIZE` is +/// smaller for tests. pub fn aggregate_values_proof_chunk(input_width: usize, input_item_bits: usize) -> usize { - max(2, TARGET_PROOF_SIZE / input_width / (input_item_bits + 1)).next_power_of_two() + non_zero_prev_power_of_two(max( + 2, + TARGET_PROOF_SIZE / input_width / (input_item_bits + 1), + )) } // This is the step count for AggregateChunkStep. We need it to size RecordId arrays. diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index 2dabdc3f4..b98fe9612 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -381,7 +381,7 @@ mod tests { helpers::stream::process_slice_by_chunks, protocol::{ context::{dzkp_validator::DZKPValidator, UpgradableContext, TEST_DZKP_STEPS}, - ipa_prf::{CONV_CHUNK, CONV_PROOF_CHUNK, PRF_CHUNK}, + ipa_prf::{conv_proof_chunk, CONV_CHUNK, PRF_CHUNK}, }, rand::thread_rng, secret_sharing::SharedValue, @@ -415,7 +415,7 @@ mod tests { let [res0, res1, res2] = world .semi_honest(records.into_iter(), |ctx, records| async move { let c_ctx = ctx.set_total_records((COUNT + CONV_CHUNK - 1) / CONV_CHUNK); - let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, CONV_PROOF_CHUNK); + let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, conv_proof_chunk()); let m_ctx = validator.context(); seq_join( m_ctx.active_work(), diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index af451b458..26ce5c5d0 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -1,14 +1,15 @@ use std::{borrow::Borrow, iter::zip, marker::PhantomData}; -#[cfg(all(test, unit_test))] -use crate::ff::Fp31; use crate::{ error::Error::{self, DZKPMasks}, ff::{Fp61BitPrime, PrimeField}, helpers::hashing::{compute_hash, hash_to_field}, protocol::{ context::Context, - ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + ipa_prf::{ + malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + CompressedProofGenerator, + }, prss::SharedRandomness, RecordId, RecordIdRange, }, @@ -84,8 +85,8 @@ where // compute final uv values let (u_values, v_values) = &mut self.uv_chunks[0]; // shift first element to last position - u_values[SmallProofGenerator::RECURSION_FACTOR - 1] = u_values[0]; - v_values[SmallProofGenerator::RECURSION_FACTOR - 1] = v_values[0]; + 
u_values[CompressedProofGenerator::RECURSION_FACTOR - 1] = u_values[0]; + v_values[CompressedProofGenerator::RECURSION_FACTOR - 1] = v_values[0]; // set masks in first position u_values[0] = my_p_mask; v_values[0] = my_q_mask; @@ -105,15 +106,11 @@ pub struct ProofGenerator, } -#[cfg(all(test, unit_test))] -pub type TestProofGenerator = ProofGenerator; - // Compression Factor is L // P, Proof size is 2*L - 1 // M, the number of interpolated points is L - 1 // The reason we need these is that Rust doesn't support basic math operations on const generics -pub type SmallProofGenerator = ProofGenerator; -pub type LargeProofGenerator = ProofGenerator; +pub type SmallProofGenerator = ProofGenerator; impl ProofGenerator { // define constants such that they can be used externally @@ -265,7 +262,7 @@ mod test { context::Context, ipa_prf::malicious_security::{ lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - prover::{LargeProofGenerator, SmallProofGenerator, TestProofGenerator, UVValues}, + prover::{ProofGenerator, SmallProofGenerator, UVValues}, }, RecordId, RecordIdRange, }, @@ -274,6 +271,9 @@ mod test { test_fixture::{Runner, TestWorld}, }; + type TestProofGenerator = ProofGenerator; + type LargeProofGenerator = ProofGenerator; + fn zip_chunks(a: I, b: J) -> UVValues where I: IntoIterator, diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index f4225a0d8..0ffe2adbc 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, iter::zip, num::NonZeroU32, ops::Add}; +use std::{cmp::max, convert::Infallible, iter::zip, num::NonZeroU32, ops::Add}; use futures::{stream, StreamExt, TryStreamExt}; use generic_array::{ArrayLength, GenericArray}; @@ -24,8 +24,8 @@ use crate::{ protocol::{ basics::{BooleanArrayMul, BooleanProtocols, Reveal}, context::{ - dzkp_validator::DZKPValidator, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, - UpgradableContext, + dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, + DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, UpgradableContext, }, ipa_prf::{ boolean_ops::convert_to_fp25519, @@ -44,6 +44,7 @@ use crate::{ BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; pub(crate) mod aggregation; @@ -58,7 +59,9 @@ pub(crate) mod shuffle; pub(crate) mod step; pub mod validation_protocol; -pub use malicious_security::prover::{LargeProofGenerator, SmallProofGenerator}; +pub type FirstProofGenerator = malicious_security::prover::SmallProofGenerator; +pub type CompressedProofGenerator = malicious_security::prover::SmallProofGenerator; + pub use shuffle::Shuffle; /// Match key type @@ -409,13 +412,17 @@ where Ok(noisy_output_histogram) } -// We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor -// is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of -// multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a -// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so -// we pick the closest power of two close to 381 but less than that value. 256 gives us around 33M -// multiplications per batch -const CONV_PROOF_CHUNK: usize = 256; +/// Returns a suitable proof chunk size (in records) for use with `convert_to_fp25519`. +/// +/// We expect 2*256 = 512 gates in total for two additions per conversion. 
The +/// vectorization factor is `CONV_CHUNK`. Let `len` equal the number of converted +/// shares. The total amount of multiplications is `CONV_CHUNK`*512*len. We want +/// `CONV_CHUNK`*512*len ≈ 50M for a reasonably-sized proof. There is also a constraint +/// on proof chunks to be powers of two, and we don't want to compute a proof chunk +/// of zero when `TARGET_PROOF_SIZE` is smaller for tests. +fn conv_proof_chunk() -> usize { + non_zero_prev_power_of_two(max(2, TARGET_PROOF_SIZE / CONV_CHUNK / 512)) +} #[tracing::instrument(name = "compute_prf_for_inputs", skip_all)] async fn compute_prf_for_inputs( @@ -443,7 +450,7 @@ where protocol: &Step::ConvertFp25519, validate: &Step::ConvertFp25519Validate, }, - CONV_PROOF_CHUNK, + conv_proof_chunk(), ); let m_ctx = validator.context(); diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index f6bf2e339..3617ab767 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -52,6 +52,7 @@ use crate::{ replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, + utils::non_zero_prev_power_of_two, }; pub mod feature_label_dot_product; @@ -515,7 +516,10 @@ where // TODO: this override was originally added to work around problems with // read_size vs. batch size alignment. Those are now fixed (in #1332), but this // is still observed to help performance (see #1376), so has been retained. - std::cmp::min(sh_ctx.active_work().get(), chunk_size.next_power_of_two()), + std::cmp::min( + sh_ctx.active_work().get(), + non_zero_prev_power_of_two(chunk_size), + ), ); dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?; diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index 943dfb1ec..26131224d 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -30,6 +30,7 @@ use crate::{ Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; impl ChunkBuffer for (Vec>, Vec>) @@ -98,7 +99,7 @@ where } fn quicksort_proof_chunk(key_bits: usize) -> usize { - (TARGET_PROOF_SIZE / key_bits / SORT_CHUNK).next_power_of_two() + non_zero_prev_power_of_two(TARGET_PROOF_SIZE / key_bits / SORT_CHUNK) } /// Insecure quicksort using MPC comparisons and a key extraction function `get_key`. 
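As a sanity check that the runtime computation reproduces the old hard-coded chunk size (a sketch for illustration only; `non_zero_prev_power_of_two` is copied verbatim from the utility this patch introduces):

fn non_zero_prev_power_of_two(target: usize) -> usize {
    let bits = usize::BITS - target.leading_zeros();
    1 << (std::cmp::max(1, bits) - 1)
}

fn main() {
    // TARGET_PROOF_SIZE = 50M, CONV_CHUNK = 256, 512 gates per conversion:
    let len = 50_000_000 / 256 / 512; // = 381 converted shares per proof
    // The largest power of two <= 381 is 256, matching the old
    // CONV_PROOF_CHUNK = 256 (about 256 * 256 * 512 = ~33.5M multiplications).
    assert_eq!(non_zero_prev_power_of_two(std::cmp::max(2, len)), 256);
}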
diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs index cb2754e5f..0a658614e 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs @@ -1,9 +1,8 @@ -use std::{array, iter::zip}; +use std::{array, iter::zip, ops::Mul}; -use typenum::{UInt, UTerm, Unsigned, B0, B1}; +use typenum::{Unsigned, U, U8}; use crate::{ - const_assert_eq, error::Error, ff::{Fp61BitPrime, Serializable}, helpers::{Direction, MpcMessage, TotalRecords}, @@ -13,9 +12,9 @@ use crate::{ dzkp_validator::MAX_PROOF_RECURSION, Context, }, - ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - prover::{LargeProofGenerator, SmallProofGenerator}, + ipa_prf::{ + malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + CompressedProofGenerator, FirstProofGenerator, }, prss::SharedRandomness, RecordId, RecordIdRange, @@ -25,8 +24,8 @@ use crate::{ /// This a `ProofBatch` generated by a prover. pub struct ProofBatch { - pub first_proof: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH], - pub proofs: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>, + pub first_proof: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH], + pub proofs: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>, } impl FromIterator for ProofBatch { @@ -35,10 +34,11 @@ impl FromIterator for ProofBatch { // consume the first P elements let first_proof = iterator .by_ref() - .take(LargeProofGenerator::PROOF_LENGTH) - .collect::<[Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH]>(); + .take(FirstProofGenerator::PROOF_LENGTH) + .collect::<[Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH]>(); // consume the rest - let proofs = iterator.collect::>(); + let proofs = + iterator.collect::>(); ProofBatch { first_proof, proofs, @@ -51,7 +51,8 @@ impl ProofBatch { #[allow(clippy::len_without_is_empty)] #[must_use] pub fn len(&self) -> usize { - self.proofs.len() * SmallProofGenerator::PROOF_LENGTH + LargeProofGenerator::PROOF_LENGTH + FirstProofGenerator::PROOF_LENGTH + + self.proofs.len() * CompressedProofGenerator::PROOF_LENGTH } #[allow(clippy::unnecessary_box_returns)] // clippy bug? 
`Array` exceeds unnecessary-box-size @@ -89,19 +90,19 @@ impl ProofBatch { C: Context, I: Iterator> + Clone, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const LLL: usize = LargeProofGenerator::LAGRANGE_LENGTH; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; - const SLL: usize = SmallProofGenerator::LAGRANGE_LENGTH; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; + const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; + const FLL: usize = FirstProofGenerator::LAGRANGE_LENGTH; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; + const CLL: usize = CompressedProofGenerator::LAGRANGE_LENGTH; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; // precomputation for first proof - let first_denominator = CanonicalLagrangeDenominator::::new(); - let first_lagrange_table = LagrangeTable::::from(first_denominator); + let first_denominator = CanonicalLagrangeDenominator::::new(); + let first_lagrange_table = LagrangeTable::::from(first_denominator); // generate first proof from input iterator let (mut uv_values, first_proof_from_left, my_first_proof_left_share) = - LargeProofGenerator::gen_artefacts_from_recursive_step( + FirstProofGenerator::gen_artefacts_from_recursive_step( ctx, &mut prss_record_ids, &first_lagrange_table, @@ -110,9 +111,9 @@ impl ProofBatch { // `MAX_PROOF_RECURSION - 2` because: // * The first level of recursion has already happened. - // * We need (SRF - 1) at the last level to have room for the masks. + // * We need (CRF - 1) at the last level to have room for the masks. let max_uv_values: usize = - (SRF - 1) * SRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap()); + (CRF - 1) * CRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap()); assert!( uv_values.len() <= max_uv_values, "Proof batch is too large: have {} uv_values, max is {}", @@ -122,9 +123,9 @@ impl ProofBatch { // storage for other proofs let mut my_proofs_left_shares = - Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); + Vec::<[Fp61BitPrime; CPL]>::with_capacity(MAX_PROOF_RECURSION - 1); let mut shares_of_proofs_from_prover_left = - Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); + Vec::<[Fp61BitPrime; CPL]>::with_capacity(MAX_PROOF_RECURSION - 1); // generate masks // Prover `P_i` and verifier `P_{i-1}` both compute p(x) @@ -138,25 +139,25 @@ impl ProofBatch { let (q_mask_from_left_prover, my_q_mask) = ctx.prss().generate_fields(prss_record_ids.expect_next()); - let denominator = CanonicalLagrangeDenominator::::new(); - let lagrange_table = LagrangeTable::::from(denominator); + let denominator = CanonicalLagrangeDenominator::::new(); + let lagrange_table = LagrangeTable::::from(denominator); // The last recursion can only include (λ - 1) u/v value pairs, because it needs to put the - // masks in the constant term. If we compress to `uv_values.len() == SRF`, then we need to - // do two more iterations: compressing SRF u/v values to 1 pair of (unmasked) u/v values, + // masks in the constant term. If we compress to `uv_values.len() == CRF`, then we need to + // do two more iterations: compressing CRF u/v values to 1 pair of (unmasked) u/v values, // and then compressing that pair and the masks to the final u/v value. // // There is a test for this corner case in validation.rs. 
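A length-only trace makes the corner case concrete (illustration only; it assumes `CRF = 4` as in this patch and ignores the field arithmetic performed by the real loop below):

fn main() {
    const CRF: usize = 4; // CompressedProofGenerator::RECURSION_FACTOR
    let mut len = 16usize; // u/v pairs remaining after the first proof
    let mut did_set_masks = false;
    while !did_set_masks {
        if len < CRF {
            did_set_masks = true; // masks go into the constant term
        }
        len = len.div_ceil(CRF); // each proof step compresses by CRF
        println!("compressed to {len} u/v pair(s)");
    }
    // Prints 4, 1, 1: after reaching a single unmasked pair, the loop still
    // needs one more iteration to fold in the masks.
}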
let mut did_set_masks = false; - // recursively generate proofs via SmallProofGenerator + // recursively generate proofs via CompressedProofGenerator while !did_set_masks { - if uv_values.len() < SRF { + if uv_values.len() < CRF { did_set_masks = true; uv_values.set_masks(my_p_mask, my_q_mask).unwrap(); } let (uv_values_new, share_of_proof_from_prover_left, my_proof_left_share) = - SmallProofGenerator::gen_artefacts_from_recursive_step( + CompressedProofGenerator::gen_artefacts_from_recursive_step( ctx, &mut prss_record_ids, &lagrange_table, @@ -235,25 +236,25 @@ impl ProofBatch { inputs: I, ) -> impl Iterator< Item = ( - [Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR], - [Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR], + [Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR], + [Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR], ), > + Clone where I: Iterator> + Clone, { - assert_eq!(BLOCK_SIZE % LargeProofGenerator::RECURSION_FACTOR, 0); + assert_eq!(BLOCK_SIZE % FirstProofGenerator::RECURSION_FACTOR, 0); inputs.flat_map(|(u_block, v_block)| { - (0usize..(BLOCK_SIZE / LargeProofGenerator::RECURSION_FACTOR)).map(move |i| { + (0usize..(BLOCK_SIZE / FirstProofGenerator::RECURSION_FACTOR)).map(move |i| { ( - <[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>::try_from( - &u_block[i * LargeProofGenerator::RECURSION_FACTOR - ..(i + 1) * LargeProofGenerator::RECURSION_FACTOR], + <[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>::try_from( + &u_block[i * FirstProofGenerator::RECURSION_FACTOR + ..(i + 1) * FirstProofGenerator::RECURSION_FACTOR], ) .unwrap(), - <[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>::try_from( - &v_block[i * LargeProofGenerator::RECURSION_FACTOR - ..(i + 1) * LargeProofGenerator::RECURSION_FACTOR], + <[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>::try_from( + &v_block[i * FirstProofGenerator::RECURSION_FACTOR + ..(i + 1) * FirstProofGenerator::RECURSION_FACTOR], ) .unwrap(), ) @@ -262,21 +263,12 @@ impl ProofBatch { } } -const_assert_eq!( - MAX_PROOF_RECURSION, - 9, - "following impl valid only for MAX_PROOF_RECURSION = 9" -); - -#[rustfmt::skip] -type U1464 = UInt, B0>, B1>, B1>, B0>, B1>, B1>, B1>, B0>, B0>, B0>; - -const ARRAY_LEN: usize = 183; +const ARRAY_LEN: usize = FirstProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * CompressedProofGenerator::PROOF_LENGTH; type Array = [Fp61BitPrime; ARRAY_LEN]; impl Serializable for Box { - type Size = U1464; + type Size = as Mul>::Output; type DeserializationError = ::DeserializationError; diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs index b197dcfa3..3f656b3eb 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs @@ -5,7 +5,7 @@ use std::{ use futures_util::future::{try_join, try_join4}; use subtle::ConstantTimeEq; -use typenum::{Unsigned, U288, U80}; +use typenum::{Unsigned, U120, U448}; use crate::{ const_assert_eq, @@ -20,11 +20,11 @@ use crate::{ dzkp_validator::MAX_PROOF_RECURSION, step::DzkpProofVerifyStep as Step, Context, }, ipa_prf::{ - malicious_security::{ - prover::{LargeProofGenerator, SmallProofGenerator}, - verifier::{compute_g_differences, recursively_compute_final_check}, + malicious_security::verifier::{ + compute_g_differences, recursively_compute_final_check, }, validation_protocol::proof_generation::ProofBatch, + CompressedProofGenerator, FirstProofGenerator, }, 
RecordId, }, @@ -45,10 +45,10 @@ use crate::{ #[derive(Debug)] #[allow(clippy::struct_field_names)] pub struct BatchToVerify { - first_proof_from_left_prover: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH], - first_proof_from_right_prover: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH], - proofs_from_left_prover: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>, - proofs_from_right_prover: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>, + first_proof_from_left_prover: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH], + first_proof_from_right_prover: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH], + proofs_from_left_prover: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>, + proofs_from_right_prover: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>, p_mask_from_right_prover: Fp61BitPrime, q_mask_from_left_prover: Fp61BitPrime, } @@ -105,13 +105,13 @@ impl BatchToVerify { where C: Context, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; // exclude for first proof - let exclude_large = u128::try_from(LRF).unwrap(); + let exclude_large = u128::try_from(FRF).unwrap(); // exclude for other proofs - let exclude_small = u128::try_from(SRF).unwrap(); + let exclude_small = u128::try_from(CRF).unwrap(); // generate hashes let my_hashes_prover_left = ProofHashes::generate_hashes(self, Direction::Left); @@ -175,17 +175,17 @@ impl BatchToVerify { U: Iterator + Send, V: Iterator + Send, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; // compute p_r - let p_r_right_prover = recursively_compute_final_check::<_, _, LRF, SRF>( + let p_r_right_prover = recursively_compute_final_check::<_, _, FRF, CRF>( u_from_right_prover.into_iter(), challenges_for_right_prover, self.p_mask_from_right_prover, ); // compute q_r - let q_r_left_prover = recursively_compute_final_check::<_, _, LRF, SRF>( + let q_r_left_prover = recursively_compute_final_check::<_, _, FRF, CRF>( v_from_left_prover.into_iter(), challenges_for_left_prover, self.q_mask_from_left_prover, @@ -242,11 +242,11 @@ impl BatchToVerify { where C: Context, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; - const LPL: usize = LargeProofGenerator::PROOF_LENGTH; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; let p_times_q_right = Self::compute_p_times_q( ctx.narrow(&Step::PTimesQ), @@ -257,7 +257,7 @@ impl BatchToVerify { .await?; // add Zero for p_times_q and sum since they are not secret shared - let diff_left = compute_g_differences::<_, SPL, SRF, LPL, LRF>( + let diff_left = compute_g_differences::<_, CPL, CRF, FPL, FRF>( &self.first_proof_from_left_prover, &self.proofs_from_left_prover, challenges_for_left_prover, @@ -265,7 +265,7 @@ impl BatchToVerify { Fp61BitPrime::ZERO, ); - let diff_right = compute_g_differences::<_, SPL, SRF, LPL, LRF>( + let diff_right = compute_g_differences::<_, CPL, CRF, FPL, 
FRF>( &self.first_proof_from_right_prover, &self.proofs_from_right_prover, challenges_for_right_prover, @@ -375,12 +375,12 @@ impl ProofHashes { const_assert_eq!( MAX_PROOF_RECURSION, - 9, - "following impl valid only for MAX_PROOF_RECURSION = 9" + 14, + "following impl valid only for MAX_PROOF_RECURSION = 14" ); impl Serializable for [Hash; MAX_PROOF_RECURSION] { - type Size = U288; + type Size = U448; type DeserializationError = ::DeserializationError; @@ -409,14 +409,14 @@ impl MpcMessage for [Hash; MAX_PROOF_RECURSION] {} const_assert_eq!( MAX_PROOF_RECURSION, - 9, - "following impl valid only for MAX_PROOF_RECURSION = 9" + 14, + "following impl valid only for MAX_PROOF_RECURSION = 14" ); type ProofDiff = [Fp61BitPrime; MAX_PROOF_RECURSION + 1]; impl Serializable for ProofDiff { - type Size = U80; + type Size = U120; type DeserializationError = ::DeserializationError; @@ -459,10 +459,10 @@ pub mod test { ipa_prf::{ malicious_security::{ lagrange::CanonicalLagrangeDenominator, - prover::{LargeProofGenerator, SmallProofGenerator}, verifier::{compute_sum_share, interpolate_at_r}, }, validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, + CompressedProofGenerator, FirstProofGenerator, }, prss::SharedRandomness, RecordId, RecordIdRange, @@ -484,7 +484,7 @@ pub mod test { // first proof has correct length assert_eq!( left_verifier.first_proof_from_left_prover.len(), - LargeProofGenerator::PROOF_LENGTH + FirstProofGenerator::PROOF_LENGTH ); assert_eq!( left_verifier.first_proof_from_left_prover.len(), @@ -494,7 +494,7 @@ pub mod test { for i in 0..left_verifier.proofs_from_left_prover.len() { assert_eq!( (i, left_verifier.proofs_from_left_prover[i].len()), - (i, SmallProofGenerator::PROOF_LENGTH) + (i, CompressedProofGenerator::PROOF_LENGTH) ); assert_eq!( (i, left_verifier.proofs_from_left_prover[i].len()), @@ -513,29 +513,29 @@ pub mod test { // check first proof, // compute simple proof without lagrange interpolated points let simple_proof = { - let block_to_polynomial = BLOCK_SIZE / LargeProofGenerator::RECURSION_FACTOR; + let block_to_polynomial = BLOCK_SIZE / FirstProofGenerator::RECURSION_FACTOR; let simple_proof_uv = (0usize..100 * block_to_polynomial) .map(|i| { ( - (LargeProofGenerator::RECURSION_FACTOR * i - ..LargeProofGenerator::RECURSION_FACTOR * (i + 1)) + (FirstProofGenerator::RECURSION_FACTOR * i + ..FirstProofGenerator::RECURSION_FACTOR * (i + 1)) .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h) - .collect::<[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>(), - (LargeProofGenerator::RECURSION_FACTOR * i - ..LargeProofGenerator::RECURSION_FACTOR * (i + 1)) + .collect::<[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>(), + (FirstProofGenerator::RECURSION_FACTOR * i + ..FirstProofGenerator::RECURSION_FACTOR * (i + 1)) .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h) - .collect::<[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>(), + .collect::<[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>(), ) }) .collect::>(); simple_proof_uv.iter().fold( - [Fp61BitPrime::ZERO; LargeProofGenerator::RECURSION_FACTOR], + [Fp61BitPrime::ZERO; FirstProofGenerator::RECURSION_FACTOR], |mut acc, (left, right)| { - for i in 0..LargeProofGenerator::RECURSION_FACTOR { + for i in 0..FirstProofGenerator::RECURSION_FACTOR { acc[i] += left[i] * right[i]; } acc @@ -558,7 +558,7 @@ pub mod test { (h.as_u128(), simple_proof.to_vec()), ( h.as_u128(), - proof_computed[0..LargeProofGenerator::RECURSION_FACTOR].to_vec() + 
proof_computed[0..FirstProofGenerator::RECURSION_FACTOR].to_vec() ) ); } @@ -777,9 +777,9 @@ pub mod test { } fn assert_batch(left: &BatchToVerify, right: &BatchToVerify, challenges: &[Fp61BitPrime]) { - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; - const LPL: usize = LargeProofGenerator::PROOF_LENGTH; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; let first = recombine( &left.first_proof_from_left_prover, @@ -791,19 +791,19 @@ pub mod test { .zip(right.proofs_from_right_prover.iter()) .map(|(left, right)| recombine(left, right)) .collect::>(); - let denominator_first = CanonicalLagrangeDenominator::<_, LPL>::new(); - let denominator = CanonicalLagrangeDenominator::<_, SPL>::new(); + let denominator_first = CanonicalLagrangeDenominator::<_, FPL>::new(); + let denominator = CanonicalLagrangeDenominator::<_, CPL>::new(); let length = others.len(); let mut out = interpolate_at_r(&first, &challenges[0], &denominator_first); for (i, proof) in others.iter().take(length - 1).enumerate() { - assert_eq!((i, out), (i, compute_sum_share::<_, SRF, SPL>(proof))); + assert_eq!((i, out), (i, compute_sum_share::<_, CRF, CPL>(proof))); out = interpolate_at_r(proof, &challenges[i + 1], &denominator); } // last sum without masks let masks = others[length - 1][0]; - let last_sum = compute_sum_share::<_, SRF, SPL>(&others[length - 1]); + let last_sum = compute_sum_share::<_, CRF, CPL>(&others[length - 1]); assert_eq!(out, last_sum - masks); } @@ -869,7 +869,7 @@ pub mod test { let denominator = CanonicalLagrangeDenominator::< Fp61BitPrime, - { SmallProofGenerator::PROOF_LENGTH }, + { CompressedProofGenerator::PROOF_LENGTH }, >::new(); let g_r_left = interpolate_at_r( @@ -988,9 +988,15 @@ pub mod test { // Test a batch that exercises the case where `uv_values.len() == 1` but `did_set_masks = // false` in `ProofBatch::generate`. - verify_batch( - LargeProofGenerator::RECURSION_FACTOR * SmallProofGenerator::RECURSION_FACTOR - / BLOCK_SIZE, - ); + // + // We divide by `BLOCK_SIZE` here because `generate_u_v`, which is used by + // `verify_batch` to generate test data, generates `len` chunks of u/v values of + // length `BLOCK_SIZE`. We want the input u/v values to compress to exactly one + // u/v pair after some number of proof steps. + let num_inputs = FirstProofGenerator::RECURSION_FACTOR + * CompressedProofGenerator::RECURSION_FACTOR + * CompressedProofGenerator::RECURSION_FACTOR; + assert!(num_inputs % BLOCK_SIZE == 0); + verify_batch(num_inputs / BLOCK_SIZE); } } diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index e8dfd95ae..6829f57fa 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -4,4 +4,4 @@ pub mod arraychunks; mod power_of_two; #[cfg(target_pointer_width = "64")] -pub use power_of_two::NonZeroU32PowerOfTwo; +pub use power_of_two::{non_zero_prev_power_of_two, NonZeroU32PowerOfTwo}; diff --git a/ipa-core/src/utils/power_of_two.rs b/ipa-core/src/utils/power_of_two.rs index a84455c92..b34ac0423 100644 --- a/ipa-core/src/utils/power_of_two.rs +++ b/ipa-core/src/utils/power_of_two.rs @@ -68,9 +68,19 @@ impl NonZeroU32PowerOfTwo { } } +/// Returns the largest power of two less than or equal to `target`. +/// +/// Returns 1 if `target` is zero. 
+pub fn non_zero_prev_power_of_two(target: usize) -> usize { + let bits = usize::BITS - target.leading_zeros(); + + 1 << (std::cmp::max(1, bits) - 1) +} + #[cfg(all(test, unit_test))] mod tests { use super::{ConvertError, NonZeroU32PowerOfTwo}; + use crate::utils::power_of_two::non_zero_prev_power_of_two; #[test] fn rejects_invalid_values() { @@ -107,4 +117,19 @@ mod tests { "3".parse::().unwrap_err() ); } + + #[test] + fn test_prev_power_of_two() { + const TWO_EXP_62: usize = 1usize << (usize::BITS - 2); + const TWO_EXP_63: usize = 1usize << (usize::BITS - 1); + assert_eq!(non_zero_prev_power_of_two(0), 1usize); + assert_eq!(non_zero_prev_power_of_two(1), 1usize); + assert_eq!(non_zero_prev_power_of_two(2), 2usize); + assert_eq!(non_zero_prev_power_of_two(3), 2usize); + assert_eq!(non_zero_prev_power_of_two(4), 4usize); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63 - 1), TWO_EXP_62); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63), TWO_EXP_63); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63 + 1), TWO_EXP_63); + assert_eq!(non_zero_prev_power_of_two(usize::MAX), TWO_EXP_63); + } } From 8452f63ab5090c6673b8f49572979de34352cdd6 Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Tue, 19 Nov 2024 16:20:49 -0800 Subject: [PATCH 22/47] Breakdown key reveal (#1445) * add breakdown_key reveal protocol for hybrid * update test to use shards * add semi-honest test for breakdown_key reveal that works with shuttle --- .../src/protocol/hybrid/breakdown_reveal.rs | 264 +++++++----------- ipa-core/src/protocol/hybrid/mod.rs | 25 +- .../src/protocol/ipa_prf/oprf_padding/mod.rs | 74 +++++ ipa-core/src/query/runner/hybrid.rs | 4 +- ipa-core/src/report/hybrid.rs | 66 ++++- ipa-core/src/test_fixture/hybrid.rs | 25 +- ipa-core/src/test_fixture/sharing.rs | 9 + 7 files changed, 301 insertions(+), 166 deletions(-) diff --git a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs index 00d4c36af..cb9e4caf6 100644 --- a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs +++ b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs @@ -4,101 +4,34 @@ use futures::stream; use futures_util::{StreamExt, TryStreamExt}; use tracing::{info_span, Instrument}; -use super::aggregate_values; use crate::{ error::{Error, UnwrapInfallible}, - ff::{ - boolean::Boolean, - boolean_array::{BooleanArray, BooleanArrayReader, BooleanArrayWriter, BA32}, - U128Conversions, - }, + ff::{boolean::Boolean, boolean_array::BooleanArray, U128Conversions}, helpers::TotalRecords, protocol::{ basics::{reveal, Reveal}, context::{ dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, - UpgradableContext, + ShardedContext, UpgradableContext, }, ipa_prf::{ aggregation::{ - aggregate_values_proof_chunk, step::AggregationStep as Step, AGGREGATE_DEPTH, + aggregate_values, aggregate_values_proof_chunk, step::AggregationStep as Step, + AGGREGATE_DEPTH, }, oprf_padding::{apply_dp_padding, PaddingParameters}, - prf_sharding::{AttributionOutputs, SecretSharedAttributionOutputs}, - shuffle::{Shuffle, Shuffleable}, - BreakdownKey, + shuffle::Shuffle, }, BooleanProtocols, RecordId, }, + report::hybrid::AggregateableHybridReport, secret_sharing::{ - replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, - BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, + replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, + TransposeFrom, Vectorizable, }, seq_join::seq_join, }; -impl AttributionOutputs, Replicated> -where - BK: 
BooleanArray, - TV: BooleanArray, -{ - fn join_fields(breakdown_key: BK, trigger_value: TV) -> ::Share { - let mut share = ::Share::ZERO; - - BooleanArrayWriter::new(&mut share) - .write(&breakdown_key) - .write(&trigger_value); - - share - } - - fn split_fields(share: &::Share) -> (BK, TV) { - let bits = BooleanArrayReader::new(share); - let (breakdown_key, bits) = bits.read(); - let (trigger_value, _bits) = bits.read(); - (breakdown_key, trigger_value) - } -} - -impl Shuffleable for AttributionOutputs, Replicated> -where - BK: BooleanArray, - TV: BooleanArray, -{ - /// TODO: Use a smaller BA type to contain BK and TV - type Share = BA32; - - fn left(&self) -> Self::Share { - Self::join_fields( - ReplicatedSecretSharing::left(&self.attributed_breakdown_key_bits), - ReplicatedSecretSharing::left(&self.capped_attributed_trigger_value), - ) - } - - fn right(&self) -> Self::Share { - Self::join_fields( - ReplicatedSecretSharing::right(&self.attributed_breakdown_key_bits), - ReplicatedSecretSharing::right(&self.capped_attributed_trigger_value), - ) - } - - fn new(l: Self::Share, r: Self::Share) -> Self { - debug_assert!( - BK::BITS + TV::BITS <= Self::Share::BITS, - "share type {} is too small", - std::any::type_name::(), - ); - - let left = Self::split_fields(&l); - let right = Self::split_fields(&r); - - Self { - attributed_breakdown_key_bits: ReplicatedSecretSharing::new(left.0, right.0), - capped_attributed_trigger_value: ReplicatedSecretSharing::new(left.1, right.1), - } - } -} - /// Improved Aggregation a.k.a Aggregation revealing breakdown. /// /// Aggregation steps happen after attribution. the input for Aggregation is a @@ -127,30 +60,29 @@ where /// record. This is currently ensured by the serial operation of the aggregation /// protocol (i.e. by not using `seq_join`). #[tracing::instrument(name = "breakdown_reveal_aggregation", skip_all, fields(total = attributed_values.len()))] -pub async fn breakdown_reveal_aggregation( +pub async fn breakdown_reveal_aggregation( ctx: C, - attributed_values: Vec>, + attributed_values: Vec>, padding_params: &PaddingParameters, ) -> Result>, Error> where - C: UpgradableContext + Shuffle, + C: UpgradableContext + Shuffle + ShardedContext, Boolean: FieldSimd, Replicated: BooleanProtocols, B>, - BK: BreakdownKey, + BK: BooleanArray + U128Conversions, Replicated: Reveal, Output = >::Array>, - TV: BooleanArray + U128Conversions, + V: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, BitDecomposed>: - for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, { // Apply DP padding for Breakdown Reveal Aggregation - let attributed_values_padded = - apply_dp_padding::<_, AttributionOutputs, Replicated>, B>( - ctx.narrow(&Step::PaddingDp), - attributed_values, - padding_params, - ) - .await?; + let attributed_values_padded = apply_dp_padding::<_, AggregateableHybridReport, B>( + ctx.narrow(&Step::PaddingDp), + attributed_values, + padding_params, + ) + .await?; let attributions = ctx .narrow(&Step::Shuffle) @@ -176,7 +108,7 @@ where // may exceed that. 
let mut chunk_counter = 0; let mut depth = 0; - let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()); + let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(V::BITS).unwrap()); while intermediate_results.len() > 1 { let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; @@ -220,34 +152,33 @@ where #[tracing::instrument(name = "reveal_breakdowns", skip_all, fields( total = attributions.len(), ))] -async fn reveal_breakdowns( +async fn reveal_breakdowns( parent_ctx: &C, - attributions: Vec>, -) -> Result, Error> + attributions: Vec>, +) -> Result, Error> where C: Context, Replicated: BooleanProtocols, Boolean: FieldSimd, - BK: BreakdownKey, + BK: BooleanArray + U128Conversions, Replicated: Reveal>::Array>, - TV: BooleanArray + U128Conversions, + V: BooleanArray + U128Conversions, { let reveal_ctx = parent_ctx.set_total_records(TotalRecords::specified(attributions.len())?); - let reveal_work = stream::iter(attributions).enumerate().map(|(i, ao)| { + let reveal_work = stream::iter(attributions).enumerate().map(|(i, report)| { let record_id = RecordId::from(i); let reveal_ctx = reveal_ctx.clone(); async move { - let revealed_bk = - reveal(reveal_ctx, record_id, &ao.attributed_breakdown_key_bits).await?; + let revealed_bk = reveal(reveal_ctx, record_id, &report.breakdown_key).await?; let revealed_bk = BK::from_array(&revealed_bk); let Ok(bk) = usize::try_from(revealed_bk.as_u128()) else { return Err(Error::Internal); }; - Ok::<_, Error>((bk, ao.capped_attributed_trigger_value)) + Ok::<_, Error>((bk, report.value)) } }); - let mut grouped_tvs = GroupedTriggerValues::::new(); + let mut grouped_tvs = ValueHistogram::::new(); let mut stream = pin!(seq_join(reveal_ctx.active_work(), reveal_work)); while let Some((bk, tv)) = stream.try_next().await? { grouped_tvs.push(bk, tv); } @@ -259,12 +190,12 @@ where /// Helper type that holds all the Trigger Values, grouped by their Breakdown /// Key. The main functionality is to turn into a stream that can be given to /// [`aggregate_values`].
-struct GroupedTriggerValues { - tvs: [Vec>; B], +struct ValueHistogram { + tvs: [Vec>; B], max_len: usize, } -impl GroupedTriggerValues { +impl ValueHistogram { fn new() -> Self { Self { tvs: std::array::from_fn(|_| vec![]), @@ -272,7 +203,7 @@ impl GroupedTriggerValues { } } - fn push(&mut self, bk: usize, value: Replicated) { + fn push(&mut self, bk: usize, value: Replicated) { self.tvs[bk].push(value); if self.tvs[bk].len() > self.max_len { self.max_len = self.tvs[bk].len(); @@ -280,18 +211,16 @@ impl GroupedTriggerValues { } } -impl From> +impl From> for Vec>> where Boolean: FieldSimd, BitDecomposed>: - for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, { - fn from( - mut grouped_tvs: GroupedTriggerValues, - ) -> Vec>> { + fn from(mut grouped_tvs: ValueHistogram) -> Vec>> { let iter = (0..grouped_tvs.max_len).map(move |_| { - let slice: [Replicated; B] = grouped_tvs + let slice: [Replicated; B] = grouped_tvs .tvs .each_mut() .map(|tv| tv.pop().unwrap_or(Replicated::ZERO)); @@ -315,35 +244,38 @@ pub mod tests { boolean_array::{BA3, BA5, BA8}, U128Conversions, }, - protocol::ipa_prf::{ - aggregation::breakdown_reveal::breakdown_reveal_aggregation, - oprf_padding::PaddingParameters, - prf_sharding::{AttributionOutputsTestInput, SecretSharedAttributionOutputs}, + protocol::{ + hybrid::breakdown_reveal_aggregation, ipa_prf::oprf_padding::PaddingParameters, }, rand::Rng, secret_sharing::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, }, test_executor::run_with, - test_fixture::{Reconstruct, Runner, TestWorld}, + test_fixture::{ + hybrid::TestAggregateableHybridReport, Reconstruct, Runner, TestWorld, TestWorldConfig, + WithShards, + }, }; - fn input_row(bk: usize, tv: u128) -> AttributionOutputsTestInput { - let bk: u128 = bk.try_into().unwrap(); - AttributionOutputsTestInput { - bk: BA5::truncate_from(bk), - tv: BA3::truncate_from(tv), + fn input_row(breakdown_key: usize, value: u128) -> TestAggregateableHybridReport { + TestAggregateableHybridReport { + match_key: (), + value: value.try_into().unwrap(), + breakdown_key: breakdown_key.try_into().unwrap(), } } #[test] - fn semi_honest_happy_path() { + fn breakdown_reveal_semi_honest_happy_path() { // if shuttle executor is enabled, run this test only once. // it is a very expensive test to explore all possible states, // sometimes github bails after 40 minutes of running it // (workers there are really slow). 
+ type HV = BA8; + const SHARDS: usize = 2; run_with::<_, _, 3>(|| async { - let world = TestWorld::default(); + let world = TestWorld::>::with_shards(TestWorldConfig::default()); let mut rng = world.rng(); let mut expectation = Vec::new(); for _ in 0..32 { expectation.push(rng.gen_range(0u128..256)); } let expectation = expectation; // no more mutability for safety let mut inputs = Vec::new(); for (bk, expected_hv) in expectation.iter().enumerate() { let mut remainder = *expected_hv; while remainder > 7 { let tv = rng.gen_range(0u128..8); remainder -= tv; inputs.push(input_row(bk, tv)); } inputs.push(input_row(bk, remainder)); } inputs.shuffle(&mut rng); let result: Vec<_> = world - .semi_honest(inputs.into_iter(), |ctx, input_rows| async move { - let aos = input_rows - .into_iter() - .map(|ti| SecretSharedAttributionOutputs { - attributed_breakdown_key_bits: ti.0, - capped_attributed_trigger_value: ti.1, - }) - .collect(); - let r: Vec> = - breakdown_reveal_aggregation::<_, BA5, BA3, BA8, 32>( - ctx, - aos, - &PaddingParameters::relaxed(), - ) - .map_ok(|d: BitDecomposed>| { - Vec::transposed_from(&d).unwrap() - }) - .await - .unwrap(); - r + .semi_honest(inputs.into_iter(), |ctx, reports| async move { + breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( + ctx, + reports, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap() }) .await .reconstruct(); - let result = result.iter().map(|&v| v.as_u128()).collect::>(); + let initial = vec![0_u128; 32]; + let result = result + .iter() + .fold(initial, |mut acc, vec: &Vec| { + acc.iter_mut() + .zip(vec) + .for_each(|(a, &b)| *a += b.as_u128()); + acc + }) + .into_iter() + .collect::>(); + assert_eq!(32, result.len()); assert_eq!(result, expectation); }); } #[test] #[cfg(not(feature = "shuttle"))] // too slow - fn malicious_happy_path() { + fn breakdown_reveal_malicious_happy_path() { type HV = BA16; + const SHARDS: usize = 2; run(|| async { - let world = TestWorld::default(); + let world = TestWorld::>::with_shards(TestWorldConfig::default()); let mut rng = world.rng(); let mut expectation = Vec::new(); for _ in 0..32 { expectation.push(rng.gen_range(0u128..512)); } // The size of input needed here to get complete coverage (more precisely, // the size of input to the final aggregation using `aggregate_values`) // depends on `TARGET_PROOF_SIZE`. let expectation = expectation; // no more mutability for safety let mut inputs = Vec::new(); - for (bk, expected_hv) in expectation.iter().enumerate() { - let mut remainder = *expected_hv; + // Builds out inputs with values for each breakdown_key that add up to + // the expectation. Expectation is random in (0..512). Each iteration + // generates a value (0..8) and subtracts from the expectation until a final + // remainder in (0..8) remains to be added to the vec.
+ for (breakdown_key, expected_value) in expectation.iter().enumerate() { + let mut remainder = *expected_value; while remainder > 7 { - let tv = rng.gen_range(0u128..8); - remainder -= tv; - inputs.push(input_row(bk, tv)); + let value = rng.gen_range(0u128..8); + remainder -= value; + inputs.push(input_row(breakdown_key, value)); } - inputs.push(input_row(bk, remainder)); + inputs.push(input_row(breakdown_key, remainder)); } inputs.shuffle(&mut rng); + let result: Vec<_> = world - .malicious(inputs.into_iter(), |ctx, input_rows| async move { - let aos = input_rows - .into_iter() - .map(|ti| SecretSharedAttributionOutputs { - attributed_breakdown_key_bits: ti.0, - capped_attributed_trigger_value: ti.1, - }) - .collect(); + .malicious(inputs.into_iter(), |ctx, reports| async move { breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( ctx, - aos, + reports, &PaddingParameters::relaxed(), ) .map_ok(|d: BitDecomposed>| { Vec::transposed_from(&d).unwrap() }) .await .unwrap() }) .await .reconstruct(); - let result = result.iter().map(|v: &HV| v.as_u128()).collect::>(); + + let initial = vec![0_u128; 32]; + let result = result + .iter() + .fold(initial, |mut acc, vec: &Vec| { + acc.iter_mut() + .zip(vec) + .for_each(|(a, &b)| *a += b.as_u128()); + acc + }) + .into_iter() + .collect::>(); assert_eq!(32, result.len()); assert_eq!(result, expectation); }); } } diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs index cc63df8a5..d9f43f7dd 100644 --- a/ipa-core/src/protocol/hybrid/mod.rs +++ b/ipa-core/src/protocol/hybrid/mod.rs @@ -1,7 +1,10 @@ pub(crate) mod agg; +pub(crate) mod breakdown_reveal; pub(crate) mod oprf; pub(crate) mod step; +use std::convert::Infallible; + use tracing::{info_span, Instrument}; use crate::{ @@ -12,10 +15,11 @@ use crate::{ }, helpers::query::DpMechanism, protocol::{ - basics::Reveal, + basics::{BooleanArrayMul, Reveal}, context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, hybrid::{ agg::aggregate_reports, + breakdown_reveal::breakdown_reveal_aggregation, oprf::{compute_prf_and_reshard, BreakdownKey, CONV_CHUNK, PRF_CHUNK}, step::HybridStep as Step, }, @@ -28,7 +32,10 @@ use crate::{ BooleanProtocols, }, report::hybrid::{IndistinguishableHybridReport, PrfHybridReport}, - secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Vectorizable}, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, + TransposeFrom, Vectorizable, + }, }; /// The Hybrid Protocol @@ -67,6 +74,7 @@ where BK: BreakdownKey, V: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, + Boolean: FieldSimd, Replicated: BooleanProtocols, CONV_CHUNK>, Replicated: PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, @@ -74,6 +82,11 @@ where Reveal, Output = >::Array>, PrfHybridReport: Serializable, Replicated: BooleanProtocols>, + Replicated: BooleanProtocols, B>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, { if input_rows.is_empty() { return Ok(vec![Replicated::ZERO; B]); @@ -95,7 +108,13 @@ where let sharded_reports = compute_prf_and_reshard(ctx.clone(), shuffled_input_rows).await?; - let _aggregated_reports = aggregate_reports::(ctx.clone(), sharded_reports); + let aggregated_reports = aggregate_reports::(ctx.clone(), sharded_reports).await?; + + let _histogram = breakdown_reveal_aggregation::( + ctx.clone(), + aggregated_reports, + &dp_padding_params, + );
unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented") } diff --git a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs index 6b0ee88f6..099fed80d 100644 --- a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs @@ -200,6 +200,80 @@ where } } +impl Paddable for IndistinguishableHybridReport +where + BK: BooleanArray + U128Conversions, + V: BooleanArray, +{ + /// Given an extendable collection of `AggregateableReports`s, + /// this function will pad the collection with dummy reports. The reports + /// cover every `breakdown_key` and have a secret sharing of zero for the `value`. + /// Each `breakdown_key` receives a random number of a rows. + fn add_padding_items, const B: usize>( + direction_to_excluded_helper: Direction, + padding_input_rows: &mut VC, + padding_params: &PaddingParameters, + rng: &mut InstrumentedSequentialSharedRandomness, + ) -> Result { + let mut total_number_of_fake_rows = 0; + match padding_params.aggregation_padding { + AggregationPadding::NoAggPadding => {} + AggregationPadding::Parameters { + aggregation_epsilon, + aggregation_delta, + aggregation_padding_sensitivity, + } => { + let aggregation_padding = OPRFPaddingDp::new( + aggregation_epsilon, + aggregation_delta, + aggregation_padding_sensitivity, + )?; + let num_breakdowns: u32 = u32::try_from(B).unwrap(); + // for every breakdown, sample how many dummies will be added + for breakdownkey in 0..num_breakdowns { + let sample = aggregation_padding.sample(rng); + total_number_of_fake_rows += sample; + + // now add `sample` many fake rows with this `breakdownkey` + for _ in 0..sample { + let breakdownkey_shares = match direction_to_excluded_helper { + Direction::Left => AdditiveShare::new( + BK::ZERO, + BK::truncate_from(u128::from(breakdownkey)), + ), + Direction::Right => AdditiveShare::new( + BK::truncate_from(u128::from(breakdownkey)), + BK::ZERO, + ), + }; + + let row = IndistinguishableHybridReport:: { + match_key: (), + value: AdditiveShare::new(V::ZERO, V::ZERO), + breakdown_key: breakdownkey_shares, + }; + + padding_input_rows.extend(std::iter::once(row)); + } + } + } + } + Ok(total_number_of_fake_rows) + } + + /// Given an extendable collection of `IndistinguishableHybridReport`s, + /// this function ads `total_number_of_fake_rows` of Reports with zeros in all fields. 
+ fn add_zero_shares>( + padding_input_rows: &mut VC, + total_number_of_fake_rows: u32, + ) { + padding_input_rows.extend( + repeat(IndistinguishableHybridReport::::ZERO) + .take(total_number_of_fake_rows as usize), + ); + } +} + impl Paddable for OPRFIPAInputRow where BK: BooleanArray + U128Conversions, diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 0ce953f03..8a9f375be 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -17,7 +17,7 @@ use crate::{ }, hpke::PrivateKeyRegistry, protocol::{ - basics::{BooleanProtocols, Reveal}, + basics::{BooleanArrayMul, BooleanProtocols, Reveal}, context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, hybrid::{ hybrid_protocol, @@ -73,6 +73,8 @@ where Replicated: Reveal, Output = >::Array>, Replicated: BooleanProtocols>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, { #[tracing::instrument("hybrid_query", skip_all, fields(sz=%query_size))] pub async fn execute( diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index be2bd1d3f..511a0418b 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -40,7 +40,7 @@ use crate::{ error::{BoxError, Error}, ff::{ boolean_array::{ - BooleanArray, BooleanArrayReader, BooleanArrayWriter, BA112, BA3, BA64, BA8, + BooleanArray, BooleanArrayReader, BooleanArrayWriter, BA112, BA3, BA32, BA64, BA8, }, Serializable, }, @@ -685,6 +685,35 @@ pub type PrfHybridReport = IndistinguishableHybridReport; /// that OPRF value is no longer required. pub type AggregateableHybridReport = IndistinguishableHybridReport; +impl IndistinguishableHybridReport +where + BK: BooleanArray, + V: BooleanArray, +{ + pub const ZERO: Self = Self { + match_key: (), + value: Replicated::::ZERO, + breakdown_key: Replicated::::ZERO, + }; + + fn join_fields(value: V, breakdown_key: BK) -> ::Share { + let mut share = ::Share::ZERO; + + BooleanArrayWriter::new(&mut share) + .write(&value) + .write(&breakdown_key); + + share + } + + fn split_fields(share: &::Share) -> (V, BK) { + let bits = BooleanArrayReader::new(share); + let (value, bits) = bits.read(); + let (breakdown_key, _) = bits.read(); + (value, breakdown_key) + } +} + /// When aggregating reports, we need to lift the value from `V` to `HV`. 
impl From> for AggregateableHybridReport where @@ -852,6 +881,41 @@ where } } +impl Shuffleable for IndistinguishableHybridReport +where + BK: BooleanArray, + V: BooleanArray, +{ + // this requires BK:BAXX + V:BAYY such that XX + YY <= 32 + // this is checked in a debug_assert call in ::new below + type Share = BA32; + + fn left(&self) -> Self::Share { + Self::join_fields(self.value.left(), self.breakdown_key.left()) + } + + fn right(&self) -> Self::Share { + Self::join_fields(self.value.right(), self.breakdown_key.right()) + } + + fn new(l: Self::Share, r: Self::Share) -> Self { + debug_assert!( + BK::BITS + V::BITS <= Self::Share::BITS, + "share type {} is too small", + std::any::type_name::(), + ); + + let left = Self::split_fields(&l); + let right = Self::split_fields(&r); + + Self { + match_key: (), + value: ReplicatedSecretSharing::new(left.0, right.0), + breakdown_key: ReplicatedSecretSharing::new(left.1, right.1), + } + } +} + impl PrfHybridReport { const PRF_MK_SZ: usize = 8; const V_SZ: usize = as Serializable>::Size::USIZE; diff --git a/ipa-core/src/test_fixture/hybrid.rs b/ipa-core/src/test_fixture/hybrid.rs index 7a4d86007..d522089c3 100644 --- a/ipa-core/src/test_fixture/hybrid.rs +++ b/ipa-core/src/test_fixture/hybrid.rs @@ -10,7 +10,8 @@ use crate::{ }, rand::Rng, report::hybrid::{ - HybridConversionReport, HybridImpressionReport, HybridReport, IndistinguishableHybridReport, + AggregateableHybridReport, HybridConversionReport, HybridImpressionReport, HybridReport, + IndistinguishableHybridReport, }, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, IntoShares}, test_fixture::sharing::Reconstruct, @@ -88,6 +89,28 @@ where } } +impl IntoShares> for TestAggregateableHybridReport +where + BK: BooleanArray + U128Conversions + IntoShares>, + V: BooleanArray + U128Conversions + IntoShares>, +{ + fn share_with(self, rng: &mut R) -> [AggregateableHybridReport; 3] { + let ba_breakdown_key = BK::try_from(u128::from(self.breakdown_key)) + .unwrap() + .share_with(rng); + let ba_value = V::try_from(u128::from(self.value)).unwrap().share_with(rng); + zip(ba_breakdown_key, ba_value) + .map(|(breakdown_key, value)| AggregateableHybridReport { + match_key: (), + breakdown_key, + value, + }) + .collect::>() + .try_into() + .unwrap() + } +} + impl IntoShares> for TestHybridRecord where BK: BooleanArray + U128Conversions + IntoShares>, diff --git a/ipa-core/src/test_fixture/sharing.rs b/ipa-core/src/test_fixture/sharing.rs index e30e8fc2e..424e73e17 100644 --- a/ipa-core/src/test_fixture/sharing.rs +++ b/ipa-core/src/test_fixture/sharing.rs @@ -155,6 +155,15 @@ where } } +impl Reconstruct>> for Vec<[Vec; 3]> +where + for<'i> [&'i [I]; 3]: Reconstruct>, +{ + fn reconstruct(&self) -> Vec> { + self.iter().map(Reconstruct::reconstruct).collect() + } +} + impl Reconstruct> for [BitDecomposed; 3] where for<'i> [&'i [I]; 3]: Reconstruct>, From 7cabd6ff8db76817dd6c72e15bf4f62631781e6e Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 19 Nov 2024 15:30:11 -0800 Subject: [PATCH 23/47] Complete query API for sharded environments As we discussed, we want to make the complete query API simple to implement. That means we assume that all shards communicate their results to the leader, which leaves the complete API to just clean up the state on each shard and return the response from the leader. This PR does exactly that.
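A minimal model of the flow described above (illustration only, with stub types; the real `Processor::complete` in the diff below works with transports, query handles, and a broadcast over shard channels): every shard releases its local query state, but only the leader broadcasts `CompleteQuery` and returns the merged result it holds after the shard finalizer protocol.

struct Shard {
    index: u32,
    state: Option<Vec<u8>>, // merged query result; populated only on the leader
}

impl Shard {
    fn complete(&mut self, others: &mut [Shard]) -> Result<Vec<u8>, &'static str> {
        let result = self.state.take().ok_or("no such query")?;
        if self.index == 0 {
            // Leader: instruct every other shard to clean up its state too.
            for shard in others.iter_mut() {
                shard.state.take();
            }
            Ok(result)
        } else {
            // Followers never broadcast; they only release local resources.
            Ok(Vec::new())
        }
    }
}

fn main() {
    let mut followers = vec![Shard { index: 1, state: Some(Vec::new()) }];
    let mut leader = Shard { index: 0, state: Some(b"result".to_vec()) };
    assert_eq!(leader.complete(&mut followers).unwrap(), b"result");
    assert!(followers[0].state.is_none()); // follower state cleaned up
}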
---
 ipa-core/src/app.rs                       |   7 +-
 ipa-core/src/helpers/transport/routing.rs |   2 +-
 ipa-core/src/query/processor.rs           | 134 +++++++++++++++++++++-
 ipa-core/src/query/state.rs               |   5 +-
 4 files changed, 138 insertions(+), 10 deletions(-)

diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs
index fb6f9fdb7..792215b1e 100644
--- a/ipa-core/src/app.rs
+++ b/ipa-core/src/app.rs
@@ -161,7 +161,7 @@ impl HelperApp {
         Ok(self
             .inner
             .query_processor
-            .complete(query_id)
+            .complete(query_id, self.inner.shard_transport.clone_ref())
             .await?
             .to_bytes())
     }
@@ -251,7 +251,10 @@ impl RequestHandler<HelperIdentity> for Inner {
             }
             RouteId::CompleteQuery => {
                 let query_id = ext_query_id(&req)?;
-                HelperResponse::from(qp.complete(query_id).await?)
+                HelperResponse::from(
+                    qp.complete(query_id, self.shard_transport.clone_ref())
+                        .await?,
+                )
             }
             RouteId::KillQuery => {
                 let query_id = ext_query_id(&req)?;
diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs
index 3d9c2bb5f..6cb1006df 100644
--- a/ipa-core/src/helpers/transport/routing.rs
+++ b/ipa-core/src/helpers/transport/routing.rs
@@ -8,7 +8,7 @@ use crate::{
 };
 
 // The type of request made to an MPC helper.
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub enum RouteId {
     Records,
     ReceiveQuery,
diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs
index 120e1c5ca..fb37fa555 100644
--- a/ipa-core/src/query/processor.rs
+++ b/ipa-core/src/query/processor.rs
@@ -11,6 +11,7 @@ use crate::{
     executor::IpaRuntime,
     helpers::{
         query::{PrepareQuery, QueryConfig, QueryInput},
+        routing::RouteId,
         BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role,
         RoleAssignment, ShardTransportError, ShardTransportImpl, Transport,
     },
@@ -123,6 +124,8 @@ pub enum QueryCompletionError {
     },
     #[error("query execution failed: {0}")]
     ExecutionError(#[from] ProtocolError),
+    #[error("one or more shards rejected this request: {0}")]
+    ShardError(#[from] BroadcastError<ShardIndex, ShardTransportError>),
 }
 
 impl Debug for Processor {
@@ -373,6 +376,7 @@ impl Processor {
     pub async fn complete(
         &self,
         query_id: QueryId,
+        shard_transport: ShardTransportImpl,
    ) -> Result<Box<dyn ProtocolResult>, QueryCompletionError> {
         let handle = {
             let mut queries = self.queries.inner.lock().unwrap();
@@ -397,6 +401,18 @@ impl Processor {
             }
         }; // release mutex before await
 
+        // Inform the other shards about our intent to complete the query.
+        // If any of them rejects it, report the error back. We expect all
+        // shards to be in the same state; in the normal cycle, this API is
+        // called only after the query status reports completion.
+        if shard_transport.identity() == ShardIndex::FIRST {
+            // See the shard finalizer protocol for how shards merge their
+            // results together. At the end, only the leader holds the value.
+            shard_transport
+                .broadcast((RouteId::CompleteQuery, query_id))
+                .await?;
+        }
+
        Ok(handle.await?)
     }
 
@@ -440,7 +456,8 @@ mod tests {
     use tokio::sync::Barrier;
 
     use crate::{
-        ff::FieldType,
+        executor::IpaRuntime,
+        ff::{boolean_array::BA64, FieldType},
         helpers::{
             make_owned_handler,
             query::{PrepareQuery, QueryConfig, QueryType::TestMultiply},
@@ -450,8 +467,9 @@
         },
         protocol::QueryId,
         query::{
-            processor::Processor, state::StateError, NewQueryError, PrepareQueryError, QueryStatus,
-            QueryStatusError,
+            processor::Processor,
+            state::{QueryState, RunningQuery, StateError},
+            NewQueryError, PrepareQueryError, QueryStatus, QueryStatusError,
         },
         sharding::ShardIndex,
     };
@@ -468,11 +486,11 @@
     }
 
     fn helper_respond_ok() -> Arc<dyn RequestHandler<HelperIdentity>> {
-        prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })
+        make_owned_handler(move |_req, _| futures::future::ok(HelperResponse::ok()))
     }
 
     fn shard_respond_ok(_si: ShardIndex) -> Arc<dyn RequestHandler<ShardIndex>> {
-        prepare_query_handler(|_| async { Ok(HelperResponse::ok()) })
+        make_owned_handler(move |_req, _| futures::future::ok(HelperResponse::ok()))
     }
 
     fn test_multiply_config() -> QueryConfig {
@@ -559,6 +577,12 @@
         shard_transport: InMemoryTransport<ShardIndex>,
     }
 
+    impl Default for TestComponents {
+        fn default() -> Self {
+            Self::new(TestComponentsArgs::default())
+        }
+    }
+
     impl TestComponents {
         fn new(mut args: TestComponentsArgs) -> Self {
             let mpc_network = InMemoryMpcNetwork::new(
@@ -584,6 +608,31 @@
                 shard_transport,
             }
         }
+
+        /// This initiates a new query on all shards and puts them all in the `Running` state.
+        /// It also makes up a fake query result.
+        async fn new_running_query(&self) -> QueryId {
+            self.processor
+                .new_query(
+                    self.first_transport.clone_ref(),
+                    self.shard_transport.clone_ref(),
+                    self.query_config,
+                )
+                .await
+                .unwrap();
+            let (tx, rx) = tokio::sync::oneshot::channel();
+            self.processor
+                .queries
+                .handle(QueryId)
+                .set_state(QueryState::Running(RunningQuery {
+                    result: rx,
+                    join_handle: IpaRuntime::current().spawn(async {}),
+                }))
+                .unwrap();
+            tx.send(Ok(Box::new(Vec::<BA64>::new()))).unwrap();
+
+            QueryId
+        }
     }
 
     #[tokio::test]
@@ -755,6 +804,81 @@
         assert!(t.processor.get_status(QueryId).is_none());
     }
 
+    mod complete {
+
+        use crate::{
+            helpers::{make_owned_handler, routing::RouteId, ApiError, Transport},
+            query::{
+                processor::{
+                    tests::{HelperResponse, TestComponents, TestComponentsArgs},
+                    QueryId,
+                },
+                QueryCompletionError,
+            },
+            sharding::ShardIndex,
+        };
+
+        #[tokio::test]
+        async fn complete_basic() {
+            let t = TestComponents::default();
+            let query_id = t.new_running_query().await;
+
+            t.processor
+                .complete(query_id, t.shard_transport.clone_ref())
+                .await
+                .unwrap();
+        }
+
+        #[tokio::test]
+        async fn complete_one_shard_fails() {
+            let mut args = TestComponentsArgs::default();
+
+            args.set_shard_handler(|shard_id| {
+                make_owned_handler(move |req, _| {
+                    if shard_id != ShardIndex::from(1) || req.route != RouteId::CompleteQuery {
+                        futures::future::ok(HelperResponse::ok())
+                    } else {
+                        futures::future::err(QueryCompletionError::NoSuchQuery(QueryId).into())
+                    }
+                })
+            });
+
+            let t = TestComponents::new(args);
+            let query_id = t.new_running_query().await;
+
+            let _ = t
+                .processor
+                .complete(query_id, t.shard_transport.clone_ref())
+                .await
+                .unwrap_err();
+        }
+
+        #[tokio::test]
+        async fn only_leader_broadcasts() {
+            let mut args = TestComponentsArgs::default();
+
+            args.set_shard_handler(|shard_id| {
+                make_owned_handler(move |_req, _| {
+                    if shard_id == ShardIndex::FIRST {
+                        futures::future::err(ApiError::BadRequest(
+                            "Leader shard must not receive requests through shard channels".into(),
+                        ))
+ } else { + futures::future::ok(HelperResponse::ok()) + } + }) + }); + + let t = TestComponents::new(args); + let query_id = t.new_running_query().await; + + t.processor + .complete(query_id, t.shard_transport.clone_ref()) + .await + .unwrap(); + } + } + mod prepare { use super::*; use crate::query::QueryStatusError; diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 460296022..f745ada31 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -60,13 +60,14 @@ pub enum QueryState { impl QueryState { pub fn transition(cur_state: &Self, new_state: Self) -> Result { - use QueryState::{AwaitingInputs, Empty, Preparing}; + use QueryState::{AwaitingInputs, Empty, Preparing, Running}; match (cur_state, &new_state) { // If query is not running, coordinator initial state is preparing // and followers initial state is awaiting inputs (Empty, Preparing(_) | AwaitingInputs(_, _, _)) - | (Preparing(_), AwaitingInputs(_, _, _)) => Ok(new_state), + | (Preparing(_), AwaitingInputs(_, _, _)) + | (AwaitingInputs(_, _, _), Running(_)) => Ok(new_state), (_, Preparing(_)) => Err(StateError::AlreadyRunning), (_, _) => Err(StateError::InvalidState { from: cur_state.into(), From 5179647c8b8c581ee9387dbd44812b7243c94dc8 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Tue, 19 Nov 2024 18:04:13 -0800 Subject: [PATCH 24/47] Sending state via error --- ipa-core/src/helpers/gateway/transport.rs | 13 ++- ipa-core/src/helpers/mod.rs | 10 +-- .../helpers/transport/in_memory/transport.rs | 35 ++++++-- ipa-core/src/helpers/transport/mod.rs | 14 ++- ipa-core/src/net/error.rs | 24 +++++- ipa-core/src/net/transport.rs | 13 ++- ipa-core/src/query/mod.rs | 2 +- ipa-core/src/query/processor.rs | 85 ++++++++++++++++--- ipa-core/src/query/state.rs | 36 ++++++++ 9 files changed, 201 insertions(+), 31 deletions(-) diff --git a/ipa-core/src/helpers/gateway/transport.rs b/ipa-core/src/helpers/gateway/transport.rs index dfbc9d328..d05353af3 100644 --- a/ipa-core/src/helpers/gateway/transport.rs +++ b/ipa-core/src/helpers/gateway/transport.rs @@ -3,10 +3,12 @@ use futures::Stream; use crate::{ helpers::{ - transport::routing::RouteId, MpcTransportImpl, NoResourceIdentifier, QueryIdBinding, Role, - RoleAssignment, RouteParams, StepBinding, Transport, + transport::{routing::RouteId, BroadcasteableError}, + MpcTransportImpl, NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, RouteParams, + StepBinding, Transport, }, protocol::{Gate, QueryId}, + query::QueryStatus, sharding::ShardIndex, }; @@ -14,6 +16,13 @@ use crate::{ #[error("Failed to send to {0:?}: {1:?}")] pub struct SendToRoleError(Role, ::Error); +impl BroadcasteableError for SendToRoleError { + /// Implementing this as a no-op + fn peer_state(&self) -> Option { + None + } +} + /// Transport adapter that resolves [`Role`] -> [`HelperIdentity`] mapping. As gateways created /// per query, it is not ambiguous. 
/// diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 370c42b05..93eeb6dd0 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -73,11 +73,11 @@ pub use transport::{ config as in_memory_config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, }; pub use transport::{ - make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BytesStream, - HandlerBox, HandlerRef, HelperResponse, Identity as TransportIdentity, LengthDelimitedStream, - LogErrors, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, - RecordsStream, RequestHandler, RouteParams, SingleRecordStream, StepBinding, StreamCollection, - StreamKey, Transport, WrappedBoxBodyStream, + make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BroadcasteableError, + BytesStream, HandlerBox, HandlerRef, HelperResponse, Identity as TransportIdentity, + LengthDelimitedStream, LogErrors, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, + ReceiveRecords, RecordsStream, RequestHandler, RouteParams, SingleRecordStream, StepBinding, + StreamCollection, StreamKey, Transport, WrappedBoxBodyStream, }; use typenum::{Const, ToUInt, Unsigned, U8}; use x25519_dalek::PublicKey; diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 93ac7a523..6ffd96a08 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -23,11 +23,12 @@ use crate::{ helpers::{ in_memory_config::{self, DynStreamInterceptor}, transport::routing::{Addr, RouteId}, - ApiError, BodyStream, HandlerRef, HelperIdentity, HelperResponse, NoResourceIdentifier, - QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, StepBinding, StreamCollection, - Transport, TransportIdentity, + ApiError, BodyStream, BroadcasteableError, HandlerRef, HelperIdentity, HelperResponse, + NoResourceIdentifier, QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, + StepBinding, StreamCollection, Transport, TransportIdentity, }, protocol::{Gate, QueryId}, + query::{QueryStatus, QueryStatusError}, sharding::ShardIndex, sync::{Arc, Weak}, }; @@ -59,6 +60,18 @@ pub enum Error { #[from] inner: serde_json::Error, }, + #[error("Peer is in an invalid state: {peer_state:?}")] + PeerState { peer_state: QueryStatus }, +} + +impl BroadcasteableError for Error { + fn peer_state(&self) -> Option { + let mut status = None; + if let Error::PeerState { peer_state } = self { + status = Some(peer_state); + } + status.copied() + } } /// In-memory implementation of [`Transport`] backed by Tokio mpsc channels. @@ -219,9 +232,19 @@ impl Transport for Weak> { dest, inner: "channel closed".into(), })? - .map_err(|e| Error::Rejected { - dest, - inner: e.into(), + .map_err(|e: ApiError| { + if let ApiError::QueryStatus(QueryStatusError::DifferentStatus { + my_status, .. 
+ }) = e + { + return Error::PeerState { + peer_state: my_status, + }; + } + Error::Rejected { + dest, + inner: e.into(), + } })?; Ok(()) diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index b3cfb862f..e7f2e1ebf 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -12,6 +12,7 @@ use crate::helpers::in_memory_config::InspectContext; use crate::{ helpers::{transport::routing::RouteId, HelperIdentity, Role, TransportIdentity}, protocol::{Gate, QueryId}, + query::QueryStatus, sharding::ShardIndex, }; @@ -287,13 +288,20 @@ impl RouteParams for (RouteId, QueryId) { } } +/// Broadcast errors need to tell in what state their peer is so that the processor that's +/// broadcasting knows how to handle the error. For example, if the peer is in Completed state it +/// might want to handle the error differently than if the query hasn't been started. +pub trait BroadcasteableError: Debug { + fn peer_state(&self) -> Option; +} + #[derive(thiserror::Error, Debug)] #[error("One or more peers rejected the request: {failures:?}")] -pub struct BroadcastError { +pub struct BroadcastError { pub failures: Vec<(I, E)>, } -impl From> for BroadcastError { +impl From> for BroadcastError { fn from(value: Vec<(I, E)>) -> Self { Self { failures: value } } @@ -304,7 +312,7 @@ impl From> for BroadcastError pub trait Transport: Clone + Send + Sync + 'static { type Identity: TransportIdentity; type RecordsStream: BytesStream; - type Error: std::fmt::Debug + Send; + type Error: BroadcasteableError + Send; /// Return my identity in the network (MPC or Sharded) fn identity(&self) -> Self::Identity; diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index 6a04e8282..60ace9543 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -4,7 +4,8 @@ use axum::{ }; use crate::{ - error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, sharding::ShardIndex, + error::BoxError, helpers::BroadcasteableError, net::client::ResponseFromEndpoint, + protocol::QueryId, query::QueryStatus, sharding::ShardIndex, }; #[derive(thiserror::Error, Debug)] @@ -61,6 +62,8 @@ pub enum Error { }, #[error("{error}")] Application { code: StatusCode, error: BoxError }, + #[error("Peer is in an invalid state: {peer_state:?}")] + PeerState { peer_state: QueryStatus }, } impl Error { @@ -139,9 +142,26 @@ impl From for Error { #[error("Error in shard {shard_index}: {source}")] pub struct ShardError { pub shard_index: ShardIndex, + pub status: Option, pub source: Error, } +impl BroadcasteableError for ShardError { + fn peer_state(&self) -> Option { + self.status + } +} + +impl BroadcasteableError for Error { + fn peer_state(&self) -> Option { + let mut status = None; + if let Error::PeerState { peer_state } = self { + status = Some(peer_state); + } + status.copied() + } +} + impl IntoResponse for Error { fn into_response(self) -> Response { let status_code = match self { @@ -164,6 +184,8 @@ impl IntoResponse for Error { | Self::InvalidUri(_) | Self::MissingExtension(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::PeerState { .. } => StatusCode::PRECONDITION_FAILED, + Self::Application { code, .. 
} => code, }; (status_code, self.to_string()).into_response() diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 6ed523093..542ec73e5 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -352,9 +352,16 @@ impl Transport for ShardHttpTransport { { self.inner_transport .send(dest, route, data) - .map_err(|source| ShardError { - shard_index: self.identity(), - source, + .map_err(|source: Error| { + let mut status = None; + if let Error::PeerState { peer_state } = source { + status = Some(peer_state); + } + ShardError { + shard_index: self.identity(), + status, + source, + } }) .await } diff --git a/ipa-core/src/query/mod.rs b/ipa-core/src/query/mod.rs index 6e6650862..1f2550987 100644 --- a/ipa-core/src/query/mod.rs +++ b/ipa-core/src/query/mod.rs @@ -11,4 +11,4 @@ pub use processor::{ QueryInputError, QueryKillStatus, QueryKilled, QueryStatusError, }; pub use runner::OprfIpaQuery; -pub use state::QueryStatus; +pub use state::{min_status, QueryStatus}; diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 4f860f011..a6df95a74 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -6,13 +6,14 @@ use std::{ use futures::{future::try_join, stream}; use serde::Serialize; +use super::min_status; use crate::{ error::Error as ProtocolError, executor::IpaRuntime, helpers::{ query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, - BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, - RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, + BroadcastError, BroadcasteableError, Gateway, GatewayConfig, MpcTransportError, + MpcTransportImpl, Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, hpke::{KeyRegistry, PrivateKeyOnly}, protocol::QueryId, @@ -378,13 +379,23 @@ impl Processor { return Err(QueryStatusError::NotLeader(shard_index)); } - let status = self + let mut status = self .get_status(query_id) .ok_or(QueryStatusError::NoSuchQuery(query_id))?; let shard_query_status_req = CompareStatusRequest { query_id, status }; - shard_transport.broadcast(shard_query_status_req).await?; + let shard_responses = shard_transport.broadcast(shard_query_status_req).await; + if let Err(e) = shard_responses { + // The following silently ignores the cases where the query isn't found because those + // errors return `None` for [`BroadcasteableError::peer_state()`] + let states: Vec<_> = e + .failures + .iter() + .filter_map(|(_si, error)| error.peer_state()) + .collect(); + status = states.into_iter().fold(status, min_status); + } Ok(status) } @@ -960,7 +971,9 @@ mod tests { /// * From the standpoint of leader shard in Helper 1 /// * On query_status /// - /// If one of my shards isn't ready + /// The min state should be returned. In this case, if I, as leader, am in AwaitingInputs + /// state and shards report that they are further ahead (Completed and Running), then my + /// state is returned. 
#[tokio::test] async fn combined_status_response() { fn shard_handle(si: ShardIndex) -> Arc> { @@ -973,6 +986,13 @@ mod tests { other_status: QueryStatus::Preparing, })) } + ShardIndex(2) => { + Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { + query_id: QueryId, + my_status: QueryStatus::Running, + other_status: QueryStatus::Preparing, + })) + } _ => Ok(HelperResponse::ok()), } }) @@ -999,11 +1019,56 @@ mod tests { .query_status(t.shard_transport.clone_ref(), QueryId) .await; if let Err(e) = r { - if let QueryStatusError::ShardBroadcastError(be) = e { - assert_eq!(be.failures[0].0, ShardIndex(3)); - } else { - panic!("Unexpected error type"); - } + panic!("Unexpected error {e}"); + } + if let Ok(st) = r { + assert_eq!(QueryStatus::AwaitingInputs, st); + } + } + + /// * From the standpoint of leader shard in Helper 1 + /// * On query_status + /// + /// If one of my shards hasn't received the query yet (NoSuchQuery) the leader shouldn't + /// return an error but instead with the min state. + #[tokio::test] + async fn status_query_doesnt_exist() { + fn shard_handle(si: ShardIndex) -> Arc> { + create_handler(move |_| async move { + match si { + ShardIndex(3) => Err(ApiError::QueryStatus(QueryStatusError::NoSuchQuery( + QueryId, + ))), + _ => Ok(HelperResponse::ok()), + } + }) + } + let mut args = TestComponentsArgs { + shard_count: 4, + ..Default::default() + }; + args.set_shard_handler(shard_handle); + let t = TestComponents::new(args); + let req = prepare_query(); + // Using prepare shard to set the inner state, but in reality we should be using prepare_helper + // Prepare_helper will use the shard_handle defined above though and will fail. The following + // achieves the same state. + t.processor + .prepare_shard( + &t.shard_network + .transport(HelperIdentity::ONE, ShardIndex::from(1)), + req, + ) + .unwrap(); + let r = t + .processor + .query_status(t.shard_transport.clone_ref(), QueryId) + .await; + if let Err(e) = r { + panic!("Unexpected error {e}"); + } + if let Ok(st) = r { + assert_eq!(QueryStatus::AwaitingInputs, st); } } diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 460296022..834e491bd 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -48,6 +48,21 @@ impl From<&QueryState> for QueryStatus { } } +#[must_use] +pub fn min_status(a: QueryStatus, b: QueryStatus) -> QueryStatus { + match (a, b) { + (QueryStatus::Preparing, _) | (_, QueryStatus::Preparing) => QueryStatus::Preparing, + (QueryStatus::AwaitingInputs, _) | (_, QueryStatus::AwaitingInputs) => { + QueryStatus::AwaitingInputs + } + (QueryStatus::Running, _) | (_, QueryStatus::Running) => QueryStatus::Running, + (QueryStatus::AwaitingCompletion, _) | (_, QueryStatus::AwaitingCompletion) => { + QueryStatus::AwaitingCompletion + } + (QueryStatus::Completed, _) => QueryStatus::Completed, + } +} + /// TODO: a macro would be very useful here to keep it in sync with `QueryStatus` pub enum QueryState { Empty, @@ -226,3 +241,24 @@ impl Drop for RemoveQuery<'_> { } } } + +#[cfg(all(test, unit_test))] +mod tests { + use crate::query::{state::min_status, QueryStatus}; + + #[test] + fn test_order() { + assert_eq!( + min_status(QueryStatus::Preparing, QueryStatus::Preparing), + QueryStatus::Preparing + ); + assert_eq!( + min_status(QueryStatus::Preparing, QueryStatus::Completed), + QueryStatus::Preparing + ); + assert_eq!( + min_status(QueryStatus::AwaitingCompletion, QueryStatus::AwaitingInputs), + QueryStatus::AwaitingInputs + ); + } +} From 
22ba17d9c675ea973949c0dfa41244f3ed0fa7ad Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 19 Nov 2024 21:40:07 -0800 Subject: [PATCH 25/47] Tests for `peer_count` --- .../helpers/transport/in_memory/transport.rs | 26 +++++++++++++++-- ipa-core/src/net/transport.rs | 29 +++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 93ac7a523..a2a1abea6 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -373,10 +373,11 @@ mod tests { }, routing::RouteId, }, - HandlerBox, HelperIdentity, HelperResponse, OrderingSender, Role, RoleAssignment, - Transport, TransportIdentity, + HandlerBox, HelperIdentity, HelperResponse, InMemoryShardNetwork, OrderingSender, Role, + RoleAssignment, Transport, TransportIdentity, }, protocol::{Gate, QueryId}, + sharding::ShardIndex, sync::Arc, }; @@ -625,6 +626,27 @@ mod tests { // must be received by now assert_eq!(vec![vec![0, 1]], recv.collect::>().await); } + + #[tokio::test] + async fn peer_count() { + let mpc_network = InMemoryMpcNetwork::default(); + assert_eq!(2, mpc_network.transport(HelperIdentity::ONE).peer_count()); + assert_eq!(2, mpc_network.transport(HelperIdentity::TWO).peer_count()); + + let shard_network = InMemoryShardNetwork::with_shards(5); + assert_eq!( + 4, + shard_network + .transport(HelperIdentity::ONE, ShardIndex::FIRST) + .peer_count() + ); + assert_eq!( + 4, + shard_network + .transport(HelperIdentity::TWO, ShardIndex::from(4)) + .peer_count() + ); + } } pub struct TransportConfig { diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index b6d726a43..55d363f8b 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -573,4 +573,33 @@ mod tests { .build(); test_make_helpers(conf).await; } + + #[tokio::test] + async fn peer_count() { + fn new_transport(identity: F::Identity) -> Arc> { + Arc::new(HttpTransport { + http_runtime: IpaRuntime::current(), + identity, + clients: Vec::new(), + handler: None, + record_streams: StreamCollection::default(), + }) + } + + assert_eq!( + 2, + MpcHttpTransport { + inner_transport: new_transport(HelperIdentity::ONE) + } + .peer_count() + ); + assert_eq!( + 9, + ShardHttpTransport { + inner_transport: new_transport(ShardIndex::FIRST), + shard_count: 10.into() + } + .peer_count() + ); + } } From 36d18a239e6d69d67ad9d5bd4a148eca3e8236ab Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 19 Nov 2024 21:44:38 -0800 Subject: [PATCH 26/47] Docs --- ipa-core/src/helpers/gateway/mod.rs | 1 + ipa-core/src/query/runner/sharded_shuffle.rs | 3 +++ 2 files changed, 4 insertions(+) diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 9ecde1377..c14bfb0fd 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -122,6 +122,7 @@ impl ShardConfiguration for &Gateway { } fn shard_count(&self) -> ShardIndex { + // total number of shards include this instance and all its peers, so we add 1. 
         ShardIndex::from(self.transports.shard.peer_count() + 1)
     }
 }
 
diff --git a/ipa-core/src/query/runner/sharded_shuffle.rs b/ipa-core/src/query/runner/sharded_shuffle.rs
index 0025f07cb..a90161b25 100644
--- a/ipa-core/src/query/runner/sharded_shuffle.rs
+++ b/ipa-core/src/query/runner/sharded_shuffle.rs
@@ -18,6 +18,9 @@ use crate::{
     sync::Arc,
 };
 
+/// This executes the sharded shuffle protocol that consists of only one step:
+/// permute the private inputs using a permutation that is not known to any
+/// helper or to the client.
 pub async fn execute_sharded_shuffle<'a>(
     prss: &'a PrssEndpoint,
     gateway: &'a Gateway,

From c1859b5bbba205c6a9e9746642b9e7de0eabee07 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Wed, 20 Nov 2024 09:50:29 -0800
Subject: [PATCH 27/47] broadcast_with_errors

---
 ipa-core/src/helpers/gateway/transport.rs    | 13 ++-------
 ipa-core/src/helpers/mod.rs                  | 10 +++----
 .../helpers/transport/in_memory/transport.rs | 16 ++---------
 ipa-core/src/helpers/transport/mod.rs        | 28 +++++++++++++------
 ipa-core/src/net/error.rs                    | 20 ++-----------
 ipa-core/src/query/processor.rs              |  4 +--
 6 files changed, 34 insertions(+), 57 deletions(-)

diff --git a/ipa-core/src/helpers/gateway/transport.rs b/ipa-core/src/helpers/gateway/transport.rs
index d05353af3..dfbc9d328 100644
--- a/ipa-core/src/helpers/gateway/transport.rs
+++ b/ipa-core/src/helpers/gateway/transport.rs
@@ -3,12 +3,10 @@ use futures::Stream;
 
 use crate::{
     helpers::{
-        transport::{routing::RouteId, BroadcasteableError},
-        MpcTransportImpl, NoResourceIdentifier, QueryIdBinding, Role, RoleAssignment, RouteParams,
-        StepBinding, Transport,
+        transport::routing::RouteId, MpcTransportImpl, NoResourceIdentifier, QueryIdBinding, Role,
+        RoleAssignment, RouteParams, StepBinding, Transport,
     },
     protocol::{Gate, QueryId},
-    query::QueryStatus,
     sharding::ShardIndex,
 };
 
@@ -16,13 +14,6 @@ use crate::{
 #[error("Failed to send to {0:?}: {1:?}")]
 pub struct SendToRoleError(Role, <MpcTransportImpl as Transport>::Error);
 
-impl BroadcasteableError for SendToRoleError {
-    /// Implementing this as a no-op
-    fn peer_state(&self) -> Option<QueryStatus> {
-        None
-    }
-}
-
 /// Transport adapter that resolves [`Role`] -> [`HelperIdentity`] mapping. As gateways created
 /// per query, it is not ambiguous.
/// diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 93eeb6dd0..370c42b05 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -73,11 +73,11 @@ pub use transport::{ config as in_memory_config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, }; pub use transport::{ - make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BroadcasteableError, - BytesStream, HandlerBox, HandlerRef, HelperResponse, Identity as TransportIdentity, - LengthDelimitedStream, LogErrors, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, - ReceiveRecords, RecordsStream, RequestHandler, RouteParams, SingleRecordStream, StepBinding, - StreamCollection, StreamKey, Transport, WrappedBoxBodyStream, + make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BytesStream, + HandlerBox, HandlerRef, HelperResponse, Identity as TransportIdentity, LengthDelimitedStream, + LogErrors, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, + RecordsStream, RequestHandler, RouteParams, SingleRecordStream, StepBinding, StreamCollection, + StreamKey, Transport, WrappedBoxBodyStream, }; use typenum::{Const, ToUInt, Unsigned, U8}; use x25519_dalek::PublicKey; diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 6ffd96a08..4d7661622 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -23,9 +23,9 @@ use crate::{ helpers::{ in_memory_config::{self, DynStreamInterceptor}, transport::routing::{Addr, RouteId}, - ApiError, BodyStream, BroadcasteableError, HandlerRef, HelperIdentity, HelperResponse, - NoResourceIdentifier, QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, - StepBinding, StreamCollection, Transport, TransportIdentity, + ApiError, BodyStream, HandlerRef, HelperIdentity, HelperResponse, NoResourceIdentifier, + QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, StepBinding, StreamCollection, + Transport, TransportIdentity, }, protocol::{Gate, QueryId}, query::{QueryStatus, QueryStatusError}, @@ -64,16 +64,6 @@ pub enum Error { PeerState { peer_state: QueryStatus }, } -impl BroadcasteableError for Error { - fn peer_state(&self) -> Option { - let mut status = None; - if let Error::PeerState { peer_state } = self { - status = Some(peer_state); - } - status.copied() - } -} - /// In-memory implementation of [`Transport`] backed by Tokio mpsc channels. /// Use [`Setup`] to initialize it and call [`Setup::start`] to make it actively listen for /// incoming messages. 
diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index e7f2e1ebf..82f8d85a6 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -297,11 +297,11 @@ pub trait BroadcasteableError: Debug { #[derive(thiserror::Error, Debug)] #[error("One or more peers rejected the request: {failures:?}")] -pub struct BroadcastError { +pub struct BroadcastError { pub failures: Vec<(I, E)>, } -impl From> for BroadcastError { +impl From> for BroadcastError { fn from(value: Vec<(I, E)>) -> Self { Self { failures: value } } @@ -312,7 +312,7 @@ impl From> for Broadca pub trait Transport: Clone + Send + Sync + 'static { type Identity: TransportIdentity; type RecordsStream: BytesStream; - type Error: BroadcasteableError + Send; + type Error: Debug + Send; /// Return my identity in the network (MPC or Sharded) fn identity(&self) -> Self::Identity; @@ -351,6 +351,22 @@ pub trait Transport: Clone + Send + Sync + 'static { &self, route: R, ) -> Result<(), BroadcastError> + where + Option: From, + Option: From, + Q: QueryIdBinding, + S: StepBinding, + R: RouteParams + Clone, + { + let errs = self.broadcast_with_errors(route).await; + if errs.is_empty() { + Ok(()) + } else { + Err(errs.into()) + } + } + + async fn broadcast_with_errors(&self, route: R) -> Vec<(Self::Identity, Self::Error)> where Option: From, Option: From, @@ -373,11 +389,7 @@ pub trait Transport: Clone + Send + Sync + 'static { } } - if errs.is_empty() { - Ok(()) - } else { - Err(errs.into()) - } + errs } /// Alias for `Clone::clone`. diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index 60ace9543..ed817b6e7 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -4,8 +4,8 @@ use axum::{ }; use crate::{ - error::BoxError, helpers::BroadcasteableError, net::client::ResponseFromEndpoint, - protocol::QueryId, query::QueryStatus, sharding::ShardIndex, + error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, query::QueryStatus, + sharding::ShardIndex, }; #[derive(thiserror::Error, Debug)] @@ -146,22 +146,6 @@ pub struct ShardError { pub source: Error, } -impl BroadcasteableError for ShardError { - fn peer_state(&self) -> Option { - self.status - } -} - -impl BroadcasteableError for Error { - fn peer_state(&self) -> Option { - let mut status = None; - if let Error::PeerState { peer_state } = self { - status = Some(peer_state); - } - status.copied() - } -} - impl IntoResponse for Error { fn into_response(self) -> Response { let status_code = match self { diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index a6df95a74..9031f69f3 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -12,8 +12,8 @@ use crate::{ executor::IpaRuntime, helpers::{ query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, - BroadcastError, BroadcasteableError, Gateway, GatewayConfig, MpcTransportError, - MpcTransportImpl, Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, + BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, + RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, hpke::{KeyRegistry, PrivateKeyOnly}, protocol::QueryId, From 6b5e162fa44cc81077f904dc075ced0149980480 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Wed, 20 Nov 2024 10:27:15 -0800 Subject: [PATCH 28/47] removing conversions --- .../src/helpers/transport/in_memory/transport.rs | 15 +-------------- 
ipa-core/src/net/error.rs | 7 +------ ipa-core/src/net/transport.rs | 8 +------- 3 files changed, 3 insertions(+), 27 deletions(-) diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 4d7661622..61a136968 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -28,7 +28,6 @@ use crate::{ Transport, TransportIdentity, }, protocol::{Gate, QueryId}, - query::{QueryStatus, QueryStatusError}, sharding::ShardIndex, sync::{Arc, Weak}, }; @@ -60,8 +59,6 @@ pub enum Error { #[from] inner: serde_json::Error, }, - #[error("Peer is in an invalid state: {peer_state:?}")] - PeerState { peer_state: QueryStatus }, } /// In-memory implementation of [`Transport`] backed by Tokio mpsc channels. @@ -222,19 +219,9 @@ impl Transport for Weak> { dest, inner: "channel closed".into(), })? - .map_err(|e: ApiError| { - if let ApiError::QueryStatus(QueryStatusError::DifferentStatus { - my_status, .. - }) = e - { - return Error::PeerState { - peer_state: my_status, - }; - } - Error::Rejected { + .map_err(|e| Error::Rejected { dest, inner: e.into(), - } })?; Ok(()) diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index ed817b6e7..da31b28fb 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -4,7 +4,7 @@ use axum::{ }; use crate::{ - error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, query::QueryStatus, + error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, sharding::ShardIndex, }; @@ -62,8 +62,6 @@ pub enum Error { }, #[error("{error}")] Application { code: StatusCode, error: BoxError }, - #[error("Peer is in an invalid state: {peer_state:?}")] - PeerState { peer_state: QueryStatus }, } impl Error { @@ -142,7 +140,6 @@ impl From for Error { #[error("Error in shard {shard_index}: {source}")] pub struct ShardError { pub shard_index: ShardIndex, - pub status: Option, pub source: Error, } @@ -168,8 +165,6 @@ impl IntoResponse for Error { | Self::InvalidUri(_) | Self::MissingExtension(_) => StatusCode::INTERNAL_SERVER_ERROR, - Self::PeerState { .. } => StatusCode::PRECONDITION_FAILED, - Self::Application { code, .. 
} => code, }; (status_code, self.to_string()).into_response() diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 542ec73e5..dbd3fc8d9 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -352,16 +352,10 @@ impl Transport for ShardHttpTransport { { self.inner_transport .send(dest, route, data) - .map_err(|source: Error| { - let mut status = None; - if let Error::PeerState { peer_state } = source { - status = Some(peer_state); - } - ShardError { + .map_err(|source| ShardError { shard_index: self.identity(), status, source, - } }) .await } From cd28ea0190365fd559c62772bf1f7cb11f663ed6 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Wed, 20 Nov 2024 10:55:03 -0800 Subject: [PATCH 29/47] cfg functions + dispatch --- ipa-core/src/helpers/mod.rs | 1 + ipa-core/src/helpers/transport/in_memory/mod.rs | 2 +- .../src/helpers/transport/in_memory/transport.rs | 4 ++-- ipa-core/src/helpers/transport/mod.rs | 12 +++--------- ipa-core/src/net/error.rs | 3 +-- ipa-core/src/net/transport.rs | 5 ++--- ipa-core/src/query/processor.rs | 14 +++++++++++++- 7 files changed, 23 insertions(+), 18 deletions(-) diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 370c42b05..95eb33b71 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -71,6 +71,7 @@ pub use transport::WrappedAxumBodyStream; #[cfg(feature = "in-memory-infra")] pub use transport::{ config as in_memory_config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, + InMemoryTransportError, }; pub use transport::{ make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BytesStream, diff --git a/ipa-core/src/helpers/transport/in_memory/mod.rs b/ipa-core/src/helpers/transport/in_memory/mod.rs index a5c34bba5..1f23eb278 100644 --- a/ipa-core/src/helpers/transport/in_memory/mod.rs +++ b/ipa-core/src/helpers/transport/in_memory/mod.rs @@ -3,8 +3,8 @@ mod sharding; mod transport; pub use sharding::InMemoryShardNetwork; -pub use transport::Setup; use transport::TransportConfigBuilder; +pub use transport::{Error as InMemoryTransportError, Setup}; use crate::{ helpers::{ diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 61a136968..93ac7a523 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -220,8 +220,8 @@ impl Transport for Weak> { inner: "channel closed".into(), })? 
.map_err(|e| Error::Rejected { - dest, - inner: e.into(), + dest, + inner: e.into(), })?; Ok(()) diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 82f8d85a6..61e97912e 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -12,7 +12,6 @@ use crate::helpers::in_memory_config::InspectContext; use crate::{ helpers::{transport::routing::RouteId, HelperIdentity, Role, TransportIdentity}, protocol::{Gate, QueryId}, - query::QueryStatus, sharding::ShardIndex, }; @@ -28,7 +27,9 @@ pub use handler::{ make_owned_handler, Error as ApiError, HandlerBox, HandlerRef, HelperResponse, RequestHandler, }; #[cfg(feature = "in-memory-infra")] -pub use in_memory::{config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport}; +pub use in_memory::{ + config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, InMemoryTransportError, +}; use ipa_metrics::LabelValue; pub use receive::{LogErrors, ReceiveRecords}; #[cfg(feature = "web-app")] @@ -288,13 +289,6 @@ impl RouteParams for (RouteId, QueryId) { } } -/// Broadcast errors need to tell in what state their peer is so that the processor that's -/// broadcasting knows how to handle the error. For example, if the peer is in Completed state it -/// might want to handle the error differently than if the query hasn't been started. -pub trait BroadcasteableError: Debug { - fn peer_state(&self) -> Option; -} - #[derive(thiserror::Error, Debug)] #[error("One or more peers rejected the request: {failures:?}")] pub struct BroadcastError { diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index da31b28fb..6a04e8282 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -4,8 +4,7 @@ use axum::{ }; use crate::{ - error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, - sharding::ShardIndex, + error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, sharding::ShardIndex, }; #[derive(thiserror::Error, Debug)] diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index dbd3fc8d9..6ed523093 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -353,9 +353,8 @@ impl Transport for ShardHttpTransport { self.inner_transport .send(dest, route, data) .map_err(|source| ShardError { - shard_index: self.identity(), - status, - source, + shard_index: self.identity(), + source, }) .await } diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 9031f69f3..92bfbe0c7 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -362,6 +362,18 @@ impl Processor { Some(status) } + #[cfg(feature = "in-memory-infra")] + fn get_state_from_error( + _be: &crate::helpers::InMemoryTransportError, + ) -> Option { + todo!() + } + + #[cfg(feature = "real-world-infra")] + fn get_state_from_error(be: &ShardError) -> QueryStatus { + todo!() + } + /// Returns the query status in this helper, by querying all shards. /// /// ## Errors @@ -392,7 +404,7 @@ impl Processor { let states: Vec<_> = e .failures .iter() - .filter_map(|(_si, error)| error.peer_state()) + .filter_map(|(_si, e)| Self::get_state_from_error(e)) .collect(); status = states.into_iter().fold(status, min_status); } From e2956a02ac7d5ee80f97baee2ef7945fb8a72973 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 20 Nov 2024 11:24:59 -0800 Subject: [PATCH 30/47] Fix `four_shards_http` test This fix is quite involved because it requires fixing the query workflow cycle. 
Receive inputs and complete query handlers must work for both shard and MPC and HTTP handlers must be present to process data. --- ipa-core/src/app.rs | 9 +++- ipa-core/src/helpers/transport/routing.rs | 12 +++++ ipa-core/src/net/client/mod.rs | 36 ++++++++++----- ipa-core/src/net/server/handlers/mod.rs | 4 +- .../src/net/server/handlers/query/input.rs | 13 +++--- ipa-core/src/net/server/handlers/query/mod.rs | 23 +++++++--- .../src/net/server/handlers/query/results.rs | 17 ++++--- ipa-core/src/net/test.rs | 13 ++++++ ipa-core/src/net/transport.rs | 44 ++++++++++++++++--- 9 files changed, 133 insertions(+), 38 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 792215b1e..a19787c1b 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -177,7 +177,7 @@ impl RequestHandler for Inner { async fn handle( &self, req: Addr, - _data: BodyStream, + data: BodyStream, ) -> Result { let qp = &self.query_processor; @@ -186,6 +186,13 @@ impl RequestHandler for Inner { let req = req.into::()?; HelperResponse::from(qp.prepare_shard(&self.shard_transport, req)?) } + RouteId::QueryInput | RouteId::CompleteQuery => { + // The processing flow for this API is exactly the same, regardless + // whether it was received from a peer shard or from report collector. + // Authentication is handled on the layer above, so we erase the identity + // and pass it down to the MPC handler. + RequestHandler::::handle(self, req.erase_origin(), data).await? + } r => { return Err(ApiError::BadRequest( format!("{r:?} request must not be handled by shard query processing flow") diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs index 6cb1006df..e851865ce 100644 --- a/ipa-core/src/helpers/transport/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -49,6 +49,18 @@ impl Addr { } } + /// Drop the origin value and convert this into a request for a different identity type. + /// Useful when we need to handle this request in both shard and MPC handlers. + pub fn erase_origin(self) -> Addr { + Addr { + route: self.route, + origin: None, + query_id: self.query_id, + gate: self.gate, + params: self.params, + } + } + /// Deserializes JSON-encoded request parameters into a client-supplied type `T`. /// /// ## Errors diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 0bcead345..778aba6f1 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -372,6 +372,30 @@ impl IpaHttpClient { let resp = self.request(req).await?; resp_ok(resp).await } + + /// Intended to be called externally, e.g. by the report collector. After the report collector + /// calls "create query", it must then send the data for the query to each of the clients. This + /// query input contains the data intended for a helper. + /// # Errors + /// If the request has illegal arguments, or fails to deliver to helper + pub async fn query_input(&self, data: QueryInput) -> Result<(), Error> { + let req = http_serde::query::input::Request::new(data); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; + let resp = self.request(req).await?; + resp_ok(resp).await + } + + /// Complete query API can be called on the leader shard by the report collector or + /// by the leader shard to other shards. 
+ /// + /// # Errors + /// If the request has illegal arguments, or fails to be delivered + pub async fn complete_query(&self, query_id: QueryId) -> Result<(), Error> { + let req = http_serde::query::results::Request::new(query_id); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; + let resp = self.request(req).await?; + resp_ok(resp).await + } } impl IpaHttpClient { @@ -418,18 +442,6 @@ impl IpaHttpClient { } } - /// Intended to be called externally, e.g. by the report collector. After the report collector - /// calls "create query", it must then send the data for the query to each of the clients. This - /// query input contains the data intended for a helper. - /// # Errors - /// If the request has illegal arguments, or fails to deliver to helper - pub async fn query_input(&self, data: QueryInput) -> Result<(), Error> { - let req = http_serde::query::input::Request::new(data); - let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; - let resp = self.request(req).await?; - resp_ok(resp).await - } - /// Retrieve the status of a query. /// /// ## Errors diff --git a/ipa-core/src/net/server/handlers/mod.rs b/ipa-core/src/net/server/handlers/mod.rs index dc99ebff5..54303bdbb 100644 --- a/ipa-core/src/net/server/handlers/mod.rs +++ b/ipa-core/src/net/server/handlers/mod.rs @@ -17,6 +17,8 @@ pub fn mpc_router(transport: MpcHttpTransport) -> Router { pub fn shard_router(transport: ShardHttpTransport) -> Router { echo::router().nest( http_serde::query::BASE_AXUM_PATH, - Router::new().merge(query::s2s_router(transport)), + Router::new() + .merge(query::c2s_router(&transport)) + .merge(query::s2s_router(transport)), ) } diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs index da47e9386..4e5487e2d 100644 --- a/ipa-core/src/net/server/handlers/query/input.rs +++ b/ipa-core/src/net/server/handlers/query/input.rs @@ -3,12 +3,13 @@ use hyper::StatusCode; use crate::{ helpers::{query::QueryInput, routing::RouteId, BodyStream}, - net::{http_serde, transport::MpcHttpTransport, Error}, + net::{http_serde, ConnectionFlavor, Error, HttpTransport}, protocol::QueryId, + sync::Arc, }; -async fn handler( - transport: Extension, +async fn handler( + transport: Extension>>, Path(query_id): Path, input_stream: BodyStream, ) -> Result<(), Error> { @@ -16,7 +17,7 @@ async fn handler( query_id, input_stream, }; - let _ = transport + let _ = Arc::clone(&transport) .dispatch( (RouteId::QueryInput, query_input.query_id), query_input.input_stream, @@ -27,9 +28,9 @@ async fn handler( Ok(()) } -pub fn router(transport: MpcHttpTransport) -> Router { +pub fn router(transport: Arc>) -> Router { Router::new() - .route(http_serde::query::input::AXUM_PATH, post(handler)) + .route(http_serde::query::input::AXUM_PATH, post(handler::)) .layer(Extension(transport)) } diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index fdd8935ae..4f19bfa46 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -19,9 +19,12 @@ use futures_util::{ use hyper::{Request, StatusCode}; use tower::{layer::layer_fn, Service}; -use crate::net::{ - server::ClientIdentity, transport::MpcHttpTransport, ConnectionFlavor, Helper, Shard, - ShardHttpTransport, +use crate::{ + net::{ + server::ClientIdentity, transport::MpcHttpTransport, ConnectionFlavor, Helper, Shard, + ShardHttpTransport, + }, + sync::Arc, }; /// Construct 
router for IPA query web service @@ -32,10 +35,10 @@ use crate::net::{ pub fn query_router(transport: MpcHttpTransport) -> Router { Router::new() .merge(create::router(transport.clone())) - .merge(input::router(transport.clone())) + .merge(input::router(Arc::clone(&transport.inner_transport))) .merge(status::router(transport.clone())) .merge(kill::router(transport.clone())) - .merge(results::router(transport)) + .merge(results::router(transport.inner_transport)) } /// Construct router for helper-to-helper communications @@ -55,10 +58,18 @@ pub fn h2h_router(transport: MpcHttpTransport) -> Router { /// Construct router for shard-to-shard communications similar to [`h2h_router`]. pub fn s2s_router(transport: ShardHttpTransport) -> Router { Router::new() - .merge(prepare::router(transport.inner_transport)) + .merge(prepare::router(Arc::clone(&transport.inner_transport))) + .merge(results::router(transport.inner_transport)) .layer(layer_fn(HelperAuthentication::<_, Shard>::new)) } +/// Client-to-shard routes. There are only a few cases where we expect parties +/// to talk to individual shards. Input submission is one of them. This path does +/// not require cert authentication. +pub fn c2s_router(transport: &ShardHttpTransport) -> Router { + Router::new().merge(input::router(Arc::clone(&transport.inner_transport))) +} + /// Returns HTTP 401 Unauthorized if the request does not have valid authentication. /// /// Authentication information is carried via the `ClientIdentity` request extension. The extension diff --git a/ipa-core/src/net/server/handlers/query/results.rs b/ipa-core/src/net/server/handlers/query/results.rs index 1c359b659..2f2a59f5e 100644 --- a/ipa-core/src/net/server/handlers/query/results.rs +++ b/ipa-core/src/net/server/handlers/query/results.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use axum::{extract::Path, routing::get, Extension, Router}; use hyper::StatusCode; @@ -6,27 +8,30 @@ use crate::{ net::{ http_serde::{self, query::results::Request}, server::Error, - transport::MpcHttpTransport, + ConnectionFlavor, HttpTransport, }, protocol::QueryId, }; /// Handles the completion of the query by blocking the sender until query is completed. -async fn handler( - transport: Extension, +async fn handler( + transport: Extension>>, Path(query_id): Path, ) -> Result, Error> { let req = Request { query_id }; // TODO: we may be able to stream the response - match transport.dispatch(req, BodyStream::empty()).await { + match Arc::clone(&transport) + .dispatch(req, BodyStream::empty()) + .await + { Ok(resp) => Ok(resp.into_body()), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } -pub fn router(transport: MpcHttpTransport) -> Router { +pub fn router(transport: Arc>) -> Router { Router::new() - .route(http_serde::query::results::AXUM_PATH, get(handler)) + .route(http_serde::query::results::AXUM_PATH, get(handler::)) .layer(Extension(transport)) } diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index f4e8a087c..fb0164c55 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -305,6 +305,19 @@ impl TestConfig { } } + /// Returns full set of clients to talk to each individual shard in the MPC. 
+ #[must_use] + pub fn shard_clients(&self) -> [Vec>; 3] { + let shard_clients = HelperIdentity::make_three().map(|id| { + IpaHttpClient::shards_from_conf( + &IpaRuntime::current(), + &self.get_shards_for_helper(id).network, + &ClientIdentity::None, + ) + }); + shard_clients + } + /// Transforms this easy to modify configuration into an easy to run [`TestApp`]. #[must_use] pub fn into_apps(self) -> Vec { diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 6ed523093..4a1d17897 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -106,10 +106,14 @@ impl HttpTransport { let req = serde_json::from_str(route.extra().borrow()).unwrap(); self.clients[client_ix].prepare_query(req).await } + RouteId::CompleteQuery => { + let query_id = >::from(route.query_id()) + .expect("query_id is required to call complete query API"); + self.clients[client_ix].complete_query(query_id).await + } evt @ (RouteId::QueryInput | RouteId::ReceiveQuery | RouteId::QueryStatus - | RouteId::CompleteQuery | RouteId::KillQuery) => { unimplemented!( "attempting to send client-specific request {evt:?} to another helper" @@ -478,13 +482,15 @@ mod tests { } async fn test_make_helpers(conf: TestConfig) { - let clients = IpaHttpClient::from_conf( + let mpc_clients = IpaHttpClient::from_conf( &IpaRuntime::current(), &conf.leaders_ring().network, &ClientIdentity::None, ); + let shard_clients = conf.shard_clients(); + let _helpers = make_helpers(conf).await; - test_multiply(&clients).await; + test_multiply_single_shard(&mpc_clients, shard_clients.each_ref().map(AsRef::as_ref)).await; } #[tokio::test(flavor = "multi_thread")] @@ -495,13 +501,26 @@ mod tests { &conf.leaders_ring().network, &ClientIdentity::None, ); + let shard_clients = conf.shard_clients(); + let shard_clients_ref = shard_clients.each_ref().map(AsRef::as_ref); let _helpers = make_helpers(conf).await; - test_multiply(&clients).await; - test_multiply(&clients).await; + test_multiply_single_shard(&clients, shard_clients_ref).await; + test_multiply_single_shard(&clients, shard_clients_ref).await; } - async fn test_multiply(clients: &[IpaHttpClient; 3]) { + /// This executes test multiplication protocol by running it exclusively on the leader shards. + /// If there is more than one shard in the system, they receive no inputs but still participate + /// by doing the full query cycle. It is backward compatible with traditional 3-party MPC with + /// no shards. + /// + /// The sharding requires some amendments to the test multiplication protocol that are + /// currently in progress. Once completed, this test can be fixed by fully utilizing all + /// shards in the system. + async fn test_multiply_single_shard( + clients: &[IpaHttpClient; 3], + shard_clients: [&[IpaHttpClient]; 3], + ) { const SZ: usize = as Serializable>::Size::USIZE; // send a create query command @@ -532,6 +551,19 @@ mod tests { } try_join_all(handle_resps).await.unwrap(); + // shards receive their own input - in this case empty + try_join_all(shard_clients.each_ref().map(|helper_shard_clients| { + // convention - first client is shard leader, and we submitted the inputs to it. 
+ try_join_all(helper_shard_clients.iter().skip(1).map(|shard_client| { + shard_client.query_input(QueryInput { + query_id, + input_stream: BodyStream::empty(), + }) + })) + })) + .await + .unwrap(); + let result: [_; 3] = join_all(clients.clone().map(|client| async move { let r = client.query_results(query_id).await.unwrap(); AdditiveShare::::from_byte_slice_unchecked(&r).collect::>() From 6f4e61dc1b5868ae4ca3f2af66d0aee5aab261bb Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Wed, 20 Nov 2024 11:52:32 -0800 Subject: [PATCH 31/47] http get state --- ipa-core/src/net/mod.rs | 2 +- ipa-core/src/query/processor.rs | 41 ++++++++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index 05c3e69be..621365077 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -24,7 +24,7 @@ pub mod test; mod transport; pub use client::{ClientIdentity, IpaHttpClient}; -pub use error::Error; +pub use error::{Error, ShardError}; pub use server::{IpaHttpServer, TracingSpanMaker}; pub use transport::{HttpTransport, MpcHttpTransport, ShardHttpTransport}; diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 92bfbe0c7..073ecf331 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -8,7 +8,7 @@ use serde::Serialize; use super::min_status; use crate::{ - error::Error as ProtocolError, + error::{BoxError, Error as ProtocolError}, executor::IpaRuntime, helpers::{ query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, @@ -362,16 +362,45 @@ impl Processor { Some(status) } + /// This helper function is used to transform a [`BoxError`] into a + /// [`QueryStatusError::DifferentStatus`] and retrieve it's internal state. Returns [`None`] + /// if not possible. + fn downcast_state_error(be: BoxError) -> Option { + let ae = be.downcast::().ok()?; + if let crate::helpers::ApiError::QueryStatus(QueryStatusError::DifferentStatus { + my_status, + .. + }) = *ae + { + return Some(my_status); + } + None + } + + /// This helper is used by the in-memory stack to obtain the state of other shards via a + /// [`QueryStatusError::DifferentStatus`] error. + /// TODO: Ideally broadcast should return a value, that we could use to parse the state instead + /// of relying on errors. #[cfg(feature = "in-memory-infra")] fn get_state_from_error( - _be: &crate::helpers::InMemoryTransportError, + be: crate::helpers::InMemoryTransportError, ) -> Option { - todo!() + if let crate::helpers::InMemoryTransportError::Rejected { inner, .. } = be { + return Self::downcast_state_error(inner); + } + None } + /// This helper is used by the HTTP stack to obtain the state of other shards via a + /// [`QueryStatusError::DifferentStatus`] error. + /// TODO: Ideally broadcast should return a value, that we could use to parse the state instead + /// of relying on errors. #[cfg(feature = "real-world-infra")] - fn get_state_from_error(be: &ShardError) -> QueryStatus { - todo!() + fn get_state_from_error(se: crate::net::ShardError) -> Option { + if let crate::net::Error::Application { error, .. } = se.source { + return Self::downcast_state_error(error); + } + None } /// Returns the query status in this helper, by querying all shards. 
@@ -403,7 +432,7 @@ impl Processor {
             // errors return `None` for [`BroadcasteableError::peer_state()`]
             let states: Vec<_> = e
                 .failures
-                .iter()
+                .into_iter()
                 .filter_map(|(_si, e)| Self::get_state_from_error(e))
                 .collect();
             status = states.into_iter().fold(status, min_status);

From 2de4cf9369ca50ee93c97b31047f8706a8014a57 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Wed, 20 Nov 2024 12:03:41 -0800
Subject: [PATCH 32/47] removed broadcast_with_errors

---
 ipa-core/src/helpers/transport/mod.rs | 22 +++++-----------------
 ipa-core/src/query/processor.rs       |  4 ++--
 2 files changed, 7 insertions(+), 19 deletions(-)

diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs
index 61e97912e..bbe16ee6f 100644
--- a/ipa-core/src/helpers/transport/mod.rs
+++ b/ipa-core/src/helpers/transport/mod.rs
@@ -345,22 +345,6 @@ pub trait Transport: Clone + Send + Sync + 'static {
         &self,
         route: R,
     ) -> Result<(), BroadcastError<Self::Identity, Self::Error>>
-    where
-        Option<QueryId>: From<Q>,
-        Option<Gate>: From<S>,
-        Q: QueryIdBinding,
-        S: StepBinding,
-        R: RouteParams<RouteId, Q, S> + Clone,
-    {
-        let errs = self.broadcast_with_errors(route).await;
-        if errs.is_empty() {
-            Ok(())
-        } else {
-            Err(errs.into())
-        }
-    }
-
-    async fn broadcast_with_errors<Q, S, R>(&self, route: R) -> Vec<(Self::Identity, Self::Error)>
     where
         Option<QueryId>: From<Q>,
         Option<Gate>: From<S>,
         Q: QueryIdBinding,
         S: StepBinding,
         R: RouteParams<RouteId, Q, S> + Clone,
@@ -383,7 +367,11 @@ pub trait Transport: Clone + Send + Sync + 'static {
             }
         }
 
-        errs
+        if errs.is_empty() {
+            Ok(())
+        } else {
+            Err(errs.into())
+        }
     }
 
     /// Alias for `Clone::clone`.
diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs
index 073ecf331..3f23b6260 100644
--- a/ipa-core/src/query/processor.rs
+++ b/ipa-core/src/query/processor.rs
@@ -377,7 +377,7 @@ impl Processor {
         None
     }
 
-    /// This helper is used by the in-memory stack to obtain the state of other shards via a 
+    /// This helper is used by the in-memory stack to obtain the state of other shards via a
     /// [`QueryStatusError::DifferentStatus`] error.
     /// TODO: Ideally broadcast should return a value that we could use to parse the state instead
     /// of relying on errors.
@@ -391,7 +391,7 @@ impl Processor {
         None
     }
 
-    /// This helper is used by the HTTP stack to obtain the state of other shards via a 
+    /// This helper is used by the HTTP stack to obtain the state of other shards via a
     /// [`QueryStatusError::DifferentStatus`] error.
    /// TODO: Ideally broadcast should return a value that we could use to parse the state instead
     /// of relying on errors.
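The merged `broadcast` above first collects every per-peer failure, then decides between
`Ok` and a composite error, so callers like `get_status` can still inspect partial
failures. A minimal, self-contained sketch of that collect-then-classify flow (the names
below are illustrative, not the crate's API):

```rust
// Sketch: gather per-peer outcomes, then either succeed or surface every failure.
fn classify<I, E>(outcomes: Vec<(I, Result<(), E>)>) -> Result<(), Vec<(I, E)>> {
    let errs: Vec<(I, E)> = outcomes
        .into_iter()
        .filter_map(|(peer, res)| res.err().map(|e| (peer, e)))
        .collect();
    if errs.is_empty() {
        Ok(())
    } else {
        Err(errs)
    }
}

fn main() {
    let all_ok: Vec<(u8, Result<(), &str>)> = vec![(1, Ok(())), (2, Ok(()))];
    assert!(classify(all_ok).is_ok());

    let mixed = vec![(1, Ok(())), (2, Err("different status"))];
    assert_eq!(classify(mixed).unwrap_err(), vec![(2, "different status")]);
}
```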
From 4fab6edf5071294c62677defb229fb04bb8deaad Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 20 Nov 2024 13:35:01 -0800 Subject: [PATCH 33/47] Fix clippy shuttle --- ipa-core/src/net/http_serde.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 1965c15ce..85e63a144 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -507,12 +507,10 @@ pub mod query { } impl Request { - #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking pub fn new(query_id: QueryId) -> Self { Self { query_id } } - #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking pub fn try_into_http_request( self, scheme: axum::http::uri::Scheme, From fab480bf82bc9f2a4b79dc2ff666810efeccd0ce Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 20 Nov 2024 14:12:48 -0800 Subject: [PATCH 34/47] Create an abstraction for multiply-accumulate (#1448) --- .github/workflows/check.yml | 2 +- ipa-core/src/ff/accumulator.rs | 442 ++++++++++++++++++ ipa-core/src/ff/boolean.rs | 9 +- ipa-core/src/ff/ec_prime_field.rs | 8 +- ipa-core/src/ff/field.rs | 2 + ipa-core/src/ff/galois_field.rs | 8 +- ipa-core/src/ff/mod.rs | 2 + ipa-core/src/ff/prime_field.rs | 35 +- .../ipa_prf/malicious_security/lagrange.rs | 21 +- 9 files changed, 510 insertions(+), 19 deletions(-) create mode 100644 ipa-core/src/ff/accumulator.rs diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 9a7735e9e..552fd8b49 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -201,7 +201,7 @@ jobs: - name: Add Rust sources run: rustup component add rust-src - name: Run tests with sanitizer - run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std -p ipa-core --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate ${{ matrix.features }}" + run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std -p ipa-core --all-targets --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate ${{ matrix.features }}" miri: runs-on: ubuntu-latest diff --git a/ipa-core/src/ff/accumulator.rs b/ipa-core/src/ff/accumulator.rs new file mode 100644 index 000000000..6edf5390a --- /dev/null +++ b/ipa-core/src/ff/accumulator.rs @@ -0,0 +1,442 @@ +//! Optimized multiply-accumulate for field elements. +//! +//! An add or multiply operation in a prime field can be implemented as the +//! corresponding operation over the integers, followed by a reduction modulo the prime. +//! +//! In the case of several arithmetic operations performed in sequence, it is not +//! necessary to perform the reduction after every operation. As long as an exact +//! integer result is maintained up until reducing, the reduction can be performed once +//! at the end of the sequence. +//! +//! The reduction is usually the most expensive part of a multiplication operation, even +//! when using special primes like Mersenne primes that have an efficient reduction +//! operation. +//! +//! This module implements an optimized multiply-accumulate operation for field elements +//! that defers reduction. +//! +//! To enable this optimized implementation for a field, it is necessary to (1) select +//! 
+//! an accumulator type, and (2) calculate the number of multiply-accumulate operations
+//! that can be performed without overflowing the accumulator. Record these values
+//! in an implementation of the `MultiplyAccumulate` trait.
+//!
+//! This module also provides a generic implementation of `MultiplyAccumulate` that
+//! reduces after every operation. To use the generic implementation for `MyField`:
+//!
+//! ```ignore
+//! impl MultiplyAccumulate for MyField {
+//!     type Accumulator = MyField;
+//!     type AccumulatorArray<const N: usize> = [MyField; N];
+//! }
+//! ```
+//!
+//! Currently, an optimized implementation is supplied only for `Fp61BitPrime`, which is
+//! used _extensively_ in DZKP-based malicious security. All other fields use a naive
+//! implementation. (The implementation of DZKPs is also the only place that the traits
+//! in this module are used; there is no reason to adopt them if not trying to access
+//! the optimized implementation, or at least trying to be easily portable to an
+//! optimized implementation.)
+//!
+//! To perform multiply-accumulate operations using this API:
+//! ```
+//! use ipa_core::ff::{PrimeField, MultiplyAccumulate, MultiplyAccumulator};
+//! fn dot_product<F: PrimeField, const N: usize>(a: &[F; N], b: &[F; N]) -> F {
+//!     let mut acc = <F as MultiplyAccumulate>::Accumulator::new();
+//!     for i in 0..N {
+//!         acc.multiply_accumulate(a[i], b[i]);
+//!     }
+//!     acc.take()
+//! }
+//! ```
+
+use std::{
+    array,
+    marker::PhantomData,
+    ops::{AddAssign, Mul},
+};
+
+use crate::{
+    ff::{Field, U128Conversions},
+    secret_sharing::SharedValue,
+};
+
+/// Trait for multiply-accumulate operations on single values.
+///
+/// See the module-level documentation for usage.
+pub trait MultiplyAccumulator<F>: Clone {
+    /// Create a new accumulator with a value of zero.
+    fn new() -> Self
+    where
+        Self: Sized;
+
+    /// Performs _accumulator <- accumulator + lhs * rhs_.
+    fn multiply_accumulate(&mut self, lhs: F, rhs: F);
+
+    /// Consume the accumulator and return its value.
+    fn take(self) -> F;
+
+    #[cfg(test)]
+    fn reduce_interval() -> usize;
+}
+
+/// Trait for multiply-accumulate operations on vectors.
+///
+/// See the module-level documentation for usage.
+pub trait MultiplyAccumulatorArray<F, const N: usize>: Clone {
+    /// Create a new accumulator with a value of zero.
+    fn new() -> Self
+    where
+        Self: Sized;
+
+    /// Performs _accumulator <- accumulator + lhs * rhs_.
+    ///
+    /// Each of _accumulator_, _lhs_, and _rhs_ is a vector of length `N`, and _lhs *
+    /// rhs_ is an element-wise product.
+    fn multiply_accumulate(&mut self, lhs: &[F; N], rhs: &[F; N]);
+
+    /// Consume the accumulator and return its value.
+    fn take(self) -> [F; N];
+
+    #[cfg(test)]
+    fn reduce_interval() -> usize;
+}
+
+/// Trait for values (e.g. `Field`s) that support multiply-accumulate operations.
+///
+/// See the module-level documentation for usage.
+pub trait MultiplyAccumulate: Sized {
+    type Accumulator: MultiplyAccumulator<Self>;
+    type AccumulatorArray<const N: usize>: MultiplyAccumulatorArray<Self, N>;
+}
+
+#[derive(Clone, Default)]
+pub struct Accumulator<F, A, const REDUCE_INTERVAL: usize> {
+    value: A,
+    count: usize,
+    phantom_data: PhantomData<F>,
+}
+
+#[cfg(all(test, unit_test))]
+impl<F, A, const REDUCE_INTERVAL: usize> Accumulator<F, A, REDUCE_INTERVAL> {
+    /// Return the raw accumulator value, which may not directly correspond to
+    /// a valid value for `F`. This is intended for tests.
+    fn into_raw(self) -> A {
+        self.value
+    }
+}
+
+/// Create a new accumulator containing the specified value.
+///
+/// This is currently used only by tests, and for that purpose it is sufficient to
+/// access it via the concrete `Accumulator` type.
+/// To use it more generally, it would
+/// need to be made part of the trait (either by adding a `From` supertrait bound, or by
+/// adding a method in the trait to perform the operation).
+impl<F, A, const REDUCE_INTERVAL: usize> From<F> for Accumulator<F, A, REDUCE_INTERVAL>
+where
+    A: Default + From<u128>,
+    F: U128Conversions,
+{
+    fn from(value: F) -> Self {
+        Self {
+            value: value.as_u128().into(),
+            count: 0,
+            phantom_data: PhantomData,
+        }
+    }
+}
+/// Optimized multiply-accumulate implementation that adds `REDUCE_INTERVAL` products
+/// into an accumulator before reducing.
+///
+/// Note that the accumulator must be large enough to hold `REDUCE_INTERVAL` products,
+/// plus one additional field element. The additional field element represents the
+/// output of the previous reduction, or an initial value of the accumulator.
+impl<F, A, const REDUCE_INTERVAL: usize> MultiplyAccumulator<F>
+    for Accumulator<F, A, REDUCE_INTERVAL>
+where
+    A: AddAssign + Copy + Default + Mul<Output = A> + From<u128> + Into<u128>,
+    F: SharedValue + U128Conversions + MultiplyAccumulate<Accumulator = Self>,
+{
+    #[inline]
+    fn new() -> Self
+    where
+        Self: Sized,
+    {
+        Self::default()
+    }
+
+    #[inline]
+    fn multiply_accumulate(&mut self, lhs: F, rhs: F) {
+        self.value += A::from(lhs.as_u128()) * A::from(rhs.as_u128());
+        self.count += 1;
+        if self.count == REDUCE_INTERVAL {
+            // Modulo, not really a truncation.
+            self.value = A::from(F::truncate_from(self.value).as_u128());
+            self.count = 0;
+        }
+    }
+
+    #[inline]
+    fn take(self) -> F {
+        // Modulo, not really a truncation.
+        F::truncate_from(self.value)
+    }
+
+    #[cfg(test)]
+    fn reduce_interval() -> usize {
+        REDUCE_INTERVAL
+    }
+}
+
+/// Optimized multiply-accumulate implementation that adds `REDUCE_INTERVAL` products
+/// into an accumulator before reducing. This version operates on arrays.
+impl<F, A, const REDUCE_INTERVAL: usize, const N: usize> MultiplyAccumulatorArray<F, N>
+    for Accumulator<F, [A; N], REDUCE_INTERVAL>
+where
+    A: AddAssign + Copy + Default + Mul<Output = A> + From<u128> + Into<u128>,
+    F: SharedValue + U128Conversions + MultiplyAccumulate<AccumulatorArray<N> = Self>,
+{
+    #[inline]
+    fn new() -> Self
+    where
+        Self: Sized,
+    {
+        Accumulator {
+            value: [A::default(); N],
+            count: 0,
+            phantom_data: PhantomData,
+        }
+    }
+
+    #[inline]
+    fn multiply_accumulate(&mut self, lhs: &[F; N], rhs: &[F; N]) {
+        for i in 0..N {
+            self.value[i] += A::from(lhs[i].as_u128()) * A::from(rhs[i].as_u128());
+        }
+        self.count += 1;
+        if self.count == REDUCE_INTERVAL {
+            // Modulo, not really a truncation.
+            self.value = array::from_fn(|i| A::from(F::truncate_from(self.value[i]).as_u128()));
+            self.count = 0;
+        }
+    }
+
+    #[inline]
+    fn take(self) -> [F; N] {
+        // Modulo, not really a truncation.
+        array::from_fn(|i| F::truncate_from(self.value[i]))
+    }
+
+    #[cfg(test)]
+    fn reduce_interval() -> usize {
+        REDUCE_INTERVAL
+    }
+}
+
+// Unoptimized implementation usable for any field.
+impl<F: Field> MultiplyAccumulator<F> for F {
+    #[inline]
+    fn new() -> Self
+    where
+        Self: Sized,
+    {
+        F::ZERO
+    }
+
+    #[inline]
+    fn multiply_accumulate(&mut self, lhs: F, rhs: F) {
+        *self += lhs * rhs;
+    }
+
+    #[inline]
+    fn take(self) -> F {
+        self
+    }
+
+    #[cfg(test)]
+    fn reduce_interval() -> usize {
+        1
+    }
+}
+
+// Unoptimized implementation usable for any field. This version operates on arrays.
+impl MultiplyAccumulatorArray for [F; N] { + #[inline] + fn new() -> Self + where + Self: Sized, + { + [F::ZERO; N] + } + + #[inline] + fn multiply_accumulate(&mut self, lhs: &[F; N], rhs: &[F; N]) { + for i in 0..N { + self[i] += lhs[i] * rhs[i]; + } + } + + #[inline] + fn take(self) -> [F; N] { + self + } + + #[cfg(test)] + fn reduce_interval() -> usize { + 1 + } +} + +#[cfg(all(test, unit_test))] +mod test { + use crate::ff::{Fp61BitPrime, MultiplyAccumulate, MultiplyAccumulator, U128Conversions}; + + // If adding optimized multiply-accumulate for an additional field, it would make + // sense to convert the freestanding `fp61bit_*` tests here to a macro and replicate + // them for the new field. + + #[test] + fn fp61bit_accum_size() { + // Test that the accumulator does not overflow before reaching REDUCE_INTERVAL. + type Accumulator = ::Accumulator; + let max = -Fp61BitPrime::from_bit(true); + let mut acc = Accumulator::from(max); + for _ in 0..Accumulator::reduce_interval() - 1 { + acc.multiply_accumulate(max, max); + } + + let expected = max.as_u128() + + u128::try_from(Accumulator::reduce_interval() - 1).unwrap() + * (max.as_u128() * max.as_u128()); + assert_eq!(acc.clone().into_raw(), expected); + + assert_eq!(acc.take(), Fp61BitPrime::truncate_from(expected)); + + // Test that the largest value the accumulator should ever hold (which is not + // visible through the API, because it will be reduced immediately) does not + // overflow. + let _ = max.as_u128() + + u128::try_from(Accumulator::reduce_interval()).unwrap() + * (max.as_u128() * max.as_u128()); + } + + #[test] + fn fp61bit_accum_reduction() { + // Test that the accumulator reduces after reaching the specified interval. + // (This test assumes that the implementation (1) sets REDUCE_INTERVAL as large + // as possible, and (2) fully reduces upon reaching REDUCE_INTERVAL. It is + // possible for a correct implementation not to have these properties. If + // adding such an implementation, this test will need to be adjusted.) + type Accumulator = ::Accumulator; + let max = -Fp61BitPrime::from_bit(true); + let mut acc = Accumulator::new(); + for _ in 0..Accumulator::reduce_interval() { + acc.multiply_accumulate(max, max); + } + + let expected = Fp61BitPrime::truncate_from( + u128::try_from(Accumulator::reduce_interval()).unwrap() + * (max.as_u128() * max.as_u128()), + ); + assert_eq!(acc.clone().into_raw(), expected.as_u128()); + assert_eq!(acc.take(), expected); + } + + #[macro_export] + macro_rules! accum_tests { + ($field:ty) => { + mod accum_tests { + use std::iter::zip; + + use proptest::prelude::*; + + use $crate::{ + ff::{MultiplyAccumulate, MultiplyAccumulator, MultiplyAccumulatorArray}, + test_executor::run_random, + }; + use super::*; + + const SZ: usize = 2; + + #[test] + fn accum_simple() { + run_random(|mut rng| async move { + let a = rng.gen(); + let b = rng.gen(); + let c = rng.gen(); + let d = rng.gen(); + + let mut acc = <$field as MultiplyAccumulate>::Accumulator::new(); + acc.multiply_accumulate(a, b); + acc.multiply_accumulate(c, d); + + assert_eq!(acc.take(), a * b + c * d); + }); + } + + prop_compose! { + fn arb_inputs(max_len: usize) + (len in 0..max_len) + ( + lhs in prop::collection::vec(any::<$field>(), len), + rhs in prop::collection::vec(any::<$field>(), len), + ) + -> (Vec<$field>, Vec<$field>) { + (lhs, rhs) + } + } + + proptest::proptest! 
{ + #[test] + fn accum_proptest((lhs, rhs) in arb_inputs( + 10 * <$field as MultiplyAccumulate>::Accumulator::reduce_interval() + )) { + type Accumulator = <$field as MultiplyAccumulate>::Accumulator; + let mut acc = Accumulator::new(); + let mut expected = <$field>::ZERO; + for (lhs_term, rhs_term) in zip(lhs, rhs) { + acc.multiply_accumulate(lhs_term, rhs_term); + expected += lhs_term * rhs_term; + } + assert_eq!(acc.take(), expected); + } + } + + prop_compose! { + fn arb_array_inputs(max_len: usize) + (len in 0..max_len) + ( + lhs in prop::collection::vec( + prop::array::uniform(any::<$field>()), + len, + ), + rhs in prop::collection::vec( + prop::array::uniform(any::<$field>()), + len, + ), + ) + -> (Vec<[$field; SZ]>, Vec<[$field; SZ]>) { + (lhs, rhs) + } + } + + proptest::proptest! { + #[test] + fn accum_array_proptest((lhs, rhs) in arb_array_inputs( + 10 * <$field as MultiplyAccumulate>::AccumulatorArray::::reduce_interval() + )) { + type Accumulator = <$field as MultiplyAccumulate>::AccumulatorArray::; + let mut acc = Accumulator::new(); + let mut expected = [<$field>::ZERO; SZ]; + for (lhs_arr, rhs_arr) in zip(lhs, rhs) { + acc.multiply_accumulate(&lhs_arr, &rhs_arr); + for i in 0..SZ { + expected[i] += lhs_arr[i] * rhs_arr[i]; + } + } + assert_eq!(acc.take(), expected); + } + } + } + } + } +} diff --git a/ipa-core/src/ff/boolean.rs b/ipa-core/src/ff/boolean.rs index 3296e3a4e..b76699587 100644 --- a/ipa-core/src/ff/boolean.rs +++ b/ipa-core/src/ff/boolean.rs @@ -5,7 +5,7 @@ use generic_array::GenericArray; use typenum::U1; use crate::{ - ff::{ArrayAccess, Field, PrimeField, Serializable, U128Conversions}, + ff::{ArrayAccess, Field, MultiplyAccumulate, PrimeField, Serializable, U128Conversions}, impl_shared_value_common, protocol::{ context::{dzkp_field::DZKPCompatibleField, dzkp_validator::SegmentEntry}, @@ -58,6 +58,13 @@ impl PrimeField for Boolean { const PRIME: Self::PrimeInteger = 2; } +// Note: The multiply-accumulate tests are not currently instantiated for `Boolean`. +impl MultiplyAccumulate for Boolean { + type Accumulator = Boolean; + // This could be specialized with a bit vector type if it ever mattered. + type AccumulatorArray = [Boolean; N]; +} + impl SharedValue for Boolean { type Storage = bool; const BITS: u32 = 1; diff --git a/ipa-core/src/ff/ec_prime_field.rs b/ipa-core/src/ff/ec_prime_field.rs index 35ad47a2c..893e56007 100644 --- a/ipa-core/src/ff/ec_prime_field.rs +++ b/ipa-core/src/ff/ec_prime_field.rs @@ -6,7 +6,7 @@ use subtle::{Choice, ConstantTimeEq}; use typenum::{U2, U32}; use crate::{ - ff::{boolean_array::BA256, Field, Serializable}, + ff::{boolean_array::BA256, Field, MultiplyAccumulate, Serializable}, impl_shared_value_common, protocol::{ ipa_prf::PRF_CHUNK, @@ -226,6 +226,12 @@ impl ExtendableField for Fp25519 { } } +// Note: The multiply-accumulate tests are not currently instantiated for `Boolean`. 
+impl MultiplyAccumulate for Fp25519 { + type Accumulator = Fp25519; + type AccumulatorArray = [Fp25519; N]; +} + impl FromRandom for Fp25519 { type SourceLength = U2; diff --git a/ipa-core/src/ff/field.rs b/ipa-core/src/ff/field.rs index 186e81c8a..f34667161 100644 --- a/ipa-core/src/ff/field.rs +++ b/ipa-core/src/ff/field.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; use typenum::{U1, U4, U8}; use crate::{ + ff::MultiplyAccumulate, protocol::prss::FromRandom, secret_sharing::{Block, FieldVectorizable, SharedValue, Vectorizable}, }; @@ -30,6 +31,7 @@ pub trait Field: SharedValue + Mul + MulAssign + + MultiplyAccumulate + FromRandom + Into + Vectorizable<1> diff --git a/ipa-core/src/ff/galois_field.rs b/ipa-core/src/ff/galois_field.rs index 2c99ed164..51181ae3e 100644 --- a/ipa-core/src/ff/galois_field.rs +++ b/ipa-core/src/ff/galois_field.rs @@ -13,7 +13,7 @@ use typenum::{Unsigned, U1, U2, U3, U4, U5}; use crate::{ error::LengthError, - ff::{boolean_array::NonZeroPadding, Field, Serializable, U128Conversions}, + ff::{boolean_array::NonZeroPadding, Field, MultiplyAccumulate, Serializable, U128Conversions}, impl_serializable_trait, impl_shared_value_common, protocol::prss::FromRandomU128, secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, @@ -180,6 +180,12 @@ macro_rules! bit_array_impl { const ONE: Self = Self($one); } + // Note: The multiply-accumulate tests are not currently instantiated for Galois fields. + impl MultiplyAccumulate for $name { + type Accumulator = $name; + type AccumulatorArray = [$name; N]; + } + impl U128Conversions for $name { fn as_u128(&self) -> u128 { (*self).into() diff --git a/ipa-core/src/ff/mod.rs b/ipa-core/src/ff/mod.rs index c53476db1..3d2ffb187 100644 --- a/ipa-core/src/ff/mod.rs +++ b/ipa-core/src/ff/mod.rs @@ -2,6 +2,7 @@ // // This is where we store arithmetic shared secret data models. +mod accumulator; pub mod boolean; pub mod boolean_array; pub mod curve_points; @@ -15,6 +16,7 @@ use std::{ ops::{Add, AddAssign, Sub, SubAssign}, }; +pub use accumulator::{MultiplyAccumulate, MultiplyAccumulator, MultiplyAccumulatorArray}; pub use field::{Field, FieldType}; pub use galois_field::{GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf8Bit, Gf9Bit}; use generic_array::{ArrayLength, GenericArray}; diff --git a/ipa-core/src/ff/prime_field.rs b/ipa-core/src/ff/prime_field.rs index d45be9e3a..de07710b7 100644 --- a/ipa-core/src/ff/prime_field.rs +++ b/ipa-core/src/ff/prime_field.rs @@ -6,7 +6,10 @@ use subtle::{Choice, ConstantTimeEq}; use super::Field; use crate::{ const_assert, - ff::{Serializable, U128Conversions}, + ff::{ + accumulator::{Accumulator, MultiplyAccumulate}, + Serializable, U128Conversions, + }, impl_shared_value_common, protocol::prss::FromRandomU128, secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, @@ -437,9 +440,17 @@ mod fp31 { field_impl! { Fp31, u8, u16, 8, 31 } rem_modulo_impl! 
{ Fp31, u16 }
 
+    impl MultiplyAccumulate for Fp31 {
+        type Accumulator = Fp31;
+        type AccumulatorArray<const N: usize> = [Fp31; N];
+    }
+
     #[cfg(all(test, unit_test))]
     mod specialized_tests {
         use super::*;
+        use crate::accum_tests;
+
+        accum_tests!(Fp31);
 
         #[test]
         fn fp31() {
@@ -469,9 +480,17 @@ mod fp32bit {
         type ArrayAlias = StdArray<Fp32BitPrime, 32>;
     }
 
+    impl MultiplyAccumulate for Fp32BitPrime {
+        type Accumulator = Fp32BitPrime;
+        type AccumulatorArray<const N: usize> = [Fp32BitPrime; N];
+    }
+
     #[cfg(all(test, unit_test))]
     mod specialized_tests {
         use super::*;
+        use crate::accum_tests;
+
+        accum_tests!(Fp32BitPrime);
 
         #[test]
         fn thirty_two_bit_prime() {
@@ -518,6 +537,14 @@ mod fp32bit {
 mod fp61bit {
     field_impl! { Fp61BitPrime, u64, u128, 61, 2_305_843_009_213_693_951 }
 
+    // For multiply-accumulate of `Fp61BitPrime` using `u128` as the accumulator, we can add 64
+    // products onto an original field element before reduction is necessary, i.e.,
+    // 64 * (2^61 - 2)^2 + (2^61 - 2) < 2^128.
+    impl MultiplyAccumulate for Fp61BitPrime {
+        type Accumulator = Accumulator<Self, u128, 64>;
+        type AccumulatorArray<const N: usize> = Accumulator<Self, [u128; N], 64>;
+    }
+
     impl Fp61BitPrime {
         #[must_use]
         pub const fn const_truncate(input: u64) -> Self {
@@ -565,6 +592,12 @@ mod fp61bit {
         use proptest::proptest;
 
         use super::*;
+        use crate::accum_tests;
+
+        // Note: besides the tests generated with this macro, there are some additional
+        // tests for the optimized accumulator for `Fp61BitPrime` in the `accumulator`
+        // module.
+        accum_tests!(Fp61BitPrime);
 
         // copied from 32 bit prime field, adjusted wrap arounds, computed using wolframalpha.com
         #[test]
diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
index 7f805040f..41abab508 100644
--- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
+++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs
@@ -2,7 +2,9 @@ use std::{array::from_fn, fmt::Debug};
 
 use typenum::Unsigned;
 
-use crate::ff::{batch_invert, Field, PrimeField, Serializable};
+use crate::ff::{
+    batch_invert, Field, MultiplyAccumulate, MultiplyAccumulator, PrimeField, Serializable,
+};
 
 /// The Canonical Lagrange denominator is defined as the denominator of the Lagrange base polynomials
 /// `https://en.wikipedia.org/wiki/Lagrange_polynomial`
@@ -162,23 +164,14 @@ where
 /// Computes the dot product of two arrays of the same size.
 /// It is isolated from Lagrange because there could be potential SIMD optimizations used
 fn dot_product<F: PrimeField, const N: usize>(a: &[F; N], b: &[F; N]) -> F {
-    // Staying in integers allows rustc to optimize this code properly, but puts a restriction
-    // on how large the prime field can be
-    debug_assert!(
-        2 * F::BITS + N.next_power_of_two().ilog2() <= 128,
-        "The prime field {} is too large for this dot product implementation",
-        F::PRIME.into()
-    );
-
-    let mut sum = 0;
-
     // I am cautious about using zip in hot code
     // https://github.com/rust-lang/rust/issues/103555
+
+    let mut acc = <F as MultiplyAccumulate>::Accumulator::new();
     for i in 0..N {
-        sum += a[i].as_u128() * b[i].as_u128();
+        acc.multiply_accumulate(a[i], b[i]);
     }
-
-    F::truncate_from(sum)
+    acc.take()
 }
 
 #[cfg(all(test, unit_test))]

From 0084ada5be8a8a4a6b8faba79570514003044dc9 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff
Date: Wed, 20 Nov 2024 15:11:46 -0800
Subject: [PATCH 35/47] Starting sharded helpers

This reverts commit 5904d98f0ffc6826895263af0dd29c8445cae487.
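The network.toml layout this change assumes is ring-major: each ring of six consecutive
entries holds the three MPC endpoints followed by the three shard endpoints for that
ring. A small sketch of the index math implied by the `skip`/`step_by` calls in this
patch (the helper functions are hypothetical, for illustration only):

```rust
// Ring r occupies entries 6r..6r+6: MPC endpoints first, shard endpoints second.
fn mpc_peer_index(ring: usize, helper: usize) -> usize {
    6 * ring + helper
}

fn shard_peer_index(ring: usize, helper: usize) -> usize {
    6 * ring + 3 + helper
}

fn main() {
    // MPC peers for ring 1 are entries 6..9 (`skip(6).take(3)`).
    assert_eq!(mpc_peer_index(1, 0), 6);
    // H2's (helper index 1) shard URLs are entries 4, 10, 16, ...
    // (`skip(3 + 1).step_by(6)` in the patch).
    assert_eq!(shard_peer_index(0, 1), 4);
    assert_eq!(shard_peer_index(1, 1), 10);
}
```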
--- ipa-core/src/bin/helper.rs | 155 +++++++++++++++++++++--------- ipa-core/src/config.rs | 88 ++++++++++++++++- ipa-core/tests/common/mod.rs | 21 ++-- ipa-core/tests/helper_networks.rs | 6 +- 4 files changed, 209 insertions(+), 61 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 734db1bc5..9b558c862 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -8,17 +8,23 @@ use std::{ }; use clap::{self, Parser, Subcommand}; +use futures::future::join; use hyper::http::uri::Scheme; use ipa_core::{ cli::{ client_config_setup, keygen, test_setup, ConfGenArgs, KeygenArgs, LoggingHandle, TestSetupArgs, Verbosity, }, - config::{hpke_registry, HpkeServerConfig, NetworkConfig, ServerConfig, TlsConfig}, + config::{ + hpke_registry, sharded_server_from_toml_str, HpkeServerConfig, ServerConfig, TlsConfig, + }, error::BoxError, executor::IpaRuntime, helpers::HelperIdentity, - net::{ClientIdentity, IpaHttpClient, MpcHttpTransport, ShardHttpTransport}, + net::{ + ClientIdentity, ConnectionFlavor, IpaHttpClient, MpcHttpTransport, Shard, + ShardHttpTransport, + }, sharding::ShardIndex, AppConfig, AppSetup, NonZeroU32PowerOfTwo, }; @@ -55,16 +61,31 @@ struct ServerArgs { #[arg(short, long, required = true)] identity: Option, + #[arg(default_value = "0")] + shard_index: Option, + + #[arg(default_value = "1")] + shard_count: Option, + /// Port to listen on #[arg(short, long, default_value = "3000")] port: Option, - /// Use the supplied prebound socket instead of binding a new socket + #[arg(default_value = "6000")] + shard_port: Option, + + /// Use the supplied prebound socket instead of binding a new socket for mpc /// /// This is only intended for avoiding port conflicts in tests. #[arg(hide = true, long)] server_socket_fd: Option, + /// Use the supplied prebound socket instead of binding a new socket for shard server + /// + /// This is only intended for avoiding port conflicts in tests. + #[arg(hide = true, long)] + shard_server_socket_fd: Option, + /// Use insecure HTTP #[arg(short = 'k', long)] disable_https: bool, @@ -73,7 +94,7 @@ struct ServerArgs { #[arg(long, required = true)] network: Option, - /// TLS certificate for helper-to-helper communication + /// TLS certificate for helper-to-helper and shard-to-shard communication #[arg( long, visible_alias("cert"), @@ -82,7 +103,7 @@ struct ServerArgs { )] tls_cert: Option, - /// TLS key for helper-to-helper communication + /// TLS key for helper-to-helper and shard-to-shard communication #[arg(long, visible_alias("key"), requires = "tls_cert")] tls_key: Option, @@ -114,24 +135,58 @@ fn read_file(path: &Path) -> Result, BoxError> { .map_err(|e| format!("failed to open file {}: {e:?}", path.display()))?) } -async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), BoxError> { - let my_identity = HelperIdentity::try_from(args.identity.expect("enforced by clap")).unwrap(); - - let (identity, server_tls) = match (args.tls_cert, args.tls_key) { +/// Helper function that creates the client identity; either with certificates if they are provided +/// or just with headers otherwise. This works both for sharded and helper configs. 
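+/// (The connection-flavor type parameter is what makes the sharing possible: MPC
+/// connections identify themselves with a `HelperIdentity`, shard connections with a
+/// `ShardIndex`.)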
+fn create_client_identity( + id: F::Identity, + tls_cert: Option, + tls_key: Option, +) -> Result<(ClientIdentity, Option), BoxError> { + match (tls_cert, tls_key) { (Some(cert_file), Some(key_file)) => { let mut key = read_file(&key_file)?; let mut certs = read_file(&cert_file)?; - ( - ClientIdentity::from_pkcs8(&mut certs, &mut key)?, + Ok(( + ClientIdentity::::from_pkcs8(&mut certs, &mut key)?, Some(TlsConfig::File { certificate_file: cert_file, private_key_file: key_file, }), - ) + )) } - (None, None) => (ClientIdentity::Header(my_identity), None), - _ => panic!("should have been rejected by clap"), - }; + (None, None) => Ok((ClientIdentity::Header(id), None)), + _ => Err("should have been rejected by clap".into()), + } +} + +// SAFETY: +// 1. The `--server-socket-fd` option is only intended for use in tests, not in production. +// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has +// only one owner. +fn create_listener(server_socket_fd: Option) -> Result, BoxError> { + server_socket_fd + .map(|fd| { + let listener = unsafe { TcpListener::from_raw_fd(fd) }; + if listener.local_addr().is_ok() { + info!("adopting fd {fd} as listening socket"); + Ok(listener) + } else { + Err(BoxError::from(format!("the server was asked to listen on fd {fd}, but it does not appear to be a valid socket"))) + } + }) + .transpose() +} + +async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), BoxError> { + let my_identity = HelperIdentity::try_from(args.identity.expect("enforced by clap")).unwrap(); + let shard_index = ShardIndex::from(args.shard_index.expect("enforced by clap")); + let shard_count = ShardIndex::from(args.shard_count.expect("enforced by clap")); + assert!(shard_index < shard_count); + + let (identity, server_tls) = + create_client_identity(my_identity, args.tls_cert.clone(), args.tls_key.clone())?; + let (shard_identity, shard_server_tls) = + create_client_identity(shard_index, args.tls_cert, args.tls_key)?; let mk_encryption = args.mk_private_key.map(|sk_path| HpkeServerConfig::File { private_key_file: sk_path, @@ -149,6 +204,13 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B port: args.port, disable_https: args.disable_https, tls: server_tls, + hpke_config: mk_encryption.clone(), + }; + + let shard_server_config = ServerConfig { + port: args.shard_port, + disable_https: args.disable_https, + tls: shard_server_tls, hpke_config: mk_encryption, }; @@ -157,60 +219,48 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B } else { Scheme::HTTPS }; - let network_config_path = args.network.as_deref().unwrap(); - let network_config = NetworkConfig::from_toml_str(&fs::read_to_string(network_config_path)?)? - .override_scheme(&scheme); - // TODO: Following is just temporary until Shard Transport is actually used. 
- let shard_clients_config = network_config.client.clone(); - let shard_server_config = server_config.clone(); - // --- + let network_config_path = args.network.as_deref().unwrap(); + let network_config_string = &fs::read_to_string(network_config_path)?; + let (mut mpc_network, mut shard_network) = + sharded_server_from_toml_str(network_config_string, my_identity, shard_index, shard_count)?; + mpc_network = mpc_network.override_scheme(&scheme); + shard_network = shard_network.override_scheme(&scheme); let http_runtime = new_http_runtime(&logging_handle); let clients = IpaHttpClient::from_conf( &IpaRuntime::from_tokio_runtime(&http_runtime), - &network_config, + &mpc_network, &identity, ); let (transport, server) = MpcHttpTransport::new( IpaRuntime::from_tokio_runtime(&http_runtime), my_identity, server_config, - network_config, + mpc_network, &clients, Some(handler), ); - // TODO: Following is just temporary until Shard Transport is actually used. - let shard_network_config = NetworkConfig::new_shards(vec![], shard_clients_config); - let (shard_transport, _shard_server) = ShardHttpTransport::new( + let shard_clients = IpaHttpClient::::shards_from_conf( + &IpaRuntime::from_tokio_runtime(&http_runtime), + &shard_network, + &shard_identity, + ); + let (shard_transport, shard_server) = ShardHttpTransport::new( IpaRuntime::from_tokio_runtime(&http_runtime), - ShardIndex::FIRST, - ShardIndex::from(1), + shard_index, + shard_count, shard_server_config, - shard_network_config, - vec![], + shard_network, + shard_clients, Some(shard_handler), ); - // --- let _app = setup.connect(transport.clone(), shard_transport.clone()); - let listener = args.server_socket_fd - .map(|fd| { - // SAFETY: - // 1. The `--server-socket-fd` option is only intended for use in tests, not in production. - // 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has - // only one owner. - let listener = unsafe { TcpListener::from_raw_fd(fd) }; - if listener.local_addr().is_ok() { - info!("adopting fd {fd} as listening socket"); - Ok(listener) - } else { - Err(BoxError::from(format!("the server was asked to listen on fd {fd}, but it does not appear to be a valid socket"))) - } - }) - .transpose()?; + let listener = create_listener(args.server_socket_fd)?; + let shard_listener = create_listener(args.shard_server_socket_fd)?; let (_addr, server_handle) = server .start_on( @@ -220,8 +270,17 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B None as Option<()>, ) .await; + let (_saddr, shard_server_handle) = shard_server + .start_on( + &IpaRuntime::from_tokio_runtime(&http_runtime), + shard_listener, + // TODO, trace based on the content of the query. 
+ None as Option<()>, + ) + .await; + + join(server_handle, shard_server_handle).await; - server_handle.await; [query_runtime, http_runtime].map(Runtime::shutdown_background); Ok(()) diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 49384b814..408213b3e 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -16,7 +16,7 @@ use tokio::fs; use crate::{ error::BoxError, - helpers::HelperIdentity, + helpers::{HelperIdentity, TransportIdentity}, hpke::{ Deserializable as _, IpaPrivateKey, IpaPublicKey, KeyRegistry, PrivateKeyOnly, PublicKeyOnly, Serializable as _, @@ -32,8 +32,10 @@ pub type OwnedPrivateKey = PrivateKeyDer<'static>; pub enum Error { #[error(transparent)] ParseError(#[from] config::ConfigError), - #[error("invalid uri: {0}")] + #[error("Invalid uri: {0}")] InvalidUri(#[from] hyper::http::uri::InvalidUri), + #[error("Invalid network size {0}")] + InvalidNetworkSize(usize), #[error(transparent)] IOError(#[from] std::io::Error), } @@ -117,6 +119,88 @@ impl NetworkConfig { } } +/// Reads a the config for a specific, single, sharded server from string. Expects config to be +/// toml format. The server in the network is specified via `id`, `shard_index` and +/// `shard_count`. +/// +/// First we read the configuration without assigning any identities. The number of peers in the +/// configuration must be a multiple of 6, or 3 as a special case to support older, non-sharded +/// configurations. +/// +/// If there are 3 entries, we assign helper identities for them. We create a dummy sharded +/// configuration. +/// +/// If there are any multiple of 6 peers, then peer assignment is as follows: +/// By rings (to be reminiscent of the previous config). The first 6 entries corresponds to the +/// leaders Ring. H1 shard 0, H2, shard 0, and H3 shard 0. The next 6 correspond increases the +/// shard index by one. +/// +/// Other methods to read the network.toml exist depending on the use, for example +/// [`NetworkConfig::from_toml_str`] reads a non-sharded config. +/// TODO: There will be one to read the information relevant for the RC (doesn't need shard +/// info) +/// +/// # Errors +/// if `input` is in an invalid format +pub fn sharded_server_from_toml_str( + input: &str, + id: HelperIdentity, + shard_index: ShardIndex, + shard_count: ShardIndex, +) -> Result<(NetworkConfig, NetworkConfig), Error> { + use config::{Config, File, FileFormat}; + + let all_network: NetworkConfig = Config::builder() + .add_source(File::from_str(input, FileFormat::Toml)) + .build()? 
+ .try_deserialize()?; + + let ix: usize = shard_index.as_index(); + let ix_count: usize = shard_count.as_index(); + let mpc_id: usize = id.as_index(); + + let total_peers = all_network.peers.len(); + if total_peers == 3 { + let mpc_network = NetworkConfig { + peers: all_network.peers.clone(), + client: all_network.client.clone(), + identities: HelperIdentity::make_three().to_vec(), + }; + let shard_network = NetworkConfig { + peers: vec![all_network.peers[mpc_id].clone()], + client: all_network.client, + identities: vec![ShardIndex(0)], + }; + Ok((mpc_network, shard_network)) + } else if total_peers > 0 && total_peers % 6 == 0 { + let mpc_network = NetworkConfig { + peers: all_network + .peers + .clone() + .into_iter() + .skip(ix * 6) + .take(3) + .collect(), + client: all_network.client.clone(), + identities: HelperIdentity::make_three().to_vec(), + }; + let shard_network = NetworkConfig { + peers: all_network + .peers + .into_iter() + .skip(3 + mpc_id) + .step_by(6) + .take(ix_count) + .collect(), + client: all_network.client, + identities: shard_count.iter().collect(), + }; + Ok((mpc_network, shard_network)) + } else { + Err(Error::InvalidNetworkSize(total_peers)) + } +} + impl NetworkConfig { /// # Panics /// In the unexpected case there are more than max usize shards. diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index ca1d5e08a..be582537c 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -109,9 +109,9 @@ impl CommandExt for Command { } } -fn test_setup(config_path: &Path) -> [TcpListener; 3] { - let sockets: [_; 3] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); - let ports: [u16; 3] = sockets +fn test_setup(config_path: &Path) -> [TcpListener; 6] { + let sockets: [_; 6] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); + let ports: [u16; 6] = sockets .each_ref() .map(|sock| sock.local_addr().unwrap().port()); @@ -121,7 +121,7 @@ fn test_setup(config_path: &Path) -> [TcpListener; 3] { .arg("test-setup") .args(["--output-dir".as_ref(), config_path.as_os_str()]) .arg("--ports") - .args(ports.map(|p| p.to_string())); + .args(ports.chunks(2).map(|p| p[0].to_string())); command.status().unwrap_status(); sockets @@ -129,10 +129,10 @@ fn test_setup(config_path: &Path) -> [TcpListener; 3] { pub fn spawn_helpers( config_path: &Path, - sockets: &[TcpListener; 3], + sockets: &[TcpListener; 6], https: bool, ) -> Vec { - zip([1, 2, 3], sockets) + zip([1, 2, 3], sockets.chunks(2)) .map(|(id, socket)| { let mut command = Command::new(HELPER_BIN); command @@ -156,8 +156,13 @@ pub fn spawn_helpers( command.arg("--disable-https"); } - command.preserved_fds(vec![socket.as_raw_fd()]); - command.args(["--server-socket-fd", &socket.as_raw_fd().to_string()]); + command.preserved_fds(vec![socket[0].as_raw_fd()]); + command.args(["--server-socket-fd", &socket[0].as_raw_fd().to_string()]); + command.preserved_fds(vec![socket[1].as_raw_fd()]); + command.args([ + "--shard-server-socket-fd", + &socket[1].as_raw_fd().to_string(), + ]); // something went wrong if command is terminated at this point. 
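The harness above now allocates six listeners, paired per helper. A sketch of the
pairing convention it relies on (assuming the `[mpc, shard]` ordering established by
`test_setup`; the function below is illustrative, not part of the crate):

```rust
// Chunk i belongs to helper i + 1: MPC listener first, shard listener second.
fn helper_listener_pairs(ports: [u16; 6]) -> [(u16, u16); 3] {
    [
        (ports[0], ports[1]),
        (ports[2], ports[3]),
        (ports[4], ports[5]),
    ]
}

fn main() {
    let pairs = helper_listener_pairs([3001, 6001, 3002, 6002, 3003, 6003]);
    assert_eq!(pairs[1], (3002, 6002)); // helper 2: (mpc port, shard port)
}
```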
let mut child = command.spawn().unwrap(); diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index 7775ffba4..4eb59a38e 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -71,8 +71,8 @@ fn keygen_confgen() { let dir = TempDir::new_delete_on_drop(); let path = dir.path(); - let sockets: [_; 3] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); - let ports: [u16; 3] = sockets + let sockets: [_; 6] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); + let ports: [u16; 6] = sockets .each_ref() .map(|sock| sock.local_addr().unwrap().port()); @@ -85,7 +85,7 @@ fn keygen_confgen() { .args(["--output-dir".as_ref(), path.as_os_str()]) .args(["--keys-dir".as_ref(), path.as_os_str()]) .arg("--ports") - .args(ports.map(|p| p.to_string())) + .args(ports.chunks(2).map(|p| p[0].to_string())) .arg("--hosts") .args(["localhost", "localhost", "localhost"]); if overwrite { From cac7a72612cc5f4f67f17b3f7f0da211c36973c8 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Wed, 20 Nov 2024 15:13:11 -0800 Subject: [PATCH 36/47] Using 3 peer configs for a ring instead of 6 This reverts commit 4ac03a4064ac4ca826d7ec64cd693fb0d5035692. --- ipa-core/src/bin/helper.rs | 8 +- ipa-core/src/config.rs | 513 +++++++++++++++++++++++++++++++------ ipa-core/src/net/config.rs | 25 ++ ipa-core/src/utils/mod.rs | 17 ++ 4 files changed, 475 insertions(+), 88 deletions(-) create mode 100644 ipa-core/src/net/config.rs diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 9b558c862..8eb006e94 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -159,10 +159,10 @@ fn create_client_identity( } } -// SAFETY: -// 1. The `--server-socket-fd` option is only intended for use in tests, not in production. -// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has -// only one owner. +/// Creates a [`TcpListener`] from an optional raw file descriptor. Safety notes: +/// 1. The `--server-socket-fd` option is only intended for use in tests, not in production. +/// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has +/// only one owner. fn create_listener(server_socket_fd: Option) -> Result, BoxError> { server_socket_fd .map(|fd| { diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 408213b3e..560f91039 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -3,6 +3,7 @@ use std::{ fmt::{Debug, Formatter}, iter::zip, path::PathBuf, + str::FromStr, time::Duration, }; @@ -38,6 +39,8 @@ pub enum Error { InvalidNetworkSize(usize), #[error(transparent)] IOError(#[from] std::io::Error), + #[error("Missing shard ports for peers {0:?}")] + MissingShardPorts(Vec), } /// Configuration describing either 3 peers in a Ring or N shard peers. In a non-sharded case a @@ -119,88 +122,6 @@ impl NetworkConfig { } } -/// Reads a the config for a specific, single, sharded server from string. Expects config to be -/// toml format. The server in the network is specified via `id`, `shard_index` and -/// `shard_count`. -/// -/// First we read the configuration without assigning any identities. The number of peers in the -/// configuration must be a multiple of 6, or 3 as a special case to support older, non-sharded -/// configurations. -/// -/// If there are 3 entries, we assign helper identities for them. We create a dummy sharded -/// configuration. 
-/// -/// If there are any multiple of 6 peers, then peer assignment is as follows: -/// By rings (to be reminiscent of the previous config). The first 6 entries corresponds to the -/// leaders Ring. H1 shard 0, H2, shard 0, and H3 shard 0. The next 6 correspond increases the -/// shard index by one. -/// -/// Other methods to read the network.toml exist depending on the use, for example -/// [`NetworkConfig::from_toml_str`] reads a non-sharded config. -/// TODO: There will be one to read the information relevant for the RC (doesn't need shard -/// info) -/// -/// # Errors -/// if `input` is in an invalid format -pub fn sharded_server_from_toml_str( - input: &str, - id: HelperIdentity, - shard_index: ShardIndex, - shard_count: ShardIndex, -) -> Result<(NetworkConfig, NetworkConfig), Error> { - use config::{Config, File, FileFormat}; - - let all_network: NetworkConfig = Config::builder() - .add_source(File::from_str(input, FileFormat::Toml)) - .build()? - .try_deserialize()?; - - let ix: usize = shard_index.as_index(); - let ix_count: usize = shard_count.as_index(); - let mpc_id: usize = id.as_index(); - - let total_peers = all_network.peers.len(); - if total_peers == 3 { - let mpc_network = NetworkConfig { - peers: all_network.peers.clone(), - client: all_network.client.clone(), - identities: HelperIdentity::make_three().to_vec(), - }; - let shard_network = NetworkConfig { - peers: vec![all_network.peers[mpc_id].clone()], - client: all_network.client, - identities: vec![ShardIndex(0)], - }; - Ok((mpc_network, shard_network)) - } else if total_peers > 0 && total_peers % 6 == 0 { - let mpc_network = NetworkConfig { - peers: all_network - .peers - .clone() - .into_iter() - .skip(ix * 6) - .take(3) - .collect(), - client: all_network.client.clone(), - identities: HelperIdentity::make_three().to_vec(), - }; - let shard_network = NetworkConfig { - peers: all_network - .peers - .into_iter() - .skip(3 + mpc_id) - .step_by(6) - .take(ix_count) - .collect(), - client: all_network.client, - identities: shard_count.iter().collect(), - }; - Ok((mpc_network, shard_network)) - } else { - Err(Error::InvalidNetworkSize(total_peers)) - } -} - impl NetworkConfig { /// # Panics /// In the unexpected case there are more than max usize shards. @@ -268,6 +189,135 @@ impl NetworkConfig { } } +/// This struct is only used by [`parse_sharded_network_toml`] to parse the entire network. +/// Unlike [`NetworkConfig`], this one doesn't have identities. +#[derive(Clone, Debug, Deserialize)] +struct ShardedNetworkToml { + pub peers: Vec, + + /// HTTP client configuration. + #[serde(default)] + pub client: ClientConfig, +} + +/// This struct is only used by [`parse_sharded_network_toml`] to generate [`PeerConfig`]. It +/// contains an optional `shard_port`. +#[derive(Clone, Debug, Deserialize)] +struct ShardedPeerConfigToml { + #[serde(flatten)] + pub config: PeerConfig, + pub shard_port: Option, +} + +impl ShardedPeerConfigToml { + /// Clones the inner Peer. + fn to_mpc_peer(&self) -> PeerConfig { + self.config.clone() + } + + /// Create a new Peer but its url using [`ShardedPeerConfigToml::shard_port`]. 
+    fn to_shard_peer(&self) -> PeerConfig {
+        let url = self.config.url.to_string();
+        let new_url = format!(
+            "{}{}",
+            &url[..=url.find(':').unwrap()],
+            self.shard_port.expect("Shard port should be set")
+        );
+        let mut shard_peer = self.config.clone();
+        shard_peer.url = Uri::from_str(&new_url).expect("Problem creating uri with sharded port");
+        shard_peer
+    }
+}
+
+/// Parses a [`ShardedNetworkToml`] from a network.toml file. Validates that sharding ports are set
+/// if necessary. The number of peers needs to be a multiple of 3.
+fn parse_sharded_network_toml(input: &str) -> Result<ShardedNetworkToml, Error> {
+    use config::{Config, File, FileFormat};
+
+    let parsed: ShardedNetworkToml = Config::builder()
+        .add_source(File::from_str(input, FileFormat::Toml))
+        .build()?
+        .try_deserialize()?;
+
+    if parsed.peers.len() % 3 != 0 {
+        return Err(Error::InvalidNetworkSize(parsed.peers.len()));
+    }
+
+    // Validate that the sharding config is set
+    let any_shard_port_set = parsed.peers.iter().any(|peer| peer.shard_port.is_some());
+    if any_shard_port_set || parsed.peers.len() > 3 {
+        let missing_ports: Vec<usize> = parsed
+            .peers
+            .iter()
+            .enumerate()
+            .filter_map(|(i, peer)| {
+                if peer.shard_port.is_some() {
+                    None
+                } else {
+                    Some(i)
+                }
+            })
+            .collect();
+        if !missing_ports.is_empty() {
+            return Err(Error::MissingShardPorts(missing_ports));
+        }
+    }
+
+    Ok(parsed)
+}
+
+/// Reads the config for a specific, single, sharded server from a string. Expects the config to
+/// be in toml format. The server in the network is specified via `id`, `shard_index` and
+/// `shard_count`.
+/// The first 3 entries correspond to the leader ring: H1 shard 0, H2 shard 0, and H3 shard 0.
+///
+/// Other methods to read the network.toml exist depending on the use, for example
+/// [`NetworkConfig::from_toml_str`] reads a non-sharded config.
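+///
+/// For example, with 9 peers (3 rings) and `id = H2` (`mpc_id = 1`): the MPC network for
+/// `shard_index = 1` takes entries 3..6, and the shard network takes entries 1, 4 and 7
+/// (`skip(1).step_by(3)`), one shard URL per ring.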
+/// TODO: There will be one to read the information relevant for the RC (doesn't need shard
+/// info)
+///
+/// # Errors
+/// if `input` is in an invalid format
+pub fn sharded_server_from_toml_str(
+    input: &str,
+    id: HelperIdentity,
+    shard_index: ShardIndex,
+    shard_count: ShardIndex,
+) -> Result<(NetworkConfig<Helper>, NetworkConfig<Shard>), Error> {
+    let all_network = parse_sharded_network_toml(input)?;
+
+    let ix: usize = shard_index.as_index();
+    let ix_count: usize = shard_count.as_index();
+    let mpc_id: usize = id.as_index();
+
+    let mpc_network = NetworkConfig {
+        peers: all_network
+            .peers
+            .iter()
+            .map(ShardedPeerConfigToml::to_mpc_peer)
+            .skip(ix * 3)
+            .take(3)
+            .collect(),
+        client: all_network.client.clone(),
+        identities: HelperIdentity::make_three().to_vec(),
+    };
+
+    let shard_network = NetworkConfig {
+        peers: all_network
+            .peers
+            .iter()
+            .map(ShardedPeerConfigToml::to_shard_peer)
+            .skip(mpc_id)
+            .step_by(3)
+            .take(ix_count)
+            .collect(),
+        client: all_network.client,
+        identities: shard_count.iter().collect(),
+    };
+
+    Ok((mpc_network, shard_network))
+}
+
 #[derive(Clone, Debug, Deserialize)]
 pub struct PeerConfig {
     /// Peer URL
@@ -611,15 +661,21 @@ mod tests {
 
     use hpke::{kem::X25519HkdfSha256, Kem};
     use hyper::Uri;
+    use once_cell::sync::Lazy;
     use rand::rngs::StdRng;
     use rand_core::SeedableRng;
 
-    use super::{NetworkConfig, PeerConfig};
+    use super::{
+        parse_sharded_network_toml, sharded_server_from_toml_str, NetworkConfig, PeerConfig,
+    };
     use crate::{
-        config::{ClientConfig, HpkeClientConfig, Http2Configurator, HttpClientConfigurator},
+        config::{
+            ClientConfig, Error, HpkeClientConfig, Http2Configurator, HttpClientConfigurator,
+        },
         helpers::HelperIdentity,
         net::test::TestConfigBuilder,
         sharding::ShardIndex,
+        utils::replace_all,
     };
 
     const URI_1: &str = "http://localhost:3000";
@@ -715,4 +771,293 @@ mod tests {
         let conf = NetworkConfig::new_shards(vec![pc1.clone()], client);
         assert_eq!(conf.peers[ShardIndex(0)].url, pc1.url);
    }
+
+    #[test]
+    fn parse_sharded_server_happy() {
+        // Assuming the position of the second helper in the second shard (the middle server in the 3 x 3)
+        let (mpc, shard) = sharded_server_from_toml_str(
+            &SHARDED_OK_REPEAT,
+            HelperIdentity::TWO,
+            ShardIndex::from(1),
+            ShardIndex::from(3),
+        )
+        .unwrap();
+        assert_eq!(
+            vec![
+                "helper1.prod.ipa-helper.shard1.dev:443",
+                "helper2.prod.ipa-helper.shard1.dev:443",
+                "helper3.prod.ipa-helper.shard1.dev:443"
+            ],
+            mpc.peers
+                .into_iter()
+                .map(|p| p.url.to_string())
+                .collect::<Vec<_>>()
+        );
+        assert_eq!(
+            vec![
+                "helper2.prod.ipa-helper.shard0.dev:555",
+                "helper2.prod.ipa-helper.shard1.dev:555",
+                "helper2.prod.ipa-helper.shard2.dev:555"
+            ],
+            shard
+                .peers
+                .into_iter()
+                .map(|p| p.url.to_string())
+                .collect::<Vec<_>>()
+        );
+    }
+
+    /// Tests that the url of a shard gets updated with the shard port.
+    #[test]
+    fn transform_sharded_peers() {
+        let mut n = parse_sharded_network_toml(&SHARDED_OK_REPEAT).unwrap();
+        assert_eq!(
+            "helper3.prod.ipa-helper.shard2.dev:666",
+            n.peers.pop().unwrap().to_shard_peer().url
+        );
+        assert_eq!(
+            "helper2.prod.ipa-helper.shard2.dev:555",
+            n.peers.pop().unwrap().to_shard_peer().url
+        );
+    }
+
+    /// Expects an error if the number of peers isn't a multiple of 3
+    #[test]
+    fn invalid_nr_of_peers() {
+        assert!(matches!(
+            parse_sharded_network_toml(&SHARDED_8),
+            Err(Error::InvalidNetworkSize(_))
+        ));
+    }
+
+    /// If any sharded port is set (indicating this is a sharding config), then ALL ports must be set.
+ #[test] + fn parse_network_toml_shard_port_some_set() { + assert!(matches!( + parse_sharded_network_toml(&SHARDED_COMPAT_ONE_PORT), + Err(Error::MissingShardPorts(_)) + )); + } + + /// If there are more than 3 peers configured (indicating this is a sharding config), then ALL ports must be set. + #[test] + fn parse_network_toml_shard_port_set() { + assert!(matches!( + parse_sharded_network_toml(&SHARDED_MISSING_PORTS_REPEAT), + Err(Error::MissingShardPorts(_)) + )); + } + + /// Testing happy case of a sharded network config + #[test] + fn happy_parse_sharded_network_toml() { + let r_entire_network = parse_sharded_network_toml(SHARDED_OK); + assert!(r_entire_network.is_ok()); + let entire_network = r_entire_network.unwrap(); + assert!(matches!( + entire_network.client.http_config, + HttpClientConfigurator::Http2(_) + )); + assert_eq!(3, entire_network.peers.len()); + assert_eq!( + "helper3.prod.ipa-helper.shard0.dev:443", + entire_network.peers[2].config.url + ); + assert_eq!(Some(666), entire_network.peers[2].shard_port); + } + + /// Testing happy case of a longer sharded network config + #[test] + fn happy_parse_larger_sharded_network_toml() { + let r_entire_network = parse_sharded_network_toml(&SHARDED_OK_REPEAT); + assert!(r_entire_network.is_ok()); + let entire_network = r_entire_network.unwrap(); + assert_eq!(9, entire_network.peers.len()); + assert_eq!(Some(666), entire_network.peers[8].shard_port); + } + + /// This test validates that the new logic that handles sharded configurations can also handle the previous version + #[test] + fn parse_non_sharded_network_toml() { + let r_entire_network = parse_sharded_network_toml(&NON_SHARDED_COMPAT); + assert!(r_entire_network.is_ok()); + let entire_network = r_entire_network.unwrap(); + assert!(matches!( + entire_network.client.http_config, + HttpClientConfigurator::Http2(_) + )); + assert_eq!(3, entire_network.peers.len()); + assert_eq!( + "helper3.prod.ipa-helper.dev:443", + entire_network.peers[2].config.url + ); + } + + // Following are some large &str const used for tests + + /// Valid: A non-sharded network toml, just how they used to be + static NON_SHARDED_COMPAT: Lazy = Lazy::new(|| format!("{CLIENT}{P1}{REST}")); + + /// Invalid: Same as [`NON_SHARDED_COMPAT`] but with a single `shard_port` set. + static SHARDED_COMPAT_ONE_PORT: Lazy = + Lazy::new(|| format!("{CLIENT}{P1}\nshard_port = 777\n{REST}")); + + /// Helper const used to create client configs + const CLIENT: &str = "[client.http_config] +ping_interval_secs = 90.0 +version = \"http2\" +"; + + /// Helper const that has the first part of a Peer, just before were `shard_port` should be + /// specified. 
+ const P1: &str = " +[[peers]] +certificate = \"\"\" +-----BEGIN CERTIFICATE----- +MIIBmzCCAUGgAwIBAgIIMlnveFys5QUwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb +aGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwM1oXDTI0 +MTIwNDAzMzMwM1owJjEkMCIGA1UEAwwbaGVscGVyMS5wcm9kLmlwYS1oZWxwZXIu +ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEWmrrkaKM7HQ0Y3ZGJtHB7vfG +cT/hDCXCoob4pJ/fpPDMrqhiwTTck3bNOuzv9QIx+p5C2Qp8u67rYfK78w86NaNZ +MFcwJgYDVR0RBB8wHYIbaGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud +DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI +zj0EAwIDSAAwRQIhAKVdDCQeXLRXDYXy4b1N1UxD/JPuD9H7zeRb8/nmIDTfAiBL +a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== +-----END CERTIFICATE----- +\"\"\" +url = \"helper1.prod.ipa-helper.dev:443\""; + + /// The rest of a configuration + const REST: &str = " +[peers.hpke] +public_key = \"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756\" + +[[peers]] +certificate = \"\"\" +-----BEGIN CERTIFICATE----- +MIIBmzCCAUGgAwIBAgIITOtoca16QckwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb +aGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwOFoXDTI0 +MTIwNDAzMzMwOFowJjEkMCIGA1UEAwwbaGVscGVyMi5wcm9kLmlwYS1oZWxwZXIu +ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAETxOH4ATz6kBxLuRznKDFRugm +XKmH7mzRB9wn5vaVlVpDzf4nDHJ+TTzSS6Lb3YLsA7jrXDx+W7xPLGow1+9FNqNZ +MFcwJgYDVR0RBB8wHYIbaGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud +DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI +zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS +XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== +-----END CERTIFICATE----- +\"\"\" +url = \"helper2.prod.ipa-helper.dev:443\" + +[peers.hpke] +public_key = \"62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22\" + +[[peers]] +certificate = \"\"\" +-----BEGIN CERTIFICATE----- +MIIBmzCCAUGgAwIBAgIIaf7eDCnXh2swCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb +aGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMxMloXDTI0 +MTIwNDAzMzMxMlowJjEkMCIGA1UEAwwbaGVscGVyMy5wcm9kLmlwYS1oZWxwZXIu +ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIMqxCCtu4joFr8YtOrEtq230 +NuTtUAaJHIHNtv4CvpUcbtlFMWFYUUum7d22A8YTfUeccG5PsjjCoQG/dhhSbKNZ +MFcwJgYDVR0RBB8wHYIbaGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud +DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI +zj0EAwIDSAAwRQIhAOTSQWbN7kfIatNJEwWTBL4xOY88E3+SOnBNExCsTkQuAiBB +/cwOQQUEeE4llrDp+EnyGbzmVm5bINz8gePIxkKqog== +-----END CERTIFICATE----- +\"\"\" +url = \"helper3.prod.ipa-helper.dev:443\" + +[peers.hpke] +public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\" +"; + + /// Valid: A sharded configuration + const SHARDED_OK: &str = " +[[peers]] +certificate = \"\"\" +-----BEGIN CERTIFICATE----- +MIIBmzCCAUGgAwIBAgIIMlnveFys5QUwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb +aGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwM1oXDTI0 +MTIwNDAzMzMwM1owJjEkMCIGA1UEAwwbaGVscGVyMS5wcm9kLmlwYS1oZWxwZXIu +ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEWmrrkaKM7HQ0Y3ZGJtHB7vfG +cT/hDCXCoob4pJ/fpPDMrqhiwTTck3bNOuzv9QIx+p5C2Qp8u67rYfK78w86NaNZ +MFcwJgYDVR0RBB8wHYIbaGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud +DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI +zj0EAwIDSAAwRQIhAKVdDCQeXLRXDYXy4b1N1UxD/JPuD9H7zeRb8/nmIDTfAiBL +a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== +-----END CERTIFICATE----- +\"\"\" +url = \"helper1.prod.ipa-helper.shard0.dev:443\" +shard_port = 444 + +[peers.hpke] +public_key = \"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756\" + +[[peers]] +certificate = \"\"\" +-----BEGIN CERTIFICATE----- 
+MIIBmzCCAUGgAwIBAgIITOtoca16QckwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb +aGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwOFoXDTI0 +MTIwNDAzMzMwOFowJjEkMCIGA1UEAwwbaGVscGVyMi5wcm9kLmlwYS1oZWxwZXIu +ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAETxOH4ATz6kBxLuRznKDFRugm +XKmH7mzRB9wn5vaVlVpDzf4nDHJ+TTzSS6Lb3YLsA7jrXDx+W7xPLGow1+9FNqNZ +MFcwJgYDVR0RBB8wHYIbaGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud +DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI +zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS +XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== +-----END CERTIFICATE----- +\"\"\" +url = \"helper2.prod.ipa-helper.shard0.dev:443\" +shard_port = 555 + +[peers.hpke] +public_key = \"62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22\" + +[[peers]] +certificate = \"\"\" +-----BEGIN CERTIFICATE----- +MIIBmzCCAUGgAwIBAgIIaf7eDCnXh2swCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb +aGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMxMloXDTI0 +MTIwNDAzMzMxMlowJjEkMCIGA1UEAwwbaGVscGVyMy5wcm9kLmlwYS1oZWxwZXIu +ZGV2MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIMqxCCtu4joFr8YtOrEtq230 +NuTtUAaJHIHNtv4CvpUcbtlFMWFYUUum7d22A8YTfUeccG5PsjjCoQG/dhhSbKNZ +MFcwJgYDVR0RBB8wHYIbaGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MA4GA1Ud +DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI +zj0EAwIDSAAwRQIhAOTSQWbN7kfIatNJEwWTBL4xOY88E3+SOnBNExCsTkQuAiBB +/cwOQQUEeE4llrDp+EnyGbzmVm5bINz8gePIxkKqog== +-----END CERTIFICATE----- +\"\"\" +url = \"helper3.prod.ipa-helper.shard0.dev:443\" +shard_port = 666 + +[peers.hpke] +public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\" +"; + + /// Valid: Three sharded configs together for 9 + static SHARDED_OK_REPEAT: Lazy = Lazy::new(|| { + format!( + "{}{}{}", + SHARDED_OK, + replace_all(SHARDED_OK, "shard0", "shard1"), + replace_all(SHARDED_OK, "shard0", "shard2") + ) + }); + + /// Invalid: A network toml with 8 entries + static SHARDED_8: Lazy = Lazy::new(|| { + let last_peers_index = SHARDED_OK_REPEAT.rfind("[[peers]]").unwrap(); + SHARDED_OK_REPEAT[..last_peers_index].to_string() + }); + + /// Invalid: Same as [`SHARDED_OK_REPEAT`] but without the expected ports + static SHARDED_MISSING_PORTS_REPEAT: Lazy = Lazy::new(|| { + let lines: Vec<&str> = SHARDED_OK_REPEAT.lines().collect(); + let new_lines: Vec = lines + .iter() + .filter(|line| !line.starts_with("shard_port =")) + .map(std::string::ToString::to_string) + .collect(); + new_lines.join("\n") + }); } diff --git a/ipa-core/src/net/config.rs b/ipa-core/src/net/config.rs new file mode 100644 index 000000000..2c5d3e5ed --- /dev/null +++ b/ipa-core/src/net/config.rs @@ -0,0 +1,25 @@ +use std::{ + fmt::Debug, + io::{self, BufRead}, + sync::Arc, +}; + +use config::{Config, File, FileFormat}; +use hyper::{header::HeaderName, Uri}; +use once_cell::sync::Lazy; +use rustls::crypto::CryptoProvider; +use rustls_pki_types::CertificateDer; +use ::serde::Deserialize; + +use crate::{ + config::{ClientConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig}, helpers::{HelperIdentity, TransportIdentity}, serde, sharding::ShardIndex +}; + + + +#[cfg(all(test, unit_test))] +mod tests { + + + +} diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index 6829f57fa..f8b785fae 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -5,3 +5,20 @@ mod power_of_two; #[cfg(target_pointer_width = "64")] pub use power_of_two::{non_zero_prev_power_of_two, NonZeroU32PowerOfTwo}; + +/// Replaces all occurrences of `from` with `to` in `s`. 
+#[allow(dead_code)]
+pub fn replace_all(s: &str, from: &str, to: &str) -> String {
+    let mut result = String::new();
+    let mut i = 0;
+    while i < s.len() {
+        if s[i..].starts_with(from) {
+            result.push_str(to);
+            i += from.len();
+        } else {
+            result.push(s.chars().nth(i).unwrap());
+            i += 1;
+        }
+    }
+    result
+}

From 1e8f787578783a6202e4f77267e2cfe9ff925985 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff 
Date: Wed, 20 Nov 2024 15:13:39 -0800
Subject: [PATCH 37/47] Network.toml requires shard_port

This reverts commit 3e533510f704c8ba47881e60151653d196fca9cb.
---
 ipa-core/src/config.rs     | 54 ++++++++++++++++++++++++++++----------
 ipa-core/src/net/config.rs | 25 ------------------
 2 files changed, 40 insertions(+), 39 deletions(-)
 delete mode 100644 ipa-core/src/net/config.rs

diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs
index 560f91039..7f087de7d 100644
--- a/ipa-core/src/config.rs
+++ b/ipa-core/src/config.rs
@@ -200,6 +200,22 @@ struct ShardedNetworkToml {
     pub client: ClientConfig,
 }
 
+impl ShardedNetworkToml {
+    fn missing_shard_ports(&self) -> Vec<usize> {
+        self.peers
+            .iter()
+            .enumerate()
+            .filter_map(|(i, peer)| {
+                if peer.shard_port.is_some() {
+                    None
+                } else {
+                    Some(i)
+                }
+            })
+            .collect()
+    }
+}
+
 /// This struct is only used by [`parse_sharded_network_toml`] to generate [`PeerConfig`]. It
 /// contains an optional `shard_port`.
 #[derive(Clone, Debug, Deserialize)]
@@ -246,18 +262,7 @@ fn parse_sharded_network_toml(input: &str) -> Result<ShardedNetworkToml, Error>
     // Validate sharding config is set
     let any_shard_port_set = parsed.peers.iter().any(|peer| peer.shard_port.is_some());
     if any_shard_port_set || parsed.peers.len() > 3 {
-        let missing_ports: Vec<usize> = parsed
-            .peers
-            .iter()
-            .enumerate()
-            .filter_map(|(i, peer)| {
-                if peer.shard_port.is_some() {
-                    None
-                } else {
-                    Some(i)
-                }
-            })
-            .collect();
+        let missing_ports = parsed.missing_shard_ports();
         if !missing_ports.is_empty() {
             return Err(Error::MissingShardPorts(missing_ports));
         }
@@ -268,8 +273,10 @@ fn parse_sharded_network_toml(input: &str) -> Result<ShardedNetworkToml, Error>
 
 /// Reads the config for a specific, single, sharded server from string. Expects config to be
 /// toml format. The server in the network is specified via `id`, `shard_index` and
-/// `shard_count`.
-/// The first 3 entries corresponds to the leaders Ring. H1 shard 0, H2, shard 0, and H3 shard 0.
+/// `shard_count`. This function expects shard ports to be set for all peers.
+///
+/// The first 3 peers correspond to the leaders Ring. H1 shard 0, H2 shard 0, and H3 shard 0.
+/// The next 3 correspond to the next ring with `shard_index` equals 1 and so on.
 /// 
 /// Other methods to read the network.toml exist depending on the use, for example
 /// [`NetworkConfig::from_toml_str`] reads a non-sharded config.
@@ -285,6 +292,10 @@ pub fn sharded_server_from_toml_str(
     shard_count: ShardIndex,
 ) -> Result<(NetworkConfig, NetworkConfig), Error> {
     let all_network = parse_sharded_network_toml(input)?;
+    let missing_ports = all_network.missing_shard_ports();
+    if !missing_ports.is_empty() {
+        return Err(Error::MissingShardPorts(missing_ports));
+    }
 
     let ix: usize = shard_index.as_index();
     let ix_count: usize = shard_count.as_index();
@@ -848,6 +859,21 @@ mod tests {
         ));
     }
 
+    /// Check that shard ports are given for [`sharded_server_from_toml_str`] or error is returned.
+    #[test]
+    fn parse_sharded_without_shard_ports() {
+        // Second, I test the networkconfig parsing
+        assert!(matches!(
+            sharded_server_from_toml_str(
+                &NON_SHARDED_COMPAT,
+                HelperIdentity::TWO,
+                ShardIndex::FIRST,
+                ShardIndex::from(1)
+            ),
+            Err(Error::MissingShardPorts(_))
+        ));
+    }
+
     /// Testing happy case of a sharded network config
     #[test]
     fn happy_parse_sharded_network_toml() {
diff --git a/ipa-core/src/net/config.rs b/ipa-core/src/net/config.rs
deleted file mode 100644
index 2c5d3e5ed..000000000
--- a/ipa-core/src/net/config.rs
+++ /dev/null
@@ -1,25 +0,0 @@
-use std::{
-    fmt::Debug,
-    io::{self, BufRead},
-    sync::Arc,
-};
-
-use config::{Config, File, FileFormat};
-use hyper::{header::HeaderName, Uri};
-use once_cell::sync::Lazy;
-use rustls::crypto::CryptoProvider;
-use rustls_pki_types::CertificateDer;
-use ::serde::Deserialize;
-
-use crate::{
-    config::{ClientConfig, OwnedCertificate, OwnedPrivateKey, PeerConfig}, helpers::{HelperIdentity, TransportIdentity}, serde, sharding::ShardIndex
-};
-
-
-
-#[cfg(all(test, unit_test))]
-mod tests {
-
-
-
-}
From fd70733e79cf52ae5394637640ea42793d50b9cf Mon Sep 17 00:00:00 2001
From: Christian Berkhoff 
Date: Wed, 20 Nov 2024 15:14:00 -0800
Subject: [PATCH 38/47] Fixing HTTP/gen tests

This reverts commit 5bae2bf5dc189819058882b8ec4a6f1008cfd750.
---
 ipa-core/src/cli/clientconf.rs    | 13 ++++-
 ipa-core/src/cli/test_setup.rs    |  8 ++-
 ipa-core/tests/common/mod.rs      | 87 ++++++++++++++++---------------
 ipa-core/tests/helper_networks.rs |  4 +-
 4 files changed, 66 insertions(+), 46 deletions(-)

diff --git a/ipa-core/src/cli/clientconf.rs b/ipa-core/src/cli/clientconf.rs
index 341a4253a..a57fd59d4 100644
--- a/ipa-core/src/cli/clientconf.rs
+++ b/ipa-core/src/cli/clientconf.rs
@@ -17,6 +17,9 @@ pub struct ConfGenArgs {
     #[arg(short, long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])]
     ports: Vec<u16>,
 
+    #[arg(short, long, num_args = 3, value_name = "SHARD_PORTS", default_values = vec!["6000", "6001", "6002"])]
+    shard_ports: Vec<u16>,
+
     #[arg(long, num_args = 3, default_values = vec!["localhost", "localhost", "localhost"])]
     hosts: Vec<String>,
 
@@ -54,13 +57,14 @@ pub struct ConfGenArgs
 /// [`ConfGenArgs`]: ConfGenArgs
 /// [`Paths`]: crate::cli::paths::PathExt
 pub fn setup(args: ConfGenArgs) -> Result<(), BoxError> {
-    let clients_conf: [_; 3] = zip(args.hosts.iter(), args.ports)
+    let clients_conf: [_; 3] = zip(args.hosts.iter(), zip(args.ports, args.shard_ports))
         .enumerate()
-        .map(|(id, (host, port))| {
+        .map(|(id, (host, (port, shard_port)))| {
             let id: u8 = u8::try_from(id).unwrap() + 1;
             HelperClientConf {
                 host,
                 port,
+                shard_port,
                 tls_cert_file: args.keys_dir.helper_tls_cert(id),
                 mk_public_key_file: args.keys_dir.helper_mk_public_key(id),
             }
@@ -96,6 +100,7 @@ pub fn setup(args: ConfGenArgs) -> Result<(), BoxError> {
 pub struct HelperClientConf<'a> {
     pub(crate) host: &'a str,
     pub(crate) port: u16,
+    pub(crate) shard_port: u16,
     pub(crate) tls_cert_file: PathBuf,
     pub(crate) mk_public_key_file: PathBuf,
 }
@@ -133,6 +138,10 @@ pub fn gen_client_config<'a>(
                 port = client_conf.port
             )),
         );
+        peer.insert(
+            String::from("shard_port"),
+            Value::Integer(client_conf.shard_port.into()),
+        );
         peer.insert(String::from("certificate"), Value::String(certificate));
         peer.insert(
             String::from("hpke"),
diff --git a/ipa-core/src/cli/test_setup.rs b/ipa-core/src/cli/test_setup.rs
index 538faf180..a3aa93cc4 100644
--- a/ipa-core/src/cli/test_setup.rs
+++ b/ipa-core/src/cli/test_setup.rs
@@ -36,6 +36,9 @@ pub struct TestSetupArgs {
     #[arg(short, 
long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])] ports: Vec, + + #[arg(short, long, num_args = 3, value_name = "SHARD_PORT", default_values = vec!["6000", "6001", "6002"])] + shard_ports: Vec, } /// Prepare a test network of three helpers. @@ -56,8 +59,8 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { let localhost = String::from("localhost"); - let clients_config: [_; 3] = zip([1, 2, 3], args.ports) - .map(|(id, port)| { + let clients_config: [_; 3] = zip([1, 2, 3], zip(args.ports, args.shard_ports)) + .map(|(id, (port, shard_port))| { let keygen_args = KeygenArgs { name: localhost.clone(), tls_cert: args.output_dir.helper_tls_cert(id), @@ -72,6 +75,7 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { Ok(HelperClientConf { host: &localhost, port, + shard_port, tls_cert_file: keygen_args.tls_cert, mk_public_key_file: keygen_args.mk_public_key, }) diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index be582537c..dae743987 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -121,7 +121,9 @@ fn test_setup(config_path: &Path) -> [TcpListener; 6] { .arg("test-setup") .args(["--output-dir".as_ref(), config_path.as_os_str()]) .arg("--ports") - .args(ports.chunks(2).map(|p| p[0].to_string())); + .args(ports.iter().take(3).map(|p| p.to_string())) + .arg("--shard-ports") + .args(ports.iter().skip(3).take(3).map(|p| p.to_string())); command.status().unwrap_status(); sockets @@ -132,47 +134,50 @@ pub fn spawn_helpers( sockets: &[TcpListener; 6], https: bool, ) -> Vec { - zip([1, 2, 3], sockets.chunks(2)) - .map(|(id, socket)| { - let mut command = Command::new(HELPER_BIN); + zip( + [1, 2, 3], + zip(sockets.iter().take(3), sockets.iter().skip(3).take(3)), + ) + .map(|(id, (socket, shard_socket))| { + let mut command = Command::new(HELPER_BIN); + command + .args(["-i", &id.to_string()]) + .args(["--network".into(), config_path.join("network.toml")]) + .silent(); + + if https { command - .args(["-i", &id.to_string()]) - .args(["--network".into(), config_path.join("network.toml")]) - .silent(); - - if https { - command - .args(["--tls-cert".into(), config_path.join(format!("h{id}.pem"))]) - .args(["--tls-key".into(), config_path.join(format!("h{id}.key"))]) - .args([ - "--mk-public-key".into(), - config_path.join(format!("h{id}_mk.pub")), - ]) - .args([ - "--mk-private-key".into(), - config_path.join(format!("h{id}_mk.key")), - ]); - } else { - command.arg("--disable-https"); - } - - command.preserved_fds(vec![socket[0].as_raw_fd()]); - command.args(["--server-socket-fd", &socket[0].as_raw_fd().to_string()]); - command.preserved_fds(vec![socket[1].as_raw_fd()]); - command.args([ - "--shard-server-socket-fd", - &socket[1].as_raw_fd().to_string(), - ]); - - // something went wrong if command is terminated at this point. 
-            let mut child = command.spawn().unwrap();
-            if let Ok(Some(status)) = child.try_wait() {
-                panic!("Helper binary terminated early with status = {status}");
-            }
-
-            child.terminate_on_drop()
-        })
-        .collect::<Vec<_>>()
+                .args(["--tls-cert".into(), config_path.join(format!("h{id}.pem"))])
+                .args(["--tls-key".into(), config_path.join(format!("h{id}.key"))])
+                .args([
+                    "--mk-public-key".into(),
+                    config_path.join(format!("h{id}_mk.pub")),
+                ])
+                .args([
+                    "--mk-private-key".into(),
+                    config_path.join(format!("h{id}_mk.key")),
+                ]);
+        } else {
+            command.arg("--disable-https");
+        }
+
+        command.preserved_fds(vec![socket.as_raw_fd()]);
+        command.args(["--server-socket-fd", &socket.as_raw_fd().to_string()]);
+        command.preserved_fds(vec![shard_socket.as_raw_fd()]);
+        command.args([
+            "--shard-server-socket-fd",
+            &shard_socket.as_raw_fd().to_string(),
+        ]);
+
+        // something went wrong if command is terminated at this point.
+        let mut child = command.spawn().unwrap();
+        if let Ok(Some(status)) = child.try_wait() {
+            panic!("Helper binary terminated early with status = {status}");
+        }
+
+        child.terminate_on_drop()
+    })
+    .collect::<Vec<_>>()
 }
 
 pub fn test_multiply(config_dir: &Path, https: bool) {
diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs
index 4eb59a38e..06adb56a7 100644
--- a/ipa-core/tests/helper_networks.rs
+++ b/ipa-core/tests/helper_networks.rs
@@ -85,7 +85,9 @@ fn keygen_confgen() {
         .args(["--output-dir".as_ref(), path.as_os_str()])
         .args(["--keys-dir".as_ref(), path.as_os_str()])
         .arg("--ports")
-        .args(ports.chunks(2).map(|p| p[0].to_string()))
+        .args(ports.iter().take(3).map(|p| p.to_string()))
+        .arg("--shard-ports")
+        .args(ports.iter().skip(3).take(3).map(|p| p.to_string()))
         .arg("--hosts")
         .args(["localhost", "localhost", "localhost"]);
     if overwrite {
From 07faaaf753bec0aa2da59413991a7f28ca9b6a66 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff 
Date: Wed, 20 Nov 2024 15:14:50 -0800
Subject: [PATCH 39/47] using shard_url instead of port

This reverts commit 44f699d87100ce19bedf888eac6b35c73043f51e.
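
For reference, a minimal sketch of the peer entry this change expects in
network.toml (hostnames and ports below are illustrative placeholders, not
taken from any real deployment):

    [[peers]]
    url = "helper1.shard0.example:443"
    shard_url = "helper1.shard0.example:444"

    [peers.hpke]
    public_key = "..."

Compared to the previous `shard_port = 444` form, which derived the shard
address from the MPC `url`, carrying a full `shard_url` also allows the
shard listener to live on a different host rather than only a different
port.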
---
 ipa-core/src/bin/helper.rs     |   1 +
 ipa-core/src/cli/clientconf.rs |   8 +-
 ipa-core/src/config.rs         | 197 +++++++++++++++++----------
 ipa-core/src/serde.rs          |  21 ++++
 ipa-core/src/utils/mod.rs      |  17 ---
 5 files changed, 129 insertions(+), 115 deletions(-)

diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs
index 8eb006e94..7de9de90c 100644
--- a/ipa-core/src/bin/helper.rs
+++ b/ipa-core/src/bin/helper.rs
@@ -71,6 +71,7 @@ struct ServerArgs {
     #[arg(short, long, default_value = "3000")]
     port: Option<u16>,
 
+    /// Port to use for shard-to-shard communication, if sharded MPC is used
     #[arg(default_value = "6000")]
     shard_port: Option<u16>,
 
diff --git a/ipa-core/src/cli/clientconf.rs b/ipa-core/src/cli/clientconf.rs
index a57fd59d4..222b26420 100644
--- a/ipa-core/src/cli/clientconf.rs
+++ b/ipa-core/src/cli/clientconf.rs
@@ -139,8 +139,12 @@ pub fn gen_client_config<'a>(
         )),
     );
     peer.insert(
-        String::from("shard_port"),
-        Value::Integer(client_conf.shard_port.into()),
+        String::from("shard_url"),
+        Value::String(format!(
+            "{host}:{port}",
+            host = client_conf.host,
+            port = client_conf.shard_port
+        )),
     );
     peer.insert(String::from("certificate"), Value::String(certificate));
     peer.insert(
diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs
index 7f087de7d..6b5df3c84 100644
--- a/ipa-core/src/config.rs
+++ b/ipa-core/src/config.rs
@@ -3,7 +3,6 @@ use std::{
     fmt::{Debug, Formatter},
     iter::zip,
     path::PathBuf,
-    str::FromStr,
     time::Duration,
 };
 
@@ -39,8 +38,8 @@ pub enum Error {
     InvalidNetworkSize(usize),
     #[error(transparent)]
     IOError(#[from] std::io::Error),
-    #[error("Missing shard ports for peers {0:?}")]
-    MissingShardPorts(Vec<usize>),
+    #[error("Missing shard URLs for peers {0:?}")]
+    MissingShardUrls(Vec<usize>),
 }
 
 /// Configuration describing either 3 peers in a Ring or N shard peers. In a non-sharded case a
@@ -201,12 +200,12 @@ struct ShardedNetworkToml {
 }
 
 impl ShardedNetworkToml {
-    fn missing_shard_ports(&self) -> Vec<usize> {
+    fn missing_shard_urls(&self) -> Vec<usize> {
         self.peers
             .iter()
             .enumerate()
             .filter_map(|(i, peer)| {
-                if peer.shard_port.is_some() {
+                if peer.shard_url.is_some() {
                     None
                 } else {
                     Some(i)
@@ -217,12 +216,14 @@ impl ShardedNetworkToml {
 }
 
 /// This struct is only used by [`parse_sharded_network_toml`] to generate [`PeerConfig`]. It
-/// contains an optional `shard_port`.
+/// contains an optional `shard_url`.
 #[derive(Clone, Debug, Deserialize)]
 struct ShardedPeerConfigToml {
     #[serde(flatten)]
     pub config: PeerConfig,
-    pub shard_port: Option<u16>,
+
+    #[serde(default, with = "crate::serde::option::uri")]
+    pub shard_url: Option<Uri>,
 }
 
 impl ShardedPeerConfigToml {
@@ -231,21 +232,15 @@ impl ShardedPeerConfigToml {
         self.config.clone()
     }
 
-    /// Create a new Peer but its url using [`ShardedPeerConfigToml::shard_port`].
+    /// Create a new Peer but its url using [`ShardedPeerConfigToml::shard_url`].
     fn to_shard_peer(&self) -> PeerConfig {
-        let url = self.config.url.to_string();
-        let new_url = format!(
-            "{}{}",
-            &url[..=url.find(':').unwrap()],
-            self.shard_port.expect("Shard port should be set")
-        );
         let mut shard_peer = self.config.clone();
-        shard_peer.url = Uri::from_str(&new_url).expect("Problem creating uri with sharded port");
+        shard_peer.url = self.shard_url.clone().expect("Shard URL should be set");
         shard_peer
     }
 }
 
-/// Parses a [`ShardedNetworkToml`] from a network.toml file. Validates that sharding ports are set
+/// Parses a [`ShardedNetworkToml`] from a network.toml file. Validates that sharding urls are set
 /// if necessary. The number of peers needs to be a multiple of 3. 
 fn parse_sharded_network_toml(input: &str) -> Result<ShardedNetworkToml, Error> {
     use config::{Config, File, FileFormat};
 
@@ -255,11 +255,11 @@ fn parse_sharded_network_toml(input: &str) -> Result<ShardedNetworkToml, Error>
     }
 
     // Validate sharding config is set
-    let any_shard_port_set = parsed.peers.iter().any(|peer| peer.shard_port.is_some());
-    if any_shard_port_set || parsed.peers.len() > 3 {
-        let missing_ports = parsed.missing_shard_ports();
-        if !missing_ports.is_empty() {
-            return Err(Error::MissingShardPorts(missing_ports));
+    let any_shard_url_set = parsed.peers.iter().any(|peer| peer.shard_url.is_some());
+    if any_shard_url_set || parsed.peers.len() > 3 {
+        let missing_urls = parsed.missing_shard_urls();
+        if !missing_urls.is_empty() {
+            return Err(Error::MissingShardUrls(missing_urls));
         }
     }
 
@@ -268,8 +268,7 @@ fn parse_sharded_network_toml(input: &str) -> Result<ShardedNetworkToml, Error>
 
 /// Reads the config for a specific, single, sharded server from string. Expects config to be
 /// toml format. The server in the network is specified via `id`, `shard_index` and
-/// `shard_count`. This function expects shard ports to be set for all peers.
+/// `shard_count`. This function expects shard urls to be set for all peers.
 /// 
 /// The first 3 peers correspond to the leaders Ring. H1 shard 0, H2 shard 0, and H3 shard 0.
 /// The next 3 correspond to the next ring with `shard_index` equals 1 and so on.
@@ -285,9 +287,9 @@ pub fn sharded_server_from_toml_str(
     shard_count: ShardIndex,
 ) -> Result<(NetworkConfig, NetworkConfig), Error> {
     let all_network = parse_sharded_network_toml(input)?;
-    let missing_ports = all_network.missing_shard_ports();
-    if !missing_ports.is_empty() {
-        return Err(Error::MissingShardPorts(missing_ports));
+    let missing_urls = all_network.missing_shard_urls();
+    if !missing_urls.is_empty() {
+        return Err(Error::MissingShardUrls(missing_urls));
     }
 
     let ix: usize = shard_index.as_index();
@@ -686,7 +681,6 @@ mod tests {
         helpers::HelperIdentity,
         net::test::TestConfigBuilder,
         sharding::ShardIndex,
-        utils::replace_all,
     };
 
     const URI_1: &str = "http://localhost:3000";
@@ -730,7 +724,10 @@ mod tests {
         let mut rng = StdRng::seed_from_u64(1);
         let (_, public_key) = X25519HkdfSha256::gen_keypair(&mut rng);
         let config = HpkeClientConfig { public_key };
-        assert_eq!(format!("{config:?}"), "HpkeClientConfig { public_key: \"2bd9da78f01d8bc6948bbcbe44ec1e7163d05083e267d110cdb2e75d847e3b6f\" }");
+        assert_eq!(
+            format!("{config:?}"),
+            r#"HpkeClientConfig { public_key: "2bd9da78f01d8bc6948bbcbe44ec1e7163d05083e267d110cdb2e75d847e3b6f" }"#
+        );
     }
 
     #[test]
@@ -795,9 +792,9 @@ mod tests {
             .unwrap();
         assert_eq!(
             vec![
-                "helper1.prod.ipa-helper.shard1.dev:443",
-                "helper2.prod.ipa-helper.shard1.dev:443",
-                "helper3.prod.ipa-helper.shard1.dev:443"
+                "helper1.shard1.org:443",
+                "helper2.shard1.org:443",
+                "helper3.shard1.org:443"
             ],
             mpc.peers
                 .into_iter()
                 .map(|p| p.url.to_string())
                 .collect::<Vec<_>>()
         );
         assert_eq!(
             vec![
-                "helper2.prod.ipa-helper.shard0.dev:555",
-                "helper2.prod.ipa-helper.shard1.dev:555",
-                "helper2.prod.ipa-helper.shard2.dev:555"
+                "helper2.shard0.org:555",
+                "helper2.shard1.org:555",
+                "helper2.shard2.org:555"
             ],
             shard
                 .peers
                 .into_iter()
                 .map(|p| p.url.to_string())
                 .collect::<Vec<_>>()
         );
     }
 
-    /// Tests that the url of a shard gets updated with the shard port.
+    /// Tests that the url of a shard gets updated with the shard url. 
#[test] fn transform_sharded_peers() { let mut n = parse_sharded_network_toml(&SHARDED_OK_REPEAT).unwrap(); assert_eq!( - "helper3.prod.ipa-helper.shard2.dev:666", + "helper3.shard2.org:666", n.peers.pop().unwrap().to_shard_peer().url ); assert_eq!( - "helper2.prod.ipa-helper.shard2.dev:555", + "helper2.shard2.org:555", n.peers.pop().unwrap().to_shard_peer().url ); } @@ -841,27 +838,27 @@ mod tests { )); } - /// If any sharded port is set (indicating this is a sharding config), then ALL ports must be set. + /// If any sharded url is set (indicating this is a sharding config), then ALL urls must be set. #[test] - fn parse_network_toml_shard_port_some_set() { + fn parse_network_toml_shard_urls_some_set() { assert!(matches!( - parse_sharded_network_toml(&SHARDED_COMPAT_ONE_PORT), - Err(Error::MissingShardPorts(_)) + parse_sharded_network_toml(&SHARDED_COMPAT_ONE_URL), + Err(Error::MissingShardUrls(_)) )); } - /// If there are more than 3 peers configured (indicating this is a sharding config), then ALL ports must be set. + /// If there are more than 3 peers configured (indicating this is a sharding config), then ALL urls must be set. #[test] - fn parse_network_toml_shard_port_set() { + fn parse_network_toml_shard_urls_set() { assert!(matches!( - parse_sharded_network_toml(&SHARDED_MISSING_PORTS_REPEAT), - Err(Error::MissingShardPorts(_)) + parse_sharded_network_toml(&SHARDED_MISSING_URLS_REPEAT), + Err(Error::MissingShardUrls(_)) )); } - /// Check that shard ports are given for [`sharded_server_from_toml_str`] or error is returned. + /// Check that shard urls are given for [`sharded_server_from_toml_str`] or error is returned. #[test] - fn parse_sharded_without_shard_ports() { + fn parse_sharded_without_shard_urls() { // Second, I test the networkconfig parsing assert!(matches!( sharded_server_from_toml_str( @@ -870,7 +867,7 @@ mod tests { ShardIndex::FIRST, ShardIndex::from(1) ), - Err(Error::MissingShardPorts(_)) + Err(Error::MissingShardUrls(_)) )); } @@ -885,11 +882,15 @@ mod tests { HttpClientConfigurator::Http2(_) )); assert_eq!(3, entire_network.peers.len()); + assert_eq!("helper3.shard0.org:443", entire_network.peers[2].config.url); assert_eq!( - "helper3.prod.ipa-helper.shard0.dev:443", - entire_network.peers[2].config.url + "helper3.shard0.org:666", + entire_network.peers[2] + .shard_url + .as_ref() + .unwrap() + .to_string() ); - assert_eq!(Some(666), entire_network.peers[2].shard_port); } /// Testing happy case of a longer sharded network config @@ -899,7 +900,14 @@ mod tests { assert!(r_entire_network.is_ok()); let entire_network = r_entire_network.unwrap(); assert_eq!(9, entire_network.peers.len()); - assert_eq!(Some(666), entire_network.peers[8].shard_port); + assert_eq!( + "helper3.shard2.org:666", + entire_network.peers[8] + .shard_url + .as_ref() + .unwrap() + .to_string() + ); } /// This test validates that the new logic that handles sharded configurations can also handle the previous version @@ -913,10 +921,7 @@ mod tests { HttpClientConfigurator::Http2(_) )); assert_eq!(3, entire_network.peers.len()); - assert_eq!( - "helper3.prod.ipa-helper.dev:443", - entire_network.peers[2].config.url - ); + assert_eq!("helper3.org:443", entire_network.peers[2].config.url); } // Following are some large &str const used for tests @@ -925,20 +930,20 @@ mod tests { static NON_SHARDED_COMPAT: Lazy = Lazy::new(|| format!("{CLIENT}{P1}{REST}")); /// Invalid: Same as [`NON_SHARDED_COMPAT`] but with a single `shard_port` set. 
- static SHARDED_COMPAT_ONE_PORT: Lazy = - Lazy::new(|| format!("{CLIENT}{P1}\nshard_port = 777\n{REST}")); + static SHARDED_COMPAT_ONE_URL: Lazy = + Lazy::new(|| format!("{CLIENT}{P1}\nshard_url = \"helper1.org:777\"\n{REST}")); /// Helper const used to create client configs - const CLIENT: &str = "[client.http_config] + const CLIENT: &str = r#"[client.http_config] ping_interval_secs = 90.0 -version = \"http2\" -"; +version = "http2" +"#; /// Helper const that has the first part of a Peer, just before were `shard_port` should be /// specified. - const P1: &str = " + const P1: &str = r#" [[peers]] -certificate = \"\"\" +certificate = """ -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIIMlnveFys5QUwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwM1oXDTI0 @@ -950,16 +955,16 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAKVdDCQeXLRXDYXy4b1N1UxD/JPuD9H7zeRb8/nmIDTfAiBL a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== -----END CERTIFICATE----- -\"\"\" -url = \"helper1.prod.ipa-helper.dev:443\""; +""" +url = "helper1.org:443""#; /// The rest of a configuration - const REST: &str = " + const REST: &str = r#" [peers.hpke] -public_key = \"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756\" +public_key = "f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756" [[peers]] -certificate = \"\"\" +certificate = """ -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIITOtoca16QckwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwOFoXDTI0 @@ -971,14 +976,14 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== -----END CERTIFICATE----- -\"\"\" -url = \"helper2.prod.ipa-helper.dev:443\" +""" +url = "helper2.org:443" [peers.hpke] -public_key = \"62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22\" +public_key = "62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22" [[peers]] -certificate = \"\"\" +certificate = """ -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIIaf7eDCnXh2swCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMxMloXDTI0 @@ -990,17 +995,17 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAOTSQWbN7kfIatNJEwWTBL4xOY88E3+SOnBNExCsTkQuAiBB /cwOQQUEeE4llrDp+EnyGbzmVm5bINz8gePIxkKqog== -----END CERTIFICATE----- -\"\"\" -url = \"helper3.prod.ipa-helper.dev:443\" +""" +url = "helper3.org:443" [peers.hpke] -public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\" -"; +public_key = "55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61" +"#; /// Valid: A sharded configuration - const SHARDED_OK: &str = " + const SHARDED_OK: &str = r#" [[peers]] -certificate = \"\"\" +certificate = """ -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIIMlnveFys5QUwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMS5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwM1oXDTI0 @@ -1012,15 +1017,15 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAKVdDCQeXLRXDYXy4b1N1UxD/JPuD9H7zeRb8/nmIDTfAiBL a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== -----END CERTIFICATE----- -\"\"\" -url = \"helper1.prod.ipa-helper.shard0.dev:443\" -shard_port = 444 +""" +url = "helper1.shard0.org:443" +shard_url = "helper1.shard0.org:444" [peers.hpke] -public_key = \"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756\" +public_key = 
"f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756" [[peers]] -certificate = \"\"\" +certificate = """ -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIITOtoca16QckwCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMi5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMwOFoXDTI0 @@ -1032,15 +1037,15 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== -----END CERTIFICATE----- -\"\"\" -url = \"helper2.prod.ipa-helper.shard0.dev:443\" -shard_port = 555 +""" +url = "helper2.shard0.org:443" +shard_url = "helper2.shard0.org:555" [peers.hpke] -public_key = \"62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22\" +public_key = "62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22" [[peers]] -certificate = \"\"\" +certificate = """ -----BEGIN CERTIFICATE----- MIIBmzCCAUGgAwIBAgIIaf7eDCnXh2swCgYIKoZIzj0EAwIwJjEkMCIGA1UEAwwb aGVscGVyMy5wcm9kLmlwYS1oZWxwZXIuZGV2MB4XDTI0MDkwNDAzMzMxMloXDTI0 @@ -1052,21 +1057,21 @@ DwEB/wQEAwICpDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwCgYIKoZI zj0EAwIDSAAwRQIhAOTSQWbN7kfIatNJEwWTBL4xOY88E3+SOnBNExCsTkQuAiBB /cwOQQUEeE4llrDp+EnyGbzmVm5bINz8gePIxkKqog== -----END CERTIFICATE----- -\"\"\" -url = \"helper3.prod.ipa-helper.shard0.dev:443\" -shard_port = 666 +""" +url = "helper3.shard0.org:443" +shard_url = "helper3.shard0.org:666" [peers.hpke] -public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\" -"; +public_key = "55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61" +"#; /// Valid: Three sharded configs together for 9 static SHARDED_OK_REPEAT: Lazy = Lazy::new(|| { format!( "{}{}{}", SHARDED_OK, - replace_all(SHARDED_OK, "shard0", "shard1"), - replace_all(SHARDED_OK, "shard0", "shard2") + SHARDED_OK.replace("shard0", "shard1"), + SHARDED_OK.replace("shard0", "shard2") ) }); @@ -1077,11 +1082,11 @@ public_key = \"55f87a8794b4de9a60f8ede9ed000f5f10c028e22390922efc4fb63bc6be0a61\ }); /// Invalid: Same as [`SHARDED_OK_REPEAT`] but without the expected ports - static SHARDED_MISSING_PORTS_REPEAT: Lazy = Lazy::new(|| { + static SHARDED_MISSING_URLS_REPEAT: Lazy = Lazy::new(|| { let lines: Vec<&str> = SHARDED_OK_REPEAT.lines().collect(); let new_lines: Vec = lines .iter() - .filter(|line| !line.starts_with("shard_port =")) + .filter(|line| !line.starts_with("shard_url =")) .map(std::string::ToString::to_string) .collect(); new_lines.join("\n") diff --git a/ipa-core/src/serde.rs b/ipa-core/src/serde.rs index 0acc2d925..ed65273d7 100644 --- a/ipa-core/src/serde.rs +++ b/ipa-core/src/serde.rs @@ -13,6 +13,27 @@ pub mod uri { } } +#[cfg(feature = "web-app")] +pub mod option { + pub mod uri { + use hyper::Uri; + use serde::{de::Error, Deserialize, Deserializer}; + + /// # Errors + /// if deserializing from string fails, or if string is not a [`Uri`] + pub fn deserialize<'de, D: Deserializer<'de>>( + deserializer: D, + ) -> Result, D::Error> { + let opt_s: Option = Deserialize::deserialize(deserializer)?; + if let Some(s) = opt_s { + s.parse().map(Some).map_err(D::Error::custom) + } else { + Ok(None) + } + } + } +} + pub mod duration { use std::time::Duration; diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index f8b785fae..6829f57fa 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -5,20 +5,3 @@ mod power_of_two; #[cfg(target_pointer_width = "64")] pub use power_of_two::{non_zero_prev_power_of_two, NonZeroU32PowerOfTwo}; - -/// Replaces all occurrences of 
`from` with `to` in `s`. -#[allow(dead_code)] -pub fn replace_all(s: &str, from: &str, to: &str) -> String { - let mut result = String::new(); - let mut i = 0; - while i < s.len() { - if s[i..].starts_with(from) { - result.push_str(to); - i += from.len(); - } else { - result.push(s.chars().nth(i).unwrap()); - i += 1; - } - } - result -} From 34d15cd6efc95cedfed3f87b194ff39f82df0188 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 20 Nov 2024 15:27:25 -0800 Subject: [PATCH 40/47] Feedback --- ipa-core/src/app.rs | 2 +- ipa-core/src/net/server/handlers/mod.rs | 4 +- .../src/net/server/handlers/query/input.rs | 13 ++-- ipa-core/src/net/server/handlers/query/mod.rs | 9 +-- ipa-core/src/net/test.rs | 18 ++--- ipa-core/src/net/transport.rs | 65 +++++++++++-------- ipa-core/src/query/processor.rs | 27 ++++---- 7 files changed, 67 insertions(+), 71 deletions(-) diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index a19787c1b..ac3e714a4 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -186,7 +186,7 @@ impl RequestHandler for Inner { let req = req.into::()?; HelperResponse::from(qp.prepare_shard(&self.shard_transport, req)?) } - RouteId::QueryInput | RouteId::CompleteQuery => { + RouteId::CompleteQuery => { // The processing flow for this API is exactly the same, regardless // whether it was received from a peer shard or from report collector. // Authentication is handled on the layer above, so we erase the identity diff --git a/ipa-core/src/net/server/handlers/mod.rs b/ipa-core/src/net/server/handlers/mod.rs index 54303bdbb..dc99ebff5 100644 --- a/ipa-core/src/net/server/handlers/mod.rs +++ b/ipa-core/src/net/server/handlers/mod.rs @@ -17,8 +17,6 @@ pub fn mpc_router(transport: MpcHttpTransport) -> Router { pub fn shard_router(transport: ShardHttpTransport) -> Router { echo::router().nest( http_serde::query::BASE_AXUM_PATH, - Router::new() - .merge(query::c2s_router(&transport)) - .merge(query::s2s_router(transport)), + Router::new().merge(query::s2s_router(transport)), ) } diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs index 4e5487e2d..da47e9386 100644 --- a/ipa-core/src/net/server/handlers/query/input.rs +++ b/ipa-core/src/net/server/handlers/query/input.rs @@ -3,13 +3,12 @@ use hyper::StatusCode; use crate::{ helpers::{query::QueryInput, routing::RouteId, BodyStream}, - net::{http_serde, ConnectionFlavor, Error, HttpTransport}, + net::{http_serde, transport::MpcHttpTransport, Error}, protocol::QueryId, - sync::Arc, }; -async fn handler( - transport: Extension>>, +async fn handler( + transport: Extension, Path(query_id): Path, input_stream: BodyStream, ) -> Result<(), Error> { @@ -17,7 +16,7 @@ async fn handler( query_id, input_stream, }; - let _ = Arc::clone(&transport) + let _ = transport .dispatch( (RouteId::QueryInput, query_input.query_id), query_input.input_stream, @@ -28,9 +27,9 @@ async fn handler( Ok(()) } -pub fn router(transport: Arc>) -> Router { +pub fn router(transport: MpcHttpTransport) -> Router { Router::new() - .route(http_serde::query::input::AXUM_PATH, post(handler::)) + .route(http_serde::query::input::AXUM_PATH, post(handler)) .layer(Extension(transport)) } diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index 4f19bfa46..6e95c4cd4 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -35,7 +35,7 @@ use crate::{ pub fn query_router(transport: MpcHttpTransport) -> Router 
{ Router::new() .merge(create::router(transport.clone())) - .merge(input::router(Arc::clone(&transport.inner_transport))) + .merge(input::router(transport.clone())) .merge(status::router(transport.clone())) .merge(kill::router(transport.clone())) .merge(results::router(transport.inner_transport)) @@ -63,13 +63,6 @@ pub fn s2s_router(transport: ShardHttpTransport) -> Router { .layer(layer_fn(HelperAuthentication::<_, Shard>::new)) } -/// Client-to-shard routes. There are only a few cases where we expect parties -/// to talk to individual shards. Input submission is one of them. This path does -/// not require cert authentication. -pub fn c2s_router(transport: &ShardHttpTransport) -> Router { - Router::new().merge(input::router(Arc::clone(&transport.inner_transport))) -} - /// Returns HTTP 401 Unauthorized if the request does not have valid authentication. /// /// Authentication information is carried via the `ClientIdentity` request extension. The extension diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index fb0164c55..c74c25610 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -279,6 +279,10 @@ impl TestConfig { &self.rings[0] } + pub fn rings(&self) -> impl Iterator> { + self.rings.iter() + } + /// Gets a ref to the entire shard network for a specific helper. #[must_use] pub fn get_shards_for_helper(&self, id: HelperIdentity) -> &TestNetwork { @@ -304,20 +308,6 @@ impl TestConfig { shards, } } - - /// Returns full set of clients to talk to each individual shard in the MPC. - #[must_use] - pub fn shard_clients(&self) -> [Vec>; 3] { - let shard_clients = HelperIdentity::make_three().map(|id| { - IpaHttpClient::shards_from_conf( - &IpaRuntime::current(), - &self.get_shards_for_helper(id).network, - &ClientIdentity::None, - ) - }); - shard_clients - } - /// Transforms this easy to modify configuration into an easy to run [`TestApp`]. 
     #[must_use]
     pub fn into_apps(self) -> Vec<TestApp> {
diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs
index 4a1d17897..ad5e41395 100644
--- a/ipa-core/src/net/transport.rs
+++ b/ipa-core/src/net/transport.rs
@@ -482,31 +482,44 @@ mod tests {
     }
 
     async fn test_make_helpers(conf: TestConfig) {
-        let mpc_clients = IpaHttpClient::from_conf(
-            &IpaRuntime::current(),
-            &conf.leaders_ring().network,
-            &ClientIdentity::None,
-        );
-        let shard_clients = conf.shard_clients();
+        let clients = conf
+            .rings()
+            .map(|test_network| {
+                IpaHttpClient::from_conf(
+                    &IpaRuntime::current(),
+                    &test_network.network,
+                    &ClientIdentity::None,
+                )
+            })
+            .collect::<Vec<_>>();
+
+        // let mpc_clients = IpaHttpClient::from_conf(
+        //     &IpaRuntime::current(),
+        //     &conf.leaders_ring().network,
+        //     &ClientIdentity::None,
+        // );

         let _helpers = make_helpers(conf).await;
-        test_multiply_single_shard(&mpc_clients, shard_clients.each_ref().map(AsRef::as_ref)).await;
+        test_multiply_single_shard(&clients).await;
     }
 
     #[tokio::test(flavor = "multi_thread")]
     async fn happy_case_twice() {
         let conf = TestConfigBuilder::default().build();
-        let clients = IpaHttpClient::from_conf(
-            &IpaRuntime::current(),
-            &conf.leaders_ring().network,
-            &ClientIdentity::None,
-        );
-        let shard_clients = conf.shard_clients();
-        let shard_clients_ref = shard_clients.each_ref().map(AsRef::as_ref);
+        let clients = conf
+            .rings()
+            .map(|test_network| {
+                IpaHttpClient::from_conf(
+                    &IpaRuntime::current(),
+                    &test_network.network,
+                    &ClientIdentity::None,
+                )
+            })
+            .collect::<Vec<_>>();
 
         let _helpers = make_helpers(conf).await;
-        test_multiply_single_shard(&clients, shard_clients_ref).await;
-        test_multiply_single_shard(&clients, shard_clients_ref).await;
+        test_multiply_single_shard(&clients).await;
+        test_multiply_single_shard(&clients).await;
     }
 
     /// This executes test multiplication protocol by running it exclusively on the leader shards.
@@ -517,14 +530,12 @@ mod tests {
     /// The sharding requires some amendments to the test multiplication protocol that are
     /// currently in progress. Once completed, this test can be fixed by fully utilizing all
     /// shards in the system.
-    async fn test_multiply_single_shard(
-        clients: &[IpaHttpClient; 3],
-        shard_clients: [&[IpaHttpClient]; 3],
-    ) {
+    async fn test_multiply_single_shard(clients: &[[IpaHttpClient; 3]]) {
         const SZ: usize = <AdditiveShare<Fp31> as Serializable>::Size::USIZE;
+        let leader_ring_clients = &clients[0];
 
         // send a create query command
-        let leader_client = &clients[0];
+        let leader_client = &leader_ring_clients[0];
         let create_data = QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap();
 
         // create query
@@ -547,14 +558,14 @@ mod tests {
                 query_id,
                 input_stream,
             };
-            handle_resps.push(clients[i].query_input(data));
+            handle_resps.push(leader_ring_clients[i].query_input(data));
         }
         try_join_all(handle_resps).await.unwrap();
 
-        // shards receive their own input - in this case empty
-        try_join_all(shard_clients.each_ref().map(|helper_shard_clients| {
-            // convention - first client is shard leader, and we submitted the inputs to it.
-            try_join_all(helper_shard_clients.iter().skip(1).map(|shard_client| {
+        // shards receive their own input - in this case empty.
+        // convention - first client is shard leader, and we submitted the inputs to it. 
+        try_join_all(clients.iter().skip(1).map(|ring| {
+            try_join_all(ring.each_ref().map(|shard_client| {
                 shard_client.query_input(QueryInput {
                     query_id,
                     input_stream: BodyStream::empty(),
@@ -564,7 +575,7 @@
         .await
         .unwrap();
 
-        let result: [_; 3] = join_all(clients.clone().map(|client| async move {
+        let result: [_; 3] = join_all(leader_ring_clients.each_ref().map(|client| async move {
             let r = client.query_results(query_id).await.unwrap();
             AdditiveShare::<Fp31>::from_byte_slice_unchecked(&r).collect::<Vec<_>>()
         }))
diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs
index fb37fa555..6d7da358e 100644
--- a/ipa-core/src/query/processor.rs
+++ b/ipa-core/src/query/processor.rs
@@ -584,6 +584,8 @@ mod tests {
     }
 
     impl TestComponents {
+        const COMPLETE_QUERY_RESULT: Vec<u8> = Vec::new();
+
         fn new(mut args: TestComponentsArgs) -> Self {
             let mpc_network = InMemoryMpcNetwork::new(
                 args.mpc_handlers
@@ -629,7 +631,7 @@ mod tests {
                     join_handle: IpaRuntime::current().spawn(async {}),
                 }))
                 .unwrap();
-            tx.send(Ok(Box::new(Vec::<u8>::new()))).unwrap();
+            tx.send(Ok(Box::new(Self::COMPLETE_QUERY_RESULT))).unwrap();
 
             QueryId
         }
@@ -807,13 +809,13 @@ mod tests {
 
     mod complete {
         use crate::{
-            helpers::{make_owned_handler, routing::RouteId, ApiError, Transport},
+            helpers::{make_owned_handler, routing::RouteId, Transport},
             query::{
                 processor::{
                     tests::{HelperResponse, TestComponents, TestComponentsArgs},
                    QueryId,
                 },
-                QueryCompletionError,
+                ProtocolResult, QueryCompletionError,
             },
             sharding::ShardIndex,
         };

@@ -823,13 +825,18 @@
            let t = TestComponents::default();
            let query_id = t.new_running_query().await;

-            t.processor
-                .complete(query_id, t.shard_transport.clone_ref())
-                .await
-                .unwrap();
+            assert_eq!(
+                TestComponents::COMPLETE_QUERY_RESULT.to_bytes(),
+                t.processor
+                    .complete(query_id, t.shard_transport.clone_ref())
+                    .await
+                    .unwrap()
+                    .to_bytes()
+            );
        }

        #[tokio::test]
+        #[should_panic(expected = "QueryCompletion(NoSuchQuery(QueryId))")]
        async fn complete_one_shard_fails() {
            let mut args = TestComponentsArgs::default();

@@ -850,7 +857,7 @@
            .processor
            .complete(query_id, t.shard_transport.clone_ref())
            .await
-            .unwrap_err();
+            .unwrap();
        }

        #[tokio::test]
@@ -860,9 +867,7 @@
            args.set_shard_handler(|shard_id| {
                make_owned_handler(move |_req, _| {
                    if shard_id == ShardIndex::FIRST {
-                        futures::future::err(ApiError::BadRequest(
-                            "Leader shard must not receive requests through shard channels".into(),
-                        ))
+                        panic!("Leader shard must not receive requests through shard channels");
                    } else {
                        futures::future::ok(HelperResponse::ok())
                    }
From 1cddf354eecb61abe7dcdc97d14647c87226ccea Mon Sep 17 00:00:00 2001
From: Alex Koshelev 
Date: Wed, 20 Nov 2024 15:47:10 -0800
Subject: [PATCH 41/47] Remove commented code

---
 ipa-core/src/net/transport.rs | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs
index ad5e41395..73054ec8f 100644
--- a/ipa-core/src/net/transport.rs
+++ b/ipa-core/src/net/transport.rs
@@ -493,12 +493,6 @@ mod tests {
             })
             .collect::<Vec<_>>();
 
-        // let mpc_clients = IpaHttpClient::from_conf(
-        //     &IpaRuntime::current(),
-        //     &conf.leaders_ring().network,
-        //     &ClientIdentity::None,
-        // );
-
         let _helpers = make_helpers(conf).await;
         test_multiply_single_shard(&clients).await;
     }
From 5547f4db0aff4e192bb497fd84eb3b045b212c65 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff 
Date: Wed, 20 Nov 2024 15:43:37 -0800
Subject: [PATCH 42/47] Allowing non-sharded network.toml

---
 ipa-core/src/bin/helper.rs | 
10 ++++- ipa-core/src/config.rs | 89 +++++++++++++++++++++++++------------- 2 files changed, 67 insertions(+), 32 deletions(-) diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 7de9de90c..8641f3547 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -183,6 +183,7 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B let shard_index = ShardIndex::from(args.shard_index.expect("enforced by clap")); let shard_count = ShardIndex::from(args.shard_count.expect("enforced by clap")); assert!(shard_index < shard_count); + assert_eq!(args.tls_cert.is_some(), !args.disable_https); let (identity, server_tls) = create_client_identity(my_identity, args.tls_cert.clone(), args.tls_key.clone())?; @@ -223,8 +224,13 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B let network_config_path = args.network.as_deref().unwrap(); let network_config_string = &fs::read_to_string(network_config_path)?; - let (mut mpc_network, mut shard_network) = - sharded_server_from_toml_str(network_config_string, my_identity, shard_index, shard_count)?; + let (mut mpc_network, mut shard_network) = sharded_server_from_toml_str( + network_config_string, + my_identity, + shard_index, + shard_count, + args.shard_port, + )?; mpc_network = mpc_network.override_scheme(&scheme); shard_network = shard_network.override_scheme(&scheme); diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 6b5df3c84..d53f9dcdf 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -3,6 +3,7 @@ use std::{ fmt::{Debug, Formatter}, iter::zip, path::PathBuf, + str::FromStr, time::Duration, }; @@ -280,20 +281,21 @@ fn parse_sharded_network_toml(input: &str) -> Result /// /// # Errors /// if `input` is in an invalid format +/// +/// # Panics +/// If you somehow provide an invalid non-sharded network toml pub fn sharded_server_from_toml_str( input: &str, id: HelperIdentity, shard_index: ShardIndex, shard_count: ShardIndex, + shard_port: Option, ) -> Result<(NetworkConfig, NetworkConfig), Error> { let all_network = parse_sharded_network_toml(input)?; - let missing_urls = all_network.missing_shard_urls(); - if !missing_urls.is_empty() { - return Err(Error::MissingShardUrls(missing_urls)); - } let ix: usize = shard_index.as_index(); let ix_count: usize = shard_count.as_index(); + // assert ix < count let mpc_id: usize = id.as_index(); let mpc_network = NetworkConfig { @@ -307,21 +309,43 @@ pub fn sharded_server_from_toml_str( client: all_network.client.clone(), identities: HelperIdentity::make_three().to_vec(), }; - - let shard_network = NetworkConfig { - peers: all_network - .peers - .iter() - .map(ShardedPeerConfigToml::to_shard_peer) - .skip(mpc_id) - .step_by(3) - .take(ix_count) - .collect(), - client: all_network.client, - identities: shard_count.iter().collect(), - }; - - Ok((mpc_network, shard_network)) + let missing_urls = all_network.missing_shard_urls(); + if missing_urls.is_empty() { + let shard_network = NetworkConfig { + peers: all_network + .peers + .iter() + .map(ShardedPeerConfigToml::to_shard_peer) + .skip(mpc_id) + .step_by(3) + .take(ix_count) + .collect(), + client: all_network.client, + identities: shard_count.iter().collect(), + }; + Ok((mpc_network, shard_network)) + } else if missing_urls == [0, 1, 2] && shard_count == ShardIndex(1) { + // This is the special case we're dealing with a non-sharded, single ring MPC. + // Since the shard network will be of size 1, it can't really communicate with anyone else. 
+ // Hence we just create a config where I'm the only shard. We take the MPC configuration + // and modify the port. + let mut myself = ShardedPeerConfigToml::to_mpc_peer(all_network.peers.get(mpc_id).unwrap()); + let url = myself.url.to_string(); + let new_url = format!( + "{}{}", + &url[..=url.find(':').unwrap()], + shard_port.expect("Shard port should be set") + ); + myself.url = Uri::from_str(&new_url).expect("Problem creating uri with sharded port"); + let shard_network = NetworkConfig { + peers: vec![myself], + client: all_network.client, + identities: shard_count.iter().collect(), + }; + Ok((mpc_network, shard_network)) + } else { + return Err(Error::MissingShardUrls(missing_urls)); + } } #[derive(Clone, Debug, Deserialize)] @@ -788,6 +812,7 @@ mod tests { HelperIdentity::TWO, ShardIndex::from(1), ShardIndex::from(3), + None, ) .unwrap(); assert_eq!( @@ -856,19 +881,23 @@ mod tests { )); } - /// Check that shard urls are given for [`sharded_server_from_toml_str`] or error is returned. + /// Check that [`sharded_server_from_toml_str`] can work in the previous format. #[test] fn parse_sharded_without_shard_urls() { - // Second, I test the networkconfig parsing - assert!(matches!( - sharded_server_from_toml_str( - &NON_SHARDED_COMPAT, - HelperIdentity::TWO, - ShardIndex::FIRST, - ShardIndex::from(1) - ), - Err(Error::MissingShardUrls(_)) - )); + let (mpc, mut shard) = sharded_server_from_toml_str( + &NON_SHARDED_COMPAT, + HelperIdentity::TWO, + ShardIndex::FIRST, + ShardIndex::from(1), + Some(666), + ) + .unwrap(); + assert_eq!(1, shard.peers.len()); + assert_eq!(3, mpc.peers.len()); + assert_eq!( + "helper2.org:666", + shard.peers.pop().unwrap().url.to_string() + ); } /// Testing happy case of a sharded network config From 8ea3cc70e486a5a2e180d8a6b1b567a48cd93763 Mon Sep 17 00:00:00 2001 From: Christian Berkhoff Date: Wed, 20 Nov 2024 19:10:11 -0800 Subject: [PATCH 43/47] handling case when url has no port --- ipa-core/src/config.rs | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index d53f9dcdf..403f38b24 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -331,11 +331,13 @@ pub fn sharded_server_from_toml_str( // and modify the port. let mut myself = ShardedPeerConfigToml::to_mpc_peer(all_network.peers.get(mpc_id).unwrap()); let url = myself.url.to_string(); - let new_url = format!( - "{}{}", - &url[..=url.find(':').unwrap()], - shard_port.expect("Shard port should be set") - ); + let pos = url.rfind(':'); + let port = shard_port.expect("Shard port should be set"); + let new_url = if pos.is_some() { + format!("{}{port}", &url[..=pos.unwrap()]) + } else { + format!("{}:{port}", &url) + }; myself.url = Uri::from_str(&new_url).expect("Problem creating uri with sharded port"); let shard_network = NetworkConfig { peers: vec![myself], @@ -884,6 +886,26 @@ mod tests { /// Check that [`sharded_server_from_toml_str`] can work in the previous format. 
     #[test]
     fn parse_sharded_without_shard_urls() {
+        let (mpc, mut shard) = sharded_server_from_toml_str(
+            &NON_SHARDED_COMPAT,
+            HelperIdentity::ONE,
+            ShardIndex::FIRST,
+            ShardIndex::from(1),
+            Some(666),
+        )
+        .unwrap();
+        assert_eq!(1, shard.peers.len());
+        assert_eq!(3, mpc.peers.len());
+        assert_eq!(
+            "helper1.org:666",
+            shard.peers.pop().unwrap().url.to_string()
+        );
+    }
+
+    /// Check that [`sharded_server_from_toml_str`] can work in the previous format, even when the
+    /// given MPC URL doesn't have a port (NOTE: helper 2 doesn't specify it).
+    #[test]
+    fn parse_sharded_without_shard_urls_no_port() {
         let (mpc, mut shard) = sharded_server_from_toml_str(
             &NON_SHARDED_COMPAT,
             HelperIdentity::TWO,
             ShardIndex::FIRST,
             ShardIndex::from(1),
             Some(666),
         )
         .unwrap();
         assert_eq!(1, shard.peers.len());
         assert_eq!(3, mpc.peers.len());
         assert_eq!(
             "helper2.org:666",
             shard.peers.pop().unwrap().url.to_string()
         );
     }
 
@@ -988,6 +1010,7 @@ a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA==
 url = "helper1.org:443""#;
 
     /// The rest of a configuration
+    /// Note the second helper doesn't provide a port as part of its url
     const REST: &str = r#"
 [peers.hpke]
 public_key = "f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756"
 
@@ -1006,7 +1029,7 @@ zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS
 XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g==
 -----END CERTIFICATE-----
 """
-url = "helper2.org:443"
+url = "helper2.org"
From 5b0673027b8762153908198d7c6bec67f57c4469 Mon Sep 17 00:00:00 2001
From: Christian Berkhoff 
Date: Wed, 20 Nov 2024 20:00:06 -0800
Subject: [PATCH 44/47] Addressing comments

---
 ipa-core/src/query/processor.rs  | 29 +++++++++------------
 ipa-core/src/query/state.rs      | 34 +++++++++++++++++-----------
 ipa-core/src/test_fixture/app.rs |  8 ++++----
 3 files changed, 37 insertions(+), 34 deletions(-)

diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs
index 3f23b6260..e6d39e92d 100644
--- a/ipa-core/src/query/processor.rs
+++ b/ipa-core/src/query/processor.rs
@@ -365,12 +365,11 @@ impl Processor {
     /// This helper function is used to transform a [`BoxError`] into a
     /// [`QueryStatusError::DifferentStatus`] and retrieve its internal state. Returns [`None`]
     /// if not possible.
-    fn downcast_state_error(be: BoxError) -> Option<QueryStatus> {
-        let ae = be.downcast::<crate::helpers::ApiError>().ok()?;
-        if let crate::helpers::ApiError::QueryStatus(QueryStatusError::DifferentStatus {
-            my_status,
-            ..
-        }) = *ae
+    fn downcast_state_error(box_error: BoxError) -> Option<QueryStatus> {
+        use crate::helpers::ApiError;
+        let api_error = box_error.downcast::<ApiError>().ok()?;
+        if let ApiError::QueryStatus(QueryStatusError::DifferentStatus { my_status, .. }) =
+            *api_error
         {
             return Some(my_status);
         }
@@ -383,9 +382,9 @@ impl Processor {
     /// of relying on errors.
     #[cfg(feature = "in-memory-infra")]
     fn get_state_from_error(
-        be: crate::helpers::InMemoryTransportError,
+        error: crate::helpers::InMemoryTransportError,
     ) -> Option<QueryStatus> {
-        if let crate::helpers::InMemoryTransportError::Rejected { inner, .. } = be {
+        if let crate::helpers::InMemoryTransportError::Rejected { inner, .. } = error {
             return Self::downcast_state_error(inner);
         }
         None
@@ -396,8 +395,8 @@ impl Processor {
     /// TODO: Ideally broadcast should return a value, that we could use to parse the state instead
     /// of relying on errors.
     #[cfg(feature = "real-world-infra")]
-    fn get_state_from_error(se: crate::net::ShardError) -> Option<QueryStatus> {
-        if let crate::net::Error::Application { error, .. } = se.source {
+    fn get_state_from_error(shard_error: crate::net::ShardError) -> Option<QueryStatus> {
+        if let crate::net::Error::Application { error, .. 
} = shard_error.source { return Self::downcast_state_error(error); } None @@ -428,8 +427,7 @@ impl Processor { let shard_responses = shard_transport.broadcast(shard_query_status_req).await; if let Err(e) = shard_responses { - // The following silently ignores the cases where the query isn't found because those - // errors return `None` for [`BroadcasteableError::peer_state()`] + // The following silently ignores the cases where the query isn't found. let states: Vec<_> = e .failures .into_iter() @@ -1290,12 +1288,7 @@ mod tests { .start_query(vec![a, b].into_iter(), test_multiply_config()) .await?; - while !app - .query_status(query_id) - .await? - .into_iter() - .all(|s| s == QueryStatus::Completed) - { + while !(app.query_status(query_id).await? == QueryStatus::Completed) { sleep(Duration::from_millis(1)).await; } diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 834e491bd..8be68f13f 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -48,6 +48,11 @@ impl From<&QueryState> for QueryStatus { } } +/// This function is used, among others, by the [`Processor`] to return a unified response when +/// queried about the state of a sharded helper. In such scenarios, there will be many different +/// [`QueryStatus`] and the [`Processor`] needs to return a single one that describes the entire +/// helper. With this function we're saying that the minimum state across all shards is the one +/// that describes the helper. #[must_use] pub fn min_status(a: QueryStatus, b: QueryStatus) -> QueryStatus { match (a, b) { @@ -248,17 +253,22 @@ mod tests { #[test] fn test_order() { - assert_eq!( - min_status(QueryStatus::Preparing, QueryStatus::Preparing), - QueryStatus::Preparing - ); - assert_eq!( - min_status(QueryStatus::Preparing, QueryStatus::Completed), - QueryStatus::Preparing - ); - assert_eq!( - min_status(QueryStatus::AwaitingCompletion, QueryStatus::AwaitingInputs), - QueryStatus::AwaitingInputs - ); + // this list sorted in priority order. Preparing is the lowest possible value, + // while Completed is the highest. + let all = [ + QueryStatus::Preparing, + QueryStatus::AwaitingInputs, + QueryStatus::Running, + QueryStatus::AwaitingCompletion, + QueryStatus::Completed, + ]; + + for i in 0..all.len() { + let this = all[i]; + for other in all.into_iter().skip(i) { + assert_eq!(this, min_status(this, other)); + assert_eq!(this, min_status(other, this)); + } + } } } diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs index cfe2e9622..1cf01cb38 100644 --- a/ipa-core/src/test_fixture/app.rs +++ b/ipa-core/src/test_fixture/app.rs @@ -12,7 +12,7 @@ use crate::{ ApiError, InMemoryMpcNetwork, InMemoryShardNetwork, Transport, }, protocol::QueryId, - query::QueryStatus, + query::{min_status, QueryStatus}, secret_sharing::IntoShares, test_fixture::try_join3_array, utils::array::zip3, @@ -120,12 +120,12 @@ impl TestApp { /// ## Panics /// Never. 
#[allow(clippy::disallowed_methods)] - pub async fn query_status(&self, query_id: QueryId) -> Result<[QueryStatus; 3], ApiError> { + pub async fn query_status(&self, query_id: QueryId) -> Result { join_all((0..3).map(|i| self.drivers[i].query_status(query_id))) .await .into_iter() - .collect::, _>>() - .map(|vec| vec.try_into().unwrap()) + .reduce(|s1, s2| Ok(min_status(s1?, s2?))) + .unwrap() } /// ## Errors From cc119fede3d7cc7fe856aa512dac799792faaedb Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 20 Nov 2024 20:17:02 -0800 Subject: [PATCH 45/47] Move back query_input --- ipa-core/src/net/client/mod.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 778aba6f1..ce719b9ed 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -373,18 +373,6 @@ impl IpaHttpClient { resp_ok(resp).await } - /// Intended to be called externally, e.g. by the report collector. After the report collector - /// calls "create query", it must then send the data for the query to each of the clients. This - /// query input contains the data intended for a helper. - /// # Errors - /// If the request has illegal arguments, or fails to deliver to helper - pub async fn query_input(&self, data: QueryInput) -> Result<(), Error> { - let req = http_serde::query::input::Request::new(data); - let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; - let resp = self.request(req).await?; - resp_ok(resp).await - } - /// Complete query API can be called on the leader shard by the report collector or /// by the leader shard to other shards. /// @@ -442,6 +430,18 @@ impl IpaHttpClient { } } + /// Intended to be called externally, e.g. by the report collector. After the report collector + /// calls "create query", it must then send the data for the query to each of the clients. This + /// query input contains the data intended for a helper. + /// # Errors + /// If the request has illegal arguments, or fails to deliver to helper + pub async fn query_input(&self, data: QueryInput) -> Result<(), Error> { + let req = http_serde::query::input::Request::new(data); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; + let resp = self.request(req).await?; + resp_ok(resp).await + } + /// Retrieve the status of a query. 
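    ///
    /// A sketch of the report-collector flow this supports (the exact
    /// `query_status` signature is assumed here, by analogy with `query_input`
    /// above; not compiled as a doc-test):
    ///
    /// ```ignore
    /// // send this helper's input, then poll for completion
    /// client.query_input(data).await?;
    /// let status = client.query_status(query_id).await?;
    /// ```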
/// /// ## Errors From 59f15c8e00a596487137a450cae39ca7cc7f49ca Mon Sep 17 00:00:00 2001 From: Erik Taubeneck Date: Thu, 21 Nov 2024 12:05:56 -0800 Subject: [PATCH 46/47] upgrade hybrid tests to use malicious context (#1429) * upgrade hybrid tests to use malicious context * Update ipa-core/src/protocol/hybrid/oprf.rs Co-authored-by: Andy Leiserson --------- Co-authored-by: Andy Leiserson --- ipa-core/src/protocol/basics/mod.rs | 10 +++++- ipa-core/src/protocol/basics/reveal.rs | 46 +++++++++++++++++++++++- ipa-core/src/protocol/context/mod.rs | 1 + ipa-core/src/protocol/hybrid/oprf.rs | 26 ++++++++++---- ipa-core/src/query/runner/hybrid.rs | 8 ++--- ipa-core/src/query/runner/reshard_tag.rs | 4 +-- ipa-core/src/test_fixture/world.rs | 3 +- 7 files changed, 83 insertions(+), 15 deletions(-) diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index ebe34cb34..ab22efbcf 100644 --- a/ipa-core/src/protocol/basics/mod.rs +++ b/ipa-core/src/protocol/basics/mod.rs @@ -23,7 +23,7 @@ use crate::{ protocol::{ context::{ Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, - UpgradedMaliciousContext, UpgradedSemiHonestContext, + ShardedUpgradedMaliciousContext, UpgradedMaliciousContext, UpgradedSemiHonestContext, }, ipa_prf::{AGG_CHUNK, PRF_CHUNK}, prss::FromPrss, @@ -66,6 +66,14 @@ where { } +impl<'a, const N: usize> BasicProtocols, Fp25519, N> + for malicious::AdditiveShare +where + Fp25519: FieldSimd, + AdditiveShare: FromPrss, +{ +} + /// Basic suite of MPC protocols for (possibly vectorized) boolean shares. /// /// Adds the requirement that the type implements `Not`. diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 0e4f08378..e38d31500 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -37,7 +37,7 @@ use crate::{ boolean::step::TwoHundredFiftySixBitOpStep, context::{ Context, DZKPContext, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, - UpgradedMaliciousContext, UpgradedSemiHonestContext, + ShardedUpgradedMaliciousContext, UpgradedMaliciousContext, UpgradedSemiHonestContext, }, RecordId, }, @@ -333,6 +333,50 @@ where } } +impl<'a, V, const N: usize, CtxF> Reveal> + for Replicated +where + CtxF: ExtendableField, + V: SharedValue + Vectorizable, +{ + type Output = >::Array; + + async fn generic_reveal<'fut>( + &'fut self, + ctx: ShardedUpgradedMaliciousContext<'a, CtxF>, + record_id: RecordId, + excluded: Option, + ) -> Result>::Array>, Error> + where + ShardedUpgradedMaliciousContext<'a, CtxF>: 'fut, + { + malicious_reveal(ctx, record_id, excluded, self).await + } +} + +impl<'a, F, const N: usize> Reveal> + for MaliciousReplicated +where + F: ExtendableFieldSimd, +{ + type Output = >::Array; + + async fn generic_reveal<'fut>( + &'fut self, + ctx: ShardedUpgradedMaliciousContext<'a, F>, + record_id: RecordId, + excluded: Option, + ) -> Result>::Array>, Error> + where + ShardedUpgradedMaliciousContext<'a, F>: 'fut, + { + use crate::secret_sharing::replicated::malicious::ThisCodeIsAuthorizedToDowngradeFromMalicious; + + let x_share = self.x().access_without_downgrade(); + malicious_reveal(ctx, record_id, excluded, x_share).await + } +} + impl<'a, V, B, const N: usize> Reveal> for Replicated where B: ShardBinding, diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 205e022d6..d21896190 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -28,6 +28,7 @@ pub 
type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>; pub type ShardedMaliciousContext<'a> = malicious::Context<'a, Sharded>; pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>; +pub type ShardedUpgradedMaliciousContext<'a, F, B = Sharded> = malicious::Upgraded<'a, F, B>; #[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] pub(crate) use malicious::TEST_DZKP_STEPS; diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index bf1a14f3b..da2bf903f 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -21,7 +21,8 @@ use crate::{ context::{ dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, reshard_try_stream, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, ShardedContext, - UpgradableContext, Validator, + ShardedUpgradedMaliciousContext, UpgradableContext, UpgradedMaliciousContext, + Validator, }, hybrid::step::HybridStep, ipa_prf::{ @@ -29,12 +30,12 @@ use crate::{ prf_eval::{eval_dy_prf, PrfSharing}, }, prss::{FromPrss, SharedRandomness}, - RecordId, + BasicProtocols, RecordId, }, report::hybrid::{IndistinguishableHybridReport, PrfHybridReport}, secret_sharing::{ - replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, - Vectorizable, + replicated::{malicious, semi_honest::AdditiveShare as Replicated}, + BitDecomposed, FieldSimd, TransposeFrom, Vectorizable, }, seq_join::seq_join, utils::non_zero_prev_power_of_two, @@ -81,6 +82,20 @@ fn conv_proof_chunk() -> usize { non_zero_prev_power_of_two(max(2, TARGET_PROOF_SIZE / CONV_CHUNK / 512)) } +/// Allow MAC-malicious shares to be used for PRF generation with shards +impl<'a, const N: usize> PrfSharing, N> + for Replicated +where + Fp25519: FieldSimd, + RP25519: Vectorizable, + malicious::AdditiveShare: + BasicProtocols, Fp25519, N>, + Replicated: FromPrss, +{ + type Field = Fp25519; + type UpgradedSharing = malicious::AdditiveShare; +} + /// This computes the Dodis-Yampolsky PRF value on every match key from input, /// and reshards the reports according to the computed PRF. At the end, reports with the /// same value end up on the same shard. @@ -233,9 +248,8 @@ mod test { }, ]; - // TODO: we need to use malicious circuits here let reports_per_shard = world - .semi_honest(records.clone().into_iter(), |ctx, reports| async move { + .malicious(records.clone().into_iter(), |ctx, reports| async move { let ind_reports = reports .into_iter() .map(IndistinguishableHybridReport::from) diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 8a9f375be..09bef945c 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -276,7 +276,7 @@ mod tests { #[should_panic( expected = "not implemented: protocol::hybrid::hybrid_protocol is not fully implemented" )] - fn encrypted_hybrid_reports() { + fn encrypted_hybrid_reports_happy() { // While this test currently checks for an unimplemented panic it is // designed to test for a correct result for a complete implementation. 
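        // A sketch of the eventual check (the names below are hypothetical):
        // once hybrid_protocol is implemented, the three helpers' result shares
        // would be reconstructed and compared against the expected aggregate,
        // along the lines of
        //
        //     let histogram = reconstruct(results.await);
        //     assert_eq!(histogram, expected);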
run(|| async { @@ -293,7 +293,7 @@ mod tests { } = build_buffers_from_records(&records, SHARDS, &hybrid_info); let world = TestWorld::>::with_shards(TestWorldConfig::default()); - let contexts = world.contexts(); + let contexts = world.malicious_contexts(); #[allow(clippy::large_futures)] let results = flatten3v(buffers.into_iter().zip(contexts).map( @@ -384,7 +384,7 @@ mod tests { let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); - let contexts = world.contexts(); + let contexts = world.malicious_contexts(); #[allow(clippy::large_futures)] let results = flatten3v(buffers.into_iter().zip(contexts).map( @@ -437,7 +437,7 @@ mod tests { let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); - let contexts = world.contexts(); + let contexts = world.malicious_contexts(); #[allow(clippy::large_futures)] let results = flatten3v(buffers.into_iter().zip(contexts).map( diff --git a/ipa-core/src/query/runner/reshard_tag.rs b/ipa-core/src/query/runner/reshard_tag.rs index 5d1c3b8f5..9b535fc11 100644 --- a/ipa-core/src/query/runner/reshard_tag.rs +++ b/ipa-core/src/query/runner/reshard_tag.rs @@ -100,7 +100,7 @@ mod tests { let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); world - .semi_honest( + .malicious( vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(), |ctx, input| async move { let shard_id = ctx.shard_id(); @@ -130,7 +130,7 @@ mod tests { let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); world - .semi_honest( + .malicious( vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(), |ctx, input| async move { reshard_aad( diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 8c2db9a4e..91cfc87ea 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -240,9 +240,10 @@ impl TestWorld> { /// Panics if world has more or less than 3 gateways/participants #[must_use] pub fn malicious_contexts(&self) -> [Vec>; 3] { + let gate = &self.next_gate(); self.shards() .iter() - .map(|shard| shard.malicious_contexts(&self.next_gate())) + .map(|shard| shard.malicious_contexts(gate)) .fold([Vec::new(), Vec::new(), Vec::new()], |mut acc, contexts| { // Distribute contexts into the respective vectors. for (vec, context) in acc.iter_mut().zip(contexts.iter()) { From 0cbc5de284a5767003d94c97ff56a268879b667a Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Fri, 22 Nov 2024 08:42:27 -0800 Subject: [PATCH 47/47] Use lookup tables for first proof Lagrange output (#1437) --- ipa-core/benches/dzkp_convert_prover.rs | 34 +- ipa-core/src/protocol/context/dzkp_field.rs | 461 +++++++++--------- .../src/protocol/context/dzkp_validator.rs | 158 +++--- .../ipa_prf/malicious_security/lagrange.rs | 2 +- .../ipa_prf/malicious_security/mod.rs | 8 + .../ipa_prf/malicious_security/prover.rs | 393 ++++++++++++--- .../ipa_prf/malicious_security/verifier.rs | 248 ++++++++-- ipa-core/src/protocol/ipa_prf/mod.rs | 7 +- .../validation_protocol/proof_generation.rs | 140 ++---- .../ipa_prf/validation_protocol/validation.rs | 181 +++---- .../src/secret_sharing/vector/transpose.rs | 4 +- 11 files changed, 974 insertions(+), 662 deletions(-) diff --git a/ipa-core/benches/dzkp_convert_prover.rs b/ipa-core/benches/dzkp_convert_prover.rs index 57b557735..c8f820bab 100644 --- a/ipa-core/benches/dzkp_convert_prover.rs +++ b/ipa-core/benches/dzkp_convert_prover.rs @@ -1,41 +1,15 @@ -//! Benchmark for the convert_prover function in dzkp_field.rs. 
+//! Benchmark for the table_indices_prover function in dzkp_field.rs. use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -use ipa_core::{ - ff::Fp61BitPrime, - protocol::context::{dzkp_field::DZKPBaseField, dzkp_validator::MultiplicationInputsBlock}, -}; +use ipa_core::protocol::context::dzkp_validator::MultiplicationInputsBlock; use rand::{thread_rng, Rng}; fn convert_prover_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("dzkp_convert_prover"); group.bench_function("convert", |b| { b.iter_batched_ref( - || { - // Generate input - let mut rng = thread_rng(); - - MultiplicationInputsBlock { - x_left: rng.gen::<[u8; 32]>().into(), - x_right: rng.gen::<[u8; 32]>().into(), - y_left: rng.gen::<[u8; 32]>().into(), - y_right: rng.gen::<[u8; 32]>().into(), - prss_left: rng.gen::<[u8; 32]>().into(), - prss_right: rng.gen::<[u8; 32]>().into(), - z_right: rng.gen::<[u8; 32]>().into(), - } - }, - |input| { - let MultiplicationInputsBlock { - x_left, - x_right, - y_left, - y_right, - prss_right, - .. - } = input; - Fp61BitPrime::convert_prover(x_left, x_right, y_left, y_right, prss_right); - }, + || thread_rng().gen(), + |input: &mut MultiplicationInputsBlock| input.table_indices_prover(), BatchSize::SmallInput, ) }); diff --git a/ipa-core/src/protocol/context/dzkp_field.rs b/ipa-core/src/protocol/context/dzkp_field.rs index 368ac36cc..758ad008b 100644 --- a/ipa-core/src/protocol/context/dzkp_field.rs +++ b/ipa-core/src/protocol/context/dzkp_field.rs @@ -1,18 +1,13 @@ -use std::{iter::zip, sync::LazyLock}; +use std::{ops::Index, sync::LazyLock}; use bitvec::field::BitField; use crate::{ ff::{Field, Fp61BitPrime, PrimeField}, - protocol::context::dzkp_validator::{Array256Bit, SegmentEntry}, + protocol::context::dzkp_validator::{Array256Bit, MultiplicationInputsBlock, SegmentEntry}, secret_sharing::{FieldSimd, SharedValue, Vectorizable}, }; -// BlockSize is fixed to 32 -pub const BLOCK_SIZE: usize = 32; -// UVTupleBlock is a block of interleaved U and V values -pub type UVTupleBlock = ([F; BLOCK_SIZE], [F; BLOCK_SIZE]); - /// Trait for fields compatible with DZKPs /// Field needs to support conversion to `SegmentEntry`, i.e. `to_segment_entry` which is required by DZKPs pub trait DZKPCompatibleField: FieldSimd { @@ -25,35 +20,12 @@ pub trait DZKPBaseField: PrimeField { const INVERSE_OF_TWO: Self; const MINUS_ONE_HALF: Self; const MINUS_TWO: Self; +} - /// Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements. - /// This function is called by the prover. - fn convert_prover<'a>( - x_left: &'a Array256Bit, - x_right: &'a Array256Bit, - y_left: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - ) -> Vec>; - - /// This is similar to `convert_prover` except that it is called by the verifier to the left of the prover. - /// The verifier on the left uses its right shares, since they are consistent with the prover's left shares. - /// This produces the 'u' values. - fn convert_value_from_right_prover<'a>( - x_right: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - z_right: &'a Array256Bit, - ) -> Vec; - - /// This is similar to `convert_prover` except that it is called by the verifier to the right of the prover. - /// The verifier on the right uses its left shares, since they are consistent with the prover's right shares. 
- /// This produces the 'v' values - fn convert_value_from_left_prover<'a>( - x_left: &'a Array256Bit, - y_left: &'a Array256Bit, - prss_left: &'a Array256Bit, - ) -> Vec; +impl DZKPBaseField for Fp61BitPrime { + const INVERSE_OF_TWO: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_976u64); + const MINUS_ONE_HALF: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_975u64); + const MINUS_TWO: Self = Fp61BitPrime::const_truncate(2_305_843_009_213_693_949u64); } impl FromIterator for [Fp61BitPrime; P] { @@ -90,7 +62,7 @@ impl FromIterator for Vec<[Fp61BitPrime; P]> { } } -/// Construct indices for the `convert_values` lookup tables. +/// Construct indices for the `TABLE_U` and `TABLE_V` lookup tables. /// /// `b0` has the least significant bit of each index, and `b1` and `b2` the subsequent /// bits. This routine rearranges the bits so that there is one table index in each @@ -102,7 +74,7 @@ impl FromIterator for Vec<[Fp61BitPrime; P]> { /// (i%4) == j. The "0s", "1s", "2s", "3s" comments trace the movement from /// input to output. #[must_use] -fn convert_values_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { +fn bits_to_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { // 0x55 is 0b0101_0101. This mask selects bits having (i%2) == 0. const CONST_55: u128 = u128::from_le_bytes([0x55; 16]); // 0xaa is 0b1010_1010. This mask selects bits having (i%2) == 1. @@ -148,10 +120,31 @@ fn convert_values_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { [y0, y1, y2, y3] } +pub struct UVTable(pub [[F; 4]; 8]); + +impl Index for UVTable { + type Output = [F; 4]; + + fn index(&self, index: u8) -> &Self::Output { + self.0.index(usize::from(index)) + } +} + +impl Index for UVTable { + type Output = [F; 4]; + + fn index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} + // Table used for `convert_prover` and `convert_value_from_right_prover`. // -// The conversion to "g" and "h" values is from https://eprint.iacr.org/2023/909.pdf. -static TABLE_RIGHT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { +// This table is for "g" or "u" values. This table is "right" on a verifier when it is +// processing values for the prover on its right. On a prover, this table is "left". +// +// The conversion logic is from https://eprint.iacr.org/2023/909.pdf. +pub static TABLE_U: LazyLock> = LazyLock::new(|| { let mut result = Vec::with_capacity(8); for e in [false, true] { for c in [false, true] { @@ -172,13 +165,16 @@ static TABLE_RIGHT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { } } } - result.try_into().unwrap() + UVTable(result.try_into().unwrap()) }); // Table used for `convert_prover` and `convert_value_from_left_prover`. // -// The conversion to "g" and "h" values is from https://eprint.iacr.org/2023/909.pdf. -static TABLE_LEFT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { +// This table is for "h" or "v" values. This table is "left" on a verifier when it is +// processing values for the prover on its left. On a prover, this table is "right". +// +// The conversion logic is from https://eprint.iacr.org/2023/909.pdf. 
+pub static TABLE_V: LazyLock> = LazyLock::new(|| { let mut result = Vec::with_capacity(8); for f in [false, true] { for d in [false, true] { @@ -199,31 +195,31 @@ static TABLE_LEFT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { } } } - result.try_into().unwrap() + UVTable(result.try_into().unwrap()) }); -/// Lookup-table-based conversion logic used by `convert_prover`, -/// `convert_value_from_left_prover`, and `convert_value_from_left_prover`. +/// Lookup-table-based conversion logic used by `table_indices_prover`, +/// `table_indices_from_right_prover`, and `table_indices_from_left_prover`. /// /// Inputs `i0`, `i1`, and `i2` each contain the value of one of the "a" through "f" /// intermediates for each of 256 multiplies. `table` is the lookup table to use, -/// which should be either `TABLE_LEFT` or `TABLE_RIGHT`. +/// which should be either `TABLE_U` or `TABLE_V`. /// /// We want to interpret the 3-tuple of intermediates at each bit position in `i0`, `i1` /// and `i2` as an integer index in the range 0..8 into the table. The -/// `convert_values_table_indices` helper does this in bulk more efficiently than using +/// `bits_to_table_indices` helper does this in bulk more efficiently than using /// bit-manipulation to handle them one-by-one. /// /// Preserving the order from inputs to outputs is not necessary for correctness as long /// as the same order is used on all three helpers. We preserve the order anyways /// to simplify the end-to-end dataflow, even though it makes this routine slightly /// more complicated. -fn convert_values( +fn intermediates_to_table_indices<'a>( i0: &Array256Bit, i1: &Array256Bit, i2: &Array256Bit, - table: &[[Fp61BitPrime; 4]; 8], -) -> Vec { + mut out: impl Iterator, +) { // Split inputs to two `u128`s. We do this because `u128` is the largest integer // type rust supports. It is possible that using SIMD types here would improve // code generation for AVX-256/512. @@ -238,45 +234,49 @@ fn convert_values( // Output word `j` in each set contains the table indices for input positions `i` // having (i%4) == j. - let [mut z00, mut z01, mut z02, mut z03] = convert_values_table_indices(i00, i10, i20); - let [mut z10, mut z11, mut z12, mut z13] = convert_values_table_indices(i01, i11, i21); + let [mut z00, mut z01, mut z02, mut z03] = bits_to_table_indices(i00, i10, i20); + let [mut z10, mut z11, mut z12, mut z13] = bits_to_table_indices(i01, i11, i21); - let mut result = Vec::with_capacity(1024); + #[allow(clippy::cast_possible_truncation)] for _ in 0..32 { // Take one index in turn from each `z` to preserve the output order. 
- for z in [&mut z00, &mut z01, &mut z02, &mut z03] { - result.extend(table[(*z as usize) & 0x7]); - *z >>= 4; - } + *out.next().unwrap() = (z00 as u8) & 0x7; + z00 >>= 4; + *out.next().unwrap() = (z01 as u8) & 0x7; + z01 >>= 4; + *out.next().unwrap() = (z02 as u8) & 0x7; + z02 >>= 4; + *out.next().unwrap() = (z03 as u8) & 0x7; + z03 >>= 4; } + #[allow(clippy::cast_possible_truncation)] for _ in 0..32 { - for z in [&mut z10, &mut z11, &mut z12, &mut z13] { - result.extend(table[(*z as usize) & 0x7]); - *z >>= 4; - } + *out.next().unwrap() = (z10 as u8) & 0x7; + z10 >>= 4; + *out.next().unwrap() = (z11 as u8) & 0x7; + z11 >>= 4; + *out.next().unwrap() = (z12 as u8) & 0x7; + z12 >>= 4; + *out.next().unwrap() = (z13 as u8) & 0x7; + z13 >>= 4; } - debug_assert!(result.len() == 1024); - - result + debug_assert!(out.next().is_none()); } -impl DZKPBaseField for Fp61BitPrime { - const INVERSE_OF_TWO: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_976u64); - const MINUS_ONE_HALF: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_975u64); - const MINUS_TWO: Self = Fp61BitPrime::const_truncate(2_305_843_009_213_693_949u64); - - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements - // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf +impl MultiplicationInputsBlock { + /// Repack the intermediates in this block into lookup indices for `TABLE_U` and `TABLE_V`. + /// + /// This is the convert function called by the prover. // - // This function does not use any optimization. + // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and left can compute: + // Prover and the verifier on its left compute: // g1=-2ac(1-2e), // g2=c(1-2e), // g3=a(1-2e), // g4=-1/2(1-2e), // - // Prover and right can compute: + // Prover and the verifier on its right compute: // h1=bd(1-2f), // h2=d(1-2f), // h3=b(1-2f), @@ -292,33 +292,30 @@ impl DZKPBaseField for Fp61BitPrime { // therefore e = ab⊕cd⊕ f must hold. (alternatively, you can also see this by substituting z_left, // i.e. 
z_left = x_left · y_left ⊕ x_left · y_right ⊕ x_right · y_left ⊕ prss_left ⊕ prss_right #[allow(clippy::many_single_char_names)] - fn convert_prover<'a>( - x_left: &'a Array256Bit, - x_right: &'a Array256Bit, - y_left: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - ) -> Vec> { - let a = x_left; - let b = y_right; - let c = y_left; - let d = x_right; + #[must_use] + pub fn table_indices_prover(&self) -> Vec<(u8, u8)> { + let a = &self.x_left; + let b = &self.y_right; + let c = &self.y_left; + let d = &self.x_right; // e = ab ⊕ cd ⊕ f = x_left * y_right ⊕ y_left * x_right ⊕ prss_right - let e = (*x_left & y_right) ^ (*y_left & x_right) ^ prss_right; - let f = prss_right; - - let g = convert_values(a, c, &e, &TABLE_RIGHT); - let h = convert_values(b, d, f, &TABLE_LEFT); + let e = (self.x_left & self.y_right) ^ (self.y_left & self.x_right) ^ self.prss_right; + let f = &self.prss_right; - zip(g.chunks_exact(BLOCK_SIZE), h.chunks_exact(BLOCK_SIZE)) - .map(|(g_chunk, h_chunk)| (g_chunk.try_into().unwrap(), h_chunk.try_into().unwrap())) - .collect() + let mut output = vec![(0u8, 0u8); 256]; + intermediates_to_table_indices(a, c, &e, output.iter_mut().map(|tup| &mut tup.0)); + intermediates_to_table_indices(b, d, f, output.iter_mut().map(|tup| &mut tup.1)); + output } - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements + /// Repack the intermediates in this block into lookup indices for `TABLE_U`. + /// + /// This is the convert function called by the verifier when processing for the + /// prover on its right. + // // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and left can compute: + // Prover and the verifier on its left compute: // g1=-2ac(1-2e), // g2=c(1-2e), // g3=a(1-2e), @@ -328,25 +325,27 @@ impl DZKPBaseField for Fp61BitPrime { // (a,c,e) = (x_right, y_right, x_right * y_right ⊕ z_right ⊕ prss_right) // here e is defined as in the paper (since the the verifier does not have access to b,d,f, // he cannot use the simplified formula for e) - fn convert_value_from_right_prover<'a>( - x_right: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - z_right: &'a Array256Bit, - ) -> Vec { - let a = x_right; - let c = y_right; + #[must_use] + pub fn table_indices_from_right_prover(&self) -> Vec { + let a = &self.x_right; + let c = &self.y_right; // e = ac ⊕ zright ⊕ prssright // as defined in the paper - let e = (*a & *c) ^ prss_right ^ z_right; + let e = (self.x_right & self.y_right) ^ self.prss_right ^ self.z_right; - convert_values(a, c, &e, &TABLE_RIGHT) + let mut output = vec![0u8; 256]; + intermediates_to_table_indices(a, c, &e, output.iter_mut()); + output } - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements + /// Repack the intermediates in this block into lookup indices for `TABLE_V`. + /// + /// This is the convert function called by the verifier when processing for the + /// prover on its left. 
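+    ///
+    /// A sketch of how a verifier might consume the result (not compiled as a
+    /// doc-test; `block` stands for some `MultiplicationInputsBlock`):
+    ///
+    /// ```ignore
+    /// let indices = block.table_indices_from_left_prover();
+    /// let v_values: Vec<[Fp61BitPrime; 4]> = indices.iter().map(|&i| TABLE_V[i]).collect();
+    /// ```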
+ // // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and right can compute: + // The prover and the verifier on its right compute: // h1=bd(1-2f), // h2=d(1-2f), // h3=b(1-2f), @@ -354,31 +353,33 @@ impl DZKPBaseField for Fp61BitPrime { // // where // (b,d,f) = (y_left, x_left, prss_left) - fn convert_value_from_left_prover<'a>( - x_left: &'a Array256Bit, - y_left: &'a Array256Bit, - prss_left: &'a Array256Bit, - ) -> Vec { - let b = y_left; - let d = x_left; - let f = prss_left; - - convert_values(b, d, f, &TABLE_LEFT) + #[must_use] + pub fn table_indices_from_left_prover(&self) -> Vec { + let b = &self.y_left; + let d = &self.x_left; + let f = &self.prss_left; + + let mut output = vec![0u8; 256]; + intermediates_to_table_indices(b, d, f, output.iter_mut()); + output } } #[cfg(all(test, unit_test))] -mod tests { - use bitvec::{array::BitArray, macros::internal::funty::Fundamental, slice::BitSlice}; +pub mod tests { + + use bitvec::{array::BitArray, macros::internal::funty::Fundamental}; use proptest::proptest; use rand::{thread_rng, Rng}; use crate::{ ff::{Field, Fp61BitPrime, U128Conversions}, - protocol::context::dzkp_field::{ - convert_values_table_indices, DZKPBaseField, UVTupleBlock, BLOCK_SIZE, + protocol::context::{ + dzkp_field::{bits_to_table_indices, DZKPBaseField, TABLE_U, TABLE_V}, + dzkp_validator::MultiplicationInputsBlock, }, secret_sharing::SharedValue, + test_executor::run_random, }; #[test] @@ -386,7 +387,7 @@ mod tests { let b0 = 0xaa; let b1 = 0xcc; let b2 = 0xf0; - let [z0, z1, z2, z3] = convert_values_table_indices(b0, b1, b2); + let [z0, z1, z2, z3] = bits_to_table_indices(b0, b1, b2); assert_eq!(z0, 0x40_u128); assert_eq!(z1, 0x51_u128); assert_eq!(z2, 0x62_u128); @@ -396,7 +397,7 @@ mod tests { let b0 = rng.gen(); let b1 = rng.gen(); let b2 = rng.gen(); - let [z0, z1, z2, z3] = convert_values_table_indices(b0, b1, b2); + let [z0, z1, z2, z3] = bits_to_table_indices(b0, b1, b2); for i in (0..128).step_by(4) { fn check(i: u32, j: u32, b0: u128, b1: u128, b2: u128, z: u128) { @@ -417,85 +418,73 @@ mod tests { } } - #[test] - fn batch_convert() { - let mut rng = thread_rng(); + impl MultiplicationInputsBlock { + /// Rotate the "right" values into the "left" values, setting the right values + /// to zero. If the input represents a prover's block of intermediates, the + /// output represents the intermediates that the verifier on the prover's right + /// shares with it. 
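+        ///
+        /// A sketch of the intended relationship (not compiled as a doc-test):
+        ///
+        /// ```ignore
+        /// let view = prover_block.rotate_left();
+        /// assert_eq!(view.x_left, prover_block.x_right);
+        /// assert_eq!(view.x_right, [0u8; 32].into());
+        /// ```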
+ #[must_use] + pub fn rotate_left(&self) -> Self { + Self { + x_left: self.x_right, + y_left: self.y_right, + prss_left: self.prss_right, + x_right: [0u8; 32].into(), + y_right: [0u8; 32].into(), + prss_right: [0u8; 32].into(), + z_right: [0u8; 32].into(), + } + } - // bitvecs - let mut vec_x_left = Vec::::new(); - let mut vec_x_right = Vec::::new(); - let mut vec_y_left = Vec::::new(); - let mut vec_y_right = Vec::::new(); - let mut vec_prss_left = Vec::::new(); - let mut vec_prss_right = Vec::::new(); - let mut vec_z_right = Vec::::new(); - - // gen 32 random values - for _i in 0..32 { - let x_left: u8 = rng.gen(); - let x_right: u8 = rng.gen(); - let y_left: u8 = rng.gen(); - let y_right: u8 = rng.gen(); - let prss_left: u8 = rng.gen(); - let prss_right: u8 = rng.gen(); - // we set this up to be equal to z_right for this local test - // local here means that only a single party is involved - // and we just test this against this single party - let z_right: u8 = (x_left & y_left) - ^ (x_left & y_right) - ^ (x_right & y_left) - ^ prss_left - ^ prss_right; - - // fill vec - vec_x_left.push(x_left); - vec_x_right.push(x_right); - vec_y_left.push(y_left); - vec_y_right.push(y_right); - vec_prss_left.push(prss_left); - vec_prss_right.push(prss_right); - vec_z_right.push(z_right); + /// Rotate the "left" values into the "right" values, setting the left values to + /// zero. `z_right` is calculated to be consistent with the other values. If the + /// input represents a prover's block of intermediates, the output represents + /// the intermediates that the verifier on the prover's left shares with it. + #[must_use] + pub fn rotate_right(&self) -> Self { + let z_right = (self.x_left & self.y_left) + ^ (self.x_left & self.y_right) + ^ (self.x_right & self.y_left) + ^ self.prss_left + ^ self.prss_right; + + Self { + x_right: self.x_left, + y_right: self.y_left, + prss_right: self.prss_left, + x_left: [0u8; 32].into(), + y_left: [0u8; 32].into(), + prss_left: [0u8; 32].into(), + z_right, + } } + } - // conv to BitVec - let x_left = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_x_left)).unwrap(); - let x_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_x_right)).unwrap(); - let y_left = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_y_left)).unwrap(); - let y_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_y_right)).unwrap(); - let prss_left = - BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_prss_left)).unwrap(); - let prss_right = - BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_prss_right)).unwrap(); - let z_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_z_right)).unwrap(); - - // check consistency of the polynomials - assert_convert( - Fp61BitPrime::convert_prover(&x_left, &x_right, &y_left, &y_right, &prss_right), - // flip intputs right to left since it is checked against itself and not party on the left - // z_right is set to match z_left - Fp61BitPrime::convert_value_from_right_prover(&x_left, &y_left, &prss_left, &z_right), - // flip intputs right to left since it is checked against itself and not party on the left - Fp61BitPrime::convert_value_from_left_prover(&x_right, &y_right, &prss_right), - ); + #[test] + fn batch_convert() { + run_random(|mut rng| async move { + let block = rng.gen::(); + + // When verifying, we rotate the intermediates to match what each prover + // would have. `rotate_right` also calculates z_right from the others. 
+ assert_convert( + block.table_indices_prover(), + block.rotate_right().table_indices_from_right_prover(), + block.rotate_left().table_indices_from_left_prover(), + ); + }); } + fn assert_convert(prover: P, verifier_left: L, verifier_right: R) where - P: IntoIterator>, - L: IntoIterator, - R: IntoIterator, + P: IntoIterator, + L: IntoIterator, + R: IntoIterator, { prover .into_iter() - .zip( - verifier_left - .into_iter() - .collect::>(), - ) - .zip( - verifier_right - .into_iter() - .collect::>(), - ) + .zip(verifier_left.into_iter().collect::>()) + .zip(verifier_right.into_iter().collect::>()) .for_each(|((prover, verifier_left), verifier_right)| { assert_eq!(prover.0, verifier_left); assert_eq!(prover.1, verifier_right); @@ -534,37 +523,15 @@ mod tests { } #[allow(clippy::fn_params_excessive_bools)] - fn correctness_prover_values( + #[must_use] + pub fn reference_convert( x_left: bool, x_right: bool, y_left: bool, y_right: bool, prss_left: bool, prss_right: bool, - ) { - let mut array_x_left = BitArray::<[u8; 32]>::ZERO; - let mut array_x_right = BitArray::<[u8; 32]>::ZERO; - let mut array_y_left = BitArray::<[u8; 32]>::ZERO; - let mut array_y_right = BitArray::<[u8; 32]>::ZERO; - let mut array_prss_left = BitArray::<[u8; 32]>::ZERO; - let mut array_prss_right = BitArray::<[u8; 32]>::ZERO; - - // initialize bits - array_x_left.set(0, x_left); - array_x_right.set(0, x_right); - array_y_left.set(0, y_left); - array_y_right.set(0, y_right); - array_prss_left.set(0, prss_left); - array_prss_right.set(0, prss_right); - - let prover = Fp61BitPrime::convert_prover( - &array_x_left, - &array_x_right, - &array_y_left, - &array_y_right, - &array_prss_right, - )[0]; - + ) -> ([Fp61BitPrime; 4], [Fp61BitPrime; 4]) { // compute expected // (a,b,c,d,f) = (x_left, y_right, y_left, x_right, prss_right) // e = x_left · y_left ⊕ z_left ⊕ prss_left @@ -601,17 +568,59 @@ mod tests { // h4=1-2f, let h4 = one_minus_two_f; + ([g1, g2, g3, g4], [h1, h2, h3, h4]) + } + + #[allow(clippy::fn_params_excessive_bools)] + fn correctness_prover_values( + x_left: bool, + x_right: bool, + y_left: bool, + y_right: bool, + prss_left: bool, + prss_right: bool, + ) { + let mut array_x_left = BitArray::<[u8; 32]>::ZERO; + let mut array_x_right = BitArray::<[u8; 32]>::ZERO; + let mut array_y_left = BitArray::<[u8; 32]>::ZERO; + let mut array_y_right = BitArray::<[u8; 32]>::ZERO; + let mut array_prss_left = BitArray::<[u8; 32]>::ZERO; + let mut array_prss_right = BitArray::<[u8; 32]>::ZERO; + + // initialize bits + array_x_left.set(0, x_left); + array_x_right.set(0, x_right); + array_y_left.set(0, y_left); + array_y_right.set(0, y_right); + array_prss_left.set(0, prss_left); + array_prss_right.set(0, prss_right); + + let block = MultiplicationInputsBlock { + x_left: array_x_left, + x_right: array_x_right, + y_left: array_y_left, + y_right: array_y_right, + prss_left: array_prss_left, + prss_right: array_prss_right, + z_right: BitArray::ZERO, + }; + + let prover = block.table_indices_prover()[0]; + + let ([g1, g2, g3, g4], [h1, h2, h3, h4]) = + reference_convert(x_left, x_right, y_left, y_right, prss_left, prss_right); + // check expected == computed // g polynomial - assert_eq!(g1, prover.0[0]); - assert_eq!(g2, prover.0[1]); - assert_eq!(g3, prover.0[2]); - assert_eq!(g4, prover.0[3]); + assert_eq!(g1, TABLE_U[prover.0][0]); + assert_eq!(g2, TABLE_U[prover.0][1]); + assert_eq!(g3, TABLE_U[prover.0][2]); + assert_eq!(g4, TABLE_U[prover.0][3]); // h polynomial - assert_eq!(h1, prover.1[0]); - assert_eq!(h2, prover.1[1]); - 
assert_eq!(h3, prover.1[2]); - assert_eq!(h4, prover.1[3]); + assert_eq!(h1, TABLE_V[prover.1][0]); + assert_eq!(h2, TABLE_V[prover.1][1]); + assert_eq!(h3, TABLE_V[prover.1][2]); + assert_eq!(h4, TABLE_V[prover.1][3]); } } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index fd616fc9c..5d8d58b2e 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -12,7 +12,7 @@ use crate::{ protocol::{ context::{ batcher::Batcher, - dzkp_field::{DZKPBaseField, UVTupleBlock}, + dzkp_field::{DZKPBaseField, TABLE_U, TABLE_V}, dzkp_malicious::DZKPUpgraded as MaliciousDZKPUpgraded, dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, step::DzkpValidationProtocolStep as Step, @@ -20,7 +20,8 @@ use crate::{ }, ipa_prf::{ validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, - CompressedProofGenerator, FirstProofGenerator, + CompressedProofGenerator, FirstProofGenerator, ProverTableIndices, + VerifierTableIndices, }, Gate, RecordId, RecordIdRange, }, @@ -33,7 +34,7 @@ pub type Array256Bit = BitArray<[u8; 32], Lsb0>; type BitSliceType = BitSlice; -const BIT_ARRAY_LEN: usize = 256; +pub const BIT_ARRAY_LEN: usize = 256; const BIT_ARRAY_MASK: usize = BIT_ARRAY_LEN - 1; const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize; @@ -58,6 +59,12 @@ pub const TARGET_PROOF_SIZE: usize = 8192; #[cfg(not(test))] pub const TARGET_PROOF_SIZE: usize = 50_000_000; +/// Minimum proof recursion depth. +/// +/// This minimum avoids special cases in the implementation that would be otherwise +/// required when the initial and final recursion steps overlap. +pub const MIN_PROOF_RECURSION: usize = 2; + /// Maximum proof recursion depth. // // This is a hard limit. Each GF(2) multiply generates four G values and four H values, @@ -74,8 +81,6 @@ pub const TARGET_PROOF_SIZE: usize = 50_000_000; // Because the number of records in a proof batch is often rounded up to a power of two // (and less significantly, because multiplication intermediate storage gets rounded up // to blocks of 256), leaving some margin is advised. -// -// The implementation requires that MAX_PROOF_RECURSION is at least 2. pub const MAX_PROOF_RECURSION: usize = 14; /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values @@ -159,34 +164,21 @@ impl MultiplicationInputsBlock { Ok(()) } +} - /// `Convert` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the prover. - fn convert_prover(&self) -> Vec> { - DF::convert_prover( - &self.x_left, - &self.x_right, - &self.y_left, - &self.y_right, - &self.prss_right, - ) - } - - /// `convert_values_from_right_prover` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the verifier on the left. - fn convert_values_from_right_prover(&self) -> Vec { - DF::convert_value_from_right_prover( - &self.x_right, - &self.y_right, - &self.prss_right, - &self.z_right, - ) - } - - /// `convert_values_from_left_prover` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the verifier on the right. 
- fn convert_values_from_left_prover(&self) -> Vec { - DF::convert_value_from_left_prover(&self.x_left, &self.y_left, &self.prss_left) +#[cfg(any(test, feature = "enable-benches"))] +impl rand::prelude::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> MultiplicationInputsBlock { + let sample = >::sample; + MultiplicationInputsBlock { + x_left: sample(self, rng).into(), + x_right: sample(self, rng).into(), + y_left: sample(self, rng).into(), + y_right: sample(self, rng).into(), + prss_left: sample(self, rng).into(), + prss_right: sample(self, rng).into(), + z_right: sample(self, rng).into(), + } } } @@ -472,34 +464,30 @@ impl MultiplicationInputsBatch { } } - /// `get_field_values_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the prover of the DZKPs - fn get_field_values_prover( - &self, - ) -> impl Iterator> + Clone + '_ { + /// `get_field_values_prover` converts a `MultiplicationInputsBatch` into an + /// iterator over pairs of indices for `TABLE_U` and `TABLE_V`. + fn get_field_values_prover(&self) -> impl Iterator + Clone + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_prover) } - /// `get_field_values_from_right_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the verifier of the DZKPs on the left side of the prover, i.e. the `u` values. - fn get_field_values_from_right_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_right_prover` converts a `MultiplicationInputsBatch` into + /// an iterator over table indices for `TABLE_U`, which is used by the verifier of + /// the DZKPs on the left side of the prover. + fn get_field_values_from_right_prover(&self) -> impl Iterator + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_values_from_right_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_from_right_prover) } - /// `get_field_values_from_left_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the verifier of the DZKPs on the right side of the prover, i.e. the `v` values. - fn get_field_values_from_left_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_left_prover` converts a `MultiplicationInputsBatch` into + /// an iterator over table indices for `TABLE_V`, which is used by the verifier of + /// the DZKPs on the right side of the prover. + fn get_field_values_from_left_prover(&self) -> impl Iterator + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_values_from_left_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_from_left_prover) } } @@ -565,36 +553,30 @@ impl Batch { .sum() } - /// `get_field_values_prover` converts a `Batch` into an iterator over field values - /// which is used by the prover of the DZKP - fn get_field_values_prover( - &self, - ) -> impl Iterator> + Clone + '_ { + /// `get_field_values_prover` converts a `Batch` into an iterator over pairs of + /// indices for `TABLE_U` and `TABLE_V`. + fn get_field_values_prover(&self) -> impl Iterator + Clone + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_prover) } - /// `get_field_values_from_right_prover` converts a `Batch` into an iterator over field values - /// which is used by the verifier of the DZKP on the left side of the prover. 
- /// This produces the `u` values. - fn get_field_values_from_right_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_right_prover` converts a `Batch` into an iterator over + /// table indices for `TABLE_U`, which is used by the verifier of the DZKP on the + /// left side of the prover. + fn get_field_values_from_right_prover(&self) -> impl Iterator + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_from_right_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_from_right_prover) } - /// `get_field_values_from_left_prover` converts a `Batch` into an iterator over field values - /// which is used by the verifier of the DZKP on the right side of the prover. - /// This produces the `v` values. - fn get_field_values_from_left_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_left_prover` converts a `Batch` into an iterator over + /// table indices for `TABLE_V`, which is used by the verifier of the DZKP on the + /// right side of the prover. + fn get_field_values_from_left_prover(&self) -> impl Iterator + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_from_left_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_from_left_prover) } /// ## Panics @@ -626,7 +608,11 @@ impl Batch { q_mask_from_left_prover, ) = { // generate BatchToVerify - ProofBatch::generate(&proof_ctx, prss_record_ids, self.get_field_values_prover()) + ProofBatch::generate( + &proof_ctx, + prss_record_ids, + ProverTableIndices(self.get_field_values_prover()), + ) }; let chunk_batch = BatchToVerify::generate_batch_to_verify( @@ -650,12 +636,7 @@ impl Batch { tracing::info!("validating {m} multiplications"); debug_assert_eq!( m, - self.get_field_values_prover::() - .flat_map(|(u_array, v_array)| { - u_array.into_iter().zip(v_array).map(|(u, v)| u * v) - }) - .count() - / 4, + self.get_field_values_prover().count(), "Number of multiplications is counted incorrectly" ); let sum_of_uv = Fp61BitPrime::truncate_from(u128::try_from(m).unwrap()) @@ -664,8 +645,14 @@ impl Batch { let (p_r_right_prover, q_r_left_prover) = chunk_batch.compute_p_and_q_r( &challenges_for_left_prover, &challenges_for_right_prover, - self.get_field_values_from_right_prover(), - self.get_field_values_from_left_prover(), + VerifierTableIndices { + input: self.get_field_values_from_right_prover(), + table: &TABLE_U, + }, + VerifierTableIndices { + input: self.get_field_values_from_left_prover(), + table: &TABLE_V, + }, ); (sum_of_uv, p_r_right_prover, q_r_left_prover) @@ -965,12 +952,11 @@ mod tests { ff::{ boolean::Boolean, boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8}, - Fp61BitPrime, }, protocol::{ basics::{select, BooleanArrayMul, SecureMul}, context::{ - dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, + dzkp_field::DZKPCompatibleField, dzkp_validator::{ Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, }, @@ -1814,16 +1800,16 @@ mod tests { fn assert_batch_convert(batch_prover: &Batch, batch_left: &Batch, batch_right: &Batch) { batch_prover - .get_field_values_prover::() + .get_field_values_prover() .zip( batch_left - .get_field_values_from_right_prover::() - .collect::>(), + .get_field_values_from_right_prover() + .collect::>(), ) .zip( batch_right - .get_field_values_from_left_prover::() - .collect::>(), + .get_field_values_from_left_prover() + .collect::>(), ) .for_each(|((prover, verifier_left), verifier_right)| { assert_eq!(prover.0, verifier_left); diff 
--git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 41abab508..1f598cb31 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -86,7 +86,7 @@ where /// The "x coordinate" of the output point is `x_output`. pub fn new(denominator: &CanonicalLagrangeDenominator, x_output: &F) -> Self { // assertion that table is not too large for the stack - assert!(::Size::USIZE * N < 2024); + debug_assert!(::Size::USIZE * N < 2024); let table = Self::compute_table_row(x_output, denominator); LagrangeTable:: { table: [table; 1] } diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs index 607827b1a..9d366735b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs @@ -1,3 +1,11 @@ pub mod lagrange; pub mod prover; pub mod verifier; + +pub type FirstProofGenerator = prover::SmallProofGenerator; +pub type CompressedProofGenerator = prover::SmallProofGenerator; +pub use lagrange::{CanonicalLagrangeDenominator, LagrangeTable}; +pub use prover::ProverTableIndices; +pub use verifier::VerifierTableIndices; + +pub const FIRST_RECURSION_FACTOR: usize = FirstProofGenerator::RECURSION_FACTOR; diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index 26ce5c5d0..aa1b25e9b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -1,18 +1,25 @@ -use std::{borrow::Borrow, iter::zip, marker::PhantomData}; +use std::{array, borrow::Borrow, marker::PhantomData}; use crate::{ error::Error::{self, DZKPMasks}, - ff::{Fp61BitPrime, PrimeField}, + ff::{Fp61BitPrime, MultiplyAccumulate, MultiplyAccumulatorArray, PrimeField}, helpers::hashing::{compute_hash, hash_to_field}, protocol::{ - context::Context, + context::{ + dzkp_field::{TABLE_U, TABLE_V}, + Context, + }, ipa_prf::{ - malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + malicious_security::{ + lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + FIRST_RECURSION_FACTOR as FRF, + }, CompressedProofGenerator, }, prss::SharedRandomness, RecordId, RecordIdRange, }, + secret_sharing::SharedValue, }; /// This struct stores intermediate `uv` values. @@ -95,6 +102,213 @@ where } } +/// Trait for inputs to Lagrange interpolation on the prover. +/// +/// Lagrange interpolation is used in the prover in two ways: +/// +/// 1. To extrapolate an additional λ - 1 y-values of degree-(λ-1) polynomials so that the +/// total 2λ - 1 y-values can be multiplied to obtain a representation of the product of +/// the polynomials. +/// 2. To evaluate polynomials at the randomly chosen challenge point _r_. +/// +/// The two methods in this trait correspond to those two uses. +/// +/// There are two implementations of this trait: `ProverTableIndices`, and +/// `ProverValues`. `ProverTableIndices` is used for the input to the first proof. Each +/// set of 4 _u_ or _v_ values input to the first proof has one of eight possible +/// values, determined by the values of the 3 associated multiplication intermediates. +/// The `ProverTableIndices` implementation uses a lookup table containing the output of +/// the Lagrange interpolation for each of these eight possible values. 
The +/// `ProverValues` implementation, which represents actual _u_ and _v_ values, is used +/// by the remaining recursive proofs. +/// +/// There is a similar trait `VerifierLagrangeInput` in `verifier.rs`. The difference is +/// that the prover operates on _u_ and _v_ values simultaneously (i.e. iterators of +/// tuples). The verifier operates on only one of _u_ or _v_ at a time. +pub trait ProverLagrangeInput { + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; +} + +/// Implementation of `ProverLagrangeInput` for table indices, used for the first proof. +#[derive(Clone)] +pub struct ProverTableIndices>(pub I); + +/// Iterator returned by `ProverTableIndices::extrapolate_y_values` and +/// `ProverTableIndices::eval_at_r`. +struct TableIndicesIterator> { + input: I, + u_table: [T; 8], + v_table: [T; 8], +} + +impl> ProverLagrangeInput + for ProverTableIndices +{ + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.0, + u_table: array::from_fn(|i| { + let mut result = [Fp61BitPrime::ZERO; P]; + let u = &TABLE_U[i]; + result[0..FRF].copy_from_slice(u); + result[FRF..].copy_from_slice(&lagrange_table.eval(u)); + result + }), + v_table: array::from_fn(|i| { + let mut result = [Fp61BitPrime::ZERO; P]; + let v = &TABLE_V[i]; + result[0..FRF].copy_from_slice(v); + result[FRF..].copy_from_slice(&lagrange_table.eval(v)); + result + }), + } + } + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.0, + u_table: array::from_fn(|i| lagrange_table.eval(&TABLE_U[i])[0]), + v_table: array::from_fn(|i| lagrange_table.eval(&TABLE_V[i])[0]), + } + } +} + +impl> Iterator for TableIndicesIterator { + type Item = (T, T); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_index, v_index)| { + ( + self.u_table[usize::from(u_index)].clone(), + self.v_table[usize::from(v_index)].clone(), + ) + }) + } +} + +/// Implementation of `ProverLagrangeInput` for _u_ and _v_ values, used for subsequent +/// recursive proofs. +#[derive(Clone)] +pub struct ProverValues>(pub I); + +/// Iterator returned by `ProverValues::extrapolate_y_values`. +struct ValuesExtrapolateIterator< + 'a, + F: PrimeField, + const L: usize, + const P: usize, + const M: usize, + I: Iterator, +> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl< + 'a, + F: PrimeField, + const L: usize, + const P: usize, + const M: usize, + I: Iterator, + > Iterator for ValuesExtrapolateIterator<'a, F, L, P, M, I> +{ + type Item = ([F; P], [F; P]); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_values, v_values)| { + let mut u = [F::ZERO; P]; + u[0..L].copy_from_slice(&u_values); + u[L..].copy_from_slice(&self.lagrange_table.eval(&u_values)); + let mut v = [F::ZERO; P]; + v[0..L].copy_from_slice(&v_values); + v[L..].copy_from_slice(&self.lagrange_table.eval(&v_values)); + (u, v) + }) + } +} + +/// Iterator returned by `ProverValues::eval_at_r`. 
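+/// A sketch of the mapping (not compiled as a doc-test): each length-L chunk of
+/// u and v values is treated as the y-values of a degree-(L-1) polynomial, and
+/// the iterator yields those polynomials evaluated at the challenge point r:
+///
+/// ```ignore
+/// let (u_at_r, v_at_r) = (lagrange_table.eval(&u_chunk)[0], lagrange_table.eval(&v_chunk)[0]);
+/// ```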
+struct ValuesEvalAtRIterator< + 'a, + F: PrimeField, + const L: usize, + I: Iterator, +> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl<'a, F: PrimeField, const L: usize, I: Iterator> Iterator + for ValuesEvalAtRIterator<'a, F, L, I> +{ + type Item = (F, F); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_values, v_values)| { + ( + self.lagrange_table.eval(&u_values)[0], + self.lagrange_table.eval(&v_values)[0], + ) + }) + } +} + +impl> ProverLagrangeInput + for ProverValues +{ + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + debug_assert_eq!(L + M, P); + ValuesExtrapolateIterator { + input: self.0, + lagrange_table, + } + } + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + ValuesEvalAtRIterator { + input: self.0, + lagrange_table, + } + } +} + /// This struct sets up the parameter for the proof generation /// and provides several functions to generate zero knowledge proofs. /// @@ -119,40 +333,47 @@ impl ProofGenerat pub const PROOF_LENGTH: usize = P; pub const LAGRANGE_LENGTH: usize = M; + pub fn compute_proof_from_uv(uv: J, lagrange_table: &LagrangeTable) -> [F; P] + where + J: Iterator, + J::Item: Borrow<([F; L], [F; L])>, + { + Self::compute_proof(uv.map(|uv| { + let (u, v) = uv.borrow(); + let mut u_ex = [F::ZERO; P]; + let mut v_ex = [F::ZERO; P]; + u_ex[0..L].copy_from_slice(u); + v_ex[0..L].copy_from_slice(v); + u_ex[L..].copy_from_slice(&lagrange_table.eval(u)); + v_ex[L..].copy_from_slice(&lagrange_table.eval(v)); + (u_ex, v_ex) + })) + } + /// /// Distributed Zero Knowledge Proofs algorithm drawn from /// `https://eprint.iacr.org/2023/909.pdf` - pub fn compute_proof(uv_iterator: J, lagrange_table: &LagrangeTable) -> [F; P] + pub fn compute_proof(pq_iterator: J) -> [F; P] where J: Iterator, - J::Item: Borrow<([F; L], [F; L])>, + J::Item: Borrow<([F; P], [F; P])>, { - let mut proof = [F::ZERO; P]; - for uv_polynomial in uv_iterator { - for (i, proof_part) in proof.iter_mut().enumerate().take(L) { - *proof_part += uv_polynomial.borrow().0[i] * uv_polynomial.borrow().1[i]; - } - let p_extrapolated = lagrange_table.eval(&uv_polynomial.borrow().0); - let q_extrapolated = lagrange_table.eval(&uv_polynomial.borrow().1); - - for (i, (x, y)) in - zip(p_extrapolated.into_iter(), q_extrapolated.into_iter()).enumerate() - { - proof[L + i] += x * y; - } - } - proof + pq_iterator + .fold( + ::AccumulatorArray::
<P>
::new(), + |mut proof, pq| { + proof.multiply_accumulate(&pq.borrow().0, &pq.borrow().1); + proof + }, + ) + .take() } - fn gen_challenge_and_recurse( + fn gen_challenge_and_recurse, const N: usize>( proof_left: &[F; P], proof_right: &[F; P], - uv_iterator: J, - ) -> UVValues - where - J: Iterator, - J::Item: Borrow<([F; L], [F; L])>, - { + uv_iterator: I, + ) -> UVValues { let r: F = hash_to_field( &compute_hash(proof_left), &compute_hash(proof_right), @@ -162,17 +383,8 @@ impl ProofGenerat let denominator = CanonicalLagrangeDenominator::::new(); let lagrange_table_r = LagrangeTable::::new(&denominator, &r); - // iter and interpolate at x coordinate r uv_iterator - .map(|polynomial| { - let (u_chunk, v_chunk) = polynomial.borrow(); - ( - // new u value - lagrange_table_r.eval(u_chunk)[0], - // new v value - lagrange_table_r.eval(v_chunk)[0], - ) - }) + .eval_at_r(&lagrange_table_r) .collect::>() } @@ -202,29 +414,23 @@ impl ProofGenerat proof_other_share } - /// This function is a helper function that computes the next proof - /// from an iterator over uv values - /// It also computes the next uv values + /// This function is a helper function that, given the computed proof, computes the shares of + /// the proof, the challenge, and the next uv values. /// /// It output `(uv values, share_of_proof_from_prover_left, my_proof_left_share)` /// where /// `share_of_proof_from_prover_left` from left has type `Vec<[F; P]>`, /// `my_proof_left_share` has type `Vec<[F; P]>`, - pub fn gen_artefacts_from_recursive_step( + pub fn gen_artefacts_from_recursive_step( ctx: &C, record_ids: &mut RecordIdRange, - lagrange_table: &LagrangeTable, - uv_iterator: J, + my_proof: [F; P], + uv_iterator: I, ) -> (UVValues, [F; P], [F; P]) where C: Context, - J: Iterator + Clone, - J::Item: Borrow<([F; L], [F; L])>, + I: ProverLagrangeInput, { - // generate next proof - // from iterator - let my_proof = Self::compute_proof(uv_iterator.clone(), lagrange_table); - // generate proof shares from prss let (share_of_proof_from_prover_left, my_proof_right_share) = Self::gen_proof_shares_from_prss(ctx, record_ids); @@ -254,20 +460,26 @@ mod test { use std::iter::zip; use futures::future::try_join; + use rand::Rng; + use super::*; use crate::{ ff::{Fp31, Fp61BitPrime, PrimeField, U128Conversions}, helpers::{Direction, Role}, protocol::{ - context::Context, - ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - prover::{ProofGenerator, SmallProofGenerator, UVValues}, + context::{ + dzkp_field::tests::reference_convert, + dzkp_validator::{MultiplicationInputsBlock, BIT_ARRAY_LEN}, + Context, + }, + ipa_prf::{ + malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + FirstProofGenerator, }, RecordId, RecordIdRange, }, seq_join::SeqJoin, - test_executor::run, + test_executor::{run, run_random}, test_fixture::{Runner, TestWorld}, }; @@ -316,7 +528,7 @@ mod test { let uv_1 = zip_chunks(U_1, V_1); // first iteration - let proof_1 = TestProofGenerator::compute_proof(uv_1.iter(), &lagrange_table); + let proof_1 = TestProofGenerator::compute_proof_from_uv(uv_1.iter(), &lagrange_table); assert_eq!( proof_1.iter().map(Fp31::as_u128).collect::>(), PROOF_1, @@ -335,12 +547,12 @@ mod test { let uv_2 = TestProofGenerator::gen_challenge_and_recurse( &proof_left_1, &proof_right_1, - uv_1.iter(), + ProverValues(uv_1.iter().copied()), ); assert_eq!(uv_2, zip_chunks(U_2, V_2)); // next iteration - let proof_2 = TestProofGenerator::compute_proof(uv_2.iter(), &lagrange_table); + let 
proof_2 = TestProofGenerator::compute_proof_from_uv(uv_2.iter(), &lagrange_table); assert_eq!( proof_2.iter().map(Fp31::as_u128).collect::>(), PROOF_2, @@ -359,7 +571,7 @@ mod test { let uv_3 = TestProofGenerator::gen_challenge_and_recurse::<_, 4>( &proof_left_2, &proof_right_2, - uv_2.iter(), + ProverValues(uv_2.iter().copied()), ); assert_eq!(uv_3, zip_chunks(U_3, V_3)); @@ -369,7 +581,8 @@ mod test { ); // final iteration - let proof_3 = TestProofGenerator::compute_proof(masked_uv_3.iter(), &lagrange_table); + let proof_3 = + TestProofGenerator::compute_proof_from_uv(masked_uv_3.iter(), &lagrange_table); assert_eq!( proof_3.iter().map(Fp31::as_u128).collect::>(), PROOF_3, @@ -397,11 +610,12 @@ mod test { // first iteration let world = TestWorld::default(); let mut record_ids = RecordIdRange::ALL; + let proof = TestProofGenerator::compute_proof_from_uv(uv_1.iter(), &lagrange_table); let (uv_values, _, _) = TestProofGenerator::gen_artefacts_from_recursive_step::<_, _, 4>( &world.contexts()[0], &mut record_ids, - &lagrange_table, - uv_1.iter(), + proof, + ProverValues(uv_1.iter().copied()), ); assert_eq!(7, uv_values.len()); @@ -436,14 +650,14 @@ mod test { >::from(denominator); // compute proof - let proof = SmallProofGenerator::compute_proof(uv_before.iter(), &lagrange_table); + let proof = SmallProofGenerator::compute_proof_from_uv(uv_before.iter(), &lagrange_table); assert_eq!(proof.len(), SmallProofGenerator::PROOF_LENGTH); let uv_after = SmallProofGenerator::gen_challenge_and_recurse::<_, 8>( &proof, &proof, - uv_before.iter(), + ProverValues(uv_before.iter().copied()), ); assert_eq!( @@ -472,14 +686,14 @@ mod test { >::from(denominator); // compute proof - let proof = LargeProofGenerator::compute_proof(uv_before.iter(), &lagrange_table); + let proof = LargeProofGenerator::compute_proof_from_uv(uv_before.iter(), &lagrange_table); assert_eq!(proof.len(), LargeProofGenerator::PROOF_LENGTH); let uv_after = LargeProofGenerator::gen_challenge_and_recurse::<_, 8>( &proof, &proof, - uv_before.iter(), + ProverValues(uv_before.iter().copied()), ); assert_eq!( @@ -594,4 +808,51 @@ mod test { assert_two_part_secret_sharing(PROOF_2, h1_proof_right, h3_proof_left); assert_two_part_secret_sharing(PROOF_3, h2_proof_right, h1_proof_left); } + + #[test] + fn prover_table_indices_equivalence() { + run_random(|mut rng| async move { + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; + const FLL: usize = FirstProofGenerator::LAGRANGE_LENGTH; + + let block = rng.gen::(); + + // Test equivalence for extrapolate_y_values + let denominator = CanonicalLagrangeDenominator::new(); + let lagrange_table = LagrangeTable::from(denominator); + + assert!(ProverTableIndices(block.table_indices_prover().into_iter()) + .extrapolate_y_values::(&lagrange_table) + .eq(ProverValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + })) + .extrapolate_y_values::(&lagrange_table))); + + // Test equivalence for eval_at_r + let denominator = CanonicalLagrangeDenominator::new(); + let r = rng.gen(); + let lagrange_table_r = LagrangeTable::new(&denominator, &r); + + assert!(ProverTableIndices(block.table_indices_prover().into_iter()) + .eval_at_r(&lagrange_table_r) + .eq(ProverValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + })) + .eval_at_r(&lagrange_table_r))); + 
}); + } } diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs index e62465b73..0038d4b85 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs @@ -1,9 +1,18 @@ -use std::iter::{self}; +use std::{ + array, + iter::{self}, +}; use crate::{ ff::PrimeField, - protocol::ipa_prf::malicious_security::lagrange::{ - CanonicalLagrangeDenominator, LagrangeTable, + protocol::{ + context::{ + dzkp_field::UVTable, + dzkp_validator::{MAX_PROOF_RECURSION, MIN_PROOF_RECURSION}, + }, + ipa_prf::malicious_security::{ + CanonicalLagrangeDenominator, LagrangeTable, FIRST_RECURSION_FACTOR as FRF, + }, }, utils::arraychunks::ArrayChunkIterator, }; @@ -96,35 +105,30 @@ pub fn compute_final_sum_share(zk } /// This function compresses the `u_or_v` values and returns the next `u_or_v` values. -fn recurse_u_or_v<'a, F: PrimeField, J, const L: usize>( - u_or_v: J, +fn recurse_u_or_v<'a, F: PrimeField, const L: usize>( + u_or_v: impl Iterator + 'a, lagrange_table: &'a LagrangeTable, -) -> impl Iterator + 'a -where - J: Iterator + 'a, -{ - u_or_v - .chunk_array::() - .map(|x| lagrange_table.eval(&x)[0]) +) -> impl Iterator + 'a { + VerifierValues(u_or_v.chunk_array::()).eval_at_r(lagrange_table) } /// This function recursively compresses the `u_or_v` values. -/// The recursion factor (or compression) of the first recursion is `L_FIRST` -/// The recursion factor of all following recursions is `L`. -pub fn recursively_compute_final_check( - u_or_v: J, +/// +/// The recursion factor (or compression) of the first recursion is fixed at +/// `FIRST_RECURSION_FACTOR` (`FRF`). The recursion factor of all following recursions +/// is `L`. +pub fn recursively_compute_final_check( + input: impl VerifierLagrangeInput, challenges: &[F], p_or_q_0: F, -) -> F -where - J: Iterator, -{ +) -> F { + // This function requires MIN_PROOF_RECURSION be at least 2. + assert!(challenges.len() >= MIN_PROOF_RECURSION && challenges.len() <= MAX_PROOF_RECURSION); let recursions_after_first = challenges.len() - 1; // compute Lagrange tables - let denominator_p_or_q_first = CanonicalLagrangeDenominator::::new(); - let table_first = - LagrangeTable::::new(&denominator_p_or_q_first, &challenges[0]); + let denominator_p_or_q_first = CanonicalLagrangeDenominator::::new(); + let table_first = LagrangeTable::::new(&denominator_p_or_q_first, &challenges[0]); let denominator_p_or_q = CanonicalLagrangeDenominator::::new(); let tables = challenges[1..] .iter() @@ -133,11 +137,10 @@ where // generate & evaluate recursive streams // to compute last array - let mut iterator: Box> = - Box::new(recurse_u_or_v::<_, _, L_FIRST>(u_or_v, &table_first)); + let mut iterator: Box> = Box::new(input.eval_at_r(&table_first)); // all following recursion except last one for lagrange_table in tables.iter().take(recursions_after_first - 1) { - iterator = Box::new(recurse_u_or_v::<_, _, L>(iterator, lagrange_table)); + iterator = Box::new(recurse_u_or_v(iterator, lagrange_table)); } let last_u_or_v_values = iterator.collect::>(); // Make sure there are less than L last u or v values. The prover is expected to continue @@ -162,18 +165,127 @@ where tables.last().unwrap().eval(&last_array)[0] } +/// Trait for inputs to Lagrange interpolation on the verifier. +/// +/// Lagrange interpolation is used in the verifier to evaluate polynomials at the +/// randomly chosen challenge point _r_. 
+/// +/// There are two implementations of this trait: `VerifierTableIndices`, and +/// `VerifierValues`. `VerifierTableIndices` is used for the input to the first proof. +/// Each set of 4 _u_ or _v_ values input to the first proof has one of eight possible +/// values, determined by the values of the 3 associated multiplication intermediates. +/// The `VerifierTableIndices` implementation uses a lookup table containing the output +/// of the Lagrange interpolation for each of these eight possible values. The +/// `VerifierValues` implementation, which represents actual _u_ and _v_ values, is used +/// by the remaining recursive proofs. +/// +/// There is a similar trait `ProverLagrangeInput` in `prover.rs`. The difference is +/// that the prover operates on _u_ and _v_ values simultaneously (i.e. iterators of +/// tuples). The verifier operates on only one of _u_ or _v_ at a time. +pub trait VerifierLagrangeInput { + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; +} + +/// Implementation of `VerifierLagrangeInput` for table indices, used for the first proof. +pub struct VerifierTableIndices<'a, F: PrimeField, I: Iterator> { + pub input: I, + pub table: &'a UVTable, +} + +/// Iterator returned by `VerifierTableIndices::eval_at_r`. +struct TableIndicesIterator> { + input: I, + table: [F; 8], +} + +impl> VerifierLagrangeInput + for VerifierTableIndices<'_, F, I> +{ + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.input, + table: array::from_fn(|i| lagrange_table.eval(&self.table[i])[0]), + } + } +} + +impl> Iterator for TableIndicesIterator { + type Item = F; + + fn next(&mut self) -> Option { + self.input + .next() + .map(|index| self.table[usize::from(index)]) + } +} + +/// Implementation of `VerifierLagrangeInput` for _u_ and _v_ values, used for +/// subsequent recursive proofs. +pub struct VerifierValues>(pub I); + +/// Iterator returned by `ProverValues::eval_at_r`. 
+struct ValuesEvalAtRIterator<'a, F: PrimeField, const L: usize, I: Iterator> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl<'a, F: PrimeField, const L: usize, I: Iterator> Iterator + for ValuesEvalAtRIterator<'a, F, L, I> +{ + type Item = F; + + fn next(&mut self) -> Option { + self.input + .next() + .map(|values| self.lagrange_table.eval(&values)[0]) + } +} + +impl> VerifierLagrangeInput + for VerifierValues +{ + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + ValuesEvalAtRIterator { + input: self.0, + lagrange_table, + } + } +} + #[cfg(all(test, unit_test))] mod test { + use rand::Rng; + + use super::*; use crate::{ ff::{Fp31, U128Conversions}, - protocol::ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - verifier::{ - compute_g_differences, compute_sum_share, interpolate_at_r, recurse_u_or_v, - recursively_compute_final_check, + protocol::{ + context::{ + dzkp_field::{tests::reference_convert, TABLE_U, TABLE_V}, + dzkp_validator::{MultiplicationInputsBlock, BIT_ARRAY_LEN}, }, + ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, }, secret_sharing::SharedValue, + test_executor::run_random, + utils::arraychunks::ArrayChunkIterator, }; fn to_field(a: &[u128]) -> Vec { @@ -292,12 +404,11 @@ mod test { let tables: [LagrangeTable; 3] = CHALLENGES .map(|r| LagrangeTable::new(&denominator_p_or_q, &Fp31::try_from(r).unwrap())); - let u_or_v_2 = recurse_u_or_v::<_, _, 4>(to_field(&U_1).into_iter(), &tables[0]) - .collect::>(); + let u_or_v_2 = + recurse_u_or_v::<_, 4>(to_field(&U_1).into_iter(), &tables[0]).collect::>(); assert_eq!(u_or_v_2, to_field(&U_2)); - let u_or_v_3 = - recurse_u_or_v::<_, _, 4>(u_or_v_2.into_iter(), &tables[1]).collect::>(); + let u_or_v_3 = recurse_u_or_v::<_, 4>(u_or_v_2.into_iter(), &tables[1]).collect::>(); assert_eq!(u_or_v_3, to_field(&U_3[..2])); @@ -309,12 +420,12 @@ mod test { ]; let p_final = - recurse_u_or_v::<_, _, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); + recurse_u_or_v::<_, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); assert_eq!(p_final[0].as_u128(), EXPECTED_P_FINAL); - let p_final_another_way = recursively_compute_final_check::( - to_field(&U_1).into_iter(), + let p_final_another_way = recursively_compute_final_check::<_, 4>( + VerifierValues(to_field(&U_1).into_iter().chunk_array::<4>()), &CHALLENGES .map(|x| Fp31::try_from(x).unwrap()) .into_iter() @@ -396,15 +507,15 @@ mod test { // final iteration let p_final = - recurse_u_or_v::<_, _, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); + recurse_u_or_v::<_, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); assert_eq!(p_final[0].as_u128(), EXPECTED_Q_FINAL); // uv values in input format let v_1 = to_field(&V_1); - let q_final_another_way = recursively_compute_final_check::( - v_1.into_iter(), + let q_final_another_way = recursively_compute_final_check::( + VerifierValues(v_1.into_iter().chunk_array::<4>()), &CHALLENGES .map(|x| Fp31::try_from(x).unwrap()) .into_iter() @@ -447,4 +558,59 @@ mod test { assert_eq!(Fp31::ZERO, g_differences[0]); } + + #[test] + fn verifier_table_indices_equivalence() { + run_random(|mut rng| async move { + let block = rng.gen::(); + + let denominator = CanonicalLagrangeDenominator::new(); + let r = rng.gen(); + let lagrange_table_r = LagrangeTable::new(&denominator, &r); + + // Test equivalence for _u_ values + assert!(VerifierTableIndices { + input: block + .rotate_right() 
+ .table_indices_from_right_prover() + .into_iter(), + table: &TABLE_U, + } + .eval_at_r(&lagrange_table_r) + .eq(VerifierValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + .0 + })) + .eval_at_r(&lagrange_table_r))); + + // Test equivalence for _v_ values + assert!(VerifierTableIndices { + input: block + .rotate_left() + .table_indices_from_left_prover() + .into_iter(), + table: &TABLE_V, + } + .eval_at_r(&lagrange_table_r) + .eq(VerifierValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + .1 + })) + .eval_at_r(&lagrange_table_r))); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 0ffe2adbc..63dc34a6f 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -59,9 +59,10 @@ pub(crate) mod shuffle; pub(crate) mod step; pub mod validation_protocol; -pub type FirstProofGenerator = malicious_security::prover::SmallProofGenerator; -pub type CompressedProofGenerator = malicious_security::prover::SmallProofGenerator; - +pub use malicious_security::{ + CompressedProofGenerator, FirstProofGenerator, LagrangeTable, ProverTableIndices, + VerifierTableIndices, +}; pub use shuffle::Shuffle; /// Match key type diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs index 0a658614e..e496f97e4 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs @@ -7,13 +7,13 @@ use crate::{ ff::{Fp61BitPrime, Serializable}, helpers::{Direction, MpcMessage, TotalRecords}, protocol::{ - context::{ - dzkp_field::{UVTupleBlock, BLOCK_SIZE}, - dzkp_validator::MAX_PROOF_RECURSION, - Context, - }, + context::{dzkp_validator::MAX_PROOF_RECURSION, Context}, ipa_prf::{ - malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + malicious_security::{ + lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + prover::{ProverLagrangeInput, ProverValues}, + FIRST_RECURSION_FACTOR as FRF, + }, CompressedProofGenerator, FirstProofGenerator, }, prss::SharedRandomness, @@ -81,16 +81,14 @@ impl ProofBatch { /// ## Panics /// Panics when the function fails to set the masks without overwritting `u` and `v` values. /// This only happens when there is an issue in the recursion. 
- pub fn generate( + pub fn generate( ctx: &C, mut prss_record_ids: RecordIdRange, - uv_tuple_inputs: I, + uv_inputs: impl ProverLagrangeInput + Clone, ) -> (Self, Self, Fp61BitPrime, Fp61BitPrime) where C: Context, - I: Iterator> + Clone, { - const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; const FLL: usize = FirstProofGenerator::LAGRANGE_LENGTH; const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; const CLL: usize = CompressedProofGenerator::LAGRANGE_LENGTH; @@ -100,13 +98,19 @@ impl ProofBatch { let first_denominator = CanonicalLagrangeDenominator::::new(); let first_lagrange_table = LagrangeTable::::from(first_denominator); + let first_proof = FirstProofGenerator::compute_proof( + uv_inputs + .clone() + .extrapolate_y_values(&first_lagrange_table), + ); + // generate first proof from input iterator let (mut uv_values, first_proof_from_left, my_first_proof_left_share) = FirstProofGenerator::gen_artefacts_from_recursive_step( ctx, &mut prss_record_ids, - &first_lagrange_table, - ProofBatch::polynomials_from_inputs(uv_tuple_inputs), + first_proof, + uv_inputs, ); // `MAX_PROOF_RECURSION - 2` because: @@ -156,12 +160,14 @@ impl ProofBatch { did_set_masks = true; uv_values.set_masks(my_p_mask, my_q_mask).unwrap(); } + let my_proof = + CompressedProofGenerator::compute_proof_from_uv(uv_values.iter(), &lagrange_table); let (uv_values_new, share_of_proof_from_prover_left, my_proof_left_share) = CompressedProofGenerator::gen_artefacts_from_recursive_step( ctx, &mut prss_record_ids, - &lagrange_table, - uv_values.iter(), + my_proof, + ProverValues(uv_values.iter().copied()), ); shares_of_proofs_from_prover_left.push(share_of_proof_from_prover_left); my_proofs_left_shares.push(my_proof_left_share); @@ -225,42 +231,6 @@ impl ProofBatch { .take(length) .collect()) } - - /// This is a helper function that allows to split a `UVTupleInputs` - /// which consists of arrays of size `BLOCK_SIZE` - /// into an iterator over arrays of size `LargeProofGenerator::RECURSION_FACTOR`. - /// - /// ## Panics - /// Panics when `unwrap` panics, i.e. `try_from` fails to convert a slice to an array. 
- pub fn polynomials_from_inputs( - inputs: I, - ) -> impl Iterator< - Item = ( - [Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR], - [Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR], - ), - > + Clone - where - I: Iterator> + Clone, - { - assert_eq!(BLOCK_SIZE % FirstProofGenerator::RECURSION_FACTOR, 0); - inputs.flat_map(|(u_block, v_block)| { - (0usize..(BLOCK_SIZE / FirstProofGenerator::RECURSION_FACTOR)).map(move |i| { - ( - <[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>::try_from( - &u_block[i * FirstProofGenerator::RECURSION_FACTOR - ..(i + 1) * FirstProofGenerator::RECURSION_FACTOR], - ) - .unwrap(), - <[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>::try_from( - &v_block[i * FirstProofGenerator::RECURSION_FACTOR - ..(i + 1) * FirstProofGenerator::RECURSION_FACTOR], - ) - .unwrap(), - ) - }) - }) - } } const ARRAY_LEN: usize = FirstProofGenerator::PROOF_LENGTH @@ -297,19 +267,25 @@ impl MpcMessage for Box {} #[cfg(all(test, unit_test))] mod test { - use rand::{thread_rng, Rng}; + use std::iter::repeat_with; + + use rand::Rng; use crate::{ - ff::{Fp61BitPrime, U128Conversions}, protocol::{ - context::{dzkp_field::BLOCK_SIZE, Context}, - ipa_prf::validation_protocol::{ - proof_generation::ProofBatch, - validation::{test::simple_proof_check, BatchToVerify}, + context::Context, + ipa_prf::{ + malicious_security::{ + prover::{ProverValues, UVValues}, + FIRST_RECURSION_FACTOR, + }, + validation_protocol::{ + proof_generation::ProofBatch, + validation::{test::simple_proof_check, BatchToVerify}, + }, }, RecordId, RecordIdRange, }, - secret_sharing::replicated::ReplicatedSecretSharing, test_executor::run, test_fixture::{Runner, TestWorld}, }; @@ -319,37 +295,16 @@ mod test { run(|| async move { let world = TestWorld::default(); - let mut rng = thread_rng(); + let mut rng = world.rng(); - // each helper samples a random value h - // which is later used to generate distinct values across helpers - let h = Fp61BitPrime::truncate_from(rng.gen_range(0u128..100)); + let uv_values = repeat_with(|| (rng.gen(), rng.gen())) + .take(100) + .collect::>(); + let uv_values_iter = uv_values.iter().copied(); + let uv_values_iter_ref = &uv_values_iter; let result = world - .semi_honest(h, |ctx, h| async move { - let h = Fp61BitPrime::truncate_from(h.left().as_u128() % 100); - // generate blocks of UV values - // generate u values as (1h,2h,3h,....,10h*BlockSize) split into Blocksize chunks - // where BlockSize = 32 - // v values are identical to u - let uv_tuple_vec = (0usize..100) - .map(|i| { - ( - (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1)) - .map(|j| { - Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h - }) - .collect::<[Fp61BitPrime; BLOCK_SIZE]>(), - (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1)) - .map(|j| { - Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h - }) - .collect::<[Fp61BitPrime; BLOCK_SIZE]>(), - ) - }) - .collect::>(); - - // generate and output VerifierBatch together with h value + .semi_honest((), |ctx, ()| async move { let ( my_batch_left_shares, shares_of_batch_from_left_prover, @@ -358,10 +313,10 @@ mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - uv_tuple_vec.into_iter(), + ProverValues(uv_values_iter_ref.clone()), ); - let batch_to_verify = BatchToVerify::generate_batch_to_verify( + BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), RecordId::FIRST, my_batch_left_shares, @@ -369,21 +324,18 @@ mod test { p_mask_from_right_prover, q_mask_from_left_prover, ) - .await; - - // generate and 
output VerifierBatch together with h value - (h, batch_to_verify) + .await }) .await; // proof from first party - simple_proof_check(result[0].0, &result[2].1, &result[1].1); + simple_proof_check(uv_values.iter(), &result[2], &result[1]); // proof from second party - simple_proof_check(result[1].0, &result[0].1, &result[2].1); + simple_proof_check(uv_values.iter(), &result[0], &result[2]); // proof from third party - simple_proof_check(result[2].0, &result[1].1, &result[0].1); + simple_proof_check(uv_values.iter(), &result[1], &result[0]); }); } } diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs index 3f656b3eb..e4022a60c 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs @@ -20,8 +20,11 @@ use crate::{ dzkp_validator::MAX_PROOF_RECURSION, step::DzkpProofVerifyStep as Step, Context, }, ipa_prf::{ - malicious_security::verifier::{ - compute_g_differences, recursively_compute_final_check, + malicious_security::{ + verifier::{ + compute_g_differences, recursively_compute_final_check, VerifierLagrangeInput, + }, + FIRST_RECURSION_FACTOR as FRF, }, validation_protocol::proof_generation::ProofBatch, CompressedProofGenerator, FirstProofGenerator, @@ -105,7 +108,6 @@ impl BatchToVerify { where C: Context, { - const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; // exclude for first proof @@ -172,21 +174,20 @@ impl BatchToVerify { v_from_left_prover: V, // Prover P_i and verifier P_{i+1} both compute `v` and `q(x)` ) -> (Fp61BitPrime, Fp61BitPrime) where - U: Iterator + Send, - V: Iterator + Send, + U: VerifierLagrangeInput + Send, + V: VerifierLagrangeInput + Send, { - const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; // compute p_r - let p_r_right_prover = recursively_compute_final_check::<_, _, FRF, CRF>( - u_from_right_prover.into_iter(), + let p_r_right_prover = recursively_compute_final_check::<_, CRF>( + u_from_right_prover, challenges_for_right_prover, self.p_mask_from_right_prover, ); // compute q_r - let q_r_left_prover = recursively_compute_final_check::<_, _, FRF, CRF>( - v_from_left_prover.into_iter(), + let q_r_left_prover = recursively_compute_final_check::<_, CRF>( + v_from_left_prover, challenges_for_left_prover, self.q_mask_from_left_prover, ); @@ -242,7 +243,6 @@ impl BatchToVerify { where C: Context, { - const FRF: usize = FirstProofGenerator::RECURSION_FACTOR; const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; const FPL: usize = FirstProofGenerator::PROOF_LENGTH; @@ -445,21 +445,22 @@ impl MpcMessage for ProofDiff {} #[cfg(all(test, unit_test))] pub mod test { + use std::iter::repeat_with; + use futures_util::future::try_join; - use rand::{thread_rng, Rng}; + use rand::Rng; use crate::{ - ff::{Fp61BitPrime, U128Conversions}, + ff::Fp61BitPrime, helpers::Direction, protocol::{ - context::{ - dzkp_field::{UVTupleBlock, BLOCK_SIZE}, - Context, - }, + context::Context, ipa_prf::{ malicious_security::{ lagrange::CanonicalLagrangeDenominator, - verifier::{compute_sum_share, interpolate_at_r}, + prover::{ProverValues, UVValues}, + verifier::{compute_sum_share, interpolate_at_r, VerifierValues}, + FIRST_RECURSION_FACTOR as FRF, }, validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, CompressedProofGenerator, 
FirstProofGenerator, @@ -467,16 +468,20 @@ pub mod test { prss::SharedRandomness, RecordId, RecordIdRange, }, - secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue}, + secret_sharing::SharedValue, test_executor::run, test_fixture::{Runner, TestWorld}, + utils::arraychunks::ArrayChunkIterator, }; - /// ## Panics - /// When proof is not generated correctly. - // todo: deprecate once validation protocol is implemented - pub fn simple_proof_check( - h: Fp61BitPrime, + // This is a helper for a test in `proof_generation.rs`, but is located here so it + // can access the internals of `BatchToVerify`. + // + // Possibly this (and the associated test) can be removed now that the validation + // protocol is implemented? (There was an old todo to that effect.) But it seems + // useful to keep around as a unit test. + pub(in crate::protocol::ipa_prf::validation_protocol) fn simple_proof_check<'a>( + uv_values: impl Iterator, left_verifier: &BatchToVerify, right_verifier: &BatchToVerify, ) { @@ -512,36 +517,12 @@ pub mod test { // check first proof, // compute simple proof without lagrange interpolated points - let simple_proof = { - let block_to_polynomial = BLOCK_SIZE / FirstProofGenerator::RECURSION_FACTOR; - let simple_proof_uv = (0usize..100 * block_to_polynomial) - .map(|i| { - ( - (FirstProofGenerator::RECURSION_FACTOR * i - ..FirstProofGenerator::RECURSION_FACTOR * (i + 1)) - .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h) - .collect::<[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>(), - (FirstProofGenerator::RECURSION_FACTOR * i - ..FirstProofGenerator::RECURSION_FACTOR * (i + 1)) - .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h) - .collect::<[Fp61BitPrime; FirstProofGenerator::RECURSION_FACTOR]>(), - ) - }) - .collect::>(); - - simple_proof_uv.iter().fold( - [Fp61BitPrime::ZERO; FirstProofGenerator::RECURSION_FACTOR], - |mut acc, (left, right)| { - for i in 0..FirstProofGenerator::RECURSION_FACTOR { - acc[i] += left[i] * right[i]; - } - acc - }, - ) - }; + let simple_proof = uv_values.fold([Fp61BitPrime::ZERO; FRF], |mut acc, (left, right)| { + for i in 0..FRF { + acc[i] += left[i] * right[i]; + } + acc + }); // reconstruct computed proof // by adding shares left and right @@ -554,13 +535,7 @@ pub mod test { // check for consistency // only check first R::USIZE field elements - assert_eq!( - (h.as_u128(), simple_proof.to_vec()), - ( - h.as_u128(), - proof_computed[0..FirstProofGenerator::RECURSION_FACTOR].to_vec() - ) - ); + assert_eq!(simple_proof.to_vec(), proof_computed[0..FRF].to_vec()); } #[test] @@ -568,39 +543,17 @@ pub mod test { run(|| async move { let world = TestWorld::default(); - let mut rng = thread_rng(); + let mut rng = world.rng(); - // each helper samples a random value h - // which is later used to generate distinct values across helpers - let h = Fp61BitPrime::truncate_from(rng.gen_range(0u128..100)); + let uv_values = repeat_with(|| (rng.gen(), rng.gen())) + .take(100) + .collect::>(); + let uv_values_iter = uv_values.iter().copied(); + let uv_values_iter_ref = &uv_values_iter; let [(helper_1_left, helper_1_right), (helper_2_left, helper_2_right), (helper_3_left, helper_3_right)] = world - .semi_honest(h, |ctx, h| async move { - let h = Fp61BitPrime::truncate_from(h.left().as_u128() % 100); - // generate blocks of UV values - // generate u values as (1h,2h,3h,....,10h*BlockSize) split into Blocksize chunks - // where BlockSize = 32 - // v values are identical to u - let uv_tuple_vec = (0usize..100) - 
.map(|i| { - ( - (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1)) - .map(|j| { - Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) - * h - }) - .collect::<[Fp61BitPrime; BLOCK_SIZE]>(), - (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1)) - .map(|j| { - Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) - * h - }) - .collect::<[Fp61BitPrime; BLOCK_SIZE]>(), - ) - }) - .collect::>(); - + .semi_honest((), |ctx, ()| async move { // generate and output VerifierBatch together with h value let ( my_batch_left_shares, @@ -610,7 +563,7 @@ pub mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - uv_tuple_vec.into_iter(), + ProverValues(uv_values_iter_ref.clone()), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( @@ -644,31 +597,32 @@ pub mod test { /// Prover `P_i` and verifier `P_{i+1}` both generate `u` /// Prover `P_i` and verifier `P_{i-1}` both generate `v` /// - /// outputs `(my_u_and_v, u_from_right_prover, v_from_left_prover)` + /// outputs `(my_u_and_v, sum_of_uv, u_from_right_prover, v_from_left_prover)` + #[allow(clippy::type_complexity)] fn generate_u_v( ctx: &C, len: usize, ) -> ( - Vec>, + Vec<([Fp61BitPrime; FRF], [Fp61BitPrime; FRF])>, Fp61BitPrime, Vec, Vec, ) { // outputs - let mut vec_u_from_right_prover = Vec::::with_capacity(BLOCK_SIZE * len); - let mut vec_v_from_left_prover = Vec::::with_capacity(BLOCK_SIZE * len); + let mut vec_u_from_right_prover = Vec::::with_capacity(FRF * len); + let mut vec_v_from_left_prover = Vec::::with_capacity(FRF * len); let mut vec_my_u_and_v = - Vec::<([Fp61BitPrime; BLOCK_SIZE], [Fp61BitPrime; BLOCK_SIZE])>::with_capacity(len); + Vec::<([Fp61BitPrime; FRF], [Fp61BitPrime; FRF])>::with_capacity(len); let mut sum_of_uv = Fp61BitPrime::ZERO; // generate random u, v values using PRSS let mut counter = RecordId::FIRST; for _ in 0..len { - let mut my_u_array = [Fp61BitPrime::ZERO; BLOCK_SIZE]; - let mut my_v_array = [Fp61BitPrime::ZERO; BLOCK_SIZE]; - for i in 0..BLOCK_SIZE { + let mut my_u_array = [Fp61BitPrime::ZERO; FRF]; + let mut my_v_array = [Fp61BitPrime::ZERO; FRF]; + for i in 0..FRF { let (my_u, u_from_right_prover) = ctx.prss().generate_fields(counter); counter += 1; let (v_from_left_prover, my_v) = ctx.prss().generate_fields(counter); @@ -725,7 +679,7 @@ pub mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - vec_my_u_and_v.into_iter(), + ProverValues(vec_my_u_and_v.into_iter()), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( @@ -831,7 +785,7 @@ pub mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - vec_my_u_and_v.into_iter(), + ProverValues(vec_my_u_and_v.into_iter()), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( @@ -858,8 +812,10 @@ pub mod test { let (p, q) = batch_to_verify.compute_p_and_q_r( &challenges_for_left_prover, &challenges_for_right_prover, - vec_u_from_right_prover.into_iter(), - vec_v_from_left_prover.into_iter(), + VerifierValues( + vec_u_from_right_prover.into_iter().chunk_array::(), + ), + VerifierValues(vec_v_from_left_prover.into_iter().chunk_array::()), ); let p_times_q = @@ -921,7 +877,7 @@ pub mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - vec_my_u_and_v.into_iter(), + ProverValues(vec_my_u_and_v.into_iter()), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( @@ -961,8 +917,8 @@ pub mod test { let (p, q) = batch_to_verify.compute_p_and_q_r( &challenges_for_left_prover, &challenges_for_right_prover, - 
vec_u_from_right_prover.into_iter(), - vec_v_from_left_prover.into_iter(), + VerifierValues(vec_u_from_right_prover.into_iter().chunk_array::()), + VerifierValues(vec_v_from_left_prover.into_iter().chunk_array::()), ); batch_to_verify @@ -989,14 +945,13 @@ pub mod test { // Test a batch that exercises the case where `uv_values.len() == 1` but `did_set_masks = // false` in `ProofBatch::generate`. // - // We divide by `BLOCK_SIZE` here because `generate_u_v`, which is used by + // We divide by `FRF` here because `generate_u_v`, which is used by // `verify_batch` to generate test data, generates `len` chunks of u/v values of - // length `BLOCK_SIZE`. We want the input u/v values to compress to exactly one + // length `FRF`. We want the input u/v values to compress to exactly one // u/v pair after some number of proof steps. - let num_inputs = FirstProofGenerator::RECURSION_FACTOR - * CompressedProofGenerator::RECURSION_FACTOR - * CompressedProofGenerator::RECURSION_FACTOR; - assert!(num_inputs % BLOCK_SIZE == 0); - verify_batch(num_inputs / BLOCK_SIZE); + let num_inputs = + FirstProofGenerator::RECURSION_FACTOR * CompressedProofGenerator::RECURSION_FACTOR; + assert!(num_inputs % FRF == 0); + verify_batch(num_inputs / FRF); } } diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index 3a8357f02..b020ddb04 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -100,7 +100,7 @@ pub trait TransposeFrom { // // From Hacker's Delight (2nd edition), Figure 7-6. // -// There are comments on `dzkp_field::convert_values_table_indices`, which implements a +// There are comments on `dzkp_field::bits_to_table_indices`, which implements a // similar transformation, that may help to understand how this works. #[inline] pub fn transpose_8x8>(x: B) -> [u8; 8] { @@ -125,7 +125,7 @@ pub fn transpose_8x8>(x: B) -> [u8; 8] { // // Loosely based on Hacker's Delight (2nd edition), Figure 7-6. // -// There are comments on `dzkp_field::convert_values_table_indices`, which implements a +// There are comments on `dzkp_field::bits_to_table_indices`, which implements a // similar transformation, that may help to understand how this works. #[inline] pub fn transpose_16x16(src: &[u8; 32]) -> [u8; 32] {