diff --git a/ipa-core/src/protocol/context/mod.rs b/ipa-core/src/protocol/context/mod.rs index 5b1de3505..bafc17be9 100644 --- a/ipa-core/src/protocol/context/mod.rs +++ b/ipa-core/src/protocol/context/mod.rs @@ -11,12 +11,12 @@ pub mod upgrade; mod batcher; pub mod validator; -use std::{collections::HashMap, iter, num::NonZeroUsize, pin::pin}; +use std::{collections::HashMap, num::NonZeroUsize, pin::pin}; use async_trait::async_trait; pub use dzkp_malicious::DZKPUpgraded as DZKPUpgradedMaliciousContext; pub use dzkp_semi_honest::DZKPUpgraded as DZKPUpgradedSemiHonestContext; -use futures::{stream, Stream, StreamExt}; +use futures::{stream, Stream, StreamExt, TryStreamExt}; use ipa_step::{Step, StepNarrow}; pub use malicious::MaliciousProtocolSteps; use prss::{InstrumentedIndexedSharedRandomness, InstrumentedSequentialSharedRandomness}; @@ -365,6 +365,12 @@ impl<'a> Inner<'a> { /// N per shard). Each channel stays open until the very last row is processed, then they are explicitly /// closed, even if nothing has been communicated between that pair. /// +/// ## Stream size +/// [`reshard_try_stream`] takes a regular stream, but will panic at runtime if the upper bound +/// of the stream size is not known. Opting for a runtime check is necessary for it to work +/// with query inputs, where the submitter stream is truncated to take at most `sz` elements. +/// This means that the stream may have fewer than `sz` elements and resharding should still work. +/// /// ## Shard picking considerations /// It is expected for `shard_picker` to select shards uniformly, by either using [`prss`] or sampling /// random values with enough entropy. Failure to do so may lead to extra memory overhead - this @@ -374,45 +380,29 @@ impl<'a> Inner<'a> { /// /// [`calculations`]: https://docs.google.com/document/d/1vej6tYgNV3GWcldD4tl7a4Z9EeZwda3F5u7roPGArlU/ /// -/// ## Stream size -/// Note that it currently works for streams where size is known in advance. 
Mainly because -/// we want to set up send buffer sizes and avoid sending records one-by-one to each shard. -/// Other than that, there are no technical limitation here, and it could be possible to make it -/// work with regular streams if the batching problem is somehow addressed. -/// -/// -/// ```compile_fail -/// use futures::stream::{self, StreamExt}; -/// use ipa_core::protocol::context::reshard_stream; -/// use ipa_core::ff::boolean::Boolean; -/// use ipa_core::secret_sharing::SharedValue; -/// async { -/// let a = [Boolean::ZERO]; -/// let mut s = stream::iter(a.into_iter()).cycle(); -/// // this should fail to compile: -/// // the trait bound `futures::stream::Cycle<...>: ExactSizeStream` is not satisfied -/// reshard_stream(todo!(), s, todo!()).await; -/// }; -/// ``` /// /// ## Panics -/// When `shard_picker` returns an out-of-bounds index. +/// When `shard_picker` returns an out-of-bounds index or if the input stream size +/// upper bound is not known. The latter may be the case for infinite streams. /// /// ## Errors -/// If cross-shard communication fails +/// If cross-shard communication fails or if an input stream +/// yields an `Err` element. /// -pub async fn reshard_stream( +pub async fn reshard_try_stream( ctx: C, input: L, shard_picker: S, ) -> Result, crate::error::Error> where - L: ExactSizeStream, + L: Stream>, S: Fn(C, RecordId, &K) -> ShardIndex, K: Message + Clone, C: ShardedContext, { - let input_len = input.len(); + let (_, Some(input_len)) = input.size_hint() else { + panic!("input stream must have size upper bound for resharding to work") + }; // We set channels capacity to be at least 1 to be able to open send channels to all peers. // It is prohibited to create them if total records is not set. We also over-provision here @@ -438,15 +428,17 @@ where // Request data from all shards. 
let rcv_stream = ctx .recv_from_shards::() - .map(|(shard_id, v)| { - ( - shard_id, - v.map(Option::Some).map_err(crate::error::Error::from), - ) + .map(|(shard_id, v)| match v { + Ok(v) => Ok((shard_id, Some(v))), + Err(e) => Err(e), }) .fuse(); let input = pin!(input); + // Annoying consequence of not having async closures stable. async blocks + // cannot capture `Copy` values and there is no way to express that + // only some things need to be moved in Rust + let mut counter = 0_u32; // This produces a stream of outcomes of send requests. // In order to make it compatible with receive stream, it also returns records that must @@ -456,38 +448,36 @@ where // whole resharding process. // If send was successful, we set the argument to Ok(None). Only records assigned to this shard // by the `shard_picker` will have the value of Ok(Some(Value)) - let send_stream = futures::stream::unfold( + let send_stream = futures::stream::try_unfold( // it is crucial that the following execution is completed sequentially, in order for record id // tracking per shard to work correctly. If tasks complete out of order, this will cause share // misplacement on the recipient side. - ( - input - .enumerate() - .zip(stream::iter(iter::repeat(ctx.clone()))), - &mut send_channels, - ), - |(mut input, send_channels)| async { - // Process more data as it comes in, or close the sending channels, if there is nothing - // left. - if let Some(((i, val), ctx)) = input.next().await { - let dest_shard = shard_picker(ctx, RecordId::from(i), &val); - if dest_shard == my_shard { - Some(((my_shard, Ok(Some(val.clone()))), (input, send_channels))) + (input, &mut send_channels, &mut counter), + |(mut input, send_channels, i)| { + let ctx = ctx.clone(); + + async { + // Process more data as it comes in, or close the sending channels, if there is nothing + // left. + if let Some(val) = input.try_next().await? 
{ + let dest_shard = shard_picker(ctx, RecordId::from(*i), &val); + *i += 1; + if dest_shard == my_shard { + Ok(Some(((my_shard, Some(val)), (input, send_channels, i)))) + } else { + let (record_id, se) = send_channels.get_mut(&dest_shard).unwrap(); + se.send(*record_id, val) + .await + .map_err(crate::error::Error::from)?; + *record_id += 1; + Ok(Some(((my_shard, None), (input, send_channels, i)))) + } } else { - let (record_id, se) = send_channels.get_mut(&dest_shard).unwrap(); - let send_result = se - .send(*record_id, val) - .await - .map_err(crate::error::Error::from) - .map(|()| None); - *record_id += 1; - Some(((my_shard, send_result), (input, send_channels))) - } - } else { - for (last_record, send_channel) in send_channels.values() { - send_channel.close(*last_record).await; + for (last_record, send_channel) in send_channels.values() { + send_channel.close(*last_record).await; + } + Ok(None) } - None } }, ) @@ -519,8 +509,8 @@ where // This approach makes sure we do what we can - send or receive. let mut send_recv = pin!(futures::stream::select(send_stream, rcv_stream)); - while let Some((shard_id, v)) = send_recv.next().await { - if let Some(m) = v? { + while let Some((shard_id, v)) = send_recv.try_next().await? { + if let Some(m) = v { r[usize::from(shard_id)].push(m); } } @@ -528,12 +518,56 @@ where Ok(r.into_iter().flatten().collect()) } +/// Provides the same functionality as [`reshard_try_stream`] on +/// infallible streams. +/// +/// ## Stream size +/// Note that it currently works only for streams whose size is known in advance, mainly because +/// we want to set up send buffer sizes and avoid sending records one-by-one to each shard. +/// Other than that, there are no technical limitations here, and it could be possible to make it +/// work with regular streams or to opt for runtime checks, as [`reshard_try_stream`] does. 
+/// +/// +/// ```compile_fail +/// use futures::stream::{self, StreamExt}; +/// use ipa_core::protocol::context::reshard_stream; +/// use ipa_core::ff::boolean::Boolean; +/// use ipa_core::secret_sharing::SharedValue; +/// async { +/// let a = [Boolean::ZERO]; +/// let mut s = stream::iter(a.into_iter()).cycle(); +/// // this should fail to compile: +/// // the trait bound `futures::stream::Cycle<...>: ExactSizeStream` is not satisfied +/// reshard_stream(todo!(), s, todo!()).await; +/// }; +/// ``` +/// ## Panics +/// When `shard_picker` returns an out-of-bounds index. +/// +/// ## Errors +/// If cross-shard communication fails +pub async fn reshard_stream( + ctx: C, + input: L, + shard_picker: S, +) -> Result, crate::error::Error> +where + L: ExactSizeStream, + S: Fn(C, RecordId, &K) -> ShardIndex, + K: Message + Clone, + C: ShardedContext, +{ + reshard_try_stream(ctx, input.map(Ok), shard_picker).await +} + /// Same as [`reshard_stream`] but takes an iterator with the known size /// as input. /// -/// # Errors +/// ## Panics +/// When `shard_picker` returns an out-of-bounds index. 
/// -/// # Panics +/// ## Errors +/// If cross-shard communication fails pub async fn reshard_iter( ctx: C, input: L, @@ -567,12 +601,13 @@ pub trait DZKPContext: Context { async fn validate_record(&self, record_id: RecordId) -> Result<(), Error>; } -#[cfg(all(test, unit_test))] +#[cfg(test)] mod tests { - use std::{iter, iter::repeat}; + use std::{iter, iter::repeat, pin::Pin, task::Poll}; - use futures::{future::join_all, stream, stream::StreamExt, try_join}; + use futures::{future::join_all, ready, stream, stream::StreamExt, try_join, Stream}; use ipa_step::StepNarrow; + use pin_project::pin_project; use rand::{ distributions::{Distribution, Standard}, Rng, @@ -588,16 +623,20 @@ mod tests { protocol::{ basics::ShareKnownValue, context::{ - reshard_iter, reshard_stream, step::MaliciousProtocolStep::MaliciousProtocol, - upgrade::Upgradable, Context, ShardedContext, UpgradableContext, Validator, + reshard_iter, reshard_stream, reshard_try_stream, + step::MaliciousProtocolStep::MaliciousProtocol, upgrade::Upgradable, Context, + ShardedContext, UpgradableContext, Validator, }, prss::SharedRandomness, RecordId, }, - secret_sharing::replicated::{ - malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, - semi_honest::AdditiveShare as Replicated, - ReplicatedSecretSharing, + secret_sharing::{ + replicated::{ + malicious::{AdditiveShare as MaliciousReplicated, ExtendableField}, + semi_honest::AdditiveShare as Replicated, + ReplicatedSecretSharing, + }, + SharedValue, }, sharding::{ShardConfiguration, ShardIndex}, telemetry::metrics::{ @@ -917,6 +956,145 @@ mod tests { }); } + #[test] + fn reshard_try_stream_basic() { + run(|| async move { + const SHARDS: u32 = 5; + let input: Vec<_> = (0..SHARDS).map(BA8::truncate_from).collect(); + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let r = world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + reshard_try_stream(ctx, stream::iter(shard_input).map(Ok), 
|_, record_id, _| { + ShardIndex::from(u32::from(record_id) % SHARDS) + }) + .await + .unwrap() + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + fn reshard_try_stream_less_items_than_expected() { + /// This allows advertising a higher upper bound + /// than the actual number of elements in the stream. + /// Resharding should be able to tolerate that. + #[pin_project] + struct Wrapper { + #[pin] + inner: S, + expected_len: usize, + } + + impl Wrapper { + fn new(inner: S, expected_len: usize) -> Self { + assert!(expected_len > 0); + Self { + inner, + expected_len, + } + } + } + + impl Stream for Wrapper { + type Item = S::Item; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let this = self.project(); + let r = match ready!(this.inner.poll_next(cx)) { + Some(val) => { + *this.expected_len -= 1; + Poll::Ready(Some(val)) + } + None => Poll::Ready(None), + }; + + assert!( + *this.expected_len > 0, + "Stream should have less elements than expected" + ); + r + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.expected_len)) + } + } + + run(|| async move { + const SHARDS: u32 = 5; + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + let input: Vec<_> = (0..5 * SHARDS).map(BA8::truncate_from).collect(); + let r = world + .semi_honest(input.clone().into_iter(), |ctx, shard_input| async move { + reshard_try_stream( + ctx, + Wrapper::new(stream::iter(shard_input).map(Ok), 25), + |_, record_id, _| ShardIndex::from(u32::from(record_id) % SHARDS), + ) + .await + .unwrap() + }) + .await + .into_iter() + .flat_map(|v| v.reconstruct()) + .collect::>(); + + assert_eq!(input, r); + }); + } + + #[test] + #[should_panic(expected = "input stream must have size upper bound for resharding to work")] + fn reshard_try_stream_infinite() { + run(|| async move { + let world: TestWorld> = 
TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest(Vec::::new().into_iter(), |ctx, _| async move { + reshard_try_stream(ctx, stream::repeat(BA8::ZERO).map(Ok), |_, _, _| { + ShardIndex::FIRST + }) + .await + .unwrap() + }) + .await; + }); + } + + #[test] + fn reshard_try_stream_err() { + run(|| async move { + let world: TestWorld> = + TestWorld::with_shards(TestWorldConfig::default()); + world + .semi_honest(Vec::::new().into_iter(), |ctx, _| async move { + let err = reshard_try_stream( + ctx, + stream::iter(vec![ + Ok(BA8::ZERO), + Err(crate::error::Error::InconsistentShares), + ]), + |_, _, _| ShardIndex::FIRST, + ) + .await + .unwrap_err(); + assert!(matches!(err, crate::error::Error::InconsistentShares)); + }) + .await; + }); + } + #[test] fn prss_one_side() { run(|| async {