Skip to content

Commit

Permalink
Merge pull request private-attribution#1430 from akoshelev/sharded-sh…
Browse files Browse the repository at this point in the history
…uffle-test

Support sharded shuffle in executor
  • Loading branch information
akoshelev authored Nov 21, 2024
2 parents 59f15c8 + 36d18a2 commit a21f229
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 10 deletions.
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/cross_shard_prss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
/// ## Errors
/// If shard communication channels fail
#[allow(dead_code)] // until this is used in real sharded protocol
async fn gen_and_distribute<R: SharedRandomness, C: ShardConfiguration>(
pub async fn gen_and_distribute<R: SharedRandomness, C: ShardConfiguration>(
gateway: &Gateway,
gate: &Gate,
prss: R,
Expand Down
23 changes: 22 additions & 1 deletion ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
ShardChannelId, TotalRecords, Transport,
},
protocol::QueryId,
sharding::ShardIndex,
sharding::{ShardConfiguration, ShardIndex},
sync::{Arc, Mutex},
utils::NonZeroU32PowerOfTwo,
};
Expand Down Expand Up @@ -106,6 +106,27 @@ pub struct GatewayConfig {
pub progress_check_interval: std::time::Duration,
}

impl ShardConfiguration for Gateway {
fn shard_id(&self) -> ShardIndex {
ShardConfiguration::shard_id(&self)
}

fn shard_count(&self) -> ShardIndex {
ShardConfiguration::shard_count(&self)
}
}

impl ShardConfiguration for &Gateway {
fn shard_id(&self) -> ShardIndex {
self.transports.shard.identity()
}

fn shard_count(&self) -> ShardIndex {
// total number of shards include this instance and all its peers, so we add 1.
ShardIndex::from(self.transports.shard.peer_count() + 1)
}
}

impl Gateway {
#[must_use]
pub fn new(
Expand Down
12 changes: 11 additions & 1 deletion ipa-core/src/helpers/gateway/stall_detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ mod gateway {
Role, RoleAssignment, SendingEnd, ShardChannelId, ShardReceivingEnd, TotalRecords,
},
protocol::QueryId,
sharding::ShardIndex,
sharding::{ShardConfiguration, ShardIndex},
sync::Arc,
utils::NonZeroU32PowerOfTwo,
};
Expand Down Expand Up @@ -207,6 +207,16 @@ mod gateway {
}
}

impl ShardConfiguration for &Observed<InstrumentedGateway> {
fn shard_id(&self) -> ShardIndex {
self.inner().gateway.shard_id()
}

fn shard_count(&self) -> ShardIndex {
self.inner().gateway.shard_count()
}
}

pub struct GatewayWaitingTasks<MS, MR, SS, SR> {
mpc_send: Option<MS>,
mpc_recv: Option<MR>,
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/helpers/gateway/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ impl Transport for RoleResolvingTransport {
Role::all().iter().filter(move |&v| v != &this).copied()
}

fn peer_count(&self) -> u32 {
self.inner.peer_count()
}

async fn send<
D: Stream<Item = Vec<u8>> + Send + 'static,
Q: QueryIdBinding,
Expand Down
1 change: 1 addition & 0 deletions ipa-core/src/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ mod gateway_exports {
pub type ShardReceivingEnd<M> = gateway::ShardReceivingEnd<M>;
}

pub use cross_shard_prss::gen_and_distribute as setup_cross_shard_prss;
pub use gateway::GatewayConfig;
// TODO: this type should only be available within infra. Right now several infra modules
// are exposed at the root level. That makes it impossible to have a proper hierarchy here.
Expand Down
26 changes: 24 additions & 2 deletions ipa-core/src/helpers/transport/in_memory/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,11 @@ mod tests {
},
routing::RouteId,
},
HandlerBox, HelperIdentity, HelperResponse, OrderingSender, Role, RoleAssignment,
Transport, TransportIdentity,
HandlerBox, HelperIdentity, HelperResponse, InMemoryShardNetwork, OrderingSender, Role,
RoleAssignment, Transport, TransportIdentity,
},
protocol::{Gate, QueryId},
sharding::ShardIndex,
sync::Arc,
};

Expand Down Expand Up @@ -625,6 +626,27 @@ mod tests {
// must be received by now
assert_eq!(vec![vec![0, 1]], recv.collect::<Vec<_>>().await);
}

#[tokio::test]
async fn peer_count() {
let mpc_network = InMemoryMpcNetwork::default();
assert_eq!(2, mpc_network.transport(HelperIdentity::ONE).peer_count());
assert_eq!(2, mpc_network.transport(HelperIdentity::TWO).peer_count());

let shard_network = InMemoryShardNetwork::with_shards(5);
assert_eq!(
4,
shard_network
.transport(HelperIdentity::ONE, ShardIndex::FIRST)
.peer_count()
);
assert_eq!(
4,
shard_network
.transport(HelperIdentity::TWO, ShardIndex::from(4))
.peer_count()
);
}
}

pub struct TransportConfig {
Expand Down
7 changes: 7 additions & 0 deletions ipa-core/src/helpers/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,13 @@ pub trait Transport: Clone + Send + Sync + 'static {
/// Returns all the other identities, besides me, in this network.
fn peers(&self) -> impl Iterator<Item = Self::Identity>;

/// The number of peers on the network. Default implementation may not be efficient,
/// because it uses [`Self::peers`] to count, so implementations are encouraged to
/// override it
fn peer_count(&self) -> u32 {
u32::try_from(self.peers().count()).expect("Number of peers is less than 4B")
}

/// Sends a new request to the given destination helper party.
/// Depending on the specific request, it may or may not require acknowledgment by the remote
/// party
Expand Down
37 changes: 37 additions & 0 deletions ipa-core/src/net/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ impl Transport for MpcHttpTransport {
.filter(move |&id| id != this)
}

fn peer_count(&self) -> u32 {
2
}

async fn send<
D: Stream<Item = Vec<u8>> + Send + 'static,
Q: QueryIdBinding,
Expand Down Expand Up @@ -340,6 +344,10 @@ impl Transport for ShardHttpTransport {
self.shard_count.iter().filter(move |&v| v != this)
}

fn peer_count(&self) -> u32 {
u32::from(self.shard_count).saturating_sub(1)
}

async fn send<D, Q, S, R>(
&self,
dest: Self::Identity,
Expand Down Expand Up @@ -602,4 +610,33 @@ mod tests {
.build();
test_make_helpers(conf).await;
}

#[tokio::test]
async fn peer_count() {
fn new_transport<F: ConnectionFlavor>(identity: F::Identity) -> Arc<HttpTransport<F>> {
Arc::new(HttpTransport {
http_runtime: IpaRuntime::current(),
identity,
clients: Vec::new(),
handler: None,
record_streams: StreamCollection::default(),
})
}

assert_eq!(
2,
MpcHttpTransport {
inner_transport: new_transport(HelperIdentity::ONE)
}
.peer_count()
);
assert_eq!(
9,
ShardHttpTransport {
inner_transport: new_transport(ShardIndex::FIRST),
shard_count: 10.into()
}
.peer_count()
);
}
}
20 changes: 20 additions & 0 deletions ipa-core/src/protocol/prss/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,26 @@ impl SharedRandomness for IndexedSharedRandomness {
}
}

impl SharedRandomness for Arc<IndexedSharedRandomness> {
type ChunkIter<'a, Z: ArrayLength> =
<IndexedSharedRandomness as SharedRandomness>::ChunkIter<'a, Z>;

fn generate_chunks_one_side<I: Into<PrssIndex>, Z: ArrayLength>(
&self,
index: I,
direction: Direction,
) -> Self::ChunkIter<'_, Z> {
IndexedSharedRandomness::generate_chunks_one_side(self, index, direction)
}

fn generate_chunks_iter<I: Into<PrssIndex>, Z: ArrayLength>(
&self,
index: I,
) -> impl Iterator<Item = (GenericArray<u128, Z>, GenericArray<u128, Z>)> {
IndexedSharedRandomness::generate_chunks_iter(self, index)
}
}

/// Specialized implementation for chunks that are generated using both left and right
/// randomness. The functionality is the same as [`std::iter::zip`], but it does not use
/// `Iterator` trait to call `left` and `right` next. It uses inlined method calls to
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/protocol/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@ use ipa_step_derive::{CompactGate, CompactStep};
#[derive(CompactStep, CompactGate)]
pub enum ProtocolStep {
Prss,
CrossShardPrss,
#[step(child = crate::protocol::ipa_prf::step::IpaPrfStep)]
IpaPrf,
#[step(child = crate::protocol::hybrid::step::HybridStep)]
Hybrid,
Multiply,
PrimeFieldAddition,
#[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)]
ShardedShuffle,
/// Steps used in unit tests are grouped under this one. Ideally it should be
/// gated behind test configuration, but it does not work with build.rs that
/// does not enable any features when creating protocol gate file
#[step(child = TestExecutionStep)]
Test,

/// This step includes all the steps that are currently not linked into a top-level protocol.
///
/// This allows those steps to be compiled. However, any use of them will fail at run time.
Expand All @@ -39,8 +41,6 @@ pub enum DeadCodeStep {
FeatureLabelDotProduct,
#[step(child = crate::protocol::ipa_prf::boolean_ops::step::MultiplicationStep)]
Multiplication,
#[step(child = crate::protocol::ipa_prf::shuffle::step::ShardedShuffleStep)]
ShardedShuffle,
}

/// Provides a unique per-iteration context in tests.
Expand Down
5 changes: 3 additions & 2 deletions ipa-core/src/query/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ use crate::{
};
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
use crate::{
ff::Fp32BitPrime, query::runner::execute_test_multiply, query::runner::test_add_in_prime_field,
ff::Fp32BitPrime, query::runner::execute_sharded_shuffle, query::runner::execute_test_multiply,
query::runner::test_add_in_prime_field,
};

pub trait Result: Send + Debug {
Expand Down Expand Up @@ -108,7 +109,7 @@ pub fn execute<R: PrivateKeyRegistry>(
config,
gateway,
input,
|_prss, _gateway, _config, _input| unimplemented!(),
|prss, gateway, _config, input| Box::pin(execute_sharded_shuffle(prss, gateway, input)),
),
#[cfg(any(test, feature = "weak-field"))]
(QueryType::TestAddInPrimeField, FieldType::Fp31) => do_query(
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/query/runner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ mod hybrid;
mod oprf_ipa;
mod reshard_tag;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
mod sharded_shuffle;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
mod test_multiply;

#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
pub(super) use add_in_prime_field::execute as test_add_in_prime_field;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
pub(super) use sharded_shuffle::execute_sharded_shuffle;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
pub(super) use test_multiply::execute_test_multiply;

pub use self::oprf_ipa::OprfIpaQuery;
Expand Down
Loading

0 comments on commit a21f229

Please sign in to comment.