updates
julian-fong committed Dec 14, 2024
1 parent 385cd35 commit 46af3fd
Showing 2 changed files with 68 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/adapters/configuration/adapter_config.py
@@ -545,14 +545,14 @@ class VeraConfig(LoRAConfig):
The `composition_mode` parameter should also be set to `add`.
"""

selfattn_lora: bool = False
selfattn_lora: bool = True
intermediate_lora: bool = False
output_lora: bool = False

r: int = 8
init_weights: str = "vera"
d: Union[bool, float] = 0.1
b: Union[bool, float] = 0
b: Union[bool, float] = 0.0
init_weights: str = "vera"


@dataclass(eq=False)
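For context, the net effect of these config changes is that a default VeraConfig now targets the self-attention projections (selfattn_lora=True) and carries float values for both scaling initializers, d and b. A minimal usage sketch, assuming VeraConfig is (or will be) exported at the package level and that the usual adapters workflow applies; the model name is only an example:

from transformers import AutoModel
import adapters
from adapters import VeraConfig  # assumed export

model = AutoModel.from_pretrained("bert-base-uncased")
adapters.init(model)

# d and b seed the trainable VeRA scaling vectors; selfattn_lora=True
# (the new default) injects the adapter into the attention projections.
config = VeraConfig(r=8, d=0.1, b=0.0)
model.add_adapter("vera_adapter", config=config)
model.train_adapter("vera_adapter")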
72 changes: 65 additions & 7 deletions src/adapters/methods/lora.py
@@ -69,6 +69,9 @@ def __init__(
elif config.init_weights == "ia3":
nn.init.ones_(self.lora_A)
nn.init.ones_(self.lora_B)
elif config.init_weights == "vera":
nn.init.kaiming_uniform_(self.lora_A)
nn.init.kaiming_uniform_(self.lora_B)
else:
raise ValueError("Unknown init_weights type: {}".format(config.init_weights))
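For reference, the new init_weights == "vera" branch gives both LoRA matrices Kaiming-uniform noise, in contrast to the "lora" branch, which zero-initializes B. A self-contained sketch with illustrative shapes (r and hidden are assumptions, not values taken from the diff):

import torch
import torch.nn as nn

r, hidden = 8, 64  # illustrative sizes
lora_A = nn.Parameter(torch.empty(r, hidden))
lora_B = nn.Parameter(torch.empty(hidden, r))
nn.init.kaiming_uniform_(lora_A)  # same calls as the new "vera" branch
nn.init.kaiming_uniform_(lora_B)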

@@ -91,6 +94,7 @@ def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
return weights - added * self.scaling

def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
print("triggered")
if hidden_states is None:
hidden_states = layer_input
hidden_states = self.lora_dropout(hidden_states) @ torch.t(self.lora_A) @ torch.t(self.lora_B)
@@ -178,15 +182,21 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
class Vera(nn.Module):
def __init__(
self,
name,
lora_A_shape,
lora_B_shape,
config: LoRAConfig,
gating_heads: int = 1,
):
super().__init__()
self.name = name
self.d = config.d
self.b = config.b
self.r = config.r
self.alpha = config.alpha
self.use_gating = config.use_gating

# Optional dropout
if config.dropout > 0.0:
self.lora_dropout = nn.Dropout(p=config.dropout)

self.lora_A_shape = lora_A_shape
self.lora_B_shape = lora_B_shape
@@ -196,6 +206,11 @@ def __init__(
# Actual trainable parameters
self.vera_D = nn.Parameter(torch.diag(torch.ones(self.d_shape) * self.d))
self.vera_B = nn.Parameter(torch.diag(torch.ones(self.b_shape) * self.b))
self.scaling = self.alpha / self.r

if self.use_gating:
self.gate = nn.Linear(lora_A_shape[-1], gating_heads)
nn.init.normal_(self.gate.weight, std=0.02)

@property
def delta_w(self) -> torch.Tensor:
@@ -204,29 +219,68 @@ def delta_w(self) -> torch.Tensor:
lora_B = parameters["lora_B"]
return self.vera_B @ lora_B @ self.vera_D @ lora_A

def com(self, weights: torch.Tensor, added: torch.Tensor, scaling=None) -> torch.Tensor:
"""Performs the composition operation between existing and injected weights."""
if scaling is None:
scaling = self.scaling
return weights + added * scaling

def com_inv(self, weights: torch.Tensor, added: torch.Tensor) -> torch.Tensor:
"""Inverts the composition operation between existing and injected weights."""
return weights - added * self.scaling

def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tensor):
parameters = ForwardContext.get_context().shared_parameters[self.name]
lora_A = parameters["lora_A"]
lora_B = parameters["lora_B"]

if hidden_states is None:
hidden_states = layer_input
hidden_states = self.vera_B @ lora_B @ self.vera_D @ lora_A

return hidden_states
if getattr(self, "lora_dropout", None) is not None:
hidden_states = self.lora_dropout(hidden_states)

hidden_states = hidden_states @ self.vera_B @ lora_B @ self.vera_D @ lora_A

if self.use_gating:
gate = torch.sigmoid(self.gate(layer_input))
gate = torch.mean(gate, dim=1).unsqueeze(-1)
hidden_states = hidden_states * gate
else:
gate = None

return hidden_states, gate

def set_vera_adapter_name(self, name):
self.name = name
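The forward pass composes the frozen, shared projections with the per-layer trainable scalings. A standalone shape check of that chain; batch, sequence length, hidden size, and rank below are illustrative, the random lora_A/lora_B stand in for the tensors normally fetched from the ForwardContext, and d_shape/b_shape are assumed to equal r and the hidden size, matching the matrix dimensions used here:

import torch

batch, seq, hidden, r = 2, 4, 64, 8  # illustrative sizes

# frozen, shared across layers (normally read from the ForwardContext)
lora_A = torch.randn(r, hidden)
lora_B = torch.randn(hidden, r)

# per-layer trainable diagonal scalings, built as in Vera.__init__
vera_D = torch.diag(torch.ones(r) * 0.1)       # d = 0.1
vera_B = torch.diag(torch.ones(hidden) * 0.0)  # b = 0.0

x = torch.randn(batch, seq, hidden)
delta = x @ vera_B @ lora_B @ vera_D @ lora_A  # same chain as Vera.forward
assert delta.shape == (batch, seq, hidden)
# with b = 0.0 the initial contribution is zero, so the adapted layer starts
# out equivalent to the base layer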


def init_shared_Vera_parameters(model_config, adapter_config, device):
hidden_size = model_config.hidden_size
r = adapter_config["r"]

parameters = nn.ParameterDict()

# initialize frozen, random tensors A, B
parameters["lora_A"] = torch.zeros(r, hidden_size).to(device)
parameters["lora_B"] = torch.zeros(hidden_size, r).to(device)

nn.init.kaiming_uniform_(parameters["lora_A"])
nn.init.kaiming_uniform_(parameters["lora_B"])
if adapter_config["init_weights"] == "lora":
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(parameters["lora_A"], a=math.sqrt(5))
nn.init.zeros_(parameters["lora_B"])
elif adapter_config["init_weights"] == "bert":
nn.init.normal_(parameters["lora_A"], std=0.02)
nn.init.normal_(parameters["lora_B"], std=0.02)
elif adapter_config["init_weights"] == "ia3":
nn.init.ones_(parameters["lora_A"])
nn.init.ones_(parameters["lora_B"])
elif adapter_config["init_weights"] == "vera":
nn.init.kaiming_uniform_(parameters["lora_A"])
nn.init.kaiming_uniform_(parameters["lora_B"])
else:
raise ValueError("Unknown init_weights type: {}".format(adapter_config["init_weights"]))

return parameters
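A quick sketch of how the shared parameters might be built in isolation. The SimpleNamespace and plain dict below are stand-ins for the real model and adapter configs (only the fields this function reads are supplied), and the import path is assumed from the file location:

from types import SimpleNamespace
from adapters.methods.lora import init_shared_Vera_parameters  # assumed import path

model_config = SimpleNamespace(hidden_size=64)     # stand-in for a transformers config
adapter_config = {"r": 8, "init_weights": "vera"}  # stand-in for a VeraConfig

shared = init_shared_Vera_parameters(model_config, adapter_config, device="cpu")
print(shared["lora_A"].shape)  # torch.Size([8, 64])
print(shared["lora_B"].shape)  # torch.Size([64, 8])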


@@ -264,7 +318,7 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
)
if lora_config is not None and self._check_lora_location(lora_config):
if lora_config.composition_mode == "add":
if lora_config.d and lora_config.b:
if isinstance(lora_config.d, float) or isinstance(lora_config.b, float):
lora_cls = Vera
else:
lora_cls = LoRA
@@ -277,6 +331,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
lora_config,
gating_heads=self.get_n_heads(lora_config),
)
# if we're using Vera, then set the adapter name into the Vera object
if lora_cls == Vera:
lora.set_vera_adapter_name(name=adapter_name)

lora.train(self.training)
lora = lora.to(self.weight.device)
self.loras[adapter_name] = lora
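The dispatch above keys off the types of d and b rather than their truthiness, so a VeraConfig with b = 0.0 still routes to the Vera class (under the old `if lora_config.d and lora_config.b` check it would not have). A minimal illustration of the same test; the getattr guard and the package-level VeraConfig export are assumptions:

from adapters import LoRAConfig, VeraConfig  # VeraConfig export is assumed

def picks_vera(cfg) -> bool:
    # mirrors the isinstance check in add_adapter; getattr guards configs
    # that do not define d or b at all
    return isinstance(getattr(cfg, "d", None), float) or isinstance(getattr(cfg, "b", None), float)

print(picks_vera(VeraConfig()))  # True  -- d=0.1 and b=0.0 are floats
print(picks_vera(LoRAConfig()))  # False -- plain LoRA configs fall through to the LoRA class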
