Skip to content

Commit

Permalink
Pass size to specialized truncated dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 30, 2022
1 parent e419d53 commit 244c37d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
6 changes: 3 additions & 3 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def update(self, node: Node):


@singledispatch
def _truncated(op: Op, lower, upper, *params):
def _truncated(op: Op, lower, upper, size, *params):
"""Return the truncated equivalent of another `RandomVariable`."""
raise NotImplementedError(f"{op} does not have an equivalent truncated version implemented")

Expand Down Expand Up @@ -150,7 +150,7 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):

# Try to use specialized Op
try:
return _truncated(dist.owner.op, lower, upper, *dist.owner.inputs)
return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs)
except NotImplementedError:
pass

Expand Down Expand Up @@ -339,7 +339,7 @@ def truncated_logprob(op, values, *inputs, **kwargs):


@_truncated.register(NormalRV)
def _truncated_normal(op, lower, upper, rng, size, dtype, mu, sigma):
def _truncated_normal(op, lower, upper, size, rng, old_size, dtype, mu, sigma):
return TruncatedNormal.dist(
mu=mu,
sigma=sigma,
Expand Down
23 changes: 21 additions & 2 deletions pymc/tests/distributions/test_truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,31 @@ def _icdf_not_implemented(*args, **kwargs):
raise NotImplementedError()


def test_truncation_specialized_op():
@pytest.mark.parametrize("shape_info", ("shape", "dims", "observed"))
def test_truncation_specialized_op(shape_info):
rng = aesara.shared(np.random.default_rng())
x = at.random.normal(0, 10, rng=rng, name="x")

xt = Truncated.dist(x, lower=5, upper=15, shape=(100,))
with Model(coords={"dim": range(100)}) as m:
if shape_info == "shape":
xt = Truncated("xt", dist=x, lower=5, upper=15, shape=(100,))
elif shape_info == "dims":
xt = Truncated("xt", dist=x, lower=5, upper=15, dims=("dim",))
elif shape_info == "observed":
xt = Truncated(
"xt",
dist=x,
lower=5,
upper=15,
observed=np.empty(
100,
),
)
else:
raise ValueError(f"Not a valid shape_info parametrization: {shape_info}")

assert isinstance(xt.owner.op, TruncatedNormalRV)
assert xt.shape.eval() == (100,)

# Test RNG is not reused
assert xt.owner.inputs[0] is not rng
Expand Down

0 comments on commit 244c37d

Please sign in to comment.