From 26594128c198681d890354166e0cb57bd6fbfd23 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 18 Mar 2024 01:16:39 +0000 Subject: [PATCH 1/2] type annotate pyro.nn --- pyro/nn/auto_reg_nn.py | 62 ++++++++++++++++++++++++++---------------- pyro/nn/dense_nn.py | 31 +++++++++++++-------- 2 files changed, 58 insertions(+), 35 deletions(-) diff --git a/pyro/nn/auto_reg_nn.py b/pyro/nn/auto_reg_nn.py index 3ae06bd055..8ef97ae051 100644 --- a/pyro/nn/auto_reg_nn.py +++ b/pyro/nn/auto_reg_nn.py @@ -2,13 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import warnings +from collections.abc import Sequence +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import functional as F -def sample_mask_indices(input_dim, hidden_dim, simple=True): +def sample_mask_indices( + input_dim: int, hidden_dim: int, simple: bool = True +) -> torch.Tensor: """ Samples the indices assigned to hidden units during the construction of MADE masks @@ -33,8 +37,12 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True): def create_mask( - input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier -): + input_dim: int, + context_dim: int, + hidden_dims: List[int], + permutation: torch.LongTensor, + output_dim_multiplier: int, +) -> Tuple[List[torch.Tensor], torch.Tensor]: """ Creates MADE masks for a conditional distribution @@ -109,11 +117,13 @@ class MaskedLinear(nn.Linear): :type bias: bool """ - def __init__(self, in_features, out_features, mask, bias=True): + def __init__( + self, in_features: int, out_features: int, mask: torch.Tensor, bias: bool = True + ) -> None: super().__init__(in_features, out_features, bias) self.register_buffer("mask", mask.data) - def forward(self, _input): + def forward(self, _input: torch.Tensor) -> torch.Tensor: masked_weight = self.weight * self.mask return F.linear(_input, masked_weight, self.bias) @@ -166,14 +176,14 @@ class ConditionalAutoRegressiveNN(nn.Module): def __init__( self, - input_dim, - context_dim, - hidden_dims, - param_dims=[1, 1], - permutation=None, - skip_connections=False, - nonlinearity=nn.ReLU(), - ): + input_dim: int, + context_dim: int, + hidden_dims: List[int], + param_dims: List[int] = [1, 1], + permutation: Optional[torch.LongTensor] = None, + skip_connections: bool = False, + nonlinearity: torch.nn.Module = nn.ReLU(), + ) -> None: super().__init__() if input_dim == 1: warnings.warn( @@ -206,6 +216,7 @@ def __init__( else: # The permutation is chosen by the user P = permutation.type(dtype=torch.int64) + self.permutation: torch.LongTensor self.register_buffer("permutation", P) # Create masks @@ -230,6 +241,7 @@ def __init__( ) self.layers = nn.ModuleList(layers) + self.skip_layer: Optional[MaskedLinear] if skip_connections: self.skip_layer = MaskedLinear( input_dim + context_dim, @@ -243,13 +255,15 @@ def __init__( # Save the nonlinearity self.f = nonlinearity - def get_permutation(self): + def get_permutation(self) -> torch.LongTensor: """ Get the permutation applied to the inputs (by default this is chosen at random) """ return self.permutation - def forward(self, x, context=None): + def forward( + self, x: torch.Tensor, context: Optional[torch.Tensor] = None + ) -> Union[Sequence[torch.Tensor], torch.Tensor]: # We must be able to broadcast the size of the context over the input if context is None: context = self.context @@ -258,7 +272,7 @@ def forward(self, x, context=None): x = torch.cat([context, x], dim=-1) return self._forward(x) - def _forward(self, x): + def _forward(self, x: 
torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]: h = x for layer in self.layers[:-1]: h = self.f(layer(h)) @@ -328,13 +342,13 @@ class AutoRegressiveNN(ConditionalAutoRegressiveNN): def __init__( self, - input_dim, - hidden_dims, - param_dims=[1, 1], - permutation=None, - skip_connections=False, - nonlinearity=nn.ReLU(), - ): + input_dim: int, + hidden_dims: List[int], + param_dims: List[int] = [1, 1], + permutation: Optional[torch.LongTensor] = None, + skip_connections: bool = False, + nonlinearity: torch.nn.Module = nn.ReLU(), + ) -> None: super(AutoRegressiveNN, self).__init__( input_dim, 0, @@ -345,5 +359,5 @@ def __init__( nonlinearity=nonlinearity, ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]: # type: ignore[override] return self._forward(x) diff --git a/pyro/nn/dense_nn.py b/pyro/nn/dense_nn.py index a7a9a7e645..ae4b2e29b8 100644 --- a/pyro/nn/dense_nn.py +++ b/pyro/nn/dense_nn.py @@ -1,6 +1,9 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence +from typing import List, Union + import torch @@ -35,12 +38,12 @@ class ConditionalDenseNN(torch.nn.Module): def __init__( self, - input_dim, - context_dim, - hidden_dims, - param_dims=[1, 1], - nonlinearity=torch.nn.ReLU(), - ): + input_dim: int, + context_dim: int, + hidden_dims: List[int], + param_dims: List[int] = [1, 1], + nonlinearity: torch.nn.Module = torch.nn.ReLU(), + ) -> None: super().__init__() self.input_dim = input_dim @@ -65,14 +68,16 @@ def __init__( # Save the nonlinearity self.f = nonlinearity - def forward(self, x, context): + def forward( + self, x: torch.Tensor, context: torch.Tensor + ) -> Union[Sequence[torch.Tensor], torch.Tensor]: # We must be able to broadcast the size of the context over the input context = context.expand(x.size()[:-1] + (context.size(-1),)) x = torch.cat([context, x], dim=-1) return self._forward(x) - def _forward(self, x): + def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]: """ The forward method """ @@ -122,11 +127,15 @@ class DenseNN(ConditionalDenseNN): """ def __init__( - self, input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=torch.nn.ReLU() - ): + self, + input_dim: int, + hidden_dims: List[int], + param_dims: List[int] = [1, 1], + nonlinearity: torch.nn.Module = torch.nn.ReLU(), + ) -> None: super(DenseNN, self).__init__( input_dim, 0, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]: # type: ignore[override] return self._forward(x) From 8d7e227836fe2dcd3c288c7abce4435b87db89e5 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 18 Mar 2024 01:28:17 +0000 Subject: [PATCH 2/2] typing.Sequence --- pyro/nn/auto_reg_nn.py | 3 +-- pyro/nn/dense_nn.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pyro/nn/auto_reg_nn.py b/pyro/nn/auto_reg_nn.py index 8ef97ae051..e6a59f5a9e 100644 --- a/pyro/nn/auto_reg_nn.py +++ b/pyro/nn/auto_reg_nn.py @@ -2,8 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -from collections.abc import Sequence -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn diff --git a/pyro/nn/dense_nn.py b/pyro/nn/dense_nn.py index ae4b2e29b8..a3cf93af8d 100644 --- a/pyro/nn/dense_nn.py +++ b/pyro/nn/dense_nn.py @@ -1,8 +1,7 @@ # Copyright 
(c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Sequence -from typing import List, Union +from typing import List, Sequence, Union import torch
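
A minimal usage sketch of the annotated API, assuming pyro and torch are
importable. The isinstance check is just one way a caller might narrow the
Union[Sequence[torch.Tensor], torch.Tensor] return type these annotations
make explicit; names and shapes follow the existing pyro.nn docstrings:

    import torch
    from pyro.nn import AutoRegressiveNN

    # With two entries in param_dims, forward() returns a sequence of
    # tensors (one per parameter); with a single entry it can return a
    # bare tensor, hence the Union return annotation.
    arn = AutoRegressiveNN(input_dim=10, hidden_dims=[40], param_dims=[1, 1])
    x = torch.randn(100, 10)

    out = arn(x)
    if isinstance(out, torch.Tensor):
        loc = out                 # single-parameter case
    else:
        loc, log_scale = out      # here: two (100, 10) tensors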