Rename AdapterType -> AdapterMethod. Test fixes.
calpt committed Jan 6, 2025
1 parent 535dd9c commit 7d346db
Showing 5 changed files with 42 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/adapters/__init__.py
@@ -84,7 +84,7 @@
         "Seq2SeqLMHead",
         "TaggingHead",
     ],
-    "interface": ["AdapterModelInterface"],
+    "interface": ["AdapterMethod", "AdapterModelInterface"],
     "methods.adapter_layer_base": ["AdapterLayerBase", "ComposableAdapterLayerBase"],
     "model_mixin": [
         "EmbeddingAdaptersMixin",
@@ -199,7 +199,7 @@
     Seq2SeqLMHead,
     TaggingHead,
 )
-from .interface import AdapterModelInterface
+from .interface import AdapterMethod, AdapterModelInterface
 from .methods.adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
 from .model_mixin import (
     EmbeddingAdaptersMixin,
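Note: with this change, AdapterMethod is exported at the package root alongside AdapterModelInterface. A minimal sketch of the new import path; the assertion only restates the string constant defined in interface.py below:

from adapters import AdapterMethod, AdapterModelInterface

# AdapterMethod members are plain strings, so they compare equal to
# their string values (see src/adapters/interface.py in this commit).
assert AdapterMethod.bottleneck == "bottleneck"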
8 changes: 4 additions & 4 deletions src/adapters/interface.py
@@ -2,9 +2,9 @@
 from typing import List, Optional


-class AdapterType:
+class AdapterMethod:
     """
-    Enum for the different adapter types.
+    Enum of all supported adapter method types.
     """

     bottleneck = "bottleneck"
@@ -25,9 +25,9 @@ def get_from_config(config) -> List[str]:
             str: The adapter type.
         """
         if config.architecture is None:
-            return [AdapterType.bottleneck]
+            return [AdapterMethod.bottleneck]
         elif config.architecture == "union":
-            return [AdapterType.get_from_config(sub_config) for sub_config in config.configs]
+            return [AdapterMethod.get_from_config(sub_config) for sub_config in config.configs]
         else:
             return [config.architecture]

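Note: a hedged sketch of how the renamed helper resolves method types, assuming the library's stock configs (SeqBnConfig leaves architecture unset, LoRAConfig sets it to "lora"):

from adapters import AdapterMethod, LoRAConfig, SeqBnConfig

# Configs with no explicit `architecture` fall back to the default
# bottleneck method; all others resolve to their architecture string.
print(AdapterMethod.get_from_config(SeqBnConfig()))  # ["bottleneck"]
print(AdapterMethod.get_from_config(LoRAConfig()))   # ["lora"]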
10 changes: 5 additions & 5 deletions src/adapters/model_mixin.py
@@ -19,7 +19,7 @@
 from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig
 from .context import AdapterSetup, ForwardContext
 from .hub_mixin import PushAdapterToHubMixin
-from .interface import AdapterModelInterface, AdapterType
+from .interface import AdapterMethod, AdapterModelInterface
 from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
 from .methods import METHOD_INIT_MAPPING
 from .methods.adapter_layer_base import AdapterLayerBase
@@ -484,16 +484,16 @@ def supports_adapter(self, type_or_config: Union[str, AdapterConfig]) -> bool:
             bool: True if the adapter type is supported, False otherwise.
         """
         if isinstance(type_or_config, AdapterConfig):
-            types = AdapterType.get_from_config(type_or_config)
+            types = AdapterMethod.get_from_config(type_or_config)
         else:
             types = [type_or_config]

         supported = []
         for _type in types:
             if getattr(self.base_model, "adapter_interface", None) is not None:
                 supported.append(_type in self.base_model.adapter_interface.adapter_types)
-            elif _type == AdapterType.prompt_tuning:
-                supported.append(self.support_prompt_tuning)
+            elif _type == AdapterMethod.prompt_tuning:
+                supported.append(self.base_model.support_prompt_tuning)
             else:
                 supported.append(True)
         return all(supported)
@@ -624,7 +624,7 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
         """
         config = AdapterConfig.load(config)  # ensure config is ok and up-to-date
         # check if config is valid for this model
-        config_or_type = config or AdapterType.bottleneck
+        config_or_type = config or AdapterMethod.bottleneck
         if not self.supports_adapter(config_or_type):
             raise ValueError(f"Adapter config or type '{config_or_type}' is not supported by this model.")
         # In case adapter already exists and we allow overwriting, explicitly delete the existing one first
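Note: taken together, these changes let callers run the same capability check that add_adapter now performs internally. A small usage sketch; the checkpoint name is chosen purely for illustration:

import adapters
from adapters import AdapterMethod
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint
adapters.init(model)

# supports_adapter accepts a method type string or an AdapterConfig;
# add_adapter raises ValueError for configs that fail this check.
if model.supports_adapter(AdapterMethod.prompt_tuning):
    model.add_adapter("pt", config="prompt_tuning")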
16 changes: 10 additions & 6 deletions tests/methods/test_adapter_common.py
@@ -12,8 +12,6 @@
     DoubleSeqBnConfig,
     DoubleSeqBnInvConfig,
     Fuse,
-    InvertibleAdaptersMixin,
-    InvertibleAdaptersWrapperMixin,
     MAMConfig,
     SeqBnConfig,
     SeqBnInvConfig,
@@ -72,7 +70,7 @@ def test_delete_adapter(self):
     def test_add_adapter_with_invertible(self):
         model = self.get_model().base_model
         model.eval()
-        if not isinstance(model, InvertibleAdaptersMixin) and not isinstance(model, InvertibleAdaptersWrapperMixin):
+        if not model.supports_adapter("invertible"):
             self.skipTest("Model does not support invertible adapters.")

         for adapter_config in [SeqBnInvConfig(), DoubleSeqBnInvConfig()]:
@@ -120,7 +118,7 @@ def test_delete_adapter_with_invertible(self):
         """Tests if the invertible adapters are deleted correctly."""
         model = self.get_model().base_model
         model.eval()
-        if not isinstance(model, InvertibleAdaptersMixin) and not isinstance(model, InvertibleAdaptersWrapperMixin):
+        if not model.supports_adapter("invertible"):
             self.skipTest("Model does not support invertible adapters.")

         # iterate through all adapter invertible adapter configs
@@ -245,6 +243,8 @@ def test_model_config_serialization(self):
         """
         for k, v in ADAPTER_CONFIG_MAP.items():
             model = self.get_model()
+            if not model.supports_adapter(v):
+                continue
             # HACK: reduce the reduction factor such that
             # the small test model can have a phm_dim of 4
             if hasattr(v, "phm_layer") and v.phm_layer:
@@ -257,15 +257,19 @@ def test_model_adapter_summary(self):
         # count model parameters before
         model = self.get_model()
         model_no_params = sum(p.numel() for p in model.parameters())
+        added = []
         for k, v in ADAPTER_CONFIG_MAP.items():
+            if not model.supports_adapter(v):
+                continue
             # HACK: reduce the reduction factor such that
             # the small test model can have a phm_dim of 4
             if hasattr(v, "phm_layer") and v.phm_layer:
                 v = v.__class__(reduction_factor=4)
             model.add_adapter(k, config=v)
+            added.append(k)
         summary = model.adapter_summary(as_dict=True)
-        self.assertEqual(len(ADAPTER_CONFIG_MAP) + 1, len(summary))
-        for name in ADAPTER_CONFIG_MAP.keys():
+        self.assertEqual(len(added) + 1, len(summary))
+        for name in added:
             self.assertTrue(any([row["name"] == name for row in summary]))
         self.assertEqual(model_no_params, summary[-1]["#param"])

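Note: the reworked summary test relies on adapter_summary(as_dict=True) returning one row per added adapter plus a trailing row for the full model. A sketch of that shape, under the same assumptions (illustrative checkpoint again):

import adapters
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint
adapters.init(model)
model.add_adapter("my_adapter", config="seq_bn")

summary = model.adapter_summary(as_dict=True)
# One row per adapter plus a final full-model row; the test checks the
# "name" and "#param" keys of each row.
assert any(row["name"] == "my_adapter" for row in summary)
print(summary[-1]["#param"])  # parameter count of the base model itself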
33 changes: 21 additions & 12 deletions tests/methods/test_lora.py
@@ -2,7 +2,7 @@

 import torch

-from adapters import LoRAConfig
+from adapters import AdapterConfig, LoRAConfig
 from adapters.methods.lora import LoRALayer
 from transformers.testing_utils import require_torch

@@ -24,21 +24,30 @@ def test_merging_with_other_adapters(self):
         model.add_adapter("lora", config="lora")

         # Add different adapters
-        model.add_adapter("bottleneck", config="seq_bn")
-        model.add_adapter("prompt", config="prompt_tuning")
-        model.add_adapter("prefix", config="prefix_tuning")
-        model.add_adapter("ia3", config="ia3")
-        model.add_adapter("unipelt", config="unipelt")
-        model.add_adapter("mam", config="mam")
-        model.add_adapter("compacter", config="compacter[phm_dim=2, reduction_factor=8]")
+        adapter_methods = [
+            "seq_bn",
+            "prompt_tuning",
+            "prefix_tuning",
+            "ia3",
+            "unipelt",
+            "mam",
+            "compacter[phm_dim=2, reduction_factor=8]",
+        ]
+
+        for adapter_method in adapter_methods:
+            config = AdapterConfig.load(adapter_method)
+            if model.supports_adapter(config):
+                model.add_adapter(adapter_method, config=config)

         # Merging adapters with different architectures with LoRA should raise a ValueError
-        for adapter_architecture in ["bottleneck", "prompt", "prefix", "ia3", "unipelt", "mam", "compacter"]:
-            with self.subTest(adapter_architecture=adapter_architecture):
+        for adapter_method in adapter_methods:
+            with self.subTest(adapter_architecture=adapter_method):
+                if adapter_method not in model.adapters_config:
+                    continue
                 with self.assertRaises(ValueError):
                     model.average_adapter(
-                        adapter_name=f"average_lora_{adapter_architecture}",
-                        adapter_list=[adapter_architecture, "lora"],
+                        adapter_name=f"average_lora_{adapter_method}",
+                        adapter_list=[adapter_method, "lora"],
                         weights=[0.5, 0.5],
                         combine_strategy="linear",
                     )
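Note: the rewritten test resolves each identifier through AdapterConfig.load, which also parses the bracketed key=value override syntax used for compacter above. A brief sketch:

from adapters import AdapterConfig

# Identifier strings resolve to config objects; bracketed key=value
# pairs override individual attributes of the loaded config.
config = AdapterConfig.load("compacter[phm_dim=2, reduction_factor=8]")
print(config.phm_dim, config.reduction_factor)  # 2 8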
