Skip to content

Commit

Permalink
refactor: use only global pg in diloco + first failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 26, 2024
1 parent df1616c commit 294c36d
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 14 deletions.
11 changes: 5 additions & 6 deletions src/zeroband/diloco.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def __init__(
config: DilocoConfig,
model: nn.Module,
fsdp_sharding_strategy: ShardingStrategy,
elastic_device_mesh: ElasticDeviceMesh,
global_pg: dist.ProcessGroup,
):
self.config = config
self.fsdp_sharding_strategy = fsdp_sharding_strategy
self.elastic_device_mesh = elastic_device_mesh
self.global_pg = global_pg

self._logger = get_logger()
self.world_info = get_world_info()
Expand All @@ -93,13 +93,12 @@ def sync_pseudo_gradient(self, model: nn.Module):
"""
self._logger.debug("sync pseudo gradient")
for param_offloaded, param in zip(self.param_list_cpu, model.parameters()):
# todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices
param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)

# gloo does not support AVG
param_offloaded.grad = param_offloaded.grad / self.elastic_device_mesh.global_pg.size()
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.elastic_device_mesh.global_pg)
# todo maybe do async here
param_offloaded.grad = param_offloaded.grad / self.global_pg.size()
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=self.global_pg)
# todo async here

def sync_inner_model(self, model: nn.Module):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def train(config: Config):
if world_info.local_world_size == 1:
raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")

diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)
diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh.global_pg)

# Setup optimizers
inner_optimizer = torch.optim.AdamW(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_dist/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ def random_available_port():
@pytest.fixture()
def dist_environment() -> callable:
@contextmanager
def dist_environment(random_available_port, local_rank=0, world_size=1):
def dist_environment(random_available_port, local_rank=0, world_size=1, local_world_size=1):
with mock.patch.dict(
os.environ,
{
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"LOCAL_WORLD_SIZE": str(local_world_size),
"RANK": str(local_rank),
"MASTER_ADDR": "localhost",
"MASTER_PORT": str(random_available_port),
"ZERO_BAND_LOG_LEVEL": "DEBUG",
},
):
try:
Expand Down
5 changes: 0 additions & 5 deletions tests/test_dist/test_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,15 @@
import torch
import pytest


import os
import multiprocessing


@pytest.mark.parametrize("world_size", [2])
def test_all_reduce(world_size, random_available_port, dist_environment):
def all_reduce(rank: int, world_size: int):
with dist_environment(random_available_port, local_rank=rank, world_size=world_size):
print(f"os.environ['LOCAL_RANK'] {os.environ['WORLD_SIZE']}")
data = (rank + 1) * torch.ones(10, 10).to("cuda")
print(data.mean())
dist.all_reduce(data, op=dist.ReduceOp.SUM)
print(data.mean())
assert data.mean() == sum([i + 1 for i in range(world_size)])

processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)]
Expand Down
59 changes: 58 additions & 1 deletion tests/test_dist/test_diloco.py
Original file line number Diff line number Diff line change
@@ -1 +1,58 @@
"""test Diloco. Need 4 gpus to run this tests"""
"""test Diloco."""

import multiprocessing
import pytest

import torch
import torch.distributed as dist
from torch.distributed.fsdp import ShardingStrategy

from zeroband.diloco import Diloco, DilocoConfig


@pytest.mark.parametrize("world_size", [2]) # [1, 2])
def test_diloco_all_reduce(world_size, random_available_port, dist_environment):
"""
In this test we manually create a inner model and a outer model where we control the weight:
inner has weight: (rank + 1) / 2
outer has weight: (rank + 1)
since we know the world_size we can predict the results of the all reduce of the pseudo gradient and therefore test
if it is done correclty.
"""

def all_reduce(rank: int, world_size: int):
with dist_environment(random_available_port, local_rank=rank, world_size=world_size):
diloco_config = DilocoConfig(inner_steps=10)

model = torch.nn.Linear(10, 10)

# init param to rank + 1
for param in model.parameters():
param.data = (rank + 1) * torch.ones_like(param.data).to("cuda")

global_pg = dist.new_group(backend="gloo")
diloco = Diloco(diloco_config, model, ShardingStrategy.FULL_SHARD, global_pg)

# simulate inner model updates
for param in model.parameters():
param.data = (rank + 1) / 2 * torch.ones_like(param.data).to("cuda")

diloco.sync_pseudo_gradient(model)

for param in diloco.param_list_cpu:
print(f"param.grad.mean() {param.grad.mean()}")
target = (
torch.ones_like(param.grad)
* sum([(rank + 1) - (rank + 1) / 2 for rank in range(world_size)])
/ world_size
)
assert param.grad.mean() == target.mean()

processes = [multiprocessing.Process(target=all_reduce, args=(rank, world_size)) for rank in range(world_size)]
for p in processes:
p.start()
for p in processes:
p.join()
if p.exitcode != 0:
pytest.fail(f"Process {p.pid} failed with exit code {p.exitcode}")

0 comments on commit 294c36d

Please sign in to comment.