Add support for named dims (torchdim) #3347

Open · wants to merge 14 commits into base: dev
3 changes: 2 additions & 1 deletion Makefile
@@ -40,7 +40,8 @@ scrub: FORCE
doctest: FORCE
# We skip testing pyro.distributions.torch wrapper classes because
# they include torch docstrings which are tested upstream.
python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py
python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py \
--ignore=pyro/contrib/named

perf-test: FORCE
bash scripts/perf_test.sh ${ref}
Empty file added pyro/contrib/named/__init__.py
6 changes: 6 additions & 0 deletions pyro/contrib/named/infer/__init__.py
@@ -0,0 +1,6 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from pyro.contrib.named.infer.elbo import Trace_ELBO

__all__ = ["Trace_ELBO"]
133 changes: 133 additions & 0 deletions pyro/contrib/named/infer/elbo.py
@@ -0,0 +1,133 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Tuple

import torch
from functorch.dim import Dim
from typing_extensions import ParamSpec

import pyro
from pyro import poutine
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.infer import ELBO as _OrigELBO
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message

_P = ParamSpec("_P")


class ELBO(_OrigELBO):
def _get_trace(self, *args, **kwargs):
raise RuntimeError("shouldn't be here!")

def differentiable_loss(self, model, guide, *args, **kwargs):
raise NotImplementedError("Must implement differentiable_loss")

def loss(self, model, guide, *args, **kwargs):
return self.differentiable_loss(model, guide, *args, **kwargs).detach().item()

def loss_and_grads(self, model, guide, *args, **kwargs):
loss = self.differentiable_loss(model, guide, *args, **kwargs)
loss.backward()
return loss.item()


def track_provenance(x: torch.Tensor, provenance: Dim) -> torch.Tensor:
return x.unsqueeze(0)[provenance]


class track_nonreparam(Messenger):
def _pyro_post_sample(self, msg: Message) -> None:
if (
msg["type"] == "sample"
and isinstance(msg["fn"], TorchDistributionMixin)
and not msg["is_observed"]
and not msg["fn"].has_rsample
):
provenance = Dim(msg["name"])
msg["value"] = track_provenance(msg["value"], provenance)


def get_importance_trace(
model: Callable[_P, Any],
guide: Callable[_P, Any],
*args: _P.args,
**kwargs: _P.kwargs
) -> Tuple[poutine.Trace, poutine.Trace]:
"""
Returns traces from the guide and from the model run (via replay) against it.
The returned traces also store the log probability at each sample site.
"""
with track_nonreparam():
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
replay_model = poutine.replay(model, trace=guide_trace)
model_trace = poutine.trace(replay_model).get_trace(*args, **kwargs)

for is_guide, trace in zip((True, False), (guide_trace, model_trace)):
for site in list(trace.nodes.values()):
if site["type"] == "sample" and isinstance(
site["fn"], TorchDistributionMixin
):
log_prob = site["fn"].log_prob(site["value"])
site["log_prob"] = log_prob

if is_guide and not site["fn"].has_rsample:
# DiCE factor: zero-valued, but carries the score-function gradient
site["log_measure"] = log_prob - log_prob.detach()
else:
trace.remove_node(site["name"])
return model_trace, guide_trace


class Trace_ELBO(ELBO):
def differentiable_loss(
self,
model: Callable[_P, Any],
guide: Callable[_P, Any],
*args: _P.args,
**kwargs: _P.kwargs
) -> torch.Tensor:
if self.num_particles > 1:
vectorize = pyro.plate(
"num_particles", self.num_particles, dim=Dim("num_particles")
)
model = vectorize(model)
guide = vectorize(guide)

model_trace, guide_trace = get_importance_trace(model, guide, *args, **kwargs)

cost_terms = []
# logp terms
for site in model_trace.nodes.values():
cost = site["log_prob"]
scale = site["scale"]
batch_dims = tuple(f.dim for f in site["cond_indep_stack"])
deps = tuple(set(getattr(cost, "dims", ())) - set(batch_dims))
cost_terms.append((cost, scale, batch_dims, deps))
# -logq terms
for site in guide_trace.nodes.values():
cost = -site["log_prob"]
scale = site["scale"]
batch_dims = tuple(f.dim for f in site["cond_indep_stack"])
deps = tuple(set(getattr(cost, "dims", ())) - set(batch_dims))
cost_terms.append((cost, scale, batch_dims, deps))

elbo = 0.0
for cost, scale, batch_dims, deps in cost_terms:
if deps:
dice_factor = 0.0
for key in deps:
dice_factor += guide_trace.nodes[str(key)]["log_measure"]
dice_factor_dims = getattr(dice_factor, "dims", ())
cost_dims = getattr(cost, "dims", ())
sum_dims = tuple(set(dice_factor_dims) - set(cost_dims))
if sum_dims:
dice_factor = dice_factor.sum(sum_dims)
cost = torch.exp(dice_factor) * cost
cost = cost.mean(deps)
if scale is not None:
cost = cost * scale
elbo += cost.sum(batch_dims) / self.num_particles

return -elbo
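For readers unfamiliar with the `log_measure` / `dice_factor` terms above: this is the DiCE-style surrogate used for non-reparameterized sites. The snippet below is a standalone toy illustration in plain PyTorch (not part of this PR; the tensor values are made up) of why multiplying by `exp(logq - logq.detach())` leaves the loss value unchanged while re-attaching the score-function gradient.

```python
import torch

# log-prob of a non-reparameterized sample and a downstream cost that depends on it
logq = torch.tensor(-1.3, requires_grad=True)
cost = torch.tensor(2.0)

# exp(logq - logq.detach()) evaluates to exactly 1, so the surrogate has the same
# value as `cost`, but the gradient of logq now flows through it.
dice = torch.exp(logq - logq.detach())
surrogate = dice * cost
surrogate.backward()

assert torch.isclose(surrogate, cost)   # value unchanged
assert torch.isclose(logq.grad, cost)   # score-function gradient: cost * d(logq)
```

The `sum_dims` reduction and `cost.mean(deps)` in the loss above then generalize this scalar picture to first-class dims.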
49 changes: 47 additions & 2 deletions pyro/distributions/torch_distribution.py
@@ -3,10 +3,11 @@

import warnings
from collections import OrderedDict
from typing import Callable
from typing import TYPE_CHECKING, Callable, Tuple

import torch
from torch.distributions.kl import kl_divergence, register_kl
from typing_extensions import Self

import pyro.distributions.torch

@@ -15,6 +16,9 @@
from .score_parts import ScoreParts
from .util import broadcast_shape, scale_and_mask

if TYPE_CHECKING:
from functorch.dim import Dim


class TorchDistributionMixin(Distribution, Callable):
"""
@@ -45,11 +49,52 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
batched). The shape of the result should be `self.shape()`.
:rtype: torch.Tensor
"""
return (
sample_shape = self.named_sample_shape + sample_shape
result = (
self.rsample(sample_shape)
if self.has_rsample
else self.sample(sample_shape)
)
bind_named_dims = self.named_shape[
len(self.named_shape) - len(self.named_sample_shape) :
]
if bind_named_dims:
result = result[bind_named_dims]
return result
Review comment (Member Author):
- Unit tests for distribution shapes for log_prob, mean, sample, rsample, entropy (fail when named and positional dims are mixed in the batch/event/sample shape; conflicting named dims)
- Generalize named dim binding implementation.
- Test transforms and support.
- Shape inference.


@property
def named_shape(self) -> Tuple["Dim", ...]:
if getattr(self, "_named_shape", None) is None:
result = []
for param in self.arg_constraints:
value = getattr(self, param)
for dim in getattr(value, "dims", ()):
# Can't use `dim in result` when `result` is a list or a tuple
# RuntimeError: vmap: It looks like you're attempting to use
# a Tensor in some data-dependent control flow. We don't support
# that yet, please shout over at
# https://github.com/pytorch/functorch/issues/257
if dim not in set(result):
result.append(dim)
self._named_shape = tuple(result)
return self._named_shape

def expand_named_shape(self, named_shape: Tuple["Dim", ...]) -> Self:
for dim in named_shape:
if dim not in set(self.named_shape):
self._named_shape += (dim,)
self.named_sample_shape = self.named_sample_shape + (dim.size,)
return self

@property
def named_sample_shape(self) -> torch.Size:
if getattr(self, "_named_sample_shape", None) is None:
self._named_sample_shape = torch.Size()
return self._named_sample_shape

@named_sample_shape.setter
def named_sample_shape(self, value: torch.Size) -> None:
self._named_sample_shape = value

@property
def batch_shape(self) -> torch.Size:
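Background for the new `named_shape` property: with functorch.dim, indexing a tensor by a first-class `Dim` moves that axis out of the positional shape and records it in the tensor's `.dims`, which is exactly what the property collects from each parameter listed in `arg_constraints`. A small, Pyro-independent sketch (the names `batch` and `loc` are illustrative):

```python
import torch
from functorch.dim import Dim

batch = Dim("batch")               # unbound first-class dim, as in Dim(msg["name"]) above
loc = torch.zeros(4)[batch]        # indexing binds `batch` to size 4

print(loc.dims)                    # (batch,) -- what named_shape collects per parameter
print(loc.order(batch).shape)      # torch.Size([4]) -- .order() makes the dim positional again
```

Under this PR, a distribution built from such a parameter would therefore report `(batch,)` as its `named_shape`, assuming its constructor accepts dim-carrying tensors as the tests in the TODO list intend.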
4 changes: 4 additions & 0 deletions pyro/ops/indexing.py
@@ -215,3 +215,7 @@ def __init__(self, tensor):

def __getitem__(self, args):
return vindex(self._tensor, args)


def index_select(input, dim, index):
Review comment (Member Author):
Add type annotation. Move to contrib/named in the follow up PR.

return input.order(dim)[index]
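A small usage sketch of the helper above (the tensor, the Dim, and the index values here are made up): `.order(dim)` turns the first-class dim back into the leading positional axis, and ordinary indexing then selects along it.

```python
import torch
from functorch.dim import Dim

from pyro.ops.indexing import index_select

d = Dim("d")
x = torch.randn(5, 3)[d]                         # x carries first-class dim d (size 5), positional shape (3,)
rows = index_select(x, d, torch.tensor([0, 2]))  # same as x.order(d)[torch.tensor([0, 2])]
print(rows.shape)                                # torch.Size([2, 3])
```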
12 changes: 11 additions & 1 deletion pyro/poutine/broadcast_messenger.py
@@ -8,6 +8,8 @@
from pyro.util import ignore_jit_warnings

if TYPE_CHECKING:
from functorch.dim import Dim

from pyro.poutine.runtime import Message


@@ -59,7 +61,11 @@ def _pyro_sample(msg: "Message") -> None:
target_batch_shape = [
None if size == 1 else size for size in actual_batch_shape
]
named_shape: List["Dim"] = []
for f in msg["cond_indep_stack"]:
if hasattr(f.dim, "is_bound"):
named_shape.append(f.dim)
continue
if f.dim is None or f.size == -1:
continue
assert f.dim < 0
@@ -88,6 +94,10 @@
target_batch_shape[i] = (
actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1
)
msg["fn"] = dist.expand(target_batch_shape)
if named_shape:
assert len(target_batch_shape) == 0
msg["fn"] = dist.expand_named_shape(tuple(named_shape))
else:
msg["fn"] = dist.expand(target_batch_shape)
if msg["fn"].has_rsample != dist.has_rsample:
msg["fn"].has_rsample = dist.has_rsample # copy custom attribute
19 changes: 12 additions & 7 deletions pyro/poutine/indep_messenger.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import numbers
from typing import Iterator, NamedTuple, Optional, Tuple
from typing import TYPE_CHECKING, Iterator, NamedTuple, Optional, Tuple, Union

import torch
from typing_extensions import Self
@@ -11,10 +11,13 @@
from pyro.poutine.runtime import _DIM_ALLOCATOR, Message
from pyro.util import ignore_jit_warnings

if TYPE_CHECKING:
from functorch.dim import Dim


class CondIndepStackFrame(NamedTuple):
name: str
dim: Optional[int]
dim: Optional[Union[int, "Dim"]]
size: int
counter: int
full_size: Optional[int] = None
@@ -23,7 +26,7 @@ class CondIndepStackFrame(NamedTuple):
def vectorized(self) -> bool:
return self.dim is not None

def _key(self) -> Tuple[str, Optional[int], int, int]:
def _key(self) -> Tuple[str, Optional[Union[int, "Dim"]], int, int]:
size = self.size
with ignore_jit_warnings(["Converting a tensor to a Python number"]):
if isinstance(size, torch.Tensor): # type: ignore[unreachable]
@@ -69,7 +72,7 @@ def __init__(
self,
name: str,
size: int,
dim: Optional[int] = None,
dim: Optional[Union[int, "Dim"]] = None,
device: Optional[str] = None,
) -> None:
if not torch._C._get_tracing_state() and size == 0:
@@ -97,13 +100,13 @@ def __enter__(self) -> Self:
if self._vectorized is not False:
self._vectorized = True

if self._vectorized is True:
if self._vectorized is True and not hasattr(self.dim, "is_bound"):
self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)

return super().__enter__()

def __exit__(self, *args) -> None:
if self._vectorized is True:
if self._vectorized is True and not hasattr(self.dim, "is_bound"):
assert self.dim is not None
_DIM_ALLOCATOR.free(self.name, self.dim)
return super().__exit__(*args)
@@ -124,7 +127,7 @@ def __iter__(self) -> Iterator[int]:
yield i if isinstance(i, numbers.Number) else i.item()

def _reset(self) -> None:
if self._vectorized:
if self._vectorized and not hasattr(self.dim, "is_bound"):
assert self.dim is not None
_DIM_ALLOCATOR.free(self.name, self.dim)
self._vectorized = None
@@ -134,6 +137,8 @@ def _reset(self) -> None:
def indices(self) -> torch.Tensor:
if self._indices is None:
self._indices = torch.arange(self.size, dtype=torch.long).to(self.device)
if hasattr(self.dim, "is_bound"):
return self._indices[self.dim]
return self._indices

def _process_message(self, msg: Message) -> None:
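Pure functorch.dim background for the `hasattr(..., "is_bound")` guards and the `indices` change in this file (variable names here are illustrative): only first-class `Dim` objects expose `is_bound`, so the check distinguishes a torchdim `Dim` from the ordinary negative-int dims the allocator manages, and indexing the cached arange with the `Dim` binds it and returns indices that carry the dim.

```python
import torch
from functorch.dim import Dim

d = Dim("batch")
print(hasattr(d, "is_bound"), d.is_bound)   # True False  (a Dim, not yet bound)
print(hasattr(-1, "is_bound"))              # False       (a plain positional dim index)

_indices = torch.arange(8)                  # what the messenger caches in self._indices
idx = _indices[d]                           # what the new `indices` branch returns
print(d.is_bound, d.size)                   # True 8
print(idx.dims)                             # (batch,)
```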