diff --git a/pyro/nn/module.py b/pyro/nn/module.py index b84a17875b..05190e24d7 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -138,7 +138,7 @@ def __get__( if name not in obj.__dict__["_pyro_params"]: init_value, constraint, event_dim = self # bind method's self arg - init_value = functools.partial(init_value, obj) # type: ignore[arg-type] + init_value = functools.partial(init_value, obj) # type: ignore[arg-type,misc,operator] setattr(obj, name, PyroParam(init_value, constraint, event_dim)) value: PyroParam = obj.__getattr__(name) return value diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py index 035d96c06b..e17693267b 100644 --- a/pyro/poutine/equalize_messenger.py +++ b/pyro/poutine/equalize_messenger.py @@ -72,6 +72,6 @@ def _process_message(self, msg: Message) -> None: if self.value is not None and self._is_matching(msg): # type: ignore[unreachable] msg["value"] = self.value # type: ignore[unreachable] if msg["type"] == "sample": - msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim) + msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim).mask(False) msg["infer"] = {"_deterministic": True} msg["is_observed"] = True diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index 7e2f7cfa8e..c06a2a8778 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -798,6 +798,12 @@ def test_param_equalization(self): assert_equal(tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"]) assert_not_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"]) + def test_render_model(self): + pyro.set_rng_seed(20240616) + pyro.clear_param_store() + model = poutine.equalize(self.model, ".+_std") + pyro.render_model(model) + @pytest.mark.parametrize("first_available_dim", [-1, -2, -3]) @pytest.mark.parametrize("depth", [0, 1, 2])