diff --git a/Cargo.lock b/Cargo.lock index f505c138f..5a88bad1d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1142,9 +1142,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" diff --git a/crux_time/Cargo.toml b/crux_time/Cargo.toml index fa1b0b7e8..7718e78e5 100644 --- a/crux_time/Cargo.toml +++ b/crux_time/Cargo.toml @@ -14,9 +14,9 @@ rust-version.workspace = true typegen = ["crux_core/typegen"] [dependencies] +chrono = { version = "0.4.38", features = ["serde"], optional = true } crux_core = { version = "0.10.0", path = "../crux_core" } serde = { workspace = true, features = ["derive"] } -chrono = { version = "0.4.38", features = ["serde"], optional = true } thiserror = "1.0.65" [dev-dependencies] diff --git a/crux_time/src/lib.rs b/crux_time/src/lib.rs index 1907ea589..a07770c89 100644 --- a/crux_time/src/lib.rs +++ b/crux_time/src/lib.rs @@ -15,7 +15,16 @@ pub use instant::Instant; use serde::{Deserialize, Serialize}; use crux_core::capability::{CapabilityContext, Operation}; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{ + collections::HashSet, + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering}, + LazyLock, Mutex, + }, + task::Poll, +}; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -26,7 +35,7 @@ pub enum TimeRequest { Clear { id: TimerId }, } -#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct TimerId(pub usize); fn get_timer_id() -> TimerId { @@ -127,7 +136,8 @@ where let this = self.clone(); async move { - context.update_app(callback(this.notify_at_async(tid, instant).await)); + let response = this.notify_at_async(tid, instant).await; + context.update_app(callback(response)); } }); @@ -136,10 +146,15 @@ where /// Ask to receive a notification when the specified [`Instant`] has arrived. /// This is an async call to use with [`crux_core::compose::Compose`]. - pub async fn notify_at_async(&self, id: TimerId, instant: Instant) -> TimeResponse { - self.context - .request_from_shell(TimeRequest::NotifyAt { id, instant }) - .await + pub fn notify_at_async( + &self, + id: TimerId, + instant: Instant, + ) -> TimerFuture> { + let future = self + .context + .request_from_shell(TimeRequest::NotifyAt { id, instant }); + TimerFuture::new(id, future) } /// Ask to receive a notification when the specified duration has elapsed. @@ -162,16 +177,25 @@ where /// Ask to receive a notification when the specified duration has elapsed. /// This is an async call to use with [`crux_core::compose::Compose`]. - pub async fn notify_after_async(&self, id: TimerId, duration: Duration) -> TimeResponse { - self.context - .request_from_shell(TimeRequest::NotifyAfter { id, duration }) - .await + pub fn notify_after_async( + &self, + id: TimerId, + duration: Duration, + ) -> TimerFuture> { + let future = self + .context + .request_from_shell(TimeRequest::NotifyAfter { id, duration }); + TimerFuture::new(id, future) } pub fn clear(&self, id: TimerId) { self.context.spawn({ - let context = self.context.clone(); + { + let mut lock = CLEARED_TIMER_IDS.lock().unwrap(); + lock.insert(id); + } + let context = self.context.clone(); async move { context.notify_shell(TimeRequest::Clear { id }).await; } @@ -179,6 +203,67 @@ where } } +pub struct TimerFuture +where + F: Future + Unpin, +{ + timer_id: TimerId, + is_cleared: bool, + future: F, +} + +impl Future for TimerFuture +where + F: Future + Unpin, +{ + type Output = TimeResponse; + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + if self.is_cleared { + // short-circuit return + return Poll::Ready(TimeResponse::Cleared { id: self.timer_id }); + }; + // see if the timer has been cleared + let timer_is_cleared = { + let mut lock = CLEARED_TIMER_IDS.lock().unwrap(); + lock.remove(&self.timer_id) + }; + let this = self.get_mut(); + this.is_cleared = timer_is_cleared; + if timer_is_cleared { + // if the timer has been cleared, immediately return 'Ready' without + // waiting for the timer to elapse + Poll::Ready(TimeResponse::Cleared { id: this.timer_id }) + } else { + // otherwise, defer to the inner future + Pin::new(&mut this.future).poll(cx) + } + } +} + +impl TimerFuture +where + F: Future + Unpin, +{ + fn new(timer_id: TimerId, future: F) -> Self { + Self { + timer_id, + future, + is_cleared: false, + } + } +} + +// Global HashSet containing the ids of timers which have been _cleared_ +// but the whose futures have _not since been polled_. When the future is next +// polled, the timer id is evicted from this set and the timer is 'poisoned' +// so as to return immediately without waiting on the shell. +static CLEARED_TIMER_IDS: LazyLock>> = + LazyLock::new(|| Mutex::new(HashSet::new())); + #[cfg(test)] mod test { use super::*; diff --git a/crux_time/tests/time_test.rs b/crux_time/tests/time_test.rs index 6d1dd577a..923784617 100644 --- a/crux_time/tests/time_test.rs +++ b/crux_time/tests/time_test.rs @@ -17,6 +17,7 @@ mod shared { StartDebounce, DurationElapsed(usize, TimeResponse), + Cancel(TimerId), } #[derive(Default)] @@ -96,6 +97,9 @@ mod shared { Event::DurationElapsed(_, _) => { panic!("Unexpected debounce event") } + Event::Cancel(timer_id) => { + caps.time.clear(timer_id); + } } } @@ -249,27 +253,33 @@ mod tests { } #[test] - pub fn test_cancel_timer() { + pub fn test_start_debounce_then_clear() { let app = AppTester::::default(); let mut model = Model::default(); - - let request1 = &mut app + let mut debounce = app .update(Event::StartDebounce, &mut model) .expect_one_effect() .expect_time(); - - assert!(model.debounce_time_id.is_some()); - - app.resolve_to_event_then_update( - request1, - TimeResponse::Cleared { - id: model.debounce_time_id.unwrap(), - }, - &mut model, - ) - .assert_empty(); - - assert!(!model.debounce_complete); - assert!(model.debounce_time_id.is_none()); + let timer_id = model.debounce_time_id.unwrap(); + let _cancel = app + .update(Event::Cancel(timer_id), &mut model) + .expect_one_effect() + .expect_time(); + // this is a little strange-looking. We have cleared the timer, + // so the in-flight debounce should have resolved. But to force that + // to happen, we have to run the app, and the easiest way to do that + // is to resolve the original debounce effect with a fake outcome - + // which will be ignored in favour of TimeResponse::Cleared + let ev = app + .resolve( + &mut debounce, + TimeResponse::DurationElapsed { id: timer_id }, + ) + .unwrap() + .expect_one_event(); + let Event::DurationElapsed(_, TimeResponse::Cleared { id }) = ev else { + panic!() + }; + assert_eq!(id, timer_id); } }