diff --git a/.github/workflows/adapter_docs_build.yml b/.github/workflows/adapter_docs_build.yml index 187f57d82c..35fab0de49 100644 --- a/.github/workflows/adapter_docs_build.yml +++ b/.github/workflows/adapter_docs_build.yml @@ -18,7 +18,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: "3.10" - name: Install run: | pip install setuptools==57.4.0 diff --git a/.github/workflows/tests_torch.yml b/.github/workflows/tests_torch.yml index fd5930ebb6..cb8c61be1b 100644 --- a/.github/workflows/tests_torch.yml +++ b/.github/workflows/tests_torch.yml @@ -32,8 +32,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -53,8 +53,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -76,8 +76,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -99,8 +99,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} diff --git a/docs/methods.md b/docs/methods.md index 302b1973d3..95226e3578 100644 --- a/docs/methods.md +++ b/docs/methods.md @@ -59,6 +59,11 @@ _Papers:_ * [Adapters Strike Back](https://arxiv.org/pdf/2406.06820) (Steitz and Roth., 2024) * [AdapterHub: A Framework for Adapting Transformers](https://arxiv.org/pdf/2007.07779.pdf) (Pfeiffer et al., 2020) +```{eval-rst} +.. note:: + The two parameters ``original_ln_before`` and ``original_ln_after`` inside bottleneck adapters control both the addition of the residual input and the application of the pretrained layer norm. If the original model does not apply a layer norm function at a specific position of the forward function (e.g after the FFN layer), the two bottleneck parameters of the adapter set at that same position will only control the application of the residual input. +``` + ## Language Adapters - Invertible Adapters _Configuration class_: [`SeqBnInvConfig`](adapters.SeqBnInvConfig), [`DoubleSeqBnInvConfig`](adapters.DoubleSeqBnInvConfig) diff --git a/hf_transformers b/hf_transformers index 052e652d6d..241c04d368 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 052e652d6d53c2b26ffde87e039b723949a53493 +Subproject commit 241c04d36867259cdf11dbb4e9d9a60f9cb65ebc diff --git a/notebooks/ViT_AdapterPlus_FineTuning.ipynb b/notebooks/ViT_AdapterPlus_FineTuning.ipynb index 1cf549ea75..6833a6b0e1 100644 --- a/notebooks/ViT_AdapterPlus_FineTuning.ipynb +++ b/notebooks/ViT_AdapterPlus_FineTuning.ipynb @@ -205,7 +205,18 @@ "source": [ "### Loading the `ViT` model and the `AdapterPlusConfig`\n", "\n", - "Here we load the `vit-base-patch16-224-in21k` model similar to the one used in the `AdapterConfig` paper. We will load the model using the `adapters` `AutoAdapterModel` and add the corresponding `AdapterPlusConfig`. 
To read more about the config, you can check out the docs page [here](https://docs.adapterhub.ml/methods#bottleneck-adapters) under `AdapterPlusConfig`" + "Here we load the `vit-base-patch16-224-in21k` model similar to the one used in the `AdapterConfig` paper. We will load the model using the `adapters` `AutoAdapterModel` and add the corresponding `AdapterPlusConfig`. To read more about the config, you can check out the docs page [here](https://docs.adapterhub.ml/methods#bottleneck-adapters) under `AdapterPlusConfig`.\n", + "\n", + "#### Important Note\n", + "\n", + "Please note that some configurations of the adapters parameters `original_ln_after`, `original_ln_before`, and \n", + "`residual_before_ln` may result in performance issues when training. \n", + "\n", + "In the general case:\n", + "\n", + "1) At least one of `original_ln_before` or `original_ln_after` should be set to `True` in order to ensure that the original residual\n", + " connection from pre-training is preserved. \n", + "2) If `original_ln_after` is set to `False`, `residual_before_ln` must also be set to `False` to ensure convergence during training." ] }, { @@ -218,7 +229,7 @@ "from adapters import AdapterPlusConfig\n", "\n", "model = ViTAdapterModel.from_pretrained(model_name_or_path)\n", - "config = AdapterPlusConfig(original_ln_after=True)\n", + "config = AdapterPlusConfig()\n", "\n", "model.add_adapter(\"adapterplus_config\", config)\n", "model.add_image_classification_head(\"adapterplus_config\", num_labels=num_classes)\n", diff --git a/setup.py b/setup.py index e7389c8be6..d7a15ef921 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "isort>=5.5.4", "Jinja2==2.11.3", "nltk", + "packaging", "parameterized", "pillow", "protobuf", @@ -60,7 +61,7 @@ "timeout-decorator", "torch", "torchvision", - "transformers~=4.46.3", + "transformers~=4.47.1", ] @@ -136,11 +137,12 @@ def deps_list(*pkgs): # when modifying the following list, make sure to update src/transformers/dependency_versions_check.py install_requires = [ deps["transformers"], + deps["packaging"], ] setup( name="adapters", - version="1.0.1", + version="1.1.0.dev0", author="The AdapterHub team and community contributors", author_email="calpt@mail.de", description="A Unified Library for Parameter-Efficient and Modular Transfer Learning", diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index a917828e72..88549c6969 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.0.1" +__version__ = "1.1.0.dev0" from typing import TYPE_CHECKING diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index b5249cb9f5..9e1cf052ac 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -374,10 +374,19 @@ class ParBnConfig(BnConfig): class AdapterPlusConfig(BnConfig): """ The AdapterPlus config architecture proposed by Jan-Martin O, Steitz and Stefan Roth. See https://arxiv.org/pdf/2406.06820 + + Please note that some configurations of the adapters parameters `original_ln_after`, `original_ln_before`, and + `residual_before_ln` may result in performance issues when training. + + In the general case: + 1) At least one of `original_ln_before` or `original_ln_after` should be set to True in order to ensure that the original residual + connection from pre-training is preserved. 
+ 2) If `original_ln_after` is set to `False`, `residual_before_ln` must also be set to `False` to ensure convergence during training. """ original_ln_after: bool = False - residual_before_ln: bool = True + original_ln_before: bool = True + residual_before_ln: bool = False stochastic_depth: float = 0.1 init_weights: str = "houlsby" scaling: Union[float, str] = "channel" diff --git a/src/adapters/loading.py b/src/adapters/loading.py index b1918b0a0f..69747e04cb 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -6,6 +6,7 @@ from typing import Callable, Mapping, Optional, Sequence, Tuple import torch +from packaging.version import Version try: @@ -368,6 +369,23 @@ def _rename_legacy_weights(self, k): k = k.replace(old, new) return k + def _fix_backward_compat(self, config): + # Fix error in previous versions for LoRA/ (IA)^3 + ADAPTER_PREFIX = "adapters." + MIN_VERSION = Version("1.1.0") + + version = config.get("version", "") + if version.startswith(ADAPTER_PREFIX) and Version(version[len(ADAPTER_PREFIX) :]) < MIN_VERSION: + if ( + config["config"].get("architecture", None) == "lora" + and config["config"]["r"] != config["config"]["alpha"] + ): + logger.warning( + "Loading a LoRA trained using a faulty scaling implementation of a previous library version. Editing the configuration to make sure the adapter works as trained." + "See https://github.com/adapter-hub/adapters/pull/770 for more." + ) + config["config"]["alpha"] = config["config"]["r"] + # This method is used to remove unnecessary invertible adapters from task adapters using the old format. # In the old format, task adapters e.g. using seq_bn config specify inv. adapters but don't use them. # As inv. adapters would be incorrectly used in the new implementation, @@ -560,6 +578,8 @@ def load( # The conversion to a set and then back to a list removes all duplicates leave_out = list(set(leave_out + config["config"]["leave_out"])) config["config"]["leave_out"] = leave_out + # Fix issues + self._fix_backward_compat(config) adapter_name = load_as or config["name"] # If the adapter is not part of the model, add it diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index d56a11a91d..8f3bc29401 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -100,6 +100,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens hidden_states = hidden_states * gate else: gate = None + hidden_states = hidden_states * self.scaling return hidden_states, gate @@ -171,6 +172,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens hidden_states = hidden_states * gate else: gate = None + hidden_states = hidden_states * self.scaling return hidden_states, gate diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 4380b5e038..77c6117b19 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -16,12 +16,13 @@ import torch import torch.utils.checkpoint +from torch import nn from transformers.models.deberta.modeling_deberta import ( DebertaOutput, DebertaSelfOutput, DisentangledSelfAttention, - XSoftmax, + scaled_size_sqrt, ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel @@ -95,71 +96,83 @@ def forward( """ + # >>> START AH Changes <<< attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore attention_mask = 
prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore + # >>> END AH Changes <<< if query_states is None: qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) else: - - def linear(w, b, x): - if b is not None: - return torch.matmul(x, w.t()) + b.t() - else: - return torch.matmul(x, w.t()) # + b.t() - ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0) qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)] - qkvb = [None] * 3 - - q = linear(qkvw[0], qkvb[0], query_states.to(dtype=qkvw[0].dtype)) - k, v = [linear(qkvw[i], qkvb[i], hidden_states.to(dtype=qkvw[i].dtype)) for i in range(1, 3)] + q = torch.matmul(qkvw[0], query_states.t().to(dtype=qkvw[0].dtype)) + k = torch.matmul(qkvw[1], hidden_states.t().to(dtype=qkvw[1].dtype)) + v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype)) query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] + # >>> START AH Changes <<< query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) + # >>> END AH Changes <<< query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) + # >>> START AH Changes <<< orig_key_layer = key_layer # save this for relative attention key_layer, value_layer, attention_mask = self.prefix_tuning( key_layer, value_layer, hidden_states, attention_mask, False ) (query_layer, orig_key_layer) = adjust_tensors_for_parallel(key_layer, query_layer, orig_key_layer) + # >>> END AH Changes <<< - rel_att = None + rel_att: int = 0 # Take the dot product between "query" and "key" to get the raw attention scores. scale_factor = 1 + len(self.pos_att_type) - scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) + scale = scaled_size_sqrt(query_layer, scale_factor) query_layer = query_layer / scale.to(dtype=query_layer.dtype) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.relative_attention: + + if self.relative_attention and rel_embeddings is not None and relative_pos is not None: rel_embeddings = self.pos_dropout(rel_embeddings) + # >>> START AH Changes <<< rel_att = self.disentangled_att_bias( query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor ) + # >>> END AH Changes <<< if rel_att is not None: - rel_att_padded = torch.zeros_like(attention_scores) - rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att - attention_scores = attention_scores + rel_att_padded + # >>> START AH Changes <<< + # rel_att is set to 0 by default, i.e. rel_att is always not None (don't know why HuggingFace does this). + # Hence, we must check whether rel_att is a tensor and if so, pad it with zeros to be able to add it to attention_scores. 
+ if isinstance(rel_att, torch.Tensor): + rel_att_padded = torch.zeros_like(attention_scores) + rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att + attention_scores = attention_scores + rel_att_padded + else: + attention_scores = attention_scores + rel_att + # >>> END AH Changes <<< # bxhxlxd - if self.talking_head: + if self.head_logits_proj is not None: attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_mask = attention_mask.bool() + attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min) + # bsz x height x length x dimension + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs.masked_fill(attention_mask, 0) + attention_probs = self.dropout(attention_probs) - if self.talking_head: + if self.head_weights_proj is not None: attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) - if output_attentions: - return (context_layer, attention_probs) - else: - return context_layer + if not output_attentions: + return (context_layer, None) + return (context_layer, attention_probs) diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index bc41ae82af..2b673c491f 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -16,12 +16,13 @@ import torch import torch.utils.checkpoint +from torch import nn from transformers.models.deberta_v2.modeling_deberta_v2 import ( DebertaV2Output, DebertaV2SelfOutput, DisentangledSelfAttention, - XSoftmax, + scaled_size_sqrt, ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel @@ -90,11 +91,15 @@ def forward( The embedding of relative distances. It's a tensor of shape [\\(2 \\times \\text{max_relative_positions}\\), *hidden_size*]. """ + # >>> START AH Changes <<< attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore + # >>> END AH Changes <<< if query_states is None: query_states = hidden_states + + # >>> START AH Changes <<< query_layer = self.transpose_for_scores_extended(self.query_proj(query_states), self.num_attention_heads) key_layer = self.transpose_for_scores_extended(self.key_proj(hidden_states), self.num_attention_heads) value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads) @@ -112,6 +117,7 @@ def forward( key_layer = key_layer.contiguous().view(-1, key_layer.size(2), key_layer.size(-1)) value_layer = value_layer.contiguous().view(-1, value_layer.size(2), value_layer.size(-1)) orig_key_layer = orig_key_layer.contiguous().view(-1, orig_key_layer.size(2), orig_key_layer.size(-1)) + # >>> END AH Changes <<< rel_att = None # Take the dot product between "query" and "key" to get the raw attention scores. 
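
Both DeBERTa variants above now call `prefix_attention_mask(...)` before computing attention scores so that, when prefix tuning is active, the attention mask also covers the virtual prefix key/value positions. The sketch below illustrates only that general idea; `extend_mask_for_prefix` and its arguments are hypothetical names for illustration, not the library's actual helper (which handles the extra mask dimensions seen in the hunks above).

```python
import torch


def extend_mask_for_prefix(attention_mask: torch.Tensor, prefix_len: int, dim: int = -1) -> torch.Tensor:
    """Prepend `prefix_len` visible positions (value 1) to an attention mask.

    Hypothetical helper showing why the mask must grow when prefix tuning
    prepends virtual key/value vectors to every attention layer.
    """
    prefix_shape = list(attention_mask.shape)
    prefix_shape[dim] = prefix_len
    prefix = torch.ones(prefix_shape, dtype=attention_mask.dtype, device=attention_mask.device)
    return torch.cat([prefix, attention_mask], dim=dim)


# A batch of 2 sequences of length 4 with padding, plus 3 prefix tokens.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
extended = extend_mask_for_prefix(mask, prefix_len=3)
assert extended.shape == (2, 7)
assert extended[:, :3].all()  # prefix positions are always attendable
```
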
@@ -120,25 +126,39 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) - attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype) + scale = scaled_size_sqrt(query_layer, scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype)) if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) + # >>> START AH Changes <<< rel_att = self.disentangled_attention_bias( query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor ) + # >>> END AH Changes <<< if rel_att is not None: - rel_att_padded = torch.zeros_like(attention_scores) - rel_att_padded[:, :, -rel_att.size(2) :] = rel_att - attention_scores = attention_scores + rel_att_padded + # >>> START AH Changes <<< + # rel_att is set to 0 by default, i.e. rel_att is always not None (don't know why HuggingFace does this). + # Hence, we must check whether rel_att is a tensor and if so, pad it with zeros to be able to add it to attention_scores. + if isinstance(rel_att, torch.Tensor): + rel_att_padded = torch.zeros_like(attention_scores) + rel_att_padded[:, :, -rel_att.size(2) :] = rel_att + attention_scores = attention_scores + rel_att_padded + else: + attention_scores = attention_scores + rel_att + # >>> END AH Changes <<< + attention_scores = attention_scores attention_scores = attention_scores.view( -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) ) + attention_mask = attention_mask.bool() + attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min) # bsz x height x length x dimension - attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs.masked_fill(attention_mask, 0) + attention_probs = self.dropout(attention_probs) context_layer = torch.bmm( attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer @@ -150,7 +170,6 @@ def forward( ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) - if output_attentions: - return (context_layer, attention_probs) - else: - return context_layer + if not output_attentions: + return (context_layer, None) + return (context_layer, attention_probs)
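
The DeBERTa hunks replace the `XSoftmax` autograd function with an explicit mask-then-softmax, following the refactored DeBERTa code in `transformers` 4.47. The sketch below shows the core equivalence on a toy tensor; `masked_softmax` is an illustrative stand-in, not an API of either library, and it omits the extra bookkeeping (dropout, zeroing fully masked rows) done in the actual forward pass.

```python
import torch
from torch import nn


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in for the old ``XSoftmax.apply(scores, mask, -1)``.

    ``mask`` holds 1 for positions that may be attended to and 0 for padding.
    """
    mask = mask.bool()
    # Push masked positions to the smallest representable value so that
    # softmax assigns them (numerically) zero probability.
    filled = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
    return nn.functional.softmax(filled, dim=-1)


# Toy check: 1 batch, 1 head, 2 query positions, 4 key positions,
# where the last key position is padding.
scores = torch.randn(1, 1, 2, 4)
mask = torch.tensor([[[[1, 1, 1, 0], [1, 1, 1, 0]]]])
probs = masked_softmax(scores, mask)
assert torch.allclose(probs[..., -1], torch.zeros_like(probs[..., -1]), atol=1e-6)
assert torch.allclose(probs.sum(dim=-1), torch.ones(1, 1, 2))
```
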