fix(pipeline): avoid allreduce for dense model (InternLM#570)
* avoid allreduce for dense model in pp

* avoid allreduce when num_expert=1
blankde authored Jan 9, 2024
1 parent 5539f9d commit 91480c5
Showing 2 changed files with 11 additions and 6 deletions.
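Both changed files apply the same guard: the MoE auxiliary loss and its pipeline allreduce are only needed when the model actually has more than one expert. A minimal sketch of the pattern follows; `maybe_allreduce_moe_loss`, `model_cfg`, and `pipeline_group` are illustrative names introduced here, not identifiers from the repository.

# Minimal sketch of the guard this commit introduces. Assumes torch.distributed
# is already initialized; "model_cfg" and "pipeline_group" stand in for
# gpc.config.model and the pipeline-parallel process group.
import torch
import torch.distributed as dist


def maybe_allreduce_moe_loss(accum_moe_loss: torch.Tensor, model_cfg, pipeline_group) -> torch.Tensor:
    # Dense models (no "num_experts" attribute, or a single expert) carry a
    # zero MoE loss, so the cross-stage allreduce is skipped entirely.
    if hasattr(model_cfg, "num_experts") and model_cfg.num_experts > 1:
        dist.all_reduce(accum_moe_loss, group=pipeline_group)
    return accum_moe_loss

For a dense run this removes one collective per step on each pipeline group without changing the reported loss, since the accumulated MoE loss is just the zero tensor created in the fallback branch below.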
2 changes: 1 addition & 1 deletion internlm/core/scheduler/no_pipeline_scheduler.py
@@ -122,7 +122,7 @@ def _train_one_batch(
         self._call_hooks("after_criterion", loss)
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= scale_loss
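For reference, the selection of the per-batch MoE loss that the hunk above modifies can be paraphrased as below; `compute_moe_loss` and `cfg` are illustrative names, while the real code lives inline in `_train_one_batch` and in the pipeline `_forward_step` methods shown further down.

# Paraphrase of the moe_loss expression shown in the hunk above; "cfg" stands
# in for gpc.config.
import torch


def compute_moe_loss(moe_losses, cfg, scale_loss):
    # MoE models: weighted sum of the per-layer auxiliary losses.
    if hasattr(cfg.model, "num_experts") and cfg.model.num_experts > 1:
        moe_loss = sum(moe_losses) * cfg.loss.moe_loss_coeff
    else:
        # Dense path: keep a zero tensor so downstream accumulation still works.
        moe_loss = torch.tensor(
            0.0, device=torch.cuda.current_device(), dtype=cfg.model.get("dtype")
        )
    return moe_loss / scale_loss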
15 changes: 10 additions & 5 deletions internlm/core/scheduler/pipeline_scheduler.py
@@ -308,7 +308,7 @@ def _forward_step(
 
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= self.num_microbatches
@@ -445,7 +445,9 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
         if accum_loss is not None:
             accum_loss += accum_moe_loss
@@ -647,7 +649,9 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
                 comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
         if accum_loss is not None:
             accum_loss += accum_moe_loss
@@ -855,7 +859,7 @@ def _forward_step(self, engine, chunk_id):
 
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= self.num_microbatches
@@ -1387,7 +1391,8 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
         else:
             output, label = (None, None)
 
-        dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
         accum_moe_loss = self._accum_moe_loss
 
         accum_loss = self._accum_loss
