Skip to content

Commit

Permalink
Bring back testing utilities used in downstream packages
Browse files Browse the repository at this point in the history
Follow up to
* 534a9ae
* e1d36ca
  • Loading branch information
ricardoV94 committed Mar 8, 2023
1 parent 49aacf4 commit a41d524
Show file tree
Hide file tree
Showing 37 changed files with 207 additions and 194 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ jobs:
tests/test_func_utils.py
tests/distributions/test_shape_utils.py
tests/distributions/test_mixture.py
tests/test_testing.py
- |
tests/distributions/test_continuous.py
Expand Down
58 changes: 29 additions & 29 deletions docs/source/contributing/implementing_distribution.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,11 @@ Most tests can be accommodated by the default `BaseTestDistributionRandom` class
1. Shape variable inference is correct, via `check_rv_size`

```python
from tests.distributions.util import BaseTestDistributionRandom, seeded_scipy_distribution_builder

class TestBlah(BaseTestDistributionRandom):
from pymc.testing import BaseTestDistributionRandom, seeded_scipy_distribution_builder


class TestBlah(BaseTestDistributionRandom):
pymc_dist = pm.Blah
# Parameters with which to test the blah pymc Distribution
pymc_dist_params = {"param1": 0.25, "param2": 2.0}
Expand Down Expand Up @@ -311,38 +312,36 @@ Tests for the `logp` and `logcdf` mostly make use of the helpers `check_logp`, `
`check_selfconsistency_discrete_logcdf` implemented in `~tests.distributions.util`

```python
from tests.helpers import select_by_precision
from tests.distributions.util import check_logp, check_logcdf, Domain

from pymc.testing import Domain, check_logp, check_logcdf, select_by_precision

R = Domain([-np.inf, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.inf])
Rplus = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 100, np.inf])



def test_blah():

check_logp(
pymc_dist=pm.Blah,
# Domain of the distribution values
domain=R,
# Domains of the distribution parameters
paramdomains={"mu": R, "sigma": Rplus},
# Reference scipy (or other) logp function
scipy_logp = lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma),
# Number of decimal points expected to match between the pymc and reference functions
decimal=select_by_precision(float64=6, float32=3),
# Maximum number of combinations of domain * paramdomains to test
n_samples=100,
)

check_logcdf(
pymc_dist=pm.Blah,
domain=R,
paramdomains={"mu": R, "sigma": Rplus},
scipy_logcdf=lambda value, mu, sigma: sp.norm.logcdf(value, mu, sigma),
decimal=select_by_precision(float64=6, float32=1),
n_samples=-1,
)
check_logp(
pymc_dist=pm.Blah,
# Domain of the distribution values
domain=R,
# Domains of the distribution parameters
paramdomains={"mu": R, "sigma": Rplus},
# Reference scipy (or other) logp function
scipy_logp=lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma),
# Number of decimal points expected to match between the pymc and reference functions
decimal=select_by_precision(float64=6, float32=3),
# Maximum number of combinations of domain * paramdomains to test
n_samples=100,
)

check_logcdf(
pymc_dist=pm.Blah,
domain=R,
paramdomains={"mu": R, "sigma": Rplus},
scipy_logcdf=lambda value, mu, sigma: sp.norm.logcdf(value, mu, sigma),
decimal=select_by_precision(float64=6, float32=1),
n_samples=-1,
)

```

Expand Down Expand Up @@ -382,7 +381,8 @@ which checks if:

import pytest
from pymc.distributions import Blah
from tests.distributions.util import assert_moment_is_expected
from pymc.testing import assert_moment_is_expected


@pytest.mark.parametrize(
"param1, param2, size, expected",
Expand Down
Loading

0 comments on commit a41d524

Please sign in to comment.