return empty tensor instead of None #332

Open · wants to merge 2 commits into master

This repository has been archived by the owner on Aug 1, 2023. It is now read-only.
24 changes: 4 additions & 20 deletions pytorch_translate/ensemble_export.py

@@ -256,25 +256,12 @@ def forward(self, src_tokens, src_lengths):
             # evaluation mode
             model.eval()
-
-            # TODO(jamesreed): transformer encoder returns a None output, and
-            # the fork/join API doesn't handle that well. We should figure out
-            # a way to annotate outputs as Optional and record that in fork/join
-            # traces.
-            if isinstance(model.encoder, TransformerEncoder):
-                futures.append(model.encoder(src_tokens_seq_first, src_lengths))
-            else:
-                futures.append(
-                    torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
-                )
+            futures.append(
+                torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
+            )

         for i, (model, future) in enumerate(zip(self.models, futures)):
-            if isinstance(model.encoder, TransformerEncoder):
-                encoder_out = future
-            else:
-                encoder_out = torch.jit._wait(future)
+            encoder_out = torch.jit._wait(future)
             # "primary" encoder output (vector representations per source token)
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)

@@ -1361,9 +1348,6 @@ def forward(self, src_tokens, src_lengths, char_inds, word_lengths):
                 src_tokens_seq_first, src_lengths, char_inds, word_lengths
             )

-            # evaluation mode
-            model.eval()
-
             # "primary" encoder output (vector representations per source token)
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
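Note on the change above: once every encoder returns a tensor-only tuple (see the transformer.py change below, which substitutes an empty tensor for a None padding mask), the TransformerEncoder special case and the TODO can be dropped and all encoders go through the same fork/join path. A minimal sketch of that pattern, assuming a made-up ToyEncoder rather than anything from pytorch_translate (torch.jit._fork and torch.jit._wait are the same internal calls the file already uses):

import torch


class ToyEncoder(torch.nn.Module):
    # stand-in for a real encoder: returns (outputs, src_tokens, padding_mask)
    def forward(self, src_tokens, src_lengths):
        x = src_tokens.float().unsqueeze(-1)                 # fake per-token encodings
        padding_mask = torch.Tensor().type_as(src_tokens)    # empty tensor instead of None
        return x, src_tokens, padding_mask


models = [ToyEncoder(), ToyEncoder()]
src_tokens_seq_first = torch.randint(2, 100, (7, 2))         # (seq_len, batch)
src_lengths = torch.tensor([7, 7])

# every encoder is forked the same way; no isinstance(model.encoder, TransformerEncoder) branch
futures = [torch.jit._fork(m, src_tokens_seq_first, src_lengths) for m in models]
for future in futures:
    encoder_out = torch.jit._wait(future)
    encoder_outputs = encoder_out[0]                          # "primary" encoder output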
3 changes: 3 additions & 0 deletions pytorch_translate/hybrid_transformer_rnn.py

@@ -247,6 +247,9 @@ def forward(
     ):
         (encoder_x, src_tokens, encoder_padding_mask) = encoder_out

+        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
+            encoder_padding_mask = None
+
         bsz, seqlen = prev_output_tokens.size()
         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
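The check added here is the decoder-side half of the convention: an empty mask coming out of the exported encoder means "no padding", so it is converted back to None before the attention code sees it. A standalone sketch of the same check (normalize_padding_mask is a hypothetical helper, not part of this PR):

import torch


def normalize_padding_mask(encoder_padding_mask):
    # the empty tensor is only an export-friendly stand-in for None
    if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
        return None
    return encoder_padding_mask


assert normalize_padding_mask(torch.Tensor()) is None         # empty -> None
assert normalize_padding_mask(None) is None                   # None passes through
real_mask = torch.zeros(2, 7, dtype=torch.uint8)
assert normalize_padding_mask(real_mask) is real_mask         # real masks are kept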
9 changes: 8 additions & 1 deletion pytorch_translate/transformer.py

@@ -277,6 +277,10 @@ def forward(self, src_tokens, src_lengths):
             x=x, positions=positions, encoder_padding_mask=encoder_padding_mask
         )

+        if encoder_padding_mask is None:
+            # using an empty tensor instead of None for PyTorch native export
+            encoder_padding_mask = torch.Tensor().type_as(src_tokens)
+
         return x, src_tokens, encoder_padding_mask

     def reorder_encoder_out(self, encoder_out, new_order):

@@ -285,7 +289,7 @@ def reorder_encoder_out(self, encoder_out, new_order):
         x = x.index_select(1, new_order)
         if src_tokens is not None:
             src_tokens = src_tokens.index_select(0, new_order)
-        if encoder_padding_mask is not None:
+        if encoder_padding_mask is not None and encoder_padding_mask.numel() != 0:
             encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
         return (x, src_tokens, encoder_padding_mask)

@@ -382,6 +386,9 @@ def forward(
     ):
         (encoder_x, src_tokens, encoder_padding_mask) = encoder_out

+        if encoder_padding_mask is not None and encoder_padding_mask.numel() == 0:
+            encoder_padding_mask = None
+
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens, incremental_state=incremental_state, timestep=timestep
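Taken together, these hunks define the encoder-side convention: never return None (substitute an empty tensor typed like src_tokens), and skip index_select in reorder_encoder_out when the mask is empty, since indexing dimension 0 of a zero-element tensor with batch indices would raise an error. A rough sketch of the round trip, with encode_stub, reorder_stub, and PAD_IDX invented for illustration:

import torch

PAD_IDX = 1  # hypothetical padding index


def encode_stub(src_tokens, x):
    # build the (x, src_tokens, encoder_padding_mask) tuple with the empty-tensor convention
    encoder_padding_mask = src_tokens.eq(PAD_IDX)
    if not encoder_padding_mask.any():
        # empty tensor instead of None, as in the encoder change above
        encoder_padding_mask = torch.Tensor().type_as(src_tokens)
    return x, src_tokens, encoder_padding_mask


def reorder_stub(encoder_out, new_order):
    x, src_tokens, encoder_padding_mask = encoder_out
    x = x.index_select(1, new_order)                   # x is (seq_len, batch, dim)
    src_tokens = src_tokens.index_select(0, new_order)
    if encoder_padding_mask.numel() != 0:              # nothing to reorder for an empty mask
        encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
    return x, src_tokens, encoder_padding_mask


x = torch.randn(4, 2, 8)                               # (seq_len, batch, dim)
src_tokens = torch.randint(2, 50, (2, 4))              # no PAD_IDX present -> empty mask
reordered = reorder_stub(encode_stub(src_tokens, x), torch.tensor([1, 0]))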