diff --git a/src/nnsight/modeling/language.py b/src/nnsight/modeling/language.py index 6cfa4af..45228a8 100755 --- a/src/nnsight/modeling/language.py +++ b/src/nnsight/modeling/language.py @@ -258,7 +258,7 @@ def _batch( if batched_inputs is None: return ((input,), {"labels": labels}) - + batched_labels = batched_inputs[1]["labels"] batched_inputs = batched_inputs[0][0] @@ -275,8 +275,18 @@ def _batch( if labels is not None: batched_labels = torch.cat((batched_labels, labels)) - - batched_inputs["attention_mask"][:-1, : attention_mask.shape[1]] = attention_mask + + if self.tokenizer.padding_side == "left": + + batched_inputs["attention_mask"][ + : attention_mask.shape[0], -attention_mask.shape[1] : + ] = attention_mask + + else: + + batched_inputs["attention_mask"][ + : attention_mask.shape[0], : attention_mask.shape[1] + ] = attention_mask return ((batched_inputs,), {"labels": batched_labels})