Skip to content

Commit

Permalink
Add EqualizeMessenger tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed Jun 16, 2024
1 parent 03fcce4 commit f267d9a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyro/poutine/equalize_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __enter__(self) -> Self:
self.value = None
return super().__enter__()

def _is_matching(self, msg: Message):
def _is_matching(self, msg: Message) -> bool:
if msg["type"] == self.type:
for site in self.sites:
if re.compile(site).fullmatch(msg["name"]) is not None: # type: ignore[arg-type]
Expand Down
44 changes: 44 additions & 0 deletions tests/poutine/test_poutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,50 @@ def test_infer_config_sample(self):
assert tr.nodes["p"]["infer"] == {}


class EqualizeHandlerTests(TestCase):
def setUp(self):
def per_category_model(category):
shift = pyro.param(f"{category}_shift", torch.randn(1))
mean = pyro.sample(f"{category}_mean", pyro.distributions.Normal(0, 1))
std = pyro.sample(f"{category}_std", pyro.distributions.LogNormal(0, 1))
with pyro.plate(f"{category}_num_samples", 5):
return pyro.sample(
f"{category}_values", pyro.distributions.Normal(mean + shift, std)
)

def model(categories=["dogs", "cats"]):
return {category: per_category_model(category) for category in categories}

self.model = model

def test_sample_site_equalization(self):
pyro.set_rng_seed(20240616)
pyro.clear_param_store()
model = poutine.equalize(self.model, ".+_std")
tr = pyro.poutine.trace(model).get_trace()
assert_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])
assert_not_equal(
tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"]
)
guide = pyro.infer.autoguide.AutoNormal(model)
guide_sites = [*guide()]
assert guide_sites == [
"dogs_mean",
"dogs_std",
"dogs_values",
"cats_mean",
"cats_values",
]

def test_param_equalization(self):
pyro.set_rng_seed(20240616)
pyro.clear_param_store()
model = poutine.equalize(self.model, ".+_shift", "param")
tr = pyro.poutine.trace(model).get_trace()
assert_equal(tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"])
assert_not_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])


@pytest.mark.parametrize("first_available_dim", [-1, -2, -3])
@pytest.mark.parametrize("depth", [0, 1, 2])
def test_enumerate_poutine(depth, first_available_dim):
Expand Down

0 comments on commit f267d9a

Please sign in to comment.