
redis-rs: Reconnect in Tokio Task
GilboaAWS committed Dec 16, 2024
1 parent daa5af2 commit b1e24b8
Showing 2 changed files with 147 additions and 66 deletions.
glide-core/redis-rs/redis/src/cluster_async/connections_container.rs
@@ -2,14 +2,19 @@ use crate::cluster_async::ConnectionFuture;
use crate::cluster_routing::{Route, ShardAddrs, SlotAddr};
use crate::cluster_slotmap::{ReadFromReplicaStrategy, SlotMap, SlotMapValue};
use crate::cluster_topology::TopologyHash;
use crate::RedisResult;
use dashmap::DashMap;
use futures::FutureExt;
use rand::seq::IteratorRandom;
use tokio::task::JoinHandle;
use std::net::IpAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use telemetrylib::Telemetry;

use tokio::sync::oneshot;
use futures::future::Shared;

/// Count the number of connections in a connections_map object
macro_rules! count_connections {
($conn_map:expr) => {{
@@ -121,6 +126,11 @@ pub(crate) enum ConnectionType {

pub(crate) struct ConnectionsMap<Connection>(pub(crate) DashMap<String, ClusterNode<Connection>>);

pub(crate) struct RefreshState {
pub handle: JoinHandle<()>,
pub rx: Shared<oneshot::Receiver<Arc<RedisResult<()>>>>,
}

impl<Connection> std::fmt::Display for ConnectionsMap<Connection> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for item in self.0.iter() {
@@ -139,6 +149,9 @@ pub(crate) struct ConnectionsContainer<Connection> {
pub(crate) slot_map: SlotMap,
read_from_replica_strategy: ReadFromReplicaStrategy,
topology_hash: TopologyHash,

// Tracks in-flight connection refresh operations, keyed by node address
pub(crate) refresh_operations: DashMap<String, RefreshState>,
}

impl<Connection> Drop for ConnectionsContainer<Connection> {
@@ -155,6 +168,7 @@ impl<Connection> Default for ConnectionsContainer<Connection> {
slot_map: Default::default(),
read_from_replica_strategy: ReadFromReplicaStrategy::AlwaysFromPrimary,
topology_hash: 0,
refresh_operations: DashMap::new(),
}
}
}
@@ -182,6 +196,7 @@ where
slot_map,
read_from_replica_strategy,
topology_hash,
refresh_operations: DashMap::new(),
}
}

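The new RefreshState pairs the refresh task's JoinHandle with the receiving half of a oneshot result channel. The receiver is stored as Shared because a tokio oneshot::Receiver is single-consumer and not Clone; FutureExt::shared wraps it in a cloneable future, so every request stalled on the same node can await the same refresh outcome. A toy sketch of that mechanism (the String error payload is a stand-in for the crate's RedisResult):

use std::sync::Arc;

use futures::FutureExt;
use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel::<Arc<Result<(), String>>>();

    // .shared() turns the single-consumer receiver into a cloneable future.
    let shared = rx.shared();
    let (a, b) = (shared.clone(), shared);

    tokio::spawn(async move {
        let _ = tx.send(Arc::new(Ok(()))); // broadcast the outcome once
    });

    // Both waiters observe the same Arc'd value.
    assert!(a.await.unwrap().is_ok());
    assert!(b.await.unwrap().is_ok());
}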
glide-core/redis-rs/redis/src/cluster_async/mod.rs (198 changes: 132 additions & 66 deletions)
@@ -42,6 +42,7 @@ use crate::{
commands::cluster_scan::{cluster_scan, ClusterScanArgs, ObjectType, ScanStateRC},
FromRedisValue, InfoDict, ToRedisArgs,
};
use connections_container::RefreshState;
use dashmap::DashMap;
use std::{
collections::{HashMap, HashSet},
@@ -96,11 +97,11 @@
use futures::{future::BoxFuture, prelude::*, ready};
use pin_project_lite::pin_project;
use std::sync::RwLock as StdRwLock;
use tokio::sync::{
use tokio::{sync::{
mpsc,
oneshot::{self, Receiver},
RwLock as TokioRwLock,
};
}, time::sleep};
use tracing::{debug, info, trace, warn};

use self::{
@@ -1432,68 +1433,102 @@
inner.clone(),
addrs_to_refresh,
RefreshConnectionType::AllConnections,
false,
)
.await;
true,
);
}
}

async fn refresh_connections(
inner: Arc<InnerCore<C>>,
pub(crate) fn refresh_connections(
inner_arg: Arc<InnerCore<C>>,
addresses: Vec<String>,
conn_type: RefreshConnectionType,
check_existing_conn: bool,
) {
info!("Started refreshing connections to {:?}", addresses);
let mut tasks = FuturesUnordered::new();
let inner = inner.clone();
let addresses_set = addresses.into_iter().collect::<HashSet<String>>();

for address in addresses.into_iter() {
let inner = inner.clone();
let connections_container = inner_arg.conn_lock.read().expect(MUTEX_READ_ERR);
let refresh_ops_map = &connections_container.refresh_operations;

tasks.push(async move {
let node_option = if check_existing_conn {
let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
connections_container.remove_node(&address)
} else {
None
};
let address_list: Vec<String> = refresh_ops_map.iter()
.map(|entry| entry.key().clone())
.collect();
info!("Current addresses being refreshed: {:?}", address_list);

// Override subscriptions for this connection
let mut cluster_params = inner.cluster_params.read().expect(MUTEX_READ_ERR).clone();
let subs_guard = inner.subscriptions_by_address.read().await;
cluster_params.pubsub_subscriptions = subs_guard.get(&address).cloned();
drop(subs_guard);
info!("Started refreshing task to the set of connections to {:?}", addresses_set);
for address in addresses_set {
if refresh_ops_map.contains_key(&address) {
info!("Skipping refresh for {}: already in progress", address);
continue;
}

let node = get_or_create_conn(
&address,
node_option,
&cluster_params,
conn_type,
inner.glide_connection_options.clone(),
)
.await;
info!("Starting refresh task for {}", address);

let inner = inner_arg.clone();
let address_clone = address.clone();
let (tx, rx) = oneshot::channel();

let handle = tokio::spawn(async move {
info!("TOKIO refreshing connections to {:?}", address);
let result = async {
let node_option = if check_existing_conn {
let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
connections_container.remove_node(&address)
} else {
None
};

println!("Timer started");
sleep(Duration::from_secs(5)).await;
println!("5 seconds have passed");

let mut cluster_params = inner.cluster_params.read().expect(MUTEX_READ_ERR).clone();
let subs_guard = inner.subscriptions_by_address.read().await;
cluster_params.pubsub_subscriptions = subs_guard.get(&address).cloned();
drop(subs_guard);

(address, node)
let node_result = get_or_create_conn(
&address,
node_option,
&cluster_params,
conn_type,
inner.glide_connection_options.clone(),
)
.await;

match node_result {
Ok(node) => {
let connections_container = inner.conn_lock.write().expect(MUTEX_WRITE_ERR);
connections_container.replace_or_add_connection_for_address(address.clone(), node);
Ok(())
}
Err(err) => {
warn!(
"Failed to refresh connection for node {}. Error: `{:?}`",
address, err
);
Err(err)
}
}
}.await;

// Send the result through the channel
let _ = tx.send(Arc::new(result));

// Remove the refresh operation from the map
let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
connections_container.refresh_operations.remove(&address);
info!("TOKIO refreshing connections to {:?} is done", address);
});
}

// Poll connection tasks as soon as each one finishes
while let Some(result) = tasks.next().await {
match result {
(address, Ok(node)) => {
let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);
connections_container.replace_or_add_connection_for_address(address, node);
}
(address, Err(err)) => {
warn!(
"Failed to refresh connection for node {}. Error: `{:?}`",
address, err
);
}
}
// Store both the task handle and the shared receiver in the refresh_operations map of the ConnectionsContainer
info!("Inserting tokio task to refresh_ops map of address {:?}", address_clone);
refresh_ops_map.insert(address_clone, RefreshState {
handle,
rx: rx.shared()
});
}
debug!("refresh connections completed");
debug!("refresh connections initiated");
}

async fn aggregate_results(
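For reference, a compressed, self-contained sketch of the task lifecycle implemented above: skip if a refresh is already registered, spawn the work, broadcast the outcome, then unregister. Names are hypothetical, and the entry-based locking is illustrative rather than the diff's exact code (the diff checks contains_key and inserts after spawning, which leaves a small window where a fast task could try to unregister itself before it was registered):

use std::sync::Arc;

use dashmap::{mapref::entry::Entry, DashMap};
use futures::future::{FutureExt, Shared};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;

type Outcome = Arc<Result<(), String>>; // stand-in for Arc<RedisResult<()>>

struct RefreshState {
    handle: JoinHandle<()>,
    rx: Shared<oneshot::Receiver<Outcome>>,
}

fn spawn_refresh(ops: Arc<DashMap<String, RefreshState>>, address: String) {
    // The Vacant guard keeps the shard write-locked until the new state is
    // inserted, so the task's self-removal below cannot overtake the insert.
    if let Entry::Vacant(slot) = ops.entry(address.clone()) {
        let (tx, rx) = oneshot::channel();
        let ops_task = ops.clone();
        let handle = tokio::spawn(async move {
            let result: Result<(), String> = Ok(()); // reconnect work goes here
            let _ = tx.send(Arc::new(result)); // wake every waiter
            ops_task.remove(&address); // permit a future refresh of this node
        });
        slot.insert(RefreshState { handle, rx: rx.shared() });
    } // Occupied: a refresh for this address is already in flight, skip it.
}

#[tokio::main]
async fn main() {
    let ops = Arc::new(DashMap::new());
    spawn_refresh(ops.clone(), "node-1:6379".into());
    spawn_refresh(ops, "node-1:6379".into()); // skipped while the first is registered
}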
@@ -1831,9 +1866,8 @@
inner.clone(),
addrs_to_refresh.into_iter().collect(),
RefreshConnectionType::AllConnections,
false,
)
.await;
true,
);
}
}

@@ -1868,8 +1902,7 @@
failed_connections,
RefreshConnectionType::OnlyManagementConnection,
true,
)
.await;
);
}

false
@@ -2332,24 +2365,56 @@
.connection_for_address(&address);
if let Some((address, conn)) = conn_option {
return Ok((address, conn.await));
}

// Check if this address is currently refreshing its connection
let connections_container = core.conn_lock.read().expect(MUTEX_READ_ERR);
if connections_container.refresh_operations.contains_key(&address) {
ConnectionCheck::OnlyAddress(address)
} else {
return Err((
ErrorKind::ConnectionNotFoundForRoute,
"Requested connection not found",
address,
)
.into());
).into());
}
}
};

let (address, mut conn) = match conn_check {
ConnectionCheck::Found((address, connection)) => (address, connection.await),
ConnectionCheck::OnlyAddress(addr) => {
let refresh_rx = {
let connections_container = core.conn_lock.read().expect(MUTEX_READ_ERR);
connections_container.refresh_operations
.get(&addr)
.map(|refresh_state| refresh_state.rx.clone())
}; // connections_container is dropped here

if let Some(rx) = refresh_rx {
// There's an ongoing refresh, wait for it to complete

info!("awaiting on a RX channel for address {:?}", addr);

let _ = rx.await;

let conn_option = core
.conn_lock
.read()
.expect(MUTEX_READ_ERR)
.connection_for_address(&addr);
if let Some((address, conn)) = conn_option {
return Ok((address, conn.await));
}
}

// If we reach here, either there was no refresh or refresh didn't result in a connection
let mut this_conn_params = core.get_cluster_param(|params| params.clone())?;
let subs_guard = core.subscriptions_by_address.read().await;
this_conn_params.pubsub_subscriptions = subs_guard.get(addr.as_str()).cloned();
drop(subs_guard);

match connect_and_check::<C>(
&addr,
this_conn_params,
Expand All @@ -2363,9 +2428,10 @@ where
{
Ok(node) => {
let connection_clone = node.user_connection.conn.clone().await;
let connections = core.conn_lock.read().expect(MUTEX_READ_ERR);
let address = connections.replace_or_add_connection_for_address(addr, node);
drop(connections);
let address = {
let connections = core.conn_lock.read().expect(MUTEX_READ_ERR);
connections.replace_or_add_connection_for_address(addr.clone(), node)
}; // connections is dropped here
(address, connection_clone)
}
Err(err) => {
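The waiter side above clones the Shared receiver out of the map, drops the guard, awaits the refresh, then re-checks the connection map before falling back to connect_and_check. A minimal sketch of that pattern (trimmed, hypothetical types; only the receiver matters to a waiter):

use std::sync::Arc;

use dashmap::DashMap;
use futures::future::{FutureExt, Shared};
use tokio::sync::oneshot;

type Outcome = Arc<Result<(), String>>;

struct RefreshState {
    rx: Shared<oneshot::Receiver<Outcome>>,
}

async fn wait_for_refresh(ops: &DashMap<String, RefreshState>, address: &str) -> Option<Outcome> {
    // Clone the shared future inside a short-lived borrow so the DashMap
    // guard is not held across the .await point.
    let rx = ops.get(address).map(|state| state.rx.clone())?;
    rx.await.ok() // Err(RecvError) means the refresh task dropped its sender
}

#[tokio::main]
async fn main() {
    let ops = DashMap::new();
    let (tx, rx) = oneshot::channel();
    ops.insert("node-1:6379".to_string(), RefreshState { rx: rx.shared() });
    let _ = tx.send(Arc::new(Ok(())));
    assert!(wait_for_refresh(&ops, "node-1:6379").await.is_some());
}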
@@ -2675,6 +2741,7 @@
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
info!("barak");
trace!("poll_flush: {:?}", self.state);
loop {
self.send_refresh_error();
@@ -2703,14 +2770,13 @@
)));
}
PollFlushAction::Reconnect(addresses) => {
self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
ClusterConnInner::refresh_connections(
self.inner.clone(),
addresses,
RefreshConnectionType::OnlyUserConnection,
true,
),
)));
info!("got reconnect state from the poll_complete, calling refresh_connections");
ClusterConnInner::refresh_connections(
self.inner.clone(),
addresses,
RefreshConnectionType::OnlyUserConnection,
true,
);
}
PollFlushAction::ReconnectFromInitialConnections => {
self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
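The poll_flush hunk drops the old RecoverFuture::Reconnect path: a manual poll method is synchronous, so it can spawn the reconnect but not await it, and the sink now stays usable while the refresh runs in the background. A compressed sketch of that distinction, with hypothetical types:

use std::task::{Context, Poll};

struct Conn;

impl Conn {
    // Spawning is fine in a sync poll context; awaiting is not. The old code
    // parked in a Recover state and kept polling a boxed refresh future.
    fn poll_flush(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
        tokio::spawn(async {
            // refresh_connections(...) would run detached here
        });
        Poll::Ready(Ok(())) // report flushed instead of entering Recover
    }
}

#[tokio::main]
async fn main() {
    let waker = futures::task::noop_waker();
    let mut cx = Context::from_waker(&waker);
    assert!(matches!(Conn.poll_flush(&mut cx), Poll::Ready(Ok(()))));
}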
