From 794940355fd60b5fb908d8e352836d4667d3cee5 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 12 Dec 2024 02:14:06 +0000 Subject: [PATCH 01/17] add global_ddp --- src/zeroband/global_ddp.py | 110 ++++++++++++++++++++++++++++++ src/zeroband/train.py | 13 +++- tests/test_torchrun/test_train.py | 13 ++-- 3 files changed, 131 insertions(+), 5 deletions(-) create mode 100644 src/zeroband/global_ddp.py diff --git a/src/zeroband/global_ddp.py b/src/zeroband/global_ddp.py new file mode 100644 index 00000000..7ad364e1 --- /dev/null +++ b/src/zeroband/global_ddp.py @@ -0,0 +1,110 @@ +import time +from pydantic import model_validator +from pydantic_config import BaseConfig +import torch.nn as nn +from zeroband.comms import ElasticDeviceMesh +import torch.distributed as dist +from zeroband.collectives import Compression, all_reduce +from torch.distributed._tensor.api import DTensor + +from zeroband.utils.logging import get_logger +from zeroband.utils.world_info import get_world_info + + +class GlobalDDPConfig(BaseConfig): + retry_all_reduce: int = 3 + compression: Compression = Compression.NO + dpu: bool = False + enable: bool = True + + @model_validator(mode="after") + def validate_dpu(self): + if self.dpu: + raise NotImplementedError("DPU is not implemented yet") + + return self + + +class GlobalDDP: + """ + This class implements DDP over internet. It + + :Args: + model: The model to be trained + elastic_device_mesh: The elastic device mesh to be used + dpu: Whether to use delayed parameter updates + + Example usage: + + ``` + global_ddp = GlobalDDP(model, elastic_device_mesh) + + for step in range(num_steps): + for micro_bs in range(num_micro_bs): + optimizer.zero_grad() + loss = model(batch) + loss.backward() + + global_ddp.all_reduce() + optimizer.step() + diloco.step(model) + ``` + + """ + + flag: str = "global_ddp" + + def __init__( + self, + config: GlobalDDPConfig, + elastic_device_mesh: ElasticDeviceMesh, + ): + self.elastic_device_mesh = elastic_device_mesh + self.config = config + + self._logger = get_logger() + self.world_info = get_world_info() + + def all_reduce(self, model: nn.Module): + _start_time = time.perf_counter() + + self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) + world_size = self.elastic_device_mesh.global_pg.size() + + self._logger.debug("sync pseudo gradient with world size %d", world_size) + + global_pg = self.elastic_device_mesh.global_pg + + for i in range(self.config.retry_all_reduce): + try: + _collective_start_time = time.perf_counter() + self._logger.debug("Waiting on barrier") + self.elastic_device_mesh.monitored_barrier(self.flag) + + self._logger.debug("Beginning all reduce") + + total_param = len(list(model.parameters())) + for j, param in enumerate(model.parameters()): + t0 = time.perf_counter() + if isinstance(param.grad, DTensor): + grad = param.grad.to_local() + else: + grad = param + + grad.div_(world_size) + + all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) + self._logger.debug( + f"{j}/{total_param} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {grad.numel()}" + ) + break + except Exception as e: + self._logger.error(f"Error syncing pseudo gradient: {e}, retry {i+1}/{self.config.retry_all_reduce}") + global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True) + else: + self._logger.error( + "Failed to sync pseudo gradient after %d retries. 
Resorting to calculating pseudo-gradient without reduce", + self.config.retry_all_reduce, + ) + + self._logger.info(f"Global gradient all reduce done in {time.perf_counter() - _start_time:.6f} seconds") diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 7ab7cb8d..5ebaa340 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -19,6 +19,7 @@ from zeroband import utils from zeroband.diloco import Diloco, DilocoConfig from zeroband.comms import ElasticDeviceMesh +from zeroband.global_ddp import GlobalDDP, GlobalDDPConfig from zeroband.loss import cross_entropy_max_z_loss from zeroband.models.llama.model import create_block_mask_from_seqlens @@ -105,6 +106,9 @@ class Config(BaseConfig): # sub config diloco: DilocoConfig | None = None + + global_ddp: GlobalDDPConfig | None = None + data: DataConfig = DataConfig() optim: OptimConfig = OptimConfig() train: TrainConfig @@ -185,7 +189,7 @@ def train(config: Config): apply_ac_ckpt(model, num) elastic_device_mesh = ElasticDeviceMesh( - enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src + enable=config.diloco is not None or config.global_ddp, live_recovery_rank_src=config.ckpt.live_recovery_rank_src ) mp_policy = MixedPrecisionPolicy( @@ -222,6 +226,9 @@ def train(config: Config): if config.diloco is not None: diloco = Diloco(config.diloco, model, elastic_device_mesh) + if config.global_ddp: + global_ddp = GlobalDDP(config.global_ddp, elastic_device_mesh) + scheduler = get_scheduler( sched_type=config.optim.sched_type, optimizer=inner_optimizer, @@ -397,6 +404,10 @@ def train(config: Config): dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + if config.global_ddp: + global_ddp.all_reduce(model) + inner_optimizer.step() scheduler.step() inner_optimizer.zero_grad() diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index e5703fe3..d2813150 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -34,7 +34,7 @@ def gpus_to_use(num_nodes, num_gpu, rank): return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu))) -def _test_multi_gpu(num_gpus, config, extra_args=[], diloco=False): +def _test_multi_gpu(num_gpus, config, extra_args=[], multi_nodes=False): num_nodes, num_gpu = num_gpus[0], num_gpus[1] processes = [] @@ -53,7 +53,7 @@ def _test_multi_gpu(num_gpus, config, extra_args=[], diloco=False): env = copy.deepcopy(os.environ) - if diloco: + if multi_nodes: new_env = { "GLOBAL_RANK": str(i), "GLOBAL_UNIQUE_ID": str(i), @@ -83,7 +83,7 @@ def test_multi_gpu(num_gpus): @pytest.mark.parametrize("num_gpus", [[2, 1], [2, 2]] if num_gpu >= 4 else [[2, 1]]) def test_multi_gpu_diloco(num_gpus): - _test_multi_gpu(num_gpus, "debug/diloco.toml", diloco=True) + _test_multi_gpu(num_gpus, "debug/diloco.toml", multi_nodes=True) def test_act_ckpt(): @@ -99,7 +99,7 @@ def test_act_ckpt_num(): @pytest.mark.parametrize("backend", [Compression.NO, Compression.UINT8]) def test_all_reduce_diloco(backend: Compression): num_gpus = [2, 1] - _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--diloco.compression", backend.value], diloco=True) + _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--diloco.compression", backend.value], multi_nodes=True) def test_z_loss(): @@ -112,3 +112,8 @@ def test_packing(packing: bool): num_gpus = [2, 1] packing_arg = "--train.sequence_packing" if packing else 
"--no-train.sequence_packing" _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg]) + + +def test_global_ddp(): + num_gpus = [2, 1] + _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--global_ddp.enable"], multi_nodes=True) From 668263db56da253e89825ec91e653dc276e61e08 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 12 Dec 2024 02:50:49 +0000 Subject: [PATCH 02/17] use _all_reduce --- src/zeroband/global_ddp.py | 13 ++++++++++--- src/zeroband/train.py | 4 ++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/zeroband/global_ddp.py b/src/zeroband/global_ddp.py index 7ad364e1..fbe1db96 100644 --- a/src/zeroband/global_ddp.py +++ b/src/zeroband/global_ddp.py @@ -1,6 +1,7 @@ import time from pydantic import model_validator from pydantic_config import BaseConfig +import torch import torch.nn as nn from zeroband.comms import ElasticDeviceMesh import torch.distributed as dist @@ -56,6 +57,7 @@ class GlobalDDP: def __init__( self, + model: nn.Module, config: GlobalDDPConfig, elastic_device_mesh: ElasticDeviceMesh, ): @@ -65,7 +67,12 @@ def __init__( self._logger = get_logger() self.world_info = get_world_info() - def all_reduce(self, model: nn.Module): + self.model = model + + def all_reduce(self): + self._all_reduce(list(self.model.parameters())) + + def _all_reduce(self, tensor: list[torch.Tensor]): _start_time = time.perf_counter() self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) @@ -83,8 +90,8 @@ def all_reduce(self, model: nn.Module): self._logger.debug("Beginning all reduce") - total_param = len(list(model.parameters())) - for j, param in enumerate(model.parameters()): + total_param = len(tensor) + for j, param in enumerate(tensor): t0 = time.perf_counter() if isinstance(param.grad, DTensor): grad = param.grad.to_local() diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 5ebaa340..0bd46b84 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -227,7 +227,7 @@ def train(config: Config): diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.global_ddp: - global_ddp = GlobalDDP(config.global_ddp, elastic_device_mesh) + global_ddp = GlobalDDP(model=model, config=config.global_ddp, elastic_device_mesh=elastic_device_mesh) scheduler = get_scheduler( sched_type=config.optim.sched_type, @@ -406,7 +406,7 @@ def train(config: Config): torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) if config.global_ddp: - global_ddp.all_reduce(model) + global_ddp.all_reduce() inner_optimizer.step() scheduler.step() From 3d3bbbc71ad43d208a23be7f86b746984d851fb7 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 12 Dec 2024 18:36:31 +0000 Subject: [PATCH 03/17] add staling gradient --- src/zeroband/collectives.py | 5 +- src/zeroband/global_ddp.py | 138 +++++++++++++++++++----------- src/zeroband/train.py | 12 ++- tests/test_torchrun/test_train.py | 6 +- 4 files changed, 104 insertions(+), 57 deletions(-) diff --git a/src/zeroband/collectives.py b/src/zeroband/collectives.py index efdb3ea1..ea87575c 100644 --- a/src/zeroband/collectives.py +++ b/src/zeroband/collectives.py @@ -12,7 +12,8 @@ def gloo_all_reduce( tensor: torch.Tensor, op: dist.ReduceOp = dist.ReduceOp.SUM, group: Optional[dist.ProcessGroup] = None, -) -> None: + async_op: bool = False, +) -> None | dist.Work: """Wrap gloo all reduce""" if group is None: group = dist.distributed_c10d._get_default_group() @@ -24,7 +25,7 @@ def gloo_all_reduce( # todo check numerical stability of doing post or pre div tensor.div_(group.size()) - 
dist.all_reduce(tensor, op, group=group) + return dist.all_reduce(tensor, op, group=group, async_op=async_op) class Compression(Enum): diff --git a/src/zeroband/global_ddp.py b/src/zeroband/global_ddp.py index fbe1db96..1d94ea8e 100644 --- a/src/zeroband/global_ddp.py +++ b/src/zeroband/global_ddp.py @@ -1,54 +1,84 @@ import time +from typing import Generator, NamedTuple from pydantic import model_validator from pydantic_config import BaseConfig import torch import torch.nn as nn from zeroband.comms import ElasticDeviceMesh import torch.distributed as dist -from zeroband.collectives import Compression, all_reduce +from zeroband.collectives import Compression, gloo_all_reduce from torch.distributed._tensor.api import DTensor - from zeroband.utils.logging import get_logger from zeroband.utils.world_info import get_world_info +from torch.distributed import Work + class GlobalDDPConfig(BaseConfig): - retry_all_reduce: int = 3 + # retry_all_reduce: int = 3 compression: Compression = Compression.NO dpu: bool = False enable: bool = True @model_validator(mode="after") - def validate_dpu(self): - if self.dpu: - raise NotImplementedError("DPU is not implemented yet") - + def validate_compression(self): + if self.compression != Compression.NO: + raise NotImplementedError("Compression is not implemented yet") return self +def offload_grad_generator(model: nn.Module) -> Generator: + for param in model.parameters(): + if param.grad is not None: + if isinstance(param.grad, DTensor): + yield param.grad.to_local().to("cpu") + else: + yield param.grad.to("cpu") + + +def apply_staling_grad(model: nn.Module, tensors: list[torch.Tensor]): + for param, tensor in zip(model.parameters(), tensors): + if isinstance(param.grad, DTensor): + param.grad.to_local().copy_(tensor) + else: + param.grad.copy_(tensor) + + +def maybe_unwrap_dtensor(tensor: torch.Tensor | DTensor): + if isinstance(tensor, DTensor): + return tensor.to_local() + else: + return tensor + + +class AllReduceGradWork(NamedTuple): + grad: torch.Tensor + work: Work + + class GlobalDDP: """ This class implements DDP over internet. 
It :Args: model: The model to be trained + config: The configuration for the global DDP elastic_device_mesh: The elastic device mesh to be used - dpu: Whether to use delayed parameter updates Example usage: ``` - global_ddp = GlobalDDP(model, elastic_device_mesh) + config = GlobalDDPConfig(dpu=False) + global_ddp = GlobalDDP(model, config, elastic_device_mesh) for step in range(num_steps): for micro_bs in range(num_micro_bs): - optimizer.zero_grad() loss = model(batch) loss.backward() global_ddp.all_reduce() optimizer.step() - diloco.step(model) + optimizer.zero_grad() ``` """ @@ -64,54 +94,62 @@ def __init__( self.elastic_device_mesh = elastic_device_mesh self.config = config - self._logger = get_logger() self.world_info = get_world_info() + self._logger = get_logger() self.model = model - def all_reduce(self): - self._all_reduce(list(self.model.parameters())) + self._stalling_grad_work: list[AllReduceGradWork] | None = None - def _all_reduce(self, tensor: list[torch.Tensor]): - _start_time = time.perf_counter() + def all_reduce(self): + if not self.config.dpu: + self._blocking_all_reduce(self.model) + else: + new_staling_grad_work = self._async_all_reduce(self.model) + + if self._stalling_grad_work is None: + # if it is the first step we just store the work for the next call to this function and return + self._stalling_grad_work = new_staling_grad_work + else: + # otherwise we wait for the current staling grad work to finish + start_time = time.time() + [all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._stalling_grad_work] + self._logger.debug(f"Time to wait for staling grads: {time.time() - start_time}") + # and apply the staling grads to the model + apply_staling_grad( + self.model, [all_reduce_grad_work.grad for all_reduce_grad_work in self._stalling_grad_work] + ) + # and store the new staling grad work for the next call to this function + self._stalling_grad_work = new_staling_grad_work + + def _async_all_reduce(self, model: nn.Module) -> list[AllReduceGradWork]: + """ + Triggered all reduce operation on a list of tensors in a async manner. + Return a list of async jobs that can be waited on. + """ self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) world_size = self.elastic_device_mesh.global_pg.size() - self._logger.debug("sync pseudo gradient with world size %d", world_size) - global_pg = self.elastic_device_mesh.global_pg + self.elastic_device_mesh.monitored_barrier(self.flag) + self._logger.debug("Beginning all reduce") - for i in range(self.config.retry_all_reduce): - try: - _collective_start_time = time.perf_counter() - self._logger.debug("Waiting on barrier") - self.elastic_device_mesh.monitored_barrier(self.flag) - - self._logger.debug("Beginning all reduce") - - total_param = len(tensor) - for j, param in enumerate(tensor): - t0 = time.perf_counter() - if isinstance(param.grad, DTensor): - grad = param.grad.to_local() - else: - grad = param - - grad.div_(world_size) - - all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) - self._logger.debug( - f"{j}/{total_param} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {grad.numel()}" - ) - break - except Exception as e: - self._logger.error(f"Error syncing pseudo gradient: {e}, retry {i+1}/{self.config.retry_all_reduce}") - global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True) - else: - self._logger.error( - "Failed to sync pseudo gradient after %d retries. 
Resorting to calculating pseudo-gradient without reduce", - self.config.retry_all_reduce, - ) + async_job = [] + + for param in offload_grad_generator(model): # TODO: do we need to offload when doing blocking all reduce ? + grad = maybe_unwrap_dtensor(param) + + grad.div_(world_size) + + # all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) # doing gloo all reduce direclty because of async op + + async_job.append(AllReduceGradWork(grad, gloo_all_reduce(grad, dist.ReduceOp.SUM, global_pg, True))) + + return async_job - self._logger.info(f"Global gradient all reduce done in {time.perf_counter() - _start_time:.6f} seconds") + def _blocking_all_reduce(self, tensor: list[torch.Tensor]): + """ + Triggered all reduce operation on a list of tensors in a blocking manner. + """ + [all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._async_all_reduce(tensor)] diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 0bd46b84..7ce836ff 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -296,6 +296,7 @@ def train(config: Config): logger.info("starting training") need_live_recovery = config.ckpt.live_recovery_rank_src is not None + first_step = True while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs @@ -408,9 +409,14 @@ def train(config: Config): if config.global_ddp: global_ddp.all_reduce() - inner_optimizer.step() - scheduler.step() - inner_optimizer.zero_grad() + if config.global_ddp is not None and config.global_ddp.dpu and first_step: + inner_optimizer.zero_grad() + first_step = False + ## if we are at the beginning of the dpu bubble we need to skip the first step + else: + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() # logging training_progress.step += 1 diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index d2813150..2022e161 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -114,6 +114,8 @@ def test_packing(packing: bool): _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg]) -def test_global_ddp(): +@pytest.mark.parametrize("dpu", [True, False]) +def test_global_ddp(dpu: bool): num_gpus = [2, 1] - _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--global_ddp.enable"], multi_nodes=True) + dpu_arg = "--global_ddp.dpu" if dpu else "--no-global_ddp.dpu" + _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[dpu_arg], multi_nodes=True) From d593c0795e81e629a0234a92c76fdba34bb59988 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 19 Dec 2024 17:01:36 +0000 Subject: [PATCH 04/17] make asyc all reduce a fn --- src/zeroband/global_ddp.py | 62 +++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/src/zeroband/global_ddp.py b/src/zeroband/global_ddp.py index 1d94ea8e..4bf4a049 100644 --- a/src/zeroband/global_ddp.py +++ b/src/zeroband/global_ddp.py @@ -13,6 +13,8 @@ from torch.distributed import Work +logger = get_logger(__name__) + class GlobalDDPConfig(BaseConfig): # retry_all_reduce: int = 3 @@ -56,6 +58,33 @@ class AllReduceGradWork(NamedTuple): work: Work +def async_all_reduce(model: nn.Module, elastic_device_mesh: ElasticDeviceMesh, flag: str) -> list[AllReduceGradWork]: + """ + Triggered all reduce operation on a list of tensors in a async manner. + Return a list of async jobs that can be waited on. 
+ """ + + elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) + world_size = elastic_device_mesh.global_pg.size() + + global_pg = elastic_device_mesh.global_pg + elastic_device_mesh.monitored_barrier(flag) + logger.debug("Beginning all reduce") + + async_job = [] + + for param in offload_grad_generator(model): # TODO: do we need to offload when doing blocking all reduce ? + grad = maybe_unwrap_dtensor(param) + + grad.div_(world_size) + + # all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) # doing gloo all reduce direclty because of async op + + async_job.append(AllReduceGradWork(grad, gloo_all_reduce(grad, dist.ReduceOp.SUM, global_pg, True))) + + return async_job + + class GlobalDDP: """ This class implements DDP over internet. It @@ -105,7 +134,7 @@ def all_reduce(self): if not self.config.dpu: self._blocking_all_reduce(self.model) else: - new_staling_grad_work = self._async_all_reduce(self.model) + new_staling_grad_work = async_all_reduce(self.model, self.elastic_device_mesh, self.flag) if self._stalling_grad_work is None: # if it is the first step we just store the work for the next call to this function and return @@ -122,34 +151,11 @@ def all_reduce(self): # and store the new staling grad work for the next call to this function self._stalling_grad_work = new_staling_grad_work - def _async_all_reduce(self, model: nn.Module) -> list[AllReduceGradWork]: - """ - Triggered all reduce operation on a list of tensors in a async manner. - Return a list of async jobs that can be waited on. - """ - - self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) - world_size = self.elastic_device_mesh.global_pg.size() - - global_pg = self.elastic_device_mesh.global_pg - self.elastic_device_mesh.monitored_barrier(self.flag) - self._logger.debug("Beginning all reduce") - - async_job = [] - - for param in offload_grad_generator(model): # TODO: do we need to offload when doing blocking all reduce ? - grad = maybe_unwrap_dtensor(param) - - grad.div_(world_size) - - # all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) # doing gloo all reduce direclty because of async op - - async_job.append(AllReduceGradWork(grad, gloo_all_reduce(grad, dist.ReduceOp.SUM, global_pg, True))) - - return async_job - def _blocking_all_reduce(self, tensor: list[torch.Tensor]): """ Triggered all reduce operation on a list of tensors in a blocking manner. 
""" - [all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._async_all_reduce(tensor)] + [ + all_reduce_grad_work.work.wait() + for all_reduce_grad_work in async_all_reduce(tensor, self.elastic_device_mesh, self.flag) + ] From 19a7c7b5edd9bcc832382270e4ed6487ca0ad35e Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 11 Jan 2025 06:19:05 +0000 Subject: [PATCH 05/17] fix: add global ddp to config --- src/zeroband/config.py | 3 ++- src/zeroband/train.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/zeroband/config.py b/src/zeroband/config.py index 07d4b7e0..61ac6e56 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -6,6 +6,7 @@ from zeroband.checkpoint import CkptConfig from zeroband.data import DataConfig from zeroband.diloco import DilocoConfig +from zeroband.global_ddp import GlobalDDPConfig from zeroband.models.llama.model import AttnFnType from zeroband.optimizers import OptimizersConfig, AdamConfig @@ -68,6 +69,7 @@ class Config(BaseConfig): # sub config diloco: DilocoConfig | None = None + global_ddp: GlobalDDPConfig | None = None data: DataConfig = DataConfig() optim: OptimConfig = OptimConfig() train: TrainConfig @@ -88,4 +90,3 @@ def validate_live_recovery_rank_src(self): if self.ckpt is not None and self.ckpt.live_recovery_rank_src is not None and self.diloco is None: raise ValueError("live_recovery_rank_src is only supported with diloco") return self - diff --git a/src/zeroband/train.py b/src/zeroband/train.py index d283d934..27ad7825 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -15,11 +15,11 @@ from zeroband import utils from zeroband.diloco import Diloco from zeroband.comms import ElasticDeviceMesh -from zeroband.global_ddp import GlobalDDP, GlobalDDPConfig +from zeroband.global_ddp import GlobalDDP from zeroband.loss import cross_entropy_max_z_loss from zeroband.models.llama.model import create_block_mask_from_seqlens -from zeroband.config import Config #, MemoryProfilerConfig +from zeroband.config import Config # , MemoryProfilerConfig from zeroband.optimizers import get_optimizer from zeroband.utils import ( @@ -40,6 +40,7 @@ from zeroband.checkpoint import CkptManager, TrainingProgress from zeroband.lr_scheduler import get_scheduler + def log_hash_training_state( config: Config, model: torch.nn.Module, From 4a6b96c99b4efcd846f871c36af67d0cfd54ab1b Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 11 Jan 2025 08:41:19 +0000 Subject: [PATCH 06/17] TEMP: isolated implementation --- llama-debug/config.json | 35 ++++++++++++ train_dpu.py | 115 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 llama-debug/config.json create mode 100644 train_dpu.py diff --git a/llama-debug/config.json b/llama-debug/config.json new file mode 100644 index 00000000..c3f7712c --- /dev/null +++ b/llama-debug/config.json @@ -0,0 +1,35 @@ +{ +"architectures": [ + "LlamaForCausalLM" +], +"attention_bias": false, +"attention_dropout": 0.0, +"bos_token_id": 128000, +"eos_token_id": 128001, +"head_dim": 64, +"hidden_act": "silu", +"hidden_size": 1024, +"initializer_range": 0.02, +"intermediate_size": 4096, +"max_position_embeddings": 1024, +"mlp_bias": false, +"model_type": "llama", +"num_attention_heads": 16, +"num_hidden_layers": 5, +"num_key_value_heads": 8, +"pretraining_tp": 1, +"rms_norm_eps": 1e-05, +"rope_scaling": { + "factor": 32.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" +}, 
+"rope_theta": 500000.0, +"tie_word_embeddings": true, +"torch_dtype": "bfloat16", +"transformers_version": "4.45.0.dev0", +"use_cache": true, +"vocab_size": 128256 +} \ No newline at end of file diff --git a/train_dpu.py b/train_dpu.py new file mode 100644 index 00000000..13fd5c29 --- /dev/null +++ b/train_dpu.py @@ -0,0 +1,115 @@ +from typing import List +import torch +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +from datasets import load_dataset +from torch.optim import AdamW, Optimizer + + +# Loss function +def compute_loss(model: torch.nn.Module, inputs: List[str], tokenizer) -> torch.Tensor: + """ + Compute the loss for a batch of input text using a causal language modeling objective. + + Args: + model (torch.nn.Module): The pre-trained model (e.g., Llama). + inputs (List[str]): A batch of input text strings. + tokenizer: The tokenizer associated with the model. + + Returns: + torch.Tensor: The computed loss value. + """ + # Tokenize input text and prepare for model input + input_ids = tokenizer( + inputs, return_tensors="pt", padding=True, truncation=True, max_length=model.config.max_position_embeddings + ).input_ids + input_ids = input_ids.to(model.device) + labels = input_ids.clone() + + # Compute the loss + outputs = model(input_ids, labels=labels) + return outputs.loss + + +# Optimizer update step +def optimizer_step(optimizer, model_params, gradients): + for param, grad in zip(model_params, gradients): + if param.grad is not None: + param.grad = grad # Set gradients + optimizer.step() # Perform optimizer step + optimizer.zero_grad() # Reset gradients + + +def acco_algorithm( + model: torch.nn.Module, tokenizer, data_loader: DataLoader, optimizer: Optimizer, num_steps: int +) -> None: + """ + ACCO algorithm implementation without memory leaks. 
+ """ + model_params = [p for p in model.parameters() if p.requires_grad] + + for step, batch in enumerate(data_loader): + if step >= num_steps: + break + + # Split the batch into two halves + batch_text = batch["text"] + mid_point = len(batch_text) // 2 + first_half, second_half = batch_text[:mid_point], batch_text[mid_point:] + + # Stage 1: Compute gradients g_t using the second half of the batch + loss_t = compute_loss(model, second_half, tokenizer) + loss_t.backward() # Compute gradients for g_t + g_t = [p.grad.clone() for p in model_params] # Copy gradients + optimizer.zero_grad() # Clear gradients to avoid accumulation + + # Stage 2: Estimate next parameters (tilde_theta_t+1) + with torch.no_grad(): + tilde_theta_t1 = [param - optimizer.defaults["lr"] * grad for param, grad in zip(model_params, g_t)] + + # Temporarily update model parameters to tilde_theta_t+1 + for param, tilde_param in zip(model_params, tilde_theta_t1): + param.data.copy_(tilde_param) + + # Compute estimated gradients tilde_g_t+1 using the first half of the batch + loss_tilde = compute_loss(model, first_half, tokenizer) + tilde_g_t1 = torch.autograd.grad(loss_tilde, model_params) # No need for retain_graph + + # Restore original parameters + for param, original_param in zip(model_params, tilde_theta_t1): + param.data.copy_(original_param) + + # Update parameters theta_t+1 using combined gradients + combined_gradients = [g_t_i + tilde_g_t1_i for g_t_i, tilde_g_t1_i in zip(g_t, tilde_g_t1)] + optimizer_step(optimizer, model_params, combined_gradients) + + print(f"Step {step + 1}/{num_steps}: Loss = {loss_t.item()}") + + +# Main function +def main(): + # Load dataset + dataset = load_dataset( + "/root/prime/prime/datasets/fineweb-edu", split="train", streaming=True + ) # Small subset for example + data_loader = DataLoader(dataset, batch_size=8, shuffle=False) + + # Load model and tokenizer + model_name = "llama-debug" # Replace with actual Llama model if available + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + tokenizer.pad_token = tokenizer.eos_token + config = AutoConfig.from_pretrained(model_name) + model = AutoModelForCausalLM.from_config(config) + model = model.to("cuda" if torch.cuda.is_available() else "cpu") + + # Define optimizer + optimizer = AdamW(model.parameters(), lr=1e-4) + + # Run ACCO algorithm + num_steps = 100 + acco_algorithm(model, tokenizer, data_loader, optimizer, num_steps) + + +# Entry point +if __name__ == "__main__": + main() From 449948463dedbd031ef4db57ad12703b479ea27b Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 11 Jan 2025 21:05:13 +0000 Subject: [PATCH 07/17] fix: faithful reproduction of acco eq --- train_dpu.py | 53 ++++++++++++++++++++++------------------------------ 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/train_dpu.py b/train_dpu.py index 13fd5c29..97d1d674 100644 --- a/train_dpu.py +++ b/train_dpu.py @@ -31,15 +31,6 @@ def compute_loss(model: torch.nn.Module, inputs: List[str], tokenizer) -> torch. 
return outputs.loss -# Optimizer update step -def optimizer_step(optimizer, model_params, gradients): - for param, grad in zip(model_params, gradients): - if param.grad is not None: - param.grad = grad # Set gradients - optimizer.step() # Perform optimizer step - optimizer.zero_grad() # Reset gradients - - def acco_algorithm( model: torch.nn.Module, tokenizer, data_loader: DataLoader, optimizer: Optimizer, num_steps: int ) -> None: @@ -48,6 +39,7 @@ def acco_algorithm( """ model_params = [p for p in model.parameters() if p.requires_grad] + first_step = True for step, batch in enumerate(data_loader): if step >= num_steps: break @@ -57,31 +49,32 @@ def acco_algorithm( mid_point = len(batch_text) // 2 first_half, second_half = batch_text[:mid_point], batch_text[mid_point:] - # Stage 1: Compute gradients g_t using the second half of the batch - loss_t = compute_loss(model, second_half, tokenizer) + # Stage 1: Compute gradients g_t and tilde_theta_t+1 + loss_t = compute_loss(model, first_half, tokenizer) loss_t.backward() # Compute gradients for g_t - g_t = [p.grad.clone() for p in model_params] # Copy gradients - optimizer.zero_grad() # Clear gradients to avoid accumulation + g_t = [p.grad.cpu() for p in model_params] + theta_t = [p.cpu() for p in model_params] + # TODO: Gather gradients from other workers - # Stage 2: Estimate next parameters (tilde_theta_t+1) - with torch.no_grad(): - tilde_theta_t1 = [param - optimizer.defaults["lr"] * grad for param, grad in zip(model_params, g_t)] + if not first_step: + optimizer.step() + optimizer.zero_grad() + first_step = False - # Temporarily update model parameters to tilde_theta_t+1 - for param, tilde_param in zip(model_params, tilde_theta_t1): - param.data.copy_(tilde_param) - - # Compute estimated gradients tilde_g_t+1 using the first half of the batch - loss_tilde = compute_loss(model, first_half, tokenizer) - tilde_g_t1 = torch.autograd.grad(loss_tilde, model_params) # No need for retain_graph + # Stage 2: Compute g_tilde_t+1 and theta_t+1 + loss_tilde = compute_loss(model, second_half, tokenizer) + loss_tilde.backward() # Restore original parameters - for param, original_param in zip(model_params, tilde_theta_t1): + for param, original_param in zip(model_params, theta_t): param.data.copy_(original_param) - - # Update parameters theta_t+1 using combined gradients - combined_gradients = [g_t_i + tilde_g_t1_i for g_t_i, tilde_g_t1_i in zip(g_t, tilde_g_t1)] - optimizer_step(optimizer, model_params, combined_gradients) + + # Incorporate other workers grads + for param, grad in zip(model_params, g_t): + param.grad += grad.cuda() # TODO: Offload optimizer + + optimizer.step() + optimizer.zero_grad() print(f"Step {step + 1}/{num_steps}: Loss = {loss_t.item()}") @@ -89,9 +82,7 @@ def acco_algorithm( # Main function def main(): # Load dataset - dataset = load_dataset( - "/root/prime/prime/datasets/fineweb-edu", split="train", streaming=True - ) # Small subset for example + dataset = load_dataset("/root/prime/datasets/fineweb-edu", split="train", streaming=True) data_loader = DataLoader(dataset, batch_size=8, shuffle=False) # Load model and tokenizer From e7abcd37d5170e3571bacde2ddd53c828a968e79 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 11 Jan 2025 21:15:22 +0000 Subject: [PATCH 08/17] move to main --- train_dpu.py | 46 +++++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/train_dpu.py b/train_dpu.py index 97d1d674..18acf924 100644 --- a/train_dpu.py +++ b/train_dpu.py @@ -30,13 +30,26 @@ 
def compute_loss(model: torch.nn.Module, inputs: List[str], tokenizer) -> torch. outputs = model(input_ids, labels=labels) return outputs.loss +# Main function +def main(): + # Load dataset + dataset = load_dataset("/root/prime/datasets/fineweb-edu", split="train", streaming=True) + data_loader = DataLoader(dataset, batch_size=8, shuffle=False) + + # Load model and tokenizer + model_name = "llama-debug" # Replace with actual Llama model if available + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + tokenizer.pad_token = tokenizer.eos_token + config = AutoConfig.from_pretrained(model_name) + model = AutoModelForCausalLM.from_config(config) + model = model.to("cuda" if torch.cuda.is_available() else "cpu") + + # Define optimizer + optimizer = AdamW(model.parameters(), lr=1e-4) + + # Run ACCO algorithm + num_steps = 100 -def acco_algorithm( - model: torch.nn.Module, tokenizer, data_loader: DataLoader, optimizer: Optimizer, num_steps: int -) -> None: - """ - ACCO algorithm implementation without memory leaks. - """ model_params = [p for p in model.parameters() if p.requires_grad] first_step = True @@ -79,27 +92,6 @@ def acco_algorithm( print(f"Step {step + 1}/{num_steps}: Loss = {loss_t.item()}") -# Main function -def main(): - # Load dataset - dataset = load_dataset("/root/prime/datasets/fineweb-edu", split="train", streaming=True) - data_loader = DataLoader(dataset, batch_size=8, shuffle=False) - - # Load model and tokenizer - model_name = "llama-debug" # Replace with actual Llama model if available - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") - tokenizer.pad_token = tokenizer.eos_token - config = AutoConfig.from_pretrained(model_name) - model = AutoModelForCausalLM.from_config(config) - model = model.to("cuda" if torch.cuda.is_available() else "cpu") - - # Define optimizer - optimizer = AdamW(model.parameters(), lr=1e-4) - - # Run ACCO algorithm - num_steps = 100 - acco_algorithm(model, tokenizer, data_loader, optimizer, num_steps) - # Entry point if __name__ == "__main__": From 9f551ed77b49dc63ea94a7f3b637d495237d71b5 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 11 Jan 2025 22:45:36 +0000 Subject: [PATCH 09/17] doesnt explode --- train_dpu.py | 70 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/train_dpu.py b/train_dpu.py index 18acf924..5d574adf 100644 --- a/train_dpu.py +++ b/train_dpu.py @@ -1,3 +1,4 @@ +import torch.distributed as dist from typing import List import torch from torch.utils.data import DataLoader @@ -32,6 +33,7 @@ def compute_loss(model: torch.nn.Module, inputs: List[str], tokenizer) -> torch. 
# Main function def main(): + batch_size = 8 # Load dataset dataset = load_dataset("/root/prime/datasets/fineweb-edu", split="train", streaming=True) data_loader = DataLoader(dataset, batch_size=8, shuffle=False) @@ -42,16 +44,17 @@ def main(): tokenizer.pad_token = tokenizer.eos_token config = AutoConfig.from_pretrained(model_name) model = AutoModelForCausalLM.from_config(config) - model = model.to("cuda" if torch.cuda.is_available() else "cpu") + theta_t = [p.detach().clone() for p in model.parameters() if p.requires_grad] + optimizer_copy = [p.detach().clone() for p in model.parameters() if p.requires_grad] + reduce_work = [] + model.to("cuda") # Define optimizer - optimizer = AdamW(model.parameters(), lr=1e-4) + optimizer = AdamW(optimizer_copy, lr=1e-4) # Run ACCO algorithm num_steps = 100 - model_params = [p for p in model.parameters() if p.requires_grad] - first_step = True for step, batch in enumerate(data_loader): if step >= num_steps: @@ -62,37 +65,50 @@ def main(): mid_point = len(batch_text) // 2 first_half, second_half = batch_text[:mid_point], batch_text[mid_point:] - # Stage 1: Compute gradients g_t and tilde_theta_t+1 - loss_t = compute_loss(model, first_half, tokenizer) - loss_t.backward() # Compute gradients for g_t - g_t = [p.grad.cpu() for p in model_params] - theta_t = [p.cpu() for p in model_params] - # TODO: Gather gradients from other workers + # Stage 1: Compute gradients g_tilde and theta + for p in model.parameters(): + p.grad = None + loss = compute_loss(model, first_half, tokenizer) + loss.backward() # Compute gradients for g_t + for work in reduce_work: + work.wait() if not first_step: + for opt_param, cpu_param, _g_t, _g_tilde in zip(optimizer_copy, theta_t, g_t, g_tilde): + opt_param.data = cpu_param.data + opt_param.grad = (_g_t + _g_tilde) / (batch_size * dist.get_world_size()) optimizer.step() - optimizer.zero_grad() + for param, cpu_param, opt_param in zip(model.parameters(), theta_t, optimizer_copy): + param.data.copy_(opt_param.data, non_blocking=True) + cpu_param.data.copy_(opt_param.data, non_blocking=True) first_step = False - # Stage 2: Compute g_tilde_t+1 and theta_t+1 - loss_tilde = compute_loss(model, second_half, tokenizer) - loss_tilde.backward() - - # Restore original parameters - for param, original_param in zip(model_params, theta_t): - param.data.copy_(original_param) - - # Incorporate other workers grads - for param, grad in zip(model_params, g_t): - param.grad += grad.cuda() # TODO: Offload optimizer - + g_tilde = [p.grad.cpu() for p in model.parameters() if p.requires_grad] + reduce_work = [dist.all_reduce(_g_tilde, op=dist.ReduceOp.SUM, async_op=True) for _g_tilde in g_tilde] + + # Stage 2: Compute g_t and theta_tilde + for p in model.parameters(): + p.grad = None + loss = compute_loss(model, second_half, tokenizer) + loss.backward() + g_t = [p.grad.cpu() for p in model.parameters() if p.requires_grad] + for work in reduce_work: + work.wait() + reduce_work = [dist.all_reduce(_g_t, op=dist.ReduceOp.SUM, async_op=True) for _g_t in g_t] + + # theta_tilde + for param, _g_tilde in zip(optimizer_copy, g_tilde): + ## TODO: Weight by seen by batches + param.grad = _g_tilde / (batch_size // 2 * dist.get_world_size()) optimizer.step() - optimizer.zero_grad() - - print(f"Step {step + 1}/{num_steps}: Loss = {loss_t.item()}") - + for param, param_tilde in zip(model.parameters(), optimizer_copy): + param.data.copy_(param_tilde, non_blocking=True) + print(f"Step {step + 1}/{num_steps}: Loss = {loss.item()}") # Entry point if __name__ == "__main__": + 
dist.init_process_group(backend="cpu:gloo,cuda:nccl") + torch.cuda.set_device(dist.get_rank()) main() + dist.destroy_process_group() From 0fb66373f8b7ebea107976d91982569aff4e30ba Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Mon, 13 Jan 2025 01:25:27 +0000 Subject: [PATCH 10/17] with cpu offload --- train_dpu.py | 42 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/train_dpu.py b/train_dpu.py index 5d574adf..ef6e2953 100644 --- a/train_dpu.py +++ b/train_dpu.py @@ -5,6 +5,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig from datasets import load_dataset from torch.optim import AdamW, Optimizer +import wandb +import psutil +from tqdm import tqdm # Loss function @@ -31,11 +34,37 @@ def compute_loss(model: torch.nn.Module, inputs: List[str], tokenizer) -> torch. outputs = model(input_ids, labels=labels) return outputs.loss +def print_memory_usage(): + # Get CPU memory usage + memory_info = psutil.virtual_memory() + cpu_memory_used = memory_info.used / (1024 ** 2) + cpu_memory_total = memory_info.total / (1024 ** 2) + + print(f"CPU Memory Usage:") + print(f"Used: {cpu_memory_used:.2f} MB") + print(f"Total: {cpu_memory_total:.2f} MB") + print(f"Percentage: {memory_info.percent}%\n") + + # Check if CUDA is available + if torch.cuda.is_available(): + # Get current device + device = torch.device('cuda') + gpu_memory_used = torch.cuda.memory_allocated(device=device) + gpu_memory_reserved = torch.cuda.memory_reserved(device=device) + gpu_memory_total = torch.cuda.get_device_properties(device).total_memory + + print(f"GPU Memory Usage (Device: {torch.cuda.get_device_name(device)}):") + print(f"Allocated: {gpu_memory_used / (1024 ** 2):.2f} MB") + print(f"Reserved: {gpu_memory_reserved / (1024 ** 2):.2f} MB") + print(f"Total: {gpu_memory_total / (1024 ** 2):.2f} MB\n") + else: + print("CUDA is not available.") + # Main function def main(): batch_size = 8 # Load dataset - dataset = load_dataset("/root/prime/datasets/fineweb-edu", split="train", streaming=True) + dataset = load_dataset("/root/prime/prime/datasets/fineweb-edu", split="train", streaming=True) data_loader = DataLoader(dataset, batch_size=8, shuffle=False) # Load model and tokenizer @@ -44,6 +73,8 @@ def main(): tokenizer.pad_token = tokenizer.eos_token config = AutoConfig.from_pretrained(model_name) model = AutoModelForCausalLM.from_config(config) + print(f"Model params: {sum(p.numel() for p in model.parameters()):,}, Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}") + print_memory_usage() theta_t = [p.detach().clone() for p in model.parameters() if p.requires_grad] optimizer_copy = [p.detach().clone() for p in model.parameters() if p.requires_grad] reduce_work = [] @@ -56,7 +87,9 @@ def main(): num_steps = 100 first_step = True - for step, batch in enumerate(data_loader): + print("Post Init") + print_memory_usage() + for step, batch in tqdm(enumerate(data_loader), total=num_steps): if step >= num_steps: break @@ -105,10 +138,15 @@ def main(): param.data.copy_(param_tilde, non_blocking=True) print(f"Step {step + 1}/{num_steps}: Loss = {loss.item()}") + wandb.log({"loss": loss.item()}) + print(f"End of step {step}") + print_memory_usage() # Entry point if __name__ == "__main__": dist.init_process_group(backend="cpu:gloo,cuda:nccl") torch.cuda.set_device(dist.get_rank()) + wandb.init() main() + wandb.finish() dist.destroy_process_group() From 029673e98cfa97f212ffebd56a79ae6f036d2065 Mon Sep 17 00:00:00 2001 From: 
Jackmin801 Date: Mon, 13 Jan 2025 02:02:36 +0000 Subject: [PATCH 11/17] revert to main train.py --- src/zeroband/train.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 27ad7825..87f4635d 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -15,11 +15,10 @@ from zeroband import utils from zeroband.diloco import Diloco from zeroband.comms import ElasticDeviceMesh -from zeroband.global_ddp import GlobalDDP from zeroband.loss import cross_entropy_max_z_loss from zeroband.models.llama.model import create_block_mask_from_seqlens -from zeroband.config import Config # , MemoryProfilerConfig +from zeroband.config import Config #, MemoryProfilerConfig from zeroband.optimizers import get_optimizer from zeroband.utils import ( @@ -40,7 +39,6 @@ from zeroband.checkpoint import CkptManager, TrainingProgress from zeroband.lr_scheduler import get_scheduler - def log_hash_training_state( config: Config, model: torch.nn.Module, @@ -139,7 +137,7 @@ def train(config: Config): apply_ac_ckpt(model, num) elastic_device_mesh = ElasticDeviceMesh( - enable=config.diloco is not None or config.global_ddp, live_recovery_rank_src=config.ckpt.live_recovery_rank_src + enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src ) mp_policy = MixedPrecisionPolicy( @@ -170,9 +168,6 @@ def train(config: Config): diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None - if config.global_ddp: - global_ddp = GlobalDDP(model=model, config=config.global_ddp, elastic_device_mesh=elastic_device_mesh) - scheduler = get_scheduler( sched_type=config.optim.sched_type, optimizer=inner_optimizer, @@ -236,7 +231,6 @@ def train(config: Config): logger.info("starting training") need_live_recovery = config.ckpt.live_recovery_rank_src is not None - first_step = True while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs @@ -338,18 +332,9 @@ def train(config: Config): dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - - if config.global_ddp: - global_ddp.all_reduce() - - if config.global_ddp is not None and config.global_ddp.dpu and first_step: - inner_optimizer.zero_grad() - first_step = False - ## if we are at the beginning of the dpu bubble we need to skip the first step - else: - inner_optimizer.step() - scheduler.step() - inner_optimizer.zero_grad() + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() # logging training_progress.step += 1 From 7ff29a140e6b53d8a6572991a2ef640e23104811 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Mon, 13 Jan 2025 18:17:33 +0000 Subject: [PATCH 12/17] really ugly implementation that doesnt explode --- src/zeroband/config.py | 4 +- src/zeroband/dpu.py | 5 +++ src/zeroband/train.py | 97 +++++++++++++++++++++++++++++++++++++++--- 3 files changed, 99 insertions(+), 7 deletions(-) create mode 100644 src/zeroband/dpu.py diff --git a/src/zeroband/config.py b/src/zeroband/config.py index 61ac6e56..a82a043a 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -6,9 +6,9 @@ from zeroband.checkpoint import CkptConfig from zeroband.data import DataConfig from zeroband.diloco import DilocoConfig -from zeroband.global_ddp import GlobalDDPConfig from zeroband.models.llama.model import AttnFnType from zeroband.optimizers import OptimizersConfig, AdamConfig +from 
zeroband.dpu import ACCOConfig class OptimConfig(BaseConfig): @@ -69,7 +69,7 @@ class Config(BaseConfig): # sub config diloco: DilocoConfig | None = None - global_ddp: GlobalDDPConfig | None = None + acco: ACCOConfig | None = None data: DataConfig = DataConfig() optim: OptimConfig = OptimConfig() train: TrainConfig diff --git a/src/zeroband/dpu.py b/src/zeroband/dpu.py new file mode 100644 index 00000000..d6c5161a --- /dev/null +++ b/src/zeroband/dpu.py @@ -0,0 +1,5 @@ +from typing import Optional +from pydantic_config import BaseConfig + +class ACCOConfig(BaseConfig): + theta_t_device: Optional[str] = None diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 87f4635d..2e535e61 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -83,6 +83,9 @@ def train(config: Config): assert batch_size % config.train.micro_bs == 0 gradient_accumulation_steps = batch_size // config.train.micro_bs + if config.acco: + assert gradient_accumulation_steps % 2 == 0, "ACCO requires gradient accumulation steps to be even" + gradient_accumulation_steps //= 2 if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None: assert ( @@ -137,7 +140,7 @@ def train(config: Config): apply_ac_ckpt(model, num) elastic_device_mesh = ElasticDeviceMesh( - enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src + enable=config.diloco is not None or config.acco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src ) mp_policy = MixedPrecisionPolicy( @@ -165,6 +168,12 @@ def train(config: Config): # Setup optimizers inner_optimizer = get_optimizer(model.parameters(), config.optim.optim) + if config.acco is not None: + first_step = True + reduce_work = [] + theta_t = [p.detach().clone() for p in model.parameters() if p.requires_grad] + if config.acco.theta_t_device is not None: + theta_t = [p.to(config.acco.theta_t_device) for p in theta_t] diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None @@ -331,10 +340,88 @@ def train(config: Config): if config.optim.z_loss: dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - inner_optimizer.step() - scheduler.step() - inner_optimizer.zero_grad() + if config.acco is not None: + # TODO: This is wrong, we overwrite g_tilde before we use it in the update + g_tilde = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] + for work in reduce_work: + work.wait() + #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_tilde], op=dist.ReduceOp.SUM) for _g_tilde in g_tilde] + #reduce_work = [dist.all_reduce(_g_tilde, dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_tilde in g_tilde] + #a = torch.randn(10, device="cpu") + #work = dist.all_reduce(a, op=dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) + #work.wait() + + if not first_step: + # Copy in theta_t and consume g_t + for opt_param, cpu_param, _g_t, _g_tilde in zip(model.parameters(), theta_t, g_t, g_tilde): + opt_param.data.copy_(cpu_param.data, non_blocking=True) + opt_param.grad.copy_(_g_t + _g_tilde, non_blocking=True) + opt_param.grad /= batch_size * elastic_device_mesh.global_pg.size() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() + # Update theta_t + for param, cpu_param in zip(model.parameters(), theta_t): + 
cpu_param.data.copy_(param.data, non_blocking=True) + first_step = False + + for _g_tilde in g_tilde: + work = dist.all_reduce(_g_tilde.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) + work.wait() + + # Stage 2: Compute g_t and theta_tilde + for grad_acc_step in range(gradient_accumulation_steps): + is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 + # no sync if we are accumulating gradients + model.set_requires_gradient_sync(not is_accumulating) + + batch = next(train_dataloader_iterator) + input_ids = batch["input_ids"].to("cuda") + labels = batch["labels"].to("cuda") + if config.train.sequence_packing: + seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]] + block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None + else: + block_mask = None + + logits = model(tokens=input_ids, block_mask=block_mask).contiguous() + flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") + flatten_labels = rearrange(labels, "b seq -> (b seq)") + + if config.optim.z_loss: + ce_loss, z_loss = cross_entropy_max_z_loss( + flatten_logits, flatten_labels, config.optim.z_loss_weight + ) + ce_loss /= gradient_accumulation_steps + z_loss /= gradient_accumulation_steps + + del logits + loss = ce_loss + z_loss + loss.backward() + + else: + loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps + del logits + loss.backward() + + if config.optim.z_loss: + loss_batch += ce_loss.clone().detach() + z_loss_batch += z_loss.clone().detach() + else: + loss_batch += loss.clone().detach() + + g_t = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] + for work in reduce_work: + work.wait() + #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_t], op=dist.ReduceOp.SUM) for _g_t in g_t] + reduce_work = [dist.all_reduce(_g_t.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_t in g_t] + + for opt_param, _g_tilde in zip(model.parameters(), g_tilde): + opt_param.grad.copy_(_g_tilde, non_blocking=True) + opt_param.grad /= (batch_size // 2 * elastic_device_mesh.global_pg.size()) + inner_optimizer.step() + inner_optimizer.zero_grad() # logging training_progress.step += 1 From e4323b48d8803e28350739626a08fbc5811f4145 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Mon, 13 Jan 2025 18:35:02 +0000 Subject: [PATCH 13/17] save g_tilde before update --- src/zeroband/train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 2e535e61..058f2966 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -341,8 +341,7 @@ def train(config: Config): dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) if config.acco is not None: - # TODO: This is wrong, we overwrite g_tilde before we use it in the update - g_tilde = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] + new_g_tilde = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] for work in reduce_work: work.wait() #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_tilde], op=dist.ReduceOp.SUM) for _g_tilde in g_tilde] @@ -366,9 +365,9 @@ def train(config: Config): cpu_param.data.copy_(param.data, non_blocking=True) first_step = False + g_tilde = new_g_tilde for _g_tilde in g_tilde: work = dist.all_reduce(_g_tilde.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) - work.wait() # Stage 2: Compute g_t 
and theta_tilde for grad_acc_step in range(gradient_accumulation_steps): From 8b6874ab44acba190d4ba72696d42eabb7f55cb6 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Mon, 13 Jan 2025 18:51:03 +0000 Subject: [PATCH 14/17] refactor out loss computation --- src/zeroband/train.py | 193 ++++++++++++++++++++---------------------- 1 file changed, 94 insertions(+), 99 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 058f2966..6796f16d 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -6,6 +6,7 @@ from pydantic_config import parse_argv from einops import rearrange from torch.nn import functional as F +from torch import nn from transformers import AutoTokenizer @@ -18,7 +19,7 @@ from zeroband.loss import cross_entropy_max_z_loss from zeroband.models.llama.model import create_block_mask_from_seqlens -from zeroband.config import Config #, MemoryProfilerConfig +from zeroband.config import Config # , MemoryProfilerConfig from zeroband.optimizers import get_optimizer from zeroband.utils import ( @@ -39,6 +40,7 @@ from zeroband.checkpoint import CkptManager, TrainingProgress from zeroband.lr_scheduler import get_scheduler + def log_hash_training_state( config: Config, model: torch.nn.Module, @@ -76,6 +78,61 @@ def log_hash_training_state( metric_logger.log(metrics) +def compute_loss( + model: nn.Module, + gradient_accumulation_steps: int, + train_dataloader_iterator: iter, + local_pg: dist.ProcessGroup, + enable_z_loss: bool = False, +): + loss_batch = 0 + z_loss_batch = 0 + + for grad_acc_step in range(gradient_accumulation_steps): + is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 + # no sync if we are accumulating gradients + model.set_requires_gradient_sync(not is_accumulating) + + batch = next(train_dataloader_iterator) + input_ids = batch["input_ids"].to("cuda") + labels = batch["labels"].to("cuda") + if config.train.sequence_packing: + seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]] + block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None + else: + block_mask = None + + logits = model(tokens=input_ids, block_mask=block_mask).contiguous() + flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") + flatten_labels = rearrange(labels, "b seq -> (b seq)") + + if enable_z_loss: + ce_loss, z_loss = cross_entropy_max_z_loss(flatten_logits, flatten_labels, config.optim.z_loss_weight) + ce_loss /= gradient_accumulation_steps + z_loss /= gradient_accumulation_steps + + del logits + loss = ce_loss + z_loss + loss.backward() + + else: + loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps + del logits + loss.backward() + + if config.optim.z_loss: + loss_batch += ce_loss.clone().detach() + z_loss_batch += z_loss.clone().detach() + else: + loss_batch += loss.clone().detach() + + dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=local_pg) + if config.optim.z_loss: + dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=local_pg) + + return loss_batch, z_loss_batch + + def train(config: Config): # batch_size is the total batch size for all GPUs assert config.optim.batch_size % world_info.local_world_size == 0 @@ -140,7 +197,8 @@ def train(config: Config): apply_ac_ckpt(model, num) elastic_device_mesh = ElasticDeviceMesh( - enable=config.diloco is not None or config.acco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src + enable=config.diloco is not None or config.acco is not None, + 
live_recovery_rank_src=config.ckpt.live_recovery_rank_src, ) mp_policy = MixedPrecisionPolicy( @@ -293,66 +351,27 @@ def train(config: Config): monitor.set_stage("inner_loop") for inner_step in range(num_inner_steps): - loss_batch = 0 - z_loss_batch = 0 - - for grad_acc_step in range(gradient_accumulation_steps): - is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 - # no sync if we are accumulating gradients - model.set_requires_gradient_sync(not is_accumulating) - - batch = next(train_dataloader_iterator) - input_ids = batch["input_ids"].to("cuda") - labels = batch["labels"].to("cuda") - if config.train.sequence_packing: - seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]] - block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None - else: - block_mask = None - - logits = model(tokens=input_ids, block_mask=block_mask).contiguous() - flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") - flatten_labels = rearrange(labels, "b seq -> (b seq)") - - if config.optim.z_loss: - ce_loss, z_loss = cross_entropy_max_z_loss( - flatten_logits, flatten_labels, config.optim.z_loss_weight - ) - ce_loss /= gradient_accumulation_steps - z_loss /= gradient_accumulation_steps - - del logits - loss = ce_loss + z_loss - loss.backward() - - else: - loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps - del logits - loss.backward() - - if config.optim.z_loss: - loss_batch += ce_loss.clone().detach() - z_loss_batch += z_loss.clone().detach() - else: - loss_batch += loss.clone().detach() - - dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) - if config.optim.z_loss: - dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) + loss_batch, z_loss_batch = compute_loss( + model, + gradient_accumulation_steps, + train_dataloader_iterator, + elastic_device_mesh.local_pg, + config.optim.z_loss, + ) if config.acco is not None: new_g_tilde = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] for work in reduce_work: work.wait() - #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_tilde], op=dist.ReduceOp.SUM) for _g_tilde in g_tilde] - #reduce_work = [dist.all_reduce(_g_tilde, dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_tilde in g_tilde] - #a = torch.randn(10, device="cpu") - #work = dist.all_reduce(a, op=dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) - #work.wait() + # reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_tilde], op=dist.ReduceOp.SUM) for _g_tilde in g_tilde] + # reduce_work = [dist.all_reduce(_g_tilde, dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_tilde in g_tilde] + # a = torch.randn(10, device="cpu") + # work = dist.all_reduce(a, op=dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) + # work.wait() if not first_step: # Copy in theta_t and consume g_t - for opt_param, cpu_param, _g_t, _g_tilde in zip(model.parameters(), theta_t, g_t, g_tilde): + for opt_param, cpu_param, _g_t, _g_tilde in zip(model.parameters(), theta_t, g_t, g_tilde): # noqa opt_param.data.copy_(cpu_param.data, non_blocking=True) opt_param.grad.copy_(_g_t + _g_tilde, non_blocking=True) opt_param.grad /= batch_size * elastic_device_mesh.global_pg.size() @@ -367,58 +386,34 @@ def train(config: Config): g_tilde = new_g_tilde for _g_tilde in g_tilde: - work = dist.all_reduce(_g_tilde.to_local(), dist.ReduceOp.SUM, 
group=elastic_device_mesh.global_pg, async_op=True) - - # Stage 2: Compute g_t and theta_tilde - for grad_acc_step in range(gradient_accumulation_steps): - is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 - # no sync if we are accumulating gradients - model.set_requires_gradient_sync(not is_accumulating) - - batch = next(train_dataloader_iterator) - input_ids = batch["input_ids"].to("cuda") - labels = batch["labels"].to("cuda") - if config.train.sequence_packing: - seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]] - block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None - else: - block_mask = None - - logits = model(tokens=input_ids, block_mask=block_mask).contiguous() - flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") - flatten_labels = rearrange(labels, "b seq -> (b seq)") - - if config.optim.z_loss: - ce_loss, z_loss = cross_entropy_max_z_loss( - flatten_logits, flatten_labels, config.optim.z_loss_weight - ) - ce_loss /= gradient_accumulation_steps - z_loss /= gradient_accumulation_steps - - del logits - loss = ce_loss + z_loss - loss.backward() - - else: - loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps - del logits - loss.backward() - - if config.optim.z_loss: - loss_batch += ce_loss.clone().detach() - z_loss_batch += z_loss.clone().detach() - else: - loss_batch += loss.clone().detach() + work = dist.all_reduce( + _g_tilde.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True + ) + + loss_batch_1, z_loss_batch_1 = compute_loss( + model, + gradient_accumulation_steps, + train_dataloader_iterator, + elastic_device_mesh.local_pg, + config.optim.z_loss, + ) + loss_batch += loss_batch_1 + z_loss_batch += z_loss_batch_1 g_t = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] for work in reduce_work: work.wait() - #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_t], op=dist.ReduceOp.SUM) for _g_t in g_t] - reduce_work = [dist.all_reduce(_g_t.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_t in g_t] + # reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_t], op=dist.ReduceOp.SUM) for _g_t in g_t] + reduce_work = [ + dist.all_reduce( + _g_t.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True + ) + for _g_t in g_t + ] for opt_param, _g_tilde in zip(model.parameters(), g_tilde): opt_param.grad.copy_(_g_tilde, non_blocking=True) - opt_param.grad /= (batch_size // 2 * elastic_device_mesh.global_pg.size()) + opt_param.grad /= batch_size // 2 * elastic_device_mesh.global_pg.size() inner_optimizer.step() inner_optimizer.zero_grad() From fd10802c0e2ea2e2897032ea1f9fe70a4818f17c Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Mon, 13 Jan 2025 19:30:35 +0000 Subject: [PATCH 15/17] fix: ddp optimizer step --- src/zeroband/train.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 6796f16d..2c91a9dc 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -83,6 +83,7 @@ def compute_loss( gradient_accumulation_steps: int, train_dataloader_iterator: iter, local_pg: dist.ProcessGroup, + loss_scaling: float = 1.0, enable_z_loss: bool = False, ): loss_batch = 0 @@ -108,15 +109,15 @@ def compute_loss( if enable_z_loss: ce_loss, z_loss = cross_entropy_max_z_loss(flatten_logits, flatten_labels, config.optim.z_loss_weight) - ce_loss /= 
gradient_accumulation_steps - z_loss /= gradient_accumulation_steps + ce_loss /= gradient_accumulation_steps * loss_scaling + z_loss /= gradient_accumulation_steps * loss_scaling del logits loss = ce_loss + z_loss loss.backward() else: - loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps + loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps * loss_scaling del logits loss.backward() @@ -356,18 +357,15 @@ def train(config: Config): gradient_accumulation_steps, train_dataloader_iterator, elastic_device_mesh.local_pg, + 0.5 if config.acco is not None else 1.0, config.optim.z_loss, ) + print(loss_batch, z_loss_batch) if config.acco is not None: new_g_tilde = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] for work in reduce_work: work.wait() - # reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_tilde], op=dist.ReduceOp.SUM) for _g_tilde in g_tilde] - # reduce_work = [dist.all_reduce(_g_tilde, dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_tilde in g_tilde] - # a = torch.randn(10, device="cpu") - # work = dist.all_reduce(a, op=dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) - # work.wait() if not first_step: # Copy in theta_t and consume g_t @@ -385,21 +383,25 @@ def train(config: Config): first_step = False g_tilde = new_g_tilde - for _g_tilde in g_tilde: - work = dist.all_reduce( + reduce_work = [ + dist.all_reduce( _g_tilde.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True ) + for _g_tilde in g_tilde + ] loss_batch_1, z_loss_batch_1 = compute_loss( model, gradient_accumulation_steps, train_dataloader_iterator, elastic_device_mesh.local_pg, + 0.5 if config.acco is not None else 1.0, config.optim.z_loss, ) loss_batch += loss_batch_1 z_loss_batch += z_loss_batch_1 + print(loss_batch, z_loss_batch) g_t = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] for work in reduce_work: work.wait() @@ -417,6 +419,11 @@ def train(config: Config): inner_optimizer.step() inner_optimizer.zero_grad() + else: + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() + # logging training_progress.step += 1 inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0] From 3968d567ebaf82e6e9c82e73f64fb44f151340a0 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Tue, 14 Jan 2025 01:22:39 +0000 Subject: [PATCH 16/17] testing config --- configs/1B/H100.toml | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/configs/1B/H100.toml b/configs/1B/H100.toml index de9cef75..41a3db65 100644 --- a/configs/1B/H100.toml +++ b/configs/1B/H100.toml @@ -1,15 +1,22 @@ name_model = "1B" -project = "debug_1B_zero_band" +project = "adam_sweep" type_model = "llama2" [train] -micro_bs = 32 +micro_bs = 2 reshard_after_forward = true [optim] -batch_size = 1024 +batch_size = 128 warmup_steps = 1000 total_steps = 8192 -[optim.optim] -lr = 7e-4 +optim.lr = 4e-4 + +[data] +seq_length = 8192 +num_workers = 2 +dataset_name_or_paths = "/home/ubuntu/prime/datasets/fineweb-edu" +split_by_data_rank = true + +[acco] From 429ac7683c692ee556040986f1d6391490209177 Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Tue, 14 Jan 2025 09:05:21 +0000 Subject: [PATCH 17/17] lalala --- configs/150M/H100.toml | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/configs/150M/H100.toml b/configs/150M/H100.toml index a6339181..94ba83cf 100644 --- a/configs/150M/H100.toml +++ 
b/configs/150M/H100.toml @@ -1,16 +1,21 @@ name_model = "150M" -project = "debug_150m_zero_band" +project = "adam_sweep" type_model = "llama2" [train] -micro_bs = 64 # change this base on the gpu +micro_bs = 4 # change this base on the gpu reshard_after_forward = true [optim] -batch_size = 512 +batch_size = 128 warmup_steps = 1000 -total_steps = 88_000 +total_steps = 8192 +optim.lr = 4e-4 -[optim.optim] -lr = 4e-4 +[data] +seq_length = 8192 +num_workers = 2 +dataset_name_or_paths = "/home/ubuntu/prime/datasets/fineweb-edu" +split_by_data_rank = true +[acco]
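For readers following patches 12–15, the two-stage update they converge on is easier to see in isolation than inside the train loop diffs. Below is a minimal, self-contained sketch of that control flow: stage 1 produces g_tilde, the previous step's g_t all-reduce is drained, the "real" update is applied from theta_t, g_tilde is only then overwritten (the ordering fixed in "save g_tilde before update"), its all-reduce overlaps with the stage-2 backward, and a speculative update with the reduced g_tilde closes the step. Everything here is illustrative, not the repo's API: the function name, the `state` dict, the toy model and loss, and the simple 1/world_size scaling are assumptions, and the FSDP DTensor `.to_local()` handling, CPU offload, z-loss, gradient accumulation via `compute_loss`, the scheduler, and the batch-size-based normalization from the patches are all omitted.

```python
# acco_sketch.py -- illustrative reconstruction of the two-stage update in
# patches 12-15, reduced to a toy model so the overlap pattern is visible.
import torch
import torch.distributed as dist
import torch.nn as nn


def acco_inner_step(model, optimizer, half_batches, world_size, first_step, state):
    # Stage 1: gradient of the first half-batch at the current (speculative)
    # parameters; this becomes the new g_tilde.
    loss = model(half_batches[0]).pow(2).mean()
    loss.backward()
    new_g_tilde = [p.grad.detach().clone() for p in model.parameters()]

    # Drain the async all-reduce of g_t launched at the end of the previous step.
    for work in state["reduce_work"]:
        work.wait()

    if not first_step:
        # Restore theta_t and apply the fully reduced (g_t + g_tilde) from the
        # previous step: the "real" optimizer update.
        for p, theta, g_t, g_tilde in zip(
            model.parameters(), state["theta_t"], state["g_t"], state["g_tilde"]
        ):
            p.data.copy_(theta)
            p.grad.copy_((g_t + g_tilde) / (2 * world_size))  # simplified scaling
        optimizer.step()
        optimizer.zero_grad(set_to_none=False)

    # Snapshot theta_t, and only now overwrite g_tilde -- the ordering fixed
    # by "save g_tilde before update".
    state["theta_t"] = [p.data.detach().clone() for p in model.parameters()]
    state["g_tilde"] = new_g_tilde
    reduce_work = [
        dist.all_reduce(g, dist.ReduceOp.SUM, async_op=True) for g in state["g_tilde"]
    ]

    # Stage 2: the second half-batch overlaps with the g_tilde all-reduce.
    loss = model(half_batches[1]).pow(2).mean()
    loss.backward()
    state["g_t"] = [p.grad.detach().clone() for p in model.parameters()]

    for work in reduce_work:
        work.wait()
    state["reduce_work"] = [
        dist.all_reduce(g, dist.ReduceOp.SUM, async_op=True) for g in state["g_t"]
    ]

    # Speculative update with the reduced g_tilde; the next step rolls the
    # parameters back to theta_t before applying the corrected gradient.
    for p, g_tilde in zip(model.parameters(), state["g_tilde"]):
        p.grad.copy_(g_tilde / world_size)  # simplified scaling
    optimizer.step()
    optimizer.zero_grad(set_to_none=False)


if __name__ == "__main__":
    dist.init_process_group("gloo")
    torch.manual_seed(0)
    model = nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    state = {"reduce_work": [], "theta_t": None, "g_t": None, "g_tilde": None}
    for step in range(4):
        half_batches = [torch.randn(8, 4), torch.randn(8, 4)]
        acco_inner_step(
            model, optimizer, half_batches, dist.get_world_size(), step == 0, state
        )
    dist.destroy_process_group()
```

The sketch can be exercised with something like `torchrun --nproc_per_node=2 acco_sketch.py` (the file name is illustrative); with the gloo backend the async all-reduce of g_tilde runs while the stage-2 backward executes, which is the whole point of splitting the batch in two and is why patch 15 scales each `compute_loss` call by 0.5, so the logged loss of the two half-batches remains comparable to a single full batch.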