add diloco #3

Merged · 15 commits · Sep 22, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
.vscode/*
logs/*

# Byte-compiled / optimized / DLL files
__pycache__/
10 changes: 9 additions & 1 deletion README.md
@@ -23,7 +23,15 @@ uv run ...
To check that everything is working you can do

```bash
ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug.toml
ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/debug.toml
```

## run diloco

To run diloco locally, you can use the helper script `scripts/simulate_multi_node.sh`

```bash
ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml
```

## run test
13 changes: 0 additions & 13 deletions configs/debug.toml

This file was deleted.

16 changes: 16 additions & 0 deletions configs/debug/debug.toml
@@ -0,0 +1,16 @@
name_model = "debugmodel"
project = "/tmp/debug"
metric_logger_type = "dummy"

[train]
micro_bs = 8
sharding_strategy = "SHARD_GRAD_OP"

[optim]
batch_size = 16
warmup_steps = 10
total_steps = 10

[data]
fake = true

19 changes: 19 additions & 0 deletions configs/debug/diloco.toml
@@ -0,0 +1,19 @@
name_model = "debugmodel"
project = "/tmp/debug"
metric_logger_type = "dummy"

[train]
micro_bs = 8
sharding_strategy = "FULL_SHARD"

[optim]
batch_size = 16
warmup_steps = 10
total_steps = 10

[data]
fake = true

[diloco]
inner_steps = 5
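As a sanity check on how this file is consumed, here is a minimal sketch of how the `[diloco]` table maps onto the `DilocoConfig` added in `src/zeroband/diloco.py` below; the `tomllib` loading here is illustrative and not necessarily how `train.py` parses its configs.

```python
# Illustrative sketch only: load configs/debug/diloco.toml and build the
# DilocoConfig defined later in this PR (train.py's own parsing may differ).
import tomllib

from zeroband.diloco import DilocoConfig

with open("configs/debug/diloco.toml", "rb") as f:
    raw = tomllib.load(f)

cfg = DilocoConfig(**raw["diloco"])
print(cfg.inner_steps)  # 5, from the [diloco] table above
print(cfg.outer_lr)     # 0.7, the DilocoConfig default
```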

67 changes: 67 additions & 0 deletions scripts/simulate_multi_node.sh
@@ -0,0 +1,67 @@
#!/bin/bash

#
# Simulate multiple nodes on a single machine: launches N torchrun instances with X GPUs each.
# Example: ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/debug.toml

# Function to get CUDA devices based on the number of GPUs and index
function get_cuda_devices() {
local num_gpu=$1
local index=$2
local start_gpu=$((num_gpu * index))
local end_gpu=$((start_gpu + num_gpu - 1))

if [ "$num_gpu" -eq 1 ]; then
echo $start_gpu
else
echo $(seq -s ',' $start_gpu $end_gpu)
fi
}

# Array to store PIDs of child processes
child_pids=()

# Function to kill all child processes
cleanup() {
echo "Cleaning up child processes..."
local killed=0
for pid in "${child_pids[@]}"; do
if kill -TERM "$pid" 2>/dev/null; then
((killed++))
fi
done
wait
echo "All child processes terminated. Killed $killed processes."
exit
}

# Check if at least three arguments were passed
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <N> <initial_peer> <num_gpu> [additional_python_args]"
exit 1
fi


N=$1 # Number of simulated nodes (torchrun instances to launch)
NUM_GPU=$2 # Number of GPUs per simulated node
shift 2 # Remove the first two arguments so $@ contains the training script and its arguments

# Register the cleanup function to be called on SIGINT (Ctrl+C)
trap cleanup SIGINT


mkdir -p logs



for i in $(seq 0 $(($N - 1 )))
do
> logs/log$i
CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
child_pids+=($!)
done

tail -f logs/log0 &
child_pids+=($!)

wait
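For concreteness, a sketch of what the script launches with the arguments used in the README (device assignment derived from `get_cuda_devices`):

```bash
# With 2 simulated nodes and 2 GPUs per node, the loop above starts:
#   node 0: CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --node-rank 0 ...
#   node 1: CUDA_VISIBLE_DEVICES=2,3 torchrun --nproc_per_node=2 --node-rank 1 ...
# Both rendezvous on localhost:9999 and write their output to logs/log0 and logs/log1.
./scripts/simulate_multi_node.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml
```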
156 changes: 156 additions & 0 deletions src/zeroband/diloco.py
@@ -0,0 +1,156 @@
from pydantic_config import BaseConfig
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch import nn
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
from torch.distributed.fsdp import ShardingStrategy
import torch.distributed as dist


class DilocoConfig(BaseConfig):
outer_lr: float = 0.7
inner_steps: int


class ElasticDeviceMesh:
"""Init two process group through device mesh, one local on gpu and one global on cpu"""

def __init__(self):
self._logger = get_logger()

self.world_info = get_world_info()

        # right now device mesh does not support two backends, so we create two identical meshes that differ only in backend
self.device_mesh = init_device_mesh(
"cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
)
self.device_mesh_cpu = init_device_mesh(
"gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local")
)

self.global_pg = self.device_mesh_cpu.get_group("global")
self.local_pg = self.device_mesh.get_group("local")

self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")


class Diloco:
"""
This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852.

    It handles the outer loop as well as the inter-node communication.

    There is no VRAM overhead with this implementation: the offloaded copy of the model and the outer optimizer live on cpu.
    All-reduce communication is also done on cpu using GLOO.

Example usage:

# Example usage in a training loop:

diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)

for outer_step in range(num_outer_steps):
for inner_step in range(config.diloco.inner_steps):
# Regular inner training loop
optimizer.zero_grad()
loss = model(batch)
loss.backward()
optimizer.step()

diloco.step(model)
"""

def __init__(
self,
config: DilocoConfig,
model: nn.Module,
fsdp_sharding_strategy: ShardingStrategy,
elastic_device_mesh: ElasticDeviceMesh,
):
self.config = config
self.fsdp_sharding_strategy = fsdp_sharding_strategy
self.elastic_device_mesh = elastic_device_mesh

self._logger = get_logger()
self.world_info = get_world_info()

self.need_to_offload = (
self.fsdp_sharding_strategy == ShardingStrategy.FULL_SHARD or self.world_info.local_rank == 0
)
        # if we are not in fully sharded mode, only local rank 0 keeps an offloaded copy of the model on cpu
if self.need_to_offload:
self._init_offloaded_optimizer(model=model)
else:
self.outer_optimizer = None
self.cpu_model = None

def _init_offloaded_optimizer(self, model):
self.cpu_model = self.get_offloaded_param(model)
self.outer_optimizer = torch.optim.SGD(self.cpu_model, lr=self.config.outer_lr, momentum=0.9, nesterov=True)
self._logger.debug("offload model to cpu")

def sync_pseudo_gradient(self, model: nn.Module):
"""
Sync the pseudo gradient from the local process group to the global process group
"""
if self.need_to_offload:
self._logger.debug("sync pseudo gradient")
for param_offloaded, param in zip(self.cpu_model, model.parameters()):
                # todo: check how to handle the SHARD_GRAD_OP strategy, where the weights are replicated across the local devices
param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)

# gloo does not support AVG
param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size()
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg)

def sync_inner_model(self, model: nn.Module):
"""
Sync the inner model from the global process group to the local process group
"""

if self.fsdp_sharding_strategy == ShardingStrategy.FULL_SHARD:
            # here each rank has a shard of the model in memory, so all ranks do the sync
self._logger.debug("sync inner model")
for param_offloaded, param in zip(self.cpu_model, model.parameters()):
param.data = param_offloaded.data.to("cuda")

elif self.fsdp_sharding_strategy in [ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD]:
self._logger.debug("sync inner model")
            # in SHARD_GRAD_OP mode, only local rank 0 has the offloaded model on cpu.
            # we first copy the model to gpu 0 and then broadcast it to the other gpus,
            # since gpu-to-gpu transfers over nvlink are faster than cpu-to-gpu copies

for i, (param_offloaded, param) in enumerate(zip(self.cpu_model, model.parameters())):
# todo: we can probably overlap both comm here
if self.world_info.local_rank == 0:
self._logger.debug(
f"i: {i} shape param {param.data.shape} shape offloaded {param_offloaded.data.shape}"
)
param.data = param_offloaded.data.to("cuda")

dist.broadcast(tensor=param.data, src=0, group=self.elastic_device_mesh.local_pg)

    def get_offloaded_param(self, model: nn.Module) -> list[torch.Tensor]:
"""
Offload the model parameters to cpu
"""
offloaded_params = []
for param in model.parameters():
if param.requires_grad:
offloaded_param = param.data.detach().clone().to("cpu")
offloaded_param.requires_grad = True
offloaded_params.append(offloaded_param)

return offloaded_params

def step(self, model: nn.Module):
"""
        Perform one outer step: sync the pseudo-gradient across nodes, step the outer optimizer, and sync the inner model.
"""
self.sync_pseudo_gradient(model)
if self.outer_optimizer is not None:
self.outer_optimizer.step()
self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this

self.sync_inner_model(model)
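To make the outer update concrete, here is a toy single-process sketch of what `step` does for one parameter; the global all-reduce and the local broadcast are omitted, and the tensor names are illustrative rather than the repo's API.

```python
import torch

# Toy, single-parameter sketch of Diloco.step (no distributed communication).
# The real code also divides the pseudo-gradient by the global world size and
# all-reduces it over the gloo group before the outer optimizer step.
outer_lr, momentum = 0.7, 0.9

theta_cpu = torch.randn(4, requires_grad=True)  # offloaded copy (get_offloaded_param)
theta_gpu = theta_cpu.detach().clone()          # inner model weights

outer_opt = torch.optim.SGD([theta_cpu], lr=outer_lr, momentum=momentum, nesterov=True)

theta_gpu -= 0.01 * torch.randn(4)              # stand-in for inner_steps of local training

# sync_pseudo_gradient: pseudo-gradient = offloaded weights - current inner weights
theta_cpu.grad = theta_cpu.data - theta_gpu

outer_opt.step()                                # Nesterov SGD update on the cpu copy
outer_opt.zero_grad()

# sync_inner_model: copy the updated cpu weights back into the inner model
theta_gpu.copy_(theta_cpu.data)
```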