Skip to content

Commit

Permalink
Add a type that enforces power of two constraint
Browse files Browse the repository at this point in the history
While working on changing the gateway and parameters, I ran into
several issues where the power of two constraint was not enforced
and breakages were hard to find.

A better model for me is to gate the active work at the config level,
prohibiting invalid constructions at the caller side.
  • Loading branch information
akoshelev committed Oct 2, 2024
1 parent b0b7223 commit a670145
Show file tree
Hide file tree
Showing 12 changed files with 198 additions and 49 deletions.
7 changes: 4 additions & 3 deletions ipa-core/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{num::NonZeroUsize, sync::Weak};
use std::sync::Weak;

use async_trait::async_trait;

Expand All @@ -13,17 +13,18 @@ use crate::{
protocol::QueryId,
query::{NewQueryError, QueryProcessor, QueryStatus},
sync::Arc,
utils::NonZeroU32PowerOfTwo,
};

#[derive(Default)]
pub struct AppConfig {
active_work: Option<NonZeroUsize>,
active_work: Option<NonZeroU32PowerOfTwo>,
key_registry: Option<KeyRegistry<PrivateKeyOnly>>,
}

impl AppConfig {
#[must_use]
pub fn with_active_work(mut self, active_work: Option<NonZeroUsize>) -> Self {
pub fn with_active_work(mut self, active_work: Option<NonZeroU32PowerOfTwo>) -> Self {
self.active_work = active_work;
self
}
Expand Down
29 changes: 20 additions & 9 deletions ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use crate::{
protocol::QueryId,
sharding::ShardIndex,
sync::{Arc, Mutex},
utils::NonZeroU32PowerOfTwo,
};

/// Alias for the currently configured transport.
Expand Down Expand Up @@ -73,8 +74,7 @@ pub struct State {
pub struct GatewayConfig {
/// The number of items that can be active at the one time.
/// This is used to determine the size of sending and receiving buffers.
/// Any value that is not a power of two will be rejected
pub active: NonZeroUsize,
pub active: NonZeroU32PowerOfTwo,

/// Number of bytes packed and sent together in one batch down to the network layer. This
/// shouldn't be too small to keep the network throughput, but setting it large enough may
Expand Down Expand Up @@ -155,7 +155,7 @@ impl Gateway {
&self,
channel_id: &HelperChannelId,
total_records: TotalRecords,
active_work: NonZeroUsize,
active_work: NonZeroU32PowerOfTwo,
) -> send::SendingEnd<Role, M> {
let transport = &self.transports.mpc;
let channel = self.inner.mpc_senders.get::<M, _>(
Expand Down Expand Up @@ -265,6 +265,11 @@ impl GatewayConfig {
/// The configured amount of active work.
#[must_use]
pub fn active_work(&self) -> NonZeroUsize {
self.active.to_non_zero_usize()
}

#[must_use]
pub fn active_work_as_power_of_two(&self) -> NonZeroU32PowerOfTwo {
self.active
}

Expand All @@ -287,12 +292,12 @@ impl GatewayConfig {
)
.next_power_of_two();
// we set active to be at least 2, so unwrap is fine.
self.active = NonZeroUsize::new(active).unwrap();
self.active = NonZeroU32PowerOfTwo::try_from(active).unwrap();
}

/// Creates a new configuration by overriding the value of active work.
#[must_use]
pub fn set_active_work(&self, active_work: NonZeroUsize) -> Self {
pub fn set_active_work(&self, active_work: NonZeroU32PowerOfTwo) -> Self {
Self {
active: active_work,
..*self
Expand All @@ -304,7 +309,6 @@ impl GatewayConfig {
mod tests {
use std::{
iter::{repeat, zip},
num::NonZeroUsize,
sync::Arc,
};

Expand Down Expand Up @@ -337,6 +341,7 @@ mod tests {
sharding::ShardConfiguration,
test_executor::run,
test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards},
utils::NonZeroU32PowerOfTwo,
};

/// Verifies that [`Gateway`] send buffer capacity is adjusted to the message size.
Expand Down Expand Up @@ -556,13 +561,19 @@ mod tests {
run(|| async move {
let world = TestWorld::new_with(TestWorldConfig {
gateway_config: GatewayConfig {
active: 5.try_into().unwrap(),
active: 8.try_into().unwrap(),
..Default::default()
},
..Default::default()
});
let new_active_work = NonZeroUsize::new(3).unwrap();
assert!(new_active_work < world.gateway(Role::H1).config().active_work());
let new_active_work = NonZeroU32PowerOfTwo::try_from(4).unwrap();
assert!(
new_active_work
< world
.gateway(Role::H1)
.config()
.active_work_as_power_of_two()
);
let sender = world.gateway(Role::H1).get_mpc_sender::<BA3>(
&ChannelId::new(Role::H2, Gate::default()),
TotalRecords::specified(15).unwrap(),
Expand Down
49 changes: 32 additions & 17 deletions ipa-core/src/helpers/gateway/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,34 +255,35 @@ impl SendChannelConfig {
total_records: TotalRecords,
record_size: usize,
) -> Self {
debug_assert!(record_size > 0, "Message size cannot be 0");
debug_assert!(
gateway_config.active.is_power_of_two(),
"Active work {} must be a power of two",
gateway_config.active.get()
);
assert!(record_size > 0, "Message size cannot be 0");

let total_capacity = gateway_config.active.get() * record_size;
// define read size in terms of percentage of active work, rather than bytes.
// both are powers of two, so it should always be possible. We pick the read size
// to be the closest to the configuration value in bytes.
// let read_size = closest_multiple(record_size, gateway_config.read_size.get());
let read_size = (gateway_config.read_size.get() / record_size + 1).next_power_of_two() / 2
* record_size;
// define read size as a multiplier of record size. The multiplier must be
// a power of two to align perfectly with total capacity.
let read_size_multiplier = {
// next_power_of_two returns a value that is greater than or equal two
// in order to compute previous, we need strictly greater value, thus
// we add 1
let prev_power_of_two =
(gateway_config.read_size.get() / record_size + 1).next_power_of_two() / 2;
std::cmp::max(1, prev_power_of_two)
};

let this = Self {
total_capacity: total_capacity.try_into().unwrap(),
record_size: record_size.try_into().unwrap(),
read_size: if total_records.is_indeterminate() || read_size <= record_size {
read_size: if total_records.is_indeterminate() {
record_size
} else {
std::cmp::min(total_capacity, read_size)
std::cmp::min(total_capacity, read_size_multiplier * record_size)
}
.try_into()
.unwrap(),
total_records,
};

debug_assert!(this.total_capacity.get() >= record_size * gateway_config.active.get());
assert!(this.total_capacity.get() >= record_size * gateway_config.active.get());
assert_eq!(0, this.total_capacity.get() % this.read_size.get());

this
}
Expand All @@ -297,7 +298,7 @@ mod test {

use crate::{
ff::{
boolean_array::{BA16, BA20, BA256, BA3, BA7},
boolean_array::{BA16, BA20, BA256, BA3, BA32, BA7},
Serializable,
},
helpers::{gateway::send::SendChannelConfig, GatewayConfig, TotalRecords},
Expand Down Expand Up @@ -412,6 +413,21 @@ mod test {
ensure_config(Some(15), 90, 16, 3);
}

#[test]
fn config_read_size_multiple_of_record_size() {
// 4 bytes * 8 = 32 bytes total capacity.
// desired read size is 15 bytes, and the closest multiple of BA32
// to it that is a power of two is 2 (4 gets us over 15 byte target)
assert_eq!(8, send_config::<BA32, 8, 15>(50.into()).read_size.get());

// here, read size is already a power of two
assert_eq!(16, send_config::<BA32, 8, 16>(50.into()).read_size.get());

// read size can be ridiculously small, config adjusts it to fit
// at least one record
assert_eq!(3, send_config::<BA20, 8, 1>(50.into()).read_size.get());
}

fn ensure_config(
total_records: Option<usize>,
active: usize,
Expand All @@ -421,7 +437,6 @@ mod test {
let gateway_config = GatewayConfig {
active: active.next_power_of_two().try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
// read_size: read_size.next_power_of_two().try_into().unwrap(),
..Default::default()
};
let config = SendChannelConfig::new_with(
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/helpers/gateway/stall_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ impl<T: ObserveState> Observed<T> {
}

mod gateway {
use std::num::NonZeroUsize;

use delegate::delegate;

Expand All @@ -81,6 +80,7 @@ mod gateway {
protocol::QueryId,
sharding::ShardIndex,
sync::Arc,
utils::NonZeroU32PowerOfTwo,
};

pub struct InstrumentedGateway {
Expand Down Expand Up @@ -154,7 +154,7 @@ mod gateway {
&self,
channel_id: &HelperChannelId,
total_records: TotalRecords,
active_work: NonZeroUsize,
active_work: NonZeroU32PowerOfTwo,
) -> SendingEnd<Role, M> {
Observed::wrap(
Weak::clone(self.get_sn()),
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/helpers/prss_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ pub async fn negotiate<R: RngCore + CryptoRng>(
let left_sender = gateway.get_mpc_sender::<PublicKey>(
&left_channel,
TotalRecords::ONE,
gateway.config().active_work(),
gateway.config().active_work_as_power_of_two(),
);
let right_sender = gateway.get_mpc_sender::<PublicKey>(
&right_channel,
TotalRecords::ONE,
gateway.config().active_work(),
gateway.config().active_work_as_power_of_two(),
);
let left_receiver = gateway.get_mpc_receiver::<PublicKey>(&left_channel);
let right_receiver = gateway.get_mpc_receiver::<PublicKey>(&right_channel);
Expand Down
7 changes: 5 additions & 2 deletions ipa-core/src/protocol/context/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,11 @@ impl<'a> DZKPUpgraded<'a> {
// This overrides the active work for this context and all children
// created from it by using narrow, clone, etc.
// This allows all steps participating in malicious validation
// to use the same active work window and prevent deadlocks
base_ctx: base_ctx.set_active_work(active_work),
// to use the same active work window and prevent deadlocks.
//
// This also checks that active work is a power of two and
// panics if it is not.
base_ctx: base_ctx.set_active_work(active_work.get().try_into().unwrap()),
}
}

Expand Down
7 changes: 5 additions & 2 deletions ipa-core/src/protocol/context/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl<'a> Context<'a> {
}

#[must_use]
pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self {
pub fn set_active_work(self, new_active_work: NonZeroU32PowerOfTwo) -> Self {
Self {
inner: self.inner.set_active_work(new_active_work),
}
Expand Down Expand Up @@ -171,7 +171,10 @@ impl Debug for Context<'_> {
}
}

use crate::sync::{Mutex, Weak};
use crate::{
sync::{Mutex, Weak},
utils::NonZeroU32PowerOfTwo,
};

pub(super) type MacBatcher<'a, F> = Mutex<Batcher<'a, validator::Malicious<'a, F>>>;

Expand Down
9 changes: 5 additions & 4 deletions ipa-core/src/protocol/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::{
secret_sharing::replicated::malicious::ExtendableField,
seq_join::SeqJoin,
sharding::{NotSharded, ShardBinding, ShardConfiguration, ShardIndex, Sharded},
utils::NonZeroU32PowerOfTwo,
};

/// Context used by each helper to perform secure computation. Provides access to shared randomness
Expand Down Expand Up @@ -162,7 +163,7 @@ pub struct Base<'a, B: ShardBinding = NotSharded> {
inner: Inner<'a>,
gate: Gate,
total_records: TotalRecords,
active_work: NonZeroUsize,
active_work: NonZeroU32PowerOfTwo,
/// This indicates whether the system uses sharding or no. It's not ideal that we keep it here
/// because it gets cloned often, a potential solution to that, if this shows up on flame graph,
/// would be to move it to [`Inner`] struct.
Expand All @@ -181,13 +182,13 @@ impl<'a, B: ShardBinding> Base<'a, B> {
inner: Inner::new(participant, gateway),
gate,
total_records,
active_work: gateway.config().active_work(),
active_work: gateway.config().active_work_as_power_of_two(),
sharding,
}
}

#[must_use]
pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self {
pub fn set_active_work(self, new_active_work: NonZeroU32PowerOfTwo) -> Self {
Self {
active_work: new_active_work,
..self.clone()
Expand Down Expand Up @@ -336,7 +337,7 @@ impl ShardConfiguration for Base<'_, Sharded> {

impl<'a, B: ShardBinding> SeqJoin for Base<'a, B> {
fn active_work(&self) -> NonZeroUsize {
self.active_work
self.active_work.to_non_zero_usize()
}
}

Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/query/processor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{
collections::hash_map::Entry,
fmt::{Debug, Formatter},
num::NonZeroUsize,
};

use futures::{future::try_join, stream};
Expand All @@ -22,6 +21,7 @@ use crate::{
CompletionHandle, ProtocolResult,
},
sync::Arc,
utils::NonZeroU32PowerOfTwo,
};

/// `Processor` accepts and tracks requests to initiate new queries on this helper party
Expand All @@ -44,7 +44,7 @@ use crate::{
pub struct Processor {
queries: RunningQueries,
key_registry: Arc<KeyRegistry<PrivateKeyOnly>>,
active_work: Option<NonZeroUsize>,
active_work: Option<NonZeroU32PowerOfTwo>,
}

impl Default for Processor {
Expand Down Expand Up @@ -118,7 +118,7 @@ impl Processor {
#[must_use]
pub fn new(
key_registry: KeyRegistry<PrivateKeyOnly>,
active_work: Option<NonZeroUsize>,
active_work: Option<NonZeroU32PowerOfTwo>,
) -> Self {
Self {
queries: RunningQueries::default(),
Expand Down
Loading

0 comments on commit a670145

Please sign in to comment.