Skip to content

Commit

Permalink
Merge pull request #1001 from akoshelev/shard-gateway-reqs
Browse files Browse the repository at this point in the history
Use `BytesStream` inside `Transport` trait
  • Loading branch information
akoshelev authored Apr 5, 2024
2 parents 90b3d98 + 5f3967d commit b651100
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 63 deletions.
10 changes: 5 additions & 5 deletions ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport,
},
transport::routing::RouteId,
HelperChannelId, Message, Role, RoleAssignment, TotalRecords, Transport,
HelperChannelId, LogErrors, Message, Role, RoleAssignment, TotalRecords, Transport,
},
protocol::QueryId,
};
Expand Down Expand Up @@ -142,10 +142,10 @@ impl Gateway {
channel_id.clone(),
self.inner.receivers.get_or_create(channel_id, || {
UnorderedReceiver::new(
Box::pin(
self.transport
.receive(channel_id.peer, (self.query_id, channel_id.gate.clone())),
),
Box::pin(LogErrors::new(self.transport.receive(
channel_id.peer,
(self.query_id, channel_id.gate.clone()),
))),
self.config.active_work(),
)
}),
Expand Down
9 changes: 5 additions & 4 deletions ipa-core/src/helpers/gateway/receive.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::marker::PhantomData;

use bytes::Bytes;
use dashmap::{mapref::entry::Entry, DashMap};
use futures::Stream;

use crate::{
error::BoxError,
helpers::{
buffers::UnorderedReceiver, gateway::transport::RoleResolvingTransport, Error,
HelperChannelId, Message, Role, Transport,
HelperChannelId, LogErrors, Message, Role, Transport,
},
protocol::RecordId,
};
Expand All @@ -25,8 +26,8 @@ pub(super) struct GatewayReceivers {
}

pub(super) type UR = UnorderedReceiver<
<RoleResolvingTransport as Transport>::RecordsStream,
<<RoleResolvingTransport as Transport>::RecordsStream as Stream>::Item,
LogErrors<<RoleResolvingTransport as Transport>::RecordsStream, Bytes, BoxError>,
Vec<u8>,
>;

impl<M: Message> ReceivingEnd<M> {
Expand Down
1 change: 1 addition & 0 deletions ipa-core/src/helpers/transport/in_memory/sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ mod tests {
sum += shard_network
.transport(identity, a)
.receive(b, (QueryId, Gate::default()))
.into_bytes_stream()
.collect::<Vec<_>>()
.await
.into_iter()
Expand Down
72 changes: 34 additions & 38 deletions ipa-core/src/helpers/transport/in_memory/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use ::tokio::sync::{
oneshot,
};
use async_trait::async_trait;
use bytes::Bytes;
use futures::{Stream, StreamExt};
#[cfg(all(feature = "shuttle", test))]
use shuttle::future as tokio;
Expand All @@ -36,7 +37,7 @@ type Packet<I> = (
);
type ConnectionTx<I> = Sender<Packet<I>>;
type ConnectionRx<I> = Receiver<Packet<I>>;
type StreamItem = Vec<u8>;
type StreamItem = Result<Bytes, BoxError>;

#[derive(Debug, thiserror::Error)]
pub enum Error<I> {
Expand Down Expand Up @@ -110,12 +111,7 @@ impl<I: TransportIdentity> InMemoryTransport<I> {
handler
.as_ref()
.expect("Handler is set")
.handle(
addr,
BodyStream::from_infallible(
stream.map(Vec::into_boxed_slice),
),
)
.handle(addr, BodyStream::from_bytes_stream(stream))
.await
}
};
Expand Down Expand Up @@ -177,7 +173,11 @@ impl<I: TransportIdentity> Transport for Weak<InMemoryTransport<I>> {
let (ack_tx, ack_rx) = oneshot::channel();

channel
.send((addr, InMemoryStream::wrap(data), ack_tx))
.send((
addr,
InMemoryStream::wrap(data.map(Bytes::from).map(Ok)),
ack_tx,
))
.await
.map_err(|_e| {
io::Error::new::<String>(io::ErrorKind::ConnectionAborted, "channel closed".into())
Expand Down Expand Up @@ -216,28 +216,11 @@ pub struct InMemoryStream {
}

impl InMemoryStream {
#[cfg(all(test, unit_test))]
fn empty() -> Self {
Self::from_iter(std::iter::empty())
}

fn wrap<S: Stream<Item = StreamItem> + Send + 'static>(value: S) -> Self {
Self {
inner: Box::pin(value),
}
}

#[cfg(all(test, unit_test))]
fn from_iter<I>(input: I) -> Self
where
I: IntoIterator<Item = StreamItem>,
I::IntoIter: Send + 'static,
{
use futures_util::stream;
Self {
inner: Box::pin(stream::iter(input)),
}
}
}

impl From<Receiver<StreamItem>> for InMemoryStream {
Expand Down Expand Up @@ -324,8 +307,11 @@ mod tests {
task::Poll,
};

use bytes::Bytes;
use futures::{stream, Stream};
use futures_util::{stream::poll_immediate, FutureExt, StreamExt};
use tokio::sync::{mpsc::channel, oneshot};
use tokio_stream::wrappers::ReceiverStream;

use crate::{
ff::{FieldType, Fp31},
Expand All @@ -348,11 +334,12 @@ mod tests {

const STEP: &str = "in-memory-transport";

async fn send_and_ack<I: TransportIdentity>(
async fn send_and_ack<I: TransportIdentity, S: Stream<Item = Vec<u8>> + Send + 'static>(
sender: &ConnectionTx<I>,
addr: Addr<I>,
data: InMemoryStream,
data: S,
) {
let data = InMemoryStream::wrap(data.map(Bytes::from).map(Ok));
let (tx, rx) = oneshot::channel();
sender.send((addr, data, tx)).await.unwrap();
let _ = rx
Expand Down Expand Up @@ -398,7 +385,7 @@ mod tests {
send_and_ack(
&tx,
Addr::from_route(Some(HelperIdentity::TWO), expected),
InMemoryStream::empty(),
stream::empty(),
)
.await;

Expand All @@ -411,7 +398,9 @@ mod tests {
let transport = Arc::downgrade(&transport);
let expected = vec![vec![1], vec![2]];

let mut stream = transport.receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP)));
let mut stream = transport
.receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP)))
.into_bytes_stream();

// make sure it is not ready as it hasn't received the records stream yet.
assert!(matches!(
Expand All @@ -421,7 +410,7 @@ mod tests {
send_and_ack(
&tx,
Addr::records(HelperIdentity::TWO, QueryId, Gate::from(STEP)),
InMemoryStream::from_iter(expected.clone()),
stream::iter(expected.clone()),
)
.await;

Expand All @@ -436,12 +425,13 @@ mod tests {
send_and_ack(
&tx,
Addr::records(HelperIdentity::TWO, QueryId, Gate::from(STEP)),
InMemoryStream::from_iter(expected.clone()),
stream::iter(expected.clone()),
)
.await;

let stream =
Arc::downgrade(&transport).receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP)));
let stream = Arc::downgrade(&transport)
.receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP)))
.into_bytes_stream();

assert_eq!(expected, stream.collect::<Vec<_>>().await);
}
Expand All @@ -454,13 +444,15 @@ mod tests {
transports: &HashMap<HelperIdentity, Weak<InMemoryTransport<HelperIdentity>>>,
) {
let (stream_tx, stream_rx) = channel(1);
let stream = InMemoryStream::from(stream_rx);
let stream = ReceiverStream::new(stream_rx);

let from_transport = transports.get(&from).unwrap();
let to_transport = transports.get(&to).unwrap();
let gate = Gate::from(STEP);

let mut recv = to_transport.receive(from, (QueryId, gate.clone()));
let mut recv = to_transport
.receive(from, (QueryId, gate.clone()))
.into_bytes_stream();
assert!(matches!(
poll_immediate(&mut recv).next().await,
Some(Poll::Pending)
Expand Down Expand Up @@ -509,10 +501,12 @@ mod tests {
let (tx, owned_transport) = Setup::new(HelperIdentity::ONE).into_active_conn(None);
let gate = Gate::from(STEP);
let (stream_tx, stream_rx) = channel(1);
let stream = InMemoryStream::from(stream_rx);
let stream = ReceiverStream::from(stream_rx);
let transport = Arc::downgrade(&owned_transport);

let mut recv_stream = transport.receive(HelperIdentity::TWO, (QueryId, gate.clone()));
let mut recv_stream = transport
.receive(HelperIdentity::TWO, (QueryId, gate.clone()))
.into_bytes_stream();
send_and_ack(
&tx,
Addr::records(HelperIdentity::TWO, QueryId, gate.clone()),
Expand Down Expand Up @@ -562,7 +556,9 @@ mod tests {
)
.await
.unwrap();
let mut recv = transport2.receive(HelperIdentity::ONE, (QueryId, gate));
let mut recv = transport2
.receive(HelperIdentity::ONE, (QueryId, gate))
.into_bytes_stream();

tx.send(0, Fp31::try_from(0_u128).unwrap()).await;
// can't receive the value at index 0 because of buffering inside the sender
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl RouteParams<RouteId, QueryId, NoStep> for (RouteId, QueryId) {
#[async_trait]
pub trait Transport: Clone + Send + Sync + 'static {
type Identity: TransportIdentity;
type RecordsStream: Stream<Item = Vec<u8>> + Send + Unpin;
type RecordsStream: BytesStream;
type Error: std::fmt::Debug;

fn identity(&self) -> Self::Identity;
Expand Down
11 changes: 11 additions & 0 deletions ipa-core/src/helpers/transport/receive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ impl<I, S> ReceiveRecords<I, S> {
}
}

#[cfg(all(test, any(unit_test, web_test)))]
impl<I: TransportIdentity, S: crate::helpers::BytesStream> ReceiveRecords<I, S> {
/// Converts this into a stream that yields owned byte chunks.
///
/// ## Panics
/// If inner stream yields an [`Err`] chunk.
pub(crate) fn into_bytes_stream(self) -> impl Stream<Item = Vec<u8>> {
self.inner.map(Result::unwrap).map(Into::into)
}
}

impl<I: TransportIdentity, S: Stream> Stream for ReceiveRecords<I, S> {
type Item = S::Item;

Expand Down
6 changes: 5 additions & 1 deletion ipa-core/src/helpers/transport/stream/box_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};

use crate::helpers::transport::stream::BoxBytesStream;
use crate::helpers::{transport::stream::BoxBytesStream, BytesStream};

pub struct WrappedBoxBodyStream(BoxBytesStream);

Expand All @@ -22,6 +22,10 @@ impl WrappedBoxBodyStream {
Self(Box::pin(input.map(Bytes::from).map(Ok)))
}

pub fn from_bytes_stream<S: BytesStream + 'static>(input: S) -> Self {
Self(Box::pin(input))
}

#[must_use]
pub fn empty() -> Self {
WrappedBoxBodyStream(Box::pin(futures::stream::empty()))
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/transport/stream/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ impl Mode for Batch {
}
}

/// Parse a [`Stream`] of [`Bytes`] into a stream of records of some
/// Parse a [`Stream`] of bytes into a stream of records of some
/// fixed-length-[`Serializable`] type `T`.
///
/// Depending on `M`, the provided stream can yield a single record `T` or multiples of `T`. See
Expand Down
5 changes: 3 additions & 2 deletions ipa-core/src/net/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,9 @@ pub(crate) mod tests {

MpcHelperClient::resp_ok(resp).await.unwrap();

let mut stream =
Arc::clone(&transport).receive(HelperIdentity::ONE, (QueryId, expected_step.clone()));
let mut stream = Arc::clone(&transport)
.receive(HelperIdentity::ONE, (QueryId, expected_step.clone()))
.into_bytes_stream();

assert_eq!(
poll_immediate(&mut stream).next().await,
Expand Down
4 changes: 3 additions & 1 deletion ipa-core/src/net/server/handlers/query/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ mod tests {
.await
.unwrap();

let mut stream = Arc::clone(&transport).receive(HelperIdentity::TWO, (QueryId, step));
let mut stream = Arc::clone(&transport)
.receive(HelperIdentity::TWO, (QueryId, step))
.into_bytes_stream();

assert_eq!(
poll_immediate(&mut stream).next().await,
Expand Down
Loading

0 comments on commit b651100

Please sign in to comment.