Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move out new_ds to arviz-stats #102

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 5 additions & 39 deletions src/arviz_plots/plots/psensedistplot.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""PsenseDist plot code."""
from importlib import import_module

from arviz_base import extract, rcParams
from arviz_base import rcParams
from arviz_base.labels import BaseLabeller
from arviz_stats.psense import _get_power_scale_weights
from arviz_stats.psense import power_scale_dataset
from xarray import concat

from arviz_plots.plot_collection import PlotCollection, process_facet_dims
Expand Down Expand Up @@ -136,16 +136,16 @@ def plot_psense_dist(
# Here we are generating new datasets for the prior and likelihood
# by resampling the original dataset with the power scale weights
# Instead we could have weighted KDEs/ecdfs/etc
ds_prior = new_ds(dt, "prior", alphas, sample_dims=sample_dims)
ds_likelihood = new_ds(dt, "likelihood", alphas, sample_dims=sample_dims)
ds_prior = power_scale_dataset(dt, "prior", alphas, sample_dims=sample_dims)
ds_likelihood = power_scale_dataset(dt, "likelihood", alphas, sample_dims=sample_dims)
distribution = concat([ds_prior, ds_likelihood], dim="component_group").assign_coords(
{"component_group": ["prior", "likelihood"]}
)
distribution = process_group_variables_coords(
distribution, group=None, var_names=var_names, filter_vars=filter_vars, coords=coords
)
if len(sample_dims) > 1:
# sample dims will have been stacked and renamed by `new_ds`
# sample dims will have been stacked and renamed by `power_scale_dataset`
sample_dims = ["sample"]

if backend is None:
Expand Down Expand Up @@ -250,37 +250,3 @@ def plot_psense_dist(
)

return plot_collection


def new_ds(dt, group, alphas, sample_dims):
"""Resample the dataset with the power scale weights."""
lower_w, upper_w = _get_power_scale_weights(dt, alphas, group=group, sample_dims=sample_dims)
lower_w = lower_w.values.flatten()
upper_w = upper_w.values.flatten()
s_size = len(lower_w)

idxs_to_drop = sample_dims if len(sample_dims) == 1 else ["sample"] + sample_dims
idxs_to_drop = set(idxs_to_drop).union(
[
idx
for idx in dt["posterior"].xindexes
if any(dim in dt["posterior"][idx].dims for dim in sample_dims)
]
)
resampled = [
extract(
dt,
group="posterior",
sample_dims=sample_dims,
num_samples=s_size,
weights=weights,
random_seed=42,
resampling_method="stratified",
).drop_indexes(idxs_to_drop)
for weights in (lower_w, upper_w)
]
resampled.insert(
1, extract(dt, group="posterior", sample_dims=sample_dims).drop_indexes(idxs_to_drop)
)

return concat(resampled, dim="alpha").assign_coords(alpha=[alphas[0], 1, alphas[1]])