From 6f0c479a51c02fb391c51b0afcfff007103d39d0 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sat, 21 Sep 2024 23:11:55 +0000 Subject: [PATCH 01/15] add diloco version --- README.md | 2 +- configs/{ => debug}/debug.toml | 1 + configs/debug/diloco.toml | 17 +++ src/zeroband/diloco.py | 85 ++++++++++++ src/zeroband/train.py | 14 +- src/zeroband/utils/world_info.py | 1 + tests/test_torchrun/test_train | 206 ------------------------------ tests/test_torchrun/test_train.py | 25 +++- 8 files changed, 133 insertions(+), 218 deletions(-) rename configs/{ => debug}/debug.toml (80%) create mode 100644 configs/debug/diloco.toml create mode 100644 src/zeroband/diloco.py delete mode 100644 tests/test_torchrun/test_train diff --git a/README.md b/README.md index d6796638..37aa2fbe 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ uv run ... To check that everything is working you can do ```bash -ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug.toml +ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/debug.toml ``` ## run test diff --git a/configs/debug.toml b/configs/debug/debug.toml similarity index 80% rename from configs/debug.toml rename to configs/debug/debug.toml index e7d6e30d..eedfea20 100644 --- a/configs/debug.toml +++ b/configs/debug/debug.toml @@ -3,6 +3,7 @@ project = "debug" [train] micro_bs = 8 +sharding_strategy = "SHARD_GRAD_OP" [optim] batch_size = 16 diff --git a/configs/debug/diloco.toml b/configs/debug/diloco.toml new file mode 100644 index 00000000..24a8602c --- /dev/null +++ b/configs/debug/diloco.toml @@ -0,0 +1,17 @@ +name_model = "debugmodel" +project = "debug" + +[train] +micro_bs = 8 +sharding_strategy = "FULL_SHARD" + +[optim] +batch_size = 16 +warmup_steps = 10 +total_steps = 5000 + +[data] +fake_data = true + +[diloco] +inner_steps = 10 \ No newline at end of file diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py new file mode 100644 index 00000000..2bed767c --- /dev/null +++ b/src/zeroband/diloco.py @@ -0,0 +1,85 @@ +from pydantic_config import BaseConfig +import torch +from torch.distributed.device_mesh import init_device_mesh +from torch import nn +from zeroband.utils.world_info import get_world_info +from zeroband.utils.logging import get_logger +from torch.distributed.fsdp import ShardingStrategy +import torch.distributed as dist + + +class DilocoConfig(BaseConfig): + outer_lr: float = 0.7 + inner_steps: int + + +def get_offloaded_param(model: nn.Module) -> list[torch.Tensor]: + """ + Offload the model parameters to cpu + """ + offloaded_params = [] + for param in model.parameters(): + if param.requires_grad: + offloaded_param = param.data.detach().clone().to("cpu") + offloaded_param.requires_grad = True + offloaded_params.append(offloaded_param) + + return offloaded_params + + +class Diloco: + def __init__(self, config: DilocoConfig, model: nn.Module, fsdp_sharding_strategy: ShardingStrategy): + self.config = config + self.fsdp_sharding_strategy = fsdp_sharding_strategy + + if self.fsdp_sharding_strategy != ShardingStrategy.FULL_SHARD: + raise NotImplementedError("Only FULL_SHARD is supported for now") + + self._logger = get_logger() + self.world_info = get_world_info() + + self._init_setup_device_mesh() + self._init_offloaded_optimizer(model=model) + + def _init_setup_device_mesh(self): + """Init two process group through device mesh, one local on gpu and one global on cpu""" + # right now device mesh does not support two backend so we just create two identicaly 
mesh expect the backend + self.device_mesh = init_device_mesh( + "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local") + ) + self.device_mesh_cpu = init_device_mesh( + "gloo", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local") + ) + + self.global_pg = self.device_mesh_cpu.get_group("global") + self.local_pg = self.device_mesh.get_group("local") + + self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}") + + def _init_offloaded_optimizer(self, model): + self.cpu_model = get_offloaded_param(model) + # todo: in case of sharded grap op we need to offload the cpu model only once per nodes + + self.outer_optimizer = torch.optim.SGD(self.cpu_model, lr=self.config.outer_lr, momentum=0.9, nesterov=True) + + def sync_pseudo_gradient(self, model: nn.Module): + """ + Sync the pseudo gradient from the local process group to the global process group + """ + + ### the whole sectione below is just a PoC. We need to benchmark and optimizer what is the most efficient: + ## do the all reduce on cpu or on gpu + ## do the outer optimizer step on cpu or on gpu + + ## right now we do all reduce on cpu + + for param_offloaded, param in zip(self.cpu_model, model.parameters()): + # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices + param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) + + if param_offloaded.grad.device == torch.device("cpu"): + # gloo does not support AVG + param_offloaded.grad = param_offloaded.grad / self.global_pg.size() + dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg) + else: + dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=self.global_pg) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 9988e767..00b6c5e8 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -19,6 +19,7 @@ ) import torch.distributed as dist from zeroband import utils +from zeroband.diloco import Diloco, DilocoConfig from zeroband.utils import get_sharding_strategy from zeroband.utils.monitor import WandbMonitor, DummyMonitor @@ -36,11 +37,6 @@ def ddp_setup(): torch.cuda.set_device(world_info.local_rank) -class DilocoConfig(BaseConfig): - outer_lr: float = 0.7 - inner_steps: int = 10 - - class DataConfig(BaseConfig): seq_length: int = 1024 fake_data: bool = False @@ -134,6 +130,9 @@ def train(config: Config): model = torch.compile(model) logger.debug("model compiled and fsdped") + if config.diloco is not None: + diloco = Diloco(config.diloco, model, sharding_strategy) + # Setup optimizers inner_optimizer = torch.optim.AdamW( model.parameters(), @@ -221,6 +220,11 @@ def train(config: Config): f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}, mfu: {mfu:.2f}" ) + if config.diloco is not None: + diloco.sync_pseudo_gradient(model) + diloco.outer_optimizer.step() + diloco.outer_optimizer.zero_grad() # todo(sami): check if we can remove this + outer_step += 1 if real_step >= config.optim.total_steps: diff --git a/src/zeroband/utils/world_info.py b/src/zeroband/utils/world_info.py index 6ab3780f..efe30a1a 100644 --- a/src/zeroband/utils/world_info.py +++ b/src/zeroband/utils/world_info.py @@ -16,6 +16,7 @@ def __init__(self): self.rank = int(os.environ["RANK"]) self.local_rank = int(os.environ["LOCAL_RANK"]) self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + 
self.nnodes = self.world_size // self.local_world_size def get_world_info() -> WorldInfo: diff --git a/tests/test_torchrun/test_train b/tests/test_torchrun/test_train deleted file mode 100644 index 3295d5f5..00000000 --- a/tests/test_torchrun/test_train +++ /dev/null @@ -1,206 +0,0 @@ -import pickle -import subprocess -import numpy as np -import pytest -import socket -from hivemind.dht.dht import DHT -from open_diloco.ckpt_utils import CKPT_PREFIX - - -def get_random_available_port(): - # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -@pytest.fixture(scope="session") -def random_available_port(): - return get_random_available_port() - - -@pytest.fixture -def config() -> list[str]: - return [ - "--path_model", - "tests/models/llama-2m-fresh", - "--fake_data", - "--no-torch_compile", - "--lr", - "1e-2", - "--per_device_train_batch_size", - "8", - "--total_batch_size", - "16", - "--max_steps", - "50", - "--metric_logger_type", - "dummy", - ] - - -@pytest.mark.parametrize("num_gpu", [2]) -def test_multi_gpu_ckpt(config, random_available_port, num_gpu, tmp_path): - ckpt_path = f"{tmp_path}/ckpt" - log_file_1 = f"{tmp_path}/log1.json" - log_file_2 = f"{tmp_path}/log2.json" - - run_1 = ["--ckpt.path", ckpt_path, "--ckpt.interval", "10", "--project", log_file_1] - - cmd = [ - "torchrun", - f"--nproc_per_node={num_gpu}", - "--rdzv-endpoint", - f"localhost:{random_available_port}", - "open_diloco/train_fsdp.py", - *config, - ] - - result = subprocess.run(cmd + run_1) - - if result.returncode != 0: - pytest.fail(f"Process {result} failed {result.stderr}") - - run_2 = ["--ckpt.path", ckpt_path, "--ckpt.resume", f"{ckpt_path}/{CKPT_PREFIX}_20", "--project", log_file_2] - - results_resume = subprocess.run(cmd + run_2) - - if results_resume.returncode != 0: - pytest.fail(f"Process {result} failed {result.stderr}") - - with open(log_file_1, "rb") as f: - log1 = pickle.load(f) - with open(log_file_2, "rb") as f: - log2 = pickle.load(f) - - log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1} - log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2} - - common_step = set(log1.keys()) & set(log2.keys()) - - for step in common_step: - assert np.allclose(log1[step][0], log2[step][0], atol=1e-3), f"Loss at step {step} is different" - assert log1[step][1] == log2[step][1], f"Lr at step {step} is different" - - -@pytest.fixture -def config_hv() -> list[str]: - config = [ - "--path_model", - "tests/models/llama-2m-fresh", - "--fake_data", - "--no-torch_compile", - "--lr", - "1e-2", - "--per_device_train_batch_size", - "8", - "--total_batch_size", - "16", - "--max_steps", - "100", - "--metric_logger_type", - "dummy", - ] - - return config + [ - "--hv.local_steps", - "25", - "--hv.skip_load_from_peers", - "--hv.fail_rank_drop", - "--hv.matchmaking_time", - "5", - ] - - -@pytest.mark.parametrize("num_diloco", [2]) -def test_multi_gpu_hivemind(config_hv, num_diloco, tmp_path): - dht = DHT( - start=True, - host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"], - ) - - initial_peers = str(dht.get_visible_maddrs()[0]) - - results = [] - - ckpt_path = f"{tmp_path}/ckpt" - - def get_base_cmd(i, initial_peers): - return [ - "torchrun", - f"--nproc_per_node={1}", - "--rdzv-endpoint", - f"localhost:{port}", - "open_diloco/train_fsdp.py", - *config_hv, - "--hv.initial_peers", - initial_peers, - "--hv.world_rank", - str(i), - 
"--hv.galaxy_size", - str(num_diloco), - ] - - for i in range(num_diloco): - port = get_random_available_port() - - cmd = get_base_cmd(i, initial_peers) + [ - "--ckpt.path", - ckpt_path, - "--ckpt.interval", - "25", - "--project", - f"{tmp_path}/log{i}_part1.json", - ] - - result = subprocess.Popen(cmd) - results.append(result) - - for result in results: - result.wait() - if result.returncode != 0: - pytest.fail(f"Process {result} failed {result.stderr}") - - # resume from ckpt - - dht.shutdown() - - del dht - dht = DHT( - start=True, - host_maddrs=[f"/ip4/0.0.0.0/tcp/{get_random_available_port()}"], - ) - initial_peers = str(dht.get_visible_maddrs()[0]) - - for i in range(num_diloco): - port = get_random_available_port() - - cmd = get_base_cmd(i, initial_peers) + [ - "--ckpt.resume", - f"{ckpt_path}/{CKPT_PREFIX}_50", - "--project", - f"{tmp_path}/log{i}_part2.json", - ] - - result = subprocess.Popen(cmd) - results.append(result) - - for result in results: - result.wait() - if result.returncode != 0: - pytest.fail(f"Process {result} failed {result.stderr}") - - for i in range(num_diloco): - with open(f"{tmp_path}/log{i}_part1.json", "rb") as f: - log1 = pickle.load(f) - with open(f"{tmp_path}/log{i}_part2.json", "rb") as f: - log2 = pickle.load(f) - - log1 = {data["step"]: [data["Loss"], data["lr"]] for data in log1} - log2 = {data["step"]: [data["Loss"], data["lr"]] for data in log2} - - common_step = set(log1.keys()) & set(log2.keys()) - - for step in common_step: - assert np.allclose(log1[step][0], log2[step][0], atol=1e-2), f"Loss at step {step} is different" - assert log1[step][1] == log2[step][1], f"Lr at step {step} is different" diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index 61c7ec74..b669dd58 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -15,21 +15,34 @@ def random_available_port(): return get_random_available_port() -@pytest.fixture() -def config_path() -> str: - # need to be executed in the root dir - return "configs/debug.toml" +@pytest.mark.parametrize("num_gpu", [1, 2]) +def test_multi_gpu(random_available_port, num_gpu): + cmd = [ + "torchrun", + f"--nproc_per_node={num_gpu}", + "--rdzv-endpoint", + f"localhost:{random_available_port}", + "src/zeroband/train.py", + "@configs/debug/debug.toml", + "--optim.total_steps", + "10", + ] + + result = subprocess.run(cmd) + + if result.returncode != 0: + pytest.fail(f"Process {result} failed {result.stderr}") @pytest.mark.parametrize("num_gpu", [1, 2]) -def test_multi_gpu_ckpt(config_path, random_available_port, num_gpu): +def test_multi_gpu_diloco(random_available_port, num_gpu): cmd = [ "torchrun", f"--nproc_per_node={num_gpu}", "--rdzv-endpoint", f"localhost:{random_available_port}", "src/zeroband/train.py", - f"@{config_path}", + "@configs/debug/diloco.toml", "--optim.total_steps", "10", ] From f34f9c9881b9d9625f8ec2562161b991d0aaa7f6 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 01:15:02 +0000 Subject: [PATCH 02/15] add better tests --- configs/debug/debug.toml | 4 +- configs/debug/diloco.toml | 5 ++- tests/test_torchrun/test_train.py | 68 ++++++++++++++++++++----------- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/configs/debug/debug.toml b/configs/debug/debug.toml index eedfea20..ae283e9e 100644 --- a/configs/debug/debug.toml +++ b/configs/debug/debug.toml @@ -8,7 +8,7 @@ sharding_strategy = "SHARD_GRAD_OP" [optim] batch_size = 16 warmup_steps = 10 -total_steps = 5000 +total_steps = 10 [data] 
-fake_data = true \ No newline at end of file +fake_data = true diff --git a/configs/debug/diloco.toml b/configs/debug/diloco.toml index 24a8602c..9283c964 100644 --- a/configs/debug/diloco.toml +++ b/configs/debug/diloco.toml @@ -8,10 +8,11 @@ sharding_strategy = "FULL_SHARD" [optim] batch_size = 16 warmup_steps = 10 -total_steps = 5000 +total_steps = 10 [data] fake_data = true [diloco] -inner_steps = 10 \ No newline at end of file +inner_steps = 5 + diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index b669dd58..8e98716a 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -1,37 +1,59 @@ +import copy +import os import subprocess import pytest import socket -def get_random_available_port(): +def get_random_available_port_list(num_port): # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] + ports = [] + while len(ports) < num_port: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + new_port = s.getsockname()[1] -@pytest.fixture() -def random_available_port(): - return get_random_available_port() + if new_port not in ports: + ports.append(new_port) + return ports -@pytest.mark.parametrize("num_gpu", [1, 2]) -def test_multi_gpu(random_available_port, num_gpu): - cmd = [ - "torchrun", - f"--nproc_per_node={num_gpu}", - "--rdzv-endpoint", - f"localhost:{random_available_port}", - "src/zeroband/train.py", - "@configs/debug/debug.toml", - "--optim.total_steps", - "10", - ] - result = subprocess.run(cmd) +def get_random_available_port(num_port): + return get_random_available_port_list(num_port)[0] - if result.returncode != 0: - pytest.fail(f"Process {result} failed {result.stderr}") + +def gpus_to_use(num_nodes, num_gpu, rank): + return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu))) + + +@pytest.mark.parametrize("num_gpus", [[1, 1], [2, 1], [1, 2]]) +@pytest.mark.parametrize("config", ["debug/debug.toml", "debug/diloco.toml"]) +def test_multi_gpu(num_gpus, config): + num_nodes, num_gpu = num_gpus[0], num_gpus[1] + + processes = [] + ports = get_random_available_port_list(num_nodes) + for i in range(num_nodes): + cmd = [ + "torchrun", + f"--nproc_per_node={num_gpu}", + "--rdzv-endpoint", + f"localhost:{ports[i]}", + "src/zeroband/train.py", + f"@configs/{config}", + ] + + env = copy.deepcopy(os.environ) + env["CUDA_VISIBLE_DEVICES"] = gpus_to_use(num_nodes, num_gpu, i) + process1 = subprocess.Popen(cmd, env=env) + processes.append(process1) + + for process in processes: + result = process.wait() + if result != 0: + pytest.fail(f"Process {result} failed {result}") @pytest.mark.parametrize("num_gpu", [1, 2]) @@ -44,7 +66,7 @@ def test_multi_gpu_diloco(random_available_port, num_gpu): "src/zeroband/train.py", "@configs/debug/diloco.toml", "--optim.total_steps", - "10", + "50", ] result = subprocess.run(cmd) From 60c910ead9db3255290c514fef4955659d73e114 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 03:30:47 +0000 Subject: [PATCH 03/15] add mutli run script --- .gitignore | 1 + scripts/simulate_multi_node.sh | 39 ++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100755 scripts/simulate_multi_node.sh diff --git a/.gitignore b/.gitignore index d5671f03..3af43cdb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .vscode/* +logs/* # Byte-compiled / optimized / DLL files __pycache__/ 
diff --git a/scripts/simulate_multi_node.sh b/scripts/simulate_multi_node.sh new file mode 100755 index 00000000..c7185138 --- /dev/null +++ b/scripts/simulate_multi_node.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +# +# simulate multi nodes on one gpu. start N torchrun on X gpu locally. +# example how to run ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/debug.toml + +# Function to get CUDA devices based on the number of GPUs and index +function get_cuda_devices() { + local num_gpu=$1 + local index=$2 + local start_gpu=$((num_gpu * index)) + local end_gpu=$((start_gpu + num_gpu - 1)) + + if [ "$num_gpu" -eq 1 ]; then + echo $start_gpu + else + echo $(seq -s ',' $start_gpu $end_gpu) + fi +} + +# Check if at least three arguments were passed +if [ "$#" -lt 3 ]; then + echo "Usage: $0 [additional_python_args]" + exit 1 +fi + +N=$1 # Set N from the first argument +NUM_GPU=$2 +shift 2 # Remove the first three arguments so $@ contains only additional Python arguments + +mkdir -p logs + + +for i in $(seq 0 $(($N - 1 ))) +do + CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 & +done + +tail -f logs/log0 From 31e86953c0ab271f5fc96f0db0bcfa60923db564 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 14:35:47 +0000 Subject: [PATCH 04/15] add better scropt --- scripts/simulate_multi_node.sh | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/scripts/simulate_multi_node.sh b/scripts/simulate_multi_node.sh index c7185138..a5344542 100755 --- a/scripts/simulate_multi_node.sh +++ b/scripts/simulate_multi_node.sh @@ -18,22 +18,50 @@ function get_cuda_devices() { fi } +# Array to store PIDs of child processes +child_pids=() + +# Function to kill all child processes +cleanup() { + echo "Cleaning up child processes..." + local killed=0 + for pid in "${child_pids[@]}"; do + if kill -TERM "$pid" 2>/dev/null; then + ((killed++)) + fi + done + wait + echo "All child processes terminated. Killed $killed processes." + exit +} + # Check if at least three arguments were passed if [ "$#" -lt 3 ]; then echo "Usage: $0 [additional_python_args]" exit 1 fi + N=$1 # Set N from the first argument NUM_GPU=$2 shift 2 # Remove the first three arguments so $@ contains only additional Python arguments +# Register the cleanup function to be called on SIGINT (Ctrl+C) +trap cleanup SIGINT + + mkdir -p logs + for i in $(seq 0 $(($N - 1 ))) do + > logs/log$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 & + child_pids+=($!) done -tail -f logs/log0 +tail -f logs/log0 & +child_pids+=($!) 
+ +wait From 6c65322fcbc7c02847b02d86d99c1e053dbe0569 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 20:46:17 +0000 Subject: [PATCH 05/15] wip --- scripts/simulate_multi_node.sh | 2 +- src/zeroband/diloco.py | 129 ++++++++++++++++++++++----------- src/zeroband/train.py | 11 +-- 3 files changed, 94 insertions(+), 48 deletions(-) diff --git a/scripts/simulate_multi_node.sh b/scripts/simulate_multi_node.sh index a5344542..518ea21e 100755 --- a/scripts/simulate_multi_node.sh +++ b/scripts/simulate_multi_node.sh @@ -57,7 +57,7 @@ mkdir -p logs for i in $(seq 0 $(($N - 1 ))) do > logs/log$i - CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 & + CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 & child_pids+=($!) done diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 2bed767c..1af0d411 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -13,36 +13,14 @@ class DilocoConfig(BaseConfig): inner_steps: int -def get_offloaded_param(model: nn.Module) -> list[torch.Tensor]: - """ - Offload the model parameters to cpu - """ - offloaded_params = [] - for param in model.parameters(): - if param.requires_grad: - offloaded_param = param.data.detach().clone().to("cpu") - offloaded_param.requires_grad = True - offloaded_params.append(offloaded_param) - - return offloaded_params - - -class Diloco: - def __init__(self, config: DilocoConfig, model: nn.Module, fsdp_sharding_strategy: ShardingStrategy): - self.config = config - self.fsdp_sharding_strategy = fsdp_sharding_strategy - - if self.fsdp_sharding_strategy != ShardingStrategy.FULL_SHARD: - raise NotImplementedError("Only FULL_SHARD is supported for now") +class ElasticDeviceMesh: + """Init two process group through device mesh, one local on gpu and one global on cpu""" + def __init__(self): self._logger = get_logger() - self.world_info = get_world_info() - self._init_setup_device_mesh() - self._init_offloaded_optimizer(model=model) + self.world_info = get_world_info() - def _init_setup_device_mesh(self): - """Init two process group through device mesh, one local on gpu and one global on cpu""" # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend self.device_mesh = init_device_mesh( "cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("global", "local") @@ -56,30 +34,97 @@ def _init_setup_device_mesh(self): self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}") - def _init_offloaded_optimizer(self, model): - self.cpu_model = get_offloaded_param(model) - # todo: in case of sharded grap op we need to offload the cpu model only once per nodes +class Diloco: + def __init__( + self, + config: DilocoConfig, + model: nn.Module, + fsdp_sharding_strategy: ShardingStrategy, + elastic_device_mesh: ElasticDeviceMesh, + ): + self.config = config + self.fsdp_sharding_strategy = fsdp_sharding_strategy + self.elastic_device_mesh = elastic_device_mesh + + self._logger = get_logger() + self.world_info = get_world_info() + + self.need_to_offload = ( + self.fsdp_sharding_strategy == ShardingStrategy.FULL_SHARD or self.world_info.local_rank == 0 + ) + # if we are not in fully sharded mode only the local rank 0 will have the model on cpu + if 
self.need_to_offload: + self._init_offloaded_optimizer(model=model) + else: + self.outer_optimizer = None + self.cpu_model = None + + def _init_offloaded_optimizer(self, model): + self.cpu_model = self.get_offloaded_param(model) self.outer_optimizer = torch.optim.SGD(self.cpu_model, lr=self.config.outer_lr, momentum=0.9, nesterov=True) + self._logger.debug("offload model to cpu") def sync_pseudo_gradient(self, model: nn.Module): """ Sync the pseudo gradient from the local process group to the global process group """ + if self.need_to_offload: + self._logger.debug("sync pseudo gradient") + for param_offloaded, param in zip(self.cpu_model, model.parameters()): + # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices + param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) + + # gloo does not support AVG + param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size() + dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg) - ### the whole sectione below is just a PoC. We need to benchmark and optimizer what is the most efficient: - ## do the all reduce on cpu or on gpu - ## do the outer optimizer step on cpu or on gpu + def sync_inner_model(self, model: nn.Module): + """ + Sync the inner model from the global process group to the local process group + """ - ## right now we do all reduce on cpu + if self.fsdp_sharding_strategy == ShardingStrategy.FULL_SHARD: + # here each rank has a shard of the model in memory so all rank do the sync + self._logger.debug("sync inner model") + for param_offloaded, param in zip(self.cpu_model, model.parameters()): + param.data = param_offloaded.data.to("cuda") + + elif self.fsdp_sharding_strategy in [ShardingStrategy.SHARD_GRAD_OP, ShardingStrategy.NO_SHARD]: + self._logger.debug("sync inner model") + # in shard_grad_op mode, only the local rank 0 has the model in cpu + # we first copy the model to the gpu 0 and then broadcast it to the other gpu as + # gpu to gpu is faster than cpu to gpu with nvlink + + for i, (param_offloaded, param) in enumerate(zip(self.cpu_model, model.parameters())): + # todo: we can probably overlap both comm here + if self.world_info.local_rank == 0: + self._logger.debug( + f"i: {i} shape param {param.data.shape} shape offloaded {param_offloaded.data.shape}" + ) + param.data = param_offloaded.data.to("cuda") + + dist.broadcast(tensor=param.data, src=0, group=self.elastic_device_mesh.local_pg) + + def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: + """ + Offload the model parameters to cpu + """ + offloaded_params = [] + for param in model.parameters(): + if param.requires_grad: + offloaded_param = param.data.detach().clone().to("cpu") + offloaded_param.requires_grad = True + offloaded_params.append(offloaded_param) - for param_offloaded, param in zip(self.cpu_model, model.parameters()): - # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices - param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) + return offloaded_params - if param_offloaded.grad.device == torch.device("cpu"): - # gloo does not support AVG - param_offloaded.grad = param_offloaded.grad / self.global_pg.size() - dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg) - else: - dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=self.global_pg) + def step(self, model: 
nn.Module): + """ + Step the optimizer + """ + self.sync_pseudo_gradient(model) + if self.outer_optimizer is not None: + self.outer_optimizer.step() + self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this + self.sync_inner_model(model) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 00b6c5e8..332b008f 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -19,7 +19,7 @@ ) import torch.distributed as dist from zeroband import utils -from zeroband.diloco import Diloco, DilocoConfig +from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh from zeroband.utils import get_sharding_strategy from zeroband.utils.monitor import WandbMonitor, DummyMonitor @@ -119,11 +119,14 @@ def train(config: Config): config.data.seq_length, ) + elastic_device_mesh = ElasticDeviceMesh() + model = FSDP( model, sharding_strategy=sharding_strategy, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), use_orig_params=True, + process_group=elastic_device_mesh.local_pg if config.diloco is not None else None, ) if config.train.torch_compile: @@ -131,7 +134,7 @@ def train(config: Config): logger.debug("model compiled and fsdped") if config.diloco is not None: - diloco = Diloco(config.diloco, model, sharding_strategy) + diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) # Setup optimizers inner_optimizer = torch.optim.AdamW( @@ -221,9 +224,7 @@ def train(config: Config): ) if config.diloco is not None: - diloco.sync_pseudo_gradient(model) - diloco.outer_optimizer.step() - diloco.outer_optimizer.zero_grad() # todo(sami): check if we can remove this + diloco.step(model) outer_step += 1 From 3c61fc785921bf47e22454d69ec3f5c03bae1148 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 21:07:45 +0000 Subject: [PATCH 06/15] add shard grap op --- scripts/simulate_multi_node.sh | 5 +---- src/zeroband/diloco.py | 14 +++++++++----- src/zeroband/train.py | 3 +++ 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/scripts/simulate_multi_node.sh b/scripts/simulate_multi_node.sh index 518ea21e..d6a9199f 100755 --- a/scripts/simulate_multi_node.sh +++ b/scripts/simulate_multi_node.sh @@ -60,8 +60,5 @@ do CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 & child_pids+=($!) done - -tail -f logs/log0 & -child_pids+=($!) 
- + wait diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 1af0d411..f694d48f 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -123,8 +123,12 @@ def step(self, model: nn.Module): """ Step the optimizer """ - self.sync_pseudo_gradient(model) - if self.outer_optimizer is not None: - self.outer_optimizer.step() - self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this - self.sync_inner_model(model) + # self.sync_pseudo_gradient(model) + # if self.outer_optimizer is not None: + # self.outer_optimizer.step() + # self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this + + for param in model.parameters(): + param.data = torch.zeros_like(param.data).to(param.data.device) + + # self.sync_inner_model(model) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 332b008f..552549cf 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -134,6 +134,9 @@ def train(config: Config): logger.debug("model compiled and fsdped") if config.diloco is not None: + if world_info.local_world_size == 1: + raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug") + diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) # Setup optimizers From b4d5fd78f67049289c7c312a28b2c4a990ce6e87 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 22:02:29 +0000 Subject: [PATCH 07/15] add instruction to readme --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 37aa2fbe..976264e0 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,14 @@ To check that everything is working you can do ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/debug.toml ``` +## run diloco + +To run diloco locally you can use the helper script `scripts/simulatsimulate_multi_nodee_mutl.sh` + +```bash +ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml +``` + ## run test You need a machine with a least two gpus to run the full test suite. 
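The `2 2` arguments in the `simulate_multi_node.sh` example above mean two simulated nodes with two GPUs each. Both `gpus_to_use` in the tests and `get_cuda_devices` in the script carve out the same contiguous GPU block per node; a minimal Python restatement of that mapping (illustrative sketch only — the function name below is not from the repo):

```python
def cuda_visible_devices(num_gpu_per_node: int, node_rank: int) -> str:
    # Contiguous GPU block for one simulated node, e.g. node 1 with 2 GPUs per node -> "2,3".
    start = node_rank * num_gpu_per_node
    return ",".join(str(gpu) for gpu in range(start, start + num_gpu_per_node))
```

Each simulated node then gets its own `torchrun` launch with `CUDA_VISIBLE_DEVICES` set to this string, so no two nodes share a device.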
From a1a72e367c9673fa31fd0f56d12a40942484adde Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 22:14:03 +0000 Subject: [PATCH 08/15] add tests --- tests/test_torchrun/test_train.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index 8e98716a..7bb545dd 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -28,9 +28,7 @@ def gpus_to_use(num_nodes, num_gpu, rank): return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu))) -@pytest.mark.parametrize("num_gpus", [[1, 1], [2, 1], [1, 2]]) -@pytest.mark.parametrize("config", ["debug/debug.toml", "debug/diloco.toml"]) -def test_multi_gpu(num_gpus, config): +def _test_multi_gpu(num_gpus, config, diloco: bool): num_nodes, num_gpu = num_gpus[0], num_gpus[1] processes = [] @@ -56,20 +54,12 @@ def test_multi_gpu(num_gpus, config): pytest.fail(f"Process {result} failed {result}") -@pytest.mark.parametrize("num_gpu", [1, 2]) -def test_multi_gpu_diloco(random_available_port, num_gpu): - cmd = [ - "torchrun", - f"--nproc_per_node={num_gpu}", - "--rdzv-endpoint", - f"localhost:{random_available_port}", - "src/zeroband/train.py", - "@configs/debug/diloco.toml", - "--optim.total_steps", - "50", - ] +@pytest.mark.parametrize("num_gpus", [[1, 1], [2, 1], [1, 2]]) +def test_multi_gpu(num_gpus): + _test_multi_gpu(num_gpus, "debug/debug.toml", diloco=False) - result = subprocess.run(cmd) - if result.returncode != 0: - pytest.fail(f"Process {result} failed {result.stderr}") +@pytest.mark.parametrize("num_gpus", [[1, 2], [2, 2]]) +def test_multi_gpu_diloco(num_gpus): + # we don't test 1,1 and 2,1 because 1 solo gpu failed with fsdp + _test_multi_gpu(num_gpus, "debug/diloco.toml", diloco=True) From 01bf6c9f1a23a9a4ba7d055b7c7aee1f935d713b Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 22:19:52 +0000 Subject: [PATCH 09/15] use dummy monitor for debug --- configs/debug/debug.toml | 4 +++- configs/debug/diloco.toml | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/configs/debug/debug.toml b/configs/debug/debug.toml index ae283e9e..a9bcea26 100644 --- a/configs/debug/debug.toml +++ b/configs/debug/debug.toml @@ -1,5 +1,6 @@ name_model = "debugmodel" -project = "debug" +project = "/tmp/debug" +metric_logger_type = "dummy" [train] micro_bs = 8 @@ -12,3 +13,4 @@ total_steps = 10 [data] fake_data = true + diff --git a/configs/debug/diloco.toml b/configs/debug/diloco.toml index 9283c964..16805f14 100644 --- a/configs/debug/diloco.toml +++ b/configs/debug/diloco.toml @@ -1,5 +1,6 @@ name_model = "debugmodel" -project = "debug" +project = "/tmp/debug" +metric_logger_type = "dummy" [train] micro_bs = 8 From 7db3ceace6e420e57654118c395f0e6e016676df Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 22:31:45 +0000 Subject: [PATCH 10/15] rename fake data argf --- configs/debug/debug.toml | 2 +- configs/debug/diloco.toml | 2 +- src/zeroband/train.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/debug/debug.toml b/configs/debug/debug.toml index a9bcea26..69757b76 100644 --- a/configs/debug/debug.toml +++ b/configs/debug/debug.toml @@ -12,5 +12,5 @@ warmup_steps = 10 total_steps = 10 [data] -fake_data = true +fake = true diff --git a/configs/debug/diloco.toml b/configs/debug/diloco.toml index 16805f14..fe2b459d 100644 --- a/configs/debug/diloco.toml +++ b/configs/debug/diloco.toml @@ -12,7 +12,7 @@ warmup_steps = 
10 total_steps = 10 [data] -fake_data = true +fake = true [diloco] inner_steps = 5 diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 552549cf..bad4f6b7 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -39,7 +39,7 @@ def ddp_setup(): class DataConfig(BaseConfig): seq_length: int = 1024 - fake_data: bool = False + fake: bool = False num_workers: int = 4 @@ -97,7 +97,7 @@ def train(config: Config): seq_length=config.data.seq_length, batch_size=config.train.micro_bs, num_workers=config.data.num_workers, - fake_data=config.data.fake_data, + fake_data=config.data.fake, ) model, model_config = get_model( From f021af4f813676d4db09b0a5e50401876c66c32d Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 22:41:50 +0000 Subject: [PATCH 11/15] add back tail -f --- scripts/simulate_multi_node.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/simulate_multi_node.sh b/scripts/simulate_multi_node.sh index d6a9199f..c5def4a9 100755 --- a/scripts/simulate_multi_node.sh +++ b/scripts/simulate_multi_node.sh @@ -60,5 +60,8 @@ do CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 & child_pids+=($!) done + +tail -f logs/log0 & +child_pids+=($!) wait From 293dfbda183af95816db58ab05f38e9e9c9a1360 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 22:49:33 +0000 Subject: [PATCH 12/15] add diloco peers logs --- src/zeroband/train.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index bad4f6b7..ab863aff 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -219,12 +219,17 @@ def train(config: Config): "mfu": mfu, } + if config.diloco is not None: + metrics["num_peers"] = elastic_device_mesh.global_pg.size() + if world_info.rank == 0: metric_logger.log(metrics) - logger.info( - f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}, mfu: {mfu:.2f}" - ) + log = f"step: {real_step}, loss: {loss_batch.item():.4f}, tokens_per_second: {metrics['tokens_per_second']:.2f}, mfu: {mfu:.2f}" + if config.diloco is not None: + log += f", diloco_peers: {metrics['num_peers']}" + + logger.info(log) if config.diloco is not None: diloco.step(model) From fc603ef85cd1a0dffa4beced3ae2ee49361a1672 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 22:54:02 +0000 Subject: [PATCH 13/15] fix mfu --- src/zeroband/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index ab863aff..5bb4243d 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -207,7 +207,7 @@ def train(config: Config): time_taken = time.time() - beginning_step_time tokens_per_second = config.data.seq_length * config.optim.batch_size / time_taken - mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops + mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size metrics = { "Loss": loss_batch.item(), From 146ecffd5970f0a8bf330ee1eb9cede754ac7dca Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 23:41:49 +0000 Subject: [PATCH 14/15] Add diloco docstring --- src/zeroband/diloco.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index f694d48f..d1d5b81c 100644 --- 
a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -36,6 +36,31 @@ def __init__(self): class Diloco: + """ + This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852. + + It handles the outer loop as well as the inter node communication. + + There is no VRAM overhead with this implementation as the model is outer optimizer is offloaded to cpu. + All reduce communication are also done on cpu using GLOO. + + Example usage: + + # Example usage in a training loop: + + diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh) + + for outer_step in range(num_outer_steps): + for inner_step in range(config.diloco.inner_steps): + # Regular inner training loop + optimizer.zero_grad() + loss = model(batch) + loss.backward() + optimizer.step() + + diloco.step(model) + """ + def __init__( self, config: DilocoConfig, @@ -123,12 +148,9 @@ def step(self, model: nn.Module): """ Step the optimizer """ - # self.sync_pseudo_gradient(model) - # if self.outer_optimizer is not None: - # self.outer_optimizer.step() - # self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this - - for param in model.parameters(): - param.data = torch.zeros_like(param.data).to(param.data.device) + self.sync_pseudo_gradient(model) + if self.outer_optimizer is not None: + self.outer_optimizer.step() + self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this - # self.sync_inner_model(model) + self.sync_inner_model(model) From 78af87ce582fe2c8aa771b9f4fe5826f9aba3660 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 23:42:44 +0000 Subject: [PATCH 15/15] refactor: remode ddp func --- src/zeroband/train.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 5bb4243d..39cae775 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -29,14 +29,6 @@ from zeroband.utils.logging import get_logger -def ddp_setup(): - """ - Initialize the distributed process group. - """ - init_process_group() - torch.cuda.set_device(world_info.local_rank) - - class DataConfig(BaseConfig): seq_length: int = 1024 fake: bool = False @@ -255,7 +247,8 @@ def train(config: Config): world_info = get_world_info() logger = get_logger() - ddp_setup() + init_process_group() + torch.cuda.set_device(world_info.local_rank) config = Config(**parse_argv()) logger.debug(f"config: {config.model_dump()}")
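Across patches 05 and 14, `Diloco.step` settles into: compute the pseudo-gradient, all-reduce it over the CPU/gloo group, take the outer optimizer step, and copy the result back to the GPU. The sketch below restates that flow outside the diffs as a simplified reference, not as the actual `src/zeroband/diloco.py` implementation: it assumes a gloo-backed `global_pg` process group is already initialized and that every rank holds its full parameter copy, and it omits the FSDP wrapping, `ElasticDeviceMesh` setup, and SHARD_GRAD_OP handling.

```python
# Simplified reconstruction of the DiLoCo outer step described in the diffs above.
# Assumption: `global_pg` is a gloo process group mirroring ElasticDeviceMesh.global_pg.
import torch
import torch.distributed as dist
from torch import nn


def offload_params(model: nn.Module) -> list[torch.Tensor]:
    # CPU copies of the trainable parameters act as the DiLoCo "outer" weights.
    return [
        p.data.detach().clone().to("cpu").requires_grad_(True)
        for p in model.parameters()
        if p.requires_grad
    ]


def outer_step(
    model: nn.Module,
    cpu_params: list[torch.Tensor],
    outer_opt: torch.optim.Optimizer,
    global_pg: dist.ProcessGroup,
) -> None:
    trainable = [p for p in model.parameters() if p.requires_grad]

    # 1. Pseudo-gradient: outer weights minus the locally trained weights,
    #    averaged across nodes (gloo has no AVG op, hence divide then SUM).
    for cpu_p, p in zip(cpu_params, trainable):
        cpu_p.grad = (cpu_p.data - p.data.to("cpu")) / global_pg.size()
        dist.all_reduce(cpu_p.grad, op=dist.ReduceOp.SUM, group=global_pg)

    # 2. Outer optimizer step on CPU.
    outer_opt.step()
    outer_opt.zero_grad()

    # 3. Copy the updated outer weights back into the on-GPU model.
    for cpu_p, p in zip(cpu_params, trainable):
        p.data.copy_(cpu_p.data.to(p.device))
```

In the patches themselves the outer optimizer is `torch.optim.SGD(cpu_params, lr=outer_lr, momentum=0.9, nesterov=True)`, and this step runs once every `diloco.inner_steps` inner AdamW steps, which is what keeps inter-node communication infrequent.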