Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified transform parameter in plot_forest to accept dictionary values #2403

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions arviz/plots/forestplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@
interpret `var_names` as substrings of the real variables names. If "regex",
interpret `var_names` as regular expressions on the real variables names. See
:ref:`this section <common_filter_vars>` for usage examples.
transform : callable, optional
Function to transform data (defaults to None i.e.the identity function).
transform : callable or dict, optional
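Function to transform the data. If a callable, it is applied to each whole dataset.
If a dict, it maps variable names to callables, each applied only to the matching
variable; names not present in the data are ignored. Any other type raises a
``ValueError``. Defaults to None, i.e., the identity function.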
coords : dict, optional
Coordinates of ``var_names`` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
See :ref:`this section <common_coords>` for usage examples.
Expand Down Expand Up @@ -228,7 +228,19 @@

datasets = [convert_to_dataset(datum) for datum in reversed(data)]
if transform is not None:
datasets = [transform(dataset) for dataset in datasets]
if callable(transform):
datasets = [transform(dataset) for dataset in datasets]
elif isinstance(transform, dict):
transformed_datasets = []
for dataset in datasets:
new_dataset = dataset.copy()
for var_name, func in transform.items():
if var_name in new_dataset:
new_dataset[var_name] = func(new_dataset[var_name])
transformed_datasets.append(new_dataset)
datasets = transformed_datasets
else:
raise ValueError("transform must be either a callable or a dict {var_name: callable}")

Check warning on line 243 in arviz/plots/forestplot.py

View check run for this annotation

Codecov / codecov/patch

arviz/plots/forestplot.py#L243

Added line #L243 was not covered by tests
datasets = get_coords(
datasets, list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords
)
Expand Down
18 changes: 18 additions & 0 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,3 +2086,21 @@ def test_plot_bf():
bf_dict1, _ = plot_bf(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
assert bf_dict0["BF10"] > bf_dict0["BF01"]
assert bf_dict1["BF10"] < bf_dict1["BF01"]


def test_plot_forest_with_transform():
    """Check plot_forest with a per-variable transform dict and with an invalid transform.

    The first call passes a ``{var_name: callable}`` mapping and only checks that
    a plot object comes back. The second call passes a value that is neither a
    callable nor a dict and expects ``ValueError``.
    """
    data = xr.Dataset(
        {
            "var1": (["chain", "draw"], np.array([[1, 2, 3], [4, 5, 6]])),
            "var2": (["chain", "draw"], np.array([[7, 8, 9], [10, 11, 12]])),
        },
        coords={"chain": [0, 1], "draw": [0, 1, 2]},
    )
    transform_dict = {
        "var1": lambda x: x + 1,
        "var2": lambda x: x * 2,
    }

    axes = plot_forest(data, transform=transform_dict, show=False)
    assert axes is not None

    # A string is neither callable nor a dict, so it must be rejected. This
    # exercises the ValueError branch that the dict-only check leaves uncovered.
    try:
        plot_forest(data, transform="not a callable", show=False)
    except ValueError:
        pass
    else:
        raise AssertionError("non-callable, non-dict transform was accepted")
38 changes: 38 additions & 0 deletions examples/matplotlib/mpl_plot_forest_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Forest Plot with transforms
==============================
_gallery_category: Distributions
"""

import arviz as az
import matplotlib.pyplot as plt
import numpy as np

# Load the ArviZ example dataset registered under the name "non_centered_eight".
non_centered_data = az.load_arviz_data("non_centered_eight")


def log_transform(data):
    """Return the natural logarithm of ``data`` with values floored at 1e-8.

    Flooring first means zeros (and anything else below 1e-8) never reach
    ``np.log``.
    """
    floored = np.maximum(data, 1e-8)
    return np.log(floored)


def exp_transform(data):
    """Return ``e`` raised to the power of each element of ``data``."""
    exponentiated = np.exp(data)
    return exponentiated


def center_data(data):
    """Shift ``data`` by its mean (taken with no axis argument) so it averages to zero."""
    overall_mean = np.mean(data)
    return data - overall_mean


# Draw the forest plot, giving each variable its own transform: theta is
# centered, mu is exponentiated and tau is log-transformed.
axes = az.plot_forest(
    non_centered_data,
    var_names=["theta", "mu", "tau"],
    filter_vars=None,
    transform={"theta": center_data, "mu": exp_transform, "tau": log_transform},
    kind="forestplot",
    combined=True,
    figsize=(9, 7),
)
axes[0].set_title("Estimated theta for 8 schools model")
plt.show()