From bf684ad85201978475d6dd521427e7f1798ee30c Mon Sep 17 00:00:00 2001
From: divyanshuaggarwal
Date: Sun, 24 Nov 2024 02:45:26 +0530
Subject: [PATCH] fix flash attention for mistral. (#758)

## fix flash attention for mistral.

This pull request fixes the flash attention forward method for Mistral.
---
 .../models/mistral/modeling_mistral.py | 21 ++++++---------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/src/adapters/models/mistral/modeling_mistral.py b/src/adapters/models/mistral/modeling_mistral.py
index 0ed865559..f9c262f06 100644
--- a/src/adapters/models/mistral/modeling_mistral.py
+++ b/src/adapters/models/mistral/modeling_mistral.py
@@ -45,7 +45,7 @@
 
 
 if is_flash_attn_2_available():
-    from transformers.models.mistral.modeling_mistral import _flash_supports_window_size
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
 
 logger = logging.get_logger(__name__)
@@ -173,18 +173,6 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
-
-        if not _flash_supports_window_size:
-            logger.warning_once(
-                "The current flash attention version does not support sliding window attention, for a more memory"
-                " efficient implementation make sure to upgrade flash-attn library."
-            )
-
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
             cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
@@ -257,14 +245,17 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
             query_states,
             key_states,
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
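
For reference, a minimal sketch (not part of the patch) of the call pattern the updated forward now relies on: sliding-window handling is delegated to the shared `_flash_attention_forward` helper via its `sliding_window` kwarg instead of the removed `_flash_supports_window_size` check. It assumes a transformers release that ships `transformers.modeling_flash_attention_utils`, flash-attn 2.x installed, and a CUDA device; the tensor shapes and sliding-window size are illustrative, not taken from the patch.

```python
# Sketch of the new call pattern, standalone (outside the model class).
# Assumes flash-attn 2.x and a CUDA device; fp16 inputs of shape
# (batch, seq_len, num_heads, head_dim), which is what the helper expects.
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

bsz, q_len, num_heads, head_dim = 2, 16, 8, 64  # illustrative shapes
q = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

# The helper applies (optionally windowed) causal flash attention; in the
# patched forward, `sliding_window` comes from the model config.
attn_output = _flash_attention_forward(
    q,
    k,
    v,
    None,      # attention_mask (None for an unpadded batch)
    q_len,     # query_length
    dropout=0.0,
    sliding_window=4096,     # illustrative window size
    use_top_left_mask=False,
    is_causal=True,
)
print(attn_output.shape)  # (bsz, q_len, num_heads, head_dim)
```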