Fix bug with chained CustomSymbolicDists #7690

Merged 1 commit on Feb 24, 2025
6 changes: 4 additions & 2 deletions pymc/distributions/distribution.py
@@ -27,7 +27,7 @@

 from pytensor import tensor as pt
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
+from pytensor.graph import FunctionGraph, graph_replace, node_rewriter
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import MetaType
@@ -588,7 +588,9 @@ def inline_symbolic_random_variable(fgraph, node):
     """Expand a SymbolicRV when obtaining the logp graph if `inline_logprob` is True."""
     op = node.op
     if op.inline_logprob:
-        return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+        return graph_replace(
+            op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)), strict=False
+        )


 # Registered before pre-canonicalization which happens at position=-10
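
Note (not part of the PR): `graph_replace` substitutes the given replacements directly into the output graph, and `strict=False` tolerates replacement keys that never feed into those outputs, which presumably matters here because some inner inputs (e.g. size or RNG inputs) may be unused by a given inner output. The accompanying test `test_inline_does_not_duplicate_graph` checks that shared inputs such as `mu` are reused rather than cloned after inlining. A minimal standalone sketch of the `strict=False` behavior, with hypothetical variable names:

# A minimal sketch (not from this PR) of graph_replace with strict=False:
# replacement keys that do not appear in the target graph are ignored
# instead of raising an error, while matching keys are substituted in place.
import pytensor.tensor as pt
from pytensor.graph import graph_replace

x = pt.scalar("x")
y = pt.scalar("y")
unused = pt.scalar("unused")  # hypothetical input that `out` never uses

out = x + 1
# With strict=True this would raise, because `unused` is not in the graph of `out`.
(new_out,) = graph_replace([out], {x: y, unused: y}, strict=False)
print(new_out.eval({y: 2.0}))  # 3.0
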
73 changes: 71 additions & 2 deletions tests/distributions/test_custom.py
@@ -21,6 +21,7 @@
 from numpy import random as npr
 from pytensor import scan
 from pytensor import tensor as pt
+from pytensor.graph import FunctionGraph
 from scipy import stats as st

 from pymc.distributions import (
@@ -42,11 +43,11 @@
     Uniform,
 )
 from pymc.distributions.custom import CustomDist, CustomDistRV, CustomSymbolicDistRV
-from pymc.distributions.distribution import support_point
+from pymc.distributions.distribution import inline_symbolic_random_variable, support_point
 from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
 from pymc.distributions.transforms import log
 from pymc.exceptions import BlockModelAccessError
-from pymc.logprob import logcdf, logp
+from pymc.logprob import conditional_logp, logcdf, logp
 from pymc.model import Deterministic, Model
 from pymc.pytensorf import collect_default_updates
 from pymc.sampling import draw, sample, sample_posterior_predictive
@@ -648,3 +649,71 @@ def dist(p, size):
         assert out.owner.op.extended_signature == "[size],(),[rng]->(),[rng]"
         assert out.owner.op.ndim_supp == 0
         assert out.owner.op.ndims_params == [0]
+
+    def test_inline_does_not_duplicate_graph(self):
+        mu = Normal.dist()
+        x = CustomDist.dist(mu, dist=lambda mu, size: Normal.dist(mu, size=size))
+
+        fgraph = FunctionGraph(outputs=[x], clone=False)
+        [inner_x, inner_rng_update] = inline_symbolic_random_variable.transform(fgraph, x.owner)
+        assert inner_rng_update.owner.inputs[-2] is mu
+        assert inner_x.owner.inputs[-2] is mu
+
+    def test_chained_custom_dist_bug(self):
+        """Regression test for issue reported in https://discourse.pymc.io/t/error-with-custom-distribution-after-using-scan/16255
+        This bug was caused by a duplication of a Scan-based CustomSymbolicDist when inlining another CustomSymbolicDist that used it as an input.
+        PyTensor failed to merge the two Scan graphs, causing a failure in the logp extraction.
+        """
+
+        rng = np.random.default_rng(123)
+        steps = 4
+        batch = 2
+
+        def scan_dist(seq, n_steps, size):
+            def step(s):
+                innov = Normal.dist()
+                traffic = s + innov
+                return traffic, {innov.owner.inputs[0]: innov.owner.outputs[0]}
+
+            rv_seq, _ = pytensor.scan(
+                fn=step,
+                sequences=[seq],
+                outputs_info=[None],
+                n_steps=n_steps,
+                strict=True,
+            )
+            return rv_seq
+
+        def normal_shifted(mu, size):
+            return Normal.dist(mu=mu, size=size) - 1
+
+        seq = pt.matrix("seq", shape=(batch, steps))
+        latent_rv = CustomDist.dist(
+            seq.T,
+            steps,
+            dist=scan_dist,
+            shape=(steps, batch),
+        )
+        latent_rv.name = "latent"
+
+        observed_rv = CustomDist.dist(
+            latent_rv,
+            dist=normal_shifted,
+            shape=(steps, batch),
+        )
+        observed_rv.name = "observed"
+
+        latent_vv = latent_rv.type()
+        observed_vv = observed_rv.type()
+
+        observed_logp = conditional_logp({latent_rv: latent_vv, observed_rv: observed_vv})[
+            observed_vv
+        ]
+        latent_vv_test = rng.standard_normal(size=(steps, batch))
+        observed_vv_test = rng.standard_normal(size=(steps, batch))
+        expected_logp = st.norm.logpdf(observed_vv_test + 1, loc=latent_vv_test)
+        np.testing.assert_allclose(
+            observed_logp.eval({latent_vv: latent_vv_test, observed_vv: observed_vv_test}),
+            expected_logp,
+        )
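
For reference (not part of the PR), the `expected_logp` above follows from the fact that `normal_shifted` returns `Normal(mu=latent) - 1`; a constant shift has unit Jacobian, so the density of `observed` at `v` equals the Normal density at `v + 1`. A quick standalone check of that identity with numpy/scipy:

# Standalone sanity check (not part of the PR) of the expected_logp identity:
# if observed = Normal(mu=latent, sigma=1) - 1, then
# logp(observed = v) = norm.logpdf(v + 1, loc=latent) = norm.logpdf(v, loc=latent - 1).
import numpy as np
from scipy import stats as st

rng = np.random.default_rng(0)
latent = rng.standard_normal(size=(4, 2))
v = rng.standard_normal(size=(4, 2))

np.testing.assert_allclose(
    st.norm.logpdf(v + 1, loc=latent),
    st.norm.logpdf(v, loc=latent - 1),
)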