From 9488da816c9568ab1587dd989f53fea9befe00b2 Mon Sep 17 00:00:00 2001
From: OlaRonning
Date: Mon, 22 Apr 2024 13:21:41 +0200
Subject: [PATCH] sketched return deterministic from guide in predictive

---
Review note: in Predictive.forward the type check now inspects
`guide_pred_res.trace` instead of the result object itself. The surrounding
code reads `.samples` and `.trace` from that object, and the two branches
already choose between `.trace` and `.trace[0]`, so the check has to look at
the `.trace` field to select the right branch. The new docstring entry also
gains its `bool` type and default. No line counts changed.

 pyro/infer/predictive.py       | 46 +++++++++++++++++++++++++++++++---
 pyro/util.py                   |  8 +++++-
 tests/infer/test_predictive.py | 38 ++++++++++++++++++++++++++++
 3 files changed, 87 insertions(+), 5 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index e30099c85e..eb0d518013 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -60,7 +60,14 @@ def _predictive_sequential(
         )
         collected_trace.append(trace)
         collected_samples.append(
-            {site: trace.nodes[site]["value"] for site in return_site_shapes}
+            {
+                site: (
+                    trace.nodes[site]["value"]
+                    if site in trace.nodes
+                    else samples[i][site]
+                )
+                for site in return_site_shapes
+            }
         )
 
     return _predictiveResults(
@@ -84,6 +91,7 @@ def _predictive(
     model_args=(),
     model_kwargs={},
     mask=True,
+    posterior_deterministic_sites=(),
 ):
     model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
     max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
@@ -122,6 +130,9 @@ def _predictive(
         elif site not in posterior_samples:
             return_site_shapes[site] = site_shape
 
+    for site in posterior_deterministic_sites:
+        return_site_shapes[site] = posterior_samples[site].shape
+
     # handle _RETURN site
     if return_sites is not None and "_RETURN" in return_sites:
         value = model_trace.nodes["_RETURN"]["value"]
@@ -143,7 +154,10 @@ def _predictive(
     ).get_trace(*model_args, **model_kwargs)
     predictions = {}
     for site, shape in return_site_shapes.items():
-        value = trace.nodes[site]["value"]
+        if site in trace.nodes:
+            value = trace.nodes[site]["value"]
+        else:
+            value = reshaped_samples[site]
         if site == "_RETURN" and shape is None:
             predictions[site] = value
             continue
@@ -179,6 +193,8 @@ class Predictive(torch.nn.Module):
     :param bool parallel: predict in parallel by wrapping the existing model
         in an outermost `plate` messenger. Note that this requires that the model has
         all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`.
+    :param bool return_deterministic_guide_sites: include deterministic sites from the
+        guide in returned samples; this does not affect the returned trace. Default is `False`.
     """
 
     def __init__(
@@ -189,6 +205,7 @@ def __init__(
         num_samples=None,
         return_sites=(),
         parallel=False,
+        return_deterministic_guide_sites=False,
     ):
         super().__init__()
         if posterior_samples is None:
@@ -231,6 +248,7 @@ def __init__(
         self.guide = guide
         self.return_sites = return_sites
         self.parallel = parallel
+        self.return_deterministic_guide_sites = return_deterministic_guide_sites
 
     def call(self, *args, **kwargs):
         """
@@ -262,10 +280,13 @@ def forward(self, *args, **kwargs):
         """
         posterior_samples = self.posterior_samples
         return_sites = self.return_sites
+
+        guide_deterministic_sites = ()
+
         if self.guide is not None:
             # return all sites by default if a guide is provided.
             return_sites = None if not return_sites else return_sites
-            posterior_samples = _predictive(
+            guide_pred_res = _predictive(
                 self.guide,
                 posterior_samples,
                 self.num_samples,
@@ -273,7 +294,23 @@ def forward(self, *args, **kwargs):
                 parallel=self.parallel,
                 model_args=args,
                 model_kwargs=kwargs,
-            ).samples
+            )
+            posterior_samples = guide_pred_res.samples
+
+            if self.return_deterministic_guide_sites:
+                if isinstance(guide_pred_res.trace, Trace):
+                    guide_tr = guide_pred_res.trace
+                else:
+                    guide_tr = guide_pred_res.trace[0]
+
+                guide_deterministic_sites = tuple(
+                    name
+                    for name, site in guide_tr.nodes.items()
+                    if site["type"] == "sample"
+                    if site["infer"].get("_deterministic")
+                    if (return_sites is None or name in return_sites)
+                )
+
         return _predictive(
             self.model,
             posterior_samples,
@@ -282,6 +319,7 @@ def forward(self, *args, **kwargs):
             parallel=self.parallel,
             model_args=args,
             model_kwargs=kwargs,
+            posterior_deterministic_sites=guide_deterministic_sites,
         ).samples
 
     def get_samples(self, *args, **kwargs):
diff --git a/pyro/util.py b/pyro/util.py
index 6c89e8fa26..267276ec17 100644
--- a/pyro/util.py
+++ b/pyro/util.py
@@ -266,6 +266,12 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
         if site["type"] == "sample"
         if site["infer"].get("is_auxiliary")
     )
+    det_vars = set(
+        name
+        for name, site in guide_trace.nodes.items()
+        if site["type"] == "sample"
+        if site["infer"].get("_deterministic")
+    )
     model_vars = set(
         name
         for name, site in model_trace.nodes.items()
@@ -284,7 +290,7 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
         warnings.warn(
             "Found auxiliary vars in the model: {}".format(aux_vars & model_vars)
         )
-    if not (guide_vars <= model_vars | aux_vars):
+    if not (guide_vars <= model_vars | aux_vars | det_vars):
         warnings.warn(
             "Found non-auxiliary vars in guide but not model, "
             "consider marking these infer={{'is_auxiliary': True}}:\n{}".format(
diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py
index ca155ed2fd..ed7c6c28c3 100644
--- a/tests/infer/test_predictive.py
+++ b/tests/infer/test_predictive.py
@@ -269,6 +269,44 @@ def model(y=None):
     assert_close(actual["x3"].mean(), y, rtol=0.1)
 
 
+@pytest.mark.parametrize("with_plate", [True, False])
+@pytest.mark.parametrize("event_shape", [(), (2,)])
+@pytest.mark.parametrize("return_deterministic_guide_sites", [True, False])
+def test_deterministic_guide_return(
+    with_plate, event_shape, return_deterministic_guide_sites
+):
+    def model(y=None):
+        with pyro.util.optional(pyro.plate("plate", 3), with_plate):
+            x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event())
+            x2 = pyro.deterministic("x2", x**2, event_dim=len(event_shape))
+
+        pyro.deterministic("x3", x2)
+        return pyro.sample("obs", dist.Normal(x2, 0.1).to_event(), obs=y)
+
+    def guide(y=None):
+        with pyro.util.optional(pyro.plate("plate", 3), with_plate):
+            x = pyro.sample("x", dist.Normal(0, 1).expand(event_shape).to_event())
+
+        pyro.deterministic("x4", x)
+
+    y = torch.tensor(4.0)
+    svi = SVI(model, guide, optim.Adam(dict(lr=0.1)), Trace_ELBO())
+    for i in range(100):
+        svi.step(y)
+
+    actual = Predictive(
+        model,
+        guide=guide,
+        num_samples=1000,
+        return_deterministic_guide_sites=return_deterministic_guide_sites,
+    )()
+
+    if return_deterministic_guide_sites:
+        assert "x4" in actual
+    else:
+        assert "x4" not in actual
+
+
 def test_get_mask_optimization():
     def model():
         x = pyro.sample("x", dist.Normal(0, 1))