diff --git a/tokio/tests/rt_poll_callbacks.rs b/tokio/tests/rt_poll_callbacks.rs index 14f2a249571..aecf079e513 100644 --- a/tokio/tests/rt_poll_callbacks.rs +++ b/tokio/tests/rt_poll_callbacks.rs @@ -2,7 +2,7 @@ #[cfg(tokio_unstable)] mod unstable { - use std::sync::{atomic::AtomicUsize, Arc}; + use std::sync::{atomic::AtomicUsize, Arc, Mutex}; use tokio::task::yield_now; @@ -12,12 +12,22 @@ mod unstable { let poll_stop_counter = Arc::new(AtomicUsize::new(0)); let poll_start = poll_start_counter.clone(); let poll_stop = poll_stop_counter.clone(); + + let before_task_poll_callback_task_id: Arc>> = + Arc::new(Mutex::new(None)); + let after_task_poll_callback_task_id: Arc>> = + Arc::new(Mutex::new(None)); + + let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id); + let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id); let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() - .on_before_task_poll(move |_meta| { + .on_before_task_poll(move |task_meta| { + before_task_poll_callback_task_id_ref.lock().unwrap().replace(task_meta.id()); poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); }) - .on_after_task_poll(move |_meta| { + .on_after_task_poll(move |task_meta| { + after_task_poll_callback_task_id_ref.lock().unwrap().replace(task_meta.id()); poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); }) .build() @@ -27,7 +37,19 @@ mod unstable { yield_now().await; yield_now().await; }); + + let spawned_task_id = task.id(); + let _ = rt.block_on(task); + + assert_eq!( + before_task_poll_callback_task_id.lock().unwrap().unwrap(), + spawned_task_id + ); + assert_eq!( + after_task_poll_callback_task_id.lock().unwrap().unwrap(), + spawned_task_id + ); assert_eq!(poll_start.load(std::sync::atomic::Ordering::Relaxed), 4); assert_eq!(poll_stop.load(std::sync::atomic::Ordering::Relaxed), 4); } @@ -39,12 +61,27 @@ mod unstable { let poll_start = poll_start_counter.clone(); let poll_stop = poll_stop_counter.clone(); + let before_task_poll_callback_task_id: Arc>> = + Arc::new(Mutex::new(None)); + let after_task_poll_callback_task_id: Arc>> = + Arc::new(Mutex::new(None)); + + let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id); + let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id); let rt = tokio::runtime::Builder::new_current_thread() .enable_all() - .on_before_task_poll(move |_meta| { + .on_before_task_poll(move |task_meta| { + before_task_poll_callback_task_id_ref + .lock() + .unwrap() + .replace(task_meta.id()); poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); }) - .on_after_task_poll(move |_meta| { + .on_after_task_poll(move |task_meta| { + after_task_poll_callback_task_id_ref + .lock() + .unwrap() + .replace(task_meta.id()); poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); }) .build() @@ -55,8 +92,19 @@ mod unstable { yield_now().await; yield_now().await; }); + + let spawned_task_id = task.id(); + let _ = rt.block_on(task); + assert_eq!( + before_task_poll_callback_task_id.lock().unwrap().unwrap(), + spawned_task_id + ); + assert_eq!( + after_task_poll_callback_task_id.lock().unwrap().unwrap(), + spawned_task_id + ); assert_eq!(poll_start.load(std::sync::atomic::Ordering::Relaxed), 4); assert_eq!(poll_stop.load(std::sync::atomic::Ordering::Relaxed), 4); }