From e3f66d987b1bef1374a14df2afb8c51e478d29ab Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 24 Jan 2025 21:11:40 +0100 Subject: [PATCH 01/13] add a test case --- tests/test_data.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_data.py b/tests/test_data.py index 5d370d02c0..5a4353d9a5 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -623,3 +623,14 @@ def test_multiple_vars(self): [draw_mA, draw_mB] = pm.draw([mA, mB]) assert draw_mA.shape == (10,) np.testing.assert_allclose(draw_mA, -draw_mB) + + +def test_scaling_data_works_in_likelihood() -> None: + data = np.array([10, 11, 12, 13, 14, 15]) + + with pm.Model() as model: + target = pm.Data("target", data) + scale = 12 + scaled_target = target / scale + mu = pm.Normal("mu", mu=0, sigma=1) + pm.Normal("x", mu=mu, sigma=1, observed=scaled_target) From 7b39d7a8aff008ae7b354ac0ac51d252ec21f0b4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 24 Jan 2025 21:50:21 +0100 Subject: [PATCH 02/13] check for random ancestors as well --- pymc/data.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pymc/data.py b/pymc/data.py index 9373eb5775..fa99975306 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -33,6 +33,7 @@ from pytensor.scalar import Cast from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.basic import IntegersRV +from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.type import TensorType from pytensor.tensor.variable import TensorConstant, TensorVariable @@ -148,6 +149,19 @@ def __str__(self): return "Minibatch" +def first_inputs(r): + if not r.owner: + return + + first_input = r.owner.inputs[0] + yield first_input + yield from first_inputs(first_input) + + +def has_random_ancestor(r): + return any(isinstance(i, RandomGeneratorSharedVariable) for i in first_inputs(r)) + + def is_valid_observed(v) -> bool: if not isinstance(v, Variable): # Non-symbolic constant @@ -165,6 +179,7 @@ def is_valid_observed(v) -> bool: and isinstance(v.owner.op.scalar_op, Cast) and is_valid_observed(v.owner.inputs[0]) ) + or not has_random_ancestor(v) # Or Minibatch or ( isinstance(v.owner.op, MinibatchOp) From b36e573fbcc00c6fce6b83e696dc08f1abd7bef7 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Fri, 24 Jan 2025 22:06:55 +0100 Subject: [PATCH 03/13] check for inputs first --- pymc/data.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pymc/data.py b/pymc/data.py index fa99975306..b3e2200f63 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -153,7 +153,12 @@ def first_inputs(r): if not r.owner: return - first_input = r.owner.inputs[0] + inputs = r.owner.inputs + + if not inputs: + return + + first_input = inputs[0] yield first_input yield from first_inputs(first_input) From 7f71397615469d6737c8ea2bb24ea6042b614d9c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 25 Jan 2025 16:30:04 +0100 Subject: [PATCH 04/13] use existing function --- pymc/data.py | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index b3e2200f63..147e93573e 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -30,15 +30,13 @@ from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Variable from pytensor.raise_op import Assert -from pytensor.scalar import Cast -from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.random.basic import IntegersRV -from pytensor.tensor.random.var import RandomGeneratorSharedVariable from pytensor.tensor.type import TensorType from pytensor.tensor.variable import TensorConstant, TensorVariable import pymc as pm +from pymc.logprob.utils import rvs_in_graph from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX from pymc.vartypes import isgenerator @@ -149,22 +147,8 @@ def __str__(self): return "Minibatch" -def first_inputs(r): - if not r.owner: - return - - inputs = r.owner.inputs - - if not inputs: - return - - first_input = inputs[0] - yield first_input - yield from first_inputs(first_input) - - def has_random_ancestor(r): - return any(isinstance(i, RandomGeneratorSharedVariable) for i in first_inputs(r)) + return rvs_in_graph(r) != set() def is_valid_observed(v) -> bool: @@ -177,14 +161,7 @@ def is_valid_observed(v) -> bool: return True return ( - # The only PyTensor operation we allow on observed data is type casting - # Although we could allow for any graph that does not depend on other RVs - ( - isinstance(v.owner.op, Elemwise) - and isinstance(v.owner.op.scalar_op, Cast) - and is_valid_observed(v.owner.inputs[0]) - ) - or not has_random_ancestor(v) + not has_random_ancestor(v) # Or Minibatch or ( isinstance(v.owner.op, MinibatchOp) From b182686edd12c7cd6c88cfd6554856a72ef856cb Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 25 Jan 2025 16:51:21 +0100 Subject: [PATCH 05/13] eval for the observed_data group --- pymc/pytensorf.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 1f390b1771..dc0c157919 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -177,6 +177,11 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: mask[mask_idx] = 1 return np.ma.MaskedArray(array_data, mask) + from pymc.data import has_random_ancestor + + if not has_random_ancestor(x): + return x.eval() + raise TypeError(f"Data cannot be extracted from {x}") From b2a3d1e648fc56355b176d84eb21648a068055fa Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 28 Jan 2025 17:52:00 +0100 Subject: [PATCH 06/13] specify the mode --- pymc/pytensorf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index dc0c157919..8dff15a45e 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -180,7 +180,8 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: from pymc.data import has_random_ancestor if not has_random_ancestor(x): - return x.eval() + cheap_eval_mode = Mode(linker="py", optimizer=None) + return x.eval(mode=cheap_eval_mode) raise TypeError(f"Data cannot be extracted from {x}") From acd22f367d064ceb30ccc75cec1138bcc5bf98dd Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 28 Jan 2025 18:06:15 +0100 Subject: [PATCH 07/13] add test to extract function --- tests/test_pytensorf.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index 0ea18dabe3..f8487533a9 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -199,6 +199,14 @@ def test_minibatch_variable(self): assert isinstance(res, np.ndarray) np.testing.assert_array_equal(res, y) + def test_pytensor_operations(self): + x = np.array([1, 2, 3]) + target = 1 + 3 * pt.as_tensor_variable(x) + + res = extract_obs_data(target) + assert isinstance(res, np.ndarray) + np.testing.assert_array_equal(res, np.array([4, 7, 10])) + @pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"]) def test_convert_data(input_dtype): From 56ad68bde8984546914d418a52b5c56c3984893f Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 29 Jan 2025 08:34:31 +0100 Subject: [PATCH 08/13] simplify and remove helper function --- pymc/data.py | 6 +----- pymc/pytensorf.py | 4 ++-- tests/test_data.py | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 147e93573e..1413ac640a 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -147,10 +147,6 @@ def __str__(self): return "Minibatch" -def has_random_ancestor(r): - return rvs_in_graph(r) != set() - - def is_valid_observed(v) -> bool: if not isinstance(v, Variable): # Non-symbolic constant @@ -161,7 +157,7 @@ def is_valid_observed(v) -> bool: return True return ( - not has_random_ancestor(v) + not rvs_in_graph(v) # Or Minibatch or ( isinstance(v.owner.op, MinibatchOp) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 8dff15a45e..d78d1c3252 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -177,9 +177,9 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: mask[mask_idx] = 1 return np.ma.MaskedArray(array_data, mask) - from pymc.data import has_random_ancestor + from pymc.logprob.utils import rvs_in_graph - if not has_random_ancestor(x): + if not rvs_in_graph(x): cheap_eval_mode = Mode(linker="py", optimizer=None) return x.eval(mode=cheap_eval_mode) diff --git a/tests/test_data.py b/tests/test_data.py index 5a4353d9a5..2f268dd00c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -628,7 +628,7 @@ def test_multiple_vars(self): def test_scaling_data_works_in_likelihood() -> None: data = np.array([10, 11, 12, 13, 14, 15]) - with pm.Model() as model: + with pm.Model(): target = pm.Data("target", data) scale = 12 scaled_target = target / scale From 6b410f30c2c5e19c73be9daffad13cb491a4f6ba Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 29 Jan 2025 08:40:16 +0100 Subject: [PATCH 09/13] check for variable having inputvars --- pymc/pytensorf.py | 2 +- tests/test_pytensorf.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index d78d1c3252..935b516fcd 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -179,7 +179,7 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: from pymc.logprob.utils import rvs_in_graph - if not rvs_in_graph(x): + if not inputvars(x) and not rvs_in_graph(x): cheap_eval_mode = Mode(linker="py", optimizer=None) return x.eval(mode=cheap_eval_mode) diff --git a/tests/test_pytensorf.py b/tests/test_pytensorf.py index f8487533a9..e47d23ceb5 100644 --- a/tests/test_pytensorf.py +++ b/tests/test_pytensorf.py @@ -207,6 +207,13 @@ def test_pytensor_operations(self): assert isinstance(res, np.ndarray) np.testing.assert_array_equal(res, np.array([4, 7, 10])) + def test_pytensor_operations_raises(self): + x = pt.scalar("x") + target = 1 + 3 * x + + with pytest.raises(TypeError, match="Data cannot be extracted from"): + extract_obs_data(target) + @pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"]) def test_convert_data(input_dtype): From 2d753b3b71376f0d0bea9db49daf59be607d9394 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 29 Jan 2025 18:24:28 +0100 Subject: [PATCH 10/13] allowing for minibatch of pytensor operations --- tests/test_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_data.py b/tests/test_data.py index 2f268dd00c..1146f506be 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -602,11 +602,11 @@ def test_allowed(self): mb = pm.Minibatch(pt.as_tensor(self.data).astype(int), batch_size=20) assert isinstance(mb.owner.op, MinibatchOp) - with pytest.raises(ValueError, match="not valid for Minibatch"): - pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20) + mb = pm.Minibatch(pt.as_tensor(self.data) * 2, batch_size=20) + assert isinstance(mb.owner.op, MinibatchOp) - with pytest.raises(ValueError, match="not valid for Minibatch"): - pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20) + for mb in pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20): + assert isinstance(mb.owner.op, MinibatchOp) def test_assert(self): d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20) From f1be18736c7ec8491b71a060ae3b60d6e6e86410 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 8 Feb 2025 12:09:43 +0100 Subject: [PATCH 11/13] case that doesnt work now --- tests/test_data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_data.py b/tests/test_data.py index 1146f506be..1dae3dc4ef 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -608,6 +608,12 @@ def test_allowed(self): for mb in pm.Minibatch(self.data, pt.as_tensor(self.data) * 2, batch_size=20): assert isinstance(mb.owner.op, MinibatchOp) + def test_not_allowed(self): + data = pt.random.normal(loc=self.data, scale=1) + + with pytest.raises(ValueError): + pm.Minibatch(data, batch_size=20) + def test_assert(self): d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20) with pytest.raises( From 77f73842aa3bd9477eee2416708b8c80806efe39 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sat, 8 Feb 2025 12:13:15 +0100 Subject: [PATCH 12/13] change the error message --- pymc/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/data.py b/pymc/data.py index 1413ac640a..73c4b85171 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -190,7 +190,7 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: for i, v in enumerate(tensors): if not is_valid_observed(v): raise ValueError( - f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed" + f"{i}: {v} is not valid for Minibatch, only non-random variables are allowed" ) upper = tensors[0].shape[0] From 2269bd632cff94ab8d2c4e1b1d0f717f525cd883 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 12 Feb 2025 19:26:34 +0100 Subject: [PATCH 13/13] add sample to check observed_data --- tests/test_data.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_data.py b/tests/test_data.py index d7495e5fd4..afca1831a7 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -547,3 +547,10 @@ def test_scaling_data_works_in_likelihood() -> None: scaled_target = target / scale mu = pm.Normal("mu", mu=0, sigma=1) pm.Normal("x", mu=mu, sigma=1, observed=scaled_target) + + idata = pm.sample(10, chains=1, tune=100) + + np.testing.assert_allclose( + idata.observed_data["x"].values, + data / scale, + )