Skip to content

Commit

Permalink
don't warn on custom groups unless requested
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Dec 9, 2024
1 parent 529d795 commit 620b391
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
27 changes: 20 additions & 7 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
def __init__(
self,
attrs: Union[None, Mapping[Any, Any]] = None,
warn_on_custom_groups: bool = False,
**kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]],
) -> None:
"""Initialize InferenceData object from keyword xarray datasets.
Expand All @@ -110,6 +111,9 @@ def __init__(
----------
attrs : dict
sets global attribute for InferenceData object.
warn_on_custom_groups : bool, default False
Emit a warning when custom groups are present in the InferenceData.
"custom group" means any group whose name isn't defined in :ref:`schema`
kwargs :
Keyword arguments of xarray datasets
Expand Down Expand Up @@ -154,9 +158,10 @@ def __init__(
for key in kwargs:
if key not in SUPPORTED_GROUPS_ALL:
key_list.append(key)
warnings.warn(
f"{key} group is not defined in the InferenceData scheme", UserWarning
)
if warn_on_custom_groups:
warnings.warn(
f"{key} group is not defined in the InferenceData scheme", UserWarning
)
for key in key_list:
dataset = kwargs[key]
dataset_warmup = None
Expand Down Expand Up @@ -1467,7 +1472,9 @@ def rename_dims(self, name_dict=None, groups=None, filter_groups=None, inplace=F
else:
return out

def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
def add_groups(
self, group_dict=None, coords=None, dims=None, warn_on_custom_groups=False, **kwargs
):
"""Add new groups to InferenceData object.
Parameters
Expand All @@ -1479,6 +1486,9 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
dims : dict of {str : list of str}, optional
Dimensions of each variable. The keys are variable names, values are lists of
coordinates.
warn_on_custom_groups : bool, default False
Emit a warning when custom groups are present in the InferenceData.
"custom group" means any group whose name isn't defined in :ref:`schema`
kwargs : dict, optional
The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.
Expand Down Expand Up @@ -1542,7 +1552,7 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
if repeated_groups:
raise ValueError(f"{repeated_groups} group(s) already exists.")
for group, dataset in group_dict.items():
if group not in SUPPORTED_GROUPS_ALL:
if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
warnings.warn(
f"The group {group} is not defined in the InferenceData scheme",
UserWarning,
Expand Down Expand Up @@ -1597,7 +1607,7 @@ def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
else:
self._groups.append(group)

def extend(self, other, join="left"):
def extend(self, other, join="left", warn_on_custom_groups=False):
"""Extend InferenceData with groups from another InferenceData.
Parameters
Expand All @@ -1608,6 +1618,9 @@ def extend(self, other, join="left"):
Defines how the two decide which group to keep when the same group is
present in both objects. 'left' will discard the group in ``other`` whereas 'right'
will keep the group in ``other`` and discard the one in ``self``.
warn_on_custom_groups : bool, default False
Emit a warning when custom groups are present in the InferenceData.
"custom group" means any group whose name isn't defined in :ref:`schema`
Examples
--------
Expand Down Expand Up @@ -1651,7 +1664,7 @@ def extend(self, other, join="left"):
for group in other._groups_all: # pylint: disable=protected-access
if hasattr(self, group) and join == "left":
continue
if group not in SUPPORTED_GROUPS_ALL:
if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
warnings.warn(
f"{group} group is not defined in the InferenceData scheme", UserWarning
)
Expand Down
2 changes: 1 addition & 1 deletion arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ def test_add_groups_warning(self, data_random):
data = np.random.normal(size=(4, 500, 8))
idata = data_random
with pytest.warns(UserWarning, match="The group.+not defined in the InferenceData scheme"):
idata.add_groups({"new_group": idata.posterior})
idata.add_groups({"new_group": idata.posterior}, warn_on_custom_groups=True)
with pytest.warns(UserWarning, match="the default dims.+will be added automatically"):
idata.add_groups(constant_data={"a": data[..., 0], "b": data})
assert idata.new_group.equals(idata.posterior)
Expand Down

0 comments on commit 620b391

Please sign in to comment.