diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py
index 6e194257..dade9076 100644
--- a/internlm/core/scheduler/base_scheduler.py
+++ b/internlm/core/scheduler/base_scheduler.py
@@ -122,6 +122,19 @@ def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
             'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
         )
 
+    def cal_max_seqlen(self, data: dict):
+        if isinstance(data, dict) and "cu_seqlens" in data:
+            # Without changing the modeling interface (for backward compatibility), we compute
+            # 'max_seqlen' in advance to avoid overlap being interrupted by blocking .item() calls.
+            if isinstance(data["cu_seqlens"], list):
+                cu_seqlens = data["cu_seqlens"][0]
+            else:
+                cu_seqlens = data["cu_seqlens"]
+
+            cu_seqlens = cu_seqlens.squeeze(0)
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            data.update({"max_seqlen": max_seqlen})
+
 
 class SchedulerHook(ABC):
     """
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 79a6f625..cd816e06 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -81,6 +81,8 @@ def _load_accum_batch(self, data: Any, label: Any):
             _data.pop("cu_seqlens")
             _data.pop("indexes")
 
+        self.cal_max_seqlen(_data)
+
         return _data, _label
 
     def _train_one_batch(
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 5b864ff4..7ea6a691 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -222,6 +222,8 @@ def load_micro_batch(self):
             micro_batch_data.pop("cu_seqlens")
             micro_batch_data.pop("indexes")
 
+        self.cal_max_seqlen(micro_batch_data)
+
         micro_batch_data["label"] = micro_batch_label
         self.microbatch_offset += self.bsz_stride
 
diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index d1770538..ed69a061 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -153,12 +153,8 @@ def __init__(self, dim: int, base=10000, scale_base=0, device=None):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
@@ -183,14 +179,23 @@ def _update_cos_sin_cache(self, x, indexes):
 
     def forward(self, qkv: torch.Tensor, **kwargs):
         if kwargs.get("indexes", None) is not None:
-            return self._forward(qkv, kwargs.pop("indexes"))
+            return self._forward(qkv, kwargs.pop("indexes"), kwargs.get("max_seqlen", None))
         if kwargs.get("inference_params", None) is not None:
             return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
         else:
             return self._eval_forward(qkv)
 
-    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv, indexes)
+    def _cal_max_seqlen(self, indexes, max_seqlen=None):
+        if not isinstance(indexes, int):
+            if max_seqlen is None:  # We try to avoid calling .item() in fwd/bwd.
+                max_seqlen = indexes.max().item() + 1
+        else:
+            max_seqlen = indexes + 1  # eval_forward
+        return max_seqlen
+
+    def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        max_seqlen = self._cal_max_seqlen(indexes, max_seqlen)
+        self._update_cos_sin_cache(qkv, max_seqlen)
         if self.scale is None:
             return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
         else:
@@ -221,9 +226,9 @@ def _eval_forward(self, qkv, seqlen_offset=0):
             self._sin_k_cached[seqlen_offset:],
         )
 
-    def _single_forward(self, x, indexes=0):
+    def _single_forward(self, x, indexes=0, **kwargs):
         assert self.scale is None
-        self._update_cos_sin_cache(x, indexes)
+        self._update_cos_sin_cache(x, self._cal_max_seqlen(indexes, kwargs.get("max_seqlen", None)))
         x = x[None, ...]
         ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
         return ret
@@ -275,12 +280,8 @@ def _update(self, seqlen, x):
             self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
             self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
-    def _update_cos_sin_cache(self, x, indexes):
+    def _update_cos_sin_cache(self, x, seqlen):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
-        if not isinstance(indexes, int):
-            seqlen = indexes.max().item() + 1
-        else:
-            seqlen = indexes + 1  # eval_forward
         if seqlen <= self.max_position_embeddings:
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index a47a5cdd..1029f0da 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -380,7 +380,9 @@ def __init__(
 
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
+    ):
         # attention_mask: compute attention on the places where the value is 1
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
@@ -401,7 +403,14 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None
 
         for _, block in enumerate(self.blocks):
             hidden_states = block(
diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index df6c7a84..cac2e434 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -448,7 +448,9 @@ def __init__(
                     setattr(param, IS_TENSOR_PARALLEL, True)
         self.parallel_output = parallel_output
 
-    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
+    def forward(
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
+    ):
         # attention_mask: compute attention on the places where the value is 1
         # old condition may fail when use shared embedding
         if gpc.is_pipeline_first_stage():
@@ -470,7 +472,14 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
             assert len(indexes) == 1
             # The indexes are used to indicate the actual position IDs of each token in the packed input.
             indexes = indexes[0]
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
+
+        if cu_seqlens is not None:
+            if "max_seqlen" not in kwargs:
+                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            else:
+                max_seqlen = kwargs.pop("max_seqlen")
+        else:
+            max_seqlen = None
 
         moe_losses = []
         for _, block in enumerate(self.blocks):
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index a20b61dd..d162cc82 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -104,10 +104,10 @@ def get_batch_size(data):
         return data.size(0)
     elif isinstance(data, (list, tuple)):
         if isinstance(data[0], dict):
-            return data[0][list(data[0].keys())[0]].size(0)
+            return data[0]["input_ids"].size(0)
         return data[0].size(0)
     elif isinstance(data, dict):
-        return data[list(data.keys())[0]].size(0)
+        return data["input_ids"].size(0)
 
 
 def check_data_is_packed(data):
diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py
index 6f66a152..91561a18 100644
--- a/tests/test_core/utils.py
+++ b/tests/test_core/utils.py
@@ -33,7 +33,7 @@ def __init__(self, start, end, model_type=None, embedding=False):
         self.embedding = embedding
 
     def forward(
-        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
     ):  # pylint: disable=W0613
         if self.model_type != "torch" and self.part[0] != 0:
             input_ids = hidden_states
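
Reviewer note: the snippet below replays, outside the training stack, what the new `BaseScheduler.cal_max_seqlen` hook computes, so the intent is easy to verify in isolation. The helper name `demo_cal_max_seqlen` and the toy batch are illustrative only and are not part of this patch.

```python
import torch


def demo_cal_max_seqlen(data: dict) -> None:
    # Mirrors BaseScheduler.cal_max_seqlen: derive the longest packed sequence
    # length from "cu_seqlens" while the micro batch is being staged, so the
    # blocking .item() (a device-to-host sync) runs before fwd/bwd rather than
    # inside it, where it would interrupt compute/communication overlap.
    if isinstance(data, dict) and "cu_seqlens" in data:
        cu_seqlens = data["cu_seqlens"]
        if isinstance(cu_seqlens, list):
            cu_seqlens = cu_seqlens[0]
        cu_seqlens = cu_seqlens.squeeze(0)
        data["max_seqlen"] = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()


# Two sequences of lengths 3 and 5 packed into one sample:
# boundaries at token offsets 0, 3 and 8.
batch = {"cu_seqlens": torch.tensor([[0, 3, 8]])}
demo_cal_max_seqlen(batch)
assert batch["max_seqlen"] == 5
```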
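On the consumer side, `RotaryEmbedding._cal_max_seqlen` and the model `forward`s prefer the value handed down via kwargs and fall back to the old synchronizing path only when it is absent (e.g. for callers that bypass the schedulers). A minimal sketch of that fallback, using the illustrative name `demo_fallback`:

```python
import torch


def demo_fallback(indexes, max_seqlen=None):
    # Mirrors RotaryEmbedding._cal_max_seqlen: use the pre-computed value when
    # present; otherwise pay the .item() sync exactly as the old code did.
    if not isinstance(indexes, int):
        if max_seqlen is None:
            max_seqlen = indexes.max().item() + 1
    else:
        max_seqlen = indexes + 1  # eval_forward passes a plain int offset
    return max_seqlen


position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4])
assert demo_fallback(position_ids) == 5                # legacy path: one sync
assert demo_fallback(position_ids, max_seqlen=5) == 5  # new path: no sync
assert demo_fallback(7) == 8                           # eval path (int offset)
```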