
Commit

add mfu, hfu, tokens_per_sec, and iters_per_sec logging
Quentin-Anthony committed Dec 18, 2024
1 parent 5ae53db commit 97e57d5
Showing 2 changed files with 148 additions and 4 deletions.
136 changes: 134 additions & 2 deletions megatron/logging.py
@@ -80,8 +80,12 @@ def human_readable_flops(num) -> str:
return "%.1f%s" % (num, "Yi")


def get_flops(neox_args, iter_time_s) -> float:
def get_actual_flops(neox_args, iter_time_s) -> float:
"""
    This function finds the actual FLOPs achieved, accounting for implementation and hardware details. Also used for HFU.
    For more detail on FLOP calculations, see https://github.com/EleutherAI/cookbook/tree/main/calc and https://github.com/Zyphra/zcookbook/tree/main/calc
    Uses the FLOPS calculation from Megatron-DeepSpeed:
https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
They get it from https://arxiv.org/pdf/2104.04473.pdf
@@ -156,6 +160,83 @@ def get_flops(neox_args, iter_time_s) -> float:
return flops_per_iteration / (iter_time_s * world_size)


def get_forward_backward_flops(neox_args, iter_time_s) -> float:
"""
    This function finds the estimated FLOPs required by a single forward+backward pass, without accounting for implementation and hardware details. Also used for MFU.
    Mostly duplicated from get_actual_flops, with only a change in activation checkpointing for now, but the two may diverge over time as implementation details accumulate, so two separate functions seem appropriate.
    For more detail on FLOP calculations, see https://github.com/EleutherAI/cookbook/tree/main/calc and https://github.com/Zyphra/zcookbook/tree/main/calc
    Uses the FLOPS calculation from Megatron-DeepSpeed:
https://github.com/microsoft/Megatron-DeepSpeed/blob/cc3a94c636789f74be2bc6cfc62a3d723fd5d749/megatron/utils.py#L253
They get it from https://arxiv.org/pdf/2104.04473.pdf
"""
world_size = torch.distributed.get_world_size()
vocab_size = neox_args.padded_vocab_size
batch_size = neox_args.train_batch_size
seq_len = neox_args.seq_length
hidden_size = neox_args.hidden_size
num_layers = neox_args.num_layers
fwd_bwd_factor = 3 # 1 for fwd, 2 for bwd and weight update
if "rwkv" in neox_args.attention_config:
num_heads = neox_args.num_attention_heads

flops_per_iteration = (
batch_size
* seq_len
* (
78 * hidden_size * hidden_size * num_layers
+ 84 * hidden_size * num_layers
+ 16 * hidden_size
+ 12 * hidden_size * vocab_size
+ 18 * hidden_size * hidden_size * num_layers / num_heads
)
)
elif "mamba" in neox_args.attention_config:
# from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py
if neox_args.expansion_factor:
d_inner = neox_args.hidden_size * neox_args.expansion_factor
elif neox_args.intermediate_size:
d_inner = neox_args.intermediate_size
else:
d_inner = neox_args.hidden_size * 2 # default expansion factor
d_state = 16 # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here
conv_dimension = 4 # TODO make conv_dimension an arg. Currently hardcoded in neox mamba definition and here
dt_rank = math.ceil(neox_args.hidden_size / 16)
ssm_flops = (
fwd_bwd_factor
* d_inner
* seq_len
* batch_size
* (11 * d_state + 4 * dt_rank + 1)
)
mamba_projectors_flops = (
fwd_bwd_factor * seq_len * batch_size * 6 * d_inner * hidden_size
)
mamba_conv_flops = (
fwd_bwd_factor * seq_len * batch_size * 2 * d_inner * conv_dimension
)
mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops
embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size
flops_per_iteration = mamba_flops * num_layers + embedding_flops
else:
flops_per_iteration = (
24
* fwd_bwd_factor
* batch_size
* seq_len
* num_layers
* (hidden_size**2)
* (
1.0
+ (seq_len / (6.0 * hidden_size))
+ (vocab_size / (16.0 * num_layers * hidden_size))
)
)
return flops_per_iteration / (iter_time_s * world_size)


def training_log(
neox_args,
timers,
@@ -350,6 +431,8 @@ def add_to_logging(name):
elapsed_time = timers("interval time").elapsed()
iteration_time = elapsed_time / neox_args.log_interval
samples_per_sec = neox_args.train_batch_size / iteration_time
steps_per_sec = 1 / iteration_time
tokens_per_sec = samples_per_sec * neox_args.seq_length
log_string = " samples/sec: {:.3f} |".format(samples_per_sec)
tb_wandb_log(
"runtime/samples_per_sec",
@@ -367,6 +450,22 @@ def add_to_logging(name):
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
tb_wandb_log(
"runtime/steps_per_sec",
steps_per_sec,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
tb_wandb_log(
"runtime/tokens_per_sec",
tokens_per_sec,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)
log_string += " iteration {:8d}/{:8d} |".format(
iteration, neox_args.train_iters
)
@@ -390,7 +489,7 @@ def add_to_logging(name):
)

# log tflop / gpu
flops_per_s_per_gpu = get_flops(neox_args, iteration_time)
flops_per_s_per_gpu = get_actual_flops(neox_args, iteration_time)

log_string += (
f" approx flops per GPU: {human_readable_flops(flops_per_s_per_gpu)} |"
@@ -404,6 +503,39 @@ def add_to_logging(name):
comet_experiment=neox_args.comet_experiment,
)

if neox_args.peak_theoretical_tflops:
# Convert peak theoretical TFLOPS to FLOPS for consistent units
peak_theoretical_flops = neox_args.peak_theoretical_tflops * (10**12)

# Calculate MFU and HFU as percentages
mfu = (
get_forward_backward_flops(neox_args, iteration_time)
/ peak_theoretical_flops
) * 100
hfu = (flops_per_s_per_gpu / peak_theoretical_flops) * 100

# Add to log string
log_string += f" MFU: {mfu:.2f}% | HFU: {hfu:.2f}% |"

# Log to tracking systems
tb_wandb_log(
"runtime/model_flops_utilization",
mfu,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)

tb_wandb_log(
"runtime/hardware_flops_utilization",
hfu,
iteration,
use_wandb=neox_args.use_wandb,
tensorboard_writer=neox_args.tensorboard_writer,
comet_experiment=neox_args.comet_experiment,
)

for key in total_loss_dict:
if key not in [skipped_iters_key, got_nan_key]:
v = (
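The two new FLOP counters above differ only in whether recomputation is counted: get_forward_backward_flops always uses a forward+backward factor of 3 (the work the model needs, the MFU numerator), while get_actual_flops counts the work the hardware actually performs (the HFU numerator). Below is a minimal standalone sketch of the dense-transformer branch, assuming (as in the Megatron-DeepSpeed formula the docstrings reference) that activation checkpointing raises the factor from 3 to 4 via an extra forward recompute; the names are illustrative, not the committed code.

def dense_transformer_flops_per_iter(
    batch_size, seq_len, num_layers, hidden_size, vocab_size, fwd_bwd_factor
):
    # 24 * factor * b * s * L * h^2 * (1 + s/(6h) + v/(16*L*h)), as in the diff above
    return (
        24
        * fwd_bwd_factor
        * batch_size
        * seq_len
        * num_layers
        * hidden_size**2
        * (
            1.0
            + seq_len / (6.0 * hidden_size)
            + vocab_size / (16.0 * num_layers * hidden_size)
        )
    )


def utilization_sketch(
    iter_time_s, world_size, peak_tflops, checkpoint_activations, **dims
):
    # MFU numerator: forward+backward only (factor 3), matching get_forward_backward_flops
    model_flops = dense_transformer_flops_per_iter(fwd_bwd_factor=3, **dims)
    # HFU numerator: add the recomputed forward pass when checkpointing (assumed factor 4)
    hw_factor = 4 if checkpoint_activations else 3
    hardware_flops = dense_transformer_flops_per_iter(fwd_bwd_factor=hw_factor, **dims)
    peak_flops = peak_tflops * 1e12  # peak_theoretical_tflops is given in TFLOPS
    per_gpu = lambda total: total / (iter_time_s * world_size)
    mfu = per_gpu(model_flops) / peak_flops * 100
    hfu = per_gpu(hardware_flops) / peak_flops * 100
    return mfu, hfu

With activation checkpointing disabled the two utilization numbers coincide; with it enabled, HFU exceeds MFU by the recompute ratio (4/3 under this assumption).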
16 changes: 14 additions & 2 deletions megatron/neox_arguments/neox_args.py
@@ -624,6 +624,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
Logging Arguments
"""

### BEGIN WANDB ARGS ###
use_wandb: bool = None
"""Flag indicating if wandb is to be used."""

@@ -644,6 +645,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):

wandb_init_all_ranks: bool = False
"""Initialize wandb on all ranks."""
### END WANDB ARGS ###

git_hash: str = get_git_commit_hash()
"""current git hash of repository"""
@@ -653,6 +655,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
Directory to save logs to.
"""

### BEGIN TENSORBOARD ARGS ###
tensorboard_writer = None
"""
initialized tensorboard writer
@@ -662,7 +665,9 @@ class NeoXArgsLogging(NeoXArgsTemplate):
"""
Write TensorBoard logs to this directory.
"""
### END TENSORBOARD ARGS ###

### BEGIN COMET ARGS ###
use_comet: bool = None
"""Flag indicating if comet is to be used."""

@@ -695,6 +700,12 @@ class NeoXArgsLogging(NeoXArgsTemplate):
"""
Initialized comet experiment object used to log data
"""
### END COMET ARGS ###

peak_theoretical_tflops: float = None
"""
The peak theoretical hardware FLOPS used to compute MFU and HFU, in units of teraflops. Automatic detection is more trouble than it's worth, so this is left to the user. A helpful table is available at https://github.com/stas00/ml-engineering/tree/master/compute/accelerator#tflops-comparison-table
"""

log_interval: int = 100
"""
@@ -714,8 +725,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
log_grad_norm: bool = False
"""
Log the frob norm of the gradients to wandb / tensorboard (useful for debugging).
(N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because
deepspeed.)
(N.B - this will only work with pp = 0 for now, as we don't have access to the gradients of the model because deepspeed.)
"""

log_optimizer_states: bool = False
@@ -738,6 +748,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
Whether to offload the buffered gradients to cpu when measuring gradient noise scale.
"""

### BEGIN PROFILING ARGS ###
memory_profiling: bool = False
"""
Whether to take a memory snapshot of the model. Useful for debugging memory issues.
@@ -770,6 +781,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
"""
Step to stop profiling at.
"""
### END PROFILING ARGS ###


@dataclass
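For a rough sense of scale from the new metrics (hypothetical numbers, not output from this commit), with peak_theoretical_tflops set to 312, the A100 bf16 dense peak from the table linked above:

peak_theoretical_tflops = 312.0   # A100 bf16 dense peak, from the linked table
achieved_tflops_per_gpu = 156.0   # hypothetical per-GPU value from get_actual_flops
model_tflops_per_gpu = 140.0      # hypothetical per-GPU value from get_forward_backward_flops

hfu = achieved_tflops_per_gpu / peak_theoretical_tflops * 100   # 50.0 %
mfu = model_tflops_per_gpu / peak_theoretical_tflops * 100      # ~44.9 %

# tokens_per_sec follows directly from the logged samples/sec:
# with train_batch_size = 1024, seq_length = 2048, iteration_time = 2.0 s
samples_per_sec = 1024 / 2.0             # 512 samples/s
tokens_per_sec = samples_per_sec * 2048  # ~1.05M tokens/s

Since HFU credits the recomputed forward pass while MFU does not, HFU is always greater than or equal to MFU for the same run.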
