diff --git a/pytorch_translate/ensemble_export.py b/pytorch_translate/ensemble_export.py
index 386e3509..96959a08 100644
--- a/pytorch_translate/ensemble_export.py
+++ b/pytorch_translate/ensemble_export.py
@@ -256,25 +256,12 @@ def forward(self, src_tokens, src_lengths):
             # evaluation mode
             model.eval()
 
-            # TODO(jamesreed): transformer encodder returns a None output, and
-            # the fork/join API doesn't handle that well. We should figure out
-            # a way to annotate outputs as Optional and record that in fork/join
-            # traces.
-            if isinstance(model.encoder, TransformerEncoder):
-                futures.append(model.encoder(src_tokens_seq_first, src_lengths))
-            else:
-                futures.append(
-                    torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
-                )
-
-            # evaluation mode
-            model.eval()
+            futures.append(
+                torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
+            )
 
         for i, (model, future) in enumerate(zip(self.models, futures)):
-            if isinstance(model.encoder, TransformerEncoder):
-                encoder_out = future
-            else:
-                encoder_out = torch.jit._wait(future)
+            encoder_out = torch.jit._wait(future)
             # "primary" encoder output (vector representations per source token)
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
@@ -1361,9 +1348,6 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
                 src_tokens_seq_first, src_lengths, char_inds, word_lengths
             )
 
-            # evaluation mode
-            model.eval()
-
             # "primary" encoder output (vector representations per source token)
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
diff --git a/pytorch_translate/hybrid_transformer_rnn.py b/pytorch_translate/hybrid_transformer_rnn.py
index 71fee2e8..f6fd6502 100644
--- a/pytorch_translate/hybrid_transformer_rnn.py
+++ b/pytorch_translate/hybrid_transformer_rnn.py
@@ -247,6 +247,9 @@ def forward(
     ):
         (encoder_x, src_tokens, encoder_padding_mask) = encoder_out
 
+        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
+            encoder_padding_mask = None
+
         bsz, seqlen = prev_output_tokens.size()
         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
diff --git a/pytorch_translate/transformer.py b/pytorch_translate/transformer.py
index f71998c0..dac29081 100644
--- a/pytorch_translate/transformer.py
+++ b/pytorch_translate/transformer.py
@@ -277,6 +277,10 @@ def forward(self, src_tokens, src_lengths):
             x=x, positions=positions, encoder_padding_mask=encoder_padding_mask
         )
 
+        if encoder_padding_mask is None:
+            # using an empty tensor instead of None for PyTorch native export
+            encoder_padding_mask = torch.Tensor().type_as(src_tokens)
+
         return x, src_tokens, encoder_padding_mask
 
     def reorder_encoder_out(self, encoder_out, new_order):
@@ -285,7 +289,7 @@ def reorder_encoder_out(self, encoder_out, new_order):
         x = x.index_select(1, new_order)
         if src_tokens is not None:
             src_tokens = src_tokens.index_select(0, new_order)
-        if encoder_padding_mask is not None:
+        if encoder_padding_mask is not None and encoder_padding_mask.numel() != 0:
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
         return (x, src_tokens, encoder_padding_mask)
 
@@ -382,6 +386,9 @@ def forward(
     ):
         (encoder_x, src_tokens, encoder_padding_mask) = encoder_out
 
+        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
+            encoder_padding_mask = None
+
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens, incremental_state=incremental_state, timestep=timestep
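
Note: the recurring pattern in this diff is a sentinel convention for optional encoder outputs. Because traced/exported graphs deal better with a plain tensor than with `None`, the transformer encoder now returns an empty tensor (`torch.Tensor().type_as(src_tokens)`) when there is no padding, and every consumer (`reorder_encoder_out`, the transformer and hybrid decoders) converts an empty mask back to `None` before using it. A minimal standalone sketch of that round trip; the `encode`/`decode_step` helpers and shapes below are illustrative, not code from this PR:

```python
import torch


def encode(src_tokens: torch.Tensor, padding_idx: int = 1):
    # Build the usual boolean padding mask; drop it when nothing is padded.
    encoder_padding_mask = src_tokens.eq(padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    # Stand-in for the real encoder states (seq_len x batch x dim).
    x = torch.randn(src_tokens.size(1), src_tokens.size(0), 8)

    if encoder_padding_mask is None:
        # Export-friendly sentinel: an empty tensor instead of None,
        # mirroring `torch.Tensor().type_as(src_tokens)` in the diff.
        encoder_padding_mask = torch.Tensor().type_as(src_tokens)
    return x, src_tokens, encoder_padding_mask


def decode_step(encoder_out):
    (encoder_x, src_tokens, encoder_padding_mask) = encoder_out
    # Undo the sentinel before code that expects Optional[Tensor].
    if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
        encoder_padding_mask = None
    return encoder_x, encoder_padding_mask


# Usage: a batch with no padding round-trips the sentinel back to None.
tokens = torch.tensor([[4, 5, 6], [7, 8, 9]])
_, _, mask = encode(tokens)
assert mask.numel() == 0
_, mask = decode_step(encode(tokens))
assert mask is None
```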
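
With the sentinel in place the ensemble no longer needs the `TransformerEncoder` special case: every encoder returns tensors only, so each one can be launched through `torch.jit._fork` and joined with `torch.jit._wait`. A hedged sketch of that fork/join pattern in isolation (the `run_encoders` helper and its arguments are placeholders, not part of the codebase; `_fork`/`_wait` are the internal JIT APIs the diff itself uses):

```python
import torch


def run_encoders(models, src_tokens, src_lengths):
    # Launch every model's encoder asynchronously; each fork returns a Future.
    futures = []
    for model in models:
        model.eval()
        futures.append(torch.jit._fork(model.encoder, src_tokens, src_lengths))

    # Join the futures in order; encoder_out[0] is the "primary" output
    # (vector representations per source token), as in the diff above.
    outputs = []
    for future in futures:
        encoder_out = torch.jit._wait(future)
        outputs.append(encoder_out[0])
    return outputs
```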