diff --git a/docs/classes/adapter_model_interface.rst b/docs/classes/adapter_model_interface.rst new file mode 100644 index 0000000000..e17f1a38ac --- /dev/null +++ b/docs/classes/adapter_model_interface.rst @@ -0,0 +1,8 @@ +Adapter Model Interface +======================= + +.. autoclass:: adapters.AdapterModelInterface + :members: + +.. autoclass:: adapters.AdapterMethod + :members: diff --git a/docs/contributing/adding_adapters_to_a_model.md b/docs/contributing/adding_adapters_to_a_model.md index 97c9fe21d7..f99297510a 100644 --- a/docs/contributing/adding_adapters_to_a_model.md +++ b/docs/contributing/adding_adapters_to_a_model.md @@ -1,4 +1,11 @@ # Adding Adapters to a Model + +```{eval-rst} +.. important:: + For most use cases, it can be much easier to support a new model architecture via the new adapter plugin interface. + Check out `Custom Models <../plugin_interface.html>`_ for more details. +``` + This document gives an overview of how new model architectures of Hugging Face Transformers can be supported by `adapters`. Before delving into implementation details, you should familiarize yourself with the main design philosophies of `adapters`: diff --git a/docs/index.rst b/docs/index.rst index 0c10de8a18..9cef2b044f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -52,7 +52,6 @@ Currently, we support the PyTorch versions of all models as listed on the `Model merging_adapters prediction_heads embeddings - extending .. toctree:: :maxdepth: 2 @@ -66,6 +65,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model :caption: Supported Models model_overview + plugin_interface classes/models/albert classes/models/auto classes/models/bart @@ -99,6 +99,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model classes/adapter_config classes/model_adapters_config classes/adapter_layer + classes/adapter_model_interface classes/model_mixins classes/adapter_training classes/adapter_utils @@ -110,6 +111,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model contributing contributing/adding_adapter_methods contributing/adding_adapters_to_a_model + extending Citation ======== diff --git a/docs/model_overview.md b/docs/model_overview.md index b364ab6ebb..d46cff99d1 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -13,6 +13,7 @@ The table below further shows which model architectures support which adaptation | Model | (Bottleneck)
Adapters | Prefix
Tuning | LoRA | Compacter | Adapter
Fusion | Invertible
Adapters | Parallel
block | Prompt
Tuning | ReFT | | --------------------------------------- | -| - | - | - | - | - | - |- | - | +| [Custom models](plugin_interface.html) | ✅(°) | | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | | [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | | [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | ✅ | @@ -38,9 +39,11 @@ The table below further shows which model architectures support which adaptation | [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +(°) `original_ln_after=False` is unsupported for bottleneck configs. (*) If the used encoder and decoder model class are supported. -**Missing a model architecture you'd like to use?** -adapters can be easily extended to new model architectures as described in [Adding Adapters to a Model](https://docs.adapterhub.ml/contributing/adding_adapters_to_a_model.html). +**Missing a model architecture you'd like to use?** +The new model plugin interface makes it easy to support new transformer models with just a few lines of code. [Learn more](plugin_interface.md). +Also, _Adapters_ can be extended to new model architectures as described in [Adding Adapters to a Model](https://docs.adapterhub.ml/contributing/adding_adapters_to_a_model.html). Feel free to [open an issue](https://github.com/Adapter-Hub/adapters/issues) requesting support for a new architecture. _We very much welcome pull requests adding new model implementations!_ diff --git a/docs/plugin_interface.md b/docs/plugin_interface.md new file mode 100644 index 0000000000..10ff2da74f --- /dev/null +++ b/docs/plugin_interface.md @@ -0,0 +1,94 @@ +# Custom Models + +The _Adapters_ library provides a simple mechanism for integrating adapter methods into any available _Transformers_ model - including custom architectures. +This can be accomplished by defining a plugin interface instance of [`AdapterModelInterface`](adapters.AdapterModelInterface). +The following example shows how this looks for Gemma 2: + +```python +import adapters +from adapters import AdapterModelInterface +from transformers import AutoModelForCausalLM + +plugin_interface = AdapterModelInterface( + adapter_methods=["lora", "reft"], + model_embeddings="embed_tokens", + model_layers="layers", + layer_self_attn="self_attn", + layer_cross_attn=None, + attn_k_proj="k_proj", + attn_q_proj="q_proj", + attn_v_proj="v_proj", + attn_o_proj="o_proj", + layer_intermediate_proj="mlp.up_proj", + layer_output_proj="mlp.down_proj", +) + +model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", token="") +adapters.init(model, interface=plugin_interface) + +model.add_adapter("my_adapter", config="lora") + +print(model.adapter_summary()) +``` + +## Walkthrough + +Let's go through what happens in the example above step by step: + +**1. Define adapter methods to plug into a model:** +The `adapter_methods` argument is the central parameter to configure which adapters will be supported in the model. +Here, we enable all LoRA and ReFT based adapters. +See [`AdapterMethod`](adapters.AdapterMethod) for valid options to specify here. +Check out [Adapter Methods](methods.md) for a detailed explanation of the methods. + +**2. Define layer and module names:** +While all Transformers layers share similar basic components, their implementation can differ in subtleties such as module names.
+Therefore, the [`AdapterModelInterface`](adapters.AdapterModelInterface) needs to translate the model-specific module structure into a common set of access points for adapter implementations to hook in. +The remaining attributes in the definition above serve this purpose. +Their attribute names follow a common syntax that specifies their location and purpose: +- The initial part before the first "_" defines the base module relative to which the name should be specified. +- The remaining part after the first "_" defines the functional component. + +E.g., `model_embeddings` identifies the embeddings layer (functional component) relative to the base model (location). +`layer_output_proj` identifies the FFN output projection relative to one Transformer layer. +Each attribute value may specify a direct submodule of the reference module (`"embed_tokens"`) or a multi-level path starting at the reference module (`"mlp.down_proj"`). + +**3. (optional) Extended interface attributes:** +There are a couple of attributes in the [`AdapterModelInterface`](adapters.AdapterModelInterface) that are only required for some adapter methods. +We don't need those in the above example for LoRA and ReFT, but when supporting bottleneck adapters as well, the full interface would look as follows: +```python +adapter_interface = AdapterModelInterface( + adapter_methods=["bottleneck", "lora", "reft"], + model_embeddings="embed_tokens", + model_layers="layers", + layer_self_attn="self_attn", + layer_cross_attn=None, + attn_k_proj="k_proj", + attn_q_proj="q_proj", + attn_v_proj="v_proj", + attn_o_proj="o_proj", + layer_intermediate_proj="mlp.up_proj", + layer_output_proj="mlp.down_proj", + layer_pre_self_attn="input_layernorm", + layer_pre_cross_attn=None, + layer_pre_ffn="pre_feedforward_layernorm", + layer_ln_1="post_attention_layernorm", + layer_ln_2="post_feedforward_layernorm", +) +``` + +**4. Initialize adapter methods in the model:** +Finally, we just need to apply the defined adapter integration to the target model. +This can be achieved using the usual `adapters.init()` method: +```python +adapters.init(model, interface=adapter_interface) +``` +Now, you can use (almost) all functionality of the _Adapters_ library on the adapted model instance!
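+
+Adapters added through the plugin interface can be trained, saved and reloaded like adapters on natively supported models. The snippet below is a minimal sketch that continues the Gemma 2 example from above (the local directory paths are placeholders): when the model is saved via `save_pretrained()`, the interface definition is additionally written to an `adapter_interface.json` file, which `load_model()` picks up automatically if no `interface` argument is passed.
+
+```python
+from adapters.wrappers import load_model
+from transformers import AutoModelForCausalLM
+
+# Freeze the base model weights and activate "my_adapter" for training
+model.train_adapter("my_adapter")
+
+# ... run a regular training loop or use adapters.AdapterTrainer ...
+
+# Saving the full model also stores the plugin interface (adapter_interface.json)
+model.save_pretrained("./gemma-2-2b-it-with-lora")
+
+# On reload, load_model() finds the stored interface file and re-applies adapters.init()
+model = load_model("./gemma-2-2b-it-with-lora", AutoModelForCausalLM)
+
+# Alternatively, save only the adapter weights; they can be loaded into any model
+# instance that was initialized with the same interface via model.load_adapter()
+model.save_adapter("./my_adapter", "my_adapter")
+```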
+ +## Limitations + +The following features of the _Adapters_ library are not supported via the plugin interface approach: +- Prefix Tuning adapters +- Parallel composition blocks +- XAdapterModel classes +- Setting `original_ln_after=False` in bottleneck adapter configurations (this affects `AdapterPlusConfig`) diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index 905706b509..b813c8644c 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -84,6 +84,7 @@ "Seq2SeqLMHead", "TaggingHead", ], + "interface": ["AdapterMethod", "AdapterModelInterface"], "methods.adapter_layer_base": ["AdapterLayerBase", "ComposableAdapterLayerBase"], "model_mixin": [ "EmbeddingAdaptersMixin", @@ -198,6 +199,7 @@ Seq2SeqLMHead, TaggingHead, ) + from .interface import AdapterMethod, AdapterModelInterface from .methods.adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase from .model_mixin import ( EmbeddingAdaptersMixin, diff --git a/src/adapters/interface.py b/src/adapters/interface.py new file mode 100644 index 0000000000..2f0026296a --- /dev/null +++ b/src/adapters/interface.py @@ -0,0 +1,121 @@ +import json +import os +from dataclasses import asdict, dataclass +from typing import List, Optional + +from transformers.utils import cached_file + +from . import __version__ +from .utils import INTERFACE_CONFIG_NAME + + +class AdapterMethod: + """ + Enum of all supported adapter method types. + + Attributes: + bottleneck: Adapter methods using bottleneck layers. + prefix_tuning: Adapters methods based on Prefix Tuning. Note that this is currently unsupported via AdapterModelInterface. + lora: Adapter methods based on low-rank adaptation. + prompt_tuning: Adapter methods based on Prompt Tuning. + reft: Adapters methods based on Representation Fine-Tuning. + invertible: Adapter methods using invertible modules. + """ + + bottleneck = "bottleneck" + prefix_tuning = "prefix_tuning" + lora = "lora" + prompt_tuning = "prompt_tuning" + reft = "reft" + invertible = "invertible" + + @staticmethod + def get_from_config(config) -> List[str]: + """ + Get the adapter type from a given adapter config. + + Args: + config: The adapter config. + + Returns: + List[str]: The adapter type. + """ + methods = [] + if getattr(config, "inv_adapter", False): + methods.append(AdapterMethod.invertible) + if config.architecture is None: + methods.append(AdapterMethod.bottleneck) + elif config.architecture == "union": + for sub_config in config.configs: + methods.extend(AdapterMethod.get_from_config(sub_config)) + else: + methods.append(config.architecture) + return methods + + +@dataclass +class AdapterModelInterface: + """ + Defines the main interface for integrating adapter methods into a model class. + This interface translates generic accessor names to model-specific attribute names. + + Args: + adapter_methods (List[str]): List of adapter types that are supported by the model. + model_embeddings (str): Name of the model's embedding layer. + model_layers (str): Name of the model's layer list. + layer_self_attn (str): Name of the self-attention layer in a transformer layer. + layer_cross_attn (str): Name of the cross-attention layer in a transformer layer. + attn_k_proj (str): Name of the key projection layer in an attention layer. + attn_q_proj (str): Name of the query projection layer in an attention layer. + attn_v_proj (str): Name of the value projection layer in an attention layer. + attn_o_proj (str): Name of the output projection layer in an attention layer. 
+ layer_intermediate_proj (str): Name of the intermediate projection layer in a transformer layer. + layer_output_proj (str): Name of the output projection layer in a transformer layer. + layer_pre_self_attn (Optional[str]): Hook point directly before the self attention layer. Used for extended bottleneck adapter support. + layer_pre_cross_attn (Optional[str]): Hook point directly before the cross attention layer. Used for extended bottleneck adapter support. + layer_pre_ffn (Optional[str]): Hook point directly before the feed forward layer. Used for extended bottleneck adapter support. + layer_ln_1 (Optional[str]): Layer norm *after* the self-attention layer. Used for extended bottleneck adapter support. + layer_ln_2 (Optional[str]): Layer norm *after* the feed forward layer. Used for extended bottleneck adapter support. + """ + + adapter_methods: List[str] + + model_embeddings: str + model_layers: str + + layer_self_attn: str + layer_cross_attn: str + attn_k_proj: str + attn_q_proj: str + attn_v_proj: str + attn_o_proj: str + + layer_intermediate_proj: str + layer_output_proj: str + + # Optional attributes for extended bottleneck adapter support + layer_pre_self_attn: Optional[str] = None + layer_pre_cross_attn: Optional[str] = None + layer_pre_ffn: Optional[str] = None + layer_ln_1: Optional[str] = None + layer_ln_2: Optional[str] = None + + def to_dict(self): + return asdict(self) + + def _save(self, save_directory, model_config): + config_dict = { + "model_type": model_config.model_type, + "interface": self.to_dict(), + "version": "adapters." + __version__, + } + save_path = os.path.join(save_directory, INTERFACE_CONFIG_NAME) + with open(save_path, "w") as f: + json.dump(config_dict, f, indent=2, sort_keys=True) + + @classmethod + def _load(cls, path_or_repo_id: str, **kwargs): + resolved_file = cached_file(path_or_repo_id, INTERFACE_CONFIG_NAME, **kwargs) + with open(resolved_file, "r") as f: + config_dict = json.load(f) + return AdapterModelInterface(**config_dict["interface"]) diff --git a/src/adapters/methods/__init__.py b/src/adapters/methods/__init__.py index e69de29bb2..5082ff0b6b 100644 --- a/src/adapters/methods/__init__.py +++ b/src/adapters/methods/__init__.py @@ -0,0 +1,14 @@ +from .bottleneck import init_bottleneck +from .invertible import init_invertible_adapters +from .lora import init_lora +from .prompt_tuning import init_prompt_tuning +from .reft import init_reft + + +METHOD_INIT_MAPPING = { + "bottleneck": init_bottleneck, + "lora": init_lora, + "prompt_tuning": init_prompt_tuning, + "reft": init_reft, + "invertible": init_invertible_adapters, +} diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index 889941d2b9..74461635ee 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -1,3 +1,4 @@ +from functools import partial from typing import List, Mapping, NamedTuple, Optional, Union import torch @@ -15,10 +16,16 @@ ) from ..configuration import BnConfig from ..context import ForwardContext +from ..utils import multigetattr from .adapter_layer_base import ComposableAdapterLayerBase from .modeling import Adapter, BertFusion, ParallelAdapter +LAYER_HOOK_UNSUPPORTED = [ + ("original_ln_after", False), +] + + class BottleneckState(NamedTuple): """ Models the input and output states of a bottleneck adapter layer. 
@@ -45,9 +52,10 @@ class BottleneckLayer(ComposableAdapterLayerBase, nn.Module): adapter_modules_name = "adapters" supported_compositions = [Stack, Fuse, Split, Parallel, BatchSplit, Average] - def __init__(self, location_key: str): + def __init__(self, location_key: str, is_layer_hooked: bool = False): super().__init__() self.location_key = location_key + self.is_layer_hooked = is_layer_hooked def init_adapters(self, model_config, adapters_config): self._init_mapping() @@ -55,6 +63,8 @@ def init_adapters(self, model_config, adapters_config): self.adapters_config = adapters_config self.adapters = nn.ModuleDict(dict()) self.adapter_fusion_layer = nn.ModuleDict(dict()) + if not hasattr(self, "is_layer_hooked"): + self.is_layer_hooked = False def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: self.layer_idx = layer_idx @@ -78,6 +88,15 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: '{"1": 16, "default": 16}' ) + # check unsupported configurations for layer hooking mode + if self.is_layer_hooked: + for key, value in LAYER_HOOK_UNSUPPORTED: + if adapter_config.get(key, None) == value: + raise ValueError( + f"Unsupported configuration for bottleneck layer hooking mode: {key}={value}. " + "Please set this configuration to a supported value." + ) + if adapter_config.is_parallel: adapter_class = ParallelAdapter else: @@ -88,6 +107,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: down_sample=int(self.model_config.hidden_size // reduction_factor), config=adapter_config, ) + # for adapters hooked via interface: + # residual & LN are applied by model, so don't apply in adapters + if self.is_layer_hooked: + adapter.original_ln_after = False adapter.train(self.training) # make sure training mode is consistent self.adapters[adapter_name] = adapter return True @@ -321,9 +344,10 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm): torch.Tensor: Output hidden states of the adapter layer. """ # Batch sizes might be different due to prefix tuning w. Parallel block - (residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input) - # Replicate in both directions as residual might be larger (e.g. GPT-J) - (hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states) + if residual_input is not None: + (residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input) + # Replicate in both directions as residual might be larger (e.g. GPT-J) + (hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states) adapter_setup = self.get_active_setup() if adapter_setup is not None: input_hidden_states = hidden_states @@ -335,9 +359,9 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm): last_adapter = self.adapters[last] hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm) - elif layer_norm: + elif layer_norm is not None and not self.is_layer_hooked: hidden_states = layer_norm(hidden_states + residual_input) - else: + elif residual_input is not None and not self.is_layer_hooked: hidden_states = hidden_states + residual_input return hidden_states @@ -354,3 +378,55 @@ def forward(self, hidden_states, residual_input, layer_norm): torch.Tensor: Output hidden states of the adapter layer. 
""" return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm) + + +def hook_fn(adapter_layer, ln_get_fn, module, args, output): + # Retrieve residual from previous hook, if existing + context = ForwardContext.get_context() + residual_input = getattr(context, f"{adapter_layer.location_key}_residual_input", None) + # Retrieve layer norm from getter fn + if ln_get_fn is not None: + layer_norm = ln_get_fn() + else: + layer_norm = None + # Call adapter layer + if isinstance(output, torch.Tensor): + return adapter_layer(output, residual_input, layer_norm) + else: + return (adapter_layer(output[0], residual_input, layer_norm),) + output[1:] + + +def _residual_hook_fn(location_key, module, args): + context = ForwardContext.get_context() + if context is not None: + setattr(context, f"{location_key}_residual_input", args[0]) + + +def init_bottleneck(model): + model = model.base_model + for _, layer in model.iter_layers(): + if self_attn := multigetattr(layer, model.adapter_interface.layer_self_attn, None): + if o_proj := multigetattr(self_attn, model.adapter_interface.attn_o_proj, None): + if not hasattr(layer, "attention_adapters"): + layer.attention_adapters = BottleneckLayer("mh_adapter", is_layer_hooked=True) + ln_1_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_1, None) + o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters, ln_1_get_fn)) + if layer_output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj, None): + if not hasattr(layer, "output_adapters"): + layer.output_adapters = BottleneckLayer("output_adapter", is_layer_hooked=True) + ln_2_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_2, None) + layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters, ln_2_get_fn)) + if cross_attn := multigetattr(layer, model.adapter_interface.layer_cross_attn, None): + if not hasattr(cross_attn, "cross_attention_adapters"): + layer.attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True) + cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters, None)) + + if model.adapter_interface.layer_pre_self_attn is not None: + if pre_self_attn := multigetattr(layer, model.adapter_interface.layer_pre_self_attn, None): + pre_self_attn.register_forward_pre_hook(partial(_residual_hook_fn, "mh_adapter")) + if model.adapter_interface.layer_pre_cross_attn is not None: + if pre_cross_attn := multigetattr(layer, model.adapter_interface.layer_pre_cross_attn, None): + pre_cross_attn.register_forward_pre_hook(partial(_residual_hook_fn, "cross_adapter")) + if model.adapter_interface.layer_pre_ffn is not None: + if pre_ffn := multigetattr(layer, model.adapter_interface.layer_pre_ffn, None): + pre_ffn.register_forward_pre_hook(partial(_residual_hook_fn, "output_adapter")) diff --git a/src/adapters/methods/invertible.py b/src/adapters/methods/invertible.py new file mode 100644 index 0000000000..4a8158599c --- /dev/null +++ b/src/adapters/methods/invertible.py @@ -0,0 +1,104 @@ +import types +from functools import partial + +import torch +import torch.nn as nn + +from ..configuration.adapter_config import BnConfig +from ..utils import multigetattr +from .adapter_layer_base import AdapterLayerBase +from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock + + +class InvertibleAdapterLayer(AdapterLayerBase, nn.ModuleDict): + adapter_modules_name = "_modules" + + def __init__(self, model_config, adapters_config): + super().__init__() + self.location_key = 
"inv_adapter" + self.model_config = model_config + self.adapters_config = adapters_config + + def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: + self.layer_idx = layer_idx + embedding_size = getattr(self.model_config, "embedding_size", self.model_config.hidden_size) + adapter_config = self.adapters_config.match( + adapter_name, + config_type=BnConfig, + location_key="inv_adapter", + ) + if adapter_config is not None and adapter_config["inv_adapter"]: + if adapter_config["inv_adapter"] == "nice": + inv_adap = NICECouplingBlock( + [[embedding_size]], + non_linearity=adapter_config["non_linearity"], + reduction_factor=adapter_config["inv_adapter_reduction_factor"], + ) + elif adapter_config["inv_adapter"] == "glow": + inv_adap = GLOWCouplingBlock( + [[embedding_size]], + non_linearity=adapter_config["non_linearity"], + reduction_factor=adapter_config["inv_adapter_reduction_factor"], + ) + else: + raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.") + self[adapter_name] = inv_adap + self[adapter_name].apply(Adapter.init_bert_weights) + return True + + return False + + def get_invertible_adapter(self): + # HACK: returns the first adapter of the currently active setup. for backwards compatibility + adapter_setup = self.get_active_setup() + if adapter_setup is not None and len(adapter_setup) > 0: + first_adapter = adapter_setup.first() + if first_adapter in self: + return self[first_adapter] + return None + + def forward(self, hidden_states: torch.Tensor, rev=False): + adapter_setup = self.get_active_setup() + if adapter_setup is not None and len(adapter_setup) > 0: + first_adapter = adapter_setup.first() + if first_adapter in self: + hidden_states = self[first_adapter](hidden_states, rev=rev) + return hidden_states + + +def hook_fn(model, module, args, embedding_output): + embedding_output = model.invertible_adapters(embedding_output) + return embedding_output + + +def inv_hook_fn(model, module, args): + inv_output = model.invertible_adapters(args[0], rev=True) + return (inv_output,) + args[1:] + + +def init_invertible_adapters(model): + base_model = model.base_model + if not hasattr(base_model, "invertible_adapters"): + base_model.invertible_adapters = InvertibleAdapterLayer(base_model.config, base_model.adapters_config) + + embed_layer = multigetattr(base_model, base_model.adapter_interface.model_embeddings) + embed_layer.register_forward_hook(partial(hook_fn, base_model)) + + # Add methods from original invertible adapter mixin. + # This is primarily for backwards compatibility and internal use. 
+ base_model.add_invertible_adapter = types.MethodType( + lambda self, *args, **kwargs: self.invertible_adapters.add_adapter(*args, **kwargs), base_model + ) + base_model.delete_invertible_adapter = types.MethodType( + lambda self, *args, **kwargs: self.invertible_adapters.delete_adapter(*args, **kwargs), base_model + ) + base_model.get_invertible_adapter = types.MethodType( + lambda self: self.invertible_adapters.get_invertible_adapter(), base_model + ) + base_model.invertible_adapters_forward = types.MethodType( + lambda self, *args, **kwargs: self.invertible_adapters(*args, **kwargs), base_model + ) + + # Register reverse forward pass + if output_embedding := model.get_output_embeddings(): + output_embedding.register_forward_pre_hook(partial(inv_hook_fn, base_model)) diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py index 1dc8ef9f23..ab4cb7aaa1 100644 --- a/src/adapters/methods/lora.py +++ b/src/adapters/methods/lora.py @@ -16,6 +16,7 @@ from ..composition import Average, BatchSplit, Parallel, Stack from ..configuration import LoRAConfig, ModelAdaptersConfig +from ..utils import multigetattr, multisetattr from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase from .utils import dequantize_bnb_weight, fix_seed @@ -336,6 +337,14 @@ def _average_adapter_lora_delta_w_svd(self, input_adapters: Dict[str, float], av avg_state_dict["lora_A"] = V avg_state_dict["lora_B"] = U @ torch.diag(S_diag) + def _copy_hooks_from(self, module: nn.Module): + for ( + k, + v, + ) in module.__dict__.items(): + if "_hooks" in k: + setattr(self, k, v) + class LoRAState(NamedTuple): """Models the input and output states of a LoRA layer. @@ -432,6 +441,7 @@ def wrap( **kwargs, ) new_module.copy_from(module) + new_module._copy_hooks_from(module) return new_module @@ -676,6 +686,7 @@ def wrap( new_module.weight = module.weight if module.bias is not None: new_module.bias = module.bias + new_module._copy_hooks_from(module) return new_module @@ -814,3 +825,25 @@ def T(w): raise ValueError(f"Invalid adapter setup. 
Cannot use {adapter_setup} with LoRA.") return F.linear(x, T(self.weight), bias=self.bias) + + +def init_lora(model): + model = model.base_model + for _, _, attention in model.iter_attentions(): + if q_proj := multigetattr(attention, model.adapter_interface.attn_q_proj, None): + lora_proj = LoRALinear.wrap(q_proj, "selfattn", model.config, model.adapters_config, attn_key="q") + multisetattr(attention, model.adapter_interface.attn_q_proj, lora_proj) + if k_proj := multigetattr(attention, model.adapter_interface.attn_k_proj, None): + lora_proj = LoRALinear.wrap(k_proj, "selfattn", model.config, model.adapters_config, attn_key="k") + multisetattr(attention, model.adapter_interface.attn_k_proj, lora_proj) + if v_proj := multigetattr(attention, model.adapter_interface.attn_v_proj, None): + lora_proj = LoRALinear.wrap(v_proj, "selfattn", model.config, model.adapters_config, attn_key="v") + multisetattr(attention, model.adapter_interface.attn_v_proj, lora_proj) + + for _, layer in model.iter_layers(): + if intermediate_proj := multigetattr(layer, model.adapter_interface.layer_intermediate_proj): + lora_proj = LoRALinear.wrap(intermediate_proj, "intermediate", model.config, model.adapters_config) + multisetattr(layer, model.adapter_interface.layer_intermediate_proj, lora_proj) + if output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj): + lora_proj = LoRALinear.wrap(output_proj, "output", model.config, model.adapters_config) + multisetattr(layer, model.adapter_interface.layer_output_proj, lora_proj) diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py index b9504ac40b..b91b7e76b9 100644 --- a/src/adapters/methods/prompt_tuning.py +++ b/src/adapters/methods/prompt_tuning.py @@ -1,6 +1,7 @@ # https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/train/prompts.py import math +from functools import partial from typing import Callable import numpy as np @@ -12,6 +13,7 @@ from ..configuration import ModelAdaptersConfig, PromptTuningConfig from ..context import ForwardContext +from ..utils import multigetattr, prefix_attention_mask from .adapter_layer_base import AdapterLayerBase from .utils import fix_seed @@ -179,3 +181,27 @@ def forward(self, hidden_states: torch.Tensor): context.prompt_tokens_length = prefix_attention_mask_length return hidden_states + + +def hook_fn(model, module, args, embedding_output): + embedding_output = model.prompt_tuning.forward(embedding_output) + return embedding_output + + +# TODO: this will only work for a limited set of models +def _attn_mask_hook_fn(module, args): + attn_mask = args[1] + attn_mask = prefix_attention_mask(attn_mask) + return (args[0], attn_mask) + args[2:] + + +def init_prompt_tuning(model): + model = model.base_model + if not hasattr(model, "prompt_tuning"): + model.support_prompt_tuning = True + model.prompt_tuning = PromptTuningLayer(model.config, model.adapters_config, model.get_input_embeddings()) + embed_layer = multigetattr(model, model.adapter_interface.model_embeddings) + embed_layer.register_forward_hook(partial(hook_fn, model)) + + for _, layer in model.iter_layers(): + layer.register_forward_pre_hook(_attn_mask_hook_fn) diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py index 1884bf5e01..d4904d7fd2 100644 --- a/src/adapters/methods/reft.py +++ b/src/adapters/methods/reft.py @@ -229,6 +229,7 @@ def hook_fn(module, args, output): def init_reft(model): + model = model.base_model for _, layer in model.iter_layers(): if not hasattr(layer, 
"reft_layer"): layer.reft_layer = ReftLayer("output", model.config, model.adapters_config) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 8303191a52..63c251874f 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -9,7 +9,7 @@ from copy import deepcopy from functools import partial from os.path import join -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union import torch from torch import nn @@ -25,7 +25,9 @@ from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig from .context import AdapterSetup, ForwardContext from .hub_mixin import PushAdapterToHubMixin +from .interface import AdapterMethod, AdapterModelInterface from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader +from .methods import METHOD_INIT_MAPPING from .methods.adapter_layer_base import AdapterLayerBase from .methods.bottleneck import BottleneckLayer from .methods.lora import LoRALayer @@ -39,6 +41,8 @@ TOKENIZER_PATH, get_adapter_config_hash, inherit_doc, + multigetattr, + multihasattr, patch_forward, resolve_adapter_path, ) @@ -432,10 +436,16 @@ def _init_adapters_submodules(self, model_config, adapters_config): if hasattr(module, "init_adapters"): module.init_adapters(model_config, adapters_config) - # Initialize reft modules - init_reft(self) + def _default_init_adapter_methods(self, model_config, adapters_config): + init_reft(self.base_model) + # Add prefix tuning + self.base_model.prefix_tuning = PrefixTuningPool(model_config, adapters_config) + # Add Prompt Tuning + if self.add_base_adapters: + if self.support_prompt_tuning: + self.prompt_tuning = PromptTuningLayer(model_config, adapters_config, self.get_input_embeddings()) - def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True): + def init_adapters(self, model_config, adapters_config): """ This method initializes adapter modules and fusion modules from the model config. 
""" @@ -443,19 +453,22 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr # Initialize adapters config init_adapters_config(self, model_config, adapters_config) + + # Initialize adapter types defined in interface + if getattr(self.base_model, "adapter_interface", None) is not None: + for adapter_type in self.base_model.adapter_interface.adapter_methods: + init_func = METHOD_INIT_MAPPING[adapter_type] + init_func(self) + else: + self._default_init_adapter_methods(self.config, self.adapters_config) + # Initialize adapters in all submodules self._init_adapters_submodules(self.config, self.adapters_config) # Link all prefix tunings - if add_prefix_tuning_pool: - self.base_model.prefix_tuning = PrefixTuningPool(self.config, self.adapters_config) + if hasattr(self.base_model, "prefix_tuning"): self.apply_to_adapter_layers(lambda i, layer: self._link_prefix_to_pool(layer)) - # Add Prompt Tuning - if self.add_base_adapters: - if self.support_prompt_tuning: - self.prompt_tuning = PromptTuningLayer(model_config, self.adapters_config, self.get_input_embeddings()) - # Initialize adapters from config for adapter_name in self.adapters_config: self._add_adapter_weights(adapter_name) @@ -468,6 +481,35 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr self._add_tied_weights_keys() + def supports_adapter(self, type_or_config: Union[str, AdapterConfig]) -> bool: + """ + Checks if the model supports a given adapter type. + + Args: + adapter_type (str): The adapter type to check. + + Returns: + bool: True if the adapter type is supported, False otherwise. + """ + if isinstance(type_or_config, AdapterConfig): + types = AdapterMethod.get_from_config(type_or_config) + else: + types = [type_or_config] + + supported = [] + for _type in types: + if getattr(self.base_model, "adapter_interface", None) is not None: + supported.append(_type in self.base_model.adapter_interface.adapter_methods) + elif _type == AdapterMethod.prompt_tuning: + supported.append(self.base_model.support_prompt_tuning) + elif _type == AdapterMethod.invertible: + supported.append( + isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin) + ) + else: + supported.append(True) + return all(supported) + # These methods have to be implemented by every deriving class: @abstractmethod @@ -593,6 +635,10 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False the adapter is added but not activated. 
""" config = AdapterConfig.load(config) # ensure config is ok and up-to-date + # check if config is valid for this model + config_or_type = config or AdapterMethod.bottleneck + if not self.supports_adapter(config_or_type): + raise ValueError(f"Adapter config or type '{config_or_type}' is not supported by this model.") # In case adapter already exists and we allow overwriting, explicitly delete the existing one first if overwrite_ok and adapter_name in self.adapters_config: self.delete_adapter(adapter_name) @@ -1200,12 +1246,10 @@ def get_adapter(self, name) -> dict: # global weights are saved at index -1 if name in self.base_model.shared_parameters: destination[-1]["shared"] = self.base_model.shared_parameters[name] - if ( - isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin) - ) and name in self.invertible_adapters: + if self.supports_adapter("invertible") and name in self.invertible_adapters: destination[-1]["invertible"] = self.invertible_adapters[name] - if self.support_prompt_tuning: + if self.supports_adapter("prompt_tuning"): prompt_tuning = self.prompt_tuning.get_adapter(name) if prompt_tuning is not None: destination[-1]["prompt"] = prompt_tuning @@ -1614,6 +1658,8 @@ def save_pretrained( lambda i, layer: layer.set_pool(None) if isinstance(layer, PrefixTuningLayer) else None ) + if interface := getattr(self.base_model, "adapter_interface", None): + interface._save(save_directory, self.config) super().save_pretrained(save_directory, **kwargs) # Remove adapters config del self.config.adapters @@ -1686,13 +1732,37 @@ def gradient_checkpointing_function(function, *args, **kwargs): @inherit_doc class ModelBaseAdaptersMixin(ModelAdaptersMixin): + adapter_interface: AdapterModelInterface = None add_base_adapters = True - def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True): - super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool) + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) patch_forward(self) + # Adapter Interface Methods + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(multigetattr(self, self.adapter_interface.model_layers)): + yield i, layer + + def get_layer(self, idx: int) -> nn.Module: + return multigetattr(self, self.adapter_interface.model_layers)[idx] + + def iter_attentions(self) -> Iterable[Tuple[int, Literal["self", "cross"], nn.Module]]: + for i, layer in self.iter_layers(): + if multihasattr(layer, self.adapter_interface.layer_self_attn or ""): + yield i, "self", multigetattr(layer, self.adapter_interface.layer_self_attn) + if multihasattr(layer, self.adapter_interface.layer_cross_attn or ""): + yield i, "cross", multigetattr(layer, self.adapter_interface.layer_cross_attn) + + def iter_layer_ffns(self) -> Iterable[Tuple[int, Literal["intermediate", "output"], nn.Module]]: + for i, layer in self.iter_layers(): + if intermediate_proj := multigetattr(layer, self.adapter_interface.layer_intermediate_proj): + yield i, "intermediate", intermediate_proj + if output_proj := multigetattr(layer, self.adapter_interface.layer_output_proj): + yield i, "output", output_proj + def post_embedding_forward(self, module, args, embedding_output): if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin): embedding_output = self.invertible_adapters_forward(embedding_output) @@ -1724,6 +1794,12 @@ def _init_adapters_submodules(self, model_config, 
adapters_config): """ pass + def _default_init_adapter_methods(self, model_config, adapters_config): + """ + Init default adapter methods in base model. This is done in sub-models, so don't do anything here. + """ + pass + @inherit_doc class ModelWithHeadsAdaptersMixin(ModelAdaptersMixin): @@ -1734,8 +1810,8 @@ class ModelWithHeadsAdaptersMixin(ModelAdaptersMixin): def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) - def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True): - super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool=add_prefix_tuning_pool) + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) self._convert_to_flex_head = False def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py index 36f16b240a..55e823328e 100644 --- a/src/adapters/models/clip/mixin_clip.py +++ b/src/adapters/models/clip/mixin_clip.py @@ -5,7 +5,7 @@ from ...composition import adjust_tensors_for_parallel_ from ...methods.bottleneck import BottleneckLayer from ...methods.lora import LoRALinear -from ...methods.prefix_tuning import PrefixTuningLayer +from ...methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool from ...methods.reft import ReftLayer, hook_fn from ...model_mixin import ( EmbeddingAdaptersMixin, @@ -119,6 +119,7 @@ def _init_adapters_submodules(self, model_config, adapters_config): if hasattr(module, "init_adapters"): module.init_adapters(model_config.vision_config, adapters_config) + def _default_init_adapter_methods(self, model_config, adapters_config): # Patch for ReFT initialization for layer in self.text_model.encoder.layers: if not hasattr(layer, "reft_layer"): @@ -128,3 +129,6 @@ def _init_adapters_submodules(self, model_config, adapters_config): if not hasattr(layer, "reft_layer"): layer.reft_layer = ReftLayer("output", model_config.vision_config, adapters_config) layer.register_forward_hook(hook_fn) + + # Add prefix tuning + self.base_model.prefix_tuning = PrefixTuningPool(model_config, adapters_config) diff --git a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py index 50257d1536..1103eae310 100644 --- a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py +++ b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py @@ -26,7 +26,7 @@ def init_adapters(self, model_config, adapters_config): # Before initializing adapters, forward adding invertible adapters to the encoder self.add_invertible_adapter = self.encoder.base_model.add_invertible_adapter - super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool=False) + super().init_adapters(model_config, adapters_config) # ensure that encoder and decoder use the same shared parameters if hasattr(self.encoder, "set_shared_parameters"): diff --git a/src/adapters/utils.py b/src/adapters/utils.py index 7c0540850a..f55a042673 100644 --- a/src/adapters/utils.py +++ b/src/adapters/utils.py @@ -54,6 +54,7 @@ EMBEDDING_FILE = "embedding.pt" TOKENIZER_PATH = "tokenizer" SETUP_CONFIG_NAME = "adapter_setup.json" +INTERFACE_CONFIG_NAME = "adapter_interface.json" ADAPTER_HUB_URL = "https://raw.githubusercontent.com/Adapter-Hub/Hub/master/dist/v2/" ADAPTER_HUB_INDEX_FILE = ADAPTER_HUB_URL + "index/{}.json" @@ -173,6 +174,39 @@ def inherit_doc(cls): return cls +def multigetattr(o: object, name: str, 
default=None) -> Optional[object]: + if not name: + return default + for n in name.split("."): + if hasattr(o, n): + o = getattr(o, n) + else: + return default + return o + + +def multihasattr(o: object, name: str) -> bool: + if not name: + return False + parts = name.split(".") + for n in parts: + if hasattr(o, n): + o = getattr(o, n) + else: + return False + return True + + +def multisetattr(o: object, name: str, value: object): + parts = name.split(".") + for n in parts[:-1]: + if hasattr(o, n): + o = getattr(o, n) + else: + return + setattr(o, parts[-1], value) + + def urljoin(*args): return "/".join([s.strip("/") for s in args]) diff --git a/src/adapters/wrappers/configuration.py b/src/adapters/wrappers/configuration.py index 40dc421787..709bb54009 100644 --- a/src/adapters/wrappers/configuration.py +++ b/src/adapters/wrappers/configuration.py @@ -94,6 +94,8 @@ def init_adapters_config( model.adapters_config = ModelAdaptersConfig() elif model_config.adapters is not None and not isinstance(model_config.adapters, ModelAdaptersConfig): model.adapters_config = ModelAdaptersConfig(**model_config.adapters) + if hasattr(model, "base_model") and model.base_model is not model: + model.base_model.adapters_config = model.adapters_config # Convert AdapterFusions from old format for backwards compatibility fusion_models = getattr(model_config, "adapter_fusion_models", []) diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py index 1f54e29ca1..70c4f285a9 100644 --- a/src/adapters/wrappers/model.py +++ b/src/adapters/wrappers/model.py @@ -10,9 +10,12 @@ from transformers.models.auto.configuration_auto import model_type_to_module_name from ..configuration import ModelAdaptersConfig +from ..interface import AdapterModelInterface from ..model_mixin import ( + EmbeddingAdaptersMixin, EmbeddingAdaptersWrapperMixin, ModelAdaptersMixin, + ModelBaseAdaptersMixin, ModelUsingSubmodelsAdaptersMixin, ModelWithHeadsAdaptersMixin, ) @@ -40,7 +43,9 @@ def replace_with_adapter_class(module: nn.Module, modules_with_adapters) -> None module.__class__.__name__, (MODEL_MIXIN_MAPPING[module.__class__.__name__], module.__class__), {} ) module.__class__ = model_class - elif module.__class__.__module__.startswith("transformers.models"): + elif module.__class__.__module__.startswith("transformers.models") or module.__class__.__module__.startswith( + "adapters.wrappers.model" + ): try: module_class = getattribute_from_module(modules_with_adapters, module.__class__.__name__ + "WithAdapters") module.__class__ = module_class @@ -49,30 +54,51 @@ def replace_with_adapter_class(module: nn.Module, modules_with_adapters) -> None pass -def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] = None) -> None: +def init( + model: PreTrainedModel, + adapters_config: Optional[ModelAdaptersConfig] = None, + interface: Optional[AdapterModelInterface] = None, +) -> None: if isinstance(model, ModelAdaptersMixin): return model - # First, replace original module classes with their adapters counterparts - model_name = get_module_name(model.config.model_type) - modules_with_adapters = importlib.import_module(f".{model_name}.modeling_{model_name}", "adapters.models") - submodules = list(model.modules()) - - # Replace the base model class - replace_with_adapter_class(submodules.pop(0), modules_with_adapters) - - # Check if the base model class derives from ModelUsingSubmodelsAdaptersMixin - if isinstance(model, ModelUsingSubmodelsAdaptersMixin): - # Before initializing the submodels, make sure that 
adapters_config is set for the whole model. - # Otherwise, it would not be shared between the submodels. - init_adapters_config(model, model.config, adapters_config) - adapters_config = model.adapters_config - model.init_submodels() - submodules = [] - - # Change the class of all child modules to their adapters class - for module in submodules: - replace_with_adapter_class(module, modules_with_adapters) + if interface is not None: + base_model = model.base_model + model_class_name = base_model.__class__.__name__ + model_class = type( + model_class_name, + (EmbeddingAdaptersMixin, ModelBaseAdaptersMixin, base_model.__class__), + {}, + ) + base_model.__class__ = model_class + base_model.adapter_interface = interface + base_model.support_prompt_tuning = False # HACK: will be set to true if init_prompt_tuning() is called + else: + # First, replace original module classes with their adapters counterparts + model_name = get_module_name(model.config.model_type) + try: + modules_with_adapters = importlib.import_module(f".{model_name}.modeling_{model_name}", "adapters.models") + except ImportError: + raise ValueError( + f"Model {model_name} not pre-supported by adapters. Please specify and pass `interface` explicitly." + ) + submodules = list(model.modules()) + + # Replace the base model class + replace_with_adapter_class(submodules.pop(0), modules_with_adapters) + + # Check if the base model class derives from ModelUsingSubmodelsAdaptersMixin + if isinstance(model, ModelUsingSubmodelsAdaptersMixin): + # Before initializing the submodels, make sure that adapters_config is set for the whole model. + # Otherwise, it would not be shared between the submodels. + init_adapters_config(model, model.config, adapters_config) + adapters_config = model.adapters_config + model.init_submodels() + submodules = [] + + # Change the class of all child modules to their adapters class + for module in submodules: + replace_with_adapter_class(module, modules_with_adapters) # Next, check if model class itself is not replaced and has an adapter-supporting base class if not isinstance(model, ModelAdaptersMixin): @@ -98,6 +124,7 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] def load_model( model_name_or_path: Optional[Union[str, os.PathLike]], model_class: Type[PreTrainedModel], + interface: Optional[AdapterModelInterface] = None, *model_args: Any, **kwargs: Any, ) -> PreTrainedModel: @@ -109,6 +136,9 @@ def load_model( Parameter identical to PreTrainedModel.from_pretrained model_class (`PreTrainedModel` or `AutoModel`): The model class to load (e.g. EncoderDecoderModel and EncoderDecoderAdapterModel both work) + interface (`AdapterModelInterface`, *optional*): + The custom adapter interface to use for the model, to be passed to the init() method. + If not provided, init() will try to use one of the built-in model integrations. model_args (sequence of positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. 
kwargs (remaining dictionary of keyword arguments, *optional*): @@ -120,15 +150,20 @@ def load_model( old_init = model_class.__init__ + # try if we can find a interface file + if interface is None: + try: + interface = AdapterModelInterface._load(model_name_or_path, **kwargs) + except EnvironmentError: + pass + def new_init(self, config, *args, **kwargs): old_init(self, config, *args, **kwargs) - init(self) + init(self, interface=interface) # wrap model after it is initialized but before the weights are loaded - model_class.__init__ = new_init - model = model_class.from_pretrained(model_name_or_path, *model_args, **kwargs) - - # restore original __init__ function for when other models of the same type are created - model_class.__init__ = old_init + new_model_class = type(model_class.__name__, (model_class,), {}) + new_model_class.__init__ = new_init + model = new_model_class.from_pretrained(model_name_or_path, *model_args, **kwargs) return model diff --git a/tests/test_methods/base.py b/tests/test_methods/base.py index f5e53fedd6..ef44167143 100644 --- a/tests/test_methods/base.py +++ b/tests/test_methods/base.py @@ -4,7 +4,7 @@ import torch import adapters -from adapters import AutoAdapterModel +from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel from transformers import AutoFeatureExtractor, AutoTokenizer, GlueDataset, GlueDataTrainingArguments from transformers.testing_utils import torch_device @@ -79,6 +79,18 @@ def build_generate_input(self, shape): """The generate() functions for inference require different inputs depeding on the model type. E.g. the text models require input_ids, where as the audio models require input_features""" return self.build_rand_ids_tensor(self.input_shape if not shape else shape).to(torch_device) + def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config=None): + if self.config_class not in ADAPTER_MODEL_MAPPING: + self.skipTest("Does not support flex heads.") + model = AutoAdapterModel.from_config(self.config()) + + # add two adapters: one will be trained and the other should be frozen + model.add_adapter(trained_adapter_name, config=adapter_config) + model.add_adapter(frozen_adapter_name, config=adapter_config) + self.add_head(model, trained_adapter_name) + + return model + class TextAdapterTestBase(AbstractAdapterTestBase): """Base class for adapter tests for text models. 
Text models test classes should inherit from this class and override the attributes and functions as needed.""" diff --git a/tests/test_methods/method_test_impl/base.py b/tests/test_methods/method_test_impl/base.py index 7f7e0ba83b..b970f2a3ee 100644 --- a/tests/test_methods/method_test_impl/base.py +++ b/tests/test_methods/method_test_impl/base.py @@ -6,7 +6,7 @@ import torch import adapters -from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer, AutoAdapterModel +from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer from adapters.heads import CausalLMHead from adapters.utils import WEIGHTS_NAME from adapters.wrappers import load_model @@ -188,8 +188,11 @@ def run_forward_test(self, model, adapter_config, dtype=torch.float32): model.set_active_adapters(None) model.delete_adapter(name) + def create_twin_models(self): + return create_twin_models(self.model_class, self.config) + def run_load_test(self, adapter_config): - model1, model2 = create_twin_models(self.model_class, self.config) + model1, model2 = self.create_twin_models() name = "dummy_adapter" model1.add_adapter(name, config=adapter_config) @@ -233,8 +236,8 @@ def run_full_model_load_test(self, adapter_config): model2, loading_info = load_model(temp_dir, self.model_class, output_loading_info=True) # check if all weights were loaded - self.assertEqual(0, len(loading_info["missing_keys"])) - self.assertEqual(0, len(loading_info["unexpected_keys"])) + self.assertEqual(0, len(loading_info["missing_keys"]), loading_info["missing_keys"]) + self.assertEqual(0, len(loading_info["unexpected_keys"]), loading_info["unexpected_keys"]) # check if adapter was correctly loaded self.assertTrue(name in model2.adapters_config) @@ -275,14 +278,7 @@ def trainings_run(self, model, lr=1.0, steps=8, batch_size=2, gradient_accumulat def run_train_test(self, adapter_config, filter_keys): if not self.do_run_train_tests: self.skipTest("Skipping training tests. 
Set `do_run_train_tests=True` to run them.") - if self.config_class not in ADAPTER_MODEL_MAPPING: - self.skipTest("Does not support flex heads.") - model = AutoAdapterModel.from_config(self.config()) - - # add two adapters: one will be trained and the other should be frozen - model.add_adapter("mrpc", config=adapter_config) - model.add_adapter("dummy", config=adapter_config) - self.add_head(model, "mrpc") + model = self._init_model_for_train_run("mrpc", "dummy", adapter_config) self._assert_adapter_available(model, "mrpc") self._assert_adapter_available(model, "dummy") @@ -314,7 +310,8 @@ def run_train_test(self, adapter_config, filter_keys): def has_tied_embeddings(k): tied_embeddings = hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings is_tied_layer = ( - isinstance(model.heads["mrpc"], CausalLMHead) + hasattr(model, "heads") + and isinstance(model.heads["mrpc"], CausalLMHead) and "heads.{}.{}.weight".format("mrpc", len(model.heads["mrpc"]._modules) - 1) in k ) return tied_embeddings and is_tied_layer @@ -322,7 +319,7 @@ def has_tied_embeddings(k): for (k1, v1), (k2, v2) in zip(state_dict_pre.items(), model.state_dict().items()): # move both to the same device to avoid device mismatch errors v1, v2 = v1.to(v2.device), v2 - if "mrpc" in k1 and not has_tied_embeddings(k1): + if "mrpc" in k1 and not has_tied_embeddings(k1) or not k1.startswith(model.base_model_prefix): adapters_with_change |= not torch.equal(v1, v2) else: base_with_change |= not torch.equal(v1, v2) @@ -463,7 +460,10 @@ def run_same_weights_test(self, adapter_config, filter_keys): self.assertTrue(torch.equal(v1, v2), msg=f"{k1} has different weights than {k2}") # Check multiple models with one adapter with same config - model1, model2 = create_twin_models(self.model_class, self.config) + if hasattr(self, "adapter_interface") and self.adapter_interface: + model1, model2 = create_twin_models(self.model_class, self.config, self.adapter_interface) + else: + model1, model2 = create_twin_models(self.model_class, self.config) model1.add_adapter("adapter", config=adapter_config) model2.add_adapter("adapter", config=adapter_config) per_model_filter_keys = {"adapter": [k.format(name="adapter") for k in filter_keys]} diff --git a/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py b/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py index 196380524f..d74fc48619 100644 --- a/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py +++ b/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py @@ -9,8 +9,11 @@ @require_torch class CompabilityTestMixin: + def create_twin_models(self): + return create_twin_models(self.model_class, self.config) + def test_load_old_non_linearity(self): - model1, model2 = create_twin_models(self.model_class, self.config) + model1, model2 = self.create_twin_models() config = SeqBnConfig(non_linearity="gelu") name = "dummy" model1.add_adapter(name, config=config) diff --git a/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py b/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py index a41b862004..93832106ac 100644 --- a/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py +++ b/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py @@ -3,7 +3,6 @@ import torch -from adapters import AutoAdapterModel from transformers import AutoTokenizer, Trainer, TrainingArguments from transformers.testing_utils import 
require_torch, torch_device @@ -85,11 +84,10 @@ def test_training_embedding(self): tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - model = AutoAdapterModel.from_config(self.config()) + model = self._init_model_for_train_run("test", "dummy") + model.add_embeddings("test", tokenizer) self.assertEqual(model.active_embeddings, "test") - model.add_adapter("test") - self.add_head(model, "test") model.train_adapter("test", train_embeddings=True) for k, v in filter_parameters(model, "adapters.test.").items(): @@ -105,7 +103,7 @@ def test_training_embedding(self): training_args = TrainingArguments( output_dir="./examples", do_train=True, - learning_rate=0.4, + learning_rate=1.0, max_steps=15, use_cpu=True, per_device_train_batch_size=2, @@ -140,17 +138,19 @@ def test_training_embedding(self): and "embed_tokens" not in k1 and "shared" not in k1 and "wte" not in k1 + and "score" not in k1 ) ) def test_reference_embedding(self): - model = AutoAdapterModel.from_config(self.config()) # self.get_model() + model = self.get_model() tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token new_tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT") model.add_embeddings("test", new_tokenizer, "default", tokenizer) + model.to(torch_device) default_embedding = model.base_model.loaded_embeddings["default"] test_embedding = model.base_model.loaded_embeddings["test"] @@ -165,8 +165,8 @@ def test_reference_embedding(self): if len(input_test) >= 5: break - input_default = torch.tensor([input_default]) - input_test = torch.tensor([input_test]) + input_default = torch.tensor([input_default]).to(torch_device) + input_test = torch.tensor([input_test]).to(torch_device) default = default_embedding(input_default) test = test_embedding(input_test) diff --git a/tests/test_methods/method_test_impl/peft/test_adapter_common.py b/tests/test_methods/method_test_impl/peft/test_adapter_common.py index 6116804c3c..a69ab0d406 100644 --- a/tests/test_methods/method_test_impl/peft/test_adapter_common.py +++ b/tests/test_methods/method_test_impl/peft/test_adapter_common.py @@ -13,8 +13,6 @@ DoubleSeqBnConfig, DoubleSeqBnInvConfig, Fuse, - InvertibleAdaptersMixin, - InvertibleAdaptersWrapperMixin, MAMConfig, SeqBnConfig, SeqBnInvConfig, @@ -75,7 +73,7 @@ def test_delete_adapter(self): def test_add_adapter_with_invertible(self): model = self.get_model().base_model model.eval() - if not isinstance(model, InvertibleAdaptersMixin) and not isinstance(model, InvertibleAdaptersWrapperMixin): + if not model.supports_adapter("invertible"): self.skipTest("Model does not support invertible adapters.") for adapter_config in [SeqBnInvConfig(), DoubleSeqBnInvConfig()]: @@ -123,7 +121,7 @@ def test_delete_adapter_with_invertible(self): """Tests if the invertible adapters are deleted correctly.""" model = self.get_model().base_model model.eval() - if not isinstance(model, InvertibleAdaptersMixin) and not isinstance(model, InvertibleAdaptersWrapperMixin): + if not model.supports_adapter("invertible"): self.skipTest("Model does not support invertible adapters.") # iterate through all adapter invertible adapter configs @@ -227,6 +225,8 @@ def test_forward_bottleneck(self): def test_invertible_adapter_forward(self): model = self.get_model() model.eval() + if not model.supports_adapter("invertible"): + self.skipTest("Model does not support 

         for adapter_config, _ in self.inv_adapter_configs_to_test:
             with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
@@ -248,6 +248,8 @@ def test_model_config_serialization(self):
         """
         for k, v in ADAPTER_CONFIG_MAP.items():
             model = self.get_model()
+            if not model.supports_adapter(v):
+                continue
             # HACK: reduce the reduction factor such that
             # the small test model can have a phm_dim of 4
             if hasattr(v, "phm_layer") and v.phm_layer:
@@ -260,15 +262,19 @@ def test_model_adapter_summary(self):
         # count model parameters before
         model = self.get_model()
         model_no_params = sum(p.numel() for p in model.parameters())

+        added = []
         for k, v in ADAPTER_CONFIG_MAP.items():
+            if not model.supports_adapter(v):
+                continue
             # HACK: reduce the reduction factor such that
             # the small test model can have a phm_dim of 4
             if hasattr(v, "phm_layer") and v.phm_layer:
                 v = v.__class__(reduction_factor=4)
             model.add_adapter(k, config=v)
+            added.append(k)
         summary = model.adapter_summary(as_dict=True)
-        self.assertEqual(len(ADAPTER_CONFIG_MAP) + 1, len(summary))
-        for name in ADAPTER_CONFIG_MAP.keys():
+        self.assertEqual(len(added) + 1, len(summary))
+        for name in added:
             self.assertTrue(any([row["name"] == name for row in summary]))
         self.assertEqual(model_no_params, summary[-1]["#param"])
diff --git a/tests/test_methods/method_test_impl/peft/test_lora.py b/tests/test_methods/method_test_impl/peft/test_lora.py
index 70d8e0a447..1014f9a36a 100644
--- a/tests/test_methods/method_test_impl/peft/test_lora.py
+++ b/tests/test_methods/method_test_impl/peft/test_lora.py
@@ -2,7 +2,7 @@

 import torch

-from adapters import LoRAConfig
+from adapters import AdapterConfig, LoRAConfig
 from adapters.methods.lora import LoRALayer
 from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin
 from transformers.testing_utils import require_torch
@@ -23,21 +23,30 @@ def test_merging_with_other_adapters(self):
         model.add_adapter("lora", config="lora")

         # Add different adapters
-        model.add_adapter("bottleneck", config="seq_bn")
-        model.add_adapter("prompt", config="prompt_tuning")
-        model.add_adapter("prefix", config="prefix_tuning")
-        model.add_adapter("ia3", config="ia3")
-        model.add_adapter("unipelt", config="unipelt")
-        model.add_adapter("mam", config="mam")
-        model.add_adapter("compacter", config="compacter[phm_dim=2, reduction_factor=8]")
+        adapter_methods = [
+            "seq_bn",
+            "prompt_tuning",
+            "prefix_tuning",
+            "ia3",
+            "unipelt",
+            "mam",
+            "compacter[phm_dim=2, reduction_factor=8]",
+        ]
+
+        for adapter_method in adapter_methods:
+            config = AdapterConfig.load(adapter_method)
+            if model.supports_adapter(config):
+                model.add_adapter(adapter_method, config=config)

         # Merging adapters with different architectures with LoRA should raise a ValueError
-        for adapter_architecture in ["bottleneck", "prompt", "prefix", "ia3", "unipelt", "mam", "compacter"]:
-            with self.subTest(adapter_architecture=adapter_architecture):
+        for adapter_method in adapter_methods:
+            with self.subTest(adapter_architecture=adapter_method):
+                if adapter_method not in model.adapters_config:
+                    continue
                 with self.assertRaises(ValueError):
                     model.average_adapter(
-                        adapter_name=f"average_lora_{adapter_architecture}",
-                        adapter_list=[adapter_architecture, "lora"],
+                        adapter_name=f"average_lora_{adapter_method}",
+                        adapter_list=[adapter_method, "lora"],
                         weights=[0.5, 0.5],
                         combine_strategy="linear",
                     )
diff --git a/tests/test_methods/method_test_impl/utils.py b/tests/test_methods/method_test_impl/utils.py
index 473c422e60..445e9d2e63 100644
--- a/tests/test_methods/method_test_impl/utils.py
+++ b/tests/test_methods/method_test_impl/utils.py
@@ -10,7 +10,7 @@
 global_rng = random.Random()


-def create_twin_models(model_class, config_creator=None):
+def create_twin_models(model_class, config_creator=None, interface=None):
     if config_creator and model_class.__name__.startswith("Auto"):
         model_config = config_creator()
         model1 = model_class.from_config(model_config)
@@ -20,7 +20,7 @@ def create_twin_models(model_class, config_creator=None):
     else:
         model_config = model_class.config_class()
         model1 = model_class(model_config)
-    init(model1)
+    init(model1, interface=interface)
     model1.eval()
     # create a twin initialized with the same random weights
     model2 = copy.deepcopy(model1)
diff --git a/tests/test_methods/test_on_custom_interface.py b/tests/test_methods/test_on_custom_interface.py
new file mode 100644
index 0000000000..2952fef3cc
--- /dev/null
+++ b/tests/test_methods/test_on_custom_interface.py
@@ -0,0 +1,125 @@
+import unittest
+
+import pytest
+
+import adapters
+from adapters import AdapterModelInterface, ConfigUnion, DoubleSeqBnConfig, LoRAConfig, ParBnConfig
+from transformers import Gemma2ForCausalLM, Gemma2ForSequenceClassification
+from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
+from transformers.testing_utils import torch_device
+
+from .base import TextAdapterTestBase
+from .generator import generate_method_tests, require_torch
+from .method_test_impl.peft.test_adapter_common import BottleneckAdapterTestMixin
+from .method_test_impl.utils import create_twin_models, make_config
+
+
+class CustomInterfaceModelTestBase(TextAdapterTestBase):
+    model_class = Gemma2ForCausalLM
+    config_class = Gemma2Config
+    config = make_config(
+        Gemma2Config,
+        hidden_size=32,
+        num_hidden_layers=4,
+        num_attention_heads=4,
+        num_key_value_heads=4,
+        intermediate_size=16,
+        pad_token_id=0,
+    )
+    tokenizer_name = "yujiepan/gemma-2-tiny-random"
+    adapter_interface = AdapterModelInterface(
+        adapter_methods=["bottleneck", "lora", "reft", "invertible"],
+        model_embeddings="embed_tokens",
+        model_layers="layers",
+        layer_self_attn="self_attn",
+        layer_cross_attn=None,
+        attn_k_proj="k_proj",
+        attn_q_proj="q_proj",
+        attn_v_proj="v_proj",
+        attn_o_proj="o_proj",
+        layer_intermediate_proj="mlp.up_proj",
+        layer_output_proj="mlp.down_proj",
+        layer_pre_self_attn="input_layernorm",
+        layer_pre_cross_attn=None,
+        layer_pre_ffn="pre_feedforward_layernorm",
+        layer_ln_1="post_attention_layernorm",
+        layer_ln_2="post_feedforward_layernorm",
+    )
+
+    def get_model(self):
+        model = Gemma2ForCausalLM(self.config())
+        adapters.init(model, interface=self.adapter_interface)
+        model.to(torch_device)
+        return model
+
+    def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config=None):
+        model = Gemma2ForSequenceClassification(self.config())
+        adapters.init(model, interface=self.adapter_interface)
+
+        model.add_adapter(trained_adapter_name, config=adapter_config or LoRAConfig(init_weights="bert"))
+        model.add_adapter(frozen_adapter_name, config=adapter_config or LoRAConfig(init_weights="bert"))
+
+        return model
+
+    adapter_configs_to_test = [
+        (DoubleSeqBnConfig(), ["adapters.{name}."]),
+        (ParBnConfig(init_weights="bert"), ["adapters.{name}."]),
+    ]
+
+    def create_twin_models(self):
+        return create_twin_models(self.model_class, self.config, self.adapter_interface)
+
+    def test_load_mam_adapter(self):
+        self.skipTest("Does not support prefix tuning.")
+
+    def test_train_mam_adapter(self):
+        self.skipTest("Does not support prefix tuning.")
+
+    def test_merging_with_other_adapters(self):
+        self.skipTest("Does not support all required methods yet.")
+
+    def test_supports_adapter(self):
+        model = self.get_model()
+        model.eval()
+
+        config = "unipelt"
+        with self.assertRaises(ValueError):
+            model.add_adapter("my_adapter", config=config)
+
+
+method_tests = generate_method_tests(
+    CustomInterfaceModelTestBase,
+    not_supported=[
+        "ConfigUnion",
+        "ClassConversion",
+        "Heads",
+        "PrefixTuning",
+        "PromptTuning",
+        "UniPELT",
+        "Composition",
+        "Bottleneck",
+    ],
+)
+
+for test_class_name, test_class in method_tests.items():
+    globals()[test_class_name] = test_class
+
+
+@require_torch
+@pytest.mark.bottleneck
+class Bottleneck(
+    CustomInterfaceModelTestBase,
+    BottleneckAdapterTestMixin,
+    unittest.TestCase,
+):
+    def test_get_adapter(self):
+        model = self.get_model()
+        model.eval()
+        n_layers = len(list(model.iter_layers()))
+
+        for adapter_config, n_expected in [
+            (DoubleSeqBnConfig(), n_layers * 2),
+            (ConfigUnion(LoRAConfig(), ParBnConfig()), n_layers * 2),
+        ]:
+            with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
+                self.run_get_test(model, adapter_config, n_expected)
diff --git a/tests/test_misc/test_custom_interface_compat.py b/tests/test_misc/test_custom_interface_compat.py
new file mode 100644
index 0000000000..1e19aade28
--- /dev/null
+++ b/tests/test_misc/test_custom_interface_compat.py
@@ -0,0 +1,209 @@
+import os
+import tempfile
+import unittest
+
+import torch
+
+import adapters
+from adapters import AdapterModelInterface, AutoAdapterModel
+from adapters.utils import WEIGHTS_NAME
+from parameterized import parameterized
+from tests.test_methods.method_test_impl.utils import ids_tensor
+from transformers import AutoModel, AutoModelForCausalLM, BertConfig, LlamaConfig
+from transformers.testing_utils import require_torch, torch_device
+
+
+@require_torch
+class CustomInterfaceCompatTest(unittest.TestCase):
+    # This test is to check if the custom interface produces the same results as the AdapterModel implementation.
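+    # For each adapter method, an adapter is added to a model initialized via the plugin interface, saved to disk,
+    # and re-loaded into the corresponding built-in AdapterModel class; all weights must load and outputs must match.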
+
+    llama_config = LlamaConfig(
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        pad_token_id=0,
+    )
+    bert_config = BertConfig(
+        hidden_size=32,
+        num_hidden_layers=5,
+        num_attention_heads=4,
+        intermediate_size=37,
+        hidden_act="gelu",
+        pad_token_id=0,
+    )
+    llama_interface = AdapterModelInterface(
+        adapter_methods=["bottleneck", "lora", "reft", "invertible"],
+        model_embeddings="embed_tokens",
+        model_layers="layers",
+        layer_self_attn="self_attn",
+        layer_cross_attn=None,
+        attn_k_proj="k_proj",
+        attn_q_proj="q_proj",
+        attn_v_proj="v_proj",
+        attn_o_proj="o_proj",
+        layer_intermediate_proj="mlp.up_proj",
+        layer_output_proj="mlp.down_proj",
+        layer_pre_self_attn="input_layernorm",
+        layer_pre_cross_attn=None,
+        layer_pre_ffn="post_attention_layernorm",
+        layer_ln_1=None,
+        layer_ln_2=None,
+    )
+    bert_interface = AdapterModelInterface(
+        adapter_methods=["bottleneck", "lora", "reft", "prompt_tuning", "invertible"],
+        model_embeddings="embeddings",
+        model_layers="encoder.layer",
+        layer_self_attn="attention",
+        layer_cross_attn=None,
+        attn_k_proj="self.key",
+        attn_q_proj="self.query",
+        attn_v_proj="self.value",
+        attn_o_proj="output.dense",
+        layer_intermediate_proj="intermediate.dense",
+        layer_output_proj="output.dense",
+        layer_pre_self_attn="attention.self",
+        layer_pre_cross_attn=None,
+        layer_pre_ffn="intermediate",
+        layer_ln_1="attention.output.LayerNorm",
+        layer_ln_2="output.LayerNorm",
+    )
+    bert_bn_rewrites = [(".attention_adapters.", ".attention.output."), (".output_adapters.", ".output.")]
+
+    def create_twin_models(self, config, adapter_interface, hf_auto_model_class):
+        model1 = hf_auto_model_class.from_config(config)
+        adapters.init(model1, interface=adapter_interface)
+        model1.eval()
+        # create a twin initialized with the same random weights
+        model2 = AutoAdapterModel.from_pretrained(None, config=config, state_dict=model1.state_dict())
+        model2.eval()
+        return model1, model2
+
+    @parameterized.expand(
+        [
+            ("LoRA_Llama", adapters.LoRAConfig(), llama_config, llama_interface, AutoModelForCausalLM),
+            ("LoRA_BERT", adapters.LoRAConfig(), bert_config, bert_interface, AutoModel),
+            ("LoReft_Llama", adapters.LoReftConfig(), llama_config, llama_interface, AutoModelForCausalLM),
+            ("LoReft_BERT", adapters.LoReftConfig(), bert_config, bert_interface, AutoModel),
+            (
+                "BnSeq_Llama",
+                adapters.SeqBnConfig(original_ln_before=False),
+                llama_config,
+                llama_interface,
+                AutoModelForCausalLM,
+            ),
+            (
+                "BnSeqInv_Llama",
+                adapters.SeqBnInvConfig(),
+                llama_config,
+                llama_interface,
+                AutoModelForCausalLM,
+            ),
+            (
+                "BnSeqPreLN_Llama",
+                adapters.SeqBnConfig(original_ln_before=True),
+                llama_config,
+                llama_interface,
+                AutoModelForCausalLM,
+            ),
+            ("BnPar_Llama", adapters.ParBnConfig(), llama_config, llama_interface, AutoModelForCausalLM),
+            (
+                "Bn2Seq_Llama",
+                adapters.DoubleSeqBnConfig(original_ln_before=True),
+                llama_config,
+                llama_interface,
+                AutoModelForCausalLM,
+            ),
+            (
+                "Bn2Par_Llama",
+                adapters.ParBnConfig(mh_adapter=True, output_adapter=True),
+                llama_config,
+                llama_interface,
+                AutoModelForCausalLM,
+            ),
+            (
+                "BnSeq_BERT",
+                adapters.SeqBnConfig(original_ln_before=False),
+                bert_config,
+                bert_interface,
+                AutoModel,
+                bert_bn_rewrites,
+            ),
+            (
+                "BnSeqInv_BERT",
+                adapters.SeqBnInvConfig(),
+                bert_config,
+                bert_interface,
+                AutoModel,
+                bert_bn_rewrites,
+            ),
+            (
+                "BnSeqPreLN_BERT",
+                adapters.SeqBnConfig(original_ln_before=True),
+                bert_config,
+                bert_interface,
+                AutoModel,
+                bert_bn_rewrites,
+            ),
+            ("BnPar_BERT", adapters.ParBnConfig(), bert_config, bert_interface, AutoModel, bert_bn_rewrites),
+            (
+                "Bn2Seq_BERT",
+                adapters.DoubleSeqBnConfig(original_ln_before=True),
+                bert_config,
+                bert_interface,
+                AutoModel,
+                bert_bn_rewrites,
+            ),
+            (
+                "Bn2Par_BERT",
+                adapters.ParBnConfig(mh_adapter=True, output_adapter=True),
+                bert_config,
+                bert_interface,
+                AutoModel,
+                bert_bn_rewrites,
+            ),
+            ("Prompt_BERT", adapters.PromptTuningConfig(), bert_config, bert_interface, AutoModel),
+        ]
+    )
+    def test_load_adapter(self, name, adapter_config, config, adapter_interface, hf_auto_model_class, rewrites=None):
+        custom_model, auto_model = self.create_twin_models(config, adapter_interface, hf_auto_model_class)
+
+        custom_model.add_adapter(name, config=adapter_config)
+        custom_model.set_active_adapters(name)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            custom_model.save_adapter(temp_dir, name)
+
+            # Check that there are actually weights saved
+            weights = torch.load(os.path.join(temp_dir, WEIGHTS_NAME), map_location="cpu")
+            self.assertTrue(len(weights) > 0)
+
+            # The weight names of the custom interface adapter and built-in adapter might be different.
+            # Apply weight renaming here if necessary.
+            if rewrites is not None:
+                for old, new in rewrites:
+                    for key in list(weights.keys()):
+                        if old in key:
+                            new_key = key.replace(old, new)
+                            weights[new_key] = weights.pop(key)
+
+                torch.save(weights, os.path.join(temp_dir, WEIGHTS_NAME))
+
+            # also tests that set_active works
+            loading_info = {}
+            auto_model.load_adapter(temp_dir, set_active=True, loading_info=loading_info)
+
+        # check if all weights were loaded
+        self.assertEqual(0, len(loading_info["missing_keys"]), loading_info["missing_keys"])
+        self.assertEqual(0, len(loading_info["unexpected_keys"]), loading_info["unexpected_keys"])
+
+        # check if adapter was correctly loaded
+        self.assertTrue(name in auto_model.adapters_config)
+
+        # check equal output
+        input_data = {"input_ids": ids_tensor((2, 128), 1000)}
+        custom_model.to(torch_device)
+        auto_model.to(torch_device)
+        output1 = custom_model(**input_data)
+        output2 = auto_model(**input_data)
+        self.assertEqual(len(output1), len(output2))
+        self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-5))