[Feat] pccl #189

Draft · wants to merge 6 commits into base: main
10 changes: 8 additions & 2 deletions configs/150M/3090.toml
@@ -7,7 +7,13 @@ micro_bs = 16 # change this base on the gpu
reshard_after_forward = true

[optim]
batch_size = 512
batch_size = 64
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4
lr = 4e-4

[data]
fake = true

[diloco]
inner_steps = 20
4 changes: 0 additions & 4 deletions configs/debug/diloco.toml
@@ -13,7 +13,3 @@ total_steps = 4

[data]
fake = true

[diloco]
inner_steps = 5

13 changes: 13 additions & 0 deletions master.py
@@ -0,0 +1,13 @@
from pccl import *

HOST: str = "0.0.0.0:48148"


def main():
print(f"Starting master node on {HOST}")
master: MasterNode = MasterNode(listen_address=HOST)
master.run()


if __name__ == "__main__":
main()
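To try the master standalone, a minimal invocation (assuming pccl is installed in the project environment, e.g. via the pyproject dependency added in this PR) would be:

```sh
# Start the PCCL master node; it listens on 0.0.0.0:48148 and coordinates
# the training replicas that point PCCL_MASTER_ADDR at this address.
uv run python master.py
```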
19 changes: 19 additions & 0 deletions meow.sh
@@ -0,0 +1,19 @@
export PCCL_LOG_LEVEL=DEBUG
export WANDB_MODE=disabled
export PCCL_MASTER_ADDR=127.0.0.1:48148

export CUDA_VISIBLE_DEVICES=0,1
#export GLOBAL_RANK=0
#export GLOBAL_UNIQUE_ID=A0
#export REPLICA_GROUP_ID=A0

#export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:10001 \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
#--ckpt.live_recovery_rank_src 0
18 changes: 18 additions & 0 deletions meow1.sh
@@ -0,0 +1,18 @@
export WANDB_MODE=disabled
export PCCL_MASTER_ADDR=127.0.0.1:48148

export CUDA_VISIBLE_DEVICES=2,3
#export GLOBAL_RANK=0
#export GLOBAL_UNIQUE_ID=A0
#export REPLICA_GROUP_ID=A0

#export GLOO_SOCKET_IFNAME=tailscale0
export ZERO_BAND_LOG_LEVEL=DEBUG
export ZERO_BAND_LOG_ALL_RANK=true

uv run torchrun --nproc_per_node=2 \
--rdzv-endpoint localhost:10002 \
src/zeroband/train.py \
@configs/150M/3090.toml \
--no-wandb-resume
#--ckpt.live_recovery_rank_src 0
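meow.sh and meow1.sh each launch a 2-GPU torchrun replica (GPUs 0-1 on rendezvous port 10001, GPUs 2-3 on port 10002) that connects to the PCCL master at 127.0.0.1:48148. A sketch of how the three pieces appear intended to run together locally (the master-first ordering is an assumption):

```sh
# Terminal 1: coordination service from master.py
uv run python master.py

# Terminal 2: first DiLoCo replica on GPUs 0 and 1
bash meow.sh

# Terminal 3: second DiLoCo replica on GPUs 2 and 3
bash meow1.sh
```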
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
"pyarrow",
"toposolve",
"psutil",
"pccl @ git+ssh://[email protected]/PrimeIntellect-ai/pccl.git@main#egg=pccl&subdirectory=python/framework",
]

[project.optional-dependencies]
28 changes: 14 additions & 14 deletions src/zeroband/diloco.py
@@ -1,15 +1,16 @@
import os
from torch.distributed.device_mesh import init_device_mesh
import re
import time
from pydantic_config import BaseConfig
import torch
from torch import nn
from zeroband.collectives import Compression, all_reduce
from zeroband.comms import ElasticDeviceMesh
from zeroband.collectives import Compression
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
import torch.distributed as dist
from torch.distributed._tensor.api import DTensor
from functools import lru_cache
from pccl import Attribute, Communicator, ReduceOp


class DilocoConfig(BaseConfig):
@@ -59,19 +60,21 @@ def __init__(
self,
config: DilocoConfig,
model: nn.Module,
elastic_device_mesh: ElasticDeviceMesh,
comm: Communicator,
):
self.config = config

if config.compression == Compression.UINT8:
from zeroband.C.collectives import ring_allreduce as _ # noqa: F401
# just force compilation

self.elastic_device_mesh = elastic_device_mesh
self.comm = comm

self._logger = get_logger()
self.world_info = get_world_info()

self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(int(os.environ["LOCAL_WORLD_SIZE"]),))

self._init_offloaded_optimizer(model=model)

@torch.no_grad()
@@ -89,14 +92,11 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str =
"""
_start_time = time.perf_counter()

self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False)
world_size_post_init = self.elastic_device_mesh.global_pg.size()

world_size = world_size_post_init
self.comm.update_topology()
world_size = self.comm.get_attribute(Attribute.CURRENT_WORLD_SIZE)

self._logger.debug("sync pseudo gradient %s with world size %d", " fake" if fake else "", world_size)

global_pg = self.elastic_device_mesh.global_pg
for i in range(self.config.retry_all_reduce):
for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
if fake:
@@ -114,7 +114,7 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str =
# all_reduce(self.config.compression, self.offloaded_grad_flat_tensor, dist.ReduceOp.SUM, global_pg)
for j, tensor_group in enumerate(self._offloaded_grad_grouped_tensor):
t0 = time.perf_counter()
all_reduce(self.config.compression, tensor_group, dist.ReduceOp.SUM, global_pg)
self.comm.all_reduce(tensor_group, tensor_group, ReduceOp.SUM)
self._logger.debug(
f"{j}/{len(self._offloaded_grad_grouped_tensor)} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {tensor_group.numel()}"
)
@@ -125,7 +125,7 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str =
break
except Exception as e:
self._logger.error(f"Error syncing pseudo gradient: {e}, retry {i+1}/{self.config.retry_all_reduce}")
global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True)
self.comm.update_topology()
else:
self._logger.error(
"Failed to sync pseudo gradient after %d retries. Resorting to calculating pseudo-gradient without reduce",
@@ -181,14 +181,14 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
offloaded_param = nn.Parameter(
DTensor.from_local(
data_tensor,
device_mesh=self.elastic_device_mesh.cpu_local_mesh,
device_mesh=self.cpu_local_mesh,
placements=param.data.placements,
)
)

offloaded_param.grad = DTensor.from_local(
grad_tensor,
device_mesh=self.elastic_device_mesh.cpu_local_mesh,
device_mesh=self.cpu_local_mesh,
placements=param.data.placements,
)
# here we pre-allocate the grad DTensor on cpu.
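The net effect of this file's changes is that the gloo-based `all_reduce` over `elastic_device_mesh.global_pg` is replaced by pccl's `Communicator`. A condensed, hypothetical sketch of that sync path, using only the pccl calls visible in this diff (`update_topology`, `all_reduce`, `get_attribute`) and an illustrative helper name `sync_buckets`:

```python
import torch
from pccl import Attribute, Communicator, ReduceOp


def sync_buckets(comm: Communicator, buckets: list[torch.Tensor], retries: int = 3) -> int:
    """Sum-reduce offloaded CPU gradient buckets across replicas via pccl."""
    comm.update_topology()  # refresh replica membership before reducing
    world_size = comm.get_attribute(Attribute.CURRENT_WORLD_SIZE)
    for attempt in range(retries):
        try:
            for bucket in buckets:
                # in-place SUM all-reduce, mirroring comm.all_reduce(tensor_group, tensor_group, ReduceOp.SUM)
                comm.all_reduce(bucket, bucket, ReduceOp.SUM)
            break
        except Exception:
            # a peer may have dropped mid-reduce; refresh the topology and retry
            comm.update_topology()
    return world_size  # the diff reads this value for logging and token accounting
```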
86 changes: 30 additions & 56 deletions src/zeroband/train.py
@@ -3,6 +3,7 @@
import time
from pydantic import model_validator
from multiprocessing.process import _children
from torch.distributed.device_mesh import init_device_mesh

import torch
from pydantic_config import parse_argv, BaseConfig
@@ -16,7 +17,6 @@
import torch.distributed as dist
from zeroband import utils
from zeroband.diloco import Diloco, DilocoConfig
from zeroband.comms import ElasticDeviceMesh
from zeroband.loss import cross_entropy_max_z_loss
from zeroband.models.llama.model import AttnFnType, create_block_mask_from_seqlens

@@ -38,6 +38,8 @@
from zeroband.checkpoint import CkptConfig, CkptManager, TrainingProgress
from zeroband.lr_scheduler import get_scheduler

from pccl import Attribute, Communicator


class OptimConfig(BaseConfig):
lr: float = 4e-4
@@ -193,6 +195,7 @@ def train(config: Config):
seq_length=config.data.seq_length,
attn_fn=config.train.attn_fn,
)
print(model)

model = model.to(world_info.local_rank)
logger.debug("model loaded")
@@ -212,9 +215,12 @@
num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt
apply_ac_ckpt(model, num)

elastic_device_mesh = ElasticDeviceMesh(
enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
)
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
if config.diloco is not None:
comm = Communicator(os.environ["PCCL_MASTER_ADDR"], peer_group=dist.get_rank())
comm.connect()
cuda_local_mesh = init_device_mesh("cuda", mesh_shape=(int(os.environ["LOCAL_WORLD_SIZE"]),))
print(cuda_local_mesh)

mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
@@ -228,13 +234,13 @@
fully_shard(
transformer_block,
mp_policy=mp_policy,
mesh=elastic_device_mesh.cuda_local_mesh,
mesh=cuda_local_mesh,
reshard_after_forward=reshard_after_forward,
)
fully_shard(
model,
mp_policy=mp_policy,
mesh=elastic_device_mesh.cuda_local_mesh,
mesh=cuda_local_mesh,
reshard_after_forward=config.train.reshard_after_forward,
)
logger.debug("model fsdped")
@@ -248,7 +254,7 @@
)

if config.diloco is not None:
diloco = Diloco(config.diloco, model, elastic_device_mesh)
diloco = Diloco(config.diloco, model, comm)

scheduler = get_scheduler(
sched_type=config.optim.sched_type,
@@ -312,60 +318,23 @@

logger.info("starting training")

need_live_recovery = config.ckpt.live_recovery_rank_src is not None
first_step = True
while True:
if num_inner_steps > 1:
# if we don't use diloco we don't print the outer step logs
logger.info(f"outer_step step: {training_progress.outer_step}")

time_start_outer = time.perf_counter()

if config.diloco is not None:
# this is a patch for now to allow live recovery worker to not affect the all reduce at all

if not need_live_recovery:
elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True)

maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to()
if maybe_dest_rank is not None:
logger.info(f"Start live recovery to rank {maybe_dest_rank}")
ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True)

elastic_device_mesh.live_recovery.reset()
else:
## receiving
time_start_live_recovery = time.perf_counter()
logger.info(f"Start live recovery from rank {config.ckpt.live_recovery_rank_src}")

## we create grad buffer and opts stats mamnually, the value will be overwritten by the ckpt but we need the DTensor to be correctly init before loading it

diloco.outer_optimizer.step() # need to step to init the DTensor stats

ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)

log_hash_training_state(
config,
model,
inner_optimizer,
diloco,
metric_logger,
step=training_progress.step,
id="live_reco_recv",
)
need_live_recovery = False

if config.ckpt.remote_data_load:
ckpt_manager.remote_data_load()

logger.info("live recovery done in %f", time.perf_counter() - time_start_live_recovery)

# at the beginning of the inner steps we allow joiner to arrive.
# We maybe reinit before the all reduce but only to allow leaving, not to join anymore
if not first_step and config.diloco is not None:
comm.update_topology()
first_step = False

if world_info.rank == 0 and config.monitor is not None:
monitor.set_stage("inner_loop")

for inner_step in range(num_inner_steps):
print("Starting inner step")
loss_batch = 0
z_loss_batch = 0

@@ -382,11 +351,14 @@
block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None
else:
block_mask = None
print("Starting inner step")

print("Model forward!")
logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
flatten_labels = rearrange(labels, "b seq -> (b seq)")

print("Mid inner step")
if config.optim.z_loss:
ce_loss, z_loss = cross_entropy_max_z_loss(
flatten_logits, flatten_labels, config.optim.z_loss_weight
@@ -402,16 +374,19 @@
loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps
del logits
loss.backward()
print("End? inner step")

if config.optim.z_loss:
loss_batch += ce_loss.clone().detach()
z_loss_batch += z_loss.clone().detach()
else:
loss_batch += loss.clone().detach()
print("Z loss")

dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG)
if config.optim.z_loss:
dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG)
print("Hi")

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
inner_optimizer.step()
@@ -432,7 +407,9 @@
else:
# we count the total tokens with respect to all diloco workers
# might need to tweak this as some worker might fail to join the all reduce later
training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size()
print("Get attr")
training_progress.total_tokens += new_tokens * comm.get_attribute(Attribute.CURRENT_WORLD_SIZE)
print("Post Get attr")

metrics = {
"Loss": loss_batch.item(),
@@ -458,7 +435,7 @@
log += f", tokens_per_second: {tokens_per_second:.2f}, mfu: {metrics['mfu']:.2f}"

if config.diloco is not None:
metrics["num_peers"] = elastic_device_mesh.global_pg.size()
metrics["num_peers"] = comm.get_attribute(Attribute.CURRENT_WORLD_SIZE)
log += f", diloco_peers: {metrics['num_peers']}"

if world_info.rank == 0:
@@ -531,9 +508,6 @@ def train(config: Config):
monitor.finish()

ckpt_manager.wait_for_blocking_job()

del elastic_device_mesh # allow to clean up for smoother tests transition

logger.info("Training finished, exiting ...")

