From 60c763f7a98fbabff3d0d36cce880639294944c3 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 28 Jan 2025 11:00:36 +0100
Subject: [PATCH] Allow censoring Categorical distributions

---
 pymc/distributions/discrete.py       | 45 +++++++++++++++++++----
 tests/distributions/test_censored.py | 16 ++++++++
 tests/distributions/test_discrete.py | 55 ++++++++++++++++++----------
 3 files changed, 88 insertions(+), 28 deletions(-)

diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py
index d2f35c8007..649ca246c0 100644
--- a/pymc/distributions/discrete.py
+++ b/pymc/distributions/discrete.py
@@ -1155,29 +1155,58 @@ def support_point(rv, size, p):
         mode = pt.full(size, mode)
         return mode
 
-    def logp(value, p):
-        k = pt.shape(p)[-1]
-        value_clip = pt.clip(value, 0, k - 1)
+    @staticmethod
+    def _safe_index_value_p(value, p):
+        # Find the probability of the given value by indexing in p,
+        # after handling broadcasting and invalid values.
 
         # In the standard case p has one more dimension than value
         dim_diff = p.type.ndim - value.type.ndim
         if dim_diff > 1:
             # p broadcasts implicitly beyond value
-            value_clip = pt.shape_padleft(value_clip, dim_diff - 1)
+            value = pt.shape_padleft(value, dim_diff - 1)
         elif dim_diff < 1:
             # value broadcasts implicitly beyond p
             p = pt.shape_padleft(p, 1 - dim_diff)
 
-        a = pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
+        k = pt.shape(p)[-1]
+        value_clip = pt.clip(value, 0, k - 1).astype(int)
+        return value, pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
 
-        res = pt.switch(
+    def logp(value, p):
+        k = pt.shape(p)[-1]
+        value, safe_value_p = Categorical._safe_index_value_p(value, p)
+
+        value_p = pt.switch(
             pt.or_(pt.lt(value, 0), pt.gt(value, k - 1)),
             -np.inf,
-            a,
+            safe_value_p,
         )
 
         return check_parameters(
-            res,
+            value_p,
+            0 <= p,
+            p <= 1,
+            pt.isclose(pt.sum(p, axis=-1), 1),
+            msg="0 <= p <= 1, sum(p) = 1",
+        )
+
+    def logcdf(value, p):
+        k = pt.shape(p)[-1]
+        value, safe_value_p = Categorical._safe_index_value_p(value, p.cumsum(-1))
+
+        value_p = pt.switch(
+            pt.lt(value, 0),
+            -np.inf,
+            pt.switch(
+                pt.gt(value, k - 1),
+                0,
+                safe_value_p,
+            ),
+        )
+
+        return check_parameters(
+            value_p,
             0 <= p,
             p <= 1,
             pt.isclose(pt.sum(p, axis=-1), 1),
diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py
index 6e8b0f9dcd..21dce537b0 100644
--- a/tests/distributions/test_censored.py
+++ b/tests/distributions/test_censored.py
@@ -17,6 +17,7 @@
 
 import pymc as pm
 
+from pymc import logp
 from pymc.distributions.shape_utils import change_dist_size
 
 
@@ -110,3 +111,18 @@ def test_dist_broadcasted_by_lower_upper(self):
             pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2))
         )
         assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2)
+
+    def test_censored_categorical(self):
+        cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2], shape=(5,))
+
+        np.testing.assert_allclose(
+            logp(cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
+            [0, 0.1, 0.2, 0.2, 0.3, 0.2, 0],
+        )
+
+        censored_cat = pm.Censored.dist(cat, lower=1, upper=3, shape=(5,))
+
+        np.testing.assert_allclose(
+            logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
+            [0, 0, 0.3, 0.2, 0.5, 0, 0],
+        )
diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py
index 24eeb504c9..55e8c23128 100644
--- a/tests/distributions/test_discrete.py
+++ b/tests/distributions/test_discrete.py
@@ -367,43 +367,58 @@ def test_poisson(self):
 
     @pytest.mark.parametrize("n", [2, 3, 4])
     def test_categorical(self, n):
+        domain = Domain(range(n), dtype="int64", edges=(0, n))
+        paramdomains = {"p": Simplex(n)}
+
         check_logp(
             pm.Categorical,
-            Domain(range(n), dtype="int64", edges=(0, n)),
-            {"p": Simplex(n)},
+            domain,
+            paramdomains,
             lambda value, p: categorical_logpdf(value, p),
         )
 
-    def test_categorical_logp_batch_dims(self):
+        check_selfconsistency_discrete_logcdf(
+            pm.Categorical,
+            domain,
+            paramdomains,
+        )
+
+    @pytest.mark.parametrize("method", (logp, logcdf), ids=lambda x: x.__name__)
+    def test_categorical_logp_batch_dims(self, method):
         # Core case
         p = np.array([0.2, 0.3, 0.5])
         value = np.array(2.0)
-        logp_expr = logp(pm.Categorical.dist(p=p, shape=value.shape), value)
-        assert logp_expr.type.ndim == 0
-        np.testing.assert_allclose(logp_expr.eval(), np.log(0.5))
+        expr = method(pm.Categorical.dist(p=p, shape=value.shape), value)
+        assert expr.type.ndim == 0
+        expected_p = 0.5 if method is logp else 1.0
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
 
         # Explicit batched value broadcasts p
         bcast_p = p[None]  # shape (1, 3)
         batch_value = np.array([0, 1])  # shape (2,)
-        logp_expr = logp(pm.Categorical.dist(p=bcast_p, shape=batch_value.shape), batch_value)
-        assert logp_expr.type.ndim == 1
-        np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.3]))
+        expr = method(pm.Categorical.dist(p=bcast_p, shape=batch_value.shape), batch_value)
+        assert expr.type.ndim == 1
+        expected_p = [0.2, 0.3] if method is logp else [0.2, 0.5]
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
+
+        # Implicit batch value broadcasts p
+        expr = method(pm.Categorical.dist(p=p, shape=()), batch_value)
+        assert expr.type.ndim == 1
+        expected_p = [0.2, 0.3] if method is logp else [0.2, 0.5]
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
 
         # Explicit batched value and batched p
         batch_p = np.array([p[::-1], p])
-        logp_expr = logp(pm.Categorical.dist(p=batch_p, shape=batch_value.shape), batch_value)
-        assert logp_expr.type.ndim == 1
-        np.testing.assert_allclose(logp_expr.eval(), np.log([0.5, 0.3]))
-
-        # Implicit batch value broadcasts p
-        logp_expr = logp(pm.Categorical.dist(p=p, shape=()), batch_value)
-        assert logp_expr.type.ndim == 1
-        np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.3]))
+        expr = method(pm.Categorical.dist(p=batch_p, shape=batch_value.shape), batch_value)
+        assert expr.type.ndim == 1
+        expected_p = [0.5, 0.3] if method is logp else [0.5, 0.5]
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
 
         # Implicit batch p broadcasts value
-        logp_expr = logp(pm.Categorical.dist(p=batch_p, shape=None), value)
-        assert logp_expr.type.ndim == 1
-        np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.5]))
+        expr = method(pm.Categorical.dist(p=batch_p, shape=None), value)
+        assert expr.type.ndim == 1
+        expected_p = [0.2, 0.5] if method is logp else [1.0, 1.0]
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
 
     @pytensor.config.change_flags(compute_test_value="raise")
     def test_categorical_bounds(self):
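
Usage sketch (not part of the patch): with this change, pm.Censored can wrap a Categorical. The censored logp lumps the probability mass at or below `lower` onto `lower` and the mass at or above `upper` onto `upper`, mirroring test_censored_categorical above; this works because pm.Censored derives the endpoint masses from the base distribution's logcdf, which this patch adds for Categorical. The numbers below come straight from that test.

    import pymc as pm
    from pymc import logp

    # Five categories with P(X = i) = [0.1, 0.2, 0.2, 0.3, 0.2]
    cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2])

    # Censor to [1, 3]:
    #   P(X = 1) absorbs P(X <= 1) = 0.1 + 0.2 = 0.3
    #   P(X = 3) absorbs P(X >= 3) = 0.3 + 0.2 = 0.5
    censored_cat = pm.Censored.dist(cat, lower=1, upper=3)

    print(logp(censored_cat, [0, 1, 2, 3, 4]).exp().eval())
    # approximately [0., 0.3, 0.2, 0.5, 0.]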