Skip to content

Commit

Permalink
Add bayes_factor function (#52)
Browse files Browse the repository at this point in the history
* bayes_factor computation functionality

* Update tests/test_bayes_factor.py

* fixing imports

* documentation update

* arrange alphabetically

---------

Co-authored-by: Osvaldo A Martin <[email protected]>
  • Loading branch information
PiyushPanwarFST and aloctavodia authored Jan 31, 2025
1 parent dcf0a61 commit 51b2aa9
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
.. autosummary::
:toctree: generated/
arviz_stats.bayes_factor
arviz_stats.ess
arviz_stats.mcse
arviz_stats.psense
Expand Down
84 changes: 84 additions & 0 deletions src/arviz_stats/bayes_factor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Bayes Factor using Savage-Dickey density ratio."""

import warnings

import numpy as np
from arviz_base import extract

from arviz_stats.base.density import _DensityBase


def bayes_factor(idata, var_name, ref_val=0, return_ref_vals=False, prior=None):
"""
Approximated Bayes Factor for comparing hypothesis of two nested models.
Parameters
----------
idata : InferenceData
Object containing posterior and prior data.
var_name : str
Name of the variable to test.
ref_val : int or float, default 0
Reference (point-null) value for Bayes factor estimation.
return_ref_vals : bool, default False
If True, also return the values of prior and posterior densities at the reference value.
Returns
-------
dict
A dictionary with Bayes Factor values: BF10 (H1/H0 ratio) and BF01 (H0/H1 ratio).
References
----------
.. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
The case of computing Bayes factors for regression parameters.
Examples
--------
Moderate evidence indicating that the parameter "a" is different from zero.
.. ipython::
In [1]: import numpy as np
...: from arviz_base import from_dict
...: import arviz_stats as azs
...: idata = from_dict({"posterior":{"a":np.random.normal(1, 0.5, (2, 1000))},
...: {"prior":{"a":np.random.normal(0, 1, (2, 1000))}}})
...: azs.plot_bf(idata, var_name="a", ref_val=0)
"""
posterior = extract(idata, var_names=var_name).values
prior = extract(idata, var_names=var_name, group="prior").values

if not isinstance(ref_val, int | float):
raise ValueError("The reference value (ref_val) must be a numerical value (int or float).")

if ref_val > posterior.max() or ref_val < posterior.min():
warnings.warn(
"The reference value is outside the posterior range. "
"This results in infinite support for H1, which may overstate evidence."
)

prior_at_ref_val = 0
posterior_at_ref_val = 0

if posterior.dtype.kind == "f":
# pylint: disable=W0212
density_instance = _DensityBase()
posterior_grid, posterior_pdf, _ = density_instance._kde(
x=posterior, grid_len=512, circular=False
)
prior_grid, prior_pdf, _ = density_instance._kde(x=prior, grid_len=512, circular=False)

posterior_at_ref_val = np.interp(ref_val, posterior_grid, posterior_pdf)
prior_at_ref_val = np.interp(ref_val, prior_grid, prior_pdf)

elif posterior.dtype.kind == "i":
posterior_at_ref_val = (posterior == ref_val).mean()
prior_at_ref_val = (prior == ref_val).mean()

bf_10 = prior_at_ref_val / posterior_at_ref_val
bf = {"BF10": bf_10, "BF01": 1 / bf_10}

if return_ref_vals:
return (bf, {"prior": prior_at_ref_val, "posterior": posterior_at_ref_val})
return bf
73 changes: 73 additions & 0 deletions tests/test_bayes_factor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import numpy as np
import pytest
from arviz_base import from_dict

from arviz_stats.bayes_factor import bayes_factor


def test_bayes_factor_comparison():
idata = from_dict(
{
"posterior": {"a": np.random.normal(1, 0.5, (2, 1000))},
"prior": {"a": np.random.normal(0, 1, (2, 1000))},
}
)
bf_dict0 = bayes_factor(idata=idata, var_name="a", ref_val=0)
custom_prior = np.random.normal(1, 2, 5000)
bf_dict1 = bayes_factor(idata=idata, var_name="a", prior={"a": custom_prior}, ref_val=1)
assert "BF10" in bf_dict0
assert "BF01" in bf_dict0
assert bf_dict0["BF10"] > bf_dict0["BF01"]
assert "BF10" in bf_dict1
assert "BF01" in bf_dict1
assert bf_dict1["BF10"] < bf_dict1["BF01"]


def test_bayes_factor_invalid_ref_val():
idata = from_dict(
{
"posterior": {"a": np.random.normal(1, 0.5, (2, 1000))},
"prior": {"a": np.random.normal(0, 1, (2, 1000))},
}
)
with pytest.raises(ValueError, match="The reference value.*must be a numerical value.*"):
bayes_factor(idata=idata, var_name="a", ref_val="invalid")


def test_bayes_factor_custom_prior():
posterior_data = np.random.normal(1, 0.5, (2, 1000))
prior_data = np.random.normal(0, 1, (2, 1000))
custom_prior = np.random.normal(0, 10, (2, 1000))
idata = from_dict({"posterior": {"a": posterior_data}, "prior": {"a": prior_data}})
result = bayes_factor(idata=idata, var_name="a", prior={"a": custom_prior}, ref_val=0)
assert "BF10" in result
assert "BF01" in result
assert result["BF10"] > 0
assert result["BF01"] > 0


def test_bayes_factor_different_ref_vals():
idata = from_dict(
{
"posterior": {"a": np.random.normal(1, 0.5, (2, 1000))},
"prior": {"a": np.random.normal(0, 1, (2, 1000))},
}
)
ref_vals = [-1, 0, 1]
for ref_val in ref_vals:
result = bayes_factor(idata=idata, var_name="a", ref_val=ref_val)
assert "BF10" in result
assert "BF01" in result
assert result["BF10"] > 0
assert result["BF01"] > 0


def test_bayes_factor_large_data():
posterior_data = np.random.normal(1, 0.5, (2, 1000))
prior_data = np.random.normal(0, 1, (2, 1000))
idata = from_dict({"posterior": {"a": posterior_data}, "prior": {"a": prior_data}})
result = bayes_factor(idata=idata, var_name="a", ref_val=0)
assert "BF10" in result
assert "BF01" in result
assert result["BF10"] > 0
assert result["BF01"] > 0

0 comments on commit 51b2aa9

Please sign in to comment.