fix ignore padding token at loss level
samsja committed Sep 24, 2024
1 parent b888717 commit 3b8c4ba
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/zeroband/train.py
@@ -172,13 +172,17 @@ def train(config: Config):
 input_ids = batch["input_ids"].to("cuda")
 labels = batch["labels"].to("cuda")

+logger.debug(f"input_ids: {input_ids[0][0:10]}")
+logger.debug(f"labels: {labels[0][0:10]}")
+
 with model.no_sync() if is_accumulating else nullcontext():
     logits = model(tokens=input_ids).contiguous()
     flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
     flatten_labels = rearrange(labels, "b seq -> (b seq)")

     loss = (
-        F.cross_entropy(flatten_logits, flatten_labels, ignore_index=-100) / gradient_accumulation_steps
+        F.cross_entropy(flatten_logits, flatten_labels, ignore_index=tokenizer.pad_token_id)
+        / gradient_accumulation_steps
     )
     loss.backward()
     loss_batch += loss.detach()
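
Editor's note: this hunk replaces the hard-coded ignore_index=-100 with the tokenizer's pad token id, so padding positions no longer contribute to the cross-entropy loss. A minimal sketch of how PyTorch's ignore_index behaves (not repo code; pad_token_id = 0 is an illustrative value):

# Minimal sketch (not repo code): ignore_index drops the flagged positions
# from both the loss sum and the mean's denominator. pad_token_id = 0 is
# an illustrative value.
import torch
import torch.nn.functional as F

pad_token_id = 0
vocab_size, seq_len = 8, 5
flat_logits = torch.randn(seq_len, vocab_size)          # flattened (batch * seq, vocab)
flat_labels = torch.tensor([3, 5, pad_token_id, pad_token_id, 2])

loss_masked = F.cross_entropy(flat_logits, flat_labels, ignore_index=pad_token_id)

# Same result as averaging the loss over only the non-pad positions.
keep = flat_labels != pad_token_id
loss_manual = F.cross_entropy(flat_logits[keep], flat_labels[keep])
assert torch.allclose(loss_masked, loss_manual)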
@@ -192,7 +196,7 @@ def train(config: Config):
 real_step = outer_step * num_inner_steps + inner_step + 1 # add + 1 because inner_step start at 0
 inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]

-dist.all_reduce(loss_batch, op=dist.ReduceOp.AVG)
+dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG)
 # syncing loss across all data parallel rank
 # todo(sami): when using diloco make sure that the loss is computed only on local world
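
Editor's note: the second hunk only switches the first argument of dist.all_reduce to the keyword form; the call still averages loss_batch across all data-parallel ranks. A rough sketch of that averaging pattern (not repo code; it uses the gloo backend with SUM followed by a divide, since ReduceOp.AVG as used in the commit is typically only supported by NCCL):

# Rough sketch (not repo code): every rank ends up with the mean of the
# per-rank losses, which is what dist.all_reduce(..., op=ReduceOp.AVG)
# produces in the commit. Gloo has no AVG op, so SUM / world_size is used here.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    loss_batch = torch.tensor(float(rank + 1))   # pretend per-rank loss: 1.0 and 2.0
    dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.SUM)
    loss_batch /= world_size                     # -> 1.5 on both ranks
    print(f"rank {rank}: averaged loss = {loss_batch.item()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)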
