Skip to content

Commit

Permalink
add shard grad op
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 22, 2024
1 parent 6c65322 commit 3c61fc7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
5 changes: 1 addition & 4 deletions scripts/simulate_multi_node.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,5 @@ do
CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
child_pids+=($!)
done

tail -f logs/log0 &
child_pids+=($!)


wait
14 changes: 9 additions & 5 deletions src/zeroband/diloco.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,12 @@ def step(self, model: nn.Module):
"""
Step the optimizer
"""
self.sync_pseudo_gradient(model)
if self.outer_optimizer is not None:
self.outer_optimizer.step()
self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this
self.sync_inner_model(model)
# self.sync_pseudo_gradient(model)
# if self.outer_optimizer is not None:
# self.outer_optimizer.step()
# self.outer_optimizer.zero_grad() # todo(sami): check if we can remove this

for param in model.parameters():
param.data = torch.zeros_like(param.data).to(param.data.device)

# self.sync_inner_model(model)
3 changes: 3 additions & 0 deletions src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def train(config: Config):
logger.debug("model compiled and fsdped")

if config.diloco is not None:
if world_info.local_world_size == 1:
raise ValueError("Diloco is not supported for local_world_size == 1 because of a pytorch bug")

diloco = Diloco(config.diloco, model, sharding_strategy, elastic_device_mesh)

# Setup optimizers
Expand Down

0 comments on commit 3c61fc7

Please sign in to comment.