From 91480c5b63a3d6eaa119ce44e6a4a1998dc77d9b Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Tue, 9 Jan 2024 10:34:22 +0800
Subject: [PATCH] fix(pipeline): avoid allreduce for dense model (#570)

* avoid allreduce for dense model in pp

* avoid allreduce when num_expert=1
---
 internlm/core/scheduler/no_pipeline_scheduler.py |  2 +-
 internlm/core/scheduler/pipeline_scheduler.py    | 15 ++++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 79a6f625..f6cabbcf 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -122,7 +122,7 @@ def _train_one_batch(
                 self._call_hooks("after_criterion", loss)
             moe_loss = (
                 sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-                if hasattr(gpc.config.model, "num_experts")
+                if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
             moe_loss /= scale_loss
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 5b864ff4..2c7e0df8 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -308,7 +308,7 @@ def _forward_step(
 
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= self.num_microbatches
@@ -445,7 +445,9 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
                 comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
         if accum_loss is not None:
             accum_loss += accum_moe_loss
@@ -647,7 +649,9 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
                 comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
         output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
-        dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
         if accum_loss is not None:
             accum_loss += accum_moe_loss
@@ -855,7 +859,7 @@ def _forward_step(self, engine, chunk_id):
 
         moe_loss = (
             sum(moe_losses) * gpc.config.loss.moe_loss_coeff
-            if hasattr(gpc.config.model, "num_experts")
+            if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
         moe_loss /= self.num_microbatches
@@ -1387,7 +1391,8 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo
         else:
             output, label = (None, None)
 
-        dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
+        if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
+            dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 
         accum_moe_loss = self._accum_moe_loss
         accum_loss = self._accum_loss
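Note (illustrative, not part of the patch): before this change every scheduler issued a dist.all_reduce on the accumulated MoE loss even for dense models, where that tensor is only the zero-initialized placeholder. The sketch below restates the guard as a standalone helper; the helper name and the model_cfg/pipeline_group parameters are hypothetical, while the check mirrors hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1 from the diff.

# Minimal sketch of the guard pattern applied by this patch. All names here are
# illustrative; the real code reads gpc.config.model and uses
# gpc.get_group(ParallelMode.PIPELINE) as the process group.
import torch
import torch.distributed as dist


def reduce_accum_moe_loss(accum_moe_loss: torch.Tensor, model_cfg, pipeline_group):
    """All-reduce the accumulated MoE auxiliary loss across pipeline ranks, MoE models only."""
    if getattr(model_cfg, "num_experts", 1) > 1:
        # MoE model: each pipeline stage holds a local MoE loss, so sum them across the group.
        dist.all_reduce(accum_moe_loss, group=pipeline_group)
    # Dense model: accum_moe_loss keeps its zero initialization and no collective is
    # launched, saving one all-reduce per step on the pipeline group.
    return accum_moe_loss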