Skip to content

Commit

Permalink
Save & load adapter interface with full model
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Feb 8, 2025
1 parent cbf74a9 commit 788bc8d
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 34 deletions.
29 changes: 28 additions & 1 deletion src/adapters/interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from dataclasses import dataclass
import json
import os
from dataclasses import asdict, dataclass
from typing import List, Optional

from transformers.utils import cached_file

from . import __version__
from .utils import INTERFACE_CONFIG_NAME


class AdapterMethod:
"""
Expand Down Expand Up @@ -83,3 +90,23 @@ class AdapterModelInterface:
layer_pre_ffn: Optional[str] = None
layer_ln_1: Optional[str] = None
layer_ln_2: Optional[str] = None

def to_dict(self):
return asdict(self)

def _save(self, save_directory, model_config):
config_dict = {
"model_type": model_config.model_type,
"interface": self.to_dict(),
"version": "adapters." + __version__,
}
save_path = os.path.join(save_directory, INTERFACE_CONFIG_NAME)
with open(save_path, "w") as f:
json.dump(config_dict, f, indent=2, sort_keys=True)

@classmethod
def _load(cls, path_or_repo_id: str, **kwargs):
resolved_file = cached_file(path_or_repo_id, INTERFACE_CONFIG_NAME, **kwargs)
with open(resolved_file, "r") as f:
config_dict = json.load(f)
return AdapterModelInterface(**config_dict["interface"])
2 changes: 2 additions & 0 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,6 +1661,8 @@ def save_pretrained(
lambda i, layer: layer.set_pool(None) if isinstance(layer, PrefixTuningLayer) else None
)

if self.base_model.adapter_interface is not None:
self.base_model.adapter_interface._save(save_directory, self.config)
super().save_pretrained(save_directory, **kwargs)
# Remove adapters config
del self.config.adapters
Expand Down
1 change: 1 addition & 0 deletions src/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
EMBEDDING_FILE = "embedding.pt"
TOKENIZER_PATH = "tokenizer"
SETUP_CONFIG_NAME = "adapter_setup.json"
INTERFACE_CONFIG_NAME = "adapter_interface.json"

ADAPTER_HUB_URL = "https://raw.githubusercontent.com/Adapter-Hub/Hub/master/dist/v2/"
ADAPTER_HUB_INDEX_FILE = ADAPTER_HUB_URL + "index/{}.json"
Expand Down
7 changes: 7 additions & 0 deletions src/adapters/wrappers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ def load_model(

old_init = model_class.__init__

    # Try to find an interface config file if none was passed explicitly
if interface is None:
try:
interface = AdapterModelInterface._load(model_name_or_path, **kwargs)
except EnvironmentError:
pass

def new_init(self, config, *args, **kwargs):
old_init(self, config, *args, **kwargs)
init(self, interface=interface)
Expand Down
31 changes: 0 additions & 31 deletions tests/test_methods/test_on_custom_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,37 +68,6 @@ def _init_model_for_train_run(self, trained_adapter_name, frozen_adapter_name, a
    def create_twin_models(self):
        # Delegate to the module-level helper, additionally passing the custom
        # adapter interface so both twin models share the same interface definition.
        return create_twin_models(self.model_class, self.config, self.adapter_interface)

    # Copied from base.py to pass custom interface to load_model
    def run_full_model_load_test(self, adapter_config):
        """Round-trip test: save a model with one adapter attached, reload it via
        `load_model` using the custom adapter interface, then verify that all
        weights loaded, the adapter config survived, and outputs match."""
        model1 = self.get_model()
        model1.eval()

        name = "dummy"
        model1.add_adapter(name, config=adapter_config)
        with tempfile.TemporaryDirectory() as temp_dir:
            model1.save_pretrained(temp_dir)

            # Reload with loading_info so missing/unexpected keys can be inspected.
            model2, loading_info = load_model(
                temp_dir, self.model_class, output_loading_info=True, interface=self.adapter_interface
            )

        # check if all weights were loaded
        self.assertEqual(0, len(loading_info["missing_keys"]), loading_info["missing_keys"])
        self.assertEqual(0, len(loading_info["unexpected_keys"]), loading_info["unexpected_keys"])

        # check if adapter was correctly loaded
        self.assertTrue(name in model2.adapters_config)

        # check equal output
        input_data = self.get_input_samples(config=model1.config)
        model1.to(torch_device)
        model2.to(torch_device)
        with AdapterSetup(name):
            output1 = model1(**input_data)
            output2 = model2(**input_data)
        self.assertEqual(len(output1), len(output2))
        # Loose tolerance since save/load may go through dtype conversions.
        self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))

    def test_load_mam_adapter(self):
        # MAM adapters rely on prefix tuning, which this custom interface
        # does not support, so the inherited test is skipped.
        self.skipTest("Does not support prefix tuning.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
from adapters import AdapterModelInterface, AutoAdapterModel
from adapters.utils import WEIGHTS_NAME
from parameterized import parameterized
from tests.test_methods.method_test_impl.utils import ids_tensor
from transformers import AutoModel, AutoModelForCausalLM, BertConfig, LlamaConfig
from transformers.testing_utils import require_torch, torch_device

from .test_adapter import ids_tensor


@require_torch
class CustomInterfaceCompatTest(unittest.TestCase):
Expand Down

0 comments on commit 788bc8d

Please sign in to comment.