Skip to content

Commit

Permalink
fix gpt2
Browse files Browse the repository at this point in the history
  • Loading branch information
TimoImhof committed Jan 13, 2025
1 parent 23b5a5f commit 6794a38
Showing 1 changed file with 54 additions and 115 deletions.
169 changes: 54 additions & 115 deletions src/adapters/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

from typing import Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import torch
import torch.utils.checkpoint

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2SdpaAttention
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
Expand All @@ -41,6 +42,7 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
Expand All @@ -49,37 +51,72 @@ def forward(
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
query_states = self.q_attn(hidden_states)
key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
shape_q = (*query_states.shape[:-1], -1, self.head_dim)
shape_kv = (*key_states.shape[:-1], -1, self.head_dim)

query_states = query_states.view(shape_q).transpose(1, 2)
key_states = key_states.view(shape_kv).transpose(1, 2)
value_states = value_states.view(shape_kv).transpose(1, 2)

if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
key_states = torch.cat((past_key, key_states), dim=-2)
value_states = torch.cat((past_value, value_states), dim=-2)

if use_cache is True:
present = (key, value)
present = (key_states, value_states)
else:
present = None

# >>> START AH Changes <<<
key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
(query,) = adjust_tensors_for_parallel(key, query)
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
# >>> END AH Changes <<<

if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
is_cross_attention = encoder_hidden_states is not None
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

using_eager = self.config._attn_implementation == "eager"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
using_eager = True
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
# Attention functions are consistent with previous equivalent attention classes, however they do not support some options
# (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
# not necessarily to eager (if mentionned options are provided).
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

if using_eager and self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(
query_states, key_states, value_states, attention_mask, head_mask
)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
head_mask=head_mask,
dropout=self.attn_dropout.p if self.training else 0.0,
is_causal=is_causal,
**kwargs,
)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

Expand All @@ -90,104 +127,6 @@ def forward(
return outputs # a, present, (attentions)


class GPT2SdpaAttentionWithAdapters(GPT2AttentionAdaptersMixin, GPT2SdpaAttention):
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)

bsz, q_len, _ = hidden_states.size()

# Initial attention projections
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

# Optional kv caching
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

present = None
if use_cache is True:
present = (key, value)

# >>> START AH Changes <<<
key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
(query,) = adjust_tensors_for_parallel(key, query)
bsz = key.shape[0]
# >>> END AH Changes <<<

# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_dropout.p if self.training else 0.0,
is_causal=is_causal,
)

# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.embed_dim)

# Final projection
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

return attn_output, present, None


class GPT2BlockWithAdapters(GPT2DecoderBlockAdaptersMixin, GPT2Block):
def forward(
self,
Expand Down

0 comments on commit 6794a38

Please sign in to comment.