From 5a52adee4c5929792833b8ddfb4d9b09383979b9 Mon Sep 17 00:00:00 2001 From: PiyushPanwarFST Date: Tue, 28 Jan 2025 17:40:41 +0530 Subject: [PATCH] bayes_factor computation functionality --- src/arviz_stats/bayes_factor.py | 84 +++++++++++++++++++++++++++++++++ tests/test_bayes_factor.py | 73 ++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 src/arviz_stats/bayes_factor.py create mode 100644 tests/test_bayes_factor.py diff --git a/src/arviz_stats/bayes_factor.py b/src/arviz_stats/bayes_factor.py new file mode 100644 index 0000000..d105c30 --- /dev/null +++ b/src/arviz_stats/bayes_factor.py @@ -0,0 +1,84 @@ +"""Bayes Factor using Savage-Dickey density ratio.""" + +import warnings + +import numpy as np +from arviz_base import extract + +from arviz_stats.base.density import _DensityBase + + +def bayes_factor(idata, var_name, ref_val=0, return_ref_vals=False, prior=None): + """ + Approximated Bayes Factor for comparing hypothesis of two nested models. + + Parameters + ---------- + idata : InferenceData + Object containing posterior and prior data. + var_name : str + Name of the variable to test. + ref_val : int or float, default 0 + Reference (point-null) value for Bayes factor estimation. + return_ref_vals : bool, default False + If True, also return the values of prior and posterior densities at the reference value. + + Returns + ------- + dict + A dictionary with Bayes Factor values: BF10 (H1/H0 ratio) and BF01 (H0/H1 ratio). + + References + ---------- + .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio: + The case of computing Bayes factors for regression parameters. + + Examples + -------- + Moderate evidence indicating that the parameter "a" is different from zero. + + .. ipython:: + + In [1]: import numpy as np + ...: from arviz_base import from_dict + ...: import arviz_stats as azs + ...: idata = from_dict({"posterior":{"a":np.random.normal(1, 0.5, (2, 1000))}, + ...: {"prior":{"a":np.random.normal(0, 1, (2, 1000))}}}) + ...: azs.plot_bf(idata, var_name="a", ref_val=0) + """ + posterior = extract(idata, var_names=var_name).values + prior = extract(idata, var_names=var_name, group="prior").values + + if not isinstance(ref_val, int | float): + raise ValueError("The reference value (ref_val) must be a numerical value (int or float).") + + if ref_val > posterior.max() or ref_val < posterior.min(): + warnings.warn( + "The reference value is outside the posterior range. " + "This results in infinite support for H1, which may overstate evidence." + ) + + prior_at_ref_val = 0 + posterior_at_ref_val = 0 + + if posterior.dtype.kind == "f": + # pylint: disable=W0212 + density_instance = _DensityBase() + posterior_grid, posterior_pdf, _ = density_instance._kde( + x=posterior, grid_len=512, circular=False + ) + prior_grid, prior_pdf, _ = density_instance._kde(x=prior, grid_len=512, circular=False) + + posterior_at_ref_val = np.interp(ref_val, posterior_grid, posterior_pdf) + prior_at_ref_val = np.interp(ref_val, prior_grid, prior_pdf) + + elif posterior.dtype.kind == "i": + posterior_at_ref_val = (posterior == ref_val).mean() + prior_at_ref_val = (prior == ref_val).mean() + + bf_10 = prior_at_ref_val / posterior_at_ref_val + bf = {"BF10": bf_10, "BF01": 1 / bf_10} + + if return_ref_vals: + return (bf, {"prior": prior_at_ref_val, "posterior": posterior_at_ref_val}) + return bf diff --git a/tests/test_bayes_factor.py b/tests/test_bayes_factor.py new file mode 100644 index 0000000..4f540ea --- /dev/null +++ b/tests/test_bayes_factor.py @@ -0,0 +1,73 @@ +import numpy as np +import pytest +from arviz_base import from_dict + +from src.arviz_stats.bayes_factor import bayes_factor + + +def test_bayes_factor_comparison(): + idata = from_dict( + { + "posterior": {"a": np.random.normal(1, 0.5, (2, 1000))}, + "prior": {"a": np.random.normal(0, 1, (2, 1000))}, + } + ) + bf_dict0 = bayes_factor(idata=idata, var_name="a", ref_val=0) + custom_prior = np.random.normal(1, 2, 5000) + bf_dict1 = bayes_factor(idata=idata, var_name="a", prior={"a": custom_prior}, ref_val=1) + assert "BF10" in bf_dict0 + assert "BF01" in bf_dict0 + assert bf_dict0["BF10"] > bf_dict0["BF01"] + assert "BF10" in bf_dict1 + assert "BF01" in bf_dict1 + assert bf_dict1["BF10"] < bf_dict1["BF01"] + + +def test_bayes_factor_invalid_ref_val(): + idata = from_dict( + { + "posterior": {"a": np.random.normal(1, 0.5, (2, 1000))}, + "prior": {"a": np.random.normal(0, 1, (2, 1000))}, + } + ) + with pytest.raises(ValueError, match="The reference value.*must be a numerical value.*"): + bayes_factor(idata=idata, var_name="a", ref_val="invalid") + + +def test_bayes_factor_custom_prior(): + posterior_data = np.random.normal(1, 0.5, (2, 1000)) + prior_data = np.random.normal(0, 1, (2, 1000)) + custom_prior = np.random.normal(0, 10, (2, 1000)) + idata = from_dict({"posterior": {"a": posterior_data}, "prior": {"a": prior_data}}) + result = bayes_factor(idata=idata, var_name="a", prior={"a": custom_prior}, ref_val=0) + assert "BF10" in result + assert "BF01" in result + assert result["BF10"] > 0 + assert result["BF01"] > 0 + + +def test_bayes_factor_different_ref_vals(): + idata = from_dict( + { + "posterior": {"a": np.random.normal(1, 0.5, (2, 1000))}, + "prior": {"a": np.random.normal(0, 1, (2, 1000))}, + } + ) + ref_vals = [-1, 0, 1] + for ref_val in ref_vals: + result = bayes_factor(idata=idata, var_name="a", ref_val=ref_val) + assert "BF10" in result + assert "BF01" in result + assert result["BF10"] > 0 + assert result["BF01"] > 0 + + +def test_bayes_factor_large_data(): + posterior_data = np.random.normal(1, 0.5, (2, 1000)) + prior_data = np.random.normal(0, 1, (2, 1000)) + idata = from_dict({"posterior": {"a": posterior_data}, "prior": {"a": prior_data}}) + result = bayes_factor(idata=idata, var_name="a", ref_val=0) + assert "BF10" in result + assert "BF01" in result + assert result["BF10"] > 0 + assert result["BF01"] > 0