From dc1fa5078a1be1c844a309f6f4469e44ee1144b8 Mon Sep 17 00:00:00 2001
From: Timo Imhof
Date: Thu, 27 Feb 2025 00:05:40 +0100
Subject: [PATCH] Remove redundant if clause and update documentation

---
 src/adapters/methods/modeling.py      | 4 ++--
 src/adapters/methods/prefix_tuning.py | 8 ++++----
 src/adapters/methods/prompt_tuning.py | 4 ++--
 src/adapters/methods/reft.py          | 4 ++--
 src/adapters/methods/utils.py         | 1 +
 5 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/adapters/methods/modeling.py b/src/adapters/methods/modeling.py
index 7fab931ec..df86fbae8 100644
--- a/src/adapters/methods/modeling.py
+++ b/src/adapters/methods/modeling.py
@@ -661,8 +661,8 @@ def _init_W(self, W_left=None, W_right=None, W=None):
 
     def reset_parameters(self):
-        if self.config.init_weights_seed:
-            fix_seed(self.config.init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(self.config.init_weights_seed)
 
         if not self.shared_W_phm:
             self._init_W()
diff --git a/src/adapters/methods/prefix_tuning.py b/src/adapters/methods/prefix_tuning.py
index 2e0af9eb4..d49e88c5b 100644
--- a/src/adapters/methods/prefix_tuning.py
+++ b/src/adapters/methods/prefix_tuning.py
@@ -31,8 +31,8 @@ def __init__(
         self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
         self.config = config
 
-        if self.config.init_weights_seed:
-            fix_seed(self.config.init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(self.config.init_weights_seed)
 
         self.wte = nn.Embedding(self.config.prefix_length, self.input_size)
         self.control_trans = nn.Sequential(
             nn.Linear(self.input_size, self.config.bottleneck_size),
@@ -83,8 +83,8 @@ def __init__(
         self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
         self.config = config
 
-        if self.config.init_weights_seed:
-            fix_seed(self.config.init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(self.config.init_weights_seed)
 
         self.control_trans = nn.Parameter(
             torch.randn(self.config.prefix_length * self.n_layers * 2 * self.n_heads * self.n_embd_per_head)
diff --git a/src/adapters/methods/prompt_tuning.py b/src/adapters/methods/prompt_tuning.py
index 32520aa89..b9504ac40 100644
--- a/src/adapters/methods/prompt_tuning.py
+++ b/src/adapters/methods/prompt_tuning.py
@@ -67,8 +67,8 @@ def __init__(
 
     def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None:
-        if self.prompt_tuning_config.init_weights_seed:
-            fix_seed(self.prompt_tuning_config.init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(self.prompt_tuning_config.init_weights_seed)
 
         if self.prompt_tuning_config.prompt_init == "random_uniform":
             nn.init.uniform_(
diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py
index 44fcfdf73..1884bf5e0 100644
--- a/src/adapters/methods/reft.py
+++ b/src/adapters/methods/reft.py
@@ -29,8 +29,8 @@ def __init__(
         super().__init__()
         self.orthogonal = orthogonal
 
-        if init_weights_seed:
-            fix_seed(init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(init_weights_seed)
 
         self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)
         projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
diff --git a/src/adapters/methods/utils.py b/src/adapters/methods/utils.py
index 3a37486c2..d0161fa74 100644
--- a/src/adapters/methods/utils.py
+++ b/src/adapters/methods/utils.py
@@ -20,6 +20,7 @@ def fix_seed(seed: Optional[int] = None):
     """
     Helper function to fix the torch seed on cpu and gpu for initializing
     adapters with the same weights.
+    The seed is only set if the config provides one.
     """
     if seed:
         torch.manual_seed(seed)
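
Each removed guard duplicated a check that fix_seed() already performs internally, so calling it unconditionally preserves behaviour. Below is a minimal standalone sketch of that contract; it mirrors only the CPU branch visible in the last hunk (the real helper also seeds the GPU, which is omitted here):

    from typing import Optional

    import torch


    def fix_seed(seed: Optional[int] = None):
        # No-op when seed is None (or 0), so call sites need no guard of their own.
        if seed:
            torch.manual_seed(seed)


    fix_seed(42)    # seeds torch's CPU RNG
    fix_seed(None)  # does nothing, same effect as the removed `if ... init_weights_seed:` guards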