Skip to content

Commit

Permalink
Make uncertainty total variation distance
Browse files Browse the repository at this point in the history
  • Loading branch information
Baxter Eaves committed Jan 13, 2024
1 parent 526ca5c commit ce69065
Show file tree
Hide file tree
Showing 12 changed files with 359 additions and 370 deletions.
15 changes: 5 additions & 10 deletions lace/benches/oracle_fns.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use criterion::Criterion;
use criterion::{black_box, criterion_group, criterion_main};
use lace::examples::Example;
use lace::{
Given, ImputeUncertaintyType, Oracle, OracleT, PredictUncertaintyType,
};
use lace::{Given, Oracle, OracleT};
use lace_data::Datum;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
Expand Down Expand Up @@ -232,9 +230,8 @@ fn bench_simulate(c: &mut Criterion) {
fn bench_impute(c: &mut Criterion) {
c.bench_function("oracle impute", |b| {
let oracle = get_oracle();
let unc_type = ImputeUncertaintyType::JsDivergence;
b.iter(|| {
let _res = black_box(oracle.impute(13, 12, Some(unc_type)));
let _res = black_box(oracle.impute(13, 12, true));
})
});
}
Expand All @@ -246,23 +243,21 @@ fn bench_predict(c: &mut Criterion) {
(2, Datum::Categorical(lace::Category::U8(0))),
]);
let oracle = get_oracle();
let unc_type = PredictUncertaintyType::JsDivergence;
b.iter(|| {
let _res =
black_box(oracle.predict(13, &given, Some(unc_type), None));
let _res = black_box(oracle.predict(13, &given, true, None));
})
});
}

fn bench_predict_continous(c: &mut Criterion) {
fn bench_predict_continuous(c: &mut Criterion) {
c.bench_function("oracle predict continuous", |b| {
let given = Given::Conditions(vec![(
4,
Datum::Categorical(lace::Category::U8(3)),
)]);
let oracle = get_satellites_oracle();
b.iter(|| {
let _res = black_box(oracle.predict(8, &given, None, None));
let _res = black_box(oracle.predict(8, &given, false, None));
})
});
}
Expand Down
1 change: 1 addition & 0 deletions lace/lace_stats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mod perm;
pub mod prior;
pub mod seq;
mod simplex;
pub mod uncertainty;

mod sample_error;

Expand Down
239 changes: 239 additions & 0 deletions lace/lace_stats/src/uncertainty.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
use crate::rv::dist::{Bernoulli, Categorical, Gaussian, Mixture, Poisson};
use crate::rv::traits::{Mean, QuadBounds, Rv};

/// Compute the normed mean Total Variation Distance of a set of mixture
/// distributions with the mean of distributions.
///
/// # Notes
/// - The output will be in [0, 1.0).
/// - Normalization is used to account for the fact that the maximum TVD is
/// limited by the number of mixtures. For example, if there are two mixtures
/// in `mixtures` the max TVD in only 1/2; if there are three, the max TVD is
/// 2/3; if there are four the max TVD is 3/4; and so on. We divide the final
/// output by `(n - 1) / n`, where `n` is the number of mixtures, so that the
/// output can be interpreted similarly regardless of the input.
pub fn mixture_normed_tvd<Fx>(mixtures: &[Mixture<Fx>]) -> f64
where
Fx: Clone,
Mixture<Fx>: TotalVariationDistance,
{
let n = mixtures.len() as f64;
let norm = (n - 1.0) / n;

let combined = Mixture::combine(mixtures.to_owned());
let tvd = mixtures.iter().map(|mm| combined.tvd(mm)).sum::<f64>()
/ mixtures.len() as f64;

tvd / norm
}

pub trait TotalVariationDistance {
fn tvd(&self, other: &Self) -> f64;
}

fn gaussian_quad_points(
f1: &Mixture<Gaussian>,
f2: &Mixture<Gaussian>,
) -> Vec<f64> {
// Get the lower and upper bound for quadrature
let (a, b) = {
let (a_1, b_1) = f1.quad_bounds();
let (a_2, b_2) = f2.quad_bounds();
(a_1.min(a_2), b_1.max(b_2))
};

// Get a list of sorted means and their associated stddevs
let params = {
let mut params = f1
.components()
.iter()
.chain(f2.components().iter())
.map(|cpnt| (cpnt.mu(), cpnt.sigma()))
.collect::<Vec<_>>();
params.sort_unstable_by(|(a, _), (b, _)| a.total_cmp(b));
params
};

let mut last_mean = params[0].0;
let mut last_std = params[0].1;
let mut points = vec![a, last_mean];

for &(mean, std) in params.iter().skip(1) {
let dist = mean - last_mean;
let z_dist = dist / ((last_std + std) / 2.0);
if z_dist > 1.0 {
points.push(mean);
last_std = std;
last_mean = mean;
}
}

points.push(b);
points
}

impl TotalVariationDistance for Mixture<Gaussian> {
fn tvd(&self, other: &Self) -> f64 {
use crate::rv::misc::{
gauss_legendre_quadrature_cached, gauss_legendre_table,
};

let func = |x| (self.f(&x) - other.f(&x)).abs();

let quad_level = 16;
let quad_points = gaussian_quad_points(self, other);
let (weights, roots) = gauss_legendre_table(quad_level);

let mut right = quad_points[0];
quad_points
.iter()
.skip(1)
.map(|&x| {
let q = gauss_legendre_quadrature_cached(
func,
(right, x),
&weights,
&roots,
);
right = x;
q
})
.sum::<f64>()
/ 2.0
}
}

impl TotalVariationDistance for Mixture<Categorical> {
fn tvd(&self, other: &Self) -> f64 {
let k = self.components()[0].k();
assert_eq!(k, other.components()[0].k());
(0..k)
.map(|x| (self.f(&x) - other.f(&x)).abs())
.sum::<f64>()
/ 2.0
}
}

impl TotalVariationDistance for Mixture<Bernoulli> {
fn tvd(&self, other: &Self) -> f64 {
let q =
(self.f(&0) - other.f(&0)).abs() + (self.f(&1) - other.f(&1)).abs();
q / 2.0
}
}

impl TotalVariationDistance for Mixture<Poisson> {
fn tvd(&self, other: &Self) -> f64 {
let threshold = 1e-14;
let m: u32 = self.mean().unwrap().min(other.mean().unwrap()) as u32;

let mut x: u32 = 0;
let mut q: f64 = 0.0;
loop {
let f1 = self.f(&x);
let f2 = other.f(&x);

let diff = (f1 - f2).abs();

q += diff;
x += 1;

if x > m && (f1 < threshold && f2 < threshold) {
break;
}
}
q / 2.0
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn gauss_moving_means_away_increases_tvd() {
let mut last_tvd = 0.0;
(0..10).for_each(|i| {
let dist = 0.5 * (i + 1) as f64;
let g1 = Gaussian::new(-dist / 2.0, 1.0).unwrap();
let g2 = Gaussian::new(dist / 2.0, 1.0).unwrap();

let m1 = Mixture::uniform(vec![g1]).unwrap();
let m2 = Mixture::uniform(vec![g2]).unwrap();

let tvd = mixture_normed_tvd(&[m1, m2]);

eprintln!("{i} - d: {dist}, tvd: {tvd}");

assert!(last_tvd < tvd);
assert!(tvd <= 1.0);

last_tvd = tvd;
});
}

#[test]
fn count_moving_means_away_increases_tvd() {
let mut last_tvd = 0.0;
(0..10).for_each(|i| {
let p1 = Poisson::new(5.0).unwrap();
let p2 = Poisson::new(5.0 + (i + 1) as f64).unwrap();

let m1 = Mixture::uniform(vec![p1]).unwrap();
let m2 = Mixture::uniform(vec![p2]).unwrap();

let tvd = mixture_normed_tvd(&[m1, m2]);

eprintln!("{i} tvd: {tvd}");

assert!(last_tvd < tvd);
assert!(tvd <= 1.0);

last_tvd = tvd;
});
}

#[test]
fn bernoulli_moving_means_away_increases_tvd() {
let mut last_tvd = std::f64::NEG_INFINITY;
(0..10).for_each(|i| {
let p = 0.5 / (i + 1) as f64;
let b1 = Bernoulli::new(p).unwrap();
let b2 = Bernoulli::new(1.0 - p).unwrap();

let m1 = Mixture::uniform(vec![b1]).unwrap();
let m2 = Mixture::uniform(vec![b2]).unwrap();

let tvd = mixture_normed_tvd(&[m1, m2]);

eprintln!("{i} p: {p}, tvd: {tvd}");

assert!(last_tvd < tvd);
assert!(tvd <= 1.0);

last_tvd = tvd;
});
}

#[test]
fn categorical_moving_means_away_increases_tvd() {
let mut last_tvd = std::f64::NEG_INFINITY;
(0..10).for_each(|i| {
let p = 0.5 / (i + 1) as f64;
let c1 = Categorical::new(&[p, 1.0 - p]).unwrap();
let c2 = Categorical::new(&[1.0 - p, p]).unwrap();

let m1 = Mixture::uniform(vec![c1]).unwrap();
let m2 = Mixture::uniform(vec![c2]).unwrap();

let tvd = mixture_normed_tvd(&[m1, m2]);

eprintln!("{i} p: {p}, tvd: {tvd}");

assert!(last_tvd < tvd);
assert!(tvd <= 1.0);

last_tvd = tvd;
});
}
}
5 changes: 2 additions & 3 deletions lace/src/interface/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ pub use lace_metadata::latest::Metadata;
pub use oracle::utils;

pub use oracle::{
ConditionalEntropyType, DatalessOracle, ImputeUncertaintyType,
MiComponents, MiType, Oracle, OracleT, PredictUncertaintyType,
RowSimilarityVariant,
ConditionalEntropyType, DatalessOracle, MiComponents, MiType, Oracle,
OracleT, RowSimilarityVariant,
};

pub use given::Given;
Expand Down
51 changes: 5 additions & 46 deletions lace/src/interface/oracle/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,38 +83,6 @@ impl MiComponents {
}
}

/// The type of uncertainty to use for `Oracle.impute`
#[derive(
Serialize,
Deserialize,
Debug,
Clone,
Copy,
Eq,
PartialEq,
Ord,
PartialOrd,
Hash,
)]
#[serde(rename_all = "snake_case")]
pub enum ImputeUncertaintyType {
/// Given a set of distributions Θ = {Θ<sub>1</sub>, ..., Θ<sub>n</sub>},
/// return the mean of KL(Θ<sub>i</sub> || Θ<sub>i</sub>)
PairwiseKl,
/// The Jensen-Shannon divergence in nats divided by ln(n), where n is the
/// number of distributions
JsDivergence,
}

/// The type of uncertainty to use for `Oracle.predict`
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum PredictUncertaintyType {
/// The Jensen-Shannon divergence in nats divided by ln(n), where n is the
/// number of distributions
JsDivergence,
}

/// The variant on conditional entropy to compute
#[derive(
Serialize,
Expand Down Expand Up @@ -407,29 +375,19 @@ mod tests {
}

#[test]
fn kl_impute_uncertainty_smoke() {
let oracle = get_oracle_from_yaml();
let u =
oracle._impute_uncertainty(0, 1, ImputeUncertaintyType::PairwiseKl);
assert!(u > 0.0);
}

#[test]
fn js_impute_uncertainty_smoke() {
fn impute_uncertainty_smoke() {
let oracle = get_oracle_from_yaml();
let u = oracle._impute_uncertainty(
0,
1,
ImputeUncertaintyType::JsDivergence,
);
let u = oracle._impute_uncertainty(0, 1);
assert!(u > 0.0);
assert!(u < 1.0);
}

#[test]
fn predict_uncertainty_smoke_no_given() {
let oracle = get_oracle_from_yaml();
let u = oracle._predict_uncertainty(0, &Given::Nothing, None);
assert!(u > 0.0);
assert!(u < 1.0);
}

#[test]
Expand All @@ -438,6 +396,7 @@ mod tests {
let given = Given::Conditions(vec![(1, Datum::Continuous(2.5))]);
let u = oracle._predict_uncertainty(0, &given, None);
assert!(u > 0.0);
assert!(u < 1.0);
}

#[test]
Expand Down
Loading

0 comments on commit ce69065

Please sign in to comment.