Commit

Use torch.tanh.
tibuch committed Apr 6, 2021
1 parent a2a407d commit cd83c90
Showing 1 changed file with 2 additions and 3 deletions.
fit/transformers/SResTransformer.py: 5 changes (2 additions & 3 deletions)
@@ -1,7 +1,6 @@
 import torch
 from fast_transformers.builders import TransformerEncoderBuilder, RecurrentEncoderBuilder
 from fast_transformers.masking import TriangularCausalMask
-from torch.nn import functional as F
 
 from fit.transformers.PositionalEncoding2D import PositionalEncoding2D
 
@@ -53,7 +52,7 @@ def forward(self, x):
         triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
         y_hat = self.encoder(x, attn_mask=triangular_mask)
         y_amp = self.predictor_amp(y_hat)
-        y_phase = F.tanh(self.predictor_phase(y_hat))
+        y_phase = torch.tanh(self.predictor_phase(y_hat))
         return torch.cat([y_amp, y_phase], dim=-1)


@@ -100,5 +99,5 @@ def forward(self, x, i=0, memory=None):
         x = self.pos_embedding.forward_i(x, i)
         y_hat, memory = self.encoder(x, memory)
         y_amp = self.predictor_amp(y_hat)
-        y_phase = F.tanh(self.predictor_phase(y_hat))
+        y_phase = torch.tanh(self.predictor_phase(y_hat))
         return torch.cat([y_amp, y_phase], dim=-1), memory
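
Background on the change: torch.nn.functional.tanh is deprecated in PyTorch in favor of torch.tanh and emits a UserWarning when called; both compute the same element-wise hyperbolic tangent, so this commit is behavior-preserving. Below is a minimal sketch of the equivalence (the random input tensor is a hypothetical stand-in, not code from this repository):

    import torch
    from torch.nn import functional as F

    x = torch.randn(4, 8)  # hypothetical stand-in for the phase predictor output

    y_old = F.tanh(x)      # still works, but warns that nn.functional.tanh is deprecated
    y_new = torch.tanh(x)  # the supported spelling, no warning

    assert torch.equal(y_old, y_new)  # identical values, element for element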
