Skip to content

Commit

Permalink
Add supports_adapter() method
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Jan 5, 2025
1 parent b01cf6d commit 535dd9c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
18 changes: 18 additions & 0 deletions src/adapters/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ class AdapterType:
prompt_tuning = "prompt_tuning"
reft = "reft"

@staticmethod
def get_from_config(config) -> List[str]:
"""
Get the adapter type from a given adapter config.
Args:
config: The adapter config.
Returns:
str: The adapter type.
"""
if config.architecture is None:
return [AdapterType.bottleneck]
elif config.architecture == "union":
return [AdapterType.get_from_config(sub_config) for sub_config in config.configs]
else:
return [config.architecture]


@dataclass
class AdapterModelInterface:
Expand Down
31 changes: 30 additions & 1 deletion src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .configuration import ADAPTER_CONFIG_MAP, AdapterConfig, AdapterFusionConfig, BnConfig
from .context import AdapterSetup, ForwardContext
from .hub_mixin import PushAdapterToHubMixin
from .interface import AdapterModelInterface
from .interface import AdapterModelInterface, AdapterType
from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
from .methods import METHOD_INIT_MAPPING
from .methods.adapter_layer_base import AdapterLayerBase
Expand Down Expand Up @@ -473,6 +473,31 @@ def init_adapters(self, model_config, adapters_config):

self._add_tied_weights_keys()

def supports_adapter(self, type_or_config: Union[str, AdapterConfig]) -> bool:
"""
Checks if the model supports a given adapter type.
Args:
adapter_type (str): The adapter type to check.
Returns:
bool: True if the adapter type is supported, False otherwise.
"""
if isinstance(type_or_config, AdapterConfig):
types = AdapterType.get_from_config(type_or_config)
else:
types = [type_or_config]

supported = []
for _type in types:
if getattr(self.base_model, "adapter_interface", None) is not None:
supported.append(_type in self.base_model.adapter_interface.adapter_types)
elif _type == AdapterType.prompt_tuning:
supported.append(self.support_prompt_tuning)
else:
supported.append(True)
return all(supported)

# These methods have to be implemented by every deriving class:

@abstractmethod
Expand Down Expand Up @@ -598,6 +623,10 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
the adapter is added but not activated.
"""
config = AdapterConfig.load(config) # ensure config is ok and up-to-date
# check if config is valid for this model
config_or_type = config or AdapterType.bottleneck
if not self.supports_adapter(config_or_type):
raise ValueError(f"Adapter config or type '{config_or_type}' is not supported by this model.")
# In case adapter already exists and we allow overwriting, explicitly delete the existing one first
if overwrite_ok and adapter_name in self.adapters_config:
self.delete_adapter(adapter_name)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_custom_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,11 @@ def test_train_mam_adapter(self):

def test_merging_with_other_adapters(self):
self.skipTest("Does not support all required methods yet.")

def test_supports_adapter(self):
model = self.get_model()
model.eval()

config = "unipelt"
with self.assertRaises(ValueError):
model.add_adapter("my_adapter", config=config)

0 comments on commit 535dd9c

Please sign in to comment.