diff --git a/models/transformer.py b/models/transformer.py index dcd536750..0e0c383a2 100644 --- a/models/transformer.py +++ b/models/transformer.py @@ -112,12 +112,6 @@ def forward(self, tgt, memory, if self.return_intermediate: intermediate.append(self.norm(output)) - if self.norm is not None: - output = self.norm(output) - if self.return_intermediate: - intermediate.pop() - intermediate.append(output) - if self.return_intermediate: return torch.stack(intermediate)