cleanup
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Jan 8, 2025
1 parent b0289e9 commit bbc555a
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions dolomite_engine/finetune.py
@@ -73,7 +73,7 @@ def train_step_without_pipeline_parallel(
     gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps()
 
     batches = [get_next_batch(train_dataloader) for _ in range(gradient_accumulation_steps)]
-    lm_loss_multiplier = gradient_accumulation_steps / sum([(batch["labels"] != -100).sum() for batch in batches])
+    lm_loss_multiplier = 1 / sum([(batch["labels"] != -100).sum() for batch in batches])
 
     with no_sync():
         for batch in batches[:-1]:
@@ -82,8 +82,7 @@ def train_step_without_pipeline_parallel(
 
             # compute gradients
             with backward_context():
-                loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
-                loss_micro_step_scaled.backward()
+                loss_micro_step_dict["loss"].backward()
 
             with torch.inference_mode():
                 metrics_tracker = metrics_tracker + loss_micro_step_dict
@@ -97,8 +96,7 @@ def train_step_without_pipeline_parallel(
 
     # compute gradients
     with backward_context():
-        loss_micro_step_scaled: torch.Tensor = loss_micro_step_dict["loss"] / gradient_accumulation_steps
-        loss_micro_step_scaled.backward()
+        loss_micro_step_dict["loss"].backward()
 
     with torch.inference_mode():
         metrics_tracker = metrics_tracker + loss_micro_step_dict
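Net effect of the change: the per-micro-step division by gradient_accumulation_steps is folded into lm_loss_multiplier itself. The multiplier is now simply 1 over the total number of non-masked tokens (labels != -100) across all micro-batches, so each micro-step loss can be backpropagated without an extra rescale. A minimal sketch of the equivalence, assuming the model returns a token-summed loss scaled by lm_loss_multiplier (the names and numbers below are illustrative, not the repository's API):

import torch

# two fake micro-batches: (summed token loss, number of non-masked tokens)
micro_batches = [(torch.tensor(12.0), 3), (torch.tensor(8.0), 2)]
gradient_accumulation_steps = len(micro_batches)
total_tokens = sum(n for _, n in micro_batches)

# old scheme: multiplier carries an extra factor that is divided out again per micro-step
old_multiplier = gradient_accumulation_steps / total_tokens
old_total = sum((loss * old_multiplier) / gradient_accumulation_steps for loss, _ in micro_batches)

# new scheme: the multiplier already is the final normalization
new_multiplier = 1 / total_tokens
new_total = sum(loss * new_multiplier for loss, _ in micro_batches)

assert torch.isclose(old_total, new_total)  # same effective loss, one less division per step

Under this assumption the accumulated gradient is unchanged; the cleanup only removes a redundant scale-then-unscale round trip in each micro-step.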
