Skip to content

Commit

Permalink
Rework SubjectLimiter::lock to return a RAII
Browse files Browse the repository at this point in the history
  • Loading branch information
jennydaman committed Sep 8, 2024
1 parent da25c3a commit 4d2f238
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 146 deletions.
229 changes: 113 additions & 116 deletions src/limiter.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
use std::collections::HashMap;
use std::future::Future;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::Semaphore;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

/// Something that can be used as a subject key.
pub(crate) trait Subject: Eq + Hash + Clone + Debug {}
impl<T: Eq + Hash + Clone + Debug> Subject for T {}

/// A synchronization and rate-limiting mechanism.
pub(crate) struct SubjectLimiter<S: Eq + Hash + Clone + tracing::Value>(PureSubjectLimiter<S>);
pub(crate) struct SubjectLimiter<S: Subject>(KindaPureSubjectLimiter<S>);

impl<S: Eq + Hash + Clone + tracing::Value> SubjectLimiter<S> {
impl<S: Subject> SubjectLimiter<S> {
/// Create a new [SubjectLimiter] which rate-limits functions per subject
/// to be called no more than once per given `interval`.
pub fn new(interval: Duration) -> Self {
Self(PureSubjectLimiter::new(Instant::now() - interval, interval))
Self(KindaPureSubjectLimiter::new(
Instant::now() - interval,
interval,
))
}

/// Wraps the given async function `f`, calling it if it isn't currently
/// running not has been called recently (within the duration specified
/// to [`SubjectLimiter::new`]). Otherwise, does nothing (i.e. `f` is not called).
pub async fn lock<Fut, R>(&self, subject: S, f: Fut) -> Option<R>
where
Fut: Future<Output = R>,
{
self.0
.lock(Instant::now(), || Instant::now(), subject, f)
.await
pub fn lock(&self, subject: S) -> Option<Permit<S>> {
self.0.lock(Instant::now(), subject)
}

/// Forget a subject. Blocks until the function running for the subject is done.
Expand All @@ -50,49 +52,59 @@ impl SubjectState {
}
}

/// Pure implementation of [SubjectLimiter].
struct PureSubjectLimiter<S: Eq + Hash + Clone + tracing::Value> {
subjects: std::sync::Mutex<HashMap<S, SubjectState>>,
/// (Not actually) pure implementation of [SubjectLimiter].
///
/// In the past, I thought it would be easier to test [SubjectLimiter] if it were
/// implemented purely, but I changed my mind about that.
struct KindaPureSubjectLimiter<S: Subject> {
subjects: Arc<Mutex<HashMap<S, SubjectState>>>,
start: Instant,
interval: Duration,
}

impl<S: Eq + Hash + Clone + tracing::Value> PureSubjectLimiter<S> {
/// A [RAII](https://github.com/rust-unofficial/patterns/blob/main/src/patterns/behavioural/RAII.md)
/// for synchronization by calling [`SubjectLimiter::lock`].
pub(crate) struct Permit<S: Subject> {
_permit: OwnedSemaphorePermit,
subject: S,
subjects: Arc<Mutex<HashMap<S, SubjectState>>>,
}

impl<S: Subject> Drop for Permit<S> {
fn drop(&mut self) {
let mut subjects = self.subjects.lock().unwrap();
if let Some(state) = subjects.get_mut(&self.subject) {
state.last_sent = Instant::now(); // impure
}
}
}

impl<S: Subject> KindaPureSubjectLimiter<S> {
fn new(start: Instant, interval: Duration) -> Self {
Self {
subjects: Default::default(),
subjects: Arc::new(Default::default()),
start,
interval,
}
}

async fn lock<L, Fut, R>(&self, now: Instant, later: L, subject: S, f: Fut) -> Option<R>
where
L: FnOnce() -> Instant,
Fut: Future<Output = R>,
{
let (try_acquire, last_sent) = {
// note: don't want to keep self.subjects locked while running `f`
let mut subjects = self.subjects.lock().unwrap();
let state = subjects
.entry(subject.clone())
.or_insert_with(|| SubjectState::new(self.start));
let permit = Arc::clone(&state.semaphore).try_acquire_owned();
(permit, state.last_sent)
};
if now - last_sent < self.interval {
fn lock(&self, now: Instant, subject: S) -> Option<Permit<S>> {
let mut subjects = self.subjects.lock().unwrap();
let state = subjects
.entry(subject.clone())
.or_insert_with(|| SubjectState::new(self.start));
if now - state.last_sent < self.interval {
return None;
}
if let Ok(_permit_raii) = try_acquire {
let ret = f.await;
let mut subjects = self.subjects.lock().unwrap();
if let Some(state) = subjects.get_mut(&subject) {
state.last_sent = later();
}
Some(ret)
} else {
None
}
Arc::clone(&state.semaphore)
.try_acquire_owned()
.ok()
.map(|permit| permit)
.map(|permit| Permit {
_permit: permit,
subject,
subjects: Arc::clone(&self.subjects),
})
}

async fn forget(&self, subject: &S) {
Expand All @@ -113,12 +125,15 @@ impl<S: Eq + Hash + Clone + tracing::Value> PureSubjectLimiter<S> {
}
}
Err(_) => {
tracing::warn!(subject = subject, "SubjectLimiter::forget called twice");
tracing::warn!(
subject = format!("{subject:?}"),
"SubjectLimiter::forget called twice"
);
}
}
} else {
tracing::warn!(
subject = subject,
subject = format!("{subject:?}"),
"SubjectLimiter::forget called on unknown subject"
);
}
Expand All @@ -128,79 +143,72 @@ impl<S: Eq + Hash + Clone + tracing::Value> PureSubjectLimiter<S> {
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use tokio::task::JoinHandle;

#[tokio::test]
async fn test_lock() {
let interval = Duration::from_millis(100);
let limiter = Arc::new(SubjectLimiter::new(interval));
let task_a = create_task(&limiter, "subject1");
let task_b = create_task(&limiter, "subject1");
let task_c = create_task(&limiter, "subject2");

let (ret_a, ret_b, ret_c) = tokio::try_join!(task_a, task_b, task_c).unwrap();
assert!(
ret_a.is_some(),
"task_a was not called, but it should have been called because it was the first function."
let limiter = SubjectLimiter::new(interval);
let task_a = create_task(&limiter, "subject1").expect(
"task_a was not called, but it should have been called \
because it was the first function.",
);
let task_b = create_task(&limiter, "subject1");
assert!(
ret_b.is_none(),
task_b.is_none(),
"task_b was called, but it should not have been called because task a recently ran."
);
assert!(
ret_c.is_some(),
"task_c was not called, but it should have been called because it is a different subject than task_a."
let task_c = create_task(&limiter, "subject2").expect(
"task_c was not called, but it should have been called \
because it is a different subject than task_a.",
);
tokio::time::sleep(interval).await;
let task_d = create_task(&limiter, "subject1");
let ret_d = task_d.await.unwrap();
assert!(
ret_d.is_some(),
"task_d was not called, but it should have been called because \"subject1\" has not been busy for a while."

tokio::time::sleep(interval * 2).await;
let task_d = create_task(&limiter, "subject1").expect(
"task_d was not called, but it should have been called \
because \"subject1\" has not been busy for a while.",
);
tokio::try_join!(task_a, task_c, task_d).unwrap();
}

fn create_task<S: Eq + Hash + Send + Clone + tracing::Value + 'static>(
limiter: &Arc<SubjectLimiter<S>>,
fn create_task<S: Subject + Send + 'static>(
limiter: &SubjectLimiter<S>,
subject: S,
) -> JoinHandle<Option<()>> {
let limiter = Arc::clone(&limiter);
tokio::spawn(async move {
limiter
.lock(subject, async {
tokio::time::sleep(Duration::from_millis(10)).await;
})
.await
})
) -> Option<JoinHandle<()>> {
if let Some(raii) = limiter.lock(subject) {
let task = tokio::spawn(async move {
let _raii_binding = raii;
tokio::time::sleep(Duration::from_millis(10)).await;
});
Some(task)
} else {
None
}
}

#[tokio::test]
async fn test_lock_during_busy_forget() {
let interval = Duration::from_millis(200);
let limiter = Arc::new(SubjectLimiter::new(interval));
let finished = Arc::new(std::sync::Mutex::new(false));
let limiter = SubjectLimiter::new(interval);
let finished = Arc::new(Mutex::new(false));
let (tx, rx) = tokio::sync::oneshot::channel();
let start = Instant::now();
let task_a = {
let limiter = Arc::clone(&limiter);
let finished = Arc::clone(&finished);
let start = start.clone();
let raii = limiter.lock("subject1").unwrap();
tokio::spawn(async move {
limiter
.lock("subject1", async {
tx.send(()).unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
*finished.lock().unwrap() = true;
start.elapsed()
})
.await
let _raii_binding = raii;
tx.send(()).unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
*finished.lock().unwrap() = true;
start.elapsed()
})
};
rx.await.unwrap();
limiter.forget(&"subject1").await;
let outer_elapsed = start.elapsed();
let task_a_elapsed = task_a.await.unwrap().unwrap();
let task_a_elapsed = task_a.await.unwrap();
assert!(
outer_elapsed >= task_a_elapsed,
"SubjectLimiter::forget should have taken as long as task_a slept for, \
Expand All @@ -213,57 +221,46 @@ mod tests {
#[tokio::test]
async fn test_forget_called_twice_shouldnt_blow_up() {
let interval = Duration::from_millis(200);
let limiter = Arc::new(SubjectLimiter::new(interval));
let limiter = SubjectLimiter::new(interval);
let (tx, rx) = tokio::sync::oneshot::channel();
let task_a = {
let limiter = Arc::clone(&limiter);
let raii = limiter.lock("subject1");
tokio::spawn(async move {
limiter
.lock("subject1", async {
tx.send(()).unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
})
.await
let _raii_binding = raii;
tx.send(()).unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
})
};
rx.await.unwrap();
tokio::join!(limiter.forget(&"subject1"), limiter.forget(&"subject1"));
task_a.await.unwrap().unwrap();
task_a.await.unwrap();
}

#[tokio::test]
async fn test_different_subjects_not_locked() {
let interval = Duration::from_millis(100);
let limiter = Arc::new(SubjectLimiter::new(interval));
let limiter = SubjectLimiter::new(interval);
let (tx, mut rx) = tokio::sync::mpsc::channel(8);
let task_a = {
let limiter = Arc::clone(&limiter);
let tx = tx.clone();
let raii = limiter.lock("subject1").unwrap();
tokio::spawn(async move {
limiter
.lock("subject1", async {
tx.send(1).await.unwrap();
tokio::time::sleep(Duration::from_millis(200)).await;
tx.send(3).await.unwrap();
})
.await
let _raii_binding = raii;
tx.send(1).await.unwrap();
tokio::time::sleep(Duration::from_millis(200)).await;
tx.send(3).await.unwrap();
})
};
let task_b = {
let limiter = Arc::clone(&limiter);
let tx = tx.clone();
let raii = limiter.lock("subject2").unwrap();
tokio::spawn(async move {
limiter
.lock("subject2", async {
tokio::time::sleep(Duration::from_millis(100)).await;
tx.send(2).await.unwrap();
})
.await
let _raii_binding = raii;
tokio::time::sleep(Duration::from_millis(100)).await;
tx.send(2).await.unwrap();
})
};
let (a, b) = tokio::try_join!(task_a, task_b).unwrap();
assert!(a.is_some());
assert!(b.is_some());
tokio::try_join!(task_a, task_b).unwrap();
let actual = [rx.recv().await, rx.recv().await, rx.recv().await];
let expected = [Some(1), Some(2), Some(3)];
assert_eq!(actual, expected);
Expand Down
Loading

0 comments on commit 4d2f238

Please sign in to comment.