Skip to content

Commit

Permalink
Fix pipeline parallelism and incorrect neox_args name
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin-Anthony committed Sep 8, 2024
1 parent e13b640 commit 4c5bdeb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def get_batch(neox_args, data_iterator):

# Items and their type.
if neox_args.train_impl == "normal":
keys = ["text", "label"] if neox_args.label_data_paths else ["text"]
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
elif neox_args.train_impl == "dpo":
keys = (
[["pos", "pos_label"], ["neg", "neg_label"]]
Expand Down Expand Up @@ -1016,7 +1016,7 @@ def train_step(
reduced_loss = train_step_pipe(
neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator
)
reduce_metrics = {"lm_loss": reduced_loss}
reduce_metrics = reduced_loss
if (
neox_args.memory_profiling
and neox_args.iteration >= neox_args.profile_step_start
Expand Down

0 comments on commit 4c5bdeb

Please sign in to comment.