From d3e46b1e538629d6a872d371d67b9f288127f317 Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Mon, 23 Sep 2024 14:08:40 +0300 Subject: [PATCH 1/2] Fix EqualizeMessenger type annotations. --- pyro/poutine/handlers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index 343b1a1f4b..bc1ba91de8 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -306,7 +306,8 @@ def escape( # type: ignore[empty-body] def equalize( sites: Union[str, List[str]], type: Optional[str], -) -> ConditionMessenger: ... + keep_dist: Optional[bool], +) -> EqualizeMessenger: ... @overload @@ -314,6 +315,7 @@ def equalize( fn: Callable[_P, _T], sites: Union[str, List[str]], type: Optional[str], + keep_dist: Optional[bool], ) -> Callable[_P, _T]: ... @@ -322,6 +324,7 @@ def equalize( # type: ignore[empty-body] fn: Callable[_P, _T], sites: Union[str, List[str]], type: Optional[str], + keep_dist: Optional[bool], ) -> Union[EqualizeMessenger, Callable[_P, _T]]: ... From d27c2f1bc1d386761032da51e7273e5dfcc543d0 Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Tue, 24 Sep 2024 08:50:35 +0300 Subject: [PATCH 2/2] Change keep_dist typing to Optional[bool]. --- pyro/poutine/equalize_messenger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py index 1bc79a5521..31e524a651 100644 --- a/pyro/poutine/equalize_messenger.py +++ b/pyro/poutine/equalize_messenger.py @@ -68,7 +68,7 @@ def __init__( self, sites: Union[str, List[str]], type: Optional[str] = "sample", - keep_dist: bool = False, + keep_dist: Optional[bool] = False, ) -> None: super().__init__() self.sites = [sites] if isinstance(sites, str) else sites