diff --git a/src/adapters/models/mistral/modeling_mistral.py b/src/adapters/models/mistral/modeling_mistral.py
index f9c262f06..14cc47615 100644
--- a/src/adapters/models/mistral/modeling_mistral.py
+++ b/src/adapters/models/mistral/modeling_mistral.py
@@ -18,36 +18,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """PyTorch Mistral model."""
-import math
-from typing import Optional, Tuple
+from typing import Callable, Optional, Tuple, Unpack
 
 import torch
 import torch.utils.checkpoint
-from torch import nn
 
 from adapters.composition import (
     adjust_tensors_for_parallel,
     adjust_tensors_for_parallel_,
     match_attn_matrices_for_parallel,
 )
-from transformers.cache_utils import Cache, StaticCache
+from transformers.cache_utils import Cache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.mistral.modeling_mistral import (
     MistralAttention,
     MistralDecoderLayer,
-    MistralFlashAttention2,
-    MistralSdpaAttention,
     apply_rotary_pos_emb,
-    repeat_kv,
+    eager_attention_forward,
 )
-from transformers.utils import is_flash_attn_2_available, logging
+from transformers.utils import logging
 
 from .mixin_mistral import MistralAttentionMixin, MistralDecoderLayerMixin
 
 
-if is_flash_attn_2_available():
-    from transformers.modeling_flash_attention_utils import _flash_attention_forward
-
-
 logger = logging.get_logger(__name__)
 
 
@@ -55,32 +49,31 @@ class MistralAttentionWithAdapters(MistralAttentionMixin, MistralAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # >>> START AH Changes <<<
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        # >>> START AH Changes <<<
         # Loosen constraint on batch_size to allow parallel adapter composition
-        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(-1, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(-1, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         query_states, key_states, value_states = match_attn_matrices_for_parallel(
             query_states, key_states, value_states
         )
-        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
+        (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask)
         # >>> END AH Changes <<<
 
-        cos, sin = self.rotary_emb(value_states, position_ids)
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
@@ -88,283 +81,44 @@ def forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        # >>> START AH Changes <<<
-        key_states, value_states, attention_mask = self.prefix_tuning(
-            key_states, value_states, hidden_states, attention_mask
-        )
-        (query_states,) = adjust_tensors_for_parallel(key_states, query_states)
-        # Make adjustments since (parallel) prefix tuning changes the attention mask
-        bsz = key_states.shape[0]
-        # >>> END AH Changes <<<
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-
-        attn_output = attn_output.view(bsz, q_len, -1)
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-class MistralFlashAttention2WithAdapters(MistralAttentionMixin, MistralFlashAttention2):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-    ):
-        if isinstance(past_key_value, StaticCache):
-            raise ValueError(
-                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` make"
-                " sure to use `sdpa` in the mean time, and open an issue at"
-                " https://github.com/huggingface/transformers"
-            )
-
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # >>> START AH Changes <<<
-        # Loosen constraint on batch_size to allow parallel adapter composition
-        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        query_states, key_states, value_states = match_attn_matrices_for_parallel(
-            query_states, key_states, value_states
-        )
-        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
-        # >>> END AH Changes <<<
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += cache_position[0]
-
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
-            if (
-                getattr(self.config, "sliding_window", None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and cache_has_contents
-            ):
-                slicing_tokens = 1 - self.config.sliding_window
-
-                past_key = past_key_value[self.layer_idx][0]
-                past_value = past_key_value[self.layer_idx][1]
-
-                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
-                past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
-                if past_key.shape[-2] != self.config.sliding_window - 1:
-                    raise ValueError(
-                        "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
-                        f" head_dim`), got {past_key.shape}"
-                    )
-
-                if attention_mask is not None:
-                    attention_mask = attention_mask[:, slicing_tokens:]
-                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
-
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        dropout_rate = 0.0 if not self.training else self.attention_dropout
-
         # >>> START AH Changes <<<
         key_states, value_states, attention_mask = self.prefix_tuning(
             key_states, value_states, hidden_states, attention_mask
         )
         (query_states,) = adjust_tensors_for_parallel(key_states, query_states)
         # Make adjustments since (parallel) prefix tuning changes the attention mask
-        kv_seq_len = key_states.shape[-2]
         bsz = key_states.shape[0]
         # >>> END AH Changes <<<
 
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in float16 just to be sure everything works as expected.
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
+                logger.warning_once(
+                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
+                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+                )
             else:
-                target_dtype = self.q_proj.weight.dtype
+                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
-            logger.warning_once(
-                "The input hidden states seems to be silently casted in float32, this might be related to the fact"
-                " you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        attn_output = _flash_attention_forward(
+        attn_output, attn_weights = attention_interface(
+            self,
             query_states,
             key_states,
             value_states,
             attention_mask,
-            q_len,
-            position_ids=position_ids,
-            dropout=dropout_rate,
-            sliding_window=getattr(self.config, "sliding_window", None),
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-            is_causal=self.is_causal,
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-class MistralSdpaAttentionWithAdapters(MistralAttentionMixin, MistralSdpaAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention`"
-                " does not support `output_attentions=True`. Falling back to the manual attention implementation, but"
-                " specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This"
-                ' warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-            )
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # >>> START AH Changes <<<
-        # Loosen constraint on batch_size to allow parallel adapter composition
-        query_states = query_states.view(-1, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(-1, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        query_states, key_states, value_states = match_attn_matrices_for_parallel(
-            query_states, key_states, value_states
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            sliding_window=getattr(self.config, "sliding_window", None),  # main diff with Llama
+            **kwargs,
         )
-        (attention_mask, position_ids) = adjust_tensors_for_parallel(query_states, attention_mask, position_ids)
-        # >>> END AH Changes <<<
-
-        cos, sin = self.rotary_emb(value_states, position_ids)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         # >>> START AH Changes <<<
-        key_states, value_states, attention_mask = self.prefix_tuning(
-            key_states, value_states, hidden_states, attention_mask
-        )
-        (query_states,) = adjust_tensors_for_parallel(key_states, query_states)
-        # Make adjustments since (parallel) prefix tuning changes the attention mask
-        bsz = key_states.shape[0]
-        # >>> END AH Changes <<<
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
-        )
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        # >>> END AH Changes <<<
         attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
+        return attn_output, attn_weights
 
 
 class MistralDecoderLayerWithAdapters(MistralDecoderLayerMixin, MistralDecoderLayer):
@@ -373,31 +127,13 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        """
-        Args:
-            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*):
-                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
-                query_sequence_length, key_sequence_length)` if default attention is used.
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more detail.
-            use_cache (`bool`, *optional*):
-                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-                (see `past_key_values`).
-            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-                Indices depicting the position of the input sequence tokens in the sequence
-            kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
-        """
         adjust_tensors_for_parallel_(hidden_states, attention_mask, position_ids)
 
         residual = hidden_states
@@ -405,7 +141,7 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -413,6 +149,8 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            **kwargs,
         )
 
         hidden_states = self.attention_adapters(hidden_states, residual, None)
@@ -423,11 +161,7 @@ def forward(
         hidden_states = self.output_adapters(hidden_states, residual, None)
 
         outputs = (hidden_states,)
-
         if output_attentions:
             outputs += (self_attn_weights,)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
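Reviewer note: the net effect of this diff is that the separate flash-attention and SDPA subclasses collapse into a single `MistralAttentionWithAdapters.forward` that defaults to `eager_attention_forward` and otherwise looks the backend up in `ALL_ATTENTION_FUNCTIONS` keyed by `config._attn_implementation`. The snippet below is a minimal, self-contained sketch of that dispatch pattern only; `toy_eager_attention`, `toy_sdpa_attention`, and `ATTENTION_REGISTRY` are illustrative stand-ins written for this note, not the transformers API.

```python
from typing import Callable, Dict

import torch


def toy_eager_attention(query, key, value, scaling):
    # Plain softmax(Q K^T * scaling) V, standing in for `eager_attention_forward`.
    attn_weights = torch.softmax(query @ key.transpose(-2, -1) * scaling, dim=-1)
    return attn_weights @ value, attn_weights


def toy_sdpa_attention(query, key, value, scaling):
    # Fused-kernel path, standing in for the "sdpa" registry entry (scale= needs torch >= 2.1).
    out = torch.nn.functional.scaled_dot_product_attention(query, key, value, scale=scaling)
    return out, None


# Registry keyed by the configured implementation name, mirroring ALL_ATTENTION_FUNCTIONS.
ATTENTION_REGISTRY: Dict[str, Callable] = {"sdpa": toy_sdpa_attention}


def dispatch_attention(attn_implementation: str, query, key, value, scaling):
    # Same shape as the diff: default to eager, swap in the registered backend otherwise.
    attention_interface: Callable = toy_eager_attention
    if attn_implementation != "eager":
        attention_interface = ATTENTION_REGISTRY[attn_implementation]
    return attention_interface(query, key, value, scaling)


if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
    out_eager, _ = dispatch_attention("eager", q, k, v, scaling=64**-0.5)
    out_sdpa, _ = dispatch_attention("sdpa", q, k, v, scaling=64**-0.5)
    print(torch.allclose(out_eager, out_sdpa, atol=1e-4))  # backends agree numerically
```

In the actual change the selected interface additionally receives `dropout`, `scaling`, `sliding_window`, and the forwarded `**kwargs`, which this sketch omits for brevity.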