From 7c2357f8d49b6dedab9ab83143b6cbbff5d92301 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 8 Jan 2025 11:20:04 +0100 Subject: [PATCH 1/9] Upgrade Transformers to v4.47.x (#776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Leon Engländer --- .github/workflows/adapter_docs_build.yml | 2 +- .github/workflows/tests_torch.yml | 16 ++--- hf_transformers | 2 +- setup.py | 2 +- .../models/deberta/modeling_deberta.py | 63 +++++++++++-------- .../models/deberta_v2/modeling_deberta_v2.py | 41 ++++++++---- 6 files changed, 79 insertions(+), 47 deletions(-) diff --git a/.github/workflows/adapter_docs_build.yml b/.github/workflows/adapter_docs_build.yml index 187f57d82c..35fab0de49 100644 --- a/.github/workflows/adapter_docs_build.yml +++ b/.github/workflows/adapter_docs_build.yml @@ -18,7 +18,7 @@ jobs: fetch-depth: 0 - uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: "3.10" - name: Install run: | pip install setuptools==57.4.0 diff --git a/.github/workflows/tests_torch.yml b/.github/workflows/tests_torch.yml index fd5930ebb6..cb8c61be1b 100644 --- a/.github/workflows/tests_torch.yml +++ b/.github/workflows/tests_torch.yml @@ -32,8 +32,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -53,8 +53,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -76,8 +76,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} @@ -99,8 +99,8 @@ jobs: submodules: true - uses: actions/setup-python@v2 with: - python-version: 3.8 - - uses: actions/cache@v2 + python-version: "3.10" + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }} diff --git a/hf_transformers b/hf_transformers index 052e652d6d..241c04d368 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 052e652d6d53c2b26ffde87e039b723949a53493 +Subproject commit 241c04d36867259cdf11dbb4e9d9a60f9cb65ebc diff --git a/setup.py b/setup.py index 1666ae3d0a..d7a15ef921 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "timeout-decorator", "torch", "torchvision", - "transformers~=4.46.3", + "transformers~=4.47.1", ] diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 4380b5e038..77c6117b19 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -16,12 +16,13 @@ import torch import torch.utils.checkpoint +from torch import nn from transformers.models.deberta.modeling_deberta import ( DebertaOutput, DebertaSelfOutput, DisentangledSelfAttention, - XSoftmax, + scaled_size_sqrt, ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel @@ -95,71 +96,83 @@ def forward( """ + # >>> START AH Changes <<< attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore + # >>> END AH Changes <<< if query_states is None: qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) else: - - def linear(w, b, x): - if b is not None: - return torch.matmul(x, w.t()) + b.t() - else: - return torch.matmul(x, w.t()) # + b.t() - ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0) qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)] - qkvb = [None] * 3 - - q = linear(qkvw[0], qkvb[0], query_states.to(dtype=qkvw[0].dtype)) - k, v = [linear(qkvw[i], qkvb[i], hidden_states.to(dtype=qkvw[i].dtype)) for i in range(1, 3)] + q = torch.matmul(qkvw[0], query_states.t().to(dtype=qkvw[0].dtype)) + k = torch.matmul(qkvw[1], hidden_states.t().to(dtype=qkvw[1].dtype)) + v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype)) query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] + # >>> START AH Changes <<< query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) + # >>> END AH Changes <<< query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) + # >>> START AH Changes <<< orig_key_layer = key_layer # save this for relative attention key_layer, value_layer, attention_mask = self.prefix_tuning( key_layer, value_layer, hidden_states, attention_mask, False ) (query_layer, orig_key_layer) = adjust_tensors_for_parallel(key_layer, query_layer, orig_key_layer) + # >>> END AH Changes <<< - rel_att = None + rel_att: int = 0 # Take the dot product between "query" and "key" to get the raw attention scores. scale_factor = 1 + len(self.pos_att_type) - scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) + scale = scaled_size_sqrt(query_layer, scale_factor) query_layer = query_layer / scale.to(dtype=query_layer.dtype) attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - if self.relative_attention: + + if self.relative_attention and rel_embeddings is not None and relative_pos is not None: rel_embeddings = self.pos_dropout(rel_embeddings) + # >>> START AH Changes <<< rel_att = self.disentangled_att_bias( query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor ) + # >>> END AH Changes <<< if rel_att is not None: - rel_att_padded = torch.zeros_like(attention_scores) - rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att - attention_scores = attention_scores + rel_att_padded + # >>> START AH Changes <<< + # rel_att is set to 0 by default, i.e. rel_att is always not None (don't know why HuggingFace does this). + # Hence, we must check whether rel_att is a tensor and if so, pad it with zeros to be able to add it to attention_scores. + if isinstance(rel_att, torch.Tensor): + rel_att_padded = torch.zeros_like(attention_scores) + rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att + attention_scores = attention_scores + rel_att_padded + else: + attention_scores = attention_scores + rel_att + # >>> END AH Changes <<< # bxhxlxd - if self.talking_head: + if self.head_logits_proj is not None: attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_mask = attention_mask.bool() + attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min) + # bsz x height x length x dimension + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs.masked_fill(attention_mask, 0) + attention_probs = self.dropout(attention_probs) - if self.talking_head: + if self.head_weights_proj is not None: attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) - if output_attentions: - return (context_layer, attention_probs) - else: - return context_layer + if not output_attentions: + return (context_layer, None) + return (context_layer, attention_probs) diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index bc41ae82af..2b673c491f 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -16,12 +16,13 @@ import torch import torch.utils.checkpoint +from torch import nn from transformers.models.deberta_v2.modeling_deberta_v2 import ( DebertaV2Output, DebertaV2SelfOutput, DisentangledSelfAttention, - XSoftmax, + scaled_size_sqrt, ) from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel @@ -90,11 +91,15 @@ def forward( The embedding of relative distances. It's a tensor of shape [\\(2 \\times \\text{max_relative_positions}\\), *hidden_size*]. """ + # >>> START AH Changes <<< attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore + # >>> END AH Changes <<< if query_states is None: query_states = hidden_states + + # >>> START AH Changes <<< query_layer = self.transpose_for_scores_extended(self.query_proj(query_states), self.num_attention_heads) key_layer = self.transpose_for_scores_extended(self.key_proj(hidden_states), self.num_attention_heads) value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads) @@ -112,6 +117,7 @@ def forward( key_layer = key_layer.contiguous().view(-1, key_layer.size(2), key_layer.size(-1)) value_layer = value_layer.contiguous().view(-1, value_layer.size(2), value_layer.size(-1)) orig_key_layer = orig_key_layer.contiguous().view(-1, orig_key_layer.size(2), orig_key_layer.size(-1)) + # >>> END AH Changes <<< rel_att = None # Take the dot product between "query" and "key" to get the raw attention scores. @@ -120,25 +126,39 @@ def forward( scale_factor += 1 if "p2c" in self.pos_att_type: scale_factor += 1 - scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor) - attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype) + scale = scaled_size_sqrt(query_layer, scale_factor) + attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype)) if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) + # >>> START AH Changes <<< rel_att = self.disentangled_attention_bias( query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor ) + # >>> END AH Changes <<< if rel_att is not None: - rel_att_padded = torch.zeros_like(attention_scores) - rel_att_padded[:, :, -rel_att.size(2) :] = rel_att - attention_scores = attention_scores + rel_att_padded + # >>> START AH Changes <<< + # rel_att is set to 0 by default, i.e. rel_att is always not None (don't know why HuggingFace does this). + # Hence, we must check whether rel_att is a tensor and if so, pad it with zeros to be able to add it to attention_scores. + if isinstance(rel_att, torch.Tensor): + rel_att_padded = torch.zeros_like(attention_scores) + rel_att_padded[:, :, -rel_att.size(2) :] = rel_att + attention_scores = attention_scores + rel_att_padded + else: + attention_scores = attention_scores + rel_att + # >>> END AH Changes <<< + attention_scores = attention_scores attention_scores = attention_scores.view( -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) ) + attention_mask = attention_mask.bool() + attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min) # bsz x height x length x dimension - attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + attention_probs.masked_fill(attention_mask, 0) + attention_probs = self.dropout(attention_probs) context_layer = torch.bmm( attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer @@ -150,7 +170,6 @@ def forward( ) new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) - if output_attentions: - return (context_layer, attention_probs) - else: - return context_layer + if not output_attentions: + return (context_layer, None) + return (context_layer, attention_probs) From 9edc20d37e7a14e5513266b7da8ab2c1c8a58069 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 8 Jan 2025 18:06:06 +0100 Subject: [PATCH 2/9] Allow saving, loading and pushing adapter compositions together (#771) Closes #441; closes #747. This PR introduces a set of new methods for saving, loading and pushing entire adapter compositions with one command: - `save_adapter_setup()` - `load_adapter_setup()` - `push_adapter_setup_to_hub()` They require two main params: - `adapter_setup`: the adapter composition to be saved. Identical to what can be specified for `active_adapters` - `head_setup`: for models with heads, the head setup to save along with the adapters. Identical to what can be specified for `active_head` Docs [here](https://github.com/adapter-hub/adapters/blob/04e69957a2bfc8093e2593186f7ebb2e71f88ec9/docs/loading.md#saving-and-loading-adapter-compositions) ### Example ```python model = AutoAdapterModel.from_pretrained("roberta-base") # create a complex setup model.add_adapter("a", config=SeqBnConfig()) model.add_adapter("b", config=SeqBnConfig()) model.add_adapter("c", config=SeqBnConfig()) model.add_adapter_fusion(["a", "b"]) model.add_classification_head("head_a") model.add_classification_head("head_b") adapter_setup = Stack(Fuse("a", "b"), "c") head_setup = BatchSplit("head_a", "head_b", batch_sizes=[1, 1]) model.set_active_adapters(adapter_setup) model.active_head = head_setup # save model.save_adapter_setup("checkpoint", adapter_setup, head_setup=head_setup) # push model.push_adapter_setup_to_hub("calpt/random_adapter_setup_test", adapter_setup, head_setup=head_setup) # re-load # model2 = AutoAdapterModel.from_pretrained("roberta-base") # model2.load_adapter_setup("checkpoint", set_active=True) ``` --------- Co-authored-by: Timo Imhof --- docs/adapter_composition.md | 2 + docs/loading.md | 36 ++++ docs/quickstart.md | 2 +- src/adapters/composition.py | 35 ++++ src/adapters/hub_mixin.py | 95 ++++++++- src/adapters/model_mixin.py | 278 ++++++++++++++++++++++++++- src/adapters/utils.py | 11 +- tests/methods/test_adapter_common.py | 43 +++++ tests/test_clip.py | 3 + 9 files changed, 497 insertions(+), 8 deletions(-) diff --git a/docs/adapter_composition.md b/docs/adapter_composition.md index 0e35f7b21b..b55dccef1c 100644 --- a/docs/adapter_composition.md +++ b/docs/adapter_composition.md @@ -125,6 +125,8 @@ model.active_adapters = ac.Fuse("d", "e", "f") To learn how training an _AdapterFusion_ layer works, check out [this Colab notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/03_Adapter_Fusion.ipynb) from the `adapters` repo. +To save and upload the full composition setup with adapters and fusion layer in one line of code, check out the docs on [saving and loading adapter compositions](loading.md#saving-and-loading-adapter-compositions). + ### Retrieving AdapterFusion attentions Finally, it is possible to retrieve the attention scores computed by each fusion layer in a forward pass of the model. diff --git a/docs/loading.md b/docs/loading.md index 8af81820d9..a1a37ed6d8 100644 --- a/docs/loading.md +++ b/docs/loading.md @@ -94,3 +94,39 @@ We will go through the different arguments and their meaning one by one: To load the adapter using a custom name, we can use the `load_as` parameter. - Finally, `set_active` will directly activate the loaded adapter for usage in each model forward pass. Otherwise, you have to manually activate the adapter via `set_active_adapters()`. + +## Saving and loading adapter compositions + +In addition to saving and loading individual adapters, you can also save, load and share entire [compositions of adapters](adapter_composition.md) with a single line of code. +_Adapters_ provides three methods for this purpose that work very similar to those for single adapters: + +- [`save_adapter_setup()`](adapters.ModelWithHeadsAdaptersMixin.save_adapter_setup) to save an adapter composition along with prediction heads to the local file system. +- [`load_adapter_setup()`](adapters.ModelWithHeadsAdaptersMixin.load_adapter_setup) to load a saved adapter composition from the local file system or the Model Hub. +- [`push_adapter_setup_to_hub()`](adapters.hub_mixin.PushAdapterToHubMixin.push_adapter_setup_to_hub) to upload an adapter setup along with prediction heads to the Model Hub. See our [Hugging Face Model Hub guide](huggingface_hub.md) for more. + +As an example, this is how you would save and load an AdapterFusion setup of three adapters with a prediction head: + +```python +# Create an AdapterFusion +model = AutoAdapterModel.from_pretrained("bert-base-uncased") +model.load_adapter("sentiment/sst-2@ukp", config=SeqBnConfig(), with_head=False) +model.load_adapter("nli/multinli@ukp", config=SeqBnConfig(), with_head=False) +model.load_adapter("sts/qqp@ukp", config=SeqBnConfig(), with_head=False) +model.add_adapter_fusion(["sst-2", "mnli", "qqp"]) +model.add_classification_head("clf_head") +adapter_setup = Fuse("sst-2", "mnli", "qqp") +head_setup = "clf_head" +model.set_active_adapters(adapter_setup) +model.active_head = head_setup + +# Train AdapterFusion ... + +# Save +model.save_adapter_setup("checkpoint", adapter_setup, head_setup=head_setup) + +# Push to Hub +model.push_adapter_setup_to_hub("/fusion_setup", adapter_setup, head_setup=head_setup) + +# Re-load +# model.load_adapter_setup("checkpoint", set_active=True) +``` diff --git a/docs/quickstart.md b/docs/quickstart.md index 9cefe33cc1..6e8b7fd49f 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -105,7 +105,7 @@ model = AutoAdapterModel.from_pretrained(example_path) model.load_adapter(example_path) ``` -Similar to how the weights of the full model are saved, the `save_adapter()` will create a file for saving the adapter weights and a file for saving the adapter configuration in the specified directory. +Similar to how the weights of the full model are saved, [`save_adapter()`](adapters.ModelWithHeadsAdaptersMixin.save_adapter) will create a file for saving the adapter weights and a file for saving the adapter configuration in the specified directory. Finally, if we have finished working with adapters, we can restore the base Transformer to its original form by deactivating and deleting the adapter: diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 48a6bc8acf..a44b9c5aac 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -1,4 +1,5 @@ import itertools +import sys import warnings from collections.abc import Sequence from typing import List, Optional, Set, Tuple, Union @@ -45,6 +46,31 @@ def parallel_channels(self): def flatten(self) -> Set[str]: return set(itertools.chain(*[[b] if isinstance(b, str) else b.flatten() for b in self.children])) + def _get_save_kwargs(self): + return None + + def to_dict(self): + save_dict = { + "type": self.__class__.__name__, + "children": [ + c.to_dict() if isinstance(c, AdapterCompositionBlock) else {"type": "single", "children": [c]} + for c in self.children + ], + } + if kwargs := self._get_save_kwargs(): + save_dict["kwargs"] = kwargs + return save_dict + + @classmethod + def from_dict(cls, data): + children = [] + for child in data["children"]: + if child["type"] == "single": + children.append(child["children"][0]) + else: + children.append(cls.from_dict(child)) + return getattr(sys.modules[__name__], data["type"])(*children, **data.get("kwargs", {})) + class Parallel(AdapterCompositionBlock): def __init__(self, *parallel_adapters: List[str]): @@ -80,12 +106,18 @@ def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], s super().__init__(*split_adapters) self.splits = splits if isinstance(splits, list) else [splits] * len(split_adapters) + def _get_save_kwargs(self): + return {"splits": self.splits} + class BatchSplit(AdapterCompositionBlock): def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], batch_sizes: Union[List[int], int]): super().__init__(*split_adapters) self.batch_sizes = batch_sizes if isinstance(batch_sizes, list) else [batch_sizes] * len(split_adapters) + def _get_save_kwargs(self): + return {"batch_sizes": self.batch_sizes} + class Average(AdapterCompositionBlock): def __init__( @@ -105,6 +137,9 @@ def __init__( else: self.weights = [1 / len(average_adapters)] * len(average_adapters) + def _get_save_kwargs(self): + return {"weights": self.weights} + # Mapping each composition block type to the allowed nested types ALLOWED_NESTINGS = { diff --git a/src/adapters/hub_mixin.py b/src/adapters/hub_mixin.py index c23c92eb7e..61942426d9 100644 --- a/src/adapters/hub_mixin.py +++ b/src/adapters/hub_mixin.py @@ -4,6 +4,8 @@ from transformers.utils.generic import working_or_temp_dir +from .composition import AdapterCompositionBlock + logger = logging.getLogger(__name__) @@ -35,7 +37,7 @@ from adapters import AutoAdapterModel model = AutoAdapterModel.from_pretrained("{model_name}") -adapter_name = model.load_adapter("{adapter_repo_name}", set_active=True) +adapter_name = model.{load_fn}("{adapter_repo_name}", set_active=True) ``` ## Architecture & Training @@ -66,6 +68,7 @@ def _save_adapter_card( language: Optional[str] = None, license: Optional[str] = None, metrics: Optional[List[str]] = None, + load_fn: str = "load_adapter", **kwargs, ): # Key remains "adapter-transformers", see: https://github.com/huggingface/huggingface.js/pull/459 @@ -103,6 +106,7 @@ def _save_adapter_card( model_name=self.model_name, dataset_name=dataset_name, head_info=head_info, + load_fn=load_fn, adapter_repo_name=adapter_repo_name, architecture_training=kwargs.pop("architecture_training", DEFAULT_TEXT), results=kwargs.pop("results", DEFAULT_TEXT), @@ -133,8 +137,6 @@ def push_adapter_to_hub( Args: repo_id (str): The name of the repository on the model hub to upload to. adapter_name (str): The name of the adapter to be uploaded. - organization (str, optional): Organization in which to push the adapter - (you must be a member of this organization). Defaults to None. datasets_tag (str, optional): Dataset identifier from https://huggingface.co/datasets. Defaults to None. local_path (str, optional): Local path used as clone directory of the adapter repository. @@ -156,6 +158,8 @@ def push_adapter_to_hub( Branch to push the uploaded files to. commit_description (`str`, *optional*): The description of the commit that will be created + adapter_card_kwargs (Optional[dict], optional): Additional arguments to pass to the adapter card text generation. + Currently includes: tags, language, license, metrics, architecture_training, results, citation. Returns: str: The url of the adapter repository on the model hub. @@ -190,3 +194,88 @@ def push_adapter_to_hub( revision=revision, commit_description=commit_description, ) + + def push_adapter_setup_to_hub( + self, + repo_id: str, + adapter_setup: Union[str, list, AdapterCompositionBlock], + head_setup: Optional[Union[bool, str, list, AdapterCompositionBlock]] = None, + datasets_tag: Optional[str] = None, + local_path: Optional[str] = None, + commit_message: Optional[str] = None, + private: Optional[bool] = None, + token: Optional[Union[bool, str]] = None, + overwrite_adapter_card: bool = False, + create_pr: bool = False, + revision: str = None, + commit_description: str = None, + adapter_card_kwargs: Optional[dict] = None, + ): + """Upload an adapter setup to HuggingFace's Model Hub. + + Args: + repo_id (str): The name of the repository on the model hub to upload to. + adapter_setup (Union[str, list, AdapterCompositionBlock]): The adapter setup to be uploaded. Usually an adapter composition block. + head_setup (Optional[Union[bool, str, list, AdapterCompositionBlock]], optional): The head setup to be uploaded. + datasets_tag (str, optional): Dataset identifier from https://huggingface.co/datasets. Defaults to + None. + local_path (str, optional): Local path used as clone directory of the adapter repository. + If not specified, will create a temporary directory. Defaults to None. + commit_message (:obj:`str`, `optional`): + Message to commit while pushing. Will default to :obj:`"add config"`, :obj:`"add tokenizer"` or + :obj:`"add model"` depending on the type of the class. + private (:obj:`bool`, `optional`): + Whether or not the repository created should be private (requires a paying subscription). + token (`bool` or `str`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True` if `repo_url` + is not specified. + overwrite_adapter_card (bool, optional): Overwrite an existing adapter card with a newly generated one. + If set to `False`, will only generate an adapter card, if none exists. Defaults to False. + create_pr (bool, optional): + Whether or not to create a PR with the uploaded files or directly commit. + revision (`str`, *optional*): + Branch to push the uploaded files to. + commit_description (`str`, *optional*): + The description of the commit that will be created + adapter_card_kwargs (Optional[dict], optional): Additional arguments to pass to the adapter card text generation. + Currently includes: tags, language, license, metrics, architecture_training, results, citation. + + Returns: + str: The url of the adapter repository on the model hub. + """ + use_temp_dir = not os.path.isdir(local_path) if local_path else True + + # Create repo or get retrieve an existing repo + repo_id = self._create_repo(repo_id, private=private, token=token) + + # Commit and push + logger.info('Pushing adapter setup "%s" to model hub at %s ...', adapter_setup, repo_id) + with working_or_temp_dir(working_dir=local_path, use_temp_dir=use_temp_dir) as work_dir: + files_timestamps = self._get_files_timestamps(work_dir) + # Save adapter and optionally create model card + if head_setup is not None: + save_kwargs = {"head_setup": head_setup} + else: + save_kwargs = {} + self.save_adapter_setup(work_dir, adapter_setup, **save_kwargs) + if overwrite_adapter_card or not os.path.exists(os.path.join(work_dir, "README.md")): + adapter_card_kwargs = adapter_card_kwargs or {} + self._save_adapter_card( + work_dir, + str(adapter_setup), + repo_id, + datasets_tag=datasets_tag, + load_fn="load_adapter_setup", + **adapter_card_kwargs, + ) + return self._upload_modified_files( + work_dir, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + create_pr=create_pr, + revision=revision, + commit_description=commit_description, + ) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 659a6cfcff..3154af5ac8 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1,4 +1,5 @@ import inspect +import json import logging import os from abc import ABC, abstractmethod @@ -15,6 +16,7 @@ from transformers.modeling_outputs import ModelOutput from transformers.utils import is_accelerate_available +from . import __version__ from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig from .context import AdapterSetup, ForwardContext @@ -27,7 +29,15 @@ from .methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool from .methods.prompt_tuning import PromptTuningLayer from .methods.reft import init_reft -from .utils import EMBEDDING_FILE, TOKENIZER_PATH, get_adapter_config_hash, inherit_doc, patch_forward +from .utils import ( + EMBEDDING_FILE, + SETUP_CONFIG_NAME, + TOKENIZER_PATH, + get_adapter_config_hash, + inherit_doc, + patch_forward, + resolve_adapter_path, +) from .wrappers.configuration import SUBMODEL_NAMES, init_adapters_config @@ -802,7 +812,7 @@ def load_adapter( adapter_name_or_path (str): can be either: - the identifier of a pre-trained task adapter to be loaded from Adapter Hub - - a path to a directory containing adapter weights saved using `model.saved_adapter()` + - a path to a directory containing adapter weights saved using `model.save_adapter()` - a URL pointing to a zip folder containing a saved adapter module config (dict or str, optional): Deprecated. version (str, optional): The version of the adapter to be loaded. @@ -881,6 +891,161 @@ def load_adapter_fusion( ) return load_name + def _save_adapter_setup_config( + self, + save_directory: str, + adapter_setup: AdapterCompositionBlock, + head_setup: Optional[Union[bool, str, list, AdapterCompositionBlock]] = None, + ): + setup_config = { + "adapter_setup": adapter_setup.to_dict(), + "head_setup": head_setup.to_dict() if isinstance(head_setup, AdapterCompositionBlock) else head_setup, + "version": "adapters." + __version__, + } + with open(join(save_directory, SETUP_CONFIG_NAME), "w") as f: + json.dump(setup_config, f, indent=2) + + def _load_adapter_setup_config( + self, load_directory: str + ) -> Tuple[AdapterCompositionBlock, Optional[AdapterCompositionBlock]]: + with open(join(load_directory, SETUP_CONFIG_NAME), "r") as f: + setup_config = json.load(f) + adapter_setup = AdapterCompositionBlock.from_dict(setup_config["adapter_setup"]) + head_setup = setup_config["head_setup"] + if isinstance(head_setup, dict): + head_setup = AdapterCompositionBlock.from_dict(head_setup) + return adapter_setup, head_setup + + def _save_adapter_setup_weights( + self, + save_directory: str, + adapter_setup: AdapterCompositionBlock, + meta_dict: dict = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + use_safetensors: bool = False, + ): + # Save single adapters + for adapter_name in adapter_setup.flatten(): + save_path = join(save_directory, adapter_name) + self.save_adapter(save_path, adapter_name, meta_dict=meta_dict, use_safetensors=use_safetensors) + # Save adapter fusions + fusions = [] + if isinstance(adapter_setup, Fuse): + fusions.append(adapter_setup) + for child_setup in adapter_setup.children: + if isinstance(child_setup, Fuse): + fusions.append(child_setup) + for fusion in fusions: + save_path = join(save_directory, fusion.name) + self.save_adapter_fusion(save_path, fusion, meta_dict=meta_dict, use_safetensors=use_safetensors) + # Save additional custom weights + if custom_weights_loaders: + for weights_loader in custom_weights_loaders: + weights_loader.save(save_directory, adapter_name) + + def _load_adapter_setup_weights( + self, + load_directory: str, + adapter_setup: AdapterCompositionBlock, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + set_active: bool = False, + use_safetensors: bool = False, + ): + # Load single adapters + for adapter_name in adapter_setup.flatten(): + save_path = join(load_directory, adapter_name) + self.load_adapter(save_path, use_safetensors=use_safetensors) + # Load adapter fusions + fusions = [] + if isinstance(adapter_setup, Fuse): + fusions.append(adapter_setup) + for child_setup in adapter_setup.children: + if isinstance(child_setup, Fuse): + fusions.append(child_setup) + for fusion in fusions: + save_path = join(load_directory, fusion.name) + self.load_adapter_fusion(save_path, use_safetensors=use_safetensors) + # Load additional custom weights + if custom_weights_loaders: + for weights_loader in custom_weights_loaders: + weights_loader.load(load_directory) + + if set_active: + self.set_active_adapters(adapter_setup) + + def save_adapter_setup( + self, + save_directory: str, + adapter_setup: Union[str, list, AdapterCompositionBlock], + meta_dict: dict = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + use_safetensors: bool = False, + ): + """Saves an adapter setup to a directory so that it can be shared or reloaded using `load_adapter_setup()`. + + Args: + save_directory (str): Path to a directory where the adapter setup should be saved. + adapter_setup (Union[str, list, AdapterCompositionBlock]): The adapter setup to be saved. Usually an adapter composition block. + use_safetensors (bool, optional): If True, weights are saved via `safetensors`. Otherwise, the regular torch save method is used. + """ + os.makedirs(save_directory, exist_ok=True) + adapter_setup = parse_composition(adapter_setup, model_type=self.config.model_type) + + self._save_adapter_setup_config(save_directory, adapter_setup) + self._save_adapter_setup_weights( + save_directory, + adapter_setup, + meta_dict=meta_dict, + custom_weights_loaders=custom_weights_loaders, + use_safetensors=use_safetensors, + ) + + def load_adapter_setup( + self, + adapter_setup_name_or_path: str, + version: str = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + set_active: bool = False, + use_safetensors: bool = False, + **kwargs, + ) -> Tuple[AdapterCompositionBlock, Any]: + """Loads an adapter setup from the local file system or a remote location. + + Args: + adapter_setup_name_or_path (str): can be either: + + - the identifier of a repository on the HuggingFace Model Hub. + - a path to a directory containing adapter weights saved using `model.save_adapter_setup()` + - a URL pointing to a zip folder containing a saved adapter module + version (str, optional): The version of the adapter to be loaded. + set_active (bool, optional): + Set the loaded adapter setup to be the active one. By default (False), the adapter setup is loaded but not + activated. + use_safetensors (bool, optional): If True, weights are loaded via `safetensors` if safetensors checkpoint is available. Otherwise, the regular torch save method is used. + + Returns: + Tuple[AdapterCompositionBlock, Any]: The loaded adapter setup and the head setup if available. + """ + resolved_folder = resolve_adapter_path( + adapter_setup_name_or_path, + version=version, + do_exists_check=False, + **kwargs, + ) + adapter_setup, head_setup = self._load_adapter_setup_config(resolved_folder) + self._load_adapter_setup_weights( + resolved_folder, + adapter_setup, + custom_weights_loaders=custom_weights_loaders, + set_active=set_active, + use_safetensors=use_safetensors, + ) + + if head_setup: + logger.warning("Loaded adapter setup contains a head setup that is not supported by the current model.") + + return adapter_setup, head_setup + def save_all_adapters( self, save_directory: str, @@ -1857,6 +2022,115 @@ def load_adapter_fusion( **kwargs, ) + def save_adapter_setup( + self, + save_directory: str, + adapter_setup: Union[str, list, AdapterCompositionBlock], + head_setup: Optional[Union[bool, str, list, AdapterCompositionBlock]] = None, + meta_dict: dict = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + use_safetensors: bool = False, + ): + """Saves an adapter setup to a directory so that it can be shared or reloaded using `load_adapter_setup()`. + + Args: + save_directory (str): Path to a directory where the adapter setup should be saved. + adapter_setup (Union[str, list, AdapterCompositionBlock]): The adapter setup to be saved. Usually an adapter composition block. + head_setup (Optional[Union[bool, str, list, AdapterCompositionBlock]], optional): The head setup to be saved. Can be either: + + - True: save the default head for models without flex heads. + - str: save a single head with the given name. + - list: save a list of heads. + - AdapterCompositionBlock: save a custom head setup. + - None (default): do not save any heads. + use_safetensors (bool, optional): If True, weights are saved via `safetensors`. Otherwise, the regular torch save method is used. + """ + os.makedirs(save_directory, exist_ok=True) + adapter_setup = parse_composition(adapter_setup, model_type=self.config.model_type) + + self._save_adapter_setup_config(save_directory, adapter_setup, head_setup) + self._save_adapter_setup_weights( + save_directory, + adapter_setup, + meta_dict=meta_dict, + custom_weights_loaders=custom_weights_loaders, + use_safetensors=use_safetensors, + ) + + if head_setup is True: + self.save_head(save_directory, use_safetensors=use_safetensors) + elif head_setup: + heads_to_save = [] + if isinstance(head_setup, AdapterCompositionBlock): + heads_to_save = head_setup.flatten() + elif isinstance(head_setup, list): + heads_to_save = head_setup + elif isinstance(head_setup, str): + heads_to_save = [head_setup] + for head_name in heads_to_save: + save_path = join(save_directory, head_name) + self.save_head(save_path, head_name, use_safetensors=use_safetensors) + + def load_adapter_setup( + self, + adapter_setup_name_or_path: str, + version: str = None, + custom_weights_loaders: Optional[List[WeightsLoader]] = None, + set_active: bool = False, + use_safetensors: bool = False, + **kwargs, + ) -> str: + """Loads an adapter setup from the local file system or a remote location. + + Args: + adapter_setup_name_or_path (str): can be either: + + - the identifier of a repository on the HuggingFace Model Hub. + - a path to a directory containing adapter weights saved using `model.save_adapter_setup()` + - a URL pointing to a zip folder containing a saved adapter module + version (str, optional): The version of the adapter to be loaded. + set_active (bool, optional): + Set the loaded adapter setup to be the active one. By default (False), the adapter setup is loaded but not + activated. + use_safetensors (bool, optional): If True, weights are loaded via `safetensors` if safetensors checkpoint is available. Otherwise, the regular torch save method is used. + + Returns: + Tuple[AdapterCompositionBlock, Any]: The loaded adapter setup and the head setup if available. + """ + resolved_folder = resolve_adapter_path( + adapter_setup_name_or_path, + version=version, + do_exists_check=False, + **kwargs, + ) + adapter_setup, head_setup = self._load_adapter_setup_config(resolved_folder) + self._load_adapter_setup_weights( + resolved_folder, + adapter_setup, + custom_weights_loaders=custom_weights_loaders, + set_active=set_active, + use_safetensors=use_safetensors, + ) + + if head_setup is True: + self.load_head(resolved_folder, use_safetensors=use_safetensors) + elif head_setup: + heads_to_load = [] + if isinstance(head_setup, AdapterCompositionBlock): + heads_to_load = head_setup.flatten() + elif isinstance(head_setup, list): + heads_to_load = head_setup + elif isinstance(head_setup, str): + heads_to_load = [head_setup] + for head_name in heads_to_load: + save_path = join(resolved_folder, head_name) + self.load_head(save_path, head_name, use_safetensors=use_safetensors) + + if set_active: + self.active_head = head_setup + + return adapter_setup, head_setup + def save_all_heads(self, save_directory: str, use_safetensors: bool = False): """Saves all prediction heads of this model to subfolders of the given location. diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 1103d9fffb..7c0540850a 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -53,6 +53,7 @@ SAFE_ADAPTERFUSION_WEIGHTS_NAME = "model_adapter_fusion.safetensors" EMBEDDING_FILE = "embedding.pt" TOKENIZER_PATH = "tokenizer" +SETUP_CONFIG_NAME = "adapter_setup.json" ADAPTER_HUB_URL = "https://raw.githubusercontent.com/Adapter-Hub/Hub/master/dist/v2/" ADAPTER_HUB_INDEX_FILE = ADAPTER_HUB_URL + "index/{}.json" @@ -671,6 +672,7 @@ def resolve_adapter_path( model_name: str = None, adapter_config: Union[dict, str] = None, version: str = None, + do_exists_check: bool = True, **kwargs, ) -> str: """ @@ -701,8 +703,13 @@ def resolve_adapter_path( # path to a local folder saved using save() elif isdir(adapter_name_or_path): if ( - isfile(join(adapter_name_or_path, WEIGHTS_NAME)) or isfile(join(adapter_name_or_path, SAFE_WEIGHTS_NAME)) - ) and isfile(join(adapter_name_or_path, CONFIG_NAME)): + not do_exists_check + or ( + isfile(join(adapter_name_or_path, WEIGHTS_NAME)) + or isfile(join(adapter_name_or_path, SAFE_WEIGHTS_NAME)) + ) + and isfile(join(adapter_name_or_path, CONFIG_NAME)) + ): return adapter_name_or_path else: raise EnvironmentError( diff --git a/tests/methods/test_adapter_common.py b/tests/methods/test_adapter_common.py index 1ea6cd6f37..717d3af98e 100644 --- a/tests/methods/test_adapter_common.py +++ b/tests/methods/test_adapter_common.py @@ -1,4 +1,5 @@ import copy +import os import tempfile import torch @@ -17,8 +18,10 @@ MAMConfig, SeqBnConfig, SeqBnInvConfig, + Stack, ) from adapters.heads.language_modeling import CausalLMHead +from adapters.utils import SETUP_CONFIG_NAME from transformers import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, CLIPConfig from transformers.testing_utils import require_torch, torch_device @@ -475,3 +478,43 @@ def test_batch_split_training(self): base_with_change |= not torch.equal(v1, v2) self.assertTrue(adapters_with_change) self.assertFalse(base_with_change) + + def test_load_adapter_setup(self): + if self.config_class not in ADAPTER_MODEL_MAPPING: + self.skipTest("Does not support flex heads.") + model1, model2 = create_twin_models(self.model_class, self.config) + + # Create a complex setup + model1.add_adapter("a", config=SeqBnConfig()) + model1.add_adapter("b", config=SeqBnConfig()) + model1.add_adapter("c", config=SeqBnConfig()) + model1.add_adapter_fusion(["a", "b"]) + self.add_head(model1, "head_a") + self.add_head(model1, "head_b") + adapter_setup = Stack(Fuse("a", "b"), "c") + head_setup = BatchSplit("head_a", "head_b", batch_sizes=[2, 1]) + model1.set_active_adapters(adapter_setup) + model1.active_head = head_setup + + with tempfile.TemporaryDirectory() as temp_dir: + model1.save_adapter_setup(temp_dir, adapter_setup, head_setup=head_setup) + + self.assertTrue(os.path.exists(os.path.join(temp_dir, SETUP_CONFIG_NAME))) + + # also tests that set_active works + model2.load_adapter_setup(temp_dir, set_active=True) + + # check if adapter was correctly loaded + for name in ["a", "b", "c"]: + self.assertTrue(name in model2.adapters_config) + self.assertEqual(adapter_setup, model2.active_adapters) + + # check equal output + input_data = self.get_input_samples(config=model1.config) + model1.to(torch_device) + model2.to(torch_device) + output1 = model1(**input_data) + output2 = model2(**input_data) + self.assertEqual(len(output1), len(output2)) + self.assertTrue(torch.allclose(output1[0][0], output2[0][0], atol=1e-4)) + self.assertTrue(torch.allclose(output1[1][0], output2[1][0], atol=1e-4)) diff --git a/tests/test_clip.py b/tests/test_clip.py index 30be353f74..ead9c7d561 100644 --- a/tests/test_clip.py +++ b/tests/test_clip.py @@ -225,3 +225,6 @@ class CLIPAdapterTest( def test_adapter_fusion_save_with_head(self): # This test is not applicable to CLIP self.skipTest("Not applicable to CLIP.") + + def test_load_adapter_setup(self): + self.skipTest("Not applicable to CLIP.") From 303c34bdd91f37e656fca5f60624d86b991def3c Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 8 Jan 2025 18:29:24 +0100 Subject: [PATCH 3/9] Custom names for AdapterFusion layers (#774) Resolves #316. This PR implements the option to specify a custom name for an added AdapterFusion layer. The name can be specified when adding a fusion layer like this: ```python model.add_adapter_fusion(["adapter1", "adapter2"], name="custom_name_fusion") ``` Afterwards, to address the custom-name fusion, specify the name in the `Fuse` block. E.g. for activation: ```python model.set_active_adapters(Fuse("adapter1", "adapter2", name="custom_name_fusion")) ``` Some fusion-specific methods can either take the named `Fuse` block or directly the fusion name: ```python # saving model.save_adapter_fusion("./checkpoint_dir", Fuse("adapter1", "adapter2", name="custom_name_fusion")) # or: # model.save_adapter_fusion("./checkpoint_dir", "custom_name_fusion") # deleting model.delete_adapter_fusion(Fuse("adapter1", "adapter2", name="custom_name_fusion")) # or: # model.delete_adapter_fusion("custom_name_fusion") ``` --------- Co-authored-by: Timo Imhof --- src/adapters/composition.py | 8 +- .../configuration/model_adapters_config.py | 30 +++++-- src/adapters/loading.py | 12 ++- src/adapters/methods/bottleneck.py | 8 +- src/adapters/model_mixin.py | 27 +++--- tests/test_adapter_fusion_common.py | 83 +++++++++++++++++++ 6 files changed, 140 insertions(+), 28 deletions(-) diff --git a/src/adapters/composition.py b/src/adapters/composition.py index a44b9c5aac..6c17fb8ebd 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -92,13 +92,17 @@ def __init__(self, *stack_layers: List[Union[AdapterCompositionBlock, str]]): class Fuse(AdapterCompositionBlock): - def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]]): + def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]], name: Optional[str] = None): super().__init__(*fuse_stacks) + self._name = name # TODO-V2 pull this up to all block classes? @property def name(self): - return ",".join([c if isinstance(c, str) else c.last() for c in self.children]) + if self._name: + return self._name + else: + return ",".join([c if isinstance(c, str) else c.last() for c in self.children]) class Split(AdapterCompositionBlock): diff --git a/src/adapters/configuration/model_adapters_config.py b/src/adapters/configuration/model_adapters_config.py index 3ae7dcf56c..f742028b67 100644 --- a/src/adapters/configuration/model_adapters_config.py +++ b/src/adapters/configuration/model_adapters_config.py @@ -1,7 +1,7 @@ import copy import logging from collections.abc import Collection, Mapping -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from .. import __version__ from ..composition import AdapterCompositionBlock @@ -27,6 +27,7 @@ def __init__(self, **kwargs): self.fusions: Mapping[str, str] = kwargs.pop("fusions", {}) self.fusion_config_map = kwargs.pop("fusion_config_map", {}) + self.fusion_name_map = kwargs.pop("fusion_name_map", {}) # TODO-V2 Save this with config? self.active_setup: Optional[AdapterCompositionBlock] = None @@ -131,7 +132,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None): self.adapters[adapter_name] = config_name logger.info(f"Adding adapter '{adapter_name}'.") - def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: + def get_fusion(self, fusion_name: Union[str, List[str]]) -> Tuple[Optional[dict], Optional[list]]: """ Gets the config dictionary for a given AdapterFusion. @@ -140,6 +141,7 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: Returns: Optional[dict]: The AdapterFusion configuration. + Optional[list]: The names of the adapters to fuse. """ if isinstance(fusion_name, list): fusion_name = ",".join(fusion_name) @@ -149,20 +151,31 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: config = self.fusion_config_map.get(config_name, None) else: config = ADAPTERFUSION_CONFIG_MAP.get(config_name, None) + + if fusion_name in self.fusion_name_map: + adapter_names = self.fusion_name_map[fusion_name] + else: + adapter_names = fusion_name.split(",") + + return config, adapter_names else: - config = None - return config + return None, None - def add_fusion(self, fusion_name: Union[str, List[str]], config: Optional[Union[str, dict]] = None): + def add_fusion( + self, adapter_names: List[str], config: Optional[Union[str, dict]] = None, fusion_name: Optional[str] = None + ): """ Adds a new AdapterFusion. Args: - fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse. + adapter_names (List[str]): The names of the adapters to fuse. config (Optional[Union[str, dict]], optional): AdapterFusion config. Defaults to None. + fusion_name (Optional[str], optional): The name of the AdapterFusion. If not specified, will default to comma-separated adapter names. """ - if isinstance(fusion_name, list): - fusion_name = ",".join(fusion_name) + if fusion_name is None: + fusion_name = ",".join(adapter_names) + else: + self.fusion_name_map[fusion_name] = adapter_names if fusion_name in self.fusions: raise ValueError(f"An AdapterFusion with the name '{fusion_name}' has already been added.") if config is None: @@ -218,6 +231,7 @@ def to_dict(self): output_dict["fusion_config_map"][k] = v.to_dict() else: output_dict["fusion_config_map"][k] = copy.deepcopy(v) + output_dict["fusion_name_map"] = copy.deepcopy(self.fusion_name_map) return output_dict def __eq__(self, other): diff --git a/src/adapters/loading.py b/src/adapters/loading.py index 69747e04cb..55ba1db45b 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -639,7 +639,7 @@ def save_to_state_dict(self, name: str): if name not in self.model.adapters_config.fusions: raise ValueError(f"No AdapterFusion with name '{name}' available.") - adapter_fusion_config = self.model.adapters_config.get_fusion(name) + adapter_fusion_config, _ = self.model.adapters_config.get_fusion(name) config_dict = build_full_config( adapter_fusion_config, @@ -676,13 +676,14 @@ def save(self, save_directory: str, name: str, meta_dict=None): else: assert isdir(save_directory), "Saving path should be a directory where the head can be saved." - adapter_fusion_config = self.model.adapters_config.get_fusion(name) + adapter_fusion_config, adapter_names = self.model.adapters_config.get_fusion(name) # Save the adapter fusion configuration config_dict = build_full_config( adapter_fusion_config, self.model.config, name=name, + adapter_names=adapter_names, model_name=self.model.model_name, model_class=self.model.__class__.__name__, ) @@ -746,9 +747,14 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs): config = self.weights_helper.load_weights_config(save_directory) adapter_fusion_name = load_as or config["name"] + adapter_names = config.get("adapter_names", adapter_fusion_name) if adapter_fusion_name not in self.model.adapters_config.fusions: self.model.add_adapter_fusion( - adapter_fusion_name, config["config"], overwrite_ok=True, set_active=kwargs.pop("set_active", True) + adapter_names, + config["config"], + name=adapter_fusion_name, + overwrite_ok=True, + set_active=kwargs.pop("set_active", True), ) else: logger.warning("Overwriting existing adapter fusion module '{}'".format(adapter_fusion_name)) diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index ff12a91cd7..889941d2b9 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -96,9 +96,9 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: def add_fusion_layer(self, adapter_names: Union[List, str]): """See BertModel.add_fusion_layer""" - adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",") + fusion_name = ",".join(adapter_names) if isinstance(adapter_names, list) else adapter_names + fusion_config, adapter_names = self.adapters_config.get_fusion(fusion_name) if self.adapters_config.common_config_value(adapter_names, self.location_key): - fusion_config = self.adapters_config.get_fusion(adapter_names) dropout_prob = fusion_config.dropout_prob or getattr(self.model_config, "attention_probs_dropout_prob", 0) fusion = BertFusion( fusion_config, @@ -106,7 +106,7 @@ def add_fusion_layer(self, adapter_names: Union[List, str]): dropout_prob, ) fusion.train(self.training) # make sure training mode is consistent - self.adapter_fusion_layer[",".join(adapter_names)] = fusion + self.adapter_fusion_layer[fusion_name] = fusion def delete_fusion_layer(self, adapter_names: Union[List, str]): adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names) @@ -223,7 +223,7 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0 context = ForwardContext.get_context() # config of _last_ fused adapter is significant - fusion_config = self.adapters_config.get_fusion(adapter_setup.name) + fusion_config, _ = self.adapters_config.get_fusion(adapter_setup.name) last = adapter_setup.last() last_adapter = self.adapters[last] hidden_states, query, residual = last_adapter.pre_forward( diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 3154af5ac8..62de6178ac 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -638,6 +638,7 @@ def add_adapter_fusion( self, adapter_names: Union[Fuse, list, str], config=None, + name: str = None, overwrite_ok: bool = False, set_active: bool = False, ): @@ -655,6 +656,8 @@ def add_adapter_fusion( - a string identifying a pre-defined adapter fusion configuration - a dictionary representing the adapter fusion configuration - the path to a file containing the adapter fusion configuration + name (str, optional): + Name of the AdapterFusion layer. If not specified, the name is generated automatically from the fused adapter names. overwrite_ok (bool, optional): Overwrite an AdapterFusion layer with the same name if it exists. By default (False), an exception is thrown. @@ -662,22 +665,24 @@ def add_adapter_fusion( Activate the added AdapterFusion. By default (False), the AdapterFusion is added but not activated. """ if isinstance(adapter_names, Fuse): + if name is None: + name = adapter_names.name adapter_names = adapter_names.children elif isinstance(adapter_names, str): adapter_names = adapter_names.split(",") + if name is None: + name = ",".join(adapter_names) if isinstance(config, dict): config = AdapterFusionConfig.from_dict(config) # ensure config is ok and up-to-date # In case adapter already exists and we allow overwriting, explicitly delete the existing one first - if overwrite_ok and self.adapters_config.get_fusion(adapter_names) is not None: - self.delete_adapter_fusion(adapter_names) - self.adapters_config.add_fusion(adapter_names, config=config) - self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names)) - self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(adapter_names)) + if overwrite_ok and self.adapters_config.get_fusion(name)[0] is not None: + self.delete_adapter_fusion(name) + self.adapters_config.add_fusion(adapter_names, config=config, fusion_name=name) + self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(name)) + self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(name)) if set_active: - if not isinstance(adapter_names, list): - adapter_names = adapter_names.split(",") - self.set_active_adapters(Fuse(*adapter_names)) + self.set_active_adapters(Fuse(*adapter_names, name=name)) def delete_adapter(self, adapter_name: str): """ @@ -710,7 +715,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]): adapter_names (Union[Fuse, list, str]): AdapterFusion layer to delete. """ if isinstance(adapter_names, Fuse): - adapter_fusion_name = ",".join(adapter_names.children) + adapter_fusion_name = adapter_names.name elif isinstance(adapter_names, list): adapter_fusion_name = ",".join(adapter_names) elif isinstance(adapter_names, str): @@ -776,7 +781,7 @@ def save_adapter_fusion( ValueError: If the given AdapterFusion name is invalid. """ if isinstance(adapter_names, Fuse): - adapter_fusion_name = ",".join(adapter_names.children) + adapter_fusion_name = adapter_names.name elif isinstance(adapter_names, list): adapter_fusion_name = ",".join(adapter_names) elif isinstance(adapter_names, str): @@ -1094,7 +1099,7 @@ def save_all_adapter_fusions( """ os.makedirs(save_directory, exist_ok=True) for name in self.adapters_config.fusions: - adapter_fusion_config = self.adapters_config.get_fusion(name) + adapter_fusion_config, _ = self.adapters_config.get_fusion(name) h = get_adapter_config_hash(adapter_fusion_config) save_path = join(save_directory, name) if meta_dict: diff --git a/tests/test_adapter_fusion_common.py b/tests/test_adapter_fusion_common.py index ccc860f667..695808eb24 100644 --- a/tests/test_adapter_fusion_common.py +++ b/tests/test_adapter_fusion_common.py @@ -214,3 +214,86 @@ def test_output_adapter_fusion_attentions(self): self.assertEqual(len(per_layer_scores), 1) for k, v in per_layer_scores.items(): self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k) + + def test_add_adapter_fusion_custom_name(self): + config_name = "seq_bn" + model = self.get_model() + model.eval() + + name1 = f"{config_name}-1" + name2 = f"{config_name}-2" + model.add_adapter(name1, config=config_name) + model.add_adapter(name2, config=config_name) + + # adapter is correctly added to config + self.assertTrue(name1 in model.adapters_config) + self.assertTrue(name2 in model.adapters_config) + + # add fusion with default name + model.add_adapter_fusion([name1, name2]) + model.to(torch_device) + + # check forward pass + input_data = self.get_input_samples(config=model.config) + model.set_active_adapters(Fuse(name1, name2)) + fusion_default_ref_output = model(**input_data) + + # add fusion with custom name + model.add_adapter_fusion([name1, name2], name="custom_name_fusion") + model.to(torch_device) + + self.assertIn(f"{name1},{name2}", model.adapters_config.fusions) + self.assertIn("custom_name_fusion", model.adapters_config.fusions) + self.assertIn("custom_name_fusion", model.adapters_config.fusion_name_map) + + # check forward pass + model.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion")) + fusion_custom_output = model(**input_data) + model.set_active_adapters(Fuse(name1, name2)) + fusion_default_output = model(**input_data) + model.set_active_adapters(None) + base_output = model(**input_data) + + self.assertFalse(torch.equal(fusion_default_ref_output[0], base_output[0])) + self.assertTrue(torch.equal(fusion_default_ref_output[0], fusion_default_output[0])) + self.assertFalse(torch.equal(fusion_custom_output[0], fusion_default_output[0])) + self.assertFalse(torch.equal(fusion_custom_output[0], base_output[0])) + + # delete only the custom fusion + model.delete_adapter_fusion(Fuse(name1, name2, name="custom_name_fusion")) + # model.delete_adapter_fusion("custom_name_fusion") + + self.assertIn(f"{name1},{name2}", model.adapters_config.fusions) + self.assertNotIn("custom_name_fusion", model.adapters_config.fusions) + + def test_load_adapter_fusion_custom_name(self): + model1 = self.get_model() + model1.eval() + + name1 = "name1" + name2 = "name2" + model1.add_adapter(name1) + model1.add_adapter(name2) + + model2 = copy.deepcopy(model1) + model2.eval() + + model1.add_adapter_fusion([name1, name2], name="custom_name_fusion") + model1.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion")) + + with tempfile.TemporaryDirectory() as temp_dir: + model1.save_adapter_fusion(temp_dir, "custom_name_fusion") + # also tests that set_active works + model2.load_adapter_fusion(temp_dir, set_active=True) + + # check if adapter was correctly loaded + self.assertEqual(model1.adapters_config.fusions.keys(), model2.adapters_config.fusions.keys()) + + # check equal output + in_data = self.get_input_samples(config=model1.config) + model1.to(torch_device) + model2.to(torch_device) + output1 = model1(**in_data) + output2 = model2(**in_data) + self.assertEqual(len(output1), len(output2)) + self.assertTrue(torch.equal(output1[0], output2[0])) From 66ff58305ca5fb2d42e6235f769427830e6734d1 Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 18 Jan 2025 20:13:21 +0100 Subject: [PATCH 4/9] ReFT generate & orthogonal fixes (#778) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves #772. Provides the following fixes for the ReFT implementation: - support cases where seq_len < prefix/ suffix position (fixes issues with seq generation) - ensure orthogonal projection is always initialized with float32 (as half precision is not supported) --------- Co-authored-by: Leon Engländer --- src/adapters/methods/reft.py | 24 ++++++++++++++++------- tests/methods/base.py | 27 ++++++++++++++++++++++++++ tests/methods/test_compacter.py | 30 +++-------------------------- tests/methods/test_prefix_tuning.py | 28 ++------------------------- tests/methods/test_reft.py | 3 +++ 5 files changed, 52 insertions(+), 60 deletions(-) diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py index 9c6647e399..a847b2b60b 100644 --- a/src/adapters/methods/reft.py +++ b/src/adapters/methods/reft.py @@ -1,3 +1,4 @@ +import logging from typing import List, Optional import torch @@ -9,6 +10,9 @@ from .modeling import Activation_Function_Class +logger = logging.getLogger(__name__) + + class ReftUnit(nn.Module): def __init__( self, @@ -26,6 +30,13 @@ def __init__( projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype) if orthogonal: + # orthogonal is not implemented for half precision + if dtype in [torch.float16, torch.bfloat16]: + logger.warning( + "Orthogonal parametrization is not supported for half precision dtypes. Converting REFT projection layer to float32.", + UserWarning, + ) + projection = projection.to(dtype=torch.float32) self.projection = nn.utils.parametrizations.orthogonal(projection) else: self.projection = projection @@ -93,19 +104,18 @@ def _gather_adapted_states(self, hidden_states: torch.Tensor): ) # create indexing matrices for prefixes & suffixes if self.prefix_positions > 0: + real_pref_len = min(self.prefix_positions, hidden_states.size(1)) pref_idx = first_non_padding.view(-1, 1, 1) + ( - torch.arange(self.prefix_positions) - .unsqueeze(-1) - .expand(bsz, self.prefix_positions, ddim) - .to(hidden_states.device) + torch.arange(real_pref_len).unsqueeze(-1).expand(bsz, real_pref_len, ddim).to(hidden_states.device) ) # Cache for next layer context.pref_idx = pref_idx if self.suffix_positions > 0: + real_suff_len = min(self.suffix_positions, hidden_states.size(1)) suff_idx = last_non_padding.view(-1, 1, 1) + ( - torch.arange(-self.suffix_positions, 0) + torch.arange(-real_suff_len, 0) .unsqueeze(-1) - .expand(bsz, self.suffix_positions, ddim) + .expand(bsz, real_suff_len, ddim) .to(hidden_states.device) ) context.suff_idx = suff_idx @@ -131,7 +141,7 @@ def _scatter_adapted_states(self, hidden_states: torch.Tensor, adapted_states: L context = ForwardContext.get_context() # merge prefix, suffix and adapted states - adapted_output = torch.cat(adapted_states, dim=1) + adapted_output = torch.cat(adapted_states, dim=1).to(hidden_states.dtype) if self.prefix_positions > 0: hidden_states = torch.scatter( diff --git a/tests/methods/base.py b/tests/methods/base.py index 0d20f32fef..338cf970aa 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -370,3 +370,30 @@ def run_reset_test(self, adapter_config): # check forward pass self.assertEqual(len(output_1), len(output_2)) self.assertTrue(torch.allclose(output_1[0], output_2[0], atol=1e-3)) + + def run_generate_test(self, adapter_config): + if self.config_class not in ADAPTER_MODEL_MAPPING or ( + "seq2seq_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types + and "causal_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types + ): + self.skipTest("No seq2seq or causal language model head") + + model1 = AutoAdapterModel.from_config(self.config()) + model1.add_adapter("dummy", config=adapter_config) + if "seq2seq_lm" in ADAPTER_MODEL_MAPPING[self.config_class].head_types: + model1.add_seq2seq_lm_head("dummy") + else: + model1.add_causal_lm_head("dummy") + model1.set_active_adapters("dummy") + model1.to(torch_device) + + seq_output_length = 32 + + # Finally, also check if generation works properly + if self.is_speech_model: + input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] + else: + input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] + input_ids = input_ids.to(torch_device) + generated = model1.generate(input_ids, max_length=seq_output_length) + self.assertLessEqual(generated.shape, (1, seq_output_length)) diff --git a/tests/methods/test_compacter.py b/tests/methods/test_compacter.py index 292fab1efb..06b3a346e1 100644 --- a/tests/methods/test_compacter.py +++ b/tests/methods/test_compacter.py @@ -1,5 +1,5 @@ -from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel, CompacterPlusPlusConfig -from transformers.testing_utils import require_torch, torch_device +from adapters import CompacterPlusPlusConfig +from transformers.testing_utils import require_torch from .base import AdapterMethodBaseTestMixin @@ -53,28 +53,4 @@ def test_train_shared_phm_compacter(self): self.run_train_test(adapter_config, ["adapters.{name}."]) def test_compacter_generate(self): - if self.config_class not in ADAPTER_MODEL_MAPPING or ( - "seq2seq_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types - and "causal_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types - ): - self.skipTest("No seq2seq or causal language model head") - - model1 = AutoAdapterModel.from_config(self.config()) - model1.add_adapter("dummy", config=CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8)) - if "seq2seq_lm" in ADAPTER_MODEL_MAPPING[self.config_class].head_types: - model1.add_seq2seq_lm_head("dummy") - else: - model1.add_causal_lm_head("dummy") - model1.set_active_adapters("dummy") - model1.to(torch_device) - - seq_output_length = 32 - - # Finally, also check if generation works properly - if self.is_speech_model: - input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] - input_ids = input_ids.to(torch_device) - generated = model1.generate(input_ids, max_length=seq_output_length) - self.assertLessEqual(generated.shape, (1, seq_output_length)) + self.run_generate_test(CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8)) diff --git a/tests/methods/test_prefix_tuning.py b/tests/methods/test_prefix_tuning.py index dd443c0d0b..d5765771ff 100644 --- a/tests/methods/test_prefix_tuning.py +++ b/tests/methods/test_prefix_tuning.py @@ -1,6 +1,6 @@ import torch -from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel, PrefixTuningConfig +from adapters import PrefixTuningConfig from transformers import CLIPConfig from transformers.testing_utils import require_torch, torch_device @@ -76,28 +76,4 @@ def test_eject_prefix(self): self.assertTrue(torch.allclose(output_1[0], output_2[0], atol=1e-4)) def test_prefix_tuning_generate(self): - if self.config_class not in ADAPTER_MODEL_MAPPING or ( - "seq2seq_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types - and "causal_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types - ): - self.skipTest("No seq2seq or causal language model head") - - model1 = AutoAdapterModel.from_config(self.config()) - model1.add_adapter("dummy", config="prefix_tuning") - if "seq2seq_lm" in ADAPTER_MODEL_MAPPING[self.config_class].head_types: - model1.add_seq2seq_lm_head("dummy") - else: - model1.add_causal_lm_head("dummy") - model1.set_active_adapters("dummy") - model1.to(torch_device) - - seq_output_length = 32 - - # Finally, also check if generation works properly - if self.is_speech_model: - input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] - input_ids = input_ids.to(torch_device) - generated = model1.generate(input_ids, max_length=seq_output_length) - self.assertLessEqual(generated.shape, (1, seq_output_length)) + self.run_generate_test(PrefixTuningConfig()) diff --git a/tests/methods/test_reft.py b/tests/methods/test_reft.py index 8849221808..f89fe18bea 100644 --- a/tests/methods/test_reft.py +++ b/tests/methods/test_reft.py @@ -77,3 +77,6 @@ def test_load_full_model_reft(self): def test_train_loreft(self): self.run_train_test(LoReftConfig(), ["refts.{name}."]) + + def test_reft_generate(self): + self.run_generate_test(LoReftConfig()) From 127f51be8a74939fe1e8de6e48c903df388b7e52 Mon Sep 17 00:00:00 2001 From: calpt Date: Sat, 18 Jan 2025 20:13:41 +0100 Subject: [PATCH 5/9] Readme & test fixes (#780) Changes in this PR: - set minimum supported Python version to 3.9 (following recent Transformers upgrade) - reduce number of warnings in tests - add `ignore_cleanup_errors=True` in full model loading tests to allow running on Windows - minor fixes in Readme --- README.md | 6 ++--- docs/installation.md | 2 +- setup.py | 2 +- src/adapters/loading.py | 4 ++-- src/adapters/model_mixin.py | 2 +- src/adapters/trainer.py | 28 +++++++++++++++------- tests/composition/test_parallel.py | 2 +- tests/extended/test_adapter_trainer_ext.py | 2 +- tests/methods/base.py | 8 +++---- tests/test_adapter_conversion.py | 8 +++---- tests/test_adapter_embeddings.py | 2 +- tests/test_adapter_fusion_common.py | 2 +- tests/test_adapter_heads.py | 2 +- tests/test_adapter_hub.py | 2 +- tests/test_adapter_trainer.py | 8 +++---- 15 files changed, 46 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 9375cfc220..4a8c34d93b 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ A Unified Library for Parameter-Efficient and Modular Transfer Learning Paper -![Tests](https://github.com/Adapter-Hub/adapters/workflows/Tests/badge.svg?branch=adapters) +![Tests](https://github.com/Adapter-Hub/adapters/workflows/Tests/badge.svg) [![GitHub](https://img.shields.io/github/license/adapter-hub/adapters.svg?color=blue)](https://github.com/adapter-hub/adapters/blob/main/LICENSE) [![PyPI](https://img.shields.io/pypi/v/adapters)](https://pypi.org/project/adapters/) @@ -45,7 +45,7 @@ _Adapters_ provides a unified interface for efficient fine-tuning and modular tr ## Installation -`adapters` currently supports **Python 3.8+** and **PyTorch 1.10+**. +`adapters` currently supports **Python 3.9+** and **PyTorch 2.0+**. After [installing PyTorch](https://pytorch.org/get-started/locally/), you can install `adapters` from PyPI ... ``` @@ -147,7 +147,7 @@ Currently, adapters integrates all architectures and methods listed below: | Method | Paper(s) | Quick Links | | --- | --- | --- | -| Bottleneck adapters | [Houlsby et al. (2019)](https://arxiv.org/pdf/1902.00751.pdf)
[Bapna and Firat (2019)](https://arxiv.org/pdf/1909.08478.pdf) | [Quickstart](https://docs.adapterhub.ml/quickstart.html), [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb) | +| Bottleneck adapters | [Houlsby et al. (2019)](https://arxiv.org/pdf/1902.00751.pdf)
[Bapna and Firat (2019)](https://arxiv.org/pdf/1909.08478.pdf)
[Steitz and Roth (2024)](https://openaccess.thecvf.com/content/CVPR2024/papers/Steitz_Adapters_Strike_Back_CVPR_2024_paper.pdf) | [Quickstart](https://docs.adapterhub.ml/quickstart.html), [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/01_Adapter_Training.ipynb) | | AdapterFusion | [Pfeiffer et al. (2021)](https://aclanthology.org/2021.eacl-main.39.pdf) | [Docs: Training](https://docs.adapterhub.ml/training.html#train-adapterfusion), [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/03_Adapter_Fusion.ipynb) | | MAD-X,
Invertible adapters | [Pfeiffer et al. (2020)](https://aclanthology.org/2020.emnlp-main.617/) | [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/04_Cross_Lingual_Transfer.ipynb) | | AdapterDrop | [Rücklé et al. (2021)](https://arxiv.org/pdf/2010.11918.pdf) | [Notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/05_Adapter_Drop_Training.ipynb) | diff --git a/docs/installation.md b/docs/installation.md index c3b8468eb8..51a5eaa3b0 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,7 +1,7 @@ # Installation The `adapters` package is designed as an add-on for Hugging Face's Transformers library. -It currently supports Python 3.8+ and PyTorch 1.10+. You will have to [install PyTorch](https://pytorch.org/get-started/locally/) first. +It currently supports Python 3.9+ and PyTorch 2.0+. You will have to [install PyTorch](https://pytorch.org/get-started/locally/) first. ```{eval-rst} .. important:: diff --git a/setup.py b/setup.py index d7a15ef921..e3af210570 100644 --- a/setup.py +++ b/setup.py @@ -155,7 +155,7 @@ def deps_list(*pkgs): packages=find_packages("src"), zip_safe=False, extras_require=extras, - python_requires=">=3.8.0", + python_requires=">=3.9.0", install_requires=install_requires, classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/src/adapters/loading.py b/src/adapters/loading.py index 55ba1db45b..154951c01a 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -160,10 +160,10 @@ def load_weights( else: logger.info(f"No safetensors file found in {save_directory}. Falling back to torch.load...") weights_file = join(save_directory, self.weights_name) - state_dict = torch.load(weights_file, map_location="cpu") + state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) else: weights_file = join(save_directory, self.weights_name) - state_dict = torch.load(weights_file, map_location="cpu") + state_dict = torch.load(weights_file, map_location="cpu", weights_only=True) except Exception: raise OSError("Unable to load weights from pytorch checkpoint file. ") logger.info("Loading module weights from {}".format(weights_file)) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 62de6178ac..ca4db8092c 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -257,7 +257,7 @@ def load_embeddings(self, path: str, name: str): embedding_path = os.path.join(path, EMBEDDING_FILE) if not os.path.isfile(embedding_path): raise FileNotFoundError("No embeddings found at {}".format(embedding_path)) - weights = torch.load(embedding_path) + weights = torch.load(embedding_path, weights_only=True) self.loaded_embeddings[name] = nn.Embedding.from_pretrained(weights) self.set_active_embeddings(name) diff --git a/src/adapters/trainer.py b/src/adapters/trainer.py index 2896585bcf..ca7662d449 100644 --- a/src/adapters/trainer.py +++ b/src/adapters/trainer.py @@ -4,21 +4,28 @@ import torch from torch import nn -from torch.utils.data.dataset import Dataset +from torch.utils.data.dataset import Dataset, IterableDataset from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, __version__ from transformers.configuration_utils import PretrainedConfig from transformers.data.data_collator import DataCollator +from transformers.feature_extraction_utils import FeatureExtractionMixin +from transformers.image_processing_utils import BaseImageProcessor from transformers.modeling_utils import unwrap_model +from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_base import PreTrainedTokenizerBase from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState from transformers.trainer_utils import EvalPrediction from transformers.training_args import TrainingArguments -from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, is_sagemaker_mp_enabled, logging +from transformers.utils import CONFIG_NAME, WEIGHTS_NAME, is_datasets_available, is_sagemaker_mp_enabled, logging from .composition import AdapterCompositionBlock, Fuse +if is_datasets_available(): + import datasets + + if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp @@ -32,15 +39,19 @@ def __init__( model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, - eval_dataset: Optional[Dataset] = None, + train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, - model_init: Callable[[], PreTrainedModel] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, adapter_names: Optional[List[List[str]]] = None, - optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + **kwargs, ): if model is not None: model_quantized = getattr(model, "is_quantized", False) @@ -51,12 +62,13 @@ def __init__( data_collator, train_dataset, eval_dataset, - tokenizer=tokenizer, + processing_class=processing_class or tokenizer, model_init=model_init, compute_metrics=compute_metrics, callbacks=[AdapterTrainerCallback(self)] + callbacks if callbacks else [AdapterTrainerCallback(self)], optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics, + **kwargs, ) if model is not None: model.is_quantized = model_quantized diff --git a/tests/composition/test_parallel.py b/tests/composition/test_parallel.py index 80e1ae8616..8a15a9f1c5 100644 --- a/tests/composition/test_parallel.py +++ b/tests/composition/test_parallel.py @@ -214,7 +214,7 @@ def run_parallel_training_test(self, adapter_config, filter_key): do_train=True, learning_rate=1.0, max_steps=20, - no_cuda=True, + use_cpu=True, remove_unused_columns=False, ) diff --git a/tests/extended/test_adapter_trainer_ext.py b/tests/extended/test_adapter_trainer_ext.py index 6e14944654..8da0ea07c8 100644 --- a/tests/extended/test_adapter_trainer_ext.py +++ b/tests/extended/test_adapter_trainer_ext.py @@ -300,7 +300,7 @@ def run_trainer( --per_device_eval_batch_size 4 --max_eval_samples 8 --val_max_target_length {max_len} - --evaluation_strategy steps + --eval_strategy steps --eval_steps {str(eval_steps)} --train_adapter """.split() diff --git a/tests/methods/base.py b/tests/methods/base.py index 338cf970aa..2d5771ce61 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -192,11 +192,11 @@ def run_load_test(self, adapter_config): name = "dummy_adapter" model1.add_adapter(name, config=adapter_config) model1.set_active_adapters(name) - with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir: model1.save_adapter(temp_dir, name) # Check that there are actually weights saved - weights = torch.load(os.path.join(temp_dir, WEIGHTS_NAME), map_location="cpu") + weights = torch.load(os.path.join(temp_dir, WEIGHTS_NAME), map_location="cpu", weights_only=True) self.assertTrue(len(weights) > 0) # also tests that set_active works @@ -225,7 +225,7 @@ def run_full_model_load_test(self, adapter_config): name = "dummy" model1.add_adapter(name, config=adapter_config) - with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir: model1.save_pretrained(temp_dir) model2, loading_info = load_model(temp_dir, self.model_class, output_loading_info=True) @@ -256,7 +256,7 @@ def trainings_run(self, model, lr=1.0, steps=8): do_train=True, learning_rate=lr, max_steps=steps, - no_cuda=True, + use_cpu=True, per_device_train_batch_size=2, remove_unused_columns=False, ) diff --git a/tests/test_adapter_conversion.py b/tests/test_adapter_conversion.py index 9653b3f340..067b1b9665 100644 --- a/tests/test_adapter_conversion.py +++ b/tests/test_adapter_conversion.py @@ -37,7 +37,7 @@ def run_test(self, static_model, input_shape=None, label_dict=None): ): self.skipTest("Skipping as base model classes are different.") - with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir: static_model.save_head(temp_dir) loading_info = {} @@ -193,7 +193,7 @@ def test_equivalent_language_generation(self): static_model.eval() flex_model.eval() - with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir: static_model.save_adapter(temp_dir, "dummy") loading_info = {} @@ -209,7 +209,7 @@ def test_equivalent_language_generation(self): model_gen = static_model.generate(**input_samples) flex_model_gen = flex_model.generate(**input_samples) - self.assertEquals(model_gen.shape, flex_model_gen.shape) + self.assertEqual(model_gen.shape, flex_model_gen.shape) self.assertTrue(torch.equal(model_gen, flex_model_gen)) def test_full_model_conversion(self): @@ -220,7 +220,7 @@ def test_full_model_conversion(self): adapters.init(static_head_model) static_head_model.eval() - with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir: static_head_model.save_pretrained(temp_dir) flex_head_model, loading_info = AutoAdapterModel.from_pretrained(temp_dir, output_loading_info=True) diff --git a/tests/test_adapter_embeddings.py b/tests/test_adapter_embeddings.py index 160828c776..0284b7c384 100644 --- a/tests/test_adapter_embeddings.py +++ b/tests/test_adapter_embeddings.py @@ -105,7 +105,7 @@ def test_training_embedding(self): do_train=True, learning_rate=0.4, max_steps=15, - no_cuda=True, + use_cpu=True, per_device_train_batch_size=2, label_names=["labels"], ) diff --git a/tests/test_adapter_fusion_common.py b/tests/test_adapter_fusion_common.py index 695808eb24..b8472483ee 100644 --- a/tests/test_adapter_fusion_common.py +++ b/tests/test_adapter_fusion_common.py @@ -126,7 +126,7 @@ def test_load_full_model_fusion(self): model1.add_adapter(name2) model1.add_adapter_fusion([name1, name2]) # save & reload model - with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir: model1.save_pretrained(temp_dir) model2 = load_model(temp_dir, self.model_class) diff --git a/tests/test_adapter_heads.py b/tests/test_adapter_heads.py index cb7ea7078c..df7a0ac7f8 100644 --- a/tests/test_adapter_heads.py +++ b/tests/test_adapter_heads.py @@ -315,7 +315,7 @@ def test_load_full_model(self): self.add_head(model, "dummy", layers=1) true_config = model.get_prediction_heads_config() - with tempfile.TemporaryDirectory() as temp_dir: + with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir: # save model.save_pretrained(temp_dir) # reload diff --git a/tests/test_adapter_hub.py b/tests/test_adapter_hub.py index fa29d13b19..0dee5eb0a6 100644 --- a/tests/test_adapter_hub.py +++ b/tests/test_adapter_hub.py @@ -76,7 +76,7 @@ def test_load_task_adapter_from_hub(self): overwrite_cache=True, ) eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") - training_args = TrainingArguments(output_dir="./examples", no_cuda=True) + training_args = TrainingArguments(output_dir="./examples", use_cpu=True) # evaluate trainer = Trainer( diff --git a/tests/test_adapter_trainer.py b/tests/test_adapter_trainer.py index fd1647865b..8630a31479 100644 --- a/tests/test_adapter_trainer.py +++ b/tests/test_adapter_trainer.py @@ -237,7 +237,7 @@ def test_training_load_best_model_at_end_full_model(self): save_steps=1, remove_unused_columns=False, load_best_model_at_end=True, - evaluation_strategy="epoch", + eval_strategy="epoch", save_strategy="epoch", num_train_epochs=2, ) @@ -273,7 +273,7 @@ def test_training_load_best_model_at_end_adapter(self): save_steps=1, remove_unused_columns=False, load_best_model_at_end=True, - evaluation_strategy="epoch", + eval_strategy="epoch", save_strategy="epoch", num_train_epochs=2, ) @@ -309,7 +309,7 @@ def test_training_load_best_model_at_end_fusion(self): save_steps=1, remove_unused_columns=False, load_best_model_at_end=True, - evaluation_strategy="epoch", + eval_strategy="epoch", save_strategy="epoch", num_train_epochs=2, ) @@ -600,7 +600,7 @@ def forward(self, x): output_dir=tempdir, per_device_train_batch_size=1, per_device_eval_batch_size=1, - evaluation_strategy="steps", + eval_strategy="steps", logging_steps=10, max_steps=5, lr_scheduler_type="constant", From adef6dc139e60624deb1e41f5ab0cacbea0b1d62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Sun, 26 Jan 2025 21:23:06 +0100 Subject: [PATCH 6/9] Add Support for Gradient Checkpointing (#759) # Add Support for Gradient Checkpointing This PR adds support for gradient checkpointing. Gradient checkpointing is a technique that trades computation for memory by recomputing intermediate activations during the backward pass instead of storing them. This is particularly useful when training large models. Because we recompute values during the backpropagation, we need to preserve the original ForwardContext in this phase. I solved this by overwriting the `gradient_checkpointing_enable` function so that the checkpoint function receives the current ForwardContext as the backward pass context manager. --------- Co-authored-by: calpt --- docs/training.md | 4 + notebooks/Gradient_Checkpointing_Llama.ipynb | 339 ++++++++++++++++++ notebooks/README.md | 7 +- src/adapters/composition.py | 4 +- src/adapters/context.py | 5 +- src/adapters/methods/lora.py | 4 +- src/adapters/model_mixin.py | 69 ++++ src/adapters/models/beit/adapter_model.py | 14 + .../models/deberta/modeling_deberta.py | 53 +++ .../models/deberta_v2/modeling_deberta_v2.py | 55 +++ src/adapters/models/mt5/modeling_mt5.py | 16 +- src/adapters/models/t5/modeling_t5.py | 16 +- tests/methods/base.py | 60 +++- tests/methods/test_ia3.py | 3 + tests/methods/test_lora.py | 3 + tests/methods/test_prefix_tuning.py | 3 + tests/methods/test_prompt_tuning.py | 3 + tests/methods/test_reft.py | 3 + tests/methods/test_unipelt.py | 3 + tests/models/test_clip.py | 5 + 20 files changed, 659 insertions(+), 10 deletions(-) create mode 100644 notebooks/Gradient_Checkpointing_Llama.ipynb diff --git a/docs/training.md b/docs/training.md index 78fcd9e757..d4de614392 100644 --- a/docs/training.md +++ b/docs/training.md @@ -223,3 +223,7 @@ trainer = AdapterTrainer( _Adapters_ supports fine-tuning of quantized language models similar to [QLoRA (Dettmers et al., 2023)](https://arxiv.org/pdf/2305.14314.pdf) via the `bitsandbytes` library integrated into Transformers. Quantized training is supported for LoRA-based adapters as well as bottleneck adapters and prefix tuning. Please refer to [this notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) for a hands-on guide. + +## Gradient Checkpointing +Gradient checkpointing is supported for all models (e.g. Llama 1/2/3) except for the models that are not supported by Hugging Face Transformers (like ALBERT). Please refer to [this notebook](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Gradient_Checkpointing_Llama.ipynb) for a hands-on guide. + diff --git a/notebooks/Gradient_Checkpointing_Llama.ipynb b/notebooks/Gradient_Checkpointing_Llama.ipynb new file mode 100644 index 0000000000..b48390d846 --- /dev/null +++ b/notebooks/Gradient_Checkpointing_Llama.ipynb @@ -0,0 +1,339 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "introduction", + "metadata": {}, + "source": [ + "# Efficient Llama Training with Gradient Checkpointing and _Adapters_\n", + "\n", + "In this notebook, we show how to efficiently fine-tune a **Llama 3** model using **gradient checkpointing** and adapter methods.\n", + "\n", + "**Gradient checkpointing** is a technique to reduce peak memory usage significantly and thus enables training larger models with larger batch sizes. Gradient checkpointing achieves this by trading compute for memory: During the forward pass, gradient checkpointing only stores a subset of activations (thus saving memory). During backpropagation, gradient checkpointing recomputes the activations that were not stored. This can significantly reduce memory requirements at the cost of slightly increased computation time.\n", + "\n", + "In this notebook, we finetune Llama-3 8B on supervised instruction tuning data collected by the [Open Assistant project](https://github.com/LAION-AI/Open-Assistant) for training chatbots.\n", + "\n", + "Another way to reduce memore usage is to use quantization. Have a look a the [QLora notebook](QLoRA_Llama_Finetuning.ipynb) for an example. This gradient checkpointing notebook is based on the QLoRA notebook. While we use a normal LoRA setup in this notebook, you can easily replace LoRA with QLoRA to reduce memory usage even further." + ] + }, + { + "cell_type": "markdown", + "id": "installation", + "metadata": {}, + "source": [ + "## Installation\n", + "\n", + "We need `adapters`, `datasets` and `pytorch` for training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "install", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -qq -U adapters datasets torch" + ] + }, + { + "cell_type": "markdown", + "id": "dataset", + "metadata": {}, + "source": [ + "## Load Open Assistant dataset\n", + "\n", + "We use the [`timdettmers/openassistant-guanaco`](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset, which contains a small subset of conversations from the full Open Assistant database." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "load_dataset", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['text'],\n", + " num_rows: 9846\n", + " })\n", + " test: Dataset({\n", + " features: ['text'],\n", + " num_rows: 518\n", + " })\n", + "})" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"timdettmers/openassistant-guanaco\")\n", + "dataset" + ] + }, + { + "cell_type": "markdown", + "id": "model_setup", + "metadata": {}, + "source": [ + "## Load and prepare model\n", + "\n", + "We download the official Llama-2 7B/ Llama-3 8B checkpoint from the HuggingFace Hub. Note that you must request access to this model on the HuggingFace website and use an API token to download it.\n", + "\n", + "The key difference in this notebook is that we'll enable gradient checkpointing to reduce memory usage during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "load_model", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "83e60dee3c434bb3a2bc656bd7f4b667", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00\"\n", + "\n", + "modelpath=\"meta-llama/Meta-Llama-3-8B\"\n", + "\n", + "# Load model with gradient checkpointing enabled\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " modelpath, \n", + " device_map=\"auto\",\n", + " torch_dtype=torch.bfloat16,\n", + " token=HUGGINGFACE_ACCESS_TOKEN,\n", + ")\n", + "model.config.use_cache = False\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(modelpath, token=HUGGINGFACE_ACCESS_TOKEN)\n", + "tokenizer.pad_token = tokenizer.eos_token" + ] + }, + { + "cell_type": "markdown", + "id": "5cd73b7d", + "metadata": {}, + "source": [ + "If you get a message similar to `WARNING:accelerate.big_modeling:Some parameters are on the meta device because they were offloaded to the cpu and disk.`, then the model itself is too big for your GPU. If you don't have a bigger / additional GPU at hand, you can use a quantization method like we show in the [QLoRA notebook](QLoRA_Llama_Finetuning.ipynb). Adding the quantization_config when loading the model and choosing a quantized `LoRAConfig` in the next step will enable quantized training." + ] + }, + { + "cell_type": "markdown", + "id": "adapter_setup", + "metadata": {}, + "source": [ + "## Initialize adapter\n", + "\n", + "We initialize the adapter functionality and add a LoRA adapter. When using gradient checkpointing with adapters, we need to enable input gradients explicitly." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "init_adapter", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================================================================\n", + "Name Architecture #Param %Param Active Train\n", + "--------------------------------------------------------------------------------\n", + "lora_adapter lora 3,407,872 0.085 1 1\n", + "--------------------------------------------------------------------------------\n", + "Full model 4,015,263,744 100.000 0\n", + "================================================================================\n" + ] + } + ], + "source": [ + "import adapters\n", + "from adapters import LoRAConfig\n", + "\n", + "adapters.init(model)\n", + "\n", + "config = LoRAConfig()\n", + "model.add_adapter(\"lora_adapter\", config=config)\n", + "model.train_adapter(\"lora_adapter\")\n", + "\n", + "# Activate gradient checkpointing\n", + "model.gradient_checkpointing_enable()\n", + "\n", + "print(model.adapter_summary())" + ] + }, + { + "cell_type": "markdown", + "id": "data_prep", + "metadata": {}, + "source": [ + "## Prepare data for training\n", + "\n", + "The dataset is tokenized and truncated." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "tokenize", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "def tokenize(element):\n", + " return tokenizer(\n", + " element[\"text\"],\n", + " truncation=True,\n", + " max_length=512,\n", + " add_special_tokens=False,\n", + " )\n", + "\n", + "dataset_tokenized = dataset.map(\n", + " tokenize, \n", + " batched=True, \n", + " num_proc=os.cpu_count(),\n", + " remove_columns=[\"text\"]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "training", + "metadata": {}, + "source": [ + "## Training\n", + "\n", + "We specify training hyperparameters and train the model using the `AdapterTrainer` class. With gradient checkpointing enabled, we can use larger batch sizes than would otherwise be possible." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "training_args", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import TrainingArguments\n", + "\n", + "args = TrainingArguments(\n", + " output_dir=\"output/llama_gradient_checkpointing\",\n", + " per_device_train_batch_size=1,\n", + " per_device_eval_batch_size=1,\n", + " evaluation_strategy=\"steps\",\n", + " logging_steps=10,\n", + " save_steps=500,\n", + " eval_steps=187,\n", + " save_total_limit=3,\n", + " gradient_accumulation_steps=16,\n", + " max_steps=1875,\n", + " learning_rate=0.0002,\n", + " bf16=True,\n", + " warmup_ratio=0.03,\n", + " group_by_length=True,\n", + " lr_scheduler_type=\"constant\",\n", + " optim=\"adamw_torch\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "train", + "metadata": {}, + "outputs": [], + "source": [ + "from adapters import AdapterTrainer\n", + "from transformers import DataCollatorForLanguageModeling\n", + "\n", + "trainer = AdapterTrainer(\n", + " model=model,\n", + " tokenizer=tokenizer,\n", + " data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),\n", + " train_dataset=dataset_tokenized[\"train\"],\n", + " eval_dataset=dataset_tokenized[\"test\"],\n", + " args=args,\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "inference", + "metadata": {}, + "source": [ + "## Inference\n", + "\n", + "For inference, we can disable gradient checkpointing since we don't need gradients:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "inference_setup", + "metadata": {}, + "outputs": [], + "source": [ + "# Disable gradient checkpointing for inference\n", + "model.gradient_checkpointing_disable()\n", + "model.config.use_cache = True\n", + "\n", + "def prompt_model(model, text: str):\n", + " batch = tokenizer(f\"### Human: {text}\\n### Assistant:\", return_tensors=\"pt\")\n", + " batch = batch.to(model.device)\n", + " \n", + " model.eval()\n", + " with torch.inference_mode():\n", + " output_tokens = model.generate(**batch, max_new_tokens=50)\n", + "\n", + " return tokenizer.decode(output_tokens[0], skip_special_tokens=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "test_inference", + "metadata": {}, + "outputs": [], + "source": [ + "print(prompt_model(model, \"Explain gradient checkpointing in simple terms\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/README.md b/notebooks/README.md index 052cdafe4d..4766baca7e 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -28,7 +28,6 @@ As adapters is fully compatible with HuggingFace's Transformers, you can also us | Notebook | Description | | |:----------------|:---------------------|--:| | [Text Generation](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Text_Generation_Training.ipynb) | How to train an adapter for language generation. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Text_Generation_Training.ipynb) | -| [QLoRA LLama Finetuning](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | How to finetune a quantized Llama model for using QLoRA. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | | [Training a NER Adapter](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_train_NER_with_id2label.ipynb) | How to train an adapter on a named entity recoginition task. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_train_NER_with_id2label.ipynb) | | [Adapter Drop Training](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Drop_Training.ipynb) | How to train an adapter using AdapterDrop | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Drop_Training.ipynb) | | [Inference example for id2label](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_train_NER_with_id2label.ipynb) | How to use the id2label dictionary for inference | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_id2label_inference.ipynb) | @@ -36,3 +35,9 @@ As adapters is fully compatible with HuggingFace's Transformers, you can also us | [Finetuning Whisper with Adapters](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | Fine Tuning Whisper using LoRA | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/Adapter_Whisper_Audio_FineTuning.ipynb) | | [Adapter Training with ReFT](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | Fine Tuning using ReFT Adapters | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ReFT_Adapters_Finetuning.ipynb) | | [ViT Fine-Tuning with AdapterPlus](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | ViT Fine-Tuning with AdapterPlus | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/ViT_AdapterPlus_FineTuning.ipynb) | + +### Memory Efficient Training +| Notebook | Description | | +|:----------------|:---------------------|--:| +| [QLoRA LLama Finetuning](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | How to finetune a quantized Llama model for using QLoRA. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | +| [Gradient Checkpointing](https://github.com/Adapter-Hub/adapters/blob/main/notebooks/Gradient_Checkpointing_Llama.ipynb) | How to finetune a quantized Llama model for using QLoRA. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapters/blob/main/notebooks/QLoRA_Llama_Finetuning.ipynb) | diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 6c17fb8ebd..f9c3ee8cc8 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -273,7 +273,7 @@ def adjust_tensors_for_parallel(hidden_states, *tensors): """ outputs = [] for tensor in tensors: - if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]: + if tensor is not None and hidden_states.shape[0] > tensor.shape[0]: repeats = [1] * len(tensor.shape) repeats[0] = hidden_states.shape[0] // tensor.shape[0] new_tensor = tensor.repeat(*repeats) @@ -288,7 +288,7 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors): In-place version of adjust_tensors_for_parallel(). """ for tensor in tensors: - if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]: + if tensor is not None and hidden_states.shape[0] > tensor.shape[0]: repeats = [1] * len(tensor.shape) repeats[0] = hidden_states.shape[0] // tensor.shape[0] new_tensor = tensor.repeat(*repeats) diff --git a/src/adapters/context.py b/src/adapters/context.py index 70e685d037..db09b8918f 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -1,10 +1,11 @@ import functools import threading +from typing import ContextManager from .composition import parse_composition, parse_heads_from_composition -class AdapterSetup: +class AdapterSetup(ContextManager): """ Represents an adapter setup of a model including active adapters and active heads. This class is intended to be used as a context manager using the ``with`` statement. The setup defined by the ``AdapterSetup`` context will @@ -67,7 +68,7 @@ def get_context_head_setup(cls): return None -class ForwardContext: +class ForwardContext(ContextManager): """ Holds context information during a forward pass through a model. This class should be used via the ``ForwardContext.wrap()`` method. diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 8f3bc29401..3245afdd99 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -718,7 +718,9 @@ def pad(self, x, lora, fill_value=None): fill_value = 1 result = x.new_full((*x.shape[:-1], self.out_features), fill_value) result = result.view(-1, self.out_features) - result[:, lora.lora_ind] = x.reshape(-1, self.out_features // 3 * self.get_n_heads(lora)) + # Move lora_ind to the same device as x + lora_ind = lora.lora_ind.to(x.device) + result[:, lora_ind] = x.reshape(-1, self.out_features // 3 * self.get_n_heads(lora)) return result.view((*x.shape[:-1], self.out_features)) def reset_adapter(self): diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index ca4db8092c..1895671f8d 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1,3 +1,5 @@ +import contextlib +import functools import inspect import json import logging @@ -5,11 +7,13 @@ from abc import ABC, abstractmethod from collections import defaultdict from copy import deepcopy +from functools import partial from os.path import join from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn +from torch.utils.checkpoint import checkpoint from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig from transformers import GenerationConfig @@ -1617,6 +1621,71 @@ def save_pretrained( # Remove adapters config del self.config.adapters + # Override PreTrainedModel.gradient_checkpointing_enable(...) method from transformers/modeling_utils.py to support gradient checkpointing for adapter training. + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": False} + + # >>> START AH Changes <<< + if "use_reentrant" not in gradient_checkpointing_kwargs: + # use_reentrant must be set. + gradient_checkpointing_kwargs["use_reentrant"] = False + else: + if gradient_checkpointing_kwargs["use_reentrant"]: + raise ValueError( + "Gradient checkpointing with use_reentrant=True is not supported. For gradient checkpointing, we need to set context_fn, which is only supported by PyTorch when use_reentrant is set to False." + ) + + def gradient_checkpointing_function(function, *args, **kwargs): + context = ForwardContext.get_context() + context_fn = lambda: (contextlib.nullcontext(), context) + return checkpoint(function, *args, context_fn=context_fn, **kwargs) + + gradient_checkpointing_func = functools.partial( + gradient_checkpointing_function, **gradient_checkpointing_kwargs + ) + # >>> END AH Changes <<< + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + # >>> START AH Changes <<< + # For adapter training, we set requires_grad=True for the input embeddings. Just like Hugging Face does for training with PEFT. + try: + self.enable_input_require_grads() + except NotImplementedError: + # Some models (CLIP) don't have input embeddings, so Hugging Face's implementation raises a NotImplementedError. We provide the user with some more information. + raise NotImplementedError( + "Model has no enable_input_require_grads method implementation by Hugging Face. Parameter efficient fine-tuning however needs gradients for embeddings. This model therefore doesn't support gradient checkpointing with Adapters nor Hugging Face's PEFT library." + ) + # >>> END AH Changes <<< + @inherit_doc class ModelBaseAdaptersMixin(ModelAdaptersMixin): diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py index 5667fa098d..578142ea11 100644 --- a/src/adapters/models/beit/adapter_model.py +++ b/src/adapters/models/beit/adapter_model.py @@ -36,6 +36,20 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + # Overwrites the function from: transformers.modeling_utils.PreTrainedModel + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings specifically for BEiT's tuple output format. + """ + + def make_inputs_require_grads(module, input, output): + # >>> START AH Changes <<< + # Handle BEiT's specific tuple output format. Hugging Face's implementation is buggy and doesn't work for BEiT. + output[0].requires_grad_(True) + # >>> END AH Changes <<< + + self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) def forward( self, diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py index 77c6117b19..8b4c87b2c5 100644 --- a/src/adapters/models/deberta/modeling_deberta.py +++ b/src/adapters/models/deberta/modeling_deberta.py @@ -19,6 +19,7 @@ from torch import nn from transformers.models.deberta.modeling_deberta import ( + DebertaEmbeddings, DebertaOutput, DebertaSelfOutput, DisentangledSelfAttention, @@ -47,6 +48,58 @@ def forward(self, hidden_states, input_tensor): return hidden_states +class DebertaEmbeddingsWithAdapters(DebertaEmbeddings): + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + # >>> START AH Changes <<< + # HuggingFace uses += instead of + which leads to a bug when using model.enable_input_require_grads. Once this is fixed, we can remove + embeddings = embeddings + position_embeddings + # >>> END AH Changes <<< + if self.token_type_embeddings is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + # >>> START AH Changes <<< + embeddings = embeddings + token_type_embeddings + # >>> END AH Changes <<< + if self.embed_proj is not None: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + class DisentangledSelfAttentionWithAdapters(DebertaSelfAttentionAdaptersMixin, DisentangledSelfAttention): """ Disentangled self-attention module diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py index 2b673c491f..2e7d86ae8a 100644 --- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py +++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py @@ -19,6 +19,7 @@ from torch import nn from transformers.models.deberta_v2.modeling_deberta_v2 import ( + DebertaV2Embeddings, DebertaV2Output, DebertaV2SelfOutput, DisentangledSelfAttention, @@ -49,6 +50,60 @@ def forward(self, hidden_states, input_tensor): return hidden_states +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm,Deberta->DebertaV2 +class DebertaV2EmbeddingsWithAdapters(DebertaV2Embeddings): + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + # >>> START AH Changes <<< + # HuggingFace uses += instead of + which leads to a bug when using model.enable_input_require_grads. Once this is fixed, we can remove DebertaV2EmbeddingsWithAdapters. + embeddings = embeddings + position_embeddings + # >>> END AH Changes <<< + if self.token_type_embeddings is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + # >>> START AH Changes <<< + embeddings = embeddings + token_type_embeddings + # >>> END AH Changes <<< + + if self.embed_proj is not None: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + class DisentangledSelfAttentionWithAdapters(DebertaV2SelfAttentionAdaptersMixin, DisentangledSelfAttention): def transpose_for_scores_extended(self, x, attention_heads): new_x_shape = x.size()[:-1] + (attention_heads, -1) diff --git a/src/adapters/models/mt5/modeling_mt5.py b/src/adapters/models/mt5/modeling_mt5.py index 05141a08cf..b317823335 100644 --- a/src/adapters/models/mt5/modeling_mt5.py +++ b/src/adapters/models/mt5/modeling_mt5.py @@ -419,8 +419,22 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: + # >>> START AH Changes <<< + # Without this change, T5 training with gradient checkpointing will fail for reft. + def create_custom_forward(module): + def custom_forward(*inputs): + # Ensure all inputs are on the same device + inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs) + return module(*inputs) + + return custom_forward + + # >>> END AH Changes <<< + layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, + # >>> START AH Changes <<< + create_custom_forward(layer_module), + # >>> END AH Changes <<< hidden_states, causal_mask, position_bias, diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py index 09b969bb1b..e401a2b840 100644 --- a/src/adapters/models/t5/modeling_t5.py +++ b/src/adapters/models/t5/modeling_t5.py @@ -419,8 +419,22 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: + # >>> START AH Changes <<< + # Without this change, T5 training with gradient checkpointing will fail for reft. + def create_custom_forward(module): + def custom_forward(*inputs): + # Ensure all inputs are on the same device + inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs) + return module(*inputs) + + return custom_forward + + # >>> END AH Changes <<< + layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, + # >>> START AH Changes <<< + create_custom_forward(layer_module), + # >>> END AH Changes <<< hidden_states, causal_mask, position_bias, diff --git a/tests/methods/base.py b/tests/methods/base.py index 2d5771ce61..55389b7052 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -1,6 +1,7 @@ import copy import os import tempfile +from typing import Callable import torch @@ -247,7 +248,7 @@ def run_full_model_load_test(self, adapter_config): self.assertEqual(len(output1), len(output2)) self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4)) - def trainings_run(self, model, lr=1.0, steps=8): + def trainings_run(self, model, lr=1.0, steps=8, batch_size=2, gradient_accumulation_steps=1): # setup dataset train_dataset = self.dataset() @@ -257,7 +258,8 @@ def trainings_run(self, model, lr=1.0, steps=8): learning_rate=lr, max_steps=steps, use_cpu=True, - per_device_train_batch_size=2, + per_device_train_batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, remove_unused_columns=False, ) @@ -371,6 +373,60 @@ def run_reset_test(self, adapter_config): self.assertEqual(len(output_1), len(output_2)) self.assertTrue(torch.allclose(output_1[0], output_2[0], atol=1e-3)) + def _run_gradient_checkpointing_test_helper(self, adapter_setup_fn: Callable[[adapters.ModelAdaptersMixin], None]): + """ + Test that gradient checkpointing produces the same results as normal training + Args: + adapter_setup_fn: Function that takes a model and sets up the adapter training. Must also add a head (usually via self.add_head(...)). We have this in a separate function to allow complex setups (like training a normal adapter or training parallel setups) + """ + + if not self.do_run_train_tests: + self.skipTest("Skipping training tests. Set `do_run_train_tests=True` to run them.") + if self.config_class not in ADAPTER_MODEL_MAPPING: + self.skipTest("Does not support flex heads.") + + config = self.config() + state_dict_after_training = {} + + # Run training twice (with & without gradient checkpointing) to verify both produce identical results (i.e. the same state dict) + for train_with_checkpointing in [True, False]: + # Set random seed + torch.manual_seed(42) + + # Initialize model + model = adapters.AutoAdapterModel.from_config(config) + + # if model doesn't support gradient checkpointing, skip the test + if not model.supports_gradient_checkpointing: + self.skipTest("Model does not support gradient checkpointing") + + model.to(torch_device) + adapter_setup_fn(model) + + # Enable gradient checkpointing + if train_with_checkpointing: + model.gradient_checkpointing_enable() + + # Train & store state dict + self.trainings_run(model, batch_size=1, gradient_accumulation_steps=2) + state_dict_after_training[train_with_checkpointing] = copy.deepcopy(model.state_dict()) + + # Check that the state dicts are the same (we know that normal training works as expected, so we only need to check that gradient checkpointing produces the same results.) + for (k1, v1), (k2, v2) in zip( + state_dict_after_training[True].items(), state_dict_after_training[False].items() + ): + v1 = v1.to(v2.device) + self.assertTrue(torch.equal(v1, v2), msg=f"Key {k1} is not equal:\nv1: {v1}\nv2: {v2}") + + def run_gradient_checkpointing_single_adapter_test(self, adapter_config): + def adapter_setup_fn(model): + model.add_adapter("adapter1", config=adapter_config) + self.add_head(model, "adapter1") + model.train_adapter("adapter1") + model.adapter_to("adapter1", torch_device) + + self._run_gradient_checkpointing_test_helper(adapter_setup_fn) + def run_generate_test(self, adapter_config): if self.config_class not in ADAPTER_MODEL_MAPPING or ( "seq2seq_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types diff --git a/tests/methods/test_ia3.py b/tests/methods/test_ia3.py index 3a30e2448d..b96dbcd02a 100644 --- a/tests/methods/test_ia3.py +++ b/tests/methods/test_ia3.py @@ -45,3 +45,6 @@ def test_merge_ia3(self): def test_reset_ia3(self): self.run_reset_test(IA3Config(init_weights="bert")) + + def test_ia3_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(IA3Config()) diff --git a/tests/methods/test_lora.py b/tests/methods/test_lora.py index 067f78c8b8..e1ced5188a 100644 --- a/tests/methods/test_lora.py +++ b/tests/methods/test_lora.py @@ -313,3 +313,6 @@ def test_merge_lora(self): def test_reset_lora(self): self.run_reset_test(LoRAConfig(init_weights="bert")) + + def test_lora_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(LoRAConfig()) diff --git a/tests/methods/test_prefix_tuning.py b/tests/methods/test_prefix_tuning.py index d5765771ff..35906edb19 100644 --- a/tests/methods/test_prefix_tuning.py +++ b/tests/methods/test_prefix_tuning.py @@ -77,3 +77,6 @@ def test_eject_prefix(self): def test_prefix_tuning_generate(self): self.run_generate_test(PrefixTuningConfig()) + + def test_prefix_tuning_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(PrefixTuningConfig()) diff --git a/tests/methods/test_prompt_tuning.py b/tests/methods/test_prompt_tuning.py index 97015d1319..f2fd1b0345 100644 --- a/tests/methods/test_prompt_tuning.py +++ b/tests/methods/test_prompt_tuning.py @@ -36,3 +36,6 @@ def test_load_full_model_prompt_tuning(self): def test_train_prompt_tuning(self): self.run_train_test(PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) + + def test_prompt_tuning_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(PromptTuningConfig(prompt_length=10)) diff --git a/tests/methods/test_reft.py b/tests/methods/test_reft.py index f89fe18bea..76d0980d57 100644 --- a/tests/methods/test_reft.py +++ b/tests/methods/test_reft.py @@ -80,3 +80,6 @@ def test_train_loreft(self): def test_reft_generate(self): self.run_generate_test(LoReftConfig()) + + def test_reft_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(LoReftConfig()) diff --git a/tests/methods/test_unipelt.py b/tests/methods/test_unipelt.py index d29fa5f18d..b855670ab4 100644 --- a/tests/methods/test_unipelt.py +++ b/tests/methods/test_unipelt.py @@ -64,3 +64,6 @@ def test_output_adapter_gating_scores_unipelt(self): self.assertGreaterEqual(len(per_layer_scores), 3) for k, v in per_layer_scores.items(): self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k) + + def test_unipelt_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(UniPELTConfig()) diff --git a/tests/models/test_clip.py b/tests/models/test_clip.py index 921e0668f5..cf1297b693 100644 --- a/tests/models/test_clip.py +++ b/tests/models/test_clip.py @@ -37,3 +37,8 @@ def test_initialization(self): [0.0, 1.0], msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + + def test_gradient_checkpointing_enable_disable(self): + # CLIPAdapterModel does not support gradient checkpointing (because enable_input_require_grads is not implemented by Hugging Face, + # which is required for gradient checkpointing with parameter efficient fine-tuning methods). + self.skipTest("CLIPAdapterModel does not support gradient checkpointing") From 1dcac5c086705574ef4fe1bb069247feea52a41a Mon Sep 17 00:00:00 2001 From: TimoImhof <62378375+TimoImhof@users.noreply.github.com> Date: Mon, 27 Jan 2025 22:40:19 +0100 Subject: [PATCH 7/9] Tests Refactoring (#740) This PR aims to refactor the test suite into a more generic version. The following points are to be covered (subject to change; will become more granular): - [x] Add more granularity to test viewer for better test overview/readability and more convenient testing of subgroups of methods (e.g. only exectue LoRA tests for one model) - [x] Add documentation about the structure of the test directory and entry points - [x] Add pytest markers to allow for more granular testing of subgroups (e.g. exectue only LoRA tests for all models at once, useful when implementing a new adapter method) - [x] Refactor test methods to extract similar/duplicate code into a set of utils methods - [x] Relocate edgecases due to model peculiarities out of the tests and into the respective model test classes to keeps the tests generic - [x] Fix config union tests, closes #785 This should make it easier to add new models or methods to the existing test suite in the future and make the development/testing process more convenient in the future --- Makefile | 19 +- conftest.py | 5 + examples/pytorch/language-modeling/run_clm.py | 2 +- pyproject.toml | 15 +- tests/README.md | 150 +++++++++++ tests/{models => fixtures}/__init__.py | 0 tests/methods/__init__.py | 42 ---- tests/test_adapter.py | 210 ---------------- tests/test_albert.py | 55 ----- tests/test_bart.py | 68 ----- tests/test_beit.py | 45 ---- tests/test_bert.py | 51 ---- tests/test_clip.py | 230 ----------------- tests/test_deberta.py | 55 ----- tests/test_debertaV2.py | 54 ---- tests/test_distilbert.py | 51 ---- tests/test_electra.py | 52 ---- tests/test_gpt2.py | 67 ----- tests/test_gptj.py | 68 ----- tests/test_llama.py | 66 ----- tests/test_mbart.py | 62 ----- tests/test_methods/__init__.py | 0 tests/test_methods/base.py | 232 ++++++++++++++++++ tests/test_methods/generator.py | 227 +++++++++++++++++ .../test_methods/method_test_impl/__init__.py | 0 .../method_test_impl}/base.py | 97 ++++---- .../method_test_impl/composition/__init__.py | 0 .../composition/test_parallel.py | 52 +--- .../method_test_impl/core/__init__.py | 0 .../test_adapter_backward_compability.py | 2 +- .../core}/test_adapter_conversion.py | 38 ++- .../core}/test_adapter_fusion_common.py | 4 +- .../method_test_impl/embeddings/__init__.py | 0 .../embeddings}/test_adapter_embeddings.py | 21 +- .../method_test_impl/heads/__init__.py | 0 .../heads}/test_adapter_heads.py | 59 ++--- .../method_test_impl/peft/__init__.py | 0 .../peft}/test_adapter_common.py | 22 +- .../method_test_impl/peft}/test_compacter.py | 25 +- .../peft}/test_config_union.py | 28 ++- .../method_test_impl/peft}/test_ia3.py | 3 +- .../method_test_impl/peft}/test_lora.py | 15 +- .../peft}/test_prefix_tuning.py | 3 +- .../peft}/test_prompt_tuning.py | 3 +- .../method_test_impl/peft}/test_reft.py | 5 +- .../method_test_impl/peft}/test_unipelt.py | 7 +- tests/test_methods/method_test_impl/utils.py | 48 ++++ tests/test_methods/test_on_albert.py | 48 ++++ tests/test_methods/test_on_bart.py | 28 +++ tests/test_methods/test_on_beit.py | 24 ++ tests/test_methods/test_on_bert.py | 23 ++ .../test_on_bert_generation.py} | 60 ++--- tests/test_methods/test_on_clip/test_model.py | 83 +++++++ .../test_on_clip/test_textmodel.py | 27 ++ .../test_textwithprojectionmodel.py | 27 ++ .../test_on_clip/test_visionmodel.py | 28 +++ .../test_visionwithprojectionmodel.py | 28 +++ tests/test_methods/test_on_deberta.py | 29 +++ tests/test_methods/test_on_debertaV2.py | 26 ++ tests/test_methods/test_on_distilbert.py | 23 ++ tests/test_methods/test_on_electra.py | 24 ++ .../test_on_encoder_decoder.py} | 47 ++-- tests/test_methods/test_on_gpt2.py | 27 ++ tests/test_methods/test_on_llama.py | 39 +++ tests/test_methods/test_on_mbart.py | 31 +++ tests/test_methods/test_on_mistral.py | 26 ++ tests/test_methods/test_on_mt5.py | 29 +++ tests/test_methods/test_on_plbart.py | 27 ++ tests/test_methods/test_on_roberta.py | 24 ++ tests/test_methods/test_on_t5.py | 28 +++ tests/test_methods/test_on_vit.py | 23 ++ tests/test_methods/test_on_whisper.py | 31 +++ tests/test_methods/test_on_xlm_roberta.py | 23 ++ tests/test_methods/test_on_xmod.py | 25 ++ .../test_adapter_composition.py | 2 +- tests/{ => test_misc}/test_adapter_config.py | 0 .../test_adapter_custom_head.py | 3 +- .../test_adapter_fusion_config.py | 0 tests/{ => test_misc}/test_adapter_hub.py | 8 +- .../test_adapter_safetensors.py | 0 .../test_adapter_save_id2label.py | 0 .../test_adapter_trainer/__init__.py | 0 .../test_adapter_trainer.py | 0 .../test_adapter_trainer_ext.py | 4 +- tests/test_mistral.py | 66 ----- tests/test_models/__init__.py | 0 tests/{models => test_models}/base.py | 0 .../test_albert_model.py} | 0 .../test_bart_model.py} | 0 .../test_beit_model.py} | 0 .../test_bert_generation_model.py} | 0 .../test_bert_model.py} | 0 .../test_clip_model.py} | 0 .../test_debertaV2_model.py} | 0 .../test_deberta_model.py} | 0 .../test_distilbert_model.py} | 0 .../test_electra_model.py} | 0 .../test_encoder_decoder_model.py} | 0 .../test_gpt2_model.py} | 0 .../test_gptj_model.py} | 0 .../test_llama_model.py} | 0 .../test_mbart_model.py} | 0 .../test_mistral_model.py} | 0 .../test_mt5_model.py} | 0 .../test_plbart_model.py} | 0 .../test_roberta_model.py} | 0 .../test_t5_model.py} | 0 .../test_vit_model.py} | 0 .../test_whisper_model.py} | 0 .../test_xlm_roberta_model.py} | 0 .../test_xmod_model.py} | 0 tests/test_mt5.py | 68 ----- tests/test_plbart.py | 67 ----- tests/test_roberta.py | 49 ---- tests/test_t5.py | 68 ----- tests/test_vit.py | 48 ---- tests/test_whisper.py | 72 ------ tests/test_xlm_roberta.py | 41 ---- tests/test_xmod.py | 49 ---- 119 files changed, 1655 insertions(+), 2006 deletions(-) create mode 100644 tests/README.md rename tests/{models => fixtures}/__init__.py (100%) delete mode 100644 tests/methods/__init__.py delete mode 100644 tests/test_adapter.py delete mode 100644 tests/test_albert.py delete mode 100644 tests/test_bart.py delete mode 100644 tests/test_beit.py delete mode 100644 tests/test_bert.py delete mode 100644 tests/test_clip.py delete mode 100644 tests/test_deberta.py delete mode 100644 tests/test_debertaV2.py delete mode 100644 tests/test_distilbert.py delete mode 100644 tests/test_electra.py delete mode 100644 tests/test_gpt2.py delete mode 100644 tests/test_gptj.py delete mode 100644 tests/test_llama.py delete mode 100644 tests/test_mbart.py create mode 100644 tests/test_methods/__init__.py create mode 100644 tests/test_methods/base.py create mode 100644 tests/test_methods/generator.py create mode 100644 tests/test_methods/method_test_impl/__init__.py rename tests/{methods => test_methods/method_test_impl}/base.py (85%) create mode 100644 tests/test_methods/method_test_impl/composition/__init__.py rename tests/{ => test_methods/method_test_impl}/composition/test_parallel.py (83%) create mode 100644 tests/test_methods/method_test_impl/core/__init__.py rename tests/{ => test_methods/method_test_impl/core}/test_adapter_backward_compability.py (96%) rename tests/{ => test_methods/method_test_impl/core}/test_adapter_conversion.py (90%) rename tests/{ => test_methods/method_test_impl/core}/test_adapter_fusion_common.py (98%) create mode 100644 tests/test_methods/method_test_impl/embeddings/__init__.py rename tests/{ => test_methods/method_test_impl/embeddings}/test_adapter_embeddings.py (88%) create mode 100644 tests/test_methods/method_test_impl/heads/__init__.py rename tests/{ => test_methods/method_test_impl/heads}/test_adapter_heads.py (93%) create mode 100644 tests/test_methods/method_test_impl/peft/__init__.py rename tests/{methods => test_methods/method_test_impl/peft}/test_adapter_common.py (96%) rename tests/{methods => test_methods/method_test_impl/peft}/test_compacter.py (57%) rename tests/{methods => test_methods/method_test_impl/peft}/test_config_union.py (59%) rename tests/{methods => test_methods/method_test_impl/peft}/test_ia3.py (95%) rename tests/{methods => test_methods/method_test_impl/peft}/test_lora.py (96%) rename tests/{methods => test_methods/method_test_impl/peft}/test_prefix_tuning.py (97%) rename tests/{methods => test_methods/method_test_impl/peft}/test_prompt_tuning.py (95%) rename tests/{methods => test_methods/method_test_impl/peft}/test_reft.py (96%) rename tests/{methods => test_methods/method_test_impl/peft}/test_unipelt.py (92%) create mode 100644 tests/test_methods/method_test_impl/utils.py create mode 100644 tests/test_methods/test_on_albert.py create mode 100644 tests/test_methods/test_on_bart.py create mode 100644 tests/test_methods/test_on_beit.py create mode 100644 tests/test_methods/test_on_bert.py rename tests/{test_bert_generation.py => test_methods/test_on_bert_generation.py} (62%) create mode 100644 tests/test_methods/test_on_clip/test_model.py create mode 100644 tests/test_methods/test_on_clip/test_textmodel.py create mode 100644 tests/test_methods/test_on_clip/test_textwithprojectionmodel.py create mode 100644 tests/test_methods/test_on_clip/test_visionmodel.py create mode 100644 tests/test_methods/test_on_clip/test_visionwithprojectionmodel.py create mode 100644 tests/test_methods/test_on_deberta.py create mode 100644 tests/test_methods/test_on_debertaV2.py create mode 100644 tests/test_methods/test_on_distilbert.py create mode 100644 tests/test_methods/test_on_electra.py rename tests/{test_encoder_decoder.py => test_methods/test_on_encoder_decoder.py} (73%) create mode 100644 tests/test_methods/test_on_gpt2.py create mode 100644 tests/test_methods/test_on_llama.py create mode 100644 tests/test_methods/test_on_mbart.py create mode 100644 tests/test_methods/test_on_mistral.py create mode 100644 tests/test_methods/test_on_mt5.py create mode 100644 tests/test_methods/test_on_plbart.py create mode 100644 tests/test_methods/test_on_roberta.py create mode 100644 tests/test_methods/test_on_t5.py create mode 100644 tests/test_methods/test_on_vit.py create mode 100644 tests/test_methods/test_on_whisper.py create mode 100644 tests/test_methods/test_on_xlm_roberta.py create mode 100644 tests/test_methods/test_on_xmod.py rename tests/{composition => test_misc}/test_adapter_composition.py (99%) rename tests/{ => test_misc}/test_adapter_config.py (100%) rename tests/{ => test_misc}/test_adapter_custom_head.py (98%) rename tests/{ => test_misc}/test_adapter_fusion_config.py (100%) rename tests/{ => test_misc}/test_adapter_hub.py (96%) rename tests/{ => test_misc}/test_adapter_safetensors.py (100%) rename tests/{ => test_misc}/test_adapter_save_id2label.py (100%) create mode 100644 tests/test_misc/test_adapter_trainer/__init__.py rename tests/{ => test_misc/test_adapter_trainer}/test_adapter_trainer.py (100%) rename tests/{extended => test_misc/test_adapter_trainer}/test_adapter_trainer_ext.py (98%) delete mode 100644 tests/test_mistral.py create mode 100644 tests/test_models/__init__.py rename tests/{models => test_models}/base.py (100%) rename tests/{models/test_albert.py => test_models/test_albert_model.py} (100%) rename tests/{models/test_bart.py => test_models/test_bart_model.py} (100%) rename tests/{models/test_beit.py => test_models/test_beit_model.py} (100%) rename tests/{models/test_bert_generation.py => test_models/test_bert_generation_model.py} (100%) rename tests/{models/test_bert.py => test_models/test_bert_model.py} (100%) rename tests/{models/test_clip.py => test_models/test_clip_model.py} (100%) rename tests/{models/test_debertaV2.py => test_models/test_debertaV2_model.py} (100%) rename tests/{models/test_deberta.py => test_models/test_deberta_model.py} (100%) rename tests/{models/test_distilbert.py => test_models/test_distilbert_model.py} (100%) rename tests/{models/test_electra.py => test_models/test_electra_model.py} (100%) rename tests/{models/test_encoder_decoder.py => test_models/test_encoder_decoder_model.py} (100%) rename tests/{models/test_gpt2.py => test_models/test_gpt2_model.py} (100%) rename tests/{models/test_gptj.py => test_models/test_gptj_model.py} (100%) rename tests/{models/test_llama.py => test_models/test_llama_model.py} (100%) rename tests/{models/test_mbart.py => test_models/test_mbart_model.py} (100%) rename tests/{models/test_mistral.py => test_models/test_mistral_model.py} (100%) rename tests/{models/test_mt5.py => test_models/test_mt5_model.py} (100%) rename tests/{models/test_plbart.py => test_models/test_plbart_model.py} (100%) rename tests/{models/test_roberta.py => test_models/test_roberta_model.py} (100%) rename tests/{models/test_t5.py => test_models/test_t5_model.py} (100%) rename tests/{models/test_vit.py => test_models/test_vit_model.py} (100%) rename tests/{models/test_whisper.py => test_models/test_whisper_model.py} (100%) rename tests/{models/test_xlm_roberta.py => test_models/test_xlm_roberta_model.py} (100%) rename tests/{models/test_xmod.py => test_models/test_xmod_model.py} (100%) delete mode 100644 tests/test_mt5.py delete mode 100644 tests/test_plbart.py delete mode 100644 tests/test_roberta.py delete mode 100644 tests/test_t5.py delete mode 100644 tests/test_vit.py delete mode 100644 tests/test_whisper.py delete mode 100644 tests/test_xlm_roberta.py delete mode 100644 tests/test_xmod.py diff --git a/Makefile b/Makefile index 04b9374133..ed978553ef 100644 --- a/Makefile +++ b/Makefile @@ -28,18 +28,29 @@ style: isort $(check_dirs) ${MAKE} extra_style_checks -# Run tests for the library +# Library Tests +# run all tests in the library test: python -m pytest -n auto --dist=loadfile -s -v ./tests/ + python -c "import transformers; print(transformers.__version__)" +# run all tests for the adapter methods for all adapter models test-adapter-methods: - python -m pytest --ignore ./tests/models -n auto --dist=loadfile -s -v ./tests/ + python -m pytest -n auto --dist=loadfile -s -v ./tests/test_methods/ +# run a subset of the adapter method tests for all adapter models +# list of all subsets: [core, heads, embeddings, composition, prefix_tuning, prompt_tuning, reft, unipelt, compacter, bottleneck, ia3, lora, config_union] +subset ?= +test-adapter-method-subset: + @echo "Running subset $(subset)" + python -m pytest -n auto --dist=loadfile -s -v ./tests/test_methods/ -m $(subset) + + +# run the hugginface test suite for all adapter models test-adapter-models: - python -m pytest -n auto --dist=loadfile -s -v ./tests/models + python -m pytest -n auto --dist=loadfile -s -v ./tests/test_models/ # Run tests for examples - test-examples: python -m pytest -n auto --dist=loadfile -s -v ./examples/pytorch/ diff --git a/conftest.py b/conftest.py index 6e6ad8e7b8..425836fcb2 100644 --- a/conftest.py +++ b/conftest.py @@ -87,3 +87,8 @@ def check_output(self, want, got, optionflags): doctest.OutputChecker = CustomOutputChecker + + +def pytest_collection_modifyitems(items): + # Exclude the 'test_class' group from the test collection since it's not a real test class and byproduct of the generic test class generation. + items[:] = [item for item in items if 'test_class' not in item.nodeid] \ No newline at end of file diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index 20d7fbba0d..f2a7b14fef 100644 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -442,7 +442,7 @@ def main(): else: model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code) n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) - logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") + logger.info(f"Training new model from scratch - Total size={n_params / 2**20:.2f}M params") # Convert the model into an adapter model adapters.init(model) diff --git a/pyproject.toml b/pyproject.toml index 3dca5b20da..219af9b189 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,21 @@ [tool.black] line-length = 119 target-version = ['py38', 'py39', 'py310'] - -# copied from HF for testing [tool.pytest.ini_options] markers = [ + "core: marks tests as core adapter test", + "composition: marks tests as composition adapter test", + "heads: marks tests as heads adapter test", + "embeddings: marks tests as embeddings adapter test", + "class_conversion: marks tests as class conversion adapter test", + "prefix_tuning: marks tests as prefix tuning adapter test", + "prompt_tuning: marks tests as prompt tuning adapter test", + "reft: marks tests as reft adapter test", + "unipelt: marks tests as unipelt adapter test", + "compacter: marks tests as compacter adapter test", + "bottleneck: marks tests as bottleneck adapter test", + "ia3: marks tests as ia3 adapter test", + "lora: marks tests as lora adapter test", "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", "generate: marks tests that use the GenerationTesterMixin" diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000000..71b5f34fc4 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,150 @@ +# Testing the Adapters Library + +This README provides a comprehensive overview of the test directory organization and explains how to execute different types of tests within the adapters library. + +## Test Directory Structure Overview + +``` +tests/ +├── __init__.py +├── fixtures/ # Datasets, test samples, ... +| └── ... +├── test_methods/ # Dynamic adapter method tests (all models) +│ ├── __init__.py +│ ├── method_test_impl/ # Implementation of tests +│ │ ├── __init__.py +│ │ ├── core/ +│ │ ├── composition/ +│ │ └── ... +│ ├── base.py # Base from which model test bases inherit +│ ├── generator.py # Testcase generation and registration +│ ├── test_on_albert.py # Example model test base for testing adapter methods on albert adapter model +│ ├── test_on_beit.py +│ └── ... +├── test_misc/ # Miscellaneous adapter method tests (single model) +│ ├── test_adapter_config.py +│ └── ... +├── test_models/ # Adapter model tests with Hugging Face test suite +│ └── __init__.py +│ │ ├── base.py +│ │ ├── test_albert_model.py +│ │ └── ... +``` + +## Test Categories + +The testing framework encompasses three distinct categories of tests: + +1. Dynamic Adapter Method Tests: These tests cover core functionalities of the adapters library, including individual adapter methods (such as LoRA and prompt tuning) and head functionalities. These tests are executed across all supported models. + +2. Miscellaneous Adapter Method Tests: These supplementary tests cover scenarios not included in the dynamic tests. To optimize resources, they are executed on a single model, as repeated execution across multiple models would not provide additional value. + +3. Adapter Model Tests: These tests verify the implementation of the adapter models themselves using the Hugging Face model test suite. + +## Test Generator and Pytest Markers + +The test_methods directory contains the central component `generator.py`, which generates appropriate sets of adapter method tests. Each model test base registers these tests using the following pattern: + +```python +method_tests = generate_method_tests(AlbertAdapterTestBase) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class +``` + +Each generated test class is decorated with a specific marker type. For example: + +```python +@require_torch +@pytest.mark.lora +class LoRA( + AlbertAdapterTestBase, + LoRATestMixin, + unittest.TestCase, +): + pass +``` + +These markers enable the execution of specific test types across all models. You can run these tests using either of these methods: + +1. Using the make command: +```bash +make test-adapter-method-subset subset=lora +``` + +2. Directly executing from the test directory: +```bash +cd tests/test_methods +pytest -m lora +``` + +Both approaches will execute all LoRA tests across every model in the adapters library. + +## Adding a New Adapter Method to the Test Suite + +The modular design of the test base simplifies the process of adding tests for new adapter methods. To add tests for a new adapter method "X", follow these steps: + +1. Create the Test Implementation: + Create a new file `tests/test_methods/method_test_impl/peft/test_X.py` and implement the test mixin class: + + ```python + @require_torch + class XTestMixin(AdapterMethodBaseTestMixin): + + default_config = XConfig() + + def test_add_X(self): + model = self.get_model() + self.run_add_test(model, self.default_config, ["adapters.{name}."]) + + def ... + ``` + +2. Register the Test Mixin: + Add the new test mixin class to `tests/test_methods/generator.py`: + + ```python + from tests.test_methods.method_test_impl.peft.test_X import XTestMixin + + def generate_method_tests(model_test_base, ...): + """ Generate method tests for the given model test base """ + test_classes = {} + + @require_torch + @pytest.mark.core + class Core( + model_test_base, + CompabilityTestMixin, + AdapterFusionModelTestMixin, + unittest.TestCase, + ): + pass + + if "Core" not in excluded_tests: + test_classes["Core"] = Core + + @require_torch + @pytest.mark.X + class X( + model_test_base, + XTestMixin, + unittest.TestCase, + ): + pass + + if "X" not in excluded_tests: + test_classes["X"] = X + ``` + + The pytest marker enables execution of the new method's tests across all adapter models using: + ```bash + make test-adapter-method-subset subset=X + ``` + + If the new method is incompatible with specific adapter models, you can exclude the tests in the respective `test_on_xyz.py` file: + + ```python + method_tests = generate_method_tests(BartAdapterTestBase, excluded_tests=["PromptTuning", "X"]) + ``` + + Note: It is recommended to design new methods to work with the complete library whenever possible. Only exclude tests when there are unavoidable compatibility issues and make them clear in the documenation. \ No newline at end of file diff --git a/tests/models/__init__.py b/tests/fixtures/__init__.py similarity index 100% rename from tests/models/__init__.py rename to tests/fixtures/__init__.py diff --git a/tests/methods/__init__.py b/tests/methods/__init__.py deleted file mode 100644 index ea65b2997b..0000000000 --- a/tests/methods/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -# flake8: noqa -# There's no way to ignore "F401 '...' imported but unused" warnings in this -# module, but to preserve other warnings. So, don't check this module at all. - -# Copyright 2020 The Adapter-Hub Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .base import create_twin_models -from .test_adapter_common import BottleneckAdapterTestMixin -from .test_compacter import CompacterTestMixin -from .test_ia3 import IA3TestMixin -from .test_lora import LoRATestMixin -from .test_prefix_tuning import PrefixTuningTestMixin -from .test_prompt_tuning import PromptTuningTestMixin -from .test_reft import ReftTestMixin -from .test_unipelt import UniPELTTestMixin - - -class AllMethodsTestMixin( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - PromptTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -): - """Shorthand mixin for models which support all adapter methods.""" - - pass diff --git a/tests/test_adapter.py b/tests/test_adapter.py deleted file mode 100644 index bafa7e65a9..0000000000 --- a/tests/test_adapter.py +++ /dev/null @@ -1,210 +0,0 @@ -import random - -import datasets -import torch - -import adapters -from adapters import AutoAdapterModel -from transformers import AutoFeatureExtractor, AutoTokenizer, GlueDataset, GlueDataTrainingArguments -from transformers.testing_utils import torch_device - - -global_rng = random.Random() - - -def make_config(config_class, **kwargs): - return staticmethod(lambda: config_class(**kwargs)) - - -def ids_tensor(shape, vocab_size, rng=None, name=None): - # Creates a random int32 tensor of the shape within the vocab size - if rng is None: - rng = global_rng - - total_dims = 1 - for dim in shape: - total_dims *= dim - - values = [] - for _ in range(total_dims): - values.append(rng.randint(0, vocab_size - 1)) - - return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() - - -class AdapterTestBase: - # If not overriden by subclass, AutoModel should be used. - model_class = AutoAdapterModel - # Default shape of inputs to use - default_input_samples_shape = (3, 64) - leave_out_layers = [0, 1] - do_run_train_tests = True - # default arguments for test_adapter_heads - batch_size = 1 - seq_length = 128 - is_speech_model = ( - False # Flag for tests to determine if the model is a speech model due to input format difference - ) - - def get_model(self): - if self.model_class == AutoAdapterModel: - model = AutoAdapterModel.from_config(self.config()) - else: - model = self.model_class(self.config()) - adapters.init(model) - model.to(torch_device) - return model - - def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs): - shape = shape or self.default_input_samples_shape - total_dims = 1 - for dim in shape: - total_dims *= dim - - values = [] - for _ in range(total_dims): - values.append(random.randint(0, vocab_size - 1)) - input_ids = torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() - # this is needed e.g. for BART - if config and config.eos_token_id is not None and config.eos_token_id < vocab_size: - input_ids[input_ids == config.eos_token_id] = random.randint(0, config.eos_token_id - 1) - input_ids[:, -1] = config.eos_token_id - in_data = {"input_ids": input_ids} - - if config and config.is_encoder_decoder: - in_data["decoder_input_ids"] = input_ids.clone() - return in_data - - def add_head(self, model, name, **kwargs): - model.add_classification_head(name, **kwargs) - return model.heads[name].config["num_labels"] - - def dataset(self, tokenizer=None): - # setup tokenizer - if tokenizer is None: - tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - data_args = GlueDataTrainingArguments( - task_name="mrpc", data_dir="./hf_transformers/tests/fixtures/tests_samples/MRPC", overwrite_cache=True - ) - return GlueDataset(data_args, tokenizer=tokenizer, mode="train") - - def assert_adapter_available(self, model, adapter_name): - self.assertTrue(adapter_name in model.adapters_config) - self.assertGreater(len(model.get_adapter(adapter_name)), 0) - - def assert_adapter_unavailable(self, model, adapter_name): - self.assertFalse(adapter_name in model.adapters_config) - self.assertEqual(len(model.get_adapter(adapter_name)), 0) - - -class VisionAdapterTestBase(AdapterTestBase): - default_input_samples_shape = (3, 3, 224, 224) - - def get_input_samples(self, shape=None, config=None, dtype=torch.float, **kwargs): - shape = shape or self.default_input_samples_shape - total_dims = 1 - for dim in shape: - total_dims *= dim - values = [] - for _ in range(total_dims): - values.append(random.random()) - pixel_values = torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous() - in_data = {"pixel_values": pixel_values} - - return in_data - - def add_head(self, model, name, **kwargs): - if "num_labels" not in kwargs: - kwargs["num_labels"] = 10 - model.add_image_classification_head(name, **kwargs) - return model.heads[name].config["num_labels"] - - def dataset(self, feature_extractor=None): - if feature_extractor is None: - feature_extractor = AutoFeatureExtractor.from_pretrained(self.feature_extractor_name) - - def transform(example_batch): - inputs = feature_extractor([x for x in example_batch["img"]], return_tensors="pt") - - inputs["labels"] = example_batch["label"] - return inputs - - dataset = datasets.load_dataset( - "./tests/fixtures/samples/cifar10", - data_dir="./tests/fixtures/samples/cifar10", - split="train", - trust_remote_code=True, - ) - dataset = dataset.with_transform(transform) - - return dataset - - -class SpeechAdapterTestBase(AdapterTestBase): - """Base class for speech adapter tests.""" - - default_input_samples_shape = (3, 80, 3000) # (batch_size, n_mels, enc_seq_len) - is_speech_model = True # Flag for tests to determine if the model is a speech model due to input format difference - time_window = 3000 # Time window for audio samples - seq_length = 80 - - def add_head(self, model, name, head_type="seq2seq_lm", **kwargs): - """Adds a head to the model.""" - if head_type == "audio_classification": - model.add_audio_classification_head(name, **kwargs) - return model.heads[name].config["num_labels"] - elif head_type == "seq2seq_lm": - kwargs.pop("num_labels", 1) # Remove num_labels from kwargs if present in the tests - model.add_seq2seq_lm_head(name, **kwargs) - return self.default_input_samples_shape[1] # Return the number of mel features - else: - raise ValueError(f"Head type {head_type} not supported.") - - def get_input_samples(self, shape=None, config=None, **kwargs): - """Creates a dummy batch of samples in the format required for speech models.""" - shape = shape or self.default_input_samples_shape - - # Input features - total_dims = 1 - for dim in shape: - total_dims *= dim - values = [] - for _ in range(total_dims): - values.append(random.random()) - input_features = torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous() - in_data = {"input_features": input_features} - - # Decoder input ids - if config and config.is_encoder_decoder: - in_data["decoder_input_ids"] = ids_tensor((shape[:-1]), config.vocab_size) - return in_data - - _TASK_DATASET_MAPPING = { - "seq2seq_lm": "./tests/fixtures/audio_datasets/common_voice_encoded", - "audio_classification": "./tests/fixtures/audio_datasets/speech_commands_encoded", - } - - def dataset(self, feature_extractor=None, processor=None, tokenizer=None, task_type: str = "seq2seq_lm", **kwargs): - """Returns a dataset to test speech model training. Standard dataset is for seq2seq_lm.""" - if task_type == "seq2seq_lm": - return self._prep_seq2seq_lm_dataset(task_type, **kwargs) - elif task_type == "audio_classification": - return self._prep_audio_classification_dataset(task_type, **kwargs) - - def _prep_seq2seq_lm_dataset(self, task_type, **kwargs): - """Prepares a dataset for conditional generation.""" - # The dataset is already processed and saved to disk, to save time during testing - # Preparation script can be found in tests/fixtures/audio_datasets/prepare_audio_datasets.py - dataset_path = self._TASK_DATASET_MAPPING[task_type] - dataset = datasets.load_from_disk(dataset_path) - return dataset["train"] - - def _prep_audio_classification_dataset(self, task_type, **kwargs): - """Prepares a dataset for audio classification.""" - # The dataset is already processed and saved to disk, to save time during testing - # Preparation script can be found in tests/fixtures/audio_datasets/prepare_audio_datasets.py - dataset_path = self._TASK_DATASET_MAPPING[task_type] - dataset = datasets.load_from_disk(dataset_path) - return dataset["train"] diff --git a/tests/test_albert.py b/tests/test_albert.py deleted file mode 100644 index 64dd62bc37..0000000000 --- a/tests/test_albert.py +++ /dev/null @@ -1,55 +0,0 @@ -import unittest -from math import ceil - -from transformers import AlbertConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class AlbertAdapterTestBase(AdapterTestBase): - config_class = AlbertConfig - config = make_config( - AlbertConfig, - embedding_size=16, - hidden_size=64, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - num_hidden_groups=2, - ) - tokenizer_name = "albert-base-v2" - leave_out_layers = [0] - - -@require_torch -class AlbertAdapterTest( - AllMethodsTestMixin, - EmbeddingTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - AlbertAdapterTestBase, - unittest.TestCase, -): - def test_context_simple(self): - expected_number_of_adapter_calls = ceil(self.config().num_hidden_layers / self.config().num_hidden_groups) - super().test_context_simple(expected_number_of_adapter_calls=expected_number_of_adapter_calls) - - -@require_torch -class AlbertClassConversionTest( - ModelClassConversionTestMixin, - AlbertAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_bart.py b/tests/test_bart.py deleted file mode 100644 index 8c11dc7033..0000000000 --- a/tests/test_bart.py +++ /dev/null @@ -1,68 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import BartConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class BartAdapterTestBase(AdapterTestBase): - config_class = BartConfig - config = make_config( - BartConfig, - d_model=16, - encoder_layers=2, - decoder_layers=2, - encoder_attention_heads=4, - decoder_attention_heads=4, - encoder_ffn_dim=4, - decoder_ffn_dim=4, - ) - tokenizer_name = "facebook/bart-base" - - -@require_torch -class BartAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - EmbeddingTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - BartAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class BartClassConversionTest( - ModelClassConversionTestMixin, - BartAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_beit.py b/tests/test_beit.py deleted file mode 100644 index a943b2e7fd..0000000000 --- a/tests/test_beit.py +++ /dev/null @@ -1,45 +0,0 @@ -import unittest - -from transformers import BeitConfig -from transformers.testing_utils import require_torch - -from .methods import AllMethodsTestMixin -from .test_adapter import VisionAdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class BeitAdapterTestBase(VisionAdapterTestBase): - config_class = BeitConfig - config = make_config( - BeitConfig, - image_size=224, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ) - feature_extractor_name = "microsoft/beit-base-patch16-224-pt22k" - - -@require_torch -class BeitAdapterTest( - AllMethodsTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - BeitAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class BeitClassConversionTest( - ModelClassConversionTestMixin, - BeitAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_bert.py b/tests/test_bert.py deleted file mode 100644 index 7bde9b557a..0000000000 --- a/tests/test_bert.py +++ /dev/null @@ -1,51 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import BertConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class BertAdapterTestBase(AdapterTestBase): - config_class = BertConfig - config = make_config( - BertConfig, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ) - tokenizer_name = "bert-base-uncased" - - -@require_torch -class BertAdapterTest( - AllMethodsTestMixin, - EmbeddingTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - BertAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class BertClassConversionTest( - ModelClassConversionTestMixin, - BertAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_clip.py b/tests/test_clip.py deleted file mode 100644 index ead9c7d561..0000000000 --- a/tests/test_clip.py +++ /dev/null @@ -1,230 +0,0 @@ -import random -import unittest - -import torch - -from transformers import ( - CLIPConfig, - CLIPTextConfig, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPVisionConfig, - CLIPVisionModel, - CLIPVisionModelWithProjection, -) -from transformers.testing_utils import require_torch, torch_device - -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, VisionAdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin - - -class CLIPVisionAdapterTestBase(VisionAdapterTestBase): - model_class = CLIPVisionModel - config_class = CLIPVisionConfig - config = make_config( - CLIPVisionConfig, - image_size=224, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ) - feature_extractor_name = "openai/clip-vit-base-patch32" - - -@require_torch -class CLIPVisionAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - CLIPVisionAdapterTestBase, - unittest.TestCase, -): - pass - - -class CLIPVisionWithProjectionAdapterTestBase(VisionAdapterTestBase): - model_class = CLIPVisionModelWithProjection - config_class = CLIPVisionConfig - config = make_config( - CLIPVisionConfig, - image_size=224, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ) - feature_extractor_name = "openai/clip-vit-base-patch32" - - -@require_torch -class CLIPVisionWithProjectionAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - CLIPVisionWithProjectionAdapterTestBase, - unittest.TestCase, -): - pass - - -class CLIPTextAdapterTestBase(AdapterTestBase): - model_class = CLIPTextModel - config_class = CLIPTextConfig - config = make_config( - CLIPTextConfig, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ) - tokenizer_name = "openai/clip-vit-base-patch32" - - -@require_torch -class CLIPTextAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - CLIPTextAdapterTestBase, - unittest.TestCase, -): - pass - - -class CLIPTextWithProjectionAdapterTestBase(AdapterTestBase): - model_class = CLIPTextModelWithProjection - config_class = CLIPTextConfig - config = make_config( - CLIPTextConfig, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ) - tokenizer_name = "openai/clip-vit-base-patch32" - - -@require_torch -class CLIPTextWithProjectionAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - CLIPTextWithProjectionAdapterTestBase, - unittest.TestCase, -): - pass - - -class CLIPAdapterTestBase(AdapterTestBase): - config_class = CLIPConfig - config = staticmethod( - lambda: CLIPConfig.from_text_vision_configs( - CLIPTextConfig( - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ), - CLIPVisionConfig( - image_size=224, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ), - ) - ) - tokenizer_name = "openai/clip-vit-base-patch32" - # Default shape of inputs to use - default_text_input_samples_shape = (3, 64) - default_vision_input_samples_shape = (3, 3, 224, 224) - do_run_train_tests = False - - def get_input_samples(self, vocab_size=5000, config=None, dtype=torch.float, **kwargs): - # text inputs - shape = self.default_text_input_samples_shape - total_dims = 1 - for dim in shape: - total_dims *= dim - values = [] - for _ in range(total_dims): - values.append(random.randint(0, vocab_size - 1)) - input_ids = torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() - # this is needed e.g. for BART - if config and config.eos_token_id is not None and config.eos_token_id < vocab_size: - input_ids[input_ids == config.eos_token_id] = random.randint(0, config.eos_token_id - 1) - input_ids[:, -1] = config.eos_token_id - in_data = {"input_ids": input_ids} - - # vision inputs - shape = self.default_vision_input_samples_shape - total_dims = 1 - for dim in shape: - total_dims *= dim - values = [] - for _ in range(total_dims): - values.append(random.random()) - pixel_values = torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous() - in_data["pixel_values"] = pixel_values - - return in_data - - def add_head(self, *args, **kwargs): - pass - - -@require_torch -class CLIPAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - CLIPAdapterTestBase, - unittest.TestCase, -): - def test_adapter_fusion_save_with_head(self): - # This test is not applicable to CLIP - self.skipTest("Not applicable to CLIP.") - - def test_load_adapter_setup(self): - self.skipTest("Not applicable to CLIP.") diff --git a/tests/test_deberta.py b/tests/test_deberta.py deleted file mode 100644 index 61b4c96957..0000000000 --- a/tests/test_deberta.py +++ /dev/null @@ -1,55 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import DebertaConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class DebertaAdapterTestBase(AdapterTestBase): - config_class = DebertaConfig - config = make_config( - DebertaConfig, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - relative_attention=True, - pos_att_type="p2c|c2p", - ) - tokenizer_name = "microsoft/deberta-base" - - -@require_torch -class DebertaAdapterTest( - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - AllMethodsTestMixin, - EmbeddingTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - DebertaAdapterTestBase, - unittest.TestCase, -): - def test_parallel_training_lora(self): - self.skipTest("Not supported for DeBERTa") - - -@require_torch -class DebertaClassConversionTest( - ModelClassConversionTestMixin, - DebertaAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_debertaV2.py b/tests/test_debertaV2.py deleted file mode 100644 index 6494e1f865..0000000000 --- a/tests/test_debertaV2.py +++ /dev/null @@ -1,54 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import DebertaV2Config -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class DebertaV2AdapterTestBase(AdapterTestBase): - config_class = DebertaV2Config - config = make_config( - DebertaV2Config, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - relative_attention=True, - pos_att_type="p2c|c2p", - ) - tokenizer_name = "microsoft/deberta-v3-base" - - -@require_torch -class DebertaV2AdapterTest( - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - AllMethodsTestMixin, - EmbeddingTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - DebertaV2AdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class DebertaV2ClassConversionTest( - ModelClassConversionTestMixin, - DebertaV2AdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_distilbert.py b/tests/test_distilbert.py deleted file mode 100644 index c90c39875c..0000000000 --- a/tests/test_distilbert.py +++ /dev/null @@ -1,51 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import DistilBertConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class DistilBertAdapterTestBase(AdapterTestBase): - config_class = DistilBertConfig - config = make_config( - DistilBertConfig, - dim=32, - n_layers=4, - n_heads=4, - hidden_dim=37, - ) - tokenizer_name = "distilbert-base-uncased" - - -@require_torch -class DistilBertAdapterTest( - AllMethodsTestMixin, - EmbeddingTestMixin, - CompabilityTestMixin, - AdapterFusionModelTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - DistilBertAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class DistilBertClassConversionTest( - ModelClassConversionTestMixin, - DistilBertAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_electra.py b/tests/test_electra.py deleted file mode 100644 index d3272a23d5..0000000000 --- a/tests/test_electra.py +++ /dev/null @@ -1,52 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import ElectraConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class ElectraAdapterTestBase(AdapterTestBase): - config_class = ElectraConfig - config = make_config( - ElectraConfig, - # vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - ) - tokenizer_name = "google/electra-base-generator" - - -@require_torch -class ElectraAdapterTest( - AllMethodsTestMixin, - EmbeddingTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - ElectraAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class ElectraClassConversionTest( - ModelClassConversionTestMixin, - ElectraAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py deleted file mode 100644 index c6ac6d188f..0000000000 --- a/tests/test_gpt2.py +++ /dev/null @@ -1,67 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import GPT2Config -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class GPT2AdapterTestBase(AdapterTestBase): - config_class = GPT2Config - config = make_config( - GPT2Config, - n_embd=32, - n_layer=4, - n_head=4, - # set pad token to eos token - pad_token_id=50256, - ) - tokenizer_name = "gpt2" - - -@require_torch -class GPT2AdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - EmbeddingTestMixin, - CompabilityTestMixin, - AdapterFusionModelTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - GPT2AdapterTestBase, - unittest.TestCase, -): - def test_parallel_training_lora(self): - self.skipTest("Not supported for GPT2") - - -@require_torch -class GPT2ClassConversionTest( - ModelClassConversionTestMixin, - GPT2AdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_gptj.py b/tests/test_gptj.py deleted file mode 100644 index 934abf2904..0000000000 --- a/tests/test_gptj.py +++ /dev/null @@ -1,68 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import GPTJConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class GPTJAdapterTestBase(AdapterTestBase): - config_class = GPTJConfig - config = make_config( - GPTJConfig, - n_embd=32, - n_layer=4, - n_head=4, - rotary_dim=4, - # set pad token to eos token - pad_token_id=50256, - resid_pdrop=0.1, - ) - tokenizer_name = "EleutherAI/gpt-j-6B" - - -@require_torch -class GPTJAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - ReftTestMixin, - UniPELTTestMixin, - PrefixTuningTestMixin, - EmbeddingTestMixin, - CompabilityTestMixin, - AdapterFusionModelTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - GPTJAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class GPTJClassConversionTest( - ModelClassConversionTestMixin, - GPTJAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_llama.py b/tests/test_llama.py deleted file mode 100644 index d3c78e23f3..0000000000 --- a/tests/test_llama.py +++ /dev/null @@ -1,66 +0,0 @@ -import unittest - -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class LlamaAdapterTestBase(AdapterTestBase): - config_class = LlamaConfig - config = make_config( - LlamaConfig, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - pad_token_id=0, - ) - tokenizer_name = "openlm-research/open_llama_13b" - - -@require_torch -class LlamaAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - EmbeddingTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - LlamaAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class LlamaClassConversionTest( - ModelClassConversionTestMixin, - LlamaAdapterTestBase, - unittest.TestCase, -): - def test_conversion_question_answering_model(self): - raise self.skipTest("We don't support the Llama QA model.") diff --git a/tests/test_mbart.py b/tests/test_mbart.py deleted file mode 100644 index 56fa406daf..0000000000 --- a/tests/test_mbart.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -from transformers import MBartConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class MBartAdapterTestBase(AdapterTestBase): - config_class = MBartConfig - config = make_config( - MBartConfig, - d_model=16, - encoder_layers=2, - decoder_layers=2, - encoder_attention_heads=4, - decoder_attention_heads=4, - encoder_ffn_dim=4, - decoder_ffn_dim=4, - vocab_size=250027, - ) - tokenizer_name = "facebook/mbart-large-cc25" - - -@require_torch -class MBartAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - MBartAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class MBartClassConversionTest( - ModelClassConversionTestMixin, - MBartAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_methods/__init__.py b/tests/test_methods/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_methods/base.py b/tests/test_methods/base.py new file mode 100644 index 0000000000..f5e53fedd6 --- /dev/null +++ b/tests/test_methods/base.py @@ -0,0 +1,232 @@ +import random + +import datasets +import torch + +import adapters +from adapters import AutoAdapterModel +from transformers import AutoFeatureExtractor, AutoTokenizer, GlueDataset, GlueDataTrainingArguments +from transformers.testing_utils import torch_device + + +class AbstractAdapterTestBase: + """Base class for adapter tests. Defines basic functions and attributes with default values which are used in the tests. + Model test classes should inherit from this class or subclass and override the attributes and functions as needed. + """ + + model_class = AutoAdapterModel + tokenizer_name = "tests/fixtures/SiBERT" # path to default tokenizer config available in the test repo + config = None # specified in the actual model test classes + input_shape = () # (batch_size, seq_length) + leave_out_layers = [] + do_run_train_tests = True + num_labels = 2 + + def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs): + """Creates a dummy batch of samples in the format required for the model.""" + raise NotImplementedError("get_input_samples() must be implemented in the subclass.") + + def add_head(self, model, name, **kwargs): + """Adds a dummy head to the model.""" + raise NotImplementedError("add_head() must be implemented in the subclass.") + + def get_dataset(self, **kwargs): + """Loads a dummy dataset for the model.""" + raise NotImplementedError("get_dataset() must be implemented in the subclass.") + + def get_dataset_non_batched(self): + """Builds a non-batched dummy dataset for the model.""" + raise NotImplementedError("build_dummy_dataset() must be implemented in the subclass.") + + def attach_labels(self, inputs): + """Attaches labels to the input samples.""" + raise NotImplementedError("attach_labels() with respective label shape must be implemented in the subclass.") + + def get_model(self): + """Builds a model instance for testing based on the provied model configuration.""" + if self.model_class == AutoAdapterModel: + model = AutoAdapterModel.from_config(self.config()) + else: + model = self.model_class(self.config()) + adapters.init(model) + model.to(torch_device) + return model + + def build_rand_tensor(self, shape, dtype=torch.float): + """Creates a random tensor of the given shape.""" + total_dims = self._calc_total_dim(shape) + values = [random.random() for _ in range(total_dims)] + + return torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous() + + def build_rand_ids_tensor(self, shape, vocab_size=5000): + """Creates a random tensor of type torch.long with the given shape with random values in range 0 - (vocab_size-1).""" + total_dims = self._calc_total_dim(shape) + values = [random.randint(0, vocab_size - 1) for _ in range(total_dims)] + return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() + + def _calc_total_dim(self, shape): + total_dims = 1 + for dim in shape: + total_dims *= dim + return total_dims + + def extract_input_ids(self, inputs): + # TODO: Check if this is needed in all tests and if it differs between text, vision and speech models + return inputs["input_ids"] + + def build_generate_input(self, shape): + """The generate() functions for inference require different inputs depeding on the model type. E.g. the text models require input_ids, where as the audio models require input_features""" + return self.build_rand_ids_tensor(self.input_shape if not shape else shape).to(torch_device) + + +class TextAdapterTestBase(AbstractAdapterTestBase): + """Base class for adapter tests for text models. Text models test classes should inherit from this class and override the attributes and functions as needed.""" + + input_shape = (3, 64) + leave_out_layers = [0, 1] + batch_size, seq_length = ( + input_shape # TODO: Check in which tests this is needed and if we can simplify by using input_shape + ) + + def get_input_samples(self, shape=None, vocab_size=5000, config=None, **kwargs): + shape = shape or self.input_shape + input_ids = self.build_rand_ids_tensor(shape, vocab_size=vocab_size) + + # Ensures that only tha last token in each sample is the eos token (needed e.g. for BART) + if config and config.eos_token_id is not None and config.eos_token_id < vocab_size: + input_ids[input_ids == config.eos_token_id] = random.randint(0, config.eos_token_id - 1) + input_ids[:, -1] = config.eos_token_id + in_data = {"input_ids": input_ids} + + # Add decoder input ids for models with a decoder + if config and config.is_encoder_decoder: + in_data["decoder_input_ids"] = input_ids.clone() + + if "num_labels" in kwargs: + in_data["labels"] = self.build_rand_ids_tensor(shape[:-1], vocab_size=kwargs["num_labels"]) + return in_data + + def add_head(self, model, name, **kwargs): + # TODO: Check if this should be more modular + model.add_classification_head(name, **kwargs) + return model.heads[name].config["num_labels"] + + def get_dataset(self, tokenizer=None): + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + data_args = GlueDataTrainingArguments( + task_name="mrpc", data_dir="./hf_transformers/tests/fixtures/tests_samples/MRPC", overwrite_cache=True + ) + return GlueDataset(data_args, tokenizer=tokenizer, mode="train") + + def get_dataset_non_batched(self, config): + dataset = [] + for i in range(3): + input_data = self.get_input_samples(config=config) + input_data["labels"] = self.build_rand_ids_tensor((3, 1), self.num_labels) + dataset.append(input_data) + return dataset + + def attach_labels(self, inputs): + inputs["labels"] = torch.randint(0, 2, (self.batch_size, 1), device=torch_device) + return inputs + + +class VisionAdapterTestBase(AbstractAdapterTestBase): + """Base class for adapter tests for vision models. Vision models test classes should inherit from this class and override the attributes and functions as needed.""" + + input_shape = (3, 3, 224, 224) + batch_size = 3 + + def get_input_samples(self, shape=None, config=None, dtype=torch.float, **kwargs): + shape = shape or self.input_shape + pixel_values = self.build_rand_tensor(shape, dtype=dtype) + return {"pixel_values": pixel_values} + + def add_head(self, model, name, **kwargs): + kwargs["num_labels"] = 10 if "num_labels" not in kwargs else kwargs["num_labels"] + model.add_image_classification_head(name, **kwargs) + return model.heads[name].config["num_labels"] + + def get_dataset(self, feature_extractor=None): + dataset = datasets.load_dataset( + "./tests/fixtures/samples/cifar10", + data_dir="./tests/fixtures/samples/cifar10", + split="train", + trust_remote_code=True, + ) + if feature_extractor is None: + feature_extractor = AutoFeatureExtractor.from_pretrained(self.feature_extractor_name) + + def transform(example_batch): + inputs = feature_extractor([x for x in example_batch["img"]], return_tensors="pt") + inputs["labels"] = example_batch["label"] + return inputs + + dataset = dataset.with_transform(transform) + return dataset + + +class AudioAdapterTestBase(AbstractAdapterTestBase): + """Base class for adapter tests for audio models. Audio models test classes should inherit from this class and override the attributes and functions as needed.""" + + input_shape = (3, 80, 3000) # (batch_size, n_mels, enc_seq_len) + time_window = 3000 # Time window for audio samples + seq_length = 80 + batch_size = 3 + + _TASK_DATASET_MAPPING = { + # TODO: build global mapping for all tasks and datasets + "seq2seq_lm": "./tests/fixtures/audio_datasets/common_voice_encoded", + "audio_classification": "./tests/fixtures/audio_datasets/speech_commands_encoded", + } + + def add_head(self, model, name, head_type="seq2seq_lm", **kwargs): + # TODO: simpify Audio tests by using the same head type for all tests + if head_type == "audio_classification": + model.add_audio_classification_head(name, **kwargs) + return model.heads[name].config["num_labels"] + elif head_type == "seq2seq_lm": + kwargs.pop("num_labels", 1) # Remove num_labels from kwargs if present in the tests + model.add_seq2seq_lm_head(name, **kwargs) + return self.input_shape[1] # Return the number of mel features + else: + raise ValueError(f"Head type {head_type} not supported.") + + def get_input_samples(self, shape=None, config=None, **kwargs): + shape = shape or self.input_shape + in_data = {"input_features": self.build_rand_tensor(shape, dtype=torch.float)} + + # Add decoder input ids for models with a decoder + if config and config.is_encoder_decoder: + in_data["decoder_input_ids"] = self.build_rand_ids_tensor((shape[:-1]), vocab_size=config.vocab_size) + return in_data + + def get_dataset(self, task_type: str = "seq2seq_lm", **kwargs): + # Dataset is already processed and saved to disk, to save time during testing + # Preparation script can be found in tests/fixtures/audio_datasets/respective_prepare_script.py + dataset_path = self._TASK_DATASET_MAPPING[task_type] + dataset = datasets.load_from_disk(dataset_path) + return dataset["train"] + + def extract_input_ids(self, inputs): + return inputs["input_features"] + + def build_generate_input(self, shape): + return self.build_rand_tensor(self.input_shape if not shape else shape, dtype=torch.float) + + def attach_labels(self, inputs): + inputs["labels"] = torch.randint(0, 2, (self.batch_size, self.seq_length), device=torch_device) + return inputs + + def get_dataset_non_batched(self, config): + dataset_batched = self.get_dataset() + dataset = [{} for _ in range(len(dataset_batched))] + # For non-batched training, we need to wrap the samples by an additional dimension + for i in range(len(dataset_batched)): + for key, value in dataset_batched[i].items(): + dataset[i][key] = torch.unsqueeze(value, 0) + return dataset diff --git a/tests/test_methods/generator.py b/tests/test_methods/generator.py new file mode 100644 index 0000000000..2c33b46532 --- /dev/null +++ b/tests/test_methods/generator.py @@ -0,0 +1,227 @@ +import unittest + +import pytest + +from tests.test_methods.method_test_impl.composition.test_parallel import ( + ParallelAdapterInferenceTestMixin, + ParallelTrainingMixin, +) +from tests.test_methods.method_test_impl.core.test_adapter_backward_compability import CompabilityTestMixin +from tests.test_methods.method_test_impl.core.test_adapter_conversion import ModelClassConversionTestMixin +from tests.test_methods.method_test_impl.core.test_adapter_fusion_common import AdapterFusionModelTestMixin +from tests.test_methods.method_test_impl.embeddings.test_adapter_embeddings import EmbeddingTestMixin +from tests.test_methods.method_test_impl.heads.test_adapter_heads import PredictionHeadModelTestMixin +from tests.test_methods.method_test_impl.peft.test_adapter_common import BottleneckAdapterTestMixin +from tests.test_methods.method_test_impl.peft.test_compacter import CompacterTestMixin +from tests.test_methods.method_test_impl.peft.test_config_union import ConfigUnionAdapterTest +from tests.test_methods.method_test_impl.peft.test_ia3 import IA3TestMixin +from tests.test_methods.method_test_impl.peft.test_lora import LoRATestMixin +from tests.test_methods.method_test_impl.peft.test_prefix_tuning import PrefixTuningTestMixin +from tests.test_methods.method_test_impl.peft.test_prompt_tuning import PromptTuningTestMixin +from tests.test_methods.method_test_impl.peft.test_reft import ReftTestMixin +from tests.test_methods.method_test_impl.peft.test_unipelt import UniPELTTestMixin +from transformers.testing_utils import require_torch + + +def generate_method_tests( + model_test_base, + redundant=[], + not_supported=[], +) -> dict: + """ + Generates a set of method test classes for a given model test base. + + Args: + model_test_base (type): The base class for the model tests. + redundant (list, optional): A list of redundant tests to exclude. Defaults to []. + not_supported (list, optional): A list of tests that are not supported for the model. Defaults to []. + + Returns: + dict: A dictionary mapping test class names to the generated test classes. + """ + test_classes = {} + + if "Core" not in redundant and "Core" not in not_supported: + + @require_torch + @pytest.mark.core + class Core( + model_test_base, + CompabilityTestMixin, + AdapterFusionModelTestMixin, + unittest.TestCase, + ): + pass + + test_classes["Core"] = Core + + if "Heads" not in redundant and "Heads" not in not_supported: + + @require_torch + @pytest.mark.heads + class Heads( + model_test_base, + PredictionHeadModelTestMixin, + unittest.TestCase, + ): + pass + + test_classes["Heads"] = Heads + + if "Embeddings" not in redundant and "Embeddings" not in not_supported: + + @require_torch + @pytest.mark.embeddings + class Embeddings( + model_test_base, + EmbeddingTestMixin, + unittest.TestCase, + ): + pass + + test_classes["Embeddings"] = Embeddings + + if "Composition" not in redundant and "Composition" not in not_supported: + + @require_torch + @pytest.mark.composition + class Composition( + model_test_base, + ParallelAdapterInferenceTestMixin, + ParallelTrainingMixin, + unittest.TestCase, + ): + pass + + test_classes["Composition"] = Composition + + if "ClassConversion" not in redundant and "ClassConversion" not in not_supported: + + @require_torch + class ClassConversion( + ModelClassConversionTestMixin, + model_test_base, + unittest.TestCase, + ): + pass + + test_classes["ClassConversion"] = ClassConversion + + if "PrefixTuning" not in redundant and "PrefixTuning" not in not_supported: + + @require_torch + @pytest.mark.prefix_tuning + class PrefixTuning( + model_test_base, + PrefixTuningTestMixin, + unittest.TestCase, + ): + pass + + test_classes["PrefixTuning"] = PrefixTuning + + if "PromptTuning" not in redundant and "PromptTuning" not in not_supported: + + @require_torch + @pytest.mark.prompt_tuning + class PromptTuning( + model_test_base, + PromptTuningTestMixin, + unittest.TestCase, + ): + pass + + test_classes["PromptTuning"] = PromptTuning + + if "ReFT" not in redundant and "ReFT" not in not_supported: + + @require_torch + @pytest.mark.reft + class ReFT( + model_test_base, + ReftTestMixin, + unittest.TestCase, + ): + pass + + test_classes["ReFT"] = ReFT + + if "UniPELT" not in redundant and "UniPELT" not in not_supported: + + @require_torch + @pytest.mark.unipelt + class UniPELT( + model_test_base, + UniPELTTestMixin, + unittest.TestCase, + ): + pass + + test_classes["UniPELT"] = UniPELT + + if "Compacter" not in redundant and "Compacter" not in not_supported: + + @require_torch + @pytest.mark.compacter + class Compacter( + model_test_base, + CompacterTestMixin, + unittest.TestCase, + ): + pass + + test_classes["Compacter"] = Compacter + + if "Bottleneck" not in redundant and "Bottleneck" not in not_supported: + + @require_torch + @pytest.mark.bottleneck + class Bottleneck( + model_test_base, + BottleneckAdapterTestMixin, + unittest.TestCase, + ): + pass + + test_classes["Bottleneck"] = Bottleneck + + if "IA3" not in redundant and "IA3" not in not_supported: + + @require_torch + @pytest.mark.ia3 + class IA3( + model_test_base, + IA3TestMixin, + unittest.TestCase, + ): + pass + + test_classes["IA3"] = IA3 + + if "LoRA" not in redundant and "LoRA" not in not_supported: + + @require_torch + @pytest.mark.lora + class LoRA( + model_test_base, + LoRATestMixin, + unittest.TestCase, + ): + pass + + test_classes["LoRA"] = LoRA + + if "ConfigUnion" not in redundant and "ConfigUnion" not in not_supported: + + @require_torch + @pytest.mark.config_union + class ConfigUnion( + model_test_base, + ConfigUnionAdapterTest, + unittest.TestCase, + ): + pass + + test_classes["ConfigUnion"] = ConfigUnion + + return test_classes diff --git a/tests/test_methods/method_test_impl/__init__.py b/tests/test_methods/method_test_impl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/methods/base.py b/tests/test_methods/method_test_impl/base.py similarity index 85% rename from tests/methods/base.py rename to tests/test_methods/method_test_impl/base.py index 55389b7052..95e3411725 100644 --- a/tests/methods/base.py +++ b/tests/test_methods/method_test_impl/base.py @@ -13,33 +13,26 @@ from transformers import TrainingArguments from transformers.testing_utils import require_torch, torch_device - -def create_twin_models(model_class, config_creator=None): - if config_creator and model_class.__name__.startswith("Auto"): - model_config = config_creator() - model1 = model_class.from_config(model_config) - elif config_creator: - model_config = config_creator() - model1 = model_class(model_config) - else: - model_config = model_class.config_class() - model1 = model_class(model_config) - adapters.init(model1) - model1.eval() - # create a twin initialized with the same random weights - model2 = copy.deepcopy(model1) - model2.eval() - return model1, model2 +from .utils import add_lm_head, create_twin_models @require_torch class AdapterMethodBaseTestMixin: - """Provides base test running methods for testing an adapter method implementation.""" + """Implements base test running methods for testing adapter method implementations.""" - # Model weight dtypes to test in forward pass dtypes_to_test = [torch.float32, torch.half] if torch_device == "cuda" else [torch.float32] - def filter_parameters(self, model, filter_keys): + def _assert_adapter_available(self, model, adapter_name): + """Check wether the adapter name is present in the model's adapter config and has been created.""" + self.assertTrue(adapter_name in model.adapters_config) + self.assertGreater(len(model.get_adapter(adapter_name)), 0) + + def _assert_adapter_unavailable(self, model, adapter_name): + """Check wether the adapter name is not present in the model's adapter config and has not been created.""" + self.assertFalse(adapter_name in model.adapters_config) + self.assertEqual(len(model.get_adapter(adapter_name)), 0) + + def _filter_parameters(self, model, filter_keys): return {k: v for (k, v) in model.named_parameters() if any([filter_key in k for filter_key in filter_keys])} def run_add_test(self, model, adapter_config, filter_keys): @@ -57,11 +50,15 @@ def run_add_test(self, model, adapter_config, filter_keys): # check that weights are available and active has_weights = False filter_keys = [k.format(name=name) for k in filter_keys] - for k, v in self.filter_parameters(model, filter_keys).items(): + for k, v in self._filter_parameters(model, filter_keys).items(): has_weights = True self.assertTrue(v.requires_grad, k) self.assertTrue(has_weights) + # Remove added adapters in case of multiple subtests + model.set_active_adapters(None) + model.delete_adapter(name) + def run_leave_out_test(self, model, adapter_config, leave_out): model.eval() @@ -71,7 +68,7 @@ def run_leave_out_test(self, model, adapter_config, leave_out): model.set_active_adapters(name) # adapter is correctly added to config - self.assert_adapter_available(model, name) + self._assert_adapter_available(model, name) adapter = model.get_adapter(name) @@ -96,7 +93,7 @@ def run_linear_average_test(self, model, adapter_config, filter_keys): averaged_weights = {} for i, w in enumerate(weights): this_filter_keys = [k.format(name=name + f"_{i}") for k in filter_keys] - for k, v in self.filter_parameters(model, this_filter_keys).items(): + for k, v in self._filter_parameters(model, this_filter_keys).items(): base_k = k.replace(name + f"_{i}", name) if base_k not in averaged_weights: averaged_weights[base_k] = w * v @@ -114,7 +111,7 @@ def run_linear_average_test(self, model, adapter_config, filter_keys): # compare averaged weights to collected weights this_filter_keys = [k.format(name=name) for k in filter_keys] - for k, v in self.filter_parameters(model, this_filter_keys).items(): + for k, v in self._filter_parameters(model, this_filter_keys).items(): self.assertTrue(torch.allclose(v, averaged_weights[k]), k) def run_delete_test(self, model, adapter_config, filter_keys): @@ -126,16 +123,16 @@ def run_delete_test(self, model, adapter_config, filter_keys): model.to(torch_device) # adapter is correctly added to config - self.assert_adapter_available(model, name) + self._assert_adapter_available(model, name) # remove the adapter again model.delete_adapter(name) - self.assert_adapter_unavailable(model, name) + self._assert_adapter_unavailable(model, name) # check that weights are available and active has_weights = False filter_keys = [k.format(name=name) for k in filter_keys] - for k, v in self.filter_parameters(model, filter_keys).items(): + for k, v in self._filter_parameters(model, filter_keys).items(): has_weights = True self.assertFalse(has_weights) @@ -147,7 +144,7 @@ def run_get_test(self, model, adapter_config, num_expected_modules): # adapter is correctly added to config name = "first" - self.assert_adapter_available(model, name) + self._assert_adapter_available(model, name) adapter = model.get_adapter("first") @@ -187,6 +184,10 @@ def run_forward_test(self, model, adapter_config, dtype=torch.float32): self.assertGreaterEqual(len(output_1), len(base_output)) self.assertFalse(torch.equal(output_1[0], base_output[0])) + # Remove added adapters in case of multiple subtests + model.set_active_adapters(None) + model.delete_adapter(name) + def run_load_test(self, adapter_config): model1, model2 = create_twin_models(self.model_class, self.config) @@ -250,7 +251,7 @@ def run_full_model_load_test(self, adapter_config): def trainings_run(self, model, lr=1.0, steps=8, batch_size=2, gradient_accumulation_steps=1): # setup dataset - train_dataset = self.dataset() + train_dataset = self.get_dataset() training_args = TrainingArguments( output_dir="./examples", @@ -283,8 +284,8 @@ def run_train_test(self, adapter_config, filter_keys): model.add_adapter("dummy", config=adapter_config) self.add_head(model, "mrpc") - self.assert_adapter_available(model, "mrpc") - self.assert_adapter_available(model, "dummy") + self._assert_adapter_available(model, "mrpc") + self._assert_adapter_available(model, "dummy") # train the mrpc adapter -> should be activated & unfreezed model.train_adapter("mrpc") @@ -293,13 +294,13 @@ def run_train_test(self, adapter_config, filter_keys): # all weights of the adapter should be activated has_weights = False filter_keys_trained = [k.format(name="mrpc") for k in filter_keys] - for k, v in self.filter_parameters(model, filter_keys_trained).items(): + for k, v in self._filter_parameters(model, filter_keys_trained).items(): has_weights = True self.assertTrue(v.requires_grad, k) self.assertTrue(has_weights) # all weights of the adapter not used for training should be frozen filter_keys_untrained = [k.format(name="dummy") for k in filter_keys] - for k, v in self.filter_parameters(model, filter_keys_untrained).items(): + for k, v in self._filter_parameters(model, filter_keys_untrained).items(): self.assertFalse(v.requires_grad, k) state_dict_pre = copy.deepcopy(model.state_dict()) @@ -427,29 +428,17 @@ def adapter_setup_fn(model): self._run_gradient_checkpointing_test_helper(adapter_setup_fn) - def run_generate_test(self, adapter_config): + def run_generate_test(self, adapter_config, max_new_tokens=32): if self.config_class not in ADAPTER_MODEL_MAPPING or ( "seq2seq_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types and "causal_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types ): self.skipTest("No seq2seq or causal language model head") - - model1 = AutoAdapterModel.from_config(self.config()) - model1.add_adapter("dummy", config=adapter_config) - if "seq2seq_lm" in ADAPTER_MODEL_MAPPING[self.config_class].head_types: - model1.add_seq2seq_lm_head("dummy") - else: - model1.add_causal_lm_head("dummy") - model1.set_active_adapters("dummy") - model1.to(torch_device) - - seq_output_length = 32 - - # Finally, also check if generation works properly - if self.is_speech_model: - input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] - input_ids = input_ids.to(torch_device) - generated = model1.generate(input_ids, max_length=seq_output_length) - self.assertLessEqual(generated.shape, (1, seq_output_length)) + model = self.get_model() + model.add_adapter("generate", config=adapter_config) + add_lm_head(self.config_class, model, "generate") + model.set_active_adapters("generate") + model.to(torch_device) + generate_input = self.build_generate_input(self.input_shape).to(torch_device) + generated = model.generate(generate_input, max_new_tokens=max_new_tokens) + self.assertLessEqual(generated.shape, (self.input_shape[0], self.input_shape[1] + max_new_tokens)) diff --git a/tests/test_methods/method_test_impl/composition/__init__.py b/tests/test_methods/method_test_impl/composition/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/composition/test_parallel.py b/tests/test_methods/method_test_impl/composition/test_parallel.py similarity index 83% rename from tests/composition/test_parallel.py rename to tests/test_methods/method_test_impl/composition/test_parallel.py index 8a15a9f1c5..8d3ebfccec 100644 --- a/tests/composition/test_parallel.py +++ b/tests/test_methods/method_test_impl/composition/test_parallel.py @@ -12,7 +12,7 @@ T5AdapterModel, ) from adapters.composition import BatchSplit, Parallel -from adapters.models.bert_generation.adapter_model import BertGenerationAdapterModel +from tests.test_methods.method_test_impl.utils import add_lm_head from transformers import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, Trainer, TrainingArguments from transformers.testing_utils import require_torch, torch_device @@ -116,7 +116,7 @@ def test_batch_split_with_heads(self): ) ) - def test_parallel_generate(self): + def test_parallel_generate(self, max_new_tokens=32): if self.config_class not in ADAPTER_MODEL_MAPPING or ( "seq2seq_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types and "causal_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types @@ -126,25 +126,13 @@ def test_parallel_generate(self): model1 = AutoAdapterModel.from_config(self.config()) model1.add_adapter("adapter1") model1.add_adapter("adapter2") - if "seq2seq_lm" in ADAPTER_MODEL_MAPPING[self.config_class].head_types: - model1.add_seq2seq_lm_head("adapter1") - model1.add_seq2seq_lm_head("adapter2") - else: - model1.add_causal_lm_head("adapter1") - model1.add_causal_lm_head("adapter2") + add_lm_head(self.config_class, model1, "adapter1") + add_lm_head(self.config_class, model1, "adapter2") model1.set_active_adapters(Parallel("adapter1", "adapter2")) model1.to(torch_device) - - seq_output_length = 32 - - # Finally, also check if generation works properly - if self.is_speech_model: - input_ids = self.get_input_samples((1, 80, 3000), config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"] - input_ids = input_ids.to(torch_device) - generated = model1.generate(input_ids, max_length=seq_output_length) - self.assertLessEqual(generated.shape, (2, seq_output_length)) + generate_input = self.build_generate_input(self.input_shape).to(torch_device) + generated = model1.generate(generate_input, max_new_tokens=max_new_tokens) + self.assertLessEqual(generated.shape, (self.input_shape[0] * 2, self.input_shape[1] + max_new_tokens)) class ParallelTrainingMixin: @@ -208,7 +196,7 @@ def run_parallel_training_test(self, adapter_config, filter_key): state_dict_pre = copy.deepcopy(model.state_dict()) - train_dataset = self.dataset() + train_dataset = self.get_dataset() training_args = TrainingArguments( output_dir="./examples", do_train=True, @@ -241,22 +229,7 @@ def run_parallel_training_equivalent_to_single(self, adapter_config): a1, a2 = self.create_twin_adapters(model, "a", adapter_config) b1, b2 = self.create_twin_adapters(model, "b", adapter_config) - dataset = [] - if self.is_speech_model: - dataset_batched = self.dataset() - dataset = [{} for _ in range(len(dataset_batched))] - # As this test uses a non-batched training, we need to wrap the samples by an additional dimension - for i in range(len(dataset_batched)): - for key, value in dataset_batched[i].items(): - dataset[i][key] = torch.unsqueeze(value, 0) - else: - for i in range(3): - input_data = self.get_input_samples(config=model.config) - if isinstance(model, BertGenerationAdapterModel): - input_data["labels"] = torch.randint(0, 2, (3, 64)) - else: - input_data["labels"] = torch.randint(0, 2, (3, 1)) - dataset.append(input_data) + dataset = self.get_dataset_non_batched(model.config) for adapter in [a1, b1]: model.active_head = adapter @@ -314,12 +287,7 @@ def test_parallel_training_single_forward_pass(self): input_data = self.get_input_samples( config=model.config, ) - if isinstance(model, BertGenerationAdapterModel): - input_data["labels"] = torch.randint(0, 2, (3, 64), device=torch_device) - elif self.is_speech_model: - input_data["labels"] = input_data["decoder_input_ids"] - else: - input_data["labels"] = torch.randint(0, 2, (3, 1), device=torch_device) + input_data = self.attach_labels(input_data) outputs = [] for adapter in [a1, b1]: diff --git a/tests/test_methods/method_test_impl/core/__init__.py b/tests/test_methods/method_test_impl/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_adapter_backward_compability.py b/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py similarity index 96% rename from tests/test_adapter_backward_compability.py rename to tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py index 6ec2ef2143..196380524f 100644 --- a/tests/test_adapter_backward_compability.py +++ b/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py @@ -3,7 +3,7 @@ import tempfile from adapters import SeqBnConfig, __version__ -from tests.methods import create_twin_models +from tests.test_methods.method_test_impl.utils import create_twin_models from transformers.testing_utils import require_torch diff --git a/tests/test_adapter_conversion.py b/tests/test_methods/method_test_impl/core/test_adapter_conversion.py similarity index 90% rename from tests/test_adapter_conversion.py rename to tests/test_methods/method_test_impl/core/test_adapter_conversion.py index 067b1b9665..cb285a33f4 100644 --- a/tests/test_adapter_conversion.py +++ b/tests/test_methods/method_test_impl/core/test_adapter_conversion.py @@ -103,30 +103,26 @@ def test_conversion_masked_lm_model(self): label_dict["decoder_input_ids"] = label_dict["labels"].clone() self.run_test(model, label_dict=label_dict) - def test_conversion_seq2seq_lm_model(self): - if ( - self.config_class not in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - and self.config_class not in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING - ): + def test_conversion_audio_seq2seq_lm_model(self): + if self.config_class not in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING: self.skipTest("No seq2seq language modeling class.") + label_dict = {} + model = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING[self.config_class](self.config()) + label_dict["input_features"] = torch.randn((self.input_shape), dtype=torch.float32, device=torch_device) + label_dict["decoder_input_ids"] = torch.randint( + 0, model.config.vocab_size, size=self.input_shape[:-1], device=torch_device + ) + label_dict["labels"] = label_dict["decoder_input_ids"] + adapters.init(model) + self.run_test(model, label_dict=label_dict) + def test_conversion_text_seq2seq_lm_model(self): + if self.config_class not in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: + self.skipTest("No seq2seq language modeling class.") label_dict = {} - if self.is_speech_model: - # speech models require input_features - model = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING[self.config_class](self.config()) - label_dict["input_features"] = torch.randn( - (self.default_input_samples_shape), dtype=torch.float32, device=torch_device - ) - label_dict["decoder_input_ids"] = torch.randint( - 0, model.config.vocab_size, size=self.default_input_samples_shape[:-1], device=torch_device - ) - label_dict["labels"] = label_dict["decoder_input_ids"] - else: - model = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[self.config_class](self.config()) - label_dict["labels"] = torch.zeros( - (self.batch_size, self.seq_length), dtype=torch.long, device=torch_device - ) - label_dict["decoder_input_ids"] = label_dict["labels"].clone() + model = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[self.config_class](self.config()) + label_dict["labels"] = torch.zeros((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device) + label_dict["decoder_input_ids"] = label_dict["labels"].clone() adapters.init(model) self.run_test(model, label_dict=label_dict) diff --git a/tests/test_adapter_fusion_common.py b/tests/test_methods/method_test_impl/core/test_adapter_fusion_common.py similarity index 98% rename from tests/test_adapter_fusion_common.py rename to tests/test_methods/method_test_impl/core/test_adapter_fusion_common.py index b8472483ee..754111d4a3 100644 --- a/tests/test_adapter_fusion_common.py +++ b/tests/test_methods/method_test_impl/core/test_adapter_fusion_common.py @@ -206,14 +206,14 @@ def test_output_adapter_fusion_attentions(self): model.set_active_adapters(Fuse("a", "b")) output_1 = model(**input_data, output_adapter_fusion_attentions=True) - self.assertEqual(len(output_1[0]), self.default_input_samples_shape[0]) + self.assertEqual(len(output_1[0]), self.input_shape[0]) self.assertTrue(hasattr(output_1, "adapter_fusion_attentions")) attention_scores = output_1.adapter_fusion_attentions["a,b"] self.assertEqual(len(list(model.iter_layers())), len(attention_scores)) for k, per_layer_scores in attention_scores.items(): self.assertEqual(len(per_layer_scores), 1) for k, v in per_layer_scores.items(): - self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k) + self.assertEqual(self.input_shape[0], v.shape[0], k) def test_add_adapter_fusion_custom_name(self): config_name = "seq_bn" diff --git a/tests/test_methods/method_test_impl/embeddings/__init__.py b/tests/test_methods/method_test_impl/embeddings/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_adapter_embeddings.py b/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py similarity index 88% rename from tests/test_adapter_embeddings.py rename to tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py index 0284b7c384..a41b862004 100644 --- a/tests/test_adapter_embeddings.py +++ b/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py @@ -47,7 +47,8 @@ def test_delete_embeddings(self): def test_save_load_embedding(self): model = self.get_model() - tokenizer, input_data = self._instantiate_tokenizer(model) + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + input_data = self.get_input_samples(config=self.config()) model.add_embeddings("test", tokenizer) model.eval() model.to(torch_device) @@ -65,12 +66,13 @@ def test_save_load_embedding(self): torch.equal(model.loaded_embeddings["test"].weight, model.loaded_embeddings["test_reloaded"].weight) ) self.assertTrue(torch.equal(output1[0], output2[0])) - self.assertEqual(tokenizer.vocab, tokenizer_ref.vocab) + self.assertEqual(tokenizer.get_vocab(), tokenizer_ref.get_vocab()) def test_back_to_default(self): model = self.get_model() model.eval() - tokenizer, input_data = self._instantiate_tokenizer(model) + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) + input_data = self.get_input_samples(config=self.config()) output1 = model(**input_data) model.add_embeddings("test", tokenizer) self.assertEqual(model.active_embeddings, "test") @@ -99,7 +101,7 @@ def test_training_embedding(self): state_dict_pre = copy.deepcopy(model.state_dict()) initial_embedding = model.get_input_embeddings().weight.clone() - train_dataset = self.dataset() + train_dataset = self.get_dataset() training_args = TrainingArguments( output_dir="./examples", do_train=True, @@ -174,14 +176,3 @@ def test_reference_embedding(self): # activate for training model.add_adapter("test") model.train_adapter("test", train_embeddings=True) - - def _instantiate_tokenizer(self, model): - """Depending on the model type, instantiate a tokenizer and input data. - Speech models require a different tokenizer and sample size.""" - if self.is_speech_model: - tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name) - input_data = self.get_input_samples(config=self.config()) - else: - tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT") - input_data = self.get_input_samples((1, 128), vocab_size=tokenizer.vocab_size, config=model.config) - return tokenizer, input_data diff --git a/tests/test_methods/method_test_impl/heads/__init__.py b/tests/test_methods/method_test_impl/heads/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_adapter_heads.py b/tests/test_methods/method_test_impl/heads/test_adapter_heads.py similarity index 93% rename from tests/test_adapter_heads.py rename to tests/test_methods/method_test_impl/heads/test_adapter_heads.py index df7a0ac7f8..4e54eaed33 100644 --- a/tests/test_adapter_heads.py +++ b/tests/test_methods/method_test_impl/heads/test_adapter_heads.py @@ -7,11 +7,10 @@ from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AutoAdapterModel from adapters.composition import BatchSplit, Stack from adapters.heads import PredictionHead +from tests.test_methods.method_test_impl.utils import create_twin_models from transformers import AutoModelForSequenceClassification from transformers.testing_utils import require_torch, torch_device -from .methods import create_twin_models - @require_torch class PredictionHeadModelTestMixin: @@ -21,10 +20,8 @@ def run_prediction_head_test( compare_model, head_name, input_shape=None, - output_shape=(1, 2), + output_shape=None, label_dict=None, - num_labels=None, - with_labels=False, ): # first, check if the head is actually correctly registered as part of the pt module self.assertTrue(f"heads.{head_name}" in dict(model.named_modules())) @@ -43,10 +40,8 @@ def run_prediction_head_test( # make a forward pass model.active_head = head_name - input_shape = input_shape if input_shape is not None else self._get_input_shape() - in_data = self.get_input_samples( - input_shape, config=model.config, num_labels=num_labels, with_labels=with_labels - ) + input_shape = input_shape if input_shape else self.input_shape + in_data = self.get_input_samples(shape=input_shape, config=model.config) if label_dict: for k, v in label_dict.items(): in_data[k] = v @@ -70,7 +65,9 @@ def test_classification_head(self): model1.add_classification_head("dummy") label_dict = {} label_dict["labels"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device) - self.run_prediction_head_test(model1, model2, "dummy", label_dict=label_dict) + self.run_prediction_head_test( + model1, model2, "dummy", label_dict=label_dict, output_shape=(self.batch_size, 2) + ) def test_image_classification_head(self): if "image_classification" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types: @@ -81,7 +78,14 @@ def test_image_classification_head(self): model1.add_image_classification_head("dummy") label_dict = {} label_dict["labels"] = torch.zeros(self.batch_size, dtype=torch.long, device=torch_device) - self.run_prediction_head_test(model1, model2, "dummy", input_shape=(1, 3, 224, 224), label_dict=label_dict) + self.run_prediction_head_test( + model1, + model2, + "dummy", + input_shape=self.input_shape, + label_dict=label_dict, + output_shape=(self.batch_size, 2), + ) def test_multiple_choice_head(self): if "multiple_choice" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types: @@ -93,7 +97,12 @@ def test_multiple_choice_head(self): label_dict = {} label_dict["labels"] = torch.ones(self.batch_size, dtype=torch.long, device=torch_device) self.run_prediction_head_test( - model1, model2, "dummy", input_shape=(self.batch_size, 2, self.seq_length), label_dict=label_dict + model1, + model2, + "dummy", + input_shape=(self.batch_size, 2, self.seq_length), + label_dict=label_dict, + output_shape=(self.batch_size, 2), ) def test_tagging_head(self): @@ -174,17 +183,14 @@ def test_seq2seq_lm_head(self): ) # Finally, also check if generation works properly - input_shape = self._get_input_shape() - if self.is_speech_model: - input_ids = self.get_input_samples(input_shape, config=model1.config)["input_features"] - else: - input_ids = self.get_input_samples(input_shape, config=model1.config)["input_ids"] + input_ids = self.extract_input_ids(self.get_input_samples(self.input_shape, config=model1.config)) + input_ids = input_ids.to(torch_device) # Use a different length for the seq2seq output seq_output_length = self.seq_length + 30 generated = model1.generate(input_ids, max_length=seq_output_length) self.assertTrue(generated.shape[1] <= seq_output_length) - self.assertEqual(generated.shape[0], 1) + self.assertEqual(generated.shape[0], self.input_shape[0]) def test_masked_lm_head(self): if "masked_lm" not in ADAPTER_MODEL_MAPPING[self.config_class].head_types: @@ -240,7 +246,11 @@ def test_dependency_parsing_head(self): (self.batch_size, self.seq_length), dtype=torch.long, device=torch_device ) self.run_prediction_head_test( - model1, model2, "dummy", output_shape=(1, self.seq_length, self.seq_length + 1, 2), label_dict=label_dict + model1, + model2, + "dummy", + output_shape=(self.batch_size, self.seq_length, self.seq_length + 1, 2), + label_dict=label_dict, ) def test_delete_head(self): @@ -424,8 +434,7 @@ def forward_pre_hook(module, input): self.assertIsNotNone(inv_adapter) inv_adapter.register_forward_pre_hook(forward_pre_hook) - input_shape = self._get_input_shape() - in_data = self.get_input_samples(input_shape, config=model.config) + in_data = self.get_input_samples(self.input_shape, config=model.config) model.to(torch_device) out = model(**in_data) @@ -474,14 +483,6 @@ def test_save_all_adapters_with_head(self): model.save_all_adapters(tmp_dir, with_head=False) self.assertFalse(os.path.isfile(os.path.join(tmp_dir, "test", "head_config.json"))) - def _get_input_shape(self): - # speech models require a different input dimensions compared to text models - if self.is_speech_model: - input_shape = (self.batch_size, self.seq_length, self.time_window) - else: - input_shape = (self.batch_size, self.seq_length) - return input_shape - def test_average_head(self): # Test the average_head method model = AutoAdapterModel.from_config(self.config()) diff --git a/tests/test_methods/method_test_impl/peft/__init__.py b/tests/test_methods/method_test_impl/peft/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/methods/test_adapter_common.py b/tests/test_methods/method_test_impl/peft/test_adapter_common.py similarity index 96% rename from tests/methods/test_adapter_common.py rename to tests/test_methods/method_test_impl/peft/test_adapter_common.py index 717d3af98e..53c9a1c7be 100644 --- a/tests/methods/test_adapter_common.py +++ b/tests/test_methods/method_test_impl/peft/test_adapter_common.py @@ -22,11 +22,11 @@ ) from adapters.heads.language_modeling import CausalLMHead from adapters.utils import SETUP_CONFIG_NAME +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin +from tests.test_methods.method_test_impl.utils import create_twin_models from transformers import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, CLIPConfig from transformers.testing_utils import require_torch, torch_device -from .base import AdapterMethodBaseTestMixin, create_twin_models - @require_torch class BottleneckAdapterTestMixin(AdapterMethodBaseTestMixin): @@ -134,12 +134,12 @@ def test_delete_adapter_with_invertible(self): model.set_active_adapters(name) # check if adapter is correctly added to config - self.assert_adapter_available(model, name) + self._assert_adapter_available(model, name) # remove the adapter again model.delete_adapter(name) # check if adapter is correctly removed from the model - self.assert_adapter_unavailable(model, name) + self._assert_adapter_unavailable(model, name) # check additionally if invertible adapter is removed correctly from the model self.assertFalse(name in model.invertible_adapters) @@ -148,7 +148,7 @@ def test_delete_adapter_with_invertible(self): # check that weights are available and active has_weights = False filter_keys = [k.format(name=name) for k in filter_keys] - for k, v in self.filter_parameters(model, filter_keys).items(): + for k, v in self._filter_parameters(model, filter_keys).items(): has_weights = True self.assertFalse(has_weights) @@ -390,13 +390,13 @@ def test_train_adapter_fusion(self): self.assertEqual(adapter_setup, model.active_adapters) # all weights of the adapters should be frozen (test for one) - for k, v in self.filter_parameters(model, ["adapters.a."]).items(): + for k, v in self._filter_parameters(model, ["adapters.a."]).items(): self.assertFalse(v.requires_grad, k) # all weights of the fusion layer should be activated - for k, v in self.filter_parameters(model, ["adapter_fusion_layer"]).items(): + for k, v in self._filter_parameters(model, ["adapter_fusion_layer"]).items(): self.assertTrue(v.requires_grad, k) # weights of the model should be frozen (check on some examples) - for k, v in self.filter_parameters(model, ["encoder.layer.0.attention"]).items(): + for k, v in self._filter_parameters(model, ["encoder.layer.0.attention"]).items(): self.assertFalse(v.requires_grad, k) state_dict_pre = copy.deepcopy(model.state_dict()) @@ -456,13 +456,13 @@ def test_batch_split_training(self): model.train_adapter(adapter_setup) # all weights of the adapter should be activated - for k, v in self.filter_parameters(model, ["adapters.mrpc1."]).items(): + for k, v in self._filter_parameters(model, ["adapters.mrpc1."]).items(): self.assertTrue(v.requires_grad, k) # all weights of the adapter not used for training should be frozen - for k, v in self.filter_parameters(model, ["adapters.mrpc2."]).items(): + for k, v in self._filter_parameters(model, ["adapters.mrpc2."]).items(): self.assertTrue(v.requires_grad, k) # weights of the model should be frozen (check on some examples) - for k, v in self.filter_parameters(model, ["encoder.layer.0.attention"]).items(): + for k, v in self._filter_parameters(model, ["encoder.layer.0.attention"]).items(): self.assertFalse(v.requires_grad, k) state_dict_pre = copy.deepcopy(model.state_dict()) diff --git a/tests/methods/test_compacter.py b/tests/test_methods/method_test_impl/peft/test_compacter.py similarity index 57% rename from tests/methods/test_compacter.py rename to tests/test_methods/method_test_impl/peft/test_compacter.py index 06b3a346e1..d2e97270ee 100644 --- a/tests/methods/test_compacter.py +++ b/tests/test_methods/method_test_impl/peft/test_compacter.py @@ -1,38 +1,36 @@ from adapters import CompacterPlusPlusConfig +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin from transformers.testing_utils import require_torch -from .base import AdapterMethodBaseTestMixin - @require_torch class CompacterTestMixin(AdapterMethodBaseTestMixin): + default_config = CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8) + def test_add_compacter(self): model = self.get_model() - self.run_add_test(model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8), ["adapters.{name}."]) + self.run_add_test(model, self.default_config, ["adapters.{name}."]) def test_leave_out_compacter(self): model = self.get_model() - self.run_leave_out_test(model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8), self.leave_out_layers) + self.run_leave_out_test(model, self.default_config, self.leave_out_layers) def test_linear_average_compacter(self): model = self.get_model() - self.run_linear_average_test( - model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8), ["adapters.{name}."] - ) + self.run_linear_average_test(model, self.default_config, ["adapters.{name}."]) def test_delete_compacter(self): model = self.get_model() - self.run_delete_test(model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8), ["adapters.{name}."]) + self.run_delete_test(model, self.default_config, ["adapters.{name}."]) def test_get_compacter(self): model = self.get_model() n_layers = len(list(model.iter_layers())) - self.run_get_test(model, CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8), n_layers + 1) + self.run_get_test(model, self.default_config, n_layers + 1) def test_forward_compacter(self): model = self.get_model() - adapter_config = CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8) - self.run_forward_test(model, adapter_config) + self.run_forward_test(model, self.default_config) def test_forward_shared_phm_compacter(self): model = self.get_model() @@ -40,7 +38,7 @@ def test_forward_shared_phm_compacter(self): self.run_forward_test(model, adapter_config) def test_load_compacter(self): - self.run_load_test(CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8)) + self.run_load_test(self.default_config) def test_train_shared_w_compacter(self): adapter_config = CompacterPlusPlusConfig( @@ -49,8 +47,7 @@ def test_train_shared_w_compacter(self): self.run_train_test(adapter_config, ["adapters.{name}."]) def test_train_shared_phm_compacter(self): - adapter_config = CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8) - self.run_train_test(adapter_config, ["adapters.{name}."]) + self.run_train_test(self.default_config, ["adapters.{name}."]) def test_compacter_generate(self): self.run_generate_test(CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8)) diff --git a/tests/methods/test_config_union.py b/tests/test_methods/method_test_impl/peft/test_config_union.py similarity index 59% rename from tests/methods/test_config_union.py rename to tests/test_methods/method_test_impl/peft/test_config_union.py index 12d82a5def..78d06158f7 100644 --- a/tests/methods/test_config_union.py +++ b/tests/test_methods/method_test_impl/peft/test_config_union.py @@ -6,7 +6,7 @@ PrefixTuningConfig, SeqBnConfig, ) -from tests.methods.base import AdapterMethodBaseTestMixin +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin from transformers.testing_utils import require_torch @@ -22,15 +22,17 @@ class ConfigUnionAdapterTest(AdapterMethodBaseTestMixin): ), ( ConfigUnion( - CompacterConfig(), - LoRAConfig(), + CompacterConfig( + reduction_factor=8 + ), # set to smaller value than default due to smaller hidden size of test models + LoRAConfig(init_weights="bert"), # set to bert to avoid zero initialization ), ["adapters.{name}.", "loras.{name}."], ), ( ConfigUnion( - SeqBnConfig(), - LoRAConfig(), + SeqBnConfig(phm_dim=1), + LoRAConfig(init_weights="bert"), # set to bert to avoid zero initialization ), ["adapters.{name}.", "loras.{name}."], ), @@ -39,15 +41,23 @@ class ConfigUnionAdapterTest(AdapterMethodBaseTestMixin): def test_add_union_adapter(self): model = self.get_model() model.eval() - for adapter_config, filter_keys in self.adapter_configs_to_test: - with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__): + config = ( + "ConfigUnion: " + + adapter_config.configs[0].__class__.__name__ + + adapter_config.configs[1].__class__.__name__ + ) + with self.subTest(model_class=model.__class__.__name__, config=config): self.run_add_test(model, adapter_config, filter_keys) def test_union_adapter_forward(self): model = self.get_model() model.eval() - for adapter_config, _ in self.adapter_configs_to_test: - with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__): + config = ( + "ConfigUnion: " + + adapter_config.configs[0].__class__.__name__ + + adapter_config.configs[1].__class__.__name__ + ) + with self.subTest(model_class=model.__class__.__name__, config=config): self.run_forward_test(model, adapter_config) diff --git a/tests/methods/test_ia3.py b/tests/test_methods/method_test_impl/peft/test_ia3.py similarity index 95% rename from tests/methods/test_ia3.py rename to tests/test_methods/method_test_impl/peft/test_ia3.py index b96dbcd02a..6c273f1b7d 100644 --- a/tests/methods/test_ia3.py +++ b/tests/test_methods/method_test_impl/peft/test_ia3.py @@ -1,8 +1,7 @@ from adapters import IA3Config +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin from transformers.testing_utils import require_torch -from .base import AdapterMethodBaseTestMixin - @require_torch class IA3TestMixin(AdapterMethodBaseTestMixin): diff --git a/tests/methods/test_lora.py b/tests/test_methods/method_test_impl/peft/test_lora.py similarity index 96% rename from tests/methods/test_lora.py rename to tests/test_methods/method_test_impl/peft/test_lora.py index e1ced5188a..8d47ae6242 100644 --- a/tests/methods/test_lora.py +++ b/tests/test_methods/method_test_impl/peft/test_lora.py @@ -4,10 +4,9 @@ from adapters import LoRAConfig from adapters.methods.lora import LoRALayer +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin from transformers.testing_utils import require_torch -from .base import AdapterMethodBaseTestMixin - @require_torch class LoRATestMixin(AdapterMethodBaseTestMixin): @@ -69,7 +68,7 @@ def test_linear_average_only_negate_b_lora(self): averaged_weights = {} for i, w in enumerate(weights): this_filter_keys = [k.format(name=f"{name}_{i}") for k in ["loras.{name}."]] - for k, v in self.filter_parameters(model, this_filter_keys).items(): + for k, v in self._filter_parameters(model, this_filter_keys).items(): base_k = k.replace(f"{name}_{i}", name) # Only negate the lora_B weights and use the absolute value of the weight for lora_A weights. weight = abs(w) if "lora_A" in k else w @@ -93,7 +92,7 @@ def test_linear_average_only_negate_b_lora(self): # compare averaged weights to collected weights this_filter_keys = [k.format(name=name) for k in ["loras.{name}."]] - for k, v in self.filter_parameters(model, this_filter_keys).items(): + for k, v in self._filter_parameters(model, this_filter_keys).items(): self.assertTrue(torch.allclose(v, averaged_weights[k]), k) def _check_svd_weights(self, delta_w, merged_lora, svd_rank, atol=1e-5): @@ -195,7 +194,7 @@ def test_edge_case_average_adapters_single_adapter(self): # collect weights of the first adapter so we can compare them to the newly created adapters in the subsequent tests filter_keys_adapter_0 = [k.format(name=f"{name}_0") for k in ["loras.{name}."]] - adapter_0 = self.filter_parameters(model, filter_keys_adapter_0) + adapter_0 = self._filter_parameters(model, filter_keys_adapter_0) # Run tests for every combine strategy for combine_strategy in ["linear", "lora_linear_only_negate_b", "lora_delta_w_svd"]: @@ -215,7 +214,7 @@ def test_edge_case_average_adapters_single_adapter(self): filter_keys = [k.format(name=f"{combine_strategy}_merged") for k in ["loras.{name}."]] if combine_strategy != "lora_delta_w_svd": - for k, v in self.filter_parameters(model, filter_keys).items(): + for k, v in self._filter_parameters(model, filter_keys).items(): adapter_0_key = k.replace(f"{combine_strategy}_merged", f"{name}_0") self.assertTrue(torch.allclose(v, adapter_0[adapter_0_key])) else: @@ -247,7 +246,7 @@ def test_edge_case_average_adapters_multiple_adapters(self): # collect weights of the first adapter so we can compare them to the newly created adapters in the subsequent tests filter_keys_adapter_0 = [k.format(name=f"{name}_0") for k in ["loras.{name}."]] - adapter_0 = self.filter_parameters(model, filter_keys_adapter_0) + adapter_0 = self._filter_parameters(model, filter_keys_adapter_0) # Run tests for every combine strategy for combine_strategy in ["linear", "lora_linear_only_negate_b", "lora_delta_w_svd"]: @@ -269,7 +268,7 @@ def test_edge_case_average_adapters_multiple_adapters(self): filter_keys = [k.format(name=f"{combine_strategy}_merged") for k in ["loras.{name}."]] if combine_strategy != "lora_delta_w_svd": - for k, v in self.filter_parameters(model, filter_keys).items(): + for k, v in self._filter_parameters(model, filter_keys).items(): adapter_1_key = k.replace(f"{combine_strategy}_merged", f"{name}_0") self.assertTrue(torch.allclose(v, adapter_0[adapter_1_key])) else: diff --git a/tests/methods/test_prefix_tuning.py b/tests/test_methods/method_test_impl/peft/test_prefix_tuning.py similarity index 97% rename from tests/methods/test_prefix_tuning.py rename to tests/test_methods/method_test_impl/peft/test_prefix_tuning.py index 35906edb19..1a883a817c 100644 --- a/tests/methods/test_prefix_tuning.py +++ b/tests/test_methods/method_test_impl/peft/test_prefix_tuning.py @@ -1,11 +1,10 @@ import torch from adapters import PrefixTuningConfig +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin from transformers import CLIPConfig from transformers.testing_utils import require_torch, torch_device -from .base import AdapterMethodBaseTestMixin - @require_torch class PrefixTuningTestMixin(AdapterMethodBaseTestMixin): diff --git a/tests/methods/test_prompt_tuning.py b/tests/test_methods/method_test_impl/peft/test_prompt_tuning.py similarity index 95% rename from tests/methods/test_prompt_tuning.py rename to tests/test_methods/method_test_impl/peft/test_prompt_tuning.py index f2fd1b0345..2856ca7d12 100644 --- a/tests/methods/test_prompt_tuning.py +++ b/tests/test_methods/method_test_impl/peft/test_prompt_tuning.py @@ -1,8 +1,7 @@ from adapters import PromptTuningConfig +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin from transformers.testing_utils import require_torch -from .base import AdapterMethodBaseTestMixin - @require_torch class PromptTuningTestMixin(AdapterMethodBaseTestMixin): diff --git a/tests/methods/test_reft.py b/tests/test_methods/method_test_impl/peft/test_reft.py similarity index 96% rename from tests/methods/test_reft.py rename to tests/test_methods/method_test_impl/peft/test_reft.py index 76d0980d57..a458aaa2b9 100644 --- a/tests/methods/test_reft.py +++ b/tests/test_methods/method_test_impl/peft/test_reft.py @@ -1,8 +1,7 @@ from adapters import DiReftConfig, LoReftConfig, NoReftConfig +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin from transformers.testing_utils import require_torch -from .base import AdapterMethodBaseTestMixin - @require_torch class ReftTestMixin(AdapterMethodBaseTestMixin): @@ -32,7 +31,7 @@ def test_layers_reft(self): model.set_active_adapters(name) # adapter is correctly added to config - self.assert_adapter_available(model, name) + self._assert_adapter_available(model, name) adapter = model.get_adapter(name) diff --git a/tests/methods/test_unipelt.py b/tests/test_methods/method_test_impl/peft/test_unipelt.py similarity index 92% rename from tests/methods/test_unipelt.py rename to tests/test_methods/method_test_impl/peft/test_unipelt.py index b855670ab4..6617568490 100644 --- a/tests/methods/test_unipelt.py +++ b/tests/test_methods/method_test_impl/peft/test_unipelt.py @@ -1,8 +1,7 @@ from adapters import UniPELTConfig +from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin from transformers.testing_utils import require_torch, torch_device -from .base import AdapterMethodBaseTestMixin - @require_torch class UniPELTTestMixin(AdapterMethodBaseTestMixin): @@ -56,14 +55,14 @@ def test_output_adapter_gating_scores_unipelt(self): model.set_active_adapters(name) output_1 = model(**input_data, output_adapter_gating_scores=True) - self.assertEqual(len(output_1[0]), self.default_input_samples_shape[0]) + self.assertEqual(len(output_1[0]), self.input_shape[0]) self.assertTrue(hasattr(output_1, "adapter_gating_scores")) gating_scores = output_1.adapter_gating_scores[name] self.assertEqual(len(list(model.iter_layers())), len(gating_scores)) for k, per_layer_scores in gating_scores.items(): self.assertGreaterEqual(len(per_layer_scores), 3) for k, v in per_layer_scores.items(): - self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k) + self.assertEqual(self.input_shape[0], v.shape[0], k) def test_unipelt_gradient_checkpointing_single_adapter(self): self.run_gradient_checkpointing_single_adapter_test(UniPELTConfig()) diff --git a/tests/test_methods/method_test_impl/utils.py b/tests/test_methods/method_test_impl/utils.py new file mode 100644 index 0000000000..473c422e60 --- /dev/null +++ b/tests/test_methods/method_test_impl/utils.py @@ -0,0 +1,48 @@ +import copy +import random + +import torch + +from adapters import ADAPTER_MODEL_MAPPING, init +from transformers.testing_utils import torch_device + + +global_rng = random.Random() + + +def create_twin_models(model_class, config_creator=None): + if config_creator and model_class.__name__.startswith("Auto"): + model_config = config_creator() + model1 = model_class.from_config(model_config) + elif config_creator: + model_config = config_creator() + model1 = model_class(model_config) + else: + model_config = model_class.config_class() + model1 = model_class(model_config) + init(model1) + model1.eval() + # create a twin initialized with the same random weights + model2 = copy.deepcopy(model1) + model2.eval() + return model1, model2 + + +def add_lm_head(config_class, model, adapter_name): + """Add appropriate language model head based on model type""" + if "seq2seq_lm" in ADAPTER_MODEL_MAPPING[config_class].head_types: + model.add_seq2seq_lm_head(adapter_name) + else: + model.add_causal_lm_head(adapter_name) + + +def make_config(config_class, **kwargs): + return staticmethod(lambda: config_class(**kwargs)) + + +def ids_tensor(shape, vocab_size=5000, dtype=torch.long): + total_dims = 1 + for dim in shape: + total_dims *= dim + values = [global_rng.randint(0, vocab_size - 1) for _ in range(total_dims)] + return torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous() diff --git a/tests/test_methods/test_on_albert.py b/tests/test_methods/test_on_albert.py new file mode 100644 index 0000000000..0b271e32b8 --- /dev/null +++ b/tests/test_methods/test_on_albert.py @@ -0,0 +1,48 @@ +import unittest +from math import ceil + +import pytest + +from transformers import AlbertConfig +from transformers.testing_utils import require_torch + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.heads.test_adapter_heads import PredictionHeadModelTestMixin +from .method_test_impl.utils import make_config + + +class AlbertAdapterTestBase(TextAdapterTestBase): + """Model configuration for testing methods on Albert.""" + + config_class = AlbertConfig + config = make_config( + AlbertConfig, + embedding_size=16, + hidden_size=64, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + num_hidden_groups=2, + ) + tokenizer_name = "albert-base-v2" + leave_out_layers = [0] + + +method_tests = generate_method_tests(AlbertAdapterTestBase, not_supported=["Heads"]) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class + + +@require_torch +@pytest.mark.heads +class Heads( + AlbertAdapterTestBase, + PredictionHeadModelTestMixin, + unittest.TestCase, +): + + def test_context_simple(self): + expected_number_of_adapter_calls = ceil(self.config().num_hidden_layers / self.config().num_hidden_groups) + super().test_context_simple(expected_number_of_adapter_calls=expected_number_of_adapter_calls) diff --git a/tests/test_methods/test_on_bart.py b/tests/test_methods/test_on_bart.py new file mode 100644 index 0000000000..558c3527c8 --- /dev/null +++ b/tests/test_methods/test_on_bart.py @@ -0,0 +1,28 @@ +from transformers import BartConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class BartAdapterTestBase(TextAdapterTestBase): + config_class = BartConfig + config = make_config( + BartConfig, + d_model=16, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + ) + tokenizer_name = "facebook/bart-base" + + +method_tests = generate_method_tests( + BartAdapterTestBase, not_supported=["PromptTuning"], redundant=["ConfigUnion", "Embeddings"] +) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_beit.py b/tests/test_methods/test_on_beit.py new file mode 100644 index 0000000000..a19e7ce73d --- /dev/null +++ b/tests/test_methods/test_on_beit.py @@ -0,0 +1,24 @@ +from transformers import BeitConfig + +from .base import VisionAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class BeitAdapterTestBase(VisionAdapterTestBase): + config_class = BeitConfig + config = make_config( + BeitConfig, + image_size=224, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ) + feature_extractor_name = "microsoft/beit-base-patch16-224-pt22k" + + +method_tests = generate_method_tests(BeitAdapterTestBase, not_supported=["Composition", "Embeddings"]) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_bert.py b/tests/test_methods/test_on_bert.py new file mode 100644 index 0000000000..814c6545a4 --- /dev/null +++ b/tests/test_methods/test_on_bert.py @@ -0,0 +1,23 @@ +from transformers import BertConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class BertAdapterTestBase(TextAdapterTestBase): + config_class = BertConfig + config = make_config( + BertConfig, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ) + tokenizer_name = "bert-base-uncased" + + +method_tests = generate_method_tests(BertAdapterTestBase) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_bert_generation.py b/tests/test_methods/test_on_bert_generation.py similarity index 62% rename from tests/test_bert_generation.py rename to tests/test_methods/test_on_bert_generation.py index 48fe3e7b40..a4e862963b 100644 --- a/tests/test_bert_generation.py +++ b/tests/test_methods/test_on_bert_generation.py @@ -1,21 +1,15 @@ -import unittest - +import torch from datasets import load_dataset from transformers import AutoTokenizer, BertGenerationConfig -from transformers.testing_utils import require_torch +from transformers.testing_utils import torch_device -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config -class BertGenerationAdapterTestBase(AdapterTestBase): +class BertGenerationAdapterTestBase(TextAdapterTestBase): config_class = BertGenerationConfig config = make_config( BertGenerationConfig, @@ -28,9 +22,9 @@ class BertGenerationAdapterTestBase(AdapterTestBase): def add_head(self, model, name, **kwargs): model.add_masked_lm_head(name) - return self.default_input_samples_shape[-1] + return self.input_shape[-1] - def dataset(self, tokenizer=None): + def get_dataset(self, tokenizer=None): # setup tokenizer if tokenizer is None: tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False) @@ -69,26 +63,20 @@ def preprocess_function(examples): ) return train_dataset + def get_dataset_non_batched(self, config): + dataset = [] + for i in range(3): + input_data = self.get_input_samples(config=config) + input_data = self.attach_labels(input_data) + dataset.append(input_data) + return dataset + + def attach_labels(self, inputs): + inputs["labels"] = torch.randint(0, 2, (self.batch_size, 64), device=torch_device) + return inputs + + +method_tests = generate_method_tests(BertGenerationAdapterTestBase) -@require_torch -class BertGenerationAdapterTest( - AllMethodsTestMixin, - EmbeddingTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - BertGenerationAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class BertGenerationClassConversionTest( - ModelClassConversionTestMixin, - BertGenerationAdapterTestBase, - unittest.TestCase, -): - pass +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_clip/test_model.py b/tests/test_methods/test_on_clip/test_model.py new file mode 100644 index 0000000000..6bff937bc2 --- /dev/null +++ b/tests/test_methods/test_on_clip/test_model.py @@ -0,0 +1,83 @@ +import random + +import torch + +from tests.test_methods.base import TextAdapterTestBase +from tests.test_methods.generator import generate_method_tests +from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig +from transformers.testing_utils import torch_device + + +class CLIPAdapterTestBase(TextAdapterTestBase): + config_class = CLIPConfig + config = staticmethod( + lambda: CLIPConfig.from_text_vision_configs( + CLIPTextConfig( + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ), + CLIPVisionConfig( + image_size=224, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ), + ) + ) + tokenizer_name = "openai/clip-vit-base-patch32" + # Default shape of inputs to use + default_text_input_samples_shape = (3, 64) + default_vision_input_samples_shape = (3, 3, 224, 224) + do_run_train_tests = False + + def get_input_samples(self, vocab_size=5000, config=None, dtype=torch.float, **kwargs): + # text inputs + shape = self.default_text_input_samples_shape + total_dims = 1 + for dim in shape: + total_dims *= dim + values = [] + for _ in range(total_dims): + values.append(random.randint(0, vocab_size - 1)) + input_ids = torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous() + # this is needed e.g. for BART + if config and config.eos_token_id is not None and config.eos_token_id < vocab_size: + input_ids[input_ids == config.eos_token_id] = random.randint(0, config.eos_token_id - 1) + input_ids[:, -1] = config.eos_token_id + in_data = {"input_ids": input_ids} + + # vision inputs + shape = self.default_vision_input_samples_shape + total_dims = 1 + for dim in shape: + total_dims *= dim + values = [] + for _ in range(total_dims): + values.append(random.random()) + pixel_values = torch.tensor(data=values, dtype=dtype, device=torch_device).view(shape).contiguous() + in_data["pixel_values"] = pixel_values + + return in_data + + def add_head(self, *args, **kwargs): + pass + + def test_adapter_fusion_save_with_head(self): + # This test is not applicable to CLIP + self.skipTest("Not applicable to CLIP.") + + def test_load_adapter_setup(self): + self.skipTest("Not applicable to CLIP.") + + +method_tests = generate_method_tests( + model_test_base=CLIPAdapterTestBase, + not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"], +) + + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_clip/test_textmodel.py b/tests/test_methods/test_on_clip/test_textmodel.py new file mode 100644 index 0000000000..3b9f505389 --- /dev/null +++ b/tests/test_methods/test_on_clip/test_textmodel.py @@ -0,0 +1,27 @@ +from tests.test_methods.base import TextAdapterTestBase +from tests.test_methods.generator import generate_method_tests +from tests.test_methods.method_test_impl.utils import make_config +from transformers import CLIPTextConfig, CLIPTextModel + + +class CLIPTextAdapterTestBase(TextAdapterTestBase): + model_class = CLIPTextModel + config_class = CLIPTextConfig + config = make_config( + CLIPTextConfig, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ) + tokenizer_name = "openai/clip-vit-base-patch32" + + +method_tests = generate_method_tests( + model_test_base=CLIPTextAdapterTestBase, + not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"], +) + + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_clip/test_textwithprojectionmodel.py b/tests/test_methods/test_on_clip/test_textwithprojectionmodel.py new file mode 100644 index 0000000000..438f61c52a --- /dev/null +++ b/tests/test_methods/test_on_clip/test_textwithprojectionmodel.py @@ -0,0 +1,27 @@ +from tests.test_methods.base import TextAdapterTestBase +from tests.test_methods.generator import generate_method_tests +from tests.test_methods.method_test_impl.utils import make_config +from transformers import CLIPTextConfig, CLIPTextModelWithProjection + + +class CLIPTextWithProjectionAdapterTestBase(TextAdapterTestBase): + model_class = CLIPTextModelWithProjection + config_class = CLIPTextConfig + config = make_config( + CLIPTextConfig, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ) + tokenizer_name = "openai/clip-vit-base-patch32" + + +method_tests = generate_method_tests( + model_test_base=CLIPTextWithProjectionAdapterTestBase, + not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"], +) + + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_clip/test_visionmodel.py b/tests/test_methods/test_on_clip/test_visionmodel.py new file mode 100644 index 0000000000..83c25a41f9 --- /dev/null +++ b/tests/test_methods/test_on_clip/test_visionmodel.py @@ -0,0 +1,28 @@ +from tests.test_methods.base import VisionAdapterTestBase +from tests.test_methods.generator import generate_method_tests +from tests.test_methods.method_test_impl.utils import make_config +from transformers import CLIPVisionConfig, CLIPVisionModel + + +class CLIPVisionAdapterTestBase(VisionAdapterTestBase): + model_class = CLIPVisionModel + config_class = CLIPVisionConfig + config = make_config( + CLIPVisionConfig, + image_size=224, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ) + feature_extractor_name = "openai/clip-vit-base-patch32" + + +method_tests = generate_method_tests( + model_test_base=CLIPVisionAdapterTestBase, + not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"], +) + + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_clip/test_visionwithprojectionmodel.py b/tests/test_methods/test_on_clip/test_visionwithprojectionmodel.py new file mode 100644 index 0000000000..17ffbbe305 --- /dev/null +++ b/tests/test_methods/test_on_clip/test_visionwithprojectionmodel.py @@ -0,0 +1,28 @@ +from tests.test_methods.base import VisionAdapterTestBase +from tests.test_methods.generator import generate_method_tests +from tests.test_methods.method_test_impl.utils import make_config +from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection + + +class CLIPVisionWithProjectionAdapterTestBase(VisionAdapterTestBase): + model_class = CLIPVisionModelWithProjection + config_class = CLIPVisionConfig + config = make_config( + CLIPVisionConfig, + image_size=224, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ) + feature_extractor_name = "openai/clip-vit-base-patch32" + + +method_tests = generate_method_tests( + model_test_base=CLIPVisionWithProjectionAdapterTestBase, + not_supported=["Embeddings", "Heads", "Composition", "ClassConversion", "PromptTuning", "ConfigUnion"], +) + + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_deberta.py b/tests/test_methods/test_on_deberta.py new file mode 100644 index 0000000000..feb430f535 --- /dev/null +++ b/tests/test_methods/test_on_deberta.py @@ -0,0 +1,29 @@ +from transformers import DebertaConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class DebertaAdapterTestBase(TextAdapterTestBase): + config_class = DebertaConfig + config = make_config( + DebertaConfig, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + relative_attention=True, + pos_att_type="p2c|c2p", + ) + tokenizer_name = "microsoft/deberta-base" + + def test_parallel_training_lora(self): + self.skipTest("Not supported for DeBERTa") + + +method_tests = generate_method_tests(DebertaAdapterTestBase) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_debertaV2.py b/tests/test_methods/test_on_debertaV2.py new file mode 100644 index 0000000000..91ff684036 --- /dev/null +++ b/tests/test_methods/test_on_debertaV2.py @@ -0,0 +1,26 @@ +from transformers import DebertaV2Config + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class DebertaV2AdapterTestBase(TextAdapterTestBase): + config_class = DebertaV2Config + config = make_config( + DebertaV2Config, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + relative_attention=True, + pos_att_type="p2c|c2p", + ) + tokenizer_name = "microsoft/deberta-v3-base" + + +method_tests = generate_method_tests(DebertaV2AdapterTestBase) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_distilbert.py b/tests/test_methods/test_on_distilbert.py new file mode 100644 index 0000000000..961b0fbda8 --- /dev/null +++ b/tests/test_methods/test_on_distilbert.py @@ -0,0 +1,23 @@ +from transformers import DistilBertConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class DistilBertAdapterTestBase(TextAdapterTestBase): + config_class = DistilBertConfig + config = make_config( + DistilBertConfig, + dim=32, + n_layers=4, + n_heads=4, + hidden_dim=37, + ) + tokenizer_name = "distilbert-base-uncased" + + +method_tests = generate_method_tests(DistilBertAdapterTestBase) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_electra.py b/tests/test_methods/test_on_electra.py new file mode 100644 index 0000000000..05fa5d47d6 --- /dev/null +++ b/tests/test_methods/test_on_electra.py @@ -0,0 +1,24 @@ +from transformers import ElectraConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class ElectraAdapterTestBase(TextAdapterTestBase): + config_class = ElectraConfig + config = make_config( + ElectraConfig, + # vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + ) + tokenizer_name = "google/electra-base-generator" + + +method_tests = generate_method_tests(ElectraAdapterTestBase) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_encoder_decoder.py b/tests/test_methods/test_on_encoder_decoder.py similarity index 73% rename from tests/test_encoder_decoder.py rename to tests/test_methods/test_on_encoder_decoder.py index 708a6bfbb2..c20c8f5d35 100644 --- a/tests/test_encoder_decoder.py +++ b/tests/test_methods/test_on_encoder_decoder.py @@ -1,24 +1,11 @@ -# flake8: noqa: F403,F405 -import unittest +from adapters import init +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BertConfig, EncoderDecoderConfig, EncoderDecoderModel -import adapters -from hf_transformers.tests.models.encoder_decoder.test_modeling_encoder_decoder import * # Imported to execute model tests -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BertConfig -from transformers.testing_utils import require_torch, torch_device - -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase -from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .base import TextAdapterTestBase +from .generator import generate_method_tests -class EncoderDecoderAdapterTestBase(AdapterTestBase): +class EncoderDecoderAdapterTestBase(TextAdapterTestBase): model_class = EncoderDecoderModel config_class = EncoderDecoderConfig config = staticmethod( @@ -42,22 +29,9 @@ class EncoderDecoderAdapterTestBase(AdapterTestBase): tokenizer_name = "bert-base-uncased" do_run_train_tests = False - -@require_torch -class EncoderDecoderAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - EncoderDecoderAdapterTestBase, - unittest.TestCase, -): def test_generation(self): model = AutoModelForSeq2SeqLM.from_config(self.config()) - adapters.init(model) + init(model) model.add_adapter("test", config="pfeiffer") model.set_active_adapters("test") tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False) @@ -88,3 +62,12 @@ def test_output_adapter_gating_scores_unipelt(self): def test_output_adapter_fusion_attentions(self): # TODO currently not supported self.skipTest("Not implemented.") + + +test_methods = generate_method_tests( + EncoderDecoderAdapterTestBase, + not_supported=["Heads", "ConfigUnion", "Embeddings", "Composition", "PromptTuning", "ClassConversion"], +) + +for test_class_name, test_class in test_methods.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_gpt2.py b/tests/test_methods/test_on_gpt2.py new file mode 100644 index 0000000000..afce4453c5 --- /dev/null +++ b/tests/test_methods/test_on_gpt2.py @@ -0,0 +1,27 @@ +from transformers import GPT2Config + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class GPT2AdapterTestBase(TextAdapterTestBase): + config_class = GPT2Config + config = make_config( + GPT2Config, + n_embd=32, + n_layer=4, + n_head=4, + # set pad token to eos token + pad_token_id=50256, + ) + tokenizer_name = "gpt2" + + def test_parallel_training_lora(self): + self.skipTest("Not supported for GPT2") + + +method_tests = generate_method_tests(GPT2AdapterTestBase, not_supported=["PromptTuning"]) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_llama.py b/tests/test_methods/test_on_llama.py new file mode 100644 index 0000000000..318f5f220e --- /dev/null +++ b/tests/test_methods/test_on_llama.py @@ -0,0 +1,39 @@ +import unittest + +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.testing_utils import require_torch + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.core.test_adapter_conversion import ModelClassConversionTestMixin +from .method_test_impl.utils import make_config + + +class LlamaAdapterTestBase(TextAdapterTestBase): + config_class = LlamaConfig + config = make_config( + LlamaConfig, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + pad_token_id=0, + ) + tokenizer_name = "openlm-research/open_llama_13b" + + +method_tests = generate_method_tests(LlamaAdapterTestBase, not_supported=["PromptTuning"]) + +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class + + +@require_torch +class ClassConversion( + ModelClassConversionTestMixin, + LlamaAdapterTestBase, + unittest.TestCase, +): + def test_conversion_question_answering_model(self): + raise self.skipTest("We don't support the Llama QA model.") diff --git a/tests/test_methods/test_on_mbart.py b/tests/test_methods/test_on_mbart.py new file mode 100644 index 0000000000..d6dc2e8340 --- /dev/null +++ b/tests/test_methods/test_on_mbart.py @@ -0,0 +1,31 @@ +from transformers import MBartConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class MBartAdapterTestBase(TextAdapterTestBase): + config_class = MBartConfig + config = make_config( + MBartConfig, + d_model=16, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + vocab_size=250027, + ) + tokenizer_name = "facebook/mbart-large-cc25" + + def test_parallel_training_lora(self): + self.skipTest("Not supported for MBart") + + +method_tests = generate_method_tests( + MBartAdapterTestBase, redundant=["ConfigUnion", "Embeddings"], not_supported=["PromptTuning"] +) +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_mistral.py b/tests/test_methods/test_on_mistral.py new file mode 100644 index 0000000000..94ef7721e7 --- /dev/null +++ b/tests/test_methods/test_on_mistral.py @@ -0,0 +1,26 @@ +from transformers.models.mistral.configuration_mistral import MistralConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class MistralAdapterTestBase(TextAdapterTestBase): + config_class = MistralConfig + config = make_config( + MistralConfig, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=8, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + pad_token_id=0, + ) + tokenizer_name = "HuggingFaceH4/zephyr-7b-beta" + + +test_methods = generate_method_tests(MistralAdapterTestBase, not_supported=["PromptTuning", "ConfigUnion"]) + +for test_class_name, test_class in test_methods.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_mt5.py b/tests/test_methods/test_on_mt5.py new file mode 100644 index 0000000000..8f256b0799 --- /dev/null +++ b/tests/test_methods/test_on_mt5.py @@ -0,0 +1,29 @@ +from transformers import MT5Config +from transformers.testing_utils import require_torch + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +@require_torch +class MT5AdapterTestBase(TextAdapterTestBase): + config_class = MT5Config + config = make_config( + MT5Config, + d_model=16, + num_layers=2, + num_decoder_layers=2, + num_heads=4, + d_ff=4, + d_kv=16 // 4, + tie_word_embeddings=False, + decoder_start_token_id=0, + ) + tokenizer_name = "google/mt5-base" + + +method_tests = generate_method_tests(MT5AdapterTestBase, not_supported=["PromptTuning", "ConfigUnion"]) + +for test_name, test_class in method_tests.items(): + globals()[test_name] = test_class diff --git a/tests/test_methods/test_on_plbart.py b/tests/test_methods/test_on_plbart.py new file mode 100644 index 0000000000..67f7091e0b --- /dev/null +++ b/tests/test_methods/test_on_plbart.py @@ -0,0 +1,27 @@ +from transformers import PLBartConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class PLBartAdapterTestBase(TextAdapterTestBase): + config_class = PLBartConfig + config = make_config( + PLBartConfig, + d_model=32, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + scale_embedding=False, # Required for embedding tests + ) + tokenizer_name = "uclanlp/plbart-base" + + +method_tests = generate_method_tests(PLBartAdapterTestBase, not_supported=["PromptTuning"]) + +for test_name, test_class in method_tests.items(): + globals()[test_name] = test_class diff --git a/tests/test_methods/test_on_roberta.py b/tests/test_methods/test_on_roberta.py new file mode 100644 index 0000000000..1917e8a2d0 --- /dev/null +++ b/tests/test_methods/test_on_roberta.py @@ -0,0 +1,24 @@ +from transformers import RobertaConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class RobertaAdapterTestBase(TextAdapterTestBase): + config_class = RobertaConfig + config = make_config( + RobertaConfig, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + vocab_size=50265, + ) + tokenizer_name = "roberta-base" + + +method_tests = generate_method_tests(RobertaAdapterTestBase) + +for test_name, test_class in method_tests.items(): + globals()[test_name] = test_class diff --git a/tests/test_methods/test_on_t5.py b/tests/test_methods/test_on_t5.py new file mode 100644 index 0000000000..5a737328e7 --- /dev/null +++ b/tests/test_methods/test_on_t5.py @@ -0,0 +1,28 @@ +from transformers import T5Config +from transformers.testing_utils import require_torch + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +@require_torch +class T5AdapterTestBase(TextAdapterTestBase): + config_class = T5Config + config = make_config( + T5Config, + d_model=16, + num_layers=2, + num_decoder_layers=2, + num_heads=4, + d_ff=4, + d_kv=16 // 4, + tie_word_embeddings=False, + decoder_start_token_id=0, + ) + tokenizer_name = "t5-base" + + +method_tests = generate_method_tests(T5AdapterTestBase, not_supported=["ConfigUnion", "PromptTuning"]) +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_vit.py b/tests/test_methods/test_on_vit.py new file mode 100644 index 0000000000..31df85bda6 --- /dev/null +++ b/tests/test_methods/test_on_vit.py @@ -0,0 +1,23 @@ +from transformers import ViTConfig + +from .base import VisionAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class ViTAdapterTestBase(VisionAdapterTestBase): + config_class = ViTConfig + config = make_config( + ViTConfig, + image_size=224, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + ) + feature_extractor_name = "google/vit-base-patch16-224-in21k" + + +method_tests = generate_method_tests(ViTAdapterTestBase, not_supported=["ConfigUnion", "Embeddings", "Composition"]) +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_whisper.py b/tests/test_methods/test_on_whisper.py new file mode 100644 index 0000000000..b172a6fdd2 --- /dev/null +++ b/tests/test_methods/test_on_whisper.py @@ -0,0 +1,31 @@ +from transformers import WhisperConfig + +from .base import AudioAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class WhisperAdapterTestBase(AudioAdapterTestBase): + config_class = WhisperConfig + config = make_config( + WhisperConfig, + d_model=32, + encoder_layers=2, + decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + vocab_size=51865, + ) + tokenizer_name = "openai/whisper-small" + sampling_rate = 16000 + decoder_start_token_id = 50257 + + def test_parallel_training_lora(self): + self.skipTest("Not supported for Whisper") + + +method_tests = generate_method_tests(WhisperAdapterTestBase, not_supported=["PromptTuning"]) +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_xlm_roberta.py b/tests/test_methods/test_on_xlm_roberta.py new file mode 100644 index 0000000000..80605d4fa1 --- /dev/null +++ b/tests/test_methods/test_on_xlm_roberta.py @@ -0,0 +1,23 @@ +from transformers import XLMRobertaConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class XLMRobertaAdapterTestBase(TextAdapterTestBase): + config_class = XLMRobertaConfig + config = make_config( + XLMRobertaConfig, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + vocab_size=250002, + ) + tokenizer_name = "xlm-roberta-base" + + +method_tests = generate_method_tests(XLMRobertaAdapterTestBase, redundant=["ConfigUnion", "Embeddings"]) +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/test_methods/test_on_xmod.py b/tests/test_methods/test_on_xmod.py new file mode 100644 index 0000000000..1dab7d079e --- /dev/null +++ b/tests/test_methods/test_on_xmod.py @@ -0,0 +1,25 @@ +from transformers import XmodConfig + +from .base import TextAdapterTestBase +from .generator import generate_method_tests +from .method_test_impl.utils import make_config + + +class XmodAdapterTestBase(TextAdapterTestBase): + config_class = XmodConfig + config = make_config( + XmodConfig, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + vocab_size=250002, + max_position_embeddings=512, + default_language="en_XX", + ) + tokenizer_name = "xlm-roberta-base" + + +method_tests = generate_method_tests(XmodAdapterTestBase, not_supported=["ConfigUnion", "Embeddings"]) +for test_class_name, test_class in method_tests.items(): + globals()[test_class_name] = test_class diff --git a/tests/composition/test_adapter_composition.py b/tests/test_misc/test_adapter_composition.py similarity index 99% rename from tests/composition/test_adapter_composition.py rename to tests/test_misc/test_adapter_composition.py index 3d0d47412d..29dade33a8 100644 --- a/tests/composition/test_adapter_composition.py +++ b/tests/test_misc/test_adapter_composition.py @@ -5,7 +5,7 @@ import adapters from adapters import IA3Config, LoRAConfig, PrefixTuningConfig, SeqBnConfig from adapters.composition import Average, BatchSplit, Fuse, Parallel, Split, Stack, parse_composition -from tests.test_adapter import ids_tensor +from tests.test_methods.method_test_impl.utils import ids_tensor from transformers import BertConfig, BertForSequenceClassification from transformers.testing_utils import require_torch, torch_device diff --git a/tests/test_adapter_config.py b/tests/test_misc/test_adapter_config.py similarity index 100% rename from tests/test_adapter_config.py rename to tests/test_misc/test_adapter_config.py diff --git a/tests/test_adapter_custom_head.py b/tests/test_misc/test_adapter_custom_head.py similarity index 98% rename from tests/test_adapter_custom_head.py rename to tests/test_misc/test_adapter_custom_head.py index 8e29636d05..68a6dd1946 100644 --- a/tests/test_adapter_custom_head.py +++ b/tests/test_misc/test_adapter_custom_head.py @@ -5,11 +5,10 @@ from adapters import AutoAdapterModel from adapters.heads import ClassificationHead, PredictionHead +from tests.test_methods.method_test_impl.utils import ids_tensor from transformers import AutoConfig from transformers.testing_utils import require_torch, torch_device -from .test_adapter import ids_tensor - class CustomHead(PredictionHead): def __init__( diff --git a/tests/test_adapter_fusion_config.py b/tests/test_misc/test_adapter_fusion_config.py similarity index 100% rename from tests/test_adapter_fusion_config.py rename to tests/test_misc/test_adapter_fusion_config.py diff --git a/tests/test_adapter_hub.py b/tests/test_misc/test_adapter_hub.py similarity index 96% rename from tests/test_adapter_hub.py rename to tests/test_misc/test_adapter_hub.py index 0dee5eb0a6..f7cb4fc96f 100644 --- a/tests/test_adapter_hub.py +++ b/tests/test_misc/test_adapter_hub.py @@ -1,5 +1,6 @@ import os import unittest +from pathlib import Path import numpy as np @@ -7,6 +8,7 @@ from adapters import ADAPTER_CONFIG_MAP, AdapterConfig, BertAdapterModel, get_adapter_config_hash from adapters.trainer import AdapterTrainer as Trainer from adapters.utils import find_in_index +from tests.test_methods.method_test_impl.utils import ids_tensor from transformers import ( AutoModel, AutoTokenizer, @@ -17,10 +19,10 @@ ) from transformers.testing_utils import require_torch, torch_device -from .test_adapter import ids_tensor - -SAMPLE_INDEX = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/hub-index.sample.json") +current_file_path = os.path.abspath(__file__) +fixtures_dir = Path(current_file_path).parent.parent.parent / "fixtures" +SAMPLE_INDEX = str(fixtures_dir / "hub-index.sample.json") @require_torch diff --git a/tests/test_adapter_safetensors.py b/tests/test_misc/test_adapter_safetensors.py similarity index 100% rename from tests/test_adapter_safetensors.py rename to tests/test_misc/test_adapter_safetensors.py diff --git a/tests/test_adapter_save_id2label.py b/tests/test_misc/test_adapter_save_id2label.py similarity index 100% rename from tests/test_adapter_save_id2label.py rename to tests/test_misc/test_adapter_save_id2label.py diff --git a/tests/test_misc/test_adapter_trainer/__init__.py b/tests/test_misc/test_adapter_trainer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_adapter_trainer.py b/tests/test_misc/test_adapter_trainer/test_adapter_trainer.py similarity index 100% rename from tests/test_adapter_trainer.py rename to tests/test_misc/test_adapter_trainer/test_adapter_trainer.py diff --git a/tests/extended/test_adapter_trainer_ext.py b/tests/test_misc/test_adapter_trainer/test_adapter_trainer_ext.py similarity index 98% rename from tests/extended/test_adapter_trainer_ext.py rename to tests/test_misc/test_adapter_trainer/test_adapter_trainer_ext.py index 8da0ea07c8..917a9c9193 100644 --- a/tests/extended/test_adapter_trainer_ext.py +++ b/tests/test_misc/test_adapter_trainer/test_adapter_trainer_ext.py @@ -42,7 +42,7 @@ bindir = os.path.abspath(os.path.dirname(__file__)) -with ExtendSysPath(f"{bindir}/../../examples/pytorch/translation"): +with ExtendSysPath(f"{bindir}/../../../examples/pytorch/translation"): from run_translation import main # noqa @@ -268,7 +268,7 @@ def run_trainer( do_predict: bool = True, n_gpus_to_use: int = None, ): - data_dir = self.test_file_dir / "../../hf_transformers/tests/fixtures/tests_samples/wmt_en_ro" + data_dir = self.test_file_dir / "../../../hf_transformers/tests/fixtures/tests_samples/wmt_en_ro" output_dir = self.get_auto_remove_tmp_dir() args_train = f""" --model_name_or_path {model_name} diff --git a/tests/test_mistral.py b/tests/test_mistral.py deleted file mode 100644 index b10065a702..0000000000 --- a/tests/test_mistral.py +++ /dev/null @@ -1,66 +0,0 @@ -import unittest - -from transformers.models.mistral.configuration_mistral import MistralConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class MistralAdapterTestBase(AdapterTestBase): - config_class = MistralConfig - config = make_config( - MistralConfig, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=8, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - pad_token_id=0, - ) - tokenizer_name = "HuggingFaceH4/zephyr-7b-beta" - - -@require_torch -class MistralAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - EmbeddingTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - MistralAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class MistralClassConversionTest( - ModelClassConversionTestMixin, - MistralAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/base.py b/tests/test_models/base.py similarity index 100% rename from tests/models/base.py rename to tests/test_models/base.py diff --git a/tests/models/test_albert.py b/tests/test_models/test_albert_model.py similarity index 100% rename from tests/models/test_albert.py rename to tests/test_models/test_albert_model.py diff --git a/tests/models/test_bart.py b/tests/test_models/test_bart_model.py similarity index 100% rename from tests/models/test_bart.py rename to tests/test_models/test_bart_model.py diff --git a/tests/models/test_beit.py b/tests/test_models/test_beit_model.py similarity index 100% rename from tests/models/test_beit.py rename to tests/test_models/test_beit_model.py diff --git a/tests/models/test_bert_generation.py b/tests/test_models/test_bert_generation_model.py similarity index 100% rename from tests/models/test_bert_generation.py rename to tests/test_models/test_bert_generation_model.py diff --git a/tests/models/test_bert.py b/tests/test_models/test_bert_model.py similarity index 100% rename from tests/models/test_bert.py rename to tests/test_models/test_bert_model.py diff --git a/tests/models/test_clip.py b/tests/test_models/test_clip_model.py similarity index 100% rename from tests/models/test_clip.py rename to tests/test_models/test_clip_model.py diff --git a/tests/models/test_debertaV2.py b/tests/test_models/test_debertaV2_model.py similarity index 100% rename from tests/models/test_debertaV2.py rename to tests/test_models/test_debertaV2_model.py diff --git a/tests/models/test_deberta.py b/tests/test_models/test_deberta_model.py similarity index 100% rename from tests/models/test_deberta.py rename to tests/test_models/test_deberta_model.py diff --git a/tests/models/test_distilbert.py b/tests/test_models/test_distilbert_model.py similarity index 100% rename from tests/models/test_distilbert.py rename to tests/test_models/test_distilbert_model.py diff --git a/tests/models/test_electra.py b/tests/test_models/test_electra_model.py similarity index 100% rename from tests/models/test_electra.py rename to tests/test_models/test_electra_model.py diff --git a/tests/models/test_encoder_decoder.py b/tests/test_models/test_encoder_decoder_model.py similarity index 100% rename from tests/models/test_encoder_decoder.py rename to tests/test_models/test_encoder_decoder_model.py diff --git a/tests/models/test_gpt2.py b/tests/test_models/test_gpt2_model.py similarity index 100% rename from tests/models/test_gpt2.py rename to tests/test_models/test_gpt2_model.py diff --git a/tests/models/test_gptj.py b/tests/test_models/test_gptj_model.py similarity index 100% rename from tests/models/test_gptj.py rename to tests/test_models/test_gptj_model.py diff --git a/tests/models/test_llama.py b/tests/test_models/test_llama_model.py similarity index 100% rename from tests/models/test_llama.py rename to tests/test_models/test_llama_model.py diff --git a/tests/models/test_mbart.py b/tests/test_models/test_mbart_model.py similarity index 100% rename from tests/models/test_mbart.py rename to tests/test_models/test_mbart_model.py diff --git a/tests/models/test_mistral.py b/tests/test_models/test_mistral_model.py similarity index 100% rename from tests/models/test_mistral.py rename to tests/test_models/test_mistral_model.py diff --git a/tests/models/test_mt5.py b/tests/test_models/test_mt5_model.py similarity index 100% rename from tests/models/test_mt5.py rename to tests/test_models/test_mt5_model.py diff --git a/tests/models/test_plbart.py b/tests/test_models/test_plbart_model.py similarity index 100% rename from tests/models/test_plbart.py rename to tests/test_models/test_plbart_model.py diff --git a/tests/models/test_roberta.py b/tests/test_models/test_roberta_model.py similarity index 100% rename from tests/models/test_roberta.py rename to tests/test_models/test_roberta_model.py diff --git a/tests/models/test_t5.py b/tests/test_models/test_t5_model.py similarity index 100% rename from tests/models/test_t5.py rename to tests/test_models/test_t5_model.py diff --git a/tests/models/test_vit.py b/tests/test_models/test_vit_model.py similarity index 100% rename from tests/models/test_vit.py rename to tests/test_models/test_vit_model.py diff --git a/tests/models/test_whisper.py b/tests/test_models/test_whisper_model.py similarity index 100% rename from tests/models/test_whisper.py rename to tests/test_models/test_whisper_model.py diff --git a/tests/models/test_xlm_roberta.py b/tests/test_models/test_xlm_roberta_model.py similarity index 100% rename from tests/models/test_xlm_roberta.py rename to tests/test_models/test_xlm_roberta_model.py diff --git a/tests/models/test_xmod.py b/tests/test_models/test_xmod_model.py similarity index 100% rename from tests/models/test_xmod.py rename to tests/test_models/test_xmod_model.py diff --git a/tests/test_mt5.py b/tests/test_mt5.py deleted file mode 100644 index a7d7c3a0fe..0000000000 --- a/tests/test_mt5.py +++ /dev/null @@ -1,68 +0,0 @@ -import unittest - -from transformers import MT5Config -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -@require_torch -class MT5AdapterTestBase(AdapterTestBase): - config_class = MT5Config - config = make_config( - MT5Config, - d_model=16, - num_layers=2, - num_decoder_layers=2, - num_heads=4, - d_ff=4, - d_kv=16 // 4, - tie_word_embeddings=False, - decoder_start_token_id=0, - ) - tokenizer_name = "google/mt5-base" - - -@require_torch -class MT5AdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - EmbeddingTestMixin, - CompabilityTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - AdapterFusionModelTestMixin, - PredictionHeadModelTestMixin, - MT5AdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class MT5ClassConversionTest( - ModelClassConversionTestMixin, - MT5AdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_plbart.py b/tests/test_plbart.py deleted file mode 100644 index aa84457919..0000000000 --- a/tests/test_plbart.py +++ /dev/null @@ -1,67 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import PLBartConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class PLBartAdapterTestBase(AdapterTestBase): - config_class = PLBartConfig - config = make_config( - PLBartConfig, - d_model=16, - encoder_layers=2, - decoder_layers=2, - encoder_attention_heads=4, - decoder_attention_heads=4, - encoder_ffn_dim=4, - decoder_ffn_dim=4, - scale_embedding=False, # Required for embedding tests - ) - tokenizer_name = "uclanlp/plbart-base" - - -@require_torch -class PLBartAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - EmbeddingTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - PLBartAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class PLBartClassConversionTest( - ModelClassConversionTestMixin, - PLBartAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_roberta.py b/tests/test_roberta.py deleted file mode 100644 index 142a69e7a8..0000000000 --- a/tests/test_roberta.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import RobertaConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class RobertaAdapterTestBase(AdapterTestBase): - config_class = RobertaConfig - config = make_config( - RobertaConfig, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - vocab_size=50265, - ) - tokenizer_name = "roberta-base" - - -@require_torch -class RobertaAdapterTest( - AllMethodsTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ConfigUnionAdapterTest, - RobertaAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class RobertaClassConversionTest( - ModelClassConversionTestMixin, - RobertaAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_t5.py b/tests/test_t5.py deleted file mode 100644 index 1c2480c6bb..0000000000 --- a/tests/test_t5.py +++ /dev/null @@ -1,68 +0,0 @@ -import unittest - -from transformers import T5Config -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -@require_torch -class T5AdapterTestBase(AdapterTestBase): - config_class = T5Config - config = make_config( - T5Config, - d_model=16, - num_layers=2, - num_decoder_layers=2, - num_heads=4, - d_ff=4, - d_kv=16 // 4, - tie_word_embeddings=False, - decoder_start_token_id=0, - ) - tokenizer_name = "t5-base" - - -@require_torch -class T5AdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - EmbeddingTestMixin, - CompabilityTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - AdapterFusionModelTestMixin, - PredictionHeadModelTestMixin, - T5AdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class T5ClassConversionTest( - ModelClassConversionTestMixin, - T5AdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_vit.py b/tests/test_vit.py deleted file mode 100644 index a4bdd0afb9..0000000000 --- a/tests/test_vit.py +++ /dev/null @@ -1,48 +0,0 @@ -import unittest - -from transformers import ViTConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import AllMethodsTestMixin -from .test_adapter import VisionAdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class ViTAdapterTestBase(VisionAdapterTestBase): - config_class = ViTConfig - config = make_config( - ViTConfig, - image_size=224, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - ) - feature_extractor_name = "google/vit-base-patch16-224-in21k" - - -@require_torch -class ViTAdapterTest( - AllMethodsTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ViTAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class ViTClassConversionTest( - ModelClassConversionTestMixin, - ViTAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_whisper.py b/tests/test_whisper.py deleted file mode 100644 index 95675fc425..0000000000 --- a/tests/test_whisper.py +++ /dev/null @@ -1,72 +0,0 @@ -import unittest - -from tests.methods.test_config_union import ConfigUnionAdapterTest -from transformers import WhisperConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin -from .methods import ( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, -) -from .test_adapter import SpeechAdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_embeddings import EmbeddingTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class WhisperAdapterTestBase(SpeechAdapterTestBase): - config_class = WhisperConfig - config = make_config( - WhisperConfig, - d_model=16, - encoder_layers=2, - decoder_layers=2, - encoder_attention_heads=4, - decoder_attention_heads=4, - encoder_ffn_dim=4, - decoder_ffn_dim=4, - vocab_size=51865, - ) - tokenizer_name = "openai/whisper-small" - sampling_rate = 16000 - decoder_start_token_id = 50257 - - -@require_torch -class WhisperAdapterTest( - BottleneckAdapterTestMixin, - CompacterTestMixin, - IA3TestMixin, - LoRATestMixin, - PrefixTuningTestMixin, - ReftTestMixin, - UniPELTTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - EmbeddingTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - ParallelTrainingMixin, - ConfigUnionAdapterTest, - WhisperAdapterTestBase, - unittest.TestCase, -): - def test_parallel_training_lora(self): - self.skipTest("Not supported for Whisper") - - -@require_torch -class WhisperClassConversionTest( - ModelClassConversionTestMixin, - WhisperAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_xlm_roberta.py b/tests/test_xlm_roberta.py deleted file mode 100644 index 9125b3fbeb..0000000000 --- a/tests/test_xlm_roberta.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest - -from transformers import XLMRobertaConfig -from transformers.testing_utils import require_torch - -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin - - -class XLMRobertaAdapterTestBase(AdapterTestBase): - config_class = XLMRobertaConfig - config = make_config( - XLMRobertaConfig, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - vocab_size=250002, - ) - tokenizer_name = "xlm-roberta-base" - - -@require_torch -class XLMRobertaAdapterTest( - AllMethodsTestMixin, - AdapterFusionModelTestMixin, - XLMRobertaAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class XLMRobertaClassConversionTest( - ModelClassConversionTestMixin, - XLMRobertaAdapterTestBase, - unittest.TestCase, -): - pass diff --git a/tests/test_xmod.py b/tests/test_xmod.py deleted file mode 100644 index 9ca2aaa70a..0000000000 --- a/tests/test_xmod.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest - -from transformers import XmodConfig -from transformers.testing_utils import require_torch - -from .composition.test_parallel import ParallelAdapterInferenceTestMixin -from .methods import AllMethodsTestMixin -from .test_adapter import AdapterTestBase, make_config -from .test_adapter_backward_compability import CompabilityTestMixin -from .test_adapter_conversion import ModelClassConversionTestMixin -from .test_adapter_fusion_common import AdapterFusionModelTestMixin -from .test_adapter_heads import PredictionHeadModelTestMixin - - -class XmodAdapterTestBase(AdapterTestBase): - config_class = XmodConfig - config = make_config( - XmodConfig, - hidden_size=32, - num_hidden_layers=4, - num_attention_heads=4, - intermediate_size=37, - vocab_size=250002, - max_position_embeddings=512, - default_language="en_XX", - ) - tokenizer_name = "xlm-roberta-base" - - -@require_torch -class XmodAdapterTest( - AllMethodsTestMixin, - AdapterFusionModelTestMixin, - CompabilityTestMixin, - PredictionHeadModelTestMixin, - ParallelAdapterInferenceTestMixin, - XmodAdapterTestBase, - unittest.TestCase, -): - pass - - -@require_torch -class XmodClassConversionTest( - ModelClassConversionTestMixin, - XmodAdapterTestBase, - unittest.TestCase, -): - pass From ea2b6395dc8313a7064d7765465f2f32af6b1397 Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 27 Jan 2025 23:04:57 +0100 Subject: [PATCH 8/9] Fix GenerationMixin warning for AdapterModel classes (#787) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes the following warning for all AdapterModels: ``` LlamaAdapterModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions. ``` --- src/adapters/models/bart/adapter_model.py | 5 ++++- src/adapters/models/bert/adapter_model.py | 5 ++++- src/adapters/models/bert_generation/adapter_model.py | 3 ++- src/adapters/models/distilbert/adapter_model.py | 3 ++- src/adapters/models/electra/adapter_model.py | 5 ++++- src/adapters/models/gpt2/adapter_model.py | 5 ++++- src/adapters/models/gptj/adapter_model.py | 5 ++++- src/adapters/models/llama/adapter_model.py | 5 ++++- src/adapters/models/mbart/adapter_model.py | 5 ++++- src/adapters/models/mistral/adapter_model.py | 5 ++++- src/adapters/models/mt5/adapter_model.py | 5 ++++- src/adapters/models/plbart/adapter_model.py | 5 ++++- src/adapters/models/roberta/adapter_model.py | 5 ++++- src/adapters/models/t5/adapter_model.py | 5 ++++- src/adapters/models/whisper/adapter_model.py | 5 ++++- src/adapters/models/xlm_roberta/adapter_model.py | 3 ++- src/adapters/models/xmod/adapter_model.py | 5 ++++- 17 files changed, 62 insertions(+), 17 deletions(-) diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py index 4e07fc5f10..34a5615644 100644 --- a/src/adapters/models/bart/adapter_model.py +++ b/src/adapters/models/bart/adapter_model.py @@ -1,5 +1,6 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.bart.modeling_bart import ( BART_INPUTS_DOCSTRING, BART_START_DOCSTRING, @@ -18,7 +19,9 @@ @add_start_docstrings( "BART Model with the option to add multiple flexible prediction heads on top.", BART_START_DOCSTRING ) -class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPreTrainedModel): +class BartAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BartPreTrainedModel, GenerationMixin +): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/adapters/models/bert/adapter_model.py b/src/adapters/models/bert/adapter_model.py index a15f3e4327..3be78bd5bd 100644 --- a/src/adapters/models/bert/adapter_model.py +++ b/src/adapters/models/bert/adapter_model.py @@ -1,3 +1,4 @@ +from transformers.generation import GenerationMixin from transformers.models.bert.modeling_bert import ( BERT_INPUTS_DOCSTRING, BERT_START_DOCSTRING, @@ -16,7 +17,9 @@ """Bert Model transformer with the option to add multiple flexible heads on top.""", BERT_START_DOCSTRING, ) -class BertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertPreTrainedModel): +class BertAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertPreTrainedModel, GenerationMixin +): head_types = [ "classification", diff --git a/src/adapters/models/bert_generation/adapter_model.py b/src/adapters/models/bert_generation/adapter_model.py index d3822e24a7..0bbe5ad51f 100644 --- a/src/adapters/models/bert_generation/adapter_model.py +++ b/src/adapters/models/bert_generation/adapter_model.py @@ -1,3 +1,4 @@ +from transformers.generation import GenerationMixin from transformers.models.bert_generation.modeling_bert_generation import ( BERT_GENERATION_INPUTS_DOCSTRING, BERT_GENERATION_START_DOCSTRING, @@ -17,7 +18,7 @@ BERT_GENERATION_START_DOCSTRING, ) class BertGenerationAdapterModel( - EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertGenerationPreTrainedModel + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, BertGenerationPreTrainedModel, GenerationMixin ): _keys_to_ignore_on_load_unexpected = [r"lm_head.bias"] diff --git a/src/adapters/models/distilbert/adapter_model.py b/src/adapters/models/distilbert/adapter_model.py index 3f38c893ca..d7b09dfe1e 100644 --- a/src/adapters/models/distilbert/adapter_model.py +++ b/src/adapters/models/distilbert/adapter_model.py @@ -1,5 +1,6 @@ import torch.nn as nn +from transformers.generation import GenerationMixin from transformers.models.distilbert.modeling_distilbert import ( DISTILBERT_INPUTS_DOCSTRING, DISTILBERT_START_DOCSTRING, @@ -18,7 +19,7 @@ DISTILBERT_START_DOCSTRING, ) class DistilBertAdapterModel( - EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DistilBertPreTrainedModel + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, DistilBertPreTrainedModel, GenerationMixin ): head_types = [ "classification", diff --git a/src/adapters/models/electra/adapter_model.py b/src/adapters/models/electra/adapter_model.py index 57e20fadbe..83bc8f9184 100644 --- a/src/adapters/models/electra/adapter_model.py +++ b/src/adapters/models/electra/adapter_model.py @@ -1,3 +1,4 @@ +from transformers.generation import GenerationMixin from transformers.models.electra.modeling_electra import ( ELECTRA_INPUTS_DOCSTRING, ELECTRA_START_DOCSTRING, @@ -16,7 +17,9 @@ """Electra Model transformer with the option to add multiple flexible heads on top.""", ELECTRA_START_DOCSTRING, ) -class ElectraAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, ElectraPreTrainedModel): +class ElectraAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, ElectraPreTrainedModel, GenerationMixin +): head_types = [ "classification", diff --git a/src/adapters/models/gpt2/adapter_model.py b/src/adapters/models/gpt2/adapter_model.py index 2cfbdc8821..c6b96d1204 100644 --- a/src/adapters/models/gpt2/adapter_model.py +++ b/src/adapters/models/gpt2/adapter_model.py @@ -2,6 +2,7 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.gpt2.modeling_gpt2 import GPT2_START_DOCSTRING, GPT2Model, GPT2PreTrainedModel from transformers.utils import add_start_docstrings @@ -25,7 +26,9 @@ """, GPT2_START_DOCSTRING, ) -class GPT2AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel): +class GPT2AdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPT2PreTrainedModel, GenerationMixin +): head_types = [ "classification", "multilabel_classification", diff --git a/src/adapters/models/gptj/adapter_model.py b/src/adapters/models/gptj/adapter_model.py index f029f840d6..c075aeac1a 100644 --- a/src/adapters/models/gptj/adapter_model.py +++ b/src/adapters/models/gptj/adapter_model.py @@ -2,6 +2,7 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.gptj.modeling_gptj import GPTJ_START_DOCSTRING, GPTJModel, GPTJPreTrainedModel from transformers.utils import add_start_docstrings @@ -25,7 +26,9 @@ """, GPTJ_START_DOCSTRING, ) -class GPTJAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPTJPreTrainedModel): +class GPTJAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, GPTJPreTrainedModel, GenerationMixin +): head_types = [ "classification", "multilabel_classification", diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index c3116fbe14..39d93ad9b5 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -3,6 +3,7 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.llama.modeling_llama import LLAMA_START_DOCSTRING, LlamaModel, LlamaPreTrainedModel from transformers.utils import add_start_docstrings @@ -26,7 +27,9 @@ """, LLAMA_START_DOCSTRING, ) -class LlamaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, LlamaPreTrainedModel): +class LlamaAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, LlamaPreTrainedModel, GenerationMixin +): head_types = [ "classification", "multilabel_classification", diff --git a/src/adapters/models/mbart/adapter_model.py b/src/adapters/models/mbart/adapter_model.py index ebbfb45efa..06e31650fa 100644 --- a/src/adapters/models/mbart/adapter_model.py +++ b/src/adapters/models/mbart/adapter_model.py @@ -1,5 +1,6 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.mbart.modeling_mbart import ( MBART_INPUTS_DOCSTRING, MBART_START_DOCSTRING, @@ -19,7 +20,9 @@ @add_start_docstrings( "MBART Model with the option to add multiple flexible prediction heads on top.", MBART_START_DOCSTRING ) -class MBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel): +class MBartAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MBartPreTrainedModel, GenerationMixin +): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/adapters/models/mistral/adapter_model.py b/src/adapters/models/mistral/adapter_model.py index 1909fccdec..595cace188 100644 --- a/src/adapters/models/mistral/adapter_model.py +++ b/src/adapters/models/mistral/adapter_model.py @@ -2,6 +2,7 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.mistral.modeling_mistral import MISTRAL_START_DOCSTRING, MistralModel, MistralPreTrainedModel from transformers.utils import add_start_docstrings @@ -25,7 +26,9 @@ """, MISTRAL_START_DOCSTRING, ) -class MistralAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MistralPreTrainedModel): +class MistralAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MistralPreTrainedModel, GenerationMixin +): head_types = [ "classification", "multilabel_classification", diff --git a/src/adapters/models/mt5/adapter_model.py b/src/adapters/models/mt5/adapter_model.py index 418b47b13f..705d0852ef 100644 --- a/src/adapters/models/mt5/adapter_model.py +++ b/src/adapters/models/mt5/adapter_model.py @@ -2,6 +2,7 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.mt5.modeling_mt5 import ( MT5_INPUTS_DOCSTRING, MT5_START_DOCSTRING, @@ -22,7 +23,9 @@ @add_start_docstrings( "MT5 Model with the option to add multiple flexible prediction heads on top.", MT5_START_DOCSTRING ) -class MT5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MT5PreTrainedModel): +class MT5AdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MT5PreTrainedModel, GenerationMixin +): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/adapters/models/plbart/adapter_model.py b/src/adapters/models/plbart/adapter_model.py index 2aaaf0b9fa..0475fd077d 100644 --- a/src/adapters/models/plbart/adapter_model.py +++ b/src/adapters/models/plbart/adapter_model.py @@ -1,5 +1,6 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.plbart.modeling_plbart import ( PLBART_INPUTS_DOCSTRING, PLBART_START_DOCSTRING, @@ -18,7 +19,9 @@ @add_start_docstrings( "PLBART Model with the option to add multiple flexible prediction heads on top.", PLBART_START_DOCSTRING ) -class PLBartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, PLBartPreTrainedModel): +class PLBartAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, PLBartPreTrainedModel, GenerationMixin +): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/adapters/models/roberta/adapter_model.py b/src/adapters/models/roberta/adapter_model.py index ab9411ef7d..5a9af959d8 100644 --- a/src/adapters/models/roberta/adapter_model.py +++ b/src/adapters/models/roberta/adapter_model.py @@ -1,3 +1,4 @@ +from transformers.generation import GenerationMixin from transformers.models.roberta.modeling_roberta import ( ROBERTA_INPUTS_DOCSTRING, ROBERTA_START_DOCSTRING, @@ -16,7 +17,9 @@ """Roberta Model transformer with the option to add multiple flexible heads on top.""", ROBERTA_START_DOCSTRING, ) -class RobertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, RobertaPreTrainedModel): +class RobertaAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, RobertaPreTrainedModel, GenerationMixin +): head_types = [ "classification", "multilabel_classification", diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py index 5aa7aff4fd..5f2b324380 100644 --- a/src/adapters/models/t5/adapter_model.py +++ b/src/adapters/models/t5/adapter_model.py @@ -2,6 +2,7 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.t5.modeling_t5 import T5_INPUTS_DOCSTRING, T5_START_DOCSTRING, T5Model, T5PreTrainedModel from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward @@ -15,7 +16,9 @@ @add_start_docstrings("T5 Model with the option to add multiple flexible prediction heads on top.", T5_START_DOCSTRING) -class T5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel): +class T5AdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, T5PreTrainedModel, GenerationMixin +): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/adapters/models/whisper/adapter_model.py b/src/adapters/models/whisper/adapter_model.py index d76ae610c5..4bcc026927 100644 --- a/src/adapters/models/whisper/adapter_model.py +++ b/src/adapters/models/whisper/adapter_model.py @@ -1,6 +1,7 @@ import torch from transformers import EncoderDecoderCache, StaticCache +from transformers.generation import GenerationMixin from transformers.models.whisper.modeling_whisper import ( WHISPER_INPUTS_DOCSTRING, WHISPER_START_DOCSTRING, @@ -19,7 +20,9 @@ @add_start_docstrings( "WHISPER Model with the option to add multiple flexible prediction heads on top.", WHISPER_START_DOCSTRING ) -class WhisperAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, WhisperPreTrainedModel): +class WhisperAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, WhisperPreTrainedModel, GenerationMixin +): _tied_weights_keys = [] head_types = ["seq2seq_lm"] diff --git a/src/adapters/models/xlm_roberta/adapter_model.py b/src/adapters/models/xlm_roberta/adapter_model.py index 1cab4aaac5..559202d52d 100644 --- a/src/adapters/models/xlm_roberta/adapter_model.py +++ b/src/adapters/models/xlm_roberta/adapter_model.py @@ -1,3 +1,4 @@ +from transformers.generation import GenerationMixin from transformers.models.xlm_roberta.modeling_xlm_roberta import ( XLM_ROBERTA_INPUTS_DOCSTRING, XLM_ROBERTA_START_DOCSTRING, @@ -17,7 +18,7 @@ XLM_ROBERTA_START_DOCSTRING, ) class XLMRobertaAdapterModel( - EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, XLMRobertaPreTrainedModel + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, XLMRobertaPreTrainedModel, GenerationMixin ): head_types = [ diff --git a/src/adapters/models/xmod/adapter_model.py b/src/adapters/models/xmod/adapter_model.py index a179fc6be8..e81f49dee0 100644 --- a/src/adapters/models/xmod/adapter_model.py +++ b/src/adapters/models/xmod/adapter_model.py @@ -2,6 +2,7 @@ import torch +from transformers.generation import GenerationMixin from transformers.models.xmod.modeling_xmod import ( XMOD_INPUTS_DOCSTRING, XMOD_START_DOCSTRING, @@ -20,7 +21,9 @@ """X-MOD Model transformer with the option to add multiple flexible heads on top.""", XMOD_START_DOCSTRING, ) -class XmodAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, XmodPreTrainedModel): +class XmodAdapterModel( + EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, XmodPreTrainedModel, GenerationMixin +): head_types = [ "classification", From 326d071c4dc41ab05f2a0f520813e9f4f5032979 Mon Sep 17 00:00:00 2001 From: calpt Date: Tue, 28 Jan 2025 21:12:54 +0100 Subject: [PATCH 9/9] Release v1.1.0 --- setup.py | 5 +++-- src/adapters/__init__.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index e3af210570..33a2e77e0b 100644 --- a/setup.py +++ b/setup.py @@ -142,7 +142,7 @@ def deps_list(*pkgs): setup( name="adapters", - version="1.1.0.dev0", + version="1.1.0", author="The AdapterHub team and community contributors", author_email="calpt@mail.de", description="A Unified Library for Parameter-Efficient and Modular Transfer Learning", @@ -165,9 +165,10 @@ def deps_list(*pkgs): "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index 88549c6969..905706b509 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.1.0.dev0" +__version__ = "1.1.0" from typing import TYPE_CHECKING