Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshuawuyts committed Jun 23, 2023
1 parent 962a8a2 commit f762443
Showing 1 changed file with 116 additions and 105 deletions.
221 changes: 116 additions & 105 deletions src/future/try_join/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,21 @@ macro_rules! unsafe_poll {
.map_unchecked_mut(|t| t.deref_mut())
.poll(&mut $cx)
} {
$this.outputs.$fut_idx.write(value);
*$this.completed += 1;
$this.state[$fut_idx].set_ready();
*$this.completed += 1;
// SAFETY: the future state has been changed to "ready" which
// means we'll no longer poll the future, so it's safe to drop
unsafe { ManuallyDrop::drop($futures.$fut_name.as_mut().get_unchecked_mut()) };

// Check the value, short-circuit on error.
match value {
Ok(value) => $this.outputs.$fut_idx.write(value),
Err(err) => {
// The future should no longer be polled after we're done here
*$this.consumed = true;
return Poll::Ready(Err(err));
}
};
}
}
unsafe_poll!(@inner $iteration, $this, $futures, $cx, $($F)* | $($rest)*);
Expand Down Expand Up @@ -174,10 +183,11 @@ macro_rules! impl_try_join_tuple {
state: PollArray<{$mod_name::LEN}>,
wakers: WakerArray<{$mod_name::LEN}>,
completed: usize,
consumed: bool,
_phantom: PhantomData<Err>,
}

impl<$($F, $T)+, Err> Debug for $StructName<$($F, $T,)+ Err>
impl<$($F, $T,)+ Err> Debug for $StructName<$($F, $T,)+ Err>
where
$( $F: Future + Debug, )+
{
Expand All @@ -191,20 +201,19 @@ macro_rules! impl_try_join_tuple {
#[allow(unused_mut)]
#[allow(unused_parens)]
#[allow(unused_variables)]
impl<$($F: Future, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err>
impl<$($F, $T,)+ Err> Future for $StructName<$($F, $T,)+ Err>
where $(
$F: Future<Output = Result<$T, Err>>
$F: Future<Output = Result<$T, Err>>,
)+ {
type Output = Result<($($F::Output,)+), Err>;
type Output = Result<($($T,)+), Err>;

fn poll(
self: Pin<&mut Self>, cx: &mut Context<'_>
) -> Poll<Self::Output> {
const LEN: usize = $mod_name::LEN;

let mut this = self.project();
let all_completed = !(*this.completed == LEN);
assert!(all_completed, "Futures must not be polled after completing");
assert!(!*this.consumed, "Futures must not be polled after completing");

let mut futures = this.futures.project();

Expand Down Expand Up @@ -233,13 +242,14 @@ macro_rules! impl_try_join_tuple {

if *this.completed == LEN {
let out = {
let mut out = ($(MaybeUninit::<$F::Output>::uninit(),)+);
let mut out = ($(MaybeUninit::<$T>::uninit(),)+);
core::mem::swap(&mut out, this.outputs);
let ($($F,)+) = out;
unsafe { ($($F.assume_init(),)+) }
};

this.state.set_all_completed();
*this.consumed = true;

return Poll::Ready(Ok(out));
}
Expand All @@ -251,7 +261,7 @@ macro_rules! impl_try_join_tuple {
}

#[pinned_drop]
impl<$($F, $T)+, Err> PinnedDrop for $StructName<$($F, $T,)+ Err> {
impl<$($F, $T,)+ Err> PinnedDrop for $StructName<$($F, $T,)+ Err> {
fn drop(self: Pin<&mut Self>) {
let this = self.project();

Expand All @@ -265,7 +275,7 @@ macro_rules! impl_try_join_tuple {
}

#[allow(unused_parens)]
impl<$($F, $T)+, Err> TryJoinTrait for ($($F,)+)
impl<$($F, $T,)+ Err> TryJoinTrait for ($($F,)+)
where $(
$F: IntoFuture<Output = Result<$T, Err>>,
)+ {
Expand All @@ -283,6 +293,7 @@ macro_rules! impl_try_join_tuple {
outputs: ($(MaybeUninit::<$T>::uninit(),)+),
wakers: WakerArray::new(),
completed: 0,
consumed: false,
_phantom: PhantomData,
}
}
Expand All @@ -292,97 +303,97 @@ macro_rules! impl_try_join_tuple {

impl_try_join_tuple! { try_join0 TryJoin0 }
impl_try_join_tuple! { try_join_1 TryJoin1 (A ResA) }
// impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) }
// impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) }
// impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) }
// impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) }
// impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) }
// impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) }
// impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) }
// impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) }
// impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) }
// impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) }
// impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) }
//
// #[cfg(test)]
// mod test {
// use super::*;

// use std::convert::Infallible;
// use std::future;
// use std::io::{self, Error, ErrorKind};

// #[test]
// fn all_ok() {
// futures_lite::future::block_on(async {
// let a = async { Ok::<_, Infallible>("aaaa") };
// let b = async { Ok::<_, Infallible>(1) };
// let c = async { Ok::<_, Infallible>('z') };

// let result = (a, b, c).try_join().await;
// assert_eq!(result, Ok(("aaaa", 1, 'z')));
// })
// }

// #[test]
// fn one_err() {
// futures_lite::future::block_on(async {
// let err = Error::new(ErrorKind::Other, "oh no");
// let res: io::Result<(_, char)> = (future::ready(Ok("hello")), future::ready(Err(err)))
// .try_join()
// .await;
// assert_eq!(res.unwrap_err().to_string(), String::from("oh no"));
// })
// }

// #[test]
// fn issue_135_resume_after_completion() {
// use futures_lite::future::yield_now;
// futures_lite::future::block_on(async {
// let ok = async { Ok::<_, ()>(()) };
// let err = async {
// yield_now().await;
// Ok::<_, ()>(())
// };

// let res = (ok, err).try_join().await;

// assert_eq!(res.unwrap(), ((), ()));
// });
// }

// #[test]
// fn does_not_leak_memory() {
// use core::cell::RefCell;
// use futures_lite::future::pending;

// thread_local! {
// static NOT_LEAKING: RefCell<bool> = RefCell::new(false);
// };

// struct FlipFlagAtDrop;
// impl Drop for FlipFlagAtDrop {
// fn drop(&mut self) {
// NOT_LEAKING.with(|v| {
// *v.borrow_mut() = true;
// });
// }
// }

// futures_lite::future::block_on(async {
// // this will trigger Miri if we don't drop the memory
// let string = future::ready("memory leak".to_owned());

// // this will not flip the thread_local flag if we don't drop the memory
// let flip = future::ready(FlipFlagAtDrop);

// let leak = (string, flip, pending::<u8>()).join();

// _ = futures_lite::future::poll_once(leak).await;
// });

// NOT_LEAKING.with(|flag| {
// assert!(*flag.borrow());
// })
// }
// }
impl_try_join_tuple! { try_join_2 TryJoin2 (A ResA) (B ResB) }
impl_try_join_tuple! { try_join_3 TryJoin3 (A ResA) (B ResB) (C ResC) }
impl_try_join_tuple! { try_join_4 TryJoin4 (A ResA) (B ResB) (C ResC) (D ResD) }
impl_try_join_tuple! { try_join_5 TryJoin5 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) }
impl_try_join_tuple! { try_join_6 TryJoin6 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) }
impl_try_join_tuple! { try_join_7 TryJoin7 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) }
impl_try_join_tuple! { try_join_8 TryJoin8 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) }
impl_try_join_tuple! { try_join_9 TryJoin9 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) }
impl_try_join_tuple! { try_join_10 TryJoin10 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) }
impl_try_join_tuple! { try_join_11 TryJoin11 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) }
impl_try_join_tuple! { try_join_12 TryJoin12 (A ResA) (B ResB) (C ResC) (D ResD) (E ResE) (F ResF) (G ResG) (H ResH) (I ResI) (J ResJ) (K ResK) (L ResL) }

#[cfg(test)]
mod test {
use super::*;

use std::convert::Infallible;
use std::future;
use std::io::{self, Error, ErrorKind};

#[test]
fn all_ok() {
futures_lite::future::block_on(async {
let a = async { Ok::<_, Infallible>("aaaa") };
let b = async { Ok::<_, Infallible>(1) };
let c = async { Ok::<_, Infallible>('z') };

let result = (a, b, c).try_join().await;
assert_eq!(result, Ok(("aaaa", 1, 'z')));
})
}

#[test]
fn one_err() {
futures_lite::future::block_on(async {
let err = Error::new(ErrorKind::Other, "oh no");
let res: io::Result<(_, char)> = (future::ready(Ok("hello")), future::ready(Err(err)))
.try_join()
.await;
assert_eq!(res.unwrap_err().to_string(), String::from("oh no"));
})
}

#[test]
fn issue_135_resume_after_completion() {
use futures_lite::future::yield_now;
futures_lite::future::block_on(async {
let ok = async { Ok::<_, ()>(()) };
let err = async {
yield_now().await;
Ok::<_, ()>(())
};

let res = (ok, err).try_join().await;

assert_eq!(res.unwrap(), ((), ()));
});
}

// #[test]
// fn does_not_leak_memory() {
// use core::cell::RefCell;
// use futures_lite::future::pending;

// thread_local! {
// static NOT_LEAKING: RefCell<bool> = RefCell::new(false);
// };

// struct FlipFlagAtDrop;
// impl Drop for FlipFlagAtDrop {
// fn drop(&mut self) {
// NOT_LEAKING.with(|v| {
// *v.borrow_mut() = true;
// });
// }
// }

// futures_lite::future::block_on(async {
// // this will trigger Miri if we don't drop the memory
// let string = future::ready("memory leak".to_owned());

// // this will not flip the thread_local flag if we don't drop the memory
// let flip = future::ready(FlipFlagAtDrop);

// let leak = (string, flip, pending::<u8>()).try_join();

// _ = futures_lite::future::poll_once(leak).await;
// });

// NOT_LEAKING.with(|flag| {
// assert!(*flag.borrow());
// })
// }
}

0 comments on commit f762443

Please sign in to comment.