From f599909d8cac4aa67993c220ffe22cfe79774a43 Mon Sep 17 00:00:00 2001 From: ntfrgl Date: Mon, 10 Oct 2022 20:42:50 -0700 Subject: [PATCH 1/3] Finish moving `ppl.testlib`: src/ -> tests/ --- .pyre_configuration | 2 -- README.md | 2 +- src/beanmachine/ppl/testlib/__init__.py | 0 .../inference/compositional_infer_conjugate_test_nightly.py | 2 +- tests/ppl/inference/hypothesis_testing_nightly.py | 2 +- .../single_site_ancestral_mh_conjugate_test_nightly.py | 2 +- ...le_site_hamiltonian_monte_carlo_conjugate_test_nightly.py | 2 +- ...ngle_site_newtonian_monte_carlo_conjugate_test_nightly.py | 2 +- .../single_site_no_u_turn_conjugate_test_nightly.py | 2 +- ...ingle_site_random_walk_adaptive_conjugate_test_nightly.py | 2 +- .../single_site_random_walk_conjugate_test_nightly.py | 2 +- .../single_site_uniform_mh_conjugate_test_nightly.py | 2 +- {src/beanmachine => tests}/ppl/testlib/abstract_conjugate.py | 5 +++-- {src/beanmachine => tests}/ppl/testlib/hypothesis_testing.py | 0 tests/ppl/testlib/hypothesis_testing_test.py | 5 +++-- 15 files changed, 16 insertions(+), 16 deletions(-) delete mode 100644 src/beanmachine/ppl/testlib/__init__.py rename {src/beanmachine => tests}/ppl/testlib/abstract_conjugate.py (99%) rename {src/beanmachine => tests}/ppl/testlib/hypothesis_testing.py (100%) diff --git a/.pyre_configuration b/.pyre_configuration index 279623e875..2685e535d6 100644 --- a/.pyre_configuration +++ b/.pyre_configuration @@ -5,8 +5,6 @@ "src/beanmachine/ppl/compiler/runtime.py", "src/beanmachine/ppl/inference/base_inference.py", "src/beanmachine/ppl/inference/bmg_inference.py", - "src/beanmachine/ppl/testlib/abstract_conjugate.py", - "src/beanmachine/ppl/testlib/hypothesis_testing.py", "src/beanmachine/ppl/experimental/torch_jit_backend.py", "src/beanmachine/ppl/diagnostics/tools/utils/diagnostic_tool_base.py" ], diff --git a/README.md b/README.md index af8c70aded..6cacd0e2e9 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ If you would like to run the builtin unit tests: ```bash python -m pip install "beanmachine[test]" -pytest . +pytest ``` ## License diff --git a/src/beanmachine/ppl/testlib/__init__.py b/src/beanmachine/ppl/testlib/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/ppl/inference/compositional_infer_conjugate_test_nightly.py b/tests/ppl/inference/compositional_infer_conjugate_test_nightly.py index bfe828f320..942ca3ee68 100644 --- a/tests/ppl/inference/compositional_infer_conjugate_test_nightly.py +++ b/tests/ppl/inference/compositional_infer_conjugate_test_nightly.py @@ -6,7 +6,7 @@ import unittest from beanmachine.ppl.inference.compositional_infer import CompositionalInference -from beanmachine.ppl.testlib.abstract_conjugate import AbstractConjugateTests +from ..testlib.abstract_conjugate import AbstractConjugateTests class CompositionalInferenceConjugateTest(unittest.TestCase, AbstractConjugateTests): diff --git a/tests/ppl/inference/hypothesis_testing_nightly.py b/tests/ppl/inference/hypothesis_testing_nightly.py index 84e57afed5..bed590f084 100644 --- a/tests/ppl/inference/hypothesis_testing_nightly.py +++ b/tests/ppl/inference/hypothesis_testing_nightly.py @@ -7,7 +7,7 @@ from sys import float_info import torch.distributions as dist -from beanmachine.ppl.testlib.hypothesis_testing import ( +from ..testlib.hypothesis_testing import ( inverse_normal_cdf, mean_equality_hypothesis_confidence_interval, mean_equality_hypothesis_test, diff --git a/tests/ppl/inference/single_site_ancestral_mh_conjugate_test_nightly.py b/tests/ppl/inference/single_site_ancestral_mh_conjugate_test_nightly.py index 85457da847..2c2c16ec24 100644 --- a/tests/ppl/inference/single_site_ancestral_mh_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_ancestral_mh_conjugate_test_nightly.py @@ -6,7 +6,7 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.testlib.abstract_conjugate import AbstractConjugateTests +from ..testlib.abstract_conjugate import AbstractConjugateTests class SingleSiteAncestralMetropolisHastingsConjugateTest( diff --git a/tests/ppl/inference/single_site_hamiltonian_monte_carlo_conjugate_test_nightly.py b/tests/ppl/inference/single_site_hamiltonian_monte_carlo_conjugate_test_nightly.py index 01479b55e1..6593879bc3 100644 --- a/tests/ppl/inference/single_site_hamiltonian_monte_carlo_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_hamiltonian_monte_carlo_conjugate_test_nightly.py @@ -6,7 +6,7 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.testlib.abstract_conjugate import AbstractConjugateTests +from ..testlib.abstract_conjugate import AbstractConjugateTests class SingleSiteHamiltonianMonteCarloConjugateTest( diff --git a/tests/ppl/inference/single_site_newtonian_monte_carlo_conjugate_test_nightly.py b/tests/ppl/inference/single_site_newtonian_monte_carlo_conjugate_test_nightly.py index a91963f404..abd8c9ddab 100644 --- a/tests/ppl/inference/single_site_newtonian_monte_carlo_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_newtonian_monte_carlo_conjugate_test_nightly.py @@ -6,7 +6,7 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.testlib.abstract_conjugate import AbstractConjugateTests +from ..testlib.abstract_conjugate import AbstractConjugateTests class SingleSiteNewtonianMonteCarloConjugateTest( diff --git a/tests/ppl/inference/single_site_no_u_turn_conjugate_test_nightly.py b/tests/ppl/inference/single_site_no_u_turn_conjugate_test_nightly.py index 8c98cd0c6e..172e2739fd 100644 --- a/tests/ppl/inference/single_site_no_u_turn_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_no_u_turn_conjugate_test_nightly.py @@ -6,7 +6,7 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.testlib.abstract_conjugate import AbstractConjugateTests +from ..testlib.abstract_conjugate import AbstractConjugateTests class SingleSiteNoUTurnConjugateTest(unittest.TestCase, AbstractConjugateTests): diff --git a/tests/ppl/inference/single_site_random_walk_adaptive_conjugate_test_nightly.py b/tests/ppl/inference/single_site_random_walk_adaptive_conjugate_test_nightly.py index 771a30a343..c8453e746c 100644 --- a/tests/ppl/inference/single_site_random_walk_adaptive_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_random_walk_adaptive_conjugate_test_nightly.py @@ -6,7 +6,7 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.testlib.abstract_conjugate import AbstractConjugateTests +from ..testlib.abstract_conjugate import AbstractConjugateTests class SingleSiteAdaptiveRandomWalkConjugateTest( diff --git a/tests/ppl/inference/single_site_random_walk_conjugate_test_nightly.py b/tests/ppl/inference/single_site_random_walk_conjugate_test_nightly.py index 94f6402db5..7a9ad76fed 100644 --- a/tests/ppl/inference/single_site_random_walk_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_random_walk_conjugate_test_nightly.py @@ -6,7 +6,7 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.testlib.abstract_conjugate import AbstractConjugateTests +from ..testlib.abstract_conjugate import AbstractConjugateTests class SingleSiteRandomWalkConjugateTest(unittest.TestCase, AbstractConjugateTests): diff --git a/tests/ppl/inference/single_site_uniform_mh_conjugate_test_nightly.py b/tests/ppl/inference/single_site_uniform_mh_conjugate_test_nightly.py index 7881e71d1a..a00836f4fb 100644 --- a/tests/ppl/inference/single_site_uniform_mh_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_uniform_mh_conjugate_test_nightly.py @@ -6,7 +6,7 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.testlib.abstract_conjugate import AbstractConjugateTests +from ..testlib.abstract_conjugate import AbstractConjugateTests class SingleSiteUniformMetropolisHastingsConjugateTest( diff --git a/src/beanmachine/ppl/testlib/abstract_conjugate.py b/tests/ppl/testlib/abstract_conjugate.py similarity index 99% rename from src/beanmachine/ppl/testlib/abstract_conjugate.py rename to tests/ppl/testlib/abstract_conjugate.py index 94d41c467f..15c25699c8 100644 --- a/src/beanmachine/ppl/testlib/abstract_conjugate.py +++ b/tests/ppl/testlib/abstract_conjugate.py @@ -20,11 +20,12 @@ from beanmachine.ppl.inference import utils from beanmachine.ppl.inference.base_inference import BaseInference from beanmachine.ppl.model.rv_identifier import RVIdentifier -from beanmachine.ppl.testlib.hypothesis_testing import ( +from torch import Tensor, tensor + +from .hypothesis_testing import ( mean_equality_hypothesis_confidence_interval, variance_equality_hypothesis_confidence_interval, ) -from torch import Tensor, tensor class AbstractConjugateTests(metaclass=ABCMeta): diff --git a/src/beanmachine/ppl/testlib/hypothesis_testing.py b/tests/ppl/testlib/hypothesis_testing.py similarity index 100% rename from src/beanmachine/ppl/testlib/hypothesis_testing.py rename to tests/ppl/testlib/hypothesis_testing.py diff --git a/tests/ppl/testlib/hypothesis_testing_test.py b/tests/ppl/testlib/hypothesis_testing_test.py index 00ceb25753..f0e65f5df9 100644 --- a/tests/ppl/testlib/hypothesis_testing_test.py +++ b/tests/ppl/testlib/hypothesis_testing_test.py @@ -6,7 +6,9 @@ """Tests for hypothesis_testing.py""" import unittest -from beanmachine.ppl.testlib.hypothesis_testing import ( +from torch import tensor + +from .hypothesis_testing import ( inverse_chi2_cdf, inverse_normal_cdf, mean_equality_hypothesis_confidence_interval, @@ -14,7 +16,6 @@ variance_equality_hypothesis_confidence_interval, variance_equality_hypothesis_test, ) -from torch import tensor class HypothesisTestingTest(unittest.TestCase): From a746ca7c6eff42988476fc0755a8580f23bda329 Mon Sep 17 00:00:00 2001 From: ntfrgl Date: Sun, 16 Oct 2022 17:41:27 -0700 Subject: [PATCH 2/3] Gather `ppl.examples.conjugate_models` in 1 module --- .../ppl/examples/conjugate_models.py | 146 ++++++++++++++++++ .../ppl/examples/conjugate_models/__init__.py | 21 --- .../conjugate_models/beta_bernoulli.py | 22 --- .../conjugate_models/beta_binomial.py | 34 ---- .../conjugate_models/categorical_dirichlet.py | 21 --- .../examples/conjugate_models/gamma_gamma.py | 23 --- .../examples/conjugate_models/gamma_normal.py | 25 --- .../conjugate_models/normal_normal.py | 23 --- .../fix_beta_bernoulli_alpha_rv_test.py | 10 +- .../compiler/fix_beta_binomial_basic_test.py | 2 +- .../compiler/fix_normal_normal_basic_test.py | 18 +-- tests/ppl/testlib/abstract_conjugate.py | 40 ++--- 12 files changed, 177 insertions(+), 208 deletions(-) create mode 100644 src/beanmachine/ppl/examples/conjugate_models.py delete mode 100644 src/beanmachine/ppl/examples/conjugate_models/__init__.py delete mode 100644 src/beanmachine/ppl/examples/conjugate_models/beta_bernoulli.py delete mode 100644 src/beanmachine/ppl/examples/conjugate_models/beta_binomial.py delete mode 100644 src/beanmachine/ppl/examples/conjugate_models/categorical_dirichlet.py delete mode 100644 src/beanmachine/ppl/examples/conjugate_models/gamma_gamma.py delete mode 100644 src/beanmachine/ppl/examples/conjugate_models/gamma_normal.py delete mode 100644 src/beanmachine/ppl/examples/conjugate_models/normal_normal.py diff --git a/src/beanmachine/ppl/examples/conjugate_models.py b/src/beanmachine/ppl/examples/conjugate_models.py new file mode 100644 index 0000000000..2f7aecf0bc --- /dev/null +++ b/src/beanmachine/ppl/examples/conjugate_models.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod + +import beanmachine.ppl as bm +import torch +import torch.distributions as dist +from torch import Tensor + + +class ConjugateModel(ABC): + """ + The Bean Machine models in this module are examples of conjugacy. Conjugacy + means the posterior will also be in the same family as the prior. The random + variable names theta and x follow the typical presentation of the conjugate + prior relation in the form of p(theta|x) = p(x|theta) * p(theta)/p(x). + + See: + https://en.wikipedia.org/wiki/Conjugate_prior + """ + + x_dim = 0 + """ + Number of indices in likelihood. + """ + + @abstractmethod + def theta(self) -> dist.Distribution: + """ + Prior of a conjugate model. + """ + pass + + @abstractmethod + def x(self, *args) -> dist.Distribution: + """ + Likelihood of a conjugate model. + """ + pass + + +class BetaBernoulliModel(ConjugateModel): + x_dim = 1 + + def __init__(self, alpha: Tensor, beta: Tensor) -> None: + self.alpha_ = alpha + self.beta_ = beta + + @bm.random_variable + def theta(self) -> dist.Distribution: + return dist.Beta(self.alpha_, self.beta_) + + @bm.random_variable + def x(self, i: int) -> dist.Distribution: + return dist.Bernoulli(self.theta()) + + +class BetaBinomialModel(ConjugateModel): + def __init__(self, alpha: Tensor, beta: Tensor, n: Tensor) -> None: + self.alpha_ = alpha + self.beta_ = beta + self.n_ = n + + @bm.random_variable + def theta(self) -> dist.Distribution: + return dist.Beta(self.alpha_, self.beta_) + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Binomial(self.n_, self.theta()) + + +class CategoricalDirichletModel(ConjugateModel): + def __init__(self, alpha: Tensor) -> None: + self.alpha_ = alpha + + @bm.random_variable + def theta(self) -> dist.Distribution: + return dist.Dirichlet(self.alpha_) + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Categorical(self.theta()) + + +class GammaGammaModel(ConjugateModel): + def __init__(self, shape: Tensor, rate: Tensor, alpha: Tensor) -> None: + self.shape_ = shape + self.rate_ = rate + self.alpha_ = alpha + + @bm.random_variable + def theta(self) -> dist.Distribution: + return dist.Gamma(self.shape_, self.rate_) + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Gamma(self.alpha_, self.theta()) + + +class GammaNormalModel(ConjugateModel): + def __init__(self, shape: Tensor, rate: Tensor, mu: Tensor) -> None: + self.shape_ = shape + self.rate_ = rate + self.mu_ = mu + + @bm.random_variable + def theta(self) -> dist.Distribution: + return dist.Gamma(self.shape_, self.rate_) + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Normal(self.mu_, torch.tensor(1) / torch.sqrt(self.theta())) + + +class NormalNormalModel(ConjugateModel): + def __init__(self, mu: Tensor, sigma: Tensor, std: Tensor) -> None: + self.mu = mu + self.sigma = sigma + self.std = std + + @bm.random_variable + def theta(self) -> dist.Distribution: + return dist.Normal(self.mu, self.sigma) + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Normal(self.theta(), self.std) + + +class MV_NormalNormalModel(ConjugateModel): + def __init__(self, mu, sigma, std) -> None: + self.mu = mu + self.sigma = sigma + self.std = std + + @bm.random_variable + def theta(self): + return dist.MultivariateNormal(self.mu, self.sigma) + + @bm.random_variable + def x(self): + return dist.MultivariateNormal(self.theta(), self.std) diff --git a/src/beanmachine/ppl/examples/conjugate_models/__init__.py b/src/beanmachine/ppl/examples/conjugate_models/__init__.py deleted file mode 100644 index 84a4c22ec2..0000000000 --- a/src/beanmachine/ppl/examples/conjugate_models/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from beanmachine.ppl.examples.conjugate_models.beta_binomial import BetaBinomialModel -from beanmachine.ppl.examples.conjugate_models.categorical_dirichlet import ( - CategoricalDirichletModel, -) -from beanmachine.ppl.examples.conjugate_models.gamma_gamma import GammaGammaModel -from beanmachine.ppl.examples.conjugate_models.gamma_normal import GammaNormalModel -from beanmachine.ppl.examples.conjugate_models.normal_normal import NormalNormalModel - - -__all__ = [ - "BetaBinomialModel", - "CategoricalDirichletModel", - "GammaGammaModel", - "GammaNormalModel", - "NormalNormalModel", -] diff --git a/src/beanmachine/ppl/examples/conjugate_models/beta_bernoulli.py b/src/beanmachine/ppl/examples/conjugate_models/beta_bernoulli.py deleted file mode 100644 index ce2893463e..0000000000 --- a/src/beanmachine/ppl/examples/conjugate_models/beta_bernoulli.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import beanmachine.ppl as bm -import torch.distributions as dist -from torch import Tensor - - -class BetaBernoulliModel: - def __init__(self, alpha: Tensor, beta: Tensor) -> None: - self.alpha_ = alpha - self.beta_ = beta - - @bm.random_variable - def theta(self) -> dist.Distribution: - return dist.Beta(self.alpha_, self.beta_) - - @bm.random_variable - def y(self, i: int) -> dist.Distribution: - return dist.Bernoulli(self.theta()) diff --git a/src/beanmachine/ppl/examples/conjugate_models/beta_binomial.py b/src/beanmachine/ppl/examples/conjugate_models/beta_binomial.py deleted file mode 100644 index e4abece3d2..0000000000 --- a/src/beanmachine/ppl/examples/conjugate_models/beta_binomial.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import beanmachine.ppl as bm -import torch.distributions as dist -from torch import Tensor - - -class BetaBinomialModel: - """This Bean Machine model is an example of conjugacy, where - the prior and the likelihood are the Beta and the Binomial - distributions respectively. Conjugacy means the posterior - will also be in the same family as the prior, Beta. - The random variable names theta and x follow the - typical presentation of the conjugate prior relation in the - form of p(theta|x) = p(x|theta) * p(theta)/p(x). - Note: Variable names here follow those used on: - https://en.wikipedia.org/wiki/Conjugate_prior - """ - - def __init__(self, alpha: Tensor, beta: Tensor, n: Tensor) -> None: - self.alpha_ = alpha - self.beta_ = beta - self.n_ = n - - @bm.random_variable - def theta(self) -> dist.Distribution: - return dist.Beta(self.alpha_, self.beta_) - - @bm.random_variable - def x(self) -> dist.Distribution: - return dist.Binomial(self.n_, self.theta()) diff --git a/src/beanmachine/ppl/examples/conjugate_models/categorical_dirichlet.py b/src/beanmachine/ppl/examples/conjugate_models/categorical_dirichlet.py deleted file mode 100644 index 0c838c4c1d..0000000000 --- a/src/beanmachine/ppl/examples/conjugate_models/categorical_dirichlet.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import beanmachine.ppl as bm -import torch.distributions as dist -from torch import Tensor - - -class CategoricalDirichletModel: - def __init__(self, alpha: Tensor) -> None: - self.alpha_ = alpha - - @bm.random_variable - def dirichlet(self) -> dist.Distribution: - return dist.Dirichlet(self.alpha_) - - @bm.random_variable - def categorical(self) -> dist.Distribution: - return dist.Categorical(self.dirichlet()) diff --git a/src/beanmachine/ppl/examples/conjugate_models/gamma_gamma.py b/src/beanmachine/ppl/examples/conjugate_models/gamma_gamma.py deleted file mode 100644 index 54005a8b9d..0000000000 --- a/src/beanmachine/ppl/examples/conjugate_models/gamma_gamma.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import beanmachine.ppl as bm -import torch.distributions as dist -from torch import Tensor - - -class GammaGammaModel: - def __init__(self, shape: Tensor, rate: Tensor, alpha: Tensor) -> None: - self.shape_ = shape - self.rate_ = rate - self.alpha_ = alpha - - @bm.random_variable - def gamma_p(self) -> dist.Distribution: - return dist.Gamma(self.shape_, self.rate_) - - @bm.random_variable - def gamma(self) -> dist.Distribution: - return dist.Gamma(self.alpha_, self.gamma_p()) diff --git a/src/beanmachine/ppl/examples/conjugate_models/gamma_normal.py b/src/beanmachine/ppl/examples/conjugate_models/gamma_normal.py deleted file mode 100644 index 7e4216aa07..0000000000 --- a/src/beanmachine/ppl/examples/conjugate_models/gamma_normal.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import beanmachine.ppl as bm -import torch -import torch.distributions as dist -from torch import Tensor - - -class GammaNormalModel: - def __init__(self, shape: Tensor, rate: Tensor, mu: Tensor) -> None: - self.shape_ = shape - self.rate_ = rate - self.mu_ = mu - - @bm.random_variable - def gamma(self) -> dist.Distribution: - return dist.Gamma(self.shape_, self.rate_) - - @bm.random_variable - def normal(self) -> dist.Distribution: - # pyre-fixme[58]: `/` is not supported for operand types `int` and `Tensor`. - return dist.Normal(self.mu_, 1 / torch.sqrt(self.gamma())) diff --git a/src/beanmachine/ppl/examples/conjugate_models/normal_normal.py b/src/beanmachine/ppl/examples/conjugate_models/normal_normal.py deleted file mode 100644 index 317e4297c4..0000000000 --- a/src/beanmachine/ppl/examples/conjugate_models/normal_normal.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import beanmachine.ppl as bm -import torch.distributions as dist -from torch import Tensor - - -class NormalNormalModel: - def __init__(self, mu: Tensor, std: Tensor, sigma: Tensor) -> None: - self.mu_ = mu - self.std_ = std - self.sigma_ = sigma - - @bm.random_variable - def normal_p(self) -> dist.Distribution: - return dist.Normal(self.mu_, self.std_) - - @bm.random_variable - def normal(self) -> dist.Distribution: - return dist.Normal(self.normal_p(), self.sigma_) diff --git a/tests/ppl/compiler/fix_beta_bernoulli_alpha_rv_test.py b/tests/ppl/compiler/fix_beta_bernoulli_alpha_rv_test.py index bef88fbbc7..70a6b456df 100644 --- a/tests/ppl/compiler/fix_beta_bernoulli_alpha_rv_test.py +++ b/tests/ppl/compiler/fix_beta_bernoulli_alpha_rv_test.py @@ -9,7 +9,7 @@ import unittest import beanmachine.ppl as bm -from beanmachine.ppl.examples.conjugate_models.beta_bernoulli import BetaBernoulliModel +from beanmachine.ppl.examples.conjugate_models import BetaBernoulliModel from beanmachine.ppl.inference.bmg_inference import BMGInference from torch import tensor from torch.distributions import Beta @@ -39,10 +39,10 @@ def test_conjugate_graph(self) -> None: model = BetaBernoulliAlphaRVModel() queries = [model.theta()] observations = { - model.y(0): tensor(0.0), - model.y(1): tensor(0.0), - model.y(2): tensor(1.0), - model.y(3): tensor(0.0), + model.x(0): tensor(0.0), + model.x(1): tensor(0.0), + model.x(2): tensor(1.0), + model.x(3): tensor(0.0), } num_samples = 1000 bmg = BMGInference() diff --git a/tests/ppl/compiler/fix_beta_binomial_basic_test.py b/tests/ppl/compiler/fix_beta_binomial_basic_test.py index fd606edd1c..1636445c6c 100644 --- a/tests/ppl/compiler/fix_beta_binomial_basic_test.py +++ b/tests/ppl/compiler/fix_beta_binomial_basic_test.py @@ -12,7 +12,7 @@ import beanmachine.ppl as bm import scipy import torch -from beanmachine.ppl.examples.conjugate_models.beta_binomial import BetaBinomialModel +from beanmachine.ppl.examples.conjugate_models import BetaBinomialModel from beanmachine.ppl.inference.bmg_inference import BMGInference from torch import tensor from torch.distributions import Beta diff --git a/tests/ppl/compiler/fix_normal_normal_basic_test.py b/tests/ppl/compiler/fix_normal_normal_basic_test.py index d636f3be36..220f2f8f0e 100644 --- a/tests/ppl/compiler/fix_normal_normal_basic_test.py +++ b/tests/ppl/compiler/fix_normal_normal_basic_test.py @@ -10,7 +10,7 @@ import scipy import torch -from beanmachine.ppl.examples.conjugate_models.normal_normal import NormalNormalModel +from beanmachine.ppl.examples.conjugate_models import NormalNormalModel from beanmachine.ppl.inference.bmg_inference import BMGInference from torch import tensor from torch.distributions import Normal @@ -21,8 +21,8 @@ def test_conjugate_graph(self) -> None: bmg = BMGInference() model = NormalNormalModel(10.0, 2.0, 5.0) - queries = [model.normal_p()] - observations = {model.normal(): tensor(15.9)} + queries = [model.theta()] + observations = {model.x(): tensor(15.9)} observed_bmg = bmg.to_dot(queries, observations, skip_optimizations=set()) expected_bmg = """ digraph "graph" { @@ -49,26 +49,24 @@ def test_normal_normal_conjugate(self) -> None: torch.manual_seed(seed) random.seed(seed) true_mu = 0.5 - true_y = Normal(true_mu, 10.0) + true_x = Normal(true_mu, 10.0) num_samples = 1000 bmg = BMGInference() model = NormalNormalModel(10.0, 2.0, 5.0) - queries = [model.normal_p()] - observations = { - model.normal(): true_y.sample(), - } + queries = [model.theta()] + observations = {model.x(): true_x.sample()} skip_optimizations = {"normal_normal_conjugate_fixer"} original_posterior = bmg.infer( queries, observations, num_samples, 1, skip_optimizations=skip_optimizations ) - original_samples = original_posterior[model.normal_p()][0] + original_samples = original_posterior[model.theta()][0] transformed_posterior = bmg.infer( queries, observations, num_samples, 1, skip_optimizations=set() ) - transformed_samples = transformed_posterior[model.normal_p()][0] + transformed_samples = transformed_posterior[model.theta()][0] self.assertEqual( type(original_samples), diff --git a/tests/ppl/testlib/abstract_conjugate.py b/tests/ppl/testlib/abstract_conjugate.py index 15c25699c8..29fddaf518 100644 --- a/tests/ppl/testlib/abstract_conjugate.py +++ b/tests/ppl/testlib/abstract_conjugate.py @@ -10,13 +10,13 @@ import scipy.stats import torch from beanmachine.ppl.diagnostics.common_statistics import effective_sample_size -from beanmachine.ppl.examples.conjugate_models.beta_binomial import BetaBinomialModel -from beanmachine.ppl.examples.conjugate_models.categorical_dirichlet import ( +from beanmachine.ppl.examples.conjugate_models import ( + BetaBinomialModel, CategoricalDirichletModel, + GammaGammaModel, + GammaNormalModel, + NormalNormalModel, ) -from beanmachine.ppl.examples.conjugate_models.gamma_gamma import GammaGammaModel -from beanmachine.ppl.examples.conjugate_models.gamma_normal import GammaNormalModel -from beanmachine.ppl.examples.conjugate_models.normal_normal import NormalNormalModel from beanmachine.ppl.inference import utils from beanmachine.ppl.inference.base_inference import BaseInference from beanmachine.ppl.model.rv_identifier import RVIdentifier @@ -72,7 +72,6 @@ def compute_beta_binomial_moments( (alpha_prime * beta_prime) / ((alpha_prime + beta_prime).pow(2.0) * (alpha_prime + beta_prime + 1.0)) ).pow(0.5) - return (mean_prime, std_prime, queries, observations) def compute_gamma_gamma_moments( @@ -89,8 +88,8 @@ def compute_gamma_gamma_moments( alpha = tensor([1.5, 1.5]) obs = tensor([2.0, 4.0]) model = GammaGammaModel(shape, rate, alpha) - queries = [model.gamma_p()] - observations = {model.gamma(): obs} + queries = [model.theta()] + observations = {model.x(): obs} shape = shape + alpha rate = rate + obs expected_mean = shape / rate @@ -111,8 +110,8 @@ def compute_gamma_normal_moments( mu = tensor([1.0, 2.0]) obs = tensor([1.5, 2.5]) model = GammaNormalModel(shape, rate, mu) - queries = [model.gamma()] - observations = {model.normal(): obs} + queries = [model.theta()] + observations = {model.x(): obs} shape = shape + tensor([0.5, 0.5]) deviations = (obs - mu).pow(2.0) rate = rate + (deviations * (0.5)) @@ -130,19 +129,14 @@ def compute_normal_normal_moments( queries and observations """ mu = tensor([1.0, 1.0]) - std = tensor([1.0, 1.0]) sigma = tensor([1.0, 1.0]) + std = tensor([1.0, 1.0]) obs = tensor([1.5, 2.5]) - model = NormalNormalModel(mu, std, sigma) - queries = [model.normal_p()] - observations = {model.normal(): obs} - expected_mean = (mu / std.pow(2.0) + obs / sigma.pow(2.0)) / ( - # pyre-fixme[58]: `/` is not supported for operand types `float` and - # `Tensor`. - 1.0 / sigma.pow(2.0) - # pyre-fixme[58]: `/` is not supported for operand types `float` and - # `Tensor`. - + 1.0 / std.pow(2.0) + model = NormalNormalModel(mu, sigma, std) + queries = [model.theta()] + observations = {model.x(): obs} + expected_mean = (mu / sigma.pow(2.0) + obs / std.pow(2.0)) / ( + tensor(1.0) / std.pow(2.0) + tensor(1.0) / sigma.pow(2.0) ) expected_std = (std.pow(-2.0) + sigma.pow(-2.0)).pow(-0.5) return (expected_mean, expected_std, queries, observations) @@ -158,8 +152,8 @@ def compute_dirichlet_categorical_moments(self): alpha = tensor([0.5, 0.5]) model = CategoricalDirichletModel(alpha) obs = tensor([1.0]) - queries = [model.dirichlet()] - observations = {model.categorical(): obs} + queries = [model.theta()] + observations = {model.x(): obs} alpha = alpha + tensor([0.0, 1.0]) expected_mean = alpha / alpha.sum() expected_std = (expected_mean * (1 - expected_mean) / (alpha.sum() + 1)).pow( From ef2e63dc50efaecd5b5747b51717d4dde9f14973 Mon Sep 17 00:00:00 2001 From: ntfrgl Date: Mon, 10 Oct 2022 19:04:21 -0700 Subject: [PATCH 3/3] Pytest: Split models from inference tests (1/2) - port: unittest -> pytest - factor out as fixtures: models, samplers, proposers - standardise RV names in models with similar structures - reduce intra-module code redundancy - apply linting --- .../ppl/examples/hierarchical_models.py | 75 +++ .../ppl/examples/primitive_models.py | 54 ++ tests/__init__.py | 0 tests/models/__init__.py | 0 tests/ppl/conftest.py | 6 +- ...positional_infer_conjugate_test_nightly.py | 1 + .../inference/hypothesis_testing_nightly.py | 5 +- .../ppl/inference/monte_carlo_samples_test.py | 390 ++++++------- tests/ppl/inference/nnc_test.py | 31 +- tests/ppl/inference/predictive_test.py | 297 ++++------ .../inference/proposer/hmc_proposer_test.py | 33 +- .../ppl/inference/proposer/hmc_utils_test.py | 38 +- ...ace_newtonian_monte_carlo_proposer_test.py | 77 +-- ...ace_newtonian_monte_carlo_proposer_test.py | 540 ++++++++---------- ...lex_newtonian_monte_carlo_proposer_test.py | 84 ++- tests/ppl/inference/sampler_test.py | 58 +- ...ite_ancestral_mh_conjugate_test_nightly.py | 1 + ...nian_monte_carlo_conjugate_test_nightly.py | 1 + ...nian_monte_carlo_conjugate_test_nightly.py | 1 + ...e_site_no_u_turn_conjugate_test_nightly.py | 1 + ...om_walk_adaptive_conjugate_test_nightly.py | 1 + ...site_random_walk_conjugate_test_nightly.py | 1 + .../inference/single_site_random_walk_test.py | 20 +- ..._site_uniform_mh_conjugate_test_nightly.py | 1 + .../inference/single_site_uniform_mh_test.py | 74 +-- tests/ppl/inference/utils_test.py | 36 +- tests/ppl/utils/fixtures.py | 145 +++++ 27 files changed, 1006 insertions(+), 965 deletions(-) create mode 100644 src/beanmachine/ppl/examples/hierarchical_models.py create mode 100644 src/beanmachine/ppl/examples/primitive_models.py create mode 100644 tests/__init__.py create mode 100644 tests/models/__init__.py create mode 100644 tests/ppl/utils/fixtures.py diff --git a/src/beanmachine/ppl/examples/hierarchical_models.py b/src/beanmachine/ppl/examples/hierarchical_models.py new file mode 100644 index 0000000000..8fb6f397c9 --- /dev/null +++ b/src/beanmachine/ppl/examples/hierarchical_models.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import beanmachine.ppl as bm +import torch +import torch.distributions as dist +from torch import Tensor, tensor + + +class UniformNormalModel: + def __init__(self, lo: Tensor, hi: Tensor, std: Tensor) -> None: + self.lo = lo + self.hi = hi + self.std = std + + @bm.random_variable + def mean(self): + return dist.Uniform(self.lo, self.hi) + + @bm.random_variable + def obs(self): + return dist.Normal(self.mean(), self.std) + + +class UniformBernoulliModel: + def __init__(self, lo: Tensor, hi: Tensor) -> None: + self.lo = lo + self.hi = hi + + @bm.random_variable + def prior(self): + return dist.Uniform(self.lo, self.hi) + + @bm.random_variable + def likelihood(self): + return dist.Bernoulli(self.prior()) + + @bm.random_variable + def likelihood_i(self, i): + return dist.Bernoulli(self.prior()) + + @bm.random_variable + def likelihood_dynamic(self, i): + assert self.lo.ndim == self.hi.ndim == 0 + if self.likelihood_i(i).item() > 0: + return dist.Normal(torch.zeros(1), torch.ones(1)) + else: + return dist.Normal(5.0 * torch.ones(1), torch.ones(1)) + + @bm.random_variable + def likelihood_reg(self, x): + assert self.lo.ndim == self.hi.ndim == 0 + return dist.Normal(self.prior() * x, torch.tensor(1.0)) + + +class LogisticRegressionModel: + @bm.random_variable + def theta_0(self): + return dist.Normal(tensor(0.0), tensor(1.0)) + + @bm.random_variable + def theta_1(self): + return dist.Normal(tensor(0.0), tensor(1.0)) + + @bm.random_variable + def x(self, i): + return dist.Normal(tensor(0.0), tensor(1.0)) + + @bm.random_variable + def y(self, i): + y = self.theta_1() * self.x(i) + self.theta_0() + probs = 1 / (1 + (y * -1).exp()) + return dist.Bernoulli(probs) diff --git a/src/beanmachine/ppl/examples/primitive_models.py b/src/beanmachine/ppl/examples/primitive_models.py new file mode 100644 index 0000000000..bd89e12b86 --- /dev/null +++ b/src/beanmachine/ppl/examples/primitive_models.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod + +import beanmachine.ppl as bm +import torch.distributions as dist +from torch import Tensor + + +class PrimitiveModel(ABC): + @abstractmethod + def x(self) -> dist.Distribution: + pass + + +class NormalModel(PrimitiveModel): + def __init__(self, mu: Tensor, sigma: Tensor) -> None: + self.mu = mu + self.sigma = sigma + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Normal(self.mu, self.sigma) + + +class GammaModel(PrimitiveModel): + def __init__(self, alpha: Tensor, beta: Tensor) -> None: + self.alpha = alpha + self.beta = beta + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Gamma(self.alpha, self.beta) + + +class PoissonModel(PrimitiveModel): + def __init__(self, rate: Tensor) -> None: + self.rate = rate + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Poisson(self.rate) + + +class DirichletModel(PrimitiveModel): + def __init__(self, alpha: Tensor) -> None: + self.alpha = alpha + + @bm.random_variable + def x(self) -> dist.Distribution: + return dist.Dirichlet(self.alpha) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/ppl/conftest.py b/tests/ppl/conftest.py index 6421f21cc6..bc50022514 100644 --- a/tests/ppl/conftest.py +++ b/tests/ppl/conftest.py @@ -9,12 +9,12 @@ @pytest.fixture(autouse=True) -def fix_random_seed(): +def fix_random_seed() -> None: """Fix the random state for every test in the test suite.""" bm.seed(0) @pytest.fixture(autouse=True) -def disable_torch_distribution_validation(): - """Disables validation of Torch distribution arguments.""" +def disable_torch_distribution_validation() -> None: + """Disable validation of Torch distribution arguments.""" dist.Distribution.set_default_validate_args(False) diff --git a/tests/ppl/inference/compositional_infer_conjugate_test_nightly.py b/tests/ppl/inference/compositional_infer_conjugate_test_nightly.py index 942ca3ee68..ba351aa3e7 100644 --- a/tests/ppl/inference/compositional_infer_conjugate_test_nightly.py +++ b/tests/ppl/inference/compositional_infer_conjugate_test_nightly.py @@ -6,6 +6,7 @@ import unittest from beanmachine.ppl.inference.compositional_infer import CompositionalInference + from ..testlib.abstract_conjugate import AbstractConjugateTests diff --git a/tests/ppl/inference/hypothesis_testing_nightly.py b/tests/ppl/inference/hypothesis_testing_nightly.py index bed590f084..39107ce533 100644 --- a/tests/ppl/inference/hypothesis_testing_nightly.py +++ b/tests/ppl/inference/hypothesis_testing_nightly.py @@ -7,13 +7,14 @@ from sys import float_info import torch.distributions as dist +from numpy import sqrt +from torch import manual_seed, mean, tensor + from ..testlib.hypothesis_testing import ( inverse_normal_cdf, mean_equality_hypothesis_confidence_interval, mean_equality_hypothesis_test, ) -from numpy import sqrt -from torch import manual_seed, mean, tensor class HypothesisTestingTest(unittest.TestCase): diff --git a/tests/ppl/inference/monte_carlo_samples_test.py b/tests/ppl/inference/monte_carlo_samples_test.py index a7be2499d1..9631e99ff4 100644 --- a/tests/ppl/inference/monte_carlo_samples_test.py +++ b/tests/ppl/inference/monte_carlo_samples_test.py @@ -4,252 +4,222 @@ # LICENSE file in the root directory of this source tree. import pickle -import unittest import beanmachine.ppl as bm import numpy as np +import pytest import torch -import torch.distributions as dist import xarray as xr +from beanmachine.ppl.examples.conjugate_models import NormalNormalModel from beanmachine.ppl.inference.monte_carlo_samples import merge_dicts, MonteCarloSamples - - -class MonteCarloSamplesTest(unittest.TestCase): - class SampleModel(object): - @bm.random_variable - def foo(self): - return dist.Normal(torch.tensor(0.0), torch.tensor(1.0)) - - @bm.random_variable - def bar(self): - return dist.Normal(self.foo(), torch.tensor(1.0)) - - def test_default_four_chains(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - foo_key = model.foo() - mcs = mh.infer([foo_key], {}, 10) - - self.assertEqual(mcs[foo_key].shape, torch.zeros(4, 10).shape) - self.assertEqual(mcs.get_variable(foo_key).shape, torch.zeros(4, 10).shape) - self.assertEqual(mcs.get_chain(3)[foo_key].shape, torch.zeros(10).shape) - self.assertEqual(mcs.num_chains, 4) - self.assertCountEqual(mcs.keys(), [foo_key]) - - mcs = mh.infer([foo_key], {}, 7, num_adaptive_samples=3) - - self.assertEqual(mcs.num_adaptive_samples, 3) - self.assertEqual(mcs[foo_key].shape, torch.zeros(4, 7).shape) - self.assertEqual(mcs.get_variable(foo_key).shape, torch.zeros(4, 7).shape) - self.assertEqual( - mcs.get_variable(foo_key, True).shape, torch.zeros(4, 10).shape - ) - self.assertEqual(mcs.get_chain(3)[foo_key].shape, torch.zeros(7).shape) - self.assertEqual(mcs.num_chains, 4) - self.assertCountEqual(mcs.keys(), [foo_key]) - - def test_one_chain(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - foo_key = model.foo() - bar_key = model.bar() - mcs = mh.infer([foo_key, bar_key], {}, 10, 1) - - self.assertEqual(mcs[foo_key].shape, torch.zeros(1, 10).shape) - self.assertEqual(mcs.get_variable(foo_key).shape, torch.zeros(1, 10).shape) - self.assertEqual(mcs.get_chain()[foo_key].shape, torch.zeros(10).shape) - self.assertEqual(mcs.num_chains, 1) - self.assertCountEqual(mcs.keys(), [foo_key, bar_key]) - - mcs = mh.infer([foo_key, bar_key], {}, 7, 1, num_adaptive_samples=3) - - self.assertEqual(mcs.num_adaptive_samples, 3) - self.assertEqual(mcs[foo_key].shape, torch.zeros(1, 7).shape) - self.assertEqual(mcs.get_variable(foo_key).shape, torch.zeros(1, 7).shape) - self.assertEqual( - mcs.get_variable(foo_key, True).shape, torch.zeros(1, 10).shape - ) - self.assertEqual(mcs.get_chain()[foo_key].shape, torch.zeros(7).shape) - self.assertEqual(mcs.num_chains, 1) - self.assertCountEqual(mcs.keys(), [foo_key, bar_key]) - - def test_chain_exceptions(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - foo_key = model.foo() - mcs = mh.infer([foo_key], {}, 10) - - with self.assertRaisesRegex(IndexError, r"Please specify a valid chain"): +from torch import tensor + +from ..utils.fixtures import parametrize_inference, parametrize_model + + +pytestmark = parametrize_model( + [NormalNormalModel(tensor(0.0), tensor(1.0), tensor(1.0))] +) + + +def test_merge_dicts(model): + chain_lists = [{model.theta(): torch.rand(3)}, {model.theta(): torch.rand(3)}] + rv_dict = merge_dicts(chain_lists) + assert model.theta() in rv_dict + assert rv_dict.get(model.theta()).shape == (2, 3) + chain_lists.append({model.x(): torch.rand(3)}) + with pytest.raises(ValueError): + merge_dicts(chain_lists) + + +def test_type_conversion(model): + samples = MonteCarloSamples( + [{model.theta(): torch.rand(5), model.x(): torch.rand(5)}], + num_adaptive_samples=3, + ) + + xr_dataset = samples.to_xarray() + assert isinstance(xr_dataset, xr.Dataset) + assert model.theta() in xr_dataset + assert np.allclose(samples[model.x()].numpy(), xr_dataset[model.x()]) + xr_dataset = samples.to_xarray(include_adapt_steps=True) + assert xr_dataset[model.theta()].shape == (1, 5) + + inference_data = samples.to_inference_data() + assert model.theta() in inference_data.posterior + + +def test_get_variable(model): + samples = MonteCarloSamples( + [{model.x(): torch.arange(10)}], num_adaptive_samples=3 + ).get_chain(0) + assert torch.all(samples.get_variable(model.x()) == torch.arange(3, 10)) + assert torch.all(samples.get_variable(model.x(), True) == torch.arange(10)) + + +@parametrize_inference([bm.SingleSiteAncestralMetropolisHastings()]) +class TestInferenceResults: + @staticmethod + def test_default_four_chains(model, inference): + p_key = model.theta() + mcs = inference.infer([p_key], {}, 10) + + assert mcs[p_key].shape == torch.zeros(4, 10).shape + assert mcs.get_variable(p_key).shape == torch.zeros(4, 10).shape + assert mcs.get_chain(3)[p_key].shape == torch.zeros(10).shape + assert mcs.num_chains == 4 + assert set(mcs.keys()) == set([p_key]) + + mcs = inference.infer([p_key], {}, 7, num_adaptive_samples=3) + + assert mcs.num_adaptive_samples == 3 + assert mcs[p_key].shape == torch.zeros(4, 7).shape + assert mcs.get_variable(p_key).shape == torch.zeros(4, 7).shape + assert mcs.get_variable(p_key, True).shape == torch.zeros(4, 10).shape + assert mcs.get_chain(3)[p_key].shape == torch.zeros(7).shape + assert mcs.num_chains == 4 + assert set(mcs.keys()) == set([p_key]) + + @staticmethod + def test_one_chain(model, inference): + p_key = model.theta() + l_key = model.x() + mcs = inference.infer([p_key, l_key], {}, 10, 1) + + assert mcs[p_key].shape == torch.zeros(1, 10).shape + assert mcs.get_variable(p_key).shape == torch.zeros(1, 10).shape + assert mcs.get_chain()[p_key].shape == torch.zeros(10).shape + assert mcs.num_chains == 1 + assert set(mcs.keys()) == set([p_key, l_key]) + + mcs = inference.infer([p_key, l_key], {}, 7, 1, num_adaptive_samples=3) + + assert mcs.num_adaptive_samples == 3 + assert mcs[p_key].shape == torch.zeros(1, 7).shape + assert mcs.get_variable(p_key).shape == torch.zeros(1, 7).shape + assert mcs.get_variable(p_key, True).shape == torch.zeros(1, 10).shape + assert mcs.get_chain()[p_key].shape == torch.zeros(7).shape + assert mcs.num_chains == 1 + assert set(mcs.keys()) == set([p_key, l_key]) + + @staticmethod + def test_chain_exceptions(model, inference): + p_key = model.theta() + mcs = inference.infer([p_key], {}, 10) + + with pytest.raises(IndexError, match="Please specify a valid chain"): mcs.get_chain(-1) - - with self.assertRaisesRegex(IndexError, r"Please specify a valid chain"): + with pytest.raises(IndexError, match="Please specify a valid chain"): mcs.get_chain(4) - - with self.assertRaisesRegex( + with pytest.raises( ValueError, - r"The current MonteCarloSamples object has already" - r" been restricted to a single chain", + match=( + r"The current MonteCarloSamples object has already" + r" been restricted to a single chain" + ), ): one_chain = mcs.get_chain() one_chain.get_chain() - def test_num_adaptive_samples(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - foo_key = model.foo() - mcs = mh.infer([foo_key], {}, 10, num_adaptive_samples=3) - - self.assertEqual(mcs[foo_key].shape, torch.zeros(4, 10).shape) - self.assertEqual(mcs.get_variable(foo_key).shape, torch.zeros(4, 10).shape) - self.assertEqual( - mcs.get_variable(foo_key, include_adapt_steps=True).shape, - torch.zeros(4, 13).shape, + @staticmethod + def test_num_adaptive_samples(model, inference): + p_key = model.theta() + mcs = inference.infer([p_key], {}, 10, num_adaptive_samples=3) + + assert mcs[p_key].shape == torch.zeros(4, 10).shape + assert mcs.get_variable(p_key).shape == torch.zeros(4, 10).shape + assert ( + mcs.get_variable(p_key, include_adapt_steps=True).shape + == torch.zeros(4, 13).shape ) - self.assertEqual(mcs.get_num_samples(), 10) - self.assertEqual(mcs.get_num_samples(include_adapt_steps=True), 13) + assert mcs.get_num_samples() == 10 + assert mcs.get_num_samples(include_adapt_steps=True) == 13 - def test_dump_and_restore_samples(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - foo_key = model.foo() - samples = mh.infer([foo_key], {}, num_samples=10, num_chains=2) - self.assertEqual(samples[foo_key].shape, (2, 10)) + @staticmethod + def test_dump_and_restore_samples(model, inference): + p_key = model.theta() + samples = inference.infer([p_key], {}, num_samples=10, num_chains=2) + assert samples[p_key].shape == (2, 10) dumped = pickle.dumps((model, samples)) # delete local variables and pretend that we are starting from a new session del model - del mh - del foo_key + del inference + del p_key del samples # reload from dumped bytes reloaded_model, reloaded_samples = pickle.loads(dumped) # check the values still exist and have the correct shape - self.assertEqual(reloaded_samples[reloaded_model.foo()].shape, (2, 10)) - - def test_get_rv_with_default(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - foo_key = model.foo() - samples = mh.infer([foo_key], {}, num_samples=10, num_chains=2) - - self.assertIn(model.foo(), samples) - self.assertIsInstance(samples.get(model.foo()), torch.Tensor) - self.assertIsNone(samples.get(model.bar())) - self.assertEqual(samples.get(model.foo(), chain=0).shape, (10,)) - - def test_merge_dicts(self): - model = self.SampleModel() - chain_lists = [{model.foo(): torch.rand(3)}, {model.foo(): torch.rand(3)}] - rv_dict = merge_dicts(chain_lists) - self.assertIn(model.foo(), rv_dict) - self.assertEqual(rv_dict.get(model.foo()).shape, (2, 3)) - chain_lists.append({model.bar(): torch.rand(3)}) - with self.assertRaises(ValueError): - merge_dicts(chain_lists) - - def test_type_conversion(self): - model = self.SampleModel() - samples = MonteCarloSamples( - [{model.foo(): torch.rand(5), model.bar(): torch.rand(5)}], - num_adaptive_samples=3, - ) - - xr_dataset = samples.to_xarray() - self.assertIsInstance(xr_dataset, xr.Dataset) - self.assertIn(model.foo(), xr_dataset) - assert np.allclose(samples[model.bar()].numpy(), xr_dataset[model.bar()]) - xr_dataset = samples.to_xarray(include_adapt_steps=True) - self.assertEqual(xr_dataset[model.foo()].shape, (1, 5)) - - inference_data = samples.to_inference_data() - self.assertIn(model.foo(), inference_data.posterior) - - def test_get_variable(self): - model = self.SampleModel() - samples = MonteCarloSamples( - [{model.foo(): torch.arange(10)}], num_adaptive_samples=3 - ).get_chain(0) - self.assertTrue( - torch.all(samples.get_variable(model.foo()) == torch.arange(3, 10)) - ) - self.assertTrue( - torch.all(samples.get_variable(model.foo(), True) == torch.arange(10)) - ) - - def test_get_log_likehoods(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - foo_key = model.foo() - bar_key = model.bar() - mcs = mh.infer( - [foo_key], - {bar_key: torch.tensor(4.0)}, + assert reloaded_samples[reloaded_model.theta()].shape == (2, 10) + + @staticmethod + def test_get_rv_with_default(model, inference): + p_key = model.theta() + samples = inference.infer([p_key], {}, num_samples=10, num_chains=2) + + assert model.theta() in samples + assert isinstance(samples.get(model.theta()), torch.Tensor) + assert samples.get(model.x()) is None + assert samples.get(model.theta(), chain=0).shape == (10,) + + @staticmethod + def test_get_log_likehoods(model, inference): + p_key = model.theta() + l_key = model.x() + mcs = inference.infer( + [p_key], + {l_key: torch.tensor(4.0)}, num_samples=5, num_chains=2, ) - self.assertTrue(hasattr(mcs, "log_likelihoods")) - self.assertIn(bar_key, mcs.log_likelihoods) - self.assertTrue(hasattr(mcs, "adaptive_log_likelihoods")) - self.assertIn(bar_key, mcs.adaptive_log_likelihoods) - self.assertEqual( - mcs.get_log_likelihoods(bar_key).shape, torch.zeros(2, 5).shape - ) + assert hasattr(mcs, "log_likelihoods") + assert l_key in mcs.log_likelihoods + assert hasattr(mcs, "adaptive_log_likelihoods") + assert l_key in mcs.adaptive_log_likelihoods + assert mcs.get_log_likelihoods(l_key).shape == torch.zeros(2, 5).shape mcs = mcs.get_chain(0) - self.assertEqual(mcs.get_log_likelihoods(bar_key).shape, torch.zeros(5).shape) + assert mcs.get_log_likelihoods(l_key).shape == torch.zeros(5).shape - mcs = mh.infer( - [foo_key], - {bar_key: torch.tensor(4.0)}, + mcs = inference.infer( + [p_key], + {l_key: torch.tensor(4.0)}, num_samples=5, num_chains=2, num_adaptive_samples=3, ) - - self.assertEqual( - mcs.get_log_likelihoods(bar_key).shape, torch.zeros(2, 5).shape - ) - self.assertEqual( - mcs.adaptive_log_likelihoods[bar_key].shape, torch.zeros(2, 3).shape - ) - self.assertEqual( - mcs.get_chain(0).get_log_likelihoods(bar_key).shape, torch.zeros(5).shape - ) - self.assertEqual( - mcs.get_log_likelihoods(bar_key, True).shape, torch.zeros(2, 8).shape - ) - self.assertEqual( - mcs.get_chain(0).adaptive_log_likelihoods[bar_key].shape, - torch.zeros(1, 3).shape, + assert mcs.get_log_likelihoods(l_key).shape == torch.zeros(2, 5).shape + assert mcs.adaptive_log_likelihoods[l_key].shape == torch.zeros(2, 3).shape + assert mcs.get_chain(0).get_log_likelihoods(l_key).shape == torch.zeros(5).shape + assert mcs.get_log_likelihoods(l_key, True).shape == torch.zeros(2, 8).shape + assert ( + mcs.get_chain(0).adaptive_log_likelihoods[l_key].shape + == torch.zeros(1, 3).shape ) - def test_thinning(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - samples = mh.infer([model.foo()], {}, num_samples=20, num_chains=1) - - self.assertEqual(samples.get(model.foo(), chain=0).shape, (20,)) - self.assertEqual(samples.get(model.foo(), chain=0, thinning=4).shape, (5,)) - - def test_add_group(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - samples = mh.infer([model.foo()], {}, num_samples=20, num_chains=1) - bar_samples = MonteCarloSamples(samples.samples, default_namespace="bar") - bar_samples.add_groups(samples) - self.assertEqual(samples.observations, bar_samples.observations) - self.assertEqual(samples.log_likelihoods, bar_samples.log_likelihoods) - self.assertIn("posterior", bar_samples.namespaces) - - def test_to_inference_data(self): - model = self.SampleModel() - mh = bm.SingleSiteAncestralMetropolisHastings() - samples = mh.infer([model.foo()], {}, num_samples=10, num_chains=1) + @staticmethod + def test_thinning(model, inference): + samples = inference.infer([model.theta()], {}, num_samples=20, num_chains=1) + assert samples.get(model.theta(), chain=0).shape == (20,) + assert samples.get(model.theta(), chain=0, thinning=4).shape == (5,) + + @staticmethod + def test_add_group(model, inference): + samples = inference.infer([model.theta()], {}, num_samples=20, num_chains=1) + new_samples = MonteCarloSamples(samples.samples, default_namespace="new") + new_samples.add_groups(samples) + assert samples.observations == new_samples.observations + assert samples.log_likelihoods == new_samples.log_likelihoods + assert "posterior" in new_samples.namespaces + + @staticmethod + def test_to_inference_data(model, inference): + samples = inference.infer([model.theta()], {}, num_samples=10, num_chains=1) az_xarray = samples.to_inference_data() - self.assertNotIn("warmup_posterior", az_xarray) + assert "warmup_posterior" not in az_xarray - samples = mh.infer( - [model.foo()], {}, num_samples=10, num_adaptive_samples=2, num_chains=1 + samples = inference.infer( + [model.theta()], {}, num_samples=10, num_adaptive_samples=2, num_chains=1 ) az_xarray = samples.to_inference_data(include_adapt_steps=True) - self.assertIn("warmup_posterior", az_xarray) + assert "warmup_posterior" in az_xarray diff --git a/tests/ppl/inference/nnc_test.py b/tests/ppl/inference/nnc_test.py index e379dbe985..3dab9d205a 100644 --- a/tests/ppl/inference/nnc_test.py +++ b/tests/ppl/inference/nnc_test.py @@ -6,38 +6,33 @@ import warnings import beanmachine.ppl as bm -import pytest import torch -import torch.distributions as dist +from beanmachine.ppl.examples.conjugate_models import NormalNormalModel +from torch import tensor +from ..utils.fixtures import parametrize_inference, parametrize_model -class SampleModel: - @bm.random_variable - def foo(self): - return dist.Normal(0.0, 1.0) - @bm.random_variable - def bar(self): - return dist.Normal(self.foo(), 1.0) +pytestmark = parametrize_model( + [NormalNormalModel(tensor(0.0), tensor(1.0), tensor(1.0))] +) -@pytest.mark.parametrize( - "algorithm", +@parametrize_inference( [ bm.GlobalNoUTurnSampler(nnc_compile=True), bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0, nnc_compile=True), - ], + ] ) -def test_nnc_compile(algorithm): - model = SampleModel() - queries = [model.foo()] - observations = {model.bar(): torch.tensor(0.5)} +def test_nnc_compile(model, inference): + queries = [model.theta()] + observations = {model.x(): torch.tensor(0.5)} num_samples = 30 num_chains = 2 with warnings.catch_warnings(): warnings.simplefilter("ignore") # verify that NNC can run through - samples = algorithm.infer( + samples = inference.infer( queries, observations, num_samples, @@ -45,4 +40,4 @@ def test_nnc_compile(algorithm): num_chains=num_chains, ) # sanity check: make sure that the samples are valid - assert not torch.isnan(samples[model.foo()]).any() + assert not torch.isnan(samples[model.theta()]).any() diff --git a/tests/ppl/inference/predictive_test.py b/tests/ppl/inference/predictive_test.py index 5293204ec6..005b0241b1 100644 --- a/tests/ppl/inference/predictive_test.py +++ b/tests/ppl/inference/predictive_test.py @@ -3,217 +3,170 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import unittest - import beanmachine.ppl as bm +import pytest import torch -import torch.distributions as dist - - -class PredictiveTest(unittest.TestCase): - @bm.random_variable - def prior(self): - return dist.Uniform(torch.tensor(0.0), torch.tensor(1.0)) - - @bm.random_variable - def likelihood(self): - return dist.Bernoulli(self.prior()) - - @bm.random_variable - def likelihood_i(self, i): - return dist.Bernoulli(self.prior()) - - @bm.random_variable - def prior_1(self): - return dist.Uniform(torch.tensor([0.0]), torch.tensor([1.0])) - - @bm.random_variable - def likelihood_1(self): - return dist.Bernoulli(self.prior_1()) - - @bm.random_variable - def likelihood_dynamic(self, i): - if self.likelihood_i(i).item() > 0: - return dist.Normal(torch.zeros(1), torch.ones(1)) - else: - return dist.Normal(5.0 * torch.ones(1), torch.ones(1)) - - @bm.random_variable - def prior_2(self): - return dist.Uniform(torch.zeros(1, 2), torch.ones(1, 2)) - - @bm.random_variable - def likelihood_2(self, i): - return dist.Bernoulli(self.prior_2()) - - @bm.random_variable - def likelihood_2_vec(self, i): - return dist.Bernoulli(self.prior_2()) - - @bm.random_variable - def likelihood_reg(self, x): - return dist.Normal(self.prior() * x, torch.tensor(1.0)) - - def test_prior_predictive(self): - queries = [self.prior(), self.likelihood()] +from beanmachine.ppl.examples.hierarchical_models import UniformBernoulliModel +from torch import tensor + +from ..utils.fixtures import ( + approx_all, + parametrize_inference, + parametrize_model, + parametrize_model_value, + parametrize_value, +) + + +@parametrize_model([UniformBernoulliModel(tensor(0.0), tensor(1.0))]) +class TestPredictive: + @staticmethod + def test_prior_predictive(model): + queries = [model.prior(), model.likelihood()] predictives = bm.simulate(queries, num_samples=10) - assert predictives[self.prior()].shape == (1, 10) - assert predictives[self.likelihood()].shape == (1, 10) - - def test_posterior_predictive(self): - obs = { - self.likelihood_i(0): torch.tensor(1.0), - self.likelihood_i(1): torch.tensor(0.0), - } - post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior()], obs, num_samples=10, num_chains=2 + assert predictives[model.prior()].shape == (1, 10) + assert predictives[model.likelihood()].shape == (1, 10) + + @staticmethod + @parametrize_value([tensor([1.0, 0.0])]) + @pytest.mark.parametrize("num_chains", [2]) + @parametrize_inference([bm.SingleSiteAncestralMetropolisHastings()]) + @pytest.mark.parametrize("vectorized", [True, False]) + def test_posterior_predictive(model, value, inference, num_chains, vectorized): + num_samples = 10 + shape_samples = (num_chains, num_samples) + model.lo.shape + + obs = {model.likelihood_i(i): value[i] for i in range(len(value))} + post_samples = inference.infer( + [model.prior()], obs, num_samples=num_samples, num_chains=num_chains ) - assert post_samples[self.prior()].shape == (2, 10) - predictives = bm.simulate(list(obs.keys()), post_samples, vectorized=True) - assert predictives[self.likelihood_i(0)].shape == (2, 10) - assert predictives[self.likelihood_i(1)].shape == (2, 10) + assert post_samples[model.prior()].shape == shape_samples - def test_posterior_predictive_seq(self): - obs = { - self.likelihood_i(0): torch.tensor(1.0), - self.likelihood_i(1): torch.tensor(0.0), - } - post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior()], obs, num_samples=10, num_chains=2 - ) - assert post_samples[self.prior()].shape == (2, 10) - predictives = bm.simulate(list(obs.keys()), post_samples, vectorized=False) - assert predictives[self.likelihood_i(0)].shape == (2, 10) - assert predictives[self.likelihood_i(1)].shape == (2, 10) + predictives = bm.simulate(list(obs.keys()), post_samples, vectorized=vectorized) + assert all(predictives[rv].shape == shape_samples for rv in obs.keys()) - def test_predictive_dynamic(self): + @staticmethod + def test_predictive_dynamic(model): obs = { - self.likelihood_dynamic(0): torch.tensor([0.9]), - self.likelihood_dynamic(1): torch.tensor([4.9]), + model.likelihood_dynamic(0): torch.tensor([0.9]), + model.likelihood_dynamic(1): torch.tensor([4.9]), } # only query one of the variables post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior()], obs, num_samples=10, num_chains=2 + [model.prior()], obs, num_samples=10, num_chains=2 ) - assert post_samples[self.prior()].shape == (2, 10) + assert post_samples[model.prior()].shape == (2, 10) predictives = bm.simulate(list(obs.keys()), post_samples, vectorized=False) - assert predictives[self.likelihood_dynamic(0)].shape == (2, 10) - assert predictives[self.likelihood_dynamic(1)].shape == (2, 10) + assert predictives[model.likelihood_dynamic(0)].shape == (2, 10) + assert predictives[model.likelihood_dynamic(1)].shape == (2, 10) - def test_predictive_data(self): + @staticmethod + def test_predictive_data(model): x = torch.randn(4) y = torch.randn(4) + 2.0 - obs = {self.likelihood_reg(x): y} + obs = {model.likelihood_reg(x): y} post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior()], obs, num_samples=10, num_chains=2 + [model.prior()], obs, num_samples=10, num_chains=2 ) - assert post_samples[self.prior()].shape == (2, 10) + assert post_samples[model.prior()].shape == (2, 10) test_x = torch.randn(4, 1, 1) - test_query = self.likelihood_reg(test_x) + test_query = model.likelihood_reg(test_x) predictives = bm.simulate([test_query], post_samples, vectorized=True) assert predictives[test_query].shape == (4, 2, 10) - def test_posterior_predictive_1d(self): - obs = {self.likelihood_1(): torch.tensor([1.0])} - post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior_1()], obs, num_samples=10, num_chains=1 - ) - assert post_samples[self.prior_1()].shape == (1, 10, 1) - predictives = bm.simulate(list(obs.keys()), post_samples, vectorized=True) - y = predictives[self.likelihood_1()].shape - assert y == (1, 10, 1) - - def test_multi_chain_infer_predictive_2d(self): - torch.manual_seed(10) + @staticmethod + def test_empirical(model): obs = { - self.likelihood_2(0): torch.tensor([[1.0, 1.0]]), - self.likelihood_2(1): torch.tensor([[0.0, 1.0]]), + model.likelihood_i(0): torch.tensor(1.0), + model.likelihood_i(1): torch.tensor(0.0), + model.likelihood_i(2): torch.tensor(0.0), } post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior_2()], obs, num_samples=10, num_chains=2 + [model.prior()], obs, num_samples=10, num_chains=4 ) - - assert post_samples[self.prior_2()].shape == (2, 10, 1, 2) - predictives = bm.simulate(list(obs.keys()), post_samples, vectorized=True) - predictive_0 = predictives[self.likelihood_2(0)] - predictive_1 = predictives[self.likelihood_2(1)] - assert predictive_0.shape == (2, 10, 1, 2) - assert predictive_1.shape == (2, 10, 1, 2) - assert (predictive_1 - predictive_0).sum().item() != 0 - - def test_empirical(self): - obs = { - self.likelihood_i(0): torch.tensor(1.0), - self.likelihood_i(1): torch.tensor(0.0), - self.likelihood_i(2): torch.tensor(0.0), - } - post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior()], obs, num_samples=10, num_chains=4 - ) - empirical = bm.empirical([self.prior()], post_samples, num_samples=26) - assert empirical[self.prior()].shape == (1, 26) + empirical = bm.empirical([model.prior()], post_samples, num_samples=26) + assert empirical[model.prior()].shape == (1, 26) predictives = bm.simulate(list(obs.keys()), post_samples, vectorized=True) empirical = bm.empirical(list(obs.keys()), predictives, num_samples=27) assert len(empirical) == 3 - assert empirical[self.likelihood_i(0)].shape == (1, 27) - assert empirical[self.likelihood_i(1)].shape == (1, 27) - - def test_return_inference_data(self): - torch.manual_seed(10) - obs = { - self.likelihood_2(0): torch.tensor([[1.0, 1.0]]), - self.likelihood_2(1): torch.tensor([[0.0, 1.0]]), - } - post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior_2()], obs, num_samples=10, num_chains=2 - ) + assert empirical[model.likelihood_i(0)].shape == (1, 27) + assert empirical[model.likelihood_i(1)].shape == (1, 27) - assert post_samples[self.prior_2()].shape == (2, 10, 1, 2) - predictives = bm.simulate( - list(obs.keys()), - post_samples, - vectorized=True, - ).to_inference_data() - assert "posterior" in predictives - assert "observed_data" in predictives - assert "log_likelihood" in predictives - assert "posterior_predictive" in predictives - assert predictives.posterior_predictive[self.likelihood_2(0)].shape == ( - 2, - 10, - 1, - 2, - ) - assert predictives.posterior_predictive[self.likelihood_2(1)].shape == ( - 2, - 10, - 1, - 2, - ) - - def test_posterior_dict(self): + @staticmethod + def test_posterior_dict(model): obs = { - self.likelihood_i(0): torch.tensor(1.0), - self.likelihood_i(1): torch.tensor(0.0), + model.likelihood_i(0): torch.tensor(1.0), + model.likelihood_i(1): torch.tensor(0.0), } - posterior = {self.prior(): torch.tensor([0.5, 0.5])} + posterior = {model.prior(): torch.tensor([0.5, 0.5])} predictives_dict = bm.simulate(list(obs.keys()), posterior) - assert predictives_dict[self.likelihood_i(0)].shape == (1, 2) - assert predictives_dict[self.likelihood_i(1)].shape == (1, 2) + assert predictives_dict[model.likelihood_i(0)].shape == (1, 2) + assert predictives_dict[model.likelihood_i(1)].shape == (1, 2) - def test_posterior_dict_predictive(self): + @staticmethod + def test_posterior_dict_predictive(model): obs = { - self.likelihood_i(0): torch.tensor(1.0), - self.likelihood_i(1): torch.tensor(0.0), + model.likelihood_i(0): torch.tensor(1.0), + model.likelihood_i(1): torch.tensor(0.0), } post_samples = bm.SingleSiteAncestralMetropolisHastings().infer( - [self.prior()], obs, num_samples=10, num_chains=1 + [model.prior()], obs, num_samples=10, num_chains=1 ) - assert post_samples[self.prior()].shape == (1, 10) + assert post_samples[model.prior()].shape == (1, 10) post_samples_dict = dict(post_samples) predictives_dict = bm.simulate(list(obs.keys()), post_samples_dict) - assert predictives_dict[self.likelihood_i(0)].shape == (1, 10) - assert predictives_dict[self.likelihood_i(1)].shape == (1, 10) + assert predictives_dict[model.likelihood_i(0)].shape == (1, 10) + assert predictives_dict[model.likelihood_i(1)].shape == (1, 10) + + +@parametrize_model_value( + [ + (UniformBernoulliModel(tensor([0.0]), tensor([1.0])), tensor([1.0])), + ( + UniformBernoulliModel(torch.zeros(1, 2), torch.ones(1, 2)), + tensor([[[1.0, 1.0]], [[0.0, 1.0]]]), + ), + ] +) +@pytest.mark.parametrize("num_chains", [1, 3]) +@parametrize_inference([bm.SingleSiteAncestralMetropolisHastings()]) +class TestPredictiveMV: + @staticmethod + def test_posterior_predictive(model, value, inference, num_chains): + torch.manual_seed(10) + num_samples = 10 + shape_samples = (num_chains, num_samples) + model.lo.shape + + # define observations + if value.ndim == model.lo.ndim: + obs = {model.likelihood(): value} + else: + obs = {model.likelihood_i(i): value[i] for i in range(len(value))} + + # run inference + post_samples = inference.infer( + [model.prior()], obs, num_samples=num_samples, num_chains=num_chains + ) + assert post_samples[model.prior()].shape == shape_samples + + # simulate predictives + predictives = bm.simulate(list(obs.keys()), post_samples, vectorized=True) + for rv in obs.keys(): + assert predictives[rv].shape == shape_samples + if value.ndim == 1 + model.lo.ndim: + rvs = list(obs.keys())[:2] + assert not approx_all(predictives[rvs[0]], predictives[rvs[1]], 0.5) + inf_data = predictives.to_inference_data() + result_keys = [ + "posterior", + "observed_data", + "log_likelihood", + "posterior_predictive", + ] + for k in result_keys: + assert k in inf_data + for rv in obs.keys(): + assert inf_data.posterior_predictive[rv].shape == shape_samples diff --git a/tests/ppl/inference/proposer/hmc_proposer_test.py b/tests/ppl/inference/proposer/hmc_proposer_test.py index 16d80354f7..5b503f1dc2 100644 --- a/tests/ppl/inference/proposer/hmc_proposer_test.py +++ b/tests/ppl/inference/proposer/hmc_proposer_test.py @@ -6,25 +6,23 @@ import beanmachine.ppl as bm import pytest import torch -import torch.distributions as dist +from beanmachine.ppl.examples.hierarchical_models import UniformNormalModel from beanmachine.ppl.inference.proposer.hmc_proposer import HMCProposer from beanmachine.ppl.world import World +from torch import tensor +from ...utils.fixtures import approx_all, parametrize_inference, parametrize_model -@bm.random_variable -def foo(): - return dist.Uniform(0.0, 1.0) - -@bm.random_variable -def bar(): - return dist.Normal(foo(), 1.0) +pytestmark = parametrize_model( + [UniformNormalModel(tensor(0.0), tensor(1.0), tensor(1.0))] +) @pytest.fixture -def world(): +def world(model): w = World() - w.call(bar()) + w.call(model.obs()) return w @@ -58,24 +56,23 @@ def test_leapfrog_step(hmc): new_positions, new_momentums, pe, pe_grad = hmc._leapfrog_step( hmc._positions, momentums, step_size, hmc._mass_inv ) - assert torch.allclose(momentums, new_momentums) - assert torch.allclose(hmc._positions, new_positions) + assert approx_all(momentums, new_momentums) + assert approx_all(hmc._positions, new_positions) -@pytest.mark.parametrize( +@parametrize_inference( # forcing the step_size to be 0 for HMC/ NUTS - "algorithm", [ bm.GlobalNoUTurnSampler(initial_step_size=0.0), bm.GlobalHamiltonianMonteCarlo(trajectory_length=1.0, initial_step_size=0.0), ], ) -def test_step_size_exception(algorithm): - queries = [foo()] - observations = {bar(): torch.tensor(0.5)} +def test_step_size_exception(model, inference): + queries = [model.mean()] + observations = {model.obs(): torch.tensor(0.5)} with pytest.raises(ValueError): - algorithm.infer( + inference.infer( queries, observations, num_samples=20, diff --git a/tests/ppl/inference/proposer/hmc_utils_test.py b/tests/ppl/inference/proposer/hmc_utils_test.py index 9f7fd5261f..efc6f6f221 100644 --- a/tests/ppl/inference/proposer/hmc_utils_test.py +++ b/tests/ppl/inference/proposer/hmc_utils_test.py @@ -5,11 +5,12 @@ import warnings -import beanmachine.ppl as bm import numpy as np import pytest import torch import torch.distributions as dist +from beanmachine.ppl.examples.hierarchical_models import UniformNormalModel +from beanmachine.ppl.examples.primitive_models import PoissonModel from beanmachine.ppl.inference.proposer.hmc_utils import ( DualAverageAdapter, MassMatrixAdapter, @@ -19,22 +20,9 @@ ) from beanmachine.ppl.inference.proposer.utils import DictToVecConverter from beanmachine.ppl.world import World +from torch import tensor - -class SampleModel: - @bm.random_variable - def foo(self): - return dist.Uniform(0.0, 1.0) - - @bm.random_variable - def bar(self): - return dist.Normal(self.foo(), 1.0) - - -class DiscreteModel: - @bm.random_variable - def baz(self): - return dist.Poisson(5.0) +from ...utils.fixtures import approx_all, parametrize_model def test_dual_average_adapter(): @@ -97,10 +85,10 @@ def test_large_window_scheme(num_adaptive_samples): @pytest.mark.parametrize("full_mass_matrix", [True, False]) -def test_mass_matrix_adapter(full_mass_matrix): - model = SampleModel() +@parametrize_model([UniformNormalModel(tensor(0.0), tensor(1.0), tensor(1.0))]) +def test_mass_matrix_adapter(model, full_mass_matrix): world = World() - world.call(model.bar()) + world.call(model.obs()) positions_dict = RealSpaceTransform(world, world.latent_nodes)(dict(world)) dict2vec = DictToVecConverter(positions_dict) positions = dict2vec.to_vec(positions_dict) @@ -116,7 +104,7 @@ def test_mass_matrix_adapter(full_mass_matrix): mass_matrix_adapter.finalize() # mass matrix adapter has seen less than 2 samples, so mass_inv is not updated - assert torch.allclose(mass_inv_old, mass_matrix_adapter.mass_inv) + assert approx_all(mass_inv_old, mass_matrix_adapter.mass_inv) # check the size of the matrix matrix_width = len(positions) @@ -135,7 +123,7 @@ def test_diagonal_welford_covariance(): welford.step(sample) sample_var = torch.var(samples, dim=0) estimated_var = welford.finalize(regularize=False) - assert torch.allclose(estimated_var, sample_var) + assert approx_all(estimated_var, sample_var) regularized_var = welford.finalize(regularize=True) assert (torch.argsort(regularized_var) == torch.argsort(estimated_var)).all() @@ -149,7 +137,7 @@ def test_dense_welford_covariance(): welford.step(sample) sample_cov = torch.from_numpy(np.cov(samples.T.numpy())).to(samples.dtype) estimated_cov = welford.finalize(regularize=False) - assert torch.allclose(estimated_cov, sample_cov) + assert approx_all(estimated_cov, sample_cov) regularized_cov = welford.finalize(regularize=True) assert (torch.argsort(regularized_cov) == torch.argsort(estimated_cov)).all() @@ -161,9 +149,9 @@ def test_welford_exception(): welford.finalize() -def test_discrete_rv_exception(): - model = DiscreteModel() +@parametrize_model([PoissonModel(tensor(5.0))]) +def test_discrete_rv_exception(model): world = World() - world.call(model.baz()) + world.call(model.x()) with pytest.raises(TypeError): RealSpaceTransform(world, world.latent_nodes)(dict(world)) diff --git a/tests/ppl/inference/proposer/nmc/single_site_half_space_newtonian_monte_carlo_proposer_test.py b/tests/ppl/inference/proposer/nmc/single_site_half_space_newtonian_monte_carlo_proposer_test.py index 6f571fe840..2a1e0783b5 100644 --- a/tests/ppl/inference/proposer/nmc/single_site_half_space_newtonian_monte_carlo_proposer_test.py +++ b/tests/ppl/inference/proposer/nmc/single_site_half_space_newtonian_monte_carlo_proposer_test.py @@ -3,73 +3,22 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import unittest - -import beanmachine.ppl as bm -import torch -import torch.distributions as dist +from beanmachine.ppl.examples.primitive_models import GammaModel from beanmachine.ppl.inference.proposer.nmc import SingleSiteHalfSpaceNMCProposer from beanmachine.ppl.world import World from torch import tensor +from ....utils.fixtures import approx, parametrize_model, parametrize_proposer -class SingleSiteHalfSpaceNewtonianMonteCarloProposerTest(unittest.TestCase): - class SampleNormalModel: - @bm.random_variable - def foo(self): - return dist.Normal(tensor(2.0), tensor(2.0)) - - @bm.random_variable - def bar(self): - return dist.Normal(self.foo(), torch.tensor(1.0)) - - class SampleLogisticRegressionModel: - @bm.random_variable - def theta_0(self): - return dist.Normal(tensor(0.0), tensor(1.0)) - - @bm.random_variable - def theta_1(self): - return dist.Normal(tensor(0.0), tensor(1.0)) - - @bm.random_variable - def x(self, i): - return dist.Normal(tensor(0.0), tensor(1.0)) - - @bm.random_variable - def y(self, i): - y = self.theta_1() * self.x(i) + self.theta_0() - probs = 1 / (1 + (y * -1).exp()) - return dist.Bernoulli(probs) - - class SampleFallbackModel: - @bm.random_variable - def foo(self): - return dist.Gamma(tensor(2.0), tensor(2.0)) - - @bm.random_variable - def bar(self): - return dist.Normal(self.foo(), torch.tensor(1.0)) - - def test_alpha_and_beta_for_gamma(self): - alpha = tensor([2.0, 2.0, 2.0]) - beta = tensor([2.0, 2.0, 2.0]) - - @bm.random_variable - def gamma(): - return dist.Gamma(alpha, beta) - world = World() - with world: - gamma() - nw_proposer = SingleSiteHalfSpaceNMCProposer(gamma()) - is_valid, predicted_alpha, predicted_beta = nw_proposer.compute_alpha_beta( - world - ) - self.assertEqual(is_valid, True) - self.assertAlmostEqual( - alpha.sum().item(), (predicted_alpha).sum().item(), delta=0.0001 - ) - self.assertAlmostEqual( - beta.sum().item(), (predicted_beta).sum().item(), delta=0.0001 - ) +@parametrize_model([GammaModel(tensor([2.0, 2.0, 2.0]), tensor([2.0, 2.0, 2.0]))]) +@parametrize_proposer([SingleSiteHalfSpaceNMCProposer]) +def test_alpha_and_beta_for_gamma(model, proposer): + world = World() + with world: + model.x() + prop = proposer(model.x()) + is_valid, predicted_alpha, predicted_beta = prop.compute_alpha_beta(world) + assert is_valid + assert approx(model.alpha.sum(), predicted_alpha.sum(), 1e-4) + assert approx(model.beta.sum(), predicted_beta.sum(), 1e-4) diff --git a/tests/ppl/inference/proposer/nmc/single_site_real_space_newtonian_monte_carlo_proposer_test.py b/tests/ppl/inference/proposer/nmc/single_site_real_space_newtonian_monte_carlo_proposer_test.py index b60309fd8d..5c454a438f 100644 --- a/tests/ppl/inference/proposer/nmc/single_site_real_space_newtonian_monte_carlo_proposer_test.py +++ b/tests/ppl/inference/proposer/nmc/single_site_real_space_newtonian_monte_carlo_proposer_test.py @@ -3,12 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import unittest - -import beanmachine.ppl as bm import torch import torch.autograd import torch.distributions as dist +from beanmachine.ppl.examples.conjugate_models import MV_NormalNormalModel +from beanmachine.ppl.examples.hierarchical_models import LogisticRegressionModel from beanmachine.ppl.inference.proposer.nmc.single_site_real_space_nmc_proposer import ( SingleSiteRealSpaceNMCProposer as SingleSiteRealSpaceNewtonianMonteCarloProposer, ) @@ -16,312 +15,243 @@ from beanmachine.ppl.world.variable import Variable from torch import tensor +from ....utils.fixtures import ( + approx, + approx_all, + parametrize_model, + parametrize_model_value_expected, + parametrize_proposer, + parametrize_value_expected, +) -class SingleSiteRealSpaceNewtonianMonteCarloProposerTest(unittest.TestCase): - class SampleNormalModel: - @bm.random_variable - def foo(self): - return dist.MultivariateNormal(torch.zeros(2), torch.eye(2)) - - @bm.random_variable - def bar(self): - return dist.MultivariateNormal(self.foo(), torch.eye(2)) - - class SampleLogisticRegressionModel: - @bm.random_variable - def theta_0(self): - return dist.Normal(tensor(0.0), tensor(1.0)) - - @bm.random_variable - def theta_1(self): - return dist.Normal(tensor(0.0), tensor(1.0)) - @bm.random_variable - def x(self, i): - return dist.Normal(tensor(0.0), tensor(1.0)) +pytestmark = parametrize_proposer([SingleSiteRealSpaceNewtonianMonteCarloProposer]) - @bm.random_variable - def y(self, i): - y = self.theta_1() * self.x(i) + self.theta_0() - probs = 1 / (1 + (y * -1).exp()) - return dist.Bernoulli(probs) - def test_mean_scale_tril_for_node_with_child(self): - foo_key = bm.random_variable( - lambda: dist.MultivariateNormal( - tensor([1.0, 1.0]), tensor([[1.0, 0.8], [0.8, 1]]) - ) - ) - bar_key = bm.random_variable( - lambda: dist.MultivariateNormal( - foo_key(), +@parametrize_model_value_expected( + [ + ( + MV_NormalNormalModel( + tensor([1.0, 1.0]), + tensor([[1.0, 0.8], [0.8, 1]]), tensor([[1.0, 0.8], [0.8, 1.0]]), - ) - ) - nw_proposer = SingleSiteRealSpaceNewtonianMonteCarloProposer(foo_key()) - val = tensor([2.0, 2.0]) - queries = [foo_key(), bar_key()] - observed_val = tensor([2.0, 2.0]) - observations = {bar_key(): observed_val} - world = World.initialize_world(queries, observations) - world_vars = world._variables - world_vars[foo_key] = val - - nw_proposer.learning_rate_ = 1.0 - prop_dist = nw_proposer.get_proposal_distribution(world).base_dist - mean, scale_tril = prop_dist.mean, prop_dist.scale_tril - expected_mean = tensor([1.5, 1.5]) - expected_scale_tril = torch.linalg.cholesky( - tensor([[0.5000, 0.4000], [0.4000, 0.5000]]) - ) - self.assertTrue(torch.isclose(mean, expected_mean).all()) - self.assertTrue(torch.isclose(scale_tril, expected_scale_tril).all()) - - def test_mean_scale_tril(self): - model = self.SampleNormalModel() - foo_key = model.foo() - nw_proposer = SingleSiteRealSpaceNewtonianMonteCarloProposer(foo_key) - val = tensor([2.0, 2.0]) - val.requires_grad_(True) - distribution = dist.MultivariateNormal( - tensor([1.0, 1.0]), tensor([[1.0, 0.8], [0.8, 1]]) - ) - queries = [foo_key] + ), + tensor([2.0, 2.0]), + ([1.5, 1.5], torch.linalg.cholesky(tensor([[0.5, 0.4], [0.4, 0.5]]))), + ), + ( + MV_NormalNormalModel(torch.zeros(2), torch.eye(2), torch.eye(2)), + ( + dist.MultivariateNormal( + tensor([1.0, 1.0]), tensor([[1.0, 0.8], [0.8, 1.0]]) + ), + tensor([2.0, 2.0]), + ), + ([1.0, 1.0], torch.linalg.cholesky(tensor([[1.0, 0.8], [0.8, 1]]))), + ), + ( + MV_NormalNormalModel(torch.zeros(2), torch.eye(2), torch.eye(2)), + ( + dist.Normal( + tensor([[1.0, 1.0], [1.0, 1.0]]), tensor([[1.0, 1.0], [1.0, 1.0]]) + ), + tensor([[2.0, 2.0], [2.0, 2.0]]), + ), + ([1.0, 1.0, 1.0, 1.0], torch.eye(4)), + ), + ] +) +def test_mean_scale_tril(model, proposer, value, expected): + # set latents + if torch.is_tensor(value): + queries = [model.theta(), model.x()] + observations = {model.x(): value} + else: + queries = [model.theta()] observations = {} - world = World.initialize_world(queries, observations) - world_vars = world._variables - world_vars[foo_key] = Variable( - value=val, - distribution=distribution, - ) - - nw_proposer.learning_rate_ = 1.0 - prop_dist = nw_proposer.get_proposal_distribution(world).base_dist - mean, scale_tril = prop_dist.mean, prop_dist.scale_tril - - expected_mean = tensor([1.0, 1.0]) - expected_scale_tril = torch.linalg.cholesky(tensor([[1.0, 0.8], [0.8, 1]])) - self.assertTrue(torch.isclose(mean, expected_mean).all()) - self.assertTrue(torch.isclose(scale_tril, expected_scale_tril).all()) - - def test_mean_scale_tril_for_iids(self): - model = self.SampleNormalModel() - foo_key = model.foo() - nw_proposer = SingleSiteRealSpaceNewtonianMonteCarloProposer(foo_key) - val = tensor([[2.0, 2.0], [2.0, 2.0]]) + world = World.initialize_world(queries, observations) + if torch.is_tensor(value): + world._variables[model.theta] = value + else: + dist, val = value val.requires_grad_(True) - distribution = dist.Normal( - tensor([[1.0, 1.0], [1.0, 1.0]]), tensor([[1.0, 1.0], [1.0, 1.0]]) - ) - queries = [foo_key] - observations = {} - world = World.initialize_world(queries, observations) - world_vars = world._variables - world_vars[foo_key] = Variable( - value=val, - distribution=distribution, - ) - - nw_proposer.learning_rate_ = 1.0 - prop_dist = nw_proposer.get_proposal_distribution(world).base_dist - mean, scale_tril = prop_dist.mean, prop_dist.scale_tril - - expected_mean = tensor([1.0, 1.0, 1.0, 1.0]) - expected_scale_tril = torch.eye(4) - self.assertTrue(torch.isclose(mean, expected_mean).all()) - self.assertTrue(torch.isclose(scale_tril, expected_scale_tril).all()) - - def test_multi_mean_scale_tril_computation_in_inference(self): - model = self.SampleLogisticRegressionModel() - theta_0_key = model.theta_0() - theta_1_key = model.theta_1() - nw_proposer = SingleSiteRealSpaceNewtonianMonteCarloProposer(theta_0_key) - - x_0_key = model.x(0) - x_1_key = model.x(1) - y_0_key = model.y(0) - y_1_key = model.y(1) - - theta_0_value = tensor(1.5708) - theta_0_value.requires_grad_(True) - x_0_value = tensor(0.7654) - x_1_value = tensor(-6.6737) - theta_1_value = tensor(-0.4459) - - theta_0_distribution = dist.Normal(torch.tensor(0.0), torch.tensor(1.0)) - queries = [theta_0_key, theta_1_key] - observations = {} - world = World.initialize_world(queries, observations) - world_vars = world._variables - world_vars[theta_0_key] = Variable( - value=theta_0_value, - distribution=theta_0_distribution, - children=set({y_0_key, y_1_key}), - ) - - world_vars[theta_1_key] = Variable( - value=theta_1_value, - distribution=theta_0_distribution, - children=set({y_0_key, y_1_key}), - ) - - x_distribution = dist.Normal(torch.tensor(0.0), torch.tensor(5.0)) - world_vars[x_0_key] = Variable( - value=x_0_value, - distribution=x_distribution, - children=set({y_0_key, y_1_key}), - ) - - world_vars[x_1_key] = Variable( - value=x_1_value, - distribution=x_distribution, - children=set({y_0_key, y_1_key}), - ) - - y = theta_0_value + theta_1_value * x_0_value - probs_0 = 1 / (1 + (y * -1).exp()) - y_0_distribution = dist.Bernoulli(probs_0) - - world_vars[y_0_key] = Variable( - value=tensor(1.0), - distribution=y_0_distribution, - parents=set({theta_0_key, theta_1_key, x_0_key}), - ) - - y = theta_0_value + theta_1_value * x_1_value - probs_1 = 1 / (1 + (y * -1).exp()) - y_1_distribution = dist.Bernoulli(probs_1) - - world_vars[y_1_key] = Variable( - value=tensor(1.0), - distribution=y_1_distribution, - parents=set({theta_0_key, theta_1_key, x_1_key}), - ) - - nw_proposer.learning_rate_ = 1.0 - prop_dist = nw_proposer.get_proposal_distribution(world).base_dist - mean, scale_tril = prop_dist.mean, prop_dist.scale_tril - - score = theta_0_distribution.log_prob(theta_0_value) - score += ( - 1 / (1 + (-1 * (theta_0_value + theta_1_value * x_0_value)).exp()) - ).log() - score += ( - 1 / (1 + (-1 * (theta_0_value + theta_1_value * x_1_value)).exp()) - ).log() - - expected_first_gradient = torch.autograd.grad( - score, theta_0_value, create_graph=True - )[0] - expected_second_gradient = torch.autograd.grad( - expected_first_gradient, theta_0_value - )[0] - - expected_covar = expected_second_gradient.reshape(1, 1).inverse() * -1 - expected_scale_tril = torch.linalg.cholesky(expected_covar) - self.assertAlmostEqual( - expected_scale_tril.item(), scale_tril.item(), delta=0.001 - ) - expected_first_gradient = expected_first_gradient.unsqueeze(0) - expected_mean = ( - theta_0_value.unsqueeze(0) - + expected_first_gradient.unsqueeze(0).mm(expected_covar) - ).squeeze(0) - self.assertAlmostEqual(mean.item(), expected_mean.item(), delta=0.001) - - proposal_value = ( - dist.MultivariateNormal(mean, scale_tril=scale_tril) - .sample() - .reshape(theta_0_value.shape) - ) - proposal_value.requires_grad_(True) - world_vars[theta_0_key].value = proposal_value - - y = proposal_value + theta_1_value * x_0_value - probs_0 = 1 / (1 + (y * -1).exp()) - y_0_distribution = dist.Bernoulli(probs_0) - world_vars[y_0_key].distribution = y_0_distribution - world_vars[y_0_key].log_prob = y_0_distribution.log_prob(tensor(1.0)) - y = proposal_value + theta_1_value * x_1_value - probs_1 = 1 / (1 + (y * -1).exp()) - y_1_distribution = dist.Bernoulli(probs_1) - world_vars[y_1_key].distribution = y_1_distribution - - nw_proposer.learning_rate_ = 1.0 - prop_dist = nw_proposer.get_proposal_distribution(world).base_dist - mean, scale_tril = prop_dist.mean, prop_dist.scale_tril - - score = tensor(0.0) - - score = theta_0_distribution.log_prob(proposal_value) - score += ( - 1 / (1 + (-1 * (proposal_value + theta_1_value * x_0_value)).exp()) - ).log() - score += ( - 1 / (1 + (-1 * (proposal_value + theta_1_value * x_1_value)).exp()) - ).log() - - expected_first_gradient = torch.autograd.grad( - score, proposal_value, create_graph=True - )[0] - expected_second_gradient = torch.autograd.grad( - expected_first_gradient, proposal_value - )[0] - expected_covar = expected_second_gradient.reshape(1, 1).inverse() * -1 - expected_scale_tril = torch.linalg.cholesky(expected_covar) - self.assertAlmostEqual( - expected_scale_tril.item(), scale_tril.item(), delta=0.001 - ) - expected_first_gradient = expected_first_gradient.unsqueeze(0) - - expected_mean = ( - proposal_value.unsqueeze(0) - + expected_first_gradient.unsqueeze(0).mm(expected_covar) - ).squeeze(0) - self.assertAlmostEqual(mean.item(), expected_mean.item(), delta=0.001) - - self.assertAlmostEqual( - scale_tril.item(), expected_scale_tril.item(), delta=0.001 - ) - - def test_adaptive_alpha_beta_computation(self): - model = self.SampleLogisticRegressionModel() - theta_0_key = model.theta_0() - nw_proposer = SingleSiteRealSpaceNewtonianMonteCarloProposer(theta_0_key) - nw_proposer.learning_rate_ = tensor(0.0416, dtype=torch.float64) - nw_proposer.running_mean_, nw_proposer.running_var_ = ( - tensor(0.079658), - tensor(0.0039118), - ) - nw_proposer.accepted_samples_ = 37 - alpha, beta = nw_proposer.compute_beta_priors_from_accepted_lr() - self.assertAlmostEqual(nw_proposer.running_mean_.item(), 0.0786, delta=0.0001) - self.assertAlmostEqual(nw_proposer.running_var_.item(), 0.00384, delta=0.00001) - self.assertAlmostEqual(alpha.item(), 1.4032, delta=0.001) - self.assertAlmostEqual(beta.item(), 16.4427, delta=0.001) - - def test_adaptive_vectorized_alpha_beta_computation(self): - model = self.SampleLogisticRegressionModel() - theta_0_key = model.theta_0() - nw_proposer = SingleSiteRealSpaceNewtonianMonteCarloProposer(theta_0_key) - nw_proposer.learning_rate_ = tensor([0.0416, 0.0583], dtype=torch.float64) - nw_proposer.running_mean_, nw_proposer.running_var_ = ( - tensor([0.079658, 0.089861]), - tensor([0.0039118, 0.0041231]), - ) - nw_proposer.accepted_samples_ = 37 - alpha, beta = nw_proposer.compute_beta_priors_from_accepted_lr() - self.assertListEqual( - [round(x.item(), 4) for x in list(nw_proposer.running_mean_)], - [0.0786, 0.089], - ) - self.assertListEqual( - [round(x.item(), 4) for x in list(nw_proposer.running_var_)], - [0.0038, 0.004], - ) - self.assertListEqual( - [round(x.item(), 4) for x in list(alpha)], [1.4032, 1.6984] - ) - self.assertListEqual( - [round(x.item(), 4) for x in list(beta)], [16.4427, 17.3829] - ) + world._variables[model.theta()] = Variable(value=val, distribution=dist) + + # evaluate proposer + prop = proposer(model.theta()) + prop.learning_rate_ = 1.0 + prop_dist = prop.get_proposal_distribution(world).base_dist + mean, scale_tril = prop_dist.mean, prop_dist.scale_tril + expected_mean, expected_scale_tril = expected + assert approx_all(mean, expected_mean) + assert approx_all(scale_tril, expected_scale_tril) + + +@parametrize_model([LogisticRegressionModel()]) +def test_multi_mean_scale_tril_computation_in_inference(model, proposer): + theta_0_key = model.theta_0() + theta_1_key = model.theta_1() + + x_0_key = model.x(0) + x_1_key = model.x(1) + y_0_key = model.y(0) + y_1_key = model.y(1) + + theta_0_value = tensor(1.5708) + theta_0_value.requires_grad_(True) + x_0_value = tensor(0.7654) + x_1_value = tensor(-6.6737) + theta_1_value = tensor(-0.4459) + + theta_0_distribution = dist.Normal(torch.tensor(0.0), torch.tensor(1.0)) + queries = [theta_0_key, theta_1_key] + observations = {} + world = World.initialize_world(queries, observations) + world_vars = world._variables + world_vars[theta_0_key] = Variable( + value=theta_0_value, + distribution=theta_0_distribution, + children=set({y_0_key, y_1_key}), + ) + world_vars[theta_1_key] = Variable( + value=theta_1_value, + distribution=theta_0_distribution, + children=set({y_0_key, y_1_key}), + ) + + x_distribution = dist.Normal(torch.tensor(0.0), torch.tensor(5.0)) + world_vars[x_0_key] = Variable( + value=x_0_value, + distribution=x_distribution, + children=set({y_0_key, y_1_key}), + ) + world_vars[x_1_key] = Variable( + value=x_1_value, + distribution=x_distribution, + children=set({y_0_key, y_1_key}), + ) + + y = theta_0_value + theta_1_value * x_0_value + probs_0 = 1 / (1 + (y * -1).exp()) + y_0_distribution = dist.Bernoulli(probs_0) + + world_vars[y_0_key] = Variable( + value=tensor(1.0), + distribution=y_0_distribution, + parents=set({theta_0_key, theta_1_key, x_0_key}), + ) + y = theta_0_value + theta_1_value * x_1_value + probs_1 = 1 / (1 + (y * -1).exp()) + y_1_distribution = dist.Bernoulli(probs_1) + world_vars[y_1_key] = Variable( + value=tensor(1.0), + distribution=y_1_distribution, + parents=set({theta_0_key, theta_1_key, x_1_key}), + ) + + prop = proposer(theta_0_key) + prop.learning_rate_ = 1.0 + prop_dist = prop.get_proposal_distribution(world).base_dist + mean, scale_tril = prop_dist.mean, prop_dist.scale_tril + + score = theta_0_distribution.log_prob(theta_0_value) + score += (1 / (1 + (-1 * (theta_0_value + theta_1_value * x_0_value)).exp())).log() + score += (1 / (1 + (-1 * (theta_0_value + theta_1_value * x_1_value)).exp())).log() + + expected_first_gradient = torch.autograd.grad( + score, theta_0_value, create_graph=True + )[0] + expected_second_gradient = torch.autograd.grad( + expected_first_gradient, theta_0_value + )[0] + + expected_covar = expected_second_gradient.reshape(1, 1).inverse() * -1 + expected_scale_tril = torch.linalg.cholesky(expected_covar) + assert approx(scale_tril, expected_scale_tril, 1e-3) + + expected_first_gradient = expected_first_gradient.unsqueeze(0) + expected_mean = ( + theta_0_value.unsqueeze(0) + + expected_first_gradient.unsqueeze(0).mm(expected_covar) + ).squeeze(0) + assert approx(mean, expected_mean, 1e-3) + + proposal_value = ( + dist.MultivariateNormal(mean, scale_tril=scale_tril) + .sample() + .reshape(theta_0_value.shape) + ) + proposal_value.requires_grad_(True) + world_vars[theta_0_key].value = proposal_value + + y = proposal_value + theta_1_value * x_0_value + probs_0 = 1 / (1 + (y * -1).exp()) + y_0_distribution = dist.Bernoulli(probs_0) + world_vars[y_0_key].distribution = y_0_distribution + world_vars[y_0_key].log_prob = y_0_distribution.log_prob(tensor(1.0)) + y = proposal_value + theta_1_value * x_1_value + probs_1 = 1 / (1 + (y * -1).exp()) + y_1_distribution = dist.Bernoulli(probs_1) + world_vars[y_1_key].distribution = y_1_distribution + + prop.learning_rate_ = 1.0 + prop_dist = prop.get_proposal_distribution(world).base_dist + mean, scale_tril = prop_dist.mean, prop_dist.scale_tril + + score = tensor(0.0) + + score = theta_0_distribution.log_prob(proposal_value) + score += (1 / (1 + (-1 * (proposal_value + theta_1_value * x_0_value)).exp())).log() + score += (1 / (1 + (-1 * (proposal_value + theta_1_value * x_1_value)).exp())).log() + + expected_first_gradient = torch.autograd.grad( + score, proposal_value, create_graph=True + )[0] + expected_second_gradient = torch.autograd.grad( + expected_first_gradient, proposal_value + )[0] + expected_covar = expected_second_gradient.reshape(1, 1).inverse() * -1 + expected_scale_tril = torch.linalg.cholesky(expected_covar) + assert approx(scale_tril, expected_scale_tril, 1e-3) + + expected_first_gradient = expected_first_gradient.unsqueeze(0) + expected_mean = ( + proposal_value.unsqueeze(0) + + expected_first_gradient.unsqueeze(0).mm(expected_covar) + ).squeeze(0) + assert approx(mean, expected_mean, 1e-3) + assert approx(scale_tril, expected_scale_tril, 1e-3) + + +@parametrize_model([LogisticRegressionModel()]) +@parametrize_value_expected( + [ + ( + tensor([0.0416, 0.079658, 0.0039118]).to(torch.float64), + [0.07863, 0.00384, 1.40321, 16.44271], + ), + ( + tensor([[0.0416, 0.0583], [0.079658, 0.089861], [0.0039118, 0.0041231]]).to( + torch.float64 + ), + [ + [0.07863, 0.08901], + [0.00384, 0.00403], + [1.40321, 1.69839], + [16.44271, 17.38294], + ], + ), + ] +) +def test_adaptive_alpha_beta_computation(model, proposer, value, expected): + theta_0_key = model.theta_0() + prop = proposer(theta_0_key) + prop.learning_rate_, prop.running_mean_, prop.running_var_ = value + prop.accepted_samples_ = 37 + alpha, beta = prop.compute_beta_priors_from_accepted_lr() + results = [prop.running_mean_, prop.running_var_, alpha, beta] + assert approx_all( + (torch.hstack if results[0].ndim == 0 else torch.vstack)(results), + expected, + 1e-5, + ) diff --git a/tests/ppl/inference/proposer/nmc/single_site_simplex_newtonian_monte_carlo_proposer_test.py b/tests/ppl/inference/proposer/nmc/single_site_simplex_newtonian_monte_carlo_proposer_test.py index cb37859a4c..0ae93d64f4 100644 --- a/tests/ppl/inference/proposer/nmc/single_site_simplex_newtonian_monte_carlo_proposer_test.py +++ b/tests/ppl/inference/proposer/nmc/single_site_simplex_newtonian_monte_carlo_proposer_test.py @@ -3,56 +3,46 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import unittest - import torch -import torch.distributions as dist -from beanmachine import ppl as bm +from beanmachine.ppl.examples.conjugate_models import BetaBernoulliModel +from beanmachine.ppl.examples.primitive_models import DirichletModel from beanmachine.ppl.inference.proposer.nmc import SingleSiteSimplexSpaceNMCProposer from beanmachine.ppl.inference.single_site_nmc import SingleSiteNewtonianMonteCarlo from beanmachine.ppl.world import World from torch import tensor - -class SingleSiteSimplexNewtonianMonteCarloProposerTest(unittest.TestCase): - def test_alpha_for_dirichlet(self): - alpha = tensor([[0.5, 0.5], [0.5, 0.5]]) - - @bm.random_variable - def a(): - return dist.Dirichlet(alpha) - - world_ = World() - with world_: - a() - nw_proposer = SingleSiteSimplexSpaceNMCProposer(a()) - is_valid, predicted_alpha = nw_proposer.compute_alpha(world_) - self.assertEqual(is_valid, True) - self.assertAlmostEqual( - alpha.sum().item(), (predicted_alpha).sum().item(), delta=0.0001 - ) - - def test_coin_flip(self): - prior_heads, prior_tails = 2.0, 2.0 - p = bm.random_variable(lambda: dist.Beta(2.0, 2.0)) - x = bm.random_variable(lambda: dist.Bernoulli(p())) - - heads_observed = 5 - samples = ( - SingleSiteNewtonianMonteCarlo() - .infer( - queries=[p()], - observations={x(): torch.ones(heads_observed)}, - num_samples=100, - num_chains=1, - ) - .get_chain(0) - ) - - # assert we are close to the conjugate poserior mean - self.assertAlmostEqual( - samples[p()].mean(), - (prior_heads + heads_observed) - / (prior_heads + prior_tails + heads_observed), - delta=0.05, - ) +from ....utils.fixtures import ( + approx, + parametrize_inference, + parametrize_model, + parametrize_proposer, +) + + +@parametrize_model([DirichletModel(tensor([[0.5, 0.5], [0.5, 0.5]]))]) +@parametrize_proposer([SingleSiteSimplexSpaceNMCProposer]) +def test_alpha_for_dirichlet(model, proposer): + world_ = World() + with world_: + model.x() + prop = proposer(model.x()) + is_valid, predicted_alpha = prop.compute_alpha(world_) + assert is_valid + assert approx(model.alpha.sum(), predicted_alpha.sum(), 1e-4) + + +@parametrize_model([BetaBernoulliModel(tensor(2.0), tensor(2.0))]) +@parametrize_inference([SingleSiteNewtonianMonteCarlo()]) +def test_coin_flip(model, inference): + prior_heads, prior_tails = model.alpha_, model.beta_ + heads_observed = 5 + conjugate_posterior_mean = (prior_heads + heads_observed) / ( + prior_heads + prior_tails + heads_observed + ) + samples = inference.infer( + queries=[model.theta()], + observations={model.x(0): torch.ones(heads_observed)}, + num_samples=100, + num_chains=1, + ).get_chain(0) + assert approx(samples[model.theta()].mean(), conjugate_posterior_mean, 5e-2) diff --git a/tests/ppl/inference/sampler_test.py b/tests/ppl/inference/sampler_test.py index f2281c61de..671db2bcd3 100644 --- a/tests/ppl/inference/sampler_test.py +++ b/tests/ppl/inference/sampler_test.py @@ -5,46 +5,50 @@ import beanmachine.ppl as bm import torch -import torch.distributions as dist +from beanmachine.ppl.examples.conjugate_models import NormalNormalModel +from torch import tensor +from ..utils.fixtures import ( + parametrize_inference, + parametrize_inference_comparison, + parametrize_model, +) -class SampleModel: - @bm.random_variable - def foo(self): - return dist.Normal(0.0, 1.0) - @bm.random_variable - def bar(self): - return dist.Normal(self.foo(), 1.0) +pytestmark = parametrize_model( + [NormalNormalModel(tensor(0.0), tensor(1.0), tensor(1.0))] +) -def test_sampler(): - model = SampleModel() - nuts = bm.GlobalNoUTurnSampler() - queries = [model.foo()] - observations = {model.bar(): torch.tensor(0.5)} +@parametrize_inference([bm.GlobalNoUTurnSampler(), bm.GlobalHamiltonianMonteCarlo(1.0)]) +def test_sampler(model, inference): + queries = [model.theta()] + observations = {model.x(): torch.tensor(0.5)} num_samples = 10 - sampler = nuts.sampler(queries, observations, num_samples, num_adaptive_samples=0) + sampler = inference.sampler( + queries, observations, num_samples, num_adaptive_samples=0 + ) worlds = list(sampler) assert len(worlds) == num_samples for world in worlds: - assert model.foo() in world + assert model.theta() in world with world: - assert isinstance(model.foo(), torch.Tensor) + assert isinstance(model.theta(), torch.Tensor) -def test_two_samplers(): - model = SampleModel() - queries = [model.foo()] - observations = {model.bar(): torch.tensor(0.5)} - nuts_sampler = bm.GlobalNoUTurnSampler().sampler(queries, observations) - hmc_sampler = bm.GlobalHamiltonianMonteCarlo(1.0).sampler(queries, observations) - world = next(nuts_sampler) +@parametrize_inference_comparison( + [bm.GlobalNoUTurnSampler(), bm.GlobalHamiltonianMonteCarlo(1.0)] +) +def test_two_samplers(model, inferences): + queries = [model.theta()] + observations = {model.x(): torch.tensor(0.5)} + samplers = [alg.sampler(queries, observations) for alg in inferences] + world = next(samplers[0]) # it's possible to use multiple sampler interchangably to update the worlds (or # in general, pass a new world to sampler and continue inference with existing # hyperparameters) for _ in range(3): - world = hmc_sampler.send(world) - world = nuts_sampler.send(world) - assert model.foo() in world - assert model.bar() in world + world = samplers[1].send(world) + world = samplers[0].send(world) + assert model.theta() in world + assert model.x() in world diff --git a/tests/ppl/inference/single_site_ancestral_mh_conjugate_test_nightly.py b/tests/ppl/inference/single_site_ancestral_mh_conjugate_test_nightly.py index 2c2c16ec24..e1e73b3ad4 100644 --- a/tests/ppl/inference/single_site_ancestral_mh_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_ancestral_mh_conjugate_test_nightly.py @@ -6,6 +6,7 @@ import unittest import beanmachine.ppl as bm + from ..testlib.abstract_conjugate import AbstractConjugateTests diff --git a/tests/ppl/inference/single_site_hamiltonian_monte_carlo_conjugate_test_nightly.py b/tests/ppl/inference/single_site_hamiltonian_monte_carlo_conjugate_test_nightly.py index 6593879bc3..dcf649e79e 100644 --- a/tests/ppl/inference/single_site_hamiltonian_monte_carlo_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_hamiltonian_monte_carlo_conjugate_test_nightly.py @@ -6,6 +6,7 @@ import unittest import beanmachine.ppl as bm + from ..testlib.abstract_conjugate import AbstractConjugateTests diff --git a/tests/ppl/inference/single_site_newtonian_monte_carlo_conjugate_test_nightly.py b/tests/ppl/inference/single_site_newtonian_monte_carlo_conjugate_test_nightly.py index abd8c9ddab..0dd198cf67 100644 --- a/tests/ppl/inference/single_site_newtonian_monte_carlo_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_newtonian_monte_carlo_conjugate_test_nightly.py @@ -6,6 +6,7 @@ import unittest import beanmachine.ppl as bm + from ..testlib.abstract_conjugate import AbstractConjugateTests diff --git a/tests/ppl/inference/single_site_no_u_turn_conjugate_test_nightly.py b/tests/ppl/inference/single_site_no_u_turn_conjugate_test_nightly.py index 172e2739fd..1f18cd45be 100644 --- a/tests/ppl/inference/single_site_no_u_turn_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_no_u_turn_conjugate_test_nightly.py @@ -6,6 +6,7 @@ import unittest import beanmachine.ppl as bm + from ..testlib.abstract_conjugate import AbstractConjugateTests diff --git a/tests/ppl/inference/single_site_random_walk_adaptive_conjugate_test_nightly.py b/tests/ppl/inference/single_site_random_walk_adaptive_conjugate_test_nightly.py index c8453e746c..08819e1854 100644 --- a/tests/ppl/inference/single_site_random_walk_adaptive_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_random_walk_adaptive_conjugate_test_nightly.py @@ -6,6 +6,7 @@ import unittest import beanmachine.ppl as bm + from ..testlib.abstract_conjugate import AbstractConjugateTests diff --git a/tests/ppl/inference/single_site_random_walk_conjugate_test_nightly.py b/tests/ppl/inference/single_site_random_walk_conjugate_test_nightly.py index 7a9ad76fed..117ce05d00 100644 --- a/tests/ppl/inference/single_site_random_walk_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_random_walk_conjugate_test_nightly.py @@ -6,6 +6,7 @@ import unittest import beanmachine.ppl as bm + from ..testlib.abstract_conjugate import AbstractConjugateTests diff --git a/tests/ppl/inference/single_site_random_walk_test.py b/tests/ppl/inference/single_site_random_walk_test.py index 857437b991..dc0dabf2ed 100644 --- a/tests/ppl/inference/single_site_random_walk_test.py +++ b/tests/ppl/inference/single_site_random_walk_test.py @@ -207,9 +207,9 @@ def test_single_site_adaptive_random_walk(self): mu=torch.tensor(0.0), std=torch.tensor(1.0), sigma=torch.ones(1) ) mh = bm.SingleSiteRandomWalk(step_size=4) - p_key = model.normal_p() + p_key = model.theta() queries = [p_key] - observations = {model.normal(): torch.tensor(100.0)} + observations = {model.x(): torch.tensor(100.0)} predictions = mh.infer(queries, observations, 100, num_adaptive_samples=30) predictions = predictions.get_chain()[p_key] self.assertIn(True, [45 < pred < 55 for pred in predictions]) @@ -223,9 +223,9 @@ def test_single_site_random_walk_rate(self): mu=torch.zeros(1), std=torch.ones(1), sigma=torch.ones(1) ) mh = bm.SingleSiteRandomWalk(step_size=10) - p_key = model.normal_p() + p_key = model.theta() queries = [p_key] - observations = {model.normal(): torch.tensor(100.0)} + observations = {model.x(): torch.tensor(100.0)} predictions = mh.infer(queries, observations, 100) predictions = predictions.get_chain()[p_key] self.assertIn(True, [45 < pred < 55 for pred in predictions]) @@ -235,9 +235,9 @@ def test_single_site_random_walk_rate_vector(self): mu=torch.zeros(2), std=torch.ones(2), sigma=torch.ones(2) ) mh = bm.SingleSiteRandomWalk(step_size=10) - p_key = model.normal_p() + p_key = model.theta() queries = [p_key] - observations = {model.normal(): torch.tensor([100.0, -100.0])} + observations = {model.x(): torch.tensor([100.0, -100.0])} predictions = mh.infer(queries, observations, 100) predictions = predictions.get_chain()[p_key] self.assertIn(True, [45 < pred[0] < 55 for pred in predictions]) @@ -248,9 +248,9 @@ def test_single_site_random_walk_half_support_rate(self): shape=torch.ones(1), rate=torch.ones(1), mu=torch.ones(1) ) mh = bm.SingleSiteRandomWalk(step_size=4.0) - p_key = model.gamma() + p_key = model.theta() queries = [p_key] - observations = {model.normal(): torch.tensor([100.0])} + observations = {model.x(): torch.tensor([100.0])} predictions = mh.infer(queries, observations, 100) predictions = predictions.get_chain()[p_key] """ @@ -285,9 +285,9 @@ def test_single_site_random_walk_interval_support_rate(self): def test_single_site_random_walk_simplex_support_rate(self): model = CategoricalDirichletModel(alpha=torch.tensor([1.0, 10.0])) mh = bm.SingleSiteRandomWalk(step_size=1.0) - p_key = model.dirichlet() + p_key = model.theta() queries = [p_key] - observations = {model.categorical(): torch.tensor([1.0, 1.0, 1.0])} + observations = {model.x(): torch.tensor([1.0, 1.0, 1.0])} predictions = mh.infer(queries, observations, 50) predictions = predictions.get_chain()[p_key] """ diff --git a/tests/ppl/inference/single_site_uniform_mh_conjugate_test_nightly.py b/tests/ppl/inference/single_site_uniform_mh_conjugate_test_nightly.py index a00836f4fb..642e12a411 100644 --- a/tests/ppl/inference/single_site_uniform_mh_conjugate_test_nightly.py +++ b/tests/ppl/inference/single_site_uniform_mh_conjugate_test_nightly.py @@ -6,6 +6,7 @@ import unittest import beanmachine.ppl as bm + from ..testlib.abstract_conjugate import AbstractConjugateTests diff --git a/tests/ppl/inference/single_site_uniform_mh_test.py b/tests/ppl/inference/single_site_uniform_mh_test.py index 6bdcb1059e..dd12068ccc 100644 --- a/tests/ppl/inference/single_site_uniform_mh_test.py +++ b/tests/ppl/inference/single_site_uniform_mh_test.py @@ -3,52 +3,32 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import unittest - import beanmachine.ppl as bm import torch -import torch.distributions as dist - - -class SingleSiteUniformMetropolisHastingsTest(unittest.TestCase): - class SampleBernoulliModel(object): - @bm.random_variable - def foo(self): - return dist.Beta(torch.tensor(2.0), torch.tensor(2.0)) - - @bm.random_variable - def bar(self): - return dist.Bernoulli(self.foo()) - - class SampleCategoricalModel(object): - @bm.random_variable - def foo(self): - return dist.Dirichlet(torch.tensor([0.5, 0.5])) - - @bm.random_variable - def bar(self): - return dist.Categorical(self.foo()) - - def test_single_site_uniform_mh_with_bernoulli(self): - model = self.SampleBernoulliModel() - mh = bm.SingleSiteUniformMetropolisHastings() - foo_key = model.foo() - bar_key = model.bar() - sampler = mh.sampler([foo_key], {bar_key: torch.tensor(0.0)}, num_samples=5) - for world in sampler: - self.assertTrue(foo_key in world) - self.assertTrue(bar_key in world) - self.assertTrue(foo_key in world.get_variable(bar_key).parents) - self.assertTrue(bar_key in world.get_variable(foo_key).children) - - def test_single_site_uniform_mh_with_categorical(self): - model = self.SampleCategoricalModel() - mh = bm.SingleSiteUniformMetropolisHastings() - foo_key = model.foo() - bar_key = model.bar() - sampler = mh.sampler([foo_key], {bar_key: torch.tensor(0.0)}, num_samples=5) - for world in sampler: - self.assertTrue(foo_key in world) - self.assertTrue(bar_key in world) - self.assertTrue(foo_key in world.get_variable(bar_key).parents) - self.assertTrue(bar_key in world.get_variable(foo_key).children) +from beanmachine.ppl.examples.conjugate_models import ( + BetaBernoulliModel, + CategoricalDirichletModel, +) +from torch import tensor + +from ..utils.fixtures import parametrize_inference, parametrize_model + + +pytestmark = parametrize_inference([bm.SingleSiteUniformMetropolisHastings()]) + + +@parametrize_model( + [ + BetaBernoulliModel(tensor(2.0), tensor(2.0)), + CategoricalDirichletModel(tensor([0.5, 0.5])), + ] +) +def test_single_site_uniform_mh(model, inference): + p_key = model.theta() + l_key = model.x(0) if model.x_dim == 1 else model.x() + sampler = inference.sampler([p_key], {l_key: torch.tensor(0.0)}, num_samples=5) + for world in sampler: + assert p_key in world + assert l_key in world + assert p_key in world.get_variable(l_key).parents + assert l_key in world.get_variable(p_key).children diff --git a/tests/ppl/inference/utils_test.py b/tests/ppl/inference/utils_test.py index f2bf4d0a71..3f19ee579d 100644 --- a/tests/ppl/inference/utils_test.py +++ b/tests/ppl/inference/utils_test.py @@ -4,37 +4,39 @@ # LICENSE file in the root directory of this source tree. import beanmachine.ppl as bm -import torch -import torch.distributions as dist +import pytest +from beanmachine.ppl.examples.primitive_models import NormalModel +from torch import tensor +from ..utils.fixtures import approx_all, parametrize_inference, parametrize_model -@bm.random_variable -def foo(): - return dist.Normal(0.0, 1.0) +pytestmark = [ + parametrize_model([NormalModel(tensor(0.0), tensor(1.0))]), + parametrize_inference([bm.SingleSiteAncestralMetropolisHastings()]), +] -def test_set_random_seed(): + +@pytest.mark.parametrize("seed", [123, 47]) +def test_set_random_seed(model, inference, seed): def sample_with_seed(seed): bm.seed(seed) - return bm.SingleSiteAncestralMetropolisHastings().infer( - [foo()], {}, num_samples=20, num_chains=1 - ) + return inference.infer([model.x()], {}, num_samples=20, num_chains=1) - samples1 = sample_with_seed(123) - samples2 = sample_with_seed(123) - assert torch.allclose(samples1[foo()], samples2[foo()]) + samples = (sample_with_seed(seed) for _ in range(2)) + assert approx_all(*(s[model.x()] for s in samples)) -def test_detach_samples(): +def test_detach_samples(model, inference): """Test to ensure samples are detached from torch computation graphs.""" - queries = [foo()] - samples = bm.SingleSiteAncestralMetropolisHastings().infer( + queries = [model.x()] + samples = inference.infer( queries=queries, observations={}, num_samples=20, num_chains=1, ) - rv_data = samples[foo()] + rv_data = samples[model.x()] idata = samples.to_inference_data() assert hasattr(rv_data, "detach") - assert not hasattr(idata["posterior"][foo()], "detach") + assert not hasattr(idata["posterior"][model.x()], "detach") diff --git a/tests/ppl/utils/fixtures.py b/tests/ppl/utils/fixtures.py new file mode 100644 index 0000000000..bd01cbf880 --- /dev/null +++ b/tests/ppl/utils/fixtures.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from operator import attrgetter +from typing import Callable, List, Tuple, Type + +import pytest +import torch.distributions as dist +from beanmachine.ppl.inference.base_inference import BaseInference +from beanmachine.ppl.inference.proposer.base_proposer import BaseProposer +from torch import allclose, is_tensor, isclose, Tensor, tensor + + +# numerical predicates ========================================================= + + +def _approx(comp_fn: Callable, result: Tensor, expected, atol=1e-8): + return comp_fn( + result, + expected if is_tensor(expected) else tensor(expected).to(result), + atol=atol, + ) + + +approx = partial(_approx, isclose) +approx_all = partial(_approx, allclose) + + +# fixture printers ============================================================= + + +def _is_model(arg): + return arg.__class__.__name__.endswith("Model") + + +_is_value_tensor = is_tensor + + +def _is_value_variable(arg): + return ( + isinstance(arg, tuple) + and isinstance(arg[0], dist.Distribution) + and is_tensor(arg[1]) + ) + + +def _is_value(arg): + return _is_value_tensor(arg) or _is_value_variable(arg) + + +def _is_inference(arg): + return isinstance(arg, BaseInference) + + +def _is_inferences(args): + return all(map(_is_inference, args)) + + +def _is_proposer(arg): + return issubclass(arg, BaseProposer) + + +_id_empty = "" +_id_model = attrgetter("__class__") + + +def _id_value_tensor(arg): + return f"Tensor{tuple(arg.shape)}" + + +def _id_value_variable(arg): + return f"Variable{tuple(arg[1].shape)}" + + +def _id_value(arg): + return _id_value_tensor(arg) if _is_value_tensor(arg) else _id_value_variable(arg) + + +_id_inference = attrgetter("__class__") + + +def _id_inferences(args): + return f"({','.join([a.__class__.__name__ for a in args])})" + + +_id_proposer = None # default printer + + +def _id_model_value(arg): + return _id_model(arg) if _is_model(arg) else _id_value(arg) + + +def _id_value_expected(arg): + return _id_value(arg) if _is_value(arg) else _id_empty + + +def _id_model_value_expected(arg): + return _id_model(arg) if _is_model(arg) else _id_value_expected(arg) + + +# fixtures ===================================================================== + + +def parametrize_model(models: List): + assert all(map(_is_model, models)) + return pytest.mark.parametrize("model", models, ids=_id_model) + + +def parametrize_value(args: List[Tuple]): + assert all(map(_is_value, args)) + return pytest.mark.parametrize("value", args, ids=_id_value) + + +def parametrize_model_value(args: List[Tuple]): + assert all(isinstance(a, tuple) for a in args) + return pytest.mark.parametrize("model, value", args, ids=_id_model_value) + + +def parametrize_value_expected(args: List[Tuple]): + assert all(isinstance(a, tuple) for a in args) + return pytest.mark.parametrize("value, expected", args, ids=_id_value_expected) + + +def parametrize_model_value_expected(args: List[Tuple]): + return pytest.mark.parametrize( + "model, value, expected", args, ids=_id_model_value_expected + ) + + +def parametrize_inference(methods: List[BaseInference]): + assert _is_inferences(methods) + return pytest.mark.parametrize("inference", methods, ids=_id_inference) + + +def parametrize_inference_comparison(methods: List[BaseInference]): + assert _is_inferences(methods) + return pytest.mark.parametrize("inferences", [methods], ids=_id_inferences) + + +def parametrize_proposer(methods: List[Type[BaseProposer]]): + assert all(map(_is_proposer, methods)) + return pytest.mark.parametrize("proposer", methods, ids=_id_proposer)