Commit message:

* test
* testing tensor cache x)
* fix logger
* condition cache class usage
* update opset for beit and data2vec vision and skip flattened/fused pkv (e.g. gpt bigcode)
* style
* fix args patcher
* fix modernbert testing
* adapt to new whisper returned generation length
* fix is_causal in transformers
* fix modernbert failures
* style
* traceable cache
* use pkv index
* add version guard and clean up other model patcher version guards
* patch sdpa attention in optimum for now
* remove modernbert condition
* style
* fix MistralModelPatcher
* correctly patch gpt2 in vision encoder decoder
* patch sdpa attention forward everywhere
* fix gpt2 cross attention in seq2seq as well
* moved traceable cache to a file for simplicity of model patcher
* Apply suggestions from code review
* style
* fix
Commit d1bcdf7 (1 parent: 50531a4). Showing 7 changed files with 272 additions and 115 deletions.
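The "traceable cache" items in the commit message above are about making the KV cache usable under `torch.jit.trace`, which the ONNX export path relies on. As a minimal sketch of the constraint, using a hypothetical toy layer rather than anything from this commit: tracing only records tensor operations, so whatever cache the exported forward touches has to boil down to plain tensor reads and writes rather than registered module state.

```python
import torch


class ToySelfAttention(torch.nn.Module):
    """Hypothetical toy layer: appends new key/value states to the cached ones."""

    def forward(self, new_key, new_value, past_key, past_value):
        key = torch.cat([past_key, new_key], dim=-2)
        value = torch.cat([past_value, new_value], dim=-2)
        return key, value


layer = ToySelfAttention()
example_inputs = (
    torch.randn(1, 4, 1, 8),  # new key for one generated token (batch, heads, seq, head_dim)
    torch.randn(1, 4, 1, 8),  # new value
    torch.randn(1, 4, 3, 8),  # cached keys for 3 past tokens
    torch.randn(1, 4, 3, 8),  # cached values
)
# Tracing records only the tensor ops above; the "cache" enters and leaves the
# graph as plain tensors, which is what the export path needs.
traced = torch.jit.trace(layer, example_inputs)
print(traced(*example_inputs)[0].shape)  # torch.Size([1, 4, 4, 8])
```

The `TraceableCache` class added below keeps the `transformers` cache interface but drops the `nn.Module` base (mirroring https://github.com/huggingface/transformers/pull/35873), so subclasses are plain Python objects that hold tensors.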
New file, 95 added lines:

```python
import logging
from typing import Any, Dict, Optional, Tuple

import torch


logger = logging.getLogger(__name__)


# Simply removing the nn.Module, same as in https://github.com/huggingface/transformers/pull/35873
class TraceableCache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.
        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    # Deprecated in favor of `get_max_cache_shape` because we want to be specific about what we mean by "max_length".
    # Previously, some cache objects did not have a "max_length" (SlidingWindowCache or SinkCache) because the cache
    # object technically handles an infinite amount of tokens. What we really need to check in the codebase is the max
    # capacity of certain cache instances, so the naming is changed to be more explicit.
    def get_max_length(self) -> Optional[int]:
        logger.warning_once(
            "`get_max_length()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
            "Calling `get_max_length()` will raise an error from v4.48"
        )
        return self.get_max_cache_shape()

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length of the cache plus the length of the new inputs is larger than the
        # maximum cache length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_cache_shape()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            if self.key_cache[layer_idx] != []:
                device = self.key_cache[layer_idx].device
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            if self.value_cache[layer_idx] != []:
                device = self.value_cache[layer_idx].device
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
```
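As a usage sketch, not part of this commit: roughly how a concrete cache could subclass `TraceableCache`. The `SimpleDynamicCache` name and its grow-by-concatenation layout are assumptions modeled on `transformers`' `DynamicCache`, and since the diff does not show the new module's path, `TraceableCache` is assumed to already be in scope.

```python
import torch
from typing import Any, Dict, List, Optional, Tuple

# TraceableCache is the class added above; its module path is not shown in this
# diff, so assume it has been imported into the current namespace.


class SimpleDynamicCache(TraceableCache):
    """Hypothetical subclass that grows the cache by concatenation, DynamicCache-style."""

    def __init__(self):
        super().__init__()
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.key_cache) <= layer_idx:
            # first call for this layer: store the states as-is
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # later calls: concatenate along the sequence dimension (dim=-2)
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        # dynamic cache: no fixed capacity
        return None
```

With `get_max_cache_shape()` returning `None`, the inherited `get_usable_length` reports the full cached length, since a cache without a size limit never evicts:

```python
cache = SimpleDynamicCache()
key = torch.randn(1, 4, 3, 8)    # (batch, num_heads, seq_len, head_dim)
value = torch.randn(1, 4, 3, 8)
cache.update(key, value, layer_idx=0)
print(cache.get_seq_length())      # 3
print(cache.get_usable_length(1))  # 3: no capacity limit, the whole cache is usable
```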