Skip to content

Commit

Permalink
Allow censoring Categorical distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 28, 2025
1 parent fa43eba commit a3cbfee
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
45 changes: 37 additions & 8 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,29 +1155,58 @@ def support_point(rv, size, p):
mode = pt.full(size, mode)
return mode

# NOTE(review): this span comes from a commit diff rendered without +/- markers
# and without indentation, so lines from before and after the change appear
# interleaved. Lines flagged below reference names that are not defined in their
# visible scope; presumably they are pre-change lines -- confirm against the
# repository before relying on them.
def logp(value, p):
k = pt.shape(p)[-1]
value_clip = pt.clip(value, 0, k - 1)
@staticmethod
def _safe_index_value_p(value, p):
# Find the probability of the given value by indexing in p,
# after handling broadcasting and invalid values.
# Returns a pair: the (possibly left-padded) value and the log of the
# entries of p selected along the last axis.

# In the standard case p has one more dimension than value
dim_diff = p.type.ndim - value.type.ndim
if dim_diff > 1:
# p broadcasts implicitly beyond value
# NOTE(review): value_clip is read on the next line before any assignment
# inside this helper; presumably a pre-change line -- confirm.
value_clip = pt.shape_padleft(value_clip, dim_diff - 1)
value = pt.shape_padleft(value, dim_diff - 1)
elif dim_diff < 1:
# value broadcasts implicitly beyond p
p = pt.shape_padleft(p, 1 - dim_diff)

# NOTE(review): `a` is assigned on the next line but the helper's return
# statement does not use it; presumably a pre-change line -- confirm.
a = pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))
k = pt.shape(p)[-1]
# Clip to [0, k - 1] and cast to int so the clipped values can be used as
# indices into the last axis of p.
value_clip = pt.clip(value, 0, k - 1).astype(int)
return value, pt.log(pt.take_along_axis(p, value_clip[..., None], axis=-1).squeeze(-1))

# NOTE(review): the call opened on the next line is never closed before the
# following `def`; presumably a pre-change line -- confirm.
res = pt.switch(
def logp(value, p):
k = pt.shape(p)[-1]
value, safe_value_p = Categorical._safe_index_value_p(value, p)

# Values below 0 or above k - 1 are mapped to -inf; all other values take the
# log-probability computed by the helper.
value_p = pt.switch(
pt.or_(pt.lt(value, 0), pt.gt(value, k - 1)),
-np.inf,
# NOTE(review): `a` is not defined in this function; presumably a pre-change
# line, with `safe_value_p` as its replacement -- confirm.
a,
safe_value_p,
)

return check_parameters(
# NOTE(review): `res` is not defined in this function; presumably a pre-change
# line, with `value_p` as its replacement -- confirm.
res,
value_p,
0 <= p,
p <= 1,
pt.isclose(pt.sum(p, axis=-1), 1),
msg="0 <= p <=1, sum(p) = 1",
)

def logcdf(value, p):
k = pt.shape(p)[-1]
value, safe_value_p = Categorical._safe_index_value_p(value, p.cumsum(-1))

value_p = pt.switch(
pt.lt(value, 0),
-np.inf,
pt.switch(
pt.gt(value, k - 1),
0,
safe_value_p,
),
)

return check_parameters(
value_p,
0 <= p,
p <= 1,
pt.isclose(pt.sum(p, axis=-1), 1),
Expand Down
16 changes: 16 additions & 0 deletions tests/distributions/test_censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pymc as pm

from pymc import logp
from pymc.distributions.shape_utils import change_dist_size


Expand Down Expand Up @@ -110,3 +111,18 @@ def test_dist_broadcasted_by_lower_upper(self):
pm.Normal.dist(size=(3, 4, 2)), lower=np.zeros((2,)), upper=np.zeros((4, 2))
)
assert tuple(x.owner.inputs[0].shape.eval()) == (3, 4, 2)

def test_censored_categorical(self):
    """Compare logp of a Categorical and of its Censored wrapper on values -1..5."""
    values = [-1, 0, 1, 2, 3, 4, 5]
    cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2], shape=(5,))

    # Uncensored: the entries of p are expected at 0..4 and zero probability
    # at the out-of-range values -1 and 5.
    cat_probs = logp(cat, values).exp().eval()
    np.testing.assert_allclose(cat_probs, [0, 0.1, 0.2, 0.2, 0.3, 0.2, 0])

    censored_cat = pm.Censored.dist(cat, lower=1, upper=3, shape=(5,))

    # Censored to [1, 3]: the expected mass at 1 is 0.1 + 0.2 and at 3 is
    # 0.3 + 0.2, with zero everywhere outside the bounds.
    censored_probs = logp(censored_cat, values).exp().eval()
    np.testing.assert_allclose(censored_probs, [0, 0, 0.3, 0.2, 0.5, 0, 0])

0 comments on commit a3cbfee

Please sign in to comment.