diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py
index 7f520bb378..9dc9426795 100644
--- a/src/adapters/configuration/adapter_config.py
+++ b/src/adapters/configuration/adapter_config.py
@@ -523,7 +523,7 @@ class ReftConfig(AdapterConfig):
     suffix_positions: int
     r: int
     orthogonality: bool
-    tied_weights: bool = True
+    tied_weights: bool = False
     subtract_projection = True
     dropout: float = 0.05
     non_linearity: Optional[str] = None
@@ -542,7 +542,7 @@ class LoReftConfig(ReftConfig):
     suffix_positions: int = 0
     r: int = 1
     orthogonality: bool = True
-    tied_weights: bool = True
+    tied_weights: bool = False
 
 
 @dataclass(eq=False)
@@ -556,7 +556,7 @@ class NoReftConfig(ReftConfig):
     suffix_positions: int = 0
     r: int = 1
     orthogonality: bool = False
-    tied_weights: bool = True
+    tied_weights: bool = False
 
 
 @dataclass(eq=False)
@@ -570,7 +570,7 @@ class DiReftConfig(ReftConfig):
     suffix_positions: int = 0
     r: int = 1
     orthogonality: bool = False
-    tied_weights: bool = True
+    tied_weights: bool = False
     subtract_projection = False
 
 
diff --git a/src/adapters/methods/reft.py b/src/adapters/methods/reft.py
index f4a66d4c7e..3089c0abf7 100644
--- a/src/adapters/methods/reft.py
+++ b/src/adapters/methods/reft.py
@@ -66,6 +66,7 @@ def __init__(self, in_features: int, config: ReftConfig):
 
     def _gather_adapted_states(self, hidden_states: torch.Tensor):
         context = ForwardContext.get_context()
+        bsz, _, ddim = hidden_states.size()
         # no cached indexing matrices available -> compute now
         if not hasattr(context, "pref_idx") and not hasattr(context, "suff_idx"):
             # read offsets & lengths from context
@@ -73,10 +74,11 @@ def _gather_adapted_states(self, hidden_states: torch.Tensor):
                 first_non_padding = context.offsets
                 last_non_padding = context.offsets + context.seqlens
             else:
-                first_non_padding = torch.tensor([0] * hidden_states.size(0))
-                last_non_padding = torch.tensor([hidden_states.size(1)] * hidden_states.size(0))
+                first_non_padding = torch.tensor([0] * hidden_states.size(0)).to(hidden_states.device)
+                last_non_padding = torch.tensor([hidden_states.size(1)] * hidden_states.size(0)).to(
+                    hidden_states.device
+                )
             # create indexing matrices for prefixes & suffixes
-            bsz, _, ddim = hidden_states.size()
             if self.prefix_positions > 0:
                 pref_idx = first_non_padding.view(-1, 1, 1) + (
                     torch.arange(self.prefix_positions)
@@ -99,11 +101,11 @@ def _gather_adapted_states(self, hidden_states: torch.Tensor):
         if self.prefix_positions > 0:
             prefix = hidden_states.gather(1, context.pref_idx)
         else:
-            prefix = torch.zeros(0, device=hidden_states.device)
+            prefix = torch.zeros(bsz, 0, ddim, device=hidden_states.device)
         if self.suffix_positions > 0:
             suffix = hidden_states.gather(1, context.suff_idx)
         else:
-            suffix = torch.zeros(0, device=hidden_states.device)
+            suffix = torch.zeros(bsz, 0, ddim, device=hidden_states.device)
 
         if self.tied_weights:
             adapted_states = [torch.cat([prefix, suffix], dim=1)]
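
Note (not part of the patch): a minimal standalone sketch of why the empty placeholders must be 3-D. torch.cat along dim=1 requires all tensors to have the same number of dimensions and matching sizes on every other axis, so the old 1-D torch.zeros(0) placeholder cannot be concatenated with a gathered (bsz, n, ddim) prefix or suffix, while torch.zeros(bsz, 0, ddim, ...) concatenates cleanly. The shapes below are assumed toy values for illustration only.

    import torch

    bsz, seq_len, ddim = 2, 8, 4                      # assumed toy shapes
    hidden_states = torch.randn(bsz, seq_len, ddim)
    suffix = hidden_states[:, -3:, :]                 # gathered suffix, shape (bsz, 3, ddim)

    old_prefix = torch.zeros(0)                       # 1-D placeholder (pre-patch)
    new_prefix = torch.zeros(bsz, 0, ddim)            # 3-D placeholder (post-patch)

    adapted = torch.cat([new_prefix, suffix], dim=1)  # works, shape (bsz, 3, ddim)
    # torch.cat([old_prefix, suffix], dim=1)          # raises: tensors must have the same number of dims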