Skip to content

Commit

Permalink
Streaming tag resharding and match key collection in Hybrid
Browse files Browse the repository at this point in the history
In private-attribution#1358 it was mentioned that waiting until all AAD tags and match keys have been collected before starting the resharding process adds unnecessary latency. We can start the resharding process as soon as we receive the first tag and do everything in parallel.

This change does that by leveraging the newly added `reshard_try_stream` and a few helper structs and functions that make it ergonomic to use.
  • Loading branch information
akoshelev committed Oct 20, 2024
1 parent 99dbe22 commit 659561d
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 34 deletions.
54 changes: 20 additions & 34 deletions ipa-core/src/query/runner/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,9 @@ use crate::{
BodyStream, LengthDelimitedStream,
},
hpke::PrivateKeyRegistry,
protocol::{
context::{reshard_iter, ShardedContext},
hybrid::step::HybridStep,
step::ProtocolStep::Hybrid,
},
report::hybrid::{EncryptedHybridReport, HybridReport, UniqueTag, UniqueTagValidator},
protocol::{context::ShardedContext, hybrid::step::HybridStep, step::ProtocolStep::Hybrid},
query::runner::reshard_tag::reshard_aad,
report::hybrid::{EncryptedHybridReport, UniqueTag, UniqueTagValidator},
secret_sharing::{replicated::semi_honest::AdditiveShare as ReplicatedShare, SharedValue},
};

Expand Down Expand Up @@ -61,35 +58,24 @@ where
));
}

let (_decrypted_reports, tags): (Vec<HybridReport<BA8, BA3>>, Vec<UniqueTag>) =
LengthDelimitedStream::<EncryptedHybridReport, _>::new(input_stream)
.map_err(Into::<Error>::into)
.map_ok(|enc_reports| {
iter(enc_reports.into_iter().map({
|enc_report| {
let dec_report = enc_report
.decrypt::<R, BA8, BA3, BA20>(key_registry.as_ref())
.map_err(Into::<Error>::into);
let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
dec_report.map(|dec_report1| (dec_report1, unique_tag))
}
}))
})
.try_flatten()
.take(sz)
.try_fold(
(Vec::with_capacity(sz), Vec::with_capacity(sz)),
|mut acc, result| async move {
acc.0.push(result.0);
acc.1.push(result.1);
Ok(acc)
},
)
.await?;

let resharded_tags = reshard_iter(
let stream = LengthDelimitedStream::<EncryptedHybridReport, _>::new(input_stream)
.map_err(Into::<Error>::into)
.map_ok(|enc_reports| {
iter(enc_reports.into_iter().map({
|enc_report| {
let dec_report = enc_report
.decrypt::<R, BA8, BA3, BA20>(key_registry.as_ref())
.map_err(Into::<Error>::into);
let unique_tag = UniqueTag::from_unique_bytes(&enc_report);
dec_report.map(|dec_report1| (dec_report1, unique_tag))
}
}))
})
.try_flatten()
.take(sz);
let (_decrypted_reports, resharded_tags) = reshard_aad(
ctx.narrow(&HybridStep::ReshardByTag),
tags,
stream,
|ctx, _, tag| tag.shard_picker(ctx.shard_count()),
)
.await?;
Expand Down
1 change: 1 addition & 0 deletions ipa-core/src/query/runner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod add_in_prime_field;
mod hybrid;
mod oprf_ipa;
mod reshard_tag;
#[cfg(any(test, feature = "cli", feature = "test-fixture"))]
mod test_multiply;

Expand Down
149 changes: 149 additions & 0 deletions ipa-core/src/query/runner/reshard_tag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use std::{
pin::{pin, Pin},
task::{Context, Poll},
};

use futures::{ready, Stream};
use pin_project::pin_project;

use crate::{
error::Error,
helpers::Message,
protocol::{
context::{reshard_try_stream, ShardedContext},
RecordId,
},
sharding::ShardIndex,
};

/// Item type of the fallible input stream accepted by `reshard_aad`: either a
/// data element `D` paired with its tag `A`, or an [`Error`].
type DataWithTag<D, A> = Result<(D, A), Error>;

/// Consumes the input of a hybrid query in a single pass, collecting the data
/// elements locally while resharding their AAD tags as they arrive.
///
/// Every encryption must be checked for uniqueness, and the AAD tag is what that
/// check relies on. Instead of waiting for all match keys to be collected before
/// the tags are resharded, both activities are driven by the same stream, which
/// should reduce the perceived latency of queries.
///
/// On success the result holds two separate collections: the data elements, in
/// the order the input yielded them, and the AAD tags that are "owned" by this
/// shard. The tags can later be checked for uniqueness.
///
/// ## Errors
/// Returns an error if the input stream contains at least one `Err` element.
#[allow(dead_code)]
pub async fn reshard_aad<L, K, A, C, S>(
    ctx: C,
    input: L,
    shard_picker: S,
) -> Result<(Vec<K>, Vec<A>), crate::error::Error>
where
    L: Stream<Item = DataWithTag<K, A>>,
    S: Fn(C, RecordId, &A) -> ShardIndex + Send,
    A: Message + Clone,
    C: ShardedContext,
{
    // Reserve space for the data using the upper bound of the stream's size
    // hint; when the stream does not report one, start with an empty buffer.
    let (_, upper_bound) = input.size_hint();
    let mut collected = Vec::with_capacity(upper_bound.unwrap_or(0));

    // The splitter diverts the data half of every pair into `collected` and
    // hands only the tags over to the resharding routine.
    let tags = reshard_try_stream(
        ctx,
        StreamSplitter {
            inner: input,
            buf: &mut collected,
        },
        shard_picker,
    )
    .await?;

    Ok((collected, tags))
}

/// Adapter over a fallible stream of `(K, A)` pairs that exposes a stream of `A`
/// only. The `K` half of every successful pair is appended to a caller-provided
/// buffer as a side effect of polling; errors from the input are passed through.
#[pin_project]
struct StreamSplitter<'a, S, K, A>
where
    S: Stream<Item = DataWithTag<K, A>>,
{
    /// Source stream of `(K, A)` pairs; structurally pinned so it can be polled.
    #[pin]
    inner: S,
    /// Destination buffer for the `K` elements taken out of the pairs.
    buf: &'a mut Vec<K>,
}

impl<S: Stream<Item = Result<(K, A), Error>>, K, A> Stream for StreamSplitter<'_, S, K, A> {
type Item = Result<A, crate::error::Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
match ready!(this.inner.poll_next(cx)) {
Some(Ok((k, a))) => {
this.buf.push(k);
Poll::Ready(Some(Ok(a)))
}
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}

#[cfg(test)]
mod tests {
    use futures::{stream, StreamExt};

    use crate::{
        error::Error,
        ff::{boolean_array::BA8, U128Conversions},
        query::runner::reshard_tag::reshard_aad,
        secret_sharing::SharedValue,
        sharding::{ShardConfiguration, ShardIndex},
        test_executor::run,
        test_fixture::{Runner, TestWorld, TestWorldConfig, WithShards},
    };

    /// Every input is tagged with the same value and every tag is routed to the
    /// first shard: each shard keeps all of its data, while only the first shard
    /// ends up holding tags.
    #[test]
    fn reshard_basic() {
        run(|| async {
            let world: TestWorld<WithShards<2>> =
                TestWorld::with_shards(TestWorldConfig::default());
            let inputs = vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)];

            world
                .semi_honest(inputs.into_iter(), |ctx, shares| async move {
                    let expected_values = shares.len();
                    let this_shard = ctx.shard_id();

                    let (values, tags) = reshard_aad(
                        ctx,
                        stream::iter(shares).map(|share| Ok((share, BA8::ZERO))),
                        |_, _, _| ShardIndex::FIRST,
                    )
                    .await
                    .unwrap();

                    assert_eq!(expected_values, values.len());
                    // All tags were sent to the first shard, the rest receive none.
                    let expected_tags = if this_shard == ShardIndex::FIRST {
                        2
                    } else {
                        0
                    };
                    assert_eq!(expected_tags, tags.len());
                })
                .await;
        });
    }

    /// An `Err` element in the input must surface as an error from
    /// `reshard_aad`; unwrapping it triggers the expected panic.
    #[test]
    #[should_panic(expected = "InconsistentShares")]
    fn reshard_err() {
        run(|| async {
            let world: TestWorld<WithShards<2>> =
                TestWorld::with_shards(TestWorldConfig::default());
            let inputs = vec![BA8::truncate_from(1u128), BA8::truncate_from(2u128)];

            world
                .semi_honest(inputs.into_iter(), |ctx, shares| async move {
                    let failing = stream::iter(shares)
                        .map(|_| Err::<(BA8, BA8), _>(Error::InconsistentShares));

                    reshard_aad(ctx, failing, |_, _, _| ShardIndex::FIRST)
                        .await
                        .unwrap();
                })
                .await;
        });
    }
}

0 comments on commit 659561d

Please sign in to comment.