From ba8679d9c89794faacdf800e5a80deb432b58b08 Mon Sep 17 00:00:00 2001 From: doscortados Date: Tue, 28 Jan 2025 15:38:10 +0100 Subject: [PATCH 1/3] chore: extract `collect` & `pull` utilities --- core/src/rpc/aggregator.rs | 135 +++++++++++++++++++------------------ 1 file changed, 70 insertions(+), 65 deletions(-) diff --git a/core/src/rpc/aggregator.rs b/core/src/rpc/aggregator.rs index dddb7051..8b2b34a0 100644 --- a/core/src/rpc/aggregator.rs +++ b/core/src/rpc/aggregator.rs @@ -1,3 +1,5 @@ +use std::future::Future; + use super::clementine::{ clementine_aggregator_server::ClementineAggregator, verifier_deposit_finalize_params, DepositParams, Empty, RawSignedMoveTx, VerifierDepositFinalizeParams, @@ -284,117 +286,120 @@ impl Aggregator { } } +async fn collect(clients: &[C], mut f: F) -> Result, E> +where + C: Clone, + Fut: Future, E>>, + F: FnMut(C) -> Fut, +{ + try_join_all( + clients.iter().map(|client| { + f(client.clone()).map(|result| result.map(|response| response.into_inner())) + }), + ) + .await +} + +async fn pull(mut stream: Streaming) -> Result, Status> { + let mut ret = Vec::new(); + while let Some(next) = stream.message().await? { + ret.push(next); + } + Ok(ret) +} + +/* + +vp <- verifiers + +op <- operators +wp <- watchtovers + +op -> verifiers +wp -> verifiers + +*/ + #[async_trait] impl ClementineAggregator for Aggregator { + // TODO: HERE #464 #[tracing::instrument(skip_all, err(level = tracing::Level::ERROR), ret(level = tracing::Level::TRACE))] async fn setup(&self, _request: Request) -> Result, Status> { tracing::info!("Collecting verifier details..."); - let verifier_params = try_join_all(self.verifier_clients.iter().map(|client| { - let mut client = client.clone(); - async move { - let response = client.get_params(Request::new(Empty {})).await?; - Ok::<_, Status>(response.into_inner()) - } - })) + let verifier_params = collect(&self.verifier_clients, |mut client| async move { + client.get_params(Request::new(Empty {})).await + }) .await?; let verifier_public_keys: Vec> = verifier_params.into_iter().map(|p| p.public_key).collect(); tracing::debug!("Verifier public keys: {:?}", verifier_public_keys); tracing::info!("Setting up verifiers..."); - try_join_all(self.verifier_clients.iter().map(|client| { - let mut client = client.clone(); - { - let verifier_public_keys = clementine::VerifierPublicKeys { - verifier_public_keys: verifier_public_keys.clone(), - }; - async move { - let response = client - .set_verifiers(Request::new(verifier_public_keys)) - .await?; - Ok::<_, Status>(response.into_inner()) - } + collect(&self.verifier_clients, |mut client| { + let verifier_public_keys = clementine::VerifierPublicKeys { + verifier_public_keys: verifier_public_keys.clone(), + }; + async move { + client + .set_verifiers(Request::new(verifier_public_keys)) + .await } - })) + }) .await?; tracing::info!("Collecting operator details..."); - let operator_params = try_join_all(self.operator_clients.iter().map(|client| { - let mut client = client.clone(); - async move { - let mut responses = Vec::new(); - let mut params_stream = client - .get_params(Request::new(Empty {})) - .await? - .into_inner(); - while let Some(response) = params_stream.message().await? { - responses.push(response); - } + let operator_params = try_join_all( + collect(&self.operator_clients, |mut client| async move { + client.get_params(Request::new(Empty {})).await + }) + .await? + .into_iter() + .map(|stream| async move { pull(stream).await }) + ).await?; - Ok::<_, Status>(responses) - } - })) - .await?; + tracing::info!("Collecting Winternitz public keys from watchtowers..."); + let watchtower_params = try_join_all( + collect(&self.watchtower_clients, |mut client| async move { + client.get_params(Request::new(Empty {})).await + }) + .await? + .into_iter() + .map(|stream| async move { pull(stream).await }) + ).await?; tracing::info!("Informing verifiers for existing operators..."); try_join_all(self.verifier_clients.iter().map(|client| { let mut client = client.clone(); let operator_params = operator_params.clone(); - async move { for params in operator_params { - let (tx, rx) = tokio::sync::mpsc::channel(1280); + let (tx, rx) = tokio::sync::mpsc::channel(params.len()); let future = client.set_operator(tokio_stream::wrappers::ReceiverStream::new(rx)); - for param in params { tx.send(param).await.unwrap(); } - - future.await?; // TODO: This is dangerous: If channel size becomes not sufficient, this will block forever. + future.await?; } - Ok::<_, tonic::Status>(()) } })) .await?; - tracing::info!("Collecting Winternitz public keys from watchtowers..."); - let watchtower_params = try_join_all(self.watchtower_clients.iter().map(|client| { - let mut client = client.clone(); - async move { - let mut responses = Vec::new(); - let mut params_stream = client - .get_params(Request::new(Empty {})) - .await? - .into_inner(); - while let Some(response) = params_stream.message().await? { - responses.push(response); - } - - Ok::<_, Status>(responses) - } - })) - .await?; - tracing::info!("Sending Winternitz public keys to verifiers..."); try_join_all(self.verifier_clients.iter().map(|client| { let mut client = client.clone(); let watchtower_params = watchtower_params.clone(); - async move { for params in watchtower_params { - let (tx, rx) = tokio::sync::mpsc::channel(1280); - + let (tx, rx) = tokio::sync::mpsc::channel(params.len()); let future = client.set_watchtower(tokio_stream::wrappers::ReceiverStream::new(rx)); for param in params { tx.send(param).await.unwrap(); } - - future.await?; // TODO: This is dangerous: If channel size becomes not sufficient, this will block forever. + future.await?; } - Ok::<_, tonic::Status>(()) } })) From d66338f455c8d76cb2ec9bafb596733a9aa8e3ba Mon Sep 17 00:00:00 2001 From: doscortados Date: Tue, 28 Jan 2025 15:41:28 +0100 Subject: [PATCH 2/3] style: fmt --- core/src/rpc/aggregator.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/rpc/aggregator.rs b/core/src/rpc/aggregator.rs index 8b2b34a0..84bc4011 100644 --- a/core/src/rpc/aggregator.rs +++ b/core/src/rpc/aggregator.rs @@ -354,8 +354,9 @@ impl ClementineAggregator for Aggregator { }) .await? .into_iter() - .map(|stream| async move { pull(stream).await }) - ).await?; + .map(|stream| async move { pull(stream).await }), + ) + .await?; tracing::info!("Collecting Winternitz public keys from watchtowers..."); let watchtower_params = try_join_all( @@ -364,8 +365,9 @@ impl ClementineAggregator for Aggregator { }) .await? .into_iter() - .map(|stream| async move { pull(stream).await }) - ).await?; + .map(|stream| async move { pull(stream).await }), + ) + .await?; tracing::info!("Informing verifiers for existing operators..."); try_join_all(self.verifier_clients.iter().map(|client| { From 63439bb28c3c81eebdf0018d9e75737e4d57595d Mon Sep 17 00:00:00 2001 From: doscortados Date: Wed, 29 Jan 2025 11:29:16 +0100 Subject: [PATCH 3/3] chore: avoid `.unwrap()` calls --- core/src/rpc/aggregator.rs | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/core/src/rpc/aggregator.rs b/core/src/rpc/aggregator.rs index 84bc4011..c39da27a 100644 --- a/core/src/rpc/aggregator.rs +++ b/core/src/rpc/aggregator.rs @@ -19,6 +19,7 @@ use bitcoin::{Amount, TapSighash}; use futures::{future::try_join_all, stream::BoxStream, FutureExt, Stream, StreamExt}; use secp256k1::musig::{MusigAggNonce, MusigPartialSignature, MusigPubNonce}; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio_stream::wrappers::ReceiverStream; use tonic::{async_trait, Request, Response, Status, Streaming}; struct AggNonceQueueItem { @@ -308,21 +309,8 @@ async fn pull(mut stream: Streaming) -> Result, Status> { Ok(ret) } -/* - -vp <- verifiers - -op <- operators -wp <- watchtovers - -op -> verifiers -wp -> verifiers - -*/ - #[async_trait] impl ClementineAggregator for Aggregator { - // TODO: HERE #464 #[tracing::instrument(skip_all, err(level = tracing::Level::ERROR), ret(level = tracing::Level::TRACE))] async fn setup(&self, _request: Request) -> Result, Status> { tracing::info!("Collecting verifier details..."); @@ -375,11 +363,11 @@ impl ClementineAggregator for Aggregator { let operator_params = operator_params.clone(); async move { for params in operator_params { - let (tx, rx) = tokio::sync::mpsc::channel(params.len()); - let future = - client.set_operator(tokio_stream::wrappers::ReceiverStream::new(rx)); + let (tx, rx) = channel(params.len()); + let future = client.set_operator(ReceiverStream::new(rx)); for param in params { - tx.send(param).await.unwrap(); + tx.send(param).await + .map_err(|e| Status::aborted(e.to_string()))? } future.await?; } @@ -394,11 +382,11 @@ impl ClementineAggregator for Aggregator { let watchtower_params = watchtower_params.clone(); async move { for params in watchtower_params { - let (tx, rx) = tokio::sync::mpsc::channel(params.len()); - let future = - client.set_watchtower(tokio_stream::wrappers::ReceiverStream::new(rx)); + let (tx, rx) = channel(params.len()); + let future = client.set_watchtower(ReceiverStream::new(rx)); for param in params { - tx.send(param).await.unwrap(); + tx.send(param).await + .map_err(|e| Status::aborted(e.to_string()))? } future.await?; }