From 5ea14a6ed51b2151c7648cd1dd4b9df6cf622a94 Mon Sep 17 00:00:00 2001
From: Paul Michel
Date: Tue, 25 Sep 2018 14:49:46 -0400
Subject: [PATCH] Remove pad tokens from attention during decoding

Incidentally this fixes #225
---
 pytorch_translate/beam_decode.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/pytorch_translate/beam_decode.py b/pytorch_translate/beam_decode.py
index fe407caf..65668949 100644
--- a/pytorch_translate/beam_decode.py
+++ b/pytorch_translate/beam_decode.py
@@ -171,6 +171,8 @@ def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=No
         attn = scores.new(bsz * beam_size, src_encoding_len, maxlen + 2)
         attn_buf = attn.clone()
+
+        nonpad_idxs = src_tokens.ne(self.pad)
 
         # list of completed sentences
         finalized = [[] for i in range(bsz)]
@@ -258,11 +260,13 @@ def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
                 sents_seen.add(sent)
 
                 def get_hypo():
-                    _, alignment = attn_clone[i].max(dim=0)
+                    # remove padding tokens from attn scores
+                    hypo_attn = attn_clone[i][nonpad_idxs[sent]]
+                    _, alignment = hypo_attn.max(dim=0)
                     return {
                         "tokens": tokens_clone[i],
                         "score": score,
-                        "attention": attn_clone[i],  # src_len x tgt_len
+                        "attention": hypo_attn,  # src_len x tgt_len
                         "alignment": alignment,
                         "positional_scores": pos_scores[i],
                     }
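
Note for readers: the fix works by boolean-masking the source dimension of the
per-hypothesis attention matrix before taking the argmax, so pad positions can
never be selected as alignments. Below is a minimal standalone sketch of that
masking step; the pad index of 1 and the toy tensor shapes are assumptions for
illustration only and are not part of the patch.

    import torch

    pad = 1  # assumed pad index for this sketch
    src_tokens = torch.tensor([[4, 5, 6, pad, pad]])  # bsz=1, src_len=5
    attn = torch.rand(5, 7)                           # src_len x tgt_len for one hypothesis

    nonpad_idxs = src_tokens.ne(pad)     # bsz x src_len boolean mask, as in the patch
    hypo_attn = attn[nonpad_idxs[0]]     # keep only non-pad rows -> 3 x 7
    _, alignment = hypo_attn.max(dim=0)  # best source position per target token
    print(alignment)                     # indices lie in [0, 3); pad rows cannot win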