From 4fb18b1b67988115243cf84a49e72ff3b3d3f284 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Wed, 3 Apr 2024 13:53:53 -0700 Subject: [PATCH 1/2] Use `BytesStream` inside `Transport` trait This change is to unify stream handling between in-memory and real-world implementations. Before, in memory streams yielded `Vec` and HTTP streams operated on `Result`. This discrepancy required us to build a compatibility layer, and it makes sense now to unify both implementations. The main motivator for this change is to be able to use `RecordsStream` inside shard receivers - `UnorderedReceiver` use case fits MPC model well, but is clunky to use inside shard channels. --- ipa-core/src/helpers/gateway/mod.rs | 10 +-- ipa-core/src/helpers/gateway/receive.rs | 9 +-- .../helpers/transport/in_memory/sharding.rs | 1 + .../helpers/transport/in_memory/transport.rs | 72 +++++++++---------- ipa-core/src/helpers/transport/mod.rs | 2 +- ipa-core/src/helpers/transport/receive.rs | 11 +++ .../src/helpers/transport/stream/box_body.rs | 6 +- .../src/helpers/transport/stream/input.rs | 12 ++-- ipa-core/src/helpers/transport/stream/mod.rs | 6 +- ipa-core/src/net/client/mod.rs | 5 +- .../src/net/server/handlers/query/step.rs | 4 +- ipa-core/src/net/transport.rs | 18 +++-- 12 files changed, 85 insertions(+), 71 deletions(-) diff --git a/ipa-core/src/helpers/gateway/mod.rs b/ipa-core/src/helpers/gateway/mod.rs index 151a920c6..d973d6c6f 100644 --- a/ipa-core/src/helpers/gateway/mod.rs +++ b/ipa-core/src/helpers/gateway/mod.rs @@ -20,7 +20,7 @@ use crate::{ receive::GatewayReceivers, send::GatewaySenders, transport::RoleResolvingTransport, }, transport::routing::RouteId, - HelperChannelId, Message, Role, RoleAssignment, TotalRecords, Transport, + HelperChannelId, LogErrors, Message, Role, RoleAssignment, TotalRecords, Transport, }, protocol::QueryId, }; @@ -142,10 +142,10 @@ impl Gateway { channel_id.clone(), self.inner.receivers.get_or_create(channel_id, || { UnorderedReceiver::new( - Box::pin( - self.transport - .receive(channel_id.peer, (self.query_id, channel_id.gate.clone())), - ), + Box::pin(LogErrors::new(self.transport.receive( + channel_id.peer, + (self.query_id, channel_id.gate.clone()), + ))), self.config.active_work(), ) }), diff --git a/ipa-core/src/helpers/gateway/receive.rs b/ipa-core/src/helpers/gateway/receive.rs index 6e37f05d8..a98166e9f 100644 --- a/ipa-core/src/helpers/gateway/receive.rs +++ b/ipa-core/src/helpers/gateway/receive.rs @@ -1,12 +1,13 @@ use std::marker::PhantomData; +use bytes::Bytes; use dashmap::{mapref::entry::Entry, DashMap}; -use futures::Stream; use crate::{ + error::BoxError, helpers::{ buffers::UnorderedReceiver, gateway::transport::RoleResolvingTransport, Error, - HelperChannelId, Message, Role, Transport, + HelperChannelId, LogErrors, Message, Role, Transport, }, protocol::RecordId, }; @@ -25,8 +26,8 @@ pub(super) struct GatewayReceivers { } pub(super) type UR = UnorderedReceiver< - ::RecordsStream, - <::RecordsStream as Stream>::Item, + LogErrors<::RecordsStream, Bytes, BoxError>, + Vec, >; impl ReceivingEnd { diff --git a/ipa-core/src/helpers/transport/in_memory/sharding.rs b/ipa-core/src/helpers/transport/in_memory/sharding.rs index 0700793cc..23175e375 100644 --- a/ipa-core/src/helpers/transport/in_memory/sharding.rs +++ b/ipa-core/src/helpers/transport/in_memory/sharding.rs @@ -122,6 +122,7 @@ mod tests { sum += shard_network .transport(identity, a) .receive(b, (QueryId, Gate::default())) + .into_bytes_stream() .collect::>() .await .into_iter() diff --git a/ipa-core/src/helpers/transport/in_memory/transport.rs b/ipa-core/src/helpers/transport/in_memory/transport.rs index cb456a599..4c8962e61 100644 --- a/ipa-core/src/helpers/transport/in_memory/transport.rs +++ b/ipa-core/src/helpers/transport/in_memory/transport.rs @@ -11,6 +11,7 @@ use ::tokio::sync::{ oneshot, }; use async_trait::async_trait; +use bytes::Bytes; use futures::{Stream, StreamExt}; #[cfg(all(feature = "shuttle", test))] use shuttle::future as tokio; @@ -36,7 +37,7 @@ type Packet = ( ); type ConnectionTx = Sender>; type ConnectionRx = Receiver>; -type StreamItem = Vec; +type StreamItem = Result; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -110,12 +111,7 @@ impl InMemoryTransport { handler .as_ref() .expect("Handler is set") - .handle( - addr, - BodyStream::from_infallible( - stream.map(Vec::into_boxed_slice), - ), - ) + .handle(addr, BodyStream::from_bytes_stream(stream)) .await } }; @@ -177,7 +173,11 @@ impl Transport for Weak> { let (ack_tx, ack_rx) = oneshot::channel(); channel - .send((addr, InMemoryStream::wrap(data), ack_tx)) + .send(( + addr, + InMemoryStream::wrap(data.map(Bytes::from).map(Ok)), + ack_tx, + )) .await .map_err(|_e| { io::Error::new::(io::ErrorKind::ConnectionAborted, "channel closed".into()) @@ -216,28 +216,11 @@ pub struct InMemoryStream { } impl InMemoryStream { - #[cfg(all(test, unit_test))] - fn empty() -> Self { - Self::from_iter(std::iter::empty()) - } - fn wrap + Send + 'static>(value: S) -> Self { Self { inner: Box::pin(value), } } - - #[cfg(all(test, unit_test))] - fn from_iter(input: I) -> Self - where - I: IntoIterator, - I::IntoIter: Send + 'static, - { - use futures_util::stream; - Self { - inner: Box::pin(stream::iter(input)), - } - } } impl From> for InMemoryStream { @@ -324,8 +307,11 @@ mod tests { task::Poll, }; + use bytes::Bytes; + use futures::{stream, Stream}; use futures_util::{stream::poll_immediate, FutureExt, StreamExt}; use tokio::sync::{mpsc::channel, oneshot}; + use tokio_stream::wrappers::ReceiverStream; use crate::{ ff::{FieldType, Fp31}, @@ -348,11 +334,12 @@ mod tests { const STEP: &str = "in-memory-transport"; - async fn send_and_ack( + async fn send_and_ack> + Send + 'static>( sender: &ConnectionTx, addr: Addr, - data: InMemoryStream, + data: S, ) { + let data = InMemoryStream::wrap(data.map(Bytes::from).map(Ok)); let (tx, rx) = oneshot::channel(); sender.send((addr, data, tx)).await.unwrap(); let _ = rx @@ -398,7 +385,7 @@ mod tests { send_and_ack( &tx, Addr::from_route(Some(HelperIdentity::TWO), expected), - InMemoryStream::empty(), + stream::empty(), ) .await; @@ -411,7 +398,9 @@ mod tests { let transport = Arc::downgrade(&transport); let expected = vec![vec![1], vec![2]]; - let mut stream = transport.receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP))); + let mut stream = transport + .receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP))) + .into_bytes_stream(); // make sure it is not ready as it hasn't received the records stream yet. assert!(matches!( @@ -421,7 +410,7 @@ mod tests { send_and_ack( &tx, Addr::records(HelperIdentity::TWO, QueryId, Gate::from(STEP)), - InMemoryStream::from_iter(expected.clone()), + stream::iter(expected.clone()), ) .await; @@ -436,12 +425,13 @@ mod tests { send_and_ack( &tx, Addr::records(HelperIdentity::TWO, QueryId, Gate::from(STEP)), - InMemoryStream::from_iter(expected.clone()), + stream::iter(expected.clone()), ) .await; - let stream = - Arc::downgrade(&transport).receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP))); + let stream = Arc::downgrade(&transport) + .receive(HelperIdentity::TWO, (QueryId, Gate::from(STEP))) + .into_bytes_stream(); assert_eq!(expected, stream.collect::>().await); } @@ -454,13 +444,15 @@ mod tests { transports: &HashMap>>, ) { let (stream_tx, stream_rx) = channel(1); - let stream = InMemoryStream::from(stream_rx); + let stream = ReceiverStream::new(stream_rx); let from_transport = transports.get(&from).unwrap(); let to_transport = transports.get(&to).unwrap(); let gate = Gate::from(STEP); - let mut recv = to_transport.receive(from, (QueryId, gate.clone())); + let mut recv = to_transport + .receive(from, (QueryId, gate.clone())) + .into_bytes_stream(); assert!(matches!( poll_immediate(&mut recv).next().await, Some(Poll::Pending) @@ -509,10 +501,12 @@ mod tests { let (tx, owned_transport) = Setup::new(HelperIdentity::ONE).into_active_conn(None); let gate = Gate::from(STEP); let (stream_tx, stream_rx) = channel(1); - let stream = InMemoryStream::from(stream_rx); + let stream = ReceiverStream::from(stream_rx); let transport = Arc::downgrade(&owned_transport); - let mut recv_stream = transport.receive(HelperIdentity::TWO, (QueryId, gate.clone())); + let mut recv_stream = transport + .receive(HelperIdentity::TWO, (QueryId, gate.clone())) + .into_bytes_stream(); send_and_ack( &tx, Addr::records(HelperIdentity::TWO, QueryId, gate.clone()), @@ -562,7 +556,9 @@ mod tests { ) .await .unwrap(); - let mut recv = transport2.receive(HelperIdentity::ONE, (QueryId, gate)); + let mut recv = transport2 + .receive(HelperIdentity::ONE, (QueryId, gate)) + .into_bytes_stream(); tx.send(0, Fp31::try_from(0_u128).unwrap()).await; // can't receive the value at index 0 because of buffering inside the sender diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 23c290388..202d9fbad 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -163,7 +163,7 @@ impl RouteParams for (RouteId, QueryId) { #[async_trait] pub trait Transport: Clone + Send + Sync + 'static { type Identity: TransportIdentity; - type RecordsStream: Stream> + Send + Unpin; + type RecordsStream: BytesStream; type Error: std::fmt::Debug; fn identity(&self) -> Self::Identity; diff --git a/ipa-core/src/helpers/transport/receive.rs b/ipa-core/src/helpers/transport/receive.rs index fec775d0b..15fa7d4d9 100644 --- a/ipa-core/src/helpers/transport/receive.rs +++ b/ipa-core/src/helpers/transport/receive.rs @@ -84,6 +84,17 @@ impl ReceiveRecords { } } +#[cfg(all(test, any(unit_test, web_test)))] +impl ReceiveRecords { + /// Converts this into a stream that yields owned byte chunks. + /// + /// ## Panics + /// If inner stream yields an [`Err`] chunk. + pub(crate) fn into_bytes_stream(self) -> impl Stream> { + self.inner.map(Result::unwrap).map(Into::into) + } +} + impl Stream for ReceiveRecords { type Item = S::Item; diff --git a/ipa-core/src/helpers/transport/stream/box_body.rs b/ipa-core/src/helpers/transport/stream/box_body.rs index aa7a25583..d59a43295 100644 --- a/ipa-core/src/helpers/transport/stream/box_body.rs +++ b/ipa-core/src/helpers/transport/stream/box_body.rs @@ -6,7 +6,7 @@ use std::{ use bytes::Bytes; use futures::{stream::StreamExt, Stream}; -use crate::helpers::transport::stream::BoxBytesStream; +use crate::helpers::{transport::stream::BoxBytesStream, BytesStream}; pub struct WrappedBoxBodyStream(BoxBytesStream); @@ -22,6 +22,10 @@ impl WrappedBoxBodyStream { Self(Box::pin(input.map(Bytes::from).map(Ok))) } + pub fn from_bytes_stream(input: S) -> Self { + Self(Box::pin(input)) + } + #[must_use] pub fn empty() -> Self { WrappedBoxBodyStream(Box::pin(futures::stream::empty())) diff --git a/ipa-core/src/helpers/transport/stream/input.rs b/ipa-core/src/helpers/transport/stream/input.rs index 2eb1ca355..4e34d1400 100644 --- a/ipa-core/src/helpers/transport/stream/input.rs +++ b/ipa-core/src/helpers/transport/stream/input.rs @@ -154,12 +154,12 @@ enum ExtendResult { Error(io::Error), } -/// Parse a [`Stream`] of [`Bytes`] into a stream of records of some +/// Parse a [`Stream`] of bytes into a stream of records of some /// fixed-length-[`Serializable`] type `T`. #[pin_project] -pub struct RecordsStream +pub struct RecordsStream = Bytes> where - S: BytesStream, + S: BytesStream, T: Serializable, { // Our implementation of `poll_next` turns a `None` from the inner stream into `Some(Err(_))` if @@ -169,12 +169,12 @@ where #[pin] stream: Fuse, buffer: BufDeque, - phantom_data: PhantomData, + phantom_data: PhantomData<(T, R)>, } -impl RecordsStream +impl> RecordsStream where - S: BytesStream, + S: BytesStream, T: Serializable, { #[must_use] diff --git a/ipa-core/src/helpers/transport/stream/mod.rs b/ipa-core/src/helpers/transport/stream/mod.rs index 053b6033c..16a6666e2 100644 --- a/ipa-core/src/helpers/transport/stream/mod.rs +++ b/ipa-core/src/helpers/transport/stream/mod.rs @@ -16,7 +16,7 @@ pub use input::{LengthDelimitedStream, RecordsStream}; use crate::error::BoxError; -pub trait BytesStream: Stream> + Send { +pub trait BytesStream = Bytes>: Stream> + Send { /// Collects the entire stream into a vec; only intended for use in tests /// # Panics /// if the stream has any failure @@ -27,11 +27,11 @@ pub trait BytesStream: Stream> + Send { { use futures::StreamExt; - Box::pin(self.map(|item| item.unwrap().to_vec()).concat()) + Box::pin(self.map(|item| item.unwrap().as_ref().to_vec()).concat()) } } -impl> + Send> BytesStream for S {} +impl, S: Stream> + Send> BytesStream for S {} pub type BoxBytesStream = Pin>; diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 795457718..fe25c32d4 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -624,8 +624,9 @@ pub(crate) mod tests { MpcHelperClient::resp_ok(resp).await.unwrap(); - let mut stream = - Arc::clone(&transport).receive(HelperIdentity::ONE, (QueryId, expected_step.clone())); + let mut stream = Arc::clone(&transport) + .receive(HelperIdentity::ONE, (QueryId, expected_step.clone())) + .into_bytes_stream(); assert_eq!( poll_immediate(&mut stream).next().await, diff --git a/ipa-core/src/net/server/handlers/query/step.rs b/ipa-core/src/net/server/handlers/query/step.rs index 223dbbff6..00a0de9d2 100644 --- a/ipa-core/src/net/server/handlers/query/step.rs +++ b/ipa-core/src/net/server/handlers/query/step.rs @@ -70,7 +70,9 @@ mod tests { .await .unwrap(); - let mut stream = Arc::clone(&transport).receive(HelperIdentity::TWO, (QueryId, step)); + let mut stream = Arc::clone(&transport) + .receive(HelperIdentity::TWO, (QueryId, step)) + .into_bytes_stream(); assert_eq!( poll_immediate(&mut stream).next().await, diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 8e9dbca8e..79a80bea7 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -6,17 +6,15 @@ use std::{ }; use async_trait::async_trait; -use bytes::Bytes; use futures::{Stream, TryFutureExt}; use pin_project::{pin_project, pinned_drop}; use crate::{ config::{NetworkConfig, ServerConfig}, - error::BoxError, helpers::{ query::QueryConfig, routing::{Addr, RouteId}, - ApiError, BodyStream, HandlerRef, HelperIdentity, HelperResponse, LogErrors, NoQueryId, + ApiError, BodyStream, HandlerRef, HelperIdentity, HelperResponse, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, RequestHandler, RouteParams, StepBinding, StreamCollection, Transport, }, @@ -25,15 +23,13 @@ use crate::{ sync::Arc, }; -type LogHttpErrors = LogErrors; - /// HTTP transport for IPA helper service. pub struct HttpTransport { identity: HelperIdentity, clients: [MpcHelperClient; 3], // TODO(615): supporting multiple queries likely require a hashmap here. It will be ok if we // only allow one query at a time. - record_streams: StreamCollection, + record_streams: StreamCollection, handler: Option, } @@ -155,14 +151,14 @@ impl HttpTransport { stream: BodyStream, ) { self.record_streams - .add_stream((query_id, from, gate), LogErrors::new(stream)); + .add_stream((query_id, from, gate), stream); } } #[async_trait] impl Transport for Arc { type Identity = HelperIdentity; - type RecordsStream = ReceiveRecords; + type RecordsStream = ReceiveRecords; type Error = Error; fn identity(&self) -> HelperIdentity { @@ -232,6 +228,7 @@ impl Transport for Arc { mod tests { use std::{iter::zip, net::TcpListener, task::Poll}; + use bytes::Bytes; use futures::stream::{poll_immediate, StreamExt}; use futures_util::future::{join_all, try_join_all}; use generic_array::GenericArray; @@ -272,8 +269,9 @@ mod tests { Arc::clone(&transport).receive_stream(QueryId, STEP.clone(), HelperIdentity::TWO, body); // Request step data reception (normally called by protocol) - let mut stream = - Arc::clone(&transport).receive(HelperIdentity::TWO, (QueryId, STEP.clone())); + let mut stream = Arc::clone(&transport) + .receive(HelperIdentity::TWO, (QueryId, STEP.clone())) + .into_bytes_stream(); // make sure it is not ready as it hasn't received any data yet. assert!(matches!( From 5f3967d77550da3f78d7fb6ec407685052158912 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Fri, 5 Apr 2024 13:24:14 -0700 Subject: [PATCH 2/2] Small fix --- ipa-core/src/helpers/transport/stream/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ipa-core/src/helpers/transport/stream/mod.rs b/ipa-core/src/helpers/transport/stream/mod.rs index a5991a665..053b6033c 100644 --- a/ipa-core/src/helpers/transport/stream/mod.rs +++ b/ipa-core/src/helpers/transport/stream/mod.rs @@ -27,7 +27,7 @@ pub trait BytesStream: Stream> + Send { { use futures::StreamExt; - Box::pin(self.map(|item| item.unwrap().as_ref().to_vec()).concat()) + Box::pin(self.map(|item| item.unwrap().to_vec()).concat()) } }