Skip to content

Commit

Permalink
Enforce active work to be a power of two
Browse files Browse the repository at this point in the history
This is the second attempt to mitigate send buffer misalignment. The previous one (private-attribution#1307) didn't handle all the edge cases and was abandoned in favour of this PR.

What I believe makes this change work is the new requirement for active work to be a power of two. With this constraint, it is much easier to align the read size with it. Given that `total_capacity = active * record_size`, we can represent `read_size` as a multiple of `record_size` too:
`read_size = X * record_size`. If X is a power of two and active_work is a power of two, then they will always be aligned with each other.

For example, if active work is 16, read size is 10 bytes and record size is 3 bytes, then:

```
total_capacity = 16*3
read_size = X*3 (close to 10)
X = 2 (power of two that satisfies the requirement)
```

When picking the read size, we round down to avoid buffer overflows. In the example above, setting X=4 (12 bytes) would bring it closer to the desired read size, but 12 is greater than 10, so we pick X=2 (6 bytes) instead.
  • Loading branch information
akoshelev committed Oct 2, 2024
1 parent 4429326 commit 7d99a5b
Show file tree
Hide file tree
Showing 2 changed files with 293 additions and 20 deletions.
217 changes: 211 additions & 6 deletions ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub struct State {
pub struct GatewayConfig {
/// The number of items that can be active at one time.
/// This is used to determine the size of sending and receiving buffers.
/// Any value that is not a power of two will be rejected.
pub active: NonZeroUsize,

/// Number of bytes packed and sent together in one batch down to the network layer. This
Expand All @@ -84,6 +85,10 @@ pub struct GatewayConfig {
/// payload may not be exactly this, but it will be the closest multiple of record size to this
/// number. For instance, having 14 bytes records and batch size of 4096 will result in
/// 4088 bytes being sent in a batch.
///
/// The actual size for read chunks may be bigger or smaller, depending on the record size
/// sent through each channel. Read size will be aligned with [`Self::active_work`] value to
/// prevent deadlocks.
pub read_size: NonZeroUsize,

/// Time to wait before checking gateway progress. If no progress has been made between
Expand Down Expand Up @@ -276,29 +281,46 @@ impl GatewayConfig {
// capabilities (see #ipa/1171) to allow that currently.
usize::from(value.size),
),
);
)
.next_power_of_two();
// we set active to be at least 2, so unwrap is fine.
self.active = NonZeroUsize::new(active).unwrap();
}
}

#[cfg(all(test, unit_test))]
mod tests {
use std::iter::{repeat, zip};
use std::{
iter::{repeat, zip},
sync::Arc,
};

use futures::{
future::{join, try_join, try_join_all},
stream,
stream::StreamExt,
};
use proptest::proptest;

use crate::{
ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions},
helpers::{Direction, GatewayConfig, MpcMessage, Role, SendingEnd},
ff::{
boolean_array::{BA20, BA256, BA3, BA4, BA5, BA6, BA7, BA8},
FieldType, Fp31, Fp32BitPrime, Gf2, U128Conversions,
},
helpers::{
gateway::QueryConfig,
query::{QuerySize, QueryType},
ChannelId, Direction, GatewayConfig, MpcMessage, MpcReceivingEnd, Role, SendingEnd,
TotalRecords,
},
protocol::{
context::{Context, ShardedContext},
RecordId,
Gate, RecordId,
},
secret_sharing::{
replicated::semi_honest::AdditiveShare, SharedValue, SharedValueArray, StdArray,
},
secret_sharing::replicated::semi_honest::AdditiveShare,
seq_join::seq_join,
sharding::ShardConfiguration,
test_executor::run,
test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards},
Expand Down Expand Up @@ -516,6 +538,87 @@ mod tests {
});
}

/// Declares a `#[test]` function with the given name that executes `send_recv`
/// once, using the supplied sample message, read size, active work and total
/// number of records.
macro_rules! send_recv_test {
    (
        message: $sample:expr,
        read_size: $read:expr,
        active_work: $active:expr,
        total_records: $total:expr,
        $name:ident
    ) => {
        #[test]
        fn $name() {
            run(|| async {
                send_recv($read, $active, $total, $sample).await;
            });
        }
    };
}

// Fixed-parameter send/receive case: 25 `BA20` records, read size of 5, active work of 8.
// NOTE(review): the `10` in the test name does not match `active_work: 8`. Presumably it
// predates the requirement for active work to be a power of two; confirm and consider
// renaming the test.
send_recv_test! {
message: BA20::ZERO,
read_size: 5,
active_work: 8,
total_records: 25,
test_ba20_5_10_25
}

// Fixed-parameter send/receive case: 43 records, each an array of 16 `BA256` values,
// read size of 2048, active work of 16.
// NOTE(review): the `10` in the test name does not match `active_work: 16`. Presumably it
// predates the requirement for active work to be a power of two; confirm and consider
// renaming the test.
send_recv_test! {
message: StdArray::<BA256, 16>::ZERO_ARRAY,
read_size: 2048,
active_work: 16,
total_records: 43,
test_ba256_by_16_2048_10_43
}

// Fixed-parameter send/receive case: 50 records, each an array of 16 `BA8` values,
// read size of 2048, active work of 32.
// NOTE(review): the `37` in the test name does not match `active_work: 32`. Presumably it
// predates the requirement for active work to be a power of two; confirm and consider
// renaming the test.
send_recv_test! {
message: StdArray::<BA8, 16>::ZERO_ARRAY,
read_size: 2048,
active_work: 32,
total_records: 50,
test_ba8_by_16_2048_37_50
}

// Randomized variant of the send/receive tests: proptest draws the number of records,
// the active work, the read size and a selector for the message type, then runs
// `send_recv` with that combination.
proptest! {
#[test]
fn send_recv_randomized(
total_records in 1_usize..10_000,
active in 1_usize..10_000,
read_size in (1_usize..32768),
record_size in 1_usize..=8,
) {
// The drawn value is rounded up to the next power of two before it is used as
// active work, so the value handed to `send_recv` may be larger than the one drawn.
let active = active.next_power_of_two();
run(move || async move {
// `record_size` only selects which message type is sent; it is not passed on as a
// byte count. NOTE(review): the types were presumably picked to cover a spread of
// serialized record sizes — confirm against the `BA*` definitions.
match record_size {
1 => send_recv(read_size, active, total_records, StdArray::<BA8, 32>::ZERO_ARRAY).await,
2 => send_recv(read_size, active, total_records, StdArray::<BA8, 64>::ZERO_ARRAY).await,
3 => send_recv(read_size, active, total_records, BA3::ZERO).await,
4 => send_recv(read_size, active, total_records, BA4::ZERO).await,
5 => send_recv(read_size, active, total_records, BA5::ZERO).await,
6 => send_recv(read_size, active, total_records, BA6::ZERO).await,
7 => send_recv(read_size, active, total_records, BA7::ZERO).await,
8 => send_recv(read_size, active, total_records, StdArray::<BA256, 16>::ZERO_ARRAY).await,
// the strategy above only yields values in 1..=8
_ => unreachable!(),
}
});
}
}

/// Active work derived from a query configuration must come out as a power of two:
/// a query of size 5 is expected to produce an active work of 8.
#[test]
fn gateway_config_active_work_power_of_two() {
    // Arrange: a gateway config with a small starting value and a query of size 5.
    let mut gateway_config = GatewayConfig {
        active: 2.try_into().unwrap(),
        ..Default::default()
    };
    let query_config = QueryConfig {
        size: QuerySize::try_from(5).unwrap(),
        field_type: FieldType::Fp31,
        query_type: QueryType::TestAddInPrimeField,
    };

    // Act: let the query configuration drive the active work setting.
    gateway_config.set_active_work_from_query_config(&query_config);

    // Assert
    assert_eq!(8, gateway_config.active_work().get());
}

async fn shard_comms_test(test_world: &TestWorld<WithShards<2>>) {
let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)];

Expand Down Expand Up @@ -553,4 +656,106 @@ mod tests {
let world_ptr = world as *mut _;
(world, world_ptr)
}

/// Test driver for the send channels, parameterized by read size, active work and
/// message type so that callers (fixed and randomized tests) can vary all three.
///
/// It builds a [`TestWorld`] whose gateway uses `read_size` and `active_work`, opens one
/// channel from H1 to H2 and one from H2 to H1, and concurrently pushes `total_records`
/// copies of `sample` through each of them, checking that every record arrives unchanged.
async fn send_recv<M>(read_size: usize, active_work: usize, total_records: usize, sample: M)
where
    M: MpcMessage + Clone + PartialEq,
{
    /// Opens a single direction of communication: the sending end on `left` that targets
    /// `right`, paired with the receiving end on `right` that reads what `left` sent.
    fn duplex_channel<M: MpcMessage>(
        world: &TestWorld,
        left: Role,
        right: Role,
        total_records: usize,
    ) -> (SendingEnd<Role, M>, MpcReceivingEnd<M>) {
        let sender = world.gateway(left).get_mpc_sender::<M>(
            &ChannelId::new(right, Gate::default()),
            TotalRecords::specified(total_records).unwrap(),
        );
        let receiver = world
            .gateway(right)
            .get_mpc_receiver::<M>(&ChannelId::new(left, Gate::default()));

        (sender, receiver)
    }

    /// Sends `total_records` copies of `msg` through `send_channel` and reads each of them
    /// back from `recv_channel`, keeping at most `active_work` records in flight.
    async fn circuit<M>(
        send_channel: SendingEnd<Role, M>,
        recv_channel: MpcReceivingEnd<M>,
        active_work: usize,
        total_records: usize,
        msg: M,
    ) where
        M: MpcMessage + Clone + PartialEq,
    {
        let batch_done = Arc::new(tokio::sync::Notify::new());
        let records = stream::iter(std::iter::repeat(msg).take(total_records).enumerate());

        // Each record goes through a "multiplication-like" step (a send followed by a
        // receive) and is then "validated": the future parks until the batch it belongs
        // to is complete, at which point the whole batch is released together.
        let in_flight = records.map(|(record_id, expected)| {
            let batch_done = Arc::clone(&batch_done);
            let tx = &send_channel;
            let rx = &recv_channel;
            async move {
                tx.send(record_id.into(), expected.clone()).await.unwrap();
                let received = rx.receive(record_id.into()).await.unwrap();

                // Mimics the validate_record API: only the last future of a batch (or
                // the very last record overall) wakes everyone else up; all other
                // futures wait for that signal.
                let closes_batch = record_id % active_work == active_work - 1
                    || record_id == total_records - 1;
                if closes_batch {
                    batch_done.notify_waiters();
                } else {
                    batch_done.notified().await;
                }
                assert_eq!(expected, received);
            }
        });

        seq_join(active_work.try_into().unwrap(), in_flight)
            .collect::<Vec<_>>()
            .await;
    }

    let gateway_config = GatewayConfig {
        active: active_work.try_into().unwrap(),
        read_size: read_size.try_into().unwrap(),
        ..Default::default()
    };
    let config = TestWorldConfig {
        gateway_config,
        ..Default::default()
    };
    let world = TestWorld::new_with(&config);

    // One channel per direction; each circuit owns the sender and the matching receiver
    // of the same direction.
    let (h1_to_h2_tx, h1_to_h2_rx) = duplex_channel(&world, Role::H1, Role::H2, total_records);
    let (h2_to_h1_tx, h2_to_h1_rx) = duplex_channel(&world, Role::H2, Role::H1, total_records);

    let h1_circuit = circuit(
        h1_to_h2_tx,
        h1_to_h2_rx,
        active_work,
        total_records,
        sample.clone(),
    );
    let h2_circuit = circuit(h2_to_h1_tx, h2_to_h1_rx, active_work, total_records, sample);

    join(h1_circuit, h2_circuit).await;
}
}
Loading

0 comments on commit 7d99a5b

Please sign in to comment.