diff --git a/docs/classes/adapter_model_interface.rst b/docs/classes/adapter_model_interface.rst
new file mode 100644
index 0000000000..e17f1a38ac
--- /dev/null
+++ b/docs/classes/adapter_model_interface.rst
@@ -0,0 +1,8 @@
+Adapter Model Interface
+=======================
+
+.. autoclass:: adapters.AdapterModelInterface
+ :members:
+
+.. autoclass:: adapters.AdapterMethod
+ :members:
diff --git a/docs/contributing/adding_adapters_to_a_model.md b/docs/contributing/adding_adapters_to_a_model.md
index 97c9fe21d7..f99297510a 100644
--- a/docs/contributing/adding_adapters_to_a_model.md
+++ b/docs/contributing/adding_adapters_to_a_model.md
@@ -1,4 +1,11 @@
# Adding Adapters to a Model
+
+```{eval-rst}
+.. important::
+ For most use cases, it can be much easier to support a new model architecture via the new adapter plugin interface.
+ Check out `Custom Models <../plugin_interface.html>`_ for more.
+```
+
This document gives an overview of how new model architectures of Hugging Face Transformers can be supported by `adapters`.
Before delving into implementation details, you should familiarize yourself with the main design philosophies of `adapters`:
diff --git a/docs/index.rst b/docs/index.rst
index 0c10de8a18..9cef2b044f 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -52,7 +52,6 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
merging_adapters
prediction_heads
embeddings
- extending
.. toctree::
:maxdepth: 2
@@ -66,6 +65,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
:caption: Supported Models
model_overview
+ plugin_interface
classes/models/albert
classes/models/auto
classes/models/bart
@@ -99,6 +99,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/adapter_config
classes/model_adapters_config
classes/adapter_layer
+ classes/adapter_model_interface
classes/model_mixins
classes/adapter_training
classes/adapter_utils
@@ -110,6 +111,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
contributing
contributing/adding_adapter_methods
contributing/adding_adapters_to_a_model
+ extending
Citation
========
diff --git a/docs/model_overview.md b/docs/model_overview.md
index b364ab6ebb..d46cff99d1 100644
--- a/docs/model_overview.md
+++ b/docs/model_overview.md
@@ -13,6 +13,7 @@ The table below further shows which model architectures support which adaptation
| Model | (Bottleneck)<br>Adapters | Prefix<br>Tuning | LoRA | Compacter | Adapter<br>Fusion | Invertible<br>Adapters | Parallel<br>block | Prompt<br>Tuning | ReFT |
| --------------------------------------- | -| - | - | - | - | - | - |- | - |
+| [Custom models](plugin_interface.html) | ✅(°) | | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ |
| [ALBERT](classes/models/albert.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [BART](classes/models/bart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [BEIT](classes/models/beit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | ✅ |
@@ -38,9 +39,11 @@ The table below further shows which model architectures support which adaptation
| [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
+(°) `original_ln_after=False` is unsupported for bottleneck configs.
(*) If the used encoder and decoder model class are supported.
-**Missing a model architecture you'd like to use?**
-adapters can be easily extended to new model architectures as described in [Adding Adapters to a Model](https://docs.adapterhub.ml/contributing/adding_adapters_to_a_model.html).
+**Missing a model architecture you'd like to use?**
+The new model plugin interface makes it easy to support new transformer models with just a few lines of code. [Learn more](plugin_interface.md).
+Also, _Adapters_ can be extended to new model architectures as described in [Adding Adapters to a Model](https://docs.adapterhub.ml/contributing/adding_adapters_to_a_model.html).
Feel free to [open an issue](https://github.com/Adapter-Hub/adapters/issues) requesting support for a new architecture.
_We very much welcome pull requests adding new model implementations!_
diff --git a/docs/plugin_interface.md b/docs/plugin_interface.md
new file mode 100644
index 0000000000..10ff2da74f
--- /dev/null
+++ b/docs/plugin_interface.md
@@ -0,0 +1,94 @@
+# Custom Models
+
+The _Adapters_ library provides a simple mechanism for integrating adapter methods into any available _Transformers_ model - including custom architectures.
+This can be accomplished by defining a plugin interface instance of [`AdapterModelInterface`](adapters.AdapterModelInterface).
+The following example shows what this looks like for Gemma 2:
+
+```python
+import adapters
+from adapters import AdapterModelInterface
+from transformers import AutoModelForCausalLM
+
+plugin_interface = AdapterModelInterface(
+ adapter_methods=["lora", "reft"],
+ model_embeddings="embed_tokens",
+ model_layers="layers",
+ layer_self_attn="self_attn",
+ layer_cross_attn=None,
+ attn_k_proj="k_proj",
+ attn_q_proj="q_proj",
+ attn_v_proj="v_proj",
+ attn_o_proj="o_proj",
+ layer_intermediate_proj="mlp.up_proj",
+ layer_output_proj="mlp.down_proj",
+)
+
+model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", token="")
+adapters.init(model, interface=plugin_interface)
+
+model.add_adapter("my_adapter", config="lora")
+
+print(model.adapter_summary())
+```
+
+## Walkthrough
+
+Let's go through what happens in the example above step by step:
+
+**1. Define adapter methods to plug into a model:**
+The `adapter_methods` argument is the central parameter for configuring which adapter methods will be supported in the model.
+Here, we enable all LoRA- and ReFT-based adapters.
+See [`AdapterMethod`](adapters.AdapterMethod) for valid options to specify here.
+Check out [Adapter Methods](methods.md) for a detailed explanation of the methods.
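+
+Since the attributes of [`AdapterMethod`](adapters.AdapterMethod) are plain string constants, they can be used interchangeably with the literal strings above. A minimal illustration:
+
+```python
+from adapters import AdapterMethod
+
+# AdapterMethod attributes are simple string constants
+assert AdapterMethod.lora == "lora"
+assert AdapterMethod.reft == "reft"
+
+# so this list is equivalent to adapter_methods=["lora", "reft"] in the example above
+adapter_methods = [AdapterMethod.lora, AdapterMethod.reft]
+```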
+
+**2. Define layer and module names:**
+While the layers of all Transformers models share similar basic components, their implementations can differ in subtle details such as module names.
+Therefore, the [`AdapterModelInterface`](adapters.AdapterModelInterface) needs to translate the model-specific module structure into a common set of access points for adapter implementations to hook in.
+The remaining attributes in the definition above serve this purpose.
+Their attribute names follow a common syntax that specifies their location and purpose:
+- The initial part before the first "_" defines the base module relative to which the name should be specified.
+- The remaining part after the first "_" defines the functional component.
+
+E.g., `model_embeddings` identifies the embeddings layer (functional component) relative to the base model (location).
+`layer_output_proj` identifies the FFN output projection relative to one Transformer layer.
+Each attribute value may specify a direct submodule of the reference module (`"embed_tokens"`) or a multi-level path starting at the reference module (`"mlp.down_proj"`).
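+
+Under the hood, such dotted paths are resolved by chaining `getattr` calls (see the `multigetattr` helper added in `src/adapters/utils.py`). A minimal sketch, using a hypothetical `resolve_module` helper for illustration:
+
+```python
+from functools import reduce
+
+import torch.nn as nn
+
+
+def resolve_module(root: nn.Module, path: str) -> nn.Module:
+    # Walk a dotted attribute path such as "mlp.down_proj" starting at `root`
+    return reduce(getattr, path.split("."), root)
+
+
+# e.g., for a single decoder layer of the Gemma 2 model above:
+# ffn_down = resolve_module(layer, "mlp.down_proj")
+```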
+
+**3. (optional) Extended interface attributes:**
+There are a couple of attributes in the [`AdapterModelInterface`](adapters.AdapterModelInterface) that are only required for some adapter methods.
+We don't need those in the above example for LoRA and ReFT, but when supporting bottleneck adapters as well, the full interface would look as follows:
+```python
+adapter_interface = AdapterModelInterface(
+ adapter_methods=["bottleneck", "lora", "reft"],
+ model_embeddings="embed_tokens",
+ model_layers="layers",
+ layer_self_attn="self_attn",
+ layer_cross_attn=None,
+ attn_k_proj="k_proj",
+ attn_q_proj="q_proj",
+ attn_v_proj="v_proj",
+ attn_o_proj="o_proj",
+ layer_intermediate_proj="mlp.up_proj",
+ layer_output_proj="mlp.down_proj",
+ layer_pre_self_attn="input_layernorm",
+ layer_pre_cross_attn=None,
+ layer_pre_ffn="pre_feedforward_layernorm",
+ layer_ln_1="post_attention_layernorm",
+ layer_ln_2="post_feedforward_layernorm",
+)
+```
+
+**4. Initialize adapter methods in the model:**
+Finally, we just need to apply the defined adapter integration in the target model.
+This can be achieved using the usual `adapters.init()` method:
+```python
+adapters.init(model, interface=adapter_interface)
+```
+Now, you can use (almost) all functionality of the _Adapters_ library on the adapted model instance!
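+
+From here, the regular adapter workflow applies. A short sketch continuing the Gemma 2 example from the top of this page (the adapter name and paths are illustrative):
+
+```python
+# Freeze the base model weights and activate the adapter for training
+model.train_adapter("my_adapter")
+
+# ... run your training loop ...
+
+# Save only the adapter weights
+model.save_adapter("./checkpoints/my_adapter", "my_adapter")
+
+# save_pretrained() additionally writes the plugin interface (adapter_interface.json),
+# so the full model can later be restored via adapters.wrappers.load_model()
+model.save_pretrained("./checkpoints/full_model")
+```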
+
+## Limitations
+
+The following features of the _Adapters_ library are not supported via the plugin interface approach:
+- Prefix Tuning adapters
+- Parallel composition blocks
+- XAdapterModel classes
+- Setting `original_ln_after=False` in bottleneck adapter configurations (this affects `AdapterPlusConfig`)
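+
+Whether a given method is available on a model initialized via the plugin interface can be checked at runtime with the `supports_adapter()` helper. A minimal sketch continuing the first example on this page:
+
+```python
+# `model` was initialized with adapters.init(model, interface=plugin_interface) above
+print(model.supports_adapter("lora"))           # True: listed in adapter_methods
+print(model.supports_adapter("prefix_tuning"))  # False: not available via the plugin interface
+
+# add_adapter() raises a ValueError for unsupported method types
+try:
+    model.add_adapter("unsupported", config="prefix_tuning")
+except ValueError as err:
+    print(err)
+```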
diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py
index 905706b509..b813c8644c 100644
--- a/src/adapters/__init__.py
+++ b/src/adapters/__init__.py
@@ -84,6 +84,7 @@
"Seq2SeqLMHead",
"TaggingHead",
],
+ "interface": ["AdapterMethod", "AdapterModelInterface"],
"methods.adapter_layer_base": ["AdapterLayerBase", "ComposableAdapterLayerBase"],
"model_mixin": [
"EmbeddingAdaptersMixin",
@@ -198,6 +199,7 @@
Seq2SeqLMHead,
TaggingHead,
)
+ from .interface import AdapterMethod, AdapterModelInterface
from .methods.adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .model_mixin import (
EmbeddingAdaptersMixin,
diff --git a/src/adapters/interface.py b/src/adapters/interface.py
new file mode 100644
index 0000000000..2f0026296a
--- /dev/null
+++ b/src/adapters/interface.py
@@ -0,0 +1,121 @@
+import json
+import os
+from dataclasses import asdict, dataclass
+from typing import List, Optional
+
+from transformers.utils import cached_file
+
+from . import __version__
+from .utils import INTERFACE_CONFIG_NAME
+
+
+class AdapterMethod:
+ """
+ Enum of all supported adapter method types.
+
+ Attributes:
+ bottleneck: Adapter methods using bottleneck layers.
+ prefix_tuning: Adapter methods based on Prefix Tuning. Note that this is currently unsupported via AdapterModelInterface.
+ lora: Adapter methods based on low-rank adaptation.
+ prompt_tuning: Adapter methods based on Prompt Tuning.
+ reft: Adapter methods based on Representation Fine-Tuning.
+ invertible: Adapter methods using invertible modules.
+ """
+
+ bottleneck = "bottleneck"
+ prefix_tuning = "prefix_tuning"
+ lora = "lora"
+ prompt_tuning = "prompt_tuning"
+ reft = "reft"
+ invertible = "invertible"
+
+ @staticmethod
+ def get_from_config(config) -> List[str]:
+ """
+ Get the adapter method types from a given adapter config.
+
+ Args:
+ config: The adapter config.
+
+ Returns:
+ List[str]: The adapter method types.
+ """
+ methods = []
+ if getattr(config, "inv_adapter", False):
+ methods.append(AdapterMethod.invertible)
+ if config.architecture is None:
+ methods.append(AdapterMethod.bottleneck)
+ elif config.architecture == "union":
+ for sub_config in config.configs:
+ methods.extend(AdapterMethod.get_from_config(sub_config))
+ else:
+ methods.append(config.architecture)
+ return methods
+
+
+@dataclass
+class AdapterModelInterface:
+ """
+ Defines the main interface for integrating adapter methods into a model class.
+ This interface translates generic accessor names to model-specific attribute names.
+
+ Args:
+ adapter_methods (List[str]): List of adapter types that are supported by the model.
+ model_embeddings (str): Name of the model's embedding layer.
+ model_layers (str): Name of the model's layer list.
+ layer_self_attn (str): Name of the self-attention layer in a transformer layer.
+ layer_cross_attn (str): Name of the cross-attention layer in a transformer layer.
+ attn_k_proj (str): Name of the key projection layer in an attention layer.
+ attn_q_proj (str): Name of the query projection layer in an attention layer.
+ attn_v_proj (str): Name of the value projection layer in an attention layer.
+ attn_o_proj (str): Name of the output projection layer in an attention layer.
+ layer_intermediate_proj (str): Name of the intermediate projection layer in a transformer layer.
+ layer_output_proj (str): Name of the output projection layer in a transformer layer.
+ layer_pre_self_attn (Optional[str]): Hook point directly before the self-attention layer. Used for extended bottleneck adapter support.
+ layer_pre_cross_attn (Optional[str]): Hook point directly before the cross-attention layer. Used for extended bottleneck adapter support.
+ layer_pre_ffn (Optional[str]): Hook point directly before the feed-forward layer. Used for extended bottleneck adapter support.
+ layer_ln_1 (Optional[str]): Layer norm *after* the self-attention layer. Used for extended bottleneck adapter support.
+ layer_ln_2 (Optional[str]): Layer norm *after* the feed-forward layer. Used for extended bottleneck adapter support.
+ """
+
+ adapter_methods: List[str]
+
+ model_embeddings: str
+ model_layers: str
+
+ layer_self_attn: str
+ layer_cross_attn: str
+ attn_k_proj: str
+ attn_q_proj: str
+ attn_v_proj: str
+ attn_o_proj: str
+
+ layer_intermediate_proj: str
+ layer_output_proj: str
+
+ # Optional attributes for extended bottleneck adapter support
+ layer_pre_self_attn: Optional[str] = None
+ layer_pre_cross_attn: Optional[str] = None
+ layer_pre_ffn: Optional[str] = None
+ layer_ln_1: Optional[str] = None
+ layer_ln_2: Optional[str] = None
+
+ def to_dict(self):
+ return asdict(self)
+
+ def _save(self, save_directory, model_config):
+ config_dict = {
+ "model_type": model_config.model_type,
+ "interface": self.to_dict(),
+ "version": "adapters." + __version__,
+ }
+ save_path = os.path.join(save_directory, INTERFACE_CONFIG_NAME)
+ with open(save_path, "w") as f:
+ json.dump(config_dict, f, indent=2, sort_keys=True)
+
+ @classmethod
+ def _load(cls, path_or_repo_id: str, **kwargs):
+ resolved_file = cached_file(path_or_repo_id, INTERFACE_CONFIG_NAME, **kwargs)
+ with open(resolved_file, "r") as f:
+ config_dict = json.load(f)
+ return AdapterModelInterface(**config_dict["interface"])
diff --git a/src/adapters/methods/__init__.py b/src/adapters/methods/__init__.py
index e69de29bb2..5082ff0b6b 100644
--- a/src/adapters/methods/__init__.py
+++ b/src/adapters/methods/__init__.py
@@ -0,0 +1,14 @@
+from .bottleneck import init_bottleneck
+from .invertible import init_invertible_adapters
+from .lora import init_lora
+from .prompt_tuning import init_prompt_tuning
+from .reft import init_reft
+
+
+METHOD_INIT_MAPPING = {
+ "bottleneck": init_bottleneck,
+ "lora": init_lora,
+ "prompt_tuning": init_prompt_tuning,
+ "reft": init_reft,
+ "invertible": init_invertible_adapters,
+}
diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py
index 889941d2b9..74461635ee 100644
--- a/src/adapters/methods/bottleneck.py
+++ b/src/adapters/methods/bottleneck.py
@@ -1,3 +1,4 @@
+from functools import partial
from typing import List, Mapping, NamedTuple, Optional, Union
import torch
@@ -15,10 +16,16 @@
)
from ..configuration import BnConfig
from ..context import ForwardContext
+from ..utils import multigetattr
from .adapter_layer_base import ComposableAdapterLayerBase
from .modeling import Adapter, BertFusion, ParallelAdapter
+LAYER_HOOK_UNSUPPORTED = [
+ ("original_ln_after", False),
+]
+
+
class BottleneckState(NamedTuple):
"""
Models the input and output states of a bottleneck adapter layer.
@@ -45,9 +52,10 @@ class BottleneckLayer(ComposableAdapterLayerBase, nn.Module):
adapter_modules_name = "adapters"
supported_compositions = [Stack, Fuse, Split, Parallel, BatchSplit, Average]
- def __init__(self, location_key: str):
+ def __init__(self, location_key: str, is_layer_hooked: bool = False):
super().__init__()
self.location_key = location_key
+ self.is_layer_hooked = is_layer_hooked
def init_adapters(self, model_config, adapters_config):
self._init_mapping()
@@ -55,6 +63,8 @@ def init_adapters(self, model_config, adapters_config):
self.adapters_config = adapters_config
self.adapters = nn.ModuleDict(dict())
self.adapter_fusion_layer = nn.ModuleDict(dict())
+ if not hasattr(self, "is_layer_hooked"):
+ self.is_layer_hooked = False
def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
self.layer_idx = layer_idx
@@ -78,6 +88,15 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
'{"1": 16, "default": 16}'
)
+ # check unsupported configurations for layer hooking mode
+ if self.is_layer_hooked:
+ for key, value in LAYER_HOOK_UNSUPPORTED:
+ if adapter_config.get(key, None) == value:
+ raise ValueError(
+ f"Unsupported configuration for bottleneck layer hooking mode: {key}={value}. "
+ "Please set this configuration to a supported value."
+ )
+
if adapter_config.is_parallel:
adapter_class = ParallelAdapter
else:
@@ -88,6 +107,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
down_sample=int(self.model_config.hidden_size // reduction_factor),
config=adapter_config,
)
+ # for adapters hooked via interface:
+ # residual & LN are applied by model, so don't apply in adapters
+ if self.is_layer_hooked:
+ adapter.original_ln_after = False
adapter.train(self.training) # make sure training mode is consistent
self.adapters[adapter_name] = adapter
return True
@@ -321,9 +344,10 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm):
torch.Tensor: Output hidden states of the adapter layer.
"""
# Batch sizes might be different due to prefix tuning w. Parallel block
- (residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input)
- # Replicate in both directions as residual might be larger (e.g. GPT-J)
- (hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states)
+ if residual_input is not None:
+ (residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input)
+ # Replicate in both directions as residual might be larger (e.g. GPT-J)
+ (hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states)
adapter_setup = self.get_active_setup()
if adapter_setup is not None:
input_hidden_states = hidden_states
@@ -335,9 +359,9 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm):
last_adapter = self.adapters[last]
hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm)
- elif layer_norm:
+ elif layer_norm is not None and not self.is_layer_hooked:
hidden_states = layer_norm(hidden_states + residual_input)
- else:
+ elif residual_input is not None and not self.is_layer_hooked:
hidden_states = hidden_states + residual_input
return hidden_states
@@ -354,3 +378,55 @@ def forward(self, hidden_states, residual_input, layer_norm):
torch.Tensor: Output hidden states of the adapter layer.
"""
return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm)
+
+
+def hook_fn(adapter_layer, ln_get_fn, module, args, output):
+ # Retrieve residual from previous hook, if existing
+ context = ForwardContext.get_context()
+ residual_input = getattr(context, f"{adapter_layer.location_key}_residual_input", None)
+ # Retrieve layer norm from getter fn
+ if ln_get_fn is not None:
+ layer_norm = ln_get_fn()
+ else:
+ layer_norm = None
+ # Call adapter layer
+ if isinstance(output, torch.Tensor):
+ return adapter_layer(output, residual_input, layer_norm)
+ else:
+ return (adapter_layer(output[0], residual_input, layer_norm),) + output[1:]
+
+
+def _residual_hook_fn(location_key, module, args):
+ context = ForwardContext.get_context()
+ if context is not None:
+ setattr(context, f"{location_key}_residual_input", args[0])
+
+
+def init_bottleneck(model):
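+ # Register forward hooks on the attention output, FFN output and (if present) cross-attention modules
+ # named in the adapter interface so that bottleneck adapter layers run on their outputs;
+ # forward pre-hooks capture the corresponding residual inputs in the forward context.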
+ model = model.base_model
+ for _, layer in model.iter_layers():
+ if self_attn := multigetattr(layer, model.adapter_interface.layer_self_attn, None):
+ if o_proj := multigetattr(self_attn, model.adapter_interface.attn_o_proj, None):
+ if not hasattr(layer, "attention_adapters"):
+ layer.attention_adapters = BottleneckLayer("mh_adapter", is_layer_hooked=True)
+ # bind the current layer as a default argument to avoid late binding in the loop
+ ln_1_get_fn = lambda layer=layer: multigetattr(layer, model.adapter_interface.layer_ln_1, None)
+ o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters, ln_1_get_fn))
+ if layer_output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj, None):
+ if not hasattr(layer, "output_adapters"):
+ layer.output_adapters = BottleneckLayer("output_adapter", is_layer_hooked=True)
+ ln_2_get_fn = lambda layer=layer: multigetattr(layer, model.adapter_interface.layer_ln_2, None)
+ layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters, ln_2_get_fn))
+ if cross_attn := multigetattr(layer, model.adapter_interface.layer_cross_attn, None):
+ if not hasattr(cross_attn, "cross_attention_adapters"):
+ layer.attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True)
+ cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters, None))
+
+ if model.adapter_interface.layer_pre_self_attn is not None:
+ if pre_self_attn := multigetattr(layer, model.adapter_interface.layer_pre_self_attn, None):
+ pre_self_attn.register_forward_pre_hook(partial(_residual_hook_fn, "mh_adapter"))
+ if model.adapter_interface.layer_pre_cross_attn is not None:
+ if pre_cross_attn := multigetattr(layer, model.adapter_interface.layer_pre_cross_attn, None):
+ pre_cross_attn.register_forward_pre_hook(partial(_residual_hook_fn, "cross_adapter"))
+ if model.adapter_interface.layer_pre_ffn is not None:
+ if pre_ffn := multigetattr(layer, model.adapter_interface.layer_pre_ffn, None):
+ pre_ffn.register_forward_pre_hook(partial(_residual_hook_fn, "output_adapter"))
diff --git a/src/adapters/methods/invertible.py b/src/adapters/methods/invertible.py
new file mode 100644
index 0000000000..4a8158599c
--- /dev/null
+++ b/src/adapters/methods/invertible.py
@@ -0,0 +1,104 @@
+import types
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from ..configuration.adapter_config import BnConfig
+from ..utils import multigetattr
+from .adapter_layer_base import AdapterLayerBase
+from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock
+
+
+class InvertibleAdapterLayer(AdapterLayerBase, nn.ModuleDict):
+ adapter_modules_name = "_modules"
+
+ def __init__(self, model_config, adapters_config):
+ super().__init__()
+ self.location_key = "inv_adapter"
+ self.model_config = model_config
+ self.adapters_config = adapters_config
+
+ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
+ self.layer_idx = layer_idx
+ embedding_size = getattr(self.model_config, "embedding_size", self.model_config.hidden_size)
+ adapter_config = self.adapters_config.match(
+ adapter_name,
+ config_type=BnConfig,
+ location_key="inv_adapter",
+ )
+ if adapter_config is not None and adapter_config["inv_adapter"]:
+ if adapter_config["inv_adapter"] == "nice":
+ inv_adap = NICECouplingBlock(
+ [[embedding_size]],
+ non_linearity=adapter_config["non_linearity"],
+ reduction_factor=adapter_config["inv_adapter_reduction_factor"],
+ )
+ elif adapter_config["inv_adapter"] == "glow":
+ inv_adap = GLOWCouplingBlock(
+ [[embedding_size]],
+ non_linearity=adapter_config["non_linearity"],
+ reduction_factor=adapter_config["inv_adapter_reduction_factor"],
+ )
+ else:
+ raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.")
+ self[adapter_name] = inv_adap
+ self[adapter_name].apply(Adapter.init_bert_weights)
+ return True
+
+ return False
+
+ def get_invertible_adapter(self):
+ # HACK: returns the first adapter of the currently active setup. for backwards compatibility
+ adapter_setup = self.get_active_setup()
+ if adapter_setup is not None and len(adapter_setup) > 0:
+ first_adapter = adapter_setup.first()
+ if first_adapter in self:
+ return self[first_adapter]
+ return None
+
+ def forward(self, hidden_states: torch.Tensor, rev=False):
+ adapter_setup = self.get_active_setup()
+ if adapter_setup is not None and len(adapter_setup) > 0:
+ first_adapter = adapter_setup.first()
+ if first_adapter in self:
+ hidden_states = self[first_adapter](hidden_states, rev=rev)
+ return hidden_states
+
+
+def hook_fn(model, module, args, embedding_output):
+ embedding_output = model.invertible_adapters(embedding_output)
+ return embedding_output
+
+
+def inv_hook_fn(model, module, args):
+ inv_output = model.invertible_adapters(args[0], rev=True)
+ return (inv_output,) + args[1:]
+
+
+def init_invertible_adapters(model):
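+ # Attach an InvertibleAdapterLayer to the base model, apply it to the embedding output via a
+ # forward hook, and invert it again right before the output embedding layer.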
+ base_model = model.base_model
+ if not hasattr(base_model, "invertible_adapters"):
+ base_model.invertible_adapters = InvertibleAdapterLayer(base_model.config, base_model.adapters_config)
+
+ embed_layer = multigetattr(base_model, base_model.adapter_interface.model_embeddings)
+ embed_layer.register_forward_hook(partial(hook_fn, base_model))
+
+ # Add methods from original invertible adapter mixin.
+ # This is primarily for backwards compatibility and internal use.
+ base_model.add_invertible_adapter = types.MethodType(
+ lambda self, *args, **kwargs: self.invertible_adapters.add_adapter(*args, **kwargs), base_model
+ )
+ base_model.delete_invertible_adapter = types.MethodType(
+ lambda self, *args, **kwargs: self.invertible_adapters.delete_adapter(*args, **kwargs), base_model
+ )
+ base_model.get_invertible_adapter = types.MethodType(
+ lambda self: self.invertible_adapters.get_invertible_adapter(), base_model
+ )
+ base_model.invertible_adapters_forward = types.MethodType(
+ lambda self, *args, **kwargs: self.invertible_adapters(*args, **kwargs), base_model
+ )
+
+ # Register reverse forward pass
+ if output_embedding := model.get_output_embeddings():
+ output_embedding.register_forward_pre_hook(partial(inv_hook_fn, base_model))
diff --git a/src/adapters/methods/lora.py b/src/adapters/methods/lora.py
index 1dc8ef9f23..ab4cb7aaa1 100644
--- a/src/adapters/methods/lora.py
+++ b/src/adapters/methods/lora.py
@@ -16,6 +16,7 @@
from ..composition import Average, BatchSplit, Parallel, Stack
from ..configuration import LoRAConfig, ModelAdaptersConfig
+from ..utils import multigetattr, multisetattr
from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .utils import dequantize_bnb_weight, fix_seed
@@ -336,6 +337,14 @@ def _average_adapter_lora_delta_w_svd(self, input_adapters: Dict[str, float], av
avg_state_dict["lora_A"] = V
avg_state_dict["lora_B"] = U @ torch.diag(S_diag)
+ def _copy_hooks_from(self, module: nn.Module):
+ # Copy all hook dictionaries registered on the wrapped module (e.g. forward hooks) to this module
+ for k, v in module.__dict__.items():
+ if "_hooks" in k:
+ setattr(self, k, v)
+
class LoRAState(NamedTuple):
"""Models the input and output states of a LoRA layer.
@@ -432,6 +441,7 @@ def wrap(
**kwargs,
)
new_module.copy_from(module)
+ new_module._copy_hooks_from(module)
return new_module
@@ -676,6 +686,7 @@ def wrap(
new_module.weight = module.weight
if module.bias is not None:
new_module.bias = module.bias
+ new_module._copy_hooks_from(module)
return new_module
@@ -814,3 +825,25 @@ def T(w):
raise ValueError(f"Invalid adapter setup. Cannot use {adapter_setup} with LoRA.")
return F.linear(x, T(self.weight), bias=self.bias)
+
+
+def init_lora(model):
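+ # Wrap the attention (q/k/v) and FFN (intermediate/output) projection modules named in the
+ # adapter interface with LoRA-enabled linear layers.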
+ model = model.base_model
+ for _, _, attention in model.iter_attentions():
+ if q_proj := multigetattr(attention, model.adapter_interface.attn_q_proj, None):
+ lora_proj = LoRALinear.wrap(q_proj, "selfattn", model.config, model.adapters_config, attn_key="q")
+ multisetattr(attention, model.adapter_interface.attn_q_proj, lora_proj)
+ if k_proj := multigetattr(attention, model.adapter_interface.attn_k_proj, None):
+ lora_proj = LoRALinear.wrap(k_proj, "selfattn", model.config, model.adapters_config, attn_key="k")
+ multisetattr(attention, model.adapter_interface.attn_k_proj, lora_proj)
+ if v_proj := multigetattr(attention, model.adapter_interface.attn_v_proj, None):
+ lora_proj = LoRALinear.wrap(v_proj, "selfattn", model.config, model.adapters_config, attn_key="v")
+ multisetattr(attention, model.adapter_interface.attn_v_proj, lora_proj)
+
+ for _, layer in model.iter_layers():
+ if intermediate_proj := multigetattr(layer, model.adapter_interface.layer_intermediate_proj):
+ lora_proj = LoRALinear.wrap(intermediate_proj, "intermediate", model.config, model.adapters_config)
+ multisetattr(layer, model.adapter_interface.layer_intermediate_proj, lora_proj)
+ if output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj):
+ lora_proj = LoRALinear.wrap(output_proj, "output", model.config, model.adapters_config)
+ multisetattr(layer, model.adapter_interface.layer_output_proj, lora_proj)
diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py
index b9504ac40b..b91b7e76b9 100644
--- a/src/adapters/methods/prompt_tuning.py
+++ b/src/adapters/methods/prompt_tuning.py
@@ -1,6 +1,7 @@
# https://github.com/google-research/prompt-tuning/blob/main/prompt_tuning/train/prompts.py
import math
+from functools import partial
from typing import Callable
import numpy as np
@@ -12,6 +13,7 @@
from ..configuration import ModelAdaptersConfig, PromptTuningConfig
from ..context import ForwardContext
+from ..utils import multigetattr, prefix_attention_mask
from .adapter_layer_base import AdapterLayerBase
from .utils import fix_seed
@@ -179,3 +181,27 @@ def forward(self, hidden_states: torch.Tensor):
context.prompt_tokens_length = prefix_attention_mask_length
return hidden_states
+
+
+def hook_fn(model, module, args, embedding_output):
+ embedding_output = model.prompt_tuning.forward(embedding_output)
+ return embedding_output
+
+
+# TODO: this will only work for a limited set of models
+def _attn_mask_hook_fn(module, args):
+ attn_mask = args[1]
+ attn_mask = prefix_attention_mask(attn_mask)
+ return (args[0], attn_mask) + args[2:]
+
+
+def init_prompt_tuning(model):
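+ # Attach a PromptTuningLayer to the base model, add prompt tokens to the embedding output via a
+ # forward hook, and extend each layer's attention mask to cover the prepended prompt tokens.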
+ model = model.base_model
+ if not hasattr(model, "prompt_tuning"):
+ model.support_prompt_tuning = True
+ model.prompt_tuning = PromptTuningLayer(model.config, model.adapters_config, model.get_input_embeddings())
+ embed_layer = multigetattr(model, model.adapter_interface.model_embeddings)
+ embed_layer.register_forward_hook(partial(hook_fn, model))
+
+ for _, layer in model.iter_layers():
+ layer.register_forward_pre_hook(_attn_mask_hook_fn)
diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py
index 1884bf5e01..d4904d7fd2 100644
--- a/src/adapters/methods/reft.py
+++ b/src/adapters/methods/reft.py
@@ -229,6 +229,7 @@ def hook_fn(module, args, output):
def init_reft(model):
+ model = model.base_model
for _, layer in model.iter_layers():
if not hasattr(layer, "reft_layer"):
layer.reft_layer = ReftLayer("output", model.config, model.adapters_config)
diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py
index 8303191a52..63c251874f 100644
--- a/src/adapters/model_mixin.py
+++ b/src/adapters/model_mixin.py
@@ -9,7 +9,7 @@
from copy import deepcopy
from functools import partial
from os.path import join
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
import torch
from torch import nn
@@ -25,7 +25,9 @@
from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig
from .context import AdapterSetup, ForwardContext
from .hub_mixin import PushAdapterToHubMixin
+from .interface import AdapterMethod, AdapterModelInterface
from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
+from .methods import METHOD_INIT_MAPPING
from .methods.adapter_layer_base import AdapterLayerBase
from .methods.bottleneck import BottleneckLayer
from .methods.lora import LoRALayer
@@ -39,6 +41,8 @@
TOKENIZER_PATH,
get_adapter_config_hash,
inherit_doc,
+ multigetattr,
+ multihasattr,
patch_forward,
resolve_adapter_path,
)
@@ -432,10 +436,16 @@ def _init_adapters_submodules(self, model_config, adapters_config):
if hasattr(module, "init_adapters"):
module.init_adapters(model_config, adapters_config)
- # Initialize reft modules
- init_reft(self)
+ def _default_init_adapter_methods(self, model_config, adapters_config):
+ init_reft(self.base_model)
+ # Add prefix tuning
+ self.base_model.prefix_tuning = PrefixTuningPool(model_config, adapters_config)
+ # Add Prompt Tuning
+ if self.add_base_adapters:
+ if self.support_prompt_tuning:
+ self.prompt_tuning = PromptTuningLayer(model_config, adapters_config, self.get_input_embeddings())
- def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True):
+ def init_adapters(self, model_config, adapters_config):
"""
This method initializes adapter modules and fusion modules from the model config.
"""
@@ -443,19 +453,22 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr
# Initialize adapters config
init_adapters_config(self, model_config, adapters_config)
+
+ # Initialize adapter types defined in interface
+ if getattr(self.base_model, "adapter_interface", None) is not None:
+ for adapter_type in self.base_model.adapter_interface.adapter_methods:
+ init_func = METHOD_INIT_MAPPING[adapter_type]
+ init_func(self)
+ else:
+ self._default_init_adapter_methods(self.config, self.adapters_config)
+
# Initialize adapters in all submodules
self._init_adapters_submodules(self.config, self.adapters_config)
# Link all prefix tunings
- if add_prefix_tuning_pool:
- self.base_model.prefix_tuning = PrefixTuningPool(self.config, self.adapters_config)
+ if hasattr(self.base_model, "prefix_tuning"):
self.apply_to_adapter_layers(lambda i, layer: self._link_prefix_to_pool(layer))
- # Add Prompt Tuning
- if self.add_base_adapters:
- if self.support_prompt_tuning:
- self.prompt_tuning = PromptTuningLayer(model_config, self.adapters_config, self.get_input_embeddings())
-
# Initialize adapters from config
for adapter_name in self.adapters_config:
self._add_adapter_weights(adapter_name)
@@ -468,6 +481,35 @@ def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=Tr
self._add_tied_weights_keys()
+ def supports_adapter(self, type_or_config: Union[str, AdapterConfig]) -> bool:
+ """
+ Checks if the model supports a given adapter method type or adapter config.
+
+ Args:
+ type_or_config (Union[str, AdapterConfig]): The adapter method type or adapter config to check.
+
+ Returns:
+ bool: True if the adapter type is supported, False otherwise.
+ """
+ if isinstance(type_or_config, AdapterConfig):
+ types = AdapterMethod.get_from_config(type_or_config)
+ else:
+ types = [type_or_config]
+
+ supported = []
+ for _type in types:
+ if getattr(self.base_model, "adapter_interface", None) is not None:
+ supported.append(_type in self.base_model.adapter_interface.adapter_methods)
+ elif _type == AdapterMethod.prompt_tuning:
+ supported.append(self.base_model.support_prompt_tuning)
+ elif _type == AdapterMethod.invertible:
+ supported.append(
+ isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin)
+ )
+ else:
+ supported.append(True)
+ return all(supported)
+
# These methods have to be implemented by every deriving class:
@abstractmethod
@@ -593,6 +635,10 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
the adapter is added but not activated.
"""
config = AdapterConfig.load(config) # ensure config is ok and up-to-date
+ # check if config is valid for this model
+ config_or_type = config or AdapterMethod.bottleneck
+ if not self.supports_adapter(config_or_type):
+ raise ValueError(f"Adapter config or type '{config_or_type}' is not supported by this model.")
# In case adapter already exists and we allow overwriting, explicitly delete the existing one first
if overwrite_ok and adapter_name in self.adapters_config:
self.delete_adapter(adapter_name)
@@ -1200,12 +1246,10 @@ def get_adapter(self, name) -> dict:
# global weights are saved at index -1
if name in self.base_model.shared_parameters:
destination[-1]["shared"] = self.base_model.shared_parameters[name]
- if (
- isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin)
- ) and name in self.invertible_adapters:
+ if self.supports_adapter("invertible") and name in self.invertible_adapters:
destination[-1]["invertible"] = self.invertible_adapters[name]
- if self.support_prompt_tuning:
+ if self.supports_adapter("prompt_tuning"):
prompt_tuning = self.prompt_tuning.get_adapter(name)
if prompt_tuning is not None:
destination[-1]["prompt"] = prompt_tuning
@@ -1614,6 +1658,8 @@ def save_pretrained(
lambda i, layer: layer.set_pool(None) if isinstance(layer, PrefixTuningLayer) else None
)
+ if interface := getattr(self.base_model, "adapter_interface", None):
+ interface._save(save_directory, self.config)
super().save_pretrained(save_directory, **kwargs)
# Remove adapters config
del self.config.adapters
@@ -1686,13 +1732,37 @@ def gradient_checkpointing_function(function, *args, **kwargs):
@inherit_doc
class ModelBaseAdaptersMixin(ModelAdaptersMixin):
+ adapter_interface: AdapterModelInterface = None
add_base_adapters = True
- def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True):
- super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool)
+ def init_adapters(self, model_config, adapters_config):
+ super().init_adapters(model_config, adapters_config)
patch_forward(self)
+ # Adapter Interface Methods
+
+ def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
+ for i, layer in enumerate(multigetattr(self, self.adapter_interface.model_layers)):
+ yield i, layer
+
+ def get_layer(self, idx: int) -> nn.Module:
+ return multigetattr(self, self.adapter_interface.model_layers)[idx]
+
+ def iter_attentions(self) -> Iterable[Tuple[int, Literal["self", "cross"], nn.Module]]:
+ for i, layer in self.iter_layers():
+ if multihasattr(layer, self.adapter_interface.layer_self_attn or ""):
+ yield i, "self", multigetattr(layer, self.adapter_interface.layer_self_attn)
+ if multihasattr(layer, self.adapter_interface.layer_cross_attn or ""):
+ yield i, "cross", multigetattr(layer, self.adapter_interface.layer_cross_attn)
+
+ def iter_layer_ffns(self) -> Iterable[Tuple[int, Literal["intermediate", "output"], nn.Module]]:
+ for i, layer in self.iter_layers():
+ if intermediate_proj := multigetattr(layer, self.adapter_interface.layer_intermediate_proj):
+ yield i, "intermediate", intermediate_proj
+ if output_proj := multigetattr(layer, self.adapter_interface.layer_output_proj):
+ yield i, "output", output_proj
+
def post_embedding_forward(self, module, args, embedding_output):
if isinstance(self, InvertibleAdaptersMixin) or isinstance(self, InvertibleAdaptersWrapperMixin):
embedding_output = self.invertible_adapters_forward(embedding_output)
@@ -1724,6 +1794,12 @@ def _init_adapters_submodules(self, model_config, adapters_config):
"""
pass
+ def _default_init_adapter_methods(self, model_config, adapters_config):
+ """
+ Init default adapter methods in base model. This is done in sub-models, so don't do anything here.
+ """
+ pass
+
@inherit_doc
class ModelWithHeadsAdaptersMixin(ModelAdaptersMixin):
@@ -1734,8 +1810,8 @@ class ModelWithHeadsAdaptersMixin(ModelAdaptersMixin):
def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
- def init_adapters(self, model_config, adapters_config, add_prefix_tuning_pool=True):
- super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool=add_prefix_tuning_pool)
+ def init_adapters(self, model_config, adapters_config):
+ super().init_adapters(model_config, adapters_config)
self._convert_to_flex_head = False
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
diff --git a/src/adapters/models/clip/mixin_clip.py b/src/adapters/models/clip/mixin_clip.py
index 36f16b240a..55e823328e 100644
--- a/src/adapters/models/clip/mixin_clip.py
+++ b/src/adapters/models/clip/mixin_clip.py
@@ -5,7 +5,7 @@
from ...composition import adjust_tensors_for_parallel_
from ...methods.bottleneck import BottleneckLayer
from ...methods.lora import LoRALinear
-from ...methods.prefix_tuning import PrefixTuningLayer
+from ...methods.prefix_tuning import PrefixTuningLayer, PrefixTuningPool
from ...methods.reft import ReftLayer, hook_fn
from ...model_mixin import (
EmbeddingAdaptersMixin,
@@ -119,6 +119,7 @@ def _init_adapters_submodules(self, model_config, adapters_config):
if hasattr(module, "init_adapters"):
module.init_adapters(model_config.vision_config, adapters_config)
+ def _default_init_adapter_methods(self, model_config, adapters_config):
# Patch for ReFT initialization
for layer in self.text_model.encoder.layers:
if not hasattr(layer, "reft_layer"):
@@ -128,3 +129,6 @@ def _init_adapters_submodules(self, model_config, adapters_config):
if not hasattr(layer, "reft_layer"):
layer.reft_layer = ReftLayer("output", model_config.vision_config, adapters_config)
layer.register_forward_hook(hook_fn)
+
+ # Add prefix tuning
+ self.base_model.prefix_tuning = PrefixTuningPool(model_config, adapters_config)
diff --git a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py
index 50257d1536..1103eae310 100644
--- a/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py
+++ b/src/adapters/models/encoder_decoder/mixin_encoder_decoder.py
@@ -26,7 +26,7 @@ def init_adapters(self, model_config, adapters_config):
# Before initializing adapters, forward adding invertible adapters to the encoder
self.add_invertible_adapter = self.encoder.base_model.add_invertible_adapter
- super().init_adapters(model_config, adapters_config, add_prefix_tuning_pool=False)
+ super().init_adapters(model_config, adapters_config)
# ensure that encoder and decoder use the same shared parameters
if hasattr(self.encoder, "set_shared_parameters"):
diff --git a/src/adapters/utils.py b/src/adapters/utils.py
index 7c0540850a..f55a042673 100644
--- a/src/adapters/utils.py
+++ b/src/adapters/utils.py
@@ -54,6 +54,7 @@
EMBEDDING_FILE = "embedding.pt"
TOKENIZER_PATH = "tokenizer"
SETUP_CONFIG_NAME = "adapter_setup.json"
+INTERFACE_CONFIG_NAME = "adapter_interface.json"
ADAPTER_HUB_URL = "https://raw.githubusercontent.com/Adapter-Hub/Hub/master/dist/v2/"
ADAPTER_HUB_INDEX_FILE = ADAPTER_HUB_URL + "index/{}.json"
@@ -173,6 +174,39 @@ def inherit_doc(cls):
return cls
+def multigetattr(o: object, name: str, default=None) -> Optional[object]:
+ if not name:
+ return default
+ for n in name.split("."):
+ if hasattr(o, n):
+ o = getattr(o, n)
+ else:
+ return default
+ return o
+
+
+def multihasattr(o: object, name: str) -> bool:
+ if not name:
+ return False
+ parts = name.split(".")
+ for n in parts:
+ if hasattr(o, n):
+ o = getattr(o, n)
+ else:
+ return False
+ return True
+
+
+def multisetattr(o: object, name: str, value: object):
+ parts = name.split(".")
+ for n in parts[:-1]:
+ if hasattr(o, n):
+ o = getattr(o, n)
+ else:
+ return
+ setattr(o, parts[-1], value)
+
+
def urljoin(*args):
return "/".join([s.strip("/") for s in args])
diff --git a/src/adapters/wrappers/configuration.py b/src/adapters/wrappers/configuration.py
index 40dc421787..709bb54009 100644
--- a/src/adapters/wrappers/configuration.py
+++ b/src/adapters/wrappers/configuration.py
@@ -94,6 +94,8 @@ def init_adapters_config(
model.adapters_config = ModelAdaptersConfig()
elif model_config.adapters is not None and not isinstance(model_config.adapters, ModelAdaptersConfig):
model.adapters_config = ModelAdaptersConfig(**model_config.adapters)
+ if hasattr(model, "base_model") and model.base_model is not model:
+ model.base_model.adapters_config = model.adapters_config
# Convert AdapterFusions from old format for backwards compatibility
fusion_models = getattr(model_config, "adapter_fusion_models", [])
diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py
index 1f54e29ca1..70c4f285a9 100644
--- a/src/adapters/wrappers/model.py
+++ b/src/adapters/wrappers/model.py
@@ -10,9 +10,12 @@
from transformers.models.auto.configuration_auto import model_type_to_module_name
from ..configuration import ModelAdaptersConfig
+from ..interface import AdapterModelInterface
from ..model_mixin import (
+ EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
ModelAdaptersMixin,
+ ModelBaseAdaptersMixin,
ModelUsingSubmodelsAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)
@@ -40,7 +43,9 @@ def replace_with_adapter_class(module: nn.Module, modules_with_adapters) -> None
module.__class__.__name__, (MODEL_MIXIN_MAPPING[module.__class__.__name__], module.__class__), {}
)
module.__class__ = model_class
- elif module.__class__.__module__.startswith("transformers.models"):
+ elif module.__class__.__module__.startswith("transformers.models") or module.__class__.__module__.startswith(
+ "adapters.wrappers.model"
+ ):
try:
module_class = getattribute_from_module(modules_with_adapters, module.__class__.__name__ + "WithAdapters")
module.__class__ = module_class
@@ -49,30 +54,51 @@ def replace_with_adapter_class(module: nn.Module, modules_with_adapters) -> None
pass
-def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] = None) -> None:
+def init(
+ model: PreTrainedModel,
+ adapters_config: Optional[ModelAdaptersConfig] = None,
+ interface: Optional[AdapterModelInterface] = None,
+) -> None:
if isinstance(model, ModelAdaptersMixin):
return model
- # First, replace original module classes with their adapters counterparts
- model_name = get_module_name(model.config.model_type)
- modules_with_adapters = importlib.import_module(f".{model_name}.modeling_{model_name}", "adapters.models")
- submodules = list(model.modules())
-
- # Replace the base model class
- replace_with_adapter_class(submodules.pop(0), modules_with_adapters)
-
- # Check if the base model class derives from ModelUsingSubmodelsAdaptersMixin
- if isinstance(model, ModelUsingSubmodelsAdaptersMixin):
- # Before initializing the submodels, make sure that adapters_config is set for the whole model.
- # Otherwise, it would not be shared between the submodels.
- init_adapters_config(model, model.config, adapters_config)
- adapters_config = model.adapters_config
- model.init_submodels()
- submodules = []
-
- # Change the class of all child modules to their adapters class
- for module in submodules:
- replace_with_adapter_class(module, modules_with_adapters)
+ if interface is not None:
+ base_model = model.base_model
+ model_class_name = base_model.__class__.__name__
+ model_class = type(
+ model_class_name,
+ (EmbeddingAdaptersMixin, ModelBaseAdaptersMixin, base_model.__class__),
+ {},
+ )
+ base_model.__class__ = model_class
+ base_model.adapter_interface = interface
+ base_model.support_prompt_tuning = False # HACK: will be set to true if init_prompt_tuning() is called
+ else:
+ # First, replace original module classes with their adapters counterparts
+ model_name = get_module_name(model.config.model_type)
+ try:
+ modules_with_adapters = importlib.import_module(f".{model_name}.modeling_{model_name}", "adapters.models")
+ except ImportError:
+ raise ValueError(
+ f"Model {model_name} is not pre-supported by adapters. Please pass an `interface` explicitly."
+ )
+ submodules = list(model.modules())
+
+ # Replace the base model class
+ replace_with_adapter_class(submodules.pop(0), modules_with_adapters)
+
+ # Check if the base model class derives from ModelUsingSubmodelsAdaptersMixin
+ if isinstance(model, ModelUsingSubmodelsAdaptersMixin):
+ # Before initializing the submodels, make sure that adapters_config is set for the whole model.
+ # Otherwise, it would not be shared between the submodels.
+ init_adapters_config(model, model.config, adapters_config)
+ adapters_config = model.adapters_config
+ model.init_submodels()
+ submodules = []
+
+ # Change the class of all child modules to their adapters class
+ for module in submodules:
+ replace_with_adapter_class(module, modules_with_adapters)
# Next, check if model class itself is not replaced and has an adapter-supporting base class
if not isinstance(model, ModelAdaptersMixin):
@@ -98,6 +124,7 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig]
def load_model(
model_name_or_path: Optional[Union[str, os.PathLike]],
model_class: Type[PreTrainedModel],
+ interface: Optional[AdapterModelInterface] = None,
*model_args: Any,
**kwargs: Any,
) -> PreTrainedModel:
@@ -109,6 +136,9 @@ def load_model(
Parameter identical to PreTrainedModel.from_pretrained
model_class (`PreTrainedModel` or `AutoModel`):
The model class to load (e.g. EncoderDecoderModel and EncoderDecoderAdapterModel both work)
+ interface (`AdapterModelInterface`, *optional*):
+ The custom adapter interface to use for the model, to be passed to the init() method.
+ If not provided, init() will try to use one of the built-in model integrations.
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments will be passed to the underlying model's `__init__` method.
kwargs (remaining dictionary of keyword arguments, *optional*):
@@ -120,15 +150,20 @@ def load_model(
old_init = model_class.__init__
+ # Check if we can find an interface file
+ if interface is None:
+ try:
+ interface = AdapterModelInterface._load(model_name_or_path, **kwargs)
+ except EnvironmentError:
+ pass
+
def new_init(self, config, *args, **kwargs):
old_init(self, config, *args, **kwargs)
- init(self)
+ init(self, interface=interface)
# wrap model after it is initialized but before the weights are loaded
- model_class.__init__ = new_init
- model = model_class.from_pretrained(model_name_or_path, *model_args, **kwargs)
-
- # restore original __init__ function for when other models of the same type are created
- model_class.__init__ = old_init
+ new_model_class = type(model_class.__name__, (model_class,), {})
+ new_model_class.__init__ = new_init
+ model = new_model_class.from_pretrained(model_name_or_path, *model_args, **kwargs)
return model
diff --git a/tests/test_methods/base.py b/tests/test_methods/base.py
index f5e53fedd6..ef44167143 100644
--- a/tests/test_methods/base.py
+++ b/tests/test_methods/base.py
@@ -4,7 +4,7 @@
import torch
import adapters
-from adapters import AutoAdapterModel
+from adapters import ADAPTER_MODEL_MAPPING, AutoAdapterModel
from transformers import AutoFeatureExtractor, AutoTokenizer, GlueDataset, GlueDataTrainingArguments
from transformers.testing_utils import torch_device
@@ -79,6 +79,18 @@ def build_generate_input(self, shape):
"""The generate() functions for inference require different inputs depeding on the model type. E.g. the text models require input_ids, where as the audio models require input_features"""
return self.build_rand_ids_tensor(self.input_shape if not shape else shape).to(torch_device)
+ def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config=None):
+ if self.config_class not in ADAPTER_MODEL_MAPPING:
+ self.skipTest("Does not support flex heads.")
+ model = AutoAdapterModel.from_config(self.config())
+
+ # add two adapters: one will be trained and the other should be frozen
+ model.add_adapter(trained_adapter_name, config=adapter_config)
+ model.add_adapter(frozen_adapter_name, config=adapter_config)
+ self.add_head(model, trained_adapter_name)
+
+ return model
+
class TextAdapterTestBase(AbstractAdapterTestBase):
"""Base class for adapter tests for text models. Text models test classes should inherit from this class and override the attributes and functions as needed."""
diff --git a/tests/test_methods/method_test_impl/base.py b/tests/test_methods/method_test_impl/base.py
index 7f7e0ba83b..b970f2a3ee 100644
--- a/tests/test_methods/method_test_impl/base.py
+++ b/tests/test_methods/method_test_impl/base.py
@@ -6,7 +6,7 @@
import torch
import adapters
-from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer, AutoAdapterModel
+from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer
from adapters.heads import CausalLMHead
from adapters.utils import WEIGHTS_NAME
from adapters.wrappers import load_model
@@ -188,8 +188,11 @@ def run_forward_test(self, model, adapter_config, dtype=torch.float32):
model.set_active_adapters(None)
model.delete_adapter(name)
+ def create_twin_models(self):
+ return create_twin_models(self.model_class, self.config)
+
def run_load_test(self, adapter_config):
- model1, model2 = create_twin_models(self.model_class, self.config)
+ model1, model2 = self.create_twin_models()
name = "dummy_adapter"
model1.add_adapter(name, config=adapter_config)
@@ -233,8 +236,8 @@ def run_full_model_load_test(self, adapter_config):
model2, loading_info = load_model(temp_dir, self.model_class, output_loading_info=True)
# check if all weights were loaded
- self.assertEqual(0, len(loading_info["missing_keys"]))
- self.assertEqual(0, len(loading_info["unexpected_keys"]))
+ self.assertEqual(0, len(loading_info["missing_keys"]), loading_info["missing_keys"])
+ self.assertEqual(0, len(loading_info["unexpected_keys"]), loading_info["unexpected_keys"])
# check if adapter was correctly loaded
self.assertTrue(name in model2.adapters_config)
@@ -275,14 +278,7 @@ def trainings_run(self, model, lr=1.0, steps=8, batch_size=2, gradient_accumulat
def run_train_test(self, adapter_config, filter_keys):
if not self.do_run_train_tests:
self.skipTest("Skipping training tests. Set `do_run_train_tests=True` to run them.")
- if self.config_class not in ADAPTER_MODEL_MAPPING:
- self.skipTest("Does not support flex heads.")
- model = AutoAdapterModel.from_config(self.config())
-
- # add two adapters: one will be trained and the other should be frozen
- model.add_adapter("mrpc", config=adapter_config)
- model.add_adapter("dummy", config=adapter_config)
- self.add_head(model, "mrpc")
+ model = self._init_model_for_train_run("mrpc", "dummy", adapter_config)
self._assert_adapter_available(model, "mrpc")
self._assert_adapter_available(model, "dummy")
@@ -314,7 +310,8 @@ def run_train_test(self, adapter_config, filter_keys):
def has_tied_embeddings(k):
tied_embeddings = hasattr(model.config, "tie_word_embeddings") and model.config.tie_word_embeddings
is_tied_layer = (
- isinstance(model.heads["mrpc"], CausalLMHead)
+ hasattr(model, "heads")
+ and isinstance(model.heads["mrpc"], CausalLMHead)
and "heads.{}.{}.weight".format("mrpc", len(model.heads["mrpc"]._modules) - 1) in k
)
return tied_embeddings and is_tied_layer
@@ -322,7 +319,7 @@ def has_tied_embeddings(k):
for (k1, v1), (k2, v2) in zip(state_dict_pre.items(), model.state_dict().items()):
# move both to the same device to avoid device mismatch errors
v1, v2 = v1.to(v2.device), v2
- if "mrpc" in k1 and not has_tied_embeddings(k1):
+ if "mrpc" in k1 and not has_tied_embeddings(k1) or not k1.startswith(model.base_model_prefix):
adapters_with_change |= not torch.equal(v1, v2)
else:
base_with_change |= not torch.equal(v1, v2)
@@ -463,7 +460,10 @@ def run_same_weights_test(self, adapter_config, filter_keys):
self.assertTrue(torch.equal(v1, v2), msg=f"{k1} has different weights than {k2}")
# Check multiple models with one adapter with same config
- model1, model2 = create_twin_models(self.model_class, self.config)
+ if hasattr(self, "adapter_interface") and self.adapter_interface:
+ model1, model2 = create_twin_models(self.model_class, self.config, self.adapter_interface)
+ else:
+ model1, model2 = create_twin_models(self.model_class, self.config)
model1.add_adapter("adapter", config=adapter_config)
model2.add_adapter("adapter", config=adapter_config)
per_model_filter_keys = {"adapter": [k.format(name="adapter") for k in filter_keys]}
diff --git a/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py b/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py
index 196380524f..d74fc48619 100644
--- a/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py
+++ b/tests/test_methods/method_test_impl/core/test_adapter_backward_compability.py
@@ -9,8 +9,11 @@
@require_torch
class CompabilityTestMixin:
+ def create_twin_models(self):
+ return create_twin_models(self.model_class, self.config)
+
def test_load_old_non_linearity(self):
- model1, model2 = create_twin_models(self.model_class, self.config)
+ model1, model2 = self.create_twin_models()
config = SeqBnConfig(non_linearity="gelu")
name = "dummy"
model1.add_adapter(name, config=config)
diff --git a/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py b/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py
index a41b862004..93832106ac 100644
--- a/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py
+++ b/tests/test_methods/method_test_impl/embeddings/test_adapter_embeddings.py
@@ -3,7 +3,6 @@
import torch
-from adapters import AutoAdapterModel
from transformers import AutoTokenizer, Trainer, TrainingArguments
from transformers.testing_utils import require_torch, torch_device
@@ -85,11 +84,10 @@ def test_training_embedding(self):
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
- model = AutoAdapterModel.from_config(self.config())
+ model = self._init_model_for_train_run("test", "dummy")
+
model.add_embeddings("test", tokenizer)
self.assertEqual(model.active_embeddings, "test")
- model.add_adapter("test")
- self.add_head(model, "test")
model.train_adapter("test", train_embeddings=True)
for k, v in filter_parameters(model, "adapters.test.").items():
@@ -105,7 +103,7 @@ def test_training_embedding(self):
training_args = TrainingArguments(
output_dir="./examples",
do_train=True,
- learning_rate=0.4,
+ learning_rate=1.0,
max_steps=15,
use_cpu=True,
per_device_train_batch_size=2,
@@ -140,17 +138,19 @@ def test_training_embedding(self):
and "embed_tokens" not in k1
and "shared" not in k1
and "wte" not in k1
+ and "score" not in k1
)
)
def test_reference_embedding(self):
- model = AutoAdapterModel.from_config(self.config()) # self.get_model()
+ model = self.get_model()
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
new_tokenizer = AutoTokenizer.from_pretrained("tests/fixtures/SiBERT")
model.add_embeddings("test", new_tokenizer, "default", tokenizer)
+ model.to(torch_device)
default_embedding = model.base_model.loaded_embeddings["default"]
test_embedding = model.base_model.loaded_embeddings["test"]
@@ -165,8 +165,8 @@ def test_reference_embedding(self):
if len(input_test) >= 5:
break
- input_default = torch.tensor([input_default])
- input_test = torch.tensor([input_test])
+ input_default = torch.tensor([input_default]).to(torch_device)
+ input_test = torch.tensor([input_test]).to(torch_device)
default = default_embedding(input_default)
test = test_embedding(input_test)
diff --git a/tests/test_methods/method_test_impl/peft/test_adapter_common.py b/tests/test_methods/method_test_impl/peft/test_adapter_common.py
index 6116804c3c..a69ab0d406 100644
--- a/tests/test_methods/method_test_impl/peft/test_adapter_common.py
+++ b/tests/test_methods/method_test_impl/peft/test_adapter_common.py
@@ -13,8 +13,6 @@
DoubleSeqBnConfig,
DoubleSeqBnInvConfig,
Fuse,
- InvertibleAdaptersMixin,
- InvertibleAdaptersWrapperMixin,
MAMConfig,
SeqBnConfig,
SeqBnInvConfig,
@@ -75,7 +73,7 @@ def test_delete_adapter(self):
def test_add_adapter_with_invertible(self):
model = self.get_model().base_model
model.eval()
- if not isinstance(model, InvertibleAdaptersMixin) and not isinstance(model, InvertibleAdaptersWrapperMixin):
+ if not model.supports_adapter("invertible"):
self.skipTest("Model does not support invertible adapters.")
for adapter_config in [SeqBnInvConfig(), DoubleSeqBnInvConfig()]:
@@ -123,7 +121,7 @@ def test_delete_adapter_with_invertible(self):
"""Tests if the invertible adapters are deleted correctly."""
model = self.get_model().base_model
model.eval()
- if not isinstance(model, InvertibleAdaptersMixin) and not isinstance(model, InvertibleAdaptersWrapperMixin):
+ if not model.supports_adapter("invertible"):
self.skipTest("Model does not support invertible adapters.")
# iterate through all invertible adapter configs
@@ -227,6 +225,8 @@ def test_forward_bottleneck(self):
def test_invertible_adapter_forward(self):
model = self.get_model()
model.eval()
+ if not model.supports_adapter("invertible"):
+ self.skipTest("Model does not support invertible adapters.")
for adapter_config, _ in self.inv_adapter_configs_to_test:
with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
@@ -248,6 +248,8 @@ def test_model_config_serialization(self):
"""
for k, v in ADAPTER_CONFIG_MAP.items():
model = self.get_model()
+ if not model.supports_adapter(v):
+ continue
# HACK: reduce the reduction factor such that
# the small test model can have a phm_dim of 4
if hasattr(v, "phm_layer") and v.phm_layer:
@@ -260,15 +262,19 @@ def test_model_adapter_summary(self):
# count model parameters before
model = self.get_model()
model_no_params = sum(p.numel() for p in model.parameters())
+ added = []
for k, v in ADAPTER_CONFIG_MAP.items():
+ if not model.supports_adapter(v):
+ continue
# HACK: reduce the reduction factor such that
# the small test model can have a phm_dim of 4
if hasattr(v, "phm_layer") and v.phm_layer:
v = v.__class__(reduction_factor=4)
model.add_adapter(k, config=v)
+ added.append(k)
summary = model.adapter_summary(as_dict=True)
- self.assertEqual(len(ADAPTER_CONFIG_MAP) + 1, len(summary))
- for name in ADAPTER_CONFIG_MAP.keys():
+ self.assertEqual(len(added) + 1, len(summary))
+ for name in added:
self.assertTrue(any([row["name"] == name for row in summary]))
self.assertEqual(model_no_params, summary[-1]["#param"])
diff --git a/tests/test_methods/method_test_impl/peft/test_lora.py b/tests/test_methods/method_test_impl/peft/test_lora.py
index 70d8e0a447..1014f9a36a 100644
--- a/tests/test_methods/method_test_impl/peft/test_lora.py
+++ b/tests/test_methods/method_test_impl/peft/test_lora.py
@@ -2,7 +2,7 @@
import torch
-from adapters import LoRAConfig
+from adapters import AdapterConfig, LoRAConfig
from adapters.methods.lora import LoRALayer
from tests.test_methods.method_test_impl.base import AdapterMethodBaseTestMixin
from transformers.testing_utils import require_torch
@@ -23,21 +23,30 @@ def test_merging_with_other_adapters(self):
model.add_adapter("lora", config="lora")
# Add different adapters
- model.add_adapter("bottleneck", config="seq_bn")
- model.add_adapter("prompt", config="prompt_tuning")
- model.add_adapter("prefix", config="prefix_tuning")
- model.add_adapter("ia3", config="ia3")
- model.add_adapter("unipelt", config="unipelt")
- model.add_adapter("mam", config="mam")
- model.add_adapter("compacter", config="compacter[phm_dim=2, reduction_factor=8]")
+ adapter_methods = [
+ "seq_bn",
+ "prompt_tuning",
+ "prefix_tuning",
+ "ia3",
+ "unipelt",
+ "mam",
+ "compacter[phm_dim=2, reduction_factor=8]",
+ ]
+
+ for adapter_method in adapter_methods:
+ config = AdapterConfig.load(adapter_method)
+ if model.supports_adapter(config):
+ model.add_adapter(adapter_method, config=config)
# Merging adapters with different architectures with LoRA should raise a ValueError
- for adapter_architecture in ["bottleneck", "prompt", "prefix", "ia3", "unipelt", "mam", "compacter"]:
- with self.subTest(adapter_architecture=adapter_architecture):
+ for adapter_method in adapter_methods:
+ with self.subTest(adapter_architecture=adapter_method):
+ if adapter_method not in model.adapters_config:
+ continue
with self.assertRaises(ValueError):
model.average_adapter(
- adapter_name=f"average_lora_{adapter_architecture}",
- adapter_list=[adapter_architecture, "lora"],
+ adapter_name=f"average_lora_{adapter_method}",
+ adapter_list=[adapter_method, "lora"],
weights=[0.5, 0.5],
combine_strategy="linear",
)
diff --git a/tests/test_methods/method_test_impl/utils.py b/tests/test_methods/method_test_impl/utils.py
index 473c422e60..445e9d2e63 100644
--- a/tests/test_methods/method_test_impl/utils.py
+++ b/tests/test_methods/method_test_impl/utils.py
@@ -10,7 +10,7 @@
global_rng = random.Random()
-def create_twin_models(model_class, config_creator=None):
+def create_twin_models(model_class, config_creator=None, interface=None):
if config_creator and model_class.__name__.startswith("Auto"):
model_config = config_creator()
model1 = model_class.from_config(model_config)
@@ -20,7 +20,7 @@ def create_twin_models(model_class, config_creator=None):
else:
model_config = model_class.config_class()
model1 = model_class(model_config)
- init(model1)
+ init(model1, interface=interface)
model1.eval()
# create a twin initialized with the same random weights
model2 = copy.deepcopy(model1)
diff --git a/tests/test_methods/test_on_custom_interface.py b/tests/test_methods/test_on_custom_interface.py
new file mode 100644
index 0000000000..2952fef3cc
--- /dev/null
+++ b/tests/test_methods/test_on_custom_interface.py
@@ -0,0 +1,125 @@
+import unittest
+
+import pytest
+
+import adapters
+from adapters import AdapterModelInterface, ConfigUnion, DoubleSeqBnConfig, LoRAConfig, ParBnConfig
+from transformers import Gemma2ForCausalLM, Gemma2ForSequenceClassification
+from transformers.models.gemma2.configuration_gemma2 import Gemma2Config
+from transformers.testing_utils import torch_device
+
+from .base import TextAdapterTestBase
+from .generator import generate_method_tests, require_torch
+from .method_test_impl.peft.test_adapter_common import BottleneckAdapterTestMixin
+from .method_test_impl.utils import create_twin_models, make_config
+
+
+class CustomInterfaceModelTestBase(TextAdapterTestBase):
+ model_class = Gemma2ForCausalLM
+ config_class = Gemma2Config
+ config = make_config(
+ Gemma2Config,
+ hidden_size=32,
+ num_hidden_layers=4,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ intermediate_size=16,
+ pad_token_id=0,
+ )
+ tokenizer_name = "yujiepan/gemma-2-tiny-random"
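+ # Maps Gemma2 module names onto the generic embedding, layer, and attention/MLP hooks of the adapter plugin interface.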
+ adapter_interface = AdapterModelInterface(
+ adapter_methods=["bottleneck", "lora", "reft", "invertible"],
+ model_embeddings="embed_tokens",
+ model_layers="layers",
+ layer_self_attn="self_attn",
+ layer_cross_attn=None,
+ attn_k_proj="k_proj",
+ attn_q_proj="q_proj",
+ attn_v_proj="v_proj",
+ attn_o_proj="o_proj",
+ layer_intermediate_proj="mlp.up_proj",
+ layer_output_proj="mlp.down_proj",
+ layer_pre_self_attn="input_layernorm",
+ layer_pre_cross_attn=None,
+ layer_pre_ffn="pre_feedforward_layernorm",
+ layer_ln_1="post_attention_layernorm",
+ layer_ln_2="post_feedforward_layernorm",
+ )
+
+ def get_model(self):
+ model = Gemma2ForCausalLM(self.config())
+ adapters.init(model, interface=self.adapter_interface)
+ model.to(torch_device)
+ return model
+
+ def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, adapter_config=None):
+ model = Gemma2ForSequenceClassification(self.config())
+ adapters.init(model, interface=self.adapter_interface)
+
+ model.add_adapter(trained_adapter_name, config=adapter_config or LoRAConfig(init_weights="bert"))
+ model.add_adapter(frozen_adapter_name, config=adapter_config or LoRAConfig(init_weights="bert"))
+
+ return model
+
+ adapter_configs_to_test = [
+ (DoubleSeqBnConfig(), ["adapters.{name}."]),
+ (ParBnConfig(init_weights="bert"), ["adapters.{name}."]),
+ ]
+
+ def create_twin_models(self):
+ return create_twin_models(self.model_class, self.config, self.adapter_interface)
+
+ def test_load_mam_adapter(self):
+ self.skipTest("Does not support prefix tuning.")
+
+ def test_train_mam_adapter(self):
+ self.skipTest("Does not support prefix tuning.")
+
+ def test_merging_with_other_adapters(self):
+ self.skipTest("Does not support all required methods yet.")
+
+ def test_supports_adapter(self):
+ model = self.get_model()
+ model.eval()
+
+ config = "unipelt"
+ with self.assertRaises(ValueError):
+ model.add_adapter("my_adapter", config=config)
+
+
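+# Generate the shared method test classes for this model, skipping methods the custom interface setup does not support.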
+method_tests = generate_method_tests(
+ CustomInterfaceModelTestBase,
+ not_supported=[
+ "ConfigUnion",
+ "ClassConversion",
+ "Heads",
+ "PrefixTuning",
+ "PromptTuning",
+ "UniPELT",
+ "Composition",
+ "Bottleneck",
+ ],
+)
+
+for test_class_name, test_class in method_tests.items():
+ globals()[test_class_name] = test_class
+
+
+@require_torch
+@pytest.mark.bottleneck
+class Bottleneck(
+ CustomInterfaceModelTestBase,
+ BottleneckAdapterTestMixin,
+ unittest.TestCase,
+):
+ def test_get_adapter(self):
+ model = self.get_model()
+ model.eval()
+ n_layers = len(list(model.iter_layers()))
+
+ for adapter_config, n_expected in [
+ (DoubleSeqBnConfig(), n_layers * 2),
+ (ConfigUnion(LoRAConfig(), ParBnConfig()), n_layers * 2),
+ ]:
+ with self.subTest(model_class=model.__class__.__name__, config=adapter_config.__class__.__name__):
+ self.run_get_test(model, adapter_config, n_expected)
diff --git a/tests/test_misc/test_custom_interface_compat.py b/tests/test_misc/test_custom_interface_compat.py
new file mode 100644
index 0000000000..1e19aade28
--- /dev/null
+++ b/tests/test_misc/test_custom_interface_compat.py
@@ -0,0 +1,209 @@
+import os
+import tempfile
+import unittest
+
+import torch
+
+import adapters
+from adapters import AdapterModelInterface, AutoAdapterModel
+from adapters.utils import WEIGHTS_NAME
+from parameterized import parameterized
+from tests.test_methods.method_test_impl.utils import ids_tensor
+from transformers import AutoModel, AutoModelForCausalLM, BertConfig, LlamaConfig
+from transformers.testing_utils import require_torch, torch_device
+
+
+@require_torch
+class CustomInterfaceCompatTest(unittest.TestCase):
+ # Checks that the custom plugin interface produces the same results as the built-in AdapterModel implementation.
+
+ llama_config = LlamaConfig(
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ pad_token_id=0,
+ )
+ bert_config = BertConfig(
+ hidden_size=32,
+ num_hidden_layers=5,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ pad_token_id=0,
+ )
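+ # Interface definitions mapping the Llama and BERT module layouts onto the generic adapter hooks.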
+ llama_interface = AdapterModelInterface(
+ adapter_methods=["bottleneck", "lora", "reft", "invertible"],
+ model_embeddings="embed_tokens",
+ model_layers="layers",
+ layer_self_attn="self_attn",
+ layer_cross_attn=None,
+ attn_k_proj="k_proj",
+ attn_q_proj="q_proj",
+ attn_v_proj="v_proj",
+ attn_o_proj="o_proj",
+ layer_intermediate_proj="mlp.up_proj",
+ layer_output_proj="mlp.down_proj",
+ layer_pre_self_attn="input_layernorm",
+ layer_pre_cross_attn=None,
+ layer_pre_ffn="post_attention_layernorm",
+ layer_ln_1=None,
+ layer_ln_2=None,
+ )
+ bert_interface = AdapterModelInterface(
+ adapter_methods=["bottleneck", "lora", "reft", "prompt_tuning", "invertible"],
+ model_embeddings="embeddings",
+ model_layers="encoder.layer",
+ layer_self_attn="attention",
+ layer_cross_attn=None,
+ attn_k_proj="self.key",
+ attn_q_proj="self.query",
+ attn_v_proj="self.value",
+ attn_o_proj="output.dense",
+ layer_intermediate_proj="intermediate.dense",
+ layer_output_proj="output.dense",
+ layer_pre_self_attn="attention.self",
+ layer_pre_cross_attn=None,
+ layer_pre_ffn="intermediate",
+ layer_ln_1="attention.output.LayerNorm",
+ layer_ln_2="output.LayerNorm",
+ )
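+ # Bottleneck modules sit under different parent modules in the plugin interface vs. the built-in BERT implementation; these pairs rename saved weight keys accordingly.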
+ bert_bn_rewrites = [(".attention_adapters.", ".attention.output."), (".output_adapters.", ".output.")]
+
+ def create_twin_models(self, config, adapter_interface, hf_auto_model_class):
+ model1 = hf_auto_model_class.from_config(config)
+ adapters.init(model1, interface=adapter_interface)
+ model1.eval()
+ # create a twin initialized with the same random weights
+ model2 = AutoAdapterModel.from_pretrained(None, config=config, state_dict=model1.state_dict())
+ model2.eval()
+ return model1, model2
+
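+ # Each case: (test name, adapter config, model config, plugin interface, HF auto class[, weight-key rewrites for saved checkpoints]).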
+ @parameterized.expand(
+ [
+ ("LoRA_Llama", adapters.LoRAConfig(), llama_config, llama_interface, AutoModelForCausalLM),
+ ("LoRA_BERT", adapters.LoRAConfig(), bert_config, bert_interface, AutoModel),
+ ("LoReft_Llama", adapters.LoReftConfig(), llama_config, llama_interface, AutoModelForCausalLM),
+ ("LoReft_BERT", adapters.LoReftConfig(), bert_config, bert_interface, AutoModel),
+ (
+ "BnSeq_Llama",
+ adapters.SeqBnConfig(original_ln_before=False),
+ llama_config,
+ llama_interface,
+ AutoModelForCausalLM,
+ ),
+ (
+ "BnSeqInv_Llama",
+ adapters.SeqBnInvConfig(),
+ llama_config,
+ llama_interface,
+ AutoModelForCausalLM,
+ ),
+ (
+ "BnSeqPreLN_Llama",
+ adapters.SeqBnConfig(original_ln_before=True),
+ llama_config,
+ llama_interface,
+ AutoModelForCausalLM,
+ ),
+ ("BnPar_Llama", adapters.ParBnConfig(), llama_config, llama_interface, AutoModelForCausalLM),
+ (
+ "Bn2Seq_Llama",
+ adapters.DoubleSeqBnConfig(original_ln_before=True),
+ llama_config,
+ llama_interface,
+ AutoModelForCausalLM,
+ ),
+ (
+ "Bn2Par_Llama",
+ adapters.ParBnConfig(mh_adapter=True, output_adapter=True),
+ llama_config,
+ llama_interface,
+ AutoModelForCausalLM,
+ ),
+ (
+ "BnSeq_BERT",
+ adapters.SeqBnConfig(original_ln_before=False),
+ bert_config,
+ bert_interface,
+ AutoModel,
+ bert_bn_rewrites,
+ ),
+ (
+ "BnSeqInv_BERT",
+ adapters.SeqBnInvConfig(),
+ bert_config,
+ bert_interface,
+ AutoModel,
+ bert_bn_rewrites,
+ ),
+ (
+ "BnSeqPreLN_BERT",
+ adapters.SeqBnConfig(original_ln_before=True),
+ bert_config,
+ bert_interface,
+ AutoModel,
+ bert_bn_rewrites,
+ ),
+ ("BnPar_BERT", adapters.ParBnConfig(), bert_config, bert_interface, AutoModel, bert_bn_rewrites),
+ (
+ "Bn2Seq_BERT",
+ adapters.DoubleSeqBnConfig(original_ln_before=True),
+ bert_config,
+ bert_interface,
+ AutoModel,
+ bert_bn_rewrites,
+ ),
+ (
+ "Bn2Par_BERT",
+ adapters.ParBnConfig(mh_adapter=True, output_adapter=True),
+ bert_config,
+ bert_interface,
+ AutoModel,
+ bert_bn_rewrites,
+ ),
+ ("Prompt_BERT", adapters.PromptTuningConfig(), bert_config, bert_interface, AutoModel),
+ ]
+ )
+ def test_load_adapter(self, name, adapter_config, config, adapter_interface, hf_auto_model_class, rewrites=None):
+ custom_model, auto_model = self.create_twin_models(config, adapter_interface, hf_auto_model_class)
+
+ custom_model.add_adapter(name, config=adapter_config)
+ custom_model.set_active_adapters(name)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ custom_model.save_adapter(temp_dir, name)
+
+ # Check that there are actually weights saved
+ weights = torch.load(os.path.join(temp_dir, WEIGHTS_NAME), map_location="cpu")
+ self.assertTrue(len(weights) > 0)
+ # The weight names of the custom interface adapter and built-in adapter might be different.
+ # Apply weight renaming here if necessary.
+ if rewrites is not None:
+ for old, new in rewrites:
+ for key in list(weights.keys()):
+ if old in key:
+ new_key = key.replace(old, new)
+ weights[new_key] = weights.pop(key)
+
+ torch.save(weights, os.path.join(temp_dir, WEIGHTS_NAME))
+
+ # also tests that set_active works
+ loading_info = {}
+ auto_model.load_adapter(temp_dir, set_active=True, loading_info=loading_info)
+
+ # check if all weights were loaded
+ self.assertEqual(0, len(loading_info["missing_keys"]), loading_info["missing_keys"])
+ self.assertEqual(0, len(loading_info["unexpected_keys"]), loading_info["unexpected_keys"])
+
+ # check if adapter was correctly loaded
+ self.assertTrue(name in auto_model.adapters_config)
+
+ # check equal output
+ input_data = {"input_ids": ids_tensor((2, 128), 1000)}
+ custom_model.to(torch_device)
+ auto_model.to(torch_device)
+ output1 = custom_model(**input_data)
+ output2 = auto_model(**input_data)
+ self.assertEqual(len(output1), len(output2))
+ self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-5))