Skip to content

Commit

Permalink
Fix WhiteNoise subclassing from Covariance (#6674)
Browse files Browse the repository at this point in the history
* fix WhiteNoise subclassing from Covariance (#6673)

Since #6458, Covariance is now the base class for kernels/covariance
functions with input_dim and active_dims, which does not include
WhiteNoise and Constant kernels.

* add regression test for #6673

* fix WhiteNoise input to marginal GP
  • Loading branch information
dehorsley authored Apr 26, 2023
1 parent 55d915c commit a86fe7c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pymc/gp/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def full(self, X, Xs=None):
return pt.alloc(self.c, X.shape[0], Xs.shape[0])


class WhiteNoise(Covariance):
class WhiteNoise(BaseCovariance):
r"""
White noise covariance function.
Expand Down
6 changes: 3 additions & 3 deletions pymc/gp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pymc as pm

from pymc.gp.cov import Constant, Covariance
from pymc.gp.cov import BaseCovariance, Constant
from pymc.gp.mean import Zero
from pymc.gp.util import (
JITTER_DEFAULT,
Expand Down Expand Up @@ -483,7 +483,7 @@ def marginal_likelihood(
"""
sigma = _handle_sigma_noise_parameters(sigma=sigma, noise=noise)

noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma)
noise_func = sigma if isinstance(sigma, BaseCovariance) else pm.gp.cov.WhiteNoise(sigma)
mu, cov = self._build_marginal_likelihood(X=X, noise_func=noise_func, jitter=jitter)
self.X = X
self.y = y
Expand Down Expand Up @@ -515,7 +515,7 @@ def _get_given_vals(self, given):

if all(val in given for val in ["X", "y", "sigma"]):
X, y, sigma = given["X"], given["y"], given["sigma"]
noise_func = sigma if isinstance(sigma, Covariance) else pm.gp.cov.WhiteNoise(sigma)
noise_func = sigma if isinstance(sigma, BaseCovariance) else pm.gp.cov.WhiteNoise(sigma)
else:
X, y, noise_func = self.X, self.y, self.sigma
return X, y, noise_func, cov_total, mean_total
Expand Down
13 changes: 13 additions & 0 deletions tests/gp/test_cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@ def test_inv_rightadd(self):
with pytest.raises(ValueError, match=r"cannot combine"):
cov = M + pm.gp.cov.ExpQuad(1, 1.0)

def test_rightadd_whitenoise(self):
X = np.linspace(0, 1, 10)[:, None]
with pm.Model() as model:
cov1 = pm.gp.cov.ExpQuad(1, 0.1)
cov2 = pm.gp.cov.WhiteNoise(sigma=1)
cov = cov1 + cov2
K = cov(X).eval()
npt.assert_allclose(K[0, 1], 0.53940, atol=1e-3)
npt.assert_allclose(K[0, 0], 2, atol=1e-3)
# check diagonal
Kd = cov(X, diag=True).eval()
npt.assert_allclose(np.diag(K), Kd, atol=1e-5)


class TestCovProd:
def test_symprod_cov(self):
Expand Down

0 comments on commit a86fe7c

Please sign in to comment.