Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Priorsense functions: allow inferencedata as input #54

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions src/arviz_stats/psense.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from arviz_base import extract
from arviz_base import convert_to_datatree, extract
from arviz_base.labels import BaseLabeller
from arviz_base.sel_utils import xarray_var_iter

Expand All @@ -18,7 +18,7 @@


def psense(
dt,
data,
var_names=None,
filter_vars=None,
group="prior",
Expand All @@ -33,11 +33,8 @@ def psense(

Parameters
----------
dt : obj
Any object that can be converted to an :class:`arviz.InferenceData` object.
Refer to documentation of :func:`arviz.convert_to_dataset` for details.
For ndarray: shape = (chain, draw).
For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
data : DataTree or InferenceData
Input data. It should contain the posterior and the log_likelihood and/or log_prior groups.
var_names : list of str, optional
Names of posterior variables to include in the power scaling sensitivity diagnostic
filter_vars: {None, "like", "regex"}, default None
Expand Down Expand Up @@ -82,8 +79,10 @@ def psense(
.. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
power-scaling*, Stat Comput 34, 57 (2024), https://doi.org/10.1007/s11222-023-10366-5
"""
data = convert_to_datatree(data)

dataset = extract(
dt,
data,
var_names=var_names,
filter_vars=filter_vars,
group="posterior",
Expand All @@ -94,7 +93,7 @@ def psense(
dataset = dataset.sel(coords)

lower_w, upper_w = _get_power_scale_weights(
dt,
data,
alphas=alphas,
group=group,
sample_dims=sample_dims,
Expand Down Expand Up @@ -130,7 +129,8 @@ def psense_summary(

Parameters
----------
data : DataTree
data : DataTree or InferenceData
Input data. It should contain the posterior and the log_likelihood and/or log_prior groups.
var_names : list of str, optional
Names of posterior variables to include in the power scaling sensitivity diagnostic
filter_vars: {None, "like", "regex"}, default None
Expand Down Expand Up @@ -234,12 +234,13 @@ def _diagnose(row):
return psense_df.round(round_to)


def power_scale_dataset(dt, group, alphas, sample_dims, group_var_names, group_coords):
"""Resample the dataset with the power scale weights.
def power_scale_dataset(data, group, alphas, sample_dims, group_var_names, group_coords):
"""Resample posterior based on power-scaled weights.

Parameters
----------
dt : DataSet
data : DataTree or InferenceData
Input data. It should contain the posterior and the log_likelihood and/or log_prior groups.
group : str
Group to resample. Either "prior" or "likelihood"
alphas : tuple of float
Expand All @@ -256,6 +257,8 @@ def power_scale_dataset(dt, group, alphas, sample_dims, group_var_names, group_c
-------
DataSet with resampled data.
"""
dt = convert_to_datatree(data)

lower_w, upper_w = _get_power_scale_weights(
dt,
alphas,
Expand Down