diff --git a/delta_residual/general_delta.py b/delta_residual/general_delta.py index 58cbfbd..8b4ed93 100644 --- a/delta_residual/general_delta.py +++ b/delta_residual/general_delta.py @@ -9,6 +9,7 @@ from .matching_strategy import find_modules from .utils import ( ModuleDeviceAddOn, + ModuleForSelfHook, SeeTrainableParametersAddOn, get_module_device, get_tuple_device, @@ -150,10 +151,10 @@ def hook_into(self, layer: nn.Module): ) self.remove_hook_from(layer) self.others_forward_pre_hook_handles[layer] = layer.register_forward_pre_hook( - hook=self._forward_pre_hook + hook=ModuleForSelfHook(self, self.__class__._forward_pre_hook) ) self.others_forward_hook_handles[layer] = layer.register_forward_hook( - hook=self._forward_hook + hook=ModuleForSelfHook(self, self.__class__._forward_hook) ) def remove_hook_from(self, layer: nn.Module): diff --git a/delta_residual/soft_prompt.py b/delta_residual/soft_prompt.py index 1372a73..76b2cda 100644 --- a/delta_residual/soft_prompt.py +++ b/delta_residual/soft_prompt.py @@ -1,4 +1,4 @@ -#%% +# %% import torch import torch.nn as nn import torch.nn.functional as F diff --git a/delta_residual/utils.py b/delta_residual/utils.py index 6f2b4f4..fe3b311 100644 --- a/delta_residual/utils.py +++ b/delta_residual/utils.py @@ -9,6 +9,20 @@ # from delta_residual.general_delta import AbstractDeltaModule +# 这里是一个partial +@nn.DataParallel +class ModuleForSelfHook(nn.Module): + """Some Information about ModuleForSelfHook""" + + def __init__(self, self_delta_model: nn.Module, hook_without_self: Callable): + super(ModuleForSelfHook, self).__init__() + self.self_delta_model = self_delta_model + self.hook_without_self = hook_without_self + + def forward(self, *args, **kwargs): + return self.hook_without_self(self.self_delta_model, *args, **kwargs) + + def get_sorted_function_inputs_from_args(fun, *args, **kwargs) -> dict[str, Any]: args = list(args) # 将输入函数的参数正则化,变成按照函数调用顺序的、写出参数名称的字典调用