diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py index 95f1f574fa..035d96c06b 100644 --- a/pyro/poutine/equalize_messenger.py +++ b/pyro/poutine/equalize_messenger.py @@ -55,7 +55,7 @@ def __enter__(self) -> Self: self.value = None return super().__enter__() - def _is_matching(self, msg: Message): + def _is_matching(self, msg: Message) -> bool: if msg["type"] == self.type: for site in self.sites: if re.compile(site).fullmatch(msg["name"]) is not None: # type: ignore[arg-type] diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index 944d26988d..7e2f7cfa8e 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -755,6 +755,50 @@ def test_infer_config_sample(self): assert tr.nodes["p"]["infer"] == {} +class EqualizeHandlerTests(TestCase): + def setUp(self): + def per_category_model(category): + shift = pyro.param(f"{category}_shift", torch.randn(1)) + mean = pyro.sample(f"{category}_mean", pyro.distributions.Normal(0, 1)) + std = pyro.sample(f"{category}_std", pyro.distributions.LogNormal(0, 1)) + with pyro.plate(f"{category}_num_samples", 5): + return pyro.sample( + f"{category}_values", pyro.distributions.Normal(mean + shift, std) + ) + + def model(categories=["dogs", "cats"]): + return {category: per_category_model(category) for category in categories} + + self.model = model + + def test_sample_site_equalization(self): + pyro.set_rng_seed(20240616) + pyro.clear_param_store() + model = poutine.equalize(self.model, ".+_std") + tr = pyro.poutine.trace(model).get_trace() + assert_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"]) + assert_not_equal( + tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"] + ) + guide = pyro.infer.autoguide.AutoNormal(model) + guide_sites = [*guide()] + assert guide_sites == [ + "dogs_mean", + "dogs_std", + "dogs_values", + "cats_mean", + "cats_values", + ] + + def test_param_equalization(self): + pyro.set_rng_seed(20240616) + pyro.clear_param_store() + model = poutine.equalize(self.model, ".+_shift", "param") + tr = pyro.poutine.trace(model).get_trace() + assert_equal(tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"]) + assert_not_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"]) + + @pytest.mark.parametrize("first_available_dim", [-1, -2, -3]) @pytest.mark.parametrize("depth", [0, 1, 2]) def test_enumerate_poutine(depth, first_available_dim):