From 4d2f23867ea5c5ef4d607f5c92af235f6aea2cd6 Mon Sep 17 00:00:00 2001 From: Jennings Zhang Date: Sun, 8 Sep 2024 13:38:04 -0400 Subject: [PATCH] Rework SubjectLimiter::lock to return a RAII --- src/limiter.rs | 229 +++++++++++++++++++------------------- src/notifier.rs | 52 ++++----- tests/integration_test.rs | 8 +- 3 files changed, 143 insertions(+), 146 deletions(-) diff --git a/src/limiter.rs b/src/limiter.rs index 7495d73..a08b229 100644 --- a/src/limiter.rs +++ b/src/limiter.rs @@ -1,30 +1,32 @@ use std::collections::HashMap; -use std::future::Future; +use std::fmt::Debug; use std::hash::Hash; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; -use tokio::sync::Semaphore; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; + +/// Something that can be used as a subject key. +pub(crate) trait Subject: Eq + Hash + Clone + Debug {} +impl Subject for T {} /// A synchronization and rate-limiting mechanism. -pub(crate) struct SubjectLimiter(PureSubjectLimiter); +pub(crate) struct SubjectLimiter(KindaPureSubjectLimiter); -impl SubjectLimiter { +impl SubjectLimiter { /// Create a new [SubjectLimiter] which rate-limits functions per subject /// to be called no more than once per given `interval`. pub fn new(interval: Duration) -> Self { - Self(PureSubjectLimiter::new(Instant::now() - interval, interval)) + Self(KindaPureSubjectLimiter::new( + Instant::now() - interval, + interval, + )) } /// Wraps the given async function `f`, calling it if it isn't currently /// running not has been called recently (within the duration specified /// to [`SubjectLimiter::new`]). Otherwise, does nothing (i.e. `f` is not called). - pub async fn lock(&self, subject: S, f: Fut) -> Option - where - Fut: Future, - { - self.0 - .lock(Instant::now(), || Instant::now(), subject, f) - .await + pub fn lock(&self, subject: S) -> Option> { + self.0.lock(Instant::now(), subject) } /// Forget a subject. Blocks until the function running for the subject is done. @@ -50,49 +52,59 @@ impl SubjectState { } } -/// Pure implementation of [SubjectLimiter]. -struct PureSubjectLimiter { - subjects: std::sync::Mutex>, +/// (Not actually) pure implementation of [SubjectLimiter]. +/// +/// In the past, I thought it would be easier to test [SubjectLimiter] if it were +/// implemented purely, but I changed my mind about that. +struct KindaPureSubjectLimiter { + subjects: Arc>>, start: Instant, interval: Duration, } -impl PureSubjectLimiter { +/// A [RAII](https://github.com/rust-unofficial/patterns/blob/main/src/patterns/behavioural/RAII.md) +/// for synchronization by calling [`SubjectLimiter::lock`]. +pub(crate) struct Permit { + _permit: OwnedSemaphorePermit, + subject: S, + subjects: Arc>>, +} + +impl Drop for Permit { + fn drop(&mut self) { + let mut subjects = self.subjects.lock().unwrap(); + if let Some(state) = subjects.get_mut(&self.subject) { + state.last_sent = Instant::now(); // impure + } + } +} + +impl KindaPureSubjectLimiter { fn new(start: Instant, interval: Duration) -> Self { Self { - subjects: Default::default(), + subjects: Arc::new(Default::default()), start, interval, } } - async fn lock(&self, now: Instant, later: L, subject: S, f: Fut) -> Option - where - L: FnOnce() -> Instant, - Fut: Future, - { - let (try_acquire, last_sent) = { - // note: don't want to keep self.subjects locked while running `f` - let mut subjects = self.subjects.lock().unwrap(); - let state = subjects - .entry(subject.clone()) - .or_insert_with(|| SubjectState::new(self.start)); - let permit = Arc::clone(&state.semaphore).try_acquire_owned(); - (permit, state.last_sent) - }; - if now - last_sent < self.interval { + fn lock(&self, now: Instant, subject: S) -> Option> { + let mut subjects = self.subjects.lock().unwrap(); + let state = subjects + .entry(subject.clone()) + .or_insert_with(|| SubjectState::new(self.start)); + if now - state.last_sent < self.interval { return None; } - if let Ok(_permit_raii) = try_acquire { - let ret = f.await; - let mut subjects = self.subjects.lock().unwrap(); - if let Some(state) = subjects.get_mut(&subject) { - state.last_sent = later(); - } - Some(ret) - } else { - None - } + Arc::clone(&state.semaphore) + .try_acquire_owned() + .ok() + .map(|permit| permit) + .map(|permit| Permit { + _permit: permit, + subject, + subjects: Arc::clone(&self.subjects), + }) } async fn forget(&self, subject: &S) { @@ -113,12 +125,15 @@ impl PureSubjectLimiter { } } Err(_) => { - tracing::warn!(subject = subject, "SubjectLimiter::forget called twice"); + tracing::warn!( + subject = format!("{subject:?}"), + "SubjectLimiter::forget called twice" + ); } } } else { tracing::warn!( - subject = subject, + subject = format!("{subject:?}"), "SubjectLimiter::forget called on unknown subject" ); } @@ -128,79 +143,72 @@ impl PureSubjectLimiter { #[cfg(test)] mod tests { use super::*; - use std::sync::Arc; use tokio::task::JoinHandle; #[tokio::test] async fn test_lock() { let interval = Duration::from_millis(100); - let limiter = Arc::new(SubjectLimiter::new(interval)); - let task_a = create_task(&limiter, "subject1"); - let task_b = create_task(&limiter, "subject1"); - let task_c = create_task(&limiter, "subject2"); - - let (ret_a, ret_b, ret_c) = tokio::try_join!(task_a, task_b, task_c).unwrap(); - assert!( - ret_a.is_some(), - "task_a was not called, but it should have been called because it was the first function." + let limiter = SubjectLimiter::new(interval); + let task_a = create_task(&limiter, "subject1").expect( + "task_a was not called, but it should have been called \ + because it was the first function.", ); + let task_b = create_task(&limiter, "subject1"); assert!( - ret_b.is_none(), + task_b.is_none(), "task_b was called, but it should not have been called because task a recently ran." ); - assert!( - ret_c.is_some(), - "task_c was not called, but it should have been called because it is a different subject than task_a." + let task_c = create_task(&limiter, "subject2").expect( + "task_c was not called, but it should have been called \ + because it is a different subject than task_a.", ); - tokio::time::sleep(interval).await; - let task_d = create_task(&limiter, "subject1"); - let ret_d = task_d.await.unwrap(); - assert!( - ret_d.is_some(), - "task_d was not called, but it should have been called because \"subject1\" has not been busy for a while." + + tokio::time::sleep(interval * 2).await; + let task_d = create_task(&limiter, "subject1").expect( + "task_d was not called, but it should have been called \ + because \"subject1\" has not been busy for a while.", ); + tokio::try_join!(task_a, task_c, task_d).unwrap(); } - fn create_task( - limiter: &Arc>, + fn create_task( + limiter: &SubjectLimiter, subject: S, - ) -> JoinHandle> { - let limiter = Arc::clone(&limiter); - tokio::spawn(async move { - limiter - .lock(subject, async { - tokio::time::sleep(Duration::from_millis(10)).await; - }) - .await - }) + ) -> Option> { + if let Some(raii) = limiter.lock(subject) { + let task = tokio::spawn(async move { + let _raii_binding = raii; + tokio::time::sleep(Duration::from_millis(10)).await; + }); + Some(task) + } else { + None + } } #[tokio::test] async fn test_lock_during_busy_forget() { let interval = Duration::from_millis(200); - let limiter = Arc::new(SubjectLimiter::new(interval)); - let finished = Arc::new(std::sync::Mutex::new(false)); + let limiter = SubjectLimiter::new(interval); + let finished = Arc::new(Mutex::new(false)); let (tx, rx) = tokio::sync::oneshot::channel(); let start = Instant::now(); let task_a = { - let limiter = Arc::clone(&limiter); let finished = Arc::clone(&finished); let start = start.clone(); + let raii = limiter.lock("subject1").unwrap(); tokio::spawn(async move { - limiter - .lock("subject1", async { - tx.send(()).unwrap(); - tokio::time::sleep(Duration::from_millis(100)).await; - *finished.lock().unwrap() = true; - start.elapsed() - }) - .await + let _raii_binding = raii; + tx.send(()).unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + *finished.lock().unwrap() = true; + start.elapsed() }) }; rx.await.unwrap(); limiter.forget(&"subject1").await; let outer_elapsed = start.elapsed(); - let task_a_elapsed = task_a.await.unwrap().unwrap(); + let task_a_elapsed = task_a.await.unwrap(); assert!( outer_elapsed >= task_a_elapsed, "SubjectLimiter::forget should have taken as long as task_a slept for, \ @@ -213,57 +221,46 @@ mod tests { #[tokio::test] async fn test_forget_called_twice_shouldnt_blow_up() { let interval = Duration::from_millis(200); - let limiter = Arc::new(SubjectLimiter::new(interval)); + let limiter = SubjectLimiter::new(interval); let (tx, rx) = tokio::sync::oneshot::channel(); let task_a = { - let limiter = Arc::clone(&limiter); + let raii = limiter.lock("subject1"); tokio::spawn(async move { - limiter - .lock("subject1", async { - tx.send(()).unwrap(); - tokio::time::sleep(Duration::from_millis(100)).await; - }) - .await + let _raii_binding = raii; + tx.send(()).unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; }) }; rx.await.unwrap(); tokio::join!(limiter.forget(&"subject1"), limiter.forget(&"subject1")); - task_a.await.unwrap().unwrap(); + task_a.await.unwrap(); } #[tokio::test] async fn test_different_subjects_not_locked() { let interval = Duration::from_millis(100); - let limiter = Arc::new(SubjectLimiter::new(interval)); + let limiter = SubjectLimiter::new(interval); let (tx, mut rx) = tokio::sync::mpsc::channel(8); let task_a = { - let limiter = Arc::clone(&limiter); let tx = tx.clone(); + let raii = limiter.lock("subject1").unwrap(); tokio::spawn(async move { - limiter - .lock("subject1", async { - tx.send(1).await.unwrap(); - tokio::time::sleep(Duration::from_millis(200)).await; - tx.send(3).await.unwrap(); - }) - .await + let _raii_binding = raii; + tx.send(1).await.unwrap(); + tokio::time::sleep(Duration::from_millis(200)).await; + tx.send(3).await.unwrap(); }) }; let task_b = { - let limiter = Arc::clone(&limiter); let tx = tx.clone(); + let raii = limiter.lock("subject2").unwrap(); tokio::spawn(async move { - limiter - .lock("subject2", async { - tokio::time::sleep(Duration::from_millis(100)).await; - tx.send(2).await.unwrap(); - }) - .await + let _raii_binding = raii; + tokio::time::sleep(Duration::from_millis(100)).await; + tx.send(2).await.unwrap(); }) }; - let (a, b) = tokio::try_join!(task_a, task_b).unwrap(); - assert!(a.is_some()); - assert!(b.is_some()); + tokio::try_join!(task_a, task_b).unwrap(); let actual = [rx.recv().await, rx.recv().await, rx.recv().await]; let expected = [Some(1), Some(2), Some(3)]; assert_eq!(actual, expected); diff --git a/src/notifier.rs b/src/notifier.rs index 3a76b9b..8d017cc 100644 --- a/src/notifier.rs +++ b/src/notifier.rs @@ -30,15 +30,17 @@ pub async fn cube_pacsfile_notifier( let mut counts: SeriesCounts = Default::default(); let limiter = Arc::new(SubjectLimiter::new(progress_interval)); while let Some((series, event)) = receiver.recv().await { - tx.send(handle_event( + let task = handle_event( &mut counts, series, event, &celery, nats_client.clone(), - Arc::clone(&limiter), - )) - .unwrap(); + &limiter, + ); + if let Some(task) = task { + tx.send(task).unwrap(); + } } drop(tx); }; @@ -69,25 +71,31 @@ fn handle_event( event: SeriesEvent, DicomInfo>, celery_client: &Arc, nats_client: Option, - limiter: Arc>, -) -> RegistrationTask { + limiter: &Arc>, +) -> Option { match event { SeriesEvent::Instance(result) => { let payload = count_series(series_key.clone(), counts, result); - tokio::spawn(async move { - maybe_send_lonk(nats_client, limiter, &series_key, payload).await + limiter.lock(series_key.clone()).map(|raii| { + tokio::spawn(async move { + let _raii_binding = raii; + maybe_send_lonk(nats_client, &series_key, payload).await + }) }) } SeriesEvent::Finish(series_info) => { let celery_client = Arc::clone(celery_client); + let limiter = Arc::clone(limiter); let ndicom = counts.remove(&series_key).unwrap_or(0); - tokio::spawn(async move { + let task = tokio::spawn(async move { + limiter.forget(&series_key).await; let (a, b) = tokio::join!( - maybe_send_final_progress_messages(nats_client, limiter, &series_key, ndicom), + maybe_send_final_progress_messages(nats_client, &series_key, ndicom), send_registration_task_to_celery(series_info, ndicom, &celery_client) ); a.and(b) - }) + }); + Some(task) } } } @@ -111,17 +119,14 @@ fn count_series( async fn maybe_send_lonk( client: Option, - limiter: Arc>, series: &SeriesKey, payload: Bytes, ) -> Result<(), ()> { if let Some(client) = client { - send_lonk(client, limiter, series, payload) - .await - .map_err(|e| { - tracing::error!(error = e.to_string()); - () - }) + send_lonk(client, series, payload).await.map_err(|e| { + tracing::error!(error = e.to_string()); + () + }) } else { Ok(()) } @@ -129,26 +134,17 @@ async fn maybe_send_lonk( async fn send_lonk( client: async_nats::Client, - limiter: Arc>, series: &SeriesKey, payload: Bytes, ) -> Result<(), async_nats::PublishError> { - let subject = subject_of(series); - limiter - .lock(subject.clone(), client.publish(subject, payload)) - .await - .unwrap_or(Ok(())) + client.publish(subject_of(series), payload).await } async fn maybe_send_final_progress_messages( client: Option, - limiter: Arc>, series: &SeriesKey, ndicom: u32, ) -> Result<(), ()> { - // ensures prior progress messages are done being sent - limiter.forget(&subject_of(series)).await; - if let Some(client) = client { send_final_progress_messages(client, series, ndicom) .await diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 7a1be38..38658d2 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -67,8 +67,12 @@ async fn test_run_everything_from_env() { // wait for server to shut down server_handle.await.unwrap().unwrap(); - // shutdown the NATS subscriber - tokio::time::sleep(core::time::Duration::from_secs(1)).await; + // Shutdown the NATS subscriber after waiting a little bit. + // Note: instead of shutting itself down after receiving the correct number + // of "DONE" messages, we prefer the naive approach of waiting 500ms instead, + // so that here in the test we can assert that the "DONE" messages do indeed + // come last and no out-of-order/race condition errors are happening. + tokio::time::sleep(Duration::from_millis(500)).await; nats_shutdown_tx.send(true).await.unwrap(); let lonk_messages = nats_subscriber_loop.await.unwrap();