From 2e3fae389408978f4be63c34278a76f597ee3d91 Mon Sep 17 00:00:00 2001 From: Alex Whitney Date: Mon, 6 Jan 2025 16:57:06 +0000 Subject: [PATCH 1/3] cancelling a timer also aborts the cancelled task --- Cargo.lock | 5 +- crux_time/Cargo.toml | 3 +- crux_time/src/lib.rs | 90 +++++++++++++++++++++++++++++++----- crux_time/tests/time_test.rs | 44 +++++++++++------- 4 files changed, 110 insertions(+), 32 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f505c138f..a401ffc0a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -441,6 +441,7 @@ version = "0.6.0" dependencies = [ "chrono", "crux_core", + "pin-project-lite", "serde", "serde_json", "thiserror", @@ -1142,9 +1143,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" diff --git a/crux_time/Cargo.toml b/crux_time/Cargo.toml index fa1b0b7e8..57f24460b 100644 --- a/crux_time/Cargo.toml +++ b/crux_time/Cargo.toml @@ -14,9 +14,10 @@ rust-version.workspace = true typegen = ["crux_core/typegen"] [dependencies] +chrono = { version = "0.4.38", features = ["serde"], optional = true } crux_core = { version = "0.10.0", path = "../crux_core" } +pin-project-lite = "0.2.16" serde = { workspace = true, features = ["derive"] } -chrono = { version = "0.4.38", features = ["serde"], optional = true } thiserror = "1.0.65" [dev-dependencies] diff --git a/crux_time/src/lib.rs b/crux_time/src/lib.rs index 1907ea589..4dacb1c6c 100644 --- a/crux_time/src/lib.rs +++ b/crux_time/src/lib.rs @@ -15,7 +15,15 @@ pub use instant::Instant; use serde::{Deserialize, Serialize}; use crux_core::capability::{CapabilityContext, Operation}; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{ + collections::HashSet, + future::Future, + sync::{ + atomic::{AtomicUsize, Ordering}, + LazyLock, Mutex, + }, + task::Poll, +}; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -26,7 +34,7 @@ pub enum TimeRequest { Clear { id: TimerId }, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct TimerId(pub usize); fn get_timer_id() -> TimerId { @@ -127,7 +135,8 @@ where let this = self.clone(); async move { - context.update_app(callback(this.notify_at_async(tid, instant).await)); + let response = this.notify_at_async(tid, instant).await; + context.update_app(callback(response)); } }); @@ -136,10 +145,15 @@ where /// Ask to receive a notification when the specified [`Instant`] has arrived. /// This is an async call to use with [`crux_core::compose::Compose`]. - pub async fn notify_at_async(&self, id: TimerId, instant: Instant) -> TimeResponse { - self.context - .request_from_shell(TimeRequest::NotifyAt { id, instant }) - .await + pub fn notify_at_async( + &self, + id: TimerId, + instant: Instant, + ) -> TimerFuture> { + let future = self + .context + .request_from_shell(TimeRequest::NotifyAt { id, instant }); + TimerFuture::new(id, future) } /// Ask to receive a notification when the specified duration has elapsed. @@ -162,16 +176,25 @@ where /// Ask to receive a notification when the specified duration has elapsed. /// This is an async call to use with [`crux_core::compose::Compose`]. - pub async fn notify_after_async(&self, id: TimerId, duration: Duration) -> TimeResponse { - self.context - .request_from_shell(TimeRequest::NotifyAfter { id, duration }) - .await + pub fn notify_after_async( + &self, + id: TimerId, + duration: Duration, + ) -> TimerFuture> { + let future = self + .context + .request_from_shell(TimeRequest::NotifyAfter { id, duration }); + TimerFuture::new(id, future) } pub fn clear(&self, id: TimerId) { self.context.spawn({ - let context = self.context.clone(); + { + let mut lock = CLEARED_TIMER_IDS.lock().unwrap(); + lock.insert(id); + } + let context = self.context.clone(); async move { context.notify_shell(TimeRequest::Clear { id }).await; } @@ -179,6 +202,49 @@ where } } +pin_project_lite::pin_project! { + pub struct TimerFuture + where + F: Future, + { + timer_id: TimerId, + #[pin] + future: F, + } +} + +impl Future for TimerFuture +where + F: Future, +{ + type Output = TimeResponse; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let is_cleared = { + let lock = CLEARED_TIMER_IDS.lock().unwrap(); + lock.contains(&self.timer_id) + }; + if is_cleared { + Poll::Ready(TimeResponse::Cleared { id: self.timer_id }) + } else { + let this = self.project(); + this.future.poll(cx) + } + } +} + +impl> TimerFuture { + fn new(timer_id: TimerId, future: F) -> Self { + Self { timer_id, future } + } +} + +static CLEARED_TIMER_IDS: LazyLock>> = + LazyLock::new(|| Mutex::new(HashSet::new())); + #[cfg(test)] mod test { use super::*; diff --git a/crux_time/tests/time_test.rs b/crux_time/tests/time_test.rs index 6d1dd577a..923784617 100644 --- a/crux_time/tests/time_test.rs +++ b/crux_time/tests/time_test.rs @@ -17,6 +17,7 @@ mod shared { StartDebounce, DurationElapsed(usize, TimeResponse), + Cancel(TimerId), } #[derive(Default)] @@ -96,6 +97,9 @@ mod shared { Event::DurationElapsed(_, _) => { panic!("Unexpected debounce event") } + Event::Cancel(timer_id) => { + caps.time.clear(timer_id); + } } } @@ -249,27 +253,33 @@ mod tests { } #[test] - pub fn test_cancel_timer() { + pub fn test_start_debounce_then_clear() { let app = AppTester::::default(); let mut model = Model::default(); - - let request1 = &mut app + let mut debounce = app .update(Event::StartDebounce, &mut model) .expect_one_effect() .expect_time(); - - assert!(model.debounce_time_id.is_some()); - - app.resolve_to_event_then_update( - request1, - TimeResponse::Cleared { - id: model.debounce_time_id.unwrap(), - }, - &mut model, - ) - .assert_empty(); - - assert!(!model.debounce_complete); - assert!(model.debounce_time_id.is_none()); + let timer_id = model.debounce_time_id.unwrap(); + let _cancel = app + .update(Event::Cancel(timer_id), &mut model) + .expect_one_effect() + .expect_time(); + // this is a little strange-looking. We have cleared the timer, + // so the in-flight debounce should have resolved. But to force that + // to happen, we have to run the app, and the easiest way to do that + // is to resolve the original debounce effect with a fake outcome - + // which will be ignored in favour of TimeResponse::Cleared + let ev = app + .resolve( + &mut debounce, + TimeResponse::DurationElapsed { id: timer_id }, + ) + .unwrap() + .expect_one_event(); + let Event::DurationElapsed(_, TimeResponse::Cleared { id }) = ev else { + panic!() + }; + assert_eq!(id, timer_id); } } From b0e94566cfe19dcc0ff36ffc2669912482de9cf5 Mon Sep 17 00:00:00 2001 From: Alex Whitney Date: Mon, 6 Jan 2025 17:17:51 +0000 Subject: [PATCH 2/3] add some more comments --- crux_time/src/lib.rs | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/crux_time/src/lib.rs b/crux_time/src/lib.rs index 4dacb1c6c..fbbcb06d5 100644 --- a/crux_time/src/lib.rs +++ b/crux_time/src/lib.rs @@ -208,6 +208,7 @@ pin_project_lite::pin_project! { F: Future, { timer_id: TimerId, + is_cleared: bool, #[pin] future: F, } @@ -223,14 +224,23 @@ where self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { - let is_cleared = { - let lock = CLEARED_TIMER_IDS.lock().unwrap(); - lock.contains(&self.timer_id) + if self.is_cleared { + // short-circuit return + return Poll::Ready(TimeResponse::Cleared { id: self.timer_id }); }; - if is_cleared { - Poll::Ready(TimeResponse::Cleared { id: self.timer_id }) + // see if the timer has been cleared + let timer_is_cleared = { + let mut lock = CLEARED_TIMER_IDS.lock().unwrap(); + lock.remove(&self.timer_id) + }; + let this = self.project(); + *this.is_cleared = timer_is_cleared; + if timer_is_cleared { + // if the timer has been cleared, immediately return 'Ready' without + // waiting for the timer to elapse + Poll::Ready(TimeResponse::Cleared { id: *this.timer_id }) } else { - let this = self.project(); + // otherwise, defer to the inner future this.future.poll(cx) } } @@ -238,10 +248,18 @@ where impl> TimerFuture { fn new(timer_id: TimerId, future: F) -> Self { - Self { timer_id, future } + Self { + timer_id, + future, + is_cleared: false, + } } } +// Global HashSet containing the ids of timers which have been _cleared_ +// but the whose futures have _not since been polled_. When the future is next +// polled, the timer id is evicted from this set and the timer is 'poisoned' +// so as to return immediately without waiting on the shell. static CLEARED_TIMER_IDS: LazyLock>> = LazyLock::new(|| Mutex::new(HashSet::new())); From 85fa5fd14276db0a3e500ae258f2a61978d72a4c Mon Sep 17 00:00:00 2001 From: Alex Whitney Date: Tue, 7 Jan 2025 09:22:55 +0000 Subject: [PATCH 3/3] remove pin_project --- Cargo.lock | 1 - crux_time/Cargo.toml | 1 - crux_time/src/lib.rs | 35 ++++++++++++++++++----------------- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a401ffc0a..5a88bad1d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -441,7 +441,6 @@ version = "0.6.0" dependencies = [ "chrono", "crux_core", - "pin-project-lite", "serde", "serde_json", "thiserror", diff --git a/crux_time/Cargo.toml b/crux_time/Cargo.toml index 57f24460b..7718e78e5 100644 --- a/crux_time/Cargo.toml +++ b/crux_time/Cargo.toml @@ -16,7 +16,6 @@ typegen = ["crux_core/typegen"] [dependencies] chrono = { version = "0.4.38", features = ["serde"], optional = true } crux_core = { version = "0.10.0", path = "../crux_core" } -pin-project-lite = "0.2.16" serde = { workspace = true, features = ["derive"] } thiserror = "1.0.65" diff --git a/crux_time/src/lib.rs b/crux_time/src/lib.rs index fbbcb06d5..a07770c89 100644 --- a/crux_time/src/lib.rs +++ b/crux_time/src/lib.rs @@ -18,6 +18,7 @@ use crux_core::capability::{CapabilityContext, Operation}; use std::{ collections::HashSet, future::Future, + pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, LazyLock, Mutex, @@ -202,26 +203,23 @@ where } } -pin_project_lite::pin_project! { - pub struct TimerFuture - where - F: Future, - { - timer_id: TimerId, - is_cleared: bool, - #[pin] - future: F, - } +pub struct TimerFuture +where + F: Future + Unpin, +{ + timer_id: TimerId, + is_cleared: bool, + future: F, } impl Future for TimerFuture where - F: Future, + F: Future + Unpin, { type Output = TimeResponse; fn poll( - self: std::pin::Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { if self.is_cleared { @@ -233,20 +231,23 @@ where let mut lock = CLEARED_TIMER_IDS.lock().unwrap(); lock.remove(&self.timer_id) }; - let this = self.project(); - *this.is_cleared = timer_is_cleared; + let this = self.get_mut(); + this.is_cleared = timer_is_cleared; if timer_is_cleared { // if the timer has been cleared, immediately return 'Ready' without // waiting for the timer to elapse - Poll::Ready(TimeResponse::Cleared { id: *this.timer_id }) + Poll::Ready(TimeResponse::Cleared { id: this.timer_id }) } else { // otherwise, defer to the inner future - this.future.poll(cx) + Pin::new(&mut this.future).poll(cx) } } } -impl> TimerFuture { +impl TimerFuture +where + F: Future + Unpin, +{ fn new(timer_id: TimerId, future: F) -> Self { Self { timer_id,