Commit

soft_prompt is implemented

2catycm committed Mar 30, 2024
1 parent a1b9133 commit 9ceed80
Showing 12 changed files with 466 additions and 130 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
learn

# Created by https://www.gitignore.io/api/osx,python,pycharm,windows,visualstudio,visualstudiocode
# Edit at https://www.gitignore.io/?templates=osx,python,pycharm,windows,visualstudio,visualstudiocode
venv
6 changes: 3 additions & 3 deletions assets/images/coverage.svg
19 changes: 0 additions & 19 deletions delta_residual/__init__.py
@@ -1,19 +0,0 @@
# type: ignore[attr-defined]
"""A minimal PyTorch re-implementation of Parameter-Efficient Fine-Tuning (Delta Tuning)."""

import sys

if sys.version_info >= (3, 8):
    from importlib import metadata as importlib_metadata
else:
    import importlib_metadata


def get_version() -> str:
    try:
        return importlib_metadata.version(__name__)
    except importlib_metadata.PackageNotFoundError:  # pragma: no cover
        return "unknown"


version: str = get_version()
67 changes: 1 addition & 66 deletions delta_residual/__main__.py
@@ -1,66 +1 @@
# type: ignore[attr-defined]
from enum import Enum
from random import choice
from typing import Optional

import typer
from rich.console import Console

from delta_residual import version
from delta_residual.example import hello


class Color(str, Enum):
    white = "white"
    red = "red"
    cyan = "cyan"
    magenta = "magenta"
    yellow = "yellow"
    green = "green"


app = typer.Typer(
    name="delta_residual",
    help="A minimal PyTorch re-implementation of Parameter-Efficient Fine-Tuning (Delta Tuning).",
    add_completion=False,
)
console = Console()


def version_callback(print_version: bool) -> None:
    """Print the version of the package."""
    if print_version:
        console.print(f"[yellow]delta_residual[/] version: [bold blue]{version}[/]")
        raise typer.Exit()


@app.command(name="")
def main(
    name: str = typer.Option(..., help="Person to greet."),
    color: Optional[Color] = typer.Option(
        None,
        "-c",
        "--color",
        "--colour",
        case_sensitive=False,
        help="Color for print. If not specified then choice will be random.",
    ),
    print_version: bool = typer.Option(
        None,
        "-v",
        "--version",
        callback=version_callback,
        is_eager=True,
        help="Prints the version of the delta_residual package.",
    ),
) -> None:
    """Print a greeting with a giving name."""
    if color is None:
        color = choice(list(Color))

    greeting: str = hello(name)
    console.print(f"[bold {color}]{greeting}[/]")


if __name__ == "__main__":
    app()
raise NotImplementedError()
19 changes: 0 additions & 19 deletions delta_residual/example.py

This file was deleted.

154 changes: 154 additions & 0 deletions delta_residual/general_delta.py
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# from utils import *
from .matching_strategy import find_modules

# This variant would only modify the attention layers. Ignoring LayerNorm, the soft prompt
# only influences the attention computation.
# class GeneralSoftPromptAttentionLayer(nn.Module):

# This variant would modify every encoder layer.
# class GeneralSoftPromptForEncoderLayer(nn.Module):


# In practice, which layers to modify can be specified explicitly.
# class DeltaModel(nn.Module):
#     """Some Information about LayerReplacedModel"""
#     def __init__(self, original_model: nn.Module):
#         super().__init__()
#         # self.original_model = (original_model, )
#         # self.forward = self.original_model[0].forward  # does not change the behavior of original_layer
#         # self.hooked: bool = False
#     # def hook_in(self):
#     #     if self.injected:
#     #         return
#     #     self.injected = True


#     # def hook_out(self):
#     #     if not self.injected:
#     #         return
#     #     self.injected = False

#     def forward(self, x):
#         # assert self.injected, "If not injected, you can't forward the model."
#         assert False
#         return x
class AbstractDeltaModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def refer_to(self, model: nn.Module = None):
        """Simply make this DeltaModel's `forward()` equal to the reference model's `forward()`.

        Note: If `self` is not hooked into the model, `self.__call__()` may not work as
        expected; put differently, the DeltaModel's own delta computations will not be invoked.
        Note: This shall not change the behavior of `model`.

        Args:
            model (nn.Module, optional): Reference PyTorch model. Defaults to None.
                If None, the DeltaModel is made non-callable.
        """
        if model is None:
            self.forward = None
        else:
            self.forward = model.forward
        self.reference_model_tup = (model,)

    def hook_into(self, model: nn.Module):
        """Inject this DeltaModel's computation into the reference model.

        Afterwards, the reference model's `__call__()` runs not only its own `forward()`
        but also the delta computations. The hooking is designed to be invertible; to undo
        it, see `remove_hook_from()`.
        Note: This method changes the behavior of `model`.

        Args:
            model (nn.Module): Reference PyTorch model.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError

    def remove_hook_from(self, model: nn.Module):
        """Remove the hooking effect of `self` on `model`.

        Note: This method changes the behavior of `model`.

        Args:
            model (nn.Module): Reference PyTorch model.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError

    def merge_into(self, model: nn.Module):
        """Re-parameterize the reference model with the delta model (merge the delta weights).

        Note: This method changes the behavior of `model`.

        Args:
            model (nn.Module): Reference PyTorch model.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError

class AbstractDeltaLayer(AbstractDeltaModule):
    """
    1. DeltaLayer is a more rigorous version of DeltaModel: it can compute `__call__()` with
       both reference parameters and delta parameters simply by `refer_to`-ing a reference
       model, without modifying the reference model's behavior. Hooking, which would change
       the behavior of the reference model, is supported but not required.
    2. This class provides resource management of hooks and handles for its subclasses.
    """

    def refer_to(self, model: nn.Module = None):
        """Simply make this DeltaLayer's `forward()` equal to the reference model's `forward()`.

        Note: This shall not change the behavior of `model`.

        Args:
            model (nn.Module, optional): Reference PyTorch model. Defaults to None.
                If None, the DeltaLayer is made non-callable.
        """
        super().refer_to(model)

class GeneralDeltaModel(AbstractDeltaModule):
    """Not an abstract class: this can be used directly with any delta layer class.
    It acts as a layer replacer.
    """

    def __init__(
        self,
        reference_model: nn.Module,
        modified_modules: list[str],
        adapter_name="delta",
        layer_delta_class=nn.Module,
        layer_config: dict = None,
    ) -> None:
        super().__init__()
        self.layer_delta_class = layer_delta_class
        self.layer_config = layer_config or dict()
        self.adapter_name = adapter_name
        self.refer_to(reference_model)
        self.initiate(reference_model, modified_modules)
        self.hook_into(reference_model)

    def initiate(self, reference_model: nn.Module, modified_modules: list[str]):
        self.delta_layers: nn.ModuleDict[str, AbstractDeltaModule] = nn.ModuleDict()
        for name, module in find_modules(reference_model, modified_modules):
            # self.delta_layers.add_module(f"{self.adapter_name}.{name}",
            self.delta_layers.add_module(
                name.replace(".", "=="),  # ModuleDict keys cannot contain dots
                self.layer_delta_class(module, **self.layer_config),
            )

    def hook_into(self, model: nn.Module):
        for name, layer in self.delta_layers.items():
            original = model.get_submodule(name.replace("==", "."))
            layer.hook_into(original)

    def remove_hook_from(self, model: nn.Module):
        for name, layer in self.delta_layers.items():
            original = model.get_submodule(name.replace("==", "."))
            layer.remove_hook_from(original)

    def merge_into(self, model: nn.Module):
        for name, layer in self.delta_layers.items():
            original = model.get_submodule(name.replace("==", "."))
            layer.merge_into(original)
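
For orientation (not part of this commit), here is a minimal sketch of how the hook-based design could be used: a hypothetical AdditiveBiasDelta layer adds a learnable bias to a Linear module's output via a forward hook, and GeneralDeltaModel wires one such layer into every matched Linear. The class name AdditiveBiasDelta and the toy backbone are assumptions for illustration only.

import torch
import torch.nn as nn

from delta_residual.general_delta import AbstractDeltaLayer, GeneralDeltaModel


class AdditiveBiasDelta(AbstractDeltaLayer):
    """Hypothetical delta layer: adds a learnable bias to the wrapped Linear's output."""

    def __init__(self, reference_layer: nn.Linear):
        super().__init__()
        self.delta_bias = nn.Parameter(torch.zeros(reference_layer.out_features))
        self.handle = None
        self.refer_to(reference_layer)

    def hook_into(self, model: nn.Module):
        # A forward hook that returns a value replaces the module's output.
        self.handle = model.register_forward_hook(
            lambda module, inputs, output: output + self.delta_bias
        )

    def remove_hook_from(self, model: nn.Module):
        if self.handle is not None:
            self.handle.remove()
            self.handle = None


backbone = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
delta = GeneralDeltaModel(
    backbone,
    modified_modules=["Linear"],
    layer_delta_class=AdditiveBiasDelta,
)
out = backbone(torch.randn(1, 8))  # hooked: output now includes the (initially zero) delta biases
delta.remove_hook_from(backbone)   # restores the backbone's original behavior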
51 changes: 51 additions & 0 deletions delta_residual/matching_strategy.py
@@ -0,0 +1,51 @@
# XPath and BeautifulSoup selectors are also ways to specify modules,
# as are regular expressions.
# The module tree is hierarchical, with different depths.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def is_match(
    module_simple_name: str, module_full_name: str, modified_modules: list[str]
) -> bool:
    for requirement in modified_modules:
        if module_simple_name == requirement or (
            module_full_name.endswith(requirement) and requirement != ""
        ):
            return True
    return False


# Special cases:
# [""] (a list containing the empty string) denotes the root model, i.e. `model` itself.
# [] (an empty list) means no module is wanted, so is_match returns False.
def find_modules_all(model: nn.Module, modified_modules: list[str]):
    for name, module in model.named_modules():  # pre-order traversal
        # print(f"name={name}, cls_name={module.__class__.__name__}")
        if is_match(module.__class__.__name__, name, modified_modules):
            yield (name, module)  # yield the names so users can check that the selected modules are the intended ones
            # print("It matches!")
        # print()


def find_modules_parent_only(
    model: nn.Module, modified_modules: list[str], parent_name=""
):
    if is_match(model.__class__.__name__, parent_name, modified_modules):
        yield parent_name, model  # usually not the case
        return  # parent modules take precedence: once matched, do not descend further
    for name, children in model.named_children():  # pre-order traversal
        # print(f"name={name}, cls_name={module.__class__.__name__}")
        full_name = f"{parent_name}.{name}" if parent_name != "" else name
        yield from find_modules_parent_only(children, modified_modules, full_name)
        # print("It matches!")
        # print()


find_modules = find_modules_parent_only


def find_modules_dict(model: nn.Module, modified_modules: list[str]) -> dict:
    return dict(find_modules(model, modified_modules))
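
For reference, here is a small usage sketch of the matching helpers (the toy model below is an assumption for illustration, not part of the commit):

import torch.nn as nn

from delta_residual.matching_strategy import find_modules, find_modules_dict

model = nn.Sequential(
    nn.Linear(4, 4),
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),
)

# Match by class name: with the parent-only strategy, only the outermost matching
# modules are returned and their subtrees are not descended into.
print(find_modules_dict(model, ["Linear"]))  # keys: '0' and '1.0'

# Match by a suffix of the full dotted module name.
print(dict(find_modules(model, ["1.0"])))    # key: '1.0'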
Empty file.