Allow deterministic (random-variable-free) graphs as observed data.

* pymc/data.py: is_valid_observed now accepts any variable for which
  rvs_in_graph(v) is false (previously only Elemwise casts of valid
  inputs), or a MinibatchOp output; the Minibatch error message is
  updated to match.
* pymc/pytensorf.py: extract_obs_data evaluates x when inputvars(x) is
  empty and rvs_in_graph(x) is false, before falling through to TypeError.
* tests: cover the newly accepted graphs and the still-rejected random
  input (asserting on the error message, like the assertions it replaces).

diff --git a/pymc/data.py b/pymc/data.py
index c21ac3001f..4b3538a340 100644
--- a/pymc/data.py
+++ b/pymc/data.py
@@ -30,13 +30,12 @@
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Variable
 from pytensor.raise_op import Assert
-from pytensor.scalar import Cast
-from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.basic import IntegersRV
 from pytensor.tensor.variable import TensorConstant, TensorVariable
 
 import pymc as pm
 
+from pymc.logprob.utils import rvs_in_graph
 from pymc.pytensorf import convert_data
 from pymc.vartypes import isgenerator
 
@@ -111,13 +110,7 @@ def is_valid_observed(v) -> bool:
         return True
 
     return (
-        # The only PyTensor operation we allow on observed data is type casting
-        # Although we could allow for any graph that does not depend on other RVs
-        (
-            isinstance(v.owner.op, Elemwise)
-            and isinstance(v.owner.op.scalar_op, Cast)
-            and is_valid_observed(v.owner.inputs[0])
-        )
+        not rvs_in_graph(v)
         # Or Minibatch
         or (
             isinstance(v.owner.op, MinibatchOp)
@@ -148,7 +141,7 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
     for i, v in enumerate(tensors):
         if not is_valid_observed(v):
             raise ValueError(
-                f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
+                f"{i}: {v} is not valid for Minibatch, only non-random variables are allowed"
             )
 
     upper = tensors[0].shape[0]
diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index cf5700b95e..82eca936b7 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -164,6 +164,12 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
         mask[mask_idx] = 1
         return np.ma.MaskedArray(array_data, mask)
 
+    from pymc.logprob.utils import rvs_in_graph
+
+    if not inputvars(x) and not rvs_in_graph(x):
+        cheap_eval_mode = Mode(linker="py", optimizer=None)
+        return x.eval(mode=cheap_eval_mode)
+
     raise TypeError(f"Data cannot be extracted from {x}")
 
 
diff --git a/tests/test_data.py b/tests/test_data.py
index 154737b637..afca1831a7 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -509,11 +509,17 @@ def test_allowed(self):
         mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20)
         assert isinstance(mb.owner.op, MinibatchOp)
 
-        with pytest.raises(ValueError, match="not valid for Minibatch"):
-            pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
+        mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20)
+        assert isinstance(mb.owner.op, MinibatchOp)
+
+        for mb in pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20):
+            assert isinstance(mb.owner.op, MinibatchOp)
 
-        with pytest.raises(ValueError, match="not valid for Minibatch"):
-            pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20)
+    def test_not_allowed(self):
+        data = pt.random.normal(loc=self.data, scale=1)
+
+        with pytest.raises(ValueError, match="not valid for Minibatch"):
+            pm.Minibatch(data, batch_size=20)
 
     def test_assert(self):
         d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
@@ -530,3 +536,21 @@ def test_multiple_vars(self):
         [draw_mA, draw_mB] = pm.draw([mA, mB])
         assert draw_mA.shape == (10,)
         np.testing.assert_allclose(draw_mA, -draw_mB)
+
+
+def test_scaling_data_works_in_likelihood() -> None:
+    data = np.array([10, 11, 12, 13, 14, 15])
+
+    with pm.Model():
+        target = pm.Data("target", data)
+        scale = 12
+        scaled_target = target / scale
+        mu = pm.Normal("mu", mu=0, sigma=1)
+        pm.Normal("x", mu=mu, sigma=1, observed=scaled_target)
+
+        idata = pm.sample(10, chains=1, tune=100)
+
+    np.testing.assert_allclose(
+        idata.observed_data["x"].values,
+        data / scale,
+    )
diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py
index b8c82886b9..34360397a3 100644
--- a/tests/test_pytensorf.py
+++ b/tests/test_pytensorf.py
@@ -195,6 +195,21 @@ def test_minibatch_variable(self):
         assert isinstance(res, np.ndarray)
         np.testing.assert_array_equal(res, y)
 
+    def test_pytensor_operations(self):
+        x = np.array([1, 2, 3])
+        target = 1 + 3 * pt.as_tensor_variable(x)
+
+        res = extract_obs_data(target)
+        assert isinstance(res, np.ndarray)
+        np.testing.assert_array_equal(res, np.array([4, 7, 10]))
+
+    def test_pytensor_operations_raises(self):
+        x = pt.scalar("x")
+        target = 1 + 3 * x
+
+        with pytest.raises(TypeError, match="Data cannot be extracted from"):
+            extract_obs_data(target)
+
 
 @pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
 def test_convert_data(input_dtype):