Option for initializing adapters with identical weights (#786)
Addresses #653 and introduces an additional parameter in the adapter
config that allows fixing a seed during adapter weight initialization
**for every layer**.

This means that the seed is reset to the specified value for every
layer, which leads to all adapter modules having the same weights upon
initialization.

The PR adds this option as an additional argument `init_weights_seed` in
the adapter config and also provides an additional test that checks for
identical weight initialization across multiple adapters and across
multiple models.

Closes #653
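
For illustration, here is a minimal usage sketch in the spirit of the tests added by this PR (the checkpoint name and adapter names are placeholders, not part of the change):

```python
import torch
from transformers import AutoModel

import adapters
from adapters import SeqBnConfig

# Placeholder base model; any model supported by adapters.init() works.
model = AutoModel.from_pretrained("bert-base-uncased")
adapters.init(model)

# With init_weights_seed set, the seed is reset before each adapter layer is
# initialized, so both adapters below start from identical weights.
config = SeqBnConfig(init_weights_seed=42)
model.add_adapter("a", config=config)
model.add_adapter("b", config=config)

state = model.state_dict()
for name_a, param_a in state.items():
    if ".adapters.a." in name_a:
        name_b = name_a.replace(".adapters.a.", ".adapters.b.")
        assert torch.equal(param_a, state[name_b]), name_a
```

Leaving `init_weights_seed` unset keeps the previous behavior: every adapter layer is initialized from the current global RNG state.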
TimoImhof authored Feb 26, 2025
1 parent acca075 commit ba2b3ba
Showing 15 changed files with 130 additions and 2 deletions.
21 changes: 21 additions & 0 deletions src/adapters/configuration/adapter_config.py
@@ -157,6 +157,9 @@ class BnConfig(AdapterConfig):
Defaults to False.
init_weights (:obj:`str`, optional): Initialization method for the weights of the adapter modules.
Currently, this can be either "bert" (default) or "mam_adapter" or "houlsby".
init_weights_seed (:obj:`int`, optional): The seed used to initialize the adapter weights in every layer.
Important: if set, the seed is reset for every adapter module, meaning that all adapter modules receive the
same initialization. If not set, each adapter module is initialized with random weights. Defaults to None.
is_parallel (:obj:`bool`, optional): If True, apply adapter transformations in parallel.
By default (False), sequential application is used.
scaling (:obj:`float` or :obj:`str`, optional):
@@ -233,6 +236,7 @@ class BnConfig(AdapterConfig):
ln_before: bool = False
ln_after: bool = False
init_weights: str = "bert"
init_weights_seed: Optional[int] = None
is_parallel: bool = False
scaling: Union[float, str] = 1.0
use_gating: bool = False
@@ -417,6 +421,9 @@ class PrefixTuningConfig(AdapterConfig):
shared_gating (:obj:`bool`, optional): Whether to use a shared gate for the prefixes of all attention matrices.
Only applicable if `use_gating=True`. Defaults to True.
init_weights_seed (:obj:`int`, optional): The seed used to initialize the adapter weights in every layer.
Important: if set, the seed is reset for every adapter module, meaning that all adapter modules receive the
same initialization. If not set, each adapter module is initialized with random weights. Defaults to None.
"""

architecture: Optional[str] = "prefix_tuning"
@@ -432,6 +439,7 @@ class PrefixTuningConfig(AdapterConfig):
dropout: float = 0.0
use_gating: bool = False
shared_gating: bool = True
init_weights_seed: Optional[int] = None


@dataclass(eq=False)
@@ -450,6 +458,9 @@ class PromptTuningConfig(AdapterConfig):
combine (str):
The method used to combine the prompt with the input. Can be either "prefix" or "prefix_after_bos".
Defaults to "prefix".
init_weights_seed (:obj:`int`, optional): The seed used to initialize the adapter weights in every layer.
Important: if set, the seed is reset for every adapter module, meaning that all adapter modules receive the
same initialization. If not set, each adapter module is initialized with random weights. Defaults to None.
"""

architecture: str = "prompt_tuning"
@@ -459,6 +470,7 @@ class PromptTuningConfig(AdapterConfig):
prompt_init_text: Optional[str] = None
random_uniform_scale = 0.5
combine: str = "prefix"
init_weights_seed: Optional[int] = None


@dataclass(eq=False)
@@ -488,6 +500,9 @@ class LoRAConfig(AdapterConfig):
(IA)^3). "scale" can only be used together with r=1. Defaults to "add".
init_weights (:obj:`str`, optional): Initialization method for the weights of the LoRA modules.
Currently, this can be either "lora" (default) or "bert".
init_weights_seed (:obj:`int`, optional): The seed used to initialize the adapter weights in every layer.
Important: if set, the seed is reset for every adapter module, meaning that all adapter modules receive the
same initialization. If not set, each adapter module is initialized with random weights. Defaults to None.
use_gating (:obj:`bool`, optional):
Place a trainable gating module besides the added parameter module to control module activation. This is
e.g. used for UniPELT. Defaults to False. Note that modules with use_gating=True cannot be merged using
@@ -508,6 +523,7 @@ class LoRAConfig(AdapterConfig):
attn_matrices: List[str] = field(default_factory=lambda: ["q", "v"])
composition_mode: str = "add"
init_weights: str = "lora"
init_weights_seed: Optional[int] = None
use_gating: bool = False
dtype: Optional[str] = None

@@ -553,6 +569,10 @@ class ReftConfig(AdapterConfig):
dropout (float): The dropout rate used in the intervention layer.
non_linearity (str): The activation function used in the intervention layer.
dtype (str, optional): torch dtype for intervention tensors. Defaults to None.
init_weights_seed (:obj:`int`, optional): The seed used to initialize the adapter weights in every layer.
Important: if set, the seed is reset for every adapter module, meaning that all adapter modules receive the
same initialization. If not set, each adapter module is initialized with random weights. Defaults to None.
"""

layers: Union[Literal["all"], List[int]]
@@ -569,6 +589,7 @@ class ReftConfig(AdapterConfig):
architecture: str = "reft"

output_reft: bool = True
init_weights_seed: Optional[int] = None


@dataclass(eq=False)
8 changes: 7 additions & 1 deletion src/adapters/methods/lora.py
@@ -17,7 +17,7 @@
from ..composition import Average, BatchSplit, Parallel, Stack
from ..configuration import LoRAConfig, ModelAdaptersConfig
from .adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .utils import dequantize_bnb_weight
from .utils import dequantize_bnb_weight, fix_seed


try:
@@ -57,6 +57,9 @@ def __init__(
self.lora_B = nn.Parameter(torch.zeros(lora_B_shape, dtype=dtype))
self.scaling = self.lora_alpha / self.r

# Set seed for reproducibility if specified in config
fix_seed(config.init_weights_seed)

# For compatibility with (IA)^3, allow all init_weights types here.
# Usually should be "lora".
if config.init_weights == "lora":
@@ -130,6 +133,9 @@ def __init__(
self.lora_B = nn.Parameter(torch.zeros(lora_B_shape))
self.scaling = self.lora_alpha

# Set seed for reproducibility if specified in config
fix_seed(config.init_weights_seed)

# For compatibility with LoRA, allow all init_weights types here.
# Usually should be "ia3".
if config.init_weights == "lora":
8 changes: 8 additions & 0 deletions src/adapters/methods/modeling.py
@@ -8,6 +8,7 @@

from ..configuration import AdapterFusionConfig, BnConfig
from ..context import ForwardContext
from .utils import fix_seed


class Activation_Function_Class(nn.Module):
@@ -115,6 +116,9 @@ def __init__(

self.dropout = nn.Dropout(p=config["dropout"])

# Set seed for reproducibility if specified in config
fix_seed(config.init_weights_seed)

# if we want to initialize with the bert strategy then this function is called for all the linear layers
if config["init_weights"] == "bert":
self.adapter_down.apply(self.init_bert_weights)
@@ -656,6 +660,10 @@ def _init_W(self, W_left=None, W_right=None, W=None):
return init_W(self.config, W_left, W_right, W)

def reset_parameters(self):

# Set seed for reproducibility if specified in config
fix_seed(self.config.init_weights_seed)

if not self.shared_W_phm:
self._init_W()

6 changes: 6 additions & 0 deletions src/adapters/methods/prefix_tuning.py
@@ -12,6 +12,7 @@
from ..context import AdapterSetup, ForwardContext
from .adapter_layer_base import ComposableAdapterLayerBase
from .modeling import Activation_Function_Class
from .utils import fix_seed


class PrefixTuning(nn.Module, ModuleUtilsMixin):
@@ -30,6 +31,8 @@ def __init__(
self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
self.config = config

# Set seed for reproducibility if specified in config
fix_seed(self.config.init_weights_seed)
self.wte = nn.Embedding(self.config.prefix_length, self.input_size)
self.control_trans = nn.Sequential(
nn.Linear(self.input_size, self.config.bottleneck_size),
@@ -80,6 +83,9 @@ def __init__(
self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
self.config = config

# Set seed for reproducibility if specified in config
fix_seed(self.config.init_weights_seed)

self.control_trans = nn.Parameter(
torch.randn(self.config.prefix_length * self.n_layers * 2 * self.n_heads * self.n_embd_per_head)
)
5 changes: 5 additions & 0 deletions src/adapters/methods/prompt_tuning.py
@@ -13,6 +13,7 @@
from ..configuration import ModelAdaptersConfig, PromptTuningConfig
from ..context import ForwardContext
from .adapter_layer_base import AdapterLayerBase
from .utils import fix_seed


class PromptTuning(nn.Module):
@@ -65,6 +66,10 @@ def __init__(
)

def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None:

# Set seed for reproducibility if specified in config
fix_seed(self.prompt_tuning_config.init_weights_seed)

if self.prompt_tuning_config.prompt_init == "random_uniform":
nn.init.uniform_(
self.prompt_embedding.weight,
8 changes: 7 additions & 1 deletion src/adapters/methods/reft.py
@@ -8,6 +8,7 @@
from ..context import ForwardContext
from .adapter_layer_base import AdapterLayerBase
from .modeling import Activation_Function_Class
from .utils import fix_seed


logger = logging.getLogger(__name__)
@@ -22,13 +23,17 @@ def __init__(
subtract_projection: bool = True,
non_linearity: str = None,
dropout: float = 0.0,
init_weights_seed: int = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.orthogonal = orthogonal
self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)

# Set seed for reproducibility if specified in config
fix_seed(init_weights_seed)
self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)
projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)

if orthogonal:
# orthogonal is not implemented for half precision
if dtype in [torch.float16, torch.bfloat16]:
@@ -72,6 +77,7 @@ def __init__(self, in_features: int, config: ReftConfig):
config.subtract_projection,
config.non_linearity,
config.dropout,
config.init_weights_seed,
dtype,
)
for _ in range(n_units)
13 changes: 13 additions & 0 deletions src/adapters/methods/utils.py
@@ -12,9 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch


def fix_seed(seed: Optional[int] = None):
"""
Helper function to fix the torch seed on CPU and GPU so that adapter modules can be initialized with identical weights.
Only takes effect if a seed is provided; otherwise it is a no-op.
"""
if seed:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)


# Copied from https://github.com/huggingface/peft/blob/main/src/peft/utils/integrations.py.
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
"""
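As a quick sanity check of the helper above, a sketch of its effect (assuming the module is importable as `adapters.methods.utils`, i.e. the path added in this diff):

```python
import torch

from adapters.methods.utils import fix_seed

fix_seed(42)
first = torch.randn(3)

fix_seed(42)
second = torch.randn(3)

# The global torch seed is reset to the same value before each draw, so the
# two tensors are identical. The same mechanism makes adapter modules that
# share an init_weights_seed start from identical weights.
assert torch.equal(first, second)

# fix_seed(None) is a no-op, leaving weight initialization random as before.
```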
31 changes: 31 additions & 0 deletions tests/test_methods/method_test_impl/base.py
@@ -442,3 +442,34 @@ def run_generate_test(self, adapter_config, max_new_tokens=32):
generate_input = self.build_generate_input(self.input_shape).to(torch_device)
generated = model.generate(generate_input, max_new_tokens=max_new_tokens)
self.assertLessEqual(generated.shape, (self.input_shape[0], self.input_shape[1] + max_new_tokens))

def run_same_weights_test(self, adapter_config, filter_keys):

# Check one model with multiple adapters with same config
model = self.get_model()
num_adapters = 2
per_model_filter_keys = {}

for i in range(num_adapters):
model.add_adapter(f"adapter{i}", config=adapter_config)
for i in range(num_adapters):
name = f"adapter{i}"
per_model_filter_keys[name] = [k.format(name=name) for k in filter_keys]

for (k1, v1), (k2, v2) in zip(
self._filter_parameters(model, per_model_filter_keys["adapter0"]).items(),
self._filter_parameters(model, per_model_filter_keys["adapter1"]).items(),
):
self.assertTrue(torch.equal(v1, v2), msg=f"{k1} has different weights than {k2}")

# Check multiple models with one adapter with same config
model1, model2 = create_twin_models(self.model_class, self.config)
model1.add_adapter("adapter", config=adapter_config)
model2.add_adapter("adapter", config=adapter_config)
per_model_filter_keys = {"adapter": [k.format(name="adapter") for k in filter_keys]}

for (k1, v1), (k2, v2) in zip(
self._filter_parameters(model1, per_model_filter_keys["adapter"]).items(),
self._filter_parameters(model2, per_model_filter_keys["adapter"]).items(),
):
self.assertTrue(torch.equal(v1, v2), msg=f"{k1} has different weights than {k2}")
@@ -518,3 +518,7 @@ def test_load_adapter_setup(self):
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.allclose(output1[0][0], output2[0][0], atol=1e-4))
self.assertTrue(torch.allclose(output1[1][0], output2[1][0], atol=1e-4))

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(SeqBnConfig(init_weights_seed=42), ["adapters.{name}."])
6 changes: 6 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_compacter.py
@@ -51,3 +51,9 @@ def test_train_shared_phm_compacter(self):

def test_compacter_generate(self):
self.run_generate_test(CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8))

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(
CompacterPlusPlusConfig(phm_dim=2, reduction_factor=8, init_weights_seed=42), ["adapters.{name}."]
)
4 changes: 4 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_ia3.py
@@ -47,3 +47,7 @@ def test_reset_ia3(self):

def test_ia3_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(IA3Config())

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(IA3Config(init_weights_seed=42), ["loras.{name}."])
4 changes: 4 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_lora.py
@@ -315,3 +315,7 @@ def test_reset_lora(self):

def test_lora_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(LoRAConfig())

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(LoRAConfig(init_weights_seed=42), ["loras.{name}."])
@@ -79,3 +79,7 @@ def test_prefix_tuning_generate(self):

def test_prefix_tuning_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(PrefixTuningConfig())

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(PrefixTuningConfig(init_weights_seed=42), ["prefix_tunings.{name}."])
@@ -38,3 +38,9 @@ def test_train_prompt_tuning(self):

def test_prompt_tuning_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(PromptTuningConfig(prompt_length=10))

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(
PromptTuningConfig(init_weights_seed=42, prompt_length=10), ["prompt_tunings.{name}."]
)
4 changes: 4 additions & 0 deletions tests/test_methods/method_test_impl/peft/test_reft.py
@@ -82,3 +82,7 @@ def test_reft_generate(self):

def test_reft_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(LoReftConfig())

def test_same_weights_after_adding_adapter(self):
# setting init_weights_seed should lead to every adapter layer having the same weights after initialization
self.run_same_weights_test(LoReftConfig(init_weights_seed=42), ["refts.{name}."])
