Skip to content

Commit

Permalink
parallel wrapper for hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
2catycm committed Apr 6, 2024
1 parent 8e632c0 commit 553c6a4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
5 changes: 3 additions & 2 deletions delta_residual/general_delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .matching_strategy import find_modules
from .utils import (
ModuleDeviceAddOn,
ModuleForSelfHook,
SeeTrainableParametersAddOn,
get_module_device,
get_tuple_device,
Expand Down Expand Up @@ -150,10 +151,10 @@ def hook_into(self, layer: nn.Module):
)
self.remove_hook_from(layer)
self.others_forward_pre_hook_handles[layer] = layer.register_forward_pre_hook(
hook=self._forward_pre_hook
hook=ModuleForSelfHook(self, self.__class__._forward_pre_hook)
)
self.others_forward_hook_handles[layer] = layer.register_forward_hook(
hook=self._forward_hook
hook=ModuleForSelfHook(self, self.__class__._forward_hook)
)

def remove_hook_from(self, layer: nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion delta_residual/soft_prompt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#%%
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
14 changes: 14 additions & 0 deletions delta_residual/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@
# from delta_residual.general_delta import AbstractDeltaModule


# 这里是一个partial
@nn.DataParallel
class ModuleForSelfHook(nn.Module):
"""Some Information about ModuleForSelfHook"""

def __init__(self, self_delta_model: nn.Module, hook_without_self: Callable):
super(ModuleForSelfHook, self).__init__()
self.self_delta_model = self_delta_model
self.hook_without_self = hook_without_self

def forward(self, *args, **kwargs):
return self.hook_without_self(self.self_delta_model, *args, **kwargs)


def get_sorted_function_inputs_from_args(fun, *args, **kwargs) -> dict[str, Any]:
args = list(args)
# 将输入函数的参数正则化,变成按照函数调用顺序的、写出参数名称的字典调用
Expand Down

0 comments on commit 553c6a4

Please sign in to comment.