Skip to content

Commit

Permalink
add staling gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Dec 19, 2024
1 parent 668263d commit 3d3bbbc
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 57 deletions.
5 changes: 3 additions & 2 deletions src/zeroband/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def gloo_all_reduce(
tensor: torch.Tensor,
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[dist.ProcessGroup] = None,
) -> None:
async_op: bool = False,
) -> None | dist.Work:
"""Wrap gloo all reduce"""
if group is None:
group = dist.distributed_c10d._get_default_group()
Expand All @@ -24,7 +25,7 @@ def gloo_all_reduce(
# todo check numerical stability of doing post or pre div
tensor.div_(group.size())

dist.all_reduce(tensor, op, group=group)
return dist.all_reduce(tensor, op, group=group, async_op=async_op)


class Compression(Enum):
Expand Down
138 changes: 88 additions & 50 deletions src/zeroband/global_ddp.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,84 @@
import time
from typing import Generator, NamedTuple
from pydantic import model_validator
from pydantic_config import BaseConfig
import torch
import torch.nn as nn
from zeroband.comms import ElasticDeviceMesh
import torch.distributed as dist
from zeroband.collectives import Compression, all_reduce
from zeroband.collectives import Compression, gloo_all_reduce
from torch.distributed._tensor.api import DTensor

from zeroband.utils.logging import get_logger
from zeroband.utils.world_info import get_world_info

from torch.distributed import Work


class GlobalDDPConfig(BaseConfig):
retry_all_reduce: int = 3
# retry_all_reduce: int = 3
compression: Compression = Compression.NO
dpu: bool = False
enable: bool = True

@model_validator(mode="after")
def validate_dpu(self):
if self.dpu:
raise NotImplementedError("DPU is not implemented yet")

def validate_compression(self):
if self.compression != Compression.NO:
raise NotImplementedError("Compression is not implemented yet")
return self


def offload_grad_generator(model: nn.Module) -> Generator:
    """Lazily yield a CPU version of every gradient held by ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped, so the number of
    yielded tensors can be smaller than the number of parameters. DTensor
    gradients are converted to their local tensor before being moved.

    Args:
        model: Module whose parameter gradients are offloaded.

    Yields:
        One tensor per parameter that has a gradient, in
        ``model.parameters()`` order, moved with ``.to("cpu")``.
    """
    for parameter in model.parameters():
        gradient = parameter.grad
        if gradient is None:
            # Nothing to offload for this parameter.
            continue
        local_gradient = gradient.to_local() if isinstance(gradient, DTensor) else gradient
        yield local_gradient.to("cpu")


def apply_staling_grad(model: nn.Module, tensors: list[torch.Tensor]):
    """Copy previously reduced ("staling") gradients into the model's gradients.

    ``tensors`` is matched positionally against the parameters of ``model``
    that currently hold a gradient. This mirrors ``offload_grad_generator``,
    which only yields a tensor for parameters whose ``.grad`` is not ``None``;
    pairing against *all* parameters would shift the tensors onto the wrong
    parameters and fail with ``AttributeError`` on a ``None`` gradient.

    Args:
        model: Module whose existing gradients are overwritten in place.
        tensors: Replacement gradient values, one per parameter that has a
            gradient, in ``model.parameters()`` order.
    """
    params_with_grad = (p for p in model.parameters() if p.grad is not None)
    for param, tensor in zip(params_with_grad, tensors):
        grad = param.grad
        # For a DTensor gradient, write into its local tensor.
        target = grad.to_local() if isinstance(grad, DTensor) else grad
        target.copy_(tensor)


def maybe_unwrap_dtensor(tensor: torch.Tensor | DTensor):
if isinstance(tensor, DTensor):
return tensor.to_local()
else:
return tensor


class AllReduceGradWork(NamedTuple):
    """Pair a gradient tensor with the work handle of its all-reduce."""

    # Tensor handed to the all-reduce call.
    grad: torch.Tensor
    # Handle for the in-flight collective; callers call ``.wait()`` on it.
    work: Work


class GlobalDDP:
"""
This class implements DDP over the internet.
:Args:
model: The model to be trained
config: The configuration for the global DDP
elastic_device_mesh: The elastic device mesh to be used
dpu: Whether to use delayed parameter updates
Example usage:
```
global_ddp = GlobalDDP(model, elastic_device_mesh)
config = GlobalDDPConfig(dpu=False)
global_ddp = GlobalDDP(model, config, elastic_device_mesh)
for step in range(num_steps):
for micro_bs in range(num_micro_bs):
optimizer.zero_grad()
loss = model(batch)
loss.backward()
global_ddp.all_reduce()
optimizer.step()
diloco.step(model)
optimizer.zero_grad()
```
"""
Expand All @@ -64,54 +94,62 @@ def __init__(
self.elastic_device_mesh = elastic_device_mesh
self.config = config

self._logger = get_logger()
self.world_info = get_world_info()
self._logger = get_logger()

self.model = model

def all_reduce(self):
self._all_reduce(list(self.model.parameters()))
self._stalling_grad_work: list[AllReduceGradWork] | None = None

def _all_reduce(self, tensor: list[torch.Tensor]):
_start_time = time.perf_counter()
def all_reduce(self):
if not self.config.dpu:
self._blocking_all_reduce(self.model)
else:
new_staling_grad_work = self._async_all_reduce(self.model)

if self._stalling_grad_work is None:
# if it is the first step we just store the work for the next call to this function and return
self._stalling_grad_work = new_staling_grad_work
else:
# otherwise we wait for the current staling grad work to finish
start_time = time.time()
[all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._stalling_grad_work]
self._logger.debug(f"Time to wait for staling grads: {time.time() - start_time}")
# and apply the staling grads to the model
apply_staling_grad(
self.model, [all_reduce_grad_work.grad for all_reduce_grad_work in self._stalling_grad_work]
)
# and store the new staling grad work for the next call to this function
self._stalling_grad_work = new_staling_grad_work

def _async_all_reduce(self, model: nn.Module) -> list[AllReduceGradWork]:
"""
Trigger an all-reduce operation on the model's offloaded gradients in an async manner.
Return a list of async jobs that can be waited on.
"""

self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False)
world_size = self.elastic_device_mesh.global_pg.size()

self._logger.debug("sync pseudo gradient with world size %d", world_size)

global_pg = self.elastic_device_mesh.global_pg
self.elastic_device_mesh.monitored_barrier(self.flag)
self._logger.debug("Beginning all reduce")

for i in range(self.config.retry_all_reduce):
try:
_collective_start_time = time.perf_counter()
self._logger.debug("Waiting on barrier")
self.elastic_device_mesh.monitored_barrier(self.flag)

self._logger.debug("Beginning all reduce")

total_param = len(tensor)
for j, param in enumerate(tensor):
t0 = time.perf_counter()
if isinstance(param.grad, DTensor):
grad = param.grad.to_local()
else:
grad = param

grad.div_(world_size)

all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg)
self._logger.debug(
f"{j}/{total_param} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {grad.numel()}"
)
break
except Exception as e:
self._logger.error(f"Error syncing pseudo gradient: {e}, retry {i+1}/{self.config.retry_all_reduce}")
global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True)
else:
self._logger.error(
"Failed to sync pseudo gradient after %d retries. Resorting to calculating pseudo-gradient without reduce",
self.config.retry_all_reduce,
)
async_job = []

for param in offload_grad_generator(model): # TODO: do we need to offload when doing blocking all reduce ?
grad = maybe_unwrap_dtensor(param)

grad.div_(world_size)

# all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) # doing gloo all reduce directly because of async op

async_job.append(AllReduceGradWork(grad, gloo_all_reduce(grad, dist.ReduceOp.SUM, global_pg, True)))

return async_job

self._logger.info(f"Global gradient all reduce done in {time.perf_counter() - _start_time:.6f} seconds")
def _blocking_all_reduce(self, tensor: list[torch.Tensor]):
"""
Trigger an all-reduce operation in a blocking manner.
"""
[all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._async_all_reduce(tensor)]
12 changes: 9 additions & 3 deletions src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def train(config: Config):
logger.info("starting training")

need_live_recovery = config.ckpt.live_recovery_rank_src is not None
first_step = True
while True:
if num_inner_steps > 1:
# if we don't use diloco we don't print the outer step logs
Expand Down Expand Up @@ -408,9 +409,14 @@ def train(config: Config):
if config.global_ddp:
global_ddp.all_reduce()

inner_optimizer.step()
scheduler.step()
inner_optimizer.zero_grad()
if config.global_ddp is not None and config.global_ddp.dpu and first_step:
inner_optimizer.zero_grad()
first_step = False
## if we are at the beginning of the dpu bubble we need to skip the first step
else:
inner_optimizer.step()
scheduler.step()
inner_optimizer.zero_grad()

# logging
training_progress.step += 1
Expand Down
6 changes: 4 additions & 2 deletions tests/test_torchrun/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def test_packing(packing: bool):
_test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg])


def test_global_ddp():
@pytest.mark.parametrize("dpu", [True, False])
def test_global_ddp(dpu: bool):
num_gpus = [2, 1]
_test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--global_ddp.enable"], multi_nodes=True)
dpu_arg = "--global_ddp.dpu" if dpu else "--no-global_ddp.dpu"
_test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[dpu_arg], multi_nodes=True)

0 comments on commit 3d3bbbc

Please sign in to comment.