Commit

Add to residual methods & create test
TimoImhof committed Jan 27, 2025
1 parent 92bb9d4 commit b621b8e
Showing 13 changed files with 98 additions and 1 deletion.
13 changes: 13 additions & 0 deletions src/adapters/configuration/adapter_config.py
@@ -421,6 +421,9 @@ class PrefixTuningConfig(AdapterConfig):
shared_gating (:obj:`bool`, optional): Whether to use a shared gate for the prefixes of all attention matrices. Only
applicable if `use_gating=True`. Defaults to True.
init_weights_seed (:obj:`int`, optional): The seed used to initialize the adapter weights of each layer.
Important: if set, the seed is re-applied before every adapter module is initialized, meaning that all adapter
modules will have the same initialization. If not set, each adapter module is initialized randomly. Defaults to None.
"""

architecture: Optional[str] = "prefix_tuning"
@@ -436,6 +439,7 @@ class PrefixTuningConfig(AdapterConfig):
dropout: float = 0.0
use_gating: bool = False
shared_gating: bool = True
init_weights_seed: Optional[int] = None


@dataclass(eq=False)
@@ -454,6 +458,9 @@ class PromptTuningConfig(AdapterConfig):
combine (str):
The method used to combine the prompt with the input. Can be either "prefix" or "prefix_after_bos".
Defaults to "prefix".
init_weights_seed (:obj:`int`, optional): The seed used to initialize the adapter weights of each layer.
Important: if set, the seed is re-applied before every adapter module is initialized, meaning that all adapter
modules will have the same initialization. If not set, each adapter module is initialized randomly. Defaults to None.
"""

architecture: str = "prompt_tuning"
@@ -463,6 +470,7 @@ class PromptTuningConfig(AdapterConfig):
prompt_init_text: Optional[str] = None
random_uniform_scale = 0.5
combine: str = "prefix"
init_weights_seed: Optional[int] = None


@dataclass(eq=False)
@@ -561,6 +569,10 @@ class ReftConfig(AdapterConfig):
dropout (float): The dropout rate used in the intervention layer.
non_linearity (str): The activation function used in the intervention layer.
dtype (str, optional): torch dtype for intervention tensors. Defaults to None.
init_weights_seed (:obj:`int`, optional): The seed used to initialize the adapter weights of each layer.
Important: if set, the seed is re-applied before every adapter module is initialized, meaning that all adapter
modules will have the same initialization. If not set, each adapter module is initialized randomly. Defaults to None.
"""

layers: Union[Literal["all"], List[int]]
@@ -577,6 +589,7 @@ class ReftConfig(AdapterConfig):
architecture: str = "reft"

output_reft: bool = True
init_weights_seed: Optional[int] = None


@dataclass(eq=False)
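For orientation, the new option is passed like any other config field. A minimal usage sketch (the checkpoint name and adapter names are illustrative, not part of this commit):

```python
from adapters import AutoAdapterModel, PrefixTuningConfig

# Illustrative backbone; any model supported by the adapters library works the same way.
model = AutoAdapterModel.from_pretrained("bert-base-uncased")

# With a fixed init_weights_seed, both adapters start from identical random weights.
config = PrefixTuningConfig(init_weights_seed=42)
model.add_adapter("adapter_a", config=config)
model.add_adapter("adapter_b", config=config)
```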
4 changes: 4 additions & 0 deletions src/adapters/methods/modeling.py
@@ -660,6 +660,10 @@ def _init_W(self, W_left=None, W_right=None, W=None):
return init_W(self.config, W_left, W_right, W)

def reset_parameters(self):

if self.config.init_weights_seed:
fix_seed(self.config.init_weights_seed)

if not self.shared_W_phm:
self._init_W()

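The fix_seed helper imported from .utils is not shown in this diff. As an assumption about what such a helper does (not the actual library implementation), it would reset the RNGs that drive weight initialization, roughly:

```python
import random

import numpy as np
import torch


def fix_seed(seed: int) -> None:
    # Assumed sketch: reset RNG state so the next randomly initialized module
    # draws the same numbers every time this seed is applied.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the CUDA generators on current PyTorch versions
```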
6 changes: 6 additions & 0 deletions src/adapters/methods/prefix_tuning.py
@@ -12,6 +12,7 @@
from ..context import AdapterSetup, ForwardContext
from .adapter_layer_base import ComposableAdapterLayerBase
from .modeling import Activation_Function_Class
from .utils import fix_seed


class PrefixTuning(nn.Module, ModuleUtilsMixin):
@@ -30,6 +31,8 @@ def __init__(
self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
self.config = config

if self.config.init_weights_seed:
fix_seed(self.config.init_weights_seed)
self.wte = nn.Embedding(self.config.prefix_length, self.input_size)
self.control_trans = nn.Sequential(
nn.Linear(self.input_size, self.config.bottleneck_size),
@@ -80,6 +83,9 @@ def __init__(
self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
self.config = config

if self.config.init_weights_seed:
fix_seed(self.config.init_weights_seed)

self.control_trans = nn.Parameter(
torch.randn(self.config.prefix_length * self.n_layers * 2 * self.n_heads * self.n_embd_per_head)
)
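The prefix-tuning change relies on the fact that re-applying the same seed immediately before constructing a module makes its random initialization reproducible. A self-contained check of that behavior (plain PyTorch, independent of the adapters code):

```python
import torch
import torch.nn as nn


def make_layer(seed: int) -> nn.Linear:
    torch.manual_seed(seed)  # stand-in for fix_seed(init_weights_seed)
    return nn.Linear(16, 8)


layer_a = make_layer(42)
layer_b = make_layer(42)
# Same seed before construction -> identical initial parameters.
assert torch.equal(layer_a.weight, layer_b.weight)
assert torch.equal(layer_a.bias, layer_b.bias)
```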
5 changes: 5 additions & 0 deletions src/adapters/methods/prompt_tuning.py
@@ -13,6 +13,7 @@
from ..configuration import ModelAdaptersConfig, PromptTuningConfig
from ..context import ForwardContext
from .adapter_layer_base import AdapterLayerBase
from .utils import fix_seed


class PromptTuning(nn.Module):
@@ -65,6 +66,10 @@ def __init__(
)

def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None:

if self.prompt_tuning_config.init_weights_seed:
fix_seed(self.prompt_tuning_config.init_weights_seed)

if self.prompt_tuning_config.prompt_init == "random_uniform":
nn.init.uniform_(
self.prompt_embedding.weight,
8 changes: 7 additions & 1 deletion src/adapters/methods/reft.py
@@ -8,6 +8,7 @@
from ..context import ForwardContext
from .adapter_layer_base import AdapterLayerBase
from .modeling import Activation_Function_Class
from .utils import fix_seed


logger = logging.getLogger(__name__)
@@ -22,13 +23,17 @@ def __init__(
subtract_projection: bool = True,
non_linearity: str = None,
dropout: float = 0.0,
init_weights_seed: int = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.orthogonal = orthogonal
self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)

if init_weights_seed:
fix_seed(init_weights_seed)
self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)
projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)

if orthogonal:
# orthogonal is not implemented for half precision
if dtype in [torch.float16, torch.bfloat16]:
@@ -72,6 +77,7 @@ def __init__(self, in_features: int, config: ReftConfig):
config.subtract_projection,
config.non_linearity,
config.dropout,
config.init_weights_seed,
dtype,
)
for _ in range(n_units)
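As the comment in the hunk above notes, the orthogonal projection path is not available for half precision. In plain PyTorch, the usual way to keep a linear projection orthogonal is the parametrization utility, sketched here with illustrative dimensions (an assumption, not the exact library code):

```python
import torch.nn as nn

projection = nn.Linear(768, 8, bias=False)  # illustrative in_dim and r_dim
# Registering an orthogonal parametrization requires full (float32) precision.
projection = nn.utils.parametrizations.orthogonal(projection)
```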
31 changes: 31 additions & 0 deletions tests/test_methods/method_test_impl/base.py
@@ -442,3 +442,34 @@ def run_generate_test(self, adapter_config, max_new_tokens=32):
generate_input = self.build_generate_input(self.input_shape).to(torch_device)
generated = model.generate(generate_input, max_new_tokens=max_new_tokens)
self.assertLessEqual(generated.shape, (self.input_shape[0], self.input_shape[1] + max_new_tokens))

def run_same_weights_test(self, adapter_config, filter_keys):

# Check one model with multiple adapters with same config
model = self.get_model()
num_adapters = 2
per_model_filter_keys = {}

for i in range(num_adapters):
model.add_adapter(f"adapter{i}", config=adapter_config)
for i in range(num_adapters):
name = f"adapter{i}"
per_model_filter_keys[name] = [k.format(name=name) for k in filter_keys]

for (k1, v1), (k2, v2) in zip(
self._filter_parameters(model, per_model_filter_keys["adapter0"]).items(),
self._filter_parameters(model, per_model_filter_keys["adapter1"]).items(),
):
self.assertTrue(torch.equal(v1, v2), msg=f"{k1} has different weights than {k2}")

# Check multiple models with one adapter with same config
model1, model2 = create_twin_models(self.model_class, self.config)
model1.add_adapter("adapter", config=adapter_config)
model2.add_adapter("adapter", config=adapter_config)
per_model_filter_keys = {"adapter": [k.format(name="adapter") for k in filter_keys]}

for (k1, v1), (k2, v2) in zip(
self._filter_parameters(model1, per_model_filter_keys["adapter"]).items(),
self._filter_parameters(model2, per_model_filter_keys["adapter"]).items(),
):
self.assertTrue(torch.equal(v1, v2), msg=f"{k1} has different weights than {k2}")
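The filter_keys entries passed to run_same_weights_test carry a {name} placeholder that is filled in per adapter; _filter_parameters (defined elsewhere in this base class) then presumably selects the parameters whose names start with those prefixes. For example:

```python
filter_keys = ["adapters.{name}."]
per_model_filter_keys = {
    f"adapter{i}": [k.format(name=f"adapter{i}") for k in filter_keys] for i in range(2)
}
# -> {"adapter0": ["adapters.adapter0."], "adapter1": ["adapters.adapter1."]}
```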
@@ -518,3 +518,7 @@ def test_load_adapter_setup(self):
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.allclose(output1[0][0], output2[0][0], atol=1e-4))
self.assertTrue(torch.allclose(output1[1][0], output2[1][0], atol=1e-4))

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(SeqBnConfig(init_weights_seed=42), ["adapters.{name}."])
6 changes: 6 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_compacter.py
@@ -51,3 +51,9 @@ def test_train_shared_phm_compacter(self):

def test_compacter_generate(self):
self.run_generate_test(CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8))

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(
CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8, init_weights_seed=42), ["adapters.{name}."]
)
4 changes: 4 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_ia3.py
@@ -47,3 +47,7 @@ def test_reset_ia3(self):

def test_ia3_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(IA3Config())

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(IA3Config(init_weights_seed=42), ["loras.{name}."])
4 changes: 4 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_lora.py
@@ -315,3 +315,7 @@ def test_reset_lora(self):

def test_lora_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(LoRAConfig())

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(LoRAConfig(init_weights_seed=42), ["loras.{name}."])
@@ -79,3 +79,7 @@ def test_prefix_tuning_generate(self):

def test_prefix_tuning_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(PrefixTuningConfig())

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(PrefixTuningConfig(init_weights_seed=42), ["prefix_tunings.{name}."])
@@ -38,3 +38,9 @@ def test_train_prompt_tuning(self):

def test_prompt_tuning_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(PromptTuningConfig(prompt_length=10))

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(
PromptTuningConfig(init_weights_seed=42, prompt_length=10), ["prompt_tunings.{name}."]
)
4 changes: 4 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_reft.py
@@ -82,3 +82,7 @@ def test_reft_generate(self):

def test_reft_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(LoReftConfig())

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(LoReftConfig(init_weights_seed=42), ["refts.{name}."])
