diff --git a/megatron/training.py b/megatron/training.py index 733904dec..d9932483a 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -316,7 +316,7 @@ def get_batch(neox_args, data_iterator): # Items and their type. if neox_args.train_impl == "normal": - keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] elif neox_args.train_impl == "dpo": keys = ( [["pos", "pos_label"], ["neg", "neg_label"]] @@ -1016,7 +1016,7 @@ def train_step( reduced_loss = train_step_pipe( neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator ) - reduce_metrics = {"lm_loss": reduced_loss} + reduce_metrics = reduced_loss if ( neox_args.memory_profiling and neox_args.iteration >= neox_args.profile_step_start