Skip to content

Commit

Permalink
Mypy warn_unreachable=True (#3312)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Jan 17, 2024
1 parent 4f17274 commit c0b36f1
Show file tree
Hide file tree
Showing 24 changed files with 291 additions and 162 deletions.
6 changes: 3 additions & 3 deletions pyro/ops/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
from torch.distributions.utils import lazy_property
Expand Down Expand Up @@ -111,7 +111,7 @@ def event_permute(self, perm) -> "Gaussian":
precision = self.precision[..., perm][..., perm, :]
return Gaussian(self.log_normalizer, info_vec, precision)

def __add__(self, other: "Gaussian") -> "Gaussian":
def __add__(self, other: Union["Gaussian", int, float, torch.Tensor]) -> "Gaussian":
"""
Adds two Gaussians in log-density space.
"""
Expand All @@ -126,7 +126,7 @@ def __add__(self, other: "Gaussian") -> "Gaussian":
return Gaussian(self.log_normalizer + other, self.info_vec, self.precision)
raise ValueError("Unsupported type: {}".format(type(other)))

def __sub__(self, other: "Gaussian") -> "Gaussian":
def __sub__(self, other: Union["Gaussian", int, float, torch.Tensor]) -> "Gaussian":
if isinstance(other, (int, float, torch.Tensor)):
return Gaussian(self.log_normalizer - other, self.info_vec, self.precision)
raise ValueError("Unsupported type: {}".format(type(other)))
Expand Down
24 changes: 14 additions & 10 deletions pyro/poutine/block_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message

if TYPE_CHECKING:
from pyro.poutine.runtime import Message


def _block_fn(
Expand All @@ -14,7 +16,7 @@ def _block_fn(
hide: List[str],
hide_types: List[str],
hide_all: bool,
msg: Message,
msg: "Message",
) -> bool:
# handle observes
if msg["type"] == "sample" and msg["is_observed"]:
Expand Down Expand Up @@ -43,7 +45,7 @@ def _make_default_hide_fn(
expose: Optional[List[str]],
hide_types: Optional[List[str]],
expose_types: Optional[List[str]],
) -> Callable[[Message], bool]:
) -> Callable[["Message"], bool]:
# first, some sanity checks:
# hide_all and expose_all intersect?
assert (hide_all is False and expose_all is False) or (
Expand Down Expand Up @@ -81,9 +83,11 @@ def _make_default_hide_fn(
return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all)


def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]:
def _negate_fn(
fn: Callable[["Message"], Optional[bool]]
) -> Callable[["Message"], bool]:
# typed version of lambda msg: not fn(msg)
def negated_fn(msg: Message) -> bool:
def negated_fn(msg: "Message") -> bool:
return not fn(msg)

return negated_fn
Expand Down Expand Up @@ -140,15 +144,15 @@ class BlockMessenger(Messenger):

def __init__(
self,
hide_fn: Optional[Callable[[Message], Optional[bool]]] = None,
expose_fn: Optional[Callable[[Message], Optional[bool]]] = None,
hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
hide_all: bool = True,
expose_all: bool = False,
hide: Optional[List[str]] = None,
expose: Optional[List[str]] = None,
hide_types: Optional[List[str]] = None,
expose_types: Optional[List[str]] = None,
):
) -> None:
super().__init__()
if not (hide_fn is None or expose_fn is None):
raise ValueError("Only specify one of hide_fn or expose_fn")
Expand All @@ -161,5 +165,5 @@ def __init__(
hide_all, expose_all, hide, expose, hide_types, expose_types
)

def _process_message(self, msg: Message) -> None:
def _process_message(self, msg: "Message") -> None:
msg["stop"] = bool(self.hide_fn(msg))
8 changes: 5 additions & 3 deletions pyro/poutine/broadcast_messenger.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.util import ignore_jit_warnings

if TYPE_CHECKING:
from pyro.poutine.runtime import Message


class BroadcastMessenger(Messenger):
"""
Expand Down Expand Up @@ -41,7 +43,7 @@ class BroadcastMessenger(Messenger):

@staticmethod
@ignore_jit_warnings(["Converting a tensor to a Python boolean"])
def _pyro_sample(msg: Message) -> None:
def _pyro_sample(msg: "Message") -> None:
"""
:param msg: current message at a trace site.
"""
Expand Down
53 changes: 24 additions & 29 deletions pyro/poutine/collapse_messenger.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from functools import reduce, singledispatch
from typing import TYPE_CHECKING, FrozenSet, Tuple

from typing_extensions import Self

import pyro
from pyro.distributions.distribution import COERCIONS
from pyro.ops.linalg import ignore_torch_deprecation_warnings
from pyro.poutine.runtime import _PYRO_STACK
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.util import site_is_subsample

from .runtime import _PYRO_STACK
from .trace_messenger import TraceMessenger

# TODO Remove import guard once funsor is a required dependency.
try:
import funsor
Expand All @@ -24,20 +27,10 @@
Funsor = type("Funsor", (), {})
Variable = type("Variable", (), {})

if TYPE_CHECKING:
from funsor.distribution import Distribution

@singledispatch
def _get_free_vars(x):
return x


@_get_free_vars.register(Variable)
def _(x):
return frozenset((x.name,))


@_get_free_vars.register(tuple)
def _(x, subs):
return frozenset().union(*map(_get_free_vars, x))
from pyro.poutine.runtime import Message


@singledispatch
Expand Down Expand Up @@ -92,7 +85,7 @@ class CollapseMessenger(TraceMessenger):

_coerce = None

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
if CollapseMessenger._coerce is None:
import funsor
from funsor.distribution import CoerceDistributionToFunsor
Expand All @@ -102,18 +95,20 @@ def __init__(self, *args, **kwargs):
self._block = False
super().__init__(*args, **kwargs)

def _process_message(self, msg):
def _process_message(self, msg: "Message") -> None:
if self._block:
return
if site_is_subsample(msg):
return
super()._process_message(msg)

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: "Message") -> None:
# Eagerly convert fn and value to Funsor.
dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]}
dim_to_name.update(self.preserved_plates)
msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name)
if TYPE_CHECKING:
assert isinstance(msg["fn"], Distribution)
domain = msg["fn"].inputs["value"]
if msg["value"] is None:
msg["value"] = funsor.Variable(msg["name"], domain)
Expand All @@ -123,14 +118,14 @@ def _pyro_sample(self, msg):
msg["done"] = True
msg["stop"] = True

def _pyro_post_sample(self, msg):
def _pyro_post_sample(self, msg: "Message") -> None:
if self._block:
return
if site_is_subsample(msg):
return
super()._pyro_post_sample(msg)

def _pyro_barrier(self, msg):
def _pyro_barrier(self, msg: "Message") -> None:
# Get log_prob and record factor.
name, log_prob, log_joint, sampled_vars = self._get_log_prob()
self._block = True
Expand All @@ -151,14 +146,14 @@ def _pyro_barrier(self, msg):
value = _substitute(value, samples)
msg["value"] = value

def __enter__(self):
def __enter__(self) -> Self:
self.preserved_plates = {
h.dim: h.name for h in _PYRO_STACK if isinstance(h, pyro.plate)
}
COERCIONS.append(self._coerce)
return super().__enter__()

def __exit__(self, *args):
def __exit__(self, *args) -> None:
_coerce = COERCIONS.pop()
assert _coerce is self._coerce
super().__exit__(*args)
Expand All @@ -168,20 +163,20 @@ def __exit__(self, *args):
pyro.factor(name, log_prob.data)

@ignore_torch_deprecation_warnings()
def _get_log_prob(self):
def _get_log_prob(self) -> Tuple[str, Funsor, Funsor, FrozenSet[str]]:
# Convert delayed statements to pyro.factor()
reduced_vars = []
reduced_vars_list = []
log_prob_terms = []
plates = frozenset()
plates: FrozenSet[str] = frozenset()
for name, site in self.trace.nodes.items():
if not site["is_observed"]:
reduced_vars.append(name)
reduced_vars_list.append(name)
log_prob_terms.append(site["fn"](value=site["value"]))
plates |= frozenset(
f.name for f in site["cond_indep_stack"] if f.vectorized
)
name = reduced_vars[0]
reduced_vars = frozenset(reduced_vars)
name = reduced_vars_list[0]
reduced_vars = frozenset(reduced_vars_list)
assert log_prob_terms, "nothing to collapse"
self.trace.nodes.clear()
reduced_plates = plates - frozenset(self.preserved_plates.values())
Expand Down
8 changes: 5 additions & 3 deletions pyro/poutine/condition_messenger.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Union
from typing import TYPE_CHECKING, Dict, Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.trace_struct import Trace

if TYPE_CHECKING:
from pyro.poutine.runtime import Message


class ConditionMessenger(Messenger):
"""
Expand Down Expand Up @@ -46,7 +48,7 @@ def __init__(self, data: Union[Dict[str, torch.Tensor], Trace]) -> None:
super().__init__()
self.data = data

def _pyro_sample(self, msg: Message) -> None:
def _pyro_sample(self, msg: "Message") -> None:
"""
:param msg: current message at a trace site.
:returns: a sample from the stochastic function at the site.
Expand Down
2 changes: 1 addition & 1 deletion pyro/poutine/do_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class DoMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.do_messenger.DoMessenger`
"""

def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]):
def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]) -> None:
super().__init__()
self.data = data
self._intervener_id = str(id(self))
Expand Down
27 changes: 16 additions & 11 deletions pyro/poutine/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union

import torch

import pyro.distributions as dist
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.runtime import Message
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites, site_is_subsample

if TYPE_CHECKING:
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.runtime import Message


class GuideMessenger(TraceMessenger, ABC):
"""
Expand Down Expand Up @@ -60,12 +62,13 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[
samples[name] = site["value"]
return samples

def _pyro_sample(self, msg: Message) -> None:
def _pyro_sample(self, msg: "Message") -> None:
if msg["is_observed"] or site_is_subsample(msg):
return
assert isinstance(msg["name"], str)
assert isinstance(msg["fn"], TorchDistributionMixin)
assert msg["infer"] is not None
if TYPE_CHECKING:
assert isinstance(msg["name"], str)
assert isinstance(msg["fn"], TorchDistributionMixin)
assert msg["infer"] is not None
prior = msg["fn"]
msg["infer"]["prior"] = prior
posterior = self.get_posterior(msg["name"], prior)
Expand All @@ -75,20 +78,21 @@ def _pyro_sample(self, msg: Message) -> None:
posterior = posterior.expand(prior.batch_shape)
msg["fn"] = posterior

def _pyro_post_sample(self, msg: Message) -> None:
def _pyro_post_sample(self, msg: "Message") -> None:
# Manually apply outer plates.
assert msg["infer"] is not None
prior = msg["infer"].get("prior")
if prior is not None:
assert isinstance(msg["fn"], TorchDistributionMixin)
if TYPE_CHECKING:
assert isinstance(msg["fn"], TorchDistributionMixin)
if prior.batch_shape != msg["fn"].batch_shape:
msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
return super()._pyro_post_sample(msg)

@abstractmethod
def get_posterior(
self, name: str, prior: TorchDistributionMixin
) -> Union[TorchDistributionMixin, torch.Tensor]:
self, name: str, prior: "TorchDistributionMixin"
) -> Union["TorchDistributionMixin", torch.Tensor]:
"""
Abstract method to compute a posterior distribution or sample a
posterior value given a prior distribution conditioned on upstream
Expand Down Expand Up @@ -148,6 +152,7 @@ def get_traces(self) -> Tuple[Trace, Trace]:
del guide_trace.nodes[name]
continue
model_site = model_trace.nodes[name].copy()
assert guide_site["infer"] is not None
model_site["fn"] = guide_site["infer"]["prior"]
model_trace.nodes[name] = model_site
return model_trace, guide_trace
Loading

0 comments on commit c0b36f1

Please sign in to comment.