Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cross-shard PRSS in sharded MPC circuits (in memory only for now) #1410

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ipa-core/src/bin/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
executor::IpaRuntime,
helpers::HelperIdentity,
net::{ClientIdentity, IpaHttpClient, MpcHttpTransport, ShardHttpTransport},
sharding::Sharded,
sharding::ShardIndex,
AppConfig, AppSetup, NonZeroU32PowerOfTwo,
};
use tokio::runtime::Runtime;
Expand Down Expand Up @@ -185,7 +185,8 @@
let shard_network_config = NetworkConfig::new_shards(vec![], shard_clients_config);
let (shard_transport, _shard_server) = ShardHttpTransport::new(
IpaRuntime::from_tokio_runtime(&http_runtime),
Sharded::new(0, 1),
ShardIndex::FIRST,
ShardIndex::from(1),

Check warning on line 189 in ipa-core/src/bin/helper.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/helper.rs#L188-L189

Added lines #L188 - L189 were not covered by tests
shard_server_config,
shard_network_config,
vec![],
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/helpers/transport/in_memory/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
helpers::{HelperIdentity, Role, RoleAssignment},
protocol::Gate,
sharding::{ShardIndex, Sharded},
sharding::ShardIndex,
sync::Arc,
};

Expand Down Expand Up @@ -90,7 +90,7 @@ pub enum InspectContext {
MpcMessage {
/// The shard of this instance.
/// This is `None` for non-sharded helpers.
shard: Option<Sharded>,
shard: Option<ShardIndex>,
/// Helper sending this stream.
source: HelperIdentity,
/// Helper that will receive this stream.
Expand Down Expand Up @@ -161,7 +161,7 @@ impl<F: Fn(&MaliciousHelperContext, &mut Vec<u8>) + Send + Sync> MaliciousHelper
pub struct MaliciousHelperContext {
/// The shard of this instance.
/// This is `None` for non-sharded helpers.
pub shard: Option<Sharded>,
pub shard: Option<ShardIndex>,
/// Helper that will receive this stream.
pub dest: Role,
/// Circuit gate this stream is tied to.
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/helpers/transport/in_memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
in_memory_config::DynStreamInterceptor, transport::in_memory::config::passthrough,
HandlerRef, HelperIdentity,
},
sharding::Sharded,
sharding::ShardIndex,
sync::{Arc, Weak},
};

Expand Down Expand Up @@ -50,13 +50,13 @@ impl InMemoryMpcNetwork {
pub fn with_stream_interceptor(
handlers: [Option<HandlerRef>; 3],
interceptor: &DynStreamInterceptor,
shard_context: Option<Sharded>,
shard: Option<ShardIndex>,
) -> Self {
let [mut first, mut second, mut third]: [_; 3] = HelperIdentity::make_three().map(|i| {
let mut config_builder = TransportConfigBuilder::for_helper(i);
config_builder.with_interceptor(interceptor);

Setup::with_config(i, config_builder.with_sharding(shard_context))
Setup::with_config(i, config_builder.with_sharding(shard))
});

first.connect(&mut second);
Expand Down
12 changes: 2 additions & 10 deletions ipa-core/src/helpers/transport/in_memory/sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
transport::in_memory::transport::{InMemoryTransport, Setup, TransportConfigBuilder},
HelperIdentity,
},
sharding::{ShardIndex, Sharded},
sharding::ShardIndex,
sync::{Arc, Weak},
};

Expand Down Expand Up @@ -37,15 +37,7 @@ impl InMemoryShardNetwork {

let mut shard_connections = shard_count
.iter()
.map(|i| {
Setup::with_config(
i,
config_builder.with_sharding(Some(Sharded {
shard_id: i,
shard_count,
})),
)
})
.map(|i| Setup::with_config(i, config_builder.with_sharding(Some(i))))
.collect::<Vec<_>>();
for i in 0..shard_connections.len() {
let (lhs, rhs) = shard_connections.split_at_mut(i);
Expand Down
14 changes: 7 additions & 7 deletions ipa-core/src/helpers/transport/in_memory/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
Transport, TransportIdentity,
},
protocol::{Gate, QueryId},
sharding::Sharded,
sharding::ShardIndex,
sync::{Arc, Weak},
};

Expand Down Expand Up @@ -192,8 +192,8 @@ impl<I: TransportIdentity> Transport for Weak<InMemoryTransport<I>> {
let gate = addr.gate.clone();

let (ack_tx, ack_rx) = oneshot::channel();
let context = gate
.map(|gate| dest.inspect_context(this.config.shard_config, this.config.identity, gate));
let context =
gate.map(|gate| dest.inspect_context(this.config.shard, this.config.identity, gate));

channel
.send((
Expand Down Expand Up @@ -628,7 +628,7 @@ mod tests {
}

pub struct TransportConfig {
pub shard_config: Option<Sharded>,
pub shard: Option<ShardIndex>,
pub identity: HelperIdentity,
pub stream_interceptor: DynStreamInterceptor,
}
Expand All @@ -652,17 +652,17 @@ impl TransportConfigBuilder {
self
}

pub fn with_sharding(&self, shard_config: Option<Sharded>) -> TransportConfig {
pub fn with_sharding(&self, shard: Option<ShardIndex>) -> TransportConfig {
TransportConfig {
shard_config,
shard,
identity: self.identity,
stream_interceptor: Arc::clone(&self.stream_interceptor),
}
}

pub fn not_sharded(&self) -> TransportConfig {
TransportConfig {
shard_config: None,
shard: None,
identity: self.identity,
stream_interceptor: Arc::clone(&self.stream_interceptor),
}
Expand Down
12 changes: 5 additions & 7 deletions ipa-core/src/helpers/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

#[cfg(feature = "in-memory-infra")]
use crate::helpers::in_memory_config::InspectContext;
#[cfg(feature = "in-memory-infra")]
use crate::sharding::Sharded;
use crate::{
helpers::{transport::routing::RouteId, HelperIdentity, Role, TransportIdentity},
protocol::{Gate, QueryId},
Expand Down Expand Up @@ -58,7 +56,7 @@
#[cfg(feature = "in-memory-infra")]
fn inspect_context(
&self,
shard: Option<Sharded>,
shard: Option<ShardIndex>,
helper: HelperIdentity,
gate: Gate,
) -> InspectContext;
Expand All @@ -84,13 +82,13 @@
#[cfg(feature = "in-memory-infra")]
fn inspect_context(
&self,
shard: Option<Sharded>,
shard: Option<ShardIndex>,
helper: HelperIdentity,
gate: Gate,
) -> InspectContext {
InspectContext::ShardMessage {
helper,
source: shard.unwrap().shard_id,
source: shard.unwrap(),
dest: *self,
gate,
}
Expand Down Expand Up @@ -125,7 +123,7 @@
#[cfg(feature = "in-memory-infra")]
fn inspect_context(
&self,
shard: Option<Sharded>,
shard: Option<ShardIndex>,
helper: HelperIdentity,
gate: Gate,
) -> InspectContext {
Expand Down Expand Up @@ -167,7 +165,7 @@
#[cfg(feature = "in-memory-infra")]
fn inspect_context(
&self,
_shard: Option<Sharded>,
_shard: Option<ShardIndex>,

Check warning on line 168 in ipa-core/src/helpers/transport/mod.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/helpers/transport/mod.rs#L168

Added line #L168 was not covered by tests
_helper: HelperIdentity,
_gate: Gate,
) -> InspectContext {
Expand Down
6 changes: 2 additions & 4 deletions ipa-core/src/net/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,8 @@ impl TestApp {
);
let (shard_transport, shard_server) = super::ShardHttpTransport::new(
IpaRuntime::current(),
crate::sharding::Sharded {
shard_id: sid.shard_index,
shard_count: self.shard_network_config.shard_count(),
},
sid.shard_index,
self.shard_network_config.shard_count(),
self.shard_server.config,
self.shard_network_config,
shard_clients,
Expand Down
17 changes: 8 additions & 9 deletions ipa-core/src/net/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
},
net::{client::IpaHttpClient, error::Error, IpaHttpServer},
protocol::{Gate, QueryId},
sharding::{ShardIndex, Sharded},
sharding::ShardIndex,
sync::Arc,
};

Expand All @@ -45,7 +45,7 @@
#[derive(Clone)]
pub struct ShardHttpTransport {
pub(super) inner_transport: Arc<HttpTransport<Shard>>,
pub(super) shard_config: Sharded,
pub(super) shard_count: ShardIndex,
}

impl RouteParams<RouteId, NoQueryId, NoStep> for QueryConfig {
Expand Down Expand Up @@ -297,7 +297,9 @@
#[must_use]
pub fn new(
http_runtime: IpaRuntime,
shard_config: Sharded,
// todo: maybe a wrapper struct for it
shard_id: ShardIndex,
shard_count: ShardIndex,
server_config: ServerConfig,
network_config: NetworkConfig<Shard>,
clients: Vec<IpaHttpClient<Shard>>,
Expand All @@ -306,12 +308,12 @@
let transport = Self {
inner_transport: Arc::new(HttpTransport {
http_runtime,
identity: shard_config.shard_id,
identity: shard_id,
clients,
handler,
record_streams: StreamCollection::default(),
}),
shard_config,
shard_count,
};

let server = IpaHttpServer::new_shards(&transport, server_config, network_config);
Expand All @@ -331,10 +333,7 @@

fn peers(&self) -> impl Iterator<Item = Self::Identity> {
let this = self.identity();
self.shard_config
.shard_count
.iter()
.filter(move |&v| v != this)
self.shard_count.iter().filter(move |&v| v != this)

Check warning on line 336 in ipa-core/src/net/transport.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/net/transport.rs#L336

Added line #L336 was not covered by tests
}

async fn send<D, Q, S, R>(
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/context/dzkp_semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@
fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
self.inner.shard_recv_channel(origin)
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
self.inner.cross_shard_prss()
}

Check warning on line 58 in ipa-core/src/protocol/context/dzkp_semi_honest.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/context/dzkp_semi_honest.rs#L56-L58

Added lines #L56 - L58 were not covered by tests
}

impl<'a, B: ShardBinding> super::Context for DZKPUpgraded<'a, B> {
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/context/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
self.inner.shard_recv_channel(origin)
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
self.inner.cross_shard_prss()
}

Check warning on line 80 in ipa-core/src/protocol/context/malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/context/malicious.rs#L78-L80

Added lines #L78 - L80 were not covered by tests
}

impl<'a> Context<'a, NotSharded> {
Expand Down
15 changes: 15 additions & 0 deletions ipa-core/src/protocol/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ impl ShardedContext for Base<'_, Sharded> {
.gateway
.get_shard_receiver(&ChannelId::new(origin, self.gate.clone()))
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it make sense to implement this on non-sharded contexts, so that a protocol can use this regardless of sharding (in the non-sharded case, it just acts like normal prss)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is where Rust type system plays nicely for us - in non-sharded world you should never require cross-shard PRSS and use regular PRSS instead. All sharded circuits we write operate over ShardedContext that will make this API available. Non-sharded MPC won't compile if it tries to use it, which is what we would want imo

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it. so for example with gen_prf_key we'd just have two different implementations, one with Upgradable constraint on C, and another with Upgradable + Sharded? (just to make sure I'm understanding correctly.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, we will delete non-sharded version anyway, so the only one that we will use is gen_prf_key with C: UpgradableContext + ShardedContext bound. That will provide access to cross_shard_prss

InstrumentedIndexedSharedRandomness::new(
self.sharding.cross_shard_prss().indexed(self.gate()),
self.gate(),
self.inner.gateway.role(),
)
}
}

impl<'a, B: ShardBinding> Context for Base<'a, B> {
Expand Down Expand Up @@ -325,6 +333,13 @@ pub trait ShardedContext: Context + ShardConfiguration {

ShardIndex::from(shard_index)
}

/// Get the indexed PRSS instance shared across all shards on this helper.
/// Each shard will see the same random values generated by it.
/// This is still PRSS - the corresponding shards on other helpers will share
/// the left and the right part
#[must_use]
fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_>;
}

impl ShardConfiguration for Base<'_, Sharded> {
Expand Down
8 changes: 8 additions & 0 deletions ipa-core/src/protocol/context/semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@
fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
self.inner.shard_recv_channel(origin)
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
self.inner.cross_shard_prss()
}
}

impl<'a, B: ShardBinding> super::Context for Context<'a, B> {
Expand Down Expand Up @@ -218,6 +222,10 @@
fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
self.inner.shard_recv_channel(origin)
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
self.inner.cross_shard_prss()
}

Check warning on line 228 in ipa-core/src/protocol/context/semi_honest.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/context/semi_honest.rs#L226-L228

Added lines #L226 - L228 were not covered by tests
}

impl<'a, B: ShardBinding, F: ExtendableField> super::Context for Upgraded<'a, B, F> {
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ mod tests {
// changing x2
if ctx.gate.as_ref().contains("transfer_x_y")
&& ctx.dest == Role::H2
&& ctx.shard.map(|s| s.shard_id) == target_shard
&& ctx.shard == target_shard
{
data[0] ^= 1u8;
}
Expand All @@ -883,7 +883,7 @@ mod tests {
// changing y1
if ctx.gate.as_ref().contains("transfer_x_y")
&& ctx.dest == Role::H3
&& ctx.shard.map(|s| s.shard_id) == target_shard
&& ctx.shard == target_shard
{
data[0] ^= 1u8;
}
Expand All @@ -899,7 +899,7 @@ mod tests {
// changing c_hat_2
if ctx.gate.as_ref().contains("transfer_c")
&& ctx.dest == Role::H2
&& ctx.shard.map(|s| s.shard_id) == target_shard
&& ctx.shard == target_shard
{
data[0] ^= 1u8;
}
Expand Down
Loading