From b42d66533e0daf06414afbe3bb0a074b912c6111 Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Tue, 23 Jul 2024 09:12:50 +0300 Subject: [PATCH 1/2] Fix models not rendering after application of the equalize effect handler. --- pyro/poutine/equalize_messenger.py | 2 +- tests/poutine/test_poutines.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py index 035d96c06b..e17693267b 100644 --- a/pyro/poutine/equalize_messenger.py +++ b/pyro/poutine/equalize_messenger.py @@ -72,6 +72,6 @@ def _process_message(self, msg: Message) -> None: if self.value is not None and self._is_matching(msg): # type: ignore[unreachable] msg["value"] = self.value # type: ignore[unreachable] if msg["type"] == "sample": - msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim) + msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False) msg["infer"] = {"_deterministic": True} msg["is_observed"] = True diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index 7e2f7cfa8e..c06a2a8778 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -798,6 +798,12 @@ def test_param_equalization(self): assert_equal(tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"]) assert_not_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"]) + def test_render_model(self): + pyro.set_rng_seed(20240616) + pyro.clear_param_store() + model = poutine.equalize(self.model, ".+_std") + pyro.render_model(model) + @pytest.mark.parametrize("first_available_dim", [-1, -2, -3]) @pytest.mark.parametrize("depth", [0, 1, 2]) From b19402dda94928966d8474ea478e25b43db9544e Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Tue, 23 Jul 2024 11:17:45 +0300 Subject: [PATCH 2/2] Ignore newly added mypy callability check. --- pyro/nn/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/nn/module.py b/pyro/nn/module.py index b84a17875b..05190e24d7 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -138,7 +138,7 @@ def __get__( if name not in obj.__dict__["_pyro_params"]: init_value, constraint, event_dim = self # bind method's self arg - init_value = functools.partial(init_value, obj) # type: ignore[arg-type] + init_value = functools.partial(init_value, obj) # type: ignore[arg-type,misc,operator] setattr(obj, name, PyroParam(init_value, constraint, event_dim)) value: PyroParam = obj.__getattr__(name) return value