Remove redundant if clause and update documentation
TimoImhof committed Feb 26, 2025
Parent: c5c2981 · Commit: dc1fa50
Showing 5 changed files with 11 additions and 10 deletions.
src/adapters/methods/modeling.py (2 additions, 2 deletions)
@@ -661,8 +661,8 @@ def _init_W(self, W_left=None, W_right=None, W=None):
 
     def reset_parameters(self):
 
-        if self.config.init_weights_seed:
-            fix_seed(self.config.init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(self.config.init_weights_seed)
 
         if not self.shared_W_phm:
             self._init_W()
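
Why the removed guard is redundant: the helper itself no-ops when no seed is given, so call sites can invoke it unconditionally. A minimal, self-contained sketch (simplified names, not taken from the repo) of the pattern, also showing what the seeding buys, namely that two modules built with the same seed get identical initial weights:

import torch
import torch.nn as nn

def fix_seed(seed=None):
    # Same guard as in src/adapters/methods/utils.py: only reseed when a
    # seed is actually provided, so callers need no `if` of their own.
    if seed:
        torch.manual_seed(seed)

def make_layer(seed=None):
    fix_seed(seed)          # the `if` removed at the call sites lives here
    return nn.Linear(4, 4)  # weights drawn from torch's global RNG

a, b = make_layer(seed=42), make_layer(seed=42)
assert torch.equal(a.weight, b.weight)  # same seed -> identical init
make_layer()                            # no seed -> fix_seed is a no-op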
src/adapters/methods/prefix_tuning.py (4 additions, 4 deletions)
@@ -31,8 +31,8 @@ def __init__(
         self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
         self.config = config
 
-        if self.config.init_weights_seed:
-            fix_seed(self.config.init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(self.config.init_weights_seed)
         self.wte = nn.Embedding(self.config.prefix_length, self.input_size)
         self.control_trans = nn.Sequential(
             nn.Linear(self.input_size, self.config.bottleneck_size),
@@ -83,8 +83,8 @@ def __init__(
         self.n_embd_per_head = n_embd_per_head or self.input_size // self.n_heads
         self.config = config
 
-        if self.config.init_weights_seed:
-            fix_seed(self.config.init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(self.config.init_weights_seed)
 
         self.control_trans = nn.Parameter(
             torch.randn(self.config.prefix_length * self.n_layers * 2 * self.n_heads * self.n_embd_per_head)
src/adapters/methods/prompt_tuning.py (2 additions, 2 deletions)
@@ -67,8 +67,8 @@ def __init__(
 
     def _init_prompt_embedding(self, base_model_embeddings: nn.Module) -> None:
 
-        if self.prompt_tuning_config.init_weights_seed:
-            fix_seed(self.prompt_tuning_config.init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(self.prompt_tuning_config.init_weights_seed)
 
         if self.prompt_tuning_config.prompt_init == "random_uniform":
             nn.init.uniform_(
src/adapters/methods/reft.py (2 additions, 2 deletions)
@@ -29,8 +29,8 @@ def __init__(
         super().__init__()
         self.orthogonal = orthogonal
 
-        if init_weights_seed:
-            fix_seed(init_weights_seed)
+        # Set seed for reproducibility if specified in config
+        fix_seed(init_weights_seed)
         self.learned_source = nn.Linear(in_dim, r_dim, bias=True, dtype=dtype)
         projection = nn.Linear(in_dim, r_dim, bias=False, dtype=dtype)
src/adapters/methods/utils.py (1 addition, 0 deletions)
@@ -20,6 +20,7 @@
 def fix_seed(seed: Optional[int] = None):
     """
     Helper function to fix the torch seed on cpu and gpu for initializing adapters with the same weights.
+    Is only executed if the config provides a respective seed.
     """
     if seed:
         torch.manual_seed(seed)
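
The hunk is cut off after torch.manual_seed; for reference, a self-contained sketch of the helper, assuming (per the docstring's mention of GPU seeding, which is not shown in the diff) that a CUDA call follows:

from typing import Optional

import torch

def fix_seed(seed: Optional[int] = None):
    """
    Helper function to fix the torch seed on cpu and gpu for initializing
    adapters with the same weights. Is only executed if the config provides
    a respective seed.
    """
    if seed:
        torch.manual_seed(seed)
        # Assumed continuation, truncated in the diff above:
        torch.cuda.manual_seed_all(seed)

Because the guard sits inside the helper, every call site above can drop its own if clause, which is exactly what this commit does; note that "if seed:" also treats a seed of 0 as unset.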