From 624bb952170d921d202e10093ef3ae3da1c540f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Thu, 22 Feb 2024 23:03:31 +0100 Subject: [PATCH 01/12] hf transformers --- hf_transformers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hf_transformers b/hf_transformers index a7cab3c283..a0857740c0 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit a7cab3c283312b8d4de5df3bbe719971e24f4281 +Subproject commit a0857740c0e6127485c11476650314df3accc2b6 From 1192c2dc013f86881a7aeefd3aeef2a063db4d74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Thu, 22 Feb 2024 23:18:38 +0100 Subject: [PATCH 02/12] also forgot to push setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 32c61c1334..9dc65b08fd 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ "sphinx-multiversion", "timeout-decorator", "torch>=1.10,!=1.12.0", - "transformers~=4.36.0", + "transformers~=4.38.1", ] From f96b085c643ef57abacda0d40f69c214d3b9b6ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Fri, 23 Feb 2024 10:38:02 +0100 Subject: [PATCH 03/12] llama fixes --- src/adapters/models/llama/adapter_model.py | 4 + src/adapters/models/llama/modeling_llama.py | 127 ++++++++------------ 2 files changed, 52 insertions(+), 79 deletions(-) diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index 97cc0c4e3f..a3ad2148f8 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -10,6 +10,8 @@ from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init +from typing import Optional + logger = logging.getLogger(__name__) @@ -58,6 +60,7 @@ def forward( past_key_values=None, inputs_embeds=None, use_cache=None, + cache_position: Optional[torch.LongTensor] = None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -79,6 +82,7 @@ def forward( position_ids=position_ids, inputs_embeds=inputs_embeds, use_cache=use_cache, + cache_position=cache_position, output_attentions=output_attentions, return_dict=return_dict, output_hidden_states=output_hidden_states, diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index baefac75a9..9b3f8aa74b 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -29,7 +29,6 @@ from adapters.composition import ( adjust_tensors_for_parallel, - adjust_tensors_for_parallel_, match_attn_matrices_for_parallel, ) from transformers.cache_utils import Cache @@ -53,14 +52,9 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" - " `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: @@ -94,20 +88,13 @@ def forward( ) (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - "The cache structure has changed since version v4.36. If you are using" - f" {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to" - " initialize the attention class with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -120,22 +107,12 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - # Make adjustments since (parallel) prefix tuning changes the attention mask - kv_seq_len = key_states.shape[-2] bsz = key_states.shape[0] - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -174,18 +151,9 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # LlamaFlashAttention2 attention does not support output_attentions - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" - " `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -206,23 +174,21 @@ def forward( ) (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) key_states, value_states, attention_mask = self.prefix_tuning( key_states, value_states, hidden_states, attention_mask ) (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - # Make adjustments since (parallel) prefix tuning changes the attention mask - kv_seq_len = key_states.shape[-2] bsz = key_states.shape[0] + past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -241,15 +207,17 @@ def forward( input_dtype = query_states.dtype if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized - if hasattr(self.config, "_pre_quantization_dtype"): + elif hasattr(self.config, "_pre_quantization_dtype"): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype logger.warning_once( - "The input hidden states seems to be silently casted in float32, this might be related to the fact" - " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) @@ -281,14 +249,13 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does" - " not support `output_attentions=True`. Falling back to the manual attention implementation, but" - " specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This" - ' warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, @@ -297,6 +264,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -314,15 +282,14 @@ def forward( ) (attention_mask,) = adjust_tensors_for_parallel(query_states, attention_mask) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -333,19 +300,15 @@ def forward( ) (query_states,) = adjust_tensors_for_parallel(key_states, query_states) - # Make adjustments since (parallel) prefix tuning changes the attention mask - kv_seq_len = key_states.shape[-2] bsz = key_states.shape[0] - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + causal_mask = attention_mask + if attention_mask is not None and cache_position is not None: + causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: + if query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() @@ -354,14 +317,12 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -377,12 +338,15 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -391,8 +355,11 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) - adjust_tensors_for_parallel_(hidden_states, attention_mask, position_ids) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -405,6 +372,8 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + **kwargs, ) hidden_states = self.attention_adapters(hidden_states, residual, None) From 2130032e15a12e6a1c4140cec09d22ceadbce343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Fri, 23 Feb 2024 18:53:04 +0100 Subject: [PATCH 04/12] fix llama qa model --- src/adapters/head_utils.py | 12 ++++++++++++ src/adapters/models/__init__.py | 3 ++- src/adapters/models/llama/adapter_model.py | 14 ++++++++++++-- src/adapters/models/llama/mixin_llama.py | 4 ++++ src/adapters/models/llama/modeling_llama.py | 18 +++++++++--------- 5 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index 2144fbe5ee..ec78430e02 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -498,6 +498,7 @@ }, "layers": [None, "qa_outputs"], }, + # T5 "T5ForConditionalGeneration": { "config": { "head_type": "seq2seq_lm", @@ -526,6 +527,7 @@ "classification_head.out_proj", ], }, + # DeBERTaV2 "DebertaV2ForSequenceClassification": { "config": { "head_type": "classification", @@ -575,6 +577,7 @@ }, "layers": [None, "pooler.dense", None, None, "classifier"], }, + # DeBERTa "DebertaForSequenceClassification": { "config": { "head_type": "classification", @@ -641,6 +644,15 @@ }, "layers": ["lm_head"], }, + "LlamaForQuestionAnswering": { + "config": { + "head_type": "question_answering", + "layers": 1, + "activation_function": None, + }, + "layers": [None, "qa_outputs"], + }, + # Electra "ElectraForTokenClassification": { "config": { "head_type": "tagging", diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 46eba733b7..0a51d37451 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -17,7 +17,7 @@ from .distilbert.mixin_distilbert import DistilBertModelAdaptersMixin, DistilBertTransformerAdaptersMixin from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin -from .llama.mixin_llama import LlamaModelAdapterMixin +from .llama.mixin_llama import LlamaModelAdapterMixin, LlamaForQuestionAnswering from .t5.mixin_t5 import ( T5BlockAdaptersMixin, T5ForCondiditionalGenerationWithHeadsMixin, @@ -83,4 +83,5 @@ "BertGenerationEncoder": BertModelAdaptersMixin, "BertGenerationLayer": BertLayerAdaptersMixin, "LlamaModel": LlamaModelAdapterMixin, + "LlamaForQuestionAnswering": LlamaForQuestionAnswering, } diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index a3ad2148f8..a4689bbafd 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -1,4 +1,5 @@ import logging +from typing import Optional import torch @@ -10,8 +11,6 @@ from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init -from typing import Optional - logger = logging.getLogger(__name__) @@ -153,3 +152,14 @@ def prepare_inputs_for_generation( } ) return model_inputs + + def _load_pretrained_model(cls, model, state_dict, loaded_keys, *args, **kwargs): + # TODO: remove this once Hugging Face has fixed the naming inconsistency (https://github.com/huggingface/transformers/pull/29258) + # LlamaForQuestionAnswering has inconsistent naming of the base model: it is called "transformer" instead of "model" + # if we are loading a LlamaForQuestionAnswering model, state_dict and model contain keys like transformer.embed_tokens.weight', 'transformer.layers.0.self_attn.q_proj.weight', ... + # rename every occurence of "transformer" to "model" in the state_dict + + state_dict = [key.replace("transformer", "model") for key in state_dict] + model = {key.replace("transformer", "model"): value for key, value in model.items()} + + return super()._load_pretrained_model(cls, model, state_dict, loaded_keys, *args, **kwargs) diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index aae339433c..c8b896ec0a 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -44,3 +44,7 @@ def post_embedding_forward(self, module, args, embedding_output): embedding_output = self.invertible_adapters_forward(embedding_output) # Prompt tuning not yet supported return embedding_output + + +class LlamaForQuestionAnswering: + base_model_prefix = "transformer" # this is needed because Transformers v4.38.1 is inconsistent with the naming of the base model but didn't change the base_model_prefix diff --git a/src/adapters/models/llama/modeling_llama.py b/src/adapters/models/llama/modeling_llama.py index 9b3f8aa74b..752cfccb10 100644 --- a/src/adapters/models/llama/modeling_llama.py +++ b/src/adapters/models/llama/modeling_llama.py @@ -27,10 +27,7 @@ import torch.utils.checkpoint from torch import nn -from adapters.composition import ( - adjust_tensors_for_parallel, - match_attn_matrices_for_parallel, -) +from adapters.composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel from transformers.cache_utils import Cache from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, apply_rotary_pos_emb, repeat_kv from transformers.utils import logging @@ -216,8 +213,8 @@ def forward( target_dtype = self.q_proj.weight.dtype logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" f" {target_dtype}." ) @@ -254,8 +251,10 @@ def forward( if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does" + " not support `output_attentions=True`. Falling back to the manual attention implementation, but" + " specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This" + ' warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) return super().forward( hidden_states=hidden_states, @@ -357,7 +356,8 @@ def forward( """ if "padding_mask" in kwargs: warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use" + " `attention_mask` instead.`" ) residual = hidden_states From 2e9710730e18a096405cb89ef6b395f79ed177b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Fri, 23 Feb 2024 20:02:57 +0100 Subject: [PATCH 05/12] make style --- src/adapters/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 0a51d37451..11dff6ac9f 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -17,7 +17,7 @@ from .distilbert.mixin_distilbert import DistilBertModelAdaptersMixin, DistilBertTransformerAdaptersMixin from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin -from .llama.mixin_llama import LlamaModelAdapterMixin, LlamaForQuestionAnswering +from .llama.mixin_llama import LlamaForQuestionAnswering, LlamaModelAdapterMixin from .t5.mixin_t5 import ( T5BlockAdaptersMixin, T5ForCondiditionalGenerationWithHeadsMixin, From b707bc2e6765efc0235fbf56b3edf67b8314d603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Fri, 23 Feb 2024 20:13:40 +0100 Subject: [PATCH 06/12] corrected misspelled class --- src/adapters/models/__init__.py | 4 ++-- src/adapters/models/llama/mixin_llama.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index 11dff6ac9f..ff19c38f3b 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -17,7 +17,7 @@ from .distilbert.mixin_distilbert import DistilBertModelAdaptersMixin, DistilBertTransformerAdaptersMixin from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin -from .llama.mixin_llama import LlamaForQuestionAnswering, LlamaModelAdapterMixin +from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin from .t5.mixin_t5 import ( T5BlockAdaptersMixin, T5ForCondiditionalGenerationWithHeadsMixin, @@ -83,5 +83,5 @@ "BertGenerationEncoder": BertModelAdaptersMixin, "BertGenerationLayer": BertLayerAdaptersMixin, "LlamaModel": LlamaModelAdapterMixin, - "LlamaForQuestionAnswering": LlamaForQuestionAnswering, + "LlamaForQuestionAnswering": LlamaForQuestionAnsweringAdapterMixin, } diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index c8b896ec0a..21314235a6 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -46,5 +46,6 @@ def post_embedding_forward(self, module, args, embedding_output): return embedding_output -class LlamaForQuestionAnswering: - base_model_prefix = "transformer" # this is needed because Transformers v4.38.1 is inconsistent with the naming of the base model but didn't change the base_model_prefix +class LlamaForQuestionAnsweringAdapterMixin: + # this is needed because Transformers v4.38.1 is inconsistent with the naming of the base model but didn't change the base_model_prefix + base_model_prefix = "transformer" From b8926019d8e6ac02c102a01d6bd963455954f7d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Wed, 28 Feb 2024 22:46:06 +0100 Subject: [PATCH 07/12] LlamaAdapterModel improve renaming state_dict and model --- src/adapters/models/llama/adapter_model.py | 17 ++++++++++------- src/adapters/models/llama/mixin_llama.py | 1 + 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index a4689bbafd..4c93319e8b 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -154,12 +154,15 @@ def prepare_inputs_for_generation( return model_inputs def _load_pretrained_model(cls, model, state_dict, loaded_keys, *args, **kwargs): - # TODO: remove this once Hugging Face has fixed the naming inconsistency (https://github.com/huggingface/transformers/pull/29258) - # LlamaForQuestionAnswering has inconsistent naming of the base model: it is called "transformer" instead of "model" - # if we are loading a LlamaForQuestionAnswering model, state_dict and model contain keys like transformer.embed_tokens.weight', 'transformer.layers.0.self_attn.q_proj.weight', ... - # rename every occurence of "transformer" to "model" in the state_dict - - state_dict = [key.replace("transformer", "model") for key in state_dict] - model = {key.replace("transformer", "model"): value for key, value in model.items()} + # LlamaForQuestionAnswering has inconsistent naming of the base model: it is called "transformer" instead of "model". Changing the variable name would be a breaking change, hence they keep it. + # => LlamaForQuestionAnswering's state_dict and model contain keys like 'transformer.embed_tokens.weight', 'transformer.layers.0.self_attn.q_proj.weight', ... + # => If the key begins with "transformer" then rename it to "model" + state_dict = [ + key.replace("transformer", "model", 1) if key.startswith("transformer") else key for key in state_dict + ] + model = { + key.replace("transformer", "model", 1) if key.startswith("transformer") else key: value + for key, value in model.items() + } return super()._load_pretrained_model(cls, model, state_dict, loaded_keys, *args, **kwargs) diff --git a/src/adapters/models/llama/mixin_llama.py b/src/adapters/models/llama/mixin_llama.py index 21314235a6..4f15c0d948 100644 --- a/src/adapters/models/llama/mixin_llama.py +++ b/src/adapters/models/llama/mixin_llama.py @@ -48,4 +48,5 @@ def post_embedding_forward(self, module, args, embedding_output): class LlamaForQuestionAnsweringAdapterMixin: # this is needed because Transformers v4.38.1 is inconsistent with the naming of the base model but didn't change the base_model_prefix + # TODO: remove this when the inconsistency is fixed and remove the LlamaForQuestionAnsweringAdapterMixin from `src/adapters/models/__init__.py` base_model_prefix = "transformer" From 487e07e53ae9ca7bfb904d23e7247f6676d0c78e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Mon, 18 Mar 2024 22:30:49 +0100 Subject: [PATCH 08/12] remove _load_pretrained_model from LlamaAdapterModel & update docs --- docs/classes/models/auto.rst | 6 ++++++ docs/classes/models/llama.rst | 6 ++++++ docs/prediction_heads.md | 12 +++++------- docs/quickstart.md | 3 ++- setup.py | 10 +++++----- src/adapters/models/llama/adapter_model.py | 14 -------------- 6 files changed, 24 insertions(+), 27 deletions(-) diff --git a/docs/classes/models/auto.rst b/docs/classes/models/auto.rst index f4081a77c4..969c353cd0 100644 --- a/docs/classes/models/auto.rst +++ b/docs/classes/models/auto.rst @@ -4,6 +4,12 @@ Auto Classes Similar to the ``AutoModel`` classes built-in into HuggingFace Transformers, adapters provides an ``AutoAdapterModel`` class. As with other auto classes, the correct adapter model class is automatically instantiated based on the pre-trained model passed to the ``from_pretrained()`` method. +```{eval-rst} +.. note:: + If the model loaded with the ``from_pretrained(...)`` function has a head, this head gets loaded as well. However, this only works for non-sharded models. If you want to load a sharded model with a head, you first need to load the model and then the head separately. + However, only LLMs are sharded. LLMs are typically used without a head, so this limitation of the ``from_pretrained(...)`` function should rarely occur. +``` + AutoAdapterModel ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/classes/models/llama.rst b/docs/classes/models/llama.rst index c7fffe1834..2eb3c45921 100644 --- a/docs/classes/models/llama.rst +++ b/docs/classes/models/llama.rst @@ -1,5 +1,11 @@ LLaMA ----------------------------------------------------------------------------------------------------------------------- +```{eval-rst} +.. note:: + Loading a ``LlamaForQuestionAnswering`` via [`AutoAdapterModel`](adapters.AutoAdapterModel) or via [`LlamaAdapterModel`](adapters.LlamaAdapterModel) does not load the head, even if the model is not sharded. Please load the base model first and then subsequently the head. + Note that for sharded models the head is never automatically loaded as described here: [Auto Classes](auto.rst) +``` + The LLaMA model was proposed in `LLaMA: Open and Efficient Foundation Language Models `__ by Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, diff --git a/docs/prediction_heads.md b/docs/prediction_heads.md index 33786385d6..eba5079024 100644 --- a/docs/prediction_heads.md +++ b/docs/prediction_heads.md @@ -6,7 +6,7 @@ We will take a look at the `AdapterModel` classes (e.g. `BertAdapterModel`) intr ```{eval-rst} .. tip:: We recommend to use the `AdapterModel classes <#adaptermodel-classes>`_ whenever possible. - They have been created specifically for working with adapters and provide more flexibility. + These **flexible** models have been created specifically for working with adapters. ``` ## AdapterModel classes @@ -18,16 +18,14 @@ First, we load pre-trained model from the Hugging Face Hub via the [`AutoAdapter model = AutoAdapterModel.from_pretrained("bert-base-uncased") ``` -By default, this model doesn't have any heads yet. We add a new one in the next step: +By default, this model doesn't have any heads yet, so let's add a new binary sequence classification head on top of our model: ```python model.add_classification_head("mrpc", num_labels=2) ``` -The line above adds a binary sequence classification head on top of our model. -Because this head is named, we could add multiple other heads with different names to the same model. -This is especially useful if used together with matching adapter modules. -To learn more about the different head types and the configuration options, please refer to the class references of the respective model classes, e.g. [`BertAdapterModel`](adapters.BertAdapterModel). +All heads have a name, we called this new head `"mrpc"`. Since all heads are named, we can add multiple other heads with different names to the same model. +To see the head types of a model and how they can get configured, please refer to the class references of the respective model classes, e.g. [`BertAdapterModel`](adapters.BertAdapterModel). -Now, of course, we would like to train our classification head together with an adapter, so let's add one: +A head alone is just one layer with very few parameters. Hence, we want to train our classification head together with an adapter, so let's add one: ```python model.add_adapter("mrpc", config="seq_bn") model.set_active_adapters("mrpc") diff --git a/docs/quickstart.md b/docs/quickstart.md index c1181f1d28..4d64b51e41 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -120,4 +120,5 @@ model.delete_adapter(adapter_name) _We also have a Quickstart Colab notebook for adapter training:_ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb) -For more examples on training different adapter setups, refer to the section on [Adapter Training](training.md). +For more examples of training different adapter setups, refer to the section on [Adapter Training](training.md). +Further information on using adapters with prediction heads can be found in the [Prediction Heads](prediction_heads.md) section. diff --git a/setup.py b/setup.py index 9dc65b08fd..0c71f1f5c6 100644 --- a/setup.py +++ b/setup.py @@ -51,13 +51,13 @@ "sacremoses", "scikit-learn", "sentencepiece>=0.1.91,!=0.1.92", - "sphinx-copybutton", - "sphinx-markdown-tables", + "sphinx-copybutton==0.5.2", + "sphinx-markdown-tables==0.0.17", "sphinx-rtd-theme==0.4.3", # sphinx-rtd-theme==0.5.0 introduced big changes in the style. - "sphinx==3.2.1", + "sphinx==5.0.2", "sphinxext-opengraph==0.4.1", - "sphinx-intl", - "sphinx-multiversion", + "sphinx-intl==2.1.0", + "sphinx-multiversion==0.2.4", "timeout-decorator", "torch>=1.10,!=1.12.0", "transformers~=4.38.1", diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index 4c93319e8b..16bea405d4 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -152,17 +152,3 @@ def prepare_inputs_for_generation( } ) return model_inputs - - def _load_pretrained_model(cls, model, state_dict, loaded_keys, *args, **kwargs): - # LlamaForQuestionAnswering has inconsistent naming of the base model: it is called "transformer" instead of "model". Changing the variable name would be a breaking change, hence they keep it. - # => LlamaForQuestionAnswering's state_dict and model contain keys like 'transformer.embed_tokens.weight', 'transformer.layers.0.self_attn.q_proj.weight', ... - # => If the key begins with "transformer" then rename it to "model" - state_dict = [ - key.replace("transformer", "model", 1) if key.startswith("transformer") else key for key in state_dict - ] - model = { - key.replace("transformer", "model", 1) if key.startswith("transformer") else key: value - for key, value in model.items() - } - - return super()._load_pretrained_model(cls, model, state_dict, loaded_keys, *args, **kwargs) From c43ab9fc879372cc8fc6ceebdec1db0e1314217a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Mon, 18 Mar 2024 22:32:11 +0100 Subject: [PATCH 09/12] remove TODO to rename the Hugging Face library key --- src/adapters/hub_mixin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/adapters/hub_mixin.py b/src/adapters/hub_mixin.py index ece00238a1..7a1009c5b8 100644 --- a/src/adapters/hub_mixin.py +++ b/src/adapters/hub_mixin.py @@ -70,7 +70,8 @@ def _save_adapter_card( metrics: Optional[List[str]] = None, **kwargs ): - all_tags = {"adapter-transformers"} # TODO: change this tag once changed on HF side + # Key remains "adapter-transformers", see: https://github.com/huggingface/huggingface.js/pull/459 + all_tags = {"adapter-transformers"} datasets = set() # Dataset/ Task info dataset_name = None From 8d5ecff549184fabce8a3659e4aebedd411d6d29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Mon, 18 Mar 2024 23:21:27 +0100 Subject: [PATCH 10/12] fix broken links in docs & fix test --- docs/adapter_composition.md | 2 +- docs/methods.md | 2 +- docs/training.md | 2 +- tests/test_llama.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md index 0851e7970f..ccbe26fe99 100644 --- a/docs/adapter_composition.md +++ b/docs/adapter_composition.md @@ -19,7 +19,7 @@ model.active_adapters = "adapter_name" - You cannot activate an adapter before previously adding it to the model using either ``add_adapter()`` or ``load_adapter()``. - All adapters not mentioned in the ``active_adapters`` setup are ignored, although they might have been loaded into the model. Thus, after adding an adapter, make sure to activate it. ``` -Note that we also could have used the [`set_active_adapters`](adapters.) method with `model.set_active_adapters("adapter_name")` which does the same. +Note that we also could have used the `set_active_adapters` method with `model.set_active_adapters("adapter_name")` which does the same. Alternatively, the [`AdapterSetup`](adapters.AdapterSetup) context manager allows dynamic configuration of activated setups without changing the model state: diff --git a/docs/methods.md b/docs/methods.md index 06ad700e68..535b23d088 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -1,7 +1,7 @@ # Adapter Methods On this page, we present all adapter methods currently integrated into the `adapters` library. -A tabular overview of adapter methods is provided [here](overview.html#table-of-adapter-methods). +A tabular overview of adapter methods is provided [here](overview.md#table-of-adapter-methods). Additionally, options to combine multiple adapter methods in a single setup are presented [on the next page](method_combinations.md). ## Bottleneck Adapters diff --git a/docs/training.md b/docs/training.md index 649ace6e28..89c5703f25 100644 --- a/docs/training.md +++ b/docs/training.md @@ -84,7 +84,7 @@ model.set_active_adapters(task_name) ### Step D - Switch to `AdapterTrainer` class -Finally, we exchange the `Trainer` class built into Transformers for the [`AdapterTrainer`](transformers.adapters.AdapterTrainer) class that is optimized for training adapter methods. +Finally, we exchange the `Trainer` class built into Transformers for the [`AdapterTrainer`](adapters.AdapterTrainer) class that is optimized for training adapter methods. See [below for more information](#adaptertrainer). Technically, this change is not required as no changes to the training loop are required for training adapters. diff --git a/tests/test_llama.py b/tests/test_llama.py index 9cb6fcfda2..e8cf0557a0 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -61,4 +61,5 @@ class LlamaClassConversionTest( LlamaAdapterTestBase, unittest.TestCase, ): - pass + def test_conversion_question_answering_model(self): + raise self.skipTest("We don't support the Llama QA model.") From d879301eaca7ce9487af1b3b6af21077382439e0 Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 6 Apr 2024 14:00:01 +0200 Subject: [PATCH 11/12] Update docs/classes/models/auto.rst --- docs/classes/models/auto.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/classes/models/auto.rst b/docs/classes/models/auto.rst index 969c353cd0..6ce8a3050e 100644 --- a/docs/classes/models/auto.rst +++ b/docs/classes/models/auto.rst @@ -7,7 +7,6 @@ As with other auto classes, the correct adapter model class is automatically ins ```{eval-rst} .. note:: If the model loaded with the ``from_pretrained(...)`` function has a head, this head gets loaded as well. However, this only works for non-sharded models. If you want to load a sharded model with a head, you first need to load the model and then the head separately. - However, only LLMs are sharded. LLMs are typically used without a head, so this limitation of the ``from_pretrained(...)`` function should rarely occur. ``` AutoAdapterModel From 73ceaba7175aabbbfbf559c28325844f5c6f2c08 Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 6 Apr 2024 14:24:04 +0200 Subject: [PATCH 12/12] Fix docs issues --- docs/adapter_composition.md | 2 +- docs/classes/models/auto.rst | 2 -- docs/classes/models/bart.rst | 2 +- docs/classes/models/electra.rst | 2 +- docs/classes/models/llama.rst | 3 +-- docs/conf.py | 2 +- docs/index.rst | 1 + docs/training.md | 2 +- src/adapters/models/bart/adapter_model.py | 4 ++-- src/adapters/utils.py | 4 ++-- 10 files changed, 11 insertions(+), 13 deletions(-) diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md index ccbe26fe99..5ff2d4284f 100644 --- a/docs/adapter_composition.md +++ b/docs/adapter_composition.md @@ -125,7 +125,7 @@ model.active_adapters = ac.Fuse("d", "e", "f") To learn how training an _AdapterFusion_ layer works, check out [this Colab notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/03_Adapter_Fusion.ipynb) from the `adapters` repo. -#### Retrieving AdapterFusion attentions +### Retrieving AdapterFusion attentions Finally, it is possible to retrieve the attention scores computed by each fusion layer in a forward pass of the model. These scores can be used for analyzing the fused adapter blocks and can serve as the basis for visualizations similar to those in the AdapterFusion paper. diff --git a/docs/classes/models/auto.rst b/docs/classes/models/auto.rst index 6ce8a3050e..a276854894 100644 --- a/docs/classes/models/auto.rst +++ b/docs/classes/models/auto.rst @@ -4,10 +4,8 @@ Auto Classes Similar to the ``AutoModel`` classes built-in into HuggingFace Transformers, adapters provides an ``AutoAdapterModel`` class. As with other auto classes, the correct adapter model class is automatically instantiated based on the pre-trained model passed to the ``from_pretrained()`` method. -```{eval-rst} .. note:: If the model loaded with the ``from_pretrained(...)`` function has a head, this head gets loaded as well. However, this only works for non-sharded models. If you want to load a sharded model with a head, you first need to load the model and then the head separately. -``` AutoAdapterModel ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/classes/models/bart.rst b/docs/classes/models/bart.rst index 5ea8eeb11f..67a5e56572 100644 --- a/docs/classes/models/bart.rst +++ b/docs/classes/models/bart.rst @@ -22,4 +22,4 @@ BartAdapterModel .. autoclass:: adapters.BartAdapterModel :members: - :inherited-members: BartPretrainedModel + :inherited-members: BartPreTrainedModel diff --git a/docs/classes/models/electra.rst b/docs/classes/models/electra.rst index d67a96d8d5..e0dc9c5ef4 100644 --- a/docs/classes/models/electra.rst +++ b/docs/classes/models/electra.rst @@ -1,5 +1,5 @@ ELECTRA -====== +======= The ELECTRA model was proposed in the paper `ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators `__. ELECTRA is a new pretraining approach which trains two diff --git a/docs/classes/models/llama.rst b/docs/classes/models/llama.rst index 2eb3c45921..f650f93225 100644 --- a/docs/classes/models/llama.rst +++ b/docs/classes/models/llama.rst @@ -1,10 +1,9 @@ LLaMA ----------------------------------------------------------------------------------------------------------------------- -```{eval-rst} + .. note:: Loading a ``LlamaForQuestionAnswering`` via [`AutoAdapterModel`](adapters.AutoAdapterModel) or via [`LlamaAdapterModel`](adapters.LlamaAdapterModel) does not load the head, even if the model is not sharded. Please load the base model first and then subsequently the head. Note that for sharded models the head is never automatically loaded as described here: [Auto Classes](auto.rst) -``` The LLaMA model was proposed in `LLaMA: Open and Efficient Foundation Language Models `__ by diff --git a/docs/conf.py b/docs/conf.py index 417623a94b..746e77d4a9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -90,7 +90,7 @@ def skip_head_member(app, what, name, obj, skip, options): if type(obj).__name__ == "function" and "inherited-members" in options and (m := re.match(r"add\_(.*)\_head$", name)): - cls_name = options["inherited-members"].replace("PreTrainedModel", "AdapterModel").replace("PretrainedModel", "AdapterModel") + cls_name = list(options["inherited-members"])[0].replace("PreTrainedModel", "AdapterModel").replace("PretrainedModel", "AdapterModel") cls = vars(sys.modules["adapters"])[cls_name] # HACK: currently parses head type from name head_type_str = m.group(1).replace("qa", "question_answering") diff --git a/docs/index.rst b/docs/index.rst index 4d13a1942d..5f87f7ae1e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -82,6 +82,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model classes/models/gptj classes/models/llama classes/models/mbart + classes/models/mt5 classes/models/roberta classes/models/t5 classes/models/vit diff --git a/docs/training.md b/docs/training.md index 89c5703f25..95ffeb2f1e 100644 --- a/docs/training.md +++ b/docs/training.md @@ -84,7 +84,7 @@ model.set_active_adapters(task_name) ### Step D - Switch to `AdapterTrainer` class -Finally, we exchange the `Trainer` class built into Transformers for the [`AdapterTrainer`](adapters.AdapterTrainer) class that is optimized for training adapter methods. +Finally, we exchange the `Trainer` class built into Transformers for the [`AdapterTrainer`](adapters.trainer.AdapterTrainer) class that is optimized for training adapter methods. See [below for more information](#adaptertrainer). Technically, this change is not required as no changes to the training loop are required for training adapters. diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py index dec5a838c2..384955cc11 100644 --- a/src/adapters/models/bart/adapter_model.py +++ b/src/adapters/models/bart/adapter_model.py @@ -5,7 +5,7 @@ BART_START_DOCSTRING, BartConfig, BartModel, - BartPretrainedModel, + BartPreTrainedModel, shift_tokens_right, ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward @@ -18,7 +18,7 @@ @add_start_docstrings( "BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING ) -class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPretrainedModel): +class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPreTrainedModel): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 0e3b20cabe..1dbe06cf24 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -837,8 +837,8 @@ def prefix_attention_mask(attention_mask, dim: int = 3, prefix_value: int = 0): The value to use for the prefix_attention_mask. Defaults to 0, however some models, e.g. DistilBert, use different values. BERT like models invert their extended_attention_mask, hence they use 0 as value for not masked tokens. This inversion is usually done in the forward method of the model in 2 different ways: - 1) by calling self.invert_attention_mask, as BERT does 2) by doing the inversion manually, e.g. ALBERT - does: `extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min` + 1) by calling self.invert_attention_mask, as BERT does 2) by doing the inversion manually, e.g. ALBERT + does: `extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min` """ forward_context = ForwardContext.get_context()