diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs
index 955d24d9e9..7de66f59cf 100644
--- a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs
+++ b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs
@@ -2,7 +2,7 @@ use crate::cluster_async::ConnectionFuture;
 use crate::cluster_routing::{Route, ShardAddrs, SlotAddr};
 use crate::cluster_slotmap::{ReadFromReplicaStrategy, SlotMap, SlotMapValue};
 use crate::cluster_topology::TopologyHash;
-use dashmap::DashMap;
+use dashmap::{DashMap, DashSet};
 use futures::FutureExt;
 use rand::seq::IteratorRandom;
 use std::net::IpAddr;
@@ -10,6 +10,8 @@ use std::sync::atomic::Ordering;
 use std::sync::Arc;
 use telemetrylib::Telemetry;
 
+use tokio::task::JoinHandle;
+
 /// Count the number of connections in a connections_map object
 macro_rules! count_connections {
     ($conn_map:expr) => {{
@@ -121,6 +123,12 @@ pub(crate) enum ConnectionType {
 
 pub(crate) struct ConnectionsMap<Connection>(pub(crate) DashMap<String, ClusterNode<Connection>>);
 
+pub(crate) struct RefreshState<Connection> {
+    pub handle: JoinHandle<()>, // The currently running refresh task
+    pub node_conn: Option<ClusterNode<Connection>>, // The refreshed connection after the task is done
+}
+
 impl<Connection> std::fmt::Display for ConnectionsMap<Connection> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         for item in self.0.iter() {
@@ -139,6 +147,14 @@ pub(crate) struct ConnectionsContainer<Connection> {
     pub(crate) slot_map: SlotMap,
     read_from_replica_strategy: ReadFromReplicaStrategy,
     topology_hash: TopologyHash,
+
+    // Holds all the failed addresses that started a refresh task.
+    pub(crate) refresh_addresses_started: DashSet<String>,
+    // Follows the refresh operations on the connections.
+    pub(crate) refresh_operations: DashMap<String, RefreshState<Connection>>,
+    // Holds all the refreshed addresses that are ready to be inserted into the connection_map.
+    pub(crate) refresh_addresses_done: DashSet<String>,
 }
 
 impl<Connection> Drop for ConnectionsContainer<Connection> {
@@ -155,6 +171,9 @@ impl<Connection> Default for ConnectionsContainer<Connection> {
             slot_map: Default::default(),
             read_from_replica_strategy: ReadFromReplicaStrategy::AlwaysFromPrimary,
             topology_hash: 0,
+            refresh_addresses_started: DashSet::new(),
+            refresh_operations: DashMap::new(),
+            refresh_addresses_done: DashSet::new(),
         }
     }
 }
@@ -182,6 +201,9 @@ where
             slot_map,
             read_from_replica_strategy,
             topology_hash,
+            refresh_addresses_started: DashSet::new(),
+            refresh_operations: DashMap::new(),
+            refresh_addresses_done: DashSet::new(),
         }
     }
 
@@ -572,6 +594,9 @@ mod tests {
             connection_map,
             read_from_replica_strategy: ReadFromReplicaStrategy::AZAffinity("use-1a".to_string()),
             topology_hash: 0,
+            refresh_addresses_started: DashSet::new(),
+            refresh_operations: DashMap::new(),
+            refresh_addresses_done: DashSet::new(),
         }
     }
 
@@ -628,6 +653,9 @@ mod tests {
             connection_map,
             read_from_replica_strategy: strategy,
             topology_hash: 0,
+            refresh_addresses_started: DashSet::new(),
+            refresh_operations: DashMap::new(),
+            refresh_addresses_done: DashSet::new(),
         }
     }
diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs
index 534fdd429e..679098ef02 100644
--- a/glide-core/redis-rs/redis/src/cluster_async/mod.rs
+++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs
@@ -40,18 +40,13 @@ use crate::{
     commands::cluster_scan::{cluster_scan, ClusterScanArgs, ScanStateRC},
     FromRedisValue, InfoDict,
 };
+use connections_container::RefreshState;
 use dashmap::DashMap;
 use std::{
-    collections::{HashMap, HashSet},
-    fmt, io, mem,
-    net::{IpAddr, SocketAddr},
-    pin::Pin,
-    sync::{
+    collections::{HashMap, HashSet}, fmt, io, iter::once, mem, net::{IpAddr, SocketAddr}, pin::Pin, sync::{
         atomic::{self, AtomicUsize, Ordering},
         Arc, Mutex,
-    },
-    task::{self, Poll},
-    time::SystemTime,
+    }, task::{self, Poll}, time::SystemTime
 };
 use strum_macros::Display;
 #[cfg(feature = "tokio-comp")]
@@ -1341,13 +1336,13 @@ where
         }
 
         // identify nodes with closed connection
-        let mut addrs_to_refresh = Vec::new();
+        let mut addrs_to_refresh = HashSet::new();
         for (addr, con_fut) in &all_valid_conns {
             let con = con_fut.clone().await;
             // connection object might be present despite the transport being closed
             if con.is_closed() {
                 // transport is closed, need to refresh
-                addrs_to_refresh.push(addr.clone());
+                addrs_to_refresh.insert(addr.clone());
             }
         }
 
@@ -1365,68 +1360,95 @@ where
                 inner.clone(),
                 addrs_to_refresh,
                 RefreshConnectionType::AllConnections,
-                false,
-            )
-            .await;
+            ).await;
         }
     }
 
     async fn refresh_connections(
         inner: Arc<InnerCore<C>>,
-        addresses: Vec<String>,
+        addresses: HashSet<String>,
         conn_type: RefreshConnectionType,
-        check_existing_conn: bool,
     ) {
         info!("Started refreshing connections to {:?}", addresses);
-        let mut tasks = FuturesUnordered::new();
-        let inner = inner.clone();
 
-        for address in addresses.into_iter() {
-            let inner = inner.clone();
+        let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
+        let refresh_ops_map = &connections_container.refresh_operations;
 
-            tasks.push(async move {
-                let node_option = if check_existing_conn {
-                    let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
-                    connections_container.remove_node(&address)
-                } else {
-                    None
-                };
+        for address in addresses {
+            if refresh_ops_map.contains_key(&address) {
+                info!("Skipping refresh for {}: already in progress", address);
+                continue;
+            }
+
+            let inner_clone = inner.clone();
+            let address_clone = address.clone();
+            let address_clone_for_task = address.clone();
+
+            let handle = tokio::spawn(async move {
+                info!("Refreshing connection task to {:?} started", address_clone_for_task);
+                let _ = async {
+                    // Add this address to be removed in poll_flush so all requests see a consistent connection map.
+                    // See the next comment for an elaborated explanation.
+                    inner_clone.conn_lock.read().expect(MUTEX_READ_ERR).refresh_addresses_started.insert(address_clone_for_task.clone());
+
+                    let mut cluster_params = inner_clone.cluster_params.read().expect(MUTEX_READ_ERR).clone();
+                    let subs_guard = inner_clone.subscriptions_by_address.read().await;
+                    cluster_params.pubsub_subscriptions = subs_guard.get(&address_clone_for_task).cloned();
+                    drop(subs_guard);
 
-                // Override subscriptions for this connection
-                let mut cluster_params = inner.cluster_params.read().expect(MUTEX_READ_ERR).clone();
-                let subs_guard = inner.subscriptions_by_address.read().await;
-                cluster_params.pubsub_subscriptions = subs_guard.get(&address).cloned();
-                drop(subs_guard);
-
-                let node = get_or_create_conn(
-                    &address,
-                    node_option,
-                    &cluster_params,
-                    conn_type,
-                    inner.glide_connection_options.clone(),
-                )
-                .await;
+                    let node_result = get_or_create_conn(
+                        &address_clone_for_task,
+                        None,
+                        &cluster_params,
+                        conn_type,
+                        inner_clone.glide_connection_options.clone(),
+                    )
+                    .await;
 
-                (address, node)
+                    match node_result {
+                        Ok(node) => {
+                            // Maintain the newly refreshed connection separately from the main connection map.
+                            // This refreshed connection will be incorporated into the main connection map at the start of the poll_flush operation.
+                            // This approach ensures that all requests within the current batch interact with a consistent connection map,
+                            // preventing potential reordering issues.
+                            //
+                            // By delaying the integration of the refreshed connection:
+                            //
+                            // 1. We maintain consistency throughout the processing of a batch of requests.
+                            // 2. We avoid mid-batch changes to the connection map that could lead to inconsistent routing or ordering of operations.
+                            // 3. We ensure that all requests in a batch see the same cluster topology, reducing the risk of race conditions or unexpected behavior.
+                            //
+                            // This strategy effectively creates a synchronization point at the beginning of poll_flush, where the connection map is
+                            // updated atomically for the next batch of operations. This approach balances the need for up-to-date connection information
+                            // with the requirement for consistent request handling within each processing cycle.
+                            let connection_container = inner_clone.conn_lock.read().expect(MUTEX_READ_ERR);
+                            if let Some(mut refresh_state) = connection_container.refresh_operations.get_mut(&address_clone_for_task) {
+                                refresh_state.node_conn = Some(node);
+                            }
+                            connection_container.refresh_addresses_done.insert(address_clone_for_task);
+                            Ok(())
+                        }
+                        Err(err) => {
+                            warn!(
+                                "Failed to refresh connection for node {}. Error: `{:?}`",
+                                address_clone_for_task, err
+                            );
+                            Err(err)
+                        }
+                    }
+                }
+                .await;
+
+                info!("Refreshing connection task to {:?} is done", address_clone);
             });
-        }
 
-        // Poll connection tasks as soon as each one finishes
-        while let Some(result) = tasks.next().await {
-            match result {
-                (address, Ok(node)) => {
-                    let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
-                    connections_container.replace_or_add_connection_for_address(address, node);
-                }
-                (address, Err(err)) => {
-                    warn!(
-                        "Failed to refresh connection for node {}. Error: `{:?}`",
-                        address, err
-                    );
-                }
-            }
+            // Keep the task handle in the RefreshState of this address
+            info!("Inserting tokio task to refresh_ops map of address {:?}", address);
+            refresh_ops_map.insert(address, RefreshState {
+                handle,
+                node_conn: None,
+            });
         }
-        debug!("refresh connections completed");
+        debug!("refresh connection tasks initiated");
     }
 
     async fn aggregate_results(
@@ -1762,11 +1784,9 @@ where
 
             // immediately trigger connection reestablishment
             Self::refresh_connections(
                 inner.clone(),
-                addrs_to_refresh.into_iter().collect(),
+                addrs_to_refresh,
                 RefreshConnectionType::AllConnections,
-                false,
-            )
-            .await;
+            ).await;
         }
     }
 
@@ -1798,9 +1818,8 @@ where
         if !failed_connections.is_empty() {
             Self::refresh_connections(
                 inner,
-                failed_connections,
+                failed_connections.into_iter().collect::<HashSet<String>>(),
                 RefreshConnectionType::OnlyManagementConnection,
-                true,
             )
             .await;
         }
@@ -2271,32 +2290,12 @@ where
         let (address, mut conn) = match conn_check {
             ConnectionCheck::Found((address, connection)) => (address, connection.await),
             ConnectionCheck::OnlyAddress(addr) => {
-                let mut this_conn_params = core.get_cluster_param(|params| params.clone())?;
-                let subs_guard = core.subscriptions_by_address.read().await;
-                this_conn_params.pubsub_subscriptions = subs_guard.get(addr.as_str()).cloned();
-                drop(subs_guard);
-                match connect_and_check::<C>(
-                    &addr,
-                    this_conn_params,
-                    None,
-                    RefreshConnectionType::AllConnections,
-                    None,
-                    core.glide_connection_options.clone(),
-                )
-                .await
-                .get_node()
-                {
-                    Ok(node) => {
-                        let connection_clone = node.user_connection.conn.clone().await;
-                        let connections = core.conn_lock.read().expect(MUTEX_READ_ERR);
-                        let address = connections.replace_or_add_connection_for_address(addr, node);
-                        drop(connections);
-                        (address, connection_clone)
-                    }
-                    Err(err) => {
-                        return Err(err);
-                    }
-                }
+                // No connection found for this address in the conn_map
+                Self::refresh_connections(core, HashSet::from_iter(once(addr)), RefreshConnectionType::AllConnections).await;
+                return Err(RedisError::from((
+                    ErrorKind::AllConnectionsUnavailable,
+                    "No connection for the address, started a refresh task",
+                )));
             }
             ConnectionCheck::RandomConnection => {
                 let random_conn = core
@@ -2393,6 +2392,47 @@ where
         Self::try_request(info, core).await
     }
 
+    fn update_refreshed_connection(&mut self) {
+        loop {
+            let connections_container = self.inner.conn_lock.read().expect(MUTEX_READ_ERR);
+
+            // Process refresh_addresses_started
+            let addresses_to_remove: Vec<String> = connections_container.refresh_addresses_started.iter().map(|r| r.key().clone()).collect();
+            for address in addresses_to_remove {
+                connections_container.refresh_addresses_started.remove(&address);
+                connections_container.remove_node(&address);
+            }
+
+            // Process refresh_addresses_done
+            let addresses_done: Vec<String> = connections_container.refresh_addresses_done.iter().map(|r| r.key().clone()).collect();
+            for address in addresses_done {
+                connections_container.refresh_addresses_done.remove(&address);
+
+                if let Some(mut refresh_state) = connections_container.refresh_operations.get_mut(&address) {
+                    info!("update_refreshed_connection: Update conn for addr: {}", address);
+
+                    // Take the node_conn out of RefreshState, replacing it with None
+                    if let Some(node_conn) = mem::take(&mut refresh_state.node_conn) {
+                        info!("update_refreshed_connection: replacing/adding the conn");
+                        // Move the node_conn into the connection map
+                        connections_container.replace_or_add_connection_for_address(address.clone(), node_conn);
+                    }
+                }
+                // Remove this entry from refresh_ops_map
+                connections_container.refresh_operations.remove(&address);
+            }
+
+            // Check if both sets are empty
+            if connections_container.refresh_addresses_started.is_empty() && connections_container.refresh_addresses_done.is_empty() {
+                break;
+            }
+
+            // Release the lock before the next iteration
+            drop(connections_container);
+        }
+    }
+
     fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll<PollFlushAction> {
         let retry_params = self
             .inner
@@ -2498,7 +2538,7 @@ where
                 }
                 Next::Reconnect { request, target } => {
                     poll_flush_action =
-                        poll_flush_action.change_state(PollFlushAction::Reconnect(vec![target]));
+                        poll_flush_action.change_state(PollFlushAction::Reconnect(HashSet::from_iter([target])));
                     if let Some(request) = request {
                         self.inner.pending_requests.lock().unwrap().push(request);
                     }
@@ -2543,7 +2583,7 @@
 enum PollFlushAction {
     None,
     RebuildSlots,
-    Reconnect(Vec<String>),
+    Reconnect(HashSet<String>),
     ReconnectFromInitialConnections,
 }
 
@@ -2581,18 +2621,18 @@ where
     fn start_send(self: Pin<&mut Self>, msg: Message<C>) -> Result<(), Self::Error> {
         let Message { cmd, sender } = msg;
-
+
         let info = RequestInfo { cmd };
-
+
         self.inner
             .pending_requests
             .lock()
             .unwrap()
             .push(PendingRequest {
-                retry: 0,
-                sender,
-                info,
-            });
+                retry: 0,
+                sender,
+                info,
+            });
         Ok(())
     }
 
@@ -2617,6 +2657,12 @@ where
             return Poll::Pending;
         }
 
+        // Update the connection_map with all the refreshed connections.
+        // In case of an active poll_recovery, that work should take care of
+        // the refreshed connections: add them if they are still relevant, and
+        // kill the refresh tasks of addresses that are no longer relevant.
+        self.update_refreshed_connection();
+
         match ready!(self.poll_complete(cx)) {
             PollFlushAction::None => return Poll::Ready(Ok(())),
             PollFlushAction::RebuildSlots => {
@@ -2632,8 +2678,7 @@
                 ClusterConnInner::refresh_connections(
                     self.inner.clone(),
                     addresses,
-                    RefreshConnectionType::OnlyUserConnection,
-                    true,
+                    RefreshConnectionType::AllConnections,
                 ),
             )));
         }
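
Note (reviewer sketch, not part of the patch): the patch introduces a three-phase handshake between background refresh tasks and poll_flush. The sketch below models it with plain std types instead of the patch's DashSet/DashMap and ClusterNode; all names here (Registry, spawn_refresh, drain, the String "connection") are illustrative stand-ins.

    use std::collections::{HashMap, HashSet};
    use std::sync::{Arc, Mutex};
    use tokio::task::JoinHandle;

    #[derive(Default)]
    struct Registry {
        started: HashSet<String>,             // refresh_addresses_started: stale entries to drop
        done: HashSet<String>,                // refresh_addresses_done: new connections ready
        ops: HashMap<String, Option<String>>, // refresh_operations: address -> parked "connection"
    }

    fn spawn_refresh(reg: Arc<Mutex<Registry>>, addr: String) -> JoinHandle<()> {
        tokio::spawn(async move {
            // Phase 1: mark the address so the next poll_flush drops the dead
            // node before the batch starts.
            reg.lock().unwrap().started.insert(addr.clone());
            // ... the reconnect would happen here; on success the connection is
            // parked, not installed directly into the live connection map ...
            let new_conn = format!("conn-to-{addr}");
            // Phase 2: park the result and signal completion.
            let mut g = reg.lock().unwrap();
            g.ops.insert(addr.clone(), Some(new_conn));
            g.done.insert(addr);
        })
    }

    // Phase 3: what update_refreshed_connection does at the top of poll_flush.
    fn drain(reg: &Mutex<Registry>, conn_map: &mut HashMap<String, String>) {
        let mut g = reg.lock().unwrap();
        for addr in g.started.drain().collect::<Vec<_>>() {
            conn_map.remove(&addr); // remove_node: the whole batch sees the same map
        }
        for addr in g.done.drain().collect::<Vec<_>>() {
            if let Some(Some(conn)) = g.ops.remove(&addr) {
                conn_map.insert(addr, conn); // replace_or_add_connection_for_address
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let reg = Arc::new(Mutex::new(Registry::default()));
        let mut conn_map = HashMap::new();
        spawn_refresh(reg.clone(), "node1:6379".into()).await.unwrap();
        drain(&reg, &mut conn_map);
        assert_eq!(conn_map["node1:6379"], "conn-to-node1:6379");
    }

The real code uses DashSet/DashMap, so tasks never hold a lock across an await, and update_refreshed_connection loops until both sets are empty because tasks may finish while the drain is in progress.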
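
Note (sketch): refresh_connections skips addresses that already have an entry in refresh_operations, then inserts the RefreshState only after spawning. With dashmap, the check and the reservation can also be combined through the entry API, which would close the small window in which a task can finish before its RefreshState exists. The Entry API below is dashmap's real interface; the Option<String> payload is a stand-in for RefreshState.

    use dashmap::mapref::entry::Entry;
    use dashmap::DashMap;

    // Returns true if the caller reserved the slot and should spawn a refresh task.
    fn try_reserve(ops: &DashMap<String, Option<String>>, addr: &str) -> bool {
        match ops.entry(addr.to_string()) {
            Entry::Occupied(_) => false, // a refresh task is already in flight
            Entry::Vacant(v) => {
                v.insert(None); // reserve before spawning, so duplicates cannot race in
                true
            }
        }
    }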
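
Note (sketch): RefreshState keeps the task's JoinHandle precisely so the owner can cancel refreshes that are no longer relevant, as the new poll_flush comment describes. JoinHandle::abort is tokio's real API; the HashMap and the predicate are stand-ins for refresh_operations and the topology check.

    use std::collections::HashMap;
    use tokio::task::JoinHandle;

    fn abort_stale(ops: &mut HashMap<String, JoinHandle<()>>, is_relevant: impl Fn(&str) -> bool) {
        ops.retain(|addr, handle| {
            if is_relevant(addr) {
                true
            } else {
                handle.abort(); // cancels the task at its next await point
                false
            }
        });
    }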
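
Note (sketch): the rewritten OnlyAddress branch changes the caller-visible contract. A request that finds no connection for its address no longer connects inline; it fails fast with AllConnectionsUnavailable while the refresh task runs in the background, and it is expected to be retried on a later cycle, by which point update_refreshed_connection may have installed the refreshed node. A hypothetical retry predicate, not from the patch:

    use redis::{ErrorKind, RedisError};

    // Decide whether a failed request should be re-queued for the next poll cycle.
    fn should_retry(err: &RedisError) -> bool {
        matches!(err.kind(), ErrorKind::AllConnectionsUnavailable)
    }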