diff --git a/pytorch/skipthoughts/skipthoughts.py b/pytorch/skipthoughts/skipthoughts.py index 5166193..66569dc 100644 --- a/pytorch/skipthoughts/skipthoughts.py +++ b/pytorch/skipthoughts/skipthoughts.py @@ -132,7 +132,7 @@ def _select_last_old(self, input, lengths): def _process_lengths(self, input): max_length = input.size(1) - lengths = list(max_length - input.data.eq(0).sum(1).squeeze()) + lengths = list(max_length - input.data.eq(0).sum(1, keepdim=False)) return lengths def _load_rnn(self):