diff --git a/.clippy.toml b/.clippy.toml index 238592639..d92b301bb 100644 --- a/.clippy.toml +++ b/.clippy.toml @@ -3,3 +3,5 @@ disallowed-methods = [ { path = "futures::future::join_all", reason = "We don't have a replacement for this method yet. Consider extending `SeqJoin` trait." }, { path = "futures::future::try_join_all", reason = "Use Context.try_join instead." }, ] + +allow-private-module-inception = true diff --git a/Cargo.toml b/Cargo.toml index 1767e8e24..70e40b5f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,6 +99,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } typenum = "1.16" # hpke is pinned to it x25519-dalek = "2.0.0-rc.3" +delegate = "0.10.0" [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.5.0" diff --git a/src/helpers/buffers/mod.rs b/src/helpers/buffers/mod.rs index 83a92c24e..943c884dd 100644 --- a/src/helpers/buffers/mod.rs +++ b/src/helpers/buffers/mod.rs @@ -5,37 +5,3 @@ mod unordered_receiver; pub use ordering_mpsc::{ordering_mpsc, OrderingMpscReceiver, OrderingMpscSender}; pub use ordering_sender::{OrderedStream, OrderingSender}; pub use unordered_receiver::UnorderedReceiver; - -#[cfg(debug_assertions)] -#[allow(unused)] // todo(alex): make test world print the state again -mod waiting { - use std::collections::HashMap; - - use crate::helpers::ChannelId; - - pub(in crate::helpers) struct WaitingTasks<'a> { - tasks: HashMap<&'a ChannelId, Vec>, - } - - impl<'a> WaitingTasks<'a> { - pub fn new(tasks: HashMap<&'a ChannelId, Vec>) -> Self { - Self { tasks } - } - - pub fn is_empty(&self) -> bool { - self.tasks.is_empty() - } - } - - impl std::fmt::Debug for WaitingTasks<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "[")?; - for (channel, records) in &self.tasks { - write!(f, "\n {channel:?}: {records:?}")?; - } - write!(f, "\n]")?; - - Ok(()) - } - } -} diff --git a/src/helpers/buffers/ordering_sender.rs b/src/helpers/buffers/ordering_sender.rs index ea1459d62..866e36005 100644 --- a/src/helpers/buffers/ordering_sender.rs +++ b/src/helpers/buffers/ordering_sender.rs @@ -188,6 +188,10 @@ impl WaitingShard { self.wakers.pop_front().unwrap().w.wake(); } } + + pub fn waiting(&self) -> impl Iterator + '_ { + self.wakers.iter().map(|waker| waker.i) + } } /// A collection of wakers that are indexed by the send index (`i`). @@ -224,6 +228,12 @@ impl Waiting { fn wake(&self, i: usize) { self.shard(i).wake(i); } + + fn waiting(&self, indices: &mut Vec) { + self.shards + .iter() + .for_each(|shard| indices.extend(shard.lock().unwrap().waiting())); + } } /// An `OrderingSender` accepts messages for sending in any order, but @@ -375,6 +385,14 @@ impl OrderingSender { ) -> OrderedStream> { OrderedStream { sender: self } } + + pub fn waiting(&self) -> Vec { + let mut buf = Vec::new(); + self.waiting.waiting(&mut buf); + buf.sort_unstable(); + + buf + } } /// A future for writing item `i` into an `OrderingSender`. diff --git a/src/helpers/buffers/unordered_receiver.rs b/src/helpers/buffers/unordered_receiver.rs index d578e6456..c637e3625 100644 --- a/src/helpers/buffers/unordered_receiver.rs +++ b/src/helpers/buffers/unordered_receiver.rs @@ -143,7 +143,7 @@ where /// Note: in protocols we try to send before receiving, so we can rely on /// that easing load on this mechanism. There might also need to be some /// end-to-end back pressure for tasks that do not involve sending at all. - overflow_wakers: Vec, + overflow_wakers: Vec<(Waker, usize)>, _marker: PhantomData, } @@ -172,7 +172,7 @@ where ); // We don't save a waker at `self.next`, so `>` and not `>=`. if i > self.next + self.wakers.len() { - self.overflow_wakers.push(waker); + self.overflow_wakers.push((waker, i)); } else { let index = i % self.wakers.len(); if let Some(old) = self.wakers[index].as_ref() { @@ -195,7 +195,8 @@ where } if self.next % (self.wakers.len() / 2) == 0 { // Wake all the overflowed wakers. See comments on `overflow_wakers`. - for w in take(&mut self.overflow_wakers) { + // todo: we may want to wake specific wakers now + for (w, _) in take(&mut self.overflow_wakers) { w.wake(); } } @@ -228,6 +229,22 @@ where } } } + + fn waiting(&self) -> impl Iterator + '_ { + let start = self.next % self.wakers.len(); + self.wakers + .iter() + .enumerate() + .filter_map(|(i, waker)| waker.as_ref().map(|_| i)) + .map(move |i| { + if i < start { + self.next + (self.wakers.len() - start + i) + } else { + self.next + (i - start) + } + }) + .chain(self.overflow_wakers.iter().map(|v| v.1)) + } } /// Take an ordered stream of bytes and make messages from that stream @@ -284,6 +301,13 @@ where _marker: PhantomData, } } + + pub fn waiting(&self) -> Vec { + let mut r = self.inner.lock().unwrap().waiting().collect::>(); + r.sort_unstable(); + + r + } } impl Clone for UnorderedReceiver diff --git a/src/helpers/gateway/gateway.rs b/src/helpers/gateway/gateway.rs new file mode 100644 index 000000000..fdae37ff0 --- /dev/null +++ b/src/helpers/gateway/gateway.rs @@ -0,0 +1,233 @@ +use std::{ + fmt::{Debug, Formatter}, + num::NonZeroUsize, + time::Duration, +}; + +use delegate::delegate; +#[cfg(all(feature = "shuttle", test))] +use shuttle::future as tokio; + +use crate::{ + helpers::{ + gateway::{ + observable::{ObserveState, Observed}, + receive, + receive::GatewayReceivers, + send, + send::GatewaySenders, + transport::RoleResolvingTransport, + }, + ChannelId, Message, ReceivingEnd, Role, RoleAssignment, SendingEnd, TotalRecords, + TransportImpl, + }, + protocol::QueryId, + sync::{atomic::AtomicUsize, Arc, Weak}, +}; + +/// Gateway into IPA Network infrastructure. It allows helpers send and receive messages. +pub struct Gateway { + config: GatewayConfig, + transport: RoleResolvingTransport, + // todo: use different state when feature is off + inner: Arc, +} + +#[derive(Default)] +pub struct State { + senders: GatewaySenders, + receivers: GatewayReceivers, +} +impl State { + pub fn downgrade(self: &Arc) -> Weak { + Arc::downgrade(self) + } +} + +#[derive(Clone, Copy, Debug)] +pub struct GatewayConfig { + /// The number of items that can be active at the one time. + /// This is used to determine the size of sending and receiving buffers. + active: NonZeroUsize, + + /// Time to wait before checking gateway progress. If no progress has been made between + /// checks, the gateway is considered to be stalled and will create a report with outstanding + /// send/receive requests + pub progress_check_interval: Duration, +} + +impl Default for GatewayConfig { + fn default() -> Self { + Self::new(1024) + } +} + +impl GatewayConfig { + /// Generate a new configuration with the given active limit. + /// + /// ## Panics + /// If `active` is 0. + #[must_use] + pub fn new(active: usize) -> Self { + // In-memory tests move data fast, so progress check intervals can be lower. + // Real world scenarios currently over-report stalls because of inefficiencies inside + // infrastructure and actual networking issues. This checks is only valuable to report + // bugs, so keeping it large enough to avoid false positives. + Self { + active: NonZeroUsize::new(active).unwrap(), + progress_check_interval: Duration::from_secs(if cfg!(test) { 5 } else { 60 }), + } + } + + /// The configured amount of active work. + #[must_use] + pub fn active_work(&self) -> NonZeroUsize { + self.active + } +} + +impl Observed { + delegate! { + to self.inner() { + #[inline] + pub fn role(&self) -> Role; + + #[inline] + pub fn config(&self) -> &GatewayConfig; + } + } + + pub fn new( + query_id: QueryId, + config: GatewayConfig, + roles: RoleAssignment, + transport: TransportImpl, + ) -> Self { + let version = Arc::new(AtomicUsize::default()); + // todo: this sucks, we shouldn't do an extra clone + Self::wrap(&version, Gateway::new(query_id, config, roles, transport)) + } + + #[must_use] + pub fn get_sender( + &self, + channel_id: &ChannelId, + total_records: TotalRecords, + ) -> SendingEnd { + Observed::wrap( + self.get_version(), + self.inner().get_sender(channel_id, total_records), + ) + } + + #[must_use] + pub fn get_receiver(&self, channel_id: &ChannelId) -> ReceivingEnd { + Observed::wrap(self.get_version(), self.inner().get_receiver(channel_id)) + } + + pub fn to_observed(&self) -> Observed> { + // todo: inner.inner + Observed::wrap(self.get_version(), self.inner().inner.downgrade()) + } +} + +pub struct GatewayWaitingTasks { + senders_state: Option, + receivers_state: Option, +} + +impl Debug for GatewayWaitingTasks { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if let Some(senders_state) = &self.senders_state { + write!(f, "\n{{{senders_state:?}\n}}")?; + } + if let Some(receivers_state) = &self.receivers_state { + write!(f, "\n{{{receivers_state:?}\n}}")?; + } + + Ok(()) + } +} + +impl ObserveState for Weak { + type State = GatewayWaitingTasks; + + fn get_state(&self) -> Option { + self.upgrade().map(|state| Self::State { + senders_state: state.senders.get_state(), + receivers_state: state.receivers.get_state(), + }) + } +} + +impl Gateway { + #[must_use] + pub fn new( + query_id: QueryId, + config: GatewayConfig, + roles: RoleAssignment, + transport: TransportImpl, + ) -> Self { + Self { + config, + transport: RoleResolvingTransport { + query_id, + roles, + inner: transport, + config, + }, + inner: State::default().into(), + } + } + + #[must_use] + pub fn role(&self) -> Role { + self.transport.role() + } + + #[must_use] + pub fn config(&self) -> &GatewayConfig { + &self.config + } + + /// + /// ## Panics + /// If there is a failure connecting via HTTP + #[must_use] + pub fn get_sender( + &self, + channel_id: &ChannelId, + total_records: TotalRecords, + ) -> send::SendingEnd { + let (tx, maybe_stream) = self.inner.senders.get_or_create::( + channel_id, + self.config.active_work(), + total_records, + ); + if let Some(stream) = maybe_stream { + tokio::spawn({ + let channel_id = channel_id.clone(); + let transport = self.transport.clone(); + async move { + // TODO(651): In the HTTP case we probably need more robust error handling here. + transport + .send(&channel_id, stream) + .await + .expect("{channel_id:?} receiving end should be accepted by transport"); + } + }); + } + + send::SendingEnd::new(tx, self.role(), channel_id) + } + + #[must_use] + pub fn get_receiver(&self, channel_id: &ChannelId) -> receive::ReceivingEnd { + receive::ReceivingEnd::new( + channel_id.clone(), + self.inner + .receivers + .get_or_create(channel_id, || self.transport.receive(channel_id)), + ) + } +} diff --git a/src/helpers/gateway/mod.rs b/src/helpers/gateway/mod.rs index c2a37916c..e93e6893e 100644 --- a/src/helpers/gateway/mod.rs +++ b/src/helpers/gateway/mod.rs @@ -1,24 +1,20 @@ +#[allow(clippy::module_inception)] // private +mod gateway; +mod observable; +#[cfg(not(feature = "shuttle"))] // todo its own feature +pub mod observer; mod receive; mod send; mod transport; -use std::{fmt::Debug, num::NonZeroUsize}; +use std::ops::RangeInclusive; -pub use send::SendingEnd; -#[cfg(all(feature = "shuttle", test))] -use shuttle::future as tokio; +// TODO: feature flag +pub type SendingEnd = Observed>; +pub type ReceivingEnd = Observed>; +pub type Gateway = Observed; -use crate::{ - helpers::{ - gateway::{ - receive::{GatewayReceivers, ReceivingEnd as ReceivingEndBase}, - send::GatewaySenders, - transport::RoleResolvingTransport, - }, - ChannelId, Message, Role, RoleAssignment, TotalRecords, Transport, - }, - protocol::QueryId, -}; +use crate::helpers::{gateway::observable::Observed, Transport}; /// Alias for the currently configured transport. /// @@ -31,120 +27,21 @@ pub type TransportImpl = super::transport::InMemoryTransport; pub type TransportImpl = crate::sync::Arc; pub type TransportError = ::Error; -pub type ReceivingEnd = ReceivingEndBase; -/// Gateway into IPA Infrastructure systems. This object allows sending and receiving messages. -/// As it is generic over network/transport layer implementation, type alias [`Gateway`] should be -/// used to avoid carrying `T` over. -/// -/// [`Gateway`]: crate::helpers::Gateway -pub struct Gateway { - config: GatewayConfig, - transport: RoleResolvingTransport, - senders: GatewaySenders, - receivers: GatewayReceivers, -} - -#[derive(Clone, Copy, Debug)] -pub struct GatewayConfig { - /// The number of items that can be active at the one time. - /// This is used to determine the size of sending and receiving buffers. - active: NonZeroUsize, -} - -impl Gateway { - #[must_use] - pub fn new( - query_id: QueryId, - config: GatewayConfig, - roles: RoleAssignment, - transport: T, - ) -> Self { - Self { - config, - transport: RoleResolvingTransport { - query_id, - roles, - inner: transport, - config, - }, - senders: GatewaySenders::default(), - receivers: GatewayReceivers::default(), - } - } - - #[must_use] - pub fn role(&self) -> Role { - self.transport.role() - } - - #[must_use] - pub fn config(&self) -> &GatewayConfig { - &self.config - } - - /// - /// ## Panics - /// If there is a failure connecting via HTTP - #[must_use] - pub fn get_sender( - &self, - channel_id: &ChannelId, - total_records: TotalRecords, - ) -> SendingEnd { - let (tx, maybe_stream) = - self.senders - .get_or_create::(channel_id, self.config.active_work(), total_records); - if let Some(stream) = maybe_stream { - tokio::spawn({ - let channel_id = channel_id.clone(); - let transport = self.transport.clone(); - async move { - // TODO(651): In the HTTP case we probably need more robust error handling here. - transport - .send(&channel_id, stream) - .await - .expect("{channel_id:?} receiving end should be accepted by transport"); - } - }); - } - - SendingEnd::new(tx, self.role(), channel_id) - } - - #[must_use] - pub fn get_receiver(&self, channel_id: &ChannelId) -> ReceivingEndBase { - ReceivingEndBase::new( - channel_id.clone(), - self.receivers - .get_or_create(channel_id, || self.transport.receive(channel_id)), - ) - } -} - -impl Default for GatewayConfig { - fn default() -> Self { - Self::new(1024) - } -} - -impl GatewayConfig { - /// Generate a new configuration with the given active limit. - /// - /// ## Panics - /// If `active` is 0. - #[must_use] - pub fn new(active: usize) -> Self { - Self { - active: NonZeroUsize::new(active).unwrap(), - } - } - - /// The configured amount of active work. - #[must_use] - pub fn active_work(&self) -> NonZeroUsize { - self.active - } +pub use gateway::GatewayConfig; + +fn to_ranges(nums: Vec) -> Vec> { + nums.into_iter() + .fold(Vec::>::new(), |mut ranges, num| { + if let Some(last_range) = ranges.last_mut().filter(|r| *r.end() == num - 1) { + *last_range = *last_range.start()..=num; + } else { + ranges.push(num..=num); + } + ranges + }) + .into_iter() + .collect() } #[cfg(all(test, unit_test))] @@ -215,6 +112,8 @@ mod tests { // sent (same batch or different does not matter here) let spawned = tokio::spawn(async move { let channel = sender_ctx.send_channel(Role::H2); + // channel.send(RecordId::from(1), Fp31::truncate_from(1_u128)).await.unwrap(); + // channel.send(RecordId::from(0), Fp31::truncate_from(0_u128)).await.unwrap(); try_join( channel.send(RecordId::from(1), Fp31::truncate_from(1_u128)), channel.send(RecordId::from(0), Fp31::truncate_from(0_u128)), diff --git a/src/helpers/gateway/observable.rs b/src/helpers/gateway/observable.rs new file mode 100644 index 000000000..cbd30ab23 --- /dev/null +++ b/src/helpers/gateway/observable.rs @@ -0,0 +1,78 @@ +use std::{ + cmp::Ordering::{Equal, Greater, Less}, + fmt::{Debug, Display}, + ops::{RangeInclusive, Sub}, +}; + +use crate::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +pub struct Observed { + version: Arc, + inner: T, +} + +impl Observed { + pub fn get_version(&self) -> &Arc { + &self.version + } + + pub fn wrap(version: &Arc, inner: T) -> Self { + Self { + version: Arc::clone(version), + inner, + } + } + + pub fn inner(&self) -> &T { + &self.inner + } + + pub fn inc_sn(&self) { + self.version.fetch_add(1, Ordering::Relaxed); + } + + pub fn get_sn(&self) -> usize { + self.version.load(Ordering::Relaxed) + } + + pub fn map R, R>(&self, f: F) -> Observed { + Observed { + version: Arc::clone(&self.version), + inner: f(self), + } + } +} + +impl Observed { + pub fn get_state(&self) -> Option { + self.inner().get_state() + } +} + +pub trait ObserveState { + type State: Debug; + fn get_state(&self) -> Option; +} + +impl ObserveState for Vec> +where + U: Copy + Display + Eq + PartialOrd + Ord + Sub + From, +{ + type State = Vec; + fn get_state(&self) -> Option { + Some( + self.iter() + .map( + |range| match (*range.end() - *range.start()).cmp(&U::from(1)) { + Less => format!("{}", range.start()), + Equal => format!("[{}, {}] ", range.start(), range.end()), + Greater => format!("[{},...,{}] ", range.start(), range.end()), + }, + ) + .collect(), + ) + } +} diff --git a/src/helpers/gateway/observer.rs b/src/helpers/gateway/observer.rs new file mode 100644 index 000000000..2041477a3 --- /dev/null +++ b/src/helpers/gateway/observer.rs @@ -0,0 +1,32 @@ +use std::time::Duration; + +use tracing::{Instrument, Span}; + +use crate::{ + helpers::gateway::observable::{ObserveState, Observed}, + task::JoinHandle, +}; + +#[cfg(not(feature = "shuttle"))] +pub fn spawn( + within: Span, + check_interval: Duration, + observed: Observed, +) -> JoinHandle<()> { + tokio::spawn(async move { + let mut last_observed = 0; + loop { + ::tokio::time::sleep(check_interval).await; + let now = observed.get_sn(); + if now == last_observed { + if let Some(state) = observed.get_state() { + tracing::warn!(sn = now, state = ?state, "Helper is stalled after {check_interval:?}"); + } else { + break + } + } else { + last_observed = now; + } + } + }.instrument(within)) +} diff --git a/src/helpers/gateway/receive.rs b/src/helpers/gateway/receive.rs index 282ff68e2..042559603 100644 --- a/src/helpers/gateway/receive.rs +++ b/src/helpers/gateway/receive.rs @@ -1,32 +1,70 @@ -use std::marker::PhantomData; +use std::{ + collections::HashMap, + fmt::{Debug, Formatter}, + marker::PhantomData, +}; use dashmap::{mapref::entry::Entry, DashMap}; +use delegate::delegate; use futures::Stream; use crate::{ - helpers::{buffers::UnorderedReceiver, ChannelId, Error, Message, Transport}, + helpers::{ + buffers::UnorderedReceiver, + gateway::{ + observable::{ObserveState, Observed}, + to_ranges, + }, + ChannelId, Error, Message, Transport, TransportImpl, + }, protocol::RecordId, }; /// Receiving end end of the gateway channel. -pub struct ReceivingEnd { +pub struct ReceivingEnd { channel_id: ChannelId, - unordered_rx: UR, + unordered_rx: UR, _phantom: PhantomData, } +impl Observed> { + delegate! { + to { self.inc_sn(); self.inner() } { + #[inline] + pub async fn receive(&self, record_id: RecordId) -> Result; + } + } +} + +pub struct WaitingTasks(HashMap>); + +impl Debug for WaitingTasks { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + for (channel, records) in &self.0 { + write!( + f, + "\n\"{:?}\", from={:?}. Waiting to receive records {:?}.", + channel.gate, channel.role, records + )?; + } + + Ok(()) + } +} + /// Receiving channels, indexed by (role, step). -pub(super) struct GatewayReceivers { - inner: DashMap>, +#[derive(Default, Clone)] +pub(super) struct GatewayReceivers { + inner: DashMap, } -pub(super) type UR = UnorderedReceiver< - ::RecordsStream, - <::RecordsStream as Stream>::Item, +pub(super) type UR = UnorderedReceiver< + ::RecordsStream, + <::RecordsStream as Stream>::Item, >; -impl ReceivingEnd { - pub(super) fn new(channel_id: ChannelId, rx: UR) -> Self { +impl ReceivingEnd { + pub(super) fn new(channel_id: ChannelId, rx: UR) -> Self { Self { channel_id, unordered_rx: rx, @@ -55,16 +93,8 @@ impl ReceivingEnd { } } -impl Default for GatewayReceivers { - fn default() -> Self { - Self { - inner: DashMap::default(), - } - } -} - -impl GatewayReceivers { - pub fn get_or_create UR>(&self, channel_id: &ChannelId, ctr: F) -> UR { +impl GatewayReceivers { + pub fn get_or_create UR>(&self, channel_id: &ChannelId, ctr: F) -> UR { // TODO: raw entry API if it becomes available to avoid cloning the key match self.inner.entry(channel_id.clone()) { Entry::Occupied(entry) => entry.get().clone(), @@ -77,3 +107,19 @@ impl GatewayReceivers { } } } + +impl ObserveState for GatewayReceivers { + type State = WaitingTasks; + + fn get_state(&self) -> Option { + let mut map = HashMap::default(); + for entry in &self.inner { + let channel = entry.key(); + if let Some(waiting) = to_ranges(entry.value().waiting()).get_state() { + map.insert(channel.clone(), waiting); + } + } + + Some(WaitingTasks(map)) + } +} diff --git a/src/helpers/gateway/send.rs b/src/helpers/gateway/send.rs index 4eb876af0..6d4356986 100644 --- a/src/helpers/gateway/send.rs +++ b/src/helpers/gateway/send.rs @@ -1,4 +1,6 @@ use std::{ + collections::HashMap, + fmt::{Debug, Formatter}, marker::PhantomData, num::NonZeroUsize, pin::Pin, @@ -6,11 +8,19 @@ use std::{ }; use dashmap::{mapref::entry::Entry, DashMap}; +use delegate::delegate; use futures::Stream; use typenum::Unsigned; use crate::{ - helpers::{buffers::OrderingSender, ChannelId, Error, Message, Role, TotalRecords}, + helpers::{ + buffers::OrderingSender, + gateway::{ + observable::{ObserveState, Observed}, + to_ranges, + }, + ChannelId, Error, Message, Role, TotalRecords, + }, protocol::RecordId, sync::Arc, telemetry::{ @@ -27,12 +37,54 @@ pub struct SendingEnd { _phantom: PhantomData, } +impl Observed> { + delegate! { + to { self.inc_sn(); self.inner() } { + #[inline] + pub async fn send(&self, record_id: RecordId, msg: M) -> Result<(), Error>; + } + } +} + /// Sending channels, indexed by (role, step). -#[derive(Default)] +#[derive(Default, Clone)] pub(super) struct GatewaySenders { inner: DashMap>, } +pub struct WaitingTasks(HashMap>); + +impl Debug for WaitingTasks { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + for (channel, records) in &self.0 { + write!( + f, + "\n\"{:?}\", to={:?}. Waiting to send records {:?}.", + channel.gate, channel.role, records + )?; + } + + Ok(()) + } +} + +impl ObserveState for GatewaySenders { + type State = WaitingTasks; + + fn get_state(&self) -> Option { + let mut state = HashMap::new(); + for entry in &self.inner { + let channel = entry.key(); + let sender = entry.value(); + if let Some(sender_state) = sender.get_state() { + state.insert(channel.clone(), sender_state); + } + } + + Some(WaitingTasks(state)) + } +} + pub(super) struct GatewaySender { channel_id: ChannelId, ordering_tx: OrderingSender, @@ -43,6 +95,15 @@ pub(super) struct GatewaySendStream { inner: Arc, } +impl ObserveState for GatewaySender { + type State = Vec; + + fn get_state(&self) -> Option { + let waiting_indices = self.ordering_tx.waiting(); + to_ranges(waiting_indices).get_state() + } +} + impl GatewaySender { fn new(channel_id: ChannelId, tx: OrderingSender, total_records: TotalRecords) -> Self { Self { diff --git a/src/helpers/gateway/transport.rs b/src/helpers/gateway/transport.rs index 94563b3c0..8c90a29ee 100644 --- a/src/helpers/gateway/transport.rs +++ b/src/helpers/gateway/transport.rs @@ -2,7 +2,7 @@ use crate::{ helpers::{ buffers::UnorderedReceiver, gateway::{receive::UR, send::GatewaySendStream}, - ChannelId, GatewayConfig, Role, RoleAssignment, RouteId, Transport, + ChannelId, GatewayConfig, Role, RoleAssignment, RouteId, Transport, TransportImpl, }, protocol::QueryId, }; @@ -12,19 +12,19 @@ use crate::{ /// /// [`HelperIdentity`]: crate::helpers::HelperIdentity #[derive(Clone)] -pub(super) struct RoleResolvingTransport { +pub(super) struct RoleResolvingTransport { pub query_id: QueryId, pub roles: RoleAssignment, pub config: GatewayConfig, - pub inner: T, + pub inner: TransportImpl, } -impl RoleResolvingTransport { +impl RoleResolvingTransport { pub(crate) async fn send( &self, channel_id: &ChannelId, data: GatewaySendStream, - ) -> Result<(), T::Error> { + ) -> Result<(), ::Error> { let dest_identity = self.roles.identity(channel_id.role); assert_ne!( dest_identity, @@ -41,7 +41,7 @@ impl RoleResolvingTransport { .await } - pub(crate) fn receive(&self, channel_id: &ChannelId) -> UR { + pub(crate) fn receive(&self, channel_id: &ChannelId) -> UR { let peer = self.roles.identity(channel_id.role); assert_ne!( peer, diff --git a/src/helpers/mod.rs b/src/helpers/mod.rs index 373070736..bd3aae00a 100644 --- a/src/helpers/mod.rs +++ b/src/helpers/mod.rs @@ -3,6 +3,8 @@ use std::{ num::NonZeroUsize, }; +use generic_array::GenericArray; + mod buffers; mod error; mod gateway; @@ -15,11 +17,12 @@ use std::ops::{Index, IndexMut}; #[cfg(test)] pub use buffers::OrderingSender; pub use error::{Error, Result}; +#[cfg(not(feature = "shuttle"))] +pub use gateway::observer; // TODO: this type should only be available within infra. Right now several infra modules // are exposed at the root level. That makes it impossible to have a proper hierarchy here. pub use gateway::{Gateway, TransportError, TransportImpl}; pub use gateway::{GatewayConfig, ReceivingEnd, SendingEnd}; -use generic_array::GenericArray; pub use prss_protocol::negotiate as negotiate_prss; #[cfg(feature = "web-app")] pub use transport::WrappedAxumBodyStream; diff --git a/src/helpers/prss_protocol.rs b/src/helpers/prss_protocol.rs index e44fec8ed..4dddd21fb 100644 --- a/src/helpers/prss_protocol.rs +++ b/src/helpers/prss_protocol.rs @@ -3,7 +3,7 @@ use rand_core::{CryptoRng, RngCore}; use x25519_dalek::PublicKey; use crate::{ - helpers::{ChannelId, Direction, Error, Gateway, TotalRecords, Transport}, + helpers::{ChannelId, Direction, Error, Gateway, TotalRecords}, protocol::{ prss, step::{Gate, Step, StepNarrow}, @@ -24,8 +24,8 @@ impl Step for PrssExchangeStep {} /// establish the prss endpoint by exchanging public keys with the other helpers /// # Errors /// if communication with other helpers fails -pub async fn negotiate( - gateway: &Gateway, +pub async fn negotiate( + gateway: &Gateway, gate: &Gate, rng: &mut R, ) -> Result { diff --git a/src/test_fixture/world.rs b/src/test_fixture/world.rs index e3502056d..520f39893 100644 --- a/src/test_fixture/world.rs +++ b/src/test_fixture/world.rs @@ -120,6 +120,18 @@ impl TestWorld { } let gateways = gateways.map(Option::unwrap); + #[cfg(not(feature = "shuttle"))] + gateways + .iter() + .map(|g| { + crate::helpers::observer::spawn( + tracing::info_span!("Observer", role=?g.role()), + config.gateway_config.progress_check_interval, + g.to_observed(), + ) + }) + .for_each(drop); + TestWorld { gateways, participants,