Skip to content

Commit

Permalink
Language Model Batching fix
Browse files: browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Feb 4, 2025
1 parent 90a1af6 commit f2b73e3
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions src/nnsight/modeling/language.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -258,7 +258,7 @@ def _batch(

if batched_inputs is None:
return ((input,), {"labels": labels})

batched_labels = batched_inputs[1]["labels"]
batched_inputs = batched_inputs[0][0]

Expand All @@ -275,8 +275,18 @@ def _batch(
if labels is not None:

batched_labels = torch.cat((batched_labels, labels))

batched_inputs["attention_mask"][:-1, : attention_mask.shape[1]] = attention_mask

if self.tokenizer.padding_side == "left":

batched_inputs["attention_mask"][
: attention_mask.shape[0], -attention_mask.shape[1] :
] = attention_mask

else:

batched_inputs["attention_mask"][
: attention_mask.shape[0], : attention_mask.shape[1]
] = attention_mask

return ((batched_inputs,), {"labels": batched_labels})

Expand Down

0 comments on commit f2b73e3

Please sign in to comment.