Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix flash attention for mistral. #758

Merged
merged 3 commits into from
Nov 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions src/adapters/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@


if is_flash_attn_2_available():
from transformers.models.mistral.modeling_mistral import _flash_supports_window_size
from transformers.modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -173,18 +173,6 @@ def forward(
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

use_sliding_windows = (
_flash_supports_window_size
and getattr(self.config, "sliding_window", None) is not None
and kv_seq_len > self.config.sliding_window
)

if not _flash_supports_window_size:
logger.warning_once(
"The current flash attention version does not support sliding window attention, for a more memory"
" efficient implementation make sure to upgrade flash-attn library."
)

if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
Expand Down Expand Up @@ -257,14 +245,17 @@ def forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

attn_output = self._flash_attention_forward(
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
position_ids=position_ids,
dropout=dropout_rate,
use_sliding_windows=use_sliding_windows,
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
Expand Down
Loading