Fix for untied weights
calpt committed Jun 22, 2024 · 1 parent 5f2b712 · commit 682de00
Showing 2 changed files with 11 additions and 9 deletions.
8 changes: 4 additions & 4 deletions src/adapters/configuration/adapter_config.py
@@ -523,7 +523,7 @@ class ReftConfig(AdapterConfig):
     suffix_positions: int
     r: int
     orthogonality: bool
-    tied_weights: bool = True
+    tied_weights: bool = False
     subtract_projection = True
     dropout: float = 0.05
     non_linearity: Optional[str] = None
@@ -542,7 +542,7 @@ class LoReftConfig(ReftConfig):
     suffix_positions: int = 0
     r: int = 1
     orthogonality: bool = True
-    tied_weights: bool = True
+    tied_weights: bool = False


 @dataclass(eq=False)
@@ -556,7 +556,7 @@ class NoReftConfig(ReftConfig):
     suffix_positions: int = 0
     r: int = 1
     orthogonality: bool = False
-    tied_weights: bool = True
+    tied_weights: bool = False


 @dataclass(eq=False)
@@ -570,7 +570,7 @@ class DiReftConfig(ReftConfig):
     suffix_positions: int = 0
     r: int = 1
     orthogonality: bool = False
-    tied_weights: bool = True
+    tied_weights: bool = False
     subtract_projection = False

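Note (illustration, not part of the commit): with the new default, the ReFT configs create separate intervention weights for prefix and suffix positions unless tied_weights=True is passed explicitly. The sketch below assumes LoReftConfig is importable from the top-level adapters package as in this repository; the field values are made up for demonstration.

from adapters import LoReftConfig

# New default: prefix and suffix positions get their own intervention weights.
config = LoReftConfig(prefix_positions=3, suffix_positions=0, r=1)
print(config.tied_weights)  # False

# Passing tied_weights=True explicitly restores the previous shared-weight behaviour.
tied_config = LoReftConfig(prefix_positions=3, suffix_positions=0, r=1, tied_weights=True)
print(tied_config.tied_weights)  # True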
12 changes: 7 additions & 5 deletions src/adapters/methods/reft.py
@@ -66,17 +66,19 @@ def __init__(self, in_features: int, config: ReftConfig):

     def _gather_adapted_states(self, hidden_states: torch.Tensor):
         context = ForwardContext.get_context()
+        bsz, _, ddim = hidden_states.size()
         # no cached indexing matrices available -> compute now
         if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx"):
             # read offsets & lengths from context
             if hasattr(context, "seqlens"):
                 first_non_padding = context.offsets
                 last_non_padding = context.offsets + context.seqlens
             else:
-                first_non_padding = torch.tensor([0] * hidden_states.size(0))
-                last_non_padding = torch.tensor([hidden_states.size(1)] * hidden_states.size(0))
+                first_non_padding = torch.tensor([0] * hidden_states.size(0)).to(hidden_states.device)
+                last_non_padding = torch.tensor([hidden_states.size(1)] * hidden_states.size(0)).to(
+                    hidden_states.device
+                )
             # create indexing matrices for prefixes & suffixes
-            bsz, _, ddim = hidden_states.size()
             if self.prefix_positions > 0:
                 pref_idx = first_non_padding.view(-1, 1, 1) + (
                     torch.arange(self.prefix_positions)
@@ -99,11 +101,11 @@ def _gather_adapted_states(self, hidden_states: torch.Tensor):
         if self.prefix_positions > 0:
             prefix = hidden_states.gather(1, context.pref_idx)
         else:
-            prefix = torch.zeros(0, device=hidden_states.device)
+            prefix = torch.zeros(bsz, 0, ddim, device=hidden_states.device)
         if self.suffix_positions > 0:
             suffix = hidden_states.gather(1, context.suff_idx)
         else:
-            suffix = torch.zeros(0, device=hidden_states.device)
+            suffix = torch.zeros(bsz, 0, ddim, device=hidden_states.device)

         if self.tied_weights:
             adapted_states = [torch.cat([prefix, suffix], dim=1)]
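Note (illustration, not part of the commit): a minimal sketch of what the changes in this file guard against, using only plain PyTorch. Moving bsz, _, ddim up makes those dimensions available even on calls where cached indices skip the first branch, the .to(hidden_states.device) calls keep the fallback index tensors on the same device as the activations, and the (bsz, 0, ddim) placeholders keep an empty prefix or suffix shaped like a regular 3-D activation. The names hidden, prefix, and suffix below are invented for the example.

import torch

bsz, seq, ddim = 2, 5, 8
hidden = torch.randn(bsz, seq, ddim)  # imagine this tensor living on a GPU in practice

# Fallback index tensors default to CPU; moving them onto hidden's device avoids
# device-mismatch errors in the later gather() calls.
first_non_padding = torch.tensor([0] * bsz).to(hidden.device)
print(first_non_padding.device)

# With suffix_positions == 0, a (bsz, 0, ddim) placeholder still concatenates along the
# sequence dimension and passes through a linear projection, unlike a bare torch.zeros(0),
# which has no batch or feature dimension.
prefix = hidden[:, :3, :]                                 # as if prefix_positions were 3
suffix = torch.zeros(bsz, 0, ddim, device=hidden.device)  # empty but correctly shaped

print(torch.cat([prefix, suffix], dim=1).shape)   # torch.Size([2, 3, 8])
print(torch.nn.Linear(ddim, ddim)(suffix).shape)  # torch.Size([2, 0, 8])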
