From 9c7a6fb417a18b44cafeafbd866c3c7b3f1c7ebb Mon Sep 17 00:00:00 2001 From: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Date: Mon, 13 Jan 2025 01:05:31 +0800 Subject: [PATCH] Check model coords for unknown shapes when building predictive models (#413) --- pymc_extras/statespace/core/statespace.py | 23 +++++++++- tests/statespace/test_statespace.py | 54 +++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 92ca4e9e..2590dd53 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -983,10 +983,31 @@ def _build_dummy_graph(self) -> None: list[pm.Flat] A list of pm.Flat variables representing all parameters estimated by the model. """ + + def infer_variable_shape(name): + shape = self._name_to_variable[name].type.shape + if not any(dim is None for dim in shape): + return shape + + dim_names = self._fit_dims.get(name, None) + if dim_names is None: + raise ValueError( + f"Could not infer shape for {name}, because it was not given coords during model" + f"fitting" + ) + + shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names]) + return tuple( + [ + shape[i] if shape[i] is not None else shape_from_coords[i] + for i in range(len(shape)) + ] + ) + for name in self.param_names: pm.Flat( name, - shape=self._name_to_variable[name].type.shape, + shape=infer_variable_shape(name), dims=self._fit_dims.get(name, None), ) diff --git a/tests/statespace/test_statespace.py b/tests/statespace/test_statespace.py index b9b78dff..2b0a1140 100644 --- a/tests/statespace/test_statespace.py +++ b/tests/statespace/test_statespace.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from functools import partial import numpy as np @@ -349,6 +350,59 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng): assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values)) +@pytest.mark.filterwarnings("ignore:Provided data contains missing values") +def test_sample_conditional_with_time_varying(): + class TVCovariance(PyMCStateSpace): + def __init__(self): + super().__init__(k_states=1, k_endog=1, k_posdef=1) + + def make_symbolic_graph(self) -> None: + self.ssm["transition", 0, 0] = 1.0 + + self.ssm["design", 0, 0] = 1.0 + + sigma_cov = self.make_and_register_variable("sigma_cov", (None,)) + self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2 + + @property + def param_names(self) -> list[str]: + return ["sigma_cov"] + + @property + def coords(self) -> dict[str, Sequence[str]]: + return make_default_coords(self) + + @property + def state_names(self) -> list[str]: + return ["level"] + + @property + def observed_states(self) -> list[str]: + return ["level"] + + @property + def shock_names(self) -> list[str]: + return ["level"] + + ss_mod = TVCovariance() + empty_data = pd.DataFrame( + np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"] + ) + + coords = ss_mod.coords + coords["time"] = empty_data.index + with pm.Model(coords=coords) as mod: + log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"]) + pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"]) + + ss_mod.build_statespace_graph(data=empty_data) + + prior = pm.sample_prior_predictive(10) + + ss_mod.sample_unconditional_prior(prior) + ss_mod.sample_conditional_prior(prior) + + def _make_time_idx(mod, use_datetime_index=True): if use_datetime_index: mod._fit_coords["time"] = nile.index