Allow censoring Categorical distributions
ricardoV94 committed Jan 30, 2025
1 parent fa43eba commit 60c763f
Showing 3 changed files with 88 additions and 28 deletions.
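
This commit makes pm.Censored usable with Categorical by giving Categorical a logcdf. Conceptually, censoring on [lower, upper] lumps all probability mass at or below lower onto lower, and all mass at or above upper onto upper, leaving interior probabilities unchanged. A minimal sketch of that semantics (not part of this commit; censored_categorical_pmf is a hypothetical helper):

import numpy as np

def censored_categorical_pmf(p, lower, upper):
    # Lump the tails of a categorical pmf onto the censoring bounds.
    p = np.asarray(p)
    out = np.zeros_like(p)
    out[lower] = p[: lower + 1].sum()              # all mass at or below `lower`
    out[upper] = p[upper:].sum()                   # all mass at or above `upper`
    out[lower + 1 : upper] = p[lower + 1 : upper]  # interior mass unchanged
    return out

censored_categorical_pmf([0.1, 0.2, 0.2, 0.3, 0.2], lower=1, upper=3)
# -> array([0. , 0.3, 0.2, 0.5, 0. ]), matching test_censored_categorical below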
45 changes: 37 additions & 8 deletions pymc/distributions/discrete.py
@@ -1155,29 +1155,58 @@ def support_point(rv, size, p):
             mode = pt.full(size, mode)
         return mode
 
-    def logp(value, p):
-        k = pt.shape(p)[-1]
-        value_clip = pt.clip(value, 0, k - 1)
+    @staticmethod
+    def _safe_index_value_p(value, p):
+        # Find the probability of the given value by indexing in p,
+        # after handling broadcasting and invalid values.
 
         # In the standard case p has one more dimension than value
         dim_diff = p.type.ndim - value.type.ndim
         if dim_diff > 1:
             # p broadcasts implicitly beyond value
-            value_clip = pt.shape_padleft(value_clip, dim_diff - 1)
+            value = pt.shape_padleft(value, dim_diff - 1)
         elif dim_diff < 1:
             # value broadcasts implicitly beyond p
             p = pt.shape_padleft(p, 1 - dim_diff)
 
-        a = pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
+        k = pt.shape(p)[-1]
+        value_clip = pt.clip(value, 0, k - 1).astype(int)
+        return value, pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
 
-        res = pt.switch(
+    def logp(value, p):
+        k = pt.shape(p)[-1]
+        value, safe_value_p = Categorical._safe_index_value_p(value, p)
+
+        value_p = pt.switch(
             pt.or_(pt.lt(value, 0), pt.gt(value, k - 1)),
             -np.inf,
-            a,
+            safe_value_p,
         )
 
         return check_parameters(
-            res,
+            value_p,
             0 <= p,
             p <= 1,
             pt.isclose(pt.sum(p, axis=-1), 1),
             msg="0 <= p <= 1, sum(p) = 1",
         )
 
+    def logcdf(value, p):
+        k = pt.shape(p)[-1]
+        value, safe_value_p = Categorical._safe_index_value_p(value, p.cumsum(-1))
+
+        value_p = pt.switch(
+            pt.lt(value, 0),
+            -np.inf,
+            pt.switch(
+                pt.gt(value, k - 1),
+                0,
+                safe_value_p,
+            ),
+        )
+
+        return check_parameters(
+            value_p,
+            0 <= p,
+            p <= 1,
+            pt.isclose(pt.sum(p, axis=-1), 1),
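Why the logcdf is the missing piece: for integer support, P(x <= v) is just the cumulative sum of p along the last axis, which is exactly what _safe_index_value_p indexes once it is handed p.cumsum(-1). A quick standalone check (plain NumPy, not from the commit) of the two lumped masses the censored test below expects:

import numpy as np

p = np.array([0.1, 0.2, 0.2, 0.3, 0.2])
cdf = p.cumsum(-1)          # what logcdf indexes via p.cumsum(-1)

lower, upper = 1, 3
print(cdf[lower])           # P(x <= 1) = 0.3, lumped at the lower bound
print(1 - cdf[upper - 1])   # P(x >= 3) = 0.5, lumped at the upper bound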
16 changes: 16 additions & 0 deletions tests/distributions/test_censored.py
@@ -17,6 +17,7 @@

 import pymc as pm
 
+from pymc import logp
 from pymc.distributions.shape_utils import change_dist_size


@@ -110,3 +111,18 @@ def test_dist_broadcasted_by_lower_upper(self):
             pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2))
         )
         assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2)
+
+    def test_censored_categorical(self):
+        cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2], shape=(5,))
+
+        np.testing.assert_allclose(
+            logp(cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
+            [0, 0.1, 0.2, 0.2, 0.3, 0.2, 0],
+        )
+
+        censored_cat = pm.Censored.dist(cat, lower=1, upper=3, shape=(5,))
+
+        np.testing.assert_allclose(
+            logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
+            [0, 0, 0.3, 0.2, 0.5, 0, 0],
+        )
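
For context, an illustrative sketch of how the new support might be used inside a model (names and data are made up, and this snippet is not part of the commit):

import numpy as np
import pymc as pm

with pm.Model() as model:
    # Unknown category probabilities with a flat Dirichlet prior
    p = pm.Dirichlet("p", a=np.ones(5))
    cat = pm.Categorical.dist(p=p)
    # Observations were clipped to [1, 3]; Censored now resolves
    # its logp thanks to the Categorical logcdf added above
    y = pm.Censored("y", cat, lower=1, upper=3, observed=[1, 1, 2, 3, 3])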
55 changes: 35 additions & 20 deletions tests/distributions/test_discrete.py
@@ -367,43 +367,58 @@ def test_poisson(self):

     @pytest.mark.parametrize("n", [2, 3, 4])
     def test_categorical(self, n):
+        domain = Domain(range(n), dtype="int64", edges=(0, n))
+        paramdomains = {"p": Simplex(n)}
+
         check_logp(
             pm.Categorical,
-            Domain(range(n), dtype="int64", edges=(0, n)),
-            {"p": Simplex(n)},
+            domain,
+            paramdomains,
             lambda value, p: categorical_logpdf(value, p),
         )
 
-    def test_categorical_logp_batch_dims(self):
+        check_selfconsistency_discrete_logcdf(
+            pm.Categorical,
+            domain,
+            paramdomains,
+        )
+
+    @pytest.mark.parametrize("method", (logp, logcdf), ids=lambda x: x.__name__)
+    def test_categorical_logp_batch_dims(self, method):
         # Core case
         p = np.array([0.2, 0.3, 0.5])
         value = np.array(2.0)
-        logp_expr = logp(pm.Categorical.dist(p=p, shape=value.shape), value)
-        assert logp_expr.type.ndim == 0
-        np.testing.assert_allclose(logp_expr.eval(), np.log(0.5))
+        expr = method(pm.Categorical.dist(p=p, shape=value.shape), value)
+        assert expr.type.ndim == 0
+        expected_p = 0.5 if method is logp else 1.0
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
 
         # Explicit batched value broadcasts p
         bcast_p = p[None]  # shape (1, 3)
         batch_value = np.array([0, 1])  # shape (2,)
-        logp_expr = logp(pm.Categorical.dist(p=bcast_p, shape=batch_value.shape), batch_value)
-        assert logp_expr.type.ndim == 1
-        np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.3]))
+        expr = method(pm.Categorical.dist(p=bcast_p, shape=batch_value.shape), batch_value)
+        assert expr.type.ndim == 1
+        expected_p = [0.2, 0.3] if method is logp else [0.2, 0.5]
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
 
+        # Implicit batch value broadcasts p
+        expr = method(pm.Categorical.dist(p=p, shape=()), batch_value)
+        assert expr.type.ndim == 1
+        expected_p = [0.2, 0.3] if method is logp else [0.2, 0.5]
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
+
         # Explicit batched value and batched p
         batch_p = np.array([p[::-1], p])
-        logp_expr = logp(pm.Categorical.dist(p=batch_p, shape=batch_value.shape), batch_value)
-        assert logp_expr.type.ndim == 1
-        np.testing.assert_allclose(logp_expr.eval(), np.log([0.5, 0.3]))
-
-        # Implicit batch value broadcasts p
-        logp_expr = logp(pm.Categorical.dist(p=p, shape=()), batch_value)
-        assert logp_expr.type.ndim == 1
-        np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.3]))
+        expr = method(pm.Categorical.dist(p=batch_p, shape=batch_value.shape), batch_value)
+        assert expr.type.ndim == 1
+        expected_p = [0.5, 0.3] if method is logp else [0.5, 0.5]
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
 
         # Implicit batch p broadcasts value
-        logp_expr = logp(pm.Categorical.dist(p=batch_p, shape=None), value)
-        assert logp_expr.type.ndim == 1
-        np.testing.assert_allclose(logp_expr.eval(), np.log([0.2, 0.5]))
+        expr = method(pm.Categorical.dist(p=batch_p, shape=None), value)
+        assert expr.type.ndim == 1
+        expected_p = [0.2, 0.5] if method is logp else [1.0, 1.0]
+        np.testing.assert_allclose(expr.exp().eval(), expected_p)
 
     @pytensor.config.change_flags(compute_test_value="raise")
     def test_categorical_bounds(self):
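A standalone NumPy sketch (not from the commit) of the batched indexing these tests exercise: with a batched p of shape (2, 3) and a batched value of shape (2,), take_along_axis picks one probability per batch row, mirroring what _safe_index_value_p does symbolically.

import numpy as np

p = np.array([0.2, 0.3, 0.5])
batch_p = np.array([p[::-1], p])  # shape (2, 3)
batch_value = np.array([0, 1])    # shape (2,)

# Pick batch_p[i, batch_value[i]] for each row i
picked = np.take_along_axis(batch_p, batch_value[..., None], axis=-1).squeeze(-1)
print(picked)  # [0.5 0.3] -- the expected logp probabilities in that case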
