diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 39cae775..ad794fd5 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -172,13 +172,17 @@ def train(config: Config):
             input_ids = batch["input_ids"].to("cuda")
             labels = batch["labels"].to("cuda")
 
+            logger.debug(f"input_ids: {input_ids[0][0:10]}")
+            logger.debug(f"labels: {labels[0][0:10]}")
+
             with model.no_sync() if is_accumulating else nullcontext():
                 logits = model(tokens=input_ids).contiguous()
                 flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
                 flatten_labels = rearrange(labels, "b seq -> (b seq)")
 
                 loss = (
-                    F.cross_entropy(flatten_logits, flatten_labels, ignore_index=-100) / gradient_accumulation_steps
+                    F.cross_entropy(flatten_logits, flatten_labels, ignore_index=tokenizer.pad_token_id)
+                    / gradient_accumulation_steps
                 )
                 loss.backward()
                 loss_batch += loss.detach()
@@ -192,7 +196,7 @@ def train(config: Config):
             real_step = outer_step * num_inner_steps + inner_step + 1  # add + 1 because inner_step start at 0
             inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]
 
-            dist.all_reduce(loss_batch, op=dist.ReduceOp.AVG)
+            dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG)
             # syncing loss across all data parallel rank
             # todo(sami): when using diloco make sure that the loss is computed only on local world
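The functional change here is the `ignore_index` switch: `F.cross_entropy` drops every position whose label equals `ignore_index`, so the loss is now masked on the tokenizer's pad token rather than the `-100` sentinel (the `tensor=` keyword in `dist.all_reduce` only names an argument explicitly). Below is a minimal standalone sketch, not part of the patch, illustrating that masking behaviour; the pad id of 0 and the tensor shapes are assumed values for illustration only.

    # sketch only: shows that labels equal to ignore_index are excluded from the loss
    import torch
    import torch.nn.functional as F

    pad_token_id = 0                        # assumed pad id, stands in for tokenizer.pad_token_id
    logits = torch.randn(2, 4, 8)           # (batch, seq, vocab), random for the example
    labels = torch.tensor([[5, 3, pad_token_id, pad_token_id],
                           [2, 7, 6, pad_token_id]])

    flat_logits = logits.reshape(-1, logits.size(-1))   # (batch*seq, vocab)
    flat_labels = labels.reshape(-1)                    # (batch*seq,)

    # Padded positions contribute nothing to the sum or the mean.
    loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=pad_token_id)
    print(loss)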