Skip to content

Commit

Permalink
Support sharded shuffle in executor
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Nov 13, 2024
1 parent 91e9d4e commit 997097f
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 40 deletions.
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/cross_shard_prss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
/// ## Errors
/// If shard communication channels fail
#[allow(dead_code)] // until this is used in real sharded protocol
async fn gen_and_distribute<R: SharedRandomness, C: ShardConfiguration>(
pub async fn gen_and_distribute<R: SharedRandomness, C: ShardConfiguration>(
gateway: &Gateway,
gate: &Gate,
prss: R,
Expand Down
12 changes: 11 additions & 1 deletion ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
ShardChannelId, TotalRecords, Transport,
},
protocol::QueryId,
sharding::ShardIndex,
sharding::{ShardConfiguration, ShardIndex},
sync::{Arc, Mutex},
utils::NonZeroU32PowerOfTwo,
};
Expand Down Expand Up @@ -106,6 +106,16 @@ pub struct GatewayConfig {
pub progress_check_interval: std::time::Duration,
}

impl ShardConfiguration for Gateway {
fn shard_id(&self) -> ShardIndex {
self.transports.shard.identity()
}

fn shard_count(&self) -> ShardIndex {
ShardIndex::from(self.transports.shard.peer_count() + 1)
}
}

impl Gateway {
#[must_use]
pub fn new(
Expand Down
12 changes: 11 additions & 1 deletion ipa-core/src/helpers/gateway/stall_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ mod gateway {
Role, RoleAssignment, SendingEnd, ShardChannelId, ShardReceivingEnd, TotalRecords,
},
protocol::QueryId,
sharding::ShardIndex,
sharding::{ShardConfiguration, ShardIndex},
sync::Arc,
utils::NonZeroU32PowerOfTwo,
};
Expand Down Expand Up @@ -207,6 +207,16 @@ mod gateway {
}
}

impl ShardConfiguration for &Observed<InstrumentedGateway> {
fn shard_id(&self) -> ShardIndex {
self.inner().gateway.shard_id()
}

fn shard_count(&self) -> ShardIndex {
self.inner().gateway.shard_count()
}
}

pub struct GatewayWaitingTasks<MS, MR, SS, SR> {
mpc_send: Option<MS>,
mpc_recv: Option<MR>,
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/helpers/gateway/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ impl Transport for RoleResolvingTransport {
Role::all().iter().filter(move |&v| v != &this).copied()
}

fn peer_count(&self) -> u32 {
self.inner.peer_count()
}

async fn send<
D: Stream<Item = Vec<u8>> + Send + 'static,
Q: QueryIdBinding,
Expand Down
1 change: 1 addition & 0 deletions ipa-core/src/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ mod gateway_exports {
pub type ShardReceivingEnd<M> = gateway::ShardReceivingEnd<M>;
}

pub use cross_shard_prss::gen_and_distribute as setup_cross_shard_prss;
pub use gateway::GatewayConfig;
// TODO: this type should only be available within infra. Right now several infra modules
// are exposed at the root level. That makes it impossible to have a proper hierarchy here.
Expand Down
7 changes: 7 additions & 0 deletions ipa-core/src/helpers/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,13 @@ pub trait Transport: Clone + Send + Sync + 'static {
/// Returns all the other identities, besides me, in this network.
fn peers(&self) -> impl Iterator<Item = Self::Identity>;

/// The number of peers on the network. Default implementation may not be efficient,
/// because it uses [`Self::peers`] to count, so implementations are encouraged to
/// override it
fn peer_count(&self) -> u32 {
u32::try_from(self.peers().count()).expect("Number of peers is less than 4B")
}

/// Sends a new request to the given destination helper party.
/// Depending on the specific request, it may or may not require acknowledgment by the remote
/// party
Expand Down
8 changes: 8 additions & 0 deletions ipa-core/src/net/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ impl Transport for MpcHttpTransport {
.filter(move |&id| id != this)
}

fn peer_count(&self) -> u32 {
2
}

async fn send<
D: Stream<Item = Vec<u8>> + Send + 'static,
Q: QueryIdBinding,
Expand Down Expand Up @@ -336,6 +340,10 @@ impl Transport for ShardHttpTransport {
self.shard_count.iter().filter(move |&v| v != this)
}

fn peer_count(&self) -> u32 {
self.shard_count.into()
}

async fn send<D, Q, S, R>(
&self,
dest: Self::Identity,
Expand Down
20 changes: 20 additions & 0 deletions ipa-core/src/protocol/prss/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,26 @@ impl SharedRandomness for IndexedSharedRandomness {
}
}

impl SharedRandomness for Arc<IndexedSharedRandomness> {
type ChunkIter<'a, Z: ArrayLength> =
<IndexedSharedRandomness as SharedRandomness>::ChunkIter<'a, Z>;

fn generate_chunks_one_side<I: Into<PrssIndex>, Z: ArrayLength>(
&self,
index: I,
direction: Direction,
) -> Self::ChunkIter<'_, Z> {
IndexedSharedRandomness::generate_chunks_one_side(self, index, direction)
}

fn generate_chunks_iter<I: Into<PrssIndex>, Z: ArrayLength>(
&self,
index: I,
) -> impl Iterator<Item = (GenericArray<u128, Z>, GenericArray<u128, Z>)> {
IndexedSharedRandomness::generate_chunks_iter(self, index)
}
}

/// Specialized implementation for chunks that are generated using both left and right
/// randomness. The functionality is the same as [`std::iter::zip`], but it does not use
/// `Iterator` trait to call `left` and `right` next. It uses inlined method calls to
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/protocol/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@ use ipa_step_derive::{CompactGate, CompactStep};
#[derive(CompactStep, CompactGate)]
pub enum ProtocolStep {
Prss,
CrossShardPrss,
#[step(child = crate::protocol::ipa_prf::step::IpaPrfStep)]
IpaPrf,
#[step(child = crate::protocol::hybrid::step::HybridStep)]
Hybrid,
Multiply,
PrimeFieldAddition,
#[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)]
ShardedShuffle,
/// Steps used in unit tests are grouped under this one. Ideally it should be
/// gated behind test configuration, but it does not work with build.rs that
/// does not enable any features when creating protocol gate file
#[step(child = TestExecutionStep)]
Test,

/// This step includes all the steps that are currently not linked into a top-level protocol.
///
/// This allows those steps to be compiled. However, any use of them will fail at run time.
Expand All @@ -39,8 +41,6 @@ pub enum DeadCodeStep {
FeatureLabelDotProduct,
#[step(child = crate::protocol::ipa_prf::boolean_ops::step::MultiplicationStep)]
Multiplication,
#[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)]
ShardedShuffle,
}

/// Provides a unique per-iteration context in tests.
Expand Down
7 changes: 2 additions & 5 deletions ipa-core/src/query/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::{
Gate,
},
query::{
runner::{OprfIpaQuery, QueryResult},
runner::{execute_sharded_shuffle, OprfIpaQuery, QueryResult},
state::RunningQuery,
},
sync::Arc,
Expand All @@ -48,7 +48,6 @@ use crate::{
use crate::{
ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field,
};
use crate::query::runner::execute_sharded_shuffle;

pub trait Result: Send + Debug {
fn to_bytes(&self) -> Vec<u8>;
Expand Down Expand Up @@ -109,9 +108,7 @@ pub fn execute<R: PrivateKeyRegistry>(
config,
gateway,
input,
|_prss, _gateway, _config, _input| {
Box::pin(execute_sharded_shuffle())
}
|prss, gateway, _config, input| Box::pin(execute_sharded_shuffle(prss, gateway, input)),
),
#[cfg(any(test, feature = "weak-field"))]
(QueryType::TestAddInPrimeField, FieldType::Fp31) => do_query(
Expand Down
8 changes: 4 additions & 4 deletions ipa-core/src/query/runner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ mod hybrid;
mod oprf_ipa;
mod reshard_tag;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
mod test_multiply;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
mod sharded_shuffle;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
mod test_multiply;

#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
pub(super) use add_in_prime_field::execute as test_add_in_prime_field;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
pub(super) use test_multiply::execute_test_multiply;
pub(super) use sharded_shuffle::execute_sharded_shuffle;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
pub(super) use sharded_shuffle::execute as execute_sharded_shuffle;
pub(super) use test_multiply::execute_test_multiply;

pub use self::oprf_ipa::OprfIpaQuery;
use crate::{error::Error, query::ProtocolResult};
Expand Down
134 changes: 109 additions & 25 deletions ipa-core/src/query/runner/sharded_shuffle.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,119 @@
use futures_util::TryStreamExt;
use crate::error::Error;
use crate::ff::boolean_array::{BA64};
use crate::helpers::{BodyStream, Gateway, RecordsStream, SingleRecordStream};
use crate::helpers::query::QuerySize;
use crate::protocol::context::{ShardedContext, ShardedSemiHonestContext};
use crate::protocol::ipa_prf::Shuffle;
use crate::secret_sharing::replicated::semi_honest::AdditiveShare;
use crate::protocol::prss::Endpoint as PrssEndpoint;
use crate::query::runner::QueryResult;

pub async fn execute_test_multiply<'a, F>(
use ipa_step::StepNarrow;

use crate::{
error::Error,
ff::boolean_array::BA64,
helpers::{setup_cross_shard_prss, BodyStream, Gateway, SingleRecordStream},
protocol::{
context::{Context, ShardedContext, ShardedSemiHonestContext},
ipa_prf::Shuffle,
prss::Endpoint as PrssEndpoint,
step::ProtocolStep,
Gate,
},
query::runner::QueryResult,
secret_sharing::replicated::semi_honest::AdditiveShare,
sharding::{ShardConfiguration, Sharded},
sync::Arc,
};

pub async fn execute_sharded_shuffle<'a>(
prss: &'a PrssEndpoint,
gateway: &'a Gateway,
input: BodyStream,
) -> QueryResult
{
gen_and_distribute()
let ctx = ShardedSemiHonestContext::new_sharded(prss, gateway).narrow(&ProtocolStep::Multiply);
let ctx = SemiHonestContext::new(prss, gateway).narrow(&ProtocolStep::Multiply);
Ok(Box::new(
execute::<F>(ctx, input).await?,
))
) -> QueryResult {
let gate = Gate::default().narrow(&ProtocolStep::CrossShardPrss);
let cross_shard_prss =
setup_cross_shard_prss(gateway, &gate, prss.indexed(&gate), gateway).await?;
let ctx = ShardedSemiHonestContext::new_sharded(
prss,
gateway,
Sharded {
shard_id: gateway.shard_id(),
shard_count: gateway.shard_count(),
prss: Arc::new(cross_shard_prss),
},
)
.narrow(&ProtocolStep::ShardedShuffle);

Ok(Box::new(execute(ctx, input).await?))
}

#[tracing::instrument("sharded_shuffle", skip_all)]
pub async fn execute<C>(
ctx: C,
input_stream: BodyStream,
) -> Result<Vec<AdditiveShare<BA64>>, Error>
where C: ShardedContext + Shuffle
pub async fn execute<C>(ctx: C, input_stream: BodyStream) -> Result<Vec<AdditiveShare<BA64>>, Error>
where
C: ShardedContext + Shuffle,
{
let input = SingleRecordStream::<AdditiveShare<BA64>, _>::new(input_stream).try_collect::<Vec<_>>().await?;
let input = SingleRecordStream::<AdditiveShare<BA64>, _>::new(input_stream)
.try_collect::<Vec<_>>()
.await?;
ctx.shuffle(input).await
}

#[cfg(all(test, unit_test))]
mod tests {
use futures_util::future::try_join_all;
use generic_array::GenericArray;
use typenum::Unsigned;

use crate::{
ff::{boolean_array::BA64, Serializable, U128Conversions},
query::runner::sharded_shuffle::execute,
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
test_executor::run,
test_fixture::{try_join3_array, Reconstruct, TestWorld, TestWorldConfig, WithShards},
utils::array::zip3,
};

#[test]
fn basic() {
run(|| async {
const SHARDS: usize = 20;
let world: TestWorld<WithShards<3>> =
TestWorld::with_shards(TestWorldConfig::default());
let contexts = world.contexts();
let input = (0..20_u128).map(BA64::truncate_from).collect::<Vec<_>>();

#[allow(clippy::redundant_closure_for_method_calls)]
let shard_shares: [Vec<Vec<AdditiveShare<BA64>>>; 3] =
input.clone().into_iter().share().map(|helper_shares| {
helper_shares
.chunks(SHARDS / 3)
.map(|v| v.to_vec())
.collect()
});

let result =
try_join3_array(zip3(contexts, shard_shares).map(|(h_contexts, h_shares)| {
try_join_all(
h_contexts
.into_iter()
.zip(h_shares)
.map(|(ctx, shard_shares)| {
let shard_stream = shard_shares
.into_iter()
.flat_map(|share| {
const SIZE: usize =
<AdditiveShare<BA64> as Serializable>::Size::USIZE;
let mut slice = [0_u8; SIZE];
share.serialize(GenericArray::from_mut_slice(&mut slice));
slice
})
.collect::<Vec<_>>()
.into();

execute(ctx, shard_stream)
}),
)
}))
.await
.unwrap()
.map(|v| v.into_iter().flatten().collect::<Vec<_>>())
.reconstruct();

// 1/20! probability of this permutation to be the same
assert_ne!(input, result);
});
}
}

0 comments on commit 997097f

Please sign in to comment.