Skip to content

Commit

Permalink
added cannot_sample_rv
Browse files Browse the repository at this point in the history
  • Loading branch information
Rishab87 committed Feb 24, 2025
1 parent 7621508 commit c3a1fe5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
6 changes: 0 additions & 6 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,12 +619,6 @@ def dist(cls, n, p, *args, **kwargs):
return super().dist([n, p], *args, **kwargs)

def support_point(rv, size, n, p):
observed = getattr(rv.tag, "observed", None)
if observed is None:
raise ValueError(
"Latent Multinomial variables are not supported for sampling. "
"Use a Categorical variable instead."
)
n = pt.shape_padright(n)
mean = n * p
mode = pt.round(mean)
Expand Down
9 changes: 9 additions & 0 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
from pymc.backends.zarr import ZarrChain, ZarrTrace
from pymc.blocking import DictToArrayBijection
from pymc.distributions.multivariate import Multinomial
from pymc.exceptions import SamplingError
from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain
from pymc.model import Model, modelcontext
Expand All @@ -63,6 +64,7 @@
)
from pymc.step_methods import NUTS, CompoundStep
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.cannot_sample import CannotSampleRV
from pymc.step_methods.hmc import quadpotential
from pymc.util import (
ProgressBarManager,
Expand Down Expand Up @@ -144,6 +146,13 @@ def instantiate_steppers(
if initial_point is None:
initial_point = model.initial_point()

for rv in model.free_RVs:
if isinstance(rv.owner.op, Multinomial) and getattr(rv.tag, "observed", None) is None:
for step_class in list(selected_steps.keys()):
if rv in selected_steps[step_class]:
selected_steps[step_class].remove(rv)
selected_steps.setdefault(CannotSampleRV, []).append(rv)

for step_class, vars in selected_steps.items():
if vars:
name = getattr(step_class, "name")
Expand Down
21 changes: 21 additions & 0 deletions pymc/step_methods/cannot_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pymc.step_methods.arraystep import ArrayStep

class CannotSampleRV(ArrayStep):
"""
A step method that raises an error when sampling a latent Multinomial variable.
"""
name = "cannot_sample_rv"
def __init__(self, vars, **kwargs):
# Remove keys that ArrayStep.__init__ does not accept.
kwargs.pop("model", None)
kwargs.pop("initial_point", None)
kwargs.pop("compile_kwargs", None)
self.vars = vars
super().__init__(vars=vars,fs=[], **kwargs)

def astep(self, q0):
# This method is required by the abstract base class.
raise ValueError(
"Latent Multinomial variables are not supported"
)

0 comments on commit c3a1fe5

Please sign in to comment.