From e5a8689bcc8fec21363fb05c952bf905517728ee Mon Sep 17 00:00:00 2001 From: TimoImhof <62378375+TimoImhof@users.noreply.github.com> Date: Wed, 26 Feb 2025 12:38:21 +0100 Subject: [PATCH] Upgrade Transformers to 4.48.x (#782) --- hf_transformers | 2 +- setup.py | 2 +- src/adapters/models/beit/modeling_beit.py | 81 +++- src/adapters/models/gpt2/modeling_gpt2.py | 169 +++----- src/adapters/models/llama/modeling_llama.py | 364 ++---------------- .../models/mistral/modeling_mistral.py | 348 +++-------------- 6 files changed, 211 insertions(+), 755 deletions(-) diff --git a/hf_transformers b/hf_transformers index 241c04d368..093bebcdd9 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 241c04d36867259cdf11dbb4e9d9a60f9cb65ebc +Subproject commit 093bebcdd9efea9626a0edf78f1cd8d256b1b652 diff --git a/setup.py b/setup.py index 33a2e77e0b..aa4803baf1 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "timeout-decorator", "torch", "torchvision", - "transformers~=4.47.1", + "transformers~=4.48.3", ] diff --git a/src/adapters/models/beit/modeling_beit.py b/src/adapters/models/beit/modeling_beit.py index 865fcdeae5..6c88965799 100644 --- a/src/adapters/models/beit/modeling_beit.py +++ b/src/adapters/models/beit/modeling_beit.py @@ -22,11 +22,21 @@ import torch.utils.checkpoint from torch import nn -from transformers.models.beit.modeling_beit import BeitLayer, BeitRelativePositionBias, BeitSelfAttention +from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from transformers.models.beit.modeling_beit import ( + BeitLayer, + BeitRelativePositionBias, + BeitSdpaSelfAttention, + BeitSelfAttention, +) +from transformers.utils import logging from .mixin_beit import BeitLayerAdaptersMixin, BeitSelfAttentionAdaptersMixin +logger = logging.get_logger(__name__) + + class BeitSelfAttentionWithAdapters(BeitSelfAttentionAdaptersMixin, BeitSelfAttention): def forward( self, @@ -84,6 +94,75 @@ def forward( return outputs +class BeitSdpaSelfAttentionWithAdapters(BeitSelfAttentionAdaptersMixin, BeitSdpaSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[Tuple[int]] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, " + "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + + mixed_query_layer = self.query(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # >>> START AH Changes <<< + query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) + + key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states) + (query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer) + # >>> END AH Changes <<< + + attn_bias = None + if self.relative_position_bias is not None: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + attn_bias = self.relative_position_bias( + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + if attn_bias is None: + attn_bias = relative_position_bias + else: + attn_bias += relative_position_bias + + scaling = 1 / math.sqrt(self.attention_head_size) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attn_bias, + dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=scaling, + ) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer, None + + class BeitLayerWithAdapters(BeitLayerAdaptersMixin, BeitLayer): """This corresponds to the Block class in the timm implementation.""" diff --git a/src/adapters/models/gpt2/modeling_gpt2.py b/src/adapters/models/gpt2/modeling_gpt2.py index bb6410f838..6555250c0b 100644 --- a/src/adapters/models/gpt2/modeling_gpt2.py +++ b/src/adapters/models/gpt2/modeling_gpt2.py @@ -15,12 +15,13 @@ # limitations under the License. """PyTorch OpenAI GPT-2 model.""" -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2SdpaAttention +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward from transformers.utils import logging from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_ @@ -41,6 +42,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, + **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): @@ -49,37 +51,72 @@ def forward( "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
) - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + query_states = self.q_attn(hidden_states) + key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + shape_q = (*query_states.shape[:-1], -1, self.head_dim) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) + + query_states = query_states.view(shape_q).transpose(1, 2) + key_states = key_states.view(shape_kv).transpose(1, 2) + value_states = value_states.view(shape_kv).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) + key_states = torch.cat((past_key, key_states), dim=-2) + value_states = torch.cat((past_value, value_states), dim=-2) if use_cache is True: - present = (key, value) + present = (key_states, value_states) else: present = None # >>> START AH Changes <<< - key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask) - (query,) = adjust_tensors_for_parallel(key, query) + key_states, value_states, attention_mask = self.prefix_tuning( + key_states, value_states, hidden_states, attention_mask + ) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) # >>> END AH Changes <<< - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + is_cross_attention = encoder_hidden_states is not None + is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention + + using_eager = self.config._attn_implementation == "eager" + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None): + using_eager = True + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + # Attention functions are consistent with previous equivalent attention classes, however they do not support some options + # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but + # not necessarily to eager (if mentionned options are provided). 
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if using_eager and self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn( + query_states, key_states, value_states, attention_mask, head_mask + ) else: - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + head_mask=head_mask, + dropout=self.attn_dropout.p if self.training else 0.0, + is_causal=is_causal, + **kwargs, + ) - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) @@ -90,104 +127,6 @@ def forward( return outputs # a, present, (attentions) -class GPT2SdpaAttentionWithAdapters(GPT2AttentionAdaptersMixin, GPT2SdpaAttention): - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - bsz, q_len, _ = hidden_states.size() - - # Initial attention projections - is_cross_attention = encoder_hidden_states is not None - if is_cross_attention: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`." 
- ) - - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask - else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - # Optional kv caching - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - present = None - if use_cache is True: - present = (key, value) - - # >>> START AH Changes <<< - key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask) - (query,) = adjust_tensors_for_parallel(key, query) - bsz = key.shape[0] - # >>> END AH Changes <<< - - # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA - if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=self.attn_dropout.p if self.training else 0.0, - is_causal=is_causal, - ) - - # Reshape outputs - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.embed_dim) - - # Final projection - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - return attn_output, present, None - - class GPT2BlockWithAdapters(GPT2DecoderBlockAdaptersMixin, GPT2Block): def forward( self, diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 27a0e8c8ea..9ff83f5c2e 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -18,26 +18,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""PyTorch LLaMA model.""" -import math -import warnings -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch -import torch.nn.functional as F import torch.utils.checkpoint -from torch import nn from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel -from transformers.cache_utils import Cache, StaticCache -from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, - LlamaFlashAttention2, - LlamaSdpaAttention, apply_rotary_pos_emb, - repeat_kv, + eager_attention_forward, ) +from transformers.processing_utils import Unpack from transformers.utils import logging from .mixin_llama import LlamaAttentionMixin, LlamaDecoderLayerMixin @@ -52,166 +48,31 @@ class LlamaAttentionWithAdapters(LlamaAttentionMixin, LlamaAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Loosen constraint on batch_size to allow parallel adapter composition - query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - # >>> START AH Changes <<< - query_states, key_states, value_states = match_attn_matrices_for_parallel( - query_states, key_states, value_states - ) - (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) - # >>> END AH Changes <<< - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings 
internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # >>> START AH Changes <<< - key_states, value_states, attention_mask = self.prefix_tuning( - key_states, value_states, hidden_states, attention_mask - ) - (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - bsz = key_states.shape[0] - # >>> END AH Changes <<< - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaFlashAttention2WithAdapters(LlamaAttentionMixin, LlamaFlashAttention2): - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` make" - " sure to use `sdpa` in the mean time, and open an issue at" - " https://github.com/huggingface/transformers" - ) - - output_attentions = False - bsz, q_len, _ = hidden_states.size() query_states = 
self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - # Loosen constraint on batch_size to allow parallel adapter composition - query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(-1, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2) - # >>> START AH Changes <<< query_states, key_states, value_states = match_attn_matrices_for_parallel( query_states, key_states, value_states ) (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) # >>> END AH Changes <<< - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -224,176 +85,35 @@ def forward( key_states, value_states, hidden_states, attention_mask ) (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - # Make adjustments since (parallel) prefix tuning changes the attention mask bsz = key_states.shape[0] # >>> END AH Changes <<< - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to the fact" - " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaSdpaAttentionWithAdapters(LlamaAttentionMixin, LlamaSdpaAttention): - - # Adapted from LlamaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does" - " not support `output_attentions=True`. Falling back to the manual attention implementation, but" - " specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This" - ' warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Loosen constraint on batch_size to allow parallel adapter composition - query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - # >>> START AH Changes <<< - query_states, key_states, value_states = match_attn_matrices_for_parallel( - query_states, key_states, value_states - ) - (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) - # >>> END AH Changes <<< - - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # >>> START AH Changes <<< - key_states, value_states, attention_mask = self.prefix_tuning( - key_states, value_states, hidden_states, attention_mask - ) - (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - # >>> END AH Changes <<< - - bsz = key_states.shape[0] - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value + return attn_output, attn_weights class LlamaDecoderLayerWithAdapters(LlamaDecoderLayerMixin, LlamaDecoderLayer): @@ -406,35 +126,16 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" - " `attention_mask` instead.`" - ) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -458,7 +159,4 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/adapters/models/mistral/modeling_mistral.py b/src/adapters/models/mistral/modeling_mistral.py index f9c262f06f..f8925ba821 100644 --- a/src/adapters/models/mistral/modeling_mistral.py +++ b/src/adapters/models/mistral/modeling_mistral.py @@ -18,36 +18,36 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""PyTorch Mistral model.""" -import math -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple + + +try: + from typing import Unpack +except ImportError: + from typing_extensions import Unpack import torch import torch.utils.checkpoint -from torch import nn from adapters.composition import ( adjust_tensors_for_parallel, adjust_tensors_for_parallel_, match_attn_matrices_for_parallel, ) -from transformers.cache_utils import Cache, StaticCache +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, - MistralFlashAttention2, - MistralSdpaAttention, apply_rotary_pos_emb, - repeat_kv, + eager_attention_forward, ) -from transformers.utils import is_flash_attn_2_available, logging +from transformers.utils import logging from .mixin_mistral import MistralAttentionMixin, MistralDecoderLayerMixin -if is_flash_attn_2_available(): - from transformers.modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) @@ -55,32 +55,31 @@ class MistralAttentionWithAdapters(MistralAttentionMixin, MistralAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # >>> START AH Changes <<< bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # >>> START AH Changes <<< # Loosen constraint on batch_size to allow parallel adapter composition - query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(-1, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2) query_states, key_states, value_states = match_attn_matrices_for_parallel( query_states, key_states, value_states ) - (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids) + (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) # >>> END AH Changes <<< - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -88,9 +87,6 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - 
value_states = repeat_kv(value_states, self.num_key_value_groups) - # >>> START AH Changes <<< key_states, value_states, attention_mask = self.prefix_tuning( key_states, value_states, hidden_states, attention_mask @@ -100,271 +96,35 @@ def forward( bsz = key_states.shape[0] # >>> END AH Changes <<< - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MistralFlashAttention2WithAdapters(MistralAttentionMixin, MistralFlashAttention2): - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ): - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` make" - " sure to use `sdpa` in the mean time, and open an issue at" - " https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # >>> START AH Changes <<< - # Loosen constraint on batch_size to allow parallel adapter composition - query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - query_states, key_states, value_states = match_attn_matrices_for_parallel( - query_states, key_states, value_states - ) - (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids) - # >>> END AH Changes <<< - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - 
past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1," - f" head_dim`), got {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # >>> START AH Changes <<< - key_states, value_states, attention_mask = self.prefix_tuning( - key_states, value_states, hidden_states, attention_mask - ) - (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - # Make adjustments since (parallel) prefix tuning changes the attention mask - kv_seq_len = key_states.shape[-2] - bsz = key_states.shape[0] - # >>> END AH Changes <<< - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to the fact" - " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." 
- ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MistralSdpaAttentionWithAdapters(MistralAttentionMixin, MistralSdpaAttention): - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention`" - " does not support `output_attentions=True`. Falling back to the manual attention implementation, but" - " specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This" - ' warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - # >>> START AH Changes <<< - # Loosen constraint on batch_size to allow parallel adapter composition - query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - query_states, key_states, value_states = match_attn_matrices_for_parallel( - query_states, key_states, value_states - ) - (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids) - # >>> END AH Changes <<< - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # >>> START AH Changes <<< - key_states, value_states, attention_mask = self.prefix_tuning( - key_states, value_states, hidden_states, attention_mask - ) - (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - # Make adjustments since (parallel) prefix tuning changes the attention mask - bsz = key_states.shape[0] - # >>> END AH Changes <<< - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, q_len, -1) + # >>> END AH Changes <<< attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value + return attn_output, attn_weights class MistralDecoderLayerWithAdapters(MistralDecoderLayerMixin, MistralDecoderLayer): @@ -373,31 +133,13 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ adjust_tensors_for_parallel_(hidden_states, attention_mask, position_ids) residual = hidden_states @@ -405,7 +147,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -413,6 +155,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, ) hidden_states = self.attention_adapters(hidden_states, residual, None) @@ -423,11 +167,7 @@ def forward( hidden_states = self.output_adapters(hidden_states, residual, None) outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs
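

Reference sketch (not part of the patch): the GPT-2, LLaMA, and Mistral classes above are moved off the per-backend attention subclasses (`GPT2SdpaAttention`, `LlamaSdpaAttention`/`LlamaFlashAttention2`, `MistralSdpaAttention`/`MistralFlashAttention2`) and onto the single attention-interface dispatch introduced in Transformers 4.48, while BEiT instead gains an SDPA subclass carrying the adapter hooks. Each rewritten `forward` now computes q/k/v once, applies the adapter-specific steps (`match_attn_matrices_for_parallel`, `prefix_tuning`, `adjust_tensors_for_parallel`), and then calls either `eager_attention_forward` or a backend looked up in `ALL_ATTENTION_FUNCTIONS`, falling back to eager when SDPA cannot honor `output_attentions` or `head_mask`. The snippet below is a minimal, self-contained illustration of that dispatch pattern only; `ATTENTION_FUNCTIONS` and `dispatch_attention` are illustrative stand-ins for the upstream registry and the inline branching in each rewritten `forward`, not adapters or Transformers APIs.

from typing import Callable, Dict, Optional, Tuple

import torch


def eager_attention_forward(
    module, query, key, value, attention_mask, scaling: float, dropout: float = 0.0, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Plain matmul-softmax attention over (batch, heads, seq, head_dim) tensors;
    # returns (output, weights) so callers can honor output_attentions=True.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=False)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


def sdpa_attention_forward(
    module, query, key, value, attention_mask, scaling: float, dropout: float = 0.0, **kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Fused backend; it cannot return attention weights, which is why the patch keeps
    # the warning plus eager fallback when output_attentions (or head_mask) is requested.
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None


# Illustrative stand-in for transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {"sdpa": sdpa_attention_forward}


def dispatch_attention(attn_implementation: str, output_attentions: bool) -> Callable:
    # Same branching the rewritten forward() methods perform on config._attn_implementation.
    if attn_implementation == "eager":
        return eager_attention_forward
    if attn_implementation == "sdpa" and output_attentions:
        return eager_attention_forward  # SDPA cannot expose weights -> fall back to eager
    return ATTENTION_FUNCTIONS[attn_implementation]


if __name__ == "__main__":
    # (batch, heads, seq_len, head_dim): the layout the adapter hooks operate on
    # before the final transpose/reshape back to (batch, seq_len, hidden).
    q = k = v = torch.randn(1, 4, 8, 16)
    interface = dispatch_attention("sdpa", output_attentions=False)
    out, weights = interface(None, q, k, v, attention_mask=None, scaling=16 ** -0.5)
    print(out.shape, weights)  # torch.Size([1, 8, 4, 16]) None

Because the backend is selected after the q/k/v projections and adapter hooks have run, the parallel-composition and prefix-tuning logic in this patch stays identical across the eager, SDPA, and flash attention implementations.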