diff --git a/scripts/simulate_multi_node.sh b/scripts/simulate_multi_node.sh
index a5344542..518ea21e 100755
--- a/scripts/simulate_multi_node.sh
+++ b/scripts/simulate_multi_node.sh
@@ -57,7 +57,7 @@ mkdir -p logs
 for i in $(seq 0 $(($N - 1 )))
 do
     > logs/log$i
-    CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
+    CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
     child_pids+=($!)
 done
diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py
index 2bed767c..1af0d411 100644
--- a/src/zeroband/diloco.py
+++ b/src/zeroband/diloco.py
@@ -13,36 +13,14 @@ class DilocoConfig(BaseConfig):
     inner_steps: int
 
 
-def get_offloaded_param(model: nn.Module) -> list[torch.Tensor]:
-    """
-    Offload the model parameters to cpu
-    """
-    offloaded_params = []
-    for param in model.parameters():
-        if param.requires_grad:
-            offloaded_param = param.data.detach().clone().to("cpu")
-            offloaded_param.requires_grad = True
-            offloaded_params.append(offloaded_param)
-
-    return offloaded_params
-
-
-class Diloco:
-    def __init__(self, config: DilocoConfig, model: nn.Module, fsdp_sharding_strategy: ShardingStrategy):
-        self.config = config
-        self.fsdp_sharding_strategy = fsdp_sharding_strategy
-
-        if self.fsdp_sharding_strategy != ShardingStrategy.FULL_SHARD:
-            raise NotImplementedError("Only FULL_SHARD is supported for now")
-
+class ElasticDeviceMesh:
+    """Init two process groups through device mesh, one local on gpu and one global on cpu"""
+
+    def __init__(self):
         self._logger = get_logger()
-        self.world_info = get_world_info()
-        self._init_setup_device_mesh()
-        self._init_offloaded_optimizer(model=model)
+        self.world_info = get_world_info()
 
-    def _init_setup_device_mesh(self):
-        """Init two process group through device mesh, one local on gpu and one global on cpu"""
         # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend
         self.device_mesh = init_device_mesh(
             "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
         )
@@ -56,30 +34,97 @@ def _init_setup_device_mesh(self):
         self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")
 
-    def _init_offloaded_optimizer(self, model):
-        self.cpu_model = get_offloaded_param(model)
-        # todo: in case of sharded grap op we need to offload the cpu model only once per nodes
+
+class Diloco:
+    def __init__(
+        self,
+        config: DilocoConfig,
+        model: nn.Module,
+        fsdp_sharding_strategy: ShardingStrategy,
+        elastic_device_mesh: ElasticDeviceMesh,
+    ):
+        self.config = config
+        self.fsdp_sharding_strategy = fsdp_sharding_strategy
+        self.elastic_device_mesh = elastic_device_mesh
+
+        self._logger = get_logger()
+        self.world_info = get_world_info()
+
+        self.need_to_offload = (
+            self.fsdp_sharding_strategy == ShardingStrategy.FULL_SHARD or self.world_info.local_rank == 0
+        )
+        # if we are not in fully sharded mode, only the local rank 0 will have the model on cpu
+        if self.need_to_offload:
+            self._init_offloaded_optimizer(model=model)
+        else:
+            self.outer_optimizer = None
+            self.cpu_model = None
+
+    def _init_offloaded_optimizer(self, model):
+        self.cpu_model = self.get_offloaded_param(model)
         self.outer_optimizer = torch.optim.SGD(self.cpu_model, lr=self.config.outer_lr, momentum=0.9, nesterov=True)
+        self._logger.debug("offload model to cpu")
 
     def sync_pseudo_gradient(self, model: nn.Module):
         """
         Sync the pseudo gradient from the local process group to the global process group
         """
+        if self.need_to_offload:
+            self._logger.debug("sync pseudo gradient")
+            for param_offloaded, param in zip(self.cpu_model, model.parameters()):
+                # todo check how to handle the SHARD_GRAD_OP strategy where the weights are replicated across the local devices
+                param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
+
+                # gloo does not support AVG
+                param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size()
+                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg)
 
-        ### the whole sectione below is just a PoC. We need to benchmark and optimizer what is the most efficient:
-        ## do the all reduce on cpu or on gpu
-        ## do the outer optimizer step on cpu or on gpu
+    def sync_inner_model(self, model: nn.Module):
+        """
+        Sync the inner model from the global process group to the local process group
+        """
 
-        ## right now we do all reduce on cpu
+        if self.fsdp_sharding_strategy == ShardingStrategy.FULL_SHARD:
+            # here each rank has a shard of the model in memory so all ranks do the sync
+            self._logger.debug("sync inner model")
+            for param_offloaded, param in zip(self.cpu_model, model.parameters()):
+                param.data = param_offloaded.data.to("cuda")
+
+        elif self.fsdp_sharding_strategy in [ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD]:
+            self._logger.debug("sync inner model")
+            # in shard_grad_op mode, only the local rank 0 has the model on cpu
+            # we first copy the model to gpu 0 and then broadcast it to the other gpus, as
+            # gpu to gpu is faster than cpu to gpu with nvlink
+
+            for i, (param_offloaded, param) in enumerate(zip(self.cpu_model, model.parameters())):
+                # todo: we can probably overlap both comms here
+                if self.world_info.local_rank == 0:
+                    self._logger.debug(
+                        f"i: {i} shape param {param.data.shape} shape offloaded {param_offloaded.data.shape}"
+                    )
+                    param.data = param_offloaded.data.to("cuda")
+
+                dist.broadcast(tensor=param.data, src=0, group=self.elastic_device_mesh.local_pg)
+
+    def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
+        """
+        Offload the model parameters to cpu
+        """
+        offloaded_params = []
+        for param in model.parameters():
+            if param.requires_grad:
+                offloaded_param = param.data.detach().clone().to("cpu")
+                offloaded_param.requires_grad = True
+                offloaded_params.append(offloaded_param)
 
-        for param_offloaded, param in zip(self.cpu_model, model.parameters()):
-            # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices
-            param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
+        return offloaded_params
 
-            if param_offloaded.grad.device == torch.device("cpu"):
-                # gloo does not support AVG
-                param_offloaded.grad = param_offloaded.grad / self.global_pg.size()
-                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg)
-            else:
-                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=self.global_pg)
+    def step(self, model: nn.Module):
+        """
+        Step the optimizer
+        """
+        self.sync_pseudo_gradient(model)
+        if self.outer_optimizer is not None:
+            self.outer_optimizer.step()
+            self.outer_optimizer.zero_grad()  # todo(sami): check if we can remove this
+        self.sync_inner_model(model)
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 00b6c5e8..332b008f 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -19,7 +19,7 @@
 )
 import torch.distributed as dist
 from zeroband import utils
-from zeroband.diloco import Diloco, DilocoConfig
+from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh
 from zeroband.utils import get_sharding_strategy
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 
@@ -119,11 +119,14 @@ def train(config: Config):
         config.data.seq_length,
     )
 
+    elastic_device_mesh = ElasticDeviceMesh()
+
     model = FSDP(
         model,
         sharding_strategy=sharding_strategy,
         mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
         use_orig_params=True,
+        process_group=elastic_device_mesh.local_pg if config.diloco is not None else None,
     )
 
     if config.train.torch_compile:
@@ -131,7 +134,7 @@ def train(config: Config):
         logger.debug("model compiled and fsdped")
 
     if config.diloco is not None:
-        diloco = Diloco(config.diloco, model, sharding_strategy)
+        diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)
 
     # Setup optimizers
     inner_optimizer = torch.optim.AdamW(
@@ -221,9 +224,7 @@ def train(config: Config):
             )
 
         if config.diloco is not None:
-            diloco.sync_pseudo_gradient(model)
-            diloco.outer_optimizer.step()
-            diloco.outer_optimizer.zero_grad()  # todo(sami): check if we can remove this
+            diloco.step(model)
 
         outer_step += 1
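
Note on the ElasticDeviceMesh hunk above: the matching gloo/CPU mesh and the two get_group calls fall outside the diff context, so here is a rough, self-contained sketch of the two-mesh idea when launched under torchrun. The "cpu:gloo,cuda:nccl" init string, the env-var arithmetic, and the variable names are illustrative assumptions, not lines from the patch.

import os

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# one default group carrying both backends: gloo for cpu tensors, nccl for cuda tensors
dist.init_process_group(backend="cpu:gloo,cuda:nccl")

local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
nnodes = int(os.environ["WORLD_SIZE"]) // local_world_size

# two identical meshes that differ only in device type / backend
mesh_gpu = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local"))
mesh_cpu = init_device_mesh("cpu", (nnodes, local_world_size), mesh_dim_names=("global", "local"))

local_pg = mesh_gpu.get_group("local")    # intra-node group, handed to FSDP as process_group
global_pg = mesh_cpu.get_group("global")  # cross-node gloo group for the pseudo-gradient all-reduce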
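
And a minimal sketch of the outer step that Diloco.step(model) now wraps: a pseudo-gradient against a CPU copy, an all-reduce on the gloo group done as SUM plus a divide (gloo has no ReduceOp.AVG), an outer Nesterov-SGD update, then a copy back to the device. It ignores FSDP sharding, and the helper names (make_cpu_copy, outer_step) are illustrative rather than part of the codebase.

import torch
import torch.distributed as dist
from torch import nn


def make_cpu_copy(model: nn.Module) -> list[torch.Tensor]:
    # same idea as Diloco.get_offloaded_param: keep a CPU replica of the trainable weights
    cpu_params = []
    for param in model.parameters():
        if param.requires_grad:
            cpu_param = param.data.detach().clone().to("cpu")
            cpu_param.requires_grad = True
            cpu_params.append(cpu_param)
    return cpu_params


def outer_step(model: nn.Module, cpu_params: list[torch.Tensor], outer_opt: torch.optim.Optimizer, global_pg) -> None:
    trainable = [p for p in model.parameters() if p.requires_grad]
    for cpu_param, param in zip(cpu_params, trainable):
        # pseudo-gradient = offloaded weights minus the locally trained weights
        cpu_param.grad = cpu_param.data - param.data.to("cpu")
        # gloo does not support AVG, so pre-divide and all-reduce with SUM
        cpu_param.grad /= global_pg.size()
        dist.all_reduce(cpu_param.grad, op=dist.ReduceOp.SUM, group=global_pg)

    outer_opt.step()  # e.g. torch.optim.SGD(cpu_params, lr=outer_lr, momentum=0.9, nesterov=True)
    outer_opt.zero_grad()

    # push the updated CPU weights back into the on-device model
    for cpu_param, param in zip(cpu_params, trainable):
        param.data = cpu_param.data.to(param.device)

Called once every config.diloco.inner_steps inner iterations, this reproduces the flow that train.py now delegates to diloco.step(model).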