Pluggable Model Integration Interface (#738)
This PR drafts a new model integration interface which makes it easier to support new and custom model architectures for selected adapter methods without a full model implementation. This is done with the new `AdapterModelInterface` class, which translates from generic model access points to model-specific attribute names.

### Example usage

Basic interface for a Qwen model:
```python
import adapters
from adapters import AdapterModelInterface, LoRAConfig
from transformers import AutoModelForCausalLM

model_interface = AdapterModelInterface(
    adapter_methods=["lora", "reft"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)

model_name = "Qwen/Qwen2-0.5B"
model = AutoModelForCausalLM.from_pretrained(model_name)
adapters.init(model, interface=model_interface)

config = LoRAConfig()
# config = LoReftConfig()
model.add_adapter("my_adapter", config=config)

print(model.adapter_summary())
```

#### Extended interface

Additionally, the interface provides optional attributes that enable (almost) full bottleneck adapter support. Without the extended interface, bottleneck adapter support is very limited. Example for Gemma 2:
```python
adapter_interface = AdapterModelInterface(
    adapter_methods=["bottleneck", "lora", "reft"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
    layer_pre_self_attn="input_layernorm",
    layer_pre_cross_attn=None,
    layer_pre_ffn="pre_feedforward_layernorm",
    layer_ln_1="post_attention_layernorm",
    layer_ln_2="post_feedforward_layernorm",
)
```

### Additional novelties

- Adds `AdapterMethod` as an enum of all supported adapter method types (e.g. `AdapterMethod.bottleneck`, `AdapterMethod.lora`, ...)
- Adds a `supports_adapter()` method for easily checking whether a model instance supports a certain adapter method. This method can receive an `AdapterMethod` string or a config object:
```python
model.supports_adapter(AdapterMethod.prompt_tuning)
# or
model.supports_adapter(PromptTuningConfig())
```
(This method is supported both by models implemented via "classic" mixins and by models using the pluggable interface.)

### State of implementation

Supported adapter types:
- [x] LoRA
- [x] ReFT
- [x] Bottleneck / Compacter: **partial**, currently does **not** support:
  - [x] `is_parallel`, via extended interface
  - [x] `original_ln_before=True`, via extended interface
  - [ ] `original_ln_after=False` (e.g. used for `AdapterPlusConfig`)
- [x] Invertible adapters
- [ ] Prefix Tuning
- [x] Prompt Tuning: **partial**: attention mask modification only supports very specific model implementations

Supported features:
- [x] Embedding training
- [x] Fusion composition

**Not** to be supported:
- Parallel composition
- AdapterModel classes
Showing 28 changed files with 1,109 additions and 103 deletions.
@@ -0,0 +1,8 @@

Adapter Model Interface
=======================

.. autoclass:: adapters.AdapterModelInterface
    :members:

.. autoclass:: adapters.AdapterMethod
    :members:
@@ -0,0 +1,94 @@

# Custom Models

The _Adapters_ library provides a simple mechanism for integrating adapter methods into any available _Transformers_ model - including custom architectures.
This can be accomplished by defining a plugin interface instance of [`AdapterModelInterface`](adapters.AdapterModelInterface).
The following example shows how this looks for Gemma 2:

```python
import adapters
from adapters import AdapterModelInterface
from transformers import AutoModelForCausalLM

plugin_interface = AdapterModelInterface(
    adapter_methods=["lora", "reft"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", token="<YOUR_TOKEN>")
adapters.init(model, interface=plugin_interface)

model.add_adapter("my_adapter", config="lora")

print(model.adapter_summary())
```

## Walkthrough

Let's go through what happens in the example above step by step:

**1. Define adapter methods to plug into a model:**
The `adapter_methods` argument is the central parameter to configure which adapters will be supported in the model.
Here, we enable all LoRA and ReFT based adapters.
See [`AdapterMethod`](adapters.AdapterMethod) for valid options to specify here.
Check out [Adapter Methods](methods.md) for a detailed explanation of the methods.
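As a quick sanity check, an initialized model can also be queried at runtime via the new `supports_adapter()` method described in the PR summary; a short sketch, assuming the Gemma 2 setup above and that the method returns a boolean:

```python
from adapters import AdapterMethod

# After adapters.init(model, interface=plugin_interface) from the example above:
assert model.supports_adapter(AdapterMethod.lora)
# Prefix Tuning is not available via the plugin interface, so this should be False:
assert not model.supports_adapter(AdapterMethod.prefix_tuning)
```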
**2. Define layer and module names:**
While all Transformers layers share similar basic components, their implementations can differ in subtleties such as module names.
Therefore, the [`AdapterModelInterface`](adapters.AdapterModelInterface) needs to translate the model-specific module structure into a common set of access points for adapter implementations to hook in.
The remaining attributes in the definition above serve this purpose.
Their attribute names follow a common syntax that specifies their location and purpose:
- The initial part before the first "_" defines the base module relative to which the name should be specified.
- The remaining part after the first "_" defines the functional component.

E.g., `model_embeddings` identifies the embeddings layer (functional component) relative to the base model (location).
`layer_output_proj` identifies the FFN output projection relative to one Transformer layer.
Each attribute value may specify a direct submodule of the reference module (`"embed_tokens"`) or a multi-level path starting at the reference module (`"mlp.down_proj"`).
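Conceptually, each value is resolved by walking attribute names from its reference module. The library's internal lookup may differ; the following is only a minimal sketch of the idea, and `resolve_module` is a hypothetical helper, not part of the public API:

```python
from functools import reduce

def resolve_module(reference_module, dotted_path):
    # Follow each attribute name in turn,
    # e.g. "mlp.down_proj" -> reference_module.mlp.down_proj
    return reduce(getattr, dotted_path.split("."), reference_module)

# For one decoder layer of the Gemma 2 model above, this would give the FFN down projection:
# ffn_out_proj = resolve_module(model.model.layers[0], "mlp.down_proj")
```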
**3. (optional) Extended interface attributes:**
A couple of attributes in the [`AdapterModelInterface`](adapters.AdapterModelInterface) are only required for some adapter methods.
We don't need them in the above example for LoRA and ReFT, but to also support bottleneck adapters, the full interface looks as follows:
```python
adapter_interface = AdapterModelInterface(
    adapter_methods=["bottleneck", "lora", "reft"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
    layer_pre_self_attn="input_layernorm",
    layer_pre_cross_attn=None,
    layer_pre_ffn="pre_feedforward_layernorm",
    layer_ln_1="post_attention_layernorm",
    layer_ln_2="post_feedforward_layernorm",
)
```

**4. Initialize adapter methods in the model:**
Finally, we just need to apply the defined adapter integration to the target model.
This can be achieved using the usual `adapters.init()` method:
```python
adapters.init(model, interface=adapter_interface)
```
Now, you can use (almost) all functionality of the _Adapters_ library on the adapted model instance!
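From here, the usual training workflow should apply. A sketch, assuming the standard _Adapters_ training methods behave the same on interface-initialized models:

```python
model.add_adapter("my_adapter", config="seq_bn")  # bottleneck adapter, enabled via the extended interface
model.train_adapter("my_adapter")  # freeze the base model weights, train only the adapter
# ... run your training loop ...
model.save_adapter("./my_adapter", "my_adapter")
```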
## Limitations

The following features of the _Adapters_ library are not supported via the plugin interface approach:
- Prefix Tuning adapters
- Parallel composition blocks
- XAdapterModel classes
- Setting `original_ln_after=False` in bottleneck adapter configurations (this affects `AdapterPlusConfig`)
@@ -0,0 +1,121 @@

```python
import json
import os
from dataclasses import asdict, dataclass
from typing import List, Optional

from transformers.utils import cached_file

from . import __version__
from .utils import INTERFACE_CONFIG_NAME


class AdapterMethod:
    """
    Enum of all supported adapter method types.

    Attributes:
        bottleneck: Adapter methods using bottleneck layers.
        prefix_tuning: Adapter methods based on Prefix Tuning. Note that this is currently unsupported via AdapterModelInterface.
        lora: Adapter methods based on low-rank adaptation.
        prompt_tuning: Adapter methods based on Prompt Tuning.
        reft: Adapter methods based on Representation Fine-Tuning.
        invertible: Adapter methods using invertible modules.
    """

    bottleneck = "bottleneck"
    prefix_tuning = "prefix_tuning"
    lora = "lora"
    prompt_tuning = "prompt_tuning"
    reft = "reft"
    invertible = "invertible"

    @staticmethod
    def get_from_config(config) -> List[str]:
        """
        Get the adapter type from a given adapter config.

        Args:
            config: The adapter config.

        Returns:
            List[str]: The adapter type.
        """
        methods = []
        if getattr(config, "inv_adapter", False):
            methods.append(AdapterMethod.invertible)
        if config.architecture is None:
            methods.append(AdapterMethod.bottleneck)
        elif config.architecture == "union":
            for sub_config in config.configs:
                methods.extend(AdapterMethod.get_from_config(sub_config))
        else:
            methods.append(config.architecture)
        return methods


@dataclass
class AdapterModelInterface:
    """
    Defines the main interface for integrating adapter methods into a model class.
    This interface translates generic accessor names to model-specific attribute names.

    Args:
        adapter_methods (List[str]): List of adapter types that are supported by the model.
        model_embeddings (str): Name of the model's embedding layer.
        model_layers (str): Name of the model's layer list.
        layer_self_attn (str): Name of the self-attention layer in a transformer layer.
        layer_cross_attn (str): Name of the cross-attention layer in a transformer layer.
        attn_k_proj (str): Name of the key projection layer in an attention layer.
        attn_q_proj (str): Name of the query projection layer in an attention layer.
        attn_v_proj (str): Name of the value projection layer in an attention layer.
        attn_o_proj (str): Name of the output projection layer in an attention layer.
        layer_intermediate_proj (str): Name of the intermediate projection layer in a transformer layer.
        layer_output_proj (str): Name of the output projection layer in a transformer layer.
        layer_pre_self_attn (Optional[str]): Hook point directly before the self-attention layer. Used for extended bottleneck adapter support.
        layer_pre_cross_attn (Optional[str]): Hook point directly before the cross-attention layer. Used for extended bottleneck adapter support.
        layer_pre_ffn (Optional[str]): Hook point directly before the feed-forward layer. Used for extended bottleneck adapter support.
        layer_ln_1 (Optional[str]): Layer norm *after* the self-attention layer. Used for extended bottleneck adapter support.
        layer_ln_2 (Optional[str]): Layer norm *after* the feed-forward layer. Used for extended bottleneck adapter support.
    """

    adapter_methods: List[str]

    model_embeddings: str
    model_layers: str

    layer_self_attn: str
    layer_cross_attn: str
    attn_k_proj: str
    attn_q_proj: str
    attn_v_proj: str
    attn_o_proj: str

    layer_intermediate_proj: str
    layer_output_proj: str

    # Optional attributes for extended bottleneck adapter support
    layer_pre_self_attn: Optional[str] = None
    layer_pre_cross_attn: Optional[str] = None
    layer_pre_ffn: Optional[str] = None
    layer_ln_1: Optional[str] = None
    layer_ln_2: Optional[str] = None

    def to_dict(self):
        return asdict(self)

    def _save(self, save_directory, model_config):
        config_dict = {
            "model_type": model_config.model_type,
            "interface": self.to_dict(),
            "version": "adapters." + __version__,
        }
        save_path = os.path.join(save_directory, INTERFACE_CONFIG_NAME)
        with open(save_path, "w") as f:
            json.dump(config_dict, f, indent=2, sort_keys=True)

    @classmethod
    def _load(cls, path_or_repo_id: str, **kwargs):
        resolved_file = cached_file(path_or_repo_id, INTERFACE_CONFIG_NAME, **kwargs)
        with open(resolved_file, "r") as f:
            config_dict = json.load(f)
        return AdapterModelInterface(**config_dict["interface"])
```
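For orientation, a small usage sketch of the two classes above, assuming (as in the library) that `LoRAConfig` is importable from `adapters` and exposes `architecture == "lora"`:

```python
from adapters import AdapterMethod, AdapterModelInterface, LoRAConfig

# Map a concrete adapter config back to its method type(s):
assert AdapterMethod.get_from_config(LoRAConfig()) == [AdapterMethod.lora]

# The interface is a plain dataclass, so it serializes cleanly -
# this is what _save()/_load() rely on when persisting it as JSON:
interface = AdapterModelInterface(
    adapter_methods=[AdapterMethod.lora, AdapterMethod.reft],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)
print(interface.to_dict()["adapter_methods"])  # ['lora', 'reft']
```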
@@ -0,0 +1,14 @@

```python
from .bottleneck import init_bottleneck
from .invertible import init_invertible_adapters
from .lora import init_lora
from .prompt_tuning import init_prompt_tuning
from .reft import init_reft


METHOD_INIT_MAPPING = {
    "bottleneck": init_bottleneck,
    "lora": init_lora,
    "prompt_tuning": init_prompt_tuning,
    "reft": init_reft,
    "invertible": init_invertible_adapters,
}
```
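Presumably, this mapping is what lets the interface-based initialization dispatch one init function per enabled adapter method. A rough, hypothetical sketch of that dispatch (not the library's actual code):

```python
def init_enabled_methods(model, interface):
    # Call the registered init function for each adapter method listed in the interface.
    for method in interface.adapter_methods:
        init_fn = METHOD_INIT_MAPPING.get(method)
        if init_fn is None:
            raise ValueError(f"Adapter method '{method}' is not supported via the plugin interface.")
        init_fn(model)
```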