diff --git a/lace/benches/oracle_fns.rs b/lace/benches/oracle_fns.rs index 5d9b4b0b..554c23ee 100644 --- a/lace/benches/oracle_fns.rs +++ b/lace/benches/oracle_fns.rs @@ -1,9 +1,7 @@ use criterion::Criterion; use criterion::{black_box, criterion_group, criterion_main}; use lace::examples::Example; -use lace::{ - Given, ImputeUncertaintyType, Oracle, OracleT, PredictUncertaintyType, -}; +use lace::{Given, Oracle, OracleT}; use lace_data::Datum; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; @@ -232,9 +230,8 @@ fn bench_simulate(c: &mut Criterion) { fn bench_impute(c: &mut Criterion) { c.bench_function("oracle impute", |b| { let oracle = get_oracle(); - let unc_type = ImputeUncertaintyType::JsDivergence; b.iter(|| { - let _res = black_box(oracle.impute(13, 12, Some(unc_type))); + let _res = black_box(oracle.impute(13, 12, true)); }) }); } @@ -246,15 +243,13 @@ fn bench_predict(c: &mut Criterion) { (2, Datum::Categorical(lace::Category::U8(0))), ]); let oracle = get_oracle(); - let unc_type = PredictUncertaintyType::JsDivergence; b.iter(|| { - let _res = - black_box(oracle.predict(13, &given, Some(unc_type), None)); + let _res = black_box(oracle.predict(13, &given, true, None)); }) }); } -fn bench_predict_continous(c: &mut Criterion) { +fn bench_predict_continuous(c: &mut Criterion) { c.bench_function("oracle predict continuous", |b| { let given = Given::Conditions(vec![( 4, @@ -262,7 +257,7 @@ fn bench_predict_continous(c: &mut Criterion) { )]); let oracle = get_satellites_oracle(); b.iter(|| { - let _res = black_box(oracle.predict(8, &given, None, None)); + let _res = black_box(oracle.predict(8, &given, false, None)); }) }); } diff --git a/lace/lace_stats/src/lib.rs b/lace/lace_stats/src/lib.rs index 84a8b427..5d101ce9 100644 --- a/lace/lace_stats/src/lib.rs +++ b/lace/lace_stats/src/lib.rs @@ -21,6 +21,7 @@ mod perm; pub mod prior; pub mod seq; mod simplex; +pub mod uncertainty; mod sample_error; diff --git a/lace/lace_stats/src/uncertainty.rs b/lace/lace_stats/src/uncertainty.rs new file mode 100644 index 00000000..74d4dfd4 --- /dev/null +++ b/lace/lace_stats/src/uncertainty.rs @@ -0,0 +1,239 @@ +use crate::rv::dist::{Bernoulli, Categorical, Gaussian, Mixture, Poisson}; +use crate::rv::traits::{Mean, QuadBounds, Rv}; + +/// Compute the normed mean Total Variation Distance of a set of mixture +/// distributions with the mean of distributions. +/// +/// # Notes +/// - The output will be in [0, 1.0). +/// - Normalization is used to account for the fact that the maximum TVD is +/// limited by the number of mixtures. For example, if there are two mixtures +/// in `mixtures` the max TVD in only 1/2; if there are three, the max TVD is +/// 2/3; if there are four the max TVD is 3/4; and so on. We divide the final +/// output by `(n - 1) / n`, where `n` is the number of mixtures, so that the +/// output can be interpreted similarly regardless of the input. +pub fn mixture_normed_tvd(mixtures: &[Mixture]) -> f64 +where + Fx: Clone, + Mixture: TotalVariationDistance, +{ + let n = mixtures.len() as f64; + let norm = (n - 1.0) / n; + + let combined = Mixture::combine(mixtures.to_owned()); + let tvd = mixtures.iter().map(|mm| combined.tvd(mm)).sum::() + / mixtures.len() as f64; + + tvd / norm +} + +pub trait TotalVariationDistance { + fn tvd(&self, other: &Self) -> f64; +} + +fn gaussian_quad_points( + f1: &Mixture, + f2: &Mixture, +) -> Vec { + // Get the lower and upper bound for quadrature + let (a, b) = { + let (a_1, b_1) = f1.quad_bounds(); + let (a_2, b_2) = f2.quad_bounds(); + (a_1.min(a_2), b_1.max(b_2)) + }; + + // Get a list of sorted means and their associated stddevs + let params = { + let mut params = f1 + .components() + .iter() + .chain(f2.components().iter()) + .map(|cpnt| (cpnt.mu(), cpnt.sigma())) + .collect::>(); + params.sort_unstable_by(|(a, _), (b, _)| a.total_cmp(b)); + params + }; + + let mut last_mean = params[0].0; + let mut last_std = params[0].1; + let mut points = vec![a, last_mean]; + + for &(mean, std) in params.iter().skip(1) { + let dist = mean - last_mean; + let z_dist = dist / ((last_std + std) / 2.0); + if z_dist > 1.0 { + points.push(mean); + last_std = std; + last_mean = mean; + } + } + + points.push(b); + points +} + +impl TotalVariationDistance for Mixture { + fn tvd(&self, other: &Self) -> f64 { + use crate::rv::misc::{ + gauss_legendre_quadrature_cached, gauss_legendre_table, + }; + + let func = |x| (self.f(&x) - other.f(&x)).abs(); + + let quad_level = 16; + let quad_points = gaussian_quad_points(self, other); + let (weights, roots) = gauss_legendre_table(quad_level); + + let mut right = quad_points[0]; + quad_points + .iter() + .skip(1) + .map(|&x| { + let q = gauss_legendre_quadrature_cached( + func, + (right, x), + &weights, + &roots, + ); + right = x; + q + }) + .sum::() + / 2.0 + } +} + +impl TotalVariationDistance for Mixture { + fn tvd(&self, other: &Self) -> f64 { + let k = self.components()[0].k(); + assert_eq!(k, other.components()[0].k()); + (0..k) + .map(|x| (self.f(&x) - other.f(&x)).abs()) + .sum::() + / 2.0 + } +} + +impl TotalVariationDistance for Mixture { + fn tvd(&self, other: &Self) -> f64 { + let q = + (self.f(&0) - other.f(&0)).abs() + (self.f(&1) - other.f(&1)).abs(); + q / 2.0 + } +} + +impl TotalVariationDistance for Mixture { + fn tvd(&self, other: &Self) -> f64 { + let threshold = 1e-14; + let m: u32 = self.mean().unwrap().min(other.mean().unwrap()) as u32; + + let mut x: u32 = 0; + let mut q: f64 = 0.0; + loop { + let f1 = self.f(&x); + let f2 = other.f(&x); + + let diff = (f1 - f2).abs(); + + q += diff; + x += 1; + + if x > m && (f1 < threshold && f2 < threshold) { + break; + } + } + q / 2.0 + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn gauss_moving_means_away_increases_tvd() { + let mut last_tvd = 0.0; + (0..10).for_each(|i| { + let dist = 0.5 * (i + 1) as f64; + let g1 = Gaussian::new(-dist / 2.0, 1.0).unwrap(); + let g2 = Gaussian::new(dist / 2.0, 1.0).unwrap(); + + let m1 = Mixture::uniform(vec![g1]).unwrap(); + let m2 = Mixture::uniform(vec![g2]).unwrap(); + + let tvd = mixture_normed_tvd(&[m1, m2]); + + eprintln!("{i} - d: {dist}, tvd: {tvd}"); + + assert!(last_tvd < tvd); + assert!(tvd <= 1.0); + + last_tvd = tvd; + }); + } + + #[test] + fn count_moving_means_away_increases_tvd() { + let mut last_tvd = 0.0; + (0..10).for_each(|i| { + let p1 = Poisson::new(5.0).unwrap(); + let p2 = Poisson::new(5.0 + (i + 1) as f64).unwrap(); + + let m1 = Mixture::uniform(vec![p1]).unwrap(); + let m2 = Mixture::uniform(vec![p2]).unwrap(); + + let tvd = mixture_normed_tvd(&[m1, m2]); + + eprintln!("{i} tvd: {tvd}"); + + assert!(last_tvd < tvd); + assert!(tvd <= 1.0); + + last_tvd = tvd; + }); + } + + #[test] + fn bernoulli_moving_means_away_increases_tvd() { + let mut last_tvd = std::f64::NEG_INFINITY; + (0..10).for_each(|i| { + let p = 0.5 / (i + 1) as f64; + let b1 = Bernoulli::new(p).unwrap(); + let b2 = Bernoulli::new(1.0 - p).unwrap(); + + let m1 = Mixture::uniform(vec![b1]).unwrap(); + let m2 = Mixture::uniform(vec![b2]).unwrap(); + + let tvd = mixture_normed_tvd(&[m1, m2]); + + eprintln!("{i} p: {p}, tvd: {tvd}"); + + assert!(last_tvd < tvd); + assert!(tvd <= 1.0); + + last_tvd = tvd; + }); + } + + #[test] + fn categorical_moving_means_away_increases_tvd() { + let mut last_tvd = std::f64::NEG_INFINITY; + (0..10).for_each(|i| { + let p = 0.5 / (i + 1) as f64; + let c1 = Categorical::new(&[p, 1.0 - p]).unwrap(); + let c2 = Categorical::new(&[1.0 - p, p]).unwrap(); + + let m1 = Mixture::uniform(vec![c1]).unwrap(); + let m2 = Mixture::uniform(vec![c2]).unwrap(); + + let tvd = mixture_normed_tvd(&[m1, m2]); + + eprintln!("{i} p: {p}, tvd: {tvd}"); + + assert!(last_tvd < tvd); + assert!(tvd <= 1.0); + + last_tvd = tvd; + }); + } +} diff --git a/lace/src/interface/mod.rs b/lace/src/interface/mod.rs index 992c7f3f..b2303bd6 100644 --- a/lace/src/interface/mod.rs +++ b/lace/src/interface/mod.rs @@ -14,9 +14,8 @@ pub use lace_metadata::latest::Metadata; pub use oracle::utils; pub use oracle::{ - ConditionalEntropyType, DatalessOracle, ImputeUncertaintyType, - MiComponents, MiType, Oracle, OracleT, PredictUncertaintyType, - RowSimilarityVariant, + ConditionalEntropyType, DatalessOracle, MiComponents, MiType, Oracle, + OracleT, RowSimilarityVariant, }; pub use given::Given; diff --git a/lace/src/interface/oracle/mod.rs b/lace/src/interface/oracle/mod.rs index b5308b90..473bf335 100644 --- a/lace/src/interface/oracle/mod.rs +++ b/lace/src/interface/oracle/mod.rs @@ -83,38 +83,6 @@ impl MiComponents { } } -/// The type of uncertainty to use for `Oracle.impute` -#[derive( - Serialize, - Deserialize, - Debug, - Clone, - Copy, - Eq, - PartialEq, - Ord, - PartialOrd, - Hash, -)] -#[serde(rename_all = "snake_case")] -pub enum ImputeUncertaintyType { - /// Given a set of distributions Θ = {Θ1, ..., Θn}, - /// return the mean of KL(Θi || Θi) - PairwiseKl, - /// The Jensen-Shannon divergence in nats divided by ln(n), where n is the - /// number of distributions - JsDivergence, -} - -/// The type of uncertainty to use for `Oracle.predict` -#[derive(Serialize, Deserialize, Debug, Clone, Copy)] -#[serde(rename_all = "snake_case")] -pub enum PredictUncertaintyType { - /// The Jensen-Shannon divergence in nats divided by ln(n), where n is the - /// number of distributions - JsDivergence, -} - /// The variant on conditional entropy to compute #[derive( Serialize, @@ -407,22 +375,11 @@ mod tests { } #[test] - fn kl_impute_uncertainty_smoke() { - let oracle = get_oracle_from_yaml(); - let u = - oracle._impute_uncertainty(0, 1, ImputeUncertaintyType::PairwiseKl); - assert!(u > 0.0); - } - - #[test] - fn js_impute_uncertainty_smoke() { + fn impute_uncertainty_smoke() { let oracle = get_oracle_from_yaml(); - let u = oracle._impute_uncertainty( - 0, - 1, - ImputeUncertaintyType::JsDivergence, - ); + let u = oracle._impute_uncertainty(0, 1); assert!(u > 0.0); + assert!(u < 1.0); } #[test] @@ -430,6 +387,7 @@ mod tests { let oracle = get_oracle_from_yaml(); let u = oracle._predict_uncertainty(0, &Given::Nothing, None); assert!(u > 0.0); + assert!(u < 1.0); } #[test] @@ -438,6 +396,7 @@ mod tests { let given = Given::Conditions(vec![(1, Datum::Continuous(2.5))]); let u = oracle._predict_uncertainty(0, &given, None); assert!(u > 0.0); + assert!(u < 1.0); } #[test] diff --git a/lace/src/interface/oracle/traits.rs b/lace/src/interface/oracle/traits.rs index 8a992cdb..4ddab9fc 100644 --- a/lace/src/interface/oracle/traits.rs +++ b/lace/src/interface/oracle/traits.rs @@ -5,10 +5,7 @@ use crate::index::{ extract_col_pair, extract_colixs, extract_row_pair, ColumnIndex, RowIndex, }; use crate::interface::oracle::error::SurprisalError; -use crate::interface::oracle::{ - ConditionalEntropyType, ImputeUncertaintyType, MiComponents, MiType, - PredictUncertaintyType, -}; +use crate::interface::oracle::{ConditionalEntropyType, MiComponents, MiType}; use crate::interface::{CanOracle, Given}; use lace_cc::feature::{FType, Feature}; use lace_cc::state::{State, StateDiagnostics}; @@ -1730,17 +1727,19 @@ pub trait OracleT: CanOracle { /// value of the cell rather than returning the existing value. To get /// the current value of a cell, use `Oracle::data`. /// + /// Impute uncertainty is the mean [total variation distance](https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures) + /// between each state's impute distribution and the average impute + /// distribution. + /// /// # Arguments /// /// - row_ix: the row index of the cell to impute /// - col_ix: the column index of the cell to impute - /// - with_unc: if `true` compute the uncertainty, otherwise a value of -1 - /// is returned in the uncertainty spot + /// - with_uncertainty: if `true` compute and return the uncertainty /// /// # Returns /// - /// A `(value, uncertainty_option)` tuple. If `with_unc` is `false`, - /// `uncertainty` is -1. + /// A `(value, uncertainty_option)` tuple. /// /// # Example /// @@ -1749,7 +1748,6 @@ pub trait OracleT: CanOracle { /// ``` /// # use lace::examples::Example; /// use lace::OracleT; - /// use lace::ImputeUncertaintyType; /// use lace_data::Datum; /// /// let oracle = Example::Animals.oracle().unwrap(); @@ -1757,13 +1755,13 @@ pub trait OracleT: CanOracle { /// let dolphin_swims = oracle.impute( /// "dolphin", /// "swims", - /// Some(ImputeUncertaintyType::JsDivergence) + /// true, /// ).unwrap(); /// /// let bear_swims = oracle.impute( /// "polar+bear", /// "swims", - /// Some(ImputeUncertaintyType::JsDivergence) + /// true, /// ).unwrap(); /// /// assert_eq!(dolphin_swims.0, Datum::Categorical(1_u8.into())); @@ -1782,14 +1780,13 @@ pub trait OracleT: CanOracle { /// ``` /// # use lace::examples::Example; /// # use lace::OracleT; - /// # use lace::ImputeUncertaintyType; /// # use lace_data::{Datum, Category}; /// let oracle = Example::Satellites.oracle().unwrap(); /// /// let (imp, _) = oracle.impute( /// "X-Sat", /// "Type_of_Orbit", - /// Some(ImputeUncertaintyType::JsDivergence), + /// true, /// ).unwrap(); /// /// assert_eq!(imp, Datum::Categorical("Sun-Synchronous".into())); @@ -1798,13 +1795,12 @@ pub trait OracleT: CanOracle { /// ``` /// # use lace::examples::Example; /// # use lace::OracleT; - /// # use lace::ImputeUncertaintyType; /// # use lace_data::{Datum, Category}; /// # let oracle = Example::Satellites.oracle().unwrap(); /// let (imp, _) = oracle.impute( /// "X-Sat", /// "longitude_radians_of_geo", - /// Some(ImputeUncertaintyType::JsDivergence), + /// true, /// ).unwrap(); /// /// assert!((imp.to_f64_opt().unwrap() - 0.18514237733859296).abs() < 1e-10); @@ -1813,7 +1809,7 @@ pub trait OracleT: CanOracle { &self, row_ix: RIx, col_ix: CIx, - unc_type_opt: Option, + with_uncertainty: bool, ) -> Result<(Datum, Option), IndexError> { let row_ix = row_ix.row_ix(self.codebook())?; let col_ix = col_ix.col_ix(self.codebook())?; @@ -1839,8 +1835,11 @@ pub trait OracleT: CanOracle { let val = utils::post_process_datum(val, col_ix, self.codebook()); - let unc_opt = unc_type_opt - .map(|unc_type| self._impute_uncertainty(row_ix, col_ix, unc_type)); + let unc_opt = if with_uncertainty { + Some(self._impute_uncertainty(row_ix, col_ix)) + } else { + None + }; Ok((val, unc_opt)) } @@ -1851,6 +1850,9 @@ pub trait OracleT: CanOracle { /// # Arguments /// - col_ix: the index of the column to predict /// - given: optional observations by which to constrain the prediction + /// - with_uncertainty: if true, copmute and return uncertainty + /// - state_ixs_opt: Optional vector of state indices from which to predict, + /// if None, use all states. /// /// # Returns /// A `(value, uncertainty_option)` Tuple @@ -1859,7 +1861,7 @@ pub trait OracleT: CanOracle { /// /// Predict the most likely class of orbit for given longitude of /// Geosynchronous orbit. - /// + /// /// ``` /// use lace::examples::Example; /// use lace::prelude::*; @@ -1871,7 +1873,7 @@ pub trait OracleT: CanOracle { /// &Given::Conditions(vec![ /// ("longitude_radians_of_geo", Datum::Continuous(1.0)) /// ]), - /// None, + /// false, /// None, /// ).unwrap(); /// @@ -1891,7 +1893,7 @@ pub trait OracleT: CanOracle { /// &Given::Conditions(vec![ /// ("longitude_radians_of_geo", Datum::Missing) /// ]), - /// None, + /// false, /// None, /// ).unwrap(); /// @@ -1909,7 +1911,7 @@ pub trait OracleT: CanOracle { /// &Given::Conditions(vec![( /// "Class_of_Orbit", Datum::Categorical("MEO".into())) /// ]), - /// None, + /// false, /// None, /// ).unwrap(); /// @@ -1926,7 +1928,7 @@ pub trait OracleT: CanOracle { /// let (pred_type, _) = oracle.predict( /// "longitude_radians_of_geo", /// &Given::::Nothing, - /// None, + /// false, /// None, /// ).unwrap(); /// @@ -1949,7 +1951,7 @@ pub trait OracleT: CanOracle { /// &Given::Conditions(vec![ /// ("Period_minutes", Datum::Continuous(1200.0)) /// ]), - /// Some(PredictUncertaintyType::JsDivergence), + /// true, /// None, /// ).unwrap(); /// @@ -1960,7 +1962,7 @@ pub trait OracleT: CanOracle { /// &Given::Conditions(vec![ /// ("Period_minutes", Datum::Continuous(1000.0)) /// ]), - /// Some(PredictUncertaintyType::JsDivergence), + /// true, /// None, /// ).unwrap(); /// @@ -1972,7 +1974,7 @@ pub trait OracleT: CanOracle { &self, col_ix: Ix, given: &Given, - unc_type_opt: Option, + with_uncertainty: bool, state_ixs_opt: Option<&[usize]>, ) -> Result<(Datum, Option), error::PredictError> { use super::validation::Mnar; @@ -2006,9 +2008,11 @@ pub trait OracleT: CanOracle { false }; if is_missing { - let unc_opt = unc_type_opt.map(|_| { - utils::mnar_uncertainty_jsd(states.as_slice(), col_ix, &given) - }); + let unc_opt = if with_uncertainty { + Some(utils::mnar_uncertainty(states.as_slice(), col_ix, &given)) + } else { + None + }; Ok((Datum::Missing, unc_opt)) } else { let value = match self.ftype(col_ix).unwrap() { @@ -2032,9 +2036,11 @@ pub trait OracleT: CanOracle { let value = utils::post_process_datum(value, col_ix, self.codebook()); - let unc_opt = unc_type_opt.map(|_| { - self._predict_uncertainty(col_ix, &given, state_ixs_opt) - }); + let unc_opt = if with_uncertainty { + Some(self._predict_uncertainty(col_ix, &given, state_ixs_opt)) + } else { + None + }; Ok((value, unc_opt)) } @@ -2319,21 +2325,8 @@ pub trait OracleT: CanOracle { /// # Arguments /// - row_ix: the row index /// - col_ix: the column index - /// - unc_type: The type of uncertainty to compute - fn _impute_uncertainty( - &self, - row_ix: usize, - col_ix: usize, - unc_type: ImputeUncertaintyType, - ) -> f64 { - match unc_type { - ImputeUncertaintyType::JsDivergence => { - utils::js_impute_uncertainty(self.states(), row_ix, col_ix) - } - ImputeUncertaintyType::PairwiseKl => { - utils::kl_impute_uncertainty(self.states(), row_ix, col_ix) - } - } + fn _impute_uncertainty(&self, row_ix: usize, col_ix: usize) -> f64 { + utils::impute_uncertainty(self.states(), row_ix, col_ix) } /// Computes the uncertainty associated with predicting the value of a @@ -2347,7 +2340,7 @@ pub trait OracleT: CanOracle { /// # Arguments /// - col_ix: the column index /// - given_opt: an optional list of (column index, value) tuples - /// designating other observations on which to condition the prediciton + /// designating other observations on which to condition the prediction fn _predict_uncertainty( &self, col_ix: usize, diff --git a/lace/src/interface/oracle/utils.rs b/lace/src/interface/oracle/utils.rs index 8b799256..d1ef1609 100644 --- a/lace/src/interface/oracle/utils.rs +++ b/lace/src/interface/oracle/utils.rs @@ -13,9 +13,7 @@ use crate::codebook::Codebook; use crate::stats::rv::dist::{ Bernoulli, Categorical, Gaussian, Mixture, Poisson, }; -use crate::stats::rv::traits::{ - Entropy, KlDivergence, Mode, QuadBounds, Rv, Variance, -}; +use crate::stats::rv::traits::{Entropy, Mode, QuadBounds, Rv, Variance}; use crate::stats::MixtureType; use lace_consts::rv::misc::logsumexp; use lace_data::{Category, Datum}; @@ -1503,49 +1501,6 @@ pub fn count_predict( // Predictive uncertainty helpers // ------------------------------ -// Jensen-shannon-divergence for a mixture -fn jsd(mm: Mixture) -> f64 -where - MixtureType: From>, - Fx: Entropy + Clone + std::fmt::Debug, -{ - let h_cpnts = mm - .weights() - .iter() - .zip(mm.components().iter()) - .fold(0.0, |acc, (&w, cpnt)| w.mul_add(cpnt.entropy(), acc)); - - let mt: MixtureType = mm.into(); - let h_mixture = mt.entropy(); - - h_mixture - h_cpnts -} - -fn jsd_mixture(mut components: Vec>) -> f64 -where - MixtureType: From>, -{ - // FIXME: we could do all this with the usual Rv Mixture functions if it - // wasn't for that damned Labeler type - let n_states = components.len() as f64; - let mut h_cpnts = 0_f64; - let mts: Vec = components - .drain(..) - .map(|mm| { - let mt = MixtureType::from(mm); - // h_cpnts += mt.entropy(); - h_cpnts += mt.entropy(); - mt - }) - .collect(); - - // let mt: MixtureType = mm.into(); - let mm = MixtureType::combine(mts); - let h_mixture = mm.entropy(); - - h_mixture - h_cpnts / n_states -} - macro_rules! predunc_arm { ($states: expr, $col_ix: expr, $given_opt: expr, $cpnt_type: ty) => {{ let mix_models: Vec> = $states @@ -1567,7 +1522,8 @@ macro_rules! predunc_arm { }) .collect(); - jsd_mixture(mix_models) + $crate::stats::uncertainty::mixture_normed_tvd(&mix_models) + // jsd_mixture(mix_models) }}; } @@ -1590,7 +1546,7 @@ pub fn predict_uncertainty( } } -pub(crate) fn mnar_uncertainty_jsd( +pub(crate) fn mnar_uncertainty( states: &[&State], col_ix: usize, given: &Given, @@ -1634,7 +1590,7 @@ pub(crate) fn mnar_uncertainty_jsd( // compute time. Bernoulli::new(p).unwrap() } - _ => panic!("Expected MNAR ColModel in MNAR uncertianty fn"), + _ => panic!("Expected MNAR ColModel in MNAR uncertainty fn"), }) .collect::>(); @@ -1662,174 +1618,59 @@ pub(crate) fn mnar_uncertainty_jsd( h_mix - h_cpnt / kf } -macro_rules! js_impunc_arm { - ($k: ident, $row_ix: ident, $states: ident, $ftr: ident, $variant: ident) => {{ +macro_rules! impunc_arm { + ($row_ix: ident, $col_ix: ident, $states: ident, $variant: ident) => {{ let n_states = $states.len(); - let col_ix = $ftr.id; - let mut cpnts = Vec::with_capacity(n_states); - cpnts.push($ftr.components[$k].fx.clone()); - for i in 1..n_states { - let view_ix_s = $states[i].asgn.asgn[col_ix]; - let view_s = &$states[i].views[view_ix_s]; - let k_s = view_s.asgn.asgn[$row_ix]; - match &view_s.ftrs[&col_ix] { - ColModel::$variant(ref ftr) => { - cpnts.push(ftr.components[k_s].fx.clone()); - } - ColModel::MissingNotAtRandom( - $crate::cc::feature::MissingNotAtRandom { fx, .. }, - ) => match &**fx { - ColModel::$variant(ref ftr) => { - cpnts.push(ftr.components[k_s].fx.clone()); - } - cm => { - panic!("Mismatched MNAR feature type: {}", cm.ftype()) - } - }, - cm => panic!("Mismatched feature type: {}", cm.ftype()), - } - } - jsd(Mixture::uniform(cpnts).unwrap()) - }}; -} - -pub fn js_impute_uncertainty( - states: &[State], - row_ix: usize, - col_ix: usize, -) -> f64 { - fn inner( - col_model: &ColModel, - states: &[State], - row_ix: usize, - k: usize, - ) -> f64 { - match col_model { - ColModel::Continuous(ref ftr) => { - js_impunc_arm!(k, row_ix, states, ftr, Continuous) - } - ColModel::Categorical(ref ftr) => { - js_impunc_arm!(k, row_ix, states, ftr, Categorical) - } - ColModel::Count(ref ftr) => { - js_impunc_arm!(k, row_ix, states, ftr, Count) - } - ColModel::MissingNotAtRandom(_) => { - panic!("Inner should not reach MissingNotAtRandom") - } - } - } - - let view_ix = states[0].asgn.asgn[col_ix]; - let view = &states[0].views[view_ix]; - let k = view.asgn.asgn[row_ix]; - let col_model = &view.ftrs[&col_ix]; - match col_model { - ColModel::MissingNotAtRandom(ref mnar) => { - inner(mnar.fx.as_ref(), states, row_ix, k) - } - _ => inner(col_model, states, row_ix, k), - } -} - -macro_rules! kl_impunc_arm { - ($i: ident, $ki: ident, $locators: ident, $fi: ident, $states: ident, $variant: ident) => {{ - let col_ix = $fi.id; - let mut partial_sum = 0.0; - let cpnt_i = &$fi.components[$ki].fx; - for (j, &(vj, kj)) in $locators.iter().enumerate() { - if $i != j { - let cm_j = &$states[j].views[vj].ftrs[&col_ix]; - match cm_j { - ColModel::$variant(ref fj) => { - let cpnt_j = &fj.components[kj].fx; - partial_sum += cpnt_i.kl(cpnt_j); - } + let mixtures = (0..n_states) + .map(|state_ix| { + let view_ix = $states[state_ix].asgn.asgn[$col_ix]; + let view = &$states[state_ix].views[view_ix]; + let k = view.asgn.asgn[$row_ix]; + match &view.ftrs[&$col_ix] { + ColModel::$variant(ref ftr) => ftr.components[k].fx.clone(), ColModel::MissingNotAtRandom( $crate::cc::feature::MissingNotAtRandom { fx, .. }, ) => match &**fx { - ColModel::$variant(ref fj) => { - let cpnt_j = &fj.components[kj].fx; - partial_sum += cpnt_i.kl(cpnt_j); + ColModel::$variant(ref ftr) => { + ftr.components[k].fx.clone() + } + cm => { + panic!( + "Mismatched MNAR feature type: {}", + cm.ftype() + ) } - _ => panic!( - "2nd mnar ColModel was incorrect type: {}", - fx.ftype() - ), }, - _ => panic!( - "2nd ColModel was incorrect type: {}", - cm_j.ftype() - ), + cm => panic!("Mismatched feature type: {}", cm.ftype()), } - } - } - partial_sum + }) + .map(|cpnt| Mixture::uniform(vec![cpnt]).unwrap()) + .collect::>(); + + $crate::stats::uncertainty::mixture_normed_tvd(&mixtures) }}; } -pub fn kl_impute_uncertainty( +pub fn impute_uncertainty( states: &[State], row_ix: usize, col_ix: usize, ) -> f64 { - let locators: Vec<(usize, usize)> = states - .iter() - .map(|state| { - let view_ix = state.asgn.asgn[col_ix]; - let cpnt_ix = state.views[view_ix].asgn.asgn[row_ix]; - (view_ix, cpnt_ix) - }) - .collect(); - - fn inner( - col_model: &ColModel, - i: usize, - ki: usize, - locators: &[(usize, usize)], - states: &[State], - ) -> f64 { - use crate::cc::feature::MissingNotAtRandom; - match col_model { - ColModel::Continuous(ref fi) => { - kl_impunc_arm!(i, ki, locators, fi, states, Continuous) - } - ColModel::Categorical(ref fi) => { - kl_impunc_arm!(i, ki, locators, fi, states, Categorical) - } - ColModel::Count(ref fi) => { - kl_impunc_arm!(i, ki, locators, fi, states, Count) - } - ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => { - match &**fx { - ColModel::Continuous(ref fi) => { - kl_impunc_arm!(i, ki, locators, fi, states, Continuous) - } - ColModel::Categorical(ref fi) => { - kl_impunc_arm!(i, ki, locators, fi, states, Categorical) - } - ColModel::Count(ref fi) => { - kl_impunc_arm!(i, ki, locators, fi, states, Count) - } - _ => panic!("Mnar within mnar?"), - } - } + let ftype = states[0].ftype(col_ix); + match ftype { + FType::Continuous => { + impunc_arm!(row_ix, col_ix, states, Continuous) + } + FType::Categorical => { + impunc_arm!(row_ix, col_ix, states, Categorical) + } + FType::Count => { + impunc_arm!(row_ix, col_ix, states, Count) + } + f => { + panic!("Unsupported ftype: {:?}", f) } } - - let mut kl_sum = 0.0; - for (i, &(vi, ki)) in locators.iter().enumerate() { - let col_model = &states[i].views[vi].ftrs[&col_ix]; - kl_sum += match col_model { - ColModel::MissingNotAtRandom(mnar) => { - inner(mnar.fx.as_ref(), i, ki, &locators, states) - } - _ => inner(col_model, i, ki, &locators, states), - }; - } - - let n_states = states.len() as f64; - kl_sum / n_states.mul_add(n_states, -n_states) } #[cfg(test)] diff --git a/lace/src/lib.rs b/lace/src/lib.rs index d8e817b2..1ae7335b 100644 --- a/lace/src/lib.rs +++ b/lace/src/lib.rs @@ -190,10 +190,9 @@ pub use config::EngineUpdateConfig; pub use interface::{ update_handler, utils, AppendStrategy, BuildEngineError, ConditionalEntropyType, DatalessOracle, Engine, EngineBuilder, Given, - HasData, HasStates, ImputeUncertaintyType, InsertDataActions, InsertMode, - Metadata, MiComponents, MiType, Oracle, OracleT, OverwriteMode, - PredictUncertaintyType, Row, RowSimilarityVariant, SupportExtension, Value, - WriteMode, + HasData, HasStates, InsertDataActions, InsertMode, Metadata, MiComponents, + MiType, Oracle, OracleT, OverwriteMode, Row, RowSimilarityVariant, + SupportExtension, Value, WriteMode, }; pub mod error { diff --git a/lace/src/prelude.rs b/lace/src/prelude.rs index e7389119..acc597b7 100644 --- a/lace/src/prelude.rs +++ b/lace/src/prelude.rs @@ -2,9 +2,8 @@ pub use crate::{ update_handler, AppendStrategy, Datum, Engine, EngineBuilder, - EngineUpdateConfig, Given, ImputeUncertaintyType, InsertMode, MiType, - OracleT, OverwriteMode, PredictUncertaintyType, Row, RowSimilarityVariant, - SupportExtension, Value, WriteMode, + EngineUpdateConfig, Given, InsertMode, MiType, OracleT, OverwriteMode, Row, + RowSimilarityVariant, SupportExtension, Value, WriteMode, }; pub use crate::data::DataSource; diff --git a/lace/tests/oracle.rs b/lace/tests/oracle.rs index fdf47313..1a1954f1 100644 --- a/lace/tests/oracle.rs +++ b/lace/tests/oracle.rs @@ -1579,7 +1579,7 @@ macro_rules! oracle_test { let oracle = $oracle_gen; assert_eq!( - oracle.impute(4, 1, None), + oracle.impute(4, 1, false), Err(IndexError::RowIndexOutOfBounds { n_rows: 4, row_ix: 4, @@ -1592,7 +1592,7 @@ macro_rules! oracle_test { let oracle = $oracle_gen; assert_eq!( - oracle.impute(1, 3, None), + oracle.impute(1, 3, false), Err(IndexError::ColumnIndexOutOfBounds { n_cols: 3, col_ix: 3, @@ -1613,7 +1613,7 @@ macro_rules! oracle_test { let oracle = $oracle_gen; assert_eq!( - oracle.predict(3, &Given::::Nothing, None, None), + oracle.predict(3, &Given::::Nothing, false, None), Err(PredictError::IndexError( IndexError::ColumnIndexOutOfBounds { n_cols: 3, @@ -1631,7 +1631,7 @@ macro_rules! oracle_test { oracle.predict( 1, &Given::Conditions(vec![(3, Datum::Continuous(1.2))]), - None, + false, None ), Err(PredictError::GivenError(GivenError::IndexError( @@ -1654,7 +1654,7 @@ macro_rules! oracle_test { 0, Datum::Categorical(1_u8.into()) )]), - None, + false, None, ), Err(PredictError::GivenError(GivenError::IndexError( @@ -1675,7 +1675,7 @@ macro_rules! oracle_test { oracle.predict( 0, &Given::Conditions(vec![(0, Datum::Continuous(2.1))]), - None, + false, None ), Err(PredictError::GivenError( diff --git a/pylace/lace/engine.py b/pylace/lace/engine.py index 90669259..c607a28e 100644 --- a/pylace/lace/engine.py +++ b/pylace/lace/engine.py @@ -1746,7 +1746,7 @@ def impute( self, col: Union[str, int], rows: Optional[List[Union[str, int]]] = None, - unc_type: Optional[str] = "js_divergence", + with_uncertainty: bool = True, ): r""" Impute (predict) the value of a cell(s) in the lace table. @@ -1762,20 +1762,6 @@ def impute( be returned, even if the value is most likely to be missing. Imputation forces the value of a cell to be present. - The following methods are used to compute uncertainty. - - * unc_type='js_divergence' computes the Jensen-Shannon divergence - between the state imputation distributions. - - .. math:: - JS(X_1, X_2, ..., X_S) - - * unc_type='pairwise_kl' computes the mean of the Kullback-Leibler - divergences between pairs of state imputation distributions. - - .. math:: - \frac{1}{S^2 - S} \sum_{i=1}^S \sum{j \in \{1,..,S\} \setminus i} KL(X_i | X_j) - Parameters ---------- col: column index @@ -1783,14 +1769,8 @@ def impute( rows: List[row index], optional Optional row indices to impute. If ``None`` (default), all the rows with missing values will be imputed - unc_type: str, optional - The type of uncertainty to compute. If ``None``, uncertainty will - not be computed. Acceptable values are: - - - 'js_divergence' (default): The Jensen-Shannon divergence between the - imputation distributions in each state. - - 'pairwise_kl': The mean pairwise Kullback-Leibler divergence - between pairs of state imputation distributions. + with_uncertainty: bool, default: True + If True, compute and return the impute uncertainty Returns ------- @@ -1848,7 +1828,7 @@ def impute( Uncertainty is optional - >>> engine.impute("Type_of_Orbit", unc_type=None) # doctest: +NORMALIZE_WHITESPACE + >>> engine.impute("Type_of_Orbit", with_uncertainty=False) # doctest: +NORMALIZE_WHITESPACE shape: (645, 2) ┌───────────────────────────────────┬─────────────────┐ │ index ┆ Type_of_Orbit │ @@ -1866,7 +1846,7 @@ def impute( │ Zhongxing 9 (Chinasat 9, Chinast… ┆ Sun-Synchronous │ └───────────────────────────────────┴─────────────────┘ """ - return self.engine.impute(col, rows, unc_type) + return self.engine.impute(col, rows, with_uncertainty) def depprob(self, col_pairs: list): """ diff --git a/pylace/src/lib.rs b/pylace/src/lib.rs index 1ab711c0..5dacf03e 100644 --- a/pylace/src/lib.rs +++ b/pylace/src/lib.rs @@ -12,9 +12,7 @@ use df::{DataFrameLike, PyDataFrame, PySeries}; use lace::data::DataSource; use lace::metadata::SerializedType; use lace::prelude::ColMetadataList; -use lace::{ - EngineUpdateConfig, FType, HasStates, OracleT, PredictUncertaintyType, -}; +use lace::{EngineUpdateConfig, FType, HasStates, OracleT}; use polars::prelude::{DataFrame, NamedFrom, Series}; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; @@ -830,35 +828,22 @@ impl CoreEngine { /// rows: list(str) or list(int), optional /// Optional list of rows to impute. If None (default), all missing /// cells will be selected. - /// unc_type: str, optional - /// Can be `'js_divergence'` (default), `'pairwise_kl'` or `None`. If - /// None, uncertainty will not be computed. + /// with_uncertainty: bool, optional + /// If true (default), compute and return uncertainty /// /// Returns /// ------- /// df: polars.DataFrame /// A data frame with columns for row names, values, and optional /// uncertainty - #[pyo3(signature=(col, rows=None, unc_type="js_divergence"))] + #[pyo3(signature=(col, rows=None, with_uncertainty=true))] fn impute( &mut self, col: &PyAny, rows: Option<&PyAny>, - unc_type: Option<&str>, + with_uncertainty: bool, ) -> PyResult { use lace::cc::feature::Feature; - use lace::ImputeUncertaintyType; - - let unc_type_opt = match unc_type { - Some("js_divergence") => { - Ok(Some(ImputeUncertaintyType::JsDivergence)) - } - Some("pairwise_kl") => Ok(Some(ImputeUncertaintyType::PairwiseKl)), - Some(val) => Err(PyErr::new::(format!( - "Invalid unc_type: '{val}'" - ))), - None => Ok(None), - }?; let col_ix = utils::value_to_index(col, &self.col_indexer)?; @@ -878,7 +863,7 @@ impl CoreEngine { row_ixs.drain(..).try_for_each(|row_ix| { self.engine - .impute(row_ix, col_ix, unc_type_opt) + .impute(row_ix, col_ix, with_uncertainty) .map(|(val, unc)| { values.push(val); row_names.push(self.row_indexer.to_name[&row_ix].clone()); @@ -946,10 +931,9 @@ impl CoreEngine { let given = dict_to_given(given, &self.engine, &self.col_indexer)?; if with_uncertainty { - let unc_type_opt = Some(PredictUncertaintyType::JsDivergence); let (pred, unc) = self .engine - .predict(col_ix, &given, unc_type_opt, state_ixs.as_deref()) + .predict(col_ix, &given, true, state_ixs.as_deref()) .map_err(|err| { PyErr::new::(format!("{err}")) })?; @@ -961,7 +945,7 @@ impl CoreEngine { } else { let (pred, _) = self .engine - .predict(col_ix, &given, None, state_ixs.as_deref()) + .predict(col_ix, &given, false, state_ixs.as_deref()) .map_err(|err| { PyErr::new::(format!("{err}")) })?;