From 12ce924fb9c1ffe0340b979fefa00d13ebf631c3 Mon Sep 17 00:00:00 2001 From: "M.Amin Rayej" Date: Sun, 14 Jan 2024 14:04:30 +0330 Subject: [PATCH] task: add `JoinSet::try_join_next` (#6280) --- tokio/src/task/join_set.rs | 55 ++++++++++++++++++++- tokio/src/util/idle_notified_set.rs | 28 +++++++++++ tokio/tests/task_join_set.rs | 77 +++++++++++++++++++++++++++++ 3 files changed, 159 insertions(+), 1 deletion(-) diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs index 4eb15a24d5f..7aace14d850 100644 --- a/tokio/src/task/join_set.rs +++ b/tokio/src/task/join_set.rs @@ -12,7 +12,7 @@ use std::task::{Context, Poll}; use crate::runtime::Handle; #[cfg(tokio_unstable)] use crate::task::Id; -use crate::task::{AbortHandle, JoinError, JoinHandle, LocalSet}; +use crate::task::{unconstrained, AbortHandle, JoinError, JoinHandle, LocalSet}; use crate::util::IdleNotifiedSet; /// A collection of tasks spawned on a Tokio runtime. @@ -306,6 +306,59 @@ impl JoinSet { crate::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await } + /// Tries to join one of the tasks in the set that has completed and return its output. + /// + /// Returns `None` if the set is empty. + pub fn try_join_next(&mut self) -> Option> { + // Loop over all notified `JoinHandle`s to find one that's ready, or until none are left. + loop { + let mut entry = self.inner.try_pop_notified()?; + + let res = entry.with_value_and_context(|jh, ctx| { + // Since this function is not async and cannot be forced to yield, we should + // disable budgeting when we want to check for the `JoinHandle` readiness. + Pin::new(&mut unconstrained(jh)).poll(ctx) + }); + + if let Poll::Ready(res) = res { + let _entry = entry.remove(); + + return Some(res); + } + } + } + + /// Tries to join one of the tasks in the set that has completed and return its output, + /// along with the [task ID] of the completed task. + /// + /// Returns `None` if the set is empty. + /// + /// When this method returns an error, then the id of the task that failed can be accessed + /// using the [`JoinError::id`] method. + /// + /// [task ID]: crate::task::Id + /// [`JoinError::id`]: fn@crate::task::JoinError::id + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn try_join_next_with_id(&mut self) -> Option> { + // Loop over all notified `JoinHandle`s to find one that's ready, or until none are left. + loop { + let mut entry = self.inner.try_pop_notified()?; + + let res = entry.with_value_and_context(|jh, ctx| { + // Since this function is not async and cannot be forced to yield, we should + // disable budgeting when we want to check for the `JoinHandle` readiness. + Pin::new(&mut unconstrained(jh)).poll(ctx) + }); + + if let Poll::Ready(res) = res { + let entry = entry.remove(); + + return Some(res.map(|output| (entry.id(), output))); + } + } + } + /// Aborts all tasks and waits for them to finish shutting down. /// /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in diff --git a/tokio/src/util/idle_notified_set.rs b/tokio/src/util/idle_notified_set.rs index 430f2e7568b..bd9c2ef1bbc 100644 --- a/tokio/src/util/idle_notified_set.rs +++ b/tokio/src/util/idle_notified_set.rs @@ -203,6 +203,34 @@ impl IdleNotifiedSet { Some(EntryInOneOfTheLists { entry, set: self }) } + /// Tries to pop an entry from the notified list to poll it. The entry is moved to + /// the idle list atomically. + pub(crate) fn try_pop_notified(&mut self) -> Option> { + // We don't decrement the length because this call moves the entry to + // the idle list rather than removing it. + if self.length == 0 { + // Fast path. + return None; + } + + let mut lock = self.lists.lock(); + + // Pop the entry, returning None if empty. + let entry = lock.notified.pop_back()?; + + lock.idle.push_front(entry.clone()); + + // Safety: We are holding the lock. + entry.my_list.with_mut(|ptr| unsafe { + *ptr = List::Idle; + }); + + drop(lock); + + // Safety: We just put the entry in the idle list, so it is in one of the lists. + Some(EntryInOneOfTheLists { entry, set: self }) + } + /// Call a function on every element in this list. pub(crate) fn for_each(&mut self, mut func: F) { fn get_ptrs(list: &mut LinkedList, ptrs: &mut Vec<*mut T>) { diff --git a/tokio/tests/task_join_set.rs b/tokio/tests/task_join_set.rs index bed9b7dad82..8a42be17b49 100644 --- a/tokio/tests/task_join_set.rs +++ b/tokio/tests/task_join_set.rs @@ -227,3 +227,80 @@ async fn join_set_coop() { assert!(coop_count >= 1); assert_eq!(count, TASK_NUM); } + +#[tokio::test(flavor = "current_thread")] +async fn try_join_next() { + const TASK_NUM: u32 = 1000; + + let (send, recv) = tokio::sync::watch::channel(()); + + let mut set = JoinSet::new(); + + for _ in 0..TASK_NUM { + let mut recv = recv.clone(); + set.spawn(async move { recv.changed().await.unwrap() }); + } + drop(recv); + + assert!(set.try_join_next().is_none()); + + send.send_replace(()); + send.closed().await; + + let mut count = 0; + loop { + match set.try_join_next() { + Some(Ok(())) => { + count += 1; + } + Some(Err(err)) => panic!("failed: {}", err), + None => { + break; + } + } + } + + assert_eq!(count, TASK_NUM); +} + +#[cfg(tokio_unstable)] +#[tokio::test(flavor = "current_thread")] +async fn try_join_next_with_id() { + const TASK_NUM: u32 = 1000; + + let (send, recv) = tokio::sync::watch::channel(()); + + let mut set = JoinSet::new(); + let mut spawned = std::collections::HashSet::with_capacity(TASK_NUM as usize); + + for _ in 0..TASK_NUM { + let mut recv = recv.clone(); + let handle = set.spawn(async move { recv.changed().await.unwrap() }); + + spawned.insert(handle.id()); + } + drop(recv); + + assert!(set.try_join_next_with_id().is_none()); + + send.send_replace(()); + send.closed().await; + + let mut count = 0; + let mut joined = std::collections::HashSet::with_capacity(TASK_NUM as usize); + loop { + match set.try_join_next_with_id() { + Some(Ok((id, ()))) => { + count += 1; + joined.insert(id); + } + Some(Err(err)) => panic!("failed: {}", err), + None => { + break; + } + } + } + + assert_eq!(count, TASK_NUM); + assert_eq!(joined, spawned); +}