dpu #178

Open · wants to merge 7 commits into main

35 changes: 35 additions & 0 deletions llama-debug/config.json
@@ -0,0 +1,35 @@
{
"architectures": [
"LlamaForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 128000,
"eos_token_id": 128001,
"head_dim": 64,
"hidden_act": "silu",
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"max_position_embeddings": 1024,
"mlp_bias": false,
"model_type": "llama",
"num_attention_heads": 16,
"num_hidden_layers": 5,
"num_key_value_heads": 8,
"pretraining_tp": 1,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"factor": 32.0,
"high_freq_factor": 4.0,
"low_freq_factor": 1.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"rope_theta": 500000.0,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.45.0.dev0",
"use_cache": true,
"vocab_size": 128256
}
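This config describes a deliberately tiny Llama variant (5 layers, hidden size 1024, 16 attention heads with 8 KV heads) intended for debugging. As a rough sketch, and assuming the Hugging Face transformers library (not part of this PR), it can be instantiated into a randomly initialised model for local testing:

```
# hedged sketch, not part of the PR: build the debug model from this config
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("llama-debug")  # reads llama-debug/config.json
model = AutoModelForCausalLM.from_config(config)    # tiny 5-layer debug Llama
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```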
5 changes: 3 additions & 2 deletions src/zeroband/collectives.py
@@ -12,7 +12,8 @@ def gloo_all_reduce(
tensor: torch.Tensor,
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[dist.ProcessGroup] = None,
) -> None:
async_op: bool = False,
) -> None | dist.Work:
"""Wrap gloo all reduce"""
if group is None:
group = dist.distributed_c10d._get_default_group()
@@ -24,7 +25,7 @@ def gloo_all_reduce(
# todo check numerical stability of doing post or pre div
tensor.div_(group.size())

dist.all_reduce(tensor, op, group=group)
return dist.all_reduce(tensor, op, group=group, async_op=async_op)


class Compression(Enum):
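The new async_op flag is forwarded to dist.all_reduce, which returns a dist.Work handle when async_op=True and None otherwise. A minimal sketch of how a caller can overlap communication with compute using the returned handle (grad and global_pg are illustrative names, not part of this diff):

```
# hedged sketch: launch the reduction, do other work, then wait on the handle
work = gloo_all_reduce(grad, dist.ReduceOp.SUM, group=global_pg, async_op=True)
# ... other computation while the all-reduce is in flight ...
work.wait()  # grad now holds the reduced value
```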
3 changes: 2 additions & 1 deletion src/zeroband/config.py
@@ -6,6 +6,7 @@
from zeroband.checkpoint import CkptConfig
from zeroband.data import DataConfig
from zeroband.diloco import DilocoConfig
from zeroband.global_ddp import GlobalDDPConfig
from zeroband.models.llama.model import AttnFnType
from zeroband.optimizers import OptimizersConfig, AdamConfig

@@ -68,6 +69,7 @@ class Config(BaseConfig):

# sub config
diloco: DilocoConfig | None = None
global_ddp: GlobalDDPConfig | None = None
data: DataConfig = DataConfig()
optim: OptimConfig = OptimConfig()
train: TrainConfig
@@ -88,4 +90,3 @@ def validate_live_recovery_rank_src(self):
if self.ckpt is not None and self.ckpt.live_recovery_rank_src is not None and self.diloco is None:
raise ValueError("live_recovery_rank_src is only supported with diloco")
return self

161 changes: 161 additions & 0 deletions src/zeroband/global_ddp.py
@@ -0,0 +1,161 @@
import time
from typing import Generator, NamedTuple
from pydantic import model_validator
from pydantic_config import BaseConfig
import torch
import torch.nn as nn
from zeroband.comms import ElasticDeviceMesh
import torch.distributed as dist
from zeroband.collectives import Compression, gloo_all_reduce
from torch.distributed._tensor.api import DTensor
from zeroband.utils.logging import get_logger
from zeroband.utils.world_info import get_world_info

from torch.distributed import Work

logger = get_logger(__name__)


class GlobalDDPConfig(BaseConfig):
# retry_all_reduce: int = 3
compression: Compression = Compression.NO
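# dpu: apply the all-reduced gradients with a one-step delay so communication overlaps the next step's compute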
dpu: bool = False
enable: bool = True

@model_validator(mode="after")
def validate_compression(self):
if self.compression != Compression.NO:
raise NotImplementedError("Compression is not implemented yet")
return self


def offload_grad_generator(model: nn.Module) -> Generator:
for param in model.parameters():
if param.grad is not None:
if isinstance(param.grad, DTensor):
yield param.grad.to_local().to("cpu")
else:
yield param.grad.to("cpu")


def apply_staling_grad(model: nn.Module, tensors: list[torch.Tensor]):
for param, tensor in zip(model.parameters(), tensors):
if isinstance(param.grad, DTensor):
param.grad.to_local().copy_(tensor)
else:
param.grad.copy_(tensor)


def maybe_unwrap_dtensor(tensor: torch.Tensor | DTensor):
if isinstance(tensor, DTensor):
return tensor.to_local()
else:
return tensor


class AllReduceGradWork(NamedTuple):
grad: torch.Tensor
work: Work


def async_all_reduce(model: nn.Module, elastic_device_mesh: ElasticDeviceMesh, flag: str) -> list[AllReduceGradWork]:
"""
Trigger an asynchronous all-reduce on the model's gradients.
Returns a list of AllReduceGradWork entries whose work handles can be waited on.
"""

elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False)
world_size = elastic_device_mesh.global_pg.size()

global_pg = elastic_device_mesh.global_pg
elastic_device_mesh.monitored_barrier(flag)
logger.debug("Beginning all reduce")

async_job = []

for param in offload_grad_generator(model):  # TODO: do we need to offload when doing a blocking all reduce?
grad = maybe_unwrap_dtensor(param)

grad.div_(world_size)

# all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg)  # calling gloo all reduce directly here because of the async op

async_job.append(AllReduceGradWork(grad, gloo_all_reduce(grad, dist.ReduceOp.SUM, global_pg, True)))

return async_job


class GlobalDDP:
"""
This class implements DDP over the internet: gradients are all-reduced across the global process group, optionally with a one-step delay (dpu) so that communication overlaps the next step's compute.

:Args:
model: The model to be trained
config: The configuration for the global DDP
elastic_device_mesh: The elastic device mesh to be used

Example usage:

```
config = GlobalDDPConfig(dpu=False)
global_ddp = GlobalDDP(model, config, elastic_device_mesh)

for step in range(num_steps):
for micro_bs in range(num_micro_bs):
loss = model(batch)
loss.backward()

global_ddp.all_reduce()
optimizer.step()
optimizer.zero_grad()
```

"""

flag: str = "global_ddp"

def __init__(
self,
model: nn.Module,
config: GlobalDDPConfig,
elastic_device_mesh: ElasticDeviceMesh,
):
self.elastic_device_mesh = elastic_device_mesh
self.config = config

self.world_info = get_world_info()
self._logger = get_logger()

self.model = model

self._stalling_grad_work: list[AllReduceGradWork] | None = None

def all_reduce(self):
if not self.config.dpu:
self._blocking_all_reduce(self.model)
else:
new_staling_grad_work = async_all_reduce(self.model, self.elastic_device_mesh, self.flag)

if self._stalling_grad_work is None:
# if it is the first step we just store the work for the next call to this function and return
self._stalling_grad_work = new_staling_grad_work
else:
# otherwise we wait for the current staling grad work to finish
start_time = time.time()
[all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._stalling_grad_work]
self._logger.debug(f"Time to wait for staling grads: {time.time() - start_time}")
# and apply the staling grads to the model
apply_staling_grad(
self.model, [all_reduce_grad_work.grad for all_reduce_grad_work in self._stalling_grad_work]
)
# and store the new staling grad work for the next call to this function
self._stalling_grad_work = new_staling_grad_work

def _blocking_all_reduce(self, model: nn.Module):
"""
Trigger an all-reduce on the model's gradients and block until it completes.
"""
[
all_reduce_grad_work.work.wait()
for all_reduce_grad_work in async_all_reduce(model, self.elastic_device_mesh, self.flag)
]
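When dpu is enabled, the gradients all-reduced at step N are only waited on, copied back into param.grad, and consumed by the optimizer at step N+1, so the optimizer always runs on gradients that are one step stale while the current step's communication proceeds in the background. A self-contained single-process sketch of that schedule (purely illustrative: the async all-reduce is replaced by a local clone, and all names are hypothetical):

```
# hedged sketch of the delayed-parameter-update (dpu) schedule implemented above
import torch

param = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.SGD([param], lr=0.1)
pending = None  # stands in for the in-flight AllReduceGradWork list

for step in range(3):
    param.grad = torch.ones_like(param) * (step + 1)  # pretend backward pass
    reduced = param.grad.clone()                       # stands in for async_all_reduce(...)
    if pending is None:
        opt.zero_grad()  # first step of the bubble: nothing reduced yet, skip the update
    else:
        param.grad.copy_(pending)  # apply the one-step-stale reduced gradient
        opt.step()
        opt.zero_grad()
    pending = reduced  # keep this step's "reduction" for the next iteration
```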
25 changes: 20 additions & 5 deletions src/zeroband/train.py
@@ -15,10 +15,11 @@
from zeroband import utils
from zeroband.diloco import Diloco
from zeroband.comms import ElasticDeviceMesh
from zeroband.global_ddp import GlobalDDP
from zeroband.loss import cross_entropy_max_z_loss

from zeroband.models.llama.model import create_block_mask_from_seqlens
from zeroband.config import Config #, MemoryProfilerConfig
from zeroband.config import Config # , MemoryProfilerConfig
from zeroband.optimizers import get_optimizer

from zeroband.utils import (
@@ -39,6 +40,7 @@
from zeroband.checkpoint import CkptManager, TrainingProgress
from zeroband.lr_scheduler import get_scheduler


def log_hash_training_state(
config: Config,
model: torch.nn.Module,
@@ -137,7 +139,7 @@ def train(config: Config):
apply_ac_ckpt(model, num)

elastic_device_mesh = ElasticDeviceMesh(
enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
enable=config.diloco is not None or config.global_ddp, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
)

mp_policy = MixedPrecisionPolicy(
@@ -168,6 +170,9 @@

diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None

if config.global_ddp:
global_ddp = GlobalDDP(model=model, config=config.global_ddp, elastic_device_mesh=elastic_device_mesh)

scheduler = get_scheduler(
sched_type=config.optim.sched_type,
optimizer=inner_optimizer,
@@ -231,6 +236,7 @@
logger.info("starting training")

need_live_recovery = config.ckpt.live_recovery_rank_src is not None
first_step = True
while True:
if num_inner_steps > 1:
# if we don't use diloco we don't print the outer step logs
@@ -332,9 +338,18 @@
dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
inner_optimizer.step()
scheduler.step()
inner_optimizer.zero_grad()

if config.global_ddp:
global_ddp.all_reduce()

if config.global_ddp is not None and config.global_ddp.dpu and first_step:
# at the start of the dpu bubble the reduced gradients are not available yet, so we skip the first optimizer step and only clear the local gradients
inner_optimizer.zero_grad()
first_step = False
else:
inner_optimizer.step()
scheduler.step()
inner_optimizer.zero_grad()

# logging
training_progress.step += 1
15 changes: 11 additions & 4 deletions tests/test_torchrun/test_train.py
@@ -36,7 +36,7 @@ def gpus_to_use(num_nodes, num_gpu, rank):
return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu)))


def _test_multi_gpu(num_gpus, config, extra_args=[], diloco=False):
def _test_multi_gpu(num_gpus, config, extra_args=[], multi_nodes=False):
num_nodes, num_gpu = num_gpus[0], num_gpus[1]

processes = []
@@ -55,7 +55,7 @@ def _test_multi_gpu(num_gpus, config, extra_args=[], diloco=False):

env = copy.deepcopy(os.environ)

if diloco:
if multi_nodes:
new_env = {
"GLOBAL_RANK": str(i),
"GLOBAL_UNIQUE_ID": str(i),
@@ -85,7 +85,7 @@ def test_multi_gpu(num_gpus):

@pytest.mark.parametrize("num_gpus", [[2, 1], [2, 2]] if num_gpu >= 4 else [[2, 1]])
def test_multi_gpu_diloco(num_gpus):
_test_multi_gpu(num_gpus, "debug/diloco.toml", diloco=True)
_test_multi_gpu(num_gpus, "debug/diloco.toml", multi_nodes=True)


def test_act_ckpt():
@@ -101,7 +101,7 @@ def test_act_ckpt_num():
@pytest.mark.parametrize("backend", [Compression.NO, Compression.UINT8])
def test_all_reduce_diloco(backend: Compression):
num_gpus = [2, 1]
_test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--diloco.compression", backend.value], diloco=True)
_test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--diloco.compression", backend.value], multi_nodes=True)


def test_z_loss():
@@ -116,6 +116,13 @@ def test_packing(packing: bool):
_test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg])


@pytest.mark.parametrize("dpu", [True, False])
def test_global_ddp(dpu: bool):
num_gpus = [2, 1]
dpu_arg = "--global_ddp.dpu" if dpu else "--no-global_ddp.dpu"
_test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[dpu_arg], multi_nodes=True)


@pytest.mark.parametrize("diloco", [False, True])
def test_soap(diloco: bool):
num_gpus = [1, 2] if diloco else [2, 1]