feat: adding Moka caching for single-key calls
geekbrother committed Mar 25, 2024
1 parent b0889c1 commit c9bde58
Showing 2 changed files with 94 additions and 16 deletions.
1 change: 1 addition & 0 deletions crates/rate_limit/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
chrono = { version = "0.4", features = ["serde"] }
deadpool-redis = "0.12"
moka = { version = "0.12", features = ["future"] }
redis = { version = "0.23", default-features = false, features = ["script"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
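The new `future` feature flag is what provides the async `moka::future::Cache` imported in src/lib.rs below. A minimal sketch of that cache on its own (the 100 ms TTL mirrors the tests' refill interval; the tokio runtime and values are illustrative):

use {moka::future::Cache, std::time::Duration};

#[tokio::main]
async fn main() {
    // Entries are evicted a fixed time after insertion, so a cached
    // rate-limited key clears itself once the bucket has had time to refill.
    let cache: Cache<String, u64> = Cache::builder()
        .time_to_live(Duration::from_millis(100))
        .build();

    cache.insert("key".to_string(), 42).await;
    assert_eq!(cache.get("key").await, Some(42));
}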
109 changes: 93 additions & 16 deletions crates/rate_limit/src/lib.rs
@@ -2,6 +2,7 @@ use {
chrono::{DateTime, Duration, Utc},
core::fmt,
deadpool_redis::{Pool, PoolError},
moka::future::Cache,
redis::{RedisError, Script},
std::{collections::HashMap, sync::Arc},
};
@@ -35,13 +36,25 @@ pub enum RateLimitError {
Internal(InternalRateLimitError),
}

/// Rate limit check using a token bucket algorithm for a single key, with an
/// in-memory cache of rate-limited keys. The `mem_cache` TTL must be set to
/// the same value as the refill interval.
pub async fn token_bucket(
mem_cache: &Cache<String, u64>,
redis_write_pool: &Arc<Pool>,
key: String,
max_tokens: u32,
interval: Duration,
refill_rate: u32,
) -> Result<(), RateLimitError> {
// Check if the key is in the in-memory cache of rate-limited keys
// to skip the Redis round trip in case of a flood
if let Some(reset) = mem_cache.get(&key).await {
return Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
reset,
}));
}

let result = token_bucket_many(
redis_write_pool,
vec![key.clone()],
@@ -54,14 +67,21 @@

let (remaining, reset) = result.get(&key).expect("Should contain the key");
if remaining.is_negative() {
let reset_interval = reset / 1000;

// Insert the rate-limited key into the in-memory cache to skip the Redis
// round trip in case of a flood
mem_cache.insert(key, reset_interval).await;

Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
-reset: reset / 1000,
reset: reset_interval,
}))
} else {
Ok(())
}
}

/// Rate limit check using a token bucket algorithm for many keys.
pub async fn token_bucket_many(
redis_write_pool: &Arc<Pool>,
keys: Vec<String>,
@@ -95,6 +115,8 @@ pub async fn token_bucket_many(
#[cfg(test)]
mod tests {
const REDIS_URI: &str = "redis://localhost:6379";
const REFILL_INTERVAL_MILLIS: i64 = 100;

use {
super::*,
deadpool_redis::{Config, Runtime},
@@ -111,16 +133,16 @@
}

#[tokio::test]
-async fn test_token_bucket() {
async fn test_token_bucket_many() {
let cfg = Config::from_url(REDIS_URI);
let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
let key = "test_token_bucket".to_string();
let key = "token_bucket_many_test_key".to_string();

// Before running the test, ensure the test keys are cleared
redis_clear_keys(REDIS_URI, &[key.clone()]).await;

let max_tokens = 10;
-let refill_interval = chrono::Duration::try_milliseconds(100).unwrap();
let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
let refill_rate = 1;
let rate_limit = || async {
token_bucket_many(
@@ -136,23 +158,78 @@
.unwrap()
.to_owned()
};
let call_rate_limit_loop = || async {
for i in 0..=max_tokens {
let curr_iter = max_tokens as i64 - i as i64 - 1;
let result = rate_limit().await;
assert_eq!(result.0, curr_iter);
}
};

-// Iterate over the max tokens
-for i in 0..=max_tokens {
-let curr_iter = max_tokens as i64 - i as i64 - 1;
-let result = rate_limit().await;
-assert_eq!(result.0, curr_iter);
-}
// Call rate limit until the max tokens limit is reached
call_rate_limit_loop().await;

// Sleep for refill and try again
-// Tokens numbers should be the same as the previous iteration
// Token numbers should be the same as the previous iteration because
// they were refilled
sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await;
call_rate_limit_loop().await;

-for i in 0..=max_tokens {
-let curr_iter = max_tokens as i64 - i as i64 - 1;
-let result = rate_limit().await;
-assert_eq!(result.0, curr_iter);
-}
// Clear keys after the test
redis_clear_keys(REDIS_URI, &[key.clone()]).await;
}

#[tokio::test]
async fn test_token_bucket() {
// Create a Moka cache with a TTL equal to the refill interval
let cache: Cache<String, u64> = Cache::builder()
.time_to_live(std::time::Duration::from_millis(
REFILL_INTERVAL_MILLIS as u64,
))
.build();

let cfg = Config::from_url(REDIS_URI);
let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
let key = "token_bucket_test_key".to_string();

// Before running the test, ensure the test keys are cleared
redis_clear_keys(REDIS_URI, &[key.clone()]).await;

let max_tokens = 10;
let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
let refill_rate = 1;
let rate_limit = || async {
token_bucket(
&cache,
&pool,
key.clone(),
max_tokens,
refill_interval,
refill_rate,
)
.await
};
let call_rate_limit_loop = || async {
for i in 0..=max_tokens {
let result = rate_limit().await;
if i == max_tokens {
assert!(result
.err()
.unwrap()
.to_string()
.contains("Rate limit exceeded"));
} else {
assert!(result.is_ok());
}
}
};

// Call rate limit until the max tokens limit is reached
call_rate_limit_loop().await;

// Sleep for refill and try again
sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await;
call_rate_limit_loop().await;

// Clear keys after the test
redis_clear_keys(REDIS_URI, &[key.clone()]).await;

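For context, a wiring sketch of the new single-key path, following the doc-comment invariant that the cache TTL equals the refill interval. Assumptions are flagged: the crate is consumed as `rate_limit`, the `reset` field of `RateLimitExceeded` is public, and the Redis URL, key, and bucket parameters are illustrative:

use {
    deadpool_redis::{Config, Runtime},
    moka::future::Cache,
    rate_limit::{token_bucket, token_bucket_many, RateLimitError},
    std::sync::Arc,
};

#[tokio::main]
async fn main() {
    let refill_interval = chrono::Duration::try_milliseconds(100).unwrap();

    // The TTL must equal the refill interval: by the time the bucket has
    // refilled, the cached rate-limited entry must have expired too, otherwise
    // the cache would keep rejecting a key whose tokens are already back.
    let mem_cache: Cache<String, u64> = Cache::builder()
        .time_to_live(refill_interval.to_std().unwrap())
        .build();

    let pool = Arc::new(
        Config::from_url("redis://localhost:6379")
            .create_pool(Some(Runtime::Tokio1))
            .unwrap(),
    );

    // Single-key path: consults the in-memory cache before touching Redis.
    let result = token_bucket(
        &mem_cache,
        &pool,
        "client:1".to_string(), // illustrative key
        10,                     // max_tokens
        refill_interval,        // interval
        1,                      // refill_rate
    )
    .await;
    match result {
        Ok(()) => { /* request allowed */ }
        Err(RateLimitError::RateLimitExceeded(e)) => {
            // `reset` is the time in seconds (per the `reset / 1000` above)
            // at which the key may be retried.
            eprintln!("rate limited, retry at {}", e.reset);
        }
        Err(e) => eprintln!("rate limiter error: {e}"),
    }

    // Many-keys path, unchanged by this commit: one call per batch of keys.
    let _statuses = token_bucket_many(&pool, vec!["client:1".to_string()], 10, refill_interval, 1)
        .await
        .unwrap();
}

Note the deliberate check order in `token_bucket`: the in-memory cache is consulted first, so during a flood on an already-limited key the Redis round trip is skipped entirely.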