diff --git a/src/zeroband/collectives.py b/src/zeroband/collectives.py index efdb3ea1..ea87575c 100644 --- a/src/zeroband/collectives.py +++ b/src/zeroband/collectives.py @@ -12,7 +12,8 @@ def gloo_all_reduce( tensor: torch.Tensor, op: dist.ReduceOp = dist.ReduceOp.SUM, group: Optional[dist.ProcessGroup] = None, -) -> None: + async_op: bool = False, +) -> None | dist.Work: """Wrap gloo all reduce""" if group is None: group = dist.distributed_c10d._get_default_group() @@ -24,7 +25,7 @@ def gloo_all_reduce( # todo check numerical stability of doing post or pre div tensor.div_(group.size()) - dist.all_reduce(tensor, op, group=group) + return dist.all_reduce(tensor, op, group=group, async_op=async_op) class Compression(Enum): diff --git a/src/zeroband/global_ddp.py b/src/zeroband/global_ddp.py index fbe1db96..84c0bfb6 100644 --- a/src/zeroband/global_ddp.py +++ b/src/zeroband/global_ddp.py @@ -1,31 +1,61 @@ import time +from typing import Generator, NamedTuple from pydantic import model_validator from pydantic_config import BaseConfig import torch import torch.nn as nn from zeroband.comms import ElasticDeviceMesh import torch.distributed as dist -from zeroband.collectives import Compression, all_reduce +from zeroband.collectives import Compression, gloo_all_reduce from torch.distributed._tensor.api import DTensor - from zeroband.utils.logging import get_logger from zeroband.utils.world_info import get_world_info +from torch.distributed import Work + class GlobalDDPConfig(BaseConfig): - retry_all_reduce: int = 3 + # retry_all_reduce: int = 3 compression: Compression = Compression.NO dpu: bool = False enable: bool = True @model_validator(mode="after") - def validate_dpu(self): - if self.dpu: - raise NotImplementedError("DPU is not implemented yet") - + def validate_compression(self): + if self.compression != Compression.NO: + raise NotImplementedError("Compression is not implemented yet") return self +def offload_grad_generator(model: nn.Module) -> Generator: + for param in model.parameters(): + if param.grad is not None: + if isinstance(param.grad, DTensor): + yield param.grad.to_local().to("cpu") + else: + yield param.grad.to("cpu") + + +def apply_staling_grad(model: nn.Module, tensors: list[torch.Tensor]): + for param, tensor in zip(model.parameters(), tensors): + if isinstance(param.grad, DTensor): + param.grad.to_local().copy_(tensor) + else: + param.grad.copy_(tensor) + + +def maybe_unwrap_dtensor(tensor: torch.Tensor | DTensor): + if isinstance(tensor, DTensor): + return tensor.to_local() + else: + return tensor + + +class AllReduceGradWork(NamedTuple): + grad: torch.Tensor + work: Work + + class GlobalDDP: """ This class implements DDP over internet. It @@ -64,54 +94,62 @@ def __init__( self.elastic_device_mesh = elastic_device_mesh self.config = config - self._logger = get_logger() self.world_info = get_world_info() + self._logger = get_logger() self.model = model - def all_reduce(self): - self._all_reduce(list(self.model.parameters())) + self._stalling_grad_work: list[AllReduceGradWork] | None = None - def _all_reduce(self, tensor: list[torch.Tensor]): - _start_time = time.perf_counter() + def all_reduce(self): + if not self.config.dpu: + self._blocking_all_reduce(self.model) + else: + new_staling_grad_work = self._async_all_reduce(self.model) + + if self._stalling_grad_work is None: + # if it is the first step we just store the work for the next call to this function and return + self._stalling_grad_work = new_staling_grad_work + else: + # otherwise we wait for the current staling grad work to finish + start_time = time.time() + [all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._stalling_grad_work] + self._logger.debug(f"Time to wait for staling grads: {time.time() - start_time}") + # and apply the staling grads to the model + apply_staling_grad( + self.model, [all_reduce_grad_work.grad for all_reduce_grad_work in self._stalling_grad_work] + ) + # and store the new staling grad work for the next call to this function + self._stalling_grad_work = new_staling_grad_work + + def _async_all_reduce(self, model: nn.Module) -> list[AllReduceGradWork]: + """ + Triggered all reduce operation on a list of tensors in a async manner. + Return a list of async jobs that can be waited on. + """ self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) world_size = self.elastic_device_mesh.global_pg.size() - self._logger.debug("sync pseudo gradient with world size %d", world_size) - global_pg = self.elastic_device_mesh.global_pg + self.elastic_device_mesh.monitored_barrier(self.flag) + self._logger.debug("Beginning all reduce") - for i in range(self.config.retry_all_reduce): - try: - _collective_start_time = time.perf_counter() - self._logger.debug("Waiting on barrier") - self.elastic_device_mesh.monitored_barrier(self.flag) - - self._logger.debug("Beginning all reduce") - - total_param = len(tensor) - for j, param in enumerate(tensor): - t0 = time.perf_counter() - if isinstance(param.grad, DTensor): - grad = param.grad.to_local() - else: - grad = param - - grad.div_(world_size) - - all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) - self._logger.debug( - f"{j}/{total_param} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {grad.numel()}" - ) - break - except Exception as e: - self._logger.error(f"Error syncing pseudo gradient: {e}, retry {i+1}/{self.config.retry_all_reduce}") - global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True) - else: - self._logger.error( - "Failed to sync pseudo gradient after %d retries. Resorting to calculating pseudo-gradient without reduce", - self.config.retry_all_reduce, - ) + async_job = [] + + for param in offload_grad_generator(model): # TODO: do we need to offload when doing blocking all reduce ? + grad = maybe_unwrap_dtensor(param) + + grad.div_(world_size) + + # all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) # doing gloo all reduce direclty because of async op + + async_job.append(AllReduceGradWork(grad, gloo_all_reduce(grad, dist.ReduceOp.SUM, global_pg, True))) + + return async_job - self._logger.info(f"Global gradient all reduce done in {time.perf_counter() - _start_time:.6f} seconds") + def _blocking_all_reduce(self, tensor: list[torch.Tensor]): + """ + Triggered all reduce operation on a list of tensors in a blocking manner. + """ + [all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._async_all_reduce(tensor)] diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 0bd46b84..7ce836ff 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -296,6 +296,7 @@ def train(config: Config): logger.info("starting training") need_live_recovery = config.ckpt.live_recovery_rank_src is not None + first_step = True while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs @@ -408,9 +409,14 @@ def train(config: Config): if config.global_ddp: global_ddp.all_reduce() - inner_optimizer.step() - scheduler.step() - inner_optimizer.zero_grad() + if config.global_ddp is not None and config.global_ddp.dpu and first_step: + inner_optimizer.zero_grad() + first_step = False + ## if we are at the beginning of the dpu bubble we need to skip the first step + else: + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() # logging training_progress.step += 1 diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index d2813150..2022e161 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -114,6 +114,8 @@ def test_packing(packing: bool): _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg]) -def test_global_ddp(): +@pytest.mark.parametrize("dpu", [True, False]) +def test_global_ddp(dpu: bool): num_gpus = [2, 1] - _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--global_ddp.enable"], multi_nodes=True) + dpu_arg = "--global_ddp.dpu" if dpu else "--no-global_ddp.dpu" + _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[dpu_arg], multi_nodes=True)