From 659561d4434c3ff7982befe652cdbc03e6c53054 Mon Sep 17 00:00:00 2001
From: Alex Koshelev
Date: Sun, 20 Oct 2024 14:18:54 -0700
Subject: [PATCH] Streaming tag resharding and match key collection in Hybrid

In #1358 it was noted that waiting until all AAD tags and match keys have
been collected before starting the resharding process adds unnecessary
latency. We can start resharding as soon as the first tag is received and
do everything in parallel.

This change does that by leveraging the newly added `reshard_try_stream`,
along with a few helper structs and functions that make it ergonomic to use.
---
 ipa-core/src/query/runner/hybrid.rs      |  54 +++-----
 ipa-core/src/query/runner/mod.rs         |   1 +
 ipa-core/src/query/runner/reshard_tag.rs | 149 +++++++++++++++++++++++
 3 files changed, 170 insertions(+), 34 deletions(-)
 create mode 100644 ipa-core/src/query/runner/reshard_tag.rs

diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs
index cc3b861c7..7154c066b 100644
--- a/ipa-core/src/query/runner/hybrid.rs
+++ b/ipa-core/src/query/runner/hybrid.rs
@@ -10,12 +10,9 @@ use crate::{
         BodyStream, LengthDelimitedStream,
     },
     hpke::PrivateKeyRegistry,
-    protocol::{
-        context::{reshard_iter, ShardedContext},
-        hybrid::step::HybridStep,
-        step::ProtocolStep::Hybrid,
-    },
-    report::hybrid::{EncryptedHybridReport, HybridReport, UniqueTag, UniqueTagValidator},
+    protocol::{context::ShardedContext, hybrid::step::HybridStep, step::ProtocolStep::Hybrid},
+    query::runner::reshard_tag::reshard_aad,
+    report::hybrid::{EncryptedHybridReport, UniqueTag, UniqueTagValidator},
     secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue},
 };
 
@@ -61,35 +58,24 @@ where
             ));
         }
 
-        let (_decrypted_reports, tags): (Vec<HybridReport<BA8, BA3>>, Vec<UniqueTag>) =
-            LengthDelimitedStream::<EncryptedHybridReport, _>::new(input_stream)
-                .map_err(Into::<Error>::into)
-                .map_ok(|enc_reports| {
-                    iter(enc_reports.into_iter().map({
-                        |enc_report| {
-                            let dec_report = enc_report
-                                .decrypt::<R>(key_registry.as_ref())
-                                .map_err(Into::<Error>::into);
-                            let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
-                            dec_report.map(|dec_report1| (dec_report1, unique_tag))
-                        }
-                    }))
-                })
-                .try_flatten()
-                .take(sz)
-                .try_fold(
-                    (Vec::with_capacity(sz), Vec::with_capacity(sz)),
-                    |mut acc, result| async move {
-                        acc.0.push(result.0);
-                        acc.1.push(result.1);
-                        Ok(acc)
-                    },
-                )
-                .await?;
-
-        let resharded_tags = reshard_iter(
+        let stream = LengthDelimitedStream::<EncryptedHybridReport, _>::new(input_stream)
+            .map_err(Into::<Error>::into)
+            .map_ok(|enc_reports| {
+                iter(enc_reports.into_iter().map({
+                    |enc_report| {
+                        let dec_report = enc_report
+                            .decrypt::<R>(key_registry.as_ref())
+                            .map_err(Into::<Error>::into);
+                        let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
+                        dec_report.map(|dec_report1| (dec_report1, unique_tag))
+                    }
+                }))
+            })
+            .try_flatten()
+            .take(sz);
+        let (_decrypted_reports, resharded_tags) = reshard_aad(
             ctx.narrow(&HybridStep::ReshardByTag),
-            tags,
+            stream,
             |ctx, _, tag| tag.shard_picker(ctx.shard_count()),
         )
         .await?;
diff --git a/ipa-core/src/query/runner/mod.rs b/ipa-core/src/query/runner/mod.rs
index 9bd739db9..3f1b59f55 100644
--- a/ipa-core/src/query/runner/mod.rs
+++ b/ipa-core/src/query/runner/mod.rs
@@ -2,6 +2,7 @@ mod add_in_prime_field;
 mod hybrid;
 mod oprf_ipa;
+mod reshard_tag;
 #[cfg(any(test, feature = "cli", feature = "test-fixture"))]
 mod test_multiply;
diff --git a/ipa-core/src/query/runner/reshard_tag.rs b/ipa-core/src/query/runner/reshard_tag.rs
new file mode 100644
index 000000000..5ef7b6311
--- /dev/null
+++ b/ipa-core/src/query/runner/reshard_tag.rs
@@ -0,0 +1,149 @@
+use std::{
+    pin::{pin, Pin},
+    task::{Context, Poll},
+};
+
+use futures::{ready, Stream};
+use pin_project::pin_project;
+
+use crate::{
+    error::Error,
+    helpers::Message,
+    protocol::{
+        context::{reshard_try_stream, ShardedContext},
+        RecordId,
+    },
+    sharding::ShardIndex,
+};
+
+type DataWithTag<D, A> = Result<(D, A), Error>;
+
+/// Helper function to work with inputs to hybrid queries. Each encryption needs
+/// to be checked for uniqueness, and we use the AAD tag for that. While match keys
+/// are being collected, AAD tags need to be resharded. This function does both at
+/// the same time, which should reduce the perceived latency of queries.
+///
+/// The output contains two separate collections: one for data and another one
+/// for AAD tags that are "owned" by this shard. The tags can later be checked for
+/// uniqueness.
+///
+/// ## Errors
+/// This will return an error if the input stream contains at least one `Err` element.
+#[allow(dead_code)]
+pub async fn reshard_aad<L, K, A, C, S>(
+    ctx: C,
+    input: L,
+    shard_picker: S,
+) -> Result<(Vec<K>, Vec<A>), crate::error::Error>
+where
+    L: Stream<Item = DataWithTag<K, A>>,
+    S: Fn(C, RecordId, &A) -> ShardIndex + Send,
+    A: Message + Clone,
+    C: ShardedContext,
+{
+    let mut k_buf = Vec::with_capacity(input.size_hint().1.unwrap_or(0));
+    let splitter = StreamSplitter {
+        inner: input,
+        buf: &mut k_buf,
+    };
+    let a_buf = reshard_try_stream(ctx, splitter, shard_picker).await?;
+
+    Ok((k_buf, a_buf))
+}
+
+/// Takes a fallible input stream that yields a tuple `(K, A)` and produces a new stream
+/// over `A` while collecting `K` elements into the provided buffer.
+/// Any error encountered from the input stream is propagated.
+#[pin_project]
+struct StreamSplitter<'a, S: Stream<Item = Result<(K, A), Error>>, K, A> {
+    #[pin]
+    inner: S,
+    buf: &'a mut Vec<K>,
+}
+
+impl<S: Stream<Item = Result<(K, A), Error>>, K, A> Stream for StreamSplitter<'_, S, K, A> {
+    type Item = Result<A, Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+        match ready!(this.inner.poll_next(cx)) {
+            Some(Ok((k, a))) => {
+                this.buf.push(k);
+                Poll::Ready(Some(Ok(a)))
+            }
+            Some(Err(e)) => Poll::Ready(Some(Err(e))),
+            None => Poll::Ready(None),
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.inner.size_hint()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use futures::{stream, StreamExt};
+
+    use crate::{
+        error::Error,
+        ff::{boolean_array::BA8, U128Conversions},
+        query::runner::reshard_tag::reshard_aad,
+        secret_sharing::SharedValue,
+        sharding::{ShardConfiguration, ShardIndex},
+        test_executor::run,
+        test_fixture::{Runner, TestWorld, TestWorldConfig, WithShards},
+    };
+
+    #[test]
+    fn reshard_basic() {
+        run(|| async {
+            let world: TestWorld<WithShards<2>> =
+                TestWorld::with_shards(TestWorldConfig::default());
+            world
+                .semi_honest(
+                    vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(),
+                    |ctx, input| async move {
+                        let shard_id = ctx.shard_id();
+                        let sz = input.len();
+                        let (values, tags) = reshard_aad(
+                            ctx,
+                            stream::iter(input).map(|v| Ok((v, BA8::ZERO))),
+                            |_, _, _| ShardIndex::FIRST,
+                        )
+                        .await
+                        .unwrap();
+                        assert_eq!(sz, values.len());
+                        match shard_id {
+                            ShardIndex::FIRST => assert_eq!(2, tags.len()),
+                            _ => assert_eq!(0, tags.len()),
+                        }
+                    },
+                )
+                .await;
+        });
+    }
+
+    #[test]
+    #[should_panic(expected = "InconsistentShares")]
+    fn reshard_err() {
+        run(|| async {
+            let world: TestWorld<WithShards<2>> =
+                TestWorld::with_shards(TestWorldConfig::default());
+            world
+                .semi_honest(
+                    vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)].into_iter(),
+                    |ctx, input| async move {
+                        reshard_aad(
+                            ctx,
+                            stream::iter(input)
+                                .map(|_| Err::<(BA8, BA8), _>(Error::InconsistentShares)),
+                            |_, _, _| ShardIndex::FIRST,
+                        )
+                        .await
+                        .unwrap();
+                    },
+                )
+                .await;
+        });
+    }
+}
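
For illustration only (not part of the patch): the change hinges on splitting a fallible
stream of (match key, AAD tag) pairs so that tag resharding can begin as soon as the first
pair arrives, instead of after everything is buffered. The standalone Rust sketch below
shows that idea using nothing but the `futures` crate; the name `split_pairs` and the
callback that stands in for `reshard_try_stream` are assumptions made for this sketch,
not APIs from this repository.

    use futures::{stream, Stream, TryStreamExt};

    /// Illustrative sketch: buffer the first element of each pair while handing the
    /// second to `consume` as soon as it arrives, so downstream work can start before
    /// the whole input has been read. Stops at the first `Err`, like `reshard_aad`.
    async fn split_pairs<K, A, E>(
        input: impl Stream<Item = Result<(K, A), E>>,
        mut consume: impl FnMut(A),
    ) -> Result<Vec<K>, E> {
        let mut keys = Vec::new();
        input
            .try_for_each(|(k, a)| {
                keys.push(k); // collected locally, like `k_buf` in `reshard_aad`
                consume(a);   // forwarded immediately, standing in for `reshard_try_stream`
                futures::future::ready(Ok(()))
            })
            .await?;
        Ok(keys)
    }

    fn main() {
        futures::executor::block_on(async {
            let input = stream::iter((0u32..4).map(|i| Ok::<_, ()>((i, i * 10))));
            let keys = split_pairs(input, |a| println!("forwarded {a}"))
                .await
                .expect("no errors in this example");
            assert_eq!(keys, vec![0, 1, 2, 3]);
        });
    }

The patch expresses the same split as a `Stream` adapter (`StreamSplitter`) rather than a
callback, so the tag half can be fed directly to `reshard_try_stream` while the match keys
accumulate in the caller-provided buffer.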