Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom groups without warnings #2401

Merged
merged 2 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

## Unreleased

### New features

### Maintenance and fixes
- Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
- Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))

### Documentation

## v0.20.0 (2024 Sep 28)

Expand Down
27 changes: 20 additions & 7 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
def __init__(
self,
attrs: Union[None, Mapping[Any, Any]] = None,
warn_on_custom_groups: bool = False,
**kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]],
) -> None:
"""Initialize InferenceData object from keyword xarray datasets.
Expand All @@ -110,6 +111,9 @@ def __init__(
----------
attrs : dict
sets global attribute for InferenceData object.
warn_on_custom_groups : bool, default False
Emit a warning when custom groups are present in the InferenceData.
A "custom group" is any group whose name isn't defined in :ref:`schema`.
kwargs :
Keyword arguments of xarray datasets

Expand Down Expand Up @@ -154,9 +158,10 @@ def __init__(
for key in kwargs:
if key not in SUPPORTED_GROUPS_ALL:
key_list.append(key)
warnings.warn(
f"{key} group is not defined in the InferenceData scheme", UserWarning
)
if warn_on_custom_groups:
warnings.warn(
f"{key} group is not defined in the InferenceData scheme", UserWarning
)
for key in key_list:
dataset = kwargs[key]
dataset_warmup = None
Expand Down Expand Up @@ -1467,7 +1472,9 @@ def rename_dims(self, name_dict=None, groups=None, filter_groups=None, inplace=F
else:
return out

def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
def add_groups(
self, group_dict=None, coords=None, dims=None, warn_on_custom_groups=False, **kwargs
):
"""Add new groups to InferenceData object.

Parameters
Expand All @@ -1479,6 +1486,9 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
dims : dict of {str : list of str}, optional
Dimensions of each variable. The keys are variable names, values are lists of
coordinates.
warn_on_custom_groups : bool, default False
Emit a warning for each group being added that is a custom group.
A "custom group" is any group whose name isn't defined in :ref:`schema`.
kwargs : dict, optional
The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.

Expand Down Expand Up @@ -1542,7 +1552,7 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
if repeated_groups:
raise ValueError(f"{repeated_groups} group(s) already exists.")
for group, dataset in group_dict.items():
if group not in SUPPORTED_GROUPS_ALL:
if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
warnings.warn(
f"The group {group} is not defined in the InferenceData scheme",
UserWarning,
Expand Down Expand Up @@ -1597,7 +1607,7 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
else:
self._groups.append(group)

def extend(self, other, join="left"):
def extend(self, other, join="left", warn_on_custom_groups=False):
"""Extend InferenceData with groups from another InferenceData.

Parameters
Expand All @@ -1608,6 +1618,9 @@ def extend(self, other, join="left"):
Defines how the two decide which group to keep when the same group is
present in both objects. 'left' will discard the group in ``other`` whereas 'right'
will keep the group in ``other`` and discard the one in ``self``.
warn_on_custom_groups : bool, default False
Emit a warning for each custom group taken from ``other``.
A "custom group" is any group whose name isn't defined in :ref:`schema`.

Examples
--------
Expand Down Expand Up @@ -1651,7 +1664,7 @@ def extend(self, other, join="left"):
for group in other._groups_all: # pylint: disable=protected-access
if hasattr(self, group) and join == "left":
continue
if group not in SUPPORTED_GROUPS_ALL:
if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
warnings.warn(
f"{group} group is not defined in the InferenceData scheme", UserWarning
)
Expand Down
19 changes: 13 additions & 6 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import importlib
import os
import warnings
from collections import namedtuple
from copy import deepcopy
from html import escape
Expand Down Expand Up @@ -938,7 +939,7 @@ def test_add_groups_warning(self, data_random):
data = np.random.normal(size=(4, 500, 8))
idata = data_random
with pytest.warns(UserWarning, match="The group.+not defined in the InferenceData scheme"):
idata.add_groups({"new_group": idata.posterior})
idata.add_groups({"new_group": idata.posterior}, warn_on_custom_groups=True)
with pytest.warns(UserWarning, match="the default dims.+will be added automatically"):
idata.add_groups(constant_data={"a": data[..., 0], "b": data})
assert idata.new_group.equals(idata.posterior)
Expand Down Expand Up @@ -979,8 +980,8 @@ def test_extend_errors_warnings(self, data_random):
with pytest.raises(ValueError, match="join must be either"):
idata.extend(idata2, join="outer")
idata2.add_groups(new_group=idata2.prior)
with pytest.warns(UserWarning):
idata.extend(idata2)
with pytest.warns(UserWarning, match="new_group"):
idata.extend(idata2, warn_on_custom_groups=True)


class TestNumpyToDataArray:
Expand Down Expand Up @@ -1173,11 +1174,17 @@ def test_bad_inference_data():
InferenceData(posterior=[1, 2, 3])


def test_inference_data_other_groups():
@pytest.mark.parametrize("warn", [True, False])
def test_inference_data_other_groups(warn):
datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
idata = InferenceData(other_group=dataset)
if warn:
with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
idata = InferenceData(other_group=dataset, warn_on_custom_groups=True)
else:
with warnings.catch_warnings():
warnings.simplefilter("error")
idata = InferenceData(other_group=dataset, warn_on_custom_groups=False)
fails = check_multiple_attrs({"other_group": ["a", "b"]}, idata)
assert not fails

Expand Down