feat(modeling): avoid calling item() in fwd/bwd #6

Open · wants to merge 7 commits into base: develop
13 changes: 13 additions & 0 deletions internlm/core/scheduler/base_scheduler.py
@@ -122,6 +122,19 @@ def _call_engine_criterion(engine: Engine, outputs: Any, labels: Any):
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
)

def cal_max_seqlen(self, data: dict):
if isinstance(data, dict) and "cu_seqlens" in data:
# Without breaking backward compatibility (BC) of the modeling interface, we calculate 'max_seqlen'
# in advance here, so overlap is not interrupted by .item() calls (which force a device sync) in fwd/bwd.
if isinstance(data["cu_seqlens"], list):
cu_seqlens = data["cu_seqlens"][0]
else:
cu_seqlens = data["cu_seqlens"]

cu_seqlens = cu_seqlens.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
data.update({"max_seqlen": max_seqlen})


class SchedulerHook(ABC):
"""
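The `cal_max_seqlen` helper added above performs the one remaining host-side sync while the micro-batch is still being prepared, before any forward/backward work is launched. A minimal sketch of what it computes, with a made-up packed batch (values are illustrative only):

```python
import torch

# cu_seqlens stores cumulative token offsets of the packed samples,
# so the per-sample lengths are the consecutive differences.
cu_seqlens = torch.tensor([[0, 5, 12, 16]])   # hypothetical: three packed samples
data = {"cu_seqlens": cu_seqlens}

cu = data["cu_seqlens"].squeeze(0)            # tensor([0, 5, 12, 16])
max_seqlen = (cu[1:] - cu[:-1]).max().item()  # 7, computed once, outside fwd/bwd
data.update({"max_seqlen": max_seqlen})       # later forwarded to the model via kwargs
```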
2 changes: 2 additions & 0 deletions internlm/core/scheduler/no_pipeline_scheduler.py
@@ -81,6 +81,8 @@ def _load_accum_batch(self, data: Any, label: Any):
_data.pop("cu_seqlens")
_data.pop("indexes")

self.cal_max_seqlen(_data)

return _data, _label

def _train_one_batch(
2 changes: 2 additions & 0 deletions internlm/core/scheduler/pipeline_scheduler.py
@@ -222,6 +222,8 @@ def load_micro_batch(self):
micro_batch_data.pop("cu_seqlens")
micro_batch_data.pop("indexes")

self.cal_max_seqlen(micro_batch_data)

micro_batch_data["label"] = micro_batch_label
self.microbatch_offset += self.bsz_stride

31 changes: 16 additions & 15 deletions internlm/model/embedding.py
@@ -153,12 +153,8 @@ def __init__(self, dim: int, base=10000, scale_base=0, device=None):
self._cos_k_cached = None
self._sin_k_cached = None

def _update_cos_sin_cache(self, x, indexes):
def _update_cos_sin_cache(self, x, seqlen):
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
if not isinstance(indexes, int):
seqlen = indexes.max().item() + 1
else:
seqlen = indexes + 1 # eval_forward
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
@@ -183,14 +179,23 @@ def _update_cos_sin_cache(self, x, indexes):

def forward(self, qkv: torch.Tensor, **kwargs):
if kwargs.get("indexes", None) is not None:
return self._forward(qkv, kwargs.pop("indexes"))
return self._forward(qkv, kwargs.pop("indexes"), kwargs.get("max_seqlen", None))
if kwargs.get("inference_params", None) is not None:
return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
else:
return self._eval_forward(qkv)

def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
self._update_cos_sin_cache(qkv, indexes)
def _cal_max_seqlen(self, indexes, max_seqlen=None):
if not isinstance(indexes, int):
if max_seqlen is None:  # We try to avoid calling .item() in fwd/bwd; compute it here only as a fallback.
max_seqlen = indexes.max().item() + 1
else:
max_seqlen = indexes + 1 # eval_forward
return max_seqlen

def _forward(self, qkv: torch.Tensor, indexes=0, max_seqlen=None) -> Tuple[torch.Tensor, torch.Tensor]:
max_seqlen = self._cal_max_seqlen(indexes, max_seqlen)
self._update_cos_sin_cache(qkv, max_seqlen)
if self.scale is None:
return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
else:
@@ -221,9 +226,9 @@ def _eval_forward(self, qkv, seqlen_offset=0):
self._sin_k_cached[seqlen_offset:],
)

def _single_forward(self, x, indexes=0):
def _single_forward(self, x, indexes=0, **kwargs):
assert self.scale is None
self._update_cos_sin_cache(x, indexes)
self._update_cos_sin_cache(x, self._cal_max_seqlen(indexes, kwargs.get("max_seqlen", None)))
x = x[None, ...]
ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
return ret
@@ -275,12 +280,8 @@ def _update(self, seqlen, x):
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

def _update_cos_sin_cache(self, x, indexes):
def _update_cos_sin_cache(self, x, seqlen):
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
if not isinstance(indexes, int):
seqlen = indexes.max().item() + 1
else:
seqlen = indexes + 1 # eval_forward
if seqlen <= self.max_position_embeddings:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
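With this change the rotary cache update receives a plain sequence length instead of the `indexes` tensor. The sketch below (hypothetical values, not part of the diff) contrasts the two paths through `_cal_max_seqlen`: when the scheduler already supplied `max_seqlen`, no device-to-host copy is needed, and `.item()` survives only as a fallback.

```python
import torch

indexes = torch.arange(16)       # packed position ids; a CUDA tensor in real runs

# Fallback path (previous behaviour): every forward pays a host sync.
seqlen_from_sync = indexes.max().item() + 1

# Preferred path (this PR): the scheduler passes max_seqlen as a plain int,
# so the cos/sin cache check compares against a Python number only.
seqlen_from_kwargs = 16          # e.g. kwargs.get("max_seqlen")

assert seqlen_from_sync == seqlen_from_kwargs
```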
13 changes: 11 additions & 2 deletions internlm/model/modeling_internlm.py
@@ -380,7 +380,9 @@ def __init__(

self.parallel_output = parallel_output

def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
def forward(
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
):
# attention_mask: compute attention on the places where the value is 1
if hasattr(self, "embedding"):
hidden_states = self.embedding(input_ids)
@@ -401,7 +403,14 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

if cu_seqlens is not None:
if "max_seqlen" not in kwargs:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
else:
max_seqlen = kwargs.pop("max_seqlen")
else:
max_seqlen = None

for _, block in enumerate(self.blocks):
hidden_states = block(
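Both language-model classes now resolve the sequence length the same way: prefer a `max_seqlen` arriving through `**kwargs`, and recompute it from `cu_seqlens` only when it is missing. A condensed, hypothetical restatement of that fallback (the helper name below is not part of the diff):

```python
from typing import Optional

import torch


def resolve_max_seqlen(cu_seqlens: Optional[torch.Tensor], kwargs: dict) -> Optional[int]:
    """Mirrors the branch in forward(): skip the .item() sync if the scheduler already did it."""
    if cu_seqlens is None:
        return None
    if "max_seqlen" in kwargs:
        return kwargs.pop("max_seqlen")                      # precomputed by the scheduler
    return (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()   # legacy path, forces a sync


# e.g. resolve_max_seqlen(torch.tensor([0, 5, 12, 16]), {"max_seqlen": 7}) -> 7
```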
13 changes: 11 additions & 2 deletions internlm/model/modeling_moe.py
@@ -448,7 +448,9 @@ def __init__(
setattr(param, IS_TENSOR_PARALLEL, True)
self.parallel_output = parallel_output

def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
def forward(
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
):
# attention_mask: compute attention on the places where the value is 1
# old condition may fail when use shared embedding
if gpc.is_pipeline_first_stage():
@@ -470,7 +472,14 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N
assert len(indexes) == 1
# The indexes are used to indicate the actual position IDs of each token in the packed input.
indexes = indexes[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

if cu_seqlens is not None:
if "max_seqlen" not in kwargs:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
else:
max_seqlen = kwargs.pop("max_seqlen")
else:
max_seqlen = None

moe_losses = []
for _, block in enumerate(self.blocks):
4 changes: 2 additions & 2 deletions internlm/utils/common.py
@@ -104,10 +104,10 @@ def get_batch_size(data):
return data.size(0)
elif isinstance(data, (list, tuple)):
if isinstance(data[0], dict):
return data[0][list(data[0].keys())[0]].size(0)
return data[0]["input_ids"].size(0)
return data[0].size(0)
elif isinstance(data, dict):
return data[list(data.keys())[0]].size(0)
return data["input_ids"].size(0)


def check_data_is_packed(data):
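A plausible motivation for keying `get_batch_size` on `"input_ids"` rather than on whichever key happens to come first: with this PR the batch dicts can carry non-tensor entries such as `max_seqlen` (a plain int with no `.size(0)`), so relying on dict ordering is fragile. A small illustration with hypothetical shapes:

```python
import torch

batch = {
    "max_seqlen": 2048,                                    # plain int, no .size()
    "input_ids": torch.zeros(4, 2048, dtype=torch.long),   # (batch, seq_len)
}

# batch[list(batch.keys())[0]] would hit "max_seqlen" here and fail on .size(0);
# indexing "input_ids" yields the batch size regardless of insertion order.
assert batch["input_ids"].size(0) == 4
```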
2 changes: 1 addition & 1 deletion tests/test_core/utils.py
@@ -33,7 +33,7 @@ def __init__(self, start, end, model_type=None, embedding=False):
self.embedding = embedding

def forward(
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None, **kwargs
): # pylint: disable=W0613
if self.model_type != "torch" and self.part[0] != 0:
input_ids = hidden_states