From 387f48c003661cd18978bb5cbc859dd5dd4f30a3 Mon Sep 17 00:00:00 2001 From: Rypo Date: Mon, 25 Nov 2024 19:39:21 -0600 Subject: [PATCH 01/13] feat: fast model loading with accelerate Prevents slow CPU initialization of model weights on load by using accelerate `init_empty_weights`. Completely compatible with from_pretrained since weights will always be overwritten by state_dict fixes #72 --- OmniGen/model.py | 33 ++++++++++++++++++++++++++------- OmniGen/pipeline.py | 44 +++++++++++++++++++++++++++----------------- 2 files changed, 53 insertions(+), 24 deletions(-) diff --git a/OmniGen/model.py b/OmniGen/model.py index 8999a8e..7504e54 100644 --- a/OmniGen/model.py +++ b/OmniGen/model.py @@ -1,5 +1,6 @@ # The code is revised from DiT import os +import gc import torch import torch.nn as nn import numpy as np @@ -10,6 +11,7 @@ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp from huggingface_hub import snapshot_download from safetensors.torch import load_file +from accelerate import init_empty_weights from OmniGen.transformer import Phi3Config, Phi3Transformer @@ -187,20 +189,37 @@ def __init__( self.llm.config.use_cache = False @classmethod - def from_pretrained(cls, model_name): + def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, device: str|torch.device='cuda', low_cpu_mem_usage: bool = True,): if not os.path.exists(model_name): cache_folder = os.getenv('HF_HUB_CACHE') model_name = snapshot_download(repo_id=model_name, cache_dir=cache_folder, ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) - config = Phi3Config.from_pretrained(model_name) - model = cls(config) - if os.path.exists(os.path.join(model_name, 'model.safetensors')): + + model_path = os.path.join(model_name, 'model.safetensors') + if not os.path.exists(model_path): + model_path = os.path.join(model_name, 'model.pt') + ckpt = torch.load(model_path, map_location='cpu') + else: print("Loading safetensors") - ckpt = load_file(os.path.join(model_name, 'model.safetensors')) + ckpt = load_file(model_path, 'cpu') + + if low_cpu_mem_usage: + with init_empty_weights(): + config = Phi3Config.from_pretrained(model_name) + model = cls(config) + + model.load_state_dict(ckpt, assign=True) + model = model.to(device, dtype) else: - ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu') - model.load_state_dict(ckpt) + config = Phi3Config.from_pretrained(model_name) + model = cls(config) + model.load_state_dict(ckpt) + model = model.to(device, dtype) + + del ckpt + torch.cuda.empty_cache() + gc.collect() return model def initialize_weights(self): diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index 09b0731..9f34086 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -41,6 +41,15 @@ ``` """ +def best_available_device(): + if torch.cuda.is_available(): + device = torch.device("cuda") + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!") + device = torch.device("cpu") + return device class OmniGenPipeline: def __init__( @@ -55,14 +64,10 @@ def __init__( self.processor = processor self.device = device - if device is None: - if torch.cuda.is_available(): - self.device = torch.device("cuda") - elif torch.backends.mps.is_available(): - self.device = torch.device("mps") - else: - logger.info("Don't detect any available GPUs, using CPU instead, this may take long time to generate image!!!") - self.device = torch.device("cpu") + if self.device is None: + self.device = best_available_device() + elif isinstance(self.device, str): + self.device = torch.device(self.device) # self.model.to(torch.bfloat16) self.model.eval() @@ -71,7 +76,7 @@ def __init__( self.model_cpu_offload = False @classmethod - def from_pretrained(cls, model_name, vae_path: str=None): + def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_mem_usage=True): if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"): logger.info("Model not found, downloading...") cache_folder = os.getenv('HF_HUB_CACHE') @@ -79,18 +84,23 @@ def from_pretrained(cls, model_name, vae_path: str=None): cache_dir=cache_folder, ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt']) logger.info(f"Downloaded model to {model_name}") - model = OmniGen.from_pretrained(model_name) + + if device is None: + device = best_available_device() + + model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, device=device, low_cpu_mem_usage=low_cpu_mem_usage) processor = OmniGenProcessor.from_pretrained(model_name) - if os.path.exists(os.path.join(model_name, "vae")): - vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae")) - elif vae_path is not None: - vae = AutoencoderKL.from_pretrained(vae_path).to(device) - else: + if vae_path is None: + vae_path = os.path.join(model_name, "vae") + + if not os.path.exists(vae_path): logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF") - vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device) + vae_path = "stabilityai/sdxl-vae" + + vae = AutoencoderKL.from_pretrained(vae_path).to(device) - return cls(vae, model, processor) + return cls(vae, model, processor, device) def merge_lora(self, lora_path: str): model = PeftModel.from_pretrained(self.model, lora_path) From 0287b507b2c092cd280c52d2c0b9f4fe764fdd20 Mon Sep 17 00:00:00 2001 From: Rypo Date: Tue, 26 Nov 2024 13:31:37 -0600 Subject: [PATCH 02/13] fix: avoid moving model to device prematurely --- OmniGen/model.py | 6 +++--- OmniGen/pipeline.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/OmniGen/model.py b/OmniGen/model.py index 7504e54..b45f2e1 100644 --- a/OmniGen/model.py +++ b/OmniGen/model.py @@ -189,7 +189,7 @@ def __init__( self.llm.config.use_cache = False @classmethod - def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, device: str|torch.device='cuda', low_cpu_mem_usage: bool = True,): + def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, low_cpu_mem_usage: bool = True,): if not os.path.exists(model_name): cache_folder = os.getenv('HF_HUB_CACHE') model_name = snapshot_download(repo_id=model_name, @@ -210,12 +210,12 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch model = cls(config) model.load_state_dict(ckpt, assign=True) - model = model.to(device, dtype) + model = model.to(dtype) else: config = Phi3Config.from_pretrained(model_name) model = cls(config) model.load_state_dict(ckpt) - model = model.to(device, dtype) + model = model.to(dtype) del ckpt torch.cuda.empty_cache() diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index 9f34086..c2325e4 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -88,7 +88,7 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_me if device is None: device = best_available_device() - model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, device=device, low_cpu_mem_usage=low_cpu_mem_usage) + model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, low_cpu_mem_usage=low_cpu_mem_usage) processor = OmniGenProcessor.from_pretrained(model_name) if vae_path is None: From 53794a0236723f88f02c390822df2379566985d3 Mon Sep 17 00:00:00 2001 From: Rypo Date: Tue, 26 Nov 2024 09:27:30 -0600 Subject: [PATCH 03/13] fix: typo --- OmniGen/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/OmniGen/transformer.py b/OmniGen/transformer.py index 4df2006..fefe546 100644 --- a/OmniGen/transformer.py +++ b/OmniGen/transformer.py @@ -42,7 +42,7 @@ def evict_previous_layer(self, layer_idx: int): for name, param in self.layers[prev_layer_idx].named_parameters(): param.data = param.data.to("cpu", non_blocking=True) - def get_offlaod_layer(self, layer_idx: int, device: torch.device): + def get_offload_layer(self, layer_idx: int, device: torch.device): # init stream if not hasattr(self, "prefetch_stream"): self.prefetch_stream = torch.cuda.Stream() @@ -153,7 +153,7 @@ def forward( ) else: if offload_model and not self.training: - self.get_offlaod_layer(layer_idx, device=inputs_embeds.device) + self.get_offload_layer(layer_idx, device=inputs_embeds.device) layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, From 889b6b95d08ab670096c6c262cf2fb487cf3a8f2 Mon Sep 17 00:00:00 2001 From: Rypo Date: Wed, 27 Nov 2024 14:38:08 -0600 Subject: [PATCH 04/13] feat: add 4bit and 8bit quantization support with bitsandbytes Add a quantization utility for HFQuantizers. Modify pipelines to accept quantization_config. Sets ground work for allow bf16 vae. Update requirements to include bitsandbytes. closes #45, closes #64 --- OmniGen/model.py | 27 ++++++++++++++++++++------- OmniGen/pipeline.py | 33 ++++++++++++++++++++++----------- OmniGen/utils.py | 26 ++++++++++++++++++++++++++ requirements.txt | 1 + 4 files changed, 69 insertions(+), 18 deletions(-) diff --git a/OmniGen/model.py b/OmniGen/model.py index b45f2e1..f1e44a6 100644 --- a/OmniGen/model.py +++ b/OmniGen/model.py @@ -12,9 +12,10 @@ from huggingface_hub import snapshot_download from safetensors.torch import load_file from accelerate import init_empty_weights +from transformers import BitsAndBytesConfig from OmniGen.transformer import Phi3Config, Phi3Transformer - +from OmniGen.utils import quantize_bnb def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) @@ -187,9 +188,13 @@ def __init__( self.llm = Phi3Transformer(config=transformer_config) self.llm.config.use_cache = False + + # bnb 4bit quantized models cannot be offloaded + self.offloadable = True + self.quantization_config = None @classmethod - def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, low_cpu_mem_usage: bool = True,): + def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True,): if not os.path.exists(model_name): cache_folder = os.getenv('HF_HUB_CACHE') model_name = snapshot_download(repo_id=model_name, @@ -201,22 +206,30 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch model_path = os.path.join(model_name, 'model.pt') ckpt = torch.load(model_path, map_location='cpu') else: - print("Loading safetensors") + #print("Loading safetensors") ckpt = load_file(model_path, 'cpu') if low_cpu_mem_usage: with init_empty_weights(): config = Phi3Config.from_pretrained(model_name) model = cls(config) - - model.load_state_dict(ckpt, assign=True) - model = model.to(dtype) + + if quantization_config: + model = quantize_bnb(model, ckpt, quantization_config=quantization_config, pre_quantized=False) + if getattr(quantization_config, 'load_in_4bit', None): + model.offloadable = False + model.quantization_config = quantization_config + else: + model.load_state_dict(ckpt, assign=True) else: + if quantization_config: + raise ValueError('Quantization not supported for `low_cpu_mem_usage=False`.') + config = Phi3Config.from_pretrained(model_name) model = cls(config) model.load_state_dict(ckpt) - model = model.to(dtype) + model = model.to(dtype) del ckpt torch.cuda.empty_cache() gc.collect() diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index c2325e4..233e924 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -1,6 +1,6 @@ import os import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union, Literal import gc from PIL import Image @@ -17,6 +17,7 @@ scale_lora_layers, unscale_lora_layers, ) +from transformers import BitsAndBytesConfig from safetensors.torch import load_file from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler @@ -76,7 +77,7 @@ def __init__( self.model_cpu_offload = False @classmethod - def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_mem_usage=True): + def from_pretrained(cls, model_name, vae_path: str=None, device=None, quantization_config:Literal['bnb_4bit','bnb_8bit']|BitsAndBytesConfig=None, low_cpu_mem_usage=True): if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"): logger.info("Model not found, downloading...") cache_folder = os.getenv('HF_HUB_CACHE') @@ -87,8 +88,16 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_me if device is None: device = best_available_device() - - model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, low_cpu_mem_usage=low_cpu_mem_usage) + + if isinstance(quantization_config, str): + if quantization_config == 'bnb_4bit': + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float32, bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=False) + elif quantization_config == 'bnb_8bit': + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + else: + raise NotImplementedError(f'Unknown `quantization_config` {quantization_config!r}') + + model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, quantization_config=quantization_config, low_cpu_mem_usage=low_cpu_mem_usage) processor = OmniGenProcessor.from_pretrained(model_name) if vae_path is None: @@ -98,7 +107,7 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_me logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF") vae_path = "stabilityai/sdxl-vae" - vae = AutoencoderKL.from_pretrained(vae_path).to(device) + vae = AutoencoderKL.from_pretrained(vae_path) return cls(vae, model, processor, device) @@ -131,7 +140,8 @@ def move_to_device(self, data): def enable_model_cpu_offload(self): self.model_cpu_offload = True - self.model.to("cpu") + if self.model.offloadable: + self.model.to("cpu") self.vae.to("cpu") torch.cuda.empty_cache() # Clear VRAM gc.collect() # Run garbage collection to free system RAM @@ -221,6 +231,7 @@ def __call__( if max_input_image_size != self.processor.max_image_size: self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size) self.model.to(dtype) + #self.vae.to(dtype) # Uncomment this line to allow bfloat16 VAE if offload_model: self.enable_model_cpu_offload() else: @@ -250,12 +261,12 @@ def __call__( for temp_pixel_values in input_data['input_pixel_values']: temp_input_latents = [] for img in temp_pixel_values: - img = self.vae_encode(img.to(self.device), dtype) + img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), dtype) temp_input_latents.append(img) input_img_latents.append(temp_input_latents) else: for img in input_data['input_pixel_values']: - img = self.vae_encode(img.to(self.device), dtype) + img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), dtype) input_img_latents.append(img) if input_images is not None and self.model_cpu_offload: self.vae.to('cpu') @@ -279,7 +290,7 @@ def __call__( else: func = self.model.forward_with_cfg - if self.model_cpu_offload: + if self.model_cpu_offload and self.model.offloadable: for name, param in self.model.named_parameters(): if 'layers' in name and 'layers.0' not in name: param.data = param.data.cpu() @@ -294,13 +305,13 @@ def __call__( samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache) samples = samples.chunk((1+num_cfg), dim=0)[0] - if self.model_cpu_offload: + if self.model_cpu_offload and self.model.offloadable: self.model.to('cpu') torch.cuda.empty_cache() gc.collect() self.vae.to(self.device) - samples = samples.to(torch.float32) + samples = samples.to(self.vae.dtype) if self.vae.config.shift_factor is not None: samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor else: diff --git a/OmniGen/utils.py b/OmniGen/utils.py index 67a64e8..2225641 100644 --- a/OmniGen/utils.py +++ b/OmniGen/utils.py @@ -1,9 +1,14 @@ +import gc import logging from PIL import Image import torch import numpy as np +from transformers import BitsAndBytesConfig +from transformers.quantizers import AutoHfQuantizer +from transformers.integrations import replace_with_bnb_linear, set_module_quantized_tensor_to_device + def create_logger(logging_dir): """ Create a logger that writes to a log file and stdout. @@ -108,3 +113,24 @@ def vae_encode_list(vae, x, weight_dtype): latents.append(img) return latents + + +@torch.no_grad() +def quantize_bnb(meta_model, state_dict:dict, quantization_config:BitsAndBytesConfig, pre_quantized=False): + # from transformers.integrations import get_keys_to_not_convert + + quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=pre_quantized) + no_convert = [] #get_keys_to_not_convert(meta_model.llm) # might be worth investigating + + model = replace_with_bnb_linear(meta_model, modules_to_not_convert=no_convert, quantization_config=quantizer.quantization_config) + + for param_name, param in state_dict.items(): + if not quantizer.check_quantized_param(model, param, param_name, state_dict): + set_module_quantized_tensor_to_device(model, param_name, device=0, value=param) + else: + quantizer.create_quantized_param(model, param, param_name, target_device=0, state_dict=state_dict) + + del state_dict + torch.cuda.empty_cache() + gc.collect() + return model \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 358c613..0600e8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ pillow==10.2.0 peft==0.13.2 diffusers==0.30.3 timm==0.9.16 +bitsandbytes==0.44.1 \ No newline at end of file From 8ea2d6d7c48a7430d68626169a108322e298e22e Mon Sep 17 00:00:00 2001 From: Rypo Date: Wed, 27 Nov 2024 19:21:53 -0600 Subject: [PATCH 05/13] feat: add cli arg to gradio demo for nbit quantization --- app.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/app.py b/app.py index ba87673..87ee5c9 100644 --- a/app.py +++ b/app.py @@ -5,12 +5,8 @@ import random import spaces - from OmniGen import OmniGenPipeline -pipe = OmniGenPipeline.from_pretrained( - "Shitao/OmniGen-v1" -) @spaces.GPU(duration=180) def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model, @@ -370,6 +366,8 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_ with gr.Column(): with gr.Column(): + # quantization = gr.Radio(["4bit (NF4)", "8bit", "None (bf16)"], label="bitsandbytes quantization", value="4bit (NF4)") + # quantization.input(change_quantization, inputs=quantization, trigger_mode="once", concurrency_limit=1) # output image output_image = gr.Image(label="Output Image") save_images = gr.Checkbox(label="Save generated images", value=False) @@ -425,7 +423,21 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_ if __name__ == "__main__": parser = argparse.ArgumentParser(description='Run the OmniGen') parser.add_argument('--share', action='store_true', help='Share the Gradio app') + parser.add_argument('-b', '--nbits', choices=['4','8'], help='bitsandbytes quantization n-bits') args = parser.parse_args() + if args.nbits == '4': + quantization_config = 'bnb_4bit' + elif args.nbits == '8': + quantization_config = 'bnb_8bit' + else: + quantization_config = None + + pipe = OmniGenPipeline.from_pretrained( + "Shitao/OmniGen-v1", + quantization_config = quantization_config, + low_cpu_mem_usage=True, + ) + # launch demo.launch(share=args.share) From 8d71606a75a9d0f690065f1b6f928a59848dfda9 Mon Sep 17 00:00:00 2001 From: Rypo Date: Tue, 3 Dec 2024 14:23:39 -0600 Subject: [PATCH 06/13] feat: support quantization with prequantized weights, add auto detection for bnb quant dict --- OmniGen/utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/OmniGen/utils.py b/OmniGen/utils.py index 2225641..946edd2 100644 --- a/OmniGen/utils.py +++ b/OmniGen/utils.py @@ -116,15 +116,22 @@ def vae_encode_list(vae, x, weight_dtype): @torch.no_grad() -def quantize_bnb(meta_model, state_dict:dict, quantization_config:BitsAndBytesConfig, pre_quantized=False): +def quantize_bnb(meta_model, state_dict:dict, quantization_config:BitsAndBytesConfig, pre_quantized=None): # from transformers.integrations import get_keys_to_not_convert + if pre_quantized is None: + if quantization_config.load_in_4bit: + pre_quantized = any('bitsandbytes__' in k for k in state_dict) + elif quantization_config.load_in_8bit: + pre_quantized = any('weight_format' in k for k in state_dict) quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=pre_quantized) no_convert = [] #get_keys_to_not_convert(meta_model.llm) # might be worth investigating model = replace_with_bnb_linear(meta_model, modules_to_not_convert=no_convert, quantization_config=quantizer.quantization_config) - - for param_name, param in state_dict.items(): + + # iterate the model keys, otherwise quantized state dict will throws errors + for param_name in model.state_dict(): + param = state_dict.get(param_name) if not quantizer.check_quantized_param(model, param, param_name, state_dict): set_module_quantized_tensor_to_device(model, param_name, device=0, value=param) else: From 7ba2cbcc79ca57178f687ea8364a9cbfd26f97c1 Mon Sep 17 00:00:00 2001 From: Rypo Date: Tue, 3 Dec 2024 14:48:39 -0600 Subject: [PATCH 07/13] feat: support loading prequantized weights from the hub --- OmniGen/model.py | 53 ++++++++++++++++++++++++----------- OmniGen/pipeline.py | 68 ++++++++++++++++++++++++++++----------------- 2 files changed, 80 insertions(+), 41 deletions(-) diff --git a/OmniGen/model.py b/OmniGen/model.py index f1e44a6..c671a56 100644 --- a/OmniGen/model.py +++ b/OmniGen/model.py @@ -1,6 +1,8 @@ # The code is revised from DiT import os import gc +import warnings +from pathlib import Path import torch import torch.nn as nn import numpy as np @@ -165,6 +167,7 @@ def __init__( pos_embed_max_size: int = 192, ): super().__init__() + self.config = transformer_config self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size @@ -191,41 +194,59 @@ def __init__( # bnb 4bit quantized models cannot be offloaded self.offloadable = True - self.quantization_config = None @classmethod def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True,): - if not os.path.exists(model_name): - cache_folder = os.getenv('HF_HUB_CACHE') - model_name = snapshot_download(repo_id=model_name, - cache_dir=cache_folder, - ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) + model_path = Path(model_name) - model_path = os.path.join(model_name, 'model.safetensors') - if not os.path.exists(model_path): - model_path = os.path.join(model_name, 'model.pt') - ckpt = torch.load(model_path, map_location='cpu') + if model_path.exists(): + if model_path.is_dir(): + if (weights_loc := list(model_path.glob('*.safetensors'))): + model_path = weights_loc[0] + elif (weights_loc := list(model_path.glob('*.pt'))): + model_path = weights_loc[0] + else: + raise FileNotFoundError(f'No .safetensors or .pt model weights found in {model_path.as_posix()!r}') + else: - #print("Loading safetensors") - ckpt = load_file(model_path, 'cpu') + cache_folder = os.getenv('HF_HUB_CACHE') + model_path = snapshot_download(repo_id=model_name, cache_dir=cache_folder, + ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) + + # assume hub files are always .safetensors + model_path = next(Path(model_path).glob('*.safetensors')) + + ckpt = (load_file(model_path, 'cpu') if model_path.suffix == '.safetensors' else + torch.load(model_path, map_location='cpu')) + + config = Phi3Config.from_pretrained(model_name) + + if hasattr(config, 'quantization_config'): + if quantization_config is not None: + # from: diffusers.quantizers.auto + warnings.warn( + "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading" + " already has a `quantization_config` attribute. The `quantization_config` from the model will be used." + ) + + config.quantization_config.pop("quant_method",None) # prevent unused keys warning + quantization_config = BitsAndBytesConfig.from_dict(config.quantization_config) if low_cpu_mem_usage: with init_empty_weights(): - config = Phi3Config.from_pretrained(model_name) model = cls(config) if quantization_config: - model = quantize_bnb(model, ckpt, quantization_config=quantization_config, pre_quantized=False) + model = quantize_bnb(model, ckpt, quantization_config=quantization_config) if getattr(quantization_config, 'load_in_4bit', None): model.offloadable = False - model.quantization_config = quantization_config + model.config.quantization_config = quantization_config else: model.load_state_dict(ckpt, assign=True) else: if quantization_config: raise ValueError('Quantization not supported for `low_cpu_mem_usage=False`.') - config = Phi3Config.from_pretrained(model_name) model = cls(config) model.load_state_dict(ckpt) diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index 233e924..e4c3c5a 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -1,5 +1,6 @@ import os import inspect +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union, Literal import gc @@ -77,37 +78,54 @@ def __init__( self.model_cpu_offload = False @classmethod - def from_pretrained(cls, model_name, vae_path: str=None, device=None, quantization_config:Literal['bnb_4bit','bnb_8bit']|BitsAndBytesConfig=None, low_cpu_mem_usage=True): - if not os.path.exists(model_name) or (not os.path.exists(os.path.join(model_name, 'model.safetensors')) and model_name == "Shitao/OmniGen-v1"): + def from_pretrained(cls, model_name, vae_path: str=None, device=None, quantization_config:Literal['bnb_4bit','bnb_8bit']|BitsAndBytesConfig=None, low_cpu_mem_usage=True, **kwargs): + pretrained_path = Path(model_name) + + # XXX: Consider renaming 'model' to 'transformer' conform to diffusers pipeline syntax + model = kwargs.get('model', None) + processor = kwargs.get('processor', None) + vae = kwargs.get('vae', None) + + # NOTE: should technically allow delayed component inits via model/vae = None, but seems like more of a footgun than it's worth at this point + + if not pretrained_path.exists(): + ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt'] + + if model is not None: + ignore_patterns.append('model.safetensors') # avoid downloading bf16 model if passing existing model + logger.info("Model not found, downloading...") cache_folder = os.getenv('HF_HUB_CACHE') - model_name = snapshot_download(repo_id=model_name, - cache_dir=cache_folder, - ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5', 'model.pt']) - logger.info(f"Downloaded model to {model_name}") - - if device is None: - device = best_available_device() + pretrained_path = Path(snapshot_download(repo_id=model_name, cache_dir=cache_folder, ignore_patterns=ignore_patterns)) + logger.info(f"Downloaded model to {pretrained_path}") + + _quant_alias = { + 'bnb_4bit': BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float32, bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=False), + 'bnb_8bit': BitsAndBytesConfig(load_in_8bit=True), + } - if isinstance(quantization_config, str): - if quantization_config == 'bnb_4bit': - quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float32, bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=False) - elif quantization_config == 'bnb_8bit': - quantization_config = BitsAndBytesConfig(load_in_8bit=True) - else: - raise NotImplementedError(f'Unknown `quantization_config` {quantization_config!r}') + if model is None: + if isinstance(quantization_config, str): + try: + quantization_config = _quant_alias[quantization_config] + except KeyError: + raise NotImplementedError(f'Unknown `quantization_config` {quantization_config!r}') + + model = OmniGen.from_pretrained(pretrained_path, dtype=torch.bfloat16, quantization_config=quantization_config, low_cpu_mem_usage=low_cpu_mem_usage) - model = OmniGen.from_pretrained(model_name, dtype=torch.bfloat16, quantization_config=quantization_config, low_cpu_mem_usage=low_cpu_mem_usage) - processor = OmniGenProcessor.from_pretrained(model_name) - if vae_path is None: - vae_path = os.path.join(model_name, "vae") - - if not os.path.exists(vae_path): - logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF") - vae_path = "stabilityai/sdxl-vae" + if processor is None: + processor = OmniGenProcessor.from_pretrained(model_name) + + if vae is None: + if vae_path is None: + vae_path = pretrained_path.joinpath("vae") - vae = AutoencoderKL.from_pretrained(vae_path) + if not os.path.exists(vae_path): + logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF") + vae_path = "stabilityai/sdxl-vae" + + vae = AutoencoderKL.from_pretrained(vae_path) return cls(vae, model, processor, device) From caac3773cef7cd4d3855891ccaf04a0c9f05680d Mon Sep 17 00:00:00 2001 From: Rypo Date: Wed, 4 Dec 2024 18:36:20 -0600 Subject: [PATCH 08/13] fix: use default model config if passing weights file directly --- OmniGen/model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/OmniGen/model.py b/OmniGen/model.py index c671a56..e94fc7a 100644 --- a/OmniGen/model.py +++ b/OmniGen/model.py @@ -10,6 +10,7 @@ from typing import Dict from diffusers.loaders import PeftAdapterMixin +from diffusers.utils import logging from timm.models.vision_transformer import PatchEmbed, Attention, Mlp from huggingface_hub import snapshot_download from safetensors.torch import load_file @@ -19,6 +20,9 @@ from OmniGen.transformer import Phi3Config, Phi3Transformer from OmniGen.utils import quantize_bnb + +logger = logging.get_logger(__name__) + def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) @@ -198,7 +202,8 @@ def __init__( @classmethod def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True,): model_path = Path(model_name) - + config_loc = model_name # these only diverge when model_name is *.safetensors or *.pt file + if model_path.exists(): if model_path.is_dir(): if (weights_loc := list(model_path.glob('*.safetensors'))): @@ -207,7 +212,9 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch model_path = weights_loc[0] else: raise FileNotFoundError(f'No .safetensors or .pt model weights found in {model_path.as_posix()!r}') - + else: + logger.info("Loading model weights from file. Using default config from 'Shitao/OmniGen-v1'.") + config_loc = "Shitao/OmniGen-v1" else: cache_folder = os.getenv('HF_HUB_CACHE') model_path = snapshot_download(repo_id=model_name, cache_dir=cache_folder, @@ -219,7 +226,7 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch ckpt = (load_file(model_path, 'cpu') if model_path.suffix == '.safetensors' else torch.load(model_path, map_location='cpu')) - config = Phi3Config.from_pretrained(model_name) + config = Phi3Config.from_pretrained(config_loc) if hasattr(config, 'quantization_config'): if quantization_config is not None: From b066b34963ec6d5e6258c614faed46036fbaa931 Mon Sep 17 00:00:00 2001 From: Rypo Date: Thu, 5 Dec 2024 12:47:37 -0600 Subject: [PATCH 09/13] refactor: use to HFQuantizer preprocessing in place of manually preprocessing, skip quant norm layers --- OmniGen/utils.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/OmniGen/utils.py b/OmniGen/utils.py index 946edd2..f26f896 100644 --- a/OmniGen/utils.py +++ b/OmniGen/utils.py @@ -7,7 +7,7 @@ from transformers import BitsAndBytesConfig from transformers.quantizers import AutoHfQuantizer -from transformers.integrations import replace_with_bnb_linear, set_module_quantized_tensor_to_device +from transformers.integrations import replace_with_bnb_linear, set_module_quantized_tensor_to_device, get_keys_to_not_convert def create_logger(logging_dir): """ @@ -116,27 +116,40 @@ def vae_encode_list(vae, x, weight_dtype): @torch.no_grad() -def quantize_bnb(meta_model, state_dict:dict, quantization_config:BitsAndBytesConfig, pre_quantized=None): - # from transformers.integrations import get_keys_to_not_convert +def quantize_bnb(meta_model, state_dict:dict, quantization_config:BitsAndBytesConfig, pre_quantized=None, dtype=None): if pre_quantized is None: if quantization_config.load_in_4bit: pre_quantized = any('bitsandbytes__' in k for k in state_dict) elif quantization_config.load_in_8bit: pre_quantized = any('weight_format' in k for k in state_dict) + if quantization_config.llm_int8_skip_modules is None: + quantization_config.llm_int8_skip_modules = get_keys_to_not_convert(meta_model.llm) # ['norm'] + quantizer = AutoHfQuantizer.from_config(quantization_config, pre_quantized=pre_quantized) - no_convert = [] #get_keys_to_not_convert(meta_model.llm) # might be worth investigating - model = replace_with_bnb_linear(meta_model, modules_to_not_convert=no_convert, quantization_config=quantizer.quantization_config) + meta_model.eval() + meta_model.requires_grad_(False) + + model = meta_model - # iterate the model keys, otherwise quantized state dict will throws errors - for param_name in model.state_dict(): + quantizer.preprocess_model(model, device_map=None,) + + # iterate the model keys, otherwise quantized state dict will throws errors + for param_name in model.state_dict(): param = state_dict.get(param_name) + if not pre_quantized: + param = param.to(dtype) + if not quantizer.check_quantized_param(model, param, param_name, state_dict): set_module_quantized_tensor_to_device(model, param_name, device=0, value=param) else: quantizer.create_quantized_param(model, param, param_name, target_device=0, state_dict=state_dict) + del state_dict[param_name], param + + model = quantizer.postprocess_model(model) + del state_dict torch.cuda.empty_cache() gc.collect() From 6d5b4bcfbc8a48ebf979c6af33bba5a92bad8884 Mon Sep 17 00:00:00 2001 From: Rypo Date: Thu, 5 Dec 2024 12:58:16 -0600 Subject: [PATCH 10/13] fix: prevent device/dtype changes to quantized models --- OmniGen/model.py | 13 +++++++------ OmniGen/pipeline.py | 32 +++++++++++++++++++++----------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/OmniGen/model.py b/OmniGen/model.py index e94fc7a..cca1f53 100644 --- a/OmniGen/model.py +++ b/OmniGen/model.py @@ -196,8 +196,9 @@ def __init__( self.llm = Phi3Transformer(config=transformer_config) self.llm.config.use_cache = False - # bnb 4bit quantized models cannot be offloaded - self.offloadable = True + # bnb quantized models cannot easily be offloaded or recast + self.quantized = False + self.dtype = None @classmethod def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True,): @@ -244,9 +245,8 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch model = cls(config) if quantization_config: - model = quantize_bnb(model, ckpt, quantization_config=quantization_config) - if getattr(quantization_config, 'load_in_4bit', None): - model.offloadable = False + model = quantize_bnb(model, ckpt, quantization_config=quantization_config, dtype=dtype) + model.quantized = True model.config.quantization_config = quantization_config else: model.load_state_dict(ckpt, assign=True) @@ -257,7 +257,8 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch model = cls(config) model.load_state_dict(ckpt) - model = model.to(dtype) + model.dtype = dtype + del ckpt torch.cuda.empty_cache() gc.collect() diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index e4c3c5a..f403c91 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -113,6 +113,7 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, quantizati model = OmniGen.from_pretrained(pretrained_path, dtype=torch.bfloat16, quantization_config=quantization_config, low_cpu_mem_usage=low_cpu_mem_usage) + model = model.requires_grad_(False).eval() if processor is None: processor = OmniGenProcessor.from_pretrained(model_name) @@ -125,7 +126,7 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, quantizati logger.info(f"No VAE found in {model_name}, downloading stabilityai/sdxl-vae from HF") vae_path = "stabilityai/sdxl-vae" - vae = AutoencoderKL.from_pretrained(vae_path) + vae = AutoencoderKL.from_pretrained(vae_path, low_cpu_mem_usage=low_cpu_mem_usage) return cls(vae, model, processor, device) @@ -158,7 +159,7 @@ def move_to_device(self, data): def enable_model_cpu_offload(self): self.model_cpu_offload = True - if self.model.offloadable: + if not self.model.quantized: self.model.to("cpu") self.vae.to("cpu") torch.cuda.empty_cache() # Clear VRAM @@ -248,8 +249,13 @@ def __call__( # set model and processor if max_input_image_size != self.processor.max_image_size: self.processor = OmniGenProcessor(self.processor.text_tokenizer, max_image_size=max_input_image_size) - self.model.to(dtype) + + if not self.model.quantized: + self.model.dtype = dtype + self.model.to(dtype) + #self.vae.to(dtype) # Uncomment this line to allow bfloat16 VAE + if offload_model: self.enable_model_cpu_offload() else: @@ -271,7 +277,7 @@ def __call__( else: generator = None latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device, generator=generator) - latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype) + latents = torch.cat([latents]*(1+num_cfg), 0).to(self.model.dtype) if input_images is not None and self.model_cpu_offload: self.vae.to(self.device) input_img_latents = [] @@ -279,12 +285,12 @@ def __call__( for temp_pixel_values in input_data['input_pixel_values']: temp_input_latents = [] for img in temp_pixel_values: - img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), dtype) + img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), self.model.dtype) temp_input_latents.append(img) input_img_latents.append(temp_input_latents) else: for img in input_data['input_pixel_values']: - img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), dtype) + img = self.vae_encode(img.to(self.vae.device, self.vae.dtype), self.model.dtype) input_img_latents.append(img) if input_images is not None and self.model_cpu_offload: self.vae.to('cpu') @@ -300,7 +306,7 @@ def __call__( img_cfg_scale=img_guidance_scale, use_img_cfg=use_img_guidance, use_kv_cache=use_kv_cache, - offload_model=offload_model, + offload_model=(offload_model and not self.model.quantized), ) if separate_cfg_infer: @@ -308,14 +314,16 @@ def __call__( else: func = self.model.forward_with_cfg - if self.model_cpu_offload and self.model.offloadable: + if self.model_cpu_offload and not self.model.quantized: for name, param in self.model.named_parameters(): if 'layers' in name and 'layers.0' not in name: - param.data = param.data.cpu() + param.data = param.data.to('cpu') else: param.data = param.data.to(self.device) for buffer_name, buffer in self.model.named_buffers(): setattr(self.model, buffer_name, buffer.to(self.device)) + torch.cuda.empty_cache() + gc.collect() # else: # self.model.to(self.device) @@ -323,8 +331,10 @@ def __call__( samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache) samples = samples.chunk((1+num_cfg), dim=0)[0] - if self.model_cpu_offload and self.model.offloadable: - self.model.to('cpu') + if self.model_cpu_offload: + if not self.model.quantized: + self.model.to("cpu") + torch.cuda.empty_cache() gc.collect() From ff20f1eb36847b1e37a7f713631cae6a9c875aa4 Mon Sep 17 00:00:00 2001 From: Rypo Date: Thu, 5 Dec 2024 13:06:40 -0600 Subject: [PATCH 11/13] refactor!: remove quantization_config argument from OmniGenPipeline.from_pretrained Removes quantization_config from main pipeline. Instead, use Diffusers style syntax where the config is passed to the transformer (model) when is then passed to the pipeline. --- OmniGen/pipeline.py | 17 +++-------------- app.py | 26 +++++++++++++------------- 2 files changed, 16 insertions(+), 27 deletions(-) diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index f403c91..0ef87de 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -78,7 +78,7 @@ def __init__( self.model_cpu_offload = False @classmethod - def from_pretrained(cls, model_name, vae_path: str=None, device=None, quantization_config:Literal['bnb_4bit','bnb_8bit']|BitsAndBytesConfig=None, low_cpu_mem_usage=True, **kwargs): + def from_pretrained(cls, model_name, vae_path: str=None, device=None, low_cpu_mem_usage=True, **kwargs): pretrained_path = Path(model_name) # XXX: Consider renaming 'model' to 'transformer' conform to diffusers pipeline syntax @@ -98,20 +98,9 @@ def from_pretrained(cls, model_name, vae_path: str=None, device=None, quantizati cache_folder = os.getenv('HF_HUB_CACHE') pretrained_path = Path(snapshot_download(repo_id=model_name, cache_dir=cache_folder, ignore_patterns=ignore_patterns)) logger.info(f"Downloaded model to {pretrained_path}") - - _quant_alias = { - 'bnb_4bit': BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float32, bnb_4bit_quant_type='nf4', bnb_4bit_use_double_quant=False), - 'bnb_8bit': BitsAndBytesConfig(load_in_8bit=True), - } - if model is None: - if isinstance(quantization_config, str): - try: - quantization_config = _quant_alias[quantization_config] - except KeyError: - raise NotImplementedError(f'Unknown `quantization_config` {quantization_config!r}') - - model = OmniGen.from_pretrained(pretrained_path, dtype=torch.bfloat16, quantization_config=quantization_config, low_cpu_mem_usage=low_cpu_mem_usage) + if model is None: + model = OmniGen.from_pretrained(pretrained_path, dtype=torch.bfloat16, quantization_config=None, low_cpu_mem_usage=low_cpu_mem_usage) model = model.requires_grad_(False).eval() diff --git a/app.py b/app.py index 87ee5c9..3dbe96b 100644 --- a/app.py +++ b/app.py @@ -4,9 +4,9 @@ import argparse import random import spaces +from transformers import BitsAndBytesConfig -from OmniGen import OmniGenPipeline - +from OmniGen import OmniGenPipeline, OmniGen @spaces.GPU(duration=180) def generate_image(text, img1, img2, img3, height, width, guidance_scale, img_guidance_scale, inference_steps, seed, separate_cfg_infer, offload_model, @@ -425,19 +425,19 @@ def run_for_examples(text, img1, img2, img3, height, width, guidance_scale, img_ parser.add_argument('--share', action='store_true', help='Share the Gradio app') parser.add_argument('-b', '--nbits', choices=['4','8'], help='bitsandbytes quantization n-bits') args = parser.parse_args() + + quantization_config = None + model = None + if args.nbits: + if args.nbits == '4': + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4') + elif args.nbits == '8': + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + + model = OmniGen.from_pretrained("Shitao/OmniGen-v1", quantization_config=quantization_config) - if args.nbits == '4': - quantization_config = 'bnb_4bit' - elif args.nbits == '8': - quantization_config = 'bnb_8bit' - else: - quantization_config = None - pipe = OmniGenPipeline.from_pretrained( - "Shitao/OmniGen-v1", - quantization_config = quantization_config, - low_cpu_mem_usage=True, - ) + pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1", model=model) # launch demo.launch(share=args.share) From d75af76e67eab373543d881a01b2774c3f4138b3 Mon Sep 17 00:00:00 2001 From: Rypo Date: Thu, 12 Dec 2024 14:16:59 -0600 Subject: [PATCH 12/13] fix: prevent float16 numerical overflow Adds a small utility to scheduler to find the minimum clip bound to prevent NaNs from popping out of the decoder layers. Search over hardcoded buffer to discard as little information as possible. Phi3Transformer now raises OverflowError when NaNs encountered. Initialize model dtype based on actual weight value to avoid bad casts when quantized. --- OmniGen/model.py | 9 +++++++-- OmniGen/pipeline.py | 3 +++ OmniGen/scheduler.py | 27 ++++++++++++++++++++++++++- OmniGen/transformer.py | 10 +++++++++- OmniGen/utils.py | 2 +- 5 files changed, 46 insertions(+), 5 deletions(-) diff --git a/OmniGen/model.py b/OmniGen/model.py index 7a36b6c..ee07299 100644 --- a/OmniGen/model.py +++ b/OmniGen/model.py @@ -201,7 +201,7 @@ def __init__( self.dtype = None @classmethod - def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch.bfloat16, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True,): + def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = None, quantization_config: BitsAndBytesConfig = None, low_cpu_mem_usage: bool = True,): model_path = Path(model_name) config_loc = model_name # these only diverge when model_name is *.safetensors or *.pt file @@ -228,6 +228,9 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch torch.load(model_path, map_location='cpu')) config = Phi3Config.from_pretrained(config_loc) + # avoid inadvertently leaving the weights as float32 + if dtype is None: + dtype = config.torch_dtype if hasattr(config, 'quantization_config'): if quantization_config is not None: @@ -257,7 +260,9 @@ def from_pretrained(cls, model_name: str|os.PathLike, dtype: torch.dtype = torch model = cls(config) model.load_state_dict(ckpt) - model.dtype = dtype + + # determine dtype via x_emb bias since as a Conv2d bias, it should never be quantized + model.dtype = model.x_embedder.proj.bias.dtype del ckpt torch.cuda.empty_cache() diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index 0ef87de..baaf22a 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -317,6 +317,9 @@ def __call__( # self.model.to(self.device) scheduler = OmniGenScheduler(num_steps=num_inference_steps) + if latents.dtype == torch.float16: + # fp16 overflows at ±2^16-32, but the actual clamp value may have to be lower to maintain decoder layer stability + scheduler._fp16_clip_autoset(self.model.llm, latents, func, model_kwargs) samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache) samples = samples.chunk((1+num_cfg), dim=0)[0] diff --git a/OmniGen/scheduler.py b/OmniGen/scheduler.py index ffa99cd..0c2912b 100644 --- a/OmniGen/scheduler.py +++ b/OmniGen/scheduler.py @@ -1,10 +1,13 @@ +import copy from tqdm import tqdm from typing import Optional, Dict, Any, Tuple, List import gc import torch from transformers.cache_utils import Cache, DynamicCache, OffloadedCache +from diffusers.utils import logging +logger = logging.get_logger(__name__) class OmniGenCache(DynamicCache): @@ -121,8 +124,30 @@ def __init__(self, num_steps: int=50, time_shifting_factor: int=1): t = torch.linspace(0, 1, num_steps+1) t = t / (t + time_shifting_factor - time_shifting_factor * t) self.sigma = t - + @torch.no_grad() + def _fp16_clip_autoset(self, model_llm, z, func, model_kwargs): + '''Recursively search for a minimal clipping value for fp16 stability''' + timesteps = torch.full(size=(len(z), ), fill_value=self.sigma[0], device=z.device) + _nan_expon = model_kwargs.pop('_nan_expon', None) + if _nan_expon is not None: + clip_val = 2**16 - 2**_nan_expon # fp16 overflows after ±2^16-32 + model_llm.set_clip_val(clip_val) + + try: + _model_kwargs = copy.deepcopy(model_kwargs) + _model_kwargs['use_kv_cache']=False # no cache while searching + _, _ = func(z.clone(), timesteps, past_key_values=None, **_model_kwargs) + except OverflowError: + if _nan_expon is None: + logger.info('FP16 overflow, searching for clamp bounds...') + _nan_expon = 5 # start at 2**5 + + if _nan_expon < 15: # stop at 2**15 + model_kwargs['_nan_expon'] = _nan_expon+1 + return self._fp16_clip_autoset(model_llm, z, func, model_kwargs) + raise OverflowError('Numerical overflow, unable to find suitable clipping bounds.') + def crop_kv_cache(self, past_key_values, num_tokens_for_img): # return crop_past_key_values = () diff --git a/OmniGen/transformer.py b/OmniGen/transformer.py index 1bafb01..e843758 100644 --- a/OmniGen/transformer.py +++ b/OmniGen/transformer.py @@ -29,6 +29,11 @@ class Phi3Transformer(Phi3Model): Args: config: Phi3Config """ + _clip_val: float = None # fp16: ~ (2**16 - 2**7) + + def set_clip_val(self, clip_val:float=None): + self._clip_val = abs(clip_val) + def prefetch_layer(self, layer_idx: int, device: torch.device): "Starts prefetching the next layer cache" with torch.cuda.stream(self.prefetch_stream): @@ -137,6 +142,8 @@ def forward( for decoder_layer in self.layers: layer_idx += 1 + if self._clip_val is not None: + hidden_states.clamp_(-self._clip_val, self._clip_val) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -173,7 +180,8 @@ def forward( all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) - + if hidden_states.isnan().any(): + raise OverflowError('Numerical Overflow: hidden states NaNs') # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/OmniGen/utils.py b/OmniGen/utils.py index f26f896..e6b2349 100644 --- a/OmniGen/utils.py +++ b/OmniGen/utils.py @@ -137,7 +137,7 @@ def quantize_bnb(meta_model, state_dict:dict, quantization_config:BitsAndBytesCo # iterate the model keys, otherwise quantized state dict will throws errors for param_name in model.state_dict(): - param = state_dict.get(param_name) + param = state_dict[param_name] if not pre_quantized: param = param.to(dtype) From 6ce30f2c7fcf95dd6514b5e7062b27b2ac644d47 Mon Sep 17 00:00:00 2001 From: Rypo Date: Mon, 16 Dec 2024 19:11:41 -0600 Subject: [PATCH 13/13] perf(fp16): reduce expected extra required iterations to 1. Start search with minimal clipping value found through testing (2^16 - 3*32). This value was sufficient for all tested inputs. Further analysis still required to guarantee that it will always be sufficient in all cases. --- OmniGen/pipeline.py | 4 +++- OmniGen/scheduler.py | 31 ++++++++++++++++++++++--------- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/OmniGen/pipeline.py b/OmniGen/pipeline.py index baaf22a..950ea33 100644 --- a/OmniGen/pipeline.py +++ b/OmniGen/pipeline.py @@ -318,7 +318,9 @@ def __call__( scheduler = OmniGenScheduler(num_steps=num_inference_steps) if latents.dtype == torch.float16: - # fp16 overflows at ±2^16-32, but the actual clamp value may have to be lower to maintain decoder layer stability + # Continue to monitor. If _clip_val never changes, can remove scheduler autoset func and just hardcode clip val here. + #self.model.llm.set_clip_val(2**16-32 - 2*32) # hardcode clip val + # dry run the inputs, adjusting the clip bounds as necessary scheduler._fp16_clip_autoset(self.model.llm, latents, func, model_kwargs) samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache, offload_kv_cache=offload_kv_cache) samples = samples.chunk((1+num_cfg), dim=0)[0] diff --git a/OmniGen/scheduler.py b/OmniGen/scheduler.py index 0c2912b..cf294f1 100644 --- a/OmniGen/scheduler.py +++ b/OmniGen/scheduler.py @@ -128,23 +128,36 @@ def __init__(self, num_steps: int=50, time_shifting_factor: int=1): @torch.no_grad() def _fp16_clip_autoset(self, model_llm, z, func, model_kwargs): '''Recursively search for a minimal clipping value for fp16 stability''' + fp16_max_repr = torch.finfo(torch.float16).max # fp16 max representable: ±2^16-32 timesteps = torch.full(size=(len(z), ), fill_value=self.sigma[0], device=z.device) - _nan_expon = model_kwargs.pop('_nan_expon', None) - if _nan_expon is not None: - clip_val = 2**16 - 2**_nan_expon # fp16 overflows after ±2^16-32 - model_llm.set_clip_val(clip_val) + _buff_expon = model_kwargs.pop('_buff_expon', None) # temp local recursion var + + if _buff_expon is None: + # fp16 overflows at ±2^16-16 with largest repr being ±2^16-32. repr vals occur at intervals of 32 for nums > 2^15. + # Prelim tests show an additional buffer of at least 2 repr values is needed for stability; why is presently unclear. + # If this continues to hold true, this function can be deleted and replaced with 1 line in pipeline. + clip_val = fp16_max_repr - 2*32 # = 2**6 = (-2,+2 buffer vals) + if model_llm._clip_val is None or model_llm._clip_val > clip_val: + model_llm.set_clip_val(clip_val) + logger.debug(f'set initial clamp: (+-){clip_val} ...') + else: + clip_val = fp16_max_repr - 2**_buff_expon + model_llm.set_clip_val(clip_val) # clamp (-clip_val, +clip_val) try: _model_kwargs = copy.deepcopy(model_kwargs) _model_kwargs['use_kv_cache']=False # no cache while searching _, _ = func(z.clone(), timesteps, past_key_values=None, **_model_kwargs) except OverflowError: - if _nan_expon is None: + if _buff_expon is None: + _buff_expon = 6 # start at 2**(6 + 1) (-4,+4 buffer vals) logger.info('FP16 overflow, searching for clamp bounds...') - _nan_expon = 5 # start at 2**5 - - if _nan_expon < 15: # stop at 2**15 - model_kwargs['_nan_expon'] = _nan_expon+1 + + if _buff_expon < 15: # stop at 2**15 (-1024,+1024 buffer vals) + _buff_expon += 1 + # each iter, double the representable value buffer capacity for both min and max + model_kwargs['_buff_expon'] = _buff_expon + logger.debug(f'trying clamp: (+-){fp16_max_repr - 2**(_buff_expon)} ...') return self._fp16_clip_autoset(model_llm, z, func, model_kwargs) raise OverflowError('Numerical overflow, unable to find suitable clipping bounds.')