
Commit

wip
samsja committed Sep 22, 2024
1 parent 31e8695 commit 6c65322
Showing 3 changed files with 94 additions and 48 deletions.
2 changes: 1 addition & 1 deletion scripts/simulate_multi_node.sh
@@ -57,7 +57,7 @@ mkdir -p logs
for i in $(seq 0 $(($N - 1 )))
do
> logs/log$i
CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
child_pids+=($!)
done

129 changes: 87 additions & 42 deletions src/zeroband/diloco.py
@@ -13,36 +13,14 @@ class DilocoConfig(BaseConfig):
inner_steps: int


def get_offloaded_param(model: nn.Module) -> list[torch.Tensor]:
"""
Offload the model parameters to cpu
"""
offloaded_params = []
for param in model.parameters():
if param.requires_grad:
offloaded_param = param.data.detach().clone().to("cpu")
offloaded_param.requires_grad = True
offloaded_params.append(offloaded_param)

return offloaded_params


class Diloco:
def __init__(self, config: DilocoConfig, model: nn.Module, fsdp_sharding_strategy: ShardingStrategy):
self.config = config
self.fsdp_sharding_strategy = fsdp_sharding_strategy

if self.fsdp_sharding_strategy != ShardingStrategy.FULL_SHARD:
raise NotImplementedError("Only FULL_SHARD is supported for now")
class ElasticDeviceMesh:
"""Init two process group through device mesh, one local on gpu and one global on cpu"""

def __init__(self):
self._logger = get_logger()
self.world_info = get_world_info()

self._init_setup_device_mesh()
self._init_offloaded_optimizer(model=model)
self.world_info = get_world_info()

def _init_setup_device_mesh(self):
"""Init two process group through device mesh, one local on gpu and one global on cpu"""
# right now device mesh does not support two backend so we just create two identicaly mesh expect the backend
self.device_mesh = init_device_mesh(
"cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
Expand All @@ -56,30 +34,97 @@ def _init_setup_device_mesh(self):

self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")

def _init_offloaded_optimizer(self, model):
self.cpu_model = get_offloaded_param(model)
# todo: in the case of the sharded grad op strategy we only need to offload the cpu model once per node

class Diloco:
def __init__(
self,
config: DilocoConfig,
model: nn.Module,
fsdp_sharding_strategy: ShardingStrategy,
elastic_device_mesh: ElasticDeviceMesh,
):
self.config = config
self.fsdp_sharding_strategy = fsdp_sharding_strategy
self.elastic_device_mesh = elastic_device_mesh

self._logger = get_logger()
self.world_info = get_world_info()

self.need_to_offload = (
self.fsdp_sharding_strategy == ShardingStrategy.FULL_SHARD or self.world_info.local_rank == 0
)
# if we are not in fully sharded mode, only local rank 0 will have the model on cpu
if self.need_to_offload:
self._init_offloaded_optimizer(model=model)
else:
self.outer_optimizer = None
self.cpu_model = None

def _init_offloaded_optimizer(self, model):
self.cpu_model = self.get_offloaded_param(model)
self.outer_optimizer = torch.optim.SGD(self.cpu_model, lr=self.config.outer_lr, momentum=0.9, nesterov=True)
self._logger.debug("offload model to cpu")

def sync_pseudo_gradient(self, model: nn.Module):
"""
Compute the pseudo-gradient (offloaded cpu weights minus current weights) and all-reduce it across the global process group
"""
if self.need_to_offload:
self._logger.debug("sync pseudo gradient")
for param_offloaded, param in zip(self.cpu_model, model.parameters()):
# todo check how to handle the SHARD_GRAD_OP strategy where the weights are replicated across the local devices
param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)

# gloo does not support AVG
param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size()
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg)
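
Since gloo has no ReduceOp.AVG, the code above pre-divides by the global group size and then sums. Below is a minimal single-process illustration of that workaround; the tensor and values are made up, and a real run would use the global gloo group across nodes.

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    pseudo_grad = torch.randn(4)          # stand-in for param_offloaded.grad on cpu
    pseudo_grad /= dist.get_world_size()  # pre-divide because gloo cannot average
    dist.all_reduce(pseudo_grad, op=dist.ReduceOp.SUM)  # sum of pre-divided grads equals the mean

    dist.destroy_process_group()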

### the whole section below is just a PoC. We need to benchmark and decide what is most efficient:
## do the all-reduce on cpu or on gpu
## do the outer optimizer step on cpu or on gpu
def sync_inner_model(self, model: nn.Module):
"""
Sync the inner model on gpu with the offloaded cpu weights, broadcasting within the node when not fully sharded
"""

## right now we do all reduce on cpu
if self.fsdp_sharding_strategy == ShardingStrategy.FULL_SHARD:
# here each rank has a shard of the model in memory so all ranks do the sync
self._logger.debug("sync inner model")
for param_offloaded, param in zip(self.cpu_model, model.parameters()):
param.data = param_offloaded.data.to("cuda")

elif self.fsdp_sharding_strategy in [ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD]:
self._logger.debug("sync inner model")
# in shard_grad_op mode, only local rank 0 has the model on cpu
# we first copy the model to gpu 0 and then broadcast it to the other gpus, as
# gpu-to-gpu transfers over nvlink are faster than cpu-to-gpu transfers

for i, (param_offloaded, param) in enumerate(zip(self.cpu_model, model.parameters())):
# todo: we can probably overlap both communications here
if self.world_info.local_rank == 0:
self._logger.debug(
f"i: {i} shape param {param.data.shape} shape offloaded {param_offloaded.data.shape}"
)
param.data = param_offloaded.data.to("cuda")

dist.broadcast(tensor=param.data, src=0, group=self.elastic_device_mesh.local_pg)
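
The copy-then-broadcast pattern above can be summarized as a hypothetical single-node sketch; names such as local_pg and local_rank are assumptions, and on a single node global rank 0 coincides with local rank 0.

    import torch.distributed as dist

    def push_param_to_local_ranks(cpu_param, gpu_param, local_rank, local_pg):
        # only local rank 0 pays the host-to-device copy ...
        if local_rank == 0:
            gpu_param.data.copy_(cpu_param.data)
        # ... then the tensor fans out gpu-to-gpu (e.g. over nvlink) to the other local ranks
        dist.broadcast(gpu_param.data, src=0, group=local_pg)
        return gpu_param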

def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:
"""
Offload the model parameters to cpu
"""
offloaded_params = []
for param in model.parameters():
if param.requires_grad:
offloaded_param = param.data.detach().clone().to("cpu")
offloaded_param.requires_grad = True
offloaded_params.append(offloaded_param)

for param_offloaded, param in zip(self.cpu_model, model.parameters()):
# todo check how to handle the SHARD_GRAD_OP strategy where the weights are replicated across the local devices
param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)
return offloaded_params

if param_offloaded.grad.device == torch.device("cpu"):
# gloo does not support AVG
param_offloaded.grad = param_offloaded.grad / self.global_pg.size()
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg)
else:
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=self.global_pg)
def step(self, model: nn.Module):
"""
Perform one outer step: sync the pseudo-gradient, step the outer optimizer, and sync the inner model
"""
self.sync_pseudo_gradient(model)
if self.outer_optimizer is not None:
self.outer_optimizer.step()
self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this
self.sync_inner_model(model)
11 changes: 6 additions & 5 deletions src/zeroband/train.py
Expand Up @@ -19,7 +19,7 @@
)
import torch.distributed as dist
from zeroband import utils
from zeroband.diloco import Diloco, DilocoConfig
from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh

from zeroband.utils import get_sharding_strategy
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
@@ -119,19 +119,22 @@ def train(config: Config):
config.data.seq_length,
)

elastic_device_mesh = ElasticDeviceMesh()

model = FSDP(
model,
sharding_strategy=sharding_strategy,
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
use_orig_params=True,
process_group=elastic_device_mesh.local_pg if config.diloco is not None else None,
)

if config.train.torch_compile:
model = torch.compile(model)
logger.debug("model compiled and fsdped")

if config.diloco is not None:
diloco = Diloco(config.diloco, model, sharding_strategy)
diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)

# Setup optimizers
inner_optimizer = torch.optim.AdamW(
@@ -221,9 +224,7 @@ def train(config: Config):
)

if config.diloco is not None:
diloco.sync_pseudo_gradient(model)
diloco.outer_optimizer.step()
diloco.outer_optimizer.zero_grad() # todo(sami): check if we can remove this
diloco.step(model)

outer_step += 1

