Fix bug with chained CustomSymbolicDists
ricardoV94 committed Feb 24, 2025
1 parent 0772383 commit cacfd8b
Showing 2 changed files with 75 additions and 4 deletions.
6 changes: 4 additions & 2 deletions pymc/distributions/distribution.py
@@ -27,7 +27,7 @@

 from pytensor import tensor as pt
 from pytensor.compile.builders import OpFromGraph
-from pytensor.graph import FunctionGraph, clone_replace, node_rewriter
+from pytensor.graph import FunctionGraph, graph_replace, node_rewriter
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import MetaType
@@ -588,7 +588,9 @@ def inline_symbolic_random_variable(fgraph, node):
     """Expand a SymbolicRV when obtaining the logp graph if `inline_logprob` is True."""
     op = node.op
     if op.inline_logprob:
-        return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
+        return graph_replace(
+            op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)), strict=False
+        )


 # Registered before pre-canonicalization which happens at position=-10
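The substance of the fix is the swap from `clone_replace` to `graph_replace`. As the regression test below describes, `clone_replace` rebuilt a fresh copy of the inner graph during inlining, which could duplicate an expensive op fed in as an input (here a Scan-based CustomSymbolicDist), and PyTensor then failed to merge the two structurally identical Scans. `graph_replace` substitutes the given variables by identity instead of cloning, and `strict=False` tolerates replacement keys that never occur in the graph, as can happen when an inner input is unused by the inner outputs. A minimal, PyMC-independent sketch of both properties (all variable names here are illustrative, not from the commit):

import pytensor.tensor as pt
from pytensor.graph import graph_replace
from pytensor.graph.basic import ancestors

a = pt.scalar("a")
b = pt.scalar("b")  # deliberately absent from the graph below
out = pt.exp(a) + 1.0

x = pt.scalar("x")
# strict=True (the default) would raise because `b` is never used;
# strict=False silently ignores unused replacements.
[new_out] = graph_replace([out], {a: x, b: x}, strict=False)

# The replacement variable enters the new graph by identity, not as a clone.
assert x in set(ancestors([new_out]))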
73 changes: 71 additions & 2 deletions tests/distributions/test_custom.py
@@ -21,6 +21,7 @@
 from numpy import random as npr
 from pytensor import scan
 from pytensor import tensor as pt
+from pytensor.graph import FunctionGraph
 from scipy import stats as st

 from pymc.distributions import (
@@ -42,11 +43,11 @@
     Uniform,
 )
 from pymc.distributions.custom import CustomDist, CustomDistRV, CustomSymbolicDistRV
-from pymc.distributions.distribution import support_point
+from pymc.distributions.distribution import inline_symbolic_random_variable, support_point
 from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
 from pymc.distributions.transforms import log
 from pymc.exceptions import BlockModelAccessError
-from pymc.logprob import logcdf, logp
+from pymc.logprob import conditional_logp, logcdf, logp
 from pymc.model import Deterministic, Model
 from pymc.pytensorf import collect_default_updates
 from pymc.sampling import draw, sample, sample_posterior_predictive
@@ -648,3 +649,71 @@ def dist(p, size):
         assert out.owner.op.extended_signature == "[size],(),[rng]->(),[rng]"
         assert out.owner.op.ndim_supp == 0
         assert out.owner.op.ndims_params == [0]

+    def test_inline_does_not_duplicate_graph(self):
+        mu = Normal.dist()
+        x = CustomDist.dist(mu, dist=lambda mu, size: Normal.dist(mu, size=size))
+
+        fgraph = FunctionGraph(outputs=[x], clone=False)
+        [inner_x, inner_rng_update] = inline_symbolic_random_variable.transform(fgraph, x.owner)
+        assert inner_rng_update.owner.inputs[-2] is mu
+        assert inner_x.owner.inputs[-2] is mu
+
+    def test_chained_custom_dist_bug(self):
+        """Regression test for issue reported in https://discourse.pymc.io/t/error-with-custom-distribution-after-using-scan/16255
+
+        This bug was caused by a duplication of a Scan-based CustomSymbolicDist when inlining
+        another CustomSymbolicDist that used it as an input.
+        PyTensor failed to merge the two Scan graphs, causing a failure in the logp extraction.
+        """
+        rng = np.random.default_rng(123)
+        steps = 4
+        batch = 2
+
+        def scan_dist(seq, n_steps, size):
+            def step(s):
+                innov = Normal.dist()
+                traffic = s + innov
+                return traffic, {innov.owner.inputs[0]: innov.owner.outputs[0]}
+
+            rv_seq, _ = pytensor.scan(
+                fn=step,
+                sequences=[seq],
+                outputs_info=[None],
+                n_steps=n_steps,
+                strict=True,
+            )
+            return rv_seq
+
+        def normal_shifted(mu, size):
+            return Normal.dist(mu=mu, size=size) - 1
+
+        seq = pt.matrix("seq", shape=(batch, steps))
+        latent_rv = CustomDist.dist(
+            seq.T,
+            steps,
+            dist=scan_dist,
+            shape=(steps, batch),
+        )
+        latent_rv.name = "latent"
+
+        observed_rv = CustomDist.dist(
+            latent_rv,
+            dist=normal_shifted,
+            shape=(steps, batch),
+        )
+        observed_rv.name = "observed"
+
+        latent_vv = latent_rv.type()
+        observed_vv = observed_rv.type()
+
+        observed_logp = conditional_logp({latent_rv: latent_vv, observed_rv: observed_vv})[
+            observed_vv
+        ]
+        latent_vv_test = rng.standard_normal(size=(steps, batch))
+        observed_vv_test = rng.standard_normal(size=(steps, batch))
+        expected_logp = st.norm.logpdf(observed_vv_test + 1, loc=latent_vv_test)
+        np.testing.assert_allclose(
+            observed_logp.eval({latent_vv: latent_vv_test, observed_vv: observed_vv_test}),
+            expected_logp,
+        )
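A note on the reference value in the last test: since observed = Normal(latent) - 1 and the shift has unit Jacobian, the density of observed at y equals the Normal(latent) density at y + 1, which is why expected_logp adds 1 to the observed draws. A standalone sanity check of that identity with scipy (not part of the commit):

import numpy as np
from scipy import stats as st

rng = np.random.default_rng(0)
mu, y = rng.standard_normal(2)

# Subtracting 1 from a Normal(mu) variable yields a Normal(mu - 1) variable,
# so the two evaluations below must agree.
assert np.isclose(st.norm.logpdf(y + 1, loc=mu), st.norm.logpdf(y, loc=mu - 1))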
