Introducing PyroModuleList, because torch.nn.ModuleList reinitializes modules when slice-indexing (#3339)
MartinBubel authored Mar 17, 2024
1 parent 01e340e commit 8869834
Showing 3 changed files with 252 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pyro/nn/__init__.py
@@ -9,7 +9,13 @@
    MaskedLinear,
)
from pyro.nn.dense_nn import ConditionalDenseNN, DenseNN
from pyro.nn.module import PyroModule, PyroParam, PyroSample, pyro_method
from pyro.nn.module import (
    PyroModule,
    PyroModuleList,
    PyroParam,
    PyroSample,
    pyro_method,
)

__all__ = [
    "AutoRegressiveNN",
@@ -21,4 +27,5 @@
    "PyroParam",
    "PyroSample",
    "pyro_method",
    "PyroModuleList",
]
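
A note on usage (not part of the diff): with this re-export in place, downstream code can import the new class directly:

from pyro.nn import PyroModuleList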
43 changes: 43 additions & 0 deletions pyro/nn/module.py
@@ -14,7 +14,21 @@
"""
import functools
import inspect
import warnings
import weakref

try:
from torch._jit_internal import _copy_to_script_wrapper
except ImportError:
warnings.warn(
"Cannot find torch._jit_internal._copy_to_script_wrapper", ImportWarning
)

# Fall back to trivial decorator.
def _copy_to_script_wrapper(fn):
return fn


from collections import OrderedDict
from dataclasses import dataclass
from types import TracebackType
@@ -902,3 +916,32 @@ def __set__(self, obj: object, value: Any) -> None:


PyroModule[torch.nn.RNNBase]._flat_weights = _FlatWeightsDescriptor() # type: ignore[attr-defined]


# PyroModuleList
# Using pyro.nn.PyroModule[torch.nn.ModuleList] can cause issues when
# slice-indexing nested PyroModuleLists, so we define a separate PyroModuleList
# class that overrides __getitem__ to return a plain torch.nn.ModuleList for
# slices, rather than using self.__class__. Using self.__class__ in
# __getitem__ would call PyroModule.__init__ without the parent module
# context, losing the parent module's _pyro_name and eventually causing
# errors during sampling, as parameter names may no longer be unique.
# The scenario is rare but has happened.
# The fix could not be applied in torch directly, which is why we have to
# deal with it here; see https://github.com/pytorch/pytorch/issues/121008
class PyroModuleList(torch.nn.ModuleList, PyroModule):
    def __init__(self, modules):
        super().__init__(modules)

    @_copy_to_script_wrapper
    def __getitem__(
        self, idx: Union[int, slice]
    ) -> Union[torch.nn.Module, "PyroModuleList"]:
        if isinstance(idx, slice):
            # return self.__class__(list(self._modules.values())[idx])
            return torch.nn.ModuleList(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]


_PyroModuleMeta._pyro_mixin_cache[torch.nn.ModuleList] = PyroModuleList
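
For illustration only (not part of the commit), a minimal sketch of the fixed behavior: slice-indexing a PyroModuleList returns a plain torch.nn.ModuleList that shares the original submodules instead of re-initializing them.

import torch
from pyro.nn import PyroModule, PyroModuleList

layers = PyroModuleList([PyroModule[torch.nn.Linear](2, 2) for _ in range(3)])

head = layers[:2]  # slice-indexing -> plain torch.nn.ModuleList
assert type(head) is torch.nn.ModuleList
assert head[0] is layers[0]  # the same module objects, not re-initialized copies
assert isinstance(layers[1], torch.nn.Linear)  # int indexing returns the module itself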
201 changes: 201 additions & 0 deletions tests/nn/test_module.py
@@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

import io
import math
import warnings
from typing import Callable, Iterable

import pytest
import torch
@@ -13,6 +15,7 @@
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide.guides import AutoDiagonalNormal
from pyro.nn.module import PyroModule, PyroParam, PyroSample, clear, to_pyro_module_
from pyro.optim import Adam
from tests.common import assert_equal, xfail_param
@@ -844,3 +847,201 @@ def forward(self, x, y):
            grad_params_func[k], torch.zeros_like(grad_params_func[k])
        ), k
        assert torch.allclose(grad_params_autograd[k], grad_params_func[k]), k


class BNN(PyroModule):
    # this is a vanilla Bayesian neural network implementation, nothing new or
    # exciting here
    def __init__(
        self,
        input_size: int,
        hidden_layer_sizes: Iterable[int],
        output_size: int,
        use_new_module_list_type: bool,
    ) -> None:
        super().__init__()

        layer_sizes = (
            [(input_size, hidden_layer_sizes[0])]
            + list(zip(hidden_layer_sizes[:-1], hidden_layer_sizes[1:]))
            + [(hidden_layer_sizes[-1], output_size)]
        )

        layers = [
            pyro.nn.module.PyroModule[torch.nn.Linear](in_size, out_size)
            for in_size, out_size in layer_sizes
        ]
        if use_new_module_list_type:
            self.layers = pyro.nn.module.PyroModuleList(layers)
        else:
            self.layers = pyro.nn.module.PyroModule[torch.nn.ModuleList](layers)

        # make the layers Bayesian
        for layer_idx, layer in enumerate(self.layers):
            layer.weight = pyro.nn.module.PyroSample(
                dist.Normal(0.0, 5.0 * math.sqrt(2 / layer_sizes[layer_idx][0]))
                .expand(
                    [
                        layer_sizes[layer_idx][1],
                        layer_sizes[layer_idx][0],
                    ]
                )
                .to_event(2)
            )
            layer.bias = pyro.nn.module.PyroSample(
                dist.Normal(0.0, 5.0).expand([layer_sizes[layer_idx][1]]).to_event(1)
            )

        self.activation = torch.nn.Tanh()
        self.output_size = output_size

    def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor:
        mean = self.layers[-1](x)

        if obs is not None:
            with pyro.plate("data", x.shape[0]):
                pyro.sample(
                    "obs", dist.Normal(mean, 0.1).to_event(self.output_size), obs=obs
                )

        return mean


class SliceIndexingModuleListBNN(BNN):
    # I claim that it makes a difference whether slice-indexing or
    # position-indexing is used when sub-PyroModules are wrapped in a
    # PyroModule[torch.nn.ModuleList]
    def __init__(
        self,
        input_size: int,
        hidden_layer_sizes: Iterable[int],
        output_size: int,
        use_new_module_list_type: bool,
    ) -> None:
        super().__init__(
            input_size, hidden_layer_sizes, output_size, use_new_module_list_type
        )

    def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor:
        for layer in self.layers[:-1]:
            x = layer(x)
            x = self.activation(x)

        return super().forward(x, obs=obs)


class PositionIndexingModuleListBNN(BNN):
    # I claim that it makes a difference whether slice-indexing or
    # position-indexing is used when sub-PyroModules are wrapped in a
    # PyroModule[torch.nn.ModuleList]
    def __init__(
        self,
        input_size: int,
        hidden_layer_sizes: Iterable[int],
        output_size: int,
        use_new_module_list_type: bool,
    ) -> None:
        super().__init__(
            input_size, hidden_layer_sizes, output_size, use_new_module_list_type
        )

    def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor:
        for i in range(len(self.layers) - 1):
            x = self.layers[i](x)
            x = self.activation(x)

        return super().forward(x, obs=obs)


class NestedBNN(pyro.nn.module.PyroModule):
    # finally, the issue I want to describe occurs after the second "layer of
    # nesting", i.e. when a PyroModule[ModuleList] is wrapped in a
    # PyroModule[ModuleList]
    def __init__(self, bnns: Iterable[BNN], use_new_module_list_type: bool) -> None:
        super().__init__()
        if use_new_module_list_type:
            self.bnns = pyro.nn.module.PyroModuleList(bnns)
        else:
            self.bnns = pyro.nn.module.PyroModule[torch.nn.ModuleList](bnns)

    def forward(self, x: torch.Tensor, obs=None) -> torch.Tensor:
        mean = sum([bnn(x) for bnn in self.bnns]) / len(self.bnns)

        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, 0.1).to_event(1), obs=obs)

        return mean


def train_bnn(model: BNN, input_size: int) -> None:
    pyro.clear_param_store()

    # small numbers for demo purposes
    num_points = 20
    num_svi_iterations = 100

    x = torch.linspace(0, 1, num_points).reshape((-1, input_size))
    y = torch.sin(2 * math.pi * x) + torch.randn(x.size()) * 0.1

    guide = AutoDiagonalNormal(model)
    adam = pyro.optim.Adam({"lr": 0.03})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    for _ in range(num_svi_iterations):
        svi.step(x, y)


class ModuleListTester:
    def setup(self, use_new_module_list_type: bool) -> None:
        self.input_size = 1
        self.output_size = 1
        self.hidden_size = 3
        self.num_hidden_layers = 3
        self.use_new_module_list_type = use_new_module_list_type

    def get_position_indexing_modulelist_bnn(self) -> PositionIndexingModuleListBNN:
        return PositionIndexingModuleListBNN(
            self.input_size,
            [self.hidden_size] * self.num_hidden_layers,
            self.output_size,
            self.use_new_module_list_type,
        )

    def get_slice_indexing_modulelist_bnn(self) -> SliceIndexingModuleListBNN:
        return SliceIndexingModuleListBNN(
            self.input_size,
            [self.hidden_size] * self.num_hidden_layers,
            self.output_size,
            self.use_new_module_list_type,
        )

    def train_nested_bnn(self, module_getter: Callable[[], BNN]) -> None:
        train_bnn(
            NestedBNN(
                [module_getter() for _ in range(2)],
                use_new_module_list_type=self.use_new_module_list_type,
            ),
            self.input_size,
        )


class TestTorchModuleList(ModuleListTester):
    def test_with_position_indexing(self) -> None:
        self.setup(False)
        self.train_nested_bnn(self.get_position_indexing_modulelist_bnn)

    def test_with_slice_indexing(self) -> None:
        self.setup(False)
        # This used to be wrapped in `with pytest.raises(RuntimeError):`, but
        # the error no longer gets raised: PyroModule[torch.nn.ModuleList] now
        # resolves to PyroModuleList via the mixin cache.
        self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn)


class TestPyroModuleList(ModuleListTester):
    def test_with_position_indexing(self) -> None:
        self.setup(True)
        self.train_nested_bnn(self.get_position_indexing_modulelist_bnn)

    def test_with_slice_indexing(self) -> None:
        self.setup(True)
        self.train_nested_bnn(self.get_slice_indexing_modulelist_bnn)


def test_module_list() -> None:
    assert PyroModule[torch.nn.ModuleList] is pyro.nn.PyroModuleList
