Skip to content

Commit

Permalink
task: add JoinSet::try_join_next (#6280)
Browse files Browse the repository at this point in the history
  • Loading branch information
maminrayej authored Jan 14, 2024
1 parent e4f9bcb commit 12ce924
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 1 deletion.
55 changes: 54 additions & 1 deletion tokio/src/task/join_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::task::{Context, Poll};
use crate::runtime::Handle;
#[cfg(tokio_unstable)]
use crate::task::Id;
use crate::task::{AbortHandle, JoinError, JoinHandle, LocalSet};
use crate::task::{unconstrained, AbortHandle, JoinError, JoinHandle, LocalSet};
use crate::util::IdleNotifiedSet;

/// A collection of tasks spawned on a Tokio runtime.
Expand Down Expand Up @@ -306,6 +306,59 @@ impl<T: 'static> JoinSet<T> {
crate::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
}

/// Tries to join one of the tasks in the set that has completed and return its output.
///
/// Returns `None` if the set is empty.
pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
// Loop over all notified `JoinHandle`s to find one that's ready, or until none are left.
loop {
let mut entry = self.inner.try_pop_notified()?;

let res = entry.with_value_and_context(|jh, ctx| {
// Since this function is not async and cannot be forced to yield, we should
// disable budgeting when we want to check for the `JoinHandle` readiness.
Pin::new(&mut unconstrained(jh)).poll(ctx)
});

if let Poll::Ready(res) = res {
let _entry = entry.remove();

return Some(res);
}
}
}

/// Tries to join one of the tasks in the set that has completed and return its output,
/// along with the [task ID] of the completed task.
///
/// Returns `None` if the set is empty.
///
/// When this method returns an error, then the id of the task that failed can be accessed
/// using the [`JoinError::id`] method.
///
/// [task ID]: crate::task::Id
/// [`JoinError::id`]: fn@crate::task::JoinError::id
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
// Loop over all notified `JoinHandle`s to find one that's ready, or until none are left.
loop {
let mut entry = self.inner.try_pop_notified()?;

let res = entry.with_value_and_context(|jh, ctx| {
// Since this function is not async and cannot be forced to yield, we should
// disable budgeting when we want to check for the `JoinHandle` readiness.
Pin::new(&mut unconstrained(jh)).poll(ctx)
});

if let Poll::Ready(res) = res {
let entry = entry.remove();

return Some(res.map(|output| (entry.id(), output)));
}
}
}

/// Aborts all tasks and waits for them to finish shutting down.
///
/// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
Expand Down
28 changes: 28 additions & 0 deletions tokio/src/util/idle_notified_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,34 @@ impl<T> IdleNotifiedSet<T> {
Some(EntryInOneOfTheLists { entry, set: self })
}

/// Tries to pop an entry from the notified list to poll it. The entry is moved to
/// the idle list atomically.
pub(crate) fn try_pop_notified(&mut self) -> Option<EntryInOneOfTheLists<'_, T>> {
// We don't decrement the length because this call moves the entry to
// the idle list rather than removing it.
if self.length == 0 {
// Fast path.
return None;
}

let mut lock = self.lists.lock();

// Pop the entry, returning None if empty.
let entry = lock.notified.pop_back()?;

lock.idle.push_front(entry.clone());

// Safety: We are holding the lock.
entry.my_list.with_mut(|ptr| unsafe {
*ptr = List::Idle;
});

drop(lock);

// Safety: We just put the entry in the idle list, so it is in one of the lists.
Some(EntryInOneOfTheLists { entry, set: self })
}

/// Call a function on every element in this list.
pub(crate) fn for_each<F: FnMut(&mut T)>(&mut self, mut func: F) {
fn get_ptrs<T>(list: &mut LinkedList<T>, ptrs: &mut Vec<*mut T>) {
Expand Down
77 changes: 77 additions & 0 deletions tokio/tests/task_join_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,80 @@ async fn join_set_coop() {
assert!(coop_count >= 1);
assert_eq!(count, TASK_NUM);
}

#[tokio::test(flavor = "current_thread")]
async fn try_join_next() {
const TASK_NUM: u32 = 1000;

let (send, recv) = tokio::sync::watch::channel(());

let mut set = JoinSet::new();

for _ in 0..TASK_NUM {
let mut recv = recv.clone();
set.spawn(async move { recv.changed().await.unwrap() });
}
drop(recv);

assert!(set.try_join_next().is_none());

send.send_replace(());
send.closed().await;

let mut count = 0;
loop {
match set.try_join_next() {
Some(Ok(())) => {
count += 1;
}
Some(Err(err)) => panic!("failed: {}", err),
None => {
break;
}
}
}

assert_eq!(count, TASK_NUM);
}

#[cfg(tokio_unstable)]
#[tokio::test(flavor = "current_thread")]
async fn try_join_next_with_id() {
const TASK_NUM: u32 = 1000;

let (send, recv) = tokio::sync::watch::channel(());

let mut set = JoinSet::new();
let mut spawned = std::collections::HashSet::with_capacity(TASK_NUM as usize);

for _ in 0..TASK_NUM {
let mut recv = recv.clone();
let handle = set.spawn(async move { recv.changed().await.unwrap() });

spawned.insert(handle.id());
}
drop(recv);

assert!(set.try_join_next_with_id().is_none());

send.send_replace(());
send.closed().await;

let mut count = 0;
let mut joined = std::collections::HashSet::with_capacity(TASK_NUM as usize);
loop {
match set.try_join_next_with_id() {
Some(Ok((id, ()))) => {
count += 1;
joined.insert(id);
}
Some(Err(err)) => panic!("failed: {}", err),
None => {
break;
}
}
}

assert_eq!(count, TASK_NUM);
assert_eq!(joined, spawned);
}

0 comments on commit 12ce924

Please sign in to comment.