Skip to content

Commit

Permalink
fix collate fn
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 24, 2024
1 parent ee43118 commit ccc0c50
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/zeroband/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def _collate_fn_causal_mask(
elif len(input_ids) > max_seq_length:
input_ids = input_ids[:max_seq_length]

batched["input_ids"].append(input_ids[1:])
batched["labels"].append(input_ids[:-1])
batched["input_ids"].append(input_ids[:-1])
batched["labels"].append(input_ids[1:])

return {"input_ids": torch.stack(batched["input_ids"], dim=0), "labels": torch.stack(batched["labels"], dim=0)}

Expand Down
19 changes: 19 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch
from zeroband.data import collate_causal_mask


def test_collate_fn():
    """Check that the causal-mask collate function shifts labels by one token.

    Two samples of different lengths (5 and 6 tokens) are collated with
    ``max_seq_length=4``. For each row the test expects ``input_ids`` to be
    the first four tokens of the sample and ``labels`` to be the same
    sequence advanced by one position (tokens 1 through 4).
    """
    raw_sequences = [[0, 1, 2, 3, 4], [0, 0, 3, 4, 1, 7]]
    samples = [{"input_ids": torch.Tensor(seq)} for seq in raw_sequences]

    collated = collate_causal_mask(max_seq_length=4)(samples)

    assert collated is not None

    # One (expected input_ids, expected labels) pair per batch row.
    expected_rows = [
        ([0, 1, 2, 3], [1, 2, 3, 4]),
        ([0, 0, 3, 4], [0, 3, 4, 1]),
    ]
    for row, (want_inputs, want_labels) in enumerate(expected_rows):
        assert collated["input_ids"][row].tolist() == want_inputs
        assert collated["labels"][row].tolist() == want_labels

0 comments on commit ccc0c50

Please sign in to comment.