From bc692442c78aa946222d906f61ae90029653a4de Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Sat, 14 Dec 2024 01:15:50 -0800 Subject: [PATCH] Fix the bug when sender was never dropped --- ipa-core/src/query/runner/hybrid.rs | 42 ++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 8e9fa85c6..0f57b2ef5 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -2,13 +2,15 @@ use std::{ convert::{Infallible, Into}, marker::PhantomData, ops::Add, + pin::Pin, sync::Arc, + task::{Context, Poll}, }; -use futures::{stream::iter, StreamExt, TryStreamExt}; -use futures_util::TryFutureExt; +use futures::{stream::iter, Stream, StreamExt, TryStreamExt}; +use futures_util::{stream, TryFutureExt}; use generic_array::ArrayLength; -use tokio_stream::wrappers::ReceiverStream; +use tokio::sync::mpsc::Receiver; use super::QueryResult; use crate::{ @@ -73,6 +75,23 @@ impl Query { } } +struct KnownSizeReceiverStream { + rx: Receiver, + sz: usize, +} + +impl Stream for KnownSizeReceiverStream { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } + + fn size_hint(&self) -> (usize, Option) { + (0, Some(self.sz)) + } +} + impl Query where C: UpgradableContext @@ -153,11 +172,18 @@ where let (tx, rx) = tokio::sync::mpsc::channel(ctx.active_work().get()); let (_, (decrypted_reports, resharded_tags)) = assert_send(futures::future::try_join( - seq_join(ctx.active_work(), stream) - .try_for_each(|(report, tag)| tx.send((report, tag)).map_err(|_| Error::Internal)), + { + let f = seq_join(ctx.active_work(), stream) + .zip(stream::repeat(tx)) + .map(|(r, tx)| r.map(|v| (v, tx))) + .try_for_each(|((report, tag), tx)| async move { + tx.send((report, tag)).map_err(|_| Error::Internal).await + }); + f + }, reshard_aad( ctx.narrow(&HybridStep::ReshardByTag), - ReceiverStream::new(rx).map(Ok), + KnownSizeReceiverStream { rx, sz }.map(Ok), |ctx, _, tag| tag.shard_picker(ctx.shard_count()), ), )) @@ -229,6 +255,7 @@ mod tests { use rand_core::SeedableRng; use crate::{ + executor::IpaRuntime, ff::{ boolean_array::{BA3, BA32, BA8}, U128Conversions, @@ -335,6 +362,7 @@ mod tests { HybridQuery::<_, BA32, KeyRegistry>::new( query_params, Arc::clone(&key_registry), + IpaRuntime::current(), ) .execute(ctx, query_size, input) }) @@ -418,6 +446,7 @@ mod tests { HybridQuery::<_, BA32, KeyRegistry>::new( query_params, Arc::clone(&key_registry), + IpaRuntime::current(), ) .execute(ctx, query_size, input) }) @@ -464,6 +493,7 @@ mod tests { HybridQuery::<_, BA32, KeyRegistry>::new( query_params, Arc::clone(&key_registry), + IpaRuntime::current(), ) .execute(ctx, query_size, input) })