-
-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1389a8b
commit 3feaf2e
Showing
2 changed files
with
201 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,92 +1,168 @@ | ||
import numpy as np | ||
import pymc as pm | ||
import pytensor | ||
import pytensor.tensor as pt | ||
|
||
from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs | ||
from pytensor.graph.op import Apply, Op | ||
from numpy.core.numeric import normalize_axis_tuple | ||
from pymc.distributions.distribution import Continuous | ||
from pytensor.compile.builders import OpFromGraph | ||
from pytensor.tensor.einsum import _delta | ||
|
||
# from pymc.logprob.abstract import MeasurableOp | ||
|
||
class Cov(Op): | ||
__props__ = ("fn",) | ||
|
||
def __init__(self, fn): | ||
self.fn = fn | ||
class GPCovariance(OpFromGraph): | ||
"""OFG representing a GP covariance""" | ||
|
||
def make_node(self, ls): | ||
ls = pt.as_tensor(ls) | ||
out = pt.matrix(shape=(None, None)) | ||
|
||
return Apply(self, [ls], [out]) | ||
|
||
def __call__(self, ls=1.0): | ||
return super().__call__(ls) | ||
|
||
def perform(self, node, inputs, output_storage): | ||
raise NotImplementedError("You should convert Cov into a TensorVariable expression!") | ||
|
||
def do_constant_folding(self, fgraph, node): | ||
return False | ||
@staticmethod | ||
def square_dist(X, ls): | ||
X = X / ls | ||
X2 = pt.sum(pt.square(X), axis=-1) | ||
sqd = -2.0 * X @ X.mT + (X2[..., :, None] + X2[..., None, :]) | ||
|
||
return sqd | ||
|
||
class GP(Op): | ||
__props__ = ("approx",) | ||
|
||
def __init__(self, approx): | ||
self.approx = approx | ||
class ExpQuadCov(GPCovariance): | ||
""" | ||
ExpQuad covariance function | ||
""" | ||
|
||
def make_node(self, mean, cov): | ||
mean = pt.as_tensor(mean) | ||
cov = pt.as_tensor(cov) | ||
|
||
if not (cov.owner and isinstance(cov.owner.op, Cov)): | ||
raise ValueError("Second argument should be a Cov output.") | ||
|
||
out = pt.vector(shape=(None,)) | ||
@classmethod | ||
def exp_quad_full(cls, X, ls): | ||
return pt.exp(-0.5 * cls.square_dist(X, ls)) | ||
|
||
return Apply(self, [mean, cov], [out]) | ||
@classmethod | ||
def build_covariance(cls, X, ls): | ||
X = pt.as_tensor(X) | ||
ls = pt.as_tensor(ls) | ||
|
||
def perform(self, node, inputs, output_storage): | ||
raise NotImplementedError("You cannot evaluate a GP, not enough RAM in the Universe.") | ||
ofg = cls(inputs=[X, ls], outputs=[cls.exp_quad_full(X, ls)]) | ||
return ofg(X, ls) | ||
|
||
def do_constant_folding(self, fgraph, node): | ||
return False | ||
|
||
def ExpQuad(X, ls): | ||
return ExpQuadCov.build_covariance(X, ls) | ||
|
||
class PriorFromGP(Op): | ||
"""This Op will be replaced by the right MvNormal.""" | ||
|
||
def make_node(self, gp, x, rng): | ||
gp = pt.as_tensor(gp) | ||
if not (gp.owner and isinstance(gp.owner.op, GP)): | ||
raise ValueError("First argument should be a GP output.") | ||
class WhiteNoiseCov(GPCovariance): | ||
@classmethod | ||
def white_noise_full(cls, X, sigma): | ||
X_shape = tuple(X.shape) | ||
shape = X_shape[:-1] + (X_shape[-2],) | ||
|
||
# TODO: Assert RNG has the right type | ||
x = pt.as_tensor(x) | ||
out = x.type() | ||
return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2 | ||
|
||
return Apply(self, [gp, x, rng], [out]) | ||
@classmethod | ||
def build_covariance(cls, X, sigma): | ||
X = pt.as_tensor(X) | ||
sigma = pt.as_tensor(sigma) | ||
|
||
def __call__(self, gp, x, rng=None): | ||
if rng is None: | ||
rng = pytensor.shared(np.random.default_rng()) | ||
return super().__call__(gp, x, rng) | ||
ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)]) | ||
return ofg(X, sigma) | ||
|
||
def perform(self, node, inputs, output_storage): | ||
raise NotImplementedError("You should convert PriorFromGP into a MvNormal!") | ||
|
||
def do_constant_folding(self, fgraph, node): | ||
return False | ||
def WhiteNoise(X, sigma): | ||
return WhiteNoiseCov.build_covariance(X, sigma) | ||
|
||
|
||
cov_op = Cov(fn=pm.gp.cov.ExpQuad) | ||
gp_op = GP("vanilla") | ||
# SymbolicRandomVariable.register(type(gp_op)) | ||
prior_from_gp = PriorFromGP() | ||
class GP_RV(pm.MvNormal.rv_type): | ||
name = "gaussian_process" | ||
signature = "(n),(n,n)->(n)" | ||
dtype = "floatX" | ||
_print_name = ("GP", "\\operatorname{GP}") | ||
|
||
MeasurableVariable.register(type(prior_from_gp)) | ||
|
||
class GP(Continuous): | ||
rv_type = GP_RV | ||
rv_op = GP_RV() | ||
|
||
@_get_measurable_outputs.register(type(prior_from_gp)) | ||
def gp_measurable_outputs(op, node): | ||
return node.outputs | ||
@classmethod | ||
def dist(cls, cov, **kwargs): | ||
cov = pt.as_tensor(cov) | ||
mu = pt.zeros(cov.shape[-1]) | ||
return super().dist([mu, cov], **kwargs) | ||
|
||
|
||
# @register_canonicalize | ||
# @node_rewriter(tracks=[pm.MvNormal.rv_type]) | ||
# def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node): | ||
# # TODO: Should this alert users that it can't be applied when the GP is in a deterministic? | ||
# gp_rng, gp_size, mu, cov = node.inputs | ||
# next_gp_rng, gp_rv = node.outputs | ||
# | ||
# if not isinstance(cov.owner.op, GPCovariance): | ||
# return | ||
# | ||
# for client, input_index in fgraph.clients[gp_rv]: | ||
# # input_index is 2 because it goes (rng, size, mu, sigma), and we want the mu | ||
# # to be the GP we're looking | ||
# if isinstance(client.op, pm.Normal.rv_type) and (input_index == 2): | ||
# next_normal_rng, normal_rv = client.outputs | ||
# normal_rng, normal_size, mu, sigma = client.inputs | ||
# | ||
# if normal_rv.ndim != gp_rv.ndim: | ||
# return | ||
# | ||
# X = cov.owner.inputs[0] | ||
# | ||
# white_noise = WhiteNoiseCov.build_covariance(X, sigma) | ||
# white_noise.name = 'WhiteNoiseCov' | ||
# cov = cov + white_noise | ||
# | ||
# if not rv_size_is_none(normal_size): | ||
# normal_size = tuple(normal_size) | ||
# new_gp_size = normal_size[:-1] | ||
# core_shape = normal_size[-1] | ||
# | ||
# cov_shape = (*(None,) * (cov.ndim - 2), core_shape, core_shape) | ||
# cov = pt.specify_shape(cov, cov_shape) | ||
# | ||
# else: | ||
# new_gp_size = None | ||
# | ||
# next_new_gp_rng, new_gp_mvn = pm.MvNormal.dist(cov=cov, rng=gp_rng, size=new_gp_size).owner.outputs | ||
# new_gp_mvn.name = 'NewGPMvn' | ||
# | ||
# # Check that the new shape is at least as specific as the shape we are replacing | ||
# for new_shape, old_shape in zip(new_gp_mvn.type.shape, normal_rv.type.shape, strict=True): | ||
# if new_shape is None: | ||
# assert old_shape is None | ||
# | ||
# return { | ||
# next_normal_rng: next_new_gp_rng, | ||
# normal_rv: new_gp_mvn, | ||
# next_gp_rng: next_new_gp_rng | ||
# } | ||
# | ||
# else: | ||
# return None | ||
# | ||
# #TODO: Why do I need to register this twice? | ||
# specialization_ir_rewrites_db.register( | ||
# GP_normal_mvnormal_conjugacy.__name__, | ||
# GP_normal_mvnormal_conjugacy, | ||
# "basic", | ||
# ) | ||
|
||
# @node_rewriter(tracks=[pm.MvNormal.rv_type]) | ||
# def GP_normal_marginal_logp(fgraph: FunctionGraph, node): | ||
# """ | ||
# Replace Normal(GP(cov), sigma) -> MvNormal(0, cov + diag(sigma)). | ||
# """ | ||
# rng, size, mu, cov = node.inputs | ||
# if cov.owner and cov.owner.op == matrix_inverse: | ||
# tau = cov.owner.inputs[0] | ||
# return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs | ||
# return None | ||
# | ||
|
||
# cov_op = GPCovariance() | ||
# gp_op = GP("vanilla") | ||
# # SymbolicRandomVariable.register(type(gp_op)) | ||
# prior_from_gp = PriorFromGP() | ||
# | ||
# MeasurableVariable.register(type(prior_from_gp)) | ||
# | ||
# | ||
# @_get_measurable_outputs.register(type(prior_from_gp)) | ||
# def gp_measurable_outputs(op, node): | ||
# return node.outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import numpy as np | ||
import pymc as pm | ||
import pytensor.tensor as pt | ||
import pytest | ||
|
||
from pymc_experimental.gp.pytensor_gp import GP, ExpQuad | ||
|
||
|
||
def test_exp_quad(): | ||
x = pt.arange(3)[:, None] | ||
ls = pt.ones(()) | ||
cov = ExpQuad.build_covariance(x, ls).eval() | ||
expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) | ||
|
||
np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance)) | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def marginal_model(): | ||
with pm.Model() as m: | ||
X = pm.Data("X", np.arange(3)[:, None]) | ||
y = np.full(3, np.pi) | ||
ls = 1.0 | ||
cov = ExpQuad(X, ls) | ||
gp = GP("gp", cov=cov) | ||
|
||
sigma = 1.0 | ||
obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y) | ||
|
||
return m | ||
|
||
|
||
def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model): | ||
obs = marginal_model["obs"] | ||
|
||
# TODO: Bring these checks back after we implement marginalization of the GP RV | ||
# | ||
# assert sum(isinstance(var.owner.op, pm.Normal.rv_type) | ||
# for var in ancestors([obs]) | ||
# if var.owner is not None) == 1 | ||
# | ||
f = pm.compile_pymc([], obs) | ||
# | ||
# assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes) | ||
|
||
draws = np.stack([f() for _ in range(10_000)]) | ||
empirical_cov = np.cov(draws.T) | ||
|
||
expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]]) | ||
|
||
np.testing.assert_allclose( | ||
empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1 | ||
) | ||
|
||
|
||
def test_marginal_gp_logp(marginal_model): | ||
expected_logps = {"obs": -8.8778} | ||
point_logps = marginal_model.point_logps(round_vals=4) | ||
for v1, v2 in zip(point_logps.values(), expected_logps.values()): | ||
np.testing.assert_allclose(v1, v2, atol=1e-6) |