Skip to content

Commit

Permalink
fix
Browse files: browse the repository at this point in the history
  • Loading branch information
SolenoidWGT committed Dec 29, 2023
1 parent 1d217eb commit e920872
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
13 changes: 11 additions & 2 deletions internlm/model/modeling_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ def __init__(
setattr(param, IS_TENSOR_PARALLEL, True)
self.parallel_output = parallel_output

def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
def forward(
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
):
# attention_mask: compute attention on the places where the value is 1
# the old condition may fail when using a shared embedding
if gpc.is_pipeline_first_stage():
Expand All @@ -470,7 +472,14 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

if cu_seqlens is not None:
if "max_seqlen" not in kwargs:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
else:
max_seqlen = kwargs.pop("max_seqlen")
else:
max_seqlen = None

moe_losses = []
for _, block in enumerate(self.blocks):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, start, end, model_type=None, embedding=False):
self.embedding = embedding

def forward(
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
): # pylint: disable=W0613
if self.model_type != "torch" and self.part[0] != 0:
input_ids = hidden_states
Expand Down

0 comments on commit e920872

Please sign in to comment.