diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index df6c7a84..cac2e434 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -448,7 +448,9 @@ def __init__(
                 setattr(param, IS_TENSOR_PARALLEL, True)
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
+    ):
         # attention_mask: compute attention on the places where the value is 1
         # old condition may fail when use shared embedding
         if gpc.is_pipeline_first_stage():
@@ -470,7 +472,14 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None
 
         moe_losses = []
         for _, block in enumerate(self.blocks):
diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py
index 6f66a152..91561a18 100644
--- a/tests/test_core/utils.py
+++ b/tests/test_core/utils.py
@@ -33,7 +33,7 @@ def __init__(self, start, end, model_type=None, embedding=False):
         self.embedding = embedding
 
     def forward(
-        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
     ):  # pylint: disable=W0613
         if self.model_type != "torch" and self.part[0] != 0:
             input_ids = hidden_states
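
For reference, below is a minimal standalone sketch of the new max_seqlen resolution logic introduced in forward(): a caller-supplied max_seqlen kwarg takes precedence over the value derived from cu_seqlens, which otherwise forces a .item() device-to-host sync whenever cu_seqlens lives on the GPU. The helper name resolve_max_seqlen is hypothetical and used only for illustration; it is not part of the patch.

import torch

def resolve_max_seqlen(cu_seqlens, **kwargs):
    # Hypothetical helper mirroring the patched branch: prefer an explicit
    # caller-provided max_seqlen, else derive it from the packed cu_seqlens.
    if cu_seqlens is not None:
        if "max_seqlen" not in kwargs:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        else:
            max_seqlen = kwargs.pop("max_seqlen")
    else:
        max_seqlen = None
    return max_seqlen

# Two packed sequences of lengths 3 and 5.
cu_seqlens = torch.tensor([0, 3, 8])
print(resolve_max_seqlen(cu_seqlens))                # 5 (derived)
print(resolve_max_seqlen(cu_seqlens, max_seqlen=8))  # 8 (caller override)
print(resolve_max_seqlen(None))                      # None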