diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index d2f35c8007..649ca246c0 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1155,29 +1155,58 @@ def support_point(rv, size, p): mode = pt.full(size, mode) return mode - def logp(value, p): - k = pt.shape(p)[-1] - value_clip = pt.clip(value, 0, k - 1) + @staticmethod + def _safe_index_value_p(value, p): + # Find the probabily of the given value by indexing in p, + # after handling broadcasting and invalid values. # In the standard case p has one more dimension than value dim_diff = p.type.ndim - value.type.ndim if dim_diff > 1: # p brodacasts implicitly beyond value - value_clip = pt.shape_padleft(value_clip, dim_diff - 1) + value = pt.shape_padleft(value, dim_diff - 1) elif dim_diff < 1: # value broadcasts implicitly beyond p p = pt.shape_padleft(p, 1 - dim_diff) - a = pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1)) + k = pt.shape(p)[-1] + value_clip = pt.clip(value, 0, k - 1).astype(int) + return value, pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1)) - res = pt.switch( + def logp(value, p): + k = pt.shape(p)[-1] + value, safe_value_p = Categorical._safe_index_value_p(value, p) + + value_p = pt.switch( pt.or_(pt.lt(value, 0), pt.gt(value, k - 1)), -np.inf, - a, + safe_value_p, ) return check_parameters( - res, + value_p, + 0 <= p, + p <= 1, + pt.isclose(pt.sum(p, axis=-1), 1), + msg="0 <= p <=1, sum(p) = 1", + ) + + def logcdf(value, p): + k = pt.shape(p)[-1] + value, safe_value_p = Categorical._safe_index_value_p(value, p.cumsum(-1)) + + value_p = pt.switch( + pt.lt(value, 0), + -np.inf, + pt.switch( + pt.gt(value, k - 1), + 0, + safe_value_p, + ), + ) + + return check_parameters( + value_p, 0 <= p, p <= 1, pt.isclose(pt.sum(p, axis=-1), 1), diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py index 6e8b0f9dcd..21dce537b0 100644 --- a/tests/distributions/test_censored.py +++ b/tests/distributions/test_censored.py @@ -17,6 +17,7 @@ import pymc as pm +from pymc import logp from pymc.distributions.shape_utils import change_dist_size @@ -110,3 +111,18 @@ def test_dist_broadcasted_by_lower_upper(self): pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2)) ) assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2) + + def test_censored_categorical(self): + cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2], shape=(5,)) + + np.testing.assert_allclose( + logp(cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(), + [0, 0.1, 0.2, 0.2, 0.3, 0.2, 0], + ) + + censored_cat = pm.Censored.dist(cat, lower=1, upper=3, shape=(5,)) + + np.testing.assert_allclose( + logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(), + [0, 0, 0.3, 0.2, 0.5, 0, 0], + )