Skip to content

Commit

Permalink
Remove non-existent filter
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed Dec 11, 2024
1 parent 2bdb27c commit 5dd7406
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 16 deletions.
1 change: 0 additions & 1 deletion docs/statespace/filters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,4 @@ Kalman Filters
SteadyStateFilter
KalmanSmoother
SingleTimeseriesFilter
CholeskyFilter
LinearGaussianStateSpace
2 changes: 0 additions & 2 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from pymc_extras.statespace.core.representation import PytensorRepresentation
from pymc_extras.statespace.filters import (
CholeskyFilter,
KalmanSmoother,
StandardFilter,
SteadyStateFilter,
Expand Down Expand Up @@ -52,7 +51,6 @@
FILTER_FACTORY = {
"standard": StandardFilter,
"univariate": UnivariateFilter,
"cholesky": CholeskyFilter,
"steady_state": SteadyStateFilter,
}

Expand Down
2 changes: 0 additions & 2 deletions pymc_extras/statespace/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace
from pymc_extras.statespace.filters.kalman_filter import (
CholeskyFilter,
SingleTimeseriesFilter,
StandardFilter,
SteadyStateFilter,
Expand All @@ -14,6 +13,5 @@
"SteadyStateFilter",
"KalmanSmoother",
"SingleTimeseriesFilter",
"CholeskyFilter",
"LinearGaussianStateSpace",
]
20 changes: 9 additions & 11 deletions tests/statespace/test_kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from numpy.testing import assert_allclose, assert_array_less

from pymc_extras.statespace.filters import (
CholeskyFilter,
KalmanSmoother,
SingleTimeseriesFilter,
StandardFilter,
Expand All @@ -33,18 +32,17 @@
RTOL = 1e-6 if floatX.endswith("64") else 1e-3

standard_inout = initialize_filter(StandardFilter())
cholesky_inout = initialize_filter(CholeskyFilter())
# cholesky_inout = initialize_filter(CholeskyFilter())
univariate_inout = initialize_filter(UnivariateFilter())

f_standard = pytensor.function(*standard_inout, on_unused_input="ignore")
f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
# f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore")
f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore")

filter_funcs = [f_standard, f_cholesky, f_univariate]
filter_funcs = [f_standard, f_univariate]

filter_names = [
"StandardFilter",
"CholeskyFilter",
"UnivariateFilter",
]

Expand Down Expand Up @@ -233,8 +231,8 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng):
@pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32")
def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng):
fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
if filter_name == "CholeskyFilter":
P0 = np.linalg.cholesky(P0)
# if filter_name == "CholeskyFilter":
# P0 = np.linalg.cholesky(P0)
inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
outputs = filter_func(*inputs)

Expand Down Expand Up @@ -282,8 +280,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob
pytest.skip("Univariate filter not stable at half precision without measurement error")

fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing)
if filter_name == "CholeskyFilter":
P0 = np.linalg.cholesky(P0)
# if filter_name == "CholeskyFilter":
# P0 = np.linalg.cholesky(P0)

H *= int(obs_noise)
inputs = [data, a0, P0, c, d, T, Z, R, H, Q]
Expand All @@ -305,8 +303,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob

@pytest.mark.parametrize(
"filter",
[StandardFilter, CholeskyFilter],
ids=["standard", "cholesky"],
[StandardFilter],
ids=["standard"],
)
def test_kalman_filter_jax(filter):
pytest.importorskip("jax")
Expand Down

0 comments on commit 5dd7406

Please sign in to comment.