From 362dd561e95a9812c00f80bf1c36c8ba5ad44f80 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 8 Nov 2023 14:58:09 -0800 Subject: [PATCH] Support multithreading in `seq_join`/`parallel_join` Support is currently behind a feature flag that is not enabled by default. We use userspace concurrency to drive many futures in parallel by spawning tasks inside the executor. This model is not ideal for performance because memory loads will happen across thread boundaries and NUMA cores, but already gives 50% more throughput for the OPRF version and 200% for the old IPA. --- ipa-core/Cargo.toml | 3 + ipa-core/src/protocol/basics/reshare.rs | 2 +- .../modulus_conversion/convert_shares.rs | 32 +- .../protocol/sort/generate_permutation_opt.rs | 4 +- .../replicated/malicious/additive_share.rs | 2 +- ipa-core/src/secret_sharing/scheme.rs | 3 +- ipa-core/src/seq_join.rs | 592 +++++++++++++----- 7 files changed, 453 insertions(+), 185 deletions(-) diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 68867c196..7ee527ec0 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -61,6 +61,8 @@ step-trace = ["descriptive-gate"] # of unit tests use it. Compact uses memory-efficient gates and is suitable for production. descriptive-gate = [] compact-gate = ["ipa-macros/compact-gate"] +# Enable using more than one thread for protocol execution. Most of the parallelism occurs at parallel/seq_join operations +multi-threading = ["async-scoped"] # Standalone aggregation protocol. We use IPA infra for communication # but it has nothing to do with IPA. 
@@ -73,6 +75,7 @@ ipa-macros = { version = "*", path = "../ipa-macros" } aes = "0.8.3" async-trait = "0.1.68" +async-scoped = { version = "0.7.1", features = ["use-tokio"], optional = true } axum = { version = "0.5.17", optional = true, features = ["http2"] } axum-server = { version = "0.5.1", optional = true, features = [ "rustls", diff --git a/ipa-core/src/protocol/basics/reshare.rs b/ipa-core/src/protocol/basics/reshare.rs index 2e9a868e6..70b65c1c3 100644 --- a/ipa-core/src/protocol/basics/reshare.rs +++ b/ipa-core/src/protocol/basics/reshare.rs @@ -42,7 +42,7 @@ use crate::{ /// `to_helper` = (`rand_left`, `rand_right`) = (r0, r1) /// `to_helper.right` = (`rand_right`, part1 + part2) = (r0, part1 + part2) #[async_trait] -pub trait Reshare: Sized { +pub trait Reshare: Sized + 'static { async fn reshare<'fut>( &self, ctx: C, diff --git a/ipa-core/src/protocol/modulus_conversion/convert_shares.rs b/ipa-core/src/protocol/modulus_conversion/convert_shares.rs index 87df09abf..dae9ae8c1 100644 --- a/ipa-core/src/protocol/modulus_conversion/convert_shares.rs +++ b/ipa-core/src/protocol/modulus_conversion/convert_shares.rs @@ -292,17 +292,17 @@ where /// # Panics /// If the total record count on the context is unspecified. 
#[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_bits( +pub fn convert_bits<'a, F, V, C, S, VS>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, Error>> +) -> impl Stream, Error>> + 'a where F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { @@ -313,35 +313,37 @@ where /// Note that unconverted fields are not upgraded, so they might need to be upgraded either before or /// after invoking this function. #[tracing::instrument(name = "modulus_conversion", skip_all, fields(bits = ?bit_range, gate = %ctx.gate().as_ref()))] -pub fn convert_selected_bits( +pub fn convert_selected_bits<'inp, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, bit_range: Range, -) -> impl Stream, R), Error>> +) -> impl Stream, R), Error>> + 'inp where + R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'inp, + C: UpgradedContext + 'inp, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, + VS: Stream + Unpin + Send + 'inp, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { convert_some_bits(ctx, binary_shares, RecordId::FIRST, bit_range) } -pub(crate) fn convert_some_bits( +pub(crate) fn convert_some_bits<'a, F, V, C, S, VS, R>( ctx: C, binary_shares: VS, first_record: RecordId, bit_range: Range, -) -> impl Stream, R), Error>> +) -> impl Stream, R), Error>> + 'a where + R: Send + 'static, F: PrimeField, - V: ToBitConversionTriples, - C: UpgradedContext, + V: ToBitConversionTriples + 'a, + C: UpgradedContext + 'a, S: LinearSecretSharing + SecureMul, - VS: Stream + Unpin + Send, 
+ VS: Stream + Unpin + Send + 'a, for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, BitConversionTriple>, BitConversionTriple>, { diff --git a/ipa-core/src/protocol/sort/generate_permutation_opt.rs b/ipa-core/src/protocol/sort/generate_permutation_opt.rs index b82e05a80..22d2eed1b 100644 --- a/ipa-core/src/protocol/sort/generate_permutation_opt.rs +++ b/ipa-core/src/protocol/sort/generate_permutation_opt.rs @@ -298,7 +298,9 @@ mod tests { } /// Passing 32 records for Fp31 doesn't work. - #[tokio::test] + /// + /// Requires one extra thread to cancel futures running in parallel with the one that panics. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[should_panic = "prime field ipa_core::ff::prime_field::fp31::Fp31 is too small to sort 32 records"] async fn fp31_overflow() { const COUNT: usize = 32; diff --git a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs index c2fe440c4..d92a0d000 100644 --- a/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs +++ b/ipa-core/src/secret_sharing/replicated/malicious/additive_share.rs @@ -77,7 +77,7 @@ impl LinearSecretSharing for AdditiveShare< /// when the protocol is done. This should not be used directly. #[async_trait] pub trait Downgrade: Send { - type Target: Send; + type Target: Send + 'static; async fn downgrade(self) -> UnauthorizedDowngradeWrapper; } diff --git a/ipa-core/src/secret_sharing/scheme.rs b/ipa-core/src/secret_sharing/scheme.rs index 0d2131eeb..b20348539 100644 --- a/ipa-core/src/secret_sharing/scheme.rs +++ b/ipa-core/src/secret_sharing/scheme.rs @@ -7,7 +7,7 @@ use super::{SharedValue, WeakSharedValue}; use crate::ff::{AddSub, AddSubAssign, GaloisField}; /// Secret sharing scheme i.e. 
Replicated secret sharing -pub trait SecretSharing: Clone + Debug + Sized + Send + Sync { +pub trait SecretSharing: Clone + Debug + Sized + Send + Sync + 'static { const ZERO: Self; } @@ -21,6 +21,7 @@ pub trait Linear: + Mul + for<'r> Mul<&'r V, Output = Self> + Neg + + 'static { } diff --git a/ipa-core/src/seq_join.rs b/ipa-core/src/seq_join.rs index e3cb8c5b3..12cab4d67 100644 --- a/ipa-core/src/seq_join.rs +++ b/ipa-core/src/seq_join.rs @@ -1,5 +1,4 @@ use std::{ - collections::VecDeque, future::IntoFuture, num::NonZeroUsize, pin::Pin, @@ -39,15 +38,13 @@ pub fn assert_send<'a, O>( /// [`try_join_all`]: futures::future::try_join_all /// [`Stream`]: futures::stream::Stream /// [`StreamExt::buffered`]: futures::stream::StreamExt::buffered -pub fn seq_join(active: NonZeroUsize, source: S) -> SequentialFutures +pub fn seq_join<'st, S, F, O>(active: NonZeroUsize, source: S) -> SequentialFutures<'st, S, F> where - S: Stream + Send, - F: Future, + S: Stream + Send + 'st, + F: Future + Send, + O: Send + 'static, { - SequentialFutures { - source: source.fuse(), - active: VecDeque::with_capacity(active.get()), - } + SequentialFutures::new(active, source) } /// The `SeqJoin` trait wraps `seq_try_join_all`, providing the `active` parameter @@ -72,16 +69,37 @@ pub trait SeqJoin { /// [`active_work`]: Self::active_work /// [`parallel_join`]: Self::parallel_join /// [`join3`]: futures::future::join3 - fn try_join(&self, iterable: I) -> TryCollect, Vec> + fn try_join<'fut, I, F, O, E>( + &self, + iterable: I, + ) -> TryCollect, Vec> where I: IntoIterator + Send, - I::IntoIter: Send, - F: Future>, + I::IntoIter: Send + 'fut, + F: Future> + Send + 'fut, + O: Send + 'static, + E: Send + 'static, { seq_try_join_all(self.active_work(), iterable) } /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. 
+ #[cfg(feature = "multi-threading")] + fn parallel_join<'a, I, F, O, E>( + &self, + iterable: I, + ) -> Pin, E>> + Send + 'a>> + where + I: IntoIterator + Send, + F: Future> + Send + 'a, + O: Send + 'static, + E: Send + 'static, + { + multi_thread::parallel_join(iterable) + } + + /// Join multiple tasks in parallel. Only do this if you can't use a sequential join. + #[cfg(not(feature = "multi-threading"))] fn parallel_join(&self, iterable: I) -> futures::future::TryJoinAll where I: IntoIterator, @@ -95,209 +113,348 @@ pub trait SeqJoin { fn active_work(&self) -> NonZeroUsize; } -type SeqTryJoinAll = SequentialFutures::IntoIter>, F>; +type SeqTryJoinAll<'st, I, F> = + SequentialFutures<'st, StreamIter<::IntoIter>, F>; /// A substitute for [`futures::future::try_join_all`] that uses [`seq_join`]. /// This awaits all the provided futures in order, /// aborting early if any future returns `Result::Err`. -pub fn seq_try_join_all( +pub fn seq_try_join_all<'iter, I, F, O, E>( active: NonZeroUsize, source: I, -) -> TryCollect, Vec> +) -> TryCollect, Vec> where I: IntoIterator + Send, - I::IntoIter: Send, - F: Future>, + I::IntoIter: Send + 'iter, + F: Future> + Send + 'iter, + O: Send + 'static, + E: Send + 'static, { seq_join(active, iter(source)).try_collect() } -enum ActiveItem { - Pending(Pin>), - Resolved(F::Output), +impl<'fut, S, F> ExactSizeStream for SequentialFutures<'fut, S, F> +where + S: Stream + Send + ExactSizeStream, + F: IntoFuture, + ::IntoFuture: Send + 'fut, + <::IntoFuture as Future>::Output: Send + 'static, +{ } -impl ActiveItem { - /// Drives this item to resolved state when value is ready to be taken out. Has no effect - /// if the value is ready. 
- /// - /// ## Panics - /// Panics if this item is completed - fn check_ready(&mut self, cx: &mut Context<'_>) -> bool { - let ActiveItem::Pending(f) = self else { - return true; - }; - if let Poll::Ready(v) = Future::poll(Pin::as_mut(f), cx) { - *self = ActiveItem::Resolved(v); - true - } else { - false - } - } +#[cfg(feature = "multi-threading")] +pub type SequentialFutures<'fut, S, F> = multi_thread::SequentialFutures<'fut, S, F>; - /// Takes the resolved value out - /// - /// ## Panics - /// If the value is not ready yet. - #[must_use] - fn take(self) -> F::Output { - let ActiveItem::Resolved(v) = self else { - panic!("No value to take out"); - }; +#[cfg(not(feature = "multi-threading"))] +pub type SequentialFutures<'unused, S, F> = local::SequentialFutures<'unused, S, F>; + +/// Parallel and sequential join that use at most one thread. Good for unit testing and debugging, +/// to get results in predictable order with fewer things happening at the same time. +#[cfg(not(feature = "multi-threading"))] +mod local { + use std::{collections::VecDeque, marker::PhantomData}; - v + use super::*; + + enum ActiveItem { + Pending(Pin>), + Resolved(F::Output), } -} -#[pin_project] -pub struct SequentialFutures -where - S: Stream + Send, - F: IntoFuture, -{ - #[pin] - source: futures::stream::Fuse, - active: VecDeque>, -} + impl ActiveItem { + /// Drives this item to resolved state when value is ready to be taken out. Has no effect + /// if the value is ready. + /// + /// ## Panics + /// Panics if this item is completed + fn check_ready(&mut self, cx: &mut Context<'_>) -> bool { + let ActiveItem::Pending(f) = self else { + return true; + }; + if let Poll::Ready(v) = Future::poll(Pin::as_mut(f), cx) { + *self = ActiveItem::Resolved(v); + true + } else { + false + } + } -impl Stream for SequentialFutures -where - S: Stream + Send, - F: IntoFuture, -{ - type Item = F::Output; + /// Takes the resolved value out + /// + /// ## Panics + /// If the value is not ready yet. 
+ #[must_use] + fn take(self) -> F::Output { + let ActiveItem::Resolved(v) = self else { + panic!("No value to take out"); + }; + + v + } + } - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); + #[pin_project] + pub struct SequentialFutures<'unused, S, F> + where + S: Stream + Send, + F: IntoFuture, + { + #[pin] + source: futures::stream::Fuse, + active: VecDeque>, + _marker: PhantomData &'unused ()>, + } - // Draw more values from the input, up to the capacity. - while this.active.len() < this.active.capacity() { - if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { - this.active - .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); - } else { - break; + impl SequentialFutures<'_, S, F> + where + S: Stream + Send, + F: IntoFuture, + { + pub fn new(active: NonZeroUsize, source: S) -> Self { + Self { + source: source.fuse(), + active: VecDeque::with_capacity(active.get()), + _marker: PhantomData, } } + } - if let Some(item) = this.active.front_mut() { - if item.check_ready(cx) { - let v = this.active.pop_front().map(ActiveItem::take); - Poll::Ready(v) - } else { - for f in this.active.iter_mut().skip(1) { - f.check_ready(cx); + impl Stream for SequentialFutures<'_, S, F> + where + S: Stream + Send, + F: IntoFuture, + { + type Item = F::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + // Draw more values from the input, up to the capacity. 
+ while this.active.len() < this.active.capacity() { + if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { + this.active + .push_back(ActiveItem::Pending(Box::pin(f.into_future()))); + } else { + break; + } + } + + if let Some(item) = this.active.front_mut() { + if item.check_ready(cx) { + let v = this.active.pop_front().map(ActiveItem::take); + Poll::Ready(v) + } else { + for f in this.active.iter_mut().skip(1) { + f.check_ready(cx); + } + Poll::Pending } + } else if this.source.is_done() { + Poll::Ready(None) + } else { Poll::Pending } - } else if this.source.is_done() { - Poll::Ready(None) - } else { - Poll::Pending } - } - fn size_hint(&self) -> (usize, Option) { - let in_progress = self.active.len(); - let (lower, upper) = self.source.size_hint(); - ( - lower.saturating_add(in_progress), - upper.and_then(|u| u.checked_add(in_progress)), - ) + fn size_hint(&self) -> (usize, Option) { + let in_progress = self.active.len(); + let (lower, upper) = self.source.size_hint(); + ( + lower.saturating_add(in_progress), + upper.and_then(|u| u.checked_add(in_progress)), + ) + } } } -impl ExactSizeStream for SequentialFutures -where - S: Stream + Send + ExactSizeStream, - F: IntoFuture, -{ -} +/// Both joins use executor tasks to drive futures to completion. Much faster than single-threaded +/// version, so this is what we want to use in release/prod mode. 
+#[cfg(feature = "multi-threading")] +mod multi_thread { + use futures::future::BoxFuture; + use tracing::{Instrument, Span}; -#[cfg(all(test, unit_test))] -mod test { - use std::{ - convert::Infallible, - iter::once, - num::NonZeroUsize, - ptr::null, - sync::{Arc, Mutex}, - task::{Context, Poll, Waker}, - }; + use super::*; - use futures::{ - future::{lazy, BoxFuture}, - stream::{iter, poll_fn, poll_immediate, repeat_with}, - Future, StreamExt, - }; + #[cfg(feature = "shuttle")] + mod shuttle_spawner { + use shuttle_crate::{ + future, + future::{JoinError, JoinHandle}, + }; - use crate::seq_join::{seq_join, seq_try_join_all}; + use super::*; - async fn immediate(count: u32) { - let capacity = NonZeroUsize::new(3).unwrap(); - let values = seq_join(capacity, iter((0..count).map(|i| async move { i }))) - .collect::>() - .await; - assert_eq!((0..count).collect::>(), values); - } + /// Spawner implementation for Shuttle framework to run tests in parallel + pub(super) struct ShuttleSpawner; - #[tokio::test] - async fn within_capacity() { - immediate(2).await; - immediate(1).await; + unsafe impl async_scoped::spawner::Spawner for ShuttleSpawner + where + T: Send + 'static, + { + type FutureOutput = Result; + type SpawnHandle = JoinHandle; + + fn spawn + Send + 'static>(&self, f: F) -> Self::SpawnHandle { + future::spawn(f) + } + } + + unsafe impl async_scoped::spawner::Blocker for ShuttleSpawner { + fn block_on>(&self, f: F) -> T { + future::block_on(f) + } + } } - #[tokio::test] - async fn over_capacity() { - immediate(10).await; + #[cfg(feature = "shuttle")] + type Spawner<'fut, T> = async_scoped::Scope<'fut, T, shuttle_spawner::ShuttleSpawner>; + #[cfg(not(feature = "shuttle"))] + type Spawner<'fut, T> = TokioScope<'fut, T>; + + unsafe fn create_spawner<'fut, T: Send + 'static>() -> Spawner<'fut, T> { + #[cfg(feature = "shuttle")] + return async_scoped::Scope::create(shuttle_spawner::ShuttleSpawner); + #[cfg(not(feature = "shuttle"))] + return 
TokioScope::create(Tokio); } - #[tokio::test] - async fn out_of_order() { - let capacity = NonZeroUsize::new(3).unwrap(); - let barrier = tokio::sync::Barrier::new(2); - let unresolved: BoxFuture<'_, u32> = Box::pin(async { - barrier.wait().await; - 0 - }); - let it = once(unresolved) - .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); - let mut seq_futures = seq_join(capacity, iter(it)); + #[pin_project] + #[must_use = "Futures do nothing, unless polled"] + pub struct SequentialFutures<'fut, S, F> + where + S: Stream + Send + 'fut, + F: IntoFuture, + <::IntoFuture as Future>::Output: Send + 'static, + { + #[pin] + spawner: Spawner<'fut, F::Output>, + #[pin] + source: futures::stream::Fuse, + capacity: usize, + } - assert_eq!( - Some(Poll::Pending), - poll_immediate(&mut seq_futures).next().await - ); - barrier.wait().await; - assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); + impl SequentialFutures<'_, S, F> + where + S: Stream + Send, + F: IntoFuture, + <::IntoFuture as Future>::Output: Send + 'static, + { + pub fn new(active: NonZeroUsize, source: S) -> Self { + SequentialFutures { + spawner: unsafe { create_spawner() }, + source: source.fuse(), + capacity: active.get(), + } + } } - #[tokio::test] - async fn join_success() { - fn f(v: T) -> impl Future> { - lazy(move |_| Ok(v)) + impl<'fut, S, F> Stream for SequentialFutures<'fut, S, F> + where + S: Stream + Send, + F: IntoFuture, + ::IntoFuture: Send + 'fut, + <::IntoFuture as Future>::Output: Send + 'static, + { + type Item = F::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + // Draw more values from the input, up to the capacity. + while this.spawner.remaining() < *this.capacity { + if let Poll::Ready(Some(f)) = this.source.as_mut().poll_next(cx) { + // Making futures cancellable is critical to avoid hangs. 
+ // if one of them panics, unwinding causes spawner to drop and, in turn, + // it blocks the thread to await all pending futures completion. If there is + // a dependency between futures, pending one will never complete. + // Cancellable futures will be cancelled when spawner is dropped which is + // the behavior we want. + this.spawner + .spawn_cancellable(f.into_future().instrument(Span::current()), || { + panic!("cancelled") + }); + } else { + break; + } + } + + // Poll spawner if it has work to do. If both source and spawner are empty, we're done + if this.spawner.remaining() > 0 { + this.spawner.as_mut().poll_next(cx).map(|v| match v { + Some(Ok(v)) => Some(v), + Some(Err(_)) => panic!("task is cancelled"), + None => None, + }) + } else if this.source.is_done() { + Poll::Ready(None) + } else { + Poll::Pending + } } - let active = NonZeroUsize::new(10).unwrap(); - let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); - assert_eq!((1..5).collect::>(), res); + fn size_hint(&self) -> (usize, Option) { + let in_progress = self.spawner.remaining(); + let (lower, upper) = self.source.size_hint(); + ( + lower.saturating_add(in_progress), + upper.and_then(|u| u.checked_add(in_progress)), + ) + } } - #[tokio::test] - async fn try_join_early_abort() { - const ERROR: &str = "error message"; - fn f(i: u32) -> impl Future> { - lazy(move |_| match i { - 1 => Ok(1), - 2 => Err(ERROR), - _ => panic!("should have aborted earlier"), - }) - } + /// TODO: change it to impl Future once https://github.com/rust-lang/rust/pull/115822 is + /// available in stable Rust. + pub(super) fn parallel_join<'fut, I, F, O, E>(iterable: I) -> BoxFuture<'fut, Result, E>> + where + I: IntoIterator + Send, + F: Future> + Send + 'fut, + O: Send + 'static, + E: Send + 'static, + { + // TODO: implement spawner for shuttle + let mut scope = { + let iter = iterable.into_iter(); + // SAFETY: scope object does not escape this function. 
All futures are driven to + // completion inside it or cancelled if a panic occurs. + let mut scope = unsafe { create_spawner() }; + for element in iter { + // it is important to make those cancellable. + // TODO: elaborate why + scope.spawn_cancellable(element.instrument(Span::current()), || { + panic!("Future is cancelled.") + }); + } + scope + }; - let active = NonZeroUsize::new(10).unwrap(); - let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); - assert_eq!(err, ERROR); + Box::pin(async move { + let mut result = Vec::with_capacity(scope.len()); + while let Some(item) = scope.next().await { + // join error is nothing we can do about + result.push(item.unwrap()?) + } + Ok(result) + }) } +} + +#[cfg(all(test, unit_test, not(feature = "multi-threading")))] +mod local_test { + use std::{ + num::NonZeroUsize, + ptr::null, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, + }; + + use futures::{ + future::lazy, + stream::{poll_fn, repeat_with}, + StreamExt, + }; + + use super::*; fn fake_waker() -> Waker { use std::task::{RawWaker, RawWakerVTable}; @@ -365,8 +522,8 @@ mod test { } /// A fully synchronous test with a synthetic stream, all the way to the end. 
- #[test] - fn complete_stream() { + #[tokio::test] + async fn complete_stream() { const VALUE: u32 = 20; const COUNT: usize = 7; let capacity = NonZeroUsize::new(3).unwrap(); @@ -408,3 +565,106 @@ mod test { assert!(matches!(res, Poll::Ready(None))); } } + +#[cfg(all(test, unit_test))] +mod test { + use std::{convert::Infallible, iter::once}; + + use futures::{ + future::{lazy, BoxFuture}, + stream::{iter, poll_immediate}, + Future, StreamExt, + }; + + use super::*; + + async fn immediate(count: u32) { + let capacity = NonZeroUsize::new(3).unwrap(); + let values = seq_join(capacity, iter((0..count).map(|i| async move { i }))) + .collect::>() + .await; + assert_eq!((0..count).collect::>(), values); + } + + #[tokio::test] + async fn within_capacity() { + immediate(2).await; + immediate(1).await; + } + + #[tokio::test] + async fn over_capacity() { + immediate(10).await; + } + + #[tokio::test] + async fn out_of_order() { + let capacity = NonZeroUsize::new(3).unwrap(); + let barrier = tokio::sync::Barrier::new(2); + let unresolved: BoxFuture<'_, u32> = Box::pin(async { + barrier.wait().await; + 0 + }); + let it = once(unresolved) + .chain((1..4_u32).map(|i| -> BoxFuture<'_, u32> { Box::pin(async move { i }) })); + let mut seq_futures = seq_join(capacity, iter(it)); + + assert_eq!( + Some(Poll::Pending), + poll_immediate(&mut seq_futures).next().await + ); + barrier.wait().await; + assert_eq!(vec![0, 1, 2, 3], seq_futures.collect::>().await); + } + + #[tokio::test] + async fn join_success() { + fn f(v: T) -> impl Future> { + lazy(move |_| Ok(v)) + } + + let active = NonZeroUsize::new(10).unwrap(); + let res = seq_try_join_all(active, (1..5).map(f)).await.unwrap(); + assert_eq!((1..5).collect::>(), res); + } + + /// This test has to use multi-threaded runtime because early return causes `TryCollect` to be + /// dropped and the remaining futures to be cancelled which can only happen if there is more + /// than one thread available. 
+ /// + /// This behavior is only applicable when `seq_try_join_all` uses more than one thread; for + /// maintenance reasons, we use it even when parallelism is turned off. + #[tokio::test(flavor = "multi_thread")] + async fn try_join_early_abort() { + const ERROR: &str = "error message"; + fn f(i: u32) -> impl Future> { + lazy(move |_| match i { + 1 => Ok(1), + 2 => Err(ERROR), + _ => panic!("should have aborted earlier"), + }) + } + + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + } + + #[tokio::test(flavor = "multi_thread")] + async fn does_not_block_on_error() { + const ERROR: &str = "returning early is safe"; + use std::pin::Pin; + + fn f(i: u32) -> Pin> + Send>> { + match i { + 1 => Box::pin(lazy(move |_| Ok(1))), + 2 => Box::pin(lazy(move |_| Err(ERROR))), + _ => Box::pin(futures::future::pending()), + } + } + + let active = NonZeroUsize::new(10).unwrap(); + let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err(); + assert_eq!(err, ERROR); + } +}