From 17c1f47376eff46e4117f4cd9e65cefddc198de4 Mon Sep 17 00:00:00 2001
From: James Cross
Date: Tue, 5 Feb 2019 19:33:02 -0800
Subject: [PATCH 1/2] model export: remove extraneous model.eval()

Differential Revision: D13969501

fbshipit-source-id: 525464bd46f7d5c925e2392ba93f6ef6533dfb63
---
 pytorch_translate/ensemble_export.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/pytorch_translate/ensemble_export.py b/pytorch_translate/ensemble_export.py
index 386e3509..be268579 100644
--- a/pytorch_translate/ensemble_export.py
+++ b/pytorch_translate/ensemble_export.py
@@ -267,9 +267,6 @@ def forward(self, src_tokens, src_lengths):
                     torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
                 )

-            # evaluation mode
-            model.eval()
-
         for i, (model, future) in enumerate(zip(self.models, futures)):
             if isinstance(model.encoder, TransformerEncoder):
                 encoder_out = future
@@ -1361,9 +1358,6 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
                 src_tokens_seq_first, src_lengths, char_inds, word_lengths
             )

-            # evaluation mode
-            model.eval()
-
             # "primary" encoder output (vector representations per source token)
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)

From c2fbd68cf35f9d91dc7df8552a820ba33a5103fa Mon Sep 17 00:00:00 2001
From: James Cross
Date: Tue, 5 Feb 2019 19:33:16 -0800
Subject: [PATCH 2/2] return empty tensor instead of None (#332)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/332

To allow efficient use of fork/join annotation, we return an empty tensor
instead of `None` for `encoder_padding_mask` from the transformer encoder in
the unmasked/inference case. Note that this slight hack is preferable to more
far-reaching changes in, e.g., Fairseq multihead_attention.

Differential Revision: D13969691

fbshipit-source-id: 862ed44019012449554527f236cb344046c75184
---
 pytorch_translate/ensemble_export.py        | 18 ++++--------------
 pytorch_translate/hybrid_transformer_rnn.py |  3 +++
 pytorch_translate/transformer.py            |  9 ++++++++-
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/pytorch_translate/ensemble_export.py b/pytorch_translate/ensemble_export.py
index be268579..96959a08 100644
--- a/pytorch_translate/ensemble_export.py
+++ b/pytorch_translate/ensemble_export.py
@@ -256,22 +256,12 @@ def forward(self, src_tokens, src_lengths):
             # evaluation mode
             model.eval()

-            # TODO(jamesreed): transformer encodder returns a None output, and
-            # the fork/join API doesn't handle that well. We should figure out
-            # a way to annotate outputs as Optional and record that in fork/join
-            # traces.
-            if isinstance(model.encoder, TransformerEncoder):
-                futures.append(model.encoder(src_tokens_seq_first, src_lengths))
-            else:
-                futures.append(
-                    torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
-                )
+            futures.append(
+                torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
+            )

         for i, (model, future) in enumerate(zip(self.models, futures)):
-            if isinstance(model.encoder, TransformerEncoder):
-                encoder_out = future
-            else:
-                encoder_out = torch.jit._wait(future)
+            encoder_out = torch.jit._wait(future)
             # "primary" encoder output (vector representations per source token)
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
diff --git a/pytorch_translate/hybrid_transformer_rnn.py b/pytorch_translate/hybrid_transformer_rnn.py
index 71fee2e8..f6fd6502 100644
--- a/pytorch_translate/hybrid_transformer_rnn.py
+++ b/pytorch_translate/hybrid_transformer_rnn.py
@@ -247,6 +247,9 @@ def forward(
     ):
         (encoder_x, src_tokens, encoder_padding_mask) = encoder_out

+        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
+            encoder_padding_mask = None
+
         bsz, seqlen = prev_output_tokens.size()
         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
diff --git a/pytorch_translate/transformer.py b/pytorch_translate/transformer.py
index f71998c0..dac29081 100644
--- a/pytorch_translate/transformer.py
+++ b/pytorch_translate/transformer.py
@@ -277,6 +277,10 @@ def forward(self, src_tokens, src_lengths):
             x=x, positions=positions, encoder_padding_mask=encoder_padding_mask
         )

+        if encoder_padding_mask is None:
+            # using an empty tensor instead of None for PyTorch native export
+            encoder_padding_mask = torch.Tensor().type_as(src_tokens)
+
         return x, src_tokens, encoder_padding_mask

     def reorder_encoder_out(self, encoder_out, new_order):
@@ -285,7 +289,7 @@ def reorder_encoder_out(self, encoder_out, new_order):
         x = x.index_select(1, new_order)
         if src_tokens is not None:
             src_tokens = src_tokens.index_select(0, new_order)
-        if encoder_padding_mask is not None:
+        if encoder_padding_mask is not None and encoder_padding_mask.numel() != 0:
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
         return (x, src_tokens, encoder_padding_mask)

@@ -382,6 +386,9 @@ def forward(
     ):
         (encoder_x, src_tokens, encoder_padding_mask) = encoder_out

+        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
+            encoder_padding_mask = None
+
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens, incremental_state=incremental_state, timestep=timestep
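
The convention introduced by PATCH 2/2 can be illustrated outside the diff. Below is a minimal sketch, not part of the patch: `ToyEncoder` and `toy_decode` are hypothetical names used only to show the two sides of the convention. The producer substitutes an empty tensor for a `None` padding mask so that every encoder output is a tensor, and the consumer maps the empty-tensor sentinel back to `None` via `numel() == 0`. It uses the public `torch.jit.fork`/`torch.jit.wait` rather than the internal `torch.jit._fork`/`_wait` calls that appear in the patch.

# Standalone sketch of the empty-tensor-instead-of-None convention (assumed
# names ToyEncoder/toy_decode; not pytorch_translate code).
import torch


class ToyEncoder(torch.nn.Module):
    def forward(self, src_tokens):
        x = src_tokens.float().unsqueeze(-1)  # stand-in for encoder hidden states
        encoder_padding_mask = None  # unmasked/inference case
        if encoder_padding_mask is None:
            # return an empty tensor instead of None so the output tuple
            # contains only tensors, which fork/join annotation can handle
            encoder_padding_mask = torch.Tensor().type_as(src_tokens)
        return x, src_tokens, encoder_padding_mask


def toy_decode(encoder_out):
    x, src_tokens, encoder_padding_mask = encoder_out
    # consumer side: map the empty-tensor sentinel back to None before use
    if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
        encoder_padding_mask = None
    assert encoder_padding_mask is None  # the unmasked case round-trips correctly
    return x.sum()


if __name__ == "__main__":
    src = torch.tensor([[4, 5, 6]])
    # fork/join over the encoder works because its outputs are all tensors
    future = torch.jit.fork(ToyEncoder(), src)
    print(toy_decode(torch.jit.wait(future)))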