Add support for fp8_e4m3fn model #80

Open · wants to merge 6 commits into base: main
19 changes: 19 additions & 0 deletions README.md
@@ -135,6 +135,14 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model Tr
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

You can also pass the `--fp8` option to enable FP8 precision and reduce memory usage. Make sure to download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-T2V-1.3B` folder.
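
For example, the 1.3B text-to-video command above can be run in FP8 by adding `--fp8` (an illustrative command; all other flags are unchanged from the example above):

```
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --fp8 --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```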

Additionally, an [FP8 version of the T5 text encoder](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use it, update the model's configuration file (e.g. `wan/configs/wan_t2v_1_3B.py`):

```
t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
```
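
In this PR, the loader keys off that filename to select the FP8 path for the T5 encoder; the check added to `wan/image2video.py` (shown in the diff below) is:

```
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
    quantization = "fp8_e4m3fn"
else:
    quantization = "disabled"
```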

> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 depending on performance.


@@ -222,6 +230,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.

For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:

1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
2. Then, execute the following command:

```
python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
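
As a rough illustration of that mapping (a sketch, assuming the output simply preserves the input aspect ratio, matches the requested pixel area, and snaps to a multiple of 16; the helper below is not part of the codebase):

```
import math

def fit_to_area(img_h, img_w, area, multiple=16):
    """Pick output (height, width) with roughly `area` pixels and the input's aspect ratio."""
    aspect = img_h / img_w
    h = round(math.sqrt(area * aspect) / multiple) * multiple
    w = round(math.sqrt(area / aspect) / multiple) * multiple
    return h, w

# A 4:3 landscape input with --size 832*480 lands near 736x544.
print(fit_to_area(768, 1024, 832 * 480))  # (544, 736)
```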


7 changes: 7 additions & 0 deletions generate.py
@@ -126,6 +126,11 @@ def _parse_args():
action="store_true",
default=False,
help="Whether to use FSDP for DiT.")
parser.add_argument(
"--fp8",
action="store_true",
default=False,
help="Whether to use fp8.")
parser.add_argument(
"--save_file",
type=str,
@@ -306,6 +311,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
fp8=args.fp8,
)

logging.info(
@@ -363,6 +369,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
fp8=args.fp8,
)

logging.info("Generating video ...")
2 changes: 2 additions & 0 deletions wan/configs/wan_i2v_14B.py
@@ -10,12 +10,14 @@
i2v_14B.update(wan_shared_cfg)

i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
i2v_14B.t5_tokenizer = 'google/umt5-xxl'

# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
# i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
i2v_14B.clip_tokenizer = 'xlm-roberta-large'

# vae
1 change: 1 addition & 0 deletions wan/configs/wan_t2v_14B.py
@@ -10,6 +10,7 @@

# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae
1 change: 1 addition & 0 deletions wan/configs/wan_t2v_1_3B.py
@@ -10,6 +10,7 @@

# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'

# vae
73 changes: 66 additions & 7 deletions wan/image2video.py
@@ -16,6 +16,10 @@
import torchvision.transforms.functional as TF
from tqdm import tqdm

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
@@ -39,6 +43,7 @@ def __init__(
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
fp8=False,
):
r"""
Initializes the image-to-video generation model components.
@@ -62,6 +67,8 @@
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
fp8 (`bool`, *optional*, defaults to False):
Enable 8-bit floating point precision for model parameters.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
@@ -73,17 +80,23 @@
self.param_dtype = config.param_dtype

shard_fn = partial(shard_model, device_id=device_id)
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
quantization = "fp8_e4m3fn"
else:
quantization = "disabled"
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
quantization=quantization,
)

self.vae_stride = config.vae_stride
self.patch_size = config.patch_size

self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
@@ -96,7 +109,46 @@
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
if not fp8:
self.model = WanModel.from_pretrained(checkpoint_dir)
else:
if '480P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
elif '720P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1]
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12
num_layers = 40 if dim == 5120 else 30
TRANSFORMER_CONFIG= {
"dim": dim,
"ffn_dim": ffn_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": 16,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
}

with init_empty_weights():
self.model = WanModel(**TRANSFORMER_CONFIG)

base_dtype=torch.bfloat16
dtype=torch.float8_e4m3fn
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
for name, param in self.model.named_parameters():
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
# dtype_to_use = torch.bfloat16
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])

del state_dict

self.model.eval().requires_grad_(False)

if t5_fsdp or dit_fsdp or use_usp:
@@ -219,13 +271,15 @@ def generate(self,
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]

@@ -242,9 +296,12 @@
torch.zeros(3, 80, h, w)
],
dim=1).to(self.device)
])[0]
],device=self.device)[0]
y = torch.concat([msk, y])

if offload_model:
self.vae.model.cpu()

@contextmanager
def noop_no_sync():
yield
@@ -332,9 +389,11 @@ def noop_no_sync():
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
# load vae model back to device
self.vae.model.to(self.device)

if self.rank == 0:
videos = self.vae.decode(x0)
videos = self.vae.decode(x0, device=self.device)

del noise, latent
del sample_scheduler
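
The FP8 branch added to `wan/image2video.py` above uses accelerate's meta-device pattern: the transformer is instantiated without allocating weights, and each parameter is then materialized from the safetensors state dict with a per-parameter dtype, keeping precision-sensitive tensors (norms, biases, embeddings, time/patch projections) in bf16 and casting the rest to `float8_e4m3fn`. A minimal standalone sketch of the same pattern (the `TinyNet` module and checkpoint filename are placeholders, not part of this PR):

```
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)
        self.norm = nn.LayerNorm(64)

# 1. Build the module on the meta device: no weight memory is allocated yet.
with init_empty_weights():
    model = TinyNet()

# 2. Load the pre-quantized checkpoint onto CPU.
state_dict = load_file("tinynet_fp8_e4m3fn.safetensors", device="cpu")

# 3. Materialize each parameter, keeping precision-sensitive tensors in bf16
#    and casting the large weights to float8_e4m3fn.
params_to_keep = {"norm", "bias"}
for name, _ in model.named_parameters():
    keep = any(k in name for k in params_to_keep)
    dtype = torch.bfloat16 if keep else torch.float8_e4m3fn
    set_module_tensor_to_device(model, name, device="cpu",
                                dtype=dtype, value=state_dict[name])
```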
10 changes: 8 additions & 2 deletions wan/modules/clip.py
@@ -7,6 +7,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from safetensors.torch import load_file

from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
@@ -515,8 +516,13 @@ def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict(
torch.load(checkpoint_path, map_location='cpu'))
if checkpoint_path.endswith('.safetensors'):
state_dict = load_file(checkpoint_path, device='cpu')
self.model.load_state_dict(state_dict)
elif checkpoint_path.endswith('.pth'):
self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
else:
raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')

# init tokenizer
self.tokenizer = HuggingfaceTokenizer(
38 changes: 30 additions & 8 deletions wan/modules/t5.py
@@ -9,6 +9,10 @@

from .tokenizers import HuggingfaceTokenizer

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file

__all__ = [
'T5Model',
'T5Encoder',
@@ -442,7 +446,7 @@ def _t5(name,
model = model_cls(**kwargs)

# set device
model = model.to(dtype=dtype, device=device)
# model = model.to(dtype=dtype, device=device)

# init tokenizer
if return_tokenizer:
@@ -479,21 +483,39 @@ def __init__(
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
quantization="disabled",
):
self.text_len = text_len
self.dtype = dtype
self.device = device
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path

# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)

logging.info(f'loading {checkpoint_path}')
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
if quantization == "disabled":
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
elif quantization == "fp8_e4m3fn":
with init_empty_weights():
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
cast_dtype = torch.float8_e4m3fn
state_dict = load_file(checkpoint_path, device="cpu")
params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
for name, param in model.named_parameters():
dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
del state_dict

self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)
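
The `params_to_keep` selection above mirrors common FP8 practice: `float8_e4m3fn` has only 3 mantissa bits and tops out around 448, so small, sensitive tensors (normalization scales, positional and token embeddings) stay in the configured dtype while the large linear weights take the FP8 memory savings. A quick way to see how coarse the format is (assumes a PyTorch build with float8 support):

```
import torch

info = torch.finfo(torch.float8_e4m3fn)
print(info.max, info.eps)  # 448.0 0.125 -- versus ~3.39e38 and ~0.0078 for bfloat16
```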
4 changes: 2 additions & 2 deletions wan/modules/vae.py
@@ -644,7 +644,7 @@ def __init__(self,
z_dim=z_dim,
).eval().requires_grad_(False).to(device)

def encode(self, videos):
def encode(self, videos, device=None):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
@@ -654,7 +654,7 @@ def encode(self, videos):
for u in videos
]

def decode(self, zs):
def decode(self, zs, device=None):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),