Upgrade Transformers to 4.48.x (#782)
TimoImhof authored Feb 26, 2025
1 parent 463222e commit e5a8689
Showing 6 changed files with 211 additions and 755 deletions.
2 changes: 1 addition & 1 deletion hf_transformers
Submodule hf_transformers updated 1918 files
2 changes: 1 addition & 1 deletion setup.py
@@ -61,7 +61,7 @@
"timeout-decorator",
"torch",
"torchvision",
"transformers~=4.47.1",
"transformers~=4.48.3",
]
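
For context, the `~=` ("compatible release") specifier keeps installs inside the 4.48.x series: 4.48.3 and later patch releases are accepted, 4.49.0 is not. A quick sketch (not part of this commit) using the `packaging` library to confirm that behaviour:

```python
# Sketch (not part of this commit): behaviour of the "~=" compatible-release
# specifier used in setup.py, checked with the `packaging` library.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet("~=4.48.3")  # equivalent to ">=4.48.3, ==4.48.*"

for candidate in ["4.47.1", "4.48.2", "4.48.3", "4.48.9", "4.49.0"]:
    print(candidate, Version(candidate) in spec)
# -> only 4.48.3 and 4.48.9 are accepted
```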


81 changes: 80 additions & 1 deletion src/adapters/models/beit/modeling_beit.py
@@ -22,11 +22,21 @@
import torch.utils.checkpoint
from torch import nn

from transformers.models.beit.modeling_beit import BeitLayer, BeitRelativePositionBias, BeitSelfAttention
from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
from transformers.models.beit.modeling_beit import (
BeitLayer,
BeitRelativePositionBias,
BeitSdpaSelfAttention,
BeitSelfAttention,
)
from transformers.utils import logging

from .mixin_beit import BeitLayerAdaptersMixin, BeitSelfAttentionAdaptersMixin


logger = logging.get_logger(__name__)


class BeitSelfAttentionWithAdapters(BeitSelfAttentionAdaptersMixin, BeitSelfAttention):
def forward(
self,
@@ -84,6 +94,75 @@ def forward(
return outputs


class BeitSdpaSelfAttentionWithAdapters(BeitSelfAttentionAdaptersMixin, BeitSdpaSelfAttention):
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
relative_position_bias: Optional["BeitRelativePositionBias"] = None,
interpolate_pos_encoding: bool = False,
resolution: Optional[Tuple[int]] = None,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not "
"support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, "
"but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
relative_position_bias=relative_position_bias,
interpolate_pos_encoding=interpolate_pos_encoding,
resolution=resolution,
)

mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

# >>> START AH Changes <<<
query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)

key_layer, value_layer, _ = self.prefix_tuning(key_layer, value_layer, hidden_states)
(query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer)
# >>> END AH Changes <<<

attn_bias = None
if self.relative_position_bias is not None:
height, width = resolution
window_size = (height // self.config.patch_size, width // self.config.patch_size)
attn_bias = self.relative_position_bias(
window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
)

# Add shared relative position bias if provided.
if relative_position_bias is not None:
if attn_bias is None:
attn_bias = relative_position_bias
else:
attn_bias += relative_position_bias

scaling = 1 / math.sqrt(self.attention_head_size)
context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attn_bias,
dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=scaling,
)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer, None
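
The SDPA path above folds the relative position bias into `attn_mask` and scales scores by `1/sqrt(attention_head_size)`. As a sanity check, here is a minimal, standalone sketch (made-up shapes, not part of this commit) showing that `scaled_dot_product_attention` with an additive float mask matches the eager `softmax(QK^T * scale + bias) V` computation:

```python
# Standalone sanity check (made-up shapes, not part of this commit):
# SDPA with an additive float mask equals eager softmax attention with a bias.
import math

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 16, 32
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)
bias = torch.randn(1, heads, seq, seq)  # stands in for a relative position bias

scale = 1 / math.sqrt(head_dim)

# Eager reference: softmax(Q K^T * scale + bias) V
scores = q @ k.transpose(-1, -2) * scale + bias
eager_out = torch.softmax(scores, dim=-1) @ v

# SDPA with the bias passed as an additive attention mask
sdpa_out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=bias, dropout_p=0.0, is_causal=False, scale=scale
)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # expected: True
```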


class BeitLayerWithAdapters(BeitLayerAdaptersMixin, BeitLayer):
"""This corresponds to the Block class in the timm implementation."""

169 changes: 54 additions & 115 deletions src/adapters/models/gpt2/modeling_gpt2.py
@@ -15,12 +15,13 @@
# limitations under the License.
"""PyTorch OpenAI GPT-2 model."""

from typing import Optional, Tuple, Union
from typing import Callable, Optional, Tuple, Union

import torch
import torch.utils.checkpoint

from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2SdpaAttention
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, eager_attention_forward
from transformers.utils import logging

from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
@@ -41,6 +42,7 @@ def forward(
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
@@ -49,37 +51,72 @@
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
query_states = self.q_attn(hidden_states)
key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
shape_q = (*query_states.shape[:-1], -1, self.head_dim)
shape_kv = (*key_states.shape[:-1], -1, self.head_dim)

query_states = query_states.view(shape_q).transpose(1, 2)
key_states = key_states.view(shape_kv).transpose(1, 2)
value_states = value_states.view(shape_kv).transpose(1, 2)

if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
key_states = torch.cat((past_key, key_states), dim=-2)
value_states = torch.cat((past_value, value_states), dim=-2)

if use_cache is True:
present = (key, value)
present = (key_states, value_states)
else:
present = None

# >>> START AH Changes <<<
key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
(query,) = adjust_tensors_for_parallel(key, query)
key_states, value_states, attention_mask = self.prefix_tuning(
key_states, value_states, hidden_states, attention_mask
)
(query_states,) = adjust_tensors_for_parallel(key_states, query_states)
# >>> END AH Changes <<<

if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
is_cross_attention = encoder_hidden_states is not None
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention

using_eager = self.config._attn_implementation == "eager"
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
using_eager = True
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
# Attention functions are consistent with previous equivalent attention classes, however they do not support some options
# (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
# not necessarily to eager (if mentioned options are provided).
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

if using_eager and self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(
query_states, key_states, value_states, attention_mask, head_mask
)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
head_mask=head_mask,
dropout=self.attn_dropout.p if self.training else 0.0,
is_causal=is_causal,
**kwargs,
)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

@@ -90,104 +127,6 @@ def forward(
return outputs # a, present, (attentions)
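
The rewritten forward above replaces the dedicated SDPA subclass with a dispatch through `ALL_ATTENTION_FUNCTIONS`, falling back to `eager_attention_forward` when SDPA cannot honour `output_attentions` or a `head_mask`. A simplified, standalone sketch of that dispatch pattern (hypothetical registry and helper names, not the actual Transformers implementation):

```python
# Simplified sketch of the dispatch pattern above (hypothetical registry and helper
# names; the real registry is transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS).
from typing import Callable, Dict, Optional

import torch
import torch.nn.functional as F


def eager_attention(q, k, v, mask: Optional[torch.Tensor] = None):
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    if mask is not None:
        scores = scores + mask
    weights = scores.softmax(dim=-1)
    return weights @ v, weights  # eager can expose the attention weights


def sdpa_attention(q, k, v, mask: Optional[torch.Tensor] = None):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out, None  # SDPA cannot return the attention weights


ATTENTION_FUNCTIONS: Dict[str, Callable] = {"eager": eager_attention, "sdpa": sdpa_attention}


def pick_attention(impl: str, output_attentions: bool, head_mask) -> Callable:
    # Fall back to eager when SDPA cannot honour the requested options.
    if impl == "sdpa" and (output_attentions or head_mask is not None):
        return ATTENTION_FUNCTIONS["eager"]
    return ATTENTION_FUNCTIONS.get(impl, eager_attention)
```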


class GPT2SdpaAttentionWithAdapters(GPT2AttentionAdaptersMixin, GPT2SdpaAttention):
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)

bsz, q_len, _ = hidden_states.size()

# Initial attention projections
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

# Optional kv caching
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

present = None
if use_cache is True:
present = (key, value)

# >>> START AH Changes <<<
key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
(query,) = adjust_tensors_for_parallel(key, query)
bsz = key.shape[0]
# >>> END AH Changes <<<

# Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False

attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_dropout.p if self.training else 0.0,
is_causal=is_causal,
)

# Reshape outputs
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.embed_dim)

# Final projection
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

return attn_output, present, None
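
Both GPT-2 attention forwards above append cached keys and values from `layer_past` along the sequence dimension before optionally re-caching them in `present`. A minimal sketch (made-up shapes, not part of this commit) of how that cache grows during generation:

```python
# Minimal sketch (made-up shapes, not part of this commit): how the layer_past
# key/value cache grows along the sequence dimension during generation.
import torch

batch, heads, head_dim = 1, 12, 64
past_key = torch.empty(batch, heads, 0, head_dim)    # empty cache before step 0
past_value = torch.empty(batch, heads, 0, head_dim)

for step in range(3):
    new_key = torch.randn(batch, heads, 1, head_dim)       # one new token per step
    new_value = torch.randn(batch, heads, 1, head_dim)
    past_key = torch.cat((past_key, new_key), dim=-2)      # append along the seq dim
    past_value = torch.cat((past_value, new_value), dim=-2)
    present = (past_key, past_value)                        # what use_cache=True returns
    print(step, past_key.shape)  # sequence length grows: 1, 2, 3
```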


class GPT2BlockWithAdapters(GPT2DecoderBlockAdaptersMixin, GPT2Block):
def forward(
self,