diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 9a7735e9e..552fd8b49 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -201,7 +201,7 @@ jobs: - name: Add Rust sources run: rustup component add rust-src - name: Run tests with sanitizer - run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std -p ipa-core --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate ${{ matrix.features }}" + run: RUSTFLAGS="-Z sanitizer=${{ matrix.sanitizer }} -Z sanitizer-memory-track-origins" cargo test -Z build-std -p ipa-core --all-targets --target $TARGET --no-default-features --features "cli web-app real-world-infra test-fixture compact-gate ${{ matrix.features }}" miri: runs-on: ubuntu-latest diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 060090172..3a8c866c0 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -143,6 +143,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sha2 = "0.10" shuttle-crate = { package = "shuttle", version = "0.6.1", optional = true } +subtle = "2.6" thiserror = "1.0" time = { version = "0.3", optional = true } tokio = { version = "1.35", features = ["fs", "rt", "rt-multi-thread", "macros"] } diff --git a/ipa-core/benches/dzkp_convert_prover.rs b/ipa-core/benches/dzkp_convert_prover.rs index 57b557735..c8f820bab 100644 --- a/ipa-core/benches/dzkp_convert_prover.rs +++ b/ipa-core/benches/dzkp_convert_prover.rs @@ -1,41 +1,15 @@ -//! Benchmark for the convert_prover function in dzkp_field.rs. +//! Benchmark for the table_indices_prover function in dzkp_field.rs. 
use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -use ipa_core::{ - ff::Fp61BitPrime, - protocol::context::{dzkp_field::DZKPBaseField, dzkp_validator::MultiplicationInputsBlock}, -}; +use ipa_core::protocol::context::dzkp_validator::MultiplicationInputsBlock; use rand::{thread_rng, Rng}; fn convert_prover_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("dzkp_convert_prover"); group.bench_function("convert", |b| { b.iter_batched_ref( - || { - // Generate input - let mut rng = thread_rng(); - - MultiplicationInputsBlock { - x_left: rng.gen::<[u8; 32]>().into(), - x_right: rng.gen::<[u8; 32]>().into(), - y_left: rng.gen::<[u8; 32]>().into(), - y_right: rng.gen::<[u8; 32]>().into(), - prss_left: rng.gen::<[u8; 32]>().into(), - prss_right: rng.gen::<[u8; 32]>().into(), - z_right: rng.gen::<[u8; 32]>().into(), - } - }, - |input| { - let MultiplicationInputsBlock { - x_left, - x_right, - y_left, - y_right, - prss_right, - .. - } = input; - Fp61BitPrime::convert_prover(x_left, x_right, y_left, y_right, prss_right); - }, + || thread_rng().gen(), + |input: &mut MultiplicationInputsBlock| input.table_indices_prover(), BatchSize::SmallInput, ) }); diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index fb6f9fdb7..6b1032f72 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use crate::{ executor::IpaRuntime, helpers::{ - query::{PrepareQuery, QueryConfig, QueryInput}, + query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, routing::{Addr, RouteId}, ApiError, BodyStream, HandlerBox, HandlerRef, HelperIdentity, HelperResponse, MpcTransportImpl, RequestHandler, ShardTransportImpl, Transport, TransportIdentity, @@ -149,8 +149,13 @@ impl HelperApp { /// /// ## Errors /// Propagates errors from the helper. - pub fn query_status(&self, query_id: QueryId) -> Result { - Ok(self.inner.query_processor.query_status(query_id)?) 
+ pub async fn query_status(&self, query_id: QueryId) -> Result { + let shard_transport = self.inner.shard_transport.clone_ref(); + Ok(self + .inner + .query_processor + .query_status(shard_transport, query_id) + .await?) } /// Waits for a query to complete and returns the result. @@ -161,7 +166,7 @@ impl HelperApp { Ok(self .inner .query_processor - .complete(query_id) + .complete(query_id, self.inner.shard_transport.clone_ref()) .await? .to_bytes()) } @@ -177,7 +182,7 @@ impl RequestHandler for Inner { async fn handle( &self, req: Addr, - _data: BodyStream, + data: BodyStream, ) -> Result { let qp = &self.query_processor; @@ -186,6 +191,17 @@ impl RequestHandler for Inner { let req = req.into::()?; HelperResponse::from(qp.prepare_shard(&self.shard_transport, req)?) } + RouteId::QueryStatus => { + let req = req.into::()?; + HelperResponse::from(qp.shard_status(&self.shard_transport, &req)?) + } + RouteId::CompleteQuery => { + // The processing flow for this API is exactly the same, regardless + // whether it was received from a peer shard or from report collector. + // Authentication is handled on the layer above, so we erase the identity + // and pass it down to the MPC handler. + RequestHandler::::handle(self, req.erase_origin(), data).await? + } r => { return Err(ApiError::BadRequest( format!("{r:?} request must not be handled by shard query processing flow") @@ -247,11 +263,16 @@ impl RequestHandler for Inner { } RouteId::QueryStatus => { let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.query_status(query_id)?) + let shard_transport = Transport::clone_ref(&self.shard_transport); + let query_status = qp.query_status(shard_transport, query_id).await?; + HelperResponse::from(query_status) } RouteId::CompleteQuery => { let query_id = ext_query_id(&req)?; - HelperResponse::from(qp.complete(query_id).await?) 
+ HelperResponse::from( + qp.complete(query_id, self.shard_transport.clone_ref()) + .await?, + ) } RouteId::KillQuery => { let query_id = ext_query_id(&req)?; diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 7de9de90c..8641f3547 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -183,6 +183,7 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B let shard_index = ShardIndex::from(args.shard_index.expect("enforced by clap")); let shard_count = ShardIndex::from(args.shard_count.expect("enforced by clap")); assert!(shard_index < shard_count); + assert_eq!(args.tls_cert.is_some(), !args.disable_https); let (identity, server_tls) = create_client_identity(my_identity, args.tls_cert.clone(), args.tls_key.clone())?; @@ -223,8 +224,13 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B let network_config_path = args.network.as_deref().unwrap(); let network_config_string = &fs::read_to_string(network_config_path)?; - let (mut mpc_network, mut shard_network) = - sharded_server_from_toml_str(network_config_string, my_identity, shard_index, shard_count)?; + let (mut mpc_network, mut shard_network) = sharded_server_from_toml_str( + network_config_string, + my_identity, + shard_index, + shard_count, + args.shard_port, + )?; mpc_network = mpc_network.override_scheme(&scheme); shard_network = shard_network.override_scheme(&scheme); diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 6b5df3c84..403f38b24 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -3,6 +3,7 @@ use std::{ fmt::{Debug, Formatter}, iter::zip, path::PathBuf, + str::FromStr, time::Duration, }; @@ -280,20 +281,21 @@ fn parse_sharded_network_toml(input: &str) -> Result /// /// # Errors /// if `input` is in an invalid format +/// +/// # Panics +/// If you somehow provide an invalid non-sharded network toml pub fn sharded_server_from_toml_str( input: &str, id: HelperIdentity, 
shard_index: ShardIndex, shard_count: ShardIndex, + shard_port: Option, ) -> Result<(NetworkConfig, NetworkConfig), Error> { let all_network = parse_sharded_network_toml(input)?; - let missing_urls = all_network.missing_shard_urls(); - if !missing_urls.is_empty() { - return Err(Error::MissingShardUrls(missing_urls)); - } let ix: usize = shard_index.as_index(); let ix_count: usize = shard_count.as_index(); + // assert ix < count let mpc_id: usize = id.as_index(); let mpc_network = NetworkConfig { @@ -307,21 +309,45 @@ pub fn sharded_server_from_toml_str( client: all_network.client.clone(), identities: HelperIdentity::make_three().to_vec(), }; - - let shard_network = NetworkConfig { - peers: all_network - .peers - .iter() - .map(ShardedPeerConfigToml::to_shard_peer) - .skip(mpc_id) - .step_by(3) - .take(ix_count) - .collect(), - client: all_network.client, - identities: shard_count.iter().collect(), - }; - - Ok((mpc_network, shard_network)) + let missing_urls = all_network.missing_shard_urls(); + if missing_urls.is_empty() { + let shard_network = NetworkConfig { + peers: all_network + .peers + .iter() + .map(ShardedPeerConfigToml::to_shard_peer) + .skip(mpc_id) + .step_by(3) + .take(ix_count) + .collect(), + client: all_network.client, + identities: shard_count.iter().collect(), + }; + Ok((mpc_network, shard_network)) + } else if missing_urls == [0, 1, 2] && shard_count == ShardIndex(1) { + // This is the special case we're dealing with a non-sharded, single ring MPC. + // Since the shard network will be of size 1, it can't really communicate with anyone else. + // Hence we just create a config where I'm the only shard. We take the MPC configuration + // and modify the port. 
+ let mut myself = ShardedPeerConfigToml::to_mpc_peer(all_network.peers.get(mpc_id).unwrap()); + let url = myself.url.to_string(); + let pos = url.rfind(':'); + let port = shard_port.expect("Shard port should be set"); + let new_url = if pos.is_some() { + format!("{}{port}", &url[..=pos.unwrap()]) + } else { + format!("{}:{port}", &url) + }; + myself.url = Uri::from_str(&new_url).expect("Problem creating uri with sharded port"); + let shard_network = NetworkConfig { + peers: vec![myself], + client: all_network.client, + identities: shard_count.iter().collect(), + }; + Ok((mpc_network, shard_network)) + } else { + return Err(Error::MissingShardUrls(missing_urls)); + } } #[derive(Clone, Debug, Deserialize)] @@ -788,6 +814,7 @@ mod tests { HelperIdentity::TWO, ShardIndex::from(1), ShardIndex::from(3), + None, ) .unwrap(); assert_eq!( @@ -856,19 +883,43 @@ mod tests { )); } - /// Check that shard urls are given for [`sharded_server_from_toml_str`] or error is returned. + /// Check that [`sharded_server_from_toml_str`] can work in the previous format. #[test] fn parse_sharded_without_shard_urls() { - // Second, I test the networkconfig parsing - assert!(matches!( - sharded_server_from_toml_str( - &NON_SHARDED_COMPAT, - HelperIdentity::TWO, - ShardIndex::FIRST, - ShardIndex::from(1) - ), - Err(Error::MissingShardUrls(_)) - )); + let (mpc, mut shard) = sharded_server_from_toml_str( + &NON_SHARDED_COMPAT, + HelperIdentity::ONE, + ShardIndex::FIRST, + ShardIndex::from(1), + Some(666), + ) + .unwrap(); + assert_eq!(1, shard.peers.len()); + assert_eq!(3, mpc.peers.len()); + assert_eq!( + "helper1.org:666", + shard.peers.pop().unwrap().url.to_string() + ); + } + + /// Check that [`sharded_server_from_toml_str`] can work in the previous format, even when the + /// given MPC URL doesn't have a port (NOTE: helper 2 doesn't specify it). 
+ #[test] + fn parse_sharded_without_shard_urls_no_port() { + let (mpc, mut shard) = sharded_server_from_toml_str( + &NON_SHARDED_COMPAT, + HelperIdentity::TWO, + ShardIndex::FIRST, + ShardIndex::from(1), + Some(666), + ) + .unwrap(); + assert_eq!(1, shard.peers.len()); + assert_eq!(3, mpc.peers.len()); + assert_eq!( + "helper2.org:666", + shard.peers.pop().unwrap().url.to_string() + ); } /// Testing happy case of a sharded network config @@ -959,6 +1010,7 @@ a6L0t1Ug8i2RcequSo21x319Tvs5nUbGwzMFSS5wKA== url = "helper1.org:443""#; /// The rest of a configuration + /// Note the second helper doesn't provide a port as part of its url const REST: &str = r#" [peers.hpke] public_key = "f458d5e1989b2b8f5dacd4143276aa81eaacf7449744ab1251ff667c43550756" @@ -977,7 +1029,7 @@ zj0EAwIDSAAwRQIhAI4G5ICVm+v5KK5Y8WVetThtNCXGykUBAM1eE973FBOUAiAS XXgJe9q9hAfHf0puZbv0j0tGY3BiqCkJJaLvK7ba+g== -----END CERTIFICATE----- """ -url = "helper2.org:443" +url = "helper2.org" [peers.hpke] public_key = "62357179868e5594372b801ddf282c8523806a868a2bff2685f66aa05ffd6c22" diff --git a/ipa-core/src/ff/accumulator.rs b/ipa-core/src/ff/accumulator.rs new file mode 100644 index 000000000..6edf5390a --- /dev/null +++ b/ipa-core/src/ff/accumulator.rs @@ -0,0 +1,442 @@ +//! Optimized multiply-accumulate for field elements. +//! +//! An add or multiply operation in a prime field can be implemented as the +//! corresponding operation over the integers, followed by a reduction modulo the prime. +//! +//! In the case of several arithmetic operations performed in sequence, it is not +//! necessary to perform the reduction after every operation. As long as an exact +//! integer result is maintained up until reducing, the reduction can be performed once +//! at the end of the sequence. +//! +//! The reduction is usually the most expensive part of a multiplication operation, even +//! when using special primes like Mersenne primes that have an efficient reduction +//! operation. +//! +//! 
This module implements an optimized multiply-accumulate operation for field elements +//! that defers reduction. +//! +//! To enable this optimized implementation for a field, it is necessary to (1) select +//! an accumulator type, (2) calculate the number of multiply-accumulate operations +//! that can be performed without overflowing the accumulator. Record these values +//! in an implementation of the `MultiplyAccumulate` trait. +//! +//! This module also provides a generic implementation of `MultiplyAccumulate` that +//! reduces after every operation. To use the generic implementation for `MyField`: +//! +//! ```ignore +//! impl MultiplyAccumulate for MyField { +//! type Accumulator = MyField; +//! type AccumulatorArray = [MyField; N]; +//! } +//! ``` +//! +//! Currently, an optimized implementation is supplied only for `Fp61BitPrime`, which is +//! used _extensively_ in DZKP-based malicious security. All other fields use a naive +//! implementation. (The implementation of DZKPs is also the only place that the traits +//! in this module are used; there is no reason to adopt them if not trying to access +//! the optimized implementation, or at least trying to be easily portable to an +//! optimized implementation.) +//! +//! To perform multiply-accumulate operations using this API: +//! ``` +//! use ipa_core::ff::{PrimeField, MultiplyAccumulate, MultiplyAccumulator}; +//! fn dot_product(a: &[F; N], b: &[F; N]) -> F { +//! let mut acc = ::Accumulator::new(); +//! for i in 0..N { +//! acc.multiply_accumulate(a[i], b[i]); +//! } +//! acc.take() +//! } +//! ``` + +use std::{ + array, + marker::PhantomData, + ops::{AddAssign, Mul}, +}; + +use crate::{ + ff::{Field, U128Conversions}, + secret_sharing::SharedValue, +}; + +/// Trait for multiply-accumulate operations on values. +/// +/// See the module-level documentation for usage. +pub trait MultiplyAccumulator: Clone { + /// Create a new accumulator with a value of zero. 
+ fn new() -> Self + where + Self: Sized; + + /// Performs _accumulator <- accumulator + lhs * rhs_. + fn multiply_accumulate(&mut self, lhs: F, rhs: F); + + /// Consume the accumulator and return its value. + fn take(self) -> F; + + #[cfg(test)] + fn reduce_interval() -> usize; +} + +/// Trait for multiply-accumulate operations on vectors. +/// +/// See the module-level documentation for usage. +pub trait MultiplyAccumulatorArray: Clone { + /// Create a new accumulator with a value of zero. + fn new() -> Self + where + Self: Sized; + + /// Performs _accumulator <- accumulator + lhs * rhs_. + /// + /// Each of _accumulator_, _lhs_, and _rhs_ is a vector of length `N`, and _lhs * + /// rhs_ is an element-wise product. + fn multiply_accumulate(&mut self, lhs: &[F; N], rhs: &[F; N]); + + /// Consume the accumulator and return its value. + fn take(self) -> [F; N]; + + #[cfg(test)] + fn reduce_interval() -> usize; +} + +/// Trait for values (e.g. `Field`s) that support multiply-accumulate operations. +/// +/// See the module-level documentation for usage. +pub trait MultiplyAccumulate: Sized { + type Accumulator: MultiplyAccumulator; + type AccumulatorArray: MultiplyAccumulatorArray; +} + +#[derive(Clone, Default)] +pub struct Accumulator { + value: A, + count: usize, + phantom_data: PhantomData, +} + +#[cfg(all(test, unit_test))] +impl Accumulator { + /// Return the raw accumulator value, which may not directly correspond to + /// a valid value for `F`. This is intended for tests. + fn into_raw(self) -> A { + self.value + } +} + +/// Create a new accumulator containing the specified value. +/// +/// This is currently used only by tests, and for that purpose it is sufficient to +/// access it via the concrete `Accumulator` type. To use it more generally, it would +/// need to be made part of the trait (either by adding a `From` supertrait bound, or by +/// adding a method in the trait to perform the operation). 
+impl From for Accumulator +where + A: Default + From, + F: U128Conversions, +{ + fn from(value: F) -> Self { + Self { + value: value.as_u128().into(), + count: 0, + phantom_data: PhantomData, + } + } +} +/// Optimized multiply-accumulate implementation that adds `REDUCE_INTERVAL` products +/// into an accumulator before reducing. +/// +/// Note that the accumulator must be large enough to hold `REDUCE_INTERVAL` products, +/// plus one additional field element. The additional field element represents the +/// output of the previous reduction, or an initial value of the accumulator. +impl MultiplyAccumulator + for Accumulator +where + A: AddAssign + Copy + Default + Mul + From + Into, + F: SharedValue + U128Conversions + MultiplyAccumulate, +{ + #[inline] + fn new() -> Self + where + Self: Sized, + { + Self::default() + } + + #[inline] + fn multiply_accumulate(&mut self, lhs: F, rhs: F) { + self.value += A::from(lhs.as_u128()) * A::from(rhs.as_u128()); + self.count += 1; + if self.count == REDUCE_INTERVAL { + // Modulo, not really a truncation. + self.value = A::from(F::truncate_from(self.value).as_u128()); + self.count = 0; + } + } + + #[inline] + fn take(self) -> F { + // Modulo, not really a truncation. + F::truncate_from(self.value) + } + + #[cfg(test)] + fn reduce_interval() -> usize { + REDUCE_INTERVAL + } +} + +/// Optimized multiply-accumulate implementation that adds `REDUCE_INTERVAL` products +/// into an accumulator before reducing. This version operates on arrays. 
+impl MultiplyAccumulatorArray + for Accumulator +where + A: AddAssign + Copy + Default + Mul + From + Into, + F: SharedValue + U128Conversions + MultiplyAccumulate = Self>, +{ + #[inline] + fn new() -> Self + where + Self: Sized, + { + Accumulator { + value: [A::default(); N], + count: 0, + phantom_data: PhantomData, + } + } + + #[inline] + fn multiply_accumulate(&mut self, lhs: &[F; N], rhs: &[F; N]) { + for i in 0..N { + self.value[i] += A::from(lhs[i].as_u128()) * A::from(rhs[i].as_u128()); + } + self.count += 1; + if self.count == REDUCE_INTERVAL { + // Modulo, not really a truncation. + self.value = array::from_fn(|i| A::from(F::truncate_from(self.value[i]).as_u128())); + self.count = 0; + } + } + + #[inline] + fn take(self) -> [F; N] { + // Modulo, not really a truncation. + array::from_fn(|i| F::truncate_from(self.value[i])) + } + + #[cfg(test)] + fn reduce_interval() -> usize { + REDUCE_INTERVAL + } +} + +// Unoptimized implementation usable for any field. +impl MultiplyAccumulator for F { + #[inline] + fn new() -> Self + where + Self: Sized, + { + F::ZERO + } + + #[inline] + fn multiply_accumulate(&mut self, lhs: F, rhs: F) { + *self += lhs * rhs; + } + + #[inline] + fn take(self) -> F { + self + } + + #[cfg(test)] + fn reduce_interval() -> usize { + 1 + } +} + +// Unoptimized implementation usable for any field. This version operates on arrays. 
+impl MultiplyAccumulatorArray for [F; N] { + #[inline] + fn new() -> Self + where + Self: Sized, + { + [F::ZERO; N] + } + + #[inline] + fn multiply_accumulate(&mut self, lhs: &[F; N], rhs: &[F; N]) { + for i in 0..N { + self[i] += lhs[i] * rhs[i]; + } + } + + #[inline] + fn take(self) -> [F; N] { + self + } + + #[cfg(test)] + fn reduce_interval() -> usize { + 1 + } +} + +#[cfg(all(test, unit_test))] +mod test { + use crate::ff::{Fp61BitPrime, MultiplyAccumulate, MultiplyAccumulator, U128Conversions}; + + // If adding optimized multiply-accumulate for an additional field, it would make + // sense to convert the freestanding `fp61bit_*` tests here to a macro and replicate + // them for the new field. + + #[test] + fn fp61bit_accum_size() { + // Test that the accumulator does not overflow before reaching REDUCE_INTERVAL. + type Accumulator = ::Accumulator; + let max = -Fp61BitPrime::from_bit(true); + let mut acc = Accumulator::from(max); + for _ in 0..Accumulator::reduce_interval() - 1 { + acc.multiply_accumulate(max, max); + } + + let expected = max.as_u128() + + u128::try_from(Accumulator::reduce_interval() - 1).unwrap() + * (max.as_u128() * max.as_u128()); + assert_eq!(acc.clone().into_raw(), expected); + + assert_eq!(acc.take(), Fp61BitPrime::truncate_from(expected)); + + // Test that the largest value the accumulator should ever hold (which is not + // visible through the API, because it will be reduced immediately) does not + // overflow. + let _ = max.as_u128() + + u128::try_from(Accumulator::reduce_interval()).unwrap() + * (max.as_u128() * max.as_u128()); + } + + #[test] + fn fp61bit_accum_reduction() { + // Test that the accumulator reduces after reaching the specified interval. + // (This test assumes that the implementation (1) sets REDUCE_INTERVAL as large + // as possible, and (2) fully reduces upon reaching REDUCE_INTERVAL. It is + // possible for a correct implementation not to have these properties. 
If + // adding such an implementation, this test will need to be adjusted.) + type Accumulator = ::Accumulator; + let max = -Fp61BitPrime::from_bit(true); + let mut acc = Accumulator::new(); + for _ in 0..Accumulator::reduce_interval() { + acc.multiply_accumulate(max, max); + } + + let expected = Fp61BitPrime::truncate_from( + u128::try_from(Accumulator::reduce_interval()).unwrap() + * (max.as_u128() * max.as_u128()), + ); + assert_eq!(acc.clone().into_raw(), expected.as_u128()); + assert_eq!(acc.take(), expected); + } + + #[macro_export] + macro_rules! accum_tests { + ($field:ty) => { + mod accum_tests { + use std::iter::zip; + + use proptest::prelude::*; + + use $crate::{ + ff::{MultiplyAccumulate, MultiplyAccumulator, MultiplyAccumulatorArray}, + test_executor::run_random, + }; + use super::*; + + const SZ: usize = 2; + + #[test] + fn accum_simple() { + run_random(|mut rng| async move { + let a = rng.gen(); + let b = rng.gen(); + let c = rng.gen(); + let d = rng.gen(); + + let mut acc = <$field as MultiplyAccumulate>::Accumulator::new(); + acc.multiply_accumulate(a, b); + acc.multiply_accumulate(c, d); + + assert_eq!(acc.take(), a * b + c * d); + }); + } + + prop_compose! { + fn arb_inputs(max_len: usize) + (len in 0..max_len) + ( + lhs in prop::collection::vec(any::<$field>(), len), + rhs in prop::collection::vec(any::<$field>(), len), + ) + -> (Vec<$field>, Vec<$field>) { + (lhs, rhs) + } + } + + proptest::proptest! { + #[test] + fn accum_proptest((lhs, rhs) in arb_inputs( + 10 * <$field as MultiplyAccumulate>::Accumulator::reduce_interval() + )) { + type Accumulator = <$field as MultiplyAccumulate>::Accumulator; + let mut acc = Accumulator::new(); + let mut expected = <$field>::ZERO; + for (lhs_term, rhs_term) in zip(lhs, rhs) { + acc.multiply_accumulate(lhs_term, rhs_term); + expected += lhs_term * rhs_term; + } + assert_eq!(acc.take(), expected); + } + } + + prop_compose! 
{ + fn arb_array_inputs(max_len: usize) + (len in 0..max_len) + ( + lhs in prop::collection::vec( + prop::array::uniform(any::<$field>()), + len, + ), + rhs in prop::collection::vec( + prop::array::uniform(any::<$field>()), + len, + ), + ) + -> (Vec<[$field; SZ]>, Vec<[$field; SZ]>) { + (lhs, rhs) + } + } + + proptest::proptest! { + #[test] + fn accum_array_proptest((lhs, rhs) in arb_array_inputs( + 10 * <$field as MultiplyAccumulate>::AccumulatorArray::::reduce_interval() + )) { + type Accumulator = <$field as MultiplyAccumulate>::AccumulatorArray::; + let mut acc = Accumulator::new(); + let mut expected = [<$field>::ZERO; SZ]; + for (lhs_arr, rhs_arr) in zip(lhs, rhs) { + acc.multiply_accumulate(&lhs_arr, &rhs_arr); + for i in 0..SZ { + expected[i] += lhs_arr[i] * rhs_arr[i]; + } + } + assert_eq!(acc.take(), expected); + } + } + } + } + } +} diff --git a/ipa-core/src/ff/boolean.rs b/ipa-core/src/ff/boolean.rs index 3296e3a4e..b76699587 100644 --- a/ipa-core/src/ff/boolean.rs +++ b/ipa-core/src/ff/boolean.rs @@ -5,7 +5,7 @@ use generic_array::GenericArray; use typenum::U1; use crate::{ - ff::{ArrayAccess, Field, PrimeField, Serializable, U128Conversions}, + ff::{ArrayAccess, Field, MultiplyAccumulate, PrimeField, Serializable, U128Conversions}, impl_shared_value_common, protocol::{ context::{dzkp_field::DZKPCompatibleField, dzkp_validator::SegmentEntry}, @@ -58,6 +58,13 @@ impl PrimeField for Boolean { const PRIME: Self::PrimeInteger = 2; } +// Note: The multiply-accumulate tests are not currently instantiated for `Boolean`. +impl MultiplyAccumulate for Boolean { + type Accumulator = Boolean; + // This could be specialized with a bit vector type if it ever mattered. 
+ type AccumulatorArray = [Boolean; N]; +} + impl SharedValue for Boolean { type Storage = bool; const BITS: u32 = 1; diff --git a/ipa-core/src/ff/ec_prime_field.rs b/ipa-core/src/ff/ec_prime_field.rs index 64de59f9e..893e56007 100644 --- a/ipa-core/src/ff/ec_prime_field.rs +++ b/ipa-core/src/ff/ec_prime_field.rs @@ -2,10 +2,11 @@ use std::convert::Infallible; use curve25519_dalek::scalar::Scalar; use generic_array::GenericArray; +use subtle::{Choice, ConstantTimeEq}; use typenum::{U2, U32}; use crate::{ - ff::{boolean_array::BA256, Field, Serializable}, + ff::{boolean_array::BA256, Field, MultiplyAccumulate, Serializable}, impl_shared_value_common, protocol::{ ipa_prf::PRF_CHUNK, @@ -75,6 +76,12 @@ impl Serializable for Fp25519 { } } +impl ConstantTimeEq for Fp25519 { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + ///generate random elements in Fp25519 impl rand::distributions::Distribution for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> Fp25519 { @@ -219,6 +226,12 @@ impl ExtendableField for Fp25519 { } } +// Note: The multiply-accumulate tests are not currently instantiated for `Fp25519`. 
+impl MultiplyAccumulate for Fp25519 { + type Accumulator = Fp25519; + type AccumulatorArray = [Fp25519; N]; +} + impl FromRandom for Fp25519 { type SourceLength = U2; diff --git a/ipa-core/src/ff/field.rs b/ipa-core/src/ff/field.rs index 186e81c8a..f34667161 100644 --- a/ipa-core/src/ff/field.rs +++ b/ipa-core/src/ff/field.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; use typenum::{U1, U4, U8}; use crate::{ + ff::MultiplyAccumulate, protocol::prss::FromRandom, secret_sharing::{Block, FieldVectorizable, SharedValue, Vectorizable}, }; @@ -30,6 +31,7 @@ pub trait Field: SharedValue + Mul + MulAssign + + MultiplyAccumulate + FromRandom + Into + Vectorizable<1> diff --git a/ipa-core/src/ff/galois_field.rs b/ipa-core/src/ff/galois_field.rs index 3a84f2b2a..51181ae3e 100644 --- a/ipa-core/src/ff/galois_field.rs +++ b/ipa-core/src/ff/galois_field.rs @@ -8,11 +8,12 @@ use bitvec::{ prelude::{bitarr, BitArr, Lsb0}, }; use generic_array::GenericArray; +use subtle::{Choice, ConstantTimeEq}; use typenum::{Unsigned, U1, U2, U3, U4, U5}; use crate::{ error::LengthError, - ff::{boolean_array::NonZeroPadding, Field, Serializable, U128Conversions}, + ff::{boolean_array::NonZeroPadding, Field, MultiplyAccumulate, Serializable, U128Conversions}, impl_serializable_trait, impl_shared_value_common, protocol::prss::FromRandomU128, secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, @@ -179,6 +180,12 @@ macro_rules! bit_array_impl { const ONE: Self = Self($one); } + // Note: The multiply-accumulate tests are not currently instantiated for Galois fields. + impl MultiplyAccumulate for $name { + type Accumulator = $name; + type AccumulatorArray = [$name; N]; + } + impl U128Conversions for $name { fn as_u128(&self) -> u128 { (*self).into() @@ -227,6 +234,18 @@ macro_rules! bit_array_impl { const POLYNOMIAL: u128 = $polynomial; } + // If the field value fits in a machine word, a naive comparison should be fine. 
+ // But this impl is important for `[T]`, and useful to document where a + // constant-time compare is intended. + impl ConstantTimeEq for $name { + fn ct_eq(&self, other: &Self) -> Choice { + // Note that this will compare the padding bits. That should not be + // a problem, because we should not allow the padding bits to become + // non-zero. + self.0.as_raw_slice().ct_eq(&other.0.as_raw_slice()) + } + } + impl rand::distributions::Distribution<$name> for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> $name { <$name>::truncate_from(rng.gen::()) diff --git a/ipa-core/src/ff/mod.rs b/ipa-core/src/ff/mod.rs index c53476db1..3d2ffb187 100644 --- a/ipa-core/src/ff/mod.rs +++ b/ipa-core/src/ff/mod.rs @@ -2,6 +2,7 @@ // // This is where we store arithmetic shared secret data models. +mod accumulator; pub mod boolean; pub mod boolean_array; pub mod curve_points; @@ -15,6 +16,7 @@ use std::{ ops::{Add, AddAssign, Sub, SubAssign}, }; +pub use accumulator::{MultiplyAccumulate, MultiplyAccumulator, MultiplyAccumulatorArray}; pub use field::{Field, FieldType}; pub use galois_field::{GaloisField, Gf2, Gf20Bit, Gf32Bit, Gf3Bit, Gf40Bit, Gf8Bit, Gf9Bit}; use generic_array::{ArrayLength, GenericArray}; diff --git a/ipa-core/src/ff/prime_field.rs b/ipa-core/src/ff/prime_field.rs index ecc8f5466..de07710b7 100644 --- a/ipa-core/src/ff/prime_field.rs +++ b/ipa-core/src/ff/prime_field.rs @@ -1,11 +1,15 @@ use std::{fmt::Display, mem}; use generic_array::GenericArray; +use subtle::{Choice, ConstantTimeEq}; use super::Field; use crate::{ const_assert, - ff::{Serializable, U128Conversions}, + ff::{ + accumulator::{Accumulator, MultiplyAccumulate}, + Serializable, U128Conversions, + }, impl_shared_value_common, protocol::prss::FromRandomU128, secret_sharing::{Block, FieldVectorizable, SharedValue, StdArray, Vectorizable}, @@ -265,6 +269,15 @@ macro_rules! field_impl { } } + // If the field value fits in a machine word, a naive comparison should be fine. 
+ // But this impl is important for `[T]`, and useful to document where a + // constant-time compare is intended. + impl ConstantTimeEq for $field { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } + } + impl rand::distributions::Distribution<$field> for rand::distributions::Standard { fn sample(&self, rng: &mut R) -> $field { <$field>::truncate_from(rng.gen::()) @@ -427,9 +440,17 @@ mod fp31 { field_impl! { Fp31, u8, u16, 8, 31 } rem_modulo_impl! { Fp31, u16 } + impl MultiplyAccumulate for Fp31 { + type Accumulator = Fp31; + type AccumulatorArray = [Fp31; N]; + } + #[cfg(all(test, unit_test))] mod specialized_tests { use super::*; + use crate::accum_tests; + + accum_tests!(Fp31); #[test] fn fp31() { @@ -459,9 +480,17 @@ mod fp32bit { type ArrayAlias = StdArray; } + impl MultiplyAccumulate for Fp32BitPrime { + type Accumulator = Fp32BitPrime; + type AccumulatorArray = [Fp32BitPrime; N]; + } + #[cfg(all(test, unit_test))] mod specialized_tests { use super::*; + use crate::accum_tests; + + accum_tests!(Fp32BitPrime); #[test] fn thirty_two_bit_prime() { @@ -508,6 +537,14 @@ mod fp32bit { mod fp61bit { field_impl! { Fp61BitPrime, u64, u128, 61, 2_305_843_009_213_693_951 } + // For multiply-accumulate of `Fp61BitPrime` using `u128` as the accumulator, we can add 64 + // products onto an original field element before reduction is necessary. i.e., 64 * (2^61 - 2)^2 + + // (2^61 - 2) < 2^128. + impl MultiplyAccumulate for Fp61BitPrime { + type Accumulator = Accumulator; + type AccumulatorArray = Accumulator; + } + impl Fp61BitPrime { #[must_use] pub const fn const_truncate(input: u64) -> Self { @@ -555,6 +592,12 @@ mod fp61bit { use proptest::proptest; use super::*; + use crate::accum_tests; + + // Note: besides the tests generated with this macro, there are some additional + // tests for the optimized accumulator for `Fp61BitPrime` in the `accumulator` + // module. 
+ accum_tests!(Fp61BitPrime); // copied from 32 bit prime field, adjusted wrap arounds, computed using wolframalpha.com #[test] diff --git a/ipa-core/src/helpers/cross_shard_prss.rs b/ipa-core/src/helpers/cross_shard_prss.rs index 9328d626c..53ebf5f45 100644 --- a/ipa-core/src/helpers/cross_shard_prss.rs +++ b/ipa-core/src/helpers/cross_shard_prss.rs @@ -18,7 +18,7 @@ use crate::{ /// ## Errors /// If shard communication channels fail #[allow(dead_code)] // until this is used in real sharded protocol -async fn gen_and_distribute( +pub async fn gen_and_distribute( gateway: &Gateway, gate: &Gate, prss: R, diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 3135bd5fe..c14bfb0fd 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -28,7 +28,7 @@ use crate::{ ShardChannelId, TotalRecords, Transport, }, protocol::QueryId, - sharding::ShardIndex, + sharding::{ShardConfiguration, ShardIndex}, sync::{Arc, Mutex}, utils::NonZeroU32PowerOfTwo, }; @@ -106,6 +106,27 @@ pub struct GatewayConfig { pub progress_check_interval: std::time::Duration, } +impl ShardConfiguration for Gateway { + fn shard_id(&self) -> ShardIndex { + ShardConfiguration::shard_id(&self) + } + + fn shard_count(&self) -> ShardIndex { + ShardConfiguration::shard_count(&self) + } +} + +impl ShardConfiguration for &Gateway { + fn shard_id(&self) -> ShardIndex { + self.transports.shard.identity() + } + + fn shard_count(&self) -> ShardIndex { + // total number of shards include this instance and all its peers, so we add 1. 
+ ShardIndex::from(self.transports.shard.peer_count() + 1) + } +} + impl Gateway { #[must_use] pub fn new( diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index d1b88af9f..bbd913031 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -25,6 +25,7 @@ use crate::{ labels::{ROLE, STEP}, metrics::{BYTES_SENT, RECORDS_SENT}, }, + utils::non_zero_prev_power_of_two, }; /// Sending end of the gateway channel. @@ -256,14 +257,6 @@ impl SendChannelConfig { total_records: TotalRecords, record_size: usize, ) -> Self { - // this computes the greatest positive power of 2 that is - // less than or equal to target. - fn non_zero_prev_power_of_two(target: usize) -> usize { - let bits = usize::BITS - target.leading_zeros(); - - 1 << (std::cmp::max(1, bits) - 1) - } - assert!(record_size > 0, "Message size cannot be 0"); let total_capacity = gateway_config.active.get() * record_size; diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs index 4a844386f..30d641a4b 100644 --- a/ipa-core/src/helpers/gateway/stall_detection.rs +++ b/ipa-core/src/helpers/gateway/stall_detection.rs @@ -78,7 +78,7 @@ mod gateway { Role, RoleAssignment, SendingEnd, ShardChannelId, ShardReceivingEnd, TotalRecords, }, protocol::QueryId, - sharding::ShardIndex, + sharding::{ShardConfiguration, ShardIndex}, sync::Arc, utils::NonZeroU32PowerOfTwo, }; @@ -207,6 +207,16 @@ mod gateway { } } + impl ShardConfiguration for &Observed { + fn shard_id(&self) -> ShardIndex { + self.inner().gateway.shard_id() + } + + fn shard_count(&self) -> ShardIndex { + self.inner().gateway.shard_count() + } + } + pub struct GatewayWaitingTasks { mpc_send: Option, mpc_recv: Option, diff --git a/ipa-core/src/helpers/gateway/transport.rs b/ipa-core/src/helpers/gateway/transport.rs index dfbc9d328..09a1053cb 100644 --- a/ipa-core/src/helpers/gateway/transport.rs +++ 
b/ipa-core/src/helpers/gateway/transport.rs @@ -46,6 +46,10 @@ impl Transport for RoleResolvingTransport { Role::all().iter().filter(move |&v| v != &this).copied() } + fn peer_count(&self) -> u32 { + self.inner.peer_count() + } + async fn send< D: Stream> + Send + 'static, Q: QueryIdBinding, diff --git a/ipa-core/src/helpers/hashing.rs b/ipa-core/src/helpers/hashing.rs index 741958c65..eaf1ac3ff 100644 --- a/ipa-core/src/helpers/hashing.rs +++ b/ipa-core/src/helpers/hashing.rs @@ -5,6 +5,7 @@ use sha2::{ digest::{Output, OutputSizeUser}, Digest, Sha256, }; +use subtle::{Choice, ConstantTimeEq}; use crate::{ ff::{PrimeField, Serializable}, @@ -15,6 +16,12 @@ use crate::{ #[derive(Clone, Debug, Default, PartialEq)] pub struct Hash(Output); +impl ConstantTimeEq for Hash { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.as_slice().ct_eq(other.0.as_slice()) + } +} + impl Serializable for Hash { type Size = ::OutputSize; diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 370c42b05..2b8e27868 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -56,6 +56,7 @@ mod gateway_exports { pub type ShardReceivingEnd = gateway::ShardReceivingEnd; } +pub use cross_shard_prss::gen_and_distribute as setup_cross_shard_prss; pub use gateway::GatewayConfig; // TODO: this type should only be available within infra. Right now several infra modules // are exposed at the root level. That makes it impossible to have a proper hierarchy here. 
@@ -71,6 +72,7 @@ pub use transport::WrappedAxumBodyStream; #[cfg(feature = "in-memory-infra")] pub use transport::{ config as in_memory_config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, + InMemoryTransportError, }; pub use transport::{ make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BytesStream, diff --git a/ipa-core/src/helpers/transport/in_memory/mod.rs b/ipa-core/src/helpers/transport/in_memory/mod.rs index a5c34bba5..1f23eb278 100644 --- a/ipa-core/src/helpers/transport/in_memory/mod.rs +++ b/ipa-core/src/helpers/transport/in_memory/mod.rs @@ -3,8 +3,8 @@ mod sharding; mod transport; pub use sharding::InMemoryShardNetwork; -pub use transport::Setup; use transport::TransportConfigBuilder; +pub use transport::{Error as InMemoryTransportError, Setup}; use crate::{ helpers::{ diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index 93ac7a523..a2a1abea6 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -373,10 +373,11 @@ mod tests { }, routing::RouteId, }, - HandlerBox, HelperIdentity, HelperResponse, OrderingSender, Role, RoleAssignment, - Transport, TransportIdentity, + HandlerBox, HelperIdentity, HelperResponse, InMemoryShardNetwork, OrderingSender, Role, + RoleAssignment, Transport, TransportIdentity, }, protocol::{Gate, QueryId}, + sharding::ShardIndex, sync::Arc, }; @@ -625,6 +626,27 @@ mod tests { // must be received by now assert_eq!(vec![vec![0, 1]], recv.collect::>().await); } + + #[tokio::test] + async fn peer_count() { + let mpc_network = InMemoryMpcNetwork::default(); + assert_eq!(2, mpc_network.transport(HelperIdentity::ONE).peer_count()); + assert_eq!(2, mpc_network.transport(HelperIdentity::TWO).peer_count()); + + let shard_network = InMemoryShardNetwork::with_shards(5); + assert_eq!( + 4, + shard_network + .transport(HelperIdentity::ONE, 
ShardIndex::FIRST) + .peer_count() + ); + assert_eq!( + 4, + shard_network + .transport(HelperIdentity::TWO, ShardIndex::from(4)) + .peer_count() + ); + } } pub struct TransportConfig { diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index b3cfb862f..6b8341966 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -27,7 +27,9 @@ pub use handler::{ make_owned_handler, Error as ApiError, HandlerBox, HandlerRef, HelperResponse, RequestHandler, }; #[cfg(feature = "in-memory-infra")] -pub use in_memory::{config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport}; +pub use in_memory::{ + config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, InMemoryTransportError, +}; use ipa_metrics::LabelValue; pub use receive::{LogErrors, ReceiveRecords}; #[cfg(feature = "web-app")] @@ -304,7 +306,7 @@ impl From> for BroadcastError pub trait Transport: Clone + Send + Sync + 'static { type Identity: TransportIdentity; type RecordsStream: BytesStream; - type Error: std::fmt::Debug + Send; + type Error: Debug + Send; /// Return my identity in the network (MPC or Sharded) fn identity(&self) -> Self::Identity; @@ -312,6 +314,13 @@ pub trait Transport: Clone + Send + Sync + 'static { /// Returns all the other identities, besides me, in this network. fn peers(&self) -> impl Iterator; + /// The number of peers on the network. Default implementation may not be efficient, + /// because it uses [`Self::peers`] to count, so implementations are encouraged to + /// override it + fn peer_count(&self) -> u32 { + u32::try_from(self.peers().count()).expect("Number of peers is less than 4B") + } + /// Sends a new request to the given destination helper party. 
/// Depending on the specific request, it may or may not require acknowledgment by the remote /// party diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index ac70209b3..e491fcb2a 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -15,6 +15,7 @@ use crate::{ RoleAssignment, RouteParams, }, protocol::QueryId, + query::QueryStatus, }; #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Serialize)] @@ -194,6 +195,33 @@ impl Debug for QueryInput { } } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[cfg_attr(test, derive(PartialEq, Eq))] +pub struct CompareStatusRequest { + pub query_id: QueryId, + pub status: QueryStatus, +} + +impl RouteParams for CompareStatusRequest { + type Params = String; + + fn resource_identifier(&self) -> RouteId { + RouteId::QueryStatus + } + + fn query_id(&self) -> QueryId { + self.query_id + } + + fn gate(&self) -> NoStep { + NoStep + } + + fn extra(&self) -> Self::Params { + serde_json::to_string(self).unwrap() + } +} + #[derive(Copy, Clone, Debug, Serialize, Deserialize)] #[cfg_attr(test, derive(PartialEq, Eq))] pub enum QueryType { diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs index 3d9c2bb5f..e851865ce 100644 --- a/ipa-core/src/helpers/transport/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -8,7 +8,7 @@ use crate::{ }; // The type of request made to an MPC helper. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum RouteId { Records, ReceiveQuery, @@ -49,6 +49,18 @@ impl Addr { } } + /// Drop the origin value and convert this into a request for a different identity type. + /// Useful when we need to handle this request in both shard and MPC handlers. 
+ pub fn erase_origin(self) -> Addr { + Addr { + route: self.route, + origin: None, + query_id: self.query_id, + gate: self.gate, + params: self.params, + } + } + /// Deserializes JSON-encoded request parameters into a client-supplied type `T`. /// /// ## Errors diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 0bcead345..ce719b9ed 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -372,6 +372,18 @@ impl IpaHttpClient { let resp = self.request(req).await?; resp_ok(resp).await } + + /// Complete query API can be called on the leader shard by the report collector or + /// by the leader shard to other shards. + /// + /// # Errors + /// If the request has illegal arguments, or fails to be delivered + pub async fn complete_query(&self, query_id: QueryId) -> Result<(), Error> { + let req = http_serde::query::results::Request::new(query_id); + let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?; + let resp = self.request(req).await?; + resp_ok(resp).await + } } impl IpaHttpClient { diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index 1965c15ce..85e63a144 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -507,12 +507,10 @@ pub mod query { } impl Request { - #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking pub fn new(query_id: QueryId) -> Self { Self { query_id } } - #[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking pub fn try_into_http_request( self, scheme: axum::http::uri::Scheme, diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index 05c3e69be..621365077 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -24,7 +24,7 @@ pub mod test; mod transport; pub use client::{ClientIdentity, IpaHttpClient}; -pub use error::Error; +pub 
use error::{Error, ShardError}; pub use server::{IpaHttpServer, TracingSpanMaker}; pub use transport::{HttpTransport, MpcHttpTransport, ShardHttpTransport}; diff --git a/ipa-core/src/net/server/handlers/query/mod.rs b/ipa-core/src/net/server/handlers/query/mod.rs index fdd8935ae..6e95c4cd4 100644 --- a/ipa-core/src/net/server/handlers/query/mod.rs +++ b/ipa-core/src/net/server/handlers/query/mod.rs @@ -19,9 +19,12 @@ use futures_util::{ use hyper::{Request, StatusCode}; use tower::{layer::layer_fn, Service}; -use crate::net::{ - server::ClientIdentity, transport::MpcHttpTransport, ConnectionFlavor, Helper, Shard, - ShardHttpTransport, +use crate::{ + net::{ + server::ClientIdentity, transport::MpcHttpTransport, ConnectionFlavor, Helper, Shard, + ShardHttpTransport, + }, + sync::Arc, }; /// Construct router for IPA query web service @@ -35,7 +38,7 @@ pub fn query_router(transport: MpcHttpTransport) -> Router { .merge(input::router(transport.clone())) .merge(status::router(transport.clone())) .merge(kill::router(transport.clone())) - .merge(results::router(transport)) + .merge(results::router(transport.inner_transport)) } /// Construct router for helper-to-helper communications @@ -55,7 +58,8 @@ pub fn h2h_router(transport: MpcHttpTransport) -> Router { /// Construct router for shard-to-shard communications similar to [`h2h_router`]. 
pub fn s2s_router(transport: ShardHttpTransport) -> Router { Router::new() - .merge(prepare::router(transport.inner_transport)) + .merge(prepare::router(Arc::clone(&transport.inner_transport))) + .merge(results::router(transport.inner_transport)) .layer(layer_fn(HelperAuthentication::<_, Shard>::new)) } diff --git a/ipa-core/src/net/server/handlers/query/results.rs b/ipa-core/src/net/server/handlers/query/results.rs index 1c359b659..2f2a59f5e 100644 --- a/ipa-core/src/net/server/handlers/query/results.rs +++ b/ipa-core/src/net/server/handlers/query/results.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use axum::{extract::Path, routing::get, Extension, Router}; use hyper::StatusCode; @@ -6,27 +8,30 @@ use crate::{ net::{ http_serde::{self, query::results::Request}, server::Error, - transport::MpcHttpTransport, + ConnectionFlavor, HttpTransport, }, protocol::QueryId, }; /// Handles the completion of the query by blocking the sender until query is completed. -async fn handler( - transport: Extension, +async fn handler( + transport: Extension>>, Path(query_id): Path, ) -> Result, Error> { let req = Request { query_id }; // TODO: we may be able to stream the response - match transport.dispatch(req, BodyStream::empty()).await { + match Arc::clone(&transport) + .dispatch(req, BodyStream::empty()) + .await + { Ok(resp) => Ok(resp.into_body()), Err(e) => Err(Error::application(StatusCode::INTERNAL_SERVER_ERROR, e)), } } -pub fn router(transport: MpcHttpTransport) -> Router { +pub fn router(transport: Arc>) -> Router { Router::new() - .route(http_serde::query::results::AXUM_PATH, get(handler)) + .route(http_serde::query::results::AXUM_PATH, get(handler::)) .layer(Extension(transport)) } diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index f4e8a087c..c74c25610 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -279,6 +279,10 @@ impl TestConfig { &self.rings[0] } + pub fn rings(&self) -> impl Iterator> { + self.rings.iter() + } + /// Gets a 
ref to the entire shard network for a specific helper. #[must_use] pub fn get_shards_for_helper(&self, id: HelperIdentity) -> &TestNetwork { @@ -304,7 +308,6 @@ impl TestConfig { shards, } } - /// Transforms this easy to modify configuration into an easy to run [`TestApp`]. #[must_use] pub fn into_apps(self) -> Vec { diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 6ed523093..de632c48e 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -106,10 +106,14 @@ impl HttpTransport { let req = serde_json::from_str(route.extra().borrow()).unwrap(); self.clients[client_ix].prepare_query(req).await } + RouteId::CompleteQuery => { + let query_id = >::from(route.query_id()) + .expect("query_id is required to call complete query API"); + self.clients[client_ix].complete_query(query_id).await + } evt @ (RouteId::QueryInput | RouteId::ReceiveQuery | RouteId::QueryStatus - | RouteId::CompleteQuery | RouteId::KillQuery) => { unimplemented!( "attempting to send client-specific request {evt:?} to another helper" @@ -266,6 +270,10 @@ impl Transport for MpcHttpTransport { .filter(move |&id| id != this) } + fn peer_count(&self) -> u32 { + 2 + } + async fn send< D: Stream> + Send + 'static, Q: QueryIdBinding, @@ -336,6 +344,10 @@ impl Transport for ShardHttpTransport { self.shard_count.iter().filter(move |&v| v != this) } + fn peer_count(&self) -> u32 { + u32::from(self.shard_count).saturating_sub(1) + } + async fn send( &self, dest: Self::Identity, @@ -478,34 +490,54 @@ mod tests { } async fn test_make_helpers(conf: TestConfig) { - let clients = IpaHttpClient::from_conf( - &IpaRuntime::current(), - &conf.leaders_ring().network, - &ClientIdentity::None, - ); + let clients = conf + .rings() + .map(|test_network| { + IpaHttpClient::from_conf( + &IpaRuntime::current(), + &test_network.network, + &ClientIdentity::None, + ) + }) + .collect::>(); + let _helpers = make_helpers(conf).await; - test_multiply(&clients).await; + 
test_multiply_single_shard(&clients).await; } #[tokio::test(flavor = "multi_thread")] async fn happy_case_twice() { let conf = TestConfigBuilder::default().build(); - let clients = IpaHttpClient::from_conf( - &IpaRuntime::current(), - &conf.leaders_ring().network, - &ClientIdentity::None, - ); + let clients = conf + .rings() + .map(|test_network| { + IpaHttpClient::from_conf( + &IpaRuntime::current(), + &test_network.network, + &ClientIdentity::None, + ) + }) + .collect::>(); let _helpers = make_helpers(conf).await; - test_multiply(&clients).await; - test_multiply(&clients).await; + test_multiply_single_shard(&clients).await; + test_multiply_single_shard(&clients).await; } - async fn test_multiply(clients: &[IpaHttpClient; 3]) { + /// This executes test multiplication protocol by running it exclusively on the leader shards. + /// If there is more than one shard in the system, they receive no inputs but still participate + /// by doing the full query cycle. It is backward compatible with traditional 3-party MPC with + /// no shards. + /// + /// The sharding requires some amendments to the test multiplication protocol that are + /// currently in progress. Once completed, this test can be fixed by fully utilizing all + /// shards in the system. + async fn test_multiply_single_shard(clients: &[[IpaHttpClient; 3]]) { const SZ: usize = as Serializable>::Size::USIZE; + let leader_ring_clients = &clients[0]; // send a create query command - let leader_client = &clients[0]; + let leader_client = &leader_ring_clients[0]; let create_data = QueryConfig::new(TestMultiply, FieldType::Fp31, 1).unwrap(); // create query @@ -528,11 +560,24 @@ mod tests { query_id, input_stream, }; - handle_resps.push(clients[i].query_input(data)); + handle_resps.push(leader_ring_clients[i].query_input(data)); } try_join_all(handle_resps).await.unwrap(); - let result: [_; 3] = join_all(clients.clone().map(|client| async move { + // shards receive their own input - in this case empty. 
+ // convention - first client is shard leader, and we submitted the inputs to it. + try_join_all(clients.iter().skip(1).map(|ring| { + try_join_all(ring.each_ref().map(|shard_client| { + shard_client.query_input(QueryInput { + query_id, + input_stream: BodyStream::empty(), + }) + })) + })) + .await + .unwrap(); + + let result: [_; 3] = join_all(leader_ring_clients.each_ref().map(|client| async move { let r = client.query_results(query_id).await.unwrap(); AdditiveShare::::from_byte_slice_unchecked(&r).collect::>() })) @@ -565,4 +610,33 @@ mod tests { .build(); test_make_helpers(conf).await; } + + #[tokio::test] + async fn peer_count() { + fn new_transport(identity: F::Identity) -> Arc> { + Arc::new(HttpTransport { + http_runtime: IpaRuntime::current(), + identity, + clients: Vec::new(), + handler: None, + record_streams: StreamCollection::default(), + }) + } + + assert_eq!( + 2, + MpcHttpTransport { + inner_transport: new_transport(HelperIdentity::ONE) + } + .peer_count() + ); + assert_eq!( + 9, + ShardHttpTransport { + inner_transport: new_transport(ShardIndex::FIRST), + shard_count: 10.into() + } + .peer_count() + ); + } } diff --git a/ipa-core/src/protocol/basics/check_zero.rs b/ipa-core/src/protocol/basics/check_zero.rs index a7aa29916..7cbe8b8b3 100644 --- a/ipa-core/src/protocol/basics/check_zero.rs +++ b/ipa-core/src/protocol/basics/check_zero.rs @@ -1,3 +1,5 @@ +use subtle::ConstantTimeEq; + use crate::{ error::Error, ff::Field, @@ -49,7 +51,7 @@ pub async fn malicious_check_zero( ) -> Result where C: Context, - F: Field + FromRandom, + F: Field + FromRandom + ConstantTimeEq, { let r_sharing: Replicated = ctx.prss().generate(record_id); @@ -61,7 +63,7 @@ where .expect("full reveal should always return a value"), ); - Ok(rv == F::ZERO) + Ok(rv.ct_eq(&F::ZERO).into()) } #[cfg(all(test, unit_test))] diff --git a/ipa-core/src/protocol/basics/mod.rs b/ipa-core/src/protocol/basics/mod.rs index 5239821af..49630f4a9 100644 --- a/ipa-core/src/protocol/basics/mod.rs 
+++ b/ipa-core/src/protocol/basics/mod.rs @@ -21,11 +21,11 @@ pub use share_known_value::ShareKnownValue; use crate::{ const_assert_eq, - ff::{boolean::Boolean, ec_prime_field::Fp25519, PrimeField}, + ff::{boolean::Boolean, ec_prime_field::Fp25519}, protocol::{ context::{ Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, - UpgradedMaliciousContext, UpgradedSemiHonestContext, + ShardedUpgradedMaliciousContext, UpgradedMaliciousContext, UpgradedSemiHonestContext, }, ipa_prf::{AGG_CHUNK, PRF_CHUNK}, prss::FromPrss, @@ -68,6 +68,14 @@ where { } +impl<'a, const N: usize> BasicProtocols, Fp25519, N> + for malicious::AdditiveShare +where + Fp25519: FieldSimd, + AdditiveShare: FromPrss, +{ +} + /// Basic suite of MPC protocols for (possibly vectorized) boolean shares. /// /// Adds the requirement that the type implements `Not`. diff --git a/ipa-core/src/protocol/basics/mul/mod.rs b/ipa-core/src/protocol/basics/mul/mod.rs index 958bafbcd..a69b7faf4 100644 --- a/ipa-core/src/protocol/basics/mul/mod.rs +++ b/ipa-core/src/protocol/basics/mul/mod.rs @@ -13,10 +13,9 @@ use crate::{ Expand, }, protocol::{ - basics::PrimeField, context::{ - dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, - semi_honest::Upgraded as SemiHonestUpgraded, Context, DZKPUpgradedMaliciousContext, + dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, Context, + DZKPUpgradedMaliciousContext, }, RecordId, }, @@ -84,26 +83,6 @@ where macro_rules! 
boolean_array_mul { ($dim:expr, $vec:ty) => { - impl<'a, B, F> BooleanArrayMul> for Replicated<$vec> - where - B: sharding::ShardBinding, - F: PrimeField, - { - type Vectorized = Replicated; - - fn multiply<'fut>( - ctx: SemiHonestUpgraded<'a, B, F>, - record_id: RecordId, - a: &'fut Self::Vectorized, - b: &'fut Self::Vectorized, - ) -> impl Future> + Send + 'fut - where - SemiHonestUpgraded<'a, B, F>: 'fut, - { - semi_honest_multiply(ctx, record_id, a, b) - } - } - impl<'a, B> BooleanArrayMul> for Replicated<$vec> where B: sharding::ShardBinding, diff --git a/ipa-core/src/protocol/basics/reveal.rs b/ipa-core/src/protocol/basics/reveal.rs index 0e4f08378..e38d31500 100644 --- a/ipa-core/src/protocol/basics/reveal.rs +++ b/ipa-core/src/protocol/basics/reveal.rs @@ -37,7 +37,7 @@ use crate::{ boolean::step::TwoHundredFiftySixBitOpStep, context::{ Context, DZKPContext, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, - UpgradedMaliciousContext, UpgradedSemiHonestContext, + ShardedUpgradedMaliciousContext, UpgradedMaliciousContext, UpgradedSemiHonestContext, }, RecordId, }, @@ -333,6 +333,50 @@ where } } +impl<'a, V, const N: usize, CtxF> Reveal> + for Replicated +where + CtxF: ExtendableField, + V: SharedValue + Vectorizable, +{ + type Output = >::Array; + + async fn generic_reveal<'fut>( + &'fut self, + ctx: ShardedUpgradedMaliciousContext<'a, CtxF>, + record_id: RecordId, + excluded: Option, + ) -> Result>::Array>, Error> + where + ShardedUpgradedMaliciousContext<'a, CtxF>: 'fut, + { + malicious_reveal(ctx, record_id, excluded, self).await + } +} + +impl<'a, F, const N: usize> Reveal> + for MaliciousReplicated +where + F: ExtendableFieldSimd, +{ + type Output = >::Array; + + async fn generic_reveal<'fut>( + &'fut self, + ctx: ShardedUpgradedMaliciousContext<'a, F>, + record_id: RecordId, + excluded: Option, + ) -> Result>::Array>, Error> + where + ShardedUpgradedMaliciousContext<'a, F>: 'fut, + { + use 
crate::secret_sharing::replicated::malicious::ThisCodeIsAuthorizedToDowngradeFromMalicious; + + let x_share = self.x().access_without_downgrade(); + malicious_reveal(ctx, record_id, excluded, x_share).await + } +} + impl<'a, V, B, const N: usize> Reveal> for Replicated where B: ShardBinding, diff --git a/ipa-core/src/protocol/basics/share_validation.rs b/ipa-core/src/protocol/basics/share_validation.rs index 43d1f34e6..492b99973 100644 --- a/ipa-core/src/protocol/basics/share_validation.rs +++ b/ipa-core/src/protocol/basics/share_validation.rs @@ -1,4 +1,5 @@ use futures_util::future::try_join; +use subtle::ConstantTimeEq; use crate::{ error::Error, @@ -50,7 +51,7 @@ where ) .await?; - if hash_left == hash_received { + if hash_left.ct_eq(&hash_received).into() { Ok(()) } else { Err(Error::InconsistentShares) diff --git a/ipa-core/src/protocol/context/dzkp_field.rs b/ipa-core/src/protocol/context/dzkp_field.rs index 368ac36cc..758ad008b 100644 --- a/ipa-core/src/protocol/context/dzkp_field.rs +++ b/ipa-core/src/protocol/context/dzkp_field.rs @@ -1,18 +1,13 @@ -use std::{iter::zip, sync::LazyLock}; +use std::{ops::Index, sync::LazyLock}; use bitvec::field::BitField; use crate::{ ff::{Field, Fp61BitPrime, PrimeField}, - protocol::context::dzkp_validator::{Array256Bit, SegmentEntry}, + protocol::context::dzkp_validator::{Array256Bit, MultiplicationInputsBlock, SegmentEntry}, secret_sharing::{FieldSimd, SharedValue, Vectorizable}, }; -// BlockSize is fixed to 32 -pub const BLOCK_SIZE: usize = 32; -// UVTupleBlock is a block of interleaved U and V values -pub type UVTupleBlock = ([F; BLOCK_SIZE], [F; BLOCK_SIZE]); - /// Trait for fields compatible with DZKPs /// Field needs to support conversion to `SegmentEntry`, i.e. 
`to_segment_entry` which is required by DZKPs pub trait DZKPCompatibleField: FieldSimd { @@ -25,35 +20,12 @@ pub trait DZKPBaseField: PrimeField { const INVERSE_OF_TWO: Self; const MINUS_ONE_HALF: Self; const MINUS_TWO: Self; +} - /// Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements. - /// This function is called by the prover. - fn convert_prover<'a>( - x_left: &'a Array256Bit, - x_right: &'a Array256Bit, - y_left: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - ) -> Vec>; - - /// This is similar to `convert_prover` except that it is called by the verifier to the left of the prover. - /// The verifier on the left uses its right shares, since they are consistent with the prover's left shares. - /// This produces the 'u' values. - fn convert_value_from_right_prover<'a>( - x_right: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - z_right: &'a Array256Bit, - ) -> Vec; - - /// This is similar to `convert_prover` except that it is called by the verifier to the right of the prover. - /// The verifier on the right uses its left shares, since they are consistent with the prover's right shares. - /// This produces the 'v' values - fn convert_value_from_left_prover<'a>( - x_left: &'a Array256Bit, - y_left: &'a Array256Bit, - prss_left: &'a Array256Bit, - ) -> Vec; +impl DZKPBaseField for Fp61BitPrime { + const INVERSE_OF_TWO: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_976u64); + const MINUS_ONE_HALF: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_975u64); + const MINUS_TWO: Self = Fp61BitPrime::const_truncate(2_305_843_009_213_693_949u64); } impl FromIterator for [Fp61BitPrime; P] { @@ -90,7 +62,7 @@ impl FromIterator for Vec<[Fp61BitPrime; P]> { } } -/// Construct indices for the `convert_values` lookup tables. +/// Construct indices for the `TABLE_U` and `TABLE_V` lookup tables. 
/// /// `b0` has the least significant bit of each index, and `b1` and `b2` the subsequent /// bits. This routine rearranges the bits so that there is one table index in each @@ -102,7 +74,7 @@ impl FromIterator for Vec<[Fp61BitPrime; P]> { /// (i%4) == j. The "0s", "1s", "2s", "3s" comments trace the movement from /// input to output. #[must_use] -fn convert_values_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { +fn bits_to_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { // 0x55 is 0b0101_0101. This mask selects bits having (i%2) == 0. const CONST_55: u128 = u128::from_le_bytes([0x55; 16]); // 0xaa is 0b1010_1010. This mask selects bits having (i%2) == 1. @@ -148,10 +120,31 @@ fn convert_values_table_indices(b0: u128, b1: u128, b2: u128) -> [u128; 4] { [y0, y1, y2, y3] } +pub struct UVTable(pub [[F; 4]; 8]); + +impl Index for UVTable { + type Output = [F; 4]; + + fn index(&self, index: u8) -> &Self::Output { + self.0.index(usize::from(index)) + } +} + +impl Index for UVTable { + type Output = [F; 4]; + + fn index(&self, index: usize) -> &Self::Output { + self.0.index(index) + } +} + // Table used for `convert_prover` and `convert_value_from_right_prover`. // -// The conversion to "g" and "h" values is from https://eprint.iacr.org/2023/909.pdf. -static TABLE_RIGHT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { +// This table is for "g" or "u" values. This table is "right" on a verifier when it is +// processing values for the prover on its right. On a prover, this table is "left". +// +// The conversion logic is from https://eprint.iacr.org/2023/909.pdf. +pub static TABLE_U: LazyLock> = LazyLock::new(|| { let mut result = Vec::with_capacity(8); for e in [false, true] { for c in [false, true] { @@ -172,13 +165,16 @@ static TABLE_RIGHT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { } } } - result.try_into().unwrap() + UVTable(result.try_into().unwrap()) }); // Table used for `convert_prover` and `convert_value_from_left_prover`. 
// -// The conversion to "g" and "h" values is from https://eprint.iacr.org/2023/909.pdf. -static TABLE_LEFT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { +// This table is for "h" or "v" values. This table is "left" on a verifier when it is +// processing values for the prover on its left. On a prover, this table is "right". +// +// The conversion logic is from https://eprint.iacr.org/2023/909.pdf. +pub static TABLE_V: LazyLock> = LazyLock::new(|| { let mut result = Vec::with_capacity(8); for f in [false, true] { for d in [false, true] { @@ -199,31 +195,31 @@ static TABLE_LEFT: LazyLock<[[Fp61BitPrime; 4]; 8]> = LazyLock::new(|| { } } } - result.try_into().unwrap() + UVTable(result.try_into().unwrap()) }); -/// Lookup-table-based conversion logic used by `convert_prover`, -/// `convert_value_from_left_prover`, and `convert_value_from_left_prover`. +/// Lookup-table-based conversion logic used by `table_indices_prover`, +/// `table_indices_from_right_prover`, and `table_indices_from_left_prover`. /// /// Inputs `i0`, `i1`, and `i2` each contain the value of one of the "a" through "f" /// intermediates for each of 256 multiplies. `table` is the lookup table to use, -/// which should be either `TABLE_LEFT` or `TABLE_RIGHT`. +/// which should be either `TABLE_U` or `TABLE_V`. /// /// We want to interpret the 3-tuple of intermediates at each bit position in `i0`, `i1` /// and `i2` as an integer index in the range 0..8 into the table. The -/// `convert_values_table_indices` helper does this in bulk more efficiently than using +/// `bits_to_table_indices` helper does this in bulk more efficiently than using /// bit-manipulation to handle them one-by-one. /// /// Preserving the order from inputs to outputs is not necessary for correctness as long /// as the same order is used on all three helpers. We preserve the order anyways /// to simplify the end-to-end dataflow, even though it makes this routine slightly /// more complicated. 
-fn convert_values( +fn intermediates_to_table_indices<'a>( i0: &Array256Bit, i1: &Array256Bit, i2: &Array256Bit, - table: &[[Fp61BitPrime; 4]; 8], -) -> Vec { + mut out: impl Iterator, +) { // Split inputs to two `u128`s. We do this because `u128` is the largest integer // type rust supports. It is possible that using SIMD types here would improve // code generation for AVX-256/512. @@ -238,45 +234,49 @@ fn convert_values( // Output word `j` in each set contains the table indices for input positions `i` // having (i%4) == j. - let [mut z00, mut z01, mut z02, mut z03] = convert_values_table_indices(i00, i10, i20); - let [mut z10, mut z11, mut z12, mut z13] = convert_values_table_indices(i01, i11, i21); + let [mut z00, mut z01, mut z02, mut z03] = bits_to_table_indices(i00, i10, i20); + let [mut z10, mut z11, mut z12, mut z13] = bits_to_table_indices(i01, i11, i21); - let mut result = Vec::with_capacity(1024); + #[allow(clippy::cast_possible_truncation)] for _ in 0..32 { // Take one index in turn from each `z` to preserve the output order. 
- for z in [&mut z00, &mut z01, &mut z02, &mut z03] { - result.extend(table[(*z as usize) & 0x7]); - *z >>= 4; - } + *out.next().unwrap() = (z00 as u8) & 0x7; + z00 >>= 4; + *out.next().unwrap() = (z01 as u8) & 0x7; + z01 >>= 4; + *out.next().unwrap() = (z02 as u8) & 0x7; + z02 >>= 4; + *out.next().unwrap() = (z03 as u8) & 0x7; + z03 >>= 4; } + #[allow(clippy::cast_possible_truncation)] for _ in 0..32 { - for z in [&mut z10, &mut z11, &mut z12, &mut z13] { - result.extend(table[(*z as usize) & 0x7]); - *z >>= 4; - } + *out.next().unwrap() = (z10 as u8) & 0x7; + z10 >>= 4; + *out.next().unwrap() = (z11 as u8) & 0x7; + z11 >>= 4; + *out.next().unwrap() = (z12 as u8) & 0x7; + z12 >>= 4; + *out.next().unwrap() = (z13 as u8) & 0x7; + z13 >>= 4; } - debug_assert!(result.len() == 1024); - - result + debug_assert!(out.next().is_none()); } -impl DZKPBaseField for Fp61BitPrime { - const INVERSE_OF_TWO: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_976u64); - const MINUS_ONE_HALF: Self = Fp61BitPrime::const_truncate(1_152_921_504_606_846_975u64); - const MINUS_TWO: Self = Fp61BitPrime::const_truncate(2_305_843_009_213_693_949u64); - - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements - // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf +impl MultiplicationInputsBlock { + /// Repack the intermediates in this block into lookup indices for `TABLE_U` and `TABLE_V`. + /// + /// This is the convert function called by the prover. // - // This function does not use any optimization. 
+ // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and left can compute: + // Prover and the verifier on its left compute: // g1=-2ac(1-2e), // g2=c(1-2e), // g3=a(1-2e), // g4=-1/2(1-2e), // - // Prover and right can compute: + // Prover and the verifier on its right compute: // h1=bd(1-2f), // h2=d(1-2f), // h3=b(1-2f), @@ -292,33 +292,30 @@ impl DZKPBaseField for Fp61BitPrime { // therefore e = ab⊕cd⊕ f must hold. (alternatively, you can also see this by substituting z_left, // i.e. z_left = x_left · y_left ⊕ x_left · y_right ⊕ x_right · y_left ⊕ prss_left ⊕ prss_right #[allow(clippy::many_single_char_names)] - fn convert_prover<'a>( - x_left: &'a Array256Bit, - x_right: &'a Array256Bit, - y_left: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - ) -> Vec> { - let a = x_left; - let b = y_right; - let c = y_left; - let d = x_right; + #[must_use] + pub fn table_indices_prover(&self) -> Vec<(u8, u8)> { + let a = &self.x_left; + let b = &self.y_right; + let c = &self.y_left; + let d = &self.x_right; // e = ab ⊕ cd ⊕ f = x_left * y_right ⊕ y_left * x_right ⊕ prss_right - let e = (*x_left & y_right) ^ (*y_left & x_right) ^ prss_right; - let f = prss_right; - - let g = convert_values(a, c, &e, &TABLE_RIGHT); - let h = convert_values(b, d, f, &TABLE_LEFT); + let e = (self.x_left & self.y_right) ^ (self.y_left & self.x_right) ^ self.prss_right; + let f = &self.prss_right; - zip(g.chunks_exact(BLOCK_SIZE), h.chunks_exact(BLOCK_SIZE)) - .map(|(g_chunk, h_chunk)| (g_chunk.try_into().unwrap(), h_chunk.try_into().unwrap())) - .collect() + let mut output = vec![(0u8, 0u8); 256]; + intermediates_to_table_indices(a, c, &e, output.iter_mut().map(|tup| &mut tup.0)); + intermediates_to_table_indices(b, d, f, output.iter_mut().map(|tup| &mut tup.1)); + output } - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements + /// Repack the intermediates in this block into 
lookup indices for `TABLE_U`. + /// + /// This is the convert function called by the verifier when processing for the + /// prover on its right. + // // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and left can compute: + // Prover and the verifier on its left compute: // g1=-2ac(1-2e), // g2=c(1-2e), // g3=a(1-2e), @@ -328,25 +325,27 @@ impl DZKPBaseField for Fp61BitPrime { // (a,c,e) = (x_right, y_right, x_right * y_right ⊕ z_right ⊕ prss_right) // here e is defined as in the paper (since the the verifier does not have access to b,d,f, // he cannot use the simplified formula for e) - fn convert_value_from_right_prover<'a>( - x_right: &'a Array256Bit, - y_right: &'a Array256Bit, - prss_right: &'a Array256Bit, - z_right: &'a Array256Bit, - ) -> Vec { - let a = x_right; - let c = y_right; + #[must_use] + pub fn table_indices_from_right_prover(&self) -> Vec { + let a = &self.x_right; + let c = &self.y_right; // e = ac ⊕ zright ⊕ prssright // as defined in the paper - let e = (*a & *c) ^ prss_right ^ z_right; + let e = (self.x_right & self.y_right) ^ self.prss_right ^ self.z_right; - convert_values(a, c, &e, &TABLE_RIGHT) + let mut output = vec![0u8; 256]; + intermediates_to_table_indices(a, c, &e, output.iter_mut()); + output } - // Convert allows to convert individual bits from multiplication gates into dzkp compatible field elements + /// Repack the intermediates in this block into lookup indices for `TABLE_V`. + /// + /// This is the convert function called by the verifier when processing for the + /// prover on its left. 
+ // // We use the conversion logic from https://eprint.iacr.org/2023/909.pdf // - // Prover and right can compute: + // The prover and the verifier on its right compute: // h1=bd(1-2f), // h2=d(1-2f), // h3=b(1-2f), @@ -354,31 +353,33 @@ impl DZKPBaseField for Fp61BitPrime { // // where // (b,d,f) = (y_left, x_left, prss_left) - fn convert_value_from_left_prover<'a>( - x_left: &'a Array256Bit, - y_left: &'a Array256Bit, - prss_left: &'a Array256Bit, - ) -> Vec { - let b = y_left; - let d = x_left; - let f = prss_left; - - convert_values(b, d, f, &TABLE_LEFT) + #[must_use] + pub fn table_indices_from_left_prover(&self) -> Vec { + let b = &self.y_left; + let d = &self.x_left; + let f = &self.prss_left; + + let mut output = vec![0u8; 256]; + intermediates_to_table_indices(b, d, f, output.iter_mut()); + output } } #[cfg(all(test, unit_test))] -mod tests { - use bitvec::{array::BitArray, macros::internal::funty::Fundamental, slice::BitSlice}; +pub mod tests { + + use bitvec::{array::BitArray, macros::internal::funty::Fundamental}; use proptest::proptest; use rand::{thread_rng, Rng}; use crate::{ ff::{Field, Fp61BitPrime, U128Conversions}, - protocol::context::dzkp_field::{ - convert_values_table_indices, DZKPBaseField, UVTupleBlock, BLOCK_SIZE, + protocol::context::{ + dzkp_field::{bits_to_table_indices, DZKPBaseField, TABLE_U, TABLE_V}, + dzkp_validator::MultiplicationInputsBlock, }, secret_sharing::SharedValue, + test_executor::run_random, }; #[test] @@ -386,7 +387,7 @@ mod tests { let b0 = 0xaa; let b1 = 0xcc; let b2 = 0xf0; - let [z0, z1, z2, z3] = convert_values_table_indices(b0, b1, b2); + let [z0, z1, z2, z3] = bits_to_table_indices(b0, b1, b2); assert_eq!(z0, 0x40_u128); assert_eq!(z1, 0x51_u128); assert_eq!(z2, 0x62_u128); @@ -396,7 +397,7 @@ mod tests { let b0 = rng.gen(); let b1 = rng.gen(); let b2 = rng.gen(); - let [z0, z1, z2, z3] = convert_values_table_indices(b0, b1, b2); + let [z0, z1, z2, z3] = bits_to_table_indices(b0, b1, b2); for i in 
(0..128).step_by(4) { fn check(i: u32, j: u32, b0: u128, b1: u128, b2: u128, z: u128) { @@ -417,85 +418,73 @@ mod tests { } } - #[test] - fn batch_convert() { - let mut rng = thread_rng(); + impl MultiplicationInputsBlock { + /// Rotate the "right" values into the "left" values, setting the right values + /// to zero. If the input represents a prover's block of intermediates, the + /// output represents the intermediates that the verifier on the prover's right + /// shares with it. + #[must_use] + pub fn rotate_left(&self) -> Self { + Self { + x_left: self.x_right, + y_left: self.y_right, + prss_left: self.prss_right, + x_right: [0u8; 32].into(), + y_right: [0u8; 32].into(), + prss_right: [0u8; 32].into(), + z_right: [0u8; 32].into(), + } + } - // bitvecs - let mut vec_x_left = Vec::::new(); - let mut vec_x_right = Vec::::new(); - let mut vec_y_left = Vec::::new(); - let mut vec_y_right = Vec::::new(); - let mut vec_prss_left = Vec::::new(); - let mut vec_prss_right = Vec::::new(); - let mut vec_z_right = Vec::::new(); - - // gen 32 random values - for _i in 0..32 { - let x_left: u8 = rng.gen(); - let x_right: u8 = rng.gen(); - let y_left: u8 = rng.gen(); - let y_right: u8 = rng.gen(); - let prss_left: u8 = rng.gen(); - let prss_right: u8 = rng.gen(); - // we set this up to be equal to z_right for this local test - // local here means that only a single party is involved - // and we just test this against this single party - let z_right: u8 = (x_left & y_left) - ^ (x_left & y_right) - ^ (x_right & y_left) - ^ prss_left - ^ prss_right; - - // fill vec - vec_x_left.push(x_left); - vec_x_right.push(x_right); - vec_y_left.push(y_left); - vec_y_right.push(y_right); - vec_prss_left.push(prss_left); - vec_prss_right.push(prss_right); - vec_z_right.push(z_right); + /// Rotate the "left" values into the "right" values, setting the left values to + /// zero. `z_right` is calculated to be consistent with the other values. 
If the + /// input represents a prover's block of intermediates, the output represents + /// the intermediates that the verifier on the prover's left shares with it. + #[must_use] + pub fn rotate_right(&self) -> Self { + let z_right = (self.x_left & self.y_left) + ^ (self.x_left & self.y_right) + ^ (self.x_right & self.y_left) + ^ self.prss_left + ^ self.prss_right; + + Self { + x_right: self.x_left, + y_right: self.y_left, + prss_right: self.prss_left, + x_left: [0u8; 32].into(), + y_left: [0u8; 32].into(), + prss_left: [0u8; 32].into(), + z_right, + } } + } - // conv to BitVec - let x_left = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_x_left)).unwrap(); - let x_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_x_right)).unwrap(); - let y_left = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_y_left)).unwrap(); - let y_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_y_right)).unwrap(); - let prss_left = - BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_prss_left)).unwrap(); - let prss_right = - BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_prss_right)).unwrap(); - let z_right = BitArray::<[u8; 32]>::try_from(BitSlice::from_slice(&vec_z_right)).unwrap(); - - // check consistency of the polynomials - assert_convert( - Fp61BitPrime::convert_prover(&x_left, &x_right, &y_left, &y_right, &prss_right), - // flip intputs right to left since it is checked against itself and not party on the left - // z_right is set to match z_left - Fp61BitPrime::convert_value_from_right_prover(&x_left, &y_left, &prss_left, &z_right), - // flip intputs right to left since it is checked against itself and not party on the left - Fp61BitPrime::convert_value_from_left_prover(&x_right, &y_right, &prss_right), - ); + #[test] + fn batch_convert() { + run_random(|mut rng| async move { + let block = rng.gen::(); + + // When verifying, we rotate the intermediates to match what each prover + // would have. 
`rotate_right` also calculates z_right from the others. + assert_convert( + block.table_indices_prover(), + block.rotate_right().table_indices_from_right_prover(), + block.rotate_left().table_indices_from_left_prover(), + ); + }); } + fn assert_convert(prover: P, verifier_left: L, verifier_right: R) where - P: IntoIterator>, - L: IntoIterator, - R: IntoIterator, + P: IntoIterator, + L: IntoIterator, + R: IntoIterator, { prover .into_iter() - .zip( - verifier_left - .into_iter() - .collect::>(), - ) - .zip( - verifier_right - .into_iter() - .collect::>(), - ) + .zip(verifier_left.into_iter().collect::>()) + .zip(verifier_right.into_iter().collect::>()) .for_each(|((prover, verifier_left), verifier_right)| { assert_eq!(prover.0, verifier_left); assert_eq!(prover.1, verifier_right); @@ -534,37 +523,15 @@ mod tests { } #[allow(clippy::fn_params_excessive_bools)] - fn correctness_prover_values( + #[must_use] + pub fn reference_convert( x_left: bool, x_right: bool, y_left: bool, y_right: bool, prss_left: bool, prss_right: bool, - ) { - let mut array_x_left = BitArray::<[u8; 32]>::ZERO; - let mut array_x_right = BitArray::<[u8; 32]>::ZERO; - let mut array_y_left = BitArray::<[u8; 32]>::ZERO; - let mut array_y_right = BitArray::<[u8; 32]>::ZERO; - let mut array_prss_left = BitArray::<[u8; 32]>::ZERO; - let mut array_prss_right = BitArray::<[u8; 32]>::ZERO; - - // initialize bits - array_x_left.set(0, x_left); - array_x_right.set(0, x_right); - array_y_left.set(0, y_left); - array_y_right.set(0, y_right); - array_prss_left.set(0, prss_left); - array_prss_right.set(0, prss_right); - - let prover = Fp61BitPrime::convert_prover( - &array_x_left, - &array_x_right, - &array_y_left, - &array_y_right, - &array_prss_right, - )[0]; - + ) -> ([Fp61BitPrime; 4], [Fp61BitPrime; 4]) { // compute expected // (a,b,c,d,f) = (x_left, y_right, y_left, x_right, prss_right) // e = x_left · y_left ⊕ z_left ⊕ prss_left @@ -601,17 +568,59 @@ mod tests { // h4=1-2f, let h4 = one_minus_two_f; + 
([g1, g2, g3, g4], [h1, h2, h3, h4]) + } + + #[allow(clippy::fn_params_excessive_bools)] + fn correctness_prover_values( + x_left: bool, + x_right: bool, + y_left: bool, + y_right: bool, + prss_left: bool, + prss_right: bool, + ) { + let mut array_x_left = BitArray::<[u8; 32]>::ZERO; + let mut array_x_right = BitArray::<[u8; 32]>::ZERO; + let mut array_y_left = BitArray::<[u8; 32]>::ZERO; + let mut array_y_right = BitArray::<[u8; 32]>::ZERO; + let mut array_prss_left = BitArray::<[u8; 32]>::ZERO; + let mut array_prss_right = BitArray::<[u8; 32]>::ZERO; + + // initialize bits + array_x_left.set(0, x_left); + array_x_right.set(0, x_right); + array_y_left.set(0, y_left); + array_y_right.set(0, y_right); + array_prss_left.set(0, prss_left); + array_prss_right.set(0, prss_right); + + let block = MultiplicationInputsBlock { + x_left: array_x_left, + x_right: array_x_right, + y_left: array_y_left, + y_right: array_y_right, + prss_left: array_prss_left, + prss_right: array_prss_right, + z_right: BitArray::ZERO, + }; + + let prover = block.table_indices_prover()[0]; + + let ([g1, g2, g3, g4], [h1, h2, h3, h4]) = + reference_convert(x_left, x_right, y_left, y_right, prss_left, prss_right); + // check expected == computed // g polynomial - assert_eq!(g1, prover.0[0]); - assert_eq!(g2, prover.0[1]); - assert_eq!(g3, prover.0[2]); - assert_eq!(g4, prover.0[3]); + assert_eq!(g1, TABLE_U[prover.0][0]); + assert_eq!(g2, TABLE_U[prover.0][1]); + assert_eq!(g3, TABLE_U[prover.0][2]); + assert_eq!(g4, TABLE_U[prover.0][3]); // h polynomial - assert_eq!(h1, prover.1[0]); - assert_eq!(h2, prover.1[1]); - assert_eq!(h3, prover.1[2]); - assert_eq!(h4, prover.1[3]); + assert_eq!(h1, TABLE_V[prover.1][0]); + assert_eq!(h2, TABLE_V[prover.1][1]); + assert_eq!(h3, TABLE_V[prover.1][2]); + assert_eq!(h4, TABLE_V[prover.1][3]); } } diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 63a412265..5d8d58b2e 100644 --- 
a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -12,7 +12,7 @@ use crate::{ protocol::{ context::{ batcher::Batcher, - dzkp_field::{DZKPBaseField, UVTupleBlock}, + dzkp_field::{DZKPBaseField, TABLE_U, TABLE_V}, dzkp_malicious::DZKPUpgraded as MaliciousDZKPUpgraded, dzkp_semi_honest::DZKPUpgraded as SemiHonestDZKPUpgraded, step::DzkpValidationProtocolStep as Step, @@ -20,7 +20,8 @@ use crate::{ }, ipa_prf::{ validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, - LargeProofGenerator, SmallProofGenerator, + CompressedProofGenerator, FirstProofGenerator, ProverTableIndices, + VerifierTableIndices, }, Gate, RecordId, RecordIdRange, }, @@ -33,7 +34,7 @@ pub type Array256Bit = BitArray<[u8; 32], Lsb0>; type BitSliceType = BitSlice; -const BIT_ARRAY_LEN: usize = 256; +pub const BIT_ARRAY_LEN: usize = 256; const BIT_ARRAY_MASK: usize = BIT_ARRAY_LEN - 1; const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize; @@ -50,11 +51,20 @@ const BIT_ARRAY_SHIFT: usize = BIT_ARRAY_LEN.ilog2() as usize; // A smaller value is used for tests, to enable covering some corner cases with a // reasonable runtime. Some of these tests use TARGET_PROOF_SIZE directly, so for tests // it does need to be a power of two. +// +// TARGET_PROOF_SIZE is closely related to MAX_PROOF_RECURSION; see the assertion that +// `uv_values.len() <= max_uv_values` in `ProofBatch` for more detail. #[cfg(test)] pub const TARGET_PROOF_SIZE: usize = 8192; #[cfg(not(test))] pub const TARGET_PROOF_SIZE: usize = 50_000_000; +/// Minimum proof recursion depth. +/// +/// This minimum avoids special cases in the implementation that would be otherwise +/// required when the initial and final recursion steps overlap. +pub const MIN_PROOF_RECURSION: usize = 2; + /// Maximum proof recursion depth. // // This is a hard limit. 
Each GF(2) multiply generates four G values and four H values, @@ -71,9 +81,7 @@ pub const TARGET_PROOF_SIZE: usize = 50_000_000; // Because the number of records in a proof batch is often rounded up to a power of two // (and less significantly, because multiplication intermediate storage gets rounded up // to blocks of 256), leaving some margin is advised. -// -// The implementation requires that MAX_PROOF_RECURSION is at least 2. -pub const MAX_PROOF_RECURSION: usize = 9; +pub const MAX_PROOF_RECURSION: usize = 14; /// `MultiplicationInputsBlock` is a block of fixed size of intermediate values /// that occur duringa multiplication. @@ -156,34 +164,21 @@ impl MultiplicationInputsBlock { Ok(()) } +} - /// `Convert` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the prover. - fn convert_prover(&self) -> Vec> { - DF::convert_prover( - &self.x_left, - &self.x_right, - &self.y_left, - &self.y_right, - &self.prss_right, - ) - } - - /// `convert_values_from_right_prover` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the verifier on the left. - fn convert_values_from_right_prover(&self) -> Vec { - DF::convert_value_from_right_prover( - &self.x_right, - &self.y_right, - &self.prss_right, - &self.z_right, - ) - } - - /// `convert_values_from_left_prover` allows to convert `MultiplicationInputs` into a format compatible with DZKPs - /// This is the convert function called by the verifier on the right. 
- fn convert_values_from_left_prover(&self) -> Vec { - DF::convert_value_from_left_prover(&self.x_left, &self.y_left, &self.prss_left) +#[cfg(any(test, feature = "enable-benches"))] +impl rand::prelude::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> MultiplicationInputsBlock { + let sample = >::sample; + MultiplicationInputsBlock { + x_left: sample(self, rng).into(), + x_right: sample(self, rng).into(), + y_left: sample(self, rng).into(), + y_right: sample(self, rng).into(), + prss_left: sample(self, rng).into(), + prss_right: sample(self, rng).into(), + z_right: sample(self, rng).into(), + } } } @@ -469,34 +464,30 @@ impl MultiplicationInputsBatch { } } - /// `get_field_values_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the prover of the DZKPs - fn get_field_values_prover( - &self, - ) -> impl Iterator> + Clone + '_ { + /// `get_field_values_prover` converts a `MultiplicationInputsBatch` into an + /// iterator over pairs of indices for `TABLE_U` and `TABLE_V`. + fn get_field_values_prover(&self) -> impl Iterator + Clone + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_prover) } - /// `get_field_values_from_right_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the verifier of the DZKPs on the left side of the prover, i.e. the `u` values. - fn get_field_values_from_right_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_right_prover` converts a `MultiplicationInputsBatch` into + /// an iterator over table indices for `TABLE_U`, which is used by the verifier of + /// the DZKPs on the left side of the prover. 
+ fn get_field_values_from_right_prover(&self) -> impl Iterator + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_values_from_right_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_from_right_prover) } - /// `get_field_values_from_left_prover` converts a `MultiplicationInputsBatch` into an iterator over `field` - /// values used by the verifier of the DZKPs on the right side of the prover, i.e. the `v` values. - fn get_field_values_from_left_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_left_prover` converts a `MultiplicationInputsBatch` into + /// an iterator over table indices for `TABLE_V`, which is used by the verifier of + /// the DZKPs on the right side of the prover. + fn get_field_values_from_left_prover(&self) -> impl Iterator + '_ { self.vec .iter() - .flat_map(MultiplicationInputsBlock::convert_values_from_left_prover::) + .flat_map(MultiplicationInputsBlock::table_indices_from_left_prover) } } @@ -562,36 +553,30 @@ impl Batch { .sum() } - /// `get_field_values_prover` converts a `Batch` into an iterator over field values - /// which is used by the prover of the DZKP - fn get_field_values_prover( - &self, - ) -> impl Iterator> + Clone + '_ { + /// `get_field_values_prover` converts a `Batch` into an iterator over pairs of + /// indices for `TABLE_U` and `TABLE_V`. + fn get_field_values_prover(&self) -> impl Iterator + Clone + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_prover) } - /// `get_field_values_from_right_prover` converts a `Batch` into an iterator over field values - /// which is used by the verifier of the DZKP on the left side of the prover. - /// This produces the `u` values. 
- fn get_field_values_from_right_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_right_prover` converts a `Batch` into an iterator over + /// table indices for `TABLE_U`, which is used by the verifier of the DZKP on the + /// left side of the prover. + fn get_field_values_from_right_prover(&self) -> impl Iterator + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_from_right_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_from_right_prover) } - /// `get_field_values_from_left_prover` converts a `Batch` into an iterator over field values - /// which is used by the verifier of the DZKP on the right side of the prover. - /// This produces the `v` values. - fn get_field_values_from_left_prover( - &self, - ) -> impl Iterator + '_ { + /// `get_field_values_from_left_prover` converts a `Batch` into an iterator over + /// table indices for `TABLE_V`, which is used by the verifier of the DZKP on the + /// right side of the prover. 
+ fn get_field_values_from_left_prover(&self) -> impl Iterator + '_ { self.inner .values() - .flat_map(MultiplicationInputsBatch::get_field_values_from_left_prover::) + .flat_map(MultiplicationInputsBatch::get_field_values_from_left_prover) } /// ## Panics @@ -601,8 +586,8 @@ impl Batch { ctx: Base<'_, B>, batch_index: usize, ) -> Result<(), Error> { - const PRSS_RECORDS_PER_BATCH: usize = LargeProofGenerator::PROOF_LENGTH - + (MAX_PROOF_RECURSION - 1) * SmallProofGenerator::PROOF_LENGTH + const PRSS_RECORDS_PER_BATCH: usize = FirstProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * CompressedProofGenerator::PROOF_LENGTH + 2; // P and Q masks let proof_ctx = ctx.narrow(&Step::GenerateProof); @@ -623,7 +608,11 @@ impl Batch { q_mask_from_left_prover, ) = { // generate BatchToVerify - ProofBatch::generate(&proof_ctx, prss_record_ids, self.get_field_values_prover()) + ProofBatch::generate( + &proof_ctx, + prss_record_ids, + ProverTableIndices(self.get_field_values_prover()), + ) }; let chunk_batch = BatchToVerify::generate_batch_to_verify( @@ -647,12 +636,7 @@ impl Batch { tracing::info!("validating {m} multiplications"); debug_assert_eq!( m, - self.get_field_values_prover::() - .flat_map(|(u_array, v_array)| { - u_array.into_iter().zip(v_array).map(|(u, v)| u * v) - }) - .count() - / 4, + self.get_field_values_prover().count(), "Number of multiplications is counted incorrectly" ); let sum_of_uv = Fp61BitPrime::truncate_from(u128::try_from(m).unwrap()) @@ -661,8 +645,14 @@ impl Batch { let (p_r_right_prover, q_r_left_prover) = chunk_batch.compute_p_and_q_r( &challenges_for_left_prover, &challenges_for_right_prover, - self.get_field_values_from_right_prover(), - self.get_field_values_from_left_prover(), + VerifierTableIndices { + input: self.get_field_values_from_right_prover(), + table: &TABLE_U, + }, + VerifierTableIndices { + input: self.get_field_values_from_left_prover(), + table: &TABLE_V, + }, ); (sum_of_uv, p_r_right_prover, q_r_left_prover) @@ -962,12 
+952,11 @@ mod tests { ff::{ boolean::Boolean, boolean_array::{BooleanArray, BA16, BA20, BA256, BA3, BA32, BA64, BA8}, - Fp61BitPrime, }, protocol::{ basics::{select, BooleanArrayMul, SecureMul}, context::{ - dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, + dzkp_field::DZKPCompatibleField, dzkp_validator::{ Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, }, @@ -1811,16 +1800,16 @@ mod tests { fn assert_batch_convert(batch_prover: &Batch, batch_left: &Batch, batch_right: &Batch) { batch_prover - .get_field_values_prover::() + .get_field_values_prover() .zip( batch_left - .get_field_values_from_right_prover::() - .collect::>(), + .get_field_values_from_right_prover() + .collect::>(), ) .zip( batch_right - .get_field_values_from_left_prover::() - .collect::>(), + .get_field_values_from_left_prover() + .collect::>(), ) .for_each(|((prover, verifier_left), verifier_right)| { assert_eq!(prover.0, verifier_left); diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 205e022d6..d21896190 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -28,6 +28,7 @@ pub type ShardedSemiHonestContext<'a> = semi_honest::Context<'a, Sharded>; pub type MaliciousContext<'a, B = NotSharded> = malicious::Context<'a, B>; pub type ShardedMaliciousContext<'a> = malicious::Context<'a, Sharded>; pub type UpgradedMaliciousContext<'a, F, B = NotSharded> = malicious::Upgraded<'a, F, B>; +pub type ShardedUpgradedMaliciousContext<'a, F, B = Sharded> = malicious::Upgraded<'a, F, B>; #[cfg(all(feature = "in-memory-infra", any(test, feature = "test-fixture")))] pub(crate) use malicious::TEST_DZKP_STEPS; diff --git a/ipa-core/src/protocol/hybrid/agg.rs b/ipa-core/src/protocol/hybrid/agg.rs new file mode 100644 index 000000000..85f59f58a --- /dev/null +++ b/ipa-core/src/protocol/hybrid/agg.rs @@ -0,0 +1,537 @@ +use std::collections::BTreeMap; + +use futures::{stream, StreamExt, TryStreamExt}; 
+ +use crate::{ + error::Error, + ff::{boolean::Boolean, boolean_array::BooleanArray, ArrayAccess}, + helpers::TotalRecords, + protocol::{ + boolean::step::EightBitStep, + context::{ + dzkp_validator::{validated_seq_join, DZKPValidator, TARGET_PROOF_SIZE}, + Context, DZKPUpgraded, MaliciousProtocolSteps, ShardedContext, UpgradableContext, + }, + hybrid::step::{AggregateReportsStep, HybridStep}, + ipa_prf::boolean_ops::addition_sequential::integer_add, + BooleanProtocols, + }, + report::hybrid::{AggregateableHybridReport, PrfHybridReport}, + secret_sharing::replicated::semi_honest::AdditiveShare as Replicated, +}; + +enum MatchEntry +where + BK: BooleanArray, + V: BooleanArray, +{ + Single(AggregateableHybridReport), + Pair( + AggregateableHybridReport, + AggregateableHybridReport, + ), + MoreThanTwo, +} + +impl MatchEntry +where + BK: BooleanArray, + V: BooleanArray, +{ + pub fn add_report(&mut self, new_report: AggregateableHybridReport) { + match self { + Self::Single(old_report) => { + *self = Self::Pair(old_report.clone(), new_report); + } + Self::Pair { .. } | Self::MoreThanTwo => *self = Self::MoreThanTwo, + } + } + + pub fn into_pair(self) -> Option<[AggregateableHybridReport; 2]> { + match self { + Self::Pair(r1, r2) => Some([r1, r2]), + _ => None, + } + } +} + +/// This function takes in a vector of `PrfHybridReports`, groups them by the oprf of the `match_key`, +/// and collects all pairs of reports with the same `match_key` into a vector of paris (as an array.) +/// +/// *Note*: Any `match_key` which appears once or more than twice is removed. +/// An honest report collector will only provide a single impression report per `match_key` and +/// an honest client will only provide a single conversion report per `match_key`. +/// Also note that a malicious client (intenional or bug) could provide exactly two conversions. +/// This would put the sum of conversion values into `breakdown_key` 0. 
As this is undetectable, +/// this makes `breakdown_key = 0` *unreliable*. +/// +/// Note: Possible Perf opportunity by removing the `collect()`. +/// See [#1443](https://github.com/private-attribution/ipa/issues/1443). +/// +/// *Note*: In order to add the pairs, the vector of pairs must be in the same order across all +/// three helpers. A standard `HashMap` uses system randomness for insertion placement, so we +/// use a `BTreeMap` to maintain consistent ordering across the helpers. +/// +fn group_report_pairs_ordered( + reports: Vec>, +) -> Vec<[AggregateableHybridReport; 2]> +where + BK: BooleanArray, + V: BooleanArray, +{ + let mut reports_by_matchkey: BTreeMap> = BTreeMap::new(); + + for report in reports { + reports_by_matchkey + .entry(report.match_key) + .and_modify(|e| e.add_report(report.clone().into())) + .or_insert(MatchEntry::Single(report.into())); + } + + // we only keep the reports from match_keys that provided exactly 2 reports + reports_by_matchkey + .into_values() + .filter_map(MatchEntry::into_pair) + .collect::>() +} + +/// This protocol is used to aggregate `PRFHybridReports` and returns `AggregateableHybridReports`. +/// It groups all the reports by the PRF of the `match_key`, finds all reports from `match_keys` +/// with that provided exactly 2 reports, then adds those 2 reports. +/// TODO (Performance opportunity): These additions are not currently vectorized. +/// We are currently deferring that work until the protocol is complete. 
+pub async fn aggregate_reports( + ctx: C, + reports: Vec>, +) -> Result>, Error> +where + C: UpgradableContext + ShardedContext, + BK: BooleanArray, + V: BooleanArray, + Replicated: BooleanProtocols>, +{ + let report_pairs = group_report_pairs_ordered(reports); + + let chunk_size: usize = TARGET_PROOF_SIZE / (BK::BITS as usize + V::BITS as usize); + + let ctx = ctx.set_total_records(TotalRecords::specified(report_pairs.len())?); + + let dzkp_validator = ctx.dzkp_validator( + MaliciousProtocolSteps { + protocol: &HybridStep::GroupBySum, + validate: &HybridStep::GroupBySumValidate, + }, + chunk_size.next_power_of_two(), + ); + + let agg_ctx = dzkp_validator.context(); + + let agg_work = stream::iter(report_pairs) + .enumerate() + .map(|(idx, reports)| { + let agg_ctx = agg_ctx.clone(); + async move { + let (breakdown_key, _) = integer_add::<_, EightBitStep, 1>( + agg_ctx.narrow(&AggregateReportsStep::AddBK), + idx.into(), + &reports[0].breakdown_key.to_bits(), + &reports[1].breakdown_key.to_bits(), + ) + .await?; + let (value, _) = integer_add::<_, EightBitStep, 1>( + agg_ctx.narrow(&AggregateReportsStep::AddV), + idx.into(), + &reports[0].value.to_bits(), + &reports[1].value.to_bits(), + ) + .await?; + Ok::<_, Error>(AggregateableHybridReport:: { + match_key: (), + breakdown_key: breakdown_key.collect_bits(), + value: value.collect_bits(), + }) + } + }); + + validated_seq_join(dzkp_validator, agg_work) + .try_collect() + .await +} + +#[cfg(all(test, unit_test))] +pub mod test { + use rand::Rng; + + use super::{aggregate_reports, group_report_pairs_ordered}; + use crate::{ + ff::{ + boolean_array::{BA3, BA8}, + U128Conversions, + }, + helpers::Role, + protocol::hybrid::step::AggregateReportsStep, + report::hybrid::{ + AggregateableHybridReport, IndistinguishableHybridReport, PrfHybridReport, + }, + secret_sharing::replicated::{ + semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing, + }, + sharding::{ShardConfiguration, ShardIndex}, + 
test_executor::{run, run_random}, + test_fixture::{ + hybrid::{TestAggregateableHybridReport, TestHybridRecord}, + Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards, + }, + }; + + // the inputs are laid out to work with exactly 2 shards + // as if it we're resharded by match_key/prf + const SHARDS: usize = 2; + + // we re-use these as the "prf" of the match_key + // to avoid needing to actually do the prf here + const SHARD1_MKS: [u64; 7] = [12345, 12345, 34567, 34567, 78901, 78901, 78901]; + const SHARD2_MKS: [u64; 7] = [23456, 23456, 45678, 56789, 67890, 67890, 67890]; + + fn get_records() -> Vec { + let shard1_records = [ + TestHybridRecord::TestImpression { + match_key: SHARD1_MKS[0], + breakdown_key: 45, + }, + TestHybridRecord::TestConversion { + match_key: SHARD1_MKS[1], + value: 1, + }, // attributed + TestHybridRecord::TestConversion { + match_key: SHARD1_MKS[2], + value: 3, + }, + TestHybridRecord::TestConversion { + match_key: SHARD1_MKS[3], + value: 4, + }, // not attibuted, but duplicated conversion. 
will land in breakdown_key 0 + TestHybridRecord::TestImpression { + match_key: SHARD1_MKS[4], + breakdown_key: 1, + }, // duplicated impression with same match_key + TestHybridRecord::TestImpression { + match_key: SHARD1_MKS[4], + breakdown_key: 2, + }, // duplicated impression with same match_key + TestHybridRecord::TestConversion { + match_key: SHARD1_MKS[5], + value: 7, + }, // removed + ]; + let shard2_records = [ + TestHybridRecord::TestImpression { + match_key: SHARD2_MKS[0], + breakdown_key: 56, + }, + TestHybridRecord::TestConversion { + match_key: SHARD2_MKS[1], + value: 2, + }, // attributed + TestHybridRecord::TestImpression { + match_key: SHARD2_MKS[2], + breakdown_key: 78, + }, // NOT attributed + TestHybridRecord::TestConversion { + match_key: SHARD2_MKS[3], + value: 5, + }, // NOT attributed + TestHybridRecord::TestImpression { + match_key: SHARD2_MKS[4], + breakdown_key: 90, + }, // attributed twice, removed + TestHybridRecord::TestConversion { + match_key: SHARD2_MKS[5], + value: 6, + }, // attributed twice, removed + TestHybridRecord::TestConversion { + match_key: SHARD2_MKS[6], + value: 7, + }, // attributed twice, removed + ]; + + shard1_records + .chunks(1) + .zip(shard2_records.chunks(1)) + .flat_map(|(a, b)| a.iter().chain(b)) + .cloned() + .collect() + } + + #[test] + fn group_reports_mpc() { + run(|| async { + let records = get_records(); + let expected = vec![ + [ + TestAggregateableHybridReport { + match_key: (), + value: 0, + breakdown_key: 45, + }, + TestAggregateableHybridReport { + match_key: (), + value: 1, + breakdown_key: 0, + }, + ], + [ + TestAggregateableHybridReport { + match_key: (), + value: 3, + breakdown_key: 0, + }, + TestAggregateableHybridReport { + match_key: (), + value: 4, + breakdown_key: 0, + }, + ], + [ + TestAggregateableHybridReport { + match_key: (), + value: 0, + breakdown_key: 56, + }, + TestAggregateableHybridReport { + match_key: (), + value: 2, + breakdown_key: 0, + }, + ], + ]; + + let world = 
TestWorld::>::with_shards(TestWorldConfig::default()); + #[allow(clippy::type_complexity)] + let results: Vec<[Vec<[AggregateableHybridReport; 2]>; 3]> = world + .malicious(records.clone().into_iter(), |ctx, input| { + let match_keys = match ctx.shard_id() { + ShardIndex(0) => SHARD1_MKS, + ShardIndex(1) => SHARD2_MKS, + _ => panic!("invalid shard_id"), + }; + async move { + let indistinguishable_reports: Vec< + IndistinguishableHybridReport, + > = input.iter().map(|r| r.clone().into()).collect::>(); + + let prf_reports: Vec> = indistinguishable_reports + .iter() + .zip(match_keys) + .map(|(indist_report, match_key)| PrfHybridReport { + match_key, + value: indist_report.value.clone(), + breakdown_key: indist_report.breakdown_key.clone(), + }) + .collect::>(); + group_report_pairs_ordered(prf_reports) + } + }) + .await; + + let results: Vec<[TestAggregateableHybridReport; 2]> = results + .into_iter() + .flat_map(|shard_result| { + shard_result[0] + .clone() + .into_iter() + .zip(shard_result[1].clone()) + .zip(shard_result[2].clone()) + .map(|((r1, r2), r3)| { + [ + [&r1[0], &r2[0], &r3[0]].reconstruct(), + [&r1[1], &r2[1], &r3[1]].reconstruct(), + ] + }) + .collect::>() + }) + .collect::>(); + + assert_eq!(results, expected); + }); + } + + #[test] + fn aggregate_reports_test() { + run(|| async { + let records = get_records(); + let expected = vec![ + TestAggregateableHybridReport { + match_key: (), + value: 1, + breakdown_key: 45, + }, + TestAggregateableHybridReport { + match_key: (), + value: 7, + breakdown_key: 0, + }, + TestAggregateableHybridReport { + match_key: (), + value: 2, + breakdown_key: 56, + }, + ]; + + let world = TestWorld::>::with_shards(TestWorldConfig::default()); + + let results: Vec<[Vec>; 3]> = world + .malicious(records.clone().into_iter(), |ctx, input| { + let match_keys = match ctx.shard_id() { + ShardIndex(0) => SHARD1_MKS, + ShardIndex(1) => SHARD2_MKS, + _ => panic!("invalid shard_id"), + }; + async move { + let 
indistinguishable_reports: Vec< + IndistinguishableHybridReport, + > = input.iter().map(|r| r.clone().into()).collect::>(); + + let prf_reports: Vec> = indistinguishable_reports + .iter() + .zip(match_keys) + .map(|(indist_report, match_key)| PrfHybridReport { + match_key, + value: indist_report.value.clone(), + breakdown_key: indist_report.breakdown_key.clone(), + }) + .collect::>(); + + aggregate_reports(ctx.clone(), prf_reports).await.unwrap() + } + }) + .await; + + let results: Vec = results + .into_iter() + .flat_map(|shard_result| { + shard_result[0] + .clone() + .into_iter() + .zip(shard_result[1].clone()) + .zip(shard_result[2].clone()) + .map(|((r1, r2), r3)| [&r1, &r2, &r3].reconstruct()) + .collect::>() + }) + .collect::>(); + + assert_eq!(results, expected); + }); + } + + fn build_prf_hybrid_report( + match_key: u64, + value: u8, + breakdown_key: u8, + ) -> PrfHybridReport { + PrfHybridReport:: { + match_key, + value: Replicated::new(BA3::truncate_from(value), BA3::truncate_from(0_u128)), + breakdown_key: Replicated::new( + BA8::truncate_from(breakdown_key), + BA8::truncate_from(0_u128), + ), + } + } + + fn build_aggregateable_report( + value: u8, + breakdown_key: u8, + ) -> AggregateableHybridReport { + AggregateableHybridReport:: { + match_key: (), + value: Replicated::new(BA3::truncate_from(value), BA3::truncate_from(0_u128)), + breakdown_key: Replicated::new( + BA8::truncate_from(breakdown_key), + BA8::truncate_from(0_u128), + ), + } + } + + #[test] + fn group_reports() { + let reports = vec![ + build_prf_hybrid_report(42, 2, 0), // pair: index (1,0) + build_prf_hybrid_report(42, 0, 3), // pair: index (1,1) + build_prf_hybrid_report(17, 4, 0), // pair: index (0,0) + build_prf_hybrid_report(17, 0, 13), // pair: index (0,1) + build_prf_hybrid_report(13, 0, 5), // single + build_prf_hybrid_report(11, 2, 0), // single + build_prf_hybrid_report(31, 1, 2), // triple + build_prf_hybrid_report(31, 3, 4), // triple + build_prf_hybrid_report(31, 5, 6), // 
triple + ]; + + let expected = vec![ + [ + build_aggregateable_report(4, 0), + build_aggregateable_report(0, 13), + ], + [ + build_aggregateable_report(2, 0), + build_aggregateable_report(0, 3), + ], + ]; + + let results = group_report_pairs_ordered(reports); + assert_eq!(results, expected); + } + + /// This test checks that the sharded malicious `aggregate_reports` fails + /// under a simple bit flip attack by H1. + #[test] + #[should_panic(expected = "DZKPValidationFailed")] + fn sharded_fail_under_bit_flip_attack_on_breakdown_key() { + use crate::helpers::in_memory_config::MaliciousHelper; + run_random(|mut rng| async move { + let target_shard = ShardIndex::from(rng.gen_range(0..u32::try_from(SHARDS).unwrap())); + let mut config = TestWorldConfig::default(); + + let step = format!("{}/{}", AggregateReportsStep::AddBK.as_ref(), "bit0",); + config.stream_interceptor = + MaliciousHelper::new(Role::H2, config.role_assignment(), move |ctx, data| { + // flip a bit of the match_key on the target shard, H1 + if ctx.gate.as_ref().contains(&step) + && ctx.dest == Role::H1 + && ctx.shard == Some(target_shard) + { + data[0] ^= 1u8; + } + }); + + let world = TestWorld::>::with_shards(config); + let records = get_records(); + let _results: Vec<[Vec>; 3]> = world + .malicious(records.clone().into_iter(), |ctx, input| { + let match_keys = match ctx.shard_id() { + ShardIndex(0) => SHARD1_MKS, + ShardIndex(1) => SHARD2_MKS, + _ => panic!("invalid shard_id"), + }; + async move { + let indistinguishable_reports: Vec< + IndistinguishableHybridReport, + > = input.iter().map(|r| r.clone().into()).collect::>(); + + let prf_reports: Vec> = indistinguishable_reports + .iter() + .zip(match_keys) + .map(|(indist_report, match_key)| PrfHybridReport { + match_key, + value: indist_report.value.clone(), + breakdown_key: indist_report.breakdown_key.clone(), + }) + .collect::>(); + + aggregate_reports(ctx.clone(), prf_reports).await.unwrap() + } + }) + .await; + }); + } +} diff --git 
a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs new file mode 100644 index 000000000..cb9e4caf6 --- /dev/null +++ b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs @@ -0,0 +1,391 @@ +use std::{convert::Infallible, pin::pin}; + +use futures::stream; +use futures_util::{StreamExt, TryStreamExt}; +use tracing::{info_span, Instrument}; + +use crate::{ + error::{Error, UnwrapInfallible}, + ff::{boolean::Boolean, boolean_array::BooleanArray, U128Conversions}, + helpers::TotalRecords, + protocol::{ + basics::{reveal, Reveal}, + context::{ + dzkp_validator::DZKPValidator, Context, DZKPUpgraded, MaliciousProtocolSteps, + ShardedContext, UpgradableContext, + }, + ipa_prf::{ + aggregation::{ + aggregate_values, aggregate_values_proof_chunk, step::AggregationStep as Step, + AGGREGATE_DEPTH, + }, + oprf_padding::{apply_dp_padding, PaddingParameters}, + shuffle::Shuffle, + }, + BooleanProtocols, RecordId, + }, + report::hybrid::AggregateableHybridReport, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, + TransposeFrom, Vectorizable, + }, + seq_join::seq_join, +}; + +/// Improved Aggregation a.k.a Aggregation revealing breakdown. +/// +/// Aggregation steps happen after attribution. the input for Aggregation is a +/// list of tuples containing Trigger Values (TV) and their corresponding +/// Breakdown Keys (BK), which were attributed in the previous step of IPA. The +/// output of Aggregation is a histogram, where each “bin” or "bucket" is a BK +/// and the value is the addition of all the TVs for it, hence the name +/// Aggregation. This can be thought as a SQL GROUP BY operation. +/// +/// The protocol involves four main steps: +/// 1. Shuffle the data to protect privacy (see [`shuffle_attributions`]). +/// 2. Reveal breakdown keys. This is the key difference to the previous +/// aggregation (see [`reveal_breakdowns`]). +/// 3. Add all values for each breakdown. 
+/// +/// This protocol explicitly manages proof batches for DZKP-based malicious security by +/// processing chunks of values from `intermediate_results.chunks()`. Procession +/// through record IDs is not uniform for all of the gates in the protocol. The first +/// layer of the reduction adds N pairs of records, the second layer adds N/2 pairs of +/// records, etc. This has a few consequences: +/// * We must specify a batch size of `usize::MAX` when calling `dzkp_validator`. +/// * We must track record IDs across chunks, so that subsequent chunks can +/// start from the last record ID that was used in the previous chunk. +/// * Because the first record ID in the proof batch is set implicitly, we must +/// guarantee that it submits multiplication intermediates before any other +/// record. This is currently ensured by the serial operation of the aggregation +/// protocol (i.e. by not using `seq_join`). +#[tracing::instrument(name = "breakdown_reveal_aggregation", skip_all, fields(total = attributed_values.len()))] +pub async fn breakdown_reveal_aggregation( + ctx: C, + attributed_values: Vec>, + padding_params: &PaddingParameters, +) -> Result>, Error> +where + C: UpgradableContext + Shuffle + ShardedContext, + Boolean: FieldSimd, + Replicated: BooleanProtocols, B>, + BK: BooleanArray + U128Conversions, + Replicated: Reveal, Output = >::Array>, + V: BooleanArray + U128Conversions, + HV: BooleanArray + U128Conversions, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, +{ + // Apply DP padding for Breakdown Reveal Aggregation + let attributed_values_padded = apply_dp_padding::<_, AggregateableHybridReport, B>( + ctx.narrow(&Step::PaddingDp), + attributed_values, + padding_params, + ) + .await?; + + let attributions = ctx + .narrow(&Step::Shuffle) + .shuffle(attributed_values_padded) + .instrument(info_span!("shuffle_attribution_outputs")) + .await?; + + // Revealing the breakdowns doesn't do any multiplies, so won't make it as far 
as + // doing a proof, but we need the validator to obtain an upgraded malicious context. + let validator = ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::Reveal, + validate: &Step::RevealValidate, + }, + usize::MAX, + ); + let grouped_tvs = reveal_breakdowns(&validator.context(), attributions).await?; + validator.validate().await?; + let mut intermediate_results: Vec>> = grouped_tvs.into(); + + // Any real-world aggregation should be able to complete in two layers (two + // iterations of the `while` loop below). Tests with small `TARGET_PROOF_SIZE` + // may exceed that. + let mut chunk_counter = 0; + let mut depth = 0; + let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(V::BITS).unwrap()); + + while intermediate_results.len() > 1 { + let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; + let mut next_intermediate_results = Vec::new(); + for chunk in intermediate_results.chunks(agg_proof_chunk) { + let chunk_len = chunk.len(); + let validator = ctx.clone().dzkp_validator( + MaliciousProtocolSteps { + protocol: &Step::aggregate(depth), + validate: &Step::AggregateValidate, + }, + usize::MAX, // See note about batching above. + ); + let result = aggregate_values::<_, HV, B>( + validator.context(), + stream::iter(chunk).map(|v| Ok(v.clone())).boxed(), + chunk_len, + Some(&mut record_ids), + ) + .await?; + validator.validate_indexed(chunk_counter).await?; + chunk_counter += 1; + next_intermediate_results.push(result); + } + depth += 1; + intermediate_results = next_intermediate_results; + } + + Ok(intermediate_results + .into_iter() + .next() + .expect("aggregation input must not be empty")) +} + +/// Transforms the Breakdown key from a secret share into a revealed `usize`. +/// The input are the Atrributions and the output is a list of lists of secret +/// shared Trigger Values. 
Since Breakdown Keys are assumed to be dense the +/// first list contains all the possible Breakdowns, the index in the list +/// representing the Breakdown value. The second list groups all the Trigger +/// Values for that particular Breakdown. +#[tracing::instrument(name = "reveal_breakdowns", skip_all, fields( + total = attributions.len(), +))] +async fn reveal_breakdowns( + parent_ctx: &C, + attributions: Vec>, +) -> Result, Error> +where + C: Context, + Replicated: BooleanProtocols, + Boolean: FieldSimd, + BK: BooleanArray + U128Conversions, + Replicated: Reveal>::Array>, + V: BooleanArray + U128Conversions, +{ + let reveal_ctx = parent_ctx.set_total_records(TotalRecords::specified(attributions.len())?); + + let reveal_work = stream::iter(attributions).enumerate().map(|(i, report)| { + let record_id = RecordId::from(i); + let reveal_ctx = reveal_ctx.clone(); + async move { + let revealed_bk = reveal(reveal_ctx, record_id, &report.breakdown_key).await?; + let revealed_bk = BK::from_array(&revealed_bk); + let Ok(bk) = usize::try_from(revealed_bk.as_u128()) else { + return Err(Error::Internal); + }; + Ok::<_, Error>((bk, report.value)) + } + }); + let mut grouped_tvs = ValueHistogram::::new(); + let mut stream = pin!(seq_join(reveal_ctx.active_work(), reveal_work)); + while let Some((bk, tv)) = stream.try_next().await? { + grouped_tvs.push(bk, tv); + } + + Ok(grouped_tvs) +} + +/// Helper type that hold all the Trigger Values, grouped by their Breakdown +/// Key. The main functionality is to turn into a stream that can be given to +/// [`aggregate_values`]. 
+struct ValueHistogram { + tvs: [Vec>; B], + max_len: usize, +} + +impl ValueHistogram { + fn new() -> Self { + Self { + tvs: std::array::from_fn(|_| vec![]), + max_len: 0, + } + } + + fn push(&mut self, bk: usize, value: Replicated) { + self.tvs[bk].push(value); + if self.tvs[bk].len() > self.max_len { + self.max_len = self.tvs[bk].len(); + } + } +} + +impl From> + for Vec>> +where + Boolean: FieldSimd, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, +{ + fn from(mut grouped_tvs: ValueHistogram) -> Vec>> { + let iter = (0..grouped_tvs.max_len).map(move |_| { + let slice: [Replicated; B] = grouped_tvs + .tvs + .each_mut() + .map(|tv| tv.pop().unwrap_or(Replicated::ZERO)); + + BitDecomposed::transposed_from(&slice).unwrap_infallible() + }); + iter.collect() + } +} + +#[cfg(all(test, any(unit_test, feature = "shuttle")))] +pub mod tests { + use futures::TryFutureExt; + use rand::seq::SliceRandom; + + #[cfg(not(feature = "shuttle"))] + use crate::{ff::boolean_array::BA16, test_executor::run}; + use crate::{ + ff::{ + boolean::Boolean, + boolean_array::{BA3, BA5, BA8}, + U128Conversions, + }, + protocol::{ + hybrid::breakdown_reveal_aggregation, ipa_prf::oprf_padding::PaddingParameters, + }, + rand::Rng, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, + }, + test_executor::run_with, + test_fixture::{ + hybrid::TestAggregateableHybridReport, Reconstruct, Runner, TestWorld, TestWorldConfig, + WithShards, + }, + }; + + fn input_row(breakdown_key: usize, value: u128) -> TestAggregateableHybridReport { + TestAggregateableHybridReport { + match_key: (), + value: value.try_into().unwrap(), + breakdown_key: breakdown_key.try_into().unwrap(), + } + } + + #[test] + fn breakdown_reveal_semi_honest_happy_path() { + // if shuttle executor is enabled, run this test only once. 
+ // it is a very expensive test to explore all possible states, + // sometimes github bails after 40 minutes of running it + // (workers there are really slow). + type HV = BA8; + const SHARDS: usize = 2; + run_with::<_, _, 3>(|| async { + let world = TestWorld::>::with_shards(TestWorldConfig::default()); + let mut rng = world.rng(); + let mut expectation = Vec::new(); + for _ in 0..32 { + expectation.push(rng.gen_range(0u128..256)); + } + let expectation = expectation; // no more mutability for safety + let mut inputs = Vec::new(); + for (bk, expected_hv) in expectation.iter().enumerate() { + let mut remainder = *expected_hv; + while remainder > 7 { + let tv = rng.gen_range(0u128..8); + remainder -= tv; + inputs.push(input_row(bk, tv)); + } + inputs.push(input_row(bk, remainder)); + } + inputs.shuffle(&mut rng); + let result: Vec<_> = world + .semi_honest(inputs.into_iter(), |ctx, reports| async move { + breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( + ctx, + reports, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap() + }) + .await + .reconstruct(); + let initial = vec![0_u128; 32]; + let result = result + .iter() + .fold(initial, |mut acc, vec: &Vec| { + acc.iter_mut() + .zip(vec) + .for_each(|(a, &b)| *a += b.as_u128()); + acc + }) + .into_iter() + .collect::>(); + + assert_eq!(32, result.len()); + assert_eq!(result, expectation); + }); + } + + #[test] + #[cfg(not(feature = "shuttle"))] // too slow + fn breakdown_reveal_malicious_happy_path() { + type HV = BA16; + const SHARDS: usize = 2; + run(|| async { + let world = TestWorld::>::with_shards(TestWorldConfig::default()); + let mut rng = world.rng(); + let mut expectation = Vec::new(); + for _ in 0..32 { + expectation.push(rng.gen_range(0u128..512)); + } + // The size of input needed here to get complete coverage (more precisely, + // the size of input to the final aggregation using `aggregate_values`) + // depends on 
`TARGET_PROOF_SIZE`. + let expectation = expectation; // no more mutability for safety + let mut inputs = Vec::new(); + // Builds out inputs with values for each breakdown_key that add up to + // the expectation. Expectation is ranomg (0..512). Each iteration + // generates a value (0..8) and subtracts from the expectation until a final + // remaninder in (0..8) remains to be added to the vec. + for (breakdown_key, expected_value) in expectation.iter().enumerate() { + let mut remainder = *expected_value; + while remainder > 7 { + let value = rng.gen_range(0u128..8); + remainder -= value; + inputs.push(input_row(breakdown_key, value)); + } + inputs.push(input_row(breakdown_key, remainder)); + } + inputs.shuffle(&mut rng); + + let result: Vec<_> = world + .malicious(inputs.into_iter(), |ctx, reports| async move { + breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( + ctx, + reports, + &PaddingParameters::relaxed(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap() + }) + .await + .reconstruct(); + + let initial = vec![0_u128; 32]; + let result = result + .iter() + .fold(initial, |mut acc, vec: &Vec| { + acc.iter_mut() + .zip(vec) + .for_each(|(a, &b)| *a += b.as_u128()); + acc + }) + .into_iter() + .collect::>(); + assert_eq!(32, result.len()); + assert_eq!(result, expectation); + }); + } +} diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs index 5aa14ed1c..d9f43f7dd 100644 --- a/ipa-core/src/protocol/hybrid/mod.rs +++ b/ipa-core/src/protocol/hybrid/mod.rs @@ -1,6 +1,10 @@ +pub(crate) mod agg; +pub(crate) mod breakdown_reveal; pub(crate) mod oprf; pub(crate) mod step; +use std::convert::Infallible; + use tracing::{info_span, Instrument}; use crate::{ @@ -11,9 +15,11 @@ use crate::{ }, helpers::query::DpMechanism, protocol::{ - basics::{BooleanProtocols, Reveal}, + basics::{BooleanArrayMul, Reveal}, context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, hybrid::{ + 
agg::aggregate_reports, + breakdown_reveal::breakdown_reveal_aggregation, oprf::{compute_prf_and_reshard, BreakdownKey, CONV_CHUNK, PRF_CHUNK}, step::HybridStep as Step, }, @@ -23,9 +29,13 @@ use crate::{ shuffle::Shuffle, }, prss::FromPrss, + BooleanProtocols, }, report::hybrid::{IndistinguishableHybridReport, PrfHybridReport}, - secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Vectorizable}, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, + TransposeFrom, Vectorizable, + }, }; /// The Hybrid Protocol @@ -64,12 +74,19 @@ where BK: BreakdownKey, V: BooleanArray + U128Conversions, HV: BooleanArray + U128Conversions, + Boolean: FieldSimd, Replicated: BooleanProtocols, CONV_CHUNK>, Replicated: PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, Replicated: Reveal, Output = >::Array>, PrfHybridReport: Serializable, + Replicated: BooleanProtocols>, + Replicated: BooleanProtocols, B>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, { if input_rows.is_empty() { return Ok(vec![Replicated::ZERO; B]); @@ -89,7 +106,15 @@ where .instrument(info_span!("shuffle_inputs")) .await?; - let _sharded_reports = compute_prf_and_reshard(ctx.clone(), shuffled_input_rows).await?; + let sharded_reports = compute_prf_and_reshard(ctx.clone(), shuffled_input_rows).await?; + + let aggregated_reports = aggregate_reports::(ctx.clone(), sharded_reports).await?; + + let _historgram = breakdown_reveal_aggregation::( + ctx.clone(), + aggregated_reports, + &dp_padding_params, + ); unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented") } diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index adc1da9ff..da2bf903f 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -1,3 +1,5 @@ +use std::cmp::max; + use futures::{stream, 
StreamExt, TryStreamExt}; use typenum::Const; @@ -17,8 +19,10 @@ use crate::{ protocol::{ basics::{BooleanProtocols, Reveal}, context::{ - dzkp_validator::DZKPValidator, reshard_try_stream, DZKPUpgraded, MacUpgraded, - MaliciousProtocolSteps, ShardedContext, UpgradableContext, Validator, + dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, + reshard_try_stream, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, ShardedContext, + ShardedUpgradedMaliciousContext, UpgradableContext, UpgradedMaliciousContext, + Validator, }, hybrid::step::HybridStep, ipa_prf::{ @@ -26,15 +30,17 @@ use crate::{ prf_eval::{eval_dy_prf, PrfSharing}, }, prss::{FromPrss, SharedRandomness}, - RecordId, + BasicProtocols, RecordId, }, report::hybrid::{IndistinguishableHybridReport, PrfHybridReport}, secret_sharing::{ - replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, - Vectorizable, + replicated::{malicious, semi_honest::AdditiveShare as Replicated}, + BitDecomposed, FieldSimd, TransposeFrom, Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; + // In theory, we could support (runtime-configured breakdown count) ≤ (compile-time breakdown count) // ≤ 2^|bk|, with all three values distinct, but at present, there is no runtime configuration and // the latter two must be equal. The implementation of `move_single_value_to_bucket` does support a @@ -64,13 +70,31 @@ pub const CONV_CHUNK: usize = 256; /// Vectorization dimension for PRF pub const PRF_CHUNK: usize = 16; -// We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor -// is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of -// multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a -// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so -// we pick the closest power of two close to 381 but less than that value. 
256 gives us around 33M -// multiplications per batch -const CONV_PROOF_CHUNK: usize = 256; +/// Returns a suitable proof chunk size (in records) for use with `convert_to_fp25519`. +/// +/// We expect 2*256 = 512 gates in total for two additions per conversion. The +/// vectorization factor is `CONV_CHUNK`. Let `len` equal the number of converted +/// shares. The total amount of multiplications is `CONV_CHUNK`*512*len. We want +/// `CONV_CHUNK`*512*len ≈ 50M for a reasonably-sized proof. There is also a constraint +/// on proof chunks to be powers of two, and we don't want to compute a proof chunk +/// of zero when `TARGET_PROOF_SIZE` is smaller for tests. +fn conv_proof_chunk() -> usize { + non_zero_prev_power_of_two(max(2, TARGET_PROOF_SIZE / CONV_CHUNK / 512)) +} + +/// Allow MAC-malicious shares to be used for PRF generation with shards +impl<'a, const N: usize> PrfSharing, N> + for Replicated +where + Fp25519: FieldSimd, + RP25519: Vectorizable, + malicious::AdditiveShare: + BasicProtocols, Fp25519, N>, + Replicated: FromPrss, +{ + type Field = Fp25519; + type UpgradedSharing = malicious::AdditiveShare; +} /// This computes the Dodis-Yampolsky PRF value on every match key from input, /// and reshards the reports according to the computed PRF. 
At the end, reports with the @@ -101,7 +125,7 @@ where protocol: &HybridStep::ConvertFp25519, validate: &HybridStep::ConvertFp25519Validate, }, - CONV_PROOF_CHUNK, + conv_proof_chunk(), ); let m_ctx = validator.context(); @@ -224,9 +248,8 @@ mod test { }, ]; - // TODO: we need to use malicious circuits here let reports_per_shard = world - .semi_honest(records.clone().into_iter(), |ctx, reports| async move { + .malicious(records.clone().into_iter(), |ctx, reports| async move { let ind_reports = reports .into_iter() .map(IndistinguishableHybridReport::from) diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs index b2b516416..b4022551d 100644 --- a/ipa-core/src/protocol/hybrid/step.rs +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -15,6 +15,16 @@ pub(crate) enum HybridStep { #[step(child = crate::protocol::context::step::MaliciousProtocolStep)] EvalPrf, ReshardByPrf, - Finalize, - FinalizeValidate, + #[step(child = AggregateReportsStep)] + GroupBySum, + #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] + GroupBySumValidate, +} + +#[derive(CompactStep)] +pub(crate) enum AggregateReportsStep { + #[step(child = crate::protocol::boolean::step::EightBitStep)] + AddBK, + #[step(child = crate::protocol::boolean::step::EightBitStep)] + AddV, } diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index d3e78fd8c..6a9adb345 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -25,6 +25,7 @@ use crate::{ replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, + utils::non_zero_prev_power_of_two, }; pub(crate) mod breakdown_reveal; @@ -96,8 +97,14 @@ pub type AggResult = Result /// saturating the output) is: /// /// $\sum_{i = 1}^k 2^{k - i} (b + i - 1) \approx 2^k (b + 1) = N (b + 1)$ +/// +/// We set a floor of 2 to avoid 
computing a chunk of zero when `TARGET_PROOF_SIZE` is +/// smaller for tests. pub fn aggregate_values_proof_chunk(input_width: usize, input_item_bits: usize) -> usize { - max(2, TARGET_PROOF_SIZE / input_width / (input_item_bits + 1)).next_power_of_two() + non_zero_prev_power_of_two(max( + 2, + TARGET_PROOF_SIZE / input_width / (input_item_bits + 1), + )) } // This is the step count for AggregateChunkStep. We need it to size RecordId arrays. diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index 2dabdc3f4..b98fe9612 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -381,7 +381,7 @@ mod tests { helpers::stream::process_slice_by_chunks, protocol::{ context::{dzkp_validator::DZKPValidator, UpgradableContext, TEST_DZKP_STEPS}, - ipa_prf::{CONV_CHUNK, CONV_PROOF_CHUNK, PRF_CHUNK}, + ipa_prf::{conv_proof_chunk, CONV_CHUNK, PRF_CHUNK}, }, rand::thread_rng, secret_sharing::SharedValue, @@ -415,7 +415,7 @@ mod tests { let [res0, res1, res2] = world .semi_honest(records.into_iter(), |ctx, records| async move { let c_ctx = ctx.set_total_records((COUNT + CONV_CHUNK - 1) / CONV_CHUNK); - let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, CONV_PROOF_CHUNK); + let validator = &c_ctx.dzkp_validator(TEST_DZKP_STEPS, conv_proof_chunk()); let m_ctx = validator.context(); seq_join( m_ctx.active_work(), diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs index 7f805040f..1f598cb31 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/lagrange.rs @@ -2,7 +2,9 @@ use std::{array::from_fn, fmt::Debug}; use typenum::Unsigned; -use crate::ff::{batch_invert, Field, PrimeField, Serializable}; +use crate::ff::{ + batch_invert, Field, 
MultiplyAccumulate, MultiplyAccumulator, PrimeField, Serializable, +}; /// The Canonical Lagrange denominator is defined as the denominator of the Lagrange base polynomials /// `https://en.wikipedia.org/wiki/Lagrange_polynomial` @@ -84,7 +86,7 @@ where /// The "x coordinate" of the output point is `x_output`. pub fn new(denominator: &CanonicalLagrangeDenominator, x_output: &F) -> Self { // assertion that table is not too large for the stack - assert!(::Size::USIZE * N < 2024); + debug_assert!(::Size::USIZE * N < 2024); let table = Self::compute_table_row(x_output, denominator); LagrangeTable:: { table: [table; 1] } @@ -162,23 +164,14 @@ where /// Computes the dot product of two arrays of the same size. /// It is isolated from Lagrange because there could be potential SIMD optimizations used fn dot_product(a: &[F; N], b: &[F; N]) -> F { - // Staying in integers allows rustc to optimize this code properly, but puts a restriction - // on how large the prime field can be - debug_assert!( - 2 * F::BITS + N.next_power_of_two().ilog2() <= 128, - "The prime field {} is too large for this dot product implementation", - F::PRIME.into() - ); - - let mut sum = 0; - // I am cautious about using zip in hot code // https://github.com/rust-lang/rust/issues/103555 + + let mut acc = ::Accumulator::new(); for i in 0..N { - sum += a[i].as_u128() * b[i].as_u128(); + acc.multiply_accumulate(a[i], b[i]); } - - F::truncate_from(sum) + acc.take() } #[cfg(all(test, unit_test))] diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs index 607827b1a..9d366735b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/mod.rs @@ -1,3 +1,11 @@ pub mod lagrange; pub mod prover; pub mod verifier; + +pub type FirstProofGenerator = prover::SmallProofGenerator; +pub type CompressedProofGenerator = prover::SmallProofGenerator; +pub use 
lagrange::{CanonicalLagrangeDenominator, LagrangeTable}; +pub use prover::ProverTableIndices; +pub use verifier::VerifierTableIndices; + +pub const FIRST_RECURSION_FACTOR: usize = FirstProofGenerator::RECURSION_FACTOR; diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs index af451b458..aa1b25e9b 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/prover.rs @@ -1,17 +1,25 @@ -use std::{borrow::Borrow, iter::zip, marker::PhantomData}; +use std::{array, borrow::Borrow, marker::PhantomData}; -#[cfg(all(test, unit_test))] -use crate::ff::Fp31; use crate::{ error::Error::{self, DZKPMasks}, - ff::{Fp61BitPrime, PrimeField}, + ff::{Fp61BitPrime, MultiplyAccumulate, MultiplyAccumulatorArray, PrimeField}, helpers::hashing::{compute_hash, hash_to_field}, protocol::{ - context::Context, - ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + context::{ + dzkp_field::{TABLE_U, TABLE_V}, + Context, + }, + ipa_prf::{ + malicious_security::{ + lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + FIRST_RECURSION_FACTOR as FRF, + }, + CompressedProofGenerator, + }, prss::SharedRandomness, RecordId, RecordIdRange, }, + secret_sharing::SharedValue, }; /// This struct stores intermediate `uv` values. @@ -84,8 +92,8 @@ where // compute final uv values let (u_values, v_values) = &mut self.uv_chunks[0]; // shift first element to last position - u_values[SmallProofGenerator::RECURSION_FACTOR - 1] = u_values[0]; - v_values[SmallProofGenerator::RECURSION_FACTOR - 1] = v_values[0]; + u_values[CompressedProofGenerator::RECURSION_FACTOR - 1] = u_values[0]; + v_values[CompressedProofGenerator::RECURSION_FACTOR - 1] = v_values[0]; // set masks in first position u_values[0] = my_p_mask; v_values[0] = my_q_mask; @@ -94,6 +102,213 @@ where } } +/// Trait for inputs to Lagrange interpolation on the prover. 
+/// +/// Lagrange interpolation is used in the prover in two ways: +/// +/// 1. To extrapolate an additional λ - 1 y-values of degree-(λ-1) polynomials so that the +/// total 2λ - 1 y-values can be multiplied to obtain a representation of the product of +/// the polynomials. +/// 2. To evaluate polynomials at the randomly chosen challenge point _r_. +/// +/// The two methods in this trait correspond to those two uses. +/// +/// There are two implementations of this trait: `ProverTableIndices`, and +/// `ProverValues`. `ProverTableIndices` is used for the input to the first proof. Each +/// set of 4 _u_ or _v_ values input to the first proof has one of eight possible +/// values, determined by the values of the 3 associated multiplication intermediates. +/// The `ProverTableIndices` implementation uses a lookup table containing the output of +/// the Lagrange interpolation for each of these eight possible values. The +/// `ProverValues` implementation, which represents actual _u_ and _v_ values, is used +/// by the remaining recursive proofs. +/// +/// There is a similar trait `VerifierLagrangeInput` in `verifier.rs`. The difference is +/// that the prover operates on _u_ and _v_ values simultaneously (i.e. iterators of +/// tuples). The verifier operates on only one of _u_ or _v_ at a time. +pub trait ProverLagrangeInput { + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; +} + +/// Implementation of `ProverLagrangeInput` for table indices, used for the first proof. +#[derive(Clone)] +pub struct ProverTableIndices>(pub I); + +/// Iterator returned by `ProverTableIndices::extrapolate_y_values` and +/// `ProverTableIndices::eval_at_r`. 
+struct TableIndicesIterator> { + input: I, + u_table: [T; 8], + v_table: [T; 8], +} + +impl> ProverLagrangeInput + for ProverTableIndices +{ + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.0, + u_table: array::from_fn(|i| { + let mut result = [Fp61BitPrime::ZERO; P]; + let u = &TABLE_U[i]; + result[0..FRF].copy_from_slice(u); + result[FRF..].copy_from_slice(&lagrange_table.eval(u)); + result + }), + v_table: array::from_fn(|i| { + let mut result = [Fp61BitPrime::ZERO; P]; + let v = &TABLE_V[i]; + result[0..FRF].copy_from_slice(v); + result[FRF..].copy_from_slice(&lagrange_table.eval(v)); + result + }), + } + } + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.0, + u_table: array::from_fn(|i| lagrange_table.eval(&TABLE_U[i])[0]), + v_table: array::from_fn(|i| lagrange_table.eval(&TABLE_V[i])[0]), + } + } +} + +impl> Iterator for TableIndicesIterator { + type Item = (T, T); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_index, v_index)| { + ( + self.u_table[usize::from(u_index)].clone(), + self.v_table[usize::from(v_index)].clone(), + ) + }) + } +} + +/// Implementation of `ProverLagrangeInput` for _u_ and _v_ values, used for subsequent +/// recursive proofs. +#[derive(Clone)] +pub struct ProverValues>(pub I); + +/// Iterator returned by `ProverValues::extrapolate_y_values`. 
+struct ValuesExtrapolateIterator< + 'a, + F: PrimeField, + const L: usize, + const P: usize, + const M: usize, + I: Iterator, +> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl< + 'a, + F: PrimeField, + const L: usize, + const P: usize, + const M: usize, + I: Iterator, + > Iterator for ValuesExtrapolateIterator<'a, F, L, P, M, I> +{ + type Item = ([F; P], [F; P]); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_values, v_values)| { + let mut u = [F::ZERO; P]; + u[0..L].copy_from_slice(&u_values); + u[L..].copy_from_slice(&self.lagrange_table.eval(&u_values)); + let mut v = [F::ZERO; P]; + v[0..L].copy_from_slice(&v_values); + v[L..].copy_from_slice(&self.lagrange_table.eval(&v_values)); + (u, v) + }) + } +} + +/// Iterator returned by `ProverValues::eval_at_r`. +struct ValuesEvalAtRIterator< + 'a, + F: PrimeField, + const L: usize, + I: Iterator, +> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl<'a, F: PrimeField, const L: usize, I: Iterator> Iterator + for ValuesEvalAtRIterator<'a, F, L, I> +{ + type Item = (F, F); + + fn next(&mut self) -> Option { + self.input.next().map(|(u_values, v_values)| { + ( + self.lagrange_table.eval(&u_values)[0], + self.lagrange_table.eval(&v_values)[0], + ) + }) + } +} + +impl> ProverLagrangeInput + for ProverValues +{ + fn extrapolate_y_values<'a, const P: usize, const M: usize>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + debug_assert_eq!(L + M, P); + ValuesExtrapolateIterator { + input: self.0, + lagrange_table, + } + } + + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + ValuesEvalAtRIterator { + input: self.0, + lagrange_table, + } + } +} + /// This struct sets up the parameter for the proof generation /// and provides several functions to generate zero knowledge proofs. 
/// @@ -105,15 +320,11 @@ pub struct ProofGenerator, } -#[cfg(all(test, unit_test))] -pub type TestProofGenerator = ProofGenerator; - // Compression Factor is L // P, Proof size is 2*L - 1 // M, the number of interpolated points is L - 1 // The reason we need these is that Rust doesn't support basic math operations on const generics -pub type SmallProofGenerator = ProofGenerator; -pub type LargeProofGenerator = ProofGenerator; +pub type SmallProofGenerator = ProofGenerator; impl ProofGenerator { // define constants such that they can be used externally @@ -122,40 +333,47 @@ impl ProofGenerat pub const PROOF_LENGTH: usize = P; pub const LAGRANGE_LENGTH: usize = M; + pub fn compute_proof_from_uv(uv: J, lagrange_table: &LagrangeTable) -> [F; P] + where + J: Iterator, + J::Item: Borrow<([F; L], [F; L])>, + { + Self::compute_proof(uv.map(|uv| { + let (u, v) = uv.borrow(); + let mut u_ex = [F::ZERO; P]; + let mut v_ex = [F::ZERO; P]; + u_ex[0..L].copy_from_slice(u); + v_ex[0..L].copy_from_slice(v); + u_ex[L..].copy_from_slice(&lagrange_table.eval(u)); + v_ex[L..].copy_from_slice(&lagrange_table.eval(v)); + (u_ex, v_ex) + })) + } + /// /// Distributed Zero Knowledge Proofs algorithm drawn from /// `https://eprint.iacr.org/2023/909.pdf` - pub fn compute_proof(uv_iterator: J, lagrange_table: &LagrangeTable) -> [F; P] + pub fn compute_proof(pq_iterator: J) -> [F; P] where J: Iterator, - J::Item: Borrow<([F; L], [F; L])>, + J::Item: Borrow<([F; P], [F; P])>, { - let mut proof = [F::ZERO; P]; - for uv_polynomial in uv_iterator { - for (i, proof_part) in proof.iter_mut().enumerate().take(L) { - *proof_part += uv_polynomial.borrow().0[i] * uv_polynomial.borrow().1[i]; - } - let p_extrapolated = lagrange_table.eval(&uv_polynomial.borrow().0); - let q_extrapolated = lagrange_table.eval(&uv_polynomial.borrow().1); - - for (i, (x, y)) in - zip(p_extrapolated.into_iter(), q_extrapolated.into_iter()).enumerate() - { - proof[L + i] += x * y; - } - } - proof + pq_iterator + .fold( + 
::AccumulatorArray::

::new(), + |mut proof, pq| { + proof.multiply_accumulate(&pq.borrow().0, &pq.borrow().1); + proof + }, + ) + .take() } - fn gen_challenge_and_recurse( + fn gen_challenge_and_recurse, const N: usize>( proof_left: &[F; P], proof_right: &[F; P], - uv_iterator: J, - ) -> UVValues - where - J: Iterator, - J::Item: Borrow<([F; L], [F; L])>, - { + uv_iterator: I, + ) -> UVValues { let r: F = hash_to_field( &compute_hash(proof_left), &compute_hash(proof_right), @@ -165,17 +383,8 @@ impl ProofGenerat let denominator = CanonicalLagrangeDenominator::::new(); let lagrange_table_r = LagrangeTable::::new(&denominator, &r); - // iter and interpolate at x coordinate r uv_iterator - .map(|polynomial| { - let (u_chunk, v_chunk) = polynomial.borrow(); - ( - // new u value - lagrange_table_r.eval(u_chunk)[0], - // new v value - lagrange_table_r.eval(v_chunk)[0], - ) - }) + .eval_at_r(&lagrange_table_r) .collect::>() } @@ -205,29 +414,23 @@ impl ProofGenerat proof_other_share } - /// This function is a helper function that computes the next proof - /// from an iterator over uv values - /// It also computes the next uv values + /// This function is a helper function that, given the computed proof, computes the shares of + /// the proof, the challenge, and the next uv values. 
/// /// It output `(uv values, share_of_proof_from_prover_left, my_proof_left_share)` /// where /// `share_of_proof_from_prover_left` from left has type `Vec<[F; P]>`, /// `my_proof_left_share` has type `Vec<[F; P]>`, - pub fn gen_artefacts_from_recursive_step( + pub fn gen_artefacts_from_recursive_step( ctx: &C, record_ids: &mut RecordIdRange, - lagrange_table: &LagrangeTable, - uv_iterator: J, + my_proof: [F; P], + uv_iterator: I, ) -> (UVValues, [F; P], [F; P]) where C: Context, - J: Iterator + Clone, - J::Item: Borrow<([F; L], [F; L])>, + I: ProverLagrangeInput, { - // generate next proof - // from iterator - let my_proof = Self::compute_proof(uv_iterator.clone(), lagrange_table); - // generate proof shares from prss let (share_of_proof_from_prover_left, my_proof_right_share) = Self::gen_proof_shares_from_prss(ctx, record_ids); @@ -257,23 +460,32 @@ mod test { use std::iter::zip; use futures::future::try_join; + use rand::Rng; + use super::*; use crate::{ ff::{Fp31, Fp61BitPrime, PrimeField, U128Conversions}, helpers::{Direction, Role}, protocol::{ - context::Context, - ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - prover::{LargeProofGenerator, SmallProofGenerator, TestProofGenerator, UVValues}, + context::{ + dzkp_field::tests::reference_convert, + dzkp_validator::{MultiplicationInputsBlock, BIT_ARRAY_LEN}, + Context, + }, + ipa_prf::{ + malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + FirstProofGenerator, }, RecordId, RecordIdRange, }, seq_join::SeqJoin, - test_executor::run, + test_executor::{run, run_random}, test_fixture::{Runner, TestWorld}, }; + type TestProofGenerator = ProofGenerator; + type LargeProofGenerator = ProofGenerator; + fn zip_chunks(a: I, b: J) -> UVValues where I: IntoIterator, @@ -316,7 +528,7 @@ mod test { let uv_1 = zip_chunks(U_1, V_1); // first iteration - let proof_1 = TestProofGenerator::compute_proof(uv_1.iter(), &lagrange_table); + let proof_1 = 
TestProofGenerator::compute_proof_from_uv(uv_1.iter(), &lagrange_table); assert_eq!( proof_1.iter().map(Fp31::as_u128).collect::>(), PROOF_1, @@ -335,12 +547,12 @@ mod test { let uv_2 = TestProofGenerator::gen_challenge_and_recurse( &proof_left_1, &proof_right_1, - uv_1.iter(), + ProverValues(uv_1.iter().copied()), ); assert_eq!(uv_2, zip_chunks(U_2, V_2)); // next iteration - let proof_2 = TestProofGenerator::compute_proof(uv_2.iter(), &lagrange_table); + let proof_2 = TestProofGenerator::compute_proof_from_uv(uv_2.iter(), &lagrange_table); assert_eq!( proof_2.iter().map(Fp31::as_u128).collect::>(), PROOF_2, @@ -359,7 +571,7 @@ mod test { let uv_3 = TestProofGenerator::gen_challenge_and_recurse::<_, 4>( &proof_left_2, &proof_right_2, - uv_2.iter(), + ProverValues(uv_2.iter().copied()), ); assert_eq!(uv_3, zip_chunks(U_3, V_3)); @@ -369,7 +581,8 @@ mod test { ); // final iteration - let proof_3 = TestProofGenerator::compute_proof(masked_uv_3.iter(), &lagrange_table); + let proof_3 = + TestProofGenerator::compute_proof_from_uv(masked_uv_3.iter(), &lagrange_table); assert_eq!( proof_3.iter().map(Fp31::as_u128).collect::>(), PROOF_3, @@ -397,11 +610,12 @@ mod test { // first iteration let world = TestWorld::default(); let mut record_ids = RecordIdRange::ALL; + let proof = TestProofGenerator::compute_proof_from_uv(uv_1.iter(), &lagrange_table); let (uv_values, _, _) = TestProofGenerator::gen_artefacts_from_recursive_step::<_, _, 4>( &world.contexts()[0], &mut record_ids, - &lagrange_table, - uv_1.iter(), + proof, + ProverValues(uv_1.iter().copied()), ); assert_eq!(7, uv_values.len()); @@ -436,14 +650,14 @@ mod test { >::from(denominator); // compute proof - let proof = SmallProofGenerator::compute_proof(uv_before.iter(), &lagrange_table); + let proof = SmallProofGenerator::compute_proof_from_uv(uv_before.iter(), &lagrange_table); assert_eq!(proof.len(), SmallProofGenerator::PROOF_LENGTH); let uv_after = SmallProofGenerator::gen_challenge_and_recurse::<_, 8>( &proof, 
&proof, - uv_before.iter(), + ProverValues(uv_before.iter().copied()), ); assert_eq!( @@ -472,14 +686,14 @@ mod test { >::from(denominator); // compute proof - let proof = LargeProofGenerator::compute_proof(uv_before.iter(), &lagrange_table); + let proof = LargeProofGenerator::compute_proof_from_uv(uv_before.iter(), &lagrange_table); assert_eq!(proof.len(), LargeProofGenerator::PROOF_LENGTH); let uv_after = LargeProofGenerator::gen_challenge_and_recurse::<_, 8>( &proof, &proof, - uv_before.iter(), + ProverValues(uv_before.iter().copied()), ); assert_eq!( @@ -594,4 +808,51 @@ mod test { assert_two_part_secret_sharing(PROOF_2, h1_proof_right, h3_proof_left); assert_two_part_secret_sharing(PROOF_3, h2_proof_right, h1_proof_left); } + + #[test] + fn prover_table_indices_equivalence() { + run_random(|mut rng| async move { + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; + const FLL: usize = FirstProofGenerator::LAGRANGE_LENGTH; + + let block = rng.gen::(); + + // Test equivalence for extrapolate_y_values + let denominator = CanonicalLagrangeDenominator::new(); + let lagrange_table = LagrangeTable::from(denominator); + + assert!(ProverTableIndices(block.table_indices_prover().into_iter()) + .extrapolate_y_values::(&lagrange_table) + .eq(ProverValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + })) + .extrapolate_y_values::(&lagrange_table))); + + // Test equivalence for eval_at_r + let denominator = CanonicalLagrangeDenominator::new(); + let r = rng.gen(); + let lagrange_table_r = LagrangeTable::new(&denominator, &r); + + assert!(ProverTableIndices(block.table_indices_prover().into_iter()) + .eval_at_r(&lagrange_table_r) + .eq(ProverValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + })) + 
.eval_at_r(&lagrange_table_r))); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs index e62465b73..0038d4b85 100644 --- a/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs +++ b/ipa-core/src/protocol/ipa_prf/malicious_security/verifier.rs @@ -1,9 +1,18 @@ -use std::iter::{self}; +use std::{ + array, + iter::{self}, +}; use crate::{ ff::PrimeField, - protocol::ipa_prf::malicious_security::lagrange::{ - CanonicalLagrangeDenominator, LagrangeTable, + protocol::{ + context::{ + dzkp_field::UVTable, + dzkp_validator::{MAX_PROOF_RECURSION, MIN_PROOF_RECURSION}, + }, + ipa_prf::malicious_security::{ + CanonicalLagrangeDenominator, LagrangeTable, FIRST_RECURSION_FACTOR as FRF, + }, }, utils::arraychunks::ArrayChunkIterator, }; @@ -96,35 +105,30 @@ pub fn compute_final_sum_share(zk } /// This function compresses the `u_or_v` values and returns the next `u_or_v` values. -fn recurse_u_or_v<'a, F: PrimeField, J, const L: usize>( - u_or_v: J, +fn recurse_u_or_v<'a, F: PrimeField, const L: usize>( + u_or_v: impl Iterator + 'a, lagrange_table: &'a LagrangeTable, -) -> impl Iterator + 'a -where - J: Iterator + 'a, -{ - u_or_v - .chunk_array::() - .map(|x| lagrange_table.eval(&x)[0]) +) -> impl Iterator + 'a { + VerifierValues(u_or_v.chunk_array::()).eval_at_r(lagrange_table) } /// This function recursively compresses the `u_or_v` values. -/// The recursion factor (or compression) of the first recursion is `L_FIRST` -/// The recursion factor of all following recursions is `L`. -pub fn recursively_compute_final_check( - u_or_v: J, +/// +/// The recursion factor (or compression) of the first recursion is fixed at +/// `FIRST_RECURSION_FACTOR` (`FRF`). The recursion factor of all following recursions +/// is `L`. 
+pub fn recursively_compute_final_check( + input: impl VerifierLagrangeInput, challenges: &[F], p_or_q_0: F, -) -> F -where - J: Iterator, -{ +) -> F { + // This function requires MIN_PROOF_RECURSION be at least 2. + assert!(challenges.len() >= MIN_PROOF_RECURSION && challenges.len() <= MAX_PROOF_RECURSION); let recursions_after_first = challenges.len() - 1; // compute Lagrange tables - let denominator_p_or_q_first = CanonicalLagrangeDenominator::::new(); - let table_first = - LagrangeTable::::new(&denominator_p_or_q_first, &challenges[0]); + let denominator_p_or_q_first = CanonicalLagrangeDenominator::::new(); + let table_first = LagrangeTable::::new(&denominator_p_or_q_first, &challenges[0]); let denominator_p_or_q = CanonicalLagrangeDenominator::::new(); let tables = challenges[1..] .iter() @@ -133,11 +137,10 @@ where // generate & evaluate recursive streams // to compute last array - let mut iterator: Box> = - Box::new(recurse_u_or_v::<_, _, L_FIRST>(u_or_v, &table_first)); + let mut iterator: Box> = Box::new(input.eval_at_r(&table_first)); // all following recursion except last one for lagrange_table in tables.iter().take(recursions_after_first - 1) { - iterator = Box::new(recurse_u_or_v::<_, _, L>(iterator, lagrange_table)); + iterator = Box::new(recurse_u_or_v(iterator, lagrange_table)); } let last_u_or_v_values = iterator.collect::>(); // Make sure there are less than L last u or v values. The prover is expected to continue @@ -162,18 +165,127 @@ where tables.last().unwrap().eval(&last_array)[0] } +/// Trait for inputs to Lagrange interpolation on the verifier. +/// +/// Lagrange interpolation is used in the verifier to evaluate polynomials at the +/// randomly chosen challenge point _r_. +/// +/// There are two implementations of this trait: `VerifierTableIndices`, and +/// `VerifierValues`. `VerifierTableIndices` is used for the input to the first proof. 
+/// Each set of 4 _u_ or _v_ values input to the first proof has one of eight possible +/// values, determined by the values of the 3 associated multiplication intermediates. +/// The `VerifierTableIndices` implementation uses a lookup table containing the output +/// of the Lagrange interpolation for each of these eight possible values. The +/// `VerifierValues` implementation, which represents actual _u_ and _v_ values, is used +/// by the remaining recursive proofs. +/// +/// There is a similar trait `ProverLagrangeInput` in `prover.rs`. The difference is +/// that the prover operates on _u_ and _v_ values simultaneously (i.e. iterators of +/// tuples). The verifier operates on only one of _u_ or _v_ at a time. +pub trait VerifierLagrangeInput { + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a; +} + +/// Implementation of `VerifierLagrangeInput` for table indices, used for the first proof. +pub struct VerifierTableIndices<'a, F: PrimeField, I: Iterator> { + pub input: I, + pub table: &'a UVTable, +} + +/// Iterator returned by `VerifierTableIndices::eval_at_r`. +struct TableIndicesIterator> { + input: I, + table: [F; 8], +} + +impl> VerifierLagrangeInput + for VerifierTableIndices<'_, F, I> +{ + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + TableIndicesIterator { + input: self.input, + table: array::from_fn(|i| lagrange_table.eval(&self.table[i])[0]), + } + } +} + +impl> Iterator for TableIndicesIterator { + type Item = F; + + fn next(&mut self) -> Option { + self.input + .next() + .map(|index| self.table[usize::from(index)]) + } +} + +/// Implementation of `VerifierLagrangeInput` for _u_ and _v_ values, used for +/// subsequent recursive proofs. +pub struct VerifierValues>(pub I); + +/// Iterator returned by `ProverValues::eval_at_r`. 
+struct ValuesEvalAtRIterator<'a, F: PrimeField, const L: usize, I: Iterator> { + input: I, + lagrange_table: &'a LagrangeTable, +} + +impl<'a, F: PrimeField, const L: usize, I: Iterator> Iterator + for ValuesEvalAtRIterator<'a, F, L, I> +{ + type Item = F; + + fn next(&mut self) -> Option { + self.input + .next() + .map(|values| self.lagrange_table.eval(&values)[0]) + } +} + +impl> VerifierLagrangeInput + for VerifierValues +{ + fn eval_at_r<'a>( + self, + lagrange_table: &'a LagrangeTable, + ) -> impl Iterator + 'a + where + Self: 'a, + { + ValuesEvalAtRIterator { + input: self.0, + lagrange_table, + } + } +} + #[cfg(all(test, unit_test))] mod test { + use rand::Rng; + + use super::*; use crate::{ ff::{Fp31, U128Conversions}, - protocol::ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - verifier::{ - compute_g_differences, compute_sum_share, interpolate_at_r, recurse_u_or_v, - recursively_compute_final_check, + protocol::{ + context::{ + dzkp_field::{tests::reference_convert, TABLE_U, TABLE_V}, + dzkp_validator::{MultiplicationInputsBlock, BIT_ARRAY_LEN}, }, + ipa_prf::malicious_security::lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, }, secret_sharing::SharedValue, + test_executor::run_random, + utils::arraychunks::ArrayChunkIterator, }; fn to_field(a: &[u128]) -> Vec { @@ -292,12 +404,11 @@ mod test { let tables: [LagrangeTable; 3] = CHALLENGES .map(|r| LagrangeTable::new(&denominator_p_or_q, &Fp31::try_from(r).unwrap())); - let u_or_v_2 = recurse_u_or_v::<_, _, 4>(to_field(&U_1).into_iter(), &tables[0]) - .collect::>(); + let u_or_v_2 = + recurse_u_or_v::<_, 4>(to_field(&U_1).into_iter(), &tables[0]).collect::>(); assert_eq!(u_or_v_2, to_field(&U_2)); - let u_or_v_3 = - recurse_u_or_v::<_, _, 4>(u_or_v_2.into_iter(), &tables[1]).collect::>(); + let u_or_v_3 = recurse_u_or_v::<_, 4>(u_or_v_2.into_iter(), &tables[1]).collect::>(); assert_eq!(u_or_v_3, to_field(&U_3[..2])); @@ -309,12 +420,12 @@ mod test { ]; 
let p_final = - recurse_u_or_v::<_, _, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); + recurse_u_or_v::<_, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); assert_eq!(p_final[0].as_u128(), EXPECTED_P_FINAL); - let p_final_another_way = recursively_compute_final_check::( - to_field(&U_1).into_iter(), + let p_final_another_way = recursively_compute_final_check::<_, 4>( + VerifierValues(to_field(&U_1).into_iter().chunk_array::<4>()), &CHALLENGES .map(|x| Fp31::try_from(x).unwrap()) .into_iter() @@ -396,15 +507,15 @@ mod test { // final iteration let p_final = - recurse_u_or_v::<_, _, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); + recurse_u_or_v::<_, 4>(u_or_v_3_masked.into_iter(), &tables[2]).collect::>(); assert_eq!(p_final[0].as_u128(), EXPECTED_Q_FINAL); // uv values in input format let v_1 = to_field(&V_1); - let q_final_another_way = recursively_compute_final_check::( - v_1.into_iter(), + let q_final_another_way = recursively_compute_final_check::( + VerifierValues(v_1.into_iter().chunk_array::<4>()), &CHALLENGES .map(|x| Fp31::try_from(x).unwrap()) .into_iter() @@ -447,4 +558,59 @@ mod test { assert_eq!(Fp31::ZERO, g_differences[0]); } + + #[test] + fn verifier_table_indices_equivalence() { + run_random(|mut rng| async move { + let block = rng.gen::(); + + let denominator = CanonicalLagrangeDenominator::new(); + let r = rng.gen(); + let lagrange_table_r = LagrangeTable::new(&denominator, &r); + + // Test equivalence for _u_ values + assert!(VerifierTableIndices { + input: block + .rotate_right() + .table_indices_from_right_prover() + .into_iter(), + table: &TABLE_U, + } + .eval_at_r(&lagrange_table_r) + .eq(VerifierValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + .0 + })) + .eval_at_r(&lagrange_table_r))); + + // Test equivalence for _v_ values + assert!(VerifierTableIndices { + input: block + 
.rotate_left() + .table_indices_from_left_prover() + .into_iter(), + table: &TABLE_V, + } + .eval_at_r(&lagrange_table_r) + .eq(VerifierValues((0..BIT_ARRAY_LEN).map(|i| { + reference_convert( + block.x_left[i], + block.x_right[i], + block.y_left[i], + block.y_right[i], + block.prss_left[i], + block.prss_right[i], + ) + .1 + })) + .eval_at_r(&lagrange_table_r))); + }); + } } diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index f4225a0d8..63dc34a6f 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, iter::zip, num::NonZeroU32, ops::Add}; +use std::{cmp::max, convert::Infallible, iter::zip, num::NonZeroU32, ops::Add}; use futures::{stream, StreamExt, TryStreamExt}; use generic_array::{ArrayLength, GenericArray}; @@ -24,8 +24,8 @@ use crate::{ protocol::{ basics::{BooleanArrayMul, BooleanProtocols, Reveal}, context::{ - dzkp_validator::DZKPValidator, DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, - UpgradableContext, + dzkp_validator::{DZKPValidator, TARGET_PROOF_SIZE}, + DZKPUpgraded, MacUpgraded, MaliciousProtocolSteps, UpgradableContext, }, ipa_prf::{ boolean_ops::convert_to_fp25519, @@ -44,6 +44,7 @@ use crate::{ BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; pub(crate) mod aggregation; @@ -58,7 +59,10 @@ pub(crate) mod shuffle; pub(crate) mod step; pub mod validation_protocol; -pub use malicious_security::prover::{LargeProofGenerator, SmallProofGenerator}; +pub use malicious_security::{ + CompressedProofGenerator, FirstProofGenerator, LagrangeTable, ProverTableIndices, + VerifierTableIndices, +}; pub use shuffle::Shuffle; /// Match key type @@ -409,13 +413,17 @@ where Ok(noisy_output_histogram) } -// We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor -// is CONV_CHUNK. 
Let `len` equal the number of converted shares. The total amount of -// multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a -// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so -// we pick the closest power of two close to 381 but less than that value. 256 gives us around 33M -// multiplications per batch -const CONV_PROOF_CHUNK: usize = 256; +/// Returns a suitable proof chunk size (in records) for use with `convert_to_fp25519`. +/// +/// We expect 2*256 = 512 gates in total for two additions per conversion. The +/// vectorization factor is `CONV_CHUNK`. Let `len` equal the number of converted +/// shares. The total amount of multiplications is `CONV_CHUNK`*512*len. We want +/// `CONV_CHUNK`*512*len ≈ 50M for a reasonably-sized proof. There is also a constraint +/// on proof chunks to be powers of two, and we don't want to compute a proof chunk +/// of zero when `TARGET_PROOF_SIZE` is smaller for tests. +fn conv_proof_chunk() -> usize { + non_zero_prev_power_of_two(max(2, TARGET_PROOF_SIZE / CONV_CHUNK / 512)) +} #[tracing::instrument(name = "compute_prf_for_inputs", skip_all)] async fn compute_prf_for_inputs( @@ -443,7 +451,7 @@ where protocol: &Step::ConvertFp25519, validate: &Step::ConvertFp25519Validate, }, - CONV_PROOF_CHUNK, + conv_proof_chunk(), ); let m_ctx = validator.context(); diff --git a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs index 0d74ad6a7..099fed80d 100644 --- a/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/oprf_padding/mod.rs @@ -194,7 +194,82 @@ where total_number_of_fake_rows: u32, ) { padding_input_rows.extend( - repeat(IndistinguishableHybridReport::ZERO).take(total_number_of_fake_rows as usize), + repeat(IndistinguishableHybridReport::::ZERO) + .take(total_number_of_fake_rows as usize), + ); + } +} + +impl Paddable for IndistinguishableHybridReport +where + 
BK: BooleanArray + U128Conversions,
+    V: BooleanArray,
+{
+    /// Given an extendable collection of `AggregateableReports`s,
+    /// this function will pad the collection with dummy reports. The reports
+    /// cover every `breakdown_key` and have a secret sharing of zero for the `value`.
+    /// Each `breakdown_key` receives a random number of rows.
+    fn add_padding_items, const B: usize>(
+        direction_to_excluded_helper: Direction,
+        padding_input_rows: &mut VC,
+        padding_params: &PaddingParameters,
+        rng: &mut InstrumentedSequentialSharedRandomness,
+    ) -> Result {
+        let mut total_number_of_fake_rows = 0;
+        match padding_params.aggregation_padding {
+            AggregationPadding::NoAggPadding => {}
+            AggregationPadding::Parameters {
+                aggregation_epsilon,
+                aggregation_delta,
+                aggregation_padding_sensitivity,
+            } => {
+                let aggregation_padding = OPRFPaddingDp::new(
+                    aggregation_epsilon,
+                    aggregation_delta,
+                    aggregation_padding_sensitivity,
+                )?;
+                let num_breakdowns: u32 = u32::try_from(B).unwrap();
+                // for every breakdown, sample how many dummies will be added
+                for breakdownkey in 0..num_breakdowns {
+                    let sample = aggregation_padding.sample(rng);
+                    total_number_of_fake_rows += sample;
+
+                    // now add `sample` many fake rows with this `breakdownkey`
+                    for _ in 0..sample {
+                        let breakdownkey_shares = match direction_to_excluded_helper {
+                            Direction::Left => AdditiveShare::new(
+                                BK::ZERO,
+                                BK::truncate_from(u128::from(breakdownkey)),
+                            ),
+                            Direction::Right => AdditiveShare::new(
+                                BK::truncate_from(u128::from(breakdownkey)),
+                                BK::ZERO,
+                            ),
+                        };
+
+                        let row = IndistinguishableHybridReport:: {
+                            match_key: (),
+                            value: AdditiveShare::new(V::ZERO, V::ZERO),
+                            breakdown_key: breakdownkey_shares,
+                        };
+
+                        padding_input_rows.extend(std::iter::once(row));
+                    }
+                }
+            }
+        }
+        Ok(total_number_of_fake_rows)
+    }
+
+    /// Given an extendable collection of `IndistinguishableHybridReport`s,
+    /// this function adds `total_number_of_fake_rows` of Reports with zeros in all fields.
+ fn add_zero_shares>( + padding_input_rows: &mut VC, + total_number_of_fake_rows: u32, + ) { + padding_input_rows.extend( + repeat(IndistinguishableHybridReport::::ZERO) + .take(total_number_of_fake_rows as usize), ); } } diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index f6bf2e339..3617ab767 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -52,6 +52,7 @@ use crate::{ replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, BitDecomposed, FieldSimd, SharedValue, TransposeFrom, Vectorizable, }, + utils::non_zero_prev_power_of_two, }; pub mod feature_label_dot_product; @@ -515,7 +516,10 @@ where // TODO: this override was originally added to work around problems with // read_size vs. batch size alignment. Those are now fixed (in #1332), but this // is still observed to help performance (see #1376), so has been retained. - std::cmp::min(sh_ctx.active_work().get(), chunk_size.next_power_of_two()), + std::cmp::min( + sh_ctx.active_work().get(), + non_zero_prev_power_of_two(chunk_size), + ), ); dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?; diff --git a/ipa-core/src/protocol/ipa_prf/quicksort.rs b/ipa-core/src/protocol/ipa_prf/quicksort.rs index 943dfb1ec..26131224d 100644 --- a/ipa-core/src/protocol/ipa_prf/quicksort.rs +++ b/ipa-core/src/protocol/ipa_prf/quicksort.rs @@ -30,6 +30,7 @@ use crate::{ Vectorizable, }, seq_join::seq_join, + utils::non_zero_prev_power_of_two, }; impl ChunkBuffer for (Vec>, Vec>) @@ -98,7 +99,7 @@ where } fn quicksort_proof_chunk(key_bits: usize) -> usize { - (TARGET_PROOF_SIZE / key_bits / SORT_CHUNK).next_power_of_two() + non_zero_prev_power_of_two(TARGET_PROOF_SIZE / key_bits / SORT_CHUNK) } /// Insecure quicksort using MPC comparisons and a key extraction 
function `get_key`. diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs index a8b6d63c0..b0745f568 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs @@ -9,6 +9,7 @@ use futures::{ }; use futures_util::future::{try_join, try_join3}; use generic_array::GenericArray; +use subtle::ConstantTimeEq; use typenum::Const; use crate::{ @@ -43,6 +44,44 @@ use crate::{ sharding::ShardIndex, }; +/// Container for left and right shares with tags attached to them. +/// Looks like an additive share, but it is not because it does not need +/// many traits that additive shares require to implement +#[derive(Clone, Debug, Default)] +struct Pair { + left: S, + right: S, +} + +impl Shuffleable for Pair { + type Share = S; + + fn left(&self) -> Self::Share { + self.left.clone() + } + + fn right(&self) -> Self::Share { + self.right.clone() + } + + fn new(l: Self::Share, r: Self::Share) -> Self { + Self { left: l, right: r } + } +} + +impl From> for Pair { + fn from(value: AdditiveShare) -> Self { + let (l, r) = value.as_tuple(); + Shuffleable::new(l, r) + } +} + +impl From> for AdditiveShare { + fn from(value: Pair) -> Self { + ReplicatedSecretSharing::new(value.left, value.right) + } +} + /// This function executes the maliciously secure shuffle protocol on the input: `shares`. 
/// /// ## Errors @@ -65,7 +104,7 @@ where .collect::>>(); // compute and append tags to rows - let shares_and_tags: Vec> = + let shares_and_tags: Vec> = compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; // shuffle @@ -75,7 +114,7 @@ where verify_shuffle::<_, S>( ctx.narrow(&OPRFShuffleStep::VerifyShuffle), &keys, - &shuffled_shares, + shuffled_shares.as_slice(), messages, ) .await?; @@ -143,7 +182,7 @@ where let keys = setup_keys(ctx.narrow(&OPRFShuffleStep::SetupKeys), amount_of_keys).await?; // compute and append tags to rows - let shares_and_tags: Vec> = + let shares_and_tags: Vec> = compute_and_add_tags(ctx.narrow(&OPRFShuffleStep::GenerateTags), &keys, shares).await?; let (shuffled_shares, messages) = match ctx.role() { @@ -170,7 +209,7 @@ where /// /// ## Panics /// Panics when `S::Bits > B::Bits`. -fn truncate_tags(shares_and_tags: &[AdditiveShare]) -> Vec +fn truncate_tags(shares_and_tags: &[Pair]) -> Vec where S: MaliciousShuffleable, { @@ -178,8 +217,8 @@ where .iter() .map(|row_with_tag| { Shuffleable::new( - split_row_and_tag::(ReplicatedSecretSharing::left(row_with_tag)).0, - split_row_and_tag::(ReplicatedSecretSharing::right(row_with_tag)).0, + split_row_and_tag::(&row_with_tag.left).0, + split_row_and_tag::(&row_with_tag.right).0, ) }) .collect() @@ -191,7 +230,9 @@ where /// When `row_with_tag` does not have the correct format, /// i.e. deserialization returns an error, /// the output row and tag will be zero. 
-fn split_row_and_tag(row_with_tag: S::ShareAndTag) -> (S::Share, Gf32Bit) { +fn split_row_and_tag( + row_with_tag: &S::ShareAndTag, +) -> (S::Share, Gf32Bit) { let mut buf = GenericArray::default(); row_with_tag.serialize(&mut buf); ( @@ -210,7 +251,7 @@ fn split_row_and_tag(row_with_tag: S::ShareAndTag) -> ( async fn verify_shuffle( ctx: C, key_shares: &[AdditiveShare], - shuffled_shares: &[AdditiveShare], + shuffled_shares: &[Pair], messages: IntermediateShuffleMessages, ) -> Result<(), Error> { // reveal keys @@ -247,7 +288,7 @@ async fn verify_shuffle( async fn h1_verify( ctx: C, keys: &[Gf32Bit], - share_a_and_b: &[AdditiveShare], + share_a_and_b: &[Pair], x1: Vec, ) -> Result<(), Error> { // compute hashes @@ -256,9 +297,9 @@ async fn h1_verify( // compute hash for A xor B let hash_a_xor_b = compute_and_hash_tags::( keys, - share_a_and_b.iter().map(|share| { - ReplicatedSecretSharing::left(share) + ReplicatedSecretSharing::right(share) - }), + share_a_and_b + .iter() + .map(|share| Shuffleable::left(share) + Shuffleable::right(share)), ); // setup channels @@ -280,21 +321,21 @@ async fn h1_verify( .await?; // check y1 - if hash_x1 != hash_y1 { + if hash_x1.ct_ne(&hash_y1).into() { return Err(Error::ShuffleValidationFailed(format!( "Y1 is inconsistent: hash of x1: {hash_x1:?}, hash of y1: {hash_y1:?}" ))); } // check c from h3 - if hash_a_xor_b != hash_h3 { + if hash_a_xor_b.ct_ne(&hash_h3).into() { return Err(Error::ShuffleValidationFailed(format!( "C from H3 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h3:?}" ))); } // check h2 - if hash_a_xor_b != hash_h2 { + if hash_a_xor_b.ct_ne(&hash_h2).into() { return Err(Error::ShuffleValidationFailed(format!( "C from H2 is inconsistent: hash of a_xor_b: {hash_a_xor_b:?}, hash of C: {hash_h2:?}" ))); @@ -314,7 +355,7 @@ async fn h1_verify( async fn h2_verify( ctx: C, keys: &[Gf32Bit], - share_b_and_c: &[AdditiveShare], + share_b_and_c: &[Pair], x2: Vec, ) -> Result<(), Error> { // compute 
hashes @@ -341,7 +382,7 @@ async fn h2_verify( .await?; // check x2 - if hash_x2 != hash_h3 { + if hash_x2.ct_ne(&hash_h3).into() { return Err(Error::ShuffleValidationFailed(format!( "X2 is inconsistent: hash of x2: {hash_x2:?}, hash of y2: {hash_h3:?}" ))); @@ -359,7 +400,7 @@ async fn h2_verify( async fn h3_verify( ctx: C, keys: &[Gf32Bit], - share_c_and_a: &[AdditiveShare], + share_c_and_a: &[Pair], y1: Vec, y2: Vec, ) -> Result<(), Error> { @@ -405,7 +446,7 @@ where let iterator = row_iterator.into_iter().map(|row_with_tag| { // when split_row_and_tags returns the default value, the verification will fail // except 2^-security_parameter, i.e. 2^-32 - let (row, tag) = split_row_and_tag::(row_with_tag); + let (row, tag) = split_row_and_tag::(&row_with_tag); >>::try_into(row) .unwrap() .into_iter() @@ -470,7 +511,7 @@ async fn compute_and_add_tags( ctx: C, keys: &[AdditiveShare], rows: Vec, -) -> Result>, Error> +) -> Result>, Error> where C: Context, S: MaliciousShuffleable, @@ -537,7 +578,7 @@ where fn concatenate_row_and_tag( row: &S, tag: &AdditiveShare, -) -> AdditiveShare { +) -> Pair { let mut row_left = GenericArray::default(); let mut row_right = GenericArray::default(); let mut tag_left = GenericArray::default(); @@ -600,7 +641,10 @@ mod tests { vec![record], ) .await - .unwrap(); + .unwrap() + .into_iter() + .map(AdditiveShare::from) + .collect(); (keys, shares_and_tags) }) @@ -701,7 +745,11 @@ mod tests { verify_shuffle::<_, AdditiveShare>( ctx.narrow("verify"), &key_shares, - &shares, + shares + .into_iter() + .map(Pair::from) + .collect::>() + .as_slice(), messages, ) .await @@ -726,21 +774,21 @@ mod tests { { let row = ::new(rng.gen(), rng.gen()); let tag = AdditiveShare::::new(rng.gen::(), rng.gen::()); - let row_and_tag: AdditiveShare = concatenate_row_and_tag(&row, &tag); + let row_and_tag: Pair = concatenate_row_and_tag(&row, &tag); let mut buf = GenericArray::default(); let mut buf_row = GenericArray::default(); let mut buf_tag = 
GenericArray::default(); // check left shares - ReplicatedSecretSharing::left(&row_and_tag).serialize(&mut buf); + Shuffleable::left(&row_and_tag).serialize(&mut buf); Shuffleable::left(&row).serialize(&mut buf_row); assert_eq!(buf[0..S::TAG_OFFSET], buf_row[..]); ReplicatedSecretSharing::left(&tag).serialize(&mut buf_tag); assert_eq!(buf[S::TAG_OFFSET..], buf_tag[..]); // check right shares - ReplicatedSecretSharing::right(&row_and_tag).serialize(&mut buf); + Shuffleable::right(&row_and_tag).serialize(&mut buf); Shuffleable::right(&row).serialize(&mut buf_row); assert_eq!(buf[0..S::TAG_OFFSET], buf_row[..]); ReplicatedSecretSharing::right(&tag).serialize(&mut buf_tag); @@ -765,6 +813,7 @@ mod tests { where S: MaliciousShuffleable, S::Share: IntoShares, + S::ShareAndTag: SharedValue, Standard: Distribution, { const RECORD_AMOUNT: usize = 10; @@ -813,6 +862,9 @@ mod tests { ) .await .unwrap() + .into_iter() + .map(AdditiveShare::from) + .collect::>() }, ) .await diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs index e69fc0093..165366826 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs @@ -279,7 +279,7 @@ pub trait MaliciousShuffleable: /// /// Having an alias here makes it easier to reference in the code, because the /// shuffle routines have an `S: MaliciousShuffleable` type parameter. - type ShareAndTag: ShuffleShare + SharedValue; + type ShareAndTag: ShuffleShare; /// Same as `Self::MaliciousShare::TAG_OFFSET`. /// @@ -316,11 +316,7 @@ where /// automatically. pub trait MaliciousShuffleShare: TryInto, Error = LengthError> { /// A type that can hold `::Share` along with a 32-bit MAC. - /// - /// The `SharedValue` bound is required because some of the malicious shuffle - /// routines use `AdditiveShare`. It might be possible to refactor - /// those routines to avoid the `SharedValue` bound. 
- type ShareAndTag: ShuffleShare + SharedValue; + type ShareAndTag: ShuffleShare; /// The offset to the MAC in `ShareAndTag`. const TAG_OFFSET: usize; diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs index cb2754e5f..e496f97e4 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/proof_generation.rs @@ -1,21 +1,20 @@ -use std::{array, iter::zip}; +use std::{array, iter::zip, ops::Mul}; -use typenum::{UInt, UTerm, Unsigned, B0, B1}; +use typenum::{Unsigned, U, U8}; use crate::{ - const_assert_eq, error::Error, ff::{Fp61BitPrime, Serializable}, helpers::{Direction, MpcMessage, TotalRecords}, protocol::{ - context::{ - dzkp_field::{UVTupleBlock, BLOCK_SIZE}, - dzkp_validator::MAX_PROOF_RECURSION, - Context, - }, - ipa_prf::malicious_security::{ - lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, - prover::{LargeProofGenerator, SmallProofGenerator}, + context::{dzkp_validator::MAX_PROOF_RECURSION, Context}, + ipa_prf::{ + malicious_security::{ + lagrange::{CanonicalLagrangeDenominator, LagrangeTable}, + prover::{ProverLagrangeInput, ProverValues}, + FIRST_RECURSION_FACTOR as FRF, + }, + CompressedProofGenerator, FirstProofGenerator, }, prss::SharedRandomness, RecordId, RecordIdRange, @@ -25,8 +24,8 @@ use crate::{ /// This a `ProofBatch` generated by a prover. 
pub struct ProofBatch {
-    pub first_proof: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH],
-    pub proofs: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>,
+    pub first_proof: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH],
+    pub proofs: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>,
 }
 
 impl FromIterator for ProofBatch {
@@ -35,10 +34,11 @@ impl FromIterator for ProofBatch {
         // consume the first P elements
         let first_proof = iterator
             .by_ref()
-            .take(LargeProofGenerator::PROOF_LENGTH)
-            .collect::<[Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH]>();
+            .take(FirstProofGenerator::PROOF_LENGTH)
+            .collect::<[Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH]>();
         // consume the rest
-        let proofs = iterator.collect::>();
+        let proofs =
+            iterator.collect::>();
         ProofBatch {
             first_proof,
             proofs,
@@ -51,7 +51,8 @@ impl ProofBatch {
     #[allow(clippy::len_without_is_empty)]
     #[must_use]
     pub fn len(&self) -> usize {
-        self.proofs.len() * SmallProofGenerator::PROOF_LENGTH + LargeProofGenerator::PROOF_LENGTH
+        FirstProofGenerator::PROOF_LENGTH
+            + self.proofs.len() * CompressedProofGenerator::PROOF_LENGTH
     }
 
     #[allow(clippy::unnecessary_box_returns)] // clippy bug? `Array` exceeds unnecessary-box-size
@@ -80,39 +81,43 @@ impl ProofBatch {
     /// ## Panics
     /// Panics when the function fails to set the masks without overwriting `u` and `v` values.
     /// This only happens when there is an issue in the recursion.
- pub fn generate( + pub fn generate( ctx: &C, mut prss_record_ids: RecordIdRange, - uv_tuple_inputs: I, + uv_inputs: impl ProverLagrangeInput + Clone, ) -> (Self, Self, Fp61BitPrime, Fp61BitPrime) where C: Context, - I: Iterator> + Clone, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const LLL: usize = LargeProofGenerator::LAGRANGE_LENGTH; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; - const SLL: usize = SmallProofGenerator::LAGRANGE_LENGTH; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; + const FLL: usize = FirstProofGenerator::LAGRANGE_LENGTH; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; + const CLL: usize = CompressedProofGenerator::LAGRANGE_LENGTH; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; // precomputation for first proof - let first_denominator = CanonicalLagrangeDenominator::::new(); - let first_lagrange_table = LagrangeTable::::from(first_denominator); + let first_denominator = CanonicalLagrangeDenominator::::new(); + let first_lagrange_table = LagrangeTable::::from(first_denominator); + + let first_proof = FirstProofGenerator::compute_proof( + uv_inputs + .clone() + .extrapolate_y_values(&first_lagrange_table), + ); // generate first proof from input iterator let (mut uv_values, first_proof_from_left, my_first_proof_left_share) = - LargeProofGenerator::gen_artefacts_from_recursive_step( + FirstProofGenerator::gen_artefacts_from_recursive_step( ctx, &mut prss_record_ids, - &first_lagrange_table, - ProofBatch::polynomials_from_inputs(uv_tuple_inputs), + first_proof, + uv_inputs, ); // `MAX_PROOF_RECURSION - 2` because: // * The first level of recursion has already happened. - // * We need (SRF - 1) at the last level to have room for the masks. + // * We need (CRF - 1) at the last level to have room for the masks. 
let max_uv_values: usize = - (SRF - 1) * SRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap()); + (CRF - 1) * CRF.pow(u32::try_from(MAX_PROOF_RECURSION - 2).unwrap()); assert!( uv_values.len() <= max_uv_values, "Proof batch is too large: have {} uv_values, max is {}", @@ -122,9 +127,9 @@ impl ProofBatch { // storage for other proofs let mut my_proofs_left_shares = - Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); + Vec::<[Fp61BitPrime; CPL]>::with_capacity(MAX_PROOF_RECURSION - 1); let mut shares_of_proofs_from_prover_left = - Vec::<[Fp61BitPrime; SPL]>::with_capacity(MAX_PROOF_RECURSION - 1); + Vec::<[Fp61BitPrime; CPL]>::with_capacity(MAX_PROOF_RECURSION - 1); // generate masks // Prover `P_i` and verifier `P_{i-1}` both compute p(x) @@ -138,29 +143,31 @@ impl ProofBatch { let (q_mask_from_left_prover, my_q_mask) = ctx.prss().generate_fields(prss_record_ids.expect_next()); - let denominator = CanonicalLagrangeDenominator::::new(); - let lagrange_table = LagrangeTable::::from(denominator); + let denominator = CanonicalLagrangeDenominator::::new(); + let lagrange_table = LagrangeTable::::from(denominator); // The last recursion can only include (λ - 1) u/v value pairs, because it needs to put the - // masks in the constant term. If we compress to `uv_values.len() == SRF`, then we need to - // do two more iterations: compressing SRF u/v values to 1 pair of (unmasked) u/v values, + // masks in the constant term. If we compress to `uv_values.len() == CRF`, then we need to + // do two more iterations: compressing CRF u/v values to 1 pair of (unmasked) u/v values, // and then compressing that pair and the masks to the final u/v value. // // There is a test for this corner case in validation.rs. 
let mut did_set_masks = false; - // recursively generate proofs via SmallProofGenerator + // recursively generate proofs via CompressedProofGenerator while !did_set_masks { - if uv_values.len() < SRF { + if uv_values.len() < CRF { did_set_masks = true; uv_values.set_masks(my_p_mask, my_q_mask).unwrap(); } + let my_proof = + CompressedProofGenerator::compute_proof_from_uv(uv_values.iter(), &lagrange_table); let (uv_values_new, share_of_proof_from_prover_left, my_proof_left_share) = - SmallProofGenerator::gen_artefacts_from_recursive_step( + CompressedProofGenerator::gen_artefacts_from_recursive_step( ctx, &mut prss_record_ids, - &lagrange_table, - uv_values.iter(), + my_proof, + ProverValues(uv_values.iter().copied()), ); shares_of_proofs_from_prover_left.push(share_of_proof_from_prover_left); my_proofs_left_shares.push(my_proof_left_share); @@ -224,59 +231,14 @@ impl ProofBatch { .take(length) .collect()) } - - /// This is a helper function that allows to split a `UVTupleInputs` - /// which consists of arrays of size `BLOCK_SIZE` - /// into an iterator over arrays of size `LargeProofGenerator::RECURSION_FACTOR`. - /// - /// ## Panics - /// Panics when `unwrap` panics, i.e. `try_from` fails to convert a slice to an array. 
- pub fn polynomials_from_inputs( - inputs: I, - ) -> impl Iterator< - Item = ( - [Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR], - [Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR], - ), - > + Clone - where - I: Iterator> + Clone, - { - assert_eq!(BLOCK_SIZE % LargeProofGenerator::RECURSION_FACTOR, 0); - inputs.flat_map(|(u_block, v_block)| { - (0usize..(BLOCK_SIZE / LargeProofGenerator::RECURSION_FACTOR)).map(move |i| { - ( - <[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>::try_from( - &u_block[i * LargeProofGenerator::RECURSION_FACTOR - ..(i + 1) * LargeProofGenerator::RECURSION_FACTOR], - ) - .unwrap(), - <[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>::try_from( - &v_block[i * LargeProofGenerator::RECURSION_FACTOR - ..(i + 1) * LargeProofGenerator::RECURSION_FACTOR], - ) - .unwrap(), - ) - }) - }) - } } -const_assert_eq!( - MAX_PROOF_RECURSION, - 9, - "following impl valid only for MAX_PROOF_RECURSION = 9" -); - -#[rustfmt::skip] -type U1464 = UInt, B0>, B1>, B1>, B0>, B1>, B1>, B1>, B0>, B0>, B0>; - -const ARRAY_LEN: usize = 183; +const ARRAY_LEN: usize = FirstProofGenerator::PROOF_LENGTH + + (MAX_PROOF_RECURSION - 1) * CompressedProofGenerator::PROOF_LENGTH; type Array = [Fp61BitPrime; ARRAY_LEN]; impl Serializable for Box { - type Size = U1464; + type Size = as Mul>::Output; type DeserializationError = ::DeserializationError; @@ -305,19 +267,25 @@ impl MpcMessage for Box {} #[cfg(all(test, unit_test))] mod test { - use rand::{thread_rng, Rng}; + use std::iter::repeat_with; + + use rand::Rng; use crate::{ - ff::{Fp61BitPrime, U128Conversions}, protocol::{ - context::{dzkp_field::BLOCK_SIZE, Context}, - ipa_prf::validation_protocol::{ - proof_generation::ProofBatch, - validation::{test::simple_proof_check, BatchToVerify}, + context::Context, + ipa_prf::{ + malicious_security::{ + prover::{ProverValues, UVValues}, + FIRST_RECURSION_FACTOR, + }, + validation_protocol::{ + proof_generation::ProofBatch, + 
validation::{test::simple_proof_check, BatchToVerify}, + }, }, RecordId, RecordIdRange, }, - secret_sharing::replicated::ReplicatedSecretSharing, test_executor::run, test_fixture::{Runner, TestWorld}, }; @@ -327,37 +295,16 @@ mod test { run(|| async move { let world = TestWorld::default(); - let mut rng = thread_rng(); + let mut rng = world.rng(); - // each helper samples a random value h - // which is later used to generate distinct values across helpers - let h = Fp61BitPrime::truncate_from(rng.gen_range(0u128..100)); + let uv_values = repeat_with(|| (rng.gen(), rng.gen())) + .take(100) + .collect::>(); + let uv_values_iter = uv_values.iter().copied(); + let uv_values_iter_ref = &uv_values_iter; let result = world - .semi_honest(h, |ctx, h| async move { - let h = Fp61BitPrime::truncate_from(h.left().as_u128() % 100); - // generate blocks of UV values - // generate u values as (1h,2h,3h,....,10h*BlockSize) split into Blocksize chunks - // where BlockSize = 32 - // v values are identical to u - let uv_tuple_vec = (0usize..100) - .map(|i| { - ( - (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1)) - .map(|j| { - Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h - }) - .collect::<[Fp61BitPrime; BLOCK_SIZE]>(), - (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1)) - .map(|j| { - Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h - }) - .collect::<[Fp61BitPrime; BLOCK_SIZE]>(), - ) - }) - .collect::>(); - - // generate and output VerifierBatch together with h value + .semi_honest((), |ctx, ()| async move { let ( my_batch_left_shares, shares_of_batch_from_left_prover, @@ -366,10 +313,10 @@ mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - uv_tuple_vec.into_iter(), + ProverValues(uv_values_iter_ref.clone()), ); - let batch_to_verify = BatchToVerify::generate_batch_to_verify( + BatchToVerify::generate_batch_to_verify( ctx.narrow("generate_batch"), RecordId::FIRST, my_batch_left_shares, @@ -377,21 +324,18 @@ mod test { 
p_mask_from_right_prover, q_mask_from_left_prover, ) - .await; - - // generate and output VerifierBatch together with h value - (h, batch_to_verify) + .await }) .await; // proof from first party - simple_proof_check(result[0].0, &result[2].1, &result[1].1); + simple_proof_check(uv_values.iter(), &result[2], &result[1]); // proof from second party - simple_proof_check(result[1].0, &result[0].1, &result[2].1); + simple_proof_check(uv_values.iter(), &result[0], &result[2]); // proof from third party - simple_proof_check(result[2].0, &result[1].1, &result[0].1); + simple_proof_check(uv_values.iter(), &result[1], &result[0]); }); } } diff --git a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs index f0430e996..e4022a60c 100644 --- a/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs +++ b/ipa-core/src/protocol/ipa_prf/validation_protocol/validation.rs @@ -4,7 +4,8 @@ use std::{ }; use futures_util::future::{try_join, try_join4}; -use typenum::{Unsigned, U288, U80}; +use subtle::ConstantTimeEq; +use typenum::{Unsigned, U120, U448}; use crate::{ const_assert_eq, @@ -20,10 +21,13 @@ use crate::{ }, ipa_prf::{ malicious_security::{ - prover::{LargeProofGenerator, SmallProofGenerator}, - verifier::{compute_g_differences, recursively_compute_final_check}, + verifier::{ + compute_g_differences, recursively_compute_final_check, VerifierLagrangeInput, + }, + FIRST_RECURSION_FACTOR as FRF, }, validation_protocol::proof_generation::ProofBatch, + CompressedProofGenerator, FirstProofGenerator, }, RecordId, }, @@ -44,10 +48,10 @@ use crate::{ #[derive(Debug)] #[allow(clippy::struct_field_names)] pub struct BatchToVerify { - first_proof_from_left_prover: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH], - first_proof_from_right_prover: [Fp61BitPrime; LargeProofGenerator::PROOF_LENGTH], - proofs_from_left_prover: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>, - 
proofs_from_right_prover: Vec<[Fp61BitPrime; SmallProofGenerator::PROOF_LENGTH]>, + first_proof_from_left_prover: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH], + first_proof_from_right_prover: [Fp61BitPrime; FirstProofGenerator::PROOF_LENGTH], + proofs_from_left_prover: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>, + proofs_from_right_prover: Vec<[Fp61BitPrime; CompressedProofGenerator::PROOF_LENGTH]>, p_mask_from_right_prover: Fp61BitPrime, q_mask_from_left_prover: Fp61BitPrime, } @@ -104,13 +108,12 @@ impl BatchToVerify { where C: Context, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; // exclude for first proof - let exclude_large = u128::try_from(LRF).unwrap(); + let exclude_large = u128::try_from(FRF).unwrap(); // exclude for other proofs - let exclude_small = u128::try_from(SRF).unwrap(); + let exclude_small = u128::try_from(CRF).unwrap(); // generate hashes let my_hashes_prover_left = ProofHashes::generate_hashes(self, Direction::Left); @@ -171,21 +174,20 @@ impl BatchToVerify { v_from_left_prover: V, // Prover P_i and verifier P_{i+1} both compute `v` and `q(x)` ) -> (Fp61BitPrime, Fp61BitPrime) where - U: Iterator + Send, - V: Iterator + Send, + U: VerifierLagrangeInput + Send, + V: VerifierLagrangeInput + Send, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; // compute p_r - let p_r_right_prover = recursively_compute_final_check::<_, _, LRF, SRF>( - u_from_right_prover.into_iter(), + let p_r_right_prover = recursively_compute_final_check::<_, CRF>( + u_from_right_prover, challenges_for_right_prover, self.p_mask_from_right_prover, ); // compute q_r - let q_r_left_prover = recursively_compute_final_check::<_, _, LRF, SRF>( - v_from_left_prover.into_iter(), + let 
q_r_left_prover = recursively_compute_final_check::<_, CRF>( + v_from_left_prover, challenges_for_left_prover, self.q_mask_from_left_prover, ); @@ -241,11 +243,10 @@ impl BatchToVerify { where C: Context, { - const LRF: usize = LargeProofGenerator::RECURSION_FACTOR; - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; - const LPL: usize = LargeProofGenerator::PROOF_LENGTH; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; let p_times_q_right = Self::compute_p_times_q( ctx.narrow(&Step::PTimesQ), @@ -256,7 +257,7 @@ impl BatchToVerify { .await?; // add Zero for p_times_q and sum since they are not secret shared - let diff_left = compute_g_differences::<_, SPL, SRF, LPL, LRF>( + let diff_left = compute_g_differences::<_, CPL, CRF, FPL, FRF>( &self.first_proof_from_left_prover, &self.proofs_from_left_prover, challenges_for_left_prover, @@ -264,7 +265,7 @@ impl BatchToVerify { Fp61BitPrime::ZERO, ); - let diff_right = compute_g_differences::<_, SPL, SRF, LPL, LRF>( + let diff_right = compute_g_differences::<_, CPL, CRF, FPL, FRF>( &self.first_proof_from_right_prover, &self.proofs_from_right_prover, challenges_for_right_prover, @@ -293,11 +294,13 @@ impl BatchToVerify { .await?; let diff_right_from_other_verifier = receive_data[0..length].to_vec(); - // compare recombined dif to zero - for i in 0..length { - if diff_right[i] + diff_right_from_other_verifier[i] != Fp61BitPrime::ZERO { - return Err(Error::DZKPValidationFailed); - } + // compare recombined diff to zero + let diff = zip(diff_right, diff_right_from_other_verifier) + .map(|(a, b)| a + b) + .collect::>(); + + if diff.ct_ne(&vec![Fp61BitPrime::ZERO; length]).into() { + return Err(Error::DZKPValidationFailed); } Ok(()) @@ -372,12 +375,12 @@ impl ProofHashes { const_assert_eq!( MAX_PROOF_RECURSION, - 9, - "following impl 
valid only for MAX_PROOF_RECURSION = 9" + 14, + "following impl valid only for MAX_PROOF_RECURSION = 14" ); impl Serializable for [Hash; MAX_PROOF_RECURSION] { - type Size = U288; + type Size = U448; type DeserializationError = ::DeserializationError; @@ -406,14 +409,14 @@ impl MpcMessage for [Hash; MAX_PROOF_RECURSION] {} const_assert_eq!( MAX_PROOF_RECURSION, - 9, - "following impl valid only for MAX_PROOF_RECURSION = 9" + 14, + "following impl valid only for MAX_PROOF_RECURSION = 14" ); type ProofDiff = [Fp61BitPrime; MAX_PROOF_RECURSION + 1]; impl Serializable for ProofDiff { - type Size = U80; + type Size = U120; type DeserializationError = ::DeserializationError; @@ -442,38 +445,43 @@ impl MpcMessage for ProofDiff {} #[cfg(all(test, unit_test))] pub mod test { + use std::iter::repeat_with; + use futures_util::future::try_join; - use rand::{thread_rng, Rng}; + use rand::Rng; use crate::{ - ff::{Fp61BitPrime, U128Conversions}, + ff::Fp61BitPrime, helpers::Direction, protocol::{ - context::{ - dzkp_field::{UVTupleBlock, BLOCK_SIZE}, - Context, - }, + context::Context, ipa_prf::{ malicious_security::{ lagrange::CanonicalLagrangeDenominator, - prover::{LargeProofGenerator, SmallProofGenerator}, - verifier::{compute_sum_share, interpolate_at_r}, + prover::{ProverValues, UVValues}, + verifier::{compute_sum_share, interpolate_at_r, VerifierValues}, + FIRST_RECURSION_FACTOR as FRF, }, validation_protocol::{proof_generation::ProofBatch, validation::BatchToVerify}, + CompressedProofGenerator, FirstProofGenerator, }, prss::SharedRandomness, RecordId, RecordIdRange, }, - secret_sharing::{replicated::ReplicatedSecretSharing, SharedValue}, + secret_sharing::SharedValue, test_executor::run, test_fixture::{Runner, TestWorld}, + utils::arraychunks::ArrayChunkIterator, }; - /// ## Panics - /// When proof is not generated correctly. 
- // todo: deprecate once validation protocol is implemented - pub fn simple_proof_check( - h: Fp61BitPrime, + // This is a helper for a test in `proof_generation.rs`, but is located here so it + // can access the internals of `BatchToVerify`. + // + // Possibly this (and the associated test) can be removed now that the validation + // protocol is implemented? (There was an old todo to that effect.) But it seems + // useful to keep around as a unit test. + pub(in crate::protocol::ipa_prf::validation_protocol) fn simple_proof_check<'a>( + uv_values: impl Iterator, left_verifier: &BatchToVerify, right_verifier: &BatchToVerify, ) { @@ -481,7 +489,7 @@ pub mod test { // first proof has correct length assert_eq!( left_verifier.first_proof_from_left_prover.len(), - LargeProofGenerator::PROOF_LENGTH + FirstProofGenerator::PROOF_LENGTH ); assert_eq!( left_verifier.first_proof_from_left_prover.len(), @@ -491,7 +499,7 @@ pub mod test { for i in 0..left_verifier.proofs_from_left_prover.len() { assert_eq!( (i, left_verifier.proofs_from_left_prover[i].len()), - (i, SmallProofGenerator::PROOF_LENGTH) + (i, CompressedProofGenerator::PROOF_LENGTH) ); assert_eq!( (i, left_verifier.proofs_from_left_prover[i].len()), @@ -509,36 +517,12 @@ pub mod test { // check first proof, // compute simple proof without lagrange interpolated points - let simple_proof = { - let block_to_polynomial = BLOCK_SIZE / LargeProofGenerator::RECURSION_FACTOR; - let simple_proof_uv = (0usize..100 * block_to_polynomial) - .map(|i| { - ( - (LargeProofGenerator::RECURSION_FACTOR * i - ..LargeProofGenerator::RECURSION_FACTOR * (i + 1)) - .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h) - .collect::<[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>(), - (LargeProofGenerator::RECURSION_FACTOR * i - ..LargeProofGenerator::RECURSION_FACTOR * (i + 1)) - .map(|j| Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) * h) - .collect::<[Fp61BitPrime; LargeProofGenerator::RECURSION_FACTOR]>(), 
- ) - }) - .collect::>(); - - simple_proof_uv.iter().fold( - [Fp61BitPrime::ZERO; LargeProofGenerator::RECURSION_FACTOR], - |mut acc, (left, right)| { - for i in 0..LargeProofGenerator::RECURSION_FACTOR { - acc[i] += left[i] * right[i]; - } - acc - }, - ) - }; + let simple_proof = uv_values.fold([Fp61BitPrime::ZERO; FRF], |mut acc, (left, right)| { + for i in 0..FRF { + acc[i] += left[i] * right[i]; + } + acc + }); // reconstruct computed proof // by adding shares left and right @@ -551,13 +535,7 @@ pub mod test { // check for consistency // only check first R::USIZE field elements - assert_eq!( - (h.as_u128(), simple_proof.to_vec()), - ( - h.as_u128(), - proof_computed[0..LargeProofGenerator::RECURSION_FACTOR].to_vec() - ) - ); + assert_eq!(simple_proof.to_vec(), proof_computed[0..FRF].to_vec()); } #[test] @@ -565,39 +543,17 @@ pub mod test { run(|| async move { let world = TestWorld::default(); - let mut rng = thread_rng(); + let mut rng = world.rng(); - // each helper samples a random value h - // which is later used to generate distinct values across helpers - let h = Fp61BitPrime::truncate_from(rng.gen_range(0u128..100)); + let uv_values = repeat_with(|| (rng.gen(), rng.gen())) + .take(100) + .collect::>(); + let uv_values_iter = uv_values.iter().copied(); + let uv_values_iter_ref = &uv_values_iter; let [(helper_1_left, helper_1_right), (helper_2_left, helper_2_right), (helper_3_left, helper_3_right)] = world - .semi_honest(h, |ctx, h| async move { - let h = Fp61BitPrime::truncate_from(h.left().as_u128() % 100); - // generate blocks of UV values - // generate u values as (1h,2h,3h,....,10h*BlockSize) split into Blocksize chunks - // where BlockSize = 32 - // v values are identical to u - let uv_tuple_vec = (0usize..100) - .map(|i| { - ( - (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1)) - .map(|j| { - Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) - * h - }) - .collect::<[Fp61BitPrime; BLOCK_SIZE]>(), - (BLOCK_SIZE * i..BLOCK_SIZE * (i + 1)) - .map(|j| { - 
Fp61BitPrime::truncate_from(u128::try_from(j).unwrap()) - * h - }) - .collect::<[Fp61BitPrime; BLOCK_SIZE]>(), - ) - }) - .collect::>(); - + .semi_honest((), |ctx, ()| async move { // generate and output VerifierBatch together with h value let ( my_batch_left_shares, @@ -607,7 +563,7 @@ pub mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - uv_tuple_vec.into_iter(), + ProverValues(uv_values_iter_ref.clone()), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( @@ -641,31 +597,32 @@ pub mod test { /// Prover `P_i` and verifier `P_{i+1}` both generate `u` /// Prover `P_i` and verifier `P_{i-1}` both generate `v` /// - /// outputs `(my_u_and_v, u_from_right_prover, v_from_left_prover)` + /// outputs `(my_u_and_v, sum_of_uv, u_from_right_prover, v_from_left_prover)` + #[allow(clippy::type_complexity)] fn generate_u_v( ctx: &C, len: usize, ) -> ( - Vec>, + Vec<([Fp61BitPrime; FRF], [Fp61BitPrime; FRF])>, Fp61BitPrime, Vec, Vec, ) { // outputs - let mut vec_u_from_right_prover = Vec::::with_capacity(BLOCK_SIZE * len); - let mut vec_v_from_left_prover = Vec::::with_capacity(BLOCK_SIZE * len); + let mut vec_u_from_right_prover = Vec::::with_capacity(FRF * len); + let mut vec_v_from_left_prover = Vec::::with_capacity(FRF * len); let mut vec_my_u_and_v = - Vec::<([Fp61BitPrime; BLOCK_SIZE], [Fp61BitPrime; BLOCK_SIZE])>::with_capacity(len); + Vec::<([Fp61BitPrime; FRF], [Fp61BitPrime; FRF])>::with_capacity(len); let mut sum_of_uv = Fp61BitPrime::ZERO; // generate random u, v values using PRSS let mut counter = RecordId::FIRST; for _ in 0..len { - let mut my_u_array = [Fp61BitPrime::ZERO; BLOCK_SIZE]; - let mut my_v_array = [Fp61BitPrime::ZERO; BLOCK_SIZE]; - for i in 0..BLOCK_SIZE { + let mut my_u_array = [Fp61BitPrime::ZERO; FRF]; + let mut my_v_array = [Fp61BitPrime::ZERO; FRF]; + for i in 0..FRF { let (my_u, u_from_right_prover) = ctx.prss().generate_fields(counter); counter += 1; let (v_from_left_prover, my_v) = 
ctx.prss().generate_fields(counter); @@ -722,7 +679,7 @@ pub mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - vec_my_u_and_v.into_iter(), + ProverValues(vec_my_u_and_v.into_iter()), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( @@ -774,9 +731,9 @@ pub mod test { } fn assert_batch(left: &BatchToVerify, right: &BatchToVerify, challenges: &[Fp61BitPrime]) { - const SRF: usize = SmallProofGenerator::RECURSION_FACTOR; - const SPL: usize = SmallProofGenerator::PROOF_LENGTH; - const LPL: usize = LargeProofGenerator::PROOF_LENGTH; + const CRF: usize = CompressedProofGenerator::RECURSION_FACTOR; + const CPL: usize = CompressedProofGenerator::PROOF_LENGTH; + const FPL: usize = FirstProofGenerator::PROOF_LENGTH; let first = recombine( &left.first_proof_from_left_prover, @@ -788,19 +745,19 @@ pub mod test { .zip(right.proofs_from_right_prover.iter()) .map(|(left, right)| recombine(left, right)) .collect::>(); - let denominator_first = CanonicalLagrangeDenominator::<_, LPL>::new(); - let denominator = CanonicalLagrangeDenominator::<_, SPL>::new(); + let denominator_first = CanonicalLagrangeDenominator::<_, FPL>::new(); + let denominator = CanonicalLagrangeDenominator::<_, CPL>::new(); let length = others.len(); let mut out = interpolate_at_r(&first, &challenges[0], &denominator_first); for (i, proof) in others.iter().take(length - 1).enumerate() { - assert_eq!((i, out), (i, compute_sum_share::<_, SRF, SPL>(proof))); + assert_eq!((i, out), (i, compute_sum_share::<_, CRF, CPL>(proof))); out = interpolate_at_r(proof, &challenges[i + 1], &denominator); } // last sum without masks let masks = others[length - 1][0]; - let last_sum = compute_sum_share::<_, SRF, SPL>(&others[length - 1]); + let last_sum = compute_sum_share::<_, CRF, CPL>(&others[length - 1]); assert_eq!(out, last_sum - masks); } @@ -828,7 +785,7 @@ pub mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - 
vec_my_u_and_v.into_iter(), + ProverValues(vec_my_u_and_v.into_iter()), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( @@ -855,8 +812,10 @@ pub mod test { let (p, q) = batch_to_verify.compute_p_and_q_r( &challenges_for_left_prover, &challenges_for_right_prover, - vec_u_from_right_prover.into_iter(), - vec_v_from_left_prover.into_iter(), + VerifierValues( + vec_u_from_right_prover.into_iter().chunk_array::(), + ), + VerifierValues(vec_v_from_left_prover.into_iter().chunk_array::()), ); let p_times_q = @@ -866,7 +825,7 @@ pub mod test { let denominator = CanonicalLagrangeDenominator::< Fp61BitPrime, - { SmallProofGenerator::PROOF_LENGTH }, + { CompressedProofGenerator::PROOF_LENGTH }, >::new(); let g_r_left = interpolate_at_r( @@ -918,7 +877,7 @@ pub mod test { ) = ProofBatch::generate( &ctx.narrow("generate_batch"), RecordIdRange::ALL, - vec_my_u_and_v.into_iter(), + ProverValues(vec_my_u_and_v.into_iter()), ); let batch_to_verify = BatchToVerify::generate_batch_to_verify( @@ -958,8 +917,8 @@ pub mod test { let (p, q) = batch_to_verify.compute_p_and_q_r( &challenges_for_left_prover, &challenges_for_right_prover, - vec_u_from_right_prover.into_iter(), - vec_v_from_left_prover.into_iter(), + VerifierValues(vec_u_from_right_prover.into_iter().chunk_array::()), + VerifierValues(vec_v_from_left_prover.into_iter().chunk_array::()), ); batch_to_verify @@ -985,9 +944,14 @@ pub mod test { // Test a batch that exercises the case where `uv_values.len() == 1` but `did_set_masks = // false` in `ProofBatch::generate`. - verify_batch( - LargeProofGenerator::RECURSION_FACTOR * SmallProofGenerator::RECURSION_FACTOR - / BLOCK_SIZE, - ); + // + // We divide by `FRF` here because `generate_u_v`, which is used by + // `verify_batch` to generate test data, generates `len` chunks of u/v values of + // length `FRF`. We want the input u/v values to compress to exactly one + // u/v pair after some number of proof steps. 
+ let num_inputs = + FirstProofGenerator::RECURSION_FACTOR * CompressedProofGenerator::RECURSION_FACTOR; + assert!(num_inputs % FRF == 0); + verify_batch(num_inputs / FRF); } } diff --git a/ipa-core/src/protocol/prss/mod.rs b/ipa-core/src/protocol/prss/mod.rs index 9c7e75ea7..dd3d2fbe6 100644 --- a/ipa-core/src/protocol/prss/mod.rs +++ b/ipa-core/src/protocol/prss/mod.rs @@ -209,6 +209,26 @@ impl SharedRandomness for IndexedSharedRandomness { } } +impl SharedRandomness for Arc { + type ChunkIter<'a, Z: ArrayLength> = + ::ChunkIter<'a, Z>; + + fn generate_chunks_one_side, Z: ArrayLength>( + &self, + index: I, + direction: Direction, + ) -> Self::ChunkIter<'_, Z> { + IndexedSharedRandomness::generate_chunks_one_side(self, index, direction) + } + + fn generate_chunks_iter, Z: ArrayLength>( + &self, + index: I, + ) -> impl Iterator, GenericArray)> { + IndexedSharedRandomness::generate_chunks_iter(self, index) + } +} + /// Specialized implementation for chunks that are generated using both left and right /// randomness. The functionality is the same as [`std::iter::zip`], but it does not use /// `Iterator` trait to call `left` and `right` next. It uses inlined method calls to diff --git a/ipa-core/src/protocol/step.rs b/ipa-core/src/protocol/step.rs index ee2ef726a..351eb23c0 100644 --- a/ipa-core/src/protocol/step.rs +++ b/ipa-core/src/protocol/step.rs @@ -5,18 +5,20 @@ use ipa_step_derive::{CompactGate, CompactStep}; #[derive(CompactStep, CompactGate)] pub enum ProtocolStep { Prss, + CrossShardPrss, #[step(child = crate::protocol::ipa_prf::step::IpaPrfStep)] IpaPrf, #[step(child = crate::protocol::hybrid::step::HybridStep)] Hybrid, Multiply, PrimeFieldAddition, + #[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)] + ShardedShuffle, /// Steps used in unit tests are grouped under this one. 
Ideally it should be /// gated behind test configuration, but it does not work with build.rs that /// does not enable any features when creating protocol gate file #[step(child = TestExecutionStep)] Test, - /// This step includes all the steps that are currently not linked into a top-level protocol. /// /// This allows those steps to be compiled. However, any use of them will fail at run time. @@ -39,8 +41,6 @@ pub enum DeadCodeStep { FeatureLabelDotProduct, #[step(child = crate::protocol::ipa_prf::boolean_ops::step::MultiplicationStep)] Multiplication, - #[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)] - ShardedShuffle, } /// Provides a unique per-iteration context in tests. diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index edd1662e4..bb11e66db 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -46,7 +46,8 @@ use crate::{ }; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] use crate::{ - ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field, + ff::Fp32BitPrime, query::runner::execute_sharded_shuffle, query::runner::execute_test_multiply, + query::runner::test_add_in_prime_field, }; pub trait Result: Send + Debug { @@ -108,7 +109,7 @@ pub fn execute( config, gateway, input, - |_prss, _gateway, _config, _input| unimplemented!(), + |prss, gateway, _config, input| Box::pin(execute_sharded_shuffle(prss, gateway, input)), ), #[cfg(any(test, feature = "weak-field"))] (QueryType::TestAddInPrimeField, FieldType::Fp31) => do_query( diff --git a/ipa-core/src/query/mod.rs b/ipa-core/src/query/mod.rs index 6e6650862..1f2550987 100644 --- a/ipa-core/src/query/mod.rs +++ b/ipa-core/src/query/mod.rs @@ -11,4 +11,4 @@ pub use processor::{ QueryInputError, QueryKillStatus, QueryKilled, QueryStatusError, }; pub use runner::OprfIpaQuery; -pub use state::QueryStatus; +pub use state::{min_status, QueryStatus}; diff --git 
a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 120e1c5ca..1735db141 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -6,11 +6,13 @@ use std::{ use futures::{future::try_join, stream}; use serde::Serialize; +use super::min_status; use crate::{ - error::Error as ProtocolError, + error::{BoxError, Error as ProtocolError}, executor::IpaRuntime, helpers::{ - query::{PrepareQuery, QueryConfig, QueryInput}, + query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, + routing::RouteId, BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, @@ -110,6 +112,18 @@ pub enum QueryInputError { pub enum QueryStatusError { #[error("The query with id {0:?} does not exist")] NoSuchQuery(QueryId), + #[error(transparent)] + ShardBroadcastError(#[from] BroadcastError), + #[error("This shard {0:?} isn't the leader (shard 0)")] + NotLeader(ShardIndex), + #[error("This is the leader shard")] + Leader, + #[error("My status {my_status:?} for query {query_id:?} differs from {other_status:?}")] + DifferentStatus { + query_id: QueryId, + my_status: QueryStatus, + other_status: QueryStatus, + }, } #[derive(thiserror::Error, Debug)] @@ -123,6 +137,8 @@ pub enum QueryCompletionError { }, #[error("query execution failed: {0}")] ExecutionError(#[from] ProtocolError), + #[error("one or more shards rejected this request: {0}")] + ShardError(#[from] BroadcastError), } impl Debug for Processor { @@ -349,17 +365,109 @@ impl Processor { Some(status) } - /// Returns the query status. + /// This helper function is used to transform a [`BoxError`] into a + /// [`QueryStatusError::DifferentStatus`] and retrieve it's internal state. Returns [`None`] + /// if not possible. 
+ fn downcast_state_error(box_error: BoxError) -> Option { + use crate::helpers::ApiError; + let api_error = box_error.downcast::().ok()?; + if let ApiError::QueryStatus(QueryStatusError::DifferentStatus { my_status, .. }) = + *api_error + { + return Some(my_status); + } + None + } + + /// This helper is used by the in-memory stack to obtain the state of other shards via a + /// [`QueryStatusError::DifferentStatus`] error. + /// TODO: Ideally broadcast should return a value, that we could use to parse the state instead + /// of relying on errors. + #[cfg(feature = "in-memory-infra")] + fn get_state_from_error( + error: crate::helpers::InMemoryTransportError, + ) -> Option { + if let crate::helpers::InMemoryTransportError::Rejected { inner, .. } = error { + return Self::downcast_state_error(inner); + } + None + } + + /// This helper is used by the HTTP stack to obtain the state of other shards via a + /// [`QueryStatusError::DifferentStatus`] error. + /// TODO: Ideally broadcast should return a value, that we could use to parse the state instead + /// of relying on errors. + #[cfg(feature = "real-world-infra")] + fn get_state_from_error(shard_error: crate::net::ShardError) -> Option { + if let crate::net::Error::Application { error, .. } = shard_error.source { + return Self::downcast_state_error(error); + } + None + } + + /// Returns the query status in this helper, by querying all shards. /// /// ## Errors /// If query is not registered on this helper. /// /// ## Panics /// If the query collection mutex is poisoned. 
- pub fn query_status(&self, query_id: QueryId) -> Result { - let status = self + pub async fn query_status( + &self, + shard_transport: ShardTransportImpl, + query_id: QueryId, + ) -> Result { + let shard_index = shard_transport.identity(); + if shard_index != ShardIndex::FIRST { + return Err(QueryStatusError::NotLeader(shard_index)); + } + + let mut status = self .get_status(query_id) .ok_or(QueryStatusError::NoSuchQuery(query_id))?; + + let shard_query_status_req = CompareStatusRequest { query_id, status }; + + let shard_responses = shard_transport.broadcast(shard_query_status_req).await; + if let Err(e) = shard_responses { + // The following silently ignores the cases where the query isn't found. + let states: Vec<_> = e + .failures + .into_iter() + .filter_map(|(_si, e)| Self::get_state_from_error(e)) + .collect(); + status = states.into_iter().fold(status, min_status); + } + + Ok(status) + } + + /// Compares this shard status against the given type. Returns an error if different. + /// + /// ## Errors + /// If query is not registered on this helper or + /// + /// ## Panics + /// If the query collection mutex is poisoned. 
+ pub fn shard_status( + &self, + shard_transport: &ShardTransportImpl, + req: &CompareStatusRequest, + ) -> Result { + let shard_index = shard_transport.identity(); + if shard_index == ShardIndex::FIRST { + return Err(QueryStatusError::Leader); + } + let status = self + .get_status(req.query_id) + .ok_or(QueryStatusError::NoSuchQuery(req.query_id))?; + if req.status != status { + return Err(QueryStatusError::DifferentStatus { + query_id: req.query_id, + my_status: status, + other_status: req.status, + }); + } Ok(status) } @@ -373,6 +481,7 @@ impl Processor { pub async fn complete( &self, query_id: QueryId, + shard_transport: ShardTransportImpl, ) -> Result, QueryCompletionError> { let handle = { let mut queries = self.queries.inner.lock().unwrap(); @@ -397,6 +506,18 @@ impl Processor { } }; // release mutex before await + // Inform other shards about our intent to complete the query. + // If any of them rejects it, report the error back. We expect all shards + // to be in the same state. In normal cycle, this API is called only after + // query status reports completion. + if shard_transport.identity() == ShardIndex::FIRST { + // See shard finalizer protocol to see how shards merge their results together. + // At the end, only leader holds the value + shard_transport + .broadcast((RouteId::CompleteQuery, query_id)) + .await?; + } + Ok(handle.await?) 
} @@ -440,39 +561,47 @@ mod tests { use tokio::sync::Barrier; use crate::{ - ff::FieldType, + executor::IpaRuntime, + ff::{boolean_array::BA64, FieldType}, helpers::{ make_owned_handler, query::{PrepareQuery, QueryConfig, QueryType::TestMultiply}, + routing::Addr, ApiError, HandlerBox, HelperIdentity, HelperResponse, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport, RequestHandler, RoleAssignment, Transport, TransportIdentity, }, protocol::QueryId, query::{ - processor::Processor, state::StateError, NewQueryError, PrepareQueryError, QueryStatus, - QueryStatusError, + processor::Processor, + state::{QueryState, RunningQuery, StateError}, + NewQueryError, PrepareQueryError, QueryStatus, QueryStatusError, }, sharding::ShardIndex, }; - fn prepare_query_handler(cb: F) -> Arc> + fn prepare_query() -> PrepareQuery { + PrepareQuery { + query_id: QueryId, + config: test_multiply_config(), + roles: RoleAssignment::new(HelperIdentity::make_three()), + } + } + + fn create_handler(cb: F) -> Arc> where - F: Fn(PrepareQuery) -> Fut + Send + Sync + 'static, + F: Fn(Addr) -> Fut + Send + Sync + 'static, Fut: Future> + Send + Sync + 'static, { - make_owned_handler(move |req, _| { - let prepare_query = req.into().unwrap(); - cb(prepare_query) - }) + make_owned_handler(move |req, _| cb(req)) } fn helper_respond_ok() -> Arc> { - prepare_query_handler(|_| async { Ok(HelperResponse::ok()) }) + create_handler(|_| async { Ok(HelperResponse::ok()) }) } fn shard_respond_ok(_si: ShardIndex) -> Arc> { - prepare_query_handler(|_| async { Ok(HelperResponse::ok()) }) + create_handler(|_| async { Ok(HelperResponse::ok()) }) } fn test_multiply_config() -> QueryConfig { @@ -559,7 +688,15 @@ mod tests { shard_transport: InMemoryTransport, } + impl Default for TestComponents { + fn default() -> Self { + Self::new(TestComponentsArgs::default()) + } + } + impl TestComponents { + const COMPLETE_QUERY_RESULT: Vec = Vec::new(); + fn new(mut args: TestComponentsArgs) -> Self { let mpc_network = 
InMemoryMpcNetwork::new( args.mpc_handlers @@ -584,6 +721,31 @@ mod tests { shard_transport, } } + + /// This initiates a new query on all shards and puts them all on running state. + /// It also makes up a fake query result + async fn new_running_query(&self) -> QueryId { + self.processor + .new_query( + self.first_transport.clone_ref(), + self.shard_transport.clone_ref(), + self.query_config, + ) + .await + .unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + self.processor + .queries + .handle(QueryId) + .set_state(QueryState::Running(RunningQuery { + result: rx, + join_handle: IpaRuntime::current().spawn(async {}), + })) + .unwrap(); + tx.send(Ok(Box::new(Self::COMPLETE_QUERY_RESULT))).unwrap(); + + QueryId + } } #[tokio::test] @@ -592,14 +754,14 @@ mod tests { let barrier = Arc::new(Barrier::new(3)); let h2_barrier = Arc::clone(&barrier); let h3_barrier = Arc::clone(&barrier); - let h2 = prepare_query_handler(move |_| { + let h2 = create_handler(move |_| { let barrier = Arc::clone(&h2_barrier); async move { barrier.wait().await; Ok(HelperResponse::ok()) } }); - let h3 = prepare_query_handler(move |_| { + let h3 = create_handler(move |_| { let barrier = Arc::clone(&h3_barrier); async move { barrier.wait().await; @@ -608,9 +770,11 @@ mod tests { }); args.mpc_handlers = [None, Some(h2), Some(h3)]; let t = TestComponents::new(args); - let qc_future = t - .processor - .new_query(t.first_transport, t.shard_transport, t.query_config); + let qc_future = t.processor.new_query( + t.first_transport, + t.shard_transport.clone_ref(), + t.query_config, + ); pin_mut!(qc_future); // poll future once to trigger query status change @@ -618,7 +782,10 @@ mod tests { assert_eq!( QueryStatus::Preparing, - t.processor.query_status(QueryId).unwrap() + t.processor + .query_status(t.shard_transport.clone_ref(), QueryId) + .await + .unwrap() ); // unblock sends barrier.wait().await; @@ -636,7 +803,10 @@ mod tests { ); assert_eq!( QueryStatus::AwaitingInputs, - 
t.processor.query_status(QueryId).unwrap() + t.processor + .query_status(t.shard_transport.clone_ref(), QueryId) + .await + .unwrap() ); } @@ -665,7 +835,7 @@ mod tests { async fn prepare_error() { let mut args = TestComponentsArgs::default(); let h2 = helper_respond_ok(); - let h3 = prepare_query_handler(|_| async move { + let h3 = create_handler(|_| async move { Err(ApiError::QueryPrepare(PrepareQueryError::WrongTarget)) }); args.mpc_handlers = [None, Some(h2), Some(h3)]; @@ -688,7 +858,7 @@ mod tests { #[tokio::test] async fn shard_prepare_error() { fn shard_handle(si: ShardIndex) -> Arc> { - prepare_query_handler(move |_| async move { + create_handler(move |_| async move { if si == ShardIndex(2) { Err(ApiError::QueryPrepare(PrepareQueryError::AlreadyRunning)) } else { @@ -704,7 +874,11 @@ mod tests { let t = TestComponents::new(args); let r = t .processor - .new_query(t.first_transport, t.shard_transport, t.query_config) + .new_query( + t.first_transport, + t.shard_transport.clone_ref(), + t.query_config, + ) .await; // The following makes sure the error is a broadcast error from shard 2 assert!(r.is_err()); @@ -716,7 +890,10 @@ mod tests { } } assert!(matches!( - t.processor.query_status(QueryId).unwrap_err(), + t.processor + .query_status(t.shard_transport, QueryId) + .await + .unwrap_err(), QueryStatusError::NoSuchQuery(_) )); } @@ -732,7 +909,7 @@ mod tests { // First we setup MPC handlers that will return some error let mut args = TestComponentsArgs::default(); let h2 = helper_respond_ok(); - let h3 = prepare_query_handler(|_| async move { + let h3 = create_handler(|_| async move { Err(ApiError::QueryPrepare(PrepareQueryError::WrongTarget)) }); args.mpc_handlers = [None, Some(h2), Some(h3)]; @@ -755,33 +932,109 @@ mod tests { assert!(t.processor.get_status(QueryId).is_none()); } + mod complete { + + use crate::{ + helpers::{make_owned_handler, routing::RouteId, Transport}, + query::{ + processor::{ + tests::{HelperResponse, TestComponents, 
TestComponentsArgs}, + QueryId, + }, + ProtocolResult, QueryCompletionError, + }, + sharding::ShardIndex, + }; + + #[tokio::test] + async fn complete_basic() { + let t = TestComponents::default(); + let query_id = t.new_running_query().await; + + assert_eq!( + TestComponents::COMPLETE_QUERY_RESULT.to_bytes(), + t.processor + .complete(query_id, t.shard_transport.clone_ref()) + .await + .unwrap() + .to_bytes() + ); + } + + #[tokio::test] + #[should_panic(expected = "QueryCompletion(NoSuchQuery(QueryId))")] + async fn complete_one_shard_fails() { + let mut args = TestComponentsArgs::default(); + + args.set_shard_handler(|shard_id| { + make_owned_handler(move |req, _| { + if shard_id != ShardIndex::from(1) || req.route != RouteId::CompleteQuery { + futures::future::ok(HelperResponse::ok()) + } else { + futures::future::err(QueryCompletionError::NoSuchQuery(QueryId).into()) + } + }) + }); + + let t = TestComponents::new(args); + let query_id = t.new_running_query().await; + + let _ = t + .processor + .complete(query_id, t.shard_transport.clone_ref()) + .await + .unwrap(); + } + + #[tokio::test] + async fn only_leader_broadcasts() { + let mut args = TestComponentsArgs::default(); + + args.set_shard_handler(|shard_id| { + make_owned_handler(move |_req, _| { + if shard_id == ShardIndex::FIRST { + panic!("Leader shard must not receive requests through shard channels"); + } else { + futures::future::ok(HelperResponse::ok()) + } + }) + }); + + let t = TestComponents::new(args); + let query_id = t.new_running_query().await; + + t.processor + .complete(query_id, t.shard_transport.clone_ref()) + .await + .unwrap(); + } + } + mod prepare { use super::*; use crate::query::QueryStatusError; - fn prepare_query() -> PrepareQuery { - PrepareQuery { - query_id: QueryId, - config: test_multiply_config(), - roles: RoleAssignment::new(HelperIdentity::make_three()), - } - } - #[tokio::test] async fn happy_case() { let req = prepare_query(); let t = 
TestComponents::new(TestComponentsArgs::default()); assert!(matches!( - t.processor.query_status(QueryId).unwrap_err(), + t.processor + .query_status(t.shard_transport.clone_ref(), QueryId) + .await + .unwrap_err(), QueryStatusError::NoSuchQuery(_) )); t.processor - .prepare_helper(t.second_transport, t.shard_transport, req) + .prepare_helper(t.second_transport, t.shard_transport.clone_ref(), req) .await .unwrap(); assert_eq!( QueryStatus::AwaitingInputs, - t.processor.query_status(QueryId).unwrap() + t.processor + .query_status(t.shard_transport, QueryId) + .await + .unwrap() ); } @@ -879,6 +1132,159 @@ mod tests { } } + mod query_status { + use super::*; + use crate::{helpers::query::CompareStatusRequest, protocol::QueryId}; + + /// * From the standpoint of leader shard in Helper 1 + /// * On query_status + /// + /// The min state should be returned. In this case, if I, as leader, am in AwaitingInputs + /// state and shards report that they are further ahead (Completed and Running), then my + /// state is returned. + #[tokio::test] + async fn combined_status_response() { + fn shard_handle(si: ShardIndex) -> Arc> { + create_handler(move |_| async move { + match si { + ShardIndex(3) => { + Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { + query_id: QueryId, + my_status: QueryStatus::Completed, + other_status: QueryStatus::Preparing, + })) + } + ShardIndex(2) => { + Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { + query_id: QueryId, + my_status: QueryStatus::Running, + other_status: QueryStatus::Preparing, + })) + } + _ => Ok(HelperResponse::ok()), + } + }) + } + let mut args = TestComponentsArgs { + shard_count: 4, + ..Default::default() + }; + args.set_shard_handler(shard_handle); + let t = TestComponents::new(args); + let req = prepare_query(); + // Using prepare shard to set the inner state, but in reality we should be using prepare_helper + // Prepare helper will use the shard_handle defined above though and will fail. 
The following + // achieves the same state. + t.processor + .prepare_shard( + &t.shard_network + .transport(HelperIdentity::ONE, ShardIndex::from(1)), + req, + ) + .unwrap(); + let r = t + .processor + .query_status(t.shard_transport.clone_ref(), QueryId) + .await; + if let Err(e) = r { + panic!("Unexpected error {e}"); + } + if let Ok(st) = r { + assert_eq!(QueryStatus::AwaitingInputs, st); + } + } + + /// * From the standpoint of leader shard in Helper 1 + /// * On query_status + /// + /// If one of my shards hasn't received the query yet (NoSuchQuery) the leader shouldn't + /// return an error but instead with the min state. + #[tokio::test] + async fn status_query_doesnt_exist() { + fn shard_handle(si: ShardIndex) -> Arc> { + create_handler(move |_| async move { + match si { + ShardIndex(3) => Err(ApiError::QueryStatus(QueryStatusError::NoSuchQuery( + QueryId, + ))), + _ => Ok(HelperResponse::ok()), + } + }) + } + let mut args = TestComponentsArgs { + shard_count: 4, + ..Default::default() + }; + args.set_shard_handler(shard_handle); + let t = TestComponents::new(args); + let req = prepare_query(); + // Using prepare shard to set the inner state, but in reality we should be using prepare_helper + // Prepare_helper will use the shard_handle defined above though and will fail. The following + // achieves the same state. + t.processor + .prepare_shard( + &t.shard_network + .transport(HelperIdentity::ONE, ShardIndex::from(1)), + req, + ) + .unwrap(); + let r = t + .processor + .query_status(t.shard_transport.clone_ref(), QueryId) + .await; + if let Err(e) = r { + panic!("Unexpected error {e}"); + } + if let Ok(st) = r { + assert_eq!(QueryStatus::AwaitingInputs, st); + } + } + + /// Context: + /// * From the standpoint of the second shard in Helper 2 + /// + /// This test makes sure that an error is returned if I get a [`Processor::query_status`] + /// call. Only the shard leader (shard 0) should handle those calls. 
+ #[tokio::test] + async fn rejects_if_not_shard_leader() { + let t = TestComponents::new(TestComponentsArgs::default()); + assert!(matches!( + t.processor + .query_status( + t.shard_network + .transport(HelperIdentity::TWO, ShardIndex::from(1)), + QueryId + ) + .await, + Err(QueryStatusError::NotLeader(_)) + )); + } + + /// Context: + /// * From the standpoint of the leader shard in Helper 2 + /// + /// This test makes sure that an error is returned if I get a [`Processor::shard_status`] + /// call. Only non-leaders (1,2,3...) should handle those calls. + #[tokio::test] + async fn shard_not_leader() { + let req = CompareStatusRequest { + query_id: QueryId, + status: QueryStatus::Running, + }; + let t = TestComponents::new(TestComponentsArgs::default()); + assert!(matches!( + t.processor + .shard_status( + &t.shard_network + .transport(HelperIdentity::TWO, ShardIndex::FIRST), + &req + ) + .unwrap_err(), + QueryStatusError::Leader + )); + } + } + mod kill { use std::sync::Arc; @@ -1011,11 +1417,7 @@ mod tests { .start_query(vec![a, b].into_iter(), test_multiply_config()) .await?; - while !app - .query_status(query_id)? - .into_iter() - .all(|s| s == QueryStatus::Completed) - { + while !(app.query_status(query_id).await? 
== QueryStatus::Completed) { sleep(Duration::from_millis(1)).await; } diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index f16d3fac2..09bef945c 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -17,7 +17,7 @@ use crate::{ }, hpke::PrivateKeyRegistry, protocol::{ - basics::{BooleanProtocols, Reveal}, + basics::{BooleanArrayMul, BooleanProtocols, Reveal}, context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, hybrid::{ hybrid_protocol, @@ -72,6 +72,9 @@ where PrfSharing, PRF_CHUNK, Field = Fp25519> + FromPrss, Replicated: Reveal, Output = >::Array>, + Replicated: BooleanProtocols>, + Replicated: BooleanArrayMul> + + Reveal, Output = >::Array>, { #[tracing::instrument("hybrid_query", skip_all, fields(sz=%query_size))] pub async fn execute( @@ -273,7 +276,7 @@ mod tests { #[should_panic( expected = "not implemented: protocol::hybrid::hybrid_protocol is not fully implemented" )] - fn encrypted_hybrid_reports() { + fn encrypted_hybrid_reports_happy() { // While this test currently checks for an unimplemented panic it is // designed to test for a correct result for a complete implementation. 
run(|| async { @@ -290,7 +293,7 @@ mod tests { } = build_buffers_from_records(&records, SHARDS, &hybrid_info); let world = TestWorld::>::with_shards(TestWorldConfig::default()); - let contexts = world.contexts(); + let contexts = world.malicious_contexts(); #[allow(clippy::large_futures)] let results = flatten3v(buffers.into_iter().zip(contexts).map( @@ -381,7 +384,7 @@ mod tests { let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); - let contexts = world.contexts(); + let contexts = world.malicious_contexts(); #[allow(clippy::large_futures)] let results = flatten3v(buffers.into_iter().zip(contexts).map( @@ -434,7 +437,7 @@ mod tests { let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); - let contexts = world.contexts(); + let contexts = world.malicious_contexts(); #[allow(clippy::large_futures)] let results = flatten3v(buffers.into_iter().zip(contexts).map( diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs index 3f1b59f55..83f033fe4 100644 --- a/ipa-core/src/query/runner/mod.rs +++ b/ipa-core/src/query/runner/mod.rs @@ -4,11 +4,15 @@ mod hybrid; mod oprf_ipa; mod reshard_tag; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] +mod sharded_shuffle; +#[cfg(any(test, feature = "cli", feature = "test-fixture"))] mod test_multiply; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use add_in_prime_field::execute as test_add_in_prime_field; #[cfg(any(test, feature = "cli", feature = "test-fixture"))] +pub(super) use sharded_shuffle::execute_sharded_shuffle; +#[cfg(any(test, feature = "cli", feature = "test-fixture"))] pub(super) use test_multiply::execute_test_multiply; pub use self::oprf_ipa::OprfIpaQuery; diff --git a/ipa-core/src/query/runner/reshard_tag.rs b/ipa-core/src/query/runner/reshard_tag.rs index 5d1c3b8f5..9b535fc11 100644 --- a/ipa-core/src/query/runner/reshard_tag.rs +++ b/ipa-core/src/query/runner/reshard_tag.rs @@ -100,7 +100,7 @@ mod tests { 
let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); world - .semi_honest( + .malicious( vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(), |ctx, input| async move { let shard_id = ctx.shard_id(); @@ -130,7 +130,7 @@ mod tests { let world: TestWorld> = TestWorld::with_shards(TestWorldConfig::default()); world - .semi_honest( + .malicious( vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(), |ctx, input| async move { reshard_aad( diff --git a/ipa-core/src/query/runner/sharded_shuffle.rs b/ipa-core/src/query/runner/sharded_shuffle.rs new file mode 100644 index 000000000..a90161b25 --- /dev/null +++ b/ipa-core/src/query/runner/sharded_shuffle.rs @@ -0,0 +1,122 @@ +use futures_util::TryStreamExt; +use ipa_step::StepNarrow; + +use crate::{ + error::Error, + ff::boolean_array::BA64, + helpers::{setup_cross_shard_prss, BodyStream, Gateway, SingleRecordStream}, + protocol::{ + context::{Context, ShardedContext, ShardedSemiHonestContext}, + ipa_prf::Shuffle, + prss::Endpoint as PrssEndpoint, + step::ProtocolStep, + Gate, + }, + query::runner::QueryResult, + secret_sharing::replicated::semi_honest::AdditiveShare, + sharding::{ShardConfiguration, Sharded}, + sync::Arc, +}; + +/// This executes the sharded shuffle protocol that consists of only one step: +/// permute the private inputs using a permutation that is not known to any helper +/// and client. 
+pub async fn execute_sharded_shuffle<'a>( + prss: &'a PrssEndpoint, + gateway: &'a Gateway, + input: BodyStream, +) -> QueryResult { + let gate = Gate::default().narrow(&ProtocolStep::CrossShardPrss); + let cross_shard_prss = + setup_cross_shard_prss(gateway, &gate, prss.indexed(&gate), gateway).await?; + let ctx = ShardedSemiHonestContext::new_sharded( + prss, + gateway, + Sharded { + shard_id: gateway.shard_id(), + shard_count: gateway.shard_count(), + prss: Arc::new(cross_shard_prss), + }, + ) + .narrow(&ProtocolStep::ShardedShuffle); + + Ok(Box::new(execute(ctx, input).await?)) +} + +#[tracing::instrument("sharded_shuffle", skip_all)] +pub async fn execute(ctx: C, input_stream: BodyStream) -> Result>, Error> +where + C: ShardedContext + Shuffle, +{ + let input = SingleRecordStream::, _>::new(input_stream) + .try_collect::>() + .await?; + ctx.shuffle(input).await +} + +#[cfg(all(test, unit_test))] +mod tests { + use futures_util::future::try_join_all; + use generic_array::GenericArray; + use typenum::Unsigned; + + use crate::{ + ff::{boolean_array::BA64, Serializable, U128Conversions}, + query::runner::sharded_shuffle::execute, + secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + test_executor::run, + test_fixture::{try_join3_array, Reconstruct, TestWorld, TestWorldConfig, WithShards}, + utils::array::zip3, + }; + + #[test] + fn basic() { + run(|| async { + const SHARDS: usize = 20; + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let contexts = world.contexts(); + let input = (0..20_u128).map(BA64::truncate_from).collect::>(); + + #[allow(clippy::redundant_closure_for_method_calls)] + let shard_shares: [Vec>>; 3] = + input.clone().into_iter().share().map(|helper_shares| { + helper_shares + .chunks(SHARDS / 3) + .map(|v| v.to_vec()) + .collect() + }); + + let result = + try_join3_array(zip3(contexts, shard_shares).map(|(h_contexts, h_shares)| { + try_join_all( + h_contexts + .into_iter() + .zip(h_shares) + 
.map(|(ctx, shard_shares)| { + let shard_stream = shard_shares + .into_iter() + .flat_map(|share| { + const SIZE: usize = + as Serializable>::Size::USIZE; + let mut slice = [0_u8; SIZE]; + share.serialize(GenericArray::from_mut_slice(&mut slice)); + slice + }) + .collect::>() + .into(); + + execute(ctx, shard_stream) + }), + ) + })) + .await + .unwrap() + .map(|v| v.into_iter().flatten().collect::>()) + .reconstruct(); + + // 1/20! probability of this permutation to be the same + assert_ne!(input, result); + }); + } +} diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index 460296022..28f981222 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -48,6 +48,26 @@ impl From<&QueryState> for QueryStatus { } } +/// This function is used, among others, by the [`Processor`] to return a unified response when +/// queried about the state of a sharded helper. In such scenarios, there will be many different +/// [`QueryStatus`] and the [`Processor`] needs to return a single one that describes the entire +/// helper. With this function we're saying that the minimum state across all shards is the one +/// that describes the helper. 
+#[must_use] +pub fn min_status(a: QueryStatus, b: QueryStatus) -> QueryStatus { + match (a, b) { + (QueryStatus::Preparing, _) | (_, QueryStatus::Preparing) => QueryStatus::Preparing, + (QueryStatus::AwaitingInputs, _) | (_, QueryStatus::AwaitingInputs) => { + QueryStatus::AwaitingInputs + } + (QueryStatus::Running, _) | (_, QueryStatus::Running) => QueryStatus::Running, + (QueryStatus::AwaitingCompletion, _) | (_, QueryStatus::AwaitingCompletion) => { + QueryStatus::AwaitingCompletion + } + (QueryStatus::Completed, _) => QueryStatus::Completed, + } +} + /// TODO: a macro would be very useful here to keep it in sync with `QueryStatus` pub enum QueryState { Empty, @@ -60,13 +80,14 @@ pub enum QueryState { impl QueryState { pub fn transition(cur_state: &Self, new_state: Self) -> Result { - use QueryState::{AwaitingInputs, Empty, Preparing}; + use QueryState::{AwaitingInputs, Empty, Preparing, Running}; match (cur_state, &new_state) { // If query is not running, coordinator initial state is preparing // and followers initial state is awaiting inputs (Empty, Preparing(_) | AwaitingInputs(_, _, _)) - | (Preparing(_), AwaitingInputs(_, _, _)) => Ok(new_state), + | (Preparing(_), AwaitingInputs(_, _, _)) + | (AwaitingInputs(_, _, _), Running(_)) => Ok(new_state), (_, Preparing(_)) => Err(StateError::AlreadyRunning), (_, _) => Err(StateError::InvalidState { from: cur_state.into(), @@ -226,3 +247,29 @@ impl Drop for RemoveQuery<'_> { } } } + +#[cfg(all(test, unit_test))] +mod tests { + use crate::query::{state::min_status, QueryStatus}; + + #[test] + fn test_order() { + // this list sorted in priority order. Preparing is the lowest possible value, + // while Completed is the highest. 
+ let all = [ + QueryStatus::Preparing, + QueryStatus::AwaitingInputs, + QueryStatus::Running, + QueryStatus::AwaitingCompletion, + QueryStatus::Completed, + ]; + + for i in 0..all.len() { + let this = all[i]; + for other in all.into_iter().skip(i) { + assert_eq!(this, min_status(this, other)); + assert_eq!(this, min_status(other, this)); + } + } + } +} diff --git a/ipa-core/src/report/hybrid.rs b/ipa-core/src/report/hybrid.rs index b3842b6b5..511a0418b 100644 --- a/ipa-core/src/report/hybrid.rs +++ b/ipa-core/src/report/hybrid.rs @@ -40,7 +40,7 @@ use crate::{ error::{BoxError, Error}, ff::{ boolean_array::{ - BooleanArray, BooleanArrayReader, BooleanArrayWriter, BA112, BA3, BA64, BA8, + BooleanArray, BooleanArrayReader, BooleanArrayWriter, BA112, BA3, BA32, BA64, BA8, }, Serializable, }, @@ -48,7 +48,7 @@ use crate::{ open_in_place, seal_in_place, CryptError, EncapsulationSize, PrivateKeyRegistry, PublicKeyRegistry, TagSize, }, - protocol::ipa_prf::shuffle::Shuffleable, + protocol::ipa_prf::{boolean_ops::expand_shared_array_in_place, shuffle::Shuffleable}, report::hybrid_info::{HybridConversionInfo, HybridImpressionInfo, HybridInfo}, secret_sharing::{ replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, @@ -681,6 +681,57 @@ where /// Converted report where shares of match key are replaced with OPRF value pub type PrfHybridReport = IndistinguishableHybridReport; +/// After grouping `IndistinguishableHybridReport`s by the OPRF of thier `match_key`, +/// that OPRF value is no longer required. 
+pub type AggregateableHybridReport = IndistinguishableHybridReport; + +impl IndistinguishableHybridReport +where + BK: BooleanArray, + V: BooleanArray, +{ + pub const ZERO: Self = Self { + match_key: (), + value: Replicated::::ZERO, + breakdown_key: Replicated::::ZERO, + }; + + fn join_fields(value: V, breakdown_key: BK) -> ::Share { + let mut share = ::Share::ZERO; + + BooleanArrayWriter::new(&mut share) + .write(&value) + .write(&breakdown_key); + + share + } + + fn split_fields(share: &::Share) -> (V, BK) { + let bits = BooleanArrayReader::new(share); + let (value, bits) = bits.read(); + let (breakdown_key, _) = bits.read(); + (value, breakdown_key) + } +} + +/// When aggregating reports, we need to lift the value from `V` to `HV`. +impl From> for AggregateableHybridReport +where + BK: SharedValue + BooleanArray, + V: SharedValue + BooleanArray, + HV: SharedValue + BooleanArray, +{ + fn from(report: PrfHybridReport) -> Self { + let mut value = Replicated::::ZERO; + expand_shared_array_in_place(&mut value, &report.value, 0); + Self { + match_key: (), + breakdown_key: report.breakdown_key, + value, + } + } +} + /// This struct is designed to fit both `HybridConversionReport`s /// and `HybridImpressionReport`s so that they can be made indistingushable. 
/// Note: these need to be shuffled (and secret shares need to be rerandomized) @@ -830,6 +881,41 @@ where } } +impl Shuffleable for IndistinguishableHybridReport +where + BK: BooleanArray, + V: BooleanArray, +{ + // this requires BK:BAXX + V:BAYY such that XX + YY <= 32 + // this is checked in a debud_assert call in ::new below + type Share = BA32; + + fn left(&self) -> Self::Share { + Self::join_fields(self.value.left(), self.breakdown_key.left()) + } + + fn right(&self) -> Self::Share { + Self::join_fields(self.value.right(), self.breakdown_key.right()) + } + + fn new(l: Self::Share, r: Self::Share) -> Self { + debug_assert!( + BK::BITS + V::BITS <= Self::Share::BITS, + "share type {} is too small", + std::any::type_name::(), + ); + + let left = Self::split_fields(&l); + let right = Self::split_fields(&r); + + Self { + match_key: (), + value: ReplicatedSecretSharing::new(left.0, right.0), + breakdown_key: ReplicatedSecretSharing::new(left.1, right.1), + } + } +} + impl PrfHybridReport { const PRF_MK_SZ: usize = 8; const V_SZ: usize = as Serializable>::Size::USIZE; diff --git a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs index ece21a2b6..a782d9a27 100644 --- a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs +++ b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs @@ -10,6 +10,7 @@ use futures::{ stream::{iter as stream_iter, StreamExt}, }; use generic_array::{ArrayLength, GenericArray}; +use subtle::ConstantTimeEq; use typenum::Unsigned; use crate::{ @@ -41,7 +42,7 @@ pub struct AdditiveShare, const N: usize } pub trait ExtendableField: Field { - type ExtendedField: Field + FromRandom; + type ExtendedField: Field + FromRandom + ConstantTimeEq; fn to_extended(&self) -> Self::ExtendedField; } @@ -57,7 +58,7 @@ impl> + FieldSimd, const N: us { } -impl ExtendableField for F { +impl ExtendableField for F { type ExtendedField = F; fn 
to_extended(&self) -> Self::ExtendedField { diff --git a/ipa-core/src/secret_sharing/vector/transpose.rs b/ipa-core/src/secret_sharing/vector/transpose.rs index 3a8357f02..b020ddb04 100644 --- a/ipa-core/src/secret_sharing/vector/transpose.rs +++ b/ipa-core/src/secret_sharing/vector/transpose.rs @@ -100,7 +100,7 @@ pub trait TransposeFrom { // // From Hacker's Delight (2nd edition), Figure 7-6. // -// There are comments on `dzkp_field::convert_values_table_indices`, which implements a +// There are comments on `dzkp_field::bits_to_table_indices`, which implements a // similar transformation, that may help to understand how this works. #[inline] pub fn transpose_8x8>(x: B) -> [u8; 8] { @@ -125,7 +125,7 @@ pub fn transpose_8x8>(x: B) -> [u8; 8] { // // Loosely based on Hacker's Delight (2nd edition), Figure 7-6. // -// There are comments on `dzkp_field::convert_values_table_indices`, which implements a +// There are comments on `dzkp_field::bits_to_table_indices`, which implements a // similar transformation, that may help to understand how this works. #[inline] pub fn transpose_16x16(src: &[u8; 32]) -> [u8; 32] { diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs index 4eb51cc9c..1cf01cb38 100644 --- a/ipa-core/src/test_fixture/app.rs +++ b/ipa-core/src/test_fixture/app.rs @@ -1,5 +1,6 @@ use std::{array, iter::zip}; +use futures::future::join_all; use generic_array::GenericArray; use typenum::Unsigned; @@ -11,7 +12,7 @@ use crate::{ ApiError, InMemoryMpcNetwork, InMemoryShardNetwork, Transport, }, protocol::QueryId, - query::QueryStatus, + query::{min_status, QueryStatus}, secret_sharing::IntoShares, test_fixture::try_join3_array, utils::array::zip3, @@ -118,12 +119,13 @@ impl TestApp { /// Propagates errors retrieving the query status. /// ## Panics /// Never. - pub fn query_status(&self, query_id: QueryId) -> Result<[QueryStatus; 3], ApiError> { - Ok((0..3) - .map(|i| self.drivers[i].query_status(query_id)) - .collect::, _>>()? 
- .try_into() - .unwrap()) + #[allow(clippy::disallowed_methods)] + pub async fn query_status(&self, query_id: QueryId) -> Result { + join_all((0..3).map(|i| self.drivers[i].query_status(query_id))) + .await + .into_iter() + .reduce(|s1, s2| Ok(min_status(s1?, s2?))) + .unwrap() } /// ## Errors diff --git a/ipa-core/src/test_fixture/hybrid.rs b/ipa-core/src/test_fixture/hybrid.rs index a28fa7232..d522089c3 100644 --- a/ipa-core/src/test_fixture/hybrid.rs +++ b/ipa-core/src/test_fixture/hybrid.rs @@ -10,7 +10,8 @@ use crate::{ }, rand::Rng, report::hybrid::{ - HybridConversionReport, HybridImpressionReport, HybridReport, IndistinguishableHybridReport, + AggregateableHybridReport, HybridConversionReport, HybridImpressionReport, HybridReport, + IndistinguishableHybridReport, }, secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, IntoShares}, test_fixture::sharing::Reconstruct, @@ -22,13 +23,15 @@ pub enum TestHybridRecord { TestConversion { match_key: u64, value: u32 }, } -#[derive(PartialEq, Eq)] -pub struct TestIndistinguishableHybridReport { - pub match_key: u64, +#[derive(PartialEq, Eq, Debug)] +pub struct TestIndistinguishableHybridReport { + pub match_key: MK, pub value: u32, pub breakdown_key: u32, } +pub type TestAggregateableHybridReport = TestIndistinguishableHybridReport<()>; + impl Reconstruct for [&IndistinguishableHybridReport; 3] where @@ -60,6 +63,54 @@ where } } +impl Reconstruct + for [&IndistinguishableHybridReport; 3] +where + BK: BooleanArray + U128Conversions + IntoShares>, + V: BooleanArray + U128Conversions + IntoShares>, +{ + fn reconstruct(&self) -> TestAggregateableHybridReport { + let breakdown_key = self + .each_ref() + .map(|v| v.breakdown_key.clone()) + .reconstruct() + .as_u128(); + let value = self + .each_ref() + .map(|v| v.value.clone()) + .reconstruct() + .as_u128(); + + TestAggregateableHybridReport { + match_key: (), + breakdown_key: breakdown_key.try_into().unwrap(), + value: value.try_into().unwrap(), + } + } 
+} + +impl IntoShares> for TestAggregateableHybridReport +where + BK: BooleanArray + U128Conversions + IntoShares>, + V: BooleanArray + U128Conversions + IntoShares>, +{ + fn share_with(self, rng: &mut R) -> [AggregateableHybridReport; 3] { + let ba_breakdown_key = BK::try_from(u128::from(self.breakdown_key)) + .unwrap() + .share_with(rng); + let ba_value = V::try_from(u128::from(self.value)).unwrap().share_with(rng); + zip(ba_breakdown_key, ba_value) + .map(|(breakdown_key, value)| AggregateableHybridReport { + match_key: (), + breakdown_key, + value, + }) + .collect::>() + .try_into() + .unwrap() + } +} + impl IntoShares> for TestHybridRecord where BK: BooleanArray + U128Conversions + IntoShares>, diff --git a/ipa-core/src/test_fixture/sharing.rs b/ipa-core/src/test_fixture/sharing.rs index e30e8fc2e..424e73e17 100644 --- a/ipa-core/src/test_fixture/sharing.rs +++ b/ipa-core/src/test_fixture/sharing.rs @@ -155,6 +155,15 @@ where } } +impl Reconstruct>> for Vec<[Vec; 3]> +where + for<'i> [&'i [I]; 3]: Reconstruct>, +{ + fn reconstruct(&self) -> Vec> { + self.iter().map(Reconstruct::reconstruct).collect() + } +} + impl Reconstruct> for [BitDecomposed; 3] where for<'i> [&'i [I]; 3]: Reconstruct>, diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 8c2db9a4e..91cfc87ea 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -240,9 +240,10 @@ impl TestWorld> { /// Panics if world has more or less than 3 gateways/participants #[must_use] pub fn malicious_contexts(&self) -> [Vec>; 3] { + let gate = &self.next_gate(); self.shards() .iter() - .map(|shard| shard.malicious_contexts(&self.next_gate())) + .map(|shard| shard.malicious_contexts(gate)) .fold([Vec::new(), Vec::new(), Vec::new()], |mut acc, contexts| { // Distribute contexts into the respective vectors. 
for (vec, context) in acc.iter_mut().zip(contexts.iter()) { diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index e8dfd95ae..6829f57fa 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -4,4 +4,4 @@ pub mod arraychunks; mod power_of_two; #[cfg(target_pointer_width = "64")] -pub use power_of_two::NonZeroU32PowerOfTwo; +pub use power_of_two::{non_zero_prev_power_of_two, NonZeroU32PowerOfTwo}; diff --git a/ipa-core/src/utils/power_of_two.rs b/ipa-core/src/utils/power_of_two.rs index a84455c92..b34ac0423 100644 --- a/ipa-core/src/utils/power_of_two.rs +++ b/ipa-core/src/utils/power_of_two.rs @@ -68,9 +68,19 @@ impl NonZeroU32PowerOfTwo { } } +/// Returns the largest power of two less than or equal to `target`. +/// +/// Returns 1 if `target` is zero. +pub fn non_zero_prev_power_of_two(target: usize) -> usize { + let bits = usize::BITS - target.leading_zeros(); + + 1 << (std::cmp::max(1, bits) - 1) +} + #[cfg(all(test, unit_test))] mod tests { use super::{ConvertError, NonZeroU32PowerOfTwo}; + use crate::utils::power_of_two::non_zero_prev_power_of_two; #[test] fn rejects_invalid_values() { @@ -107,4 +117,19 @@ mod tests { "3".parse::().unwrap_err() ); } + + #[test] + fn test_prev_power_of_two() { + const TWO_EXP_62: usize = 1usize << (usize::BITS - 2); + const TWO_EXP_63: usize = 1usize << (usize::BITS - 1); + assert_eq!(non_zero_prev_power_of_two(0), 1usize); + assert_eq!(non_zero_prev_power_of_two(1), 1usize); + assert_eq!(non_zero_prev_power_of_two(2), 2usize); + assert_eq!(non_zero_prev_power_of_two(3), 2usize); + assert_eq!(non_zero_prev_power_of_two(4), 4usize); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63 - 1), TWO_EXP_62); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63), TWO_EXP_63); + assert_eq!(non_zero_prev_power_of_two(TWO_EXP_63 + 1), TWO_EXP_63); + assert_eq!(non_zero_prev_power_of_two(usize::MAX), TWO_EXP_63); + } }