diff --git a/src/arviz_plots/plots/psensedistplot.py b/src/arviz_plots/plots/psensedistplot.py index ff1620d..b3680bd 100644 --- a/src/arviz_plots/plots/psensedistplot.py +++ b/src/arviz_plots/plots/psensedistplot.py @@ -1,9 +1,9 @@ """PsenseDist plot code.""" from importlib import import_module -from arviz_base import extract, rcParams +from arviz_base import rcParams from arviz_base.labels import BaseLabeller -from arviz_stats.psense import _get_power_scale_weights +from arviz_stats.psense import power_scale_dataset from xarray import concat from arviz_plots.plot_collection import PlotCollection, process_facet_dims @@ -136,8 +136,8 @@ def plot_psense_dist( # Here we are generating new datasets for the prior and likelihood # by resampling the original dataset with the power scale weights # Instead we could have weighted KDEs/ecdfs/etc - ds_prior = new_ds(dt, "prior", alphas, sample_dims=sample_dims) - ds_likelihood = new_ds(dt, "likelihood", alphas, sample_dims=sample_dims) + ds_prior = power_scale_dataset(dt, "prior", alphas, sample_dims=sample_dims) + ds_likelihood = power_scale_dataset(dt, "likelihood", alphas, sample_dims=sample_dims) distribution = concat([ds_prior, ds_likelihood], dim="component_group").assign_coords( {"component_group": ["prior", "likelihood"]} ) @@ -145,7 +145,7 @@ def plot_psense_dist( distribution, group=None, var_names=var_names, filter_vars=filter_vars, coords=coords ) if len(sample_dims) > 1: - # sample dims will have been stacked and renamed by `new_ds` + # sample dims will have been stacked and renamed by `power_scale_dataset` sample_dims = ["sample"] if backend is None: @@ -250,37 +250,3 @@ def plot_psense_dist( ) return plot_collection - - -def new_ds(dt, group, alphas, sample_dims): - """Resample the dataset with the power scale weights.""" - lower_w, upper_w = _get_power_scale_weights(dt, alphas, group=group, sample_dims=sample_dims) - lower_w = lower_w.values.flatten() - upper_w = upper_w.values.flatten() - s_size = len(lower_w) - - idxs_to_drop = sample_dims if len(sample_dims) == 1 else ["sample"] + sample_dims - idxs_to_drop = set(idxs_to_drop).union( - [ - idx - for idx in dt["posterior"].xindexes - if any(dim in dt["posterior"][idx].dims for dim in sample_dims) - ] - ) - resampled = [ - extract( - dt, - group="posterior", - sample_dims=sample_dims, - num_samples=s_size, - weights=weights, - random_seed=42, - resampling_method="stratified", - ).drop_indexes(idxs_to_drop) - for weights in (lower_w, upper_w) - ] - resampled.insert( - 1, extract(dt, group="posterior", sample_dims=sample_dims).drop_indexes(idxs_to_drop) - ) - - return concat(resampled, dim="alpha").assign_coords(alpha=[alphas[0], 1, alphas[1]])