diff --git a/pyro/nn/auto_reg_nn.py b/pyro/nn/auto_reg_nn.py
index 3ae06bd055..e6a59f5a9e 100644
--- a/pyro/nn/auto_reg_nn.py
+++ b/pyro/nn/auto_reg_nn.py
@@ -2,13 +2,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import warnings
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
 
 
-def sample_mask_indices(input_dim, hidden_dim, simple=True):
+def sample_mask_indices(
+    input_dim: int, hidden_dim: int, simple: bool = True
+) -> torch.Tensor:
     """
     Samples the indices assigned to hidden units during the construction of MADE masks
 
@@ -33,8 +36,12 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True):
 
 
 def create_mask(
-    input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier
-):
+    input_dim: int,
+    context_dim: int,
+    hidden_dims: List[int],
+    permutation: torch.LongTensor,
+    output_dim_multiplier: int,
+) -> Tuple[List[torch.Tensor], torch.Tensor]:
     """
     Creates MADE masks for a conditional distribution
 
@@ -109,11 +116,13 @@ class MaskedLinear(nn.Linear):
     :type bias: bool
     """
 
-    def __init__(self, in_features, out_features, mask, bias=True):
+    def __init__(
+        self, in_features: int, out_features: int, mask: torch.Tensor, bias: bool = True
+    ) -> None:
         super().__init__(in_features, out_features, bias)
         self.register_buffer("mask", mask.data)
 
-    def forward(self, _input):
+    def forward(self, _input: torch.Tensor) -> torch.Tensor:
         masked_weight = self.weight * self.mask
         return F.linear(_input, masked_weight, self.bias)
 
@@ -166,14 +175,14 @@ class ConditionalAutoRegressiveNN(nn.Module):
 
     def __init__(
         self,
-        input_dim,
-        context_dim,
-        hidden_dims,
-        param_dims=[1, 1],
-        permutation=None,
-        skip_connections=False,
-        nonlinearity=nn.ReLU(),
-    ):
+        input_dim: int,
+        context_dim: int,
+        hidden_dims: List[int],
+        param_dims: List[int] = [1, 1],
+        permutation: Optional[torch.LongTensor] = None,
+        skip_connections: bool = False,
+        nonlinearity: torch.nn.Module = nn.ReLU(),
+    ) -> None:
         super().__init__()
         if input_dim == 1:
             warnings.warn(
@@ -206,6 +215,7 @@ def __init__(
         else:
             # The permutation is chosen by the user
             P = permutation.type(dtype=torch.int64)
+        self.permutation: torch.LongTensor
         self.register_buffer("permutation", P)
 
         # Create masks
@@ -230,6 +240,7 @@ def __init__(
             )
         self.layers = nn.ModuleList(layers)
 
+        self.skip_layer: Optional[MaskedLinear]
         if skip_connections:
             self.skip_layer = MaskedLinear(
                 input_dim + context_dim,
@@ -243,13 +254,15 @@ def __init__(
         # Save the nonlinearity
         self.f = nonlinearity
 
-    def get_permutation(self):
+    def get_permutation(self) -> torch.LongTensor:
         """
         Get the permutation applied to the inputs (by default this is chosen at random)
         """
         return self.permutation
 
-    def forward(self, x, context=None):
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None
+    ) -> Union[Sequence[torch.Tensor], torch.Tensor]:
         # We must be able to broadcast the size of the context over the input
         if context is None:
             context = self.context
@@ -258,7 +271,7 @@ def forward(self, x, context=None):
         x = torch.cat([context, x], dim=-1)
         return self._forward(x)
 
-    def _forward(self, x):
+    def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:
         h = x
         for layer in self.layers[:-1]:
             h = self.f(layer(h))
@@ -328,13 +341,13 @@ class AutoRegressiveNN(ConditionalAutoRegressiveNN):
 
     def __init__(
         self,
-        input_dim,
-        hidden_dims,
-        param_dims=[1, 1],
-        permutation=None,
-        skip_connections=False,
-        nonlinearity=nn.ReLU(),
-    ):
+        input_dim: int,
+        hidden_dims: List[int],
+        param_dims: List[int] = [1, 1],
+        permutation: Optional[torch.LongTensor] = None,
+        skip_connections: bool = False,
+        nonlinearity: torch.nn.Module = nn.ReLU(),
+    ) -> None:
         super(AutoRegressiveNN, self).__init__(
             input_dim,
             0,
@@ -345,5 +358,5 @@ def __init__(
             nonlinearity=nonlinearity,
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:  # type: ignore[override]
         return self._forward(x)
diff --git a/pyro/nn/dense_nn.py b/pyro/nn/dense_nn.py
index a7a9a7e645..a3cf93af8d 100644
--- a/pyro/nn/dense_nn.py
+++ b/pyro/nn/dense_nn.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import List, Sequence, Union
+
 import torch
 
 
@@ -35,12 +37,12 @@ class ConditionalDenseNN(torch.nn.Module):
 
     def __init__(
         self,
-        input_dim,
-        context_dim,
-        hidden_dims,
-        param_dims=[1, 1],
-        nonlinearity=torch.nn.ReLU(),
-    ):
+        input_dim: int,
+        context_dim: int,
+        hidden_dims: List[int],
+        param_dims: List[int] = [1, 1],
+        nonlinearity: torch.nn.Module = torch.nn.ReLU(),
+    ) -> None:
         super().__init__()
 
         self.input_dim = input_dim
@@ -65,14 +67,16 @@ def __init__(
         # Save the nonlinearity
         self.f = nonlinearity
 
-    def forward(self, x, context):
+    def forward(
+        self, x: torch.Tensor, context: torch.Tensor
+    ) -> Union[Sequence[torch.Tensor], torch.Tensor]:
         # We must be able to broadcast the size of the context over the input
         context = context.expand(x.size()[:-1] + (context.size(-1),))
 
         x = torch.cat([context, x], dim=-1)
         return self._forward(x)
 
-    def _forward(self, x):
+    def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:
         """
         The forward method
         """
@@ -122,11 +126,15 @@ class DenseNN(ConditionalDenseNN):
     """
 
     def __init__(
-        self, input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=torch.nn.ReLU()
-    ):
+        self,
+        input_dim: int,
+        hidden_dims: List[int],
+        param_dims: List[int] = [1, 1],
+        nonlinearity: torch.nn.Module = torch.nn.ReLU(),
+    ) -> None:
         super(DenseNN, self).__init__(
             input_dim, 0, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:  # type: ignore[override]
        return self._forward(x)
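
As a reviewer aid (not part of the diff): a minimal sketch of how the newly annotated pyro/nn/auto_reg_nn.py API is exercised, illustrating why forward is typed as returning Union[Sequence[torch.Tensor], torch.Tensor]: the network yields a single tensor when param_dims has one entry and a tuple of tensors otherwise. Shapes follow the class docstrings in that module.

import torch
from pyro.nn import AutoRegressiveNN, ConditionalAutoRegressiveNN

x = torch.randn(100, 10)  # batch of 100, input_dim = 10
z = torch.randn(100, 5)   # context variable, context_dim = 5

# param_dims=[1, 1]: forward returns a tuple of two (100, 10) tensors
arn = AutoRegressiveNN(10, [50], param_dims=[1, 1])
m, s = arn(x)

# param_dims=[1]: forward returns a single (100, 10) tensor
carn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1])
p = carn(x, context=z)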
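Likewise for pyro/nn/dense_nn.py, a sketch (again not part of the diff) of the annotated ConditionalDenseNN/DenseNN forward contract, with shapes taken from the docstrings in that module; the multi-output case returns a tuple, matching the Union return annotation.

import torch
from pyro.nn import ConditionalDenseNN, DenseNN

x = torch.rand(100, 10)  # input_dim = 10
z = torch.rand(100, 5)   # context_dim = 5

# param_dims=[1, 3, 3]: tuple of tensors shaped (100, 1), (100, 3), (100, 3)
cdnn = ConditionalDenseNN(10, 5, [50], param_dims=[1, 3, 3])
a, b, c = cdnn(x, context=z)

# param_dims=[1]: a single (100, 1) tensor
dnn = DenseNN(10, [50], param_dims=[1])
p = dnn(x)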