diff --git a/arviz/plots/forestplot.py b/arviz/plots/forestplot.py
index 9cc9bc8280..28cb5d43bf 100644
--- a/arviz/plots/forestplot.py
+++ b/arviz/plots/forestplot.py
@@ -75,8 +75,11 @@ def plot_forest(
         interpret `var_names` as substrings of the real variables names. If "regex",
         interpret `var_names` as regular expressions on the real variables names. See
         :ref:`this section ` for usage examples.
-    transform : callable, optional
-        Function to transform data (defaults to None i.e.the identity function).
+    transform : callable or dict, optional
+        Function to transform the data. A callable is applied to each dataset as a whole.
+        A dict of the form ``{var_name: callable}`` applies each callable only to the
+        variable it names; variables that are not listed are left unchanged and keys that
+        do not match any variable are ignored. Defaults to None, i.e., the identity function.
     coords : dict, optional
         Coordinates of ``var_names`` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
         See :ref:`this section ` for usage examples.
@@ -228,7 +231,21 @@ def plot_forest(
 
     datasets = [convert_to_dataset(datum) for datum in reversed(data)]
     if transform is not None:
-        datasets = [transform(dataset) for dataset in datasets]
+        if callable(transform):
+            datasets = [transform(dataset) for dataset in datasets]
+        elif isinstance(transform, dict):
+            transformed_datasets = []
+            for dataset in datasets:
+                # Assign the transformed variables into a copy rather than into `dataset` itself.
+                new_dataset = dataset.copy()
+                for var_name, func in transform.items():
+                    # Keys that are not present in this dataset are skipped without error.
+                    if var_name in new_dataset:
+                        new_dataset[var_name] = func(new_dataset[var_name])
+                transformed_datasets.append(new_dataset)
+            datasets = transformed_datasets
+        else:
+            raise ValueError("transform must be either a callable or a dict {var_name: callable}")
     datasets = get_coords(
         datasets, list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords
     )
diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py
index 5afcbe1669..e0047d9e8b 100644
--- a/arviz/tests/base_tests/test_plots_matplotlib.py
+++ b/arviz/tests/base_tests/test_plots_matplotlib.py
@@ -2086,3 +2086,21 @@ def test_plot_bf():
     bf_dict1, _ = plot_bf(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
     assert bf_dict0["BF10"] > bf_dict0["BF01"]
     assert bf_dict1["BF10"] < bf_dict1["BF01"]
+
+
+def test_plot_forest_with_transform():
+    """Test if plot_forest runs successfully with a transform dictionary."""
+    data = xr.Dataset(
+        {
+            "var1": (["chain", "draw"], np.array([[1, 2, 3], [4, 5, 6]])),
+            "var2": (["chain", "draw"], np.array([[7, 8, 9], [10, 11, 12]])),
+        },
+        coords={"chain": [0, 1], "draw": [0, 1, 2]},
+    )
+    transform_dict = {
+        "var1": lambda x: x + 1,
+        "var2": lambda x: x * 2,
+    }
+
+    axes = plot_forest(data, transform=transform_dict, show=False)
+    assert axes is not None
diff --git a/examples/matplotlib/mpl_plot_forest_transform.py b/examples/matplotlib/mpl_plot_forest_transform.py
new file mode 100644
index 0000000000..d23b59ba41
--- /dev/null
+++ b/examples/matplotlib/mpl_plot_forest_transform.py
@@ -0,0 +1,38 @@
+"""Forest Plot with transforms
+==============================
+_gallery_category: Distributions
+"""
+
+import arviz as az
+import matplotlib.pyplot as plt
+import numpy as np
+
+non_centered_data = az.load_arviz_data("non_centered_eight")
+
+
+def log_transform(data):
+    """Apply log transformation, clipping values below 1e-8 to avoid log(0)."""
+    return np.log(np.maximum(data, 1e-8))
+
+
+def exp_transform(data):
+    """Apply exponential transformation."""
+    return np.exp(data)
+
+
+def center_data(data):
+    """Center the data by subtracting the mean."""
+    return data - np.mean(data)
+
+
+axes = az.plot_forest(
+    non_centered_data,
+    kind="forestplot",
+    var_names=["theta", "mu", "tau"],
+    filter_vars=None,
+    combined=True,
+    figsize=(9, 7),
+    transform={"theta": center_data, "mu": exp_transform, "tau": log_transform},
+)
+axes[0].set_title("Transformed theta, mu and tau for 8 schools model")
+plt.show()