-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathconv_lstm.py
31 lines (22 loc) · 1.11 KB
/
conv_lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torch import nn
class Conv2dLSTMCell(nn.Module):
    """Convolutional LSTM cell.

    Every gate is a 2D convolution over the channel-wise concatenation of the
    previous hidden state and the current input, followed by the standard LSTM
    update (sigmoid forget/input/output gates, tanh candidate state).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the cell and hidden states.
        kernel_size: Kernel size used by every gate convolution.
        stride: Stride used by every gate convolution. Values other than 1
            change the spatial size of the gate maps, which then no longer
            match the incoming ``cell``/``hidden`` states.
        padding: Padding used by every gate convolution. ``None`` (default)
            selects ``kernel_size // 2`` per dimension. For odd kernels at
            stride 1 this keeps the spatial size unchanged, and for the
            default ``kernel_size=3`` it equals the previous default of 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None):
        super().__init__()
        if padding is None:
            # A fixed default of 1 is only "same" padding for 3x3 kernels;
            # derive it from the kernel so other sizes keep state shapes aligned.
            padding = self._same_padding(kernel_size)
        kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        # Gates operate on [hidden, input] stacked along channels, so each
        # convolution sees out_channels + in_channels input channels.
        in_channels += out_channels

        self.forget = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.input = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.output = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.state = nn.Conv2d(in_channels, out_channels, **kwargs)

    @staticmethod
    def _same_padding(kernel_size):
        """Return ``kernel_size // 2`` (per dimension when a tuple is given)."""
        if isinstance(kernel_size, int):
            return kernel_size // 2
        return tuple(k // 2 for k in kernel_size)

    def forward(self, input, states):
        """Advance the cell by one step.

        Args:
            input: Input tensor, expected (N, in_channels, H, W).
            states: Tuple ``(cell, hidden)`` from the previous step, each
                expected (N, out_channels, H, W).

        Returns:
            Tuple ``(cell, hidden)`` for the next step.
        """
        cell, hidden = states

        # Hidden channels first, then input channels; this ordering defines
        # which convolution weights apply to which source.
        stacked = torch.cat((hidden, input), dim=1)

        forget_gate = torch.sigmoid(self.forget(stacked))
        input_gate = torch.sigmoid(self.input(stacked))
        output_gate = torch.sigmoid(self.output(stacked))
        state_gate = torch.tanh(self.state(stacked))

        # Update internal cell state, then expose a gated view of it.
        cell = forget_gate * cell + input_gate * state_gate
        hidden = output_gate * torch.tanh(cell)

        return cell, hidden