From a6701457ea57098621e198415d7788e17dc67a6a Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 2 Oct 2024 12:50:55 -0700 Subject: [PATCH] Add a type that enforces power of two constraint While working on changing the gateway and parameters, I ran into several issues where the power of two constraint was not enforced and breakages were hard to find. A better model for me is to gate the active work at the config level, prohibiting invalid constructions at the caller side. --- ipa-core/src/app.rs | 7 +- ipa-core/src/helpers/gateway/mod.rs | 29 +++-- ipa-core/src/helpers/gateway/send.rs | 49 +++++--- .../src/helpers/gateway/stall_detection.rs | 4 +- ipa-core/src/helpers/prss_protocol.rs | 4 +- .../src/protocol/context/dzkp_malicious.rs | 7 +- ipa-core/src/protocol/context/malicious.rs | 7 +- ipa-core/src/protocol/context/mod.rs | 9 +- ipa-core/src/query/processor.rs | 6 +- ipa-core/src/test_fixture/circuit.rs | 10 +- ipa-core/src/utils/mod.rs | 5 + ipa-core/src/utils/power_of_two.rs | 110 ++++++++++++++++++ 12 files changed, 198 insertions(+), 49 deletions(-) create mode 100644 ipa-core/src/utils/power_of_two.rs diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index da56e67e3..f84aed06e 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -1,4 +1,4 @@ -use std::{num::NonZeroUsize, sync::Weak}; +use std::sync::Weak; use async_trait::async_trait; @@ -13,17 +13,18 @@ use crate::{ protocol::QueryId, query::{NewQueryError, QueryProcessor, QueryStatus}, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; #[derive(Default)] pub struct AppConfig { - active_work: Option, + active_work: Option, key_registry: Option>, } impl AppConfig { #[must_use] - pub fn with_active_work(mut self, active_work: Option) -> Self { + pub fn with_active_work(mut self, active_work: Option) -> Self { self.active_work = active_work; self } diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 55d1b1ffc..15d2580d2 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -30,6 +30,7 @@ use crate::{ protocol::QueryId, sharding::ShardIndex, sync::{Arc, Mutex}, + utils::NonZeroU32PowerOfTwo, }; /// Alias for the currently configured transport. @@ -73,8 +74,7 @@ pub struct State { pub struct GatewayConfig { /// The number of items that can be active at the one time. /// This is used to determine the size of sending and receiving buffers. - /// Any value that is not a power of two will be rejected - pub active: NonZeroUsize, + pub active: NonZeroU32PowerOfTwo, /// Number of bytes packed and sent together in one batch down to the network layer. This /// shouldn't be too small to keep the network throughput, but setting it large enough may @@ -155,7 +155,7 @@ impl Gateway { &self, channel_id: &HelperChannelId, total_records: TotalRecords, - active_work: NonZeroUsize, + active_work: NonZeroU32PowerOfTwo, ) -> send::SendingEnd { let transport = &self.transports.mpc; let channel = self.inner.mpc_senders.get::( @@ -265,6 +265,11 @@ impl GatewayConfig { /// The configured amount of active work. #[must_use] pub fn active_work(&self) -> NonZeroUsize { + self.active.to_non_zero_usize() + } + + #[must_use] + pub fn active_work_as_power_of_two(&self) -> NonZeroU32PowerOfTwo { self.active } @@ -287,12 +292,12 @@ impl GatewayConfig { ) .next_power_of_two(); // we set active to be at least 2, so unwrap is fine. - self.active = NonZeroUsize::new(active).unwrap(); + self.active = NonZeroU32PowerOfTwo::try_from(active).unwrap(); } /// Creates a new configuration by overriding the value of active work. #[must_use] - pub fn set_active_work(&self, active_work: NonZeroUsize) -> Self { + pub fn set_active_work(&self, active_work: NonZeroU32PowerOfTwo) -> Self { Self { active: active_work, ..*self @@ -304,7 +309,6 @@ impl GatewayConfig { mod tests { use std::{ iter::{repeat, zip}, - num::NonZeroUsize, sync::Arc, }; @@ -337,6 +341,7 @@ mod tests { sharding::ShardConfiguration, test_executor::run, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig, WithShards}, + utils::NonZeroU32PowerOfTwo, }; /// Verifies that [`Gateway`] send buffer capacity is adjusted to the message size. @@ -556,13 +561,19 @@ mod tests { run(|| async move { let world = TestWorld::new_with(TestWorldConfig { gateway_config: GatewayConfig { - active: 5.try_into().unwrap(), + active: 8.try_into().unwrap(), ..Default::default() }, ..Default::default() }); - let new_active_work = NonZeroUsize::new(3).unwrap(); - assert!(new_active_work < world.gateway(Role::H1).config().active_work()); + let new_active_work = NonZeroU32PowerOfTwo::try_from(4).unwrap(); + assert!( + new_active_work + < world + .gateway(Role::H1) + .config() + .active_work_as_power_of_two() + ); let sender = world.gateway(Role::H1).get_mpc_sender::( &ChannelId::new(Role::H2, Gate::default()), TotalRecords::specified(15).unwrap(), diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index e75cac7b2..70cae3707 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -255,34 +255,35 @@ impl SendChannelConfig { total_records: TotalRecords, record_size: usize, ) -> Self { - debug_assert!(record_size > 0, "Message size cannot be 0"); - debug_assert!( - gateway_config.active.is_power_of_two(), - "Active work {} must be a power of two", - gateway_config.active.get() - ); + assert!(record_size > 0, "Message size cannot be 0"); let total_capacity = gateway_config.active.get() * record_size; - // define read size in terms of percentage of active work, rather than bytes. - // both are powers of two, so it should always be possible. We pick the read size - // to be the closest to the configuration value in bytes. - // let read_size = closest_multiple(record_size, gateway_config.read_size.get()); - let read_size = (gateway_config.read_size.get() / record_size + 1).next_power_of_two() / 2 - * record_size; + // define read size as a multiplier of record size. The multiplier must be + // a power of two to align perfectly with total capacity. + let read_size_multiplier = { + // next_power_of_two returns a value that is greater than or equal two + // in order to compute previous, we need strictly greater value, thus + // we add 1 + let prev_power_of_two = + (gateway_config.read_size.get() / record_size + 1).next_power_of_two() / 2; + std::cmp::max(1, prev_power_of_two) + }; + let this = Self { total_capacity: total_capacity.try_into().unwrap(), record_size: record_size.try_into().unwrap(), - read_size: if total_records.is_indeterminate() || read_size <= record_size { + read_size: if total_records.is_indeterminate() { record_size } else { - std::cmp::min(total_capacity, read_size) + std::cmp::min(total_capacity, read_size_multiplier * record_size) } .try_into() .unwrap(), total_records, }; - debug_assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + assert!(this.total_capacity.get() >= record_size * gateway_config.active.get()); + assert_eq!(0, this.total_capacity.get() % this.read_size.get()); this } @@ -297,7 +298,7 @@ mod test { use crate::{ ff::{ - boolean_array::{BA16, BA20, BA256, BA3, BA7}, + boolean_array::{BA16, BA20, BA256, BA3, BA32, BA7}, Serializable, }, helpers::{gateway::send::SendChannelConfig, GatewayConfig, TotalRecords}, @@ -412,6 +413,21 @@ mod test { ensure_config(Some(15), 90, 16, 3); } + #[test] + fn config_read_size_multiple_of_record_size() { + // 4 bytes * 8 = 32 bytes total capacity. + // desired read size is 15 bytes, and the closest multiple of BA32 + // to it that is a power of two is 2 (4 gets us over 15 byte target) + assert_eq!(8, send_config::(50.into()).read_size.get()); + + // here, read size is already a power of two + assert_eq!(16, send_config::(50.into()).read_size.get()); + + // read size can be ridiculously small, config adjusts it to fit + // at least one record + assert_eq!(3, send_config::(50.into()).read_size.get()); + } + fn ensure_config( total_records: Option, active: usize, @@ -421,7 +437,6 @@ mod test { let gateway_config = GatewayConfig { active: active.next_power_of_two().try_into().unwrap(), read_size: read_size.try_into().unwrap(), - // read_size: read_size.next_power_of_two().try_into().unwrap(), ..Default::default() }; let config = SendChannelConfig::new_with( diff --git a/ipa-core/src/helpers/gateway/stall_detection.rs b/ipa-core/src/helpers/gateway/stall_detection.rs index 43706f450..4a844386f 100644 --- a/ipa-core/src/helpers/gateway/stall_detection.rs +++ b/ipa-core/src/helpers/gateway/stall_detection.rs @@ -67,7 +67,6 @@ impl Observed { } mod gateway { - use std::num::NonZeroUsize; use delegate::delegate; @@ -81,6 +80,7 @@ mod gateway { protocol::QueryId, sharding::ShardIndex, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; pub struct InstrumentedGateway { @@ -154,7 +154,7 @@ mod gateway { &self, channel_id: &HelperChannelId, total_records: TotalRecords, - active_work: NonZeroUsize, + active_work: NonZeroU32PowerOfTwo, ) -> SendingEnd { Observed::wrap( Weak::clone(self.get_sn()), diff --git a/ipa-core/src/helpers/prss_protocol.rs b/ipa-core/src/helpers/prss_protocol.rs index f9284f9eb..850d6c733 100644 --- a/ipa-core/src/helpers/prss_protocol.rs +++ b/ipa-core/src/helpers/prss_protocol.rs @@ -24,12 +24,12 @@ pub async fn negotiate( let left_sender = gateway.get_mpc_sender::( &left_channel, TotalRecords::ONE, - gateway.config().active_work(), + gateway.config().active_work_as_power_of_two(), ); let right_sender = gateway.get_mpc_sender::( &right_channel, TotalRecords::ONE, - gateway.config().active_work(), + gateway.config().active_work_as_power_of_two(), ); let left_receiver = gateway.get_mpc_receiver::(&left_channel); let right_receiver = gateway.get_mpc_receiver::(&right_channel); diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 9f28239ba..80762fb52 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -61,8 +61,11 @@ impl<'a> DZKPUpgraded<'a> { // This overrides the active work for this context and all children // created from it by using narrow, clone, etc. // This allows all steps participating in malicious validation - // to use the same active work window and prevent deadlocks - base_ctx: base_ctx.set_active_work(active_work), + // to use the same active work window and prevent deadlocks. + // + // This also checks that active work is a power of two and + // panics if it is not. + base_ctx: base_ctx.set_active_work(active_work.get().try_into().unwrap()), } } diff --git a/ipa-core/src/protocol/context/malicious.rs b/ipa-core/src/protocol/context/malicious.rs index 8c287b1f2..b11f6f5a8 100644 --- a/ipa-core/src/protocol/context/malicious.rs +++ b/ipa-core/src/protocol/context/malicious.rs @@ -80,7 +80,7 @@ impl<'a> Context<'a> { } #[must_use] - pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self { + pub fn set_active_work(self, new_active_work: NonZeroU32PowerOfTwo) -> Self { Self { inner: self.inner.set_active_work(new_active_work), } @@ -171,7 +171,10 @@ impl Debug for Context<'_> { } } -use crate::sync::{Mutex, Weak}; +use crate::{ + sync::{Mutex, Weak}, + utils::NonZeroU32PowerOfTwo, +}; pub(super) type MacBatcher<'a, F> = Mutex>>; diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index eead81a16..abf6f8476 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -44,6 +44,7 @@ use crate::{ secret_sharing::replicated::malicious::ExtendableField, seq_join::SeqJoin, sharding::{NotSharded, ShardBinding, ShardConfiguration, ShardIndex, Sharded}, + utils::NonZeroU32PowerOfTwo, }; /// Context used by each helper to perform secure computation. Provides access to shared randomness @@ -162,7 +163,7 @@ pub struct Base<'a, B: ShardBinding = NotSharded> { inner: Inner<'a>, gate: Gate, total_records: TotalRecords, - active_work: NonZeroUsize, + active_work: NonZeroU32PowerOfTwo, /// This indicates whether the system uses sharding or no. It's not ideal that we keep it here /// because it gets cloned often, a potential solution to that, if this shows up on flame graph, /// would be to move it to [`Inner`] struct. @@ -181,13 +182,13 @@ impl<'a, B: ShardBinding> Base<'a, B> { inner: Inner::new(participant, gateway), gate, total_records, - active_work: gateway.config().active_work(), + active_work: gateway.config().active_work_as_power_of_two(), sharding, } } #[must_use] - pub fn set_active_work(self, new_active_work: NonZeroUsize) -> Self { + pub fn set_active_work(self, new_active_work: NonZeroU32PowerOfTwo) -> Self { Self { active_work: new_active_work, ..self.clone() @@ -336,7 +337,7 @@ impl ShardConfiguration for Base<'_, Sharded> { impl<'a, B: ShardBinding> SeqJoin for Base<'a, B> { fn active_work(&self) -> NonZeroUsize { - self.active_work + self.active_work.to_non_zero_usize() } } diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index a8694012e..679b740fd 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -1,7 +1,6 @@ use std::{ collections::hash_map::Entry, fmt::{Debug, Formatter}, - num::NonZeroUsize, }; use futures::{future::try_join, stream}; @@ -22,6 +21,7 @@ use crate::{ CompletionHandle, ProtocolResult, }, sync::Arc, + utils::NonZeroU32PowerOfTwo, }; /// `Processor` accepts and tracks requests to initiate new queries on this helper party @@ -44,7 +44,7 @@ use crate::{ pub struct Processor { queries: RunningQueries, key_registry: Arc>, - active_work: Option, + active_work: Option, } impl Default for Processor { @@ -118,7 +118,7 @@ impl Processor { #[must_use] pub fn new( key_registry: KeyRegistry, - active_work: Option, + active_work: Option, ) -> Self { Self { queries: RunningQueries::default(), diff --git a/ipa-core/src/test_fixture/circuit.rs b/ipa-core/src/test_fixture/circuit.rs index 5a1ecd67e..17920591f 100644 --- a/ipa-core/src/test_fixture/circuit.rs +++ b/ipa-core/src/test_fixture/circuit.rs @@ -1,4 +1,4 @@ -use std::{array, num::NonZeroUsize}; +use std::array; use futures::{future::join3, stream, StreamExt}; use ipa_step::StepNarrow; @@ -17,7 +17,7 @@ use crate::{ secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, FieldSimd, IntoShares}, seq_join::seq_join, test_fixture::{ReconstructArr, TestWorld, TestWorldConfig}, - utils::array::zip3, + utils::{array::zip3, NonZeroU32PowerOfTwo}, }; pub struct Inputs, const N: usize> { @@ -76,7 +76,7 @@ pub async fn arithmetic( [F; N]: IntoShares>, Standard: Distribution, { - let active = NonZeroUsize::new(active_work).unwrap(); + let active = NonZeroU32PowerOfTwo::try_from(active_work.next_power_of_two()).unwrap(); let config = TestWorldConfig { gateway_config: GatewayConfig { active, @@ -85,7 +85,7 @@ pub async fn arithmetic( initial_gate: Some(Gate::default().narrow(&ProtocolStep::Test)), ..Default::default() }; - let world = TestWorld::new_with(config); + let world = TestWorld::new_with(&config); // Re-use contexts for the entire execution because record identifiers are contiguous. let contexts = world.contexts(); @@ -96,7 +96,7 @@ pub async fn arithmetic( // accumulated. This gives the best performance for vectorized operation. let ctx = ctx.set_total_records(TotalRecords::Indeterminate); seq_join( - active, + config.gateway_config.active_work(), stream::iter((0..(width / u32::try_from(N).unwrap())).zip(col_data)).map( move |(record, Inputs { a, b })| { circuit(ctx.clone(), RecordId::from(record), depth, a, b) diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index a3600e899..e8dfd95ae 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -1,2 +1,7 @@ pub mod array; pub mod arraychunks; +#[cfg(target_pointer_width = "64")] +mod power_of_two; + +#[cfg(target_pointer_width = "64")] +pub use power_of_two::NonZeroU32PowerOfTwo; diff --git a/ipa-core/src/utils/power_of_two.rs b/ipa-core/src/utils/power_of_two.rs new file mode 100644 index 000000000..abce8055e --- /dev/null +++ b/ipa-core/src/utils/power_of_two.rs @@ -0,0 +1,110 @@ +use std::{fmt::Display, num::NonZeroUsize, str::FromStr}; + +#[derive(Debug, thiserror::Error)] +#[error("{0} is not a power of two or not within the 1..u32::MAX range")] +pub struct ConvertError(I); + +impl PartialEq for ConvertError { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +/// This construction guarantees the value to be a power of two and +/// within the range 0..2^32-1 +#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] +pub struct NonZeroU32PowerOfTwo(u32); + +impl Display for NonZeroU32PowerOfTwo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", u32::from(*self)) + } +} + +impl TryFrom for NonZeroU32PowerOfTwo { + type Error = ConvertError; + + fn try_from(value: usize) -> Result { + if value > 0 && value < usize::try_from(u32::MAX).unwrap() && value.is_power_of_two() { + Ok(NonZeroU32PowerOfTwo(u32::try_from(value).unwrap())) + } else { + Err(ConvertError(value)) + } + } +} + +impl From for usize { + fn from(value: NonZeroU32PowerOfTwo) -> Self { + // we are using 64 bit registers + usize::try_from(value.0).unwrap() + } +} + +impl From for u32 { + fn from(value: NonZeroU32PowerOfTwo) -> Self { + value.0 + } +} + +impl FromStr for NonZeroU32PowerOfTwo { + type Err = ConvertError; + + fn from_str(s: &str) -> Result { + let v = s.parse::().map_err(|_| ConvertError(s.to_owned()))?; + NonZeroU32PowerOfTwo::try_from(v).map_err(|_| ConvertError(s.to_owned())) + } +} + +impl NonZeroU32PowerOfTwo { + #[must_use] + pub fn to_non_zero_usize(self) -> NonZeroUsize { + let v = usize::from(self); + NonZeroUsize::new(v).unwrap_or_else(|| unreachable!()) + } + + #[must_use] + pub fn get(self) -> usize { + usize::from(self) + } +} + +#[cfg(test)] +mod tests { + use super::{ConvertError, NonZeroU32PowerOfTwo}; + + #[test] + fn rejects_invalid_values() { + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(0), + Err(ConvertError(0)) + )); + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(3), + Err(ConvertError(3)) + )); + + assert!(matches!( + NonZeroU32PowerOfTwo::try_from(1_usize << 33), + Err(ConvertError(_)) + )); + } + + #[test] + fn accepts_valid() { + assert_eq!(4, u32::from(NonZeroU32PowerOfTwo::try_from(4).unwrap())); + assert_eq!(16, u32::from(NonZeroU32PowerOfTwo::try_from(16).unwrap())); + } + + #[test] + fn parse_from_str() { + assert_eq!(NonZeroU32PowerOfTwo(4), "4".parse().unwrap()); + assert_eq!( + ConvertError("0".to_owned()), + "0".parse::().unwrap_err() + ); + assert_eq!( + ConvertError("3".to_owned()), + "3".parse::().unwrap_err() + ); + } +}