Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support PyTensor deterministic operations as observations #7656

Merged
merged 14 commits into from
Feb 27, 2025
11 changes: 2 additions & 9 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Variable
from pytensor.raise_op import Assert
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.basic import IntegersRV
from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable

import pymc as pm

from pymc.logprob.utils import rvs_in_graph
from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
from pymc.vartypes import isgenerator

Expand Down Expand Up @@ -158,13 +157,7 @@ def is_valid_observed(v) -> bool:
return True

return (
# The only PyTensor operation we allow on observed data is type casting
# Although we could allow for any graph that does not depend on other RVs
(
isinstance(v.owner.op, Elemwise)
and isinstance(v.owner.op.scalar_op, Cast)
and is_valid_observed(v.owner.inputs[0])
)
not rvs_in_graph(v)
# Or Minibatch
or (
isinstance(v.owner.op, MinibatchOp)
Expand Down
6 changes: 6 additions & 0 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,12 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
mask[mask_idx] = 1
return np.ma.MaskedArray(array_data, mask)

from pymc.logprob.utils import rvs_in_graph

if not inputvars(x) and not rvs_in_graph(x):
cheap_eval_mode = Mode(linker="py", optimizer=None)
return x.eval(mode=cheap_eval_mode)

raise TypeError(f"Data cannot be extracted from {x}")


Expand Down
19 changes: 15 additions & 4 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,11 +602,11 @@ def test_allowed(self):
mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
assert isinstance(mb.owner.op, MinibatchOp)

with pytest.raises(ValueError, match="not valid for Minibatch"):
pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
assert isinstance(mb.owner.op, MinibatchOp)

with pytest.raises(ValueError, match="not valid for Minibatch"):
pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
for mb in pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20):
assert isinstance(mb.owner.op, MinibatchOp)

def test_assert(self):
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
Expand All @@ -623,3 +623,14 @@ def test_multiple_vars(self):
[draw_mA, draw_mB] = pm.draw([mA, mB])
assert draw_mA.shape == (10,)
np.testing.assert_allclose(draw_mA, -draw_mB)


def test_scaling_data_works_in_likelihood() -> None:
    """A deterministic transform of ``pm.Data`` (division by a constant)
    is accepted as the ``observed`` argument of a likelihood."""
    raw = np.array([10, 11, 12, 13, 14, 15])

    with pm.Model():
        target = pm.Data("target", raw)
        # Rescale the observed data inside the PyTensor graph; building the
        # model must not raise now that deterministic observations are allowed.
        scale = 12
        rescaled = target / scale
        mu = pm.Normal("mu", mu=0, sigma=1)
        pm.Normal("x", mu=mu, sigma=1, observed=rescaled)
15 changes: 15 additions & 0 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,21 @@ def test_minibatch_variable(self):
assert isinstance(res, np.ndarray)
np.testing.assert_array_equal(res, y)

def test_pytensor_operations(self):
x = np.array([1, 2, 3])
target = 1 + 3 * pt.as_tensor_variable(x)

res = extract_obs_data(target)
assert isinstance(res, np.ndarray)
np.testing.assert_array_equal(res, np.array([4, 7, 10]))

def test_pytensor_operations_raises(self):
x = pt.scalar("x")
target = 1 + 3 * x

with pytest.raises(TypeError, match="Data cannot be extracted from"):
extract_obs_data(target)


@pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
def test_convert_data(input_dtype):
Expand Down
Loading