From 7cc928f0ec921b1772d7e0801063572f5c0d514f Mon Sep 17 00:00:00 2001 From: onur-ozkan Date: Mon, 7 Oct 2024 11:42:22 +0300 Subject: [PATCH] sync upstream expirable map impl Signed-off-by: onur-ozkan --- src/expirable_map.rs | 146 +++++++++++++++++++++++++----------------- src/proxy/http/mod.rs | 24 +++++-- 2 files changed, 107 insertions(+), 63 deletions(-) diff --git a/src/expirable_map.rs b/src/expirable_map.rs index 33db871..35ef3da 100644 --- a/src/expirable_map.rs +++ b/src/expirable_map.rs @@ -1,10 +1,13 @@ -//! Provides a map that associates values with keys and supports expiring entries. +//! This module provides a cross-compatible map that associates values with keys and supports expiring entries. //! //! Designed for performance-oriented use-cases utilizing `FxHashMap` under the hood, //! and is not suitable for cryptographic purposes. +#![allow(dead_code)] + use rustc_hash::FxHashMap; use std::{ + collections::BTreeMap, hash::Hash, time::{Duration, Instant}, }; @@ -16,44 +19,85 @@ pub struct ExpirableEntry { } impl ExpirableEntry { - #[allow(dead_code)] + #[inline(always)] + pub fn new(v: V, exp: Duration) -> Self { + Self { + expires_at: Instant::now() + exp, + value: v, + } + } + + #[inline(always)] pub fn get_element(&self) -> &V { &self.value } - #[allow(dead_code)] + #[inline(always)] + pub fn update_value(&mut self, v: V) { + self.value = v + } + + #[inline(always)] pub fn update_expiration(&mut self, expires_at: Instant) { self.expires_at = expires_at } + + /// Checks whether entry has longer ttl than the given one. + #[inline(always)] + pub fn has_longer_life_than(&self, min_ttl: Duration) -> bool { + self.expires_at > Instant::now() + min_ttl + } } -impl Default for ExpirableMap { +impl Default for ExpirableMap { fn default() -> Self { Self::new() } } /// A map that allows associating values with keys and expiring entries. -/// It is important to note that this implementation does not automatically -/// remove any entries; it is the caller's responsibility to invoke `clear_expired_entries` -/// at specified intervals. +/// It is important to note that this implementation does not have a background worker to +/// automatically clear expired entries. Outdated entries are only removed when the control flow +/// is handed back to the map mutably (i.e. some mutable method of the map is invoked). /// /// WARNING: This is designed for performance-oriented use-cases utilizing `FxHashMap` /// under the hood and is not suitable for cryptographic purposes. #[derive(Clone, Debug)] -pub struct ExpirableMap(FxHashMap>); +pub struct ExpirableMap { + map: FxHashMap>, + /// A sorted inverse map from expiration times to keys to speed up expired entries clearing. + expiries: BTreeMap, +} -impl ExpirableMap { +impl ExpirableMap { /// Creates a new empty `ExpirableMap` #[inline] pub fn new() -> Self { - Self(FxHashMap::default()) + Self { + map: FxHashMap::default(), + expiries: BTreeMap::new(), + } + } + + /// Returns the associated value if present and not expired. + #[inline] + pub fn get(&self, k: &K) -> Option<&V> { + self.map + .get(k) + .filter(|v| v.expires_at > Instant::now()) + .map(|v| &v.value) } - /// Returns the associated value if present. + /// Removes a key-value pair from the map and returns the associated value if present and not expired. #[inline] - pub fn get(&mut self, k: &K) -> Option<&V> { - self.0.get(k).map(|v| &v.value) + pub fn remove(&mut self, k: &K) -> Option { + self.map + .remove(k) + .filter(|v| v.expires_at > Instant::now()) + .map(|v| { + self.expiries.remove(&v.expires_at); + v.value + }) } /// Inserts a key-value pair with an expiration duration. @@ -61,24 +105,31 @@ impl ExpirableMap { /// If a value already exists for the given key, it will be updated and then /// the old one will be returned. pub fn insert(&mut self, k: K, v: V, exp: Duration) -> Option { - let entry = ExpirableEntry { - expires_at: Instant::now() + exp, - value: v, - }; - - self.0.insert(k, entry).map(|v| v.value) + self.clear_expired_entries(); + let entry = ExpirableEntry::new(v, exp); + self.expiries.insert(entry.expires_at, k); + self.map.insert(k, entry).map(|v| v.value) } /// Removes expired entries from the map. - pub fn clear_expired_entries(&mut self) { - self.0.retain(|_k, v| Instant::now() < v.expires_at); - } - - /// Removes a key-value pair from the map and returns the associated value if present. - #[inline] - #[allow(dead_code)] - pub fn remove(&mut self, k: &K) -> Option { - self.0.remove(k).map(|v| v.value) + /// + /// Iterates through the `expiries` in order, removing entries that have expired. + /// Stops at the first non-expired entry, leveraging the sorted nature of `BTreeMap`. + fn clear_expired_entries(&mut self) { + let now = Instant::now(); + + // `pop_first()` is used here as it efficiently removes expired entries. + // `first_key_value()` was considered as it wouldn't need re-insertion for + // non-expired entries, but it would require an extra remove operation for + // each expired entry. `pop_first()` needs only one re-insertion per call, + // which is an acceptable trade-off compared to multiple remove operations. + while let Some((exp, key)) = self.expiries.pop_first() { + if exp > now { + self.expiries.insert(exp, key); + break; + } + self.map.remove(&key); + } } } @@ -86,15 +137,14 @@ impl ExpirableMap { mod tests { use super::*; - #[tokio::test] async fn test_clear_expired_entries() { let mut expirable_map = ExpirableMap::new(); let value = "test_value"; let exp = Duration::from_secs(1); // Insert 2 entries with 1 sec expiration time - expirable_map.insert("key1".to_string(), value.to_string(), exp); - expirable_map.insert("key2".to_string(), value.to_string(), exp); + expirable_map.insert("key1", value, exp); + expirable_map.insert("key2", value, exp); // Wait for entries to expire tokio::time::sleep(Duration::from_secs(2)).await; @@ -103,34 +153,14 @@ mod tests { expirable_map.clear_expired_entries(); // We waited for 2 seconds, so we shouldn't have any entry accessible - assert_eq!(expirable_map.0.len(), 0); + assert_eq!(expirable_map.map.len(), 0); // Insert 5 entries - expirable_map.insert( - "key1".to_string(), - value.to_string(), - Duration::from_secs(5), - ); - expirable_map.insert( - "key2".to_string(), - value.to_string(), - Duration::from_secs(4), - ); - expirable_map.insert( - "key3".to_string(), - value.to_string(), - Duration::from_secs(7), - ); - expirable_map.insert( - "key4".to_string(), - value.to_string(), - Duration::from_secs(2), - ); - expirable_map.insert( - "key5".to_string(), - value.to_string(), - Duration::from_millis(3750), - ); + expirable_map.insert("key1", value, Duration::from_secs(5)); + expirable_map.insert("key2", value, Duration::from_secs(4)); + expirable_map.insert("key3", value, Duration::from_secs(7)); + expirable_map.insert("key4", value, Duration::from_secs(2)); + expirable_map.insert("key5", value, Duration::from_millis(3750)); // Wait 2 seconds to expire some entries tokio::time::sleep(Duration::from_secs(2)).await; @@ -139,6 +169,6 @@ mod tests { expirable_map.clear_expired_entries(); // We waited for 2 seconds, only one entry should expire - assert_eq!(expirable_map.0.len(), 4); + assert_eq!(expirable_map.map.len(), 4); } } diff --git a/src/proxy/http/mod.rs b/src/proxy/http/mod.rs index 2527933..dd63a59 100644 --- a/src/proxy/http/mod.rs +++ b/src/proxy/http/mod.rs @@ -1,6 +1,7 @@ use hyper::{StatusCode, Uri}; +use libp2p::PeerId; use proxy_signature::ProxySign; -use std::{net::SocketAddr, sync::LazyLock, time::Duration}; +use std::{net::SocketAddr, str::FromStr, sync::LazyLock, time::Duration}; use tokio::sync::Mutex; use crate::{ @@ -108,19 +109,32 @@ async fn peer_connection_healthcheck( // for 10 seconds without asking again. let know_peer_expiration = Duration::from_secs(cfg.peer_healthcheck_caching_secs); - static KNOWN_PEERS: LazyLock>> = + static KNOWN_PEERS: LazyLock>> = LazyLock::new(|| Mutex::new(ExpirableMap::new())); let mut know_peers = KNOWN_PEERS.lock().await; - know_peers.clear_expired_entries(); - let is_known = know_peers.get(&signed_message.address).is_some(); + let Ok(peer_id) = PeerId::from_str(&signed_message.address) else { + tracked_log( + log::Level::Warn, + remote_addr.ip(), + &signed_message.address, + req_uri, + format!( + "Peer id '{}' isn't valid, returning 401", + signed_message.address + ), + ); + return Err(StatusCode::UNAUTHORIZED); + }; + + let is_known = know_peers.get(&peer_id).is_some(); if !is_known { match peer_connection_healthcheck_rpc(cfg, &signed_message.address).await { Ok(response) => { if response["result"] == serde_json::json!(true) { - know_peers.insert(signed_message.address.clone(), (), know_peer_expiration); + know_peers.insert(peer_id, (), know_peer_expiration); } else { tracked_log( log::Level::Warn,