From dcb81d8bb15338c78dd3ecfa9b041a3806f46c2e Mon Sep 17 00:00:00 2001 From: doscortados Date: Thu, 6 Feb 2025 08:38:24 +0000 Subject: [PATCH] review: address final comments --- core/src/errors.rs | 3 +- core/src/rpc/aggregator.rs | 92 ++++++++++++++++++++++---------------- 2 files changed, 56 insertions(+), 39 deletions(-) diff --git a/core/src/errors.rs b/core/src/errors.rs index 1149bfd8..6bd01eda 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -111,7 +111,8 @@ pub enum BridgeError { RPCStreamEndedUnexpectedly(String), #[error("Invalid response from an RPC endpoint: {0}")] RPCInvalidResponse(String), - + #[error("RPC broadcast receiver failed: {0}")] + RPCBroadcastRecvError(#[from] tokio::sync::broadcast::error::RecvError), /// ConfigError is returned when the configuration is invalid #[error("ConfigError: {0}")] ConfigError(String), diff --git a/core/src/rpc/aggregator.rs b/core/src/rpc/aggregator.rs index 1ad21cb3..cbf3146a 100644 --- a/core/src/rpc/aggregator.rs +++ b/core/src/rpc/aggregator.rs @@ -262,6 +262,36 @@ async fn create_nonce_streams( Ok((first_responses, transformed_streams)) } +/// Use items collected from the broadcast receiver for an async function call. +/// +/// Handles the boilerplate of managing a receiver of a broadcast channel. +/// If receiver is lagged at any time (data is lost) an error is returned. +async fn collect_and_call( + rx: &mut tokio::sync::broadcast::Receiver>, + f: F, +) -> Result +where + R: Default, + T: Clone, + F: Fn(Vec) -> Fut, + Fut: Future>, +{ + loop { + match rx.recv().await { + Ok(params) => { + f(params).await?; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + break Err(BridgeError::RPCStreamEndedUnexpectedly(format!( + "lost {n} items due to lagging receiver" + )) + .into()); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break Ok(R::default()), + } + } +} + impl Aggregator { // Extracts pub_nonce from given stream. fn extract_pub_nonce( @@ -404,13 +434,12 @@ impl ClementineAggregator for Aggregator { let (operator_params_tx, operator_params_rx) = tokio::sync::broadcast::channel(CHANNEL_CAPACITY); - let tx = operator_params_tx.clone(); let operators = self.operator_clients.clone(); let get_operator_params_chunked_handle = tokio::spawn(async move { tracing::info!(clients = operators.len(), "Collecting operator details..."); try_join_all(operators.iter().map(|operator| { let mut operator = operator.clone(); - let tx = tx.clone(); + let tx = operator_params_tx.clone(); async move { let stream = operator .get_params(Request::new(Empty {})) @@ -427,28 +456,22 @@ impl ClementineAggregator for Aggregator { Ok::<_, Status>(()) }); - drop(operator_params_tx); let verifiers = self.verifier_clients.clone(); let set_operator_params_handle = tokio::spawn(async move { tracing::info!("Informing verifiers for existing operators..."); try_join_all(verifiers.iter().map(|verifier| { - let mut verifier = verifier.clone(); - let operator_params_rx = &operator_params_rx; + let verifier = verifier.clone(); + let rx = &operator_params_rx; async move { - let mut operator_params_rx = operator_params_rx.resubscribe(); - loop { - match operator_params_rx.recv().await { - Ok(params) => { - verifier.set_operator(futures::stream::iter(params)).await?; - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - break Err(BridgeError::RPCStreamEndedUnexpectedly(format!( - "lost {n} operator params batches due to lagging receiver" - ))); - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => break Ok(()), + let mut rx = rx.resubscribe(); + collect_and_call(&mut rx, |params| { + let mut verifier = verifier.clone(); + async move { + verifier.set_operator(futures::stream::iter(params)).await?; + Ok::<_, Status>(()) } - }?; + }) + .await?; Ok::<_, Status>(()) } })) @@ -461,13 +484,12 @@ impl ClementineAggregator for Aggregator { let (watchtower_params_tx, watchtower_params_rx) = tokio::sync::broadcast::channel(CHANNEL_CAPACITY); - let tx = watchtower_params_tx.clone(); let watchtowers = self.watchtower_clients.clone(); let get_watchtower_params_chunked_handle = tokio::spawn(async move { tracing::info!("Collecting Winternitz public keys from watchtowers..."); try_join_all(watchtowers.iter().map(|watchtower| { let mut watchtower = watchtower.clone(); - let tx = tx.clone(); + let tx = watchtower_params_tx.clone(); async move { let stream = watchtower .get_params(Request::new(Empty {})) @@ -484,30 +506,24 @@ impl ClementineAggregator for Aggregator { Ok::<_, Status>(()) }); - drop(watchtower_params_tx); let verifiers = self.verifier_clients.clone(); let set_watchtower_params_handle = tokio::spawn(async move { tracing::info!("Sending Winternitz public keys to verifiers..."); try_join_all(verifiers.iter().map(|verifier| { - let mut verifier = verifier.clone(); - let watchtower_params_rx = &watchtower_params_rx; + let verifier = verifier.clone(); + let rx = &watchtower_params_rx; async move { - let mut watchtower_params_rx = watchtower_params_rx.resubscribe(); - loop { - match watchtower_params_rx.recv().await { - Ok(params) => { - verifier - .set_watchtower(futures::stream::iter(params)) - .await?; - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - break Err(BridgeError::RPCStreamEndedUnexpectedly(format!( - "lost {n} watchtower params batches due to lagging receiver" - ))); - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => break Ok(()), + let mut rx = rx.resubscribe(); + collect_and_call(&mut rx, |params| { + let mut verifier = verifier.clone(); + async move { + verifier + .set_watchtower(futures::stream::iter(params)) + .await?; + Ok::<_, Status>(()) } - }?; + }) + .await?; Ok::<_, Status>(()) } }))