Commit
refactor delta management to AbstractDeltaLayer
2catycm committed Mar 30, 2024
1 parent 9ceed80 commit 81fbae7
Showing 3 changed files with 75 additions and 51 deletions.
4 changes: 2 additions & 2 deletions assets/images/coverage.svg
72 changes: 69 additions & 3 deletions delta_residual/general_delta.py
@@ -2,6 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from loguru import logger

# from utils import *
from .matching_strategy import find_modules
@@ -37,8 +38,9 @@
# assert False
# return x
class AbstractDeltaModule(nn.Module):
def __init__(self) -> None:
def __init__(self, reference_model: nn.Module = None) -> None:
super().__init__()
self.refer_to(reference_model)

def refer_to(self, model: nn.Module = None):
"""Simply let the DeltaModel `forward()` equals the reference model's `forward()`.
@@ -47,12 +49,25 @@ def refer_to(self, model: nn.Module = None):
Args:
model (nn.Module, optional): reference PyTorch model. Defaults to None; if None, the DeltaModel is left not callable.
"""
# Interface kept for compatibility.
# Note: the reference model is not a submodule of this DeltaModule; its parameters should not be saved or loaded with it, but should always be passed in from the outside.
# https://discuss.pytorch.org/t/unregister-prevent-from-registering-a-nn-module/134768
# self.reference_layer = (reference_layer, )
# self.forward = self.reference_layer[0].forward
if model is None:
self.forward = None
else:
self.forward = model.forward
self.reference_model_tup = (model,)

@property
def reference_model(self):
return self.reference_model_tup[0]

@reference_model.setter
def reference_model(self, new_reference_model: nn.Module):
self.refer_to(new_reference_model)
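
Usage sketch (not part of the diff): a minimal illustration of the `refer_to` / `reference_model` bookkeeping added above. Wrapping the reference model in a one-element tuple keeps it out of the module registry, so its weights are not saved with the delta. The Linear layer and shapes below are made up for illustration, and the abstract base is instantiated directly only to demonstrate the bookkeeping.

import torch
import torch.nn as nn

from delta_residual.general_delta import AbstractDeltaModule

linear = nn.Linear(4, 4)
delta = AbstractDeltaModule(reference_model=linear)

x = torch.randn(2, 4)
# The delta simply forwards to the reference layer ...
assert torch.allclose(delta(x), linear(x))
# ... but the layer is hidden in a tuple, so it is not a submodule and the delta owns no parameters here.
assert delta.reference_model is linear
assert len(list(delta.parameters())) == 0

# Re-pointing the delta goes through the property setter, which calls refer_to() again.
delta.reference_model = nn.Linear(4, 4)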

def hook_into(self, model: nn.Module):
"""Let the DeltaModel injects its computation into the reference model.
After that, the reference model's `__call__()` is modified, with not only reference model's own `forward()`, but also the delta computations.
@@ -99,13 +114,64 @@ class AbstractDeltaLayer(AbstractDeltaModule):
2. This class provides resource management of hooks and handles for the subclasses.
"""

def __init__(self, reference_layer: nn.Module = None) -> None:
super().__init__(
reference_model=reference_layer
) # This calls `refer_to`; it does not change original_layer's behavior. If the layer's internals are needed to decide how to initialize, use self.reference_model.
# Hook this module's own forward.
self.forward_pre_hook_handle = self.register_forward_pre_hook(
hook=self._forward_pre_hook
) # Bound method: a closure that knows about this instance.
self.forward_hook_handle = self.register_forward_hook(
hook=self._forward_hook
) # Bound method: a closure that knows about this instance.
# Handles for hooks registered on other models.
self.others_forward_pre_hook_handles = dict()
self.others_forward_hook_handles = dict()

def __del__(self):
# Copy the keys first: remove_hook_from() mutates the dict while we iterate.
for layer in list(self.others_forward_pre_hook_handles.keys()):
self.remove_hook_from(layer)

def refer_to(self, model: nn.Module = None):
"""Simply let the DeltaModel `forward()` equals the reference model's `forward()`.
Note: This shall not change the behavior of `model`.
Args:
model (nn.Module, optional): reference PyTorch model. Defaults to None; if None, the DeltaModel is left not callable.
"""
super().refer_to(model)
super().refer_to(
model
) # Compared to AbstractDeltaModule, we just change the documentation here.

def hook_into(self, layer: nn.Module):
if self.others_forward_pre_hook_handles.get(layer) is not None:
logger.warning(
f"Layer {layer} has already been hooked. I will remove the old first and then replace it with the new."
)
self.remove_hook_from(layer)
self.others_forward_pre_hook_handles[layer] = layer.register_forward_pre_hook(
hook=self._forward_pre_hook
)
self.others_forward_hook_handles[layer] = layer.register_forward_hook(
hook=self._forward_hook
)

def remove_hook_from(self, layer: nn.Module):
if not self.others_forward_pre_hook_handles.get(layer):
logger.warning(
f"Layer {layer} has not been hooked. I will do nothing and return."
)
return
self.others_forward_pre_hook_handles[layer].remove()
self.others_forward_hook_handles[layer].remove()
del self.others_forward_pre_hook_handles[layer]
del self.others_forward_hook_handles[layer]

def _forward_pre_hook(self, module: nn.Module, input: tuple):
raise NotImplementedError("Must be implemented by subclasses.")

def _forward_hook(self, module: nn.Module, input: tuple, output):
raise NotImplementedError("Must be implemented by subclasses.")
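
Usage sketch (not part of the diff): a hypothetical subclass showing the hook bookkeeping that AbstractDeltaLayer now centralizes. The subclass only supplies `_forward_pre_hook` / `_forward_hook`; registration on itself, the replace-on-rehook warning, and removal are inherited. The bias-adding delta, layer sizes, and names below are invented for illustration.

import torch
import torch.nn as nn

from delta_residual.general_delta import AbstractDeltaLayer


class BiasDeltaLayer(AbstractDeltaLayer):
    """Toy delta: add a learnable bias to the wrapped layer's output (invented example)."""

    def __init__(self, reference_layer: nn.Linear):
        super().__init__(reference_layer)  # the base class registers the hooks on self
        self.bias = nn.Parameter(torch.zeros(reference_layer.out_features))

    def _forward_pre_hook(self, module: nn.Module, input: tuple) -> tuple:
        return input  # nothing to change before the wrapped forward

    def _forward_hook(self, module: nn.Module, input: tuple, output: torch.Tensor) -> torch.Tensor:
        return output + self.bias  # the delta is applied after the wrapped forward


layer = nn.Linear(8, 8)
delta = BiasDeltaLayer(layer)
x = torch.randn(2, 8)

y_delta = delta(x)             # reference forward plus the hooks registered on the delta itself
delta.hook_into(layer)         # inject the same hooks into the original layer, in place
y_hooked = layer(x)            # the original layer now applies the delta too
delta.remove_hook_from(layer)  # original behavior restored; handles are cleaned up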


class GeneralDeltaModel(AbstractDeltaModule):
@@ -121,7 +187,7 @@ def __init__(
layer_delta_class=nn.Module,
layer_config: dict = None,
) -> None:
super().__init__()
super().__init__(reference_model)
self.layer_delta_class = layer_delta_class
self.layer_config = layer_config or dict()
self.adapter_name = adapter_name
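
Usage sketch (not part of the diff): the GeneralDeltaModel constructor is only partially visible here, so the call below is hypothetical. Only `reference_model` being forwarded to the base class, `adapter_name`, `layer_delta_class`, and `layer_config` are visible in the hunk; the keyword style, the omitted module-matching arguments, and forwarding `layer_config` (with the `soft_token_num` key) to each delta layer are assumptions.

import torch.nn as nn

from delta_residual.general_delta import GeneralDeltaModel
from delta_residual.soft_prompt import GeneralSoftPromptLayer

backbone = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
    num_layers=2,
)

# Hypothetical call: any module-matching arguments are elided in this diff and omitted here.
delta_model = GeneralDeltaModel(
    backbone,
    adapter_name="soft_prompt",
    layer_delta_class=GeneralSoftPromptLayer,
    layer_config=dict(soft_token_num=8),  # assumed to be forwarded to each GeneralSoftPromptLayer
)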
50 changes: 4 additions & 46 deletions delta_residual/soft_prompt.py
@@ -5,10 +5,10 @@
import torch.optim as optim
from loguru import logger

from .general_delta import AbstractDeltaModule
from .general_delta import AbstractDeltaLayer


class GeneralSoftPromptLayer(AbstractDeltaModule):
class GeneralSoftPromptLayer(AbstractDeltaLayer):
"""One Layer to substitute another, don't change the original one's behavior unless specified."""

def __init__(
@@ -28,14 +28,7 @@ def __init__(
dim_of_hidden=-1,
dim_of_batches=0,
):
super().__init__()
# Interface kept for compatibility.
# Note: original_layer is not a submodule of GeneralSoftPromptLayer; its parameters should not be saved or loaded here, but should always be passed in from the outside.
# https://discuss.pytorch.org/t/unregister-prevent-from-registering-a-nn-module/134768
# self.reference_layer = (reference_layer, )
# self.forward = self.reference_layer[0].forward
self.refer_to(reference_layer) # Does not change original_layer's behavior.
# super().__init__() # This way, original_layer is not registered as a submodule, which is what we want.
super().__init__(reference_layer) # This will call `refer_to`.
# Parameters of the PEFT method.
self.soft_token_num = soft_token_num
self.hidden_dim = hidden_dim
@@ -54,31 +47,15 @@ def __init__(
self.dim_of_hidden = dim_of_hidden
self.dim_of_batches = dim_of_batches

# Hook this module's own forward.
self.forward_pre_hook_handle = self.register_forward_pre_hook(
hook=self._forward_pre_hook
) # Bound method: a closure that knows about this instance.
self.forward_hook_handle = self.register_forward_hook(
hook=self._forward_hook
) # Bound method: a closure that knows about this instance.
# Handles for hooks registered on other models.
self.others_forward_pre_hook_handles = dict()
self.others_forward_hook_handles = dict()

def __del__(self):
for layer in self.others_forward_pre_hook_handles:
self.remove_hook_from(layer)

def _forward_pre_hook(self, module: nn.Module, input: tuple) -> tuple:
# Return a new input tuple.
new_input = list(input)
new_input = list(input) # because tuples do not support item assignment
for i in self.prepended_inputs:
# selected_tensor = self.tensor_selector(input[i]) # obtains a reference
# selected_tensor = torch.cat([selected_tensor, self.prompts], dim=self.dim_of_tokens) # does not modify the original tensor
# self.tensor_selector(input[i]) = selected_tensor
selected_tensor: torch.Tensor = input[i]
b = selected_tensor.shape[self.dim_of_batches]
# TODO: this is not quite right; tuples cannot be assigned to directly
new_input[i] = torch.cat(
[
selected_tensor,
@@ -127,24 +104,5 @@ def instantiate(self, hidden_dim) -> None:
self.soft_prompts: torch.Tensor = nn.Parameter(soft_prompts, requires_grad=True)
# .to(self.device)

def hook_into(self, layer: nn.Module):
if self.others_forward_pre_hook_handles.get(layer) is not None:
logger.warning(
f"Layer {layer} has already been hooked. I will remove the old first and then replace it with the new."
)
self.remove_hook_from(layer)
self.others_forward_pre_hook_handles[layer] = layer.register_forward_pre_hook(
hook=self._forward_pre_hook
)
self.others_forward_hook_handles[layer] = layer.register_forward_hook(
hook=self._forward_hook
)

def remove_hook_from(self, layer: nn.Module):
self.others_forward_pre_hook_handles[layer].remove()
self.others_forward_hook_handles[layer].remove()
del self.others_forward_pre_hook_handles[layer]
del self.others_forward_hook_handles[layer]

def merge_into(self, layer: nn.Module):
raise ArithmeticError("General Soft Prompt Tuning cannot be re-parameterized.")
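
Standalone sketch (not part of the diff) of the tensor manipulation `_forward_pre_hook` performs above: the learned prompts are broadcast over the batch and concatenated onto the selected input along the token dimension. A `[batch, tokens, hidden]` layout with `dim_of_batches=0` and `dim_of_tokens=1` is assumed here; the exact expansion code is elided in the diff.

import torch

batch, tokens, hidden = 2, 5, 16
soft_token_num = 4
dim_of_batches, dim_of_tokens = 0, 1  # assumed [batch, tokens, hidden] layout

hidden_states = torch.randn(batch, tokens, hidden)
soft_prompts = torch.nn.Parameter(torch.randn(soft_token_num, hidden))

# Broadcast the prompts over the batch before concatenating, as the hook presumably does.
expanded = soft_prompts.unsqueeze(dim_of_batches).expand(batch, -1, -1)
new_hidden_states = torch.cat([hidden_states, expanded], dim=dim_of_tokens)

print(new_hidden_states.shape)  # torch.Size([2, 9, 16]) -- token count grew by soft_token_num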
