From c4b165d3f423ad176b96056296e3b5847599298b Mon Sep 17 00:00:00 2001 From: divyanshuaggarwal Date: Tue, 5 Nov 2024 13:07:12 +0530 Subject: [PATCH 1/3] Update modeling_mistral.py fix flash attention --- .../models/mistral/modeling_mistral.py | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/adapters/models/mistral/modeling_mistral.py b/src/adapters/models/mistral/modeling_mistral.py index 0ed865559..2ab31a8e3 100644 --- a/src/adapters/models/mistral/modeling_mistral.py +++ b/src/adapters/models/mistral/modeling_mistral.py @@ -45,7 +45,8 @@ if is_flash_attn_2_available(): - from transformers.models.mistral.modeling_mistral import _flash_supports_window_size + # from transformers.models.mistral.modeling_mistral import _flash_supports_window_size + from transformers.modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -173,17 +174,17 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - ) + # use_sliding_windows = ( + # _flash_supports_window_size + # and getattr(self.config, "sliding_window", None) is not None + # and kv_seq_len > self.config.sliding_window + # ) - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory" - " efficient implementation make sure to upgrade flash-attn library." - ) + # if not _flash_supports_window_size: + # logger.warning_once( + # "The current flash attention version does not support sliding window attention, for a more memory" + # " efficient implementation make sure to upgrade flash-attn library." + # ) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute @@ -257,16 +258,20 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = self._flash_attention_forward( + attn_output = _flash_attention_forward( query_states, key_states, value_states, attention_mask, q_len, + position_ids=position_ids, dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, + sliding_window=getattr(self.config, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, ) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) From cfd6ad8ab0b5cbb95be0e6b11e607c29cdf4164f Mon Sep 17 00:00:00 2001 From: divyanshuaggarwal Date: Tue, 5 Nov 2024 07:58:48 +0000 Subject: [PATCH 2/3] code quality --- src/adapters/models/mistral/modeling_mistral.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/adapters/models/mistral/modeling_mistral.py b/src/adapters/models/mistral/modeling_mistral.py index 2ab31a8e3..d2afa07ce 100644 --- a/src/adapters/models/mistral/modeling_mistral.py +++ b/src/adapters/models/mistral/modeling_mistral.py @@ -45,7 +45,6 @@ if is_flash_attn_2_available(): - # from transformers.models.mistral.modeling_mistral import _flash_supports_window_size from transformers.modeling_flash_attention_utils import _flash_attention_forward @@ -174,17 +173,6 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # use_sliding_windows = ( - # _flash_supports_window_size - # and getattr(self.config, "sliding_window", None) is not None - # and kv_seq_len > self.config.sliding_window - # ) - - # if not _flash_supports_window_size: - # logger.warning_once( - # "The current flash attention version does not support sliding window attention, for a more memory" - # " efficient implementation make sure to upgrade flash-attn library." - # ) if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute From 6e3f527f5d3589f0297c21c53f50e9068ea5cd50 Mon Sep 17 00:00:00 2001 From: divyanshuaggarwal Date: Tue, 5 Nov 2024 08:07:06 +0000 Subject: [PATCH 3/3] code quality --- src/adapters/models/mistral/modeling_mistral.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/adapters/models/mistral/modeling_mistral.py b/src/adapters/models/mistral/modeling_mistral.py index d2afa07ce..f9c262f06 100644 --- a/src/adapters/models/mistral/modeling_mistral.py +++ b/src/adapters/models/mistral/modeling_mistral.py @@ -173,7 +173,6 @@ def forward( cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 @@ -259,7 +258,6 @@ def forward( is_causal=self.is_causal, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output)