Skip to content

Commit

Permalink
Raise NotImplementedError for multivariate CustomDists
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 31, 2023
1 parent 2303bf9 commit 1a06d50
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,11 @@ def dist(
class_name: str = "CustomDist",
**kwargs,
):
if ndim_supp > 0:
raise NotImplementedError(
"CustomDist with ndim_supp > 0 and without a `dist` function are not supported."
)

dist_params = [as_tensor_variable(param) for param in dist_params]

# Assume scalar ndims_params
Expand Down
16 changes: 16 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ def test_custom_dist_without_random(self):
with pytest.raises(NotImplementedError):
pm.sample_posterior_predictive(idata, model=model)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [(), (3,), (3, 2)], ids=str)
def test_custom_dist_with_random_multivariate(self, size):
supp_shape = 5
Expand Down Expand Up @@ -264,6 +268,10 @@ def test_custom_dist_old_api_error(self):
):
CustomDist("a", lambda x: x)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [None, (), (2,)], ids=str)
def test_custom_dist_multivariate_logp(self, size):
supp_shape = 5
Expand Down Expand Up @@ -314,6 +322,10 @@ def density_moment(rv, size, mu):
assert evaled_moment.shape == to_tuple(size)
assert np.all(evaled_moment == mu_val)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize("size", [(), (2,), (3, 2)], ids=str)
def test_custom_dist_custom_moment_multivariate(self, size):
def density_moment(rv, size, mu):
Expand All @@ -328,6 +340,10 @@ def density_moment(rv, size, mu):
assert evaled_moment.shape == to_tuple(size) + (5,)
assert np.all(evaled_moment == mu_val)

@pytest.mark.xfail(
NotImplementedError,
reason="Support shape of multivariate CustomDist cannot be inferred. See https://github.com/pymc-devs/pytensor/pull/388",
)
@pytest.mark.parametrize(
"with_random, size",
[
Expand Down

0 comments on commit 1a06d50

Please sign in to comment.