Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aggregator setup pipelining #487

Draft
wants to merge 3 commits into
base: dev
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 65 additions & 70 deletions core/src/rpc/aggregator.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::future::Future;

use super::clementine::{
clementine_aggregator_server::ClementineAggregator, verifier_deposit_finalize_params,
DepositParams, Empty, RawSignedMoveTx, VerifierDepositFinalizeParams,
Expand All @@ -17,6 +19,7 @@ use bitcoin::{Amount, TapSighash};
use futures::{future::try_join_all, stream::BoxStream, FutureExt, Stream, StreamExt};
use secp256k1::musig::{MusigAggNonce, MusigPartialSignature, MusigPubNonce};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream;
use tonic::{async_trait, Request, Response, Status, Streaming};

struct AggNonceQueueItem {
Expand Down Expand Up @@ -284,117 +287,109 @@ impl Aggregator {
}
}

async fn collect<C, T, E, Fut, F>(clients: &[C], mut f: F) -> Result<Vec<T>, E>
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot, these will reduce boilerplate by a lot.

I think the naming needs to change though, this seems like a utility function that calls all clients with a given func and returns the inner response. Something like: map_clients_collect?

Also maybe #[inline] and/or #[track_caller] are appropriate here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this definitely needs refining, thank you for pointing it out. I didn't find an elegant way to generalize it over streaming (returning Streaming<T>) and regular/non-streaming (returning Response<T>) requests 🤷 .

where
C: Clone,
Fut: Future<Output = Result<Response<T>, E>>,
F: FnMut(C) -> Fut,
{
try_join_all(
clients.iter().map(|client| {
f(client.clone()).map(|result| result.map(|response| response.into_inner()))
}),
)
.await
}

async fn pull<T>(mut stream: Streaming<T>) -> Result<Vec<T>, Status> {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need docs on the behavior. This currently collects until end of stream unless an error happens, in which case we prematurely return the error. I think collect fits better here. This is also the behavior of collect when doing a Iter<Item = Result<T, E>> -> Result<Vec<T>, E>

Also can we move these two into another module?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, this was just a first attempt to generalize stream collection, probably just .collect() on a stream would have been enough and this function is unnecessary.

let mut ret = Vec::new();
while let Some(next) = stream.message().await? {
ret.push(next);
}
Ok(ret)
}

#[async_trait]
impl ClementineAggregator for Aggregator {
#[tracing::instrument(skip_all, err(level = tracing::Level::ERROR), ret(level = tracing::Level::TRACE))]
async fn setup(&self, _request: Request<Empty>) -> Result<Response<Empty>, Status> {
tracing::info!("Collecting verifier details...");
let verifier_params = try_join_all(self.verifier_clients.iter().map(|client| {
let mut client = client.clone();
async move {
let response = client.get_params(Request::new(Empty {})).await?;
Ok::<_, Status>(response.into_inner())
}
}))
let verifier_params = collect(&self.verifier_clients, |mut client| async move {
client.get_params(Request::new(Empty {})).await
})
.await?;
let verifier_public_keys: Vec<Vec<u8>> =
verifier_params.into_iter().map(|p| p.public_key).collect();
tracing::debug!("Verifier public keys: {:?}", verifier_public_keys);

tracing::info!("Setting up verifiers...");
try_join_all(self.verifier_clients.iter().map(|client| {
let mut client = client.clone();
{
let verifier_public_keys = clementine::VerifierPublicKeys {
verifier_public_keys: verifier_public_keys.clone(),
};
async move {
let response = client
.set_verifiers(Request::new(verifier_public_keys))
.await?;
Ok::<_, Status>(response.into_inner())
}
collect(&self.verifier_clients, |mut client| {
let verifier_public_keys = clementine::VerifierPublicKeys {
verifier_public_keys: verifier_public_keys.clone(),
};
async move {
client
.set_verifiers(Request::new(verifier_public_keys))
.await
}
}))
})
.await?;

tracing::info!("Collecting operator details...");
let operator_params = try_join_all(self.operator_clients.iter().map(|client| {
let mut client = client.clone();
async move {
let mut responses = Vec::new();
let mut params_stream = client
.get_params(Request::new(Empty {}))
.await?
.into_inner();
while let Some(response) = params_stream.message().await? {
responses.push(response);
}
let operator_params = try_join_all(
collect(&self.operator_clients, |mut client| async move {
client.get_params(Request::new(Empty {})).await
})
.await?
.into_iter()
.map(|stream| async move { pull(stream).await }),
)
.await?;

Ok::<_, Status>(responses)
}
}))
tracing::info!("Collecting Winternitz public keys from watchtowers...");
let watchtower_params = try_join_all(
collect(&self.watchtower_clients, |mut client| async move {
client.get_params(Request::new(Empty {})).await
})
.await?
.into_iter()
.map(|stream| async move { pull(stream).await }),
)
.await?;

tracing::info!("Informing verifiers for existing operators...");
try_join_all(self.verifier_clients.iter().map(|client| {
let mut client = client.clone();
let operator_params = operator_params.clone();

async move {
for params in operator_params {
let (tx, rx) = tokio::sync::mpsc::channel(1280);
let future =
client.set_operator(tokio_stream::wrappers::ReceiverStream::new(rx));

let (tx, rx) = channel(params.len());
let future = client.set_operator(ReceiverStream::new(rx));
for param in params {
tx.send(param).await.unwrap();
tx.send(param).await
.map_err(|e| Status::aborted(e.to_string()))?
}

future.await?; // TODO: This is dangerous: If channel size becomes not sufficient, this will block forever.
future.await?;
}

Ok::<_, tonic::Status>(())
}
}))
.await?;

tracing::info!("Collecting Winternitz public keys from watchtowers...");
let watchtower_params = try_join_all(self.watchtower_clients.iter().map(|client| {
let mut client = client.clone();
async move {
let mut responses = Vec::new();
let mut params_stream = client
.get_params(Request::new(Empty {}))
.await?
.into_inner();
while let Some(response) = params_stream.message().await? {
responses.push(response);
}

Ok::<_, Status>(responses)
}
}))
.await?;

tracing::info!("Sending Winternitz public keys to verifiers...");
try_join_all(self.verifier_clients.iter().map(|client| {
let mut client = client.clone();
let watchtower_params = watchtower_params.clone();

async move {
for params in watchtower_params {
let (tx, rx) = tokio::sync::mpsc::channel(1280);

let future =
client.set_watchtower(tokio_stream::wrappers::ReceiverStream::new(rx));
let (tx, rx) = channel(params.len());
let future = client.set_watchtower(ReceiverStream::new(rx));
for param in params {
tx.send(param).await.unwrap();
tx.send(param).await
.map_err(|e| Status::aborted(e.to_string()))?
}

future.await?; // TODO: This is dangerous: If channel size becomes not sufficient, this will block forever.
future.await?;
}

Ok::<_, tonic::Status>(())
}
}))
Expand Down
Loading