Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use BytesStream inside Transport trait #1001

Merged
merged 3 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ipa-core/src/helpers/gateway/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::{
receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport,
},
transport::routing::RouteId,
HelperChannelId, Message, Role, RoleAssignment, TotalRecords, Transport,
HelperChannelId, LogErrors, Message, Role, RoleAssignment, TotalRecords, Transport,
},
protocol::QueryId,
};
Expand Down Expand Up @@ -142,10 +142,10 @@ impl Gateway {
channel_id.clone(),
self.inner.receivers.get_or_create(channel_id, || {
UnorderedReceiver::new(
Box::pin(
self.transport
.receive(channel_id.peer, (self.query_id, channel_id.gate.clone())),
),
Box::pin(LogErrors::new(self.transport.receive(
channel_id.peer,
(self.query_id, channel_id.gate.clone()),
))),
self.config.active_work(),
)
}),
Expand Down
9 changes: 5 additions & 4 deletions ipa-core/src/helpers/gateway/receive.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::marker::PhantomData;

use bytes::Bytes;
use dashmap::{mapref::entry::Entry, DashMap};
use futures::Stream;

use crate::{
error::BoxError,
helpers::{
buffers::UnorderedReceiver, gateway::transport::RoleResolvingTransport, Error,
HelperChannelId, Message, Role, Transport,
HelperChannelId, LogErrors, Message, Role, Transport,
},
protocol::RecordId,
};
Expand All @@ -25,8 +26,8 @@ pub(super) struct GatewayReceivers {
}

pub(super) type UR = UnorderedReceiver<
<RoleResolvingTransport as Transport>::RecordsStream,
<<RoleResolvingTransport as Transport>::RecordsStream as Stream>::Item,
LogErrors<<RoleResolvingTransport as Transport>::RecordsStream, Bytes, BoxError>,
Vec<u8>,
>;

impl<M: Message> ReceivingEnd<M> {
Expand Down
1 change: 1 addition & 0 deletions ipa-core/src/helpers/transport/in_memory/sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ mod tests {
sum += shard_network
.transport(identity, a)
.receive(b, (QueryId, Gate::default()))
.into_bytes_stream()
.collect::<Vec<_>>()
.await
.into_iter()
Expand Down
72 changes: 34 additions & 38 deletions ipa-core/src/helpers/transport/in_memory/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use ::tokio::sync::{
oneshot,
};
use async_trait::async_trait;
use bytes::Bytes;
use futures::{Stream, StreamExt};
#[cfg(all(feature = "shuttle", test))]
use shuttle::future as tokio;
Expand All @@ -36,7 +37,7 @@ type Packet<I> = (
);
type ConnectionTx<I> = Sender<Packet<I>>;
type ConnectionRx<I> = Receiver<Packet<I>>;
type StreamItem = Vec<u8>;
type StreamItem = Result<Bytes, BoxError>;

#[derive(Debug, thiserror::Error)]
pub enum Error<I> {
Expand Down Expand Up @@ -110,12 +111,7 @@ impl<I: TransportIdentity> InMemoryTransport<I> {
handler
.as_ref()
.expect("Handler is set")
.handle(
addr,
BodyStream::from_infallible(
stream.map(Vec::into_boxed_slice),
),
)
.handle(addr, BodyStream::from_bytes_stream(stream))
.await
}
};
Expand Down Expand Up @@ -177,7 +173,11 @@ impl<I: TransportIdentity> Transport for Weak<InMemoryTransport<I>> {
let (ack_tx, ack_rx) = oneshot::channel();

channel
.send((addr, InMemoryStream::wrap(data), ack_tx))
.send((
addr,
InMemoryStream::wrap(data.map(Bytes::from).map(Ok)),
ack_tx,
))
.await
.map_err(|_e| {
io::Error::new::<String>(io::ErrorKind::ConnectionAborted, "channel closed".into())
Expand Down Expand Up @@ -216,28 +216,11 @@ pub struct InMemoryStream {
}

impl InMemoryStream {
#[cfg(all(test, unit_test))]
fn empty() -> Self {
Self::from_iter(std::iter::empty())
}

fn wrap<S: Stream<Item = StreamItem> + Send + 'static>(value: S) -> Self {
Self {
inner: Box::pin(value),
}
}

#[cfg(all(test, unit_test))]
fn from_iter<I>(input: I) -> Self
where
I: IntoIterator<Item = StreamItem>,
I::IntoIter: Send + 'static,
{
use futures_util::stream;
Self {
inner: Box::pin(stream::iter(input)),
}
}
}

impl From<Receiver<StreamItem>> for InMemoryStream {
Expand Down Expand Up @@ -324,8 +307,11 @@ mod tests {
task::Poll,
};

use bytes::Bytes;
use futures::{stream, Stream};
use futures_util::{stream::poll_immediate, FutureExt, StreamExt};
use tokio::sync::{mpsc::channel, oneshot};
use tokio_stream::wrappers::ReceiverStream;

use crate::{
ff::{FieldType, Fp31},
Expand All @@ -348,11 +334,12 @@ mod tests {

const STEP: &str = "in-memory-transport";

async fn send_and_ack<I: TransportIdentity>(
async fn send_and_ack<I: TransportIdentity, S: Stream<Item = Vec<u8>> + Send + 'static>(
sender: &ConnectionTx<I>,
addr: Addr<I>,
data: InMemoryStream,
data: S,
) {
let data = InMemoryStream::wrap(data.map(Bytes::from).map(Ok));
let (tx, rx) = oneshot::channel();
sender.send((addr, data, tx)).await.unwrap();
let _ = rx
Expand Down Expand Up @@ -398,7 +385,7 @@ mod tests {
send_and_ack(
&tx,
Addr::from_route(Some(HelperIdentity::TWO), expected),
InMemoryStream::empty(),
stream::empty(),
)
.await;

Expand All @@ -411,7 +398,9 @@ mod tests {
let transport = Arc::downgrade(&transport);
let expected = vec![vec![1], vec![2]];

let mut stream = transport.receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP)));
let mut stream = transport
.receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP)))
.into_bytes_stream();

// make sure it is not ready as it hasn't received the records stream yet.
assert!(matches!(
Expand All @@ -421,7 +410,7 @@ mod tests {
send_and_ack(
&tx,
Addr::records(HelperIdentity::TWO, QueryId, Gate::from(STEP)),
InMemoryStream::from_iter(expected.clone()),
stream::iter(expected.clone()),
)
.await;

Expand All @@ -436,12 +425,13 @@ mod tests {
send_and_ack(
&tx,
Addr::records(HelperIdentity::TWO, QueryId, Gate::from(STEP)),
InMemoryStream::from_iter(expected.clone()),
stream::iter(expected.clone()),
)
.await;

let stream =
Arc::downgrade(&transport).receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP)));
let stream = Arc::downgrade(&transport)
.receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP)))
.into_bytes_stream();

assert_eq!(expected, stream.collect::<Vec<_>>().await);
}
Expand All @@ -454,13 +444,15 @@ mod tests {
transports: &HashMap<HelperIdentity, Weak<InMemoryTransport<HelperIdentity>>>,
) {
let (stream_tx, stream_rx) = channel(1);
let stream = InMemoryStream::from(stream_rx);
let stream = ReceiverStream::new(stream_rx);

let from_transport = transports.get(&from).unwrap();
let to_transport = transports.get(&to).unwrap();
let gate = Gate::from(STEP);

let mut recv = to_transport.receive(from, (QueryId, gate.clone()));
let mut recv = to_transport
.receive(from, (QueryId, gate.clone()))
.into_bytes_stream();
assert!(matches!(
poll_immediate(&mut recv).next().await,
Some(Poll::Pending)
Expand Down Expand Up @@ -509,10 +501,12 @@ mod tests {
let (tx, owned_transport) = Setup::new(HelperIdentity::ONE).into_active_conn(None);
let gate = Gate::from(STEP);
let (stream_tx, stream_rx) = channel(1);
let stream = InMemoryStream::from(stream_rx);
let stream = ReceiverStream::from(stream_rx);
let transport = Arc::downgrade(&owned_transport);

let mut recv_stream = transport.receive(HelperIdentity::TWO, (QueryId, gate.clone()));
let mut recv_stream = transport
.receive(HelperIdentity::TWO, (QueryId, gate.clone()))
.into_bytes_stream();
send_and_ack(
&tx,
Addr::records(HelperIdentity::TWO, QueryId, gate.clone()),
Expand Down Expand Up @@ -562,7 +556,9 @@ mod tests {
)
.await
.unwrap();
let mut recv = transport2.receive(HelperIdentity::ONE, (QueryId, gate));
let mut recv = transport2
.receive(HelperIdentity::ONE, (QueryId, gate))
.into_bytes_stream();

tx.send(0, Fp31::try_from(0_u128).unwrap()).await;
// can't receive the value at index 0 because of buffering inside the sender
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl RouteParams<RouteId, QueryId, NoStep> for (RouteId, QueryId) {
#[async_trait]
pub trait Transport: Clone + Send + Sync + 'static {
type Identity: TransportIdentity;
type RecordsStream: Stream<Item = Vec<u8>> + Send + Unpin;
type RecordsStream: BytesStream;
type Error: std::fmt::Debug;

fn identity(&self) -> Self::Identity;
Expand Down
11 changes: 11 additions & 0 deletions ipa-core/src/helpers/transport/receive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ impl<I, S> ReceiveRecords<I, S> {
}
}

#[cfg(all(test, any(unit_test, web_test)))]
impl<I: TransportIdentity, S: crate::helpers::BytesStream> ReceiveRecords<I, S> {
/// Converts this into a stream that yields owned byte chunks.
///
/// ## Panics
/// If inner stream yields an [`Err`] chunk.
pub(crate) fn into_bytes_stream(self) -> impl Stream<Item = Vec<u8>> {
self.inner.map(Result::unwrap).map(Into::into)
}
}

impl<I: TransportIdentity, S: Stream> Stream for ReceiveRecords<I, S> {
type Item = S::Item;

Expand Down
6 changes: 5 additions & 1 deletion ipa-core/src/helpers/transport/stream/box_body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};

use crate::helpers::transport::stream::BoxBytesStream;
use crate::helpers::{transport::stream::BoxBytesStream, BytesStream};

pub struct WrappedBoxBodyStream(BoxBytesStream);

Expand All @@ -22,6 +22,10 @@ impl WrappedBoxBodyStream {
Self(Box::pin(input.map(Bytes::from).map(Ok)))
}

pub fn from_bytes_stream<S: BytesStream + 'static>(input: S) -> Self {
Self(Box::pin(input))
}

#[must_use]
pub fn empty() -> Self {
WrappedBoxBodyStream(Box::pin(futures::stream::empty()))
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/transport/stream/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ impl Mode for Batch {
}
}

/// Parse a [`Stream`] of [`Bytes`] into a stream of records of some
/// Parse a [`Stream`] of bytes into a stream of records of some
/// fixed-length-[`Serializable`] type `T`.
///
/// Depending on `M`, the provided stream can yield a single record `T` or multiples of `T`. See
Expand Down
5 changes: 3 additions & 2 deletions ipa-core/src/net/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,9 @@ pub(crate) mod tests {

MpcHelperClient::resp_ok(resp).await.unwrap();

let mut stream =
Arc::clone(&transport).receive(HelperIdentity::ONE, (QueryId, expected_step.clone()));
let mut stream = Arc::clone(&transport)
.receive(HelperIdentity::ONE, (QueryId, expected_step.clone()))
.into_bytes_stream();

assert_eq!(
poll_immediate(&mut stream).next().await,
Expand Down
4 changes: 3 additions & 1 deletion ipa-core/src/net/server/handlers/query/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ mod tests {
.await
.unwrap();

let mut stream = Arc::clone(&transport).receive(HelperIdentity::TWO, (QueryId, step));
let mut stream = Arc::clone(&transport)
.receive(HelperIdentity::TWO, (QueryId, step))
.into_bytes_stream();

assert_eq!(
poll_immediate(&mut stream).next().await,
Expand Down
Loading
Loading