diff --git a/metrics/src/recorder/mod.rs b/metrics/src/recorder/mod.rs index 91125dcd..2ede585c 100644 --- a/metrics/src/recorder/mod.rs +++ b/metrics/src/recorder/mod.rs @@ -1,4 +1,4 @@ -use std::{cell::Cell, ptr::NonNull}; +use std::{cell::Cell, marker::PhantomData, ptr::NonNull}; mod cell; use self::cell::RecorderOnceCell; @@ -130,30 +130,39 @@ impl_recorder!(T, std::sync::Arc); /// (thread-local storage) so that it can be accessed by the macros. This guard ensures that the /// pointer we store to the reference is cleared when the guard is dropped, so that it can't be used /// after the closure has finished, even if the closure panics and unwinds the stack. -struct LocalRecorderGuard; +/// +/// ## Note +/// +/// The guard has a lifetime parameter `'a` that is bounded using a +/// `PhantomData` type. This upholds the guard's contravariance, it must live +/// _at most as long_ as the recorder it takes a reference to. The bounded +/// lifetime prevents accidental use-after-free errors when using a guard +/// directly through [`crate::set_default_local_recorder`]. +pub struct LocalRecorderGuard<'a> { + prev_recorder: Option>, + phantom: PhantomData<&'a dyn Recorder>, +} -impl LocalRecorderGuard { +impl<'a> LocalRecorderGuard<'a> { /// Creates a new `LocalRecorderGuard` and sets the thread-local recorder. - fn new(recorder: &dyn Recorder) -> Self { + fn new(recorder: &'a dyn Recorder) -> Self { // SAFETY: While we take a lifetime-less pointer to the given reference, the reference we - // derive _from_ the pointer is never given a lifetime that exceeds the lifetime of the - // input reference. + // derive _from_ the pointer is given the same lifetime of the reference + // used to construct the guard -- captured in the guard type itself -- + // and so derived references never outlive the source reference. let recorder_ptr = unsafe { NonNull::new_unchecked(recorder as *const _ as *mut _) }; - LOCAL_RECORDER.with(|local_recorder| { - local_recorder.set(Some(recorder_ptr)); - }); + let prev_recorder = + LOCAL_RECORDER.with(|local_recorder| local_recorder.replace(Some(recorder_ptr))); - Self + Self { prev_recorder, phantom: PhantomData } } } -impl Drop for LocalRecorderGuard { +impl<'a> Drop for LocalRecorderGuard<'a> { fn drop(&mut self) { // Clear the thread-local recorder. - LOCAL_RECORDER.with(|local_recorder| { - local_recorder.set(None); - }); + LOCAL_RECORDER.with(|local_recorder| local_recorder.replace(self.prev_recorder.take())); } } @@ -175,6 +184,32 @@ where GLOBAL_RECORDER.set(recorder) } +/// Sets the recorder as the default for the current thread for the duration of +/// the lifetime of the returned [`LocalRecorderGuard`]. +/// +/// This function is suitable for capturing metrics in asynchronous code, in particular +/// when using a single-threaded runtime. Any metrics registered prior to the returned +/// guard will remain attached to the recorder that was present at the time of registration, +/// and so this cannot be used to intercept existing metrics. +/// +/// Additionally, local recorders can be used in a nested fashion. When setting a new +/// default local recorder, the previous default local recorder will be captured if one +/// was set, and will be restored when the returned guard drops. +/// the lifetime of the returned [`LocalRecorderGuard`]. +/// +/// Any metrics recorded before a guard is returned will be completely ignored. +/// Metrics implementations should provide an initialization method that +/// installs the recorder internally. +/// +/// The function is suitable for capturing metrics in asynchronous code that +/// uses a single threaded runtime. +/// +/// If a global recorder is set, it will be restored once the guard is dropped. +#[must_use] +pub fn set_default_local_recorder(recorder: &dyn Recorder) -> LocalRecorderGuard { + LocalRecorderGuard::new(recorder) +} + /// Runs the closure with the given recorder set as the global recorder for the duration. /// /// This only applies as long as the closure is running, and on the thread where `with_local_recorder` is called. This @@ -212,22 +247,129 @@ pub fn with_recorder(f: impl FnOnce(&dyn Recorder) -> T) -> T { #[cfg(test)] mod tests { - use std::sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }; + use std::sync::{atomic::Ordering, Arc}; - use crate::NoopRecorder; + use crate::{with_local_recorder, NoopRecorder}; use super::{Recorder, RecorderOnceCell}; #[test] fn boxed_recorder_dropped_on_existing_set() { // This test simply ensures that if a boxed recorder is handed to us to install, and another - // recorder has already been installed, that we drop th new boxed recorder instead of + // recorder has already been installed, that we drop the new boxed recorder instead of // leaking it. + let recorder_cell = RecorderOnceCell::new(); + + // This is the first set of the cell, so it should always succeed. + let (first_recorder, _) = test_recorders::TrackOnDropRecorder::new(); + let first_set_result = recorder_cell.set(first_recorder); + assert!(first_set_result.is_ok()); + + // Since the cell is already set, this second set should fail. We'll also then assert that + // our atomic boolean is set to `true`, indicating the drop logic ran for it. + let (second_recorder, was_dropped) = test_recorders::TrackOnDropRecorder::new(); + assert!(!was_dropped.load(Ordering::SeqCst)); + + let second_set_result = recorder_cell.set(second_recorder); + assert!(second_set_result.is_err()); + assert!(!was_dropped.load(Ordering::SeqCst)); + drop(second_set_result); + assert!(was_dropped.load(Ordering::SeqCst)); + } + + #[test] + fn blanket_implementations() { + fn is_recorder(_recorder: T) {} + + let mut local = NoopRecorder; + + is_recorder(NoopRecorder); + is_recorder(Arc::new(NoopRecorder)); + is_recorder(Box::new(NoopRecorder)); + is_recorder(&local); + is_recorder(&mut local); + } + + #[test] + fn thread_scoped_recorder_guards() { + // This test ensures that when a recorder is installed through + // `crate::set_default_local_recorder` it will only be valid in the scope of the + // thread. + // + // The goal of the test is to give confidence that no invalid memory + // access errors are present when operating with locally scoped + // recorders. + let t1_recorder = test_recorders::SimpleCounterRecorder::new(); + let t2_recorder = test_recorders::SimpleCounterRecorder::new(); + let t3_recorder = test_recorders::SimpleCounterRecorder::new(); + // Start a new thread scope to take references to each recorder in the + // closures passed to the thread. + std::thread::scope(|s| { + s.spawn(|| { + let _guard = crate::set_default_local_recorder(&t1_recorder); + crate::counter!("t1_counter").increment(1); + }); + + s.spawn(|| { + with_local_recorder(&t2_recorder, || { + crate::counter!("t2_counter").increment(2); + }) + }); + + s.spawn(|| { + let _guard = crate::set_default_local_recorder(&t3_recorder); + crate::counter!("t3_counter").increment(3); + }); + }); + + assert!(t1_recorder.get_value() == 1); + assert!(t2_recorder.get_value() == 2); + assert!(t3_recorder.get_value() == 3); + } + + #[test] + fn local_recorder_restored_when_dropped() { + // This test ensures that any previously installed local recorders are + // restored when the subsequently installed recorder's guard is dropped. + let root_recorder = test_recorders::SimpleCounterRecorder::new(); + // Install the root recorder and increment the counter once. + let _guard = crate::set_default_local_recorder(&root_recorder); + crate::counter!("test_counter").increment(1); + + // Install a second recorder and increment its counter once. + let next_recorder = test_recorders::SimpleCounterRecorder::new(); + let next_guard = crate::set_default_local_recorder(&next_recorder); + crate::counter!("test_counter").increment(1); + let final_recorder = test_recorders::SimpleCounterRecorder::new(); + crate::with_local_recorder(&final_recorder, || { + // Final recorder increments the counter by 10. At the end of the + // closure, the guard should be dropped, and `next_recorder` + // restored. + crate::counter!("test_counter").increment(10); + }); + // Since `next_recorder` is restored, we can increment it once and check + // that the value is 2 (+1 before and after the closure). + crate::counter!("test_counter").increment(1); + assert!(next_recorder.get_value() == 2); + drop(next_guard); + + // At the end, increment the counter again by an arbitrary value. Since + // `next_guard` is dropped, the root recorder is restored. + crate::counter!("test_counter").increment(20); + assert!(root_recorder.get_value() == 21); + } + + mod test_recorders { + use std::sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, + }; + + use crate::Recorder; + #[derive(Debug)] - struct TrackOnDropRecorder(Arc); + // Tracks how many times the recorder was dropped + pub struct TrackOnDropRecorder(Arc); impl TrackOnDropRecorder { pub fn new() -> (Self, Arc) { @@ -236,6 +378,8 @@ mod tests { } } + // === impl TrackOnDropRecorder === + impl Recorder for TrackOnDropRecorder { fn describe_counter( &self, @@ -282,35 +426,78 @@ mod tests { } } - let recorder_cell = RecorderOnceCell::new(); + // A simple recorder that only implements `register_counter`. + #[derive(Debug)] + pub struct SimpleCounterRecorder { + state: Arc, + } - // This is the first set of the cell, so it should always succeed. - let (first_recorder, _) = TrackOnDropRecorder::new(); - let first_set_result = recorder_cell.set(first_recorder); - assert!(first_set_result.is_ok()); + impl SimpleCounterRecorder { + pub fn new() -> Self { + Self { state: Arc::new(AtomicU64::default()) } + } - // Since the cell is already set, this second set should fail. We'll also then assert that - // our atomic boolean is set to `true`, indicating the drop logic ran for it. - let (second_recorder, was_dropped) = TrackOnDropRecorder::new(); - assert!(!was_dropped.load(Ordering::SeqCst)); + pub fn get_value(&self) -> u64 { + self.state.load(Ordering::Acquire) + } + } - let second_set_result = recorder_cell.set(second_recorder); - assert!(second_set_result.is_err()); - assert!(!was_dropped.load(Ordering::SeqCst)); - drop(second_set_result); - assert!(was_dropped.load(Ordering::SeqCst)); - } + struct SimpleCounterHandle { + state: Arc, + } - #[test] - fn blanket_implementations() { - fn is_recorder(_recorder: T) {} + impl crate::CounterFn for SimpleCounterHandle { + fn increment(&self, value: u64) { + self.state.fetch_add(value, Ordering::Acquire); + } - let mut local = NoopRecorder; + fn absolute(&self, _value: u64) { + unimplemented!() + } + } - is_recorder(NoopRecorder); - is_recorder(Arc::new(NoopRecorder)); - is_recorder(Box::new(NoopRecorder)); - is_recorder(&local); - is_recorder(&mut local); + // === impl SimpleCounterRecorder === + + impl Recorder for SimpleCounterRecorder { + fn describe_counter( + &self, + _: crate::KeyName, + _: Option, + _: crate::SharedString, + ) { + } + fn describe_gauge( + &self, + _: crate::KeyName, + _: Option, + _: crate::SharedString, + ) { + } + fn describe_histogram( + &self, + _: crate::KeyName, + _: Option, + _: crate::SharedString, + ) { + } + + fn register_counter(&self, _: &crate::Key, _: &crate::Metadata<'_>) -> crate::Counter { + crate::Counter::from_arc(Arc::new(SimpleCounterHandle { + state: self.state.clone(), + })) + } + + fn register_gauge(&self, _: &crate::Key, _: &crate::Metadata<'_>) -> crate::Gauge { + crate::Gauge::noop() + } + + fn register_histogram( + &self, + _: &crate::Key, + _: &crate::Metadata<'_>, + ) -> crate::Histogram { + crate::Histogram::noop() + } + } } }