[Backend support] Allow num_logits_to_keep as Tensor + add flag (#35757)

* support

* Update modeling_utils.py

* style

* most models

* Other models

* fix-copies

* tests + generation utils
Cyrilvallez authored Jan 23, 2025
1 parent 8736e91 commit d3af76d
Showing 62 changed files with 603 additions and 315 deletions.
6 changes: 3 additions & 3 deletions src/transformers/generation/candidate_generator.py
@@ -129,9 +129,9 @@ def __init__(
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)

# Remove potential default "num_logits_to_keep" key
if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep():
del assistant_kwargs["num_logits_to_keep"]
# Remove potential default "logits_to_keep" key
if "logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_logits_to_keep():
del assistant_kwargs["logits_to_keep"]

if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
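For readers outside the repo, here is a minimal, self-contained sketch of the pattern this hunk relies on (toy class names, not the Transformers source): probe `forward()` for a `logits_to_keep` parameter via `inspect.signature`, then drop the kwarg when the assistant model does not accept it.

```python
import inspect

import torch
from torch import nn


class WithLogitsToKeep(nn.Module):
    def forward(self, input_ids: torch.Tensor, logits_to_keep: int = 0):
        return input_ids


class WithoutLogitsToKeep(nn.Module):
    def forward(self, input_ids: torch.Tensor):
        return input_ids


def supports_logits_to_keep(model: nn.Module) -> bool:
    # Same idea as _supports_logits_to_keep(): inspect the forward() signature
    # instead of introducing a dedicated model attribute.
    return "logits_to_keep" in inspect.signature(model.forward).parameters


assistant_kwargs = {"logits_to_keep": 1, "use_cache": True}
if "logits_to_keep" in assistant_kwargs and not supports_logits_to_keep(WithoutLogitsToKeep()):
    del assistant_kwargs["logits_to_keep"]

print(assistant_kwargs)  # {'use_cache': True}
```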
24 changes: 12 additions & 12 deletions src/transformers/generation/utils.py
@@ -1780,12 +1780,12 @@ def _prepare_cache_for_generation(
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)

def _supports_num_logits_to_keep(self) -> bool:
def _supports_logits_to_keep(self) -> bool:
"""
Return True if the current model supports the keyword argument `num_logits_to_keep` in forward()
Return True if the current model supports the keyword argument `logits_to_keep` in forward()
to save memory. Checking it in this way allows us to avoid using a new model attribute.
"""
return "num_logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())
return "logits_to_keep" in set(inspect.signature(self.forward).parameters.keys())

def _prepare_special_tokens(
self,
@@ -2066,11 +2066,11 @@ def generate(
input_ids_length=input_ids_length,
)

# If the model supports `num_logits_to_keep` in forward(), set it to 1 to avoid computing the whole
# If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
# logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
# dynamically overrides this value as it can need more than the last token logits
if self._supports_num_logits_to_keep() and "num_logits_to_keep" not in model_kwargs:
model_kwargs["num_logits_to_keep"] = 1
if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
model_kwargs["logits_to_keep"] = 1

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

@@ -4236,8 +4236,8 @@ def _assisted_decoding(
)

model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
if "num_logits_to_keep" in model_inputs:
model_inputs["num_logits_to_keep"] = candidate_length + 1
if "logits_to_keep" in model_inputs:
model_inputs["logits_to_keep"] = candidate_length + 1

# 2.2. Run a forward pass on the candidate sequence
# prepare variable output controls (note: some models won't accept all output controls)
@@ -4575,7 +4575,7 @@ def _split_model_inputs(
# ModelOutput object.
# bool should not be split but replicated for each split
bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"]
keys_to_ignore = ["cache_position", "encoder_outputs", "logits_to_keep"]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

num_hidden_layers = config.get_text_config().num_hidden_layers
@@ -4595,10 +4595,10 @@ def _split_model_inputs(
data_split_list = [
{**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
]
# num_logits_to_keep should be replicated for each split, similar to bool values
if "num_logits_to_keep" in model_input:
# logits_to_keep should be replicated for each split, similar to bool values
if "logits_to_keep" in model_input:
data_split_list = [
{**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list
{**data_split, "logits_to_keep": model_input["logits_to_keep"]} for data_split in data_split_list
]

# Convert each dictionary in the list to an object of the inferred class
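As a rough, illustrative sketch of the generation-side flow above (plain functions standing in for `generate()` and `_assisted_decoding()`, not the actual API): the default keeps only the last token's logits, while assisted decoding requests `candidate_length + 1` rows so every draft token can be verified.

```python
def prepare_generation_kwargs(supports_logits_to_keep: bool, model_kwargs: dict) -> dict:
    # generate(): default to a single row of logits when the model supports slicing
    if supports_logits_to_keep and "logits_to_keep" not in model_kwargs:
        model_kwargs["logits_to_keep"] = 1
    return model_kwargs


def prepare_assisted_inputs(model_inputs: dict, candidate_length: int) -> dict:
    # _assisted_decoding(): verification needs logits for every draft token plus one
    if "logits_to_keep" in model_inputs:
        model_inputs["logits_to_keep"] = candidate_length + 1
    return model_inputs


print(prepare_generation_kwargs(True, {}))                # {'logits_to_keep': 1}
print(prepare_assisted_inputs({"logits_to_keep": 1}, 4))  # {'logits_to_keep': 5}
```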
9 changes: 9 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1292,6 +1292,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# `config.base_model_tp_plan` during `post_init`.
_tp_plan = None

# This flag signals that the model can be used as an efficient backend in TGI and vLLM.
# In practice, it means the model supports attention interface functions, fully passes the kwargs
# through all modules up to the Attention layer, and can slice logits with a Tensor.
_supports_attention_backend = False

@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
@@ -5188,6 +5193,10 @@ def get_compiled_call(self, compile_config: CompileConfig):
self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
return self._compiled_call

@classmethod
def is_backend_compatible(cls):
return cls._supports_attention_backend


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
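A toy sketch of the new opt-in flag (class names are made up; only the flag and classmethod mirror the diff): subclasses declare `_supports_attention_backend = True` once they satisfy the backend contract, and `is_backend_compatible()` simply reports it.

```python
class ToyPreTrainedModel:
    # Defaults to False, matching the conservative default added in modeling_utils.py
    _supports_attention_backend = False

    @classmethod
    def is_backend_compatible(cls) -> bool:
        return cls._supports_attention_backend


class BackendReadyModel(ToyPreTrainedModel):
    _supports_attention_backend = True


print(ToyPreTrainedModel.is_backend_compatible())  # False
print(BackendReadyModel.is_backend_compatible())   # True
```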
28 changes: 18 additions & 10 deletions src/transformers/models/aria/modeling_aria.py
@@ -37,6 +37,7 @@
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.import_utils import is_torch_available
from ..auto import AutoModel, AutoModelForCausalLM
from .configuration_aria import AriaConfig, AriaTextConfig
@@ -708,6 +709,7 @@ class AriaPreTrainedModel(PreTrainedModel):
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
_supports_attention_backend = False

def _init_weights(self, module):
std = self.config.initializer_range
@@ -1168,6 +1170,7 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1183,7 +1186,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
@@ -1193,10 +1196,12 @@ def forward(
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
@@ -1239,7 +1244,8 @@ def forward(

hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])

loss = None
if labels is not None:
@@ -1324,8 +1330,9 @@ class AriaCausalLMOutputWithPast(ModelOutput):
Whether to output hidden states.
return_dict (`bool`, *optional*):
Whether to return a `ModelOutput` object.
num_logits_to_keep (`int`, *optional*, defaults to 0):
Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`.
Otherwise, slice according to the 1D tensor in the sequence length dimension.
cache_position (`torch.LongTensor`, *optional*):
Cache positions.
**loss_kwargs:
@@ -1426,6 +1433,7 @@ def get_image_features(
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
def forward(
@@ -1442,7 +1450,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
@@ -1552,7 +1560,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
logits_to_keep=logits_to_keep,
)

logits = outputs[0]
@@ -1584,7 +1592,7 @@ def prepare_inputs_for_generation(
pixel_mask=None,
attention_mask=None,
cache_position=None,
num_logits_to_keep=None,
logits_to_keep=None,
**kwargs,
):
model_inputs = self.language_model.prepare_inputs_for_generation(
Expand All @@ -1593,7 +1601,7 @@ def prepare_inputs_for_generation(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
logits_to_keep=logits_to_keep,
**kwargs,
)

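The slicing rule introduced in every forward() above is small enough to check in isolation. A hedged, standalone sketch (toy shapes, not the Aria model) covering the three cases: a positive int, the `0` special case, and a 1D index tensor.

```python
import torch

hidden_states = torch.randn(2, 6, 8)          # (batch, seq_len, hidden)
lm_head = torch.nn.Linear(8, 32, bias=False)  # toy vocabulary of 32


def compute_logits(logits_to_keep):
    # Same expression as in the diff: an int keeps the last N positions
    # (0 keeps everything, since slice(0, None) covers the full range),
    # while a tensor is used directly as indices along the sequence dimension.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return lm_head(hidden_states[:, slice_indices, :])


print(compute_logits(1).shape)                     # torch.Size([2, 1, 32])
print(compute_logits(0).shape)                     # torch.Size([2, 6, 32])
print(compute_logits(torch.tensor([0, 5])).shape)  # torch.Size([2, 2, 32])
```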
17 changes: 11 additions & 6 deletions src/transformers/models/aria/modular_aria.py
@@ -45,6 +45,7 @@
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.import_utils import is_torch_available
from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from ..llama.configuration_llama import LlamaConfig
@@ -1222,6 +1223,8 @@ def _init_weights(self, module):


class AriaPreTrainedModel(LlamaPreTrainedModel):
_supports_attention_backend = False

def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
@@ -1301,8 +1304,9 @@ class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
Whether to output hidden states.
return_dict (`bool`, *optional*):
Whether to return a `ModelOutput` object.
num_logits_to_keep (`int`, *optional*, defaults to 0):
Calculate logits for the last `num_logits_to_keep` tokens, or all `input_ids` if `0`.
logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
If an `int`, calculate logits for the last `logits_to_keep` tokens, or all `input_ids` if `0`.
Otherwise, slice according to the 1D tensor in the sequence length dimension.
cache_position (`torch.LongTensor`, *optional*):
Cache positions.
**loss_kwargs:
@@ -1403,6 +1407,7 @@ def get_image_features(
image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask)
return image_features

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
def forward(
@@ -1419,7 +1424,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**loss_kwargs,
) -> Union[Tuple, AriaCausalLMOutputWithPast]:
@@ -1529,7 +1534,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
logits_to_keep=logits_to_keep,
)

logits = outputs[0]
@@ -1561,7 +1566,7 @@ def prepare_inputs_for_generation(
pixel_mask=None,
attention_mask=None,
cache_position=None,
num_logits_to_keep=None,
logits_to_keep=None,
**kwargs,
):
model_inputs = self.language_model.prepare_inputs_for_generation(
Expand All @@ -1570,7 +1575,7 @@ def prepare_inputs_for_generation(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
logits_to_keep=logits_to_keep,
**kwargs,
)

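A rough sketch of what a `deprecate_kwarg`-style shim does (this is not the transformers implementation, only the remapping idea): calls that still pass `num_logits_to_keep` are warned and forwarded to the new `logits_to_keep` name until the old name is removed in 4.50.

```python
import functools
import warnings


def remap_kwarg(old_name: str, new_name: str, version: str):
    """Forward a deprecated keyword argument to its new name with a warning."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in v{version}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            return func(*args, **kwargs)

        return wrapper

    return decorator


@remap_kwarg("num_logits_to_keep", "logits_to_keep", version="4.50")
def forward(logits_to_keep=0):
    return logits_to_keep


print(forward(num_logits_to_keep=1))  # warns, then returns 1
```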
19 changes: 12 additions & 7 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -41,6 +41,7 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils.deprecation import deprecate_kwarg
from ...utils.import_utils import (
is_causal_conv1d_available,
is_mamba_2_ssm_available,
@@ -1466,6 +1467,7 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model

@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(BAMBA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -1481,7 +1483,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
@@ -1491,10 +1493,12 @@
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int` or `None`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
`input_ids`. Only last token logits are needed for generation, and calculating them only for that token
can save memory, which becomes pretty significant for long sequences.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
@@ -1537,7 +1541,8 @@ def forward(

hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])

loss = None
if labels is not None:
@@ -1602,7 +1607,7 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": self.config.num_logits_to_keep,
"logits_to_keep": self.config.num_logits_to_keep,
"cache_position": cache_position,
}
)
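Finally, a speculative usage sketch for the packed-tensor case mentioned in the new docstrings (the sequence lengths and shapes here are invented): with batch and sequence flattened into one dimension, a 1D tensor of last-token positions can be passed as `logits_to_keep`.

```python
import torch

# three packed sequences of lengths 3, 2 and 4 flattened into one row of length 9
seq_lengths = torch.tensor([3, 2, 4])
last_positions = torch.cumsum(seq_lengths, dim=0) - 1  # tensor([2, 4, 8])

hidden_states = torch.randn(1, 9, 16)                  # (1, packed_len, hidden)
lm_head = torch.nn.Linear(16, 50, bias=False)          # toy vocabulary of 50

# Tensor indices are applied directly along the sequence dimension,
# so only the final position of each packed sequence produces logits.
logits = lm_head(hidden_states[:, last_positions, :])
print(logits.shape)  # torch.Size([1, 3, 50])
```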