From cacfd8b69db17abc1944a2f9f20e079a6b75fb10 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 24 Feb 2025 12:38:12 +0100
Subject: [PATCH] Fix bug with chained CustomSymbolicDists

---
 pymc/distributions/distribution.py |  6 ++-
 tests/distributions/test_custom.py | 73 +++++++++++++++++++++++++++++-
 2 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py
index b2ec6fb79b..577d9245a6 100644
--- a/pymc/distributions/distribution.py
+++ b/pymc/distributions/distribution.py
@@ -27,7 +27,7 @@
 
 from pytensor import tensor as pt
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
+from pytensor.graph import FunctionGraph, graph_replace, node_rewriter
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import MetaType
@@ -588,7 +588,9 @@ def inline_symbolic_random_variable(fgraph, node):
     """Expand a SymbolicRV when obtaining the logp graph if `inline_logprob` is True."""
     op = node.op
     if op.inline_logprob:
-        return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+        return graph_replace(
+            op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)), strict=False
+        )
 
 
 # Registered before pre-canonicalization which happens at position=-10
diff --git a/tests/distributions/test_custom.py b/tests/distributions/test_custom.py
index d3de7cf4f7..a076eef7b6 100644
--- a/tests/distributions/test_custom.py
+++ b/tests/distributions/test_custom.py
@@ -21,6 +21,7 @@
 from numpy import random as npr
 from pytensor import scan
 from pytensor import tensor as pt
+from pytensor.graph import FunctionGraph
 from scipy import stats as st
 
 from pymc.distributions import (
@@ -42,11 +43,11 @@
     Uniform,
 )
 from pymc.distributions.custom import CustomDist, CustomDistRV, CustomSymbolicDistRV
-from pymc.distributions.distribution import support_point
+from pymc.distributions.distribution import inline_symbolic_random_variable, support_point
 from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
 from pymc.distributions.transforms import log
 from pymc.exceptions import BlockModelAccessError
-from pymc.logprob import logcdf, logp
+from pymc.logprob import conditional_logp, logcdf, logp
 from pymc.model import Deterministic, Model
 from pymc.pytensorf import collect_default_updates
 from pymc.sampling import draw, sample, sample_posterior_predictive
@@ -648,3 +649,71 @@ def dist(p, size):
         assert out.owner.op.extended_signature == "[size],(),[rng]->(),[rng]"
         assert out.owner.op.ndim_supp == 0
         assert out.owner.op.ndims_params == [0]
+
+    def test_inline_does_not_duplicate_graph(self):
+        mu = Normal.dist()
+        x = CustomDist.dist(mu, dist=lambda mu, size: Normal.dist(mu, size=size))
+
+        fgraph = FunctionGraph(outputs=[x], clone=False)
+        [inner_x, inner_rng_update] = inline_symbolic_random_variable.transform(fgraph, x.owner)
+        assert inner_rng_update.owner.inputs[-2] is mu
+        assert inner_x.owner.inputs[-2] is mu
+
+    def test_chained_custom_dist_bug(self):
+        """Regression test for issue reported in https://discourse.pymc.io/t/error-with-custom-distribution-after-using-scan/16255
+
+        This bug was caused by a duplication of a Scan-based CustomSymbolicDist when inlining another CustomSymbolicDist that used it as an input.
+        PyTensor failed to merge the two Scan graphs, causing a failure in the logp extraction.
+        """
+
+        rng = np.random.default_rng(123)
+        steps = 4
+        batch = 2
+
+        def scan_dist(seq, n_steps, size):
+            def step(s):
+                innov = Normal.dist()
+                traffic = s + innov
+                return traffic, {innov.owner.inputs[0]: innov.owner.outputs[0]}
+
+            rv_seq, _ = pytensor.scan(
+                fn=step,
+                sequences=[seq],
+                outputs_info=[None],
+                n_steps=n_steps,
+                strict=True,
+            )
+            return rv_seq
+
+        def normal_shifted(mu, size):
+            return Normal.dist(mu=mu, size=size) - 1
+
+        seq = pt.matrix("seq", shape=(batch, steps))
+        latent_rv = CustomDist.dist(
+            seq.T,
+            steps,
+            dist=scan_dist,
+            shape=(steps, batch),
+        )
+        latent_rv.name = "latent"
+
+        observed_rv = CustomDist.dist(
+            latent_rv,
+            dist=normal_shifted,
+            shape=(steps, batch),
+        )
+        observed_rv.name = "observed"
+
+        latent_vv = latent_rv.type()
+        observed_vv = observed_rv.type()
+
+        observed_logp = conditional_logp({latent_rv: latent_vv, observed_rv: observed_vv})[
+            observed_vv
+        ]
+        latent_vv_test = rng.standard_normal(size=(steps, batch))
+        observed_vv_test = rng.standard_normal(size=(steps, batch))
+        expected_logp = st.norm.logpdf(observed_vv_test + 1, loc=latent_vv_test)
+        np.testing.assert_allclose(
+            observed_logp.eval({latent_vv: latent_vv_test, observed_vv: observed_vv_test}),
+            expected_logp,
+        )