Rough implementation of new API
jessegrabowski committed Dec 4, 2024
1 parent 1389a8b commit 3feaf2e
Showing 2 changed files with 201 additions and 65 deletions.
206 changes: 141 additions & 65 deletions pymc_experimental/gp/pytensor_gp.py
@@ -1,92 +1,168 @@
 import numpy as np
 import pymc as pm
 import pytensor
 import pytensor.tensor as pt

-from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs
-from pytensor.graph.op import Apply, Op
+from numpy.core.numeric import normalize_axis_tuple
+from pymc.distributions.distribution import Continuous
+from pytensor.compile.builders import OpFromGraph
+from pytensor.tensor.einsum import _delta
+
+# from pymc.logprob.abstract import MeasurableOp

-class Cov(Op):
-    __props__ = ("fn",)
-
-    def __init__(self, fn):
-        self.fn = fn
+class GPCovariance(OpFromGraph):
+    """OFG representing a GP covariance"""

-    def make_node(self, ls):
-        ls = pt.as_tensor(ls)
-        out = pt.matrix(shape=(None, None))
-
-        return Apply(self, [ls], [out])
-
-    def __call__(self, ls=1.0):
-        return super().__call__(ls)
-
-    def perform(self, node, inputs, output_storage):
-        raise NotImplementedError("You should convert Cov into a TensorVariable expression!")
-
-    def do_constant_folding(self, fgraph, node):
-        return False
+    @staticmethod
+    def square_dist(X, ls):
+        X = X / ls
+        X2 = pt.sum(pt.square(X), axis=-1)
+        sqd = -2.0 * X @ X.mT + (X2[..., :, None] + X2[..., None, :])
+
+        return sqd
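square_dist vectorizes the identity |x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2 x_i . x_j on the length-scaled inputs, which is why no pairwise loop is needed. A minimal NumPy sketch of the same computation, checked against the naive double loop (illustrative only, not part of the commit):

    import numpy as np

    def square_dist_np(X, ls):
        # same computation as GPCovariance.square_dist, in NumPy
        X = X / ls
        X2 = np.sum(np.square(X), axis=-1)
        return -2.0 * X @ X.T + (X2[:, None] + X2[None, :])

    X = np.random.default_rng(0).normal(size=(5, 2))
    naive = np.array([[np.sum((a - b) ** 2) for b in X] for a in X])
    np.testing.assert_allclose(square_dist_np(X, 1.0), naive, atol=1e-10)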

-class GP(Op):
-    __props__ = ("approx",)
-
-    def __init__(self, approx):
-        self.approx = approx
+class ExpQuadCov(GPCovariance):
+    """
+    ExpQuad covariance function
+    """

-    def make_node(self, mean, cov):
-        mean = pt.as_tensor(mean)
-        cov = pt.as_tensor(cov)
-
-        if not (cov.owner and isinstance(cov.owner.op, Cov)):
-            raise ValueError("Second argument should be a Cov output.")
-
-        out = pt.vector(shape=(None,))
-
-        return Apply(self, [mean, cov], [out])
+    @classmethod
+    def exp_quad_full(cls, X, ls):
+        return pt.exp(-0.5 * cls.square_dist(X, ls))

-    def perform(self, node, inputs, output_storage):
-        raise NotImplementedError("You cannot evaluate a GP, not enough RAM in the Universe.")
+    @classmethod
+    def build_covariance(cls, X, ls):
+        X = pt.as_tensor(X)
+        ls = pt.as_tensor(ls)

-    def do_constant_folding(self, fgraph, node):
-        return False
+        ofg = cls(inputs=[X, ls], outputs=[cls.exp_quad_full(X, ls)])
+        return ofg(X, ls)
+
+
+def ExpQuad(X, ls):
+    return ExpQuadCov.build_covariance(X, ls)
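Wrapping the kernel in an OpFromGraph lets rewrites recognize it by Op type while it still evaluates like any PyTensor expression. A quick usage sketch (mirrors test_exp_quad below; the import path is the one used in the tests):

    import pytensor.tensor as pt
    from pymc_experimental.gp.pytensor_gp import ExpQuad

    x = pt.arange(3)[:, None]    # three 1-d points: 0, 1, 2
    K = ExpQuad(x, ls=1.0)       # symbolic 3x3 covariance
    print(K.eval())              # entries are exp(-0.5 * squared distance)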

-class PriorFromGP(Op):
-    """This Op will be replaced by the right MvNormal."""
-
-    def make_node(self, gp, x, rng):
-        gp = pt.as_tensor(gp)
-        if not (gp.owner and isinstance(gp.owner.op, GP)):
-            raise ValueError("First argument should be a GP output.")
-
-        # TODO: Assert RNG has the right type
-        x = pt.as_tensor(x)
-        out = x.type()
-
-        return Apply(self, [gp, x, rng], [out])
-
-    def __call__(self, gp, x, rng=None):
-        if rng is None:
-            rng = pytensor.shared(np.random.default_rng())
-        return super().__call__(gp, x, rng)
-
-    def perform(self, node, inputs, output_storage):
-        raise NotImplementedError("You should convert PriorFromGP into a MvNormal!")
-
-    def do_constant_folding(self, fgraph, node):
-        return False
+class WhiteNoiseCov(GPCovariance):
+    @classmethod
+    def white_noise_full(cls, X, sigma):
+        X_shape = tuple(X.shape)
+        shape = X_shape[:-1] + (X_shape[-2],)
+
+        return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2
+
+    @classmethod
+    def build_covariance(cls, X, sigma):
+        X = pt.as_tensor(X)
+        sigma = pt.as_tensor(sigma)
+
+        ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)])
+        return ofg(X, sigma)
+
+
+def WhiteNoise(X, sigma):
+    return WhiteNoiseCov.build_covariance(X, sigma)
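For a single input matrix X of shape (n, d), white_noise_full reduces to sigma**2 * I_n; _delta builds that identity over the last two axes so the construction also broadcasts across batch dimensions. A NumPy sketch of the unbatched case (illustrative only):

    import numpy as np

    def white_noise_np(X, sigma):
        n = X.shape[-2]              # number of input points
        return sigma**2 * np.eye(n)  # iid noise covariance

    X = np.zeros((4, 1))
    print(white_noise_np(X, 0.5))    # 0.25 on the diagonal, 0 elsewhere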


-cov_op = Cov(fn=pm.gp.cov.ExpQuad)
-gp_op = GP("vanilla")
-# SymbolicRandomVariable.register(type(gp_op))
-prior_from_gp = PriorFromGP()
-
-MeasurableVariable.register(type(prior_from_gp))
-
-
-@_get_measurable_outputs.register(type(prior_from_gp))
-def gp_measurable_outputs(op, node):
-    return node.outputs
+class GP_RV(pm.MvNormal.rv_type):
+    name = "gaussian_process"
+    signature = "(n),(n,n)->(n)"
+    dtype = "floatX"
+    _print_name = ("GP", "\\operatorname{GP}")
+
+
+class GP(Continuous):
+    rv_type = GP_RV
+    rv_op = GP_RV()
+
+    @classmethod
+    def dist(cls, cov, **kwargs):
+        cov = pt.as_tensor(cov)
+        mu = pt.zeros(cov.shape[-1])
+        return super().dist([mu, cov], **kwargs)
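Because GP_RV subclasses pm.MvNormal.rv_type with signature "(n),(n,n)->(n)", a GP variable here is a zero-mean MvNormal whose covariance comes from the kernel graph. A usage sketch, assuming the rough API of this commit:

    import numpy as np
    import pymc as pm
    import pytensor.tensor as pt
    from pymc_experimental.gp.pytensor_gp import GP, ExpQuad

    X = np.linspace(0, 1, 5)[:, None]
    cov = ExpQuad(pt.as_tensor(X), ls=0.3)
    gp = GP.dist(cov=cov)                # zero-mean MvNormal under the hood
    draw = pm.draw(gp, random_seed=1)    # one prior draw, shape (5,)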


+# @register_canonicalize
+# @node_rewriter(tracks=[pm.MvNormal.rv_type])
+# def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node):
+#     # TODO: Should this alert users that it can't be applied when the GP is in a deterministic?
+#     gp_rng, gp_size, mu, cov = node.inputs
+#     next_gp_rng, gp_rv = node.outputs
+#
+#     if not isinstance(cov.owner.op, GPCovariance):
+#         return
+#
+#     for client, input_index in fgraph.clients[gp_rv]:
+#         # input_index is 2 because the inputs go (rng, size, mu, sigma), and we want the mu
+#         # to be the GP we're looking for
+#         if isinstance(client.op, pm.Normal.rv_type) and (input_index == 2):
+#             next_normal_rng, normal_rv = client.outputs
+#             normal_rng, normal_size, mu, sigma = client.inputs
+#
+#             if normal_rv.ndim != gp_rv.ndim:
+#                 return
+#
+#             X = cov.owner.inputs[0]
+#
+#             white_noise = WhiteNoiseCov.build_covariance(X, sigma)
+#             white_noise.name = 'WhiteNoiseCov'
+#             cov = cov + white_noise
+#
+#             if not rv_size_is_none(normal_size):
+#                 normal_size = tuple(normal_size)
+#                 new_gp_size = normal_size[:-1]
+#                 core_shape = normal_size[-1]
+#
+#                 cov_shape = (*(None,) * (cov.ndim - 2), core_shape, core_shape)
+#                 cov = pt.specify_shape(cov, cov_shape)
+#
+#             else:
+#                 new_gp_size = None
+#
+#             next_new_gp_rng, new_gp_mvn = pm.MvNormal.dist(cov=cov, rng=gp_rng, size=new_gp_size).owner.outputs
+#             new_gp_mvn.name = 'NewGPMvn'
+#
+#             # Check that the new shape is at least as specific as the shape we are replacing
+#             for new_shape, old_shape in zip(new_gp_mvn.type.shape, normal_rv.type.shape, strict=True):
+#                 if new_shape is None:
+#                     assert old_shape is None
+#
+#             return {
+#                 next_normal_rng: next_new_gp_rng,
+#                 normal_rv: new_gp_mvn,
+#                 next_gp_rng: next_new_gp_rng,
+#             }
+#
+#         else:
+#             return None
+#
+# # TODO: Why do I need to register this twice?
+# specialization_ir_rewrites_db.register(
+#     GP_normal_mvnormal_conjugacy.__name__,
+#     GP_normal_mvnormal_conjugacy,
+#     "basic",
+# )
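The rewrite sketched above encodes the standard conjugacy fact: if f ~ MvNormal(0, K) and y | f ~ Normal(f, sigma), then marginally y ~ MvNormal(0, K + sigma**2 * I). A numeric check of that identity (illustrative only, not part of the commit):

    import numpy as np

    rng = np.random.default_rng(0)
    d2 = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
    K, sigma = np.exp(-0.5 * d2), 1.0

    # sample f ~ MvNormal(0, K), then add iid Normal(0, sigma) noise
    f = rng.multivariate_normal(np.zeros(3), K, size=200_000)
    y = f + rng.normal(0.0, sigma, size=f.shape)

    # the empirical covariance of y matches K + sigma**2 * I
    np.testing.assert_allclose(np.cov(y.T), K + sigma**2 * np.eye(3), atol=0.02)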

+# @node_rewriter(tracks=[pm.MvNormal.rv_type])
+# def GP_normal_marginal_logp(fgraph: FunctionGraph, node):
+#     """
+#     Replace Normal(GP(cov), sigma) -> MvNormal(0, cov + diag(sigma)).
+#     """
+#     rng, size, mu, cov = node.inputs
+#     if cov.owner and cov.owner.op == matrix_inverse:
+#         tau = cov.owner.inputs[0]
+#         return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
+#     return None
+#

+# cov_op = GPCovariance()
+# gp_op = GP("vanilla")
+# # SymbolicRandomVariable.register(type(gp_op))
+# prior_from_gp = PriorFromGP()
+#
+# MeasurableVariable.register(type(prior_from_gp))
+#
+#
+# @_get_measurable_outputs.register(type(prior_from_gp))
+# def gp_measurable_outputs(op, node):
+#     return node.outputs
60 changes: 60 additions & 0 deletions tests/test_gp.py
@@ -0,0 +1,60 @@
+import numpy as np
+import pymc as pm
+import pytensor.tensor as pt
+import pytest
+
+from pymc_experimental.gp.pytensor_gp import GP, ExpQuad
+
+
+def test_exp_quad():
+    x = pt.arange(3)[:, None]
+    ls = pt.ones(())
+    cov = ExpQuad(x, ls).eval()
+    expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+
+    np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance))
+
+
+@pytest.fixture(scope="session")
+def marginal_model():
+    with pm.Model() as m:
+        X = pm.Data("X", np.arange(3)[:, None])
+        y = np.full(3, np.pi)
+        ls = 1.0
+        cov = ExpQuad(X, ls)
+        gp = GP("gp", cov=cov)
+
+        sigma = 1.0
+        obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y)
+
+    return m
+
+
+def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model):
+    obs = marginal_model["obs"]
+
+    # TODO: Bring these checks back after we implement marginalization of the GP RV
+    #
+    # assert sum(isinstance(var.owner.op, pm.Normal.rv_type)
+    #            for var in ancestors([obs])
+    #            if var.owner is not None) == 1
+    #
+    f = pm.compile_pymc([], obs)
+    #
+    # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes)
+
+    draws = np.stack([f() for _ in range(10_000)])
+    empirical_cov = np.cov(draws.T)
+
+    expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+
+    np.testing.assert_allclose(
+        empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1
+    )
+
+
+def test_marginal_gp_logp(marginal_model):
+    expected_logps = {"obs": -8.8778}
+    point_logps = marginal_model.point_logps(round_vals=4)
+    for v1, v2 in zip(point_logps.values(), expected_logps.values()):
+        np.testing.assert_allclose(v1, v2, atol=1e-6)
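The expected value -8.8778 is the log-density of y = pi * ones(3) under the marginal MvNormal(0, K + I). A cross-check with scipy (illustrative, not part of the test suite):

    import numpy as np
    from scipy.stats import multivariate_normal

    d2 = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
    K = np.exp(-0.5 * d2)
    logp = multivariate_normal(np.zeros(3), K + np.eye(3)).logpdf(np.full(3, np.pi))
    print(round(logp, 4))  # -8.8778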
