Skip to content

Commit

Permalink
Fix batch size to be a power of two in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Oct 2, 2024
1 parent a670145 commit e1479ed
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions ipa-core/src/bin/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use ipa_core::{
error::BoxError,
helpers::HelperIdentity,
net::{ClientIdentity, HttpShardTransport, HttpTransport, MpcHelperClient},
AppConfig, AppSetup,
AppConfig, AppSetup, NonZeroU32PowerOfTwo,
};
use tracing::{error, info};

Expand Down Expand Up @@ -93,7 +93,7 @@ struct ServerArgs {

/// Override the amount of active work processed in parallel
#[arg(long)]
active_work: Option<NonZeroUsize>,
active_work: Option<NonZeroU32PowerOfTwo>,
}

#[derive(Debug, Subcommand)]
Expand Down
1 change: 1 addition & 0 deletions ipa-core/src/helpers/gateway/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ mod test {
read_size: usize,
record_size: usize,
) {
#[allow(clippy::needless_update)] // stall detection feature wants default value
let gateway_config = GatewayConfig {
active: active.next_power_of_two().try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ mod seq_join;
mod serde;
pub mod sharding;
mod utils;

pub use app::{AppConfig, HelperApp, Setup as AppSetup};
pub use utils::NonZeroU32PowerOfTwo;

extern crate core;
#[cfg(all(feature = "shuttle", test))]
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/basics/mul/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ mod test {

let res = world
.malicious((a, b), |ctx, (a, b)| async move {
let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 10);
let validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 8);
let mctx = validator.context();
let result = a
.multiply(&b, mctx.set_total_records(1), RecordId::from(0))
Expand Down
7 changes: 2 additions & 5 deletions ipa-core/src/protocol/context/dzkp_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1197,10 +1197,7 @@ mod tests {

fn max_multiplications_per_gate_strategy(record_count: usize) -> impl Strategy<Value = usize> {
let max_max_mults = record_count.min(128);
prop_oneof![
1usize..=max_max_mults,
(0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i)
]
(0u32..=max_max_mults.ilog2()).prop_map(|i| 1usize << i)
}

prop_compose! {
Expand Down Expand Up @@ -1546,7 +1543,7 @@ mod tests {

let [h1_batch, h2_batch, h3_batch] = world
.malicious((a, b), |ctx, (a, b)| async move {
let mut validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 10);
let mut validator = ctx.dzkp_validator(TEST_DZKP_STEPS, 8);
let mctx = validator.context();
let _ = a
.multiply(&b, mctx.set_total_records(1), RecordId::from(0))
Expand Down
6 changes: 4 additions & 2 deletions ipa-core/src/protocol/ipa_prf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,10 @@ where
// We expect 2*256 = 512 gates in total for two additions per conversion. The vectorization factor
// is CONV_CHUNK. Let `len` equal the number of converted shares. The total amount of
// multiplications is CONV_CHUNK*512*len. We want CONV_CHUNK*512*len ≈ 50M, or len ≈ 381, for a
// reasonably-sized proof.
const CONV_PROOF_CHUNK: usize = 400;
// reasonably-sized proof. There is also a constraint on proof chunks to be powers of two, so
// we pick the closest power of two close to 381 but less than that value. 256 gives us around 33M
// multiplications per batch
const CONV_PROOF_CHUNK: usize = 256;

#[tracing::instrument(name = "compute_prf_for_inputs", skip_all)]
async fn compute_prf_for_inputs<C, BK, TV, TS>(
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ where
},
// TODO: this should not be necessary, but probably can't be removed
// until we align read_size with the batch size.
std::cmp::min(sh_ctx.active_work().get(), chunk_size),
std::cmp::min(sh_ctx.active_work().get(), chunk_size.next_power_of_two()),
);
dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap());
let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?;
Expand Down Expand Up @@ -541,7 +541,7 @@ where
protocol: &Step::Aggregate,
validate: &Step::AggregateValidate,
},
aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()),
aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()).next_power_of_two(),
);
let user_contributions = flattened_user_results.try_collect::<Vec<_>>().await?;
let result =
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/test_fixture/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ impl Runner<NotSharded> for TestWorld<NotSharded> {
R: Future<Output = O> + Send,
{
self.malicious(input, |ctx, share| async {
let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 10);
let v = ctx.dzkp_validator(TEST_DZKP_STEPS, 8);
let m_ctx = v.context();
let m_result = helper_fn(m_ctx, share).await;
v.validate().await.unwrap();
Expand Down

0 comments on commit e1479ed

Please sign in to comment.