log all reduce time diloco (#25)
* log all reduce time diloco

* add an MFU computation that takes the all-reduce into account

---------

Co-authored-by: Sami jaghouar <[email protected]>
samsja and Sami jaghouar authored Sep 30, 2024
1 parent 6475ea4 commit ac78db1
Showing 3 changed files with 17 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/zeroband/diloco.py
@@ -1,3 +1,4 @@
+import time
 from pydantic_config import BaseConfig
 import torch
 from torch import nn
@@ -106,7 +107,10 @@ def step(self, model: nn.Module):
         """
         Step the optimizer
         """
+        time_start = time.perf_counter()
         self.sync_pseudo_gradient(model)
+        self._logger.info(f"all reduce pseudo gradient in: {time.perf_counter() - time_start} seconds")
+
         if self.outer_optimizer is not None:
             self.outer_optimizer.step()
             self.outer_optimizer.zero_grad()  # todo(sami): check if we can remove this
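For context, here is a minimal sketch of what the timed sync_pseudo_gradient call plausibly does, following the general DiLoCo recipe; this is an assumption for illustration, not the repository's actual implementation, and outer_params is a hypothetical name for a copy of the weights taken before the inner steps:

import torch
import torch.distributed as dist
from torch import nn

def sync_pseudo_gradient(model: nn.Module, outer_params: list[torch.Tensor]) -> None:
    # Hypothetical sketch: the pseudo-gradient is the drift of the inner
    # weights away from the outer weights, averaged across workers.
    for param, outer_param in zip(model.parameters(), outer_params):
        # pseudo-gradient = outer weights - current inner weights
        outer_param.grad = outer_param.data - param.data
        # average across workers; this collective is what step() now times
        dist.all_reduce(outer_param.grad, op=dist.ReduceOp.AVG)

The logged time is therefore dominated by this collective, which is exactly the communication cost the new effective-MFU number in train.py accounts for.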
12 changes: 12 additions & 0 deletions src/zeroband/train.py
@@ -1,6 +1,7 @@
 import os
 from contextlib import nullcontext
 from typing import Literal
+import time
 
 import torch
 from pydantic_config import parse_argv, BaseConfig
@@ -201,6 +202,7 @@ def train(config: Config):
         # if we don't use diloco we don't print the outer step logs
         logger.info(f"outer_step step: {training_progress.outer_step}")
 
+        time_start_outer = time.perf_counter()
         for _inner_step in range(num_inner_steps):
             loss_batch = 0
 
@@ -290,6 +292,16 @@ def train(config: Config):
             # we only allow checkpointing after an outer step. For non-diloco training, outer step = 1 anyway
             ckpt_manager.save(config.ckpt.path, config.ckpt.remote_path)
 
+        if config.diloco:
+            tokens_per_second = (
+                config.optim.batch_size
+                * config.diloco.inner_steps
+                * config.data.seq_length
+                / (time.perf_counter() - time_start_outer)
+            )
+            mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size
+            logger.info(f"effective mfu: {mfu}")
+
         if training_progress.step >= config.optim.total_steps:
             # we only allow breaking outside of the inner loop;
             # this avoids ending the training in the middle of the inner loop
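The MFU here is labeled "effective" because throughput is measured over the whole outer step, so the time spent in the all-reduce logged above lowers the number instead of being hidden between steps. A worked example with purely illustrative values (none of these numbers come from this repository):

# All values below are assumptions for illustration only.
batch_size = 512          # stands in for config.optim.batch_size
inner_steps = 50          # stands in for config.diloco.inner_steps
seq_length = 1024         # stands in for config.data.seq_length
outer_step_time = 120.0   # seconds per outer step, all-reduce included

tokens_per_second = batch_size * inner_steps * seq_length / outer_step_time
# ~218,453 tokens/s

num_flop_per_token = 6 * 1.1e9  # ~6N FLOPs/token for a hypothetical 1.1B-param model
gpu_peak_flops = 989e12         # e.g. H100 BF16 dense peak
local_world_size = 8            # GPUs on the node

mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / local_world_size
print(f"effective mfu: {mfu:.1f}%")  # ~18.2%

If the all-reduce were free, the same inner steps would finish sooner, outer_step_time would shrink, and the effective MFU would rise accordingly.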
2 changes: 1 addition & 1 deletion src/zeroband/utils/__init__.py
@@ -90,7 +90,7 @@ def count_tokens(self, tokens: int):
     def get_tokens_per_second(self) -> float | None:
         if len(self.tokens) < 2:
             return None
-        return sum(self.tokens) / (self.times[-1] - self.times[0])
+        return sum(self.tokens[1:]) / (self.times[-1] - self.times[0])
 
 
 TENSOR_SIG_SAMPLE_SIZE = 100
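The one-character fix in get_tokens_per_second addresses an off-by-one: times[-1] - times[0] spans only the intervals after the first measurement, so the token count recorded at times[0] arrived before the measured window and inflated the rate. A numeric illustration with stand-in data (not the repository's actual counter):

# Stand-in data: 100 tokens logged once per second, at t = 0, 1, 2.
tokens = [100, 100, 100]
times = [0.0, 1.0, 2.0]

# Before: the batch recorded at times[0] is counted even though it arrived
# before the 2-second window, overstating the rate by 50%.
old_rate = sum(tokens) / (times[-1] - times[0])      # 150.0 tokens/s

# After: only batches recorded inside the window are counted.
new_rate = sum(tokens[1:]) / (times[-1] - times[0])  # 100.0 tokens/s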
