diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index d24abd8..703221c 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -54,3 +54,4 @@ url = "2.5.0" [dev-dependencies] axum-test-helper = "0.3.0" +tokio = { version = "1.36.0", features = ["full", "test-util"] } diff --git a/crates/sdk/src/default_main.rs b/crates/sdk/src/default_main.rs index 9fef731..fcc7232 100644 --- a/crates/sdk/src/default_main.rs +++ b/crates/sdk/src/default_main.rs @@ -1,5 +1,6 @@ use std::net; use std::path::{Path, PathBuf}; +use std::time::Duration; use axum::{ body::Body, @@ -24,6 +25,7 @@ use crate::connector::{Connector, ConnectorSetup, ErrorResponse, Result}; use crate::fetch_metrics::fetch_metrics; use crate::json_rejection::JsonRejection; use crate::json_response::JsonResponse; +use crate::throttle; use crate::tracing::{init_tracing, make_span, on_response}; #[derive(Parser)] @@ -139,6 +141,7 @@ pub struct ServerState { configuration: C::Configuration, state: C::State, metrics: Registry, + health: throttle::Throttle>, } impl Clone for ServerState @@ -151,20 +154,59 @@ where configuration: self.configuration.clone(), state: self.state.clone(), metrics: self.metrics.clone(), + health: self.health.clone(), } } } -impl ServerState { +impl ServerState +where + C::Configuration: Clone, + C::State: Clone, +{ pub fn new(configuration: C::Configuration, state: C::State, metrics: Registry) -> Self { Self { - configuration, - state, + configuration: configuration.clone(), + state: state.clone(), metrics, + health: throttle::Throttle::new( + ConnectorHealth { + configuration, + state, + }, + Duration::from_secs(1), + ), } } } +#[derive(Debug)] +struct ConnectorHealth { + configuration: C::Configuration, + state: C::State, +} + +impl Clone for ConnectorHealth +where + C::Configuration: Clone, + C::State: Clone, +{ + fn clone(&self) -> Self { + Self { + configuration: self.configuration.clone(), + state: self.state.clone(), + } + } +} + +impl throttle::Operation for ConnectorHealth { + type Output = Result<()>; + + async fn run(&self) -> Self::Output { + C::health_check(&self.configuration, &self.state).await + } +} + /// A default main function for a connector. /// /// The intent is that this function can replace your `main` function @@ -286,7 +328,11 @@ where pub async fn init_server_state( setup: Setup, config_directory: impl AsRef + Send, -) -> Result> { +) -> Result> +where + ::Configuration: Clone, + ::State: Clone, +{ let mut metrics = Registry::new(); let configuration = setup.parse_configuration(config_directory).await?; let state = setup.try_init_state(&configuration, &mut metrics).await?; @@ -383,7 +429,7 @@ async fn get_capabilities() -> JsonResponse } async fn get_health(State(state): State>) -> Result<()> { - C::health_check(&state.configuration, &state.state).await + state.health.next().await } async fn get_schema( diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index a4741cd..53d74e4 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -4,6 +4,7 @@ pub mod default_main; pub mod fetch_metrics; pub mod json_rejection; pub mod json_response; +pub mod throttle; pub mod tracing; pub use ndc_models as models; diff --git a/crates/sdk/src/throttle.rs b/crates/sdk/src/throttle.rs new file mode 100644 index 0000000..9e6f876 --- /dev/null +++ b/crates/sdk/src/throttle.rs @@ -0,0 +1,316 @@ +//! Provides a wrapper around behavior that throttles it when it is called concurrently or in quick +//! succession. +//! +//! Once an operation is completed, the next is delayed by the interval provided from the _start_ +//! time of the previous operation. +//! +//! If multiple operations are delayed, they are pooled so that only one is run, and the rest clone +//! the resulting value. + +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::{oneshot, Mutex}; +use tokio::time::{self, Instant}; + +/// A type alias for a boxed future, to simplify places where it's used. +type BoxedFuture<'a, T> = std::pin::Pin + Send + 'a>>; + +/// An operation that can be throttled. +/// +/// We use this instead of `Fn() -> BoxedFuture` because we run into strange and confusing +/// lifetime errors with that pattern. + +pub trait Operation { + type Output: Clone + Send + Sync; + + fn run(&self) -> impl std::future::Future + Send; +} + +/// The state of the throttle at any given moment. +/// +/// This is always wrapped in an `Arc>`. +enum ThrottleState { + /// On first run, the throttle is in this state. We never return to it. + NeverRun, + /// If the throttle was run before, but is not currently running, it is in this state. + Idle { last_run: Instant }, + /// If the throttle is currently running, it is in this state. See [RunningState] for more + /// details. + Running { state: Arc>> }, +} + +/// The state of a running operation. +/// +/// This is independently wrapped in an `Arc>` because we would otherwise need to store +/// the result (in the `Finished` state) in the outer state, which means it would be possible (by +/// construction) for callers to acquire previous results instead of waiting for a new one. By +/// splitting this out, we ensure that once an operation has finished, a new one _must_ be started +/// in order to get a result. +enum RunningState { + Running(Vec>), + Finished(T), +} + +/// The throttle delays operations if they are called too quickly. It is constructed with some +/// behavior (implementing [Operation]) and an interval. +/// +/// See the module-level documentation for details. +/// +/// Cloning this creates a new value which shares its state with the original. +#[derive(Clone)] +pub struct Throttle { + behavior: Arc, + interval: Duration, + state: Arc>>, +} + +impl std::fmt::Debug for Throttle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Throttle({:?})", self.behavior) + } +} + +impl Throttle { + /// Constructs a new throttle with the given behavior and interval. + pub fn new(behavior: Behavior, interval: Duration) -> Self { + Self { + behavior: Arc::new(behavior), + interval, + state: Arc::new(Mutex::new(ThrottleState::NeverRun)), + } + } + + /// Gets the next value, either by running the operation provided or by waiting for an + /// already-running operation to complete. + /// + /// It may be delayed by up to the interval. + pub async fn next(&self) -> Behavior::Output { + self.decide().await.await + } + + /// Decide what to do when asked to perform an operation. + /// + /// This acquires locks, but does not hold them in the returned future. + async fn decide(&self) -> BoxedFuture { + let mut state = self.state.lock().await; + + // First, we check if we need to delay. If so, we store a future which will wait the + // appropriate amount of time. + // + // We do not delay while holding onto the state lock. + let delay: BoxedFuture<()> = match &*state { + ThrottleState::NeverRun | ThrottleState::Running { .. } => Box::pin(async {}), + ThrottleState::Idle { last_run } => { + let delayed_start = *last_run + self.interval; + Box::pin(time::sleep_until(delayed_start)) + } + }; + + match &*state { + // If nothing is running, we start a new operation. + ThrottleState::NeverRun | ThrottleState::Idle { last_run: _ } => { + let start = Instant::now(); + let running_state = Arc::new(Mutex::new(RunningState::Running(Vec::new()))); + let running_state_for_later = running_state.clone(); + *state = ThrottleState::Running { + state: running_state, + }; + drop(state); + + Box::pin(async move { + // First, we wait for the appropriate delay. + delay.await; + + // Next, we run the operation to get the value. + let value = self.behavior.run().await; + + // If any other operations are waiting, we let them know of the result. + let mut running_state = running_state_for_later.lock().await; + match &mut *running_state { + RunningState::Running(waiters) => { + for waiter in waiters.drain(..) { + let _ = waiter.send(value.clone()); + } + *running_state = RunningState::Finished(value.clone()); + } + RunningState::Finished(_) => { + unreachable!("throttle completed twice"); + } + } + + // Finally, we mark the throttle state as idle. + let mut state = self.state.lock().await; + *state = ThrottleState::Idle { last_run: start }; + + value + }) + } + // If something is running, we wait for it by pushing a one-shot channel into a + // queue, and then waiting for the result. + ThrottleState::Running { + state: running_state, + } => { + let mut running_state = running_state.lock().await; + match &mut *running_state { + RunningState::Running(running) => { + let (sender, receiver) = oneshot::channel(); + running.push(sender); + drop(running_state); + drop(state); + + Box::pin(async { receiver.await.unwrap() }) + } + // If it already finished since we acquired the outer lock, simply return + // the result. + RunningState::Finished(value) => { + let value = value.clone(); + drop(running_state); + drop(state); + + Box::pin(async move { value }) + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicI32, Ordering}; + + use tokio::task; + + use super::*; + + #[tokio::test] + async fn makes_requests() { + time::pause(); + + let counter = Counter::new(); + let throttled_counter = ThrottledCounter::new(counter.clone()); + let throttle = Arc::new(Throttle::new(throttled_counter, Duration::from_secs(1))); + + let handle = task::spawn({ + let t = throttle.clone(); + async move { t.next().await } + }); + + assert_eq!(counter.value(), 0); + + task::yield_now().await; + assert_eq!(counter.value(), 1); + + assert_eq!(handle.await.unwrap(), 0); + } + + #[tokio::test] + async fn throttles_requests() { + time::pause(); + + let counter = Counter::new(); + let throttled_counter = ThrottledCounter::new(counter.clone()); + let throttle = Arc::new(Throttle::new(throttled_counter, Duration::from_secs(1))); + + let first = task::spawn({ + let t = throttle.clone(); + async move { t.next().await } + }); + let second = task::spawn({ + let t = throttle.clone(); + async move { t.next().await } + }); + + assert_eq!(counter.value(), 0); + + task::yield_now().await; + assert_eq!(counter.value(), 1); + + time::advance(Duration::from_millis(500)).await; + task::yield_now().await; + assert_eq!(counter.value(), 1); + + time::advance(Duration::from_millis(501)).await; + task::yield_now().await; + assert_eq!(counter.value(), 2); + + assert_eq!(first.await.unwrap(), 0); + assert_eq!(second.await.unwrap(), 1); + } + + #[tokio::test] + async fn requests_are_pooled_during_throttling() { + time::pause(); + + let counter = Counter::new(); + let throttled_counter = ThrottledCounter::new(counter.clone()); + let throttle = Arc::new(Throttle::new(throttled_counter, Duration::from_secs(1))); + + let requests = [0, 1, 1, 1, 1].map(|expected_value| { + ( + expected_value, + task::spawn({ + let t = throttle.clone(); + async move { t.next().await } + }), + ) + }); + + assert_eq!(counter.value(), 0); + + task::yield_now().await; + assert_eq!(counter.value(), 1); + + time::advance(Duration::from_millis(500)).await; + task::yield_now().await; + assert_eq!(counter.value(), 1); + + time::advance(Duration::from_millis(501)).await; + task::yield_now().await; + assert_eq!(counter.value(), 2); + + time::advance(Duration::from_secs(60)).await; + task::yield_now().await; + assert_eq!(counter.value(), 2); + + for (value, request) in requests { + assert_eq!(request.await.unwrap(), value); + } + } + + #[derive(Clone)] + struct Counter(Arc); + + impl Counter { + fn new() -> Self { + Counter(Arc::new(AtomicI32::new(0))) + } + + fn value(&self) -> i32 { + self.0.load(Ordering::Acquire) + } + + fn fetch_and_inc(&self) -> i32 { + self.0.fetch_add(1, Ordering::AcqRel) + } + } + + struct ThrottledCounter { + counter: Counter, + } + + impl ThrottledCounter { + fn new(counter: Counter) -> Self { + Self { counter } + } + } + + impl Operation for ThrottledCounter { + type Output = i32; + + async fn run(&self) -> Self::Output { + self.counter.fetch_and_inc() + } + } +}