From a86fe7cf8d60df7e143aa9b11ef2521666c1a7b1 Mon Sep 17 00:00:00 2001 From: David Horsley <3401668+dehorsley@users.noreply.github.com> Date: Thu, 27 Apr 2023 03:24:39 +1000 Subject: [PATCH] Fix WhiteNoise subclassing from Covariance (#6674) * fix WhiteNoise subclassing from Covariance (#6673) Since #6458, Covariance is now the base class for kernels/covariance functions with input_dim and active_dims, which does not include WhiteNoise and Constant kernels. * add regression test for #6673 * fix WhiteNoise input to marginal GP --- pymc/gp/cov.py | 2 +- pymc/gp/gp.py | 6 +++--- tests/gp/test_cov.py | 13 +++++++++++++ 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/pymc/gp/cov.py b/pymc/gp/cov.py index 134c3f207ea..c910ce93fd0 100644 --- a/pymc/gp/cov.py +++ b/pymc/gp/cov.py @@ -386,7 +386,7 @@ def full(self, X, Xs=None): return pt.alloc(self.c, X.shape[0], Xs.shape[0]) -class WhiteNoise(Covariance): +class WhiteNoise(BaseCovariance): r""" White noise covariance function. diff --git a/pymc/gp/gp.py b/pymc/gp/gp.py index e8a695787a3..2518b2daa55 100644 --- a/pymc/gp/gp.py +++ b/pymc/gp/gp.py @@ -21,7 +21,7 @@ import pymc as pm -from pymc.gp.cov import Constant, Covariance +from pymc.gp.cov import BaseCovariance, Constant from pymc.gp.mean import Zero from pymc.gp.util import ( JITTER_DEFAULT, @@ -483,7 +483,7 @@ def marginal_likelihood( """ sigma = _handle_sigma_noise_parameters(sigma=sigma, noise=noise) - noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma) + noise_func = sigma if isinstance(sigma, BaseCovariance) else pm.gp.cov.WhiteNoise(sigma) mu, cov = self._build_marginal_likelihood(X=X, noise_func=noise_func, jitter=jitter) self.X = X self.y = y @@ -515,7 +515,7 @@ def _get_given_vals(self, given): if all(val in given for val in ["X", "y", "sigma"]): X, y, sigma = given["X"], given["y"], given["sigma"] - noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma) + noise_func = sigma if isinstance(sigma, BaseCovariance) else pm.gp.cov.WhiteNoise(sigma) else: X, y, noise_func = self.X, self.y, self.sigma return X, y, noise_func, cov_total, mean_total diff --git a/tests/gp/test_cov.py b/tests/gp/test_cov.py index e671ba8fc31..ba8d41962cc 100644 --- a/tests/gp/test_cov.py +++ b/tests/gp/test_cov.py @@ -95,6 +95,19 @@ def test_inv_rightadd(self): with pytest.raises(ValueError, match=r"cannot combine"): cov = M + pm.gp.cov.ExpQuad(1, 1.0) + def test_rightadd_whitenoise(self): + X = np.linspace(0, 1, 10)[:, None] + with pm.Model() as model: + cov1 = pm.gp.cov.ExpQuad(1, 0.1) + cov2 = pm.gp.cov.WhiteNoise(sigma=1) + cov = cov1 + cov2 + K = cov(X).eval() + npt.assert_allclose(K[0, 1], 0.53940, atol=1e-3) + npt.assert_allclose(K[0, 0], 2, atol=1e-3) + # check diagonal + Kd = cov(X, diag=True).eval() + npt.assert_allclose(np.diag(K), Kd, atol=1e-5) + class TestCovProd: def test_symprod_cov(self):