diff --git a/pyro/ops/gaussian.py b/pyro/ops/gaussian.py index 12f17e973c..b3da2a0915 100644 --- a/pyro/ops/gaussian.py +++ b/pyro/ops/gaussian.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch.distributions.utils import lazy_property @@ -111,7 +111,7 @@ def event_permute(self, perm) -> "Gaussian": precision = self.precision[..., perm][..., perm, :] return Gaussian(self.log_normalizer, info_vec, precision) - def __add__(self, other: "Gaussian") -> "Gaussian": + def __add__(self, other: Union["Gaussian", int, float, torch.Tensor]) -> "Gaussian": """ Adds two Gaussians in log-density space. """ @@ -126,7 +126,7 @@ def __add__(self, other: "Gaussian") -> "Gaussian": return Gaussian(self.log_normalizer + other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) - def __sub__(self, other: "Gaussian") -> "Gaussian": + def __sub__(self, other: Union["Gaussian", int, float, torch.Tensor]) -> "Gaussian": if isinstance(other, (int, float, torch.Tensor)): return Gaussian(self.log_normalizer - other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) diff --git a/pyro/poutine/block_messenger.py b/pyro/poutine/block_messenger.py index 79b594a29e..e7ebb5bbe0 100644 --- a/pyro/poutine/block_messenger.py +++ b/pyro/poutine/block_messenger.py @@ -2,10 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message def _block_fn( @@ -14,7 +16,7 @@ def _block_fn( hide: List[str], hide_types: List[str], hide_all: bool, - msg: Message, + msg: "Message", ) -> bool: # handle observes if msg["type"] == "sample" and msg["is_observed"]: @@ -43,7 +45,7 @@ def _make_default_hide_fn( expose: Optional[List[str]], hide_types: Optional[List[str]], expose_types: Optional[List[str]], -) -> Callable[[Message], bool]: +) -> Callable[["Message"], bool]: # first, some sanity checks: # hide_all and expose_all intersect? assert (hide_all is False and expose_all is False) or ( @@ -81,9 +83,11 @@ def _make_default_hide_fn( return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all) -def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]: +def _negate_fn( + fn: Callable[["Message"], Optional[bool]] +) -> Callable[["Message"], bool]: # typed version of lambda msg: not fn(msg) - def negated_fn(msg: Message) -> bool: + def negated_fn(msg: "Message") -> bool: return not fn(msg) return negated_fn @@ -140,15 +144,15 @@ class BlockMessenger(Messenger): def __init__( self, - hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, - expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, + hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None, + expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None, - ): + ) -> None: super().__init__() if not (hide_fn is None or expose_fn is None): raise ValueError("Only specify one of hide_fn or expose_fn") @@ -161,5 +165,5 @@ def __init__( hide_all, expose_all, hide, expose, hide_types, expose_types ) - def _process_message(self, msg: Message) -> None: + def _process_message(self, msg: "Message") -> None: msg["stop"] = bool(self.hide_fn(msg)) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 445cd1caea..87e9dd2f7b 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -1,13 +1,15 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.util import ignore_jit_warnings +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class BroadcastMessenger(Messenger): """ @@ -41,7 +43,7 @@ class BroadcastMessenger(Messenger): @staticmethod @ignore_jit_warnings(["Converting a tensor to a Python boolean"]) - def _pyro_sample(msg: Message) -> None: + def _pyro_sample(msg: "Message") -> None: """ :param msg: current message at a trace site. """ diff --git a/pyro/poutine/collapse_messenger.py b/pyro/poutine/collapse_messenger.py index 6206894943..4a7bd05479 100644 --- a/pyro/poutine/collapse_messenger.py +++ b/pyro/poutine/collapse_messenger.py @@ -1,16 +1,19 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + from functools import reduce, singledispatch +from typing import TYPE_CHECKING, FrozenSet, Tuple + +from typing_extensions import Self import pyro from pyro.distributions.distribution import COERCIONS from pyro.ops.linalg import ignore_torch_deprecation_warnings +from pyro.poutine.runtime import _PYRO_STACK +from pyro.poutine.trace_messenger import TraceMessenger from pyro.poutine.util import site_is_subsample -from .runtime import _PYRO_STACK -from .trace_messenger import TraceMessenger - # TODO Remove import guard once funsor is a required dependency. try: import funsor @@ -24,20 +27,10 @@ Funsor = type("Funsor", (), {}) Variable = type("Variable", (), {}) +if TYPE_CHECKING: + from funsor.distribution import Distribution -@singledispatch -def _get_free_vars(x): - return x - - -@_get_free_vars.register(Variable) -def _(x): - return frozenset((x.name,)) - - -@_get_free_vars.register(tuple) -def _(x, subs): - return frozenset().union(*map(_get_free_vars, x)) + from pyro.poutine.runtime import Message @singledispatch @@ -92,7 +85,7 @@ class CollapseMessenger(TraceMessenger): _coerce = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if CollapseMessenger._coerce is None: import funsor from funsor.distribution import CoerceDistributionToFunsor @@ -102,18 +95,20 @@ def __init__(self, *args, **kwargs): self._block = False super().__init__(*args, **kwargs) - def _process_message(self, msg): + def _process_message(self, msg: "Message") -> None: if self._block: return if site_is_subsample(msg): return super()._process_message(msg) - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: "Message") -> None: # Eagerly convert fn and value to Funsor. dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]} dim_to_name.update(self.preserved_plates) msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name) + if TYPE_CHECKING: + assert isinstance(msg["fn"], Distribution) domain = msg["fn"].inputs["value"] if msg["value"] is None: msg["value"] = funsor.Variable(msg["name"], domain) @@ -123,14 +118,14 @@ def _pyro_sample(self, msg): msg["done"] = True msg["stop"] = True - def _pyro_post_sample(self, msg): + def _pyro_post_sample(self, msg: "Message") -> None: if self._block: return if site_is_subsample(msg): return super()._pyro_post_sample(msg) - def _pyro_barrier(self, msg): + def _pyro_barrier(self, msg: "Message") -> None: # Get log_prob and record factor. name, log_prob, log_joint, sampled_vars = self._get_log_prob() self._block = True @@ -151,14 +146,14 @@ def _pyro_barrier(self, msg): value = _substitute(value, samples) msg["value"] = value - def __enter__(self): + def __enter__(self) -> Self: self.preserved_plates = { h.dim: h.name for h in _PYRO_STACK if isinstance(h, pyro.plate) } COERCIONS.append(self._coerce) return super().__enter__() - def __exit__(self, *args): + def __exit__(self, *args) -> None: _coerce = COERCIONS.pop() assert _coerce is self._coerce super().__exit__(*args) @@ -168,20 +163,20 @@ def __exit__(self, *args): pyro.factor(name, log_prob.data) @ignore_torch_deprecation_warnings() - def _get_log_prob(self): + def _get_log_prob(self) -> Tuple[str, Funsor, Funsor, FrozenSet[str]]: # Convert delayed statements to pyro.factor() - reduced_vars = [] + reduced_vars_list = [] log_prob_terms = [] - plates = frozenset() + plates: FrozenSet[str] = frozenset() for name, site in self.trace.nodes.items(): if not site["is_observed"]: - reduced_vars.append(name) + reduced_vars_list.append(name) log_prob_terms.append(site["fn"](value=site["value"])) plates |= frozenset( f.name for f in site["cond_indep_stack"] if f.vectorized ) - name = reduced_vars[0] - reduced_vars = frozenset(reduced_vars) + name = reduced_vars_list[0] + reduced_vars = frozenset(reduced_vars_list) assert log_prob_terms, "nothing to collapse" self.trace.nodes.clear() reduced_plates = plates - frozenset(self.preserved_plates.values()) diff --git a/pyro/poutine/condition_messenger.py b/pyro/poutine/condition_messenger.py index 16ef9dd250..9ce259cd1f 100644 --- a/pyro/poutine/condition_messenger.py +++ b/pyro/poutine/condition_messenger.py @@ -1,14 +1,16 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Union +from typing import TYPE_CHECKING, Dict, Union import torch from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.trace_struct import Trace +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class ConditionMessenger(Messenger): """ @@ -46,7 +48,7 @@ def __init__(self, data: Union[Dict[str, torch.Tensor], Trace]) -> None: super().__init__() self.data = data - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: """ :param msg: current message at a trace site. :returns: a sample from the stochastic function at the site. diff --git a/pyro/poutine/do_messenger.py b/pyro/poutine/do_messenger.py index bfd1bd0c73..ac686061dc 100644 --- a/pyro/poutine/do_messenger.py +++ b/pyro/poutine/do_messenger.py @@ -49,7 +49,7 @@ class DoMessenger(Messenger): :returns: stochastic function decorated with a :class:`~pyro.poutine.do_messenger.DoMessenger` """ - def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]): + def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]) -> None: super().__init__() self.data = data self._intervener_id = str(id(self)) diff --git a/pyro/poutine/guide.py b/pyro/poutine/guide.py index c50d51bb77..6685fc4df4 100644 --- a/pyro/poutine/guide.py +++ b/pyro/poutine/guide.py @@ -2,17 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union import torch import pyro.distributions as dist -from pyro.distributions.torch_distribution import TorchDistributionMixin -from pyro.poutine.runtime import Message from pyro.poutine.trace_messenger import TraceMessenger from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites, site_is_subsample +if TYPE_CHECKING: + from pyro.distributions.torch_distribution import TorchDistributionMixin + from pyro.poutine.runtime import Message + class GuideMessenger(TraceMessenger, ABC): """ @@ -60,12 +62,13 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[ samples[name] = site["value"] return samples - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: if msg["is_observed"] or site_is_subsample(msg): return - assert isinstance(msg["name"], str) - assert isinstance(msg["fn"], TorchDistributionMixin) - assert msg["infer"] is not None + if TYPE_CHECKING: + assert isinstance(msg["name"], str) + assert isinstance(msg["fn"], TorchDistributionMixin) + assert msg["infer"] is not None prior = msg["fn"] msg["infer"]["prior"] = prior posterior = self.get_posterior(msg["name"], prior) @@ -75,20 +78,21 @@ def _pyro_sample(self, msg: Message) -> None: posterior = posterior.expand(prior.batch_shape) msg["fn"] = posterior - def _pyro_post_sample(self, msg: Message) -> None: + def _pyro_post_sample(self, msg: "Message") -> None: # Manually apply outer plates. assert msg["infer"] is not None prior = msg["infer"].get("prior") if prior is not None: - assert isinstance(msg["fn"], TorchDistributionMixin) + if TYPE_CHECKING: + assert isinstance(msg["fn"], TorchDistributionMixin) if prior.batch_shape != msg["fn"].batch_shape: msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape) return super()._pyro_post_sample(msg) @abstractmethod def get_posterior( - self, name: str, prior: TorchDistributionMixin - ) -> Union[TorchDistributionMixin, torch.Tensor]: + self, name: str, prior: "TorchDistributionMixin" + ) -> Union["TorchDistributionMixin", torch.Tensor]: """ Abstract method to compute a posterior distribution or sample a posterior value given a prior distribution conditioned on upstream @@ -148,6 +152,7 @@ def get_traces(self) -> Tuple[Trace, Trace]: del guide_trace.nodes[name] continue model_site = model_trace.nodes[name].copy() + assert guide_site["infer"] is not None model_site["fn"] = guide_site["infer"]["prior"] model_trace.nodes[name] = model_site return model_trace, guide_trace diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index 969c5476d6..f0116b9be2 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -51,6 +51,9 @@ import collections import functools +from typing import Callable, Iterable, Optional, TypeVar, Union, overload + +from typing_extensions import ParamSpec from pyro.poutine import util @@ -74,6 +77,9 @@ from .trace_messenger import TraceMessenger from .uncondition_messenger import UnconditionMessenger +_P = ParamSpec("_P") +_T = TypeVar("_T") + ############################################ # Begin primitive operations ############################################ @@ -276,7 +282,46 @@ def _fn(*args, **kwargs): return wrapper(fn) if fn is not None else wrapper -def markov(fn=None, history=1, keep=False, dim=None, name=None): +@overload +def markov( + fn: None = ..., + history: int = 1, + keep: bool = False, + dim: Optional[int] = None, + name: Optional[str] = None, +) -> MarkovMessenger: + ... + + +@overload +def markov( + fn: Iterable[int] = ..., + history: int = 1, + keep: bool = False, + dim: Optional[int] = None, + name: Optional[str] = None, +) -> MarkovMessenger: + ... + + +@overload +def markov( + fn: Callable[_P, _T] = ..., + history: int = 1, + keep: bool = False, + dim: Optional[int] = None, + name: Optional[str] = None, +) -> Callable[_P, _T]: + ... + + +def markov( + fn: Optional[Union[Iterable[int], Callable]] = None, + history: int = 1, + keep: bool = False, + dim: Optional[int] = None, + name: Optional[str] = None, +) -> Union[MarkovMessenger, Callable]: """ Markov dependency declaration. diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index bfb3f3d1ae..13a175403d 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -26,7 +26,7 @@ def vectorized(self) -> bool: def _key(self) -> Tuple[str, Optional[int], int, int]: with ignore_jit_warnings(["Converting a tensor to a Python number"]): size = ( - self.size.item() if isinstance(self.size, torch.Tensor) else self.size # type: ignore[attr-defined] + self.size.item() if isinstance(self.size, torch.Tensor) else self.size # type: ignore[attr-defined, unreachable] ) return self.name, self.dim, size, self.counter @@ -71,7 +71,7 @@ def __init__( size: int, dim: Optional[int] = None, device: Optional[str] = None, - ): + ) -> None: if not torch._C._get_tracing_state() and size == 0: raise ZeroDivisionError("size cannot be zero") diff --git a/pyro/poutine/infer_config_messenger.py b/pyro/poutine/infer_config_messenger.py index f70e8cbfb8..362ae2cf86 100644 --- a/pyro/poutine/infer_config_messenger.py +++ b/pyro/poutine/infer_config_messenger.py @@ -1,10 +1,12 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import TYPE_CHECKING, Callable from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import InferDict, Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import InferDict, Message class InferConfigMessenger(Messenger): @@ -18,7 +20,7 @@ class InferConfigMessenger(Messenger): :returns: stochastic function decorated with :class:`~pyro.poutine.infer_config_messenger.InferConfigMessenger` """ - def __init__(self, config_fn: Callable[[Message], InferDict]): + def __init__(self, config_fn: Callable[["Message"], "InferDict"]) -> None: """ :param config_fn: a callable taking a site and returning an infer dict @@ -28,7 +30,7 @@ def __init__(self, config_fn: Callable[[Message], InferDict]): super().__init__() self.config_fn = config_fn - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: """ :param msg: current message at a trace site. @@ -41,7 +43,7 @@ def _pyro_sample(self, msg: Message) -> None: assert msg["infer"] is not None msg["infer"].update(self.config_fn(msg)) - def _pyro_param(self, msg: Message) -> None: + def _pyro_param(self, msg: "Message") -> None: """ :param msg: current message at a trace site. diff --git a/pyro/poutine/lift_messenger.py b/pyro/poutine/lift_messenger.py index 09703e6825..f40de4d381 100644 --- a/pyro/poutine/lift_messenger.py +++ b/pyro/poutine/lift_messenger.py @@ -2,16 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -from typing import Callable, Dict, Set, Union +from typing import TYPE_CHECKING, Callable, Dict, Set, Union from typing_extensions import Self from pyro import params from pyro.distributions.distribution import Distribution from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class LiftMessenger(Messenger): """ @@ -55,7 +57,7 @@ def __init__( """ super().__init__() self.prior = prior - self._samples_cache: Dict[str, Message] = {} + self._samples_cache: Dict[str, "Message"] = {} def __enter__(self) -> Self: self._samples_cache = {} @@ -77,10 +79,10 @@ def __exit__(self, *args, **kwargs) -> None: ) return super().__exit__(*args, **kwargs) - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: return None - def _pyro_param(self, msg: Message) -> None: + def _pyro_param(self, msg: "Message") -> None: """ Overrides the `pyro.param` call with samples sampled from the distribution specified in the prior. The prior can be a @@ -118,8 +120,7 @@ def _pyro_param(self, msg: Message) -> None: msg["fn"] = self.prior msg["args"] = msg["args"][1:] else: - # otherwise leave as is - return None + raise TypeError("unreachable") msg["type"] = "sample" if name in self._samples_cache: # Multiple pyro.param statements with the same diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index 1d68c9e06a..6c41594fb9 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -2,9 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 from collections import Counter -from contextlib import ExitStack # python 3 +from contextlib import ExitStack +from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional, Set -from .reentrant_messenger import ReentrantMessenger +from typing_extensions import Self + +from pyro.poutine.reentrant_messenger import ReentrantMessenger + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message class MarkovMessenger(ReentrantMessenger): @@ -27,7 +33,13 @@ class MarkovMessenger(ReentrantMessenger): Interface stub, behavior not yet implemented. """ - def __init__(self, history=1, keep=False, dim=None, name=None): + def __init__( + self, + history: int = 1, + keep: bool = False, + dim: Optional[int] = None, + name: Optional[str] = None, + ) -> None: assert history >= 0 self.history = history self.keep = keep @@ -41,34 +53,35 @@ def __init__(self, history=1, keep=False, dim=None, name=None): raise NotImplementedError( "vectorized markov not yet implemented, try setting name to None" ) - self._iterable = None + self._iterable: Optional[Iterable[int]] = None self._pos = -1 - self._stack = [] + self._stack: List[Set[str]] = [] super().__init__() - def generator(self, iterable): + def generator(self, iterable: Iterable[int]) -> Self: self._iterable = iterable return self - def __iter__(self): + def __iter__(self) -> Iterator[int]: with ExitStack() as stack: + assert self._iterable is not None for value in self._iterable: stack.enter_context(self) yield value - def __enter__(self): + def __enter__(self) -> Self: self._pos += 1 if len(self._stack) <= self._pos: self._stack.append(set()) return super().__enter__() - def __exit__(self, *args, **kwargs): + def __exit__(self, *args, **kwargs) -> None: if not self.keep: self._stack.pop() self._pos -= 1 return super().__exit__(*args, **kwargs) - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: "Message") -> None: if msg["done"] or type(msg["fn"]).__name__ == "_Subsample": return @@ -76,6 +89,8 @@ def _pyro_sample(self, msg): # go out of scope when any one of their markov contexts exits. # This accounting can be done by users of these fields, # e.g. EnumMessenger. + assert msg["name"] is not None + assert msg["infer"] is not None infer = msg["infer"] scope = infer.setdefault( "_markov_scope", Counter() diff --git a/pyro/poutine/mask_messenger.py b/pyro/poutine/mask_messenger.py index c3c375d8a2..132acf3b33 100644 --- a/pyro/poutine/mask_messenger.py +++ b/pyro/poutine/mask_messenger.py @@ -1,12 +1,14 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Union +from typing import TYPE_CHECKING, Union import torch from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message class MaskMessenger(Messenger): @@ -34,5 +36,5 @@ def __init__(self, mask: Union[bool, torch.BoolTensor]) -> None: super().__init__() self.mask = mask - def _process_message(self, msg: Message) -> None: + def _process_message(self, msg: "Message") -> None: msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] & self.mask diff --git a/pyro/poutine/reparam_messenger.py b/pyro/poutine/reparam_messenger.py index acf6ff5d40..10405e0330 100644 --- a/pyro/poutine/reparam_messenger.py +++ b/pyro/poutine/reparam_messenger.py @@ -16,12 +16,13 @@ import torch from typing_extensions import ParamSpec -from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message, effectful +from pyro.poutine.runtime import effectful if TYPE_CHECKING: + from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.infer.reparam.reparam import Reparam + from pyro.poutine.runtime import Message _P = ParamSpec("_P") _T = TypeVar("_T") @@ -59,7 +60,7 @@ class ReparamMessenger(Messenger): def __init__( self, - config: Union[Dict[str, "Reparam"], Callable[[Message], Optional["Reparam"]]], + config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]], ) -> None: super().__init__() assert isinstance(config, dict) or callable(config) @@ -69,11 +70,12 @@ def __init__( def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]: return ReparamHandler(self, fn) - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: if type(msg["fn"]).__name__ == "_Subsample": return assert msg["name"] is not None - assert isinstance(msg["fn"], TorchDistributionMixin) + if TYPE_CHECKING: + assert isinstance(msg["fn"], TorchDistributionMixin) if isinstance(self.config, dict): reparam = self.config.get(msg["name"]) else: diff --git a/pyro/poutine/replay_messenger.py b/pyro/poutine/replay_messenger.py index 548c971473..9c26490528 100644 --- a/pyro/poutine/replay_messenger.py +++ b/pyro/poutine/replay_messenger.py @@ -1,7 +1,15 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from .messenger import Messenger +from typing import TYPE_CHECKING, Dict, Optional + +from pyro.poutine.messenger import Messenger + +if TYPE_CHECKING: + import torch + + from pyro.poutine.runtime import Message + from pyro.poutine.trace_struct import Trace class ReplayMessenger(Messenger): @@ -32,7 +40,11 @@ class ReplayMessenger(Messenger): :returns: a stochastic function decorated with a :class:`~pyro.poutine.replay_messenger.ReplayMessenger` """ - def __init__(self, trace=None, params=None): + def __init__( + self, + trace: Optional["Trace"] = None, + params: Optional[Dict[str, "torch.Tensor"]] = None, + ) -> None: """ :param trace: a trace whose values should be reused @@ -45,7 +57,7 @@ def __init__(self, trace=None, params=None): self.trace = trace self.params = params - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: "Message") -> None: """ :param msg: current message at a trace site. @@ -56,6 +68,7 @@ def _pyro_sample(self, msg): At a sample site that does not appear in self.trace, reverts to default Messenger._pyro_sample behavior with no additional side effects. """ + assert msg["name"] is not None name = msg["name"] if self.trace is not None and name in self.trace: guide_msg = self.trace.nodes[name] @@ -66,9 +79,8 @@ def _pyro_sample(self, msg): msg["done"] = True msg["value"] = guide_msg["value"] msg["infer"] = guide_msg["infer"] - return None - def _pyro_param(self, msg): + def _pyro_param(self, msg: "Message") -> None: name = msg["name"] if self.params is not None and name in self.params: assert hasattr( @@ -76,4 +88,3 @@ def _pyro_param(self, msg): ), "param {} must be constrained value".format(name) msg["done"] = True msg["value"] = self.params[name] - return None diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index cbb0b6aa73..1a25a3405c 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -28,6 +28,8 @@ _T = TypeVar("_T") if TYPE_CHECKING: + from collections import Counter + from pyro.distributions.score_parts import ScoreParts from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.poutine.indep_messenger import CondIndepStackFrame @@ -57,7 +59,7 @@ class InferDict(TypedDict, total=False): _dim_to_symbol: Dict[int, str] _do_not_trace: bool _enumerate_symbol: str - _markov_scope: Optional[Dict[str, int]] + _markov_scope: "Counter" _enumerate_dim: int _dim_to_id: Dict[int, int] _markov_depth: int diff --git a/pyro/poutine/scale_messenger.py b/pyro/poutine/scale_messenger.py index 121ecf6bc4..48e96c7255 100644 --- a/pyro/poutine/scale_messenger.py +++ b/pyro/poutine/scale_messenger.py @@ -1,14 +1,16 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Union +from typing import TYPE_CHECKING, Union import torch from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class ScaleMessenger(Messenger): """ @@ -47,5 +49,5 @@ def __init__(self, scale: Union[float, torch.Tensor]) -> None: super().__init__() self.scale = scale - def _process_message(self, msg: Message) -> None: + def _process_message(self, msg: "Message") -> None: msg["scale"] = self.scale * msg["scale"] diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 6a13a5a528..58fba6faa8 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -165,10 +165,10 @@ def _process_message(self, msg: Message) -> None: full_size=self.size, # used for param initialization ) msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] - if isinstance(self.size, torch.Tensor) or isinstance( - self.subsample_size, torch.Tensor + if isinstance(self.size, torch.Tensor) or isinstance( # type: ignore[unreachable] + self.subsample_size, torch.Tensor # type: ignore[unreachable] ): - if not isinstance(msg["scale"], torch.Tensor): + if not isinstance(msg["scale"], torch.Tensor): # type: ignore[unreachable] with ignore_jit_warnings(): msg["scale"] = torch.tensor(msg["scale"]) msg["scale"] = msg["scale"] * self.size / self.subsample_size diff --git a/pyro/poutine/substitute_messenger.py b/pyro/poutine/substitute_messenger.py index 2b28616381..69caa6f298 100644 --- a/pyro/poutine/substitute_messenger.py +++ b/pyro/poutine/substitute_messenger.py @@ -2,16 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -from typing import Dict, Set +from typing import TYPE_CHECKING, Dict, Set -import torch from typing_extensions import Self from pyro import params from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled +if TYPE_CHECKING: + import torch + + from pyro.poutine.runtime import Message + class SubstituteMessenger(Messenger): """ @@ -32,14 +35,14 @@ class SubstituteMessenger(Messenger): :returns: ``fn`` decorated with a :class:`~pyro.poutine.substitute_messenger.SubstituteMessenger` """ - def __init__(self, data: Dict[str, torch.Tensor]) -> None: + def __init__(self, data: Dict[str, "torch.Tensor"]) -> None: """ :param data: values for the parameters. Constructor """ super().__init__() self.data = data - self._data_cache: Dict[str, Message] = {} + self._data_cache: Dict[str, "Message"] = {} def __enter__(self) -> Self: self._data_cache = {} @@ -61,10 +64,10 @@ def __exit__(self, *args, **kwargs) -> None: ) return super().__exit__(*args, **kwargs) - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: return None - def _pyro_param(self, msg: Message) -> None: + def _pyro_param(self, msg: "Message") -> None: """ Overrides the `pyro.param` with substituted values. If the param name does not match the name the keys in `data`, diff --git a/pyro/poutine/trace_messenger.py b/pyro/poutine/trace_messenger.py index 2b7609a3b5..735828c4fd 100644 --- a/pyro/poutine/trace_messenger.py +++ b/pyro/poutine/trace_messenger.py @@ -2,15 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import sys -from typing import Any, Callable, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional from typing_extensions import Self from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.trace_struct import Trace from pyro.poutine.util import site_is_subsample +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + def identify_dense_edges(trace: Trace) -> None: """ @@ -134,7 +136,7 @@ def _reset(self) -> None: self.trace = tr super()._reset() - def _pyro_post_sample(self, msg: Message) -> None: + def _pyro_post_sample(self, msg: "Message") -> None: if self.param_only: return assert msg["name"] is not None @@ -145,7 +147,7 @@ def _pyro_post_sample(self, msg: Message) -> None: return self.trace.add_node(msg["name"], **msg.copy()) - def _pyro_post_param(self, msg: Message) -> None: + def _pyro_post_param(self, msg: "Message") -> None: assert msg["name"] is not None self.trace.add_node(msg["name"], **msg.copy()) @@ -162,7 +164,7 @@ class TraceHandler: We can also use this for visualization. """ - def __init__(self, msngr: TraceMessenger, fn: Callable): + def __init__(self, msngr: TraceMessenger, fn: Callable) -> None: self.fn = fn self.msngr = msngr diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 4c2b58bb23..7b7e286747 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -4,6 +4,7 @@ import sys from collections import OrderedDict from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -18,18 +19,21 @@ ) import opt_einsum -import torch -from pyro.distributions.distribution import Distribution from pyro.distributions.score_parts import ScoreParts from pyro.distributions.util import scale_and_mask from pyro.ops.packed import pack -from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled from pyro.util import warn_if_inf, warn_if_nan +if TYPE_CHECKING: + import torch -def allow_all_sites(name: str, site: Message) -> bool: + from pyro.distributions.distribution import Distribution + from pyro.poutine.runtime import Message + + +def allow_all_sites(name: str, site: "Message") -> bool: return True @@ -95,7 +99,7 @@ def __init__(self, graph_type: Literal["flat", "dense"] = "flat") -> None: graph_type ) self.graph_type = graph_type - self.nodes: OrderedDict[str, Message] = OrderedDict() + self.nodes: OrderedDict[str, "Message"] = OrderedDict() self._succ: OrderedDict[str, Set[str]] = OrderedDict() self._pred: OrderedDict[str, Set[str]] = OrderedDict() @@ -198,8 +202,8 @@ def topological_sort(self, reverse: bool = False) -> List[str]: def log_prob_sum( self, - site_filter: Callable[[str, Message], bool] = allow_all_sites, - ) -> Union[torch.Tensor, float]: + site_filter: Callable[[str, "Message"], bool] = allow_all_sites, + ) -> Union["torch.Tensor", float]: """ Compute the site-wise log probabilities of the trace. Each ``log_prob`` has shape equal to the corresponding ``batch_shape``. @@ -212,7 +216,8 @@ def log_prob_sum( result = 0.0 for name, site in self.nodes.items(): if site["type"] == "sample" and site_filter(name, site): - assert isinstance(site["fn"], Distribution) + if TYPE_CHECKING: + assert isinstance(site["fn"], Distribution) if "log_prob_sum" in site: log_p = site["log_prob_sum"] else: @@ -242,7 +247,7 @@ def log_prob_sum( def compute_log_prob( self, - site_filter: Callable[[str, Message], bool] = allow_all_sites, + site_filter: Callable[[str, "Message"], bool] = allow_all_sites, ) -> None: """ Compute the site-wise log probabilities of the trace. @@ -252,7 +257,8 @@ def compute_log_prob( """ for name, site in self.nodes.items(): if site["type"] == "sample" and site_filter(name, site): - assert isinstance(site["fn"], Distribution) + if TYPE_CHECKING: + assert isinstance(site["fn"], Distribution) if "log_prob" not in site: try: log_p = site["fn"].log_prob( @@ -290,7 +296,8 @@ def compute_score_parts(self) -> None: """ for name, site in self.nodes.items(): if site["type"] == "sample" and "score_parts" not in site: - assert isinstance(site["fn"], Distribution) + if TYPE_CHECKING: + assert isinstance(site["fn"], Distribution) # Note that ScoreParts overloads the multiplication operator # to correctly scale each of its three parts. try: @@ -380,7 +387,7 @@ def nonreparam_stochastic_nodes(self) -> List[str]: """ return list(set(self.stochastic_nodes) - set(self.reparameterized_nodes)) - def iter_stochastic_nodes(self) -> Iterator[Tuple[str, Message]]: + def iter_stochastic_nodes(self) -> Iterator[Tuple[str, "Message"]]: """ :return: an iterator over stochastic nodes in the trace. """ @@ -465,18 +472,22 @@ def pack_tensors(self, plate_to_symbol: Optional[Dict[str, str]] = None) -> None ) ).with_traceback(traceback) from e - def format_shapes(self, title="Trace Shapes:", last_site=None): + def format_shapes( + self, title: str = "Trace Shapes:", last_site: Optional[str] = None + ) -> str: """ Returns a string showing a table of the shapes of all sites in the trace. """ if not self.nodes: return title - rows = [[title]] + rows: List[List[Optional[str]]] = [[title]] rows.append(["Param Sites:"]) for name, site in self.nodes.items(): if site["type"] == "param": + if TYPE_CHECKING: + assert isinstance(site["value"], torch.Tensor) rows.append([name, None] + [str(size) for size in site["value"].shape]) if name == last_site: break @@ -520,7 +531,7 @@ def format_shapes(self, title="Trace Shapes:", last_site=None): return _format_table(rows) -def _format_table(rows): +def _format_table(rows: List[List[Optional[str]]]) -> str: """ Formats a right justified table using None as column separator. """ @@ -538,8 +549,9 @@ def _format_table(rows): column_widths[j] = max(column_widths[j], widths[j]) # justify columns - for i, row in enumerate(rows): - cols = [[], [], []] + justified_rows: List[List[str]] = [] + for row in rows: + cols: List[List[str]] = [[], [], []] j = 0 for cell in row: if cell is None: @@ -552,16 +564,16 @@ def _format_table(rows): else col + [""] * (width - len(col)) for width, col, direction in zip(column_widths, cols, "rrl") ] - rows[i] = sum(cols, []) + justified_rows.append(sum(cols, [])) # compute cell widths - cell_widths = [0] * len(rows[0]) - for row in rows: - for j, cell in enumerate(row): + cell_widths = [0] * len(justified_rows[0]) + for justified_row in justified_rows: + for j, cell in enumerate(justified_row): cell_widths[j] = max(cell_widths[j], len(cell)) # justify cells return "\n".join( - " ".join(cell.rjust(width) for cell, width in zip(row, cell_widths)) - for row in rows + " ".join(cell.rjust(width) for cell, width in zip(justified_row, cell_widths)) + for justified_row in justified_rows ) diff --git a/pyro/poutine/uncondition_messenger.py b/pyro/poutine/uncondition_messenger.py index 1978ba9a85..34febd0543 100644 --- a/pyro/poutine/uncondition_messenger.py +++ b/pyro/poutine/uncondition_messenger.py @@ -1,8 +1,12 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING + from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message class UnconditionMessenger(Messenger): @@ -11,10 +15,10 @@ class UnconditionMessenger(Messenger): distribution, ignoring observations. """ - def __init__(self): + def __init__(self) -> None: super().__init__() - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: """ :param msg: current message at a trace site. diff --git a/pyro/poutine/util.py b/pyro/poutine/util.py index 3a0ec0316b..8c682e2aef 100644 --- a/pyro/poutine/util.py +++ b/pyro/poutine/util.py @@ -1,36 +1,43 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from .. import settings +from typing import TYPE_CHECKING, List, Optional + +from pyro import settings + +if TYPE_CHECKING: + from pyro.distributions.distribution import Distribution + from pyro.poutine.runtime import Message + from pyro.poutine.trace_struct import Trace _VALIDATION_ENABLED = __debug__ settings.register("validate_poutine", __name__, "_VALIDATION_ENABLED") -def enable_validation(is_validate): +def enable_validation(is_validate: bool) -> None: global _VALIDATION_ENABLED _VALIDATION_ENABLED = is_validate -def is_validation_enabled(): +def is_validation_enabled() -> bool: return _VALIDATION_ENABLED -def site_is_subsample(site): +def site_is_subsample(site: "Message") -> bool: """ Determines whether a trace site originated from a subsample statement inside an `plate`. """ return site["type"] == "sample" and type(site["fn"]).__name__ == "_Subsample" -def site_is_factor(site): +def site_is_factor(site: "Message") -> bool: """ Determines whether a trace site originated from a factor statement. """ return site["type"] == "sample" and type(site["fn"]).__name__ == "Unit" -def prune_subsample_sites(trace): +def prune_subsample_sites(trace: "Trace") -> "Trace": """ Copies and removes all subsample sites from a trace. """ @@ -41,7 +48,9 @@ def prune_subsample_sites(trace): return trace -def enum_extend(trace, msg, num_samples=None): +def enum_extend( + trace: "Trace", msg: "Message", num_samples: Optional[int] = None +) -> List["Trace"]: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -57,18 +66,23 @@ def enum_extend(trace, msg, num_samples=None): num_samples = -1 extended_traces = [] + assert msg["name"] is not None + if TYPE_CHECKING: + assert isinstance(msg["fn"], Distribution) for i, s in enumerate(msg["fn"].enumerate_support(*msg["args"], **msg["kwargs"])): if i > num_samples and num_samples >= 0: break msg_copy = msg.copy() - msg_copy.update(value=s) + msg_copy.update(value=s) # type: ignore[call-arg] tr_cp = trace.copy() tr_cp.add_node(msg["name"], **msg_copy) extended_traces.append(tr_cp) return extended_traces -def mc_extend(trace, msg, num_samples=None): +def mc_extend( + trace: "Trace", msg: "Message", num_samples: Optional[int] = None +) -> List["Trace"]: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -88,12 +102,13 @@ def mc_extend(trace, msg, num_samples=None): msg_copy = msg.copy() msg_copy["value"] = msg_copy["fn"](*msg_copy["args"], **msg_copy["kwargs"]) tr_cp = trace.copy() + assert msg_copy["name"] is not None tr_cp.add_node(msg_copy["name"], **msg_copy) extended_traces.append(tr_cp) return extended_traces -def discrete_escape(trace, msg): +def discrete_escape(trace: "Trace", msg: "Message") -> bool: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -105,14 +120,15 @@ def discrete_escape(trace, msg): Subroutine for integrating out discrete variables for variance reduction. """ return ( - (msg["type"] == "sample") - and (not msg["is_observed"]) - and (msg["name"] not in trace) - and (getattr(msg["fn"], "has_enumerate_support", False)) + msg["type"] == "sample" + and not msg["is_observed"] + and msg["name"] is not None + and msg["name"] not in trace + and getattr(msg["fn"], "has_enumerate_support", False) ) -def all_escape(trace, msg): +def all_escape(trace: "Trace", msg: "Message") -> bool: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -124,7 +140,8 @@ def all_escape(trace, msg): Subroutine for approximately integrating out variables for variance reduction. """ return ( - (msg["type"] == "sample") - and (not msg["is_observed"]) - and (msg["name"] not in trace) + msg["type"] == "sample" + and not msg["is_observed"] + and msg["name"] is not None + and msg["name"] not in trace ) diff --git a/setup.cfg b/setup.cfg index e3ca753137..20d7fc3dd9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,6 +41,7 @@ warn_return_any = True warn_unused_configs = True warn_incomplete_stub = True ignore_missing_imports = True +warn_unreachable = True # Per-module options: