From e452ecda739966232f0158dc7649d2de4630f1df Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Thu, 2 Nov 2023 16:33:48 -0700 Subject: [PATCH] Final touches --- Cargo.toml | 3 +- src/helpers/gateway/mod.rs | 2 +- src/helpers/gateway/send.rs | 1 - src/helpers/gateway/stall_detection.rs | 345 +++++++++++++------------ src/helpers/mod.rs | 6 +- src/protocol/step/compact.rs | 2 +- src/protocol/step/descriptive.rs | 2 +- 7 files changed, 182 insertions(+), 179 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 15bc169f9..e35bfb09a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,8 @@ disable-metrics = [] web-app = ["axum", "axum-server", "base64", "clap", "comfy-table", "enable-serde", "hyper", "hyper-rustls", "rcgen", "rustls", "rustls-pemfile", "time", "tokio-rustls", "toml", "tower", "tower-http"] test-fixture = ["enable-serde", "weak-field"] # Include observability instruments that detect lack of progress inside MPC. If there is a bug that leads to helper -# miscommunication, this feature helps to detect it. Turning it on hurts performance a bit. +# miscommunication, this feature helps to detect it. Turning it on has some cost. +# If "shuttle" feature is enabled, turning this on has no effect. stall-detection = [] shuttle = ["shuttle-crate", "test-fixture"] debug-trace = ["tracing/max_level_trace", "tracing/release_max_level_debug"] diff --git a/src/helpers/gateway/mod.rs b/src/helpers/gateway/mod.rs index 54ec1e8c1..5365da0f2 100644 --- a/src/helpers/gateway/mod.rs +++ b/src/helpers/gateway/mod.rs @@ -11,7 +11,7 @@ pub(super) use send::SendingEnd; #[cfg(all(test, feature = "shuttle"))] use shuttle::future as tokio; #[cfg(feature = "stall-detection")] -pub(super) use {stall_detection::InstrumentedGateway}; +pub(super) use stall_detection::InstrumentedGateway; use crate::{ helpers::{ diff --git a/src/helpers/gateway/send.rs b/src/helpers/gateway/send.rs index 1deed8349..1bad4b28f 100644 --- a/src/helpers/gateway/send.rs +++ b/src/helpers/gateway/send.rs @@ -27,7 +27,6 @@ pub struct SendingEnd { _phantom: PhantomData, } - /// Sending channels, indexed by (role, step). #[derive(Default, Clone)] pub(super) struct GatewaySenders { diff --git a/src/helpers/gateway/stall_detection.rs b/src/helpers/gateway/stall_detection.rs index dd4af2309..34ef59a91 100644 --- a/src/helpers/gateway/stall_detection.rs +++ b/src/helpers/gateway/stall_detection.rs @@ -1,26 +1,14 @@ use std::{ - fmt::{Debug, Formatter}, - ops::RangeInclusive, + fmt::{Debug, Display, Formatter}, + ops::{RangeInclusive, Sub}, }; -use std::fmt::Display; -use std::ops::Sub; -use std::sync::atomic::Ordering; -use delegate::delegate; +pub use gateway::InstrumentedGateway; -use super::{ - Gateway, GatewayConfig, State, +use crate::sync::{ + atomic::{AtomicUsize, Ordering}, + Weak, }; -use crate::{ - helpers::{ - ChannelId, Message, ReceivingEnd, Role, RoleAssignment, SendingEnd, TotalRecords, - TransportImpl, - }, - protocol::QueryId, - sync::{atomic::AtomicUsize, Arc, Weak}, - task::JoinHandle, -}; - /// Trait for structs that can report their current state. pub trait ObserveState { @@ -41,27 +29,28 @@ pub struct Observed { /// object is dropped. /// /// External observers watching this object will declare it stalled if it's sequence number - /// hasn't been incremented for long enough time. If `T` implements `ObserveState`, then the - /// state of `T` is also reported. + /// hasn't been incremented for long enough time. It can happen for two reasons: either there is + /// no work to do for this object, or its state is not drained/consumed by the clients. In the + /// former case, the bottleneck is somewhere else, otherwise if `T` implements `ObserveState`, + /// the current state of `T` is also reported. sn: Weak, inner: T, } -impl Observed { +impl Observed { fn wrap(sn: Weak, inner: T) -> Self { - Self { - sn, - inner, - } + Self { sn, inner } } - fn current_sn(&self) -> &Weak { + fn get_sn(&self) -> &Weak { &self.sn } + /// Advances the sequence number ahead. + /// /// ## Panics /// This will panic if the sequence number is dropped. - fn inc_sn(&self) { + fn advance(&self) { let sn = self.sn.upgrade().unwrap(); sn.fetch_add(1, Ordering::Relaxed); } @@ -69,7 +58,6 @@ impl Observed { fn inner(&self) -> &T { &self.inner } - } impl Observed { @@ -78,177 +66,174 @@ impl Observed { } } -pub fn spawn( - within: tracing::Span, - check_interval: std::time::Duration, - observed: Observed, -) -> JoinHandle<()> { - use tracing::Instrument; - tokio::spawn(async move { - let mut last_observed = 0; - loop { - ::tokio::time::sleep(check_interval).await; - let now = observed.current_sn().upgrade().map(|v| v.load(Ordering::Relaxed)); - if let Some(now) = now { - if now == last_observed { - if let Some(state) = observed.get_state() { - tracing::warn!(sn = now, state = ?state, "Helper is stalled after {check_interval:?}"); - } - } - last_observed = now; - } else { - break; - } - } - }.instrument(within)) -} +mod gateway { + use delegate::delegate; -pub struct InstrumentedGateway { - gateway: Gateway, - // Gateway owns the sequence number associated with it. When it goes out of scope, sn is destroyed - // and external observers can see that they no longer need to watch it. - _sn: Arc, -} - -impl Observed { - delegate! { - to self.inner().gateway { - #[inline] - pub fn role(&self) -> Role; + use super::*; + use crate::{ + helpers::{ + gateway::{Gateway, State}, + ChannelId, GatewayConfig, Message, ReceivingEnd, Role, RoleAssignment, SendingEnd, + TotalRecords, TransportImpl, + }, + protocol::QueryId, + sync::Arc, + }; - #[inline] - pub fn config(&self) -> &GatewayConfig; - } + pub struct InstrumentedGateway { + gateway: Gateway, + // Gateway owns the sequence number associated with it. When it goes out of scope, sn is destroyed + // and external observers can see that they no longer need to watch it. + _sn: Arc, } - pub fn new( - query_id: QueryId, - config: GatewayConfig, - roles: RoleAssignment, - transport: TransportImpl, - ) -> Self { - let version = Arc::new(AtomicUsize::default()); - let r = Self::wrap( - Arc::downgrade(&version), - InstrumentedGateway { - gateway: Gateway::new(query_id, config, roles, transport), - _sn: version, - }, - ); - - // spawn observer - spawn( - tracing::info_span!("observer", role=?r.role()), - config.progress_check_interval, - r.to_observed(), - ); - - r - } + impl Observed { + delegate! { + to self.inner().gateway { - #[must_use] - pub fn get_sender( - &self, - channel_id: &ChannelId, - total_records: TotalRecords, - ) -> SendingEnd { - Observed::wrap( - Weak::clone(self.current_sn()), - self.inner().gateway.get_sender(channel_id, total_records), - ) - } + #[inline] + pub fn role(&self) -> Role; - #[must_use] - pub fn get_receiver(&self, channel_id: &ChannelId) -> ReceivingEnd { - Observed::wrap( - Weak::clone(self.current_sn()), - self.inner().gateway.get_receiver(channel_id), - ) - } + #[inline] + pub fn config(&self) -> &GatewayConfig; + } + } - pub fn to_observed(&self) -> Observed> { - // todo: inner.inner - Observed::wrap( - Weak::clone(self.current_sn()), - Arc::downgrade(&self.inner().gateway.inner), - ) - } -} + pub fn new( + query_id: QueryId, + config: GatewayConfig, + roles: RoleAssignment, + transport: TransportImpl, + ) -> Self { + let version = Arc::new(AtomicUsize::default()); + let r = Self::wrap( + Arc::downgrade(&version), + InstrumentedGateway { + gateway: Gateway::new(query_id, config, roles, transport), + _sn: version, + }, + ); + + // spawn the watcher + #[cfg(not(feature = "shuttle"))] + { + use tracing::Instrument; + + tokio::spawn({ + let gateway = r.to_observed(); + async move { + let mut last_observed = 0; + loop { + ::tokio::time::sleep(config.progress_check_interval).await; + let now = gateway.get_sn().upgrade().map(|v| v.load(Ordering::Relaxed)); + if let Some(now) = now { + if now == last_observed { + if let Some(state) = gateway.get_state() { + tracing::warn!(sn = now, state = ?state, "Helper is stalled"); + } + } + last_observed = now; + } else { + break; + } + } + }.instrument(tracing::info_span!("stall_detector", role = ?r.role())) + }); + } -pub struct GatewayWaitingTasks { - senders_state: Option, - receivers_state: Option, -} + r + } -impl Debug for GatewayWaitingTasks { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - if let Some(senders_state) = &self.senders_state { - write!(f, "\n{{{senders_state:?}\n}}")?; + #[must_use] + pub fn get_sender( + &self, + channel_id: &ChannelId, + total_records: TotalRecords, + ) -> SendingEnd { + Observed::wrap( + Weak::clone(self.get_sn()), + self.inner().gateway.get_sender(channel_id, total_records), + ) } - if let Some(receivers_state) = &self.receivers_state { - write!(f, "\n{{{receivers_state:?}\n}}")?; + + #[must_use] + pub fn get_receiver(&self, channel_id: &ChannelId) -> ReceivingEnd { + Observed::wrap( + Weak::clone(self.get_sn()), + self.inner().gateway.get_receiver(channel_id), + ) } - Ok(()) + pub fn to_observed(&self) -> Observed> { + // todo: inner.inner + Observed::wrap( + Weak::clone(self.get_sn()), + Arc::downgrade(&self.inner().gateway.inner), + ) + } } -} -impl ObserveState for Weak { - type State = GatewayWaitingTasks; + pub struct GatewayWaitingTasks { + senders_state: Option, + receivers_state: Option, + } - fn get_state(&self) -> Option { - self.upgrade().and_then(|state| { - match (state.senders.get_state(), state.receivers.get_state()) { - (None, None) => None, - (senders_state, receivers_state) => Some(Self::State { - senders_state, - receivers_state, - }), + impl Debug for GatewayWaitingTasks { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if let Some(senders_state) = &self.senders_state { + write!(f, "\n{{{senders_state:?}\n}}")?; } - }) + if let Some(receivers_state) = &self.receivers_state { + write!(f, "\n{{{receivers_state:?}\n}}")?; + } + + Ok(()) + } } -} -fn to_ranges(nums: Vec) -> Vec> { - nums.into_iter() - .fold(Vec::>::new(), |mut ranges, num| { - if let Some(last_range) = ranges.last_mut().filter(|r| *r.end() == num - 1) { - *last_range = *last_range.start()..=num; - } else { - ranges.push(num..=num); - } - ranges - }) + impl ObserveState for Weak { + type State = GatewayWaitingTasks; + + fn get_state(&self) -> Option { + self.upgrade().and_then(|state| { + match (state.senders.get_state(), state.receivers.get_state()) { + (None, None) => None, + (senders_state, receivers_state) => Some(Self::State { + senders_state, + receivers_state, + }), + } + }) + } + } } mod receive { use std::{ - collections::HashMap, fmt::{Debug, Formatter}, }; + use std::collections::BTreeMap; + use super::*; use crate::{ helpers::{ error::Error, - gateway::{ - receive::GatewayReceivers, ReceivingEnd, - }, + gateway::{receive::GatewayReceivers, ReceivingEnd}, ChannelId, Message, }, protocol::RecordId, }; - use super::*; impl Observed> { delegate::delegate! { - to { self.inc_sn(); self.inner() } { + to { self.advance(); self.inner() } { #[inline] pub async fn receive(&self, record_id: RecordId) -> Result; } } } - pub struct WaitingTasks(HashMap>); + pub struct WaitingTasks(BTreeMap>); impl Debug for WaitingTasks { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -268,7 +253,7 @@ mod receive { type State = WaitingTasks; fn get_state(&self) -> Option { - let mut map = HashMap::default(); + let mut map = BTreeMap::default(); for entry in &self.inner { let channel = entry.key(); if let Some(waiting) = super::to_ranges(entry.value().waiting()).get_state() { @@ -283,27 +268,30 @@ mod receive { mod send { use std::{ - collections::HashMap, fmt::{Debug, Formatter}, }; + use std::collections::BTreeMap; - use crate::helpers::{gateway::{ - send::{GatewaySender, GatewaySenders}, - }, ChannelId, Message, TotalRecords}; - use crate::protocol::RecordId; - use crate::helpers::error::Error; use super::*; + use crate::{ + helpers::{ + error::Error, + gateway::send::{GatewaySender, GatewaySenders}, + ChannelId, Message, TotalRecords, + }, + protocol::RecordId, + }; impl Observed> { delegate::delegate! { - to { self.inc_sn(); self.inner() } { + to { self.advance(); self.inner() } { #[inline] pub async fn send(&self, record_id: RecordId, msg: M) -> Result<(), Error>; } } } - pub struct WaitingTasks(HashMap)>); + pub struct WaitingTasks(BTreeMap)>); impl Debug for WaitingTasks { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -323,7 +311,7 @@ mod send { type State = WaitingTasks; fn get_state(&self) -> Option { - let mut state = HashMap::new(); + let mut state = BTreeMap::new(); for entry in &self.inner { let channel = entry.key(); let sender = entry.value(); @@ -346,9 +334,24 @@ mod send { } } +/// Converts a vector of numbers into a vector of ranges. +/// For example, [1, 2, 3, 4, 5, 7, 9, 10, 11] produces [(1..=5), (7..=7), (9..=11)]. +fn to_ranges(nums: Vec) -> Vec> { + nums.into_iter() + .fold(Vec::>::new(), |mut ranges, num| { + if let Some(last_range) = ranges.last_mut().filter(|r| *r.end() == num - 1) { + *last_range = *last_range.start()..=num; + } else { + ranges.push(num..=num); + } + ranges + }) +} + +/// Range formatter that prints one-element wide ranges as single numbers. impl ObserveState for Vec> - where - U: Copy + Display + Eq + PartialOrd + Ord + Sub + From, +where + U: Copy + Display + Eq + PartialOrd + Ord + Sub + From, { type State = Vec; fn get_state(&self) -> Option { @@ -358,7 +361,7 @@ impl ObserveState for Vec> |range| match (*range.end() - *range.start()).cmp(&U::from(1)) { std::cmp::Ordering::Less => format!("{}", range.start()), std::cmp::Ordering::Equal => format!("[{}, {}]", range.start(), range.end()), - std::cmp::Ordering::Greater => format!("[{},...,{}]", range.start(), range.end()), + std::cmp::Ordering::Greater => format!("[{}..{}]", range.start(), range.end()), }, ) .collect::>(); diff --git a/src/helpers/mod.rs b/src/helpers/mod.rs index a976437f5..188ebb154 100644 --- a/src/helpers/mod.rs +++ b/src/helpers/mod.rs @@ -22,7 +22,7 @@ pub use error::{Error, Result}; mod gateway_stuff { use crate::helpers::{ gateway, - gateway::{InstrumentedGateway, stall_detection::Observed}, + gateway::{stall_detection::Observed, InstrumentedGateway}, }; pub type Gateway = Observed; @@ -219,7 +219,7 @@ impl IndexMut for Vec { /// may be `H2` or `H3`. /// Each helper instance must be able to take any role, but once the role is assigned, it cannot /// be changed for the remainder of the query. -#[derive(Copy, Clone, Debug, PartialEq, Hash, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] #[cfg_attr(feature = "cli", derive(clap::ValueEnum))] #[cfg_attr( feature = "enable-serde", @@ -408,7 +408,7 @@ impl TryFrom<[Role; 3]> for RoleAssignment { /// Combination of helper role and step that uniquely identifies a single channel of communication /// between two helpers. -#[derive(Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Eq, PartialEq, Hash, Ord, PartialOrd)] pub struct ChannelId { pub role: Role, // TODO: step could be either reference or owned value. references are convenient to use inside diff --git a/src/protocol/step/compact.rs b/src/protocol/step/compact.rs index 585ce2042..e65abf737 100644 --- a/src/protocol/step/compact.rs +++ b/src/protocol/step/compact.rs @@ -5,7 +5,7 @@ use ipa_macros::Gate; use super::StepNarrow; use crate::helpers::{prss_protocol::PrssExchangeStep, query::QueryType}; -#[derive(Gate, Clone, Hash, PartialEq, Eq, Default)] +#[derive(Gate, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Default)] #[cfg_attr( feature = "enable-serde", derive(serde::Deserialize), diff --git a/src/protocol/step/descriptive.rs b/src/protocol/step/descriptive.rs index 4f41a4881..dc13e40a1 100644 --- a/src/protocol/step/descriptive.rs +++ b/src/protocol/step/descriptive.rs @@ -22,7 +22,7 @@ use crate::telemetry::{labels::STEP, metrics::STEP_NARROWED}; /// Step "a" would be executed with a context identifier of "protocol/a", which it /// would `narrow()` into "protocol/a/x" and "protocol/a/y" to produce a final set /// of identifiers: ".../a/x", ".../a/y", ".../b", and ".../c". -#[derive(Clone, Hash, PartialEq, Eq)] +#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] #[cfg_attr( feature = "enable-serde", derive(serde::Deserialize),