Commit b84a99d
lora update
eps696 committed Dec 11, 2023
1 parent 0f516f2 commit b84a99d
Showing 4 changed files with 32 additions and 53 deletions.
21 changes: 15 additions & 6 deletions model_half.py
@@ -6,8 +6,8 @@
 import torch
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--dir", '-d', default='./', help="directory with models")
-parser.add_argument("--ext", '-e', default=['ckpt','pt', 'bin'], help="model extensions")
+parser.add_argument("--input", '-i', default='./', help="file or directory with models")
+parser.add_argument("--ext", '-e', default=['ckpt','pt', 'bin','safetensors'], help="model extensions")
 a = parser.parse_args()
 
 def basename(file):
@@ -29,18 +29,27 @@ def float2half(data):
         if isinstance(data[k], collections.abc.Mapping):
             data[k] = float2half(data[k])
         elif isinstance(data[k], list):
-            data[k] = [float2half(x) for x in data[k]]
+            data[k] = [float2half(x) for x in data[k] if not isinstance(x, int)]
         else:
             if data[k] is not None and torch.is_tensor(data[k]) and data[k].type() in ['torch.FloatTensor', 'torch.cuda.FloatTensor']:
                 data[k] = data[k].half()
     return data
 
-models = file_list(a.dir, a.ext)
+models = [a.input] if os.path.isfile(a.input) else file_list(a.input, a.ext)
+
+if any(['safetensors' in f for f in models]):
+    import safetensors.torch as safe
+
 for model_path in models:
-    model = torch.load(model_path)
+    issafe = '.safetensors' in model_path.lower()
+    model = safe.load_file(model_path) if issafe else torch.load(model_path)
+
     model = float2half(model)
 
     file_bak = basename(model_path) + '-full' + os.path.splitext(model_path)[-1]
     move(model_path, file_bak)
-    torch.save(model, model_path)
+    if issafe:
+        safe.save_file(model, model_path)
+    else:
+        torch.save(model, model_path)
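
Note: for reference, a minimal standalone sketch of the half-precision round trip that the updated model_half.py performs on a .safetensors checkpoint; the file names here are placeholders, and it assumes the safetensors package is installed.

import torch
import safetensors.torch as safe

state = safe.load_file('model.safetensors')                                          # flat dict of tensors
state = {k: v.half() if v.dtype == torch.float32 else v for k, v in state.items()}   # fp32 -> fp16
safe.save_file(state, 'model-half.safetensors')                                      # write the halved copy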

4 changes: 3 additions & 1 deletion requirements.txt
@@ -3,9 +3,11 @@ git+https://github.com/huggingface/diffusers
 git+https://github.com/openai/CLIP
 torchmetrics
 accelerate
+safetensors==0.3.1
+peft
 
 imageio-ffmpeg
 omegaconf
 easydict
 numpy==1.23
 tqdm
@@ -19,4 +21,4 @@ resize_right
 torchdiffeq
 torchsde
 ipywidgets
-
+av
40 changes: 2 additions & 38 deletions src/core/finetune.py
@@ -12,20 +12,7 @@
 from torch.utils.data import Dataset
 from torchvision import transforms
 
-try: # diffusers 0.14
-    from diffusers.models.attention import CrossAttention as Attention
-    from diffusers.models.cross_attention import LoRACrossAttnProcessor as LoRAAttnProcessor
-    try:
-        import xformers
-    except: pass
-except: # diffusers 0.15+
-    from diffusers.models.attention_processor import Attention
-    try:
-        import xformers
-        from diffusers.models.attention_processor import LoRAXFormersAttnProcessor as LoRAAttnProcessor
-    except:
-        from diffusers.models.attention_processor import LoRAAttnProcessor
-from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import Attention
 
 PIL_INTERPOLATION = PIL.Image.Resampling.BICUBIC if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0") else PIL.Image.BICUBIC

@@ -65,7 +52,7 @@ def __init__(self):
         }
         model_path = CAPTION_MODELS['blip-base'] # 'blip-large' is too imaginative
         model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype) # Blip2ForConditionalGeneration
-        self.processor = AutoProcessor.from_pretrained(model_path)
+        self.processor = AutoProcessor.from_pretrained(model_path, do_rescale=False)
         self.model = model.eval().to(self.device)
def __call__(self, image):
@@ -275,29 +262,6 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         hidden_states = attn.to_out[1](hidden_states) # dropout
         return hidden_states
 
-# # # # # # # # # LoRA # # # # # # # # #
-
-def prep_lora(unet):
-    # attention processors => 32 layers
-    # 3x down blocks * 2x attn layers * 2x transformer layers = 12
-    # 1x mid blocks * 2x attn layers * 1x transformer layers = 2
-    # 3x up blocks * 2x attn layers * 3x transformer layers = 18
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim).to(unet.device, dtype=unet.dtype)
-    unet.set_attn_processor(lora_attn_procs)
-    lora_layers = AttnProcsLayers(unet.attn_processors)
-    return unet, lora_layers
-
 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
 
 def save_embeds(save_path, text_encoder, tokens, tokens_id, accelerator=None):
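
Note: the prep_lora() helper removed above is superseded by peft adapters set up directly in src/train.py (see the diff below). A minimal sketch of that pattern, assuming a recent diffusers with peft installed; the model id is a placeholder.

import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')  # placeholder id
unet.requires_grad_(False)                       # freeze the base weights

# inject rank-4 LoRA matrices into every attention projection
lora_config = LoraConfig(r=4, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"])
unet.add_adapter(lora_config)

lora_params = [p for p in unet.parameters() if p.requires_grad]  # only adapter weights stay trainable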
20 changes: 12 additions & 8 deletions src/train.py
@@ -17,8 +17,10 @@
 transformers.utils.logging.set_verbosity_warning()
 from transformers import CLIPTextModel, CLIPTokenizer
 from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 
-from core.finetune import FinetuneDataset, custom_diff, prep_lora, save_delta, load_delta, save_embeds
+from core.finetune import FinetuneDataset, custom_diff, save_delta, load_delta, save_embeds
 from core.utils import save_img, save_cfg, isset, progbar, basename
 
 import warnings
@@ -47,6 +49,7 @@
 parser.add_argument('-val', '--validate', action='store_true', help="Save test samples during training")
 parser.add_argument('-lo', '--low_mem', action='store_true', help="Use gradient checkpointing: less memory, slower training")
 parser.add_argument('--freeze_model', default='crossattn_kv', help="set 'crossattn' to enable fine-tuning of all key, value, query matrices")
+parser.add_argument('--rank', default=4, type=int, help="The dimension of the LoRA update matrices.")
 parser.add_argument('-lr', '--lr', default=1e-5, type=float, help="Initial learning rate") # 1e-3 ~ 5e-4 for text inv, 1e-4 for lora
 parser.add_argument('--scale_lr', default=True, help="Scale learning rate by batch")
 parser.add_argument('-S', '--seed', default=None, type=int, help="A seed for reproducible training.")
@@ -158,9 +161,10 @@ def collate_fn(examples):
     else:
         text_encoder.to(dtype=weight_dtype).requires_grad_(False)
     if a.type=='lora':
-        unet, lora_layers = prep_lora(unet)
         unet.requires_grad_(False)
-        lora_layers.requires_grad_(True)
+        unet_lora_config = LoraConfig(r=a.rank, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"])
+        unet.add_adapter(unet_lora_config)
+        lora_params = filter(lambda p: p.requires_grad, unet.parameters())
         unet_dtype = torch.float32 # !!! must be trained as float32; otherwise inf/nan
     elif a.type=='text':
         unet.requires_grad_(False)
@@ -172,7 +176,7 @@ def collate_fn(examples):
     unet.to(dtype=unet_dtype)
     text_encoder.to(dtype=torch.float32) # !!! must be trained as float32 + avoiding q/k/v mistype error at lora validation
 
-    if a.xformers is True and a.type != 'lora':
+    if a.xformers is True:
         try:
             import xformers
             unet.enable_xformers_memory_efficient_attention()
@@ -192,7 +196,7 @@ def collate_fn(examples):
     if a.type == 'text': # text inversion: only new embedding(s)
         params_to_optimize = text_encoder.get_input_embeddings().parameters()
     elif a.type == 'lora':
-        params_to_optimize = lora_layers.parameters()
+        params_to_optimize = lora_params
     elif a.freeze_model == 'crossattn': # custom: embeddings & unet attention all
         params_to_optimize = itertools.chain(text_encoder.get_input_embeddings().parameters(),
                                              [x[1] for x in unet.named_parameters() if 'attn2' in x[0]])
@@ -255,11 +259,11 @@ def collate_fn(examples):
             save_delta(save_path, unet, text_encoder, mod_tokens, mod_tokens_id, a.freeze_model, unet0=unet0)
         elif a.type == 'lora':
             save_path = os.path.join(a.out_dir, '%s-%s-%04d.pt' % (basename(a.data), a.type, global_step))
-            torch.save(lora_layers.state_dict(), save_path)
+            torch.save(get_peft_model_state_dict(unet), save_path)
 
         # test sample
         if a.validate:
-            pipetest = StableDiffusionPipeline(vae, text_encoder, tokenizer, unet, scheduler, None, None, False).to(device)
+            pipetest = StableDiffusionPipeline(vae, text_encoder, tokenizer, unet, scheduler, None, None, None, False).to(device)
             pipetest.set_progress_bar_config(disable=True)
             with torch.autocast("cuda"):
                 if train_txtenc:
@@ -284,7 +288,7 @@ def collate_fn(examples):
         save_delta(save_path, unet, text_encoder, mod_tokens, mod_tokens_id, a.freeze_model, unet0=unet0)
     elif a.type == 'lora':
         save_path = os.path.join(a.out_dir, '%s-%s.pt' % (basename(a.data), a.type))
-        torch.save(lora_layers.state_dict(), save_path)
+        torch.save(get_peft_model_state_dict(unet), save_path)
 
 
 if __name__ == "__main__":
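
Note: train.py now saves only the peft state dict of the adapted unet, which keeps the checkpoint small. One way to restore those weights for inference, sketched under the assumption of matching peft/diffusers versions; the checkpoint path and model id are placeholders.

import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig
from peft.utils import set_peft_model_state_dict

unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')  # placeholder id

# recreate an adapter with the same config used for training, then load the saved weights into it
unet.add_adapter(LoraConfig(r=4, init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"]))
lora_state = torch.load('mydata-lora.pt', map_location='cpu')    # placeholder: file written by train.py above
set_peft_model_state_dict(unet, lora_state)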
