diff --git a/crates/rate_limit/Cargo.toml b/crates/rate_limit/Cargo.toml
index 8b4a7db..1171b76 100644
--- a/crates/rate_limit/Cargo.toml
+++ b/crates/rate_limit/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2021"
 [dependencies]
 chrono = { version = "0.4", features = ["serde"] }
 deadpool-redis = "0.12"
+moka = { version = "0.12", features = ["future"] }
 redis = { version = "0.23", default-features = false, features = ["script"] }
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
diff --git a/crates/rate_limit/src/lib.rs b/crates/rate_limit/src/lib.rs
index 288e7e1..b0dfed9 100644
--- a/crates/rate_limit/src/lib.rs
+++ b/crates/rate_limit/src/lib.rs
@@ -2,6 +2,7 @@ use {
     chrono::{DateTime, Duration, Utc},
     core::fmt,
     deadpool_redis::{Pool, PoolError},
+    moka::future::Cache,
     redis::{RedisError, Script},
     std::{collections::HashMap, sync::Arc},
 };
@@ -35,13 +36,25 @@ pub enum RateLimitError {
     Internal(InternalRateLimitError),
 }
 
+/// Rate limit check using a token bucket algorithm for one key and in-memory
+/// cache for rate-limited keys. `mem_cache` TTL must be set to the same value
+/// as the refill interval.
 pub async fn token_bucket(
+    mem_cache: &Cache<String, u64>,
     redis_write_pool: &Arc<Pool>,
     key: String,
     max_tokens: u32,
     interval: Duration,
     refill_rate: u32,
 ) -> Result<(), RateLimitError> {
+    // Check if the key is in the memory cache of rate limited keys
+    // to omit the redis RTT in case of flood
+    if let Some(reset) = mem_cache.get(&key).await {
+        return Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
+            reset,
+        }));
+    }
+
     let result = token_bucket_many(
         redis_write_pool,
         vec![key.clone()],
@@ -54,14 +67,21 @@ pub async fn token_bucket(
 
     let (remaining, reset) = result.get(&key).expect("Should contain the key");
     if remaining.is_negative() {
+        let reset_interval = reset / 1000;
+
+        // Insert the rate-limited key into the memory cache to avoid the redis RTT in
+        // case of flood
+        mem_cache.insert(key, reset_interval).await;
+
         Err(RateLimitError::RateLimitExceeded(RateLimitExceeded {
-            reset: reset / 1000,
+            reset: reset_interval,
         }))
     } else {
         Ok(())
     }
 }
 
+/// Rate limit check using a token bucket algorithm for many keys.
 pub async fn token_bucket_many(
     redis_write_pool: &Arc<Pool>,
     keys: Vec<String>,
@@ -95,6 +115,8 @@ pub async fn token_bucket_many(
 #[cfg(test)]
 mod tests {
     const REDIS_URI: &str = "redis://localhost:6379";
+    const REFILL_INTERVAL_MILLIS: i64 = 100;
+
     use {
         super::*,
         deadpool_redis::{Config, Runtime},
@@ -111,16 +133,16 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_token_bucket() {
+    async fn test_token_bucket_many() {
         let cfg = Config::from_url(REDIS_URI);
         let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
-        let key = "test_token_bucket".to_string();
+        let key = "token_bucket_many_test_key".to_string();
 
         // Before running the test, ensure the test keys are cleared
         redis_clear_keys(REDIS_URI, &[key.clone()]).await;
 
         let max_tokens = 10;
-        let refill_interval = chrono::Duration::try_milliseconds(100).unwrap();
+        let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
         let refill_rate = 1;
         let rate_limit = || async {
             token_bucket_many(
@@ -136,23 +158,78 @@ mod tests {
             .unwrap()
             .to_owned()
         };
+        let call_rate_limit_loop = || async {
+            for i in 0..=max_tokens {
+                let curr_iter = max_tokens as i64 - i as i64 - 1;
+                let result = rate_limit().await;
+                assert_eq!(result.0, curr_iter);
+            }
+        };
 
-        // Iterate over the max tokens
-        for i in 0..=max_tokens {
-            let curr_iter = max_tokens as i64 - i as i64 - 1;
-            let result = rate_limit().await;
-            assert_eq!(result.0, curr_iter);
-        }
+        // Call rate limit until max tokens limit is reached
+        call_rate_limit_loop().await;
 
         // Sleep for refill and try again
-        // Tokens numbers should be the same as the previous iteration
+        // Tokens numbers should be the same as the previous iteration because
+        // they were refilled
         sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await;
+        call_rate_limit_loop().await;
 
-        for i in 0..=max_tokens {
-            let curr_iter = max_tokens as i64 - i as i64 - 1;
-            let result = rate_limit().await;
-            assert_eq!(result.0, curr_iter);
-        }
+        // Clear keys after the test
+        redis_clear_keys(REDIS_URI, &[key.clone()]).await;
+    }
+
+    #[tokio::test]
+    async fn test_token_bucket() {
+        // Create Moka cache with a TTL of the refill interval
+        let cache: Cache<String, u64> = Cache::builder()
+            .time_to_live(std::time::Duration::from_millis(
+                REFILL_INTERVAL_MILLIS as u64,
+            ))
+            .build();
+
+        let cfg = Config::from_url(REDIS_URI);
+        let pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1)).unwrap());
+        let key = "token_bucket_test_key".to_string();
+
+        // Before running the test, ensure the test keys are cleared
+        redis_clear_keys(REDIS_URI, &[key.clone()]).await;
+
+        let max_tokens = 10;
+        let refill_interval = chrono::Duration::try_milliseconds(REFILL_INTERVAL_MILLIS).unwrap();
+        let refill_rate = 1;
+        let rate_limit = || async {
+            token_bucket(
+                &cache,
+                &pool,
+                key.clone(),
+                max_tokens,
+                refill_interval,
+                refill_rate,
+            )
+            .await
+        };
+        let call_rate_limit_loop = || async {
+            for i in 0..=max_tokens {
+                let result = rate_limit().await;
+                if i == max_tokens {
+                    assert!(result
+                        .err()
+                        .unwrap()
+                        .to_string()
+                        .contains("Rate limit exceeded"));
+                } else {
+                    assert!(result.is_ok());
+                }
+            }
+        };
+
+        // Call rate limit until max tokens limit is reached
+        call_rate_limit_loop().await;
+
+        // Sleep for refill and try again
+        sleep((refill_interval * max_tokens as i32).to_std().unwrap()).await;
+        call_rate_limit_loop().await;
 
         // Clear keys after the test
         redis_clear_keys(REDIS_URI, &[key.clone()]).await;
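
For context, a minimal caller-side sketch of the new `token_bucket` entry point, assuming the crate is consumed as `rate_limit` and honoring the invariant stated in the new doc comment (the moka cache's `time_to_live` must equal the refill interval). The Redis URL, key name, and limits below are illustrative, not part of this change:

```rust
use {
    chrono::Duration,
    deadpool_redis::{Config, Runtime},
    moka::future::Cache,
    rate_limit::token_bucket,
    std::sync::Arc,
};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Refill interval for the bucket; the cache TTL must match it so that
    // locally cached "rate limited" entries expire exactly when tokens refill.
    let refill_interval = Duration::try_milliseconds(100).unwrap();

    let mem_cache: Cache<String, u64> = Cache::builder()
        .time_to_live(refill_interval.to_std()?)
        .build();

    // Redis pool, same setup as in the tests (URL is illustrative).
    let cfg = Config::from_url("redis://localhost:6379");
    let redis_pool = Arc::new(cfg.create_pool(Some(Runtime::Tokio1))?);

    // One check: Ok(()) while tokens remain, RateLimitExceeded (carrying a
    // reset timestamp) once the bucket for this key is empty.
    match token_bucket(
        &mem_cache,
        &redis_pool,
        "client:42".to_string(), // hypothetical key
        10,                      // max_tokens
        refill_interval,         // interval
        1,                       // refill_rate
    )
    .await
    {
        Ok(()) => println!("allowed"),
        Err(e) => println!("rejected: {e}"),
    }

    Ok(())
}
```

With this wiring, a flooded key pays the Redis round trip at most once per refill interval: subsequent checks hit the in-memory cache and are rejected locally until the cached entry expires.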