diff --git a/pyro/infer/__init__.py b/pyro/infer/__init__.py index c0f3a26c3f..6934bd29fe 100644 --- a/pyro/infer/__init__.py +++ b/pyro/infer/__init__.py @@ -12,7 +12,7 @@ from pyro.infer.mcmc.hmc import HMC from pyro.infer.mcmc.nuts import NUTS from pyro.infer.mcmc.rwkernel import RandomWalkKernel -from pyro.infer.predictive import Predictive +from pyro.infer.predictive import Predictive, WeighedPredictive from pyro.infer.renyi_elbo import RenyiELBO from pyro.infer.rws import ReweightedWakeSleep from pyro.infer.smcfilter import SMCFilter @@ -62,4 +62,5 @@ "TraceTailAdaptive_ELBO", "Trace_ELBO", "Trace_MMD", + "WeighedPredictive", ] diff --git a/pyro/infer/importance.py b/pyro/infer/importance.py index d7c25a843d..d25cf16680 100644 --- a/pyro/infer/importance.py +++ b/pyro/infer/importance.py @@ -12,6 +12,7 @@ from .abstract_infer import TracePosterior from .enum import get_importance_trace +from .util import plate_log_prob_sum class Importance(TracePosterior): @@ -143,22 +144,9 @@ def _fn(*args, **kwargs): log_weights = model_trace.log_prob_sum() - guide_trace.log_prob_sum() else: wd = guide_trace.plate_to_symbol["num_particles_vectorized"] - log_weights = 0.0 - for site in model_trace.nodes.values(): - if site["type"] != "sample": - continue - log_weights += torch.einsum( - site["packed"]["log_prob"]._pyro_dims + "->" + wd, - [site["packed"]["log_prob"]], - ) - - for site in guide_trace.nodes.values(): - if site["type"] != "sample": - continue - log_weights -= torch.einsum( - site["packed"]["log_prob"]._pyro_dims + "->" + wd, - [site["packed"]["log_prob"]], - ) + log_weights = plate_log_prob_sum(model_trace, wd) - plate_log_prob_sum( + guide_trace, wd + ) if normalized: log_weights = log_weights - torch.logsumexp(log_weights) diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index 9d8b1c7f76..6be8b5cb5f 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -3,11 +3,14 @@ import warnings from functools import reduce +from typing import List, NamedTuple, Union import torch import pyro import pyro.poutine as poutine +from pyro.infer.util import plate_log_prob_sum +from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites @@ -31,16 +34,20 @@ def _guess_max_plate_nesting(model, args, kwargs): return max_plate_nesting +class _predictiveResults(NamedTuple): + """ + Return value of call to ``_predictive`` and ``_predictive_sequential``. 
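+
+    ``samples`` maps each returned site name to a tensor of collected values,
+    while ``trace`` holds the corresponding trace (a list of traces when
+    sampling sequentially via ``_predictive_sequential``).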
+ """ + + samples: dict + trace: Union[Trace, List[Trace]] + + def _predictive_sequential( - model, - posterior_samples, - model_args, - model_kwargs, - num_samples, - return_site_shapes, - return_trace=False, + model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes ): - collected = [] + collected_samples = [] + collected_trace = [] samples = [ {k: v[i] for k, v in posterior_samples.items()} for i in range(num_samples) ] @@ -48,20 +55,21 @@ def _predictive_sequential( trace = poutine.trace(poutine.condition(model, samples[i])).get_trace( *model_args, **model_kwargs ) - if return_trace: - collected.append(trace) - else: - collected.append( - {site: trace.nodes[site]["value"] for site in return_site_shapes} - ) + collected_trace.append(trace) + collected_samples.append( + {site: trace.nodes[site]["value"] for site in return_site_shapes} + ) - if return_trace: - return collected - else: - return { - site: torch.stack([s[site] for s in collected]).reshape(shape) + return _predictiveResults( + trace=collected_trace, + samples={ + site: torch.stack([s[site] for s in collected_samples]).reshape(shape) for site, shape in return_site_shapes.items() - } + }, + ) + + +_predictive_vectorize_plate_name = "_num_predictive_samples" def _predictive( @@ -69,15 +77,15 @@ def _predictive( posterior_samples, num_samples, return_sites=(), - return_trace=False, parallel=False, model_args=(), model_kwargs={}, + mask=True, ): - model = torch.no_grad()(poutine.mask(model, mask=False)) + model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model) max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs) vectorize = pyro.plate( - "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1 + _predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1 ) model_trace = prune_subsample_sites( poutine.trace(model).get_trace(*model_args, **model_kwargs) @@ -93,12 +101,6 @@ def _predictive( ) reshaped_samples[name] = sample - if return_trace: - trace = poutine.trace( - poutine.condition(vectorize(model), reshaped_samples) - ).get_trace(*model_args, **model_kwargs) - return trace - return_site_shapes = {} for site in model_trace.stochastic_nodes + model_trace.observation_nodes: append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape) @@ -131,7 +133,6 @@ def _predictive( model_kwargs, num_samples, return_site_shapes, - return_trace=False, ) trace = poutine.trace( @@ -148,7 +149,7 @@ def _predictive( else: predictions[site] = value.reshape(shape) - return predictions + return _predictiveResults(trace=trace, samples=predictions) class Predictive(torch.nn.Module): @@ -269,7 +270,7 @@ def forward(self, *args, **kwargs): parallel=self.parallel, model_args=args, model_kwargs=kwargs, - ) + ).samples return _predictive( self.model, posterior_samples, @@ -278,7 +279,7 @@ def forward(self, *args, **kwargs): parallel=self.parallel, model_args=args, model_kwargs=kwargs, - ) + ).samples def get_samples(self, *args, **kwargs): warnings.warn( @@ -304,12 +305,144 @@ def get_vectorized_trace(self, *args, **kwargs): parallel=self.parallel, model_args=args, model_kwargs=kwargs, - ) + ).samples return _predictive( self.model, posterior_samples, self.num_samples, - return_trace=True, + parallel=True, model_args=args, model_kwargs=kwargs, + ).trace + + +class WeighedPredictiveResults(NamedTuple): + """ + Return value of call to instance of :class:`WeighedPredictive`. 
+ """ + + samples: Union[dict, tuple] + log_weights: torch.Tensor + guide_log_prob: torch.Tensor + model_log_prob: torch.Tensor + + +class WeighedPredictive(Predictive): + """ + Class used to construct a weighed predictive distribution that is based + on the same initialization interface as :class:`Predictive`. + + The methods `.forward` and `.call` can be called with an additional keyword argument + ``model_guide`` which is the model used to create and optimize the guide (if not + provided ``model_guide`` defaults to ``self.model``), and they return both samples and log_weights. + + The weights are calculated as the per sample gap between the model_guide log-probability + and the guide log-probability (a guide must always be provided). + + A typical use case would be based on a ``model`` :math:`p(x,z)=p(x|z)p(z)` and ``guide`` :math:`q(z)` + that has already been fitted to the model given observations :math:`p(X_{obs},z)`, both of which + are provided at itialization of :class:`WeighedPredictive` (same as you would do with :class:`Predictive`). + When calling an instance of :class:`WeighedPredictive` we provide the model given observations :math:`p(X_{obs},z)` + as the keyword argument ``model_guide``. + The resulting output would be the usual samples :math:`p(x|z)q(z)` returned by :class:`Predictive`, + along with per sample weights :math:`p(X_{obs},z)/q(z)`. The samples and weights can be fed into + :any:`weighed_quantile` in order to obtain the true quantiles of the resulting distribution. + + Note that the ``model`` can be more elaborate with sample sites :math:`y` that are not observed + and are not part of the guide, if the samples sites :math:`y` are sampled after the observations + and the latent variables sampled by the guide, such that :math:`p(x,y,z)=p(y|x,z)p(x|z)p(z)` where + each element in the product represents a set of ``pyro.sample`` statements. + """ + + def call(self, *args, **kwargs): + """ + Method `.call` that is backwards compatible with the same method found in :class:`Predictive` + but can be called with an additional keyword argument `model_guide` + which is the model used to create and optimize the guide. + + Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample + weights ``.log_weights``. + """ + result = self.forward(*args, **kwargs) + return WeighedPredictiveResults( + samples=tuple(v for _, v in sorted(result.items())), + log_weights=result.log_weights, + guide_log_prob=result.guide_log_prob, + model_log_prob=result.model_log_prob, + ) + + def forward(self, *args, **kwargs): + """ + Method `.forward` that is backwards compatible with the same method found in :class:`Predictive` + but can be called with an additional keyword argument `model_guide` + which is the model used to create and optimize the guide. + + Returns :class:`WeighedPredictiveResults` which has attributes ``.samples`` and per sample + weights ``.log_weights``. + """ + model_guide = kwargs.pop("model_guide", self.model) + return_sites = self.return_sites + # return all sites by default if a guide is provided. 
+        return_sites = self.return_sites
+        # return all sites by default if a guide is provided.
+        return_sites = None if not return_sites else return_sites
+        guide_predictive = _predictive(
+            self.guide,
+            self.posterior_samples,
+            self.num_samples,
+            return_sites=None,
+            parallel=self.parallel,
+            model_args=args,
+            model_kwargs=kwargs,
+            mask=False,
+        )
+        posterior_samples = guide_predictive.samples
+        model_predictive = _predictive(
+            model_guide,
+            posterior_samples,
+            self.num_samples,
+            return_sites=return_sites,
+            parallel=self.parallel,
+            model_args=args,
+            model_kwargs=kwargs,
+            mask=False,
+        )
+        if not isinstance(guide_predictive.trace, list):
+            guide_trace = prune_subsample_sites(guide_predictive.trace)
+            model_trace = prune_subsample_sites(model_predictive.trace)
+            guide_trace.compute_score_parts()
+            model_trace.compute_log_prob()
+            guide_trace.pack_tensors()
+            model_trace.pack_tensors(guide_trace.plate_to_symbol)
+            plate_symbol = guide_trace.plate_to_symbol[_predictive_vectorize_plate_name]
+            guide_log_prob = plate_log_prob_sum(guide_trace, plate_symbol)
+            model_log_prob = plate_log_prob_sum(model_trace, plate_symbol)
+        else:
+            guide_log_prob = torch.Tensor(
+                [
+                    trace_element.log_prob_sum()
+                    for trace_element in guide_predictive.trace
+                ]
+            )
+            model_log_prob = torch.Tensor(
+                [
+                    trace_element.log_prob_sum()
+                    for trace_element in model_predictive.trace
+                ]
+            )
+        return WeighedPredictiveResults(
+            samples=(
+                _predictive(
+                    self.model,
+                    posterior_samples,
+                    self.num_samples,
+                    return_sites=return_sites,
+                    parallel=self.parallel,
+                    model_args=args,
+                    model_kwargs=kwargs,
+                ).samples
+                if model_guide is not self.model
+                else model_predictive.samples
+            ),
+            log_weights=model_log_prob - guide_log_prob,
+            guide_log_prob=guide_log_prob,
+            model_log_prob=model_log_prob,
+        )
diff --git a/pyro/infer/util.py b/pyro/infer/util.py
index 7ea460c1ec..13e1d9e12f 100644
--- a/pyro/infer/util.py
+++ b/pyro/infer/util.py
@@ -14,6 +14,7 @@
 from pyro.ops import packed
 from pyro.ops.einsum.adjoint import require_backward
 from pyro.ops.rings import MarginalRing
+from pyro.poutine.trace_struct import Trace
 from pyro.poutine.util import site_is_subsample
 
 from .. import settings
@@ -342,3 +343,18 @@ def check_fully_reparametrized(guide_site):
     raise NotImplementedError(
         "All distributions in the guide must be fully reparameterized."
     )
+
+
+def plate_log_prob_sum(trace: Trace, plate_symbol: str) -> torch.Tensor:
+    """
+    Get the log probability sum from a trace while keeping indexing over the specified plate.
+    """
+    log_prob_sum = 0.0
+    for site in trace.nodes.values():
+        if site["type"] != "sample":
+            continue
+        log_prob_sum += torch.einsum(
+            site["packed"]["log_prob"]._pyro_dims + "->" + plate_symbol,
+            [site["packed"]["log_prob"]],
+        )
+    return log_prob_sum
diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py
index 2ec57d4784..8e0bd2631f 100644
--- a/pyro/ops/stats.py
+++ b/pyro/ops/stats.py
@@ -277,14 +277,17 @@ def weighed_quantile(
     :param int dim: dimension to take quantiles from ``input``.
     :returns torch.Tensor: quantiles of ``input`` at ``probs``.
 
-    Example:
-        >>> from pyro.ops.stats import weighed_quantile
-        >>> import torch
-        >>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
-        >>> probs = torch.Tensor([0.2, 0.8])
-        >>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
-        >>> result = weighed_quantile(input, probs, log_weights, -1)
-        >>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
+    **Example:**
+
+    .. doctest::
+
+        >>> from pyro.ops.stats import weighed_quantile
+        >>> import torch
+        >>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
+        >>> probs = torch.Tensor([0.2, 0.8])
+        >>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
+        >>> result = weighed_quantile(input, probs, log_weights, -1)
+        >>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
     """
     dim = dim if dim >= 0 else (len(input.shape) + dim)
     if isinstance(probs, (list, tuple)):
diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py
index fc6f63fa37..1f28e1f05c 100644
--- a/tests/infer/test_predictive.py
+++ b/tests/infer/test_predictive.py
@@ -8,7 +8,7 @@
 import pyro.distributions as dist
 import pyro.optim as optim
 import pyro.poutine as poutine
-from pyro.infer import SVI, Predictive, Trace_ELBO
+from pyro.infer import SVI, Predictive, Trace_ELBO, WeighedPredictive
 from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal
 from tests.common import assert_close
 
@@ -39,29 +39,44 @@ def beta_guide(num_trials):
     pyro.sample("phi", phi_posterior)
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("parallel", [False, True])
-def test_posterior_predictive_svi_manual_guide(parallel):
+def test_posterior_predictive_svi_manual_guide(parallel, predictive):
     true_probs = torch.ones(5) * 0.7
-    num_trials = torch.ones(5) * 1000
+    num_trials = (
+        torch.ones(5) * 400
+    )  # Reduced from 1000 to 400 so that guide optimization converges
     num_success = dist.Binomial(num_trials, true_probs).sample()
     conditioned_model = poutine.condition(model, data={"obs": num_success})
     elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
-    svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=1.0)), elbo)
-    for i in range(1000):
+    svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=3.0)), elbo)
+    for i in range(
+        5000
+    ):  # Increased from 1000 to 5000 so that guide optimization converges
         svi.step(num_trials)
-    posterior_predictive = Predictive(
+    posterior_predictive = predictive(
         model,
         guide=beta_guide,
         num_samples=10000,
         parallel=parallel,
         return_sites=["_RETURN"],
     )
-    marginal_return_vals = posterior_predictive(num_trials)["_RETURN"]
-    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
+    if predictive is Predictive:
+        marginal_return_vals = posterior_predictive(num_trials)["_RETURN"]
+    else:
+        weighed_samples = posterior_predictive(
+            num_trials, model_guide=conditioned_model
+        )
+        marginal_return_vals = weighed_samples.samples["_RETURN"]
+        assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
+        # Weights should be uniform since the guide has the same distribution as the model
+        assert weighed_samples.log_weights.std() < 0.6
+    assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("parallel", [False, True])
-def test_posterior_predictive_svi_auto_delta_guide(parallel):
+def test_posterior_predictive_svi_auto_delta_guide(parallel, predictive):
     true_probs = torch.ones(5) * 0.7
     num_trials = torch.ones(5) * 1000
     num_success = dist.Binomial(num_trials, true_probs).sample()
@@ -70,15 +85,23 @@ def test_posterior_predictive_svi_auto_delta_guide(parallel):
     svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=1.0)), Trace_ELBO())
     for i in range(1000):
         svi.step(num_trials)
-    posterior_predictive = Predictive(
+    posterior_predictive = predictive(
         model, guide=guide, num_samples=10000, parallel=parallel
     )
-    marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
+    if predictive is Predictive:
+        marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
+    else:
+        weighed_samples = posterior_predictive.get_samples(
+            num_trials, model_guide=conditioned_model
+        )
+        marginal_return_vals = weighed_samples.samples["obs"]
+        assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
     assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("return_trace", [False, True])
-def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
+def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace, predictive):
     true_probs = torch.ones(5) * 0.7
     num_trials = torch.ones(5) * 1000
     num_success = dist.Binomial(num_trials, true_probs).sample()
@@ -87,15 +110,22 @@ def test_posterior_predictive_svi_auto_diag_normal_guide(return_trace):
     svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
     for i in range(1000):
         svi.step(num_trials)
-    posterior_predictive = Predictive(
+    posterior_predictive = predictive(
         model, guide=guide, num_samples=10000, parallel=True
     )
     if return_trace:
         marginal_return_vals = posterior_predictive.get_vectorized_trace(
             num_trials
         ).nodes["obs"]["value"]
     else:
-        marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
+        if predictive is Predictive:
+            marginal_return_vals = posterior_predictive.get_samples(num_trials)["obs"]
+        else:
+            weighed_samples = posterior_predictive.get_samples(
+                num_trials, model_guide=conditioned_model
+            )
+            marginal_return_vals = weighed_samples.samples["obs"]
+            assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
     assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 700, rtol=0.05)
 
 
@@ -113,8 +143,9 @@ def test_posterior_predictive_svi_one_hot():
     assert_close(marginal_return_vals.mean(dim=0), true_probs.unsqueeze(0), rtol=0.1)
 
 
+@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
 @pytest.mark.parametrize("parallel", [False, True])
-def test_shapes(parallel):
+def test_shapes(parallel, predictive):
     num_samples = 10
 
     def model():
@@ -132,22 +163,26 @@ def model():
     expected = poutine.replay(vectorize(model), trace)()
 
     # Use Predictive.
- predictive = Predictive( + actual = predictive( model, guide=guide, return_sites=["x", "y"], num_samples=num_samples, parallel=parallel, - ) - actual = predictive() + )() + if predictive is WeighedPredictive: + assert actual.samples["x"].shape[:1] == actual.log_weights.shape + assert actual.samples["y"].shape[:1] == actual.log_weights.shape + actual = actual.samples assert set(actual) == set(expected) assert actual["x"].shape == expected["x"].shape assert actual["y"].shape == expected["y"].shape +@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive]) @pytest.mark.parametrize("with_plate", [True, False]) @pytest.mark.parametrize("event_shape", [(), (2,)]) -def test_deterministic(with_plate, event_shape): +def test_deterministic(with_plate, event_shape, predictive): def model(y=None): with pyro.util.optional(pyro.plate("plate", 3), with_plate): x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event()) @@ -162,9 +197,13 @@ def model(y=None): for i in range(100): svi.step(y) - actual = Predictive( + actual = predictive( model, guide=guide, return_sites=["x2", "x3"], num_samples=1000 )() + if predictive is WeighedPredictive: + assert actual.samples["x2"].shape[:1] == actual.log_weights.shape + assert actual.samples["x3"].shape[:1] == actual.log_weights.shape + actual = actual.samples x2_batch_shape = (3,) if with_plate else () assert actual["x2"].shape == (1000,) + x2_batch_shape + event_shape # x3 shape is prepended 1 to match Pyro shape semantics