Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancelling a timer also aborts the cancelled task #292

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion crux_time/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ rust-version.workspace = true
typegen = ["crux_core/typegen"]

[dependencies]
chrono = { version = "0.4.38", features = ["serde"], optional = true }
crux_core = { version = "0.10.0", path = "../crux_core" }
pin-project-lite = "0.2.16"
serde = { workspace = true, features = ["derive"] }
chrono = { version = "0.4.38", features = ["serde"], optional = true }
thiserror = "1.0.65"

[dev-dependencies]
Expand Down
108 changes: 96 additions & 12 deletions crux_time/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ pub use instant::Instant;
use serde::{Deserialize, Serialize};

use crux_core::capability::{CapabilityContext, Operation};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{
collections::HashSet,
future::Future,
sync::{
atomic::{AtomicUsize, Ordering},
LazyLock, Mutex,
},
task::Poll,
};

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand All @@ -26,7 +34,7 @@ pub enum TimeRequest {
Clear { id: TimerId },
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TimerId(pub usize);

fn get_timer_id() -> TimerId {
Expand Down Expand Up @@ -127,7 +135,8 @@ where
let this = self.clone();

async move {
context.update_app(callback(this.notify_at_async(tid, instant).await));
let response = this.notify_at_async(tid, instant).await;
context.update_app(callback(response));
}
});

Expand All @@ -136,10 +145,15 @@ where

/// Ask to receive a notification when the specified [`Instant`] has arrived.
/// This is an async call to use with [`crux_core::compose::Compose`].
pub async fn notify_at_async(&self, id: TimerId, instant: Instant) -> TimeResponse {
self.context
.request_from_shell(TimeRequest::NotifyAt { id, instant })
.await
pub fn notify_at_async(
&self,
id: TimerId,
instant: Instant,
) -> TimerFuture<impl Future<Output = TimeResponse>> {
let future = self
.context
.request_from_shell(TimeRequest::NotifyAt { id, instant });
TimerFuture::new(id, future)
}

/// Ask to receive a notification when the specified duration has elapsed.
Expand All @@ -162,23 +176,93 @@ where

/// Ask to receive a notification when the specified duration has elapsed.
/// This is an async call to use with [`crux_core::compose::Compose`].
pub async fn notify_after_async(&self, id: TimerId, duration: Duration) -> TimeResponse {
self.context
.request_from_shell(TimeRequest::NotifyAfter { id, duration })
.await
pub fn notify_after_async(
&self,
id: TimerId,
duration: Duration,
) -> TimerFuture<impl Future<Output = TimeResponse>> {
let future = self
.context
.request_from_shell(TimeRequest::NotifyAfter { id, duration });
TimerFuture::new(id, future)
}

pub fn clear(&self, id: TimerId) {
self.context.spawn({
let context = self.context.clone();
{
let mut lock = CLEARED_TIMER_IDS.lock().unwrap();
lock.insert(id);
}

let context = self.context.clone();
async move {
context.notify_shell(TimeRequest::Clear { id }).await;
}
});
}
}

pin_project_lite::pin_project! {
pub struct TimerFuture<F>
where
F: Future<Output = TimeResponse>,
{
timer_id: TimerId,
is_cleared: bool,
#[pin]
future: F,
}
}

impl<F> Future for TimerFuture<F>
where
F: Future<Output = TimeResponse>,
{
type Output = TimeResponse;

fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
if self.is_cleared {
// short-circuit return
return Poll::Ready(TimeResponse::Cleared { id: self.timer_id });
};
// see if the timer has been cleared
let timer_is_cleared = {
let mut lock = CLEARED_TIMER_IDS.lock().unwrap();
lock.remove(&self.timer_id)
};
let this = self.project();
*this.is_cleared = timer_is_cleared;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd have to try to have specific advice, but I think you might be able to avoid the pin_project by self.deref_mut() and an Unpin trait bound for whatever is complaining about not being guaranteed unpin.

Was that problematic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have removed pin_project, you're right, it wasn't necessary in this instance.

if timer_is_cleared {
// if the timer has been cleared, immediately return 'Ready' without
// waiting for the timer to elapse
Poll::Ready(TimeResponse::Cleared { id: *this.timer_id })
} else {
// otherwise, defer to the inner future
this.future.poll(cx)
}
}
}

impl<F: Future<Output = TimeResponse>> TimerFuture<F> {
fn new(timer_id: TimerId, future: F) -> Self {
Self {
timer_id,
future,
is_cleared: false,
}
}
}

// Global HashSet containing the ids of timers which have been _cleared_
// but the whose futures have _not since been polled_. When the future is next
// polled, the timer id is evicted from this set and the timer is 'poisoned'
// so as to return immediately without waiting on the shell.
static CLEARED_TIMER_IDS: LazyLock<Mutex<HashSet<TimerId>>> =
LazyLock::new(|| Mutex::new(HashSet::new()));

#[cfg(test)]
mod test {
use super::*;
Expand Down
44 changes: 27 additions & 17 deletions crux_time/tests/time_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod shared {

StartDebounce,
DurationElapsed(usize, TimeResponse),
Cancel(TimerId),
}

#[derive(Default)]
Expand Down Expand Up @@ -96,6 +97,9 @@ mod shared {
Event::DurationElapsed(_, _) => {
panic!("Unexpected debounce event")
}
Event::Cancel(timer_id) => {
caps.time.clear(timer_id);
}
}
}

Expand Down Expand Up @@ -249,27 +253,33 @@ mod tests {
}

#[test]
pub fn test_cancel_timer() {
pub fn test_start_debounce_then_clear() {
let app = AppTester::<App, _>::default();
let mut model = Model::default();

let request1 = &mut app
let mut debounce = app
.update(Event::StartDebounce, &mut model)
.expect_one_effect()
.expect_time();

assert!(model.debounce_time_id.is_some());

app.resolve_to_event_then_update(
request1,
TimeResponse::Cleared {
id: model.debounce_time_id.unwrap(),
},
&mut model,
)
.assert_empty();

assert!(!model.debounce_complete);
assert!(model.debounce_time_id.is_none());
let timer_id = model.debounce_time_id.unwrap();
let _cancel = app
.update(Event::Cancel(timer_id), &mut model)
.expect_one_effect()
.expect_time();
// this is a little strange-looking. We have cleared the timer,
// so the in-flight debounce should have resolved. But to force that
// to happen, we have to run the app, and the easiest way to do that
// is to resolve the original debounce effect with a fake outcome -
// which will be ignored in favour of TimeResponse::Cleared
let ev = app
.resolve(
&mut debounce,
TimeResponse::DurationElapsed { id: timer_id },
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A good test that this doesn't cause issues if the shell still fires the timer.

In a way, we could simply not tell the shell about the cancellation and just stop the task like you are doing which will ignore the timer when it fires.

.unwrap()
.expect_one_event();
let Event::DurationElapsed(_, TimeResponse::Cleared { id }) = ev else {
panic!()
};
assert_eq!(id, timer_id);
}
}
Loading