Skip to content

Commit

Permalink
Make seq_join tests Miri compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Jan 19, 2024
1 parent e355c39 commit e47b15a
Showing 1 changed file with 57 additions and 47 deletions.
104 changes: 57 additions & 47 deletions ipa-core/src/seq_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ mod test {
};

use super::*;
use crate::test_executor::run;

async fn immediate(count: u32) {
let capacity = NonZeroUsize::new(3).unwrap();
Expand Down Expand Up @@ -632,8 +633,8 @@ mod test {
///
/// This behavior is only applicable when `seq_try_join_all` uses more than one thread, for
/// maintenance reasons, we use it even parallelism is turned off.
#[tokio::test(flavor = "multi_thread")]
async fn try_join_early_abort() {
#[test]
fn try_join_early_abort() {
const ERROR: &str = "error message";
fn f(i: u32) -> impl Future<Output = Result<u32, &'static str>> {
lazy(move |_| match i {
Expand All @@ -643,13 +644,15 @@ mod test {
})
}

let active = NonZeroUsize::new(10).unwrap();
let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err();
assert_eq!(err, ERROR);
run(|| async {
let active = NonZeroUsize::new(10).unwrap();
let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err();
assert_eq!(err, ERROR);
});
}

#[tokio::test(flavor = "multi_thread")]
async fn does_not_block_on_error() {
#[test]
fn does_not_block_on_error() {
const ERROR: &str = "returning early is safe";
use std::pin::Pin;

Expand All @@ -661,60 +664,67 @@ mod test {
}
}

let active = NonZeroUsize::new(10).unwrap();
let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err();
assert_eq!(err, ERROR);
run(|| async {
let active = NonZeroUsize::new(10).unwrap();
let err = seq_try_join_all(active, (1..=3).map(f)).await.unwrap_err();
assert_eq!(err, ERROR);
});
}

/// This test demonstrates that forgetting the future returned by `parallel_join` is not safe and will cause
/// use-after-free safety error.
///
/// TODO: Run tests with multi-threading runtimes in CI
#[tokio::test(flavor = "multi_thread")]
#[test]
#[cfg(feature = "multi-threading")]
async fn parallel_join_forget_is_not_safe() {
fn parallel_join_forget_is_not_safe() {
use std::mem::ManuallyDrop;

use futures::future::poll_immediate;

use crate::{seq_join::multi_thread::parallel_join, sync::Arc};

const N: usize = 24;
let borrow_from_me = Arc::new(vec![1, 2, 3]);
let start = Arc::new(tokio::sync::Barrier::new(N + 1));
// counts how many tasks have accessed `borrow_from_me` after it was destroyed.
// this test expects all tasks to access `borrow_from_me` at least once.
let bad_accesses = Arc::new(tokio::sync::Barrier::new(N + 1));

let iterable = (0..N)
.map(|_| {
let borrowed = Arc::downgrade(&borrow_from_me);
let start = start.clone();
let bad_access = bad_accesses.clone();
async move {
start.wait().await;
// at this point, the parent future is forgotten and borrowed should point to nothing
for _ in 0..100 {
if borrowed.upgrade().is_none() {
bad_access.wait().await;
break;
run(|| async {
const N: usize = 24;
let borrow_from_me = Arc::new(vec![1, 2, 3]);
let start = Arc::new(tokio::sync::Barrier::new(N + 1));
// counts how many tasks have accessed `borrow_from_me` after it was destroyed.
// this test expects all tasks to access `borrow_from_me` at least once.
let bad_accesses = Arc::new(tokio::sync::Barrier::new(N + 1));

let futures = (0..N)
.map(|_| {
let borrowed = Arc::downgrade(&borrow_from_me);
let start = start.clone();
let bad_access = bad_accesses.clone();
async move {
start.wait().await;
// at this point, the parent future is forgotten and borrowed should point to nothing
for _ in 0..100 {
if borrowed.upgrade().is_none() {
bad_access.wait().await;
break;
}
tokio::task::yield_now().await;
}
tokio::task::yield_now().await;
Ok::<(), ()>(())
}
Ok::<(), ()>(())
}
})
.collect::<Vec<_>>();
})
.collect::<Vec<_>>();

let mut f = parallel_join(futures);
poll_immediate(&mut f).await;
start.wait().await;

let mut f = parallel_join(iterable);
poll_immediate(&mut f).await;
start.wait().await;
// forgetting f does not mean that futures spawned by `parallel_join` will be cancelled.
let guard = ManuallyDrop::new(f);

// forgetting f does not mean that futures spawned by `parallel_join` will be cancelled.
std::mem::forget(f);
// Async executor will still be polling futures and they will try to follow this pointer.
drop(borrow_from_me);

// Async executor will still be polling futures and they will try to follow this pointer.
drop(borrow_from_me);
// this test should terminate because all tasks should access `borrow_from_me` at least once.
bad_accesses.wait().await;

// this test should terminate because all tasks should access `borrow_from_me` at least once.
bad_accesses.wait().await;
// do not leak memory
let _ = ManuallyDrop::into_inner(guard);
})
}
}

0 comments on commit e47b15a

Please sign in to comment.