Skip to content

Commit

Permalink
Enforce active work to be a power of two
Browse files Browse the repository at this point in the history
This is the second attempt to mitigate send buffer misalignment. Previous one (private-attribution#1307) didn't handle all the edge cases and was abandoned in favour of this PR.

What I believe makes this change work is the new requirement for active work to be a power of two. With this constraint, it is much easier to align the read size with it. Given that `total_capacity = active * record_size`, we can represent `read_size` as a multiple of `record_size` too:
`read_size = X * record_size`. If X is a power of two and active_work is a power of two, then they will always be aligned with each other.

For example, if active work is 16, read size is 10 bytes and record size is 3 bytes, then:

```
total_capacity = 16*3
read_size = X*3 (close to 10)
X = 2 (power of two that satisfies the requirement)
```

when picking up the read size, we are rounding down to avoid buffer overflows. In the example above, setting X=3 would make it closer to the desired read size, but it is greater than 10, so we pick 2 instead.
  • Loading branch information
akoshelev committed Oct 2, 2024
1 parent a55a3b4 commit b0b7223
Show file tree
Hide file tree
Showing 2 changed files with 289 additions and 18 deletions.
211 changes: 207 additions & 4 deletions ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub struct State {
pub struct GatewayConfig {
/// The number of items that can be active at the one time.
/// This is used to determine the size of sending and receiving buffers.
/// Any value that is not a power of two will be rejected
pub active: NonZeroUsize,

/// Number of bytes packed and sent together in one batch down to the network layer. This
Expand All @@ -84,6 +85,10 @@ pub struct GatewayConfig {
/// payload may not be exactly this, but it will be the closest multiple of record size to this
/// number. For instance, having 14 bytes records and batch size of 4096 will result in
/// 4088 bytes being sent in a batch.
///
/// The actual size for read chunks may be bigger or smaller, depending on the record size
/// sent through each channel. Read size will be aligned with [`Self::active_work`] value to
/// prevent deadlocks.
pub read_size: NonZeroUsize,

/// Time to wait before checking gateway progress. If no progress has been made between
Expand Down Expand Up @@ -279,7 +284,8 @@ impl GatewayConfig {
// capabilities (see #ipa/1171) to allow that currently.
usize::from(value.size),
),
);
)
.next_power_of_two();
// we set active to be at least 2, so unwrap is fine.
self.active = NonZeroUsize::new(active).unwrap();
}
Expand All @@ -299,23 +305,35 @@ mod tests {
use std::{
iter::{repeat, zip},
num::NonZeroUsize,
sync::Arc,
};

use futures::{
future::{join, try_join, try_join_all},
stream,
stream::StreamExt,
};
use proptest::proptest;

use crate::{
ff::{boolean_array::BA3, Fp31, Fp32BitPrime, Gf2, U128Conversions},
ff::{
boolean_array::{BA20, BA256, BA3, BA4, BA5, BA6, BA7, BA8},
FieldType, Fp31, Fp32BitPrime, Gf2, U128Conversions,
},
helpers::{
ChannelId, Direction, GatewayConfig, MpcMessage, Role, SendingEnd, TotalRecords,
gateway::QueryConfig,
query::{QuerySize, QueryType},
ChannelId, Direction, GatewayConfig, MpcMessage, MpcReceivingEnd, Role, SendingEnd,
TotalRecords,
},
protocol::{
context::{Context, ShardedContext},
Gate, RecordId,
},
secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
secret_sharing::{
replicated::semi_honest::AdditiveShare, SharedValue, SharedValueArray, StdArray,
},
seq_join::seq_join,
sharding::ShardConfiguration,
test_executor::run,
test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards},
Expand Down Expand Up @@ -569,6 +587,87 @@ mod tests {
});
}

macro_rules! send_recv_test {
(
message: $message:expr,
read_size: $read_size:expr,
active_work: $active_work:expr,
total_records: $total_records:expr,
$test_fn: ident
) => {
#[test]
fn $test_fn() {
run(|| async {
send_recv($read_size, $active_work, $total_records, $message).await;
});
}
};
}

send_recv_test! {
message: BA20::ZERO,
read_size: 5,
active_work: 8,
total_records: 25,
test_ba20_5_10_25
}

send_recv_test! {
message: StdArray::<BA256, 16>::ZERO_ARRAY,
read_size: 2048,
active_work: 16,
total_records: 43,
test_ba256_by_16_2048_10_43
}

send_recv_test! {
message: StdArray::<BA8, 16>::ZERO_ARRAY,
read_size: 2048,
active_work: 32,
total_records: 50,
test_ba8_by_16_2048_37_50
}

proptest! {
#[test]
fn send_recv_randomized(
total_records in 1_usize..500,
active in 2_usize..1000,
read_size in (1_usize..32768),
record_size in 1_usize..=8,
) {
let active = active.next_power_of_two();
run(move || async move {
match record_size {
1 => send_recv(read_size, active, total_records, StdArray::<BA8, 32>::ZERO_ARRAY).await,
2 => send_recv(read_size, active, total_records, StdArray::<BA8, 64>::ZERO_ARRAY).await,
3 => send_recv(read_size, active, total_records, BA3::ZERO).await,
4 => send_recv(read_size, active, total_records, BA4::ZERO).await,
5 => send_recv(read_size, active, total_records, BA5::ZERO).await,
6 => send_recv(read_size, active, total_records, BA6::ZERO).await,
7 => send_recv(read_size, active, total_records, BA7::ZERO).await,
8 => send_recv(read_size, active, total_records, StdArray::<BA256, 16>::ZERO_ARRAY).await,
_ => unreachable!(),
}
});
}
}

/// ensures when active work is set from query input, it is always a power of two
#[test]
fn gateway_config_active_work_power_of_two() {
let mut config = GatewayConfig {
active: 2.try_into().unwrap(),
..Default::default()
};
config.set_active_work_from_query_config(&QueryConfig {
size: QuerySize::try_from(5).unwrap(),
field_type: FieldType::Fp31,
query_type: QueryType::TestAddInPrimeField,
});
assert_eq!(8, config.active_work().get());
}

async fn shard_comms_test(test_world: &TestWorld<WithShards<2>>) {
let input = vec![BA3::truncate_from(0_u32), BA3::truncate_from(1_u32)];

Expand Down Expand Up @@ -606,4 +705,108 @@ mod tests {
let world_ptr = world as *mut _;
(world, world_ptr)
}

/// This serves the purpose of randomized testing of our send channels by providing
/// variable sizes for read size, active work and record size
async fn send_recv<M>(read_size: usize, active_work: usize, total_records: usize, sample: M)
where
M: MpcMessage + Clone + PartialEq,
{
fn duplex_channel<M: MpcMessage>(
world: &TestWorld,
left: Role,
right: Role,
total_records: usize,
active_work: usize,
) -> (SendingEnd<Role, M>, MpcReceivingEnd<M>) {
(
world.gateway(left).get_mpc_sender::<M>(
&ChannelId::new(right, Gate::default()),
TotalRecords::specified(total_records).unwrap(),
active_work.try_into().unwrap(),
),
world
.gateway(right)
.get_mpc_receiver::<M>(&ChannelId::new(left, Gate::default())),
)
}

async fn circuit<M>(
send_channel: SendingEnd<Role, M>,
recv_channel: MpcReceivingEnd<M>,
active_work: usize,
total_records: usize,
msg: M,
) where
M: MpcMessage + Clone + PartialEq,
{
let send_notify = Arc::new(tokio::sync::Notify::new());

// perform "multiplication-like" operation (send + subsequent receive)
// and "validate": block the future until we have at least `active_work`
// futures pending and unblock them all at the same time
seq_join(
active_work.try_into().unwrap(),
stream::iter(std::iter::repeat(msg).take(total_records).enumerate()).map(
|(record_id, msg)| {
let send_channel = &send_channel;
let recv_channel = &recv_channel;
let send_notify = Arc::clone(&send_notify);
async move {
send_channel
.send(record_id.into(), msg.clone())
.await
.unwrap();
let r = recv_channel.receive(record_id.into()).await.unwrap();
// this simulates validate_record API by forcing futures to wait
// until the entire batch is validated by the last future in that batch
if record_id % active_work == active_work - 1
|| record_id == total_records - 1
{
send_notify.notify_waiters();
} else {
send_notify.notified().await;
}
assert_eq!(msg, r);
}
},
),
)
.collect::<Vec<_>>()
.await;
}

let config = TestWorldConfig {
gateway_config: GatewayConfig {
active: active_work.try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
..Default::default()
},
..Default::default()
};

let world = TestWorld::new_with(&config);
let (h1_send_channel, h1_recv_channel) =
duplex_channel(&world, Role::H1, Role::H2, total_records, active_work);
let (h2_send_channel, h2_recv_channel) =
duplex_channel(&world, Role::H2, Role::H1, total_records, active_work);

join(
circuit(
h1_send_channel,
h1_recv_channel,
active_work,
total_records,
sample.clone(),
),
circuit(
h2_send_channel,
h2_recv_channel,
active_work,
total_records,
sample,
),
)
.await;
}
}
96 changes: 82 additions & 14 deletions ipa-core/src/helpers/gateway/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,35 +248,51 @@ impl<I: Debug> Stream for GatewaySendStream<I> {

impl SendChannelConfig {
fn new<M: Message>(gateway_config: GatewayConfig, total_records: TotalRecords) -> Self {
debug_assert!(M::Size::USIZE > 0, "Message size cannot be 0");
Self::new_with(gateway_config, total_records, M::Size::USIZE)
}
fn new_with(
gateway_config: GatewayConfig,
total_records: TotalRecords,
record_size: usize,
) -> Self {
debug_assert!(record_size > 0, "Message size cannot be 0");
debug_assert!(
gateway_config.active.is_power_of_two(),
"Active work {} must be a power of two",
gateway_config.active.get()
);

let record_size = M::Size::USIZE;
let total_capacity = gateway_config.active.get() * record_size;
Self {
// define read size in terms of percentage of active work, rather than bytes.
// both are powers of two, so it should always be possible. We pick the read size
// to be the closest to the configuration value in bytes.
// let read_size = closest_multiple(record_size, gateway_config.read_size.get());
let read_size = (gateway_config.read_size.get() / record_size + 1).next_power_of_two() / 2
* record_size;
let this = Self {
total_capacity: total_capacity.try_into().unwrap(),
record_size: record_size.try_into().unwrap(),
read_size: if total_records.is_indeterminate()
|| gateway_config.read_size.get() <= record_size
{
read_size: if total_records.is_indeterminate() || read_size <= record_size {
record_size
} else {
std::cmp::min(
total_capacity,
// closest multiple of record_size to read_size
gateway_config.read_size.get() / record_size * record_size,
)
std::cmp::min(total_capacity, read_size)
}
.try_into()
.unwrap(),
total_records,
}
};

debug_assert!(this.total_capacity.get() >= record_size * gateway_config.active.get());

this
}
}

#[cfg(test)]
mod test {
use std::num::NonZeroUsize;

use proptest::proptest;
use typenum::Unsigned;

use crate::{
Expand Down Expand Up @@ -379,15 +395,67 @@ mod test {
fn config_read_size_closest_multiple_to_record_size() {
assert_eq!(
6,
send_config::<BA20, 12, 7>(TotalRecords::Specified(2.try_into().unwrap()))
send_config::<BA20, 16, 7>(TotalRecords::Specified(2.try_into().unwrap()))
.read_size
.get()
);
assert_eq!(
6,
send_config::<BA20, 12, 8>(TotalRecords::Specified(2.try_into().unwrap()))
send_config::<BA20, 16, 8>(TotalRecords::Specified(2.try_into().unwrap()))
.read_size
.get()
);
}

#[test]
fn config_read_size_record_size_misalignment() {
ensure_config(Some(15), 90, 16, 3);
}

fn ensure_config(
total_records: Option<usize>,
active: usize,
read_size: usize,
record_size: usize,
) {
let gateway_config = GatewayConfig {
active: active.next_power_of_two().try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
// read_size: read_size.next_power_of_two().try_into().unwrap(),
..Default::default()
};
let config = SendChannelConfig::new_with(
gateway_config,
total_records.map_or(TotalRecords::Indeterminate, |v| {
TotalRecords::specified(v).unwrap()
}),
record_size,
);

// total capacity checks
assert!(config.total_capacity.get() > 0);
assert!(config.total_capacity.get() >= config.read_size.get());
assert_eq!(0, config.total_capacity.get() % config.record_size.get());
assert_eq!(
config.total_capacity.get(),
record_size * gateway_config.active.get()
);

// read size checks
assert!(config.read_size.get() > 0);
assert!(config.read_size.get() >= config.record_size.get());
assert_eq!(0, config.total_capacity.get() % config.read_size.get());
}

proptest! {
#[test]
fn config_prop(
total_records in proptest::option::of(1_usize..1 << 32),
active in 1_usize..100_000,
read_size in 1_usize..32768,
record_size in 1_usize..4096,
) {
ensure_config(total_records, active, read_size, record_size);
}
}
}

0 comments on commit b0b7223

Please sign in to comment.