diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 790245587..db6adf7fd 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -18,7 +18,7 @@ use ipa_core::{ error::BoxError, helpers::HelperIdentity, net::{ClientIdentity, HttpShardTransport, HttpTransport, MpcHelperClient}, - AppConfig, AppSetup, + AppConfig, AppSetup, NonZeroU32PowerOfTwo, }; use tracing::{error, info}; @@ -93,7 +93,7 @@ struct ServerArgs { /// Override the amount of active work processed in parallel #[arg(long)] - active_work: Option, + active_work: Option, } #[derive(Debug, Subcommand)] diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index 70cae3707..71bfc1c6b 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -434,6 +434,7 @@ mod test { read_size: usize, record_size: usize, ) { + #[allow(clippy::needless_update)] // stall detection feature wants default value let gateway_config = GatewayConfig { active: active.next_power_of_two().try_into().unwrap(), read_size: read_size.try_into().unwrap(), diff --git a/ipa-core/src/lib.rs b/ipa-core/src/lib.rs index 59cae0106..f88ea718e 100644 --- a/ipa-core/src/lib.rs +++ b/ipa-core/src/lib.rs @@ -32,8 +32,8 @@ mod seq_join; mod serde; pub mod sharding; mod utils; - pub use app::{AppConfig, HelperApp, Setup as AppSetup}; +pub use utils::NonZeroU32PowerOfTwo; extern crate core; #[cfg(all(feature = "shuttle", test))] diff --git a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs index 23a96c982..e024c4483 100644 --- a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs +++ b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs @@ -101,7 +101,7 @@ mod test { let res = world .malicious((a, b), |ctx, (a, b)| async move { - let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 10); + let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 8); let mctx = validator.context(); let result = a .multiply(&b, mctx.set_total_records(1), RecordId::from(0)) diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 835a32e9d..a16ca32fd 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -1197,10 +1197,7 @@ mod tests { fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy { let max_max_mults = record_count.min(128); - prop_oneof![ - 1usize..=max_max_mults, - (0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i) - ] + (0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i) } prop_compose! { @@ -1546,7 +1543,7 @@ mod tests { let [h1_batch, h2_batch, h3_batch] = world .malicious((a, b), |ctx, (a, b)| async move { - let mut validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 10); + let mut validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 8); let mctx = validator.context(); let _ = a .multiply(&b, mctx.set_total_records(1), RecordId::from(0)) diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index cc3fa2633..754f179b6 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -308,8 +308,10 @@ where // We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor // is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of // multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a -// reasonably-sized proof. -const CONV_PROOF_CHUNK: usize = 400; +// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so +// we pick the closest power of two close to 381 but less than that value. 256 gives us around 33M +// multiplications per batch +const CONV_PROOF_CHUNK: usize = 256; #[tracing::instrument(name = "compute_prf_for_inputs", skip_all)] async fn compute_prf_for_inputs( diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 9a1f8f278..e3c8cf49c 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -512,7 +512,7 @@ where }, // TODO: this should not be necessary, but probably can't be removed // until we align read_size with the batch size. - std::cmp::min(sh_ctx.active_work().get(), chunk_size), + std::cmp::min(sh_ctx.active_work().get(), chunk_size.next_power_of_two()), ); dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap()); let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?; @@ -541,7 +541,7 @@ where protocol: &Step::Aggregate, validate: &Step::AggregateValidate, }, - aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()), + aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()).next_power_of_two(), ); let user_contributions = flattened_user_results.try_collect::>().await?; let result = diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index f92326c9b..bdcd7448e 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -676,7 +676,7 @@ impl Runner for TestWorld { R: Future + Send, { self.malicious(input, |ctx, share| async { - let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 10); + let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 8); let m_ctx = v.context(); let m_result = helper_fn(m_ctx, share).await; v.validate().await.unwrap();