Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate pyro.nn.dense_nn and pyro.nn.auto_reg_nn #3342

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 37 additions & 24 deletions pyro/nn/auto_reg_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F


def sample_mask_indices(input_dim, hidden_dim, simple=True):
def sample_mask_indices(
input_dim: int, hidden_dim: int, simple: bool = True
) -> torch.Tensor:
"""
Samples the indices assigned to hidden units during the construction of MADE masks

Expand All @@ -33,8 +36,12 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True):


def create_mask(
input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier
):
input_dim: int,
context_dim: int,
hidden_dims: List[int],
permutation: torch.LongTensor,
output_dim_multiplier: int,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
"""
Creates MADE masks for a conditional distribution

Expand Down Expand Up @@ -109,11 +116,13 @@ class MaskedLinear(nn.Linear):
:type bias: bool
"""

def __init__(self, in_features, out_features, mask, bias=True):
def __init__(
self, in_features: int, out_features: int, mask: torch.Tensor, bias: bool = True
) -> None:
super().__init__(in_features, out_features, bias)
self.register_buffer("mask", mask.data)

def forward(self, _input):
def forward(self, _input: torch.Tensor) -> torch.Tensor:
masked_weight = self.weight * self.mask
return F.linear(_input, masked_weight, self.bias)

Expand Down Expand Up @@ -166,14 +175,14 @@ class ConditionalAutoRegressiveNN(nn.Module):

def __init__(
self,
input_dim,
context_dim,
hidden_dims,
param_dims=[1, 1],
permutation=None,
skip_connections=False,
nonlinearity=nn.ReLU(),
):
input_dim: int,
context_dim: int,
hidden_dims: List[int],
param_dims: List[int] = [1, 1],
permutation: Optional[torch.LongTensor] = None,
skip_connections: bool = False,
nonlinearity: torch.nn.Module = nn.ReLU(),
) -> None:
super().__init__()
if input_dim == 1:
warnings.warn(
Expand Down Expand Up @@ -206,6 +215,7 @@ def __init__(
else:
# The permutation is chosen by the user
P = permutation.type(dtype=torch.int64)
self.permutation: torch.LongTensor
self.register_buffer("permutation", P)

# Create masks
Expand All @@ -230,6 +240,7 @@ def __init__(
)
self.layers = nn.ModuleList(layers)

self.skip_layer: Optional[MaskedLinear]
if skip_connections:
self.skip_layer = MaskedLinear(
input_dim + context_dim,
Expand All @@ -243,13 +254,15 @@ def __init__(
# Save the nonlinearity
self.f = nonlinearity

def get_permutation(self):
def get_permutation(self) -> torch.LongTensor:
"""
Get the permutation applied to the inputs (by default this is chosen at random)
"""
return self.permutation

def forward(self, x, context=None):
def forward(
self, x: torch.Tensor, context: Optional[torch.Tensor] = None
) -> Union[Sequence[torch.Tensor], torch.Tensor]:
# We must be able to broadcast the size of the context over the input
if context is None:
context = self.context
Expand All @@ -258,7 +271,7 @@ def forward(self, x, context=None):
x = torch.cat([context, x], dim=-1)
return self._forward(x)

def _forward(self, x):
def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:
h = x
for layer in self.layers[:-1]:
h = self.f(layer(h))
Expand Down Expand Up @@ -328,13 +341,13 @@ class AutoRegressiveNN(ConditionalAutoRegressiveNN):

def __init__(
self,
input_dim,
hidden_dims,
param_dims=[1, 1],
permutation=None,
skip_connections=False,
nonlinearity=nn.ReLU(),
):
input_dim: int,
hidden_dims: List[int],
param_dims: List[int] = [1, 1],
permutation: Optional[torch.LongTensor] = None,
skip_connections: bool = False,
nonlinearity: torch.nn.Module = nn.ReLU(),
) -> None:
super(AutoRegressiveNN, self).__init__(
input_dim,
0,
Expand All @@ -345,5 +358,5 @@ def __init__(
nonlinearity=nonlinearity,
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]: # type: ignore[override]
return self._forward(x)
30 changes: 19 additions & 11 deletions pyro/nn/dense_nn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import List, Sequence, Union

import torch


Expand Down Expand Up @@ -35,12 +37,12 @@ class ConditionalDenseNN(torch.nn.Module):

def __init__(
self,
input_dim,
context_dim,
hidden_dims,
param_dims=[1, 1],
nonlinearity=torch.nn.ReLU(),
):
input_dim: int,
context_dim: int,
hidden_dims: List[int],
param_dims: List[int] = [1, 1],
nonlinearity: torch.nn.Module = torch.nn.ReLU(),
) -> None:
super().__init__()

self.input_dim = input_dim
Expand All @@ -65,14 +67,16 @@ def __init__(
# Save the nonlinearity
self.f = nonlinearity

def forward(self, x, context):
def forward(
self, x: torch.Tensor, context: torch.Tensor
) -> Union[Sequence[torch.Tensor], torch.Tensor]:
# We must be able to broadcast the size of the context over the input
context = context.expand(x.size()[:-1] + (context.size(-1),))

x = torch.cat([context, x], dim=-1)
return self._forward(x)

def _forward(self, x):
def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:
"""
The forward method
"""
Expand Down Expand Up @@ -122,11 +126,15 @@ class DenseNN(ConditionalDenseNN):
"""

def __init__(
self, input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=torch.nn.ReLU()
):
self,
input_dim: int,
hidden_dims: List[int],
param_dims: List[int] = [1, 1],
nonlinearity: torch.nn.Module = torch.nn.ReLU(),
) -> None:
super(DenseNN, self).__init__(
input_dim, 0, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]: # type: ignore[override]
return self._forward(x)
Loading