Update modeling_mistral.py
fix flash attention
divyanshuaggarwal authored Nov 5, 2024
1 parent 0e18a53 commit c4b165d
Showing 1 changed file with 18 additions and 13 deletions.
src/adapters/models/mistral/modeling_mistral.py (31 changes: 18 additions & 13 deletions)
@@ -45,7 +45,8 @@


 if is_flash_attn_2_available():
-    from transformers.models.mistral.modeling_mistral import _flash_supports_window_size
+    # from transformers.models.mistral.modeling_mistral import _flash_supports_window_size
+    from transformers.modeling_flash_attention_utils import _flash_attention_forward


 logger = logging.get_logger(__name__)
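
Aside on this hunk: the old import appears to have broken because recent transformers releases no longer expose _flash_supports_window_size from transformers.models.mistral.modeling_mistral; the shared helper in transformers.modeling_flash_attention_utils now owns that capability check. A minimal sketch of a version-tolerant import, not part of the commit and purely illustrative:

# Hedged sketch, not part of the commit: a version-tolerant import that degrades
# gracefully on older transformers releases that lack the shared helper.
try:
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
except ImportError:  # transformers too old to ship the shared flash-attention helper
    _flash_attention_forward = None  # callers must then fall back to eager/SDPA attention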
@@ -173,17 +174,17 @@ def forward(
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-        )
+        # use_sliding_windows = (
+        #     _flash_supports_window_size
+        #     and getattr(self.config, "sliding_window", None) is not None
+        #     and kv_seq_len > self.config.sliding_window
+        # )

-        if not _flash_supports_window_size:
-            logger.warning_once(
-                "The current flash attention version does not support sliding window attention, for a more memory"
-                " efficient implementation make sure to upgrade flash-attn library."
-            )
+        # if not _flash_supports_window_size:
+        #     logger.warning_once(
+        #         "The current flash attention version does not support sliding window attention, for a more memory"
+        #         " efficient implementation make sure to upgrade flash-attn library."
+        #     )

         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
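
The gating removed in this hunk is redundant with a check the shared helper performs internally. A hedged sketch of that check, paraphrased from transformers.modeling_flash_attention_utils around transformers 4.43+ (exact details vary by release), with illustrative function names:

# Hedged paraphrase of the check the shared flash-attention helper runs internally:
# sliding-window attention is only requested when the installed flash_attn_func
# actually accepts a window_size argument. Function names here are illustrative.
import inspect
from typing import Optional


def flash_supports_window_size() -> bool:
    try:
        from flash_attn import flash_attn_func
    except ImportError:
        return False
    return "window_size" in inspect.signature(flash_attn_func).parameters


def should_use_sliding_windows(sliding_window: Optional[int], kv_seq_len: int) -> bool:
    # Mirrors the condition commented out above, now owned by the shared helper.
    return (
        flash_supports_window_size()
        and sliding_window is not None
        and kv_seq_len > sliding_window
    )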
@@ -257,16 +258,20 @@ def forward(
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

-        attn_output = self._flash_attention_forward(
+        attn_output = _flash_attention_forward(
             query_states,
             key_states,
             value_states,
             attention_mask,
             q_len,
+            position_ids=position_ids,
             dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
+            sliding_window=getattr(self.config, "sliding_window", None),
+            use_top_left_mask=self._flash_attn_uses_top_left_mask,
+            is_causal=self.is_causal,
         )


         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
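
For reference, a minimal sketch of calling the shared helper directly, using the transformers 4.43+ keyword names shown in the diff. The toy shapes, dtype, and sliding_window value are assumptions for illustration, and the call is skipped when flash-attn 2 or a CUDA device is unavailable:

# Hedged sketch: invoking the shared flash-attention helper with toy tensors.
# Inputs are laid out as (batch, seq_len, num_heads, head_dim), which is why the
# diff transposes query/key/value before the call.
import torch
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available() and torch.cuda.is_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

    bsz, q_len, num_heads, head_dim = 2, 16, 8, 64
    query = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    key = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    value = torch.randn(bsz, q_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask=None,      # no padding mask in this toy example
        query_length=q_len,
        is_causal=True,
        dropout=0.0,
        sliding_window=4096,      # ignored by the helper if flash-attn lacks window support
        use_top_left_mask=False,
    )
    print(attn_output.shape)  # torch.Size([2, 16, 8, 64])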
