Skip to content

Commit

Permalink
Add test for compare method standard error sorting consistency (#2350)
Browse files Browse the repository at this point in the history
  • Loading branch information
uzairgheewala committed Jan 12, 2025
1 parent 0fc1117 commit 69cbc02
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from xarray import DataArray, Dataset
from xarray_einstats.stats import XrContinuousRV

from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
from ...rcparams import rcParams
from ...stats import (
apply_test_function,
Expand Down Expand Up @@ -882,3 +882,43 @@ def test_bayes_factor():
bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
assert bf_dict0["BF10"] > bf_dict0["BF01"]
assert bf_dict1["BF10"] < bf_dict1["BF01"]

def test_compare_sorting_consistency():
chains, draws = 4, 1000

# Model 1 - good fit
log_lik1 = np.random.normal(-2, 1, size=(chains, draws))
posterior1 = Dataset(
{"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
coords={"chain": range(chains), "draw": range(draws)},
)
log_like1 = Dataset(
{"y": (("chain", "draw"), log_lik1)},
coords={"chain": range(chains), "draw": range(draws)},
)
data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1)

# Model 2 - poor fit (higher variance)
log_lik2 = np.random.normal(-5, 2, size=(chains, draws))
posterior2 = Dataset(
{"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
coords={"chain": range(chains), "draw": range(draws)},
)
log_like2 = Dataset(
{"y": (("chain", "draw"), log_lik2)},
coords={"chain": range(chains), "draw": range(draws)},
)
data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2)

# Compare models in different orders
comp_dict1 = {"M1": data1, "M2": data2}
comp_dict2 = {"M2": data2, "M1": data1}

comparison1 = compare(comp_dict1, method="bb-pseudo-bma")
comparison2 = compare(comp_dict2, method="bb-pseudo-bma")

assert comparison1.index.tolist() == comparison2.index.tolist()

se1 = comparison1["se"].values
se2 = comparison2["se"].values
np.testing.assert_array_almost_equal(se1, se2)

0 comments on commit 69cbc02

Please sign in to comment.