From 6f148c2fade4a059909ba2ff19739463a18770b7 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 16 Jan 2025 22:53:15 -0800 Subject: [PATCH 01/20] Create the Flux Pipeline and tests --- .../sharktank/layers/configs/llm_configs.py | 8 +- sharktank/sharktank/models/vae/model.py | 9 - .../sharktank/pipelines/flux/__init__.py | 10 + sharktank/sharktank/pipelines/flux/export.py | 105 +++++ .../sharktank/pipelines/flux/flux_pipeline.py | 438 ++++++++++++++++++ .../sharktank/tools/import_hf_dataset.py | 7 +- .../pipelines/flux/flux_pipeline_test.py | 294 ++++++++++++ 7 files changed, 854 insertions(+), 17 deletions(-) create mode 100644 sharktank/sharktank/pipelines/flux/__init__.py create mode 100644 sharktank/sharktank/pipelines/flux/export.py create mode 100644 sharktank/sharktank/pipelines/flux/flux_pipeline.py create mode 100644 sharktank/tests/pipelines/flux/flux_pipeline_test.py diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 6cf79402e..55b97fd8e 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -225,8 +225,8 @@ def __post_init__(self): def from_gguf_properties(properties: dict[str, Any], **kwargs): assert properties["general.architecture"] == "t5" assert ( - properties["t5.attention.layer_norm_epsilon"] - == properties["t5.attention.layer_norm_rms_epsilon"] + properties["t5.attention.layer_norm_epsilon"] + == properties["t5.attention.layer_norm_rms_epsilon"] ) all_kwargs = {"vocab_size": None, "feed_forward_proj": None} @@ -293,7 +293,7 @@ class ClipTextConfig: @staticmethod def from_hugging_face_clip_text_model_config( - config: "transformers.CLIPTextConfig", + config: "transformers.CLIPTextConfig", # type: ignore ) -> "ClipTextConfig": return ClipTextConfig( vocab_size=config.vocab_size, @@ -314,7 +314,7 @@ def from_hugging_face_clip_text_model_config( dtype=config.torch_dtype or torch.float32, ) - def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig": + def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig": # type: ignore kwargs = self.to_properties() kwargs["torch_dtype"] = kwargs["dtype"] del kwargs["dtype"] diff --git a/sharktank/sharktank/models/vae/model.py b/sharktank/sharktank/models/vae/model.py index d689aaf72..c2746e800 100644 --- a/sharktank/sharktank/models/vae/model.py +++ b/sharktank/sharktank/models/vae/model.py @@ -74,15 +74,6 @@ def forward( "latent_embeds": latent_embeds, }, ) - if not self.hp.use_post_quant_conv: - sample = rearrange( - sample, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=math.ceil(1024 / 16), - w=math.ceil(1024 / 16), - ph=2, - pw=2, - ) sample = sample / self.hp.scaling_factor + self.hp.shift_factor if self.hp.use_post_quant_conv: diff --git a/sharktank/sharktank/pipelines/flux/__init__.py b/sharktank/sharktank/pipelines/flux/__init__.py new file mode 100644 index 000000000..0c3b3f4cb --- /dev/null +++ b/sharktank/sharktank/pipelines/flux/__init__.py @@ -0,0 +1,10 @@ +"""Flux text-to-image generation pipeline.""" + +from .flux_pipeline import FluxPipeline +from .export import export_flux_pipeline_mlir #, export_flux_pipeline_iree_parameters + +__all__ = [ + "FluxPipeline", + "export_flux_pipeline_mlir", + #"export_flux_pipeline_iree_parameters", +] \ No newline at end of file diff --git a/sharktank/sharktank/pipelines/flux/export.py b/sharktank/sharktank/pipelines/flux/export.py new file mode 100644 index 000000000..dfd4c1b97 --- /dev/null +++ 
b/sharktank/sharktank/pipelines/flux/export.py @@ -0,0 +1,105 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Export utilities for Flux text-to-image pipeline.""" +#TODO: DO NOT SUBMIT: FIX AND TEST THIS FILE +import functools +from typing import Optional, Union +from pathlib import Path +import torch +from copy import copy + +from .flux_pipeline import FluxPipeline +from ...types import Dataset +from ...transforms.dataset import set_float_dtype +from iree.turbine.aot import FxProgramsBuilder, export + +__all__ = [ + "export_flux_pipeline_mlir", + "export_flux_pipeline_iree_parameters", +] + +def export_flux_pipeline_mlir( + model: Union[FluxPipeline, Path, str], + batch_sizes: list[int], + mlir_output_path: str, + dtype: torch.dtype, +): + """Export Flux pipeline to MLIR format. + + Args: + model: Either the FluxPipeline instance or path to model files + batch_sizes: List of batch sizes to export for + mlir_output_path: Output path for MLIR file + """ + if isinstance(model, (Path, str)): + model = FluxPipeline( + t5_path=str(Path(model) / "text_encoder_2/model.gguf"), + clip_path=str(Path(model) / "text_encoder/model.irpa"), + transformer_path=str(Path(model) / "transformer/model.irpa"), + ae_path=str(Path(model) / "vae/model.irpa"), + dtype=dtype, + ) + + fxb = FxProgramsBuilder(model) + + for batch_size in batch_sizes: + # Create sample inputs with default dimensions + t5_prompt_ids = torch.zeros((batch_size, 128), dtype=torch.long) + clip_prompt_ids = torch.zeros((batch_size, 77), dtype=torch.long) + latents = model._get_noise( + 1, + 1024, + 1024, + seed=12345, + ) + + @fxb.export_program( + name=f"forward_bs{batch_size}", + args=(t5_prompt_ids, clip_prompt_ids, latents), + dynamic_shapes={}, + strict=False, + ) + def _(model, t5_prompt_ids, clip_prompt_ids, latents): + return model.forward( + t5_prompt_ids=t5_prompt_ids, + clip_prompt_ids=clip_prompt_ids, + latents=latents, + ) + + try: + output = export(fxb) + except Exception as e: + print(f"Error during export: {e}") + print(f"Model dtype: {model.dtype}") + print(f"Latents dtype: {latents.dtype}") + raise + output.save_mlir(mlir_output_path) + +# def export_flux_pipeline_iree_parameters( +# model_path_or_dataset: str | Dataset, +# output_path: str, +# dtype: Optional[torch.dtype] = None, +# ): +# """Export Flux pipeline parameters to IREE format. 
+ +# Args: +# model_path_or_dataset: Path to model files or Dataset instance +# output_path: Output path for IREE parameters +# dtype: Optional dtype to convert parameters to +# """ +# # TODO: loop over models +# if isinstance(model_path_or_dataset, Dataset): +# dataset = copy(model_path_or_dataset) +# else: +# dataset = Dataset.load(model_path_or_dataset) + +# if dtype: +# dataset.root_theta = dataset.root_theta.transform( +# functools.partial(set_float_dtype, dtype=dtype) +# ) + +# dataset.save(output_path) \ No newline at end of file diff --git a/sharktank/sharktank/pipelines/flux/flux_pipeline.py b/sharktank/sharktank/pipelines/flux/flux_pipeline.py new file mode 100644 index 000000000..fc02d6582 --- /dev/null +++ b/sharktank/sharktank/pipelines/flux/flux_pipeline.py @@ -0,0 +1,438 @@ +"""Flux text-to-image generation pipeline.""" + +import argparse +import functools +import math +from os import PathLike +from typing import Callable, Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from PIL import Image +from torch import Tensor +from transformers import CLIPTokenizer, T5Tokenizer, CLIPTextModel as HfCLIPTextModel + +from sharktank.models.t5 import T5Config, T5Encoder +from sharktank.models.clip import ClipTextModel, ClipTextConfig +from sharktank.models.flux.flux import FluxModelV1, FluxParams +from sharktank.models.vae.model import VaeDecoderModel +from sharktank.types import Dataset +from sharktank.transforms.dataset import set_float_dtype + +class FluxPipeline(nn.Module): + """Pipeline for text-to-image generation using the Flux model.""" + + def __init__( + self, + t5_path: PathLike, + clip_path: PathLike, + transformer_path: PathLike, + ae_path: PathLike, + t5_tokenizer_path: Optional[PathLike] = None, + clip_tokenizer_path: Optional[PathLike] = None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + dtype: torch.dtype = torch.bfloat16, + ): + """Initialize the Flux pipeline.""" + super().__init__() + self.device = torch.device(device) + self.dtype = dtype + if t5_tokenizer_path: + self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_tokenizer_path) + if clip_tokenizer_path: + self.clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path) + + # Load T5 + t5_dataset = Dataset.load(t5_path) + t5_config = T5Config.from_gguf_properties( + t5_dataset.properties, + feed_forward_proj="gated-gelu", + ) + t5_dataset.root_theta = t5_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) + self.t5_model = T5Encoder(theta=t5_dataset.root_theta, config=t5_config) + self.add_module('t5_model', self.t5_model) + self.t5_model.to(device) + + # Load CLIP + clip_dataset = Dataset.load(clip_path) + # TODO: Refactor CLIP to not make the config rely on HuggingFace + hf_clip_model = HfCLIPTextModel.from_pretrained("/data/flux/FLUX.1-dev/text_encoder/") + clip_config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_clip_model.config) + clip_dataset.root_theta = clip_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) + self.clip_model = ClipTextModel(theta=clip_dataset.root_theta, config=clip_config) + self.add_module('clip_model', self.clip_model) + self.clip_model.to(device) + + # Load Flux Transformer + transformer_dataset = Dataset.load(transformer_path) + transformer_params = FluxParams.from_hugging_face_properties(transformer_dataset.properties) + transformer_dataset.root_theta = transformer_dataset.root_theta.transform( + functools.partial(set_float_dtype, 
dtype=dtype) + ) + self.transformer_model = FluxModelV1( + theta=transformer_dataset.root_theta, + params=transformer_params + ) + self.add_module('transformer_model', self.transformer_model) + self.transformer_model.to(device) + + # Load VAE + ae_dataset = Dataset.load(ae_path) + ae_dataset.root_theta = ae_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) + self.ae_model = VaeDecoderModel.from_dataset(ae_dataset) + self.add_module('ae_model', self.ae_model) + self.ae_model.to(device) + + self._rng = torch.Generator(device="cpu") + + def _get_noise( + self, + num_samples: int, + height: int, + width: int, + seed: Optional[int] = None, + ) -> Tensor: + """Generate initial noise for the diffusion process.""" + if seed is not None: + self._rng.manual_seed(seed) + + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=self.device, + dtype=self.dtype, + generator=self._rng, + ) + + def __call__( + self, + prompt: str, + height: int = 1024, + width: int = 1024, + num_inference_steps: Optional[int] = 50, + guidance_scale: float = 3.5, + seed: Optional[int] = None, + ) -> Tensor: + """Generate images from a prompt + + Args: + prompt: Text prompt for image generation + height: Height of output image + width: Width of output image + num_inference_steps: Number of denoising steps + guidance_scale: Scale for classifier-free guidance + seed: Random seed for reproducibility + + Returns: + Image tensor + + Raises: + ValueError: If tokenizers are not provided + """ + if not self.t5_tokenizer or not self.clip_tokenizer: + raise ValueError("Tokenizers must be provided to use the __call__ method") + + t5_prompt_ids, clip_prompt_ids = self.tokenize_prompt(prompt) + latents = self._get_noise( + 1, + height, + width, + seed=seed, + ) + + with torch.inference_mode(): + return self.forward( + t5_prompt_ids, + clip_prompt_ids, + latents, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) + + + def forward( + self, + t5_prompt_ids: Tensor, + clip_prompt_ids: Tensor, + latents: Tensor, + height: int = 1024, + width: int = 1024, + num_inference_steps: Optional[int] = 1, # TODO: DO NOT SUBMIT + guidance_scale: float = 3.5, + seed: Optional[int] = None, + ) -> Tensor: + # Set default steps # TODO: Check if dev or schnell + if num_inference_steps is None: + num_inference_steps = 50 + + # Adjust dimensions to be multiples of 16 + height = 16 * (height // 16) + width = 16 * (width // 16) + + # Generate initial noise + x = latents + + # Prepare inputs + inp = self._prepare(self.t5_model, self.clip_model, t5_prompt_ids, clip_prompt_ids, x) + timesteps = self._get_schedule(num_inference_steps, inp["img"].shape[1], shift=True) + + # Denoise + x = self._denoise( + **inp, + timesteps=timesteps, + guidance=guidance_scale, + ) + + # Decode latents + x = self._unpack(x.to(dtype=self.dtype), height, width) + x = self.ae_model(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + x = x[0] + x = x.cpu() + x = x.clamp(-1, 1) + x = rearrange(x, "c h w -> h w c") + return x.float() + + + def _prepare(self, t5, clip, t5_prompt_ids, clip_prompt_ids, img: Tensor) -> dict[str, Tensor]: + """Prepare inputs for the transformer model. 
+ + Args: + t5: T5 model for text encoding + clip: CLIP model for text encoding + t5_prompt_ids: Tokenized T5 prompt IDs + clip_prompt_ids: Tokenized CLIP prompt IDs + img: Initial noise tensor + + Returns: + Dictionary containing prepared inputs for the transformer + """ + bs, c, h, w = img.shape + + # Prepare image and position IDs + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + # Get text embeddings + # Process text through T5 + txt = t5(t5_prompt_ids)["last_hidden_state"] + txt_ids = torch.zeros(bs, txt.shape[1], 3, device=img.device) + + # Process text through CLIP + vec = clip(clip_prompt_ids)["pooler_output"] + + # Return prepared inputs + return { + "img": img, + "img_ids": img_ids, + "txt": txt, + "txt_ids": txt_ids, + "vec": vec, + } + + def _time_shift(self, mu: float, sigma: float, t: Tensor) -> Tensor: + """Apply time shift to the schedule.""" + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + def _get_lin_function( + self, + x1: float = 256, + y1: float = 0.5, + x2: float = 4096, + y2: float = 1.15 + ) -> Callable[[float], float]: + """Get linear interpolation function.""" + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + + def _get_schedule( + self, + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, + ) -> list[float]: + """Generate sampling schedule. + + Args: + num_steps: Number of diffusion steps + image_seq_len: Length of the image sequence + base_shift: Base shift value for schedule adjustment + max_shift: Maximum shift value for schedule adjustment + shift: Whether to apply schedule shifting + + Returns: + List of timesteps for the diffusion process + """ + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = self._get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = self._time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + def _denoise( + self, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + # extra img tokens + img_cond: Optional[Tensor] = None, + ) -> Tensor: + """Denoise the latents through the diffusion process.""" + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=self.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_curr_vec = torch.full((img.shape[0],), t_curr, dtype=self.dtype, device=img.device) + t_prev_vec = torch.full((img.shape[0],), t_prev, dtype=self.dtype, device=img.device) + pred = self.transformer_model( + img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_curr_vec, + guidance=guidance_vec, + ) + print(t_prev, t_curr) + img = img + (t_prev_vec - t_curr_vec) * pred + + return img + + + def _unpack(self, x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + 
ph=2, + pw=2, + ) + + def tokenize_prompt(self, prompt: str) -> tuple[Tensor, Tensor]: + """Tokenize a prompt using T5 and CLIP tokenizers. + + Args: + prompt: Text prompt to tokenize + + Returns: + Tuple of (t5_prompt_ids, clip_prompt_ids) tensors + """ + # T5 tokenization + t5_prompt_ids = [self.t5_tokenizer(p).input_ids for p in [prompt]] + t5_prompt_ids = torch.tensor(t5_prompt_ids, dtype=torch.long) + + # CLIP tokenization + clip_prompt_ids = [self.clip_tokenizer(p).input_ids for p in [prompt]] + clip_prompt_ids = torch.tensor(clip_prompt_ids, dtype=torch.long) + + return t5_prompt_ids, clip_prompt_ids + +def main(): + """Example usage of FluxPipeline.""" + parser = argparse.ArgumentParser(description="Flux text-to-image generation pipeline") + + # Model paths + parser.add_argument("--t5-path", default="/data/t5-v1_1-xxl/model.gguf", + help="Path to T5 model") + parser.add_argument("--clip-path", default="/data/flux/FLUX.1-dev/text_encoder/model.irpa", + help="Path to CLIP model") + parser.add_argument("--transformer-path", default="/data/flux/FLUX.1-dev/transformer/model.irpa", + help="Path to Transformer model") + parser.add_argument("--ae-path", default="/data/flux/FLUX.1-dev/vae/model.irpa", + help="Path to VAE model") + parser.add_argument("--t5-tokenizer-path", default="/data/flux/FLUX.1-dev/tokenizer_2/", + help="Path to T5 tokenizer") + parser.add_argument("--clip-tokenizer-path", default="/data/flux/FLUX.1-dev/tokenizer/", + help="Path to CLIP tokenizer") + + # Generation parameters + parser.add_argument("--height", type=int, default=1024, + help="Height of output image") + parser.add_argument("--width", type=int, default=1024, + help="Width of output image") + parser.add_argument("--num-inference-steps", type=int, default=None, + help="Number of denoising steps") + parser.add_argument("--guidance-scale", type=float, default=3.5, + help="Scale for classifier-free guidance") + parser.add_argument("--seed", type=int, default=None, + help="Random seed for reproducibility") + + # Other parameters + parser.add_argument("--prompt", default="a photo of a forest with mist swirling around the tree trunks. 
" + 'The word "FLUX" is painted over it in big, red brush strokes with visible texture', + help="Text prompt for image generation") + parser.add_argument("--output", default="output.jpg", + help="Output image path") + parser.add_argument("--dtype", default="bfloat16", choices=["float32", "float16", "bfloat16"], + help="Data type for model") + + args = parser.parse_args() + + # Map dtype string to torch dtype + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16 + } + + # Initialize pipeline + pipeline = FluxPipeline( + t5_path=args.t5_path, + clip_path=args.clip_path, + transformer_path=args.transformer_path, + ae_path=args.ae_path, + t5_tokenizer_path=args.t5_tokenizer_path, + clip_tokenizer_path=args.clip_tokenizer_path, + dtype=dtype_map[args.dtype] + ) + + # Generate image + x = pipeline( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + seed=args.seed + ) + + # Transform and save first image + image = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + image.save(args.output, quality=95, subsampling=0) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/tools/import_hf_dataset.py b/sharktank/sharktank/tools/import_hf_dataset.py index 8fb5b3f16..46f3f973c 100644 --- a/sharktank/sharktank/tools/import_hf_dataset.py +++ b/sharktank/sharktank/tools/import_hf_dataset.py @@ -48,12 +48,11 @@ def import_hf_dataset( meta_params = {k: v for k, v in config_json.items() if k.startswith("_")} hparams = {k: v for k, v in config_json.items() if not k.startswith("_")} + tensors = [] for params_path in param_paths: with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: - tensors = [ - DefaultPrimitiveTensor( - name=name, data=st.get_tensor(name).to(target_dtype) - ) + tensors += [ + DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) for name in st.keys() ] diff --git a/sharktank/tests/pipelines/flux/flux_pipeline_test.py b/sharktank/tests/pipelines/flux/flux_pipeline_test.py new file mode 100644 index 000000000..0f3ca202e --- /dev/null +++ b/sharktank/tests/pipelines/flux/flux_pipeline_test.py @@ -0,0 +1,294 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# DO NOT SUBMIT: REVIEW AND TEST FILE + +"""Tests for Flux text-to-image pipeline.""" + +import functools +from typing import Optional +import os +from collections import OrderedDict +import pytest +import torch +from unittest import TestCase +import numpy + +from transformers import CLIPTokenizer, T5Tokenizer +from diffusers import FluxPipeline as ReferenceFluxPipeline + +from sharktank.types import Dataset, dtype_to_serialized_short_name +from sharktank.pipelines.flux import ( + FluxPipeline, + export_flux_pipeline_mlir, + #export_flux_pipeline_iree_parameters, +) +from sharktank.utils.testing import TempDirTestBase +from sharktank.transforms.dataset import set_float_dtype +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + flatten_for_iree_signature, + iree_to_torch, +) +from sharktank import ops +import iree.compiler + +with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')") + +@pytest.mark.usefixtures("get_model_artifacts") +class FluxPipelineEagerTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + @with_flux_data + def testFluxPipelineAgainstGolden(self): + """Test against golden outputs from the original Flux pipeline.""" + model = FluxPipeline( + t5_path="/data/t5-v1_1-xxl/model.gguf", + clip_path="/data/flux/FLUX.1-dev/text_encoder/model.irpa", + transformer_path="/data/flux/FLUX.1-dev/transformer/model.irpa", + ae_path="/data/flux/FLUX.1-dev/vae/model.irpa", + dtype=torch.bfloat16, + ) + + # Load reference inputs + with open("/data/flux/test_data/t5_prompt_ids.pt", "rb") as f: + t5_prompt_ids = torch.load(f) + with open("/data/flux/test_data/clip_prompt_ids.pt", "rb") as f: + clip_prompt_ids = torch.load(f) + + # Generate output using forward method directly + latents = model._get_noise( + 1, + 1024, + 1024, + seed=12345, + ) + output = model.forward( + t5_prompt_ids, + clip_prompt_ids, + latents=latents, + num_inference_steps=1, + seed=12345, + ) + + # Compare against golden output + with open("/data/flux/test_data/flux_1_step_output.pt", "rb") as f: + reference_output = torch.load(f) + + torch.testing.assert_close(output, reference_output) # TODO: why is this not passing? 
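+        # With dtype=torch.bfloat16, a full denoise step plus VAE decode can accumulate
+        # enough rounding error that the default tolerances above fail even when the
+        # pipeline is correct. One possible looser check, reusing the BF16 tolerances
+        # used by the IREE tests below (values are illustrative assumptions, not
+        # validated for this golden output):
+        #     torch.testing.assert_close(
+        #         output, reference_output, atol=1e-2, rtol=1.6e-2
+        #     )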
+ + def runTestFluxPipelineAgainstReference( + self, + dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + """Compare pipeline outputs between different dtypes.""" + # Initialize reference model + reference_model = ReferenceFluxPipeline.from_pretrained("/data/flux/FLUX.1-dev/") + + # Initialize target model + target_model = FluxPipeline( + t5_path="/data/t5-v1_1-xxl/model.gguf", + clip_path="/data/flux/FLUX.1-dev/text_encoder/model.irpa", + transformer_path="/data/flux/FLUX.1-dev/transformer/model.irpa", + ae_path="/data/flux/FLUX.1-dev/vae/model.irpa", + t5_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer_2/", + clip_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer/", + dtype=dtype, + ) + + # Generate outputs using string prompt + prompt = "a photo of a forest with mist" + reference_image_output = reference_model( + prompt=prompt, + height=1024, + width=1024, + num_inference_steps=1, + guidance_scale=3.5 + ).images[0] + reference_output = torch.tensor(numpy.array(reference_image_output)).to(dtype=dtype) + + target_output = target_model( + prompt=prompt, + height=1024, + width=1024, + num_inference_steps=1, + guidance_scale=3.5 + ) + + torch.testing.assert_close(reference_output, target_output, atol=atol, rtol=rtol) + + @with_flux_data + def testFluxPipelineF32(self): + """Test F32 pipeline against reference.""" + self.runTestFluxPipelineAgainstReference( + dtype=torch.float32, + ) + + @with_flux_data + def testFluxPipelineBF16(self): + """Test BF16 pipeline against refence.""" + self.runTestFluxPipelineAgainstReference( + dtype=torch.bfloat16, + ) + + +@pytest.mark.usefixtures("caching", "get_model_artifacts", "path_prefix") +class FluxPipelineIreeTest(TempDirTestBase): + def setUp(self): + super().setUp() + if self.path_prefix is None: + self.path_prefix = f"{self._temp_dir}/" + + def runTestFluxPipelineIreeCompare( + self, + reference_dtype: torch.dtype, + target_dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + """Compare IREE pipeline against eager execution.""" + # Initialize reference model + reference_model = FluxPipeline( + t5_path="/data/t5-v1_1-xxl/model.gguf", + clip_path="/data/flux/FLUX.1-dev/text_encoder/model.irpa", + transformer_path="/data/flux/FLUX.1-dev/transformer/model.irpa", + ae_path="/data/flux/FLUX.1-dev/vae/model.irpa", + t5_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer_2/", + clip_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer/", + dtype=reference_dtype, + ) + + # Create input tokens + t5_tokenizer = T5Tokenizer.from_pretrained("/data/flux/FLUX.1-dev/tokenizer_2/") + clip_tokenizer = CLIPTokenizer.from_pretrained("/data/flux/FLUX.1-dev/tokenizer/") + + prompt = "a photo of a forest with mist" + t5_prompt_ids = torch.tensor([t5_tokenizer(prompt).input_ids], dtype=torch.long) + clip_prompt_ids = torch.tensor([clip_tokenizer(prompt).input_ids], dtype=torch.long) + latents = reference_model._get_noise( + 1, + 1024, + 1024, + seed=12345, + ).to(dtype=target_dtype) # TODO: it isn't great to be getting this from the reference model + + input_args = OrderedDict([ + ("t5_prompt_ids", t5_prompt_ids), + ("clip_prompt_ids", clip_prompt_ids), + ("latents", latents) + ]) + batch_size = t5_prompt_ids.shape[0] + + # Get reference result + reference_result = reference_model.forward(t5_prompt_ids, clip_prompt_ids, latents) + + # Export and compile for IREE + target_dtype_name = dtype_to_serialized_short_name(target_dtype) + target_path_prefix = f"{self.path_prefix}flux_pipeline_{target_dtype_name}" + 
+ parameters_path = f"/data/flux/FLUX.1-dev/" + # if not self.caching or not os.path.exists(parameters_path): + # export_flux_pipeline_iree_parameters( + # "/data/flux/FLUX.1-dev", + # parameters_path, + # dtype=target_dtype, + # ) + + mlir_path = f"{target_path_prefix}.mlir" + if not self.caching or not os.path.exists(mlir_path): + export_flux_pipeline_mlir( + parameters_path, + batch_sizes=[batch_size], + mlir_output_path=mlir_path, + dtype=target_dtype + ) + + iree_module_path = f"{target_path_prefix}.vmfb" + if not self.caching or not os.path.exists(iree_module_path): + iree.compiler.compile_file( + mlir_path, + output_file=iree_module_path, + extra_args=[ + "--iree-hal-target-device=hip", + "--iree-hip-target=gfx942", + "--iree-opt-const-eval=false", + "--iree-opt-strip-assertions=true", + "--iree-global-opt-propagate-transposes=true", + "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", + "--iree-dispatch-creation-enable-aggressive-fusion=true", + "--iree-opt-aggressively-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-llvmgpu-enable-prefetch=true", + "--iree-opt-data-tiling=false", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-codegen-llvmgpu-use-vector-distribution", + "--iree-hip-waves-per-eu=2", + "--iree-execution-model=async-external", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", + ], + ) + + # Run with IREE + iree_devices = get_iree_devices(driver="hip", device_count=1) + iree_module, iree_vm_context, iree_vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=parameters_path, + ) + iree_args = prepare_iree_module_function_args( + args=flatten_for_iree_signature(input_args), + devices=iree_devices, + ) + iree_result = iree_to_torch( + *run_iree_module_function( + module=iree_module, + vm_context=iree_vm_context, + args=iree_args, + driver="hip", + function_name=f"forward_bs{batch_size}", + trace_path_prefix=f"{target_path_prefix}_iree_", + ) + ) + iree_result = [ + ops.to(iree_result[i], dtype=reference_result[i].dtype) + for i in range(len(reference_result)) + ] + + torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol) + + @with_flux_data + def testFluxPipelineIreeF32(self): + """Test F32 IREE pipeline against eager execution.""" + self.runTestFluxPipelineIreeCompare( + reference_dtype=torch.float32, + target_dtype=torch.float32, + atol=1e-4, + rtol=2.0e-3, + ) + + @pytest.mark.xfail( + raises=AssertionError, + reason="BF16 vs F32 accuracy needs investigation", + ) + @with_flux_data + def testFluxPipelineIreeBF16vsF32(self): + """Test BF16 IREE pipeline against F16 eager execution.""" + self.runTestFluxPipelineIreeCompare( + reference_dtype=torch.float32, + target_dtype=torch.bfloat16, + atol=1e-2, + rtol=1.6e-2, + ) From af6e9d7d7f79da928a8f1a892136b9a746bf3478 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 20 Jan 2025 10:55:02 -0800 Subject: [PATCH 02/20] More updates --- sharktank/sharktank/models/clip/export.py | 8 +- sharktank/sharktank/models/flux/export.py | 11 +- sharktank/sharktank/models/flux/flux.py | 15 ++- .../sharktank/pipelines/flux/__init__.py | 4 +- sharktank/sharktank/pipelines/flux/export.py | 111 +++++++++++++----- .../sharktank/pipelines/flux/flux_pipeline.py | 38 +++--- .../pipelines/flux/flux_pipeline_test.py | 80 +++++++------ 7 files changed, 178 insertions(+), 
89 deletions(-) diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py index 95dbdacad..aba0e730d 100644 --- a/sharktank/sharktank/models/clip/export.py +++ b/sharktank/sharktank/models/clip/export.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from typing import Optional, Union import transformers from transformers.models.clip.modeling_clip import ( @@ -18,6 +19,7 @@ from ...layers.configs import ClipTextConfig from .clip import ClipTextModel from iree.turbine.aot import FxProgramsBuilder, export +from sharktank.transforms.dataset import set_float_dtype def hugging_face_clip_attention_to_theta(model: HfCLIPAttention) -> Theta: @@ -50,8 +52,12 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset: return Dataset(properties=model.config.to_properties(), root_theta=model.theta) -def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike): +def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike, dtype=None): dataset = clip_text_model_to_dataset(model) + if dtype: + dataset.root_theta = tdataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) dataset.save(output_path) diff --git a/sharktank/sharktank/models/flux/export.py b/sharktank/sharktank/models/flux/export.py index a63e75af9..6485593c0 100644 --- a/sharktank/sharktank/models/flux/export.py +++ b/sharktank/sharktank/models/flux/export.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from os import PathLike import os from pathlib import Path @@ -14,6 +15,7 @@ from .flux import FluxModelV1, FluxParams from ...types import Dataset from ...utils.hf_datasets import get_dataset +from sharktank.transforms.dataset import set_float_dtype flux_transformer_default_batch_sizes = [1] @@ -27,11 +29,14 @@ def export_flux_transformer_model_mlir( def export_flux_transformer_iree_parameters( - model: FluxModelV1, parameters_output_path: PathLike + model: FluxModelV1, parameters_output_path: PathLike, dtype = None ): model.theta.rename_tensors_to_paths() - # TODO: export properties - dataset = Dataset(root_theta=model.theta, properties={}) + dataset = Dataset(root_theta=model.theta, properties=model.params.to_hugging_face_properties()) + if dtype: + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) dataset.save(parameters_output_path) diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py index 725c483ca..e8003bb68 100644 --- a/sharktank/sharktank/models/flux/flux.py +++ b/sharktank/sharktank/models/flux/flux.py @@ -12,8 +12,8 @@ from typing import Any, Optional from collections import OrderedDict from copy import copy +from dataclasses import dataclass, asdict import math -from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F @@ -49,6 +49,19 @@ class FluxParams: qkv_bias: bool guidance_embed: bool + def to_hugging_face_properties(self) -> dict[str, Any]: + hparams = { + "in_channels": self.in_channels, + "pooled_projection_dim": self.vec_in_dim, + "joint_attention_dim": self.context_in_dim, + "num_attention_heads": self.num_heads, + "num_layers": self.depth, + "num_single_layers": self.depth_single_blocks, + "attention_head_dim": sum(self.axes_dim), + 
"guidance_embeds": self.guidance_embed + } + return {"hparams": hparams} + @staticmethod def from_hugging_face_properties(properties: dict[str, Any]) -> "FluxParams": p = properties["hparams"] diff --git a/sharktank/sharktank/pipelines/flux/__init__.py b/sharktank/sharktank/pipelines/flux/__init__.py index 0c3b3f4cb..f427f423a 100644 --- a/sharktank/sharktank/pipelines/flux/__init__.py +++ b/sharktank/sharktank/pipelines/flux/__init__.py @@ -1,10 +1,10 @@ """Flux text-to-image generation pipeline.""" from .flux_pipeline import FluxPipeline -from .export import export_flux_pipeline_mlir #, export_flux_pipeline_iree_parameters +from .export import export_flux_pipeline_mlir , export_flux_pipeline_iree_parameters __all__ = [ "FluxPipeline", "export_flux_pipeline_mlir", - #"export_flux_pipeline_iree_parameters", + "export_flux_pipeline_iree_parameters", ] \ No newline at end of file diff --git a/sharktank/sharktank/pipelines/flux/export.py b/sharktank/sharktank/pipelines/flux/export.py index dfd4c1b97..a1f9502ef 100644 --- a/sharktank/sharktank/pipelines/flux/export.py +++ b/sharktank/sharktank/pipelines/flux/export.py @@ -5,17 +5,24 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """Export utilities for Flux text-to-image pipeline.""" -#TODO: DO NOT SUBMIT: FIX AND TEST THIS FILE import functools from typing import Optional, Union from pathlib import Path import torch from copy import copy +import logging from .flux_pipeline import FluxPipeline -from ...types import Dataset +from ...types import Dataset, dtype_to_serialized_short_name from ...transforms.dataset import set_float_dtype from iree.turbine.aot import FxProgramsBuilder, export +from ...models.t5.export import export_encoder_iree_parameters as export_t5_parameters +from ...models.clip.export import export_clip_text_model_iree_parameters +from ...models.flux.export import export_flux_transformer_iree_parameters +from ...models.vae.model import VaeDecoderModel +from ...models.clip import ClipTextModel, ClipTextConfig +from transformers import CLIPTokenizer, T5Tokenizer, CLIPTextModel as HfCLIPTextModel +from ...models.flux.flux import FluxModelV1, FluxParams __all__ = [ "export_flux_pipeline_mlir", @@ -36,11 +43,12 @@ def export_flux_pipeline_mlir( mlir_output_path: Output path for MLIR file """ if isinstance(model, (Path, str)): + model_parameter_path = Path(model) / f"exported_parameters_{dtype_to_serialized_short_name(dtype)}" model = FluxPipeline( - t5_path=str(Path(model) / "text_encoder_2/model.gguf"), - clip_path=str(Path(model) / "text_encoder/model.irpa"), - transformer_path=str(Path(model) / "transformer/model.irpa"), - ae_path=str(Path(model) / "vae/model.irpa"), + t5_path=str(model_parameter_path / "t5.irpa"), + clip_path=str(model_parameter_path / "clip.irpa"), + transformer_path=str(model_parameter_path / "transformer.irpa"), + ae_path=str(model_parameter_path / "vae.irpa"), dtype=dtype, ) @@ -79,27 +87,72 @@ def _(model, t5_prompt_ids, clip_prompt_ids, latents): raise output.save_mlir(mlir_output_path) -# def export_flux_pipeline_iree_parameters( -# model_path_or_dataset: str | Dataset, -# output_path: str, -# dtype: Optional[torch.dtype] = None, -# ): -# """Export Flux pipeline parameters to IREE format. +def is_already_exported(output_path: Path) -> bool: + return output_path.exists() + +def export_flux_pipeline_iree_parameters( + model_path_or_dataset: str | Dataset, + output_path: str, + dtype: Optional[torch.dtype] = None, +): + """Export Flux pipeline parameters to IREE format. 
-# Args: -# model_path_or_dataset: Path to model files or Dataset instance -# output_path: Output path for IREE parameters -# dtype: Optional dtype to convert parameters to -# """ -# # TODO: loop over models -# if isinstance(model_path_or_dataset, Dataset): -# dataset = copy(model_path_or_dataset) -# else: -# dataset = Dataset.load(model_path_or_dataset) - -# if dtype: -# dataset.root_theta = dataset.root_theta.transform( -# functools.partial(set_float_dtype, dtype=dtype) -# ) - -# dataset.save(output_path) \ No newline at end of file + Args: + model_path_or_dataset: Path to model files or Dataset instance + output_path: Output path for IREE parameters + dtype: Optional dtype to convert parameters to + """ + # Ensure output_path is a Path object + output_path = Path(output_path) / f"exported_parameters_{dtype_to_serialized_short_name(dtype)}" + output_path.mkdir(parents=True, exist_ok=True) + + # Export T5 parameters + t5_path = Path(model_path_or_dataset) / "text_encoder_2/model.gguf" + t5_output_path = output_path / "t5.irpa" + print("hi") + if not is_already_exported(t5_output_path): + print("hello") + export_t5_parameters(t5_path, str(t5_output_path), dtype) + logging.info(f"Exported T5 parameters to {t5_output_path}") + else: + logging.info(f"Skipped T5 parameter export, already exists at {t5_output_path}") + + # Export CLIP parameters + clip_path = Path(model_path_or_dataset) / "text_encoder/model.irpa" + clip_output_path = output_path / "clip.irpa" + if not is_already_exported(clip_output_path): + clip_dataset = Dataset.load(clip_path) + # TODO: Refactor CLIP to not make the config rely on HuggingFace + hf_clip_model = HfCLIPTextModel.from_pretrained("/data/flux/FLUX.1-dev/text_encoder/") + clip_config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_clip_model.config) + clip_model = ClipTextModel(theta=clip_dataset.root_theta, config=clip_config) + export_clip_text_model_iree_parameters(clip_model, str(clip_output_path)) + logging.info(f"Exported CLIP parameters to {clip_output_path}") + else: + logging.info(f"Skipped CLIP parameter export, already exists at {clip_output_path}") + + # Export FluxTransformer parameters + transformer_path = Path(model_path_or_dataset) / "transformer/model.irpa" + transformer_output_path = output_path / "transformer.irpa" + if not is_already_exported(transformer_output_path): + transformer_dataset = Dataset.load(transformer_path) + transformer_model = FluxModelV1(theta=transformer_dataset.root_theta, params=FluxParams.from_hugging_face_properties(transformer_dataset.properties)) + export_flux_transformer_iree_parameters(transformer_model, str(transformer_output_path), dtype=dtype) + logging.info(f"Exported FluxTransformer parameters to {transformer_output_path}") + else: + logging.info(f"Skipped FluxTransformer parameter export, already exists at {transformer_output_path}") + + # Export VAE parameters + vae_path = Path(model_path_or_dataset) / "vae/model.irpa" + vae_output_path = output_path / "vae.irpa" + if not is_already_exported(vae_output_path): + vae_dataset = Dataset.load(vae_path) + vae_dataset.root_theta = vae_dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) + vae_dataset.save(str(vae_output_path)) + logging.info(f"Exported VAE parameters to {vae_output_path}") + else: + logging.info(f"Skipped VAE parameter export, already exists at {vae_output_path}") + + logging.info(f"Completed Flux pipeline parameter export to {output_path}") \ No newline at end of file diff --git 
a/sharktank/sharktank/pipelines/flux/flux_pipeline.py b/sharktank/sharktank/pipelines/flux/flux_pipeline.py index fc02d6582..12fc4b9bb 100644 --- a/sharktank/sharktank/pipelines/flux/flux_pipeline.py +++ b/sharktank/sharktank/pipelines/flux/flux_pipeline.py @@ -49,9 +49,9 @@ def __init__( t5_dataset.properties, feed_forward_proj="gated-gelu", ) - t5_dataset.root_theta = t5_dataset.root_theta.transform( - functools.partial(set_float_dtype, dtype=dtype) - ) + # t5_dataset.root_theta = t5_dataset.root_theta.transform( + # functools.partial(set_float_dtype, dtype=dtype) + # ) self.t5_model = T5Encoder(theta=t5_dataset.root_theta, config=t5_config) self.add_module('t5_model', self.t5_model) self.t5_model.to(device) @@ -61,9 +61,9 @@ def __init__( # TODO: Refactor CLIP to not make the config rely on HuggingFace hf_clip_model = HfCLIPTextModel.from_pretrained("/data/flux/FLUX.1-dev/text_encoder/") clip_config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_clip_model.config) - clip_dataset.root_theta = clip_dataset.root_theta.transform( - functools.partial(set_float_dtype, dtype=dtype) - ) + # clip_dataset.root_theta = clip_dataset.root_theta.transform( + # functools.partial(set_float_dtype, dtype=dtype) + # ) self.clip_model = ClipTextModel(theta=clip_dataset.root_theta, config=clip_config) self.add_module('clip_model', self.clip_model) self.clip_model.to(device) @@ -71,9 +71,9 @@ def __init__( # Load Flux Transformer transformer_dataset = Dataset.load(transformer_path) transformer_params = FluxParams.from_hugging_face_properties(transformer_dataset.properties) - transformer_dataset.root_theta = transformer_dataset.root_theta.transform( - functools.partial(set_float_dtype, dtype=dtype) - ) + # transformer_dataset.root_theta = transformer_dataset.root_theta.transform( + # functools.partial(set_float_dtype, dtype=dtype) + # ) self.transformer_model = FluxModelV1( theta=transformer_dataset.root_theta, params=transformer_params @@ -83,9 +83,9 @@ def __init__( # Load VAE ae_dataset = Dataset.load(ae_path) - ae_dataset.root_theta = ae_dataset.root_theta.transform( - functools.partial(set_float_dtype, dtype=dtype) - ) + # ae_dataset.root_theta = ae_dataset.root_theta.transform( + # functools.partial(set_float_dtype, dtype=dtype) + # ) self.ae_model = VaeDecoderModel.from_dataset(ae_dataset) self.add_module('ae_model', self.ae_model) self.ae_model.to(device) @@ -119,6 +119,7 @@ def __call__( prompt: str, height: int = 1024, width: int = 1024, + latents: Optional[Tensor] = None, num_inference_steps: Optional[int] = 50, guidance_scale: float = 3.5, seed: Optional[int] = None, @@ -143,12 +144,13 @@ def __call__( raise ValueError("Tokenizers must be provided to use the __call__ method") t5_prompt_ids, clip_prompt_ids = self.tokenize_prompt(prompt) - latents = self._get_noise( - 1, - height, - width, - seed=seed, - ) + if not latents: + latents = self._get_noise( + 1, + height, + width, + seed=seed, + ) with torch.inference_mode(): return self.forward( diff --git a/sharktank/tests/pipelines/flux/flux_pipeline_test.py b/sharktank/tests/pipelines/flux/flux_pipeline_test.py index 0f3ca202e..ef9237c1e 100644 --- a/sharktank/tests/pipelines/flux/flux_pipeline_test.py +++ b/sharktank/tests/pipelines/flux/flux_pipeline_test.py @@ -24,7 +24,7 @@ from sharktank.pipelines.flux import ( FluxPipeline, export_flux_pipeline_mlir, - #export_flux_pipeline_iree_parameters, + export_flux_pipeline_iree_parameters, ) from sharktank.utils.testing import TempDirTestBase from sharktank.transforms.dataset import 
set_float_dtype @@ -110,10 +110,17 @@ def runTestFluxPipelineAgainstReference( # Generate outputs using string prompt prompt = "a photo of a forest with mist" + latents = reference_model._get_noise( + 1, + 1024, + 1024, + seed=12345, + ).to(dtype=dtype) reference_image_output = reference_model( prompt=prompt, height=1024, width=1024, + latents=latents, num_inference_steps=1, guidance_scale=3.5 ).images[0] @@ -123,6 +130,7 @@ def runTestFluxPipelineAgainstReference( prompt=prompt, height=1024, width=1024, + latents=latents, num_inference_steps=1, guidance_scale=3.5 ) @@ -159,17 +167,6 @@ def runTestFluxPipelineIreeCompare( rtol: Optional[float] = None, ): """Compare IREE pipeline against eager execution.""" - # Initialize reference model - reference_model = FluxPipeline( - t5_path="/data/t5-v1_1-xxl/model.gguf", - clip_path="/data/flux/FLUX.1-dev/text_encoder/model.irpa", - transformer_path="/data/flux/FLUX.1-dev/transformer/model.irpa", - ae_path="/data/flux/FLUX.1-dev/vae/model.irpa", - t5_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer_2/", - clip_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer/", - dtype=reference_dtype, - ) - # Create input tokens t5_tokenizer = T5Tokenizer.from_pretrained("/data/flux/FLUX.1-dev/tokenizer_2/") clip_tokenizer = CLIPTokenizer.from_pretrained("/data/flux/FLUX.1-dev/tokenizer/") @@ -177,34 +174,31 @@ def runTestFluxPipelineIreeCompare( prompt = "a photo of a forest with mist" t5_prompt_ids = torch.tensor([t5_tokenizer(prompt).input_ids], dtype=torch.long) clip_prompt_ids = torch.tensor([clip_tokenizer(prompt).input_ids], dtype=torch.long) - latents = reference_model._get_noise( - 1, - 1024, - 1024, - seed=12345, - ).to(dtype=target_dtype) # TODO: it isn't great to be getting this from the reference model + # latents = reference_model._get_noise( + # 1, + # 1024, + # 1024, + # seed=12345, + # ).to(dtype=target_dtype) # TODO: it isn't great to be getting this from the reference model - input_args = OrderedDict([ - ("t5_prompt_ids", t5_prompt_ids), - ("clip_prompt_ids", clip_prompt_ids), - ("latents", latents) - ]) + # input_args = OrderedDict([ + # ("t5_prompt_ids", t5_prompt_ids), + # ("clip_prompt_ids", clip_prompt_ids), + # ("latents", latents) + # ]) batch_size = t5_prompt_ids.shape[0] - # Get reference result - reference_result = reference_model.forward(t5_prompt_ids, clip_prompt_ids, latents) - # Export and compile for IREE target_dtype_name = dtype_to_serialized_short_name(target_dtype) target_path_prefix = f"{self.path_prefix}flux_pipeline_{target_dtype_name}" - parameters_path = f"/data/flux/FLUX.1-dev/" - # if not self.caching or not os.path.exists(parameters_path): - # export_flux_pipeline_iree_parameters( - # "/data/flux/FLUX.1-dev", - # parameters_path, - # dtype=target_dtype, - # ) + parameters_path = "/data/flux/FLUX.1-dev/" + if not self.caching or not os.path.exists(mlir_path): + export_flux_pipeline_iree_parameters( + "/data/flux/FLUX.1-dev/", + parameters_path, + dtype=target_dtype, + ) mlir_path = f"{target_path_prefix}.mlir" if not self.caching or not os.path.exists(mlir_path): @@ -262,11 +256,27 @@ def runTestFluxPipelineIreeCompare( trace_path_prefix=f"{target_path_prefix}_iree_", ) ) + + # Reference model + reference_model = FluxPipeline( + t5_path="/data/t5-v1_1-xxl/model.gguf", + clip_path="/data/flux/FLUX.1-dev/text_encoder/model.irpa", + transformer_path="/data/flux/FLUX.1-dev/transformer/model.irpa", + ae_path="/data/flux/FLUX.1-dev/vae/model.irpa", + t5_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer_2/", + 
clip_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer/", + dtype=reference_dtype, + ) + # reference_result = reference_model.forward(t5_prompt_ids, clip_prompt_ids, latents) + + # Reformat the result for direct comparison iree_result = [ ops.to(iree_result[i], dtype=reference_result[i].dtype) for i in range(len(reference_result)) ] + + torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol) @with_flux_data @@ -284,10 +294,10 @@ def testFluxPipelineIreeF32(self): reason="BF16 vs F32 accuracy needs investigation", ) @with_flux_data - def testFluxPipelineIreeBF16vsF32(self): + def testFluxPipelineIreeBF16(self): """Test BF16 IREE pipeline against F16 eager execution.""" self.runTestFluxPipelineIreeCompare( - reference_dtype=torch.float32, + reference_dtype=torch.float16, target_dtype=torch.bfloat16, atol=1e-2, rtol=1.6e-2, From b58cb11edb9b9785836800f0b6787c810e5a2f64 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 2 Dec 2024 11:35:24 -0600 Subject: [PATCH 03/20] Add dynamo exports for flux.1 --- sharktank/sharktank/dynamo_exports/flux/ae.py | 347 ++++++++++++ .../sharktank/dynamo_exports/flux/export.py | 382 +++++++++++++ .../sharktank/dynamo_exports/flux/mmdit.py | 182 ++++++ .../dynamo_exports/flux/scheduler.py | 53 ++ sharktank/sharktank/dynamo_exports/flux/te.py | 524 ++++++++++++++++++ 5 files changed, 1488 insertions(+) create mode 100644 sharktank/sharktank/dynamo_exports/flux/ae.py create mode 100644 sharktank/sharktank/dynamo_exports/flux/export.py create mode 100644 sharktank/sharktank/dynamo_exports/flux/mmdit.py create mode 100644 sharktank/sharktank/dynamo_exports/flux/scheduler.py create mode 100644 sharktank/sharktank/dynamo_exports/flux/te.py diff --git a/sharktank/sharktank/dynamo_exports/flux/ae.py b/sharktank/sharktank/dynamo_exports/flux/ae.py new file mode 100644 index 000000000..ad05cc16c --- /dev/null +++ b/sharktank/sharktank/dynamo_exports/flux/ae.py @@ -0,0 +1,347 @@ +from torch import Tensor, nn +import torch +from einops import rearrange +from dataclasses import dataclass + +# This Flux AE implementation is copied from https://github.com/black-forest-labs/flux. 
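+# AutoEncoderParams mirrors the upstream Flux config: `ch_mult` sets the channel
+# multiplier per resolution level, `z_channels` is the latent channel count (16 for
+# FLUX), and `scale_factor`/`shift_factor` define the latent normalization used by
+# the AutoEncoder below: encode applies z = scale_factor * (z - shift_factor) and
+# decode inverts it with z = z / scale_factor + shift_factor.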
+ + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: list[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +def swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, h_: Tensor) -> Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = nn.functional.scaled_dot_product_attention(q, k, v) + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = nn.GroupNorm( + num_groups=32, num_channels=out_channels, eps=1e-6, affine=True + ) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h = x + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + # no asymmetric padding in torch conv, must do it ourselves + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x: Tensor): + pad = (0, 1, 0, 1) + x = nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor): + x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + 
self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + block_in = self.ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: Tensor) -> Tensor: + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, num_channels=block_in, eps=1e-6, affine=True + ) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # upsampling + for i_level in 
reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) diff --git a/sharktank/sharktank/dynamo_exports/flux/export.py b/sharktank/sharktank/dynamo_exports/flux/export.py new file mode 100644 index 000000000..abdb93db4 --- /dev/null +++ b/sharktank/sharktank/dynamo_exports/flux/export.py @@ -0,0 +1,382 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import re +from dataclasses import dataclass + +from iree.compiler.ir import Context +from iree.turbine.aot import * +from iree.turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) +import torch + +from diffusers.models.transformers import FluxTransformer2DModel +from te import ClipTextEncoderModule +from ae import AutoEncoder, AutoEncoderParams +from scheduler import FluxScheduler +from mmdit import get_flux_transformer_model + + +@dataclass +class ModelSpec: + ae_params: AutoEncoderParams + ae_path: str | None + + +fluxconfigs = { + "flux-dev": ModelSpec( + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + ae_path=None, # os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + +model_repo_map = { + "flux-dev": "black-forest-labs/FLUX.1-dev", + "flux-schnell": "black-forest-labs/FLUX.1-schnell", +} +model_file_map = { + "flux-dev": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors", + "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/flux1-schnell.safetensors", +} + +torch_dtypes = { + "fp16": torch.float16, + "fp32": torch.float32, + "bf16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, +} + + +def create_safe_name(hf_model_name, model_name_str=""): + if not model_name_str: + model_name_str = "" + if model_name_str != "" and (not model_name_str.startswith("_")): + model_name_str = "_" + model_name_str + + safe_name = hf_model_name.split("/")[-1].strip() + model_name_str + safe_name = re.sub("-", "_", safe_name) + safe_name = re.sub("\.", "_", safe_name) + return safe_name + + +def get_flux_model_and_inputs( + hf_model_name, precision, batch_size, max_length, height, width +): + dtype = torch_dtypes[precision] + return get_flux_transformer_model( + hf_model_name, height, width, 8, max_length, dtype, batch_size + ) + + +def get_te_model_and_inputs( + hf_model_name, component, precision, batch_size, max_length +): + match component: + case "clip": + # te = CLIPTextModel.from_pretrained( + # model_repo_map[hf_model_name], + # subfolder="text_encoder" + # ) + te = ClipTextEncoderModule( + model_repo_map[hf_model_name], torch_dtypes[precision] + ) + clip_ids_shape = ( + batch_size, + 77, + 2, + ) + input_args = [ + torch.ones(clip_ids_shape, dtype=torch.int64), + ] + return te, input_args + case "t5xxl": + return None, None + + +def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width): + dtype = torch_dtypes[precision] + aeparams = fluxconfigs[hf_model_name].ae_params + ae = AutoEncoder(params=aeparams).to(dtype) + latents_shape = ( + batch_size, + 16, + int(height * width / 256), + 64, + ) + img_shape = ( + 1, + aeparams.in_channels, + int(height), + int(width), + ) + encode_inputs = [ + torch.empty(img_shape, dtype=dtype), + ] + decode_inputs = [ + torch.empty(latents_shape, dtype=dtype), + ] + return ae, encode_inputs, decode_inputs + + +def get_scheduler_model_and_inputs(hf_model_name, max_length, precision): + is_schnell = "schnell" in hf_model_name + mod = FluxScheduler( + max_length=max_length, + 
torch_dtype=torch_dtypes[precision], + is_schnell=is_schnell, + ) + sample_inputs = (torch.empty(1, dtype=torch.int64),) + # tdim = torch.export.Dim("timesteps") + # dynamic_inputs = {"timesteps": {0: tdim}} + return mod, sample_inputs + + +@torch.no_grad() +def export_flux_model( + hf_model_name, + component, + batch_size, + height, + width, + precision="fp16", + max_length=512, + compile_to="torch", + external_weights=None, + external_weight_path=None, + decomp_attn=False, +): + dtype = torch_dtypes[precision] + decomp_list = [] + if decomp_attn == True: + decomp_list = [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.scaled_dot_product_attention, + ] + with decompositions.extend_aot_decompositions( + from_current=True, + add_ops=decomp_list, + ): + if component == "mmdit": + model, sample_inputs, _ = get_flux_model_and_inputs( + hf_model_name, precision, batch_size, max_length, height, width + ) + + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + args=(sample_inputs,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + class CompiledFluxTransformer(CompiledModule): + run_forward = _forward + + if external_weights: + externalize_module_parameters(model) + save_module_parameters(external_weight_path, model) + + inst = CompiledFluxTransformer(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + + elif component in ["clip", "t5xxl"]: + model, sample_inputs = get_te_model_and_inputs( + hf_model_name, component, precision, batch_size, max_length + ) + + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + args=(sample_inputs,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + class CompiledFluxTextEncoder(CompiledModule): + encode_prompts = _forward + + if external_weights: + externalize_module_parameters(model) + save_module_parameters(external_weight_path, model) + + inst = CompiledFluxTextEncoder(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + elif component == "vae": + model, encode_inputs, decode_inputs = get_ae_model_and_inputs( + hf_model_name, precision, batch_size, height, width + ) + + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + args=(encode_inputs,), + ) + def _encode( + module, + inputs, + ): + return module.encode(*inputs) + + @fxb.export_program( + args=(decode_inputs,), + ) + def _decode( + module, + inputs, + ): + return module.decode(*inputs) + + class CompiledFluxAutoEncoder(CompiledModule): + encode = _encode + decode = _decode + + if external_weights: + externalize_module_parameters(model) + save_module_parameters(external_weight_path, model) + + inst = CompiledFluxAutoEncoder(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + + elif component == "scheduler": + model, sample_inputs = get_scheduler_model_and_inputs( + hf_model_name, max_length, precision + ) + + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + args=(sample_inputs,), + ) + def _prepare( + module, + inputs, + ): + return module.prepare(*inputs) + + class CompiledFlowScheduler(CompiledModule): + run_prep = _prepare + + inst = CompiledFlowScheduler(context=Context(), import_to="IMPORT") + + module = CompiledModule.get_mlir_module(inst) + + module_str = str(module) + return module_str + + +def get_filename(args): + match args.component: + case "mmdit": + return create_safe_name( + args.model, + 
f"mmdit_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}", + ) + case "clip": + return create_safe_name( + args.model, f"clip_bs{args.batch_size}_77_{args.precision}" + ) + case "scheduler": + return create_safe_name( + args.model, + f"scheduler_bs{args.batch_size}_{args.max_length}_{args.precision}", + ) + case "vae": + return create_safe_name( + args.model, + f"vae_bs{args.batch_size}_{args.height}x{args.width}_{args.precision}", + ) + + +if __name__ == "__main__": + import logging + import argparse + + logging.basicConfig(level=logging.DEBUG) + p = argparse.ArgumentParser() + p.add_argument( + "--model", + default="flux-schnell", + choices=["flux-dev", "flux-schnell", "flux-pro"], + ) + p.add_argument( + "--component", + default="mmdit", + choices=["mmdit", "clip", "t5xxl", "scheduler", "vae"], + ) + p.add_argument("--batch_size", default=1) + p.add_argument("--height", default=1024) + p.add_argument("--width", default=1024) + p.add_argument("--precision", default="fp32") + p.add_argument("--max_length", default=512) + p.add_argument("--external_weights", default="irpa") + p.add_argument("--external_weights_file", default=None) + p.add_argument("--decomp_attn", action="store_true") + args = p.parse_args() + + if args.external_weights and not args.external_weights_file: + args.external_weights_file = ( + create_safe_name( + args.model, + args.component + "_" + args.precision, + ) + + "." + + args.external_weights + ) + safe_name = get_filename(args) + mod_str = export_flux_model( + args.model, + args.component, + args.batch_size, + args.height, + args.width, + args.precision, + args.max_length, + "mlir", + args.external_weights, + args.external_weights_file, + args.decomp_attn, + ) + + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/sharktank/sharktank/dynamo_exports/flux/mmdit.py b/sharktank/sharktank/dynamo_exports/flux/mmdit.py new file mode 100644 index 000000000..d4329f89d --- /dev/null +++ b/sharktank/sharktank/dynamo_exports/flux/mmdit.py @@ -0,0 +1,182 @@ +import os +import torch +import math +from diffusers import FluxTransformer2DModel +from typing import Callable +from iree.turbine.aot import * + + +def get_local_path(local_dir, model_dir): + model_local_dir = os.path.join(local_dir, model_dir) + if not os.path.exists(model_local_dir): + os.makedirs(model_local_dir) + return model_local_dir + + +class FluxModelCFG(torch.nn.Module): + def __init__(self, torch_dtype): + super().__init__() + self.mmdit = FluxTransformer2DModel.from_single_file( + "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" + ).to(torch_dtype) + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + img_ids, + txt_ids, + guidance_vec, + t_vec, + t_curr, + t_prev, + cfg_scale, + ): + pred = self.mmdit( + hidden_states=hidden_states, + img_ids=img_ids, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + pooled_projections=pooled_projections, + timestep=t_vec, + guidance=guidance_vec, + return_dict=False, + )[0] + pred_uncond, pred = torch.chunk(pred, 2, dim=0) + pred = pred_uncond + cfg_scale * (pred - pred_uncond) + hidden_states = hidden_states + (t_prev - t_curr) * pred + return hidden_states + + +class FluxModelSchnell(torch.nn.Module): + def __init__(self, torch_dtype): + super().__init__() + self.mmdit = FluxTransformer2DModel.from_single_file( + 
"https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/flux1-schnell.safetensors" + ).to(torch_dtype) + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + img_ids, + txt_ids, + guidance_vec, + t_vec, + t_curr, + t_prev, + cfg_scale, + ): + pred = self.mmdit( + hidden_states=hidden_states, + img_ids=img_ids, + encoder_hidden_states=encoder_hidden_states, + txt_ids=txt_ids, + pooled_projections=pooled_projections, + timestep=t_vec, + guidance=guidance_vec, + return_dict=False, + )[0] + hidden_states = hidden_states + (t_prev - t_curr) * pred + return hidden_states + + +@torch.no_grad() +def get_flux_transformer_model( + hf_model_path, + img_height=1024, + img_width=1024, + compression_factor=8, + max_len=512, + torch_dtype=torch.float32, + bs=1, +): + + latent_h, latent_w = ( + img_height // compression_factor, + img_width // compression_factor, + ) + + if "schnell" in hf_model_path: + model = FluxModelSchnell(torch_dtype=torch_dtype) + config = model.mmdit.config + sample_inputs = ( + torch.randn( + bs, + (latent_h // 2) * (latent_w // 2), + config["in_channels"], + dtype=torch_dtype, + ), + torch.randn(bs, max_len, config["joint_attention_dim"], dtype=torch_dtype), + torch.randn(bs, config["pooled_projection_dim"], dtype=torch_dtype), + torch.randn((latent_h // 2) * (latent_w // 2), 3, dtype=torch_dtype), + torch.randn(max_len, 3, dtype=torch_dtype), + torch.tensor([1.0] * bs, dtype=torch_dtype), + torch.tensor([1.0] * bs, dtype=torch_dtype), + torch.tensor([1.0], dtype=torch_dtype), + torch.tensor([1.0], dtype=torch_dtype), + torch.tensor([1.0] * bs, dtype=torch_dtype), + ) + else: + model = FluxModelCFG(torch_dtype=torch_dtype) + config = model.mmdit.config + cfg_bs = bs * 2 + sample_inputs = ( + torch.randn( + cfg_bs, + (latent_h // 2) * (latent_w // 2), + config["in_channels"], + dtype=torch_dtype, + ), + torch.randn( + cfg_bs, max_len, config["joint_attention_dim"], dtype=torch_dtype + ), + torch.randn(cfg_bs, config["pooled_projection_dim"], dtype=torch_dtype), + torch.randn((latent_h // 2) * (latent_w // 2), 3, dtype=torch_dtype), + torch.randn(max_len, 3, dtype=torch_dtype), + torch.tensor([1.0] * bs, dtype=torch_dtype), + torch.tensor([1.0] * cfg_bs, dtype=torch_dtype), + torch.tensor([1.0], dtype=torch_dtype), + torch.tensor([1.0], dtype=torch_dtype), + torch.tensor([1.0] * bs, dtype=torch_dtype), + ) + + input_names = [ + "hidden_states", + "encoder_hidden_states", + "pooled_projections", + "img_ids", + "txt_ids", + "guidance_vec", + "t_curr", + "t_prev", + "cfg_scale", + ] + return model, sample_inputs, input_names + + # if not os.path.isfile(onnx_path): + # output_names = ["latent"] + # dynamic_axes = { + # 'hidden_states': {0: 'B', 1: 'latent_dim'}, + # 'encoder_hidden_states': {0: 'B',1: 'L'}, + # 'pooled_projections': {0: 'B'}, + # 'timestep': {0: 'B'}, + # 'img_ids': {0: 'latent_dim'}, + # 'txt_ids': {0: 'L'}, + # 'guidance': {0: 'B'}, + # } + + # with torch.inference_mode(): + # torch.onnx.export( + # model, + # sample_inputs, + # onnx_path, + # export_params=True, + # input_names=input_names, + # output_names=output_names) + + # assert os.path.isfile(onnx_path) + + # return onnx_path diff --git a/sharktank/sharktank/dynamo_exports/flux/scheduler.py b/sharktank/sharktank/dynamo_exports/flux/scheduler.py new file mode 100644 index 000000000..397f7b815 --- /dev/null +++ b/sharktank/sharktank/dynamo_exports/flux/scheduler.py @@ -0,0 +1,53 @@ +import math +import torch +from typing import Callable + + +def time_shift(mu: 
float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps + + +class FluxScheduler(torch.nn.Module): + def __init__(self, max_length, torch_dtype, is_schnell=False): + super().__init__() + self.is_schnell = is_schnell + self.max_length = max_length + timesteps = [torch.empty((100), dtype=torch_dtype, requires_grad=False)] * 100 + for i in range(1, 100): + schedule = get_schedule(i, max_length, shift=not self.is_schnell) + timesteps[i] = torch.nn.functional.pad(schedule, (0, 99 - i), "constant", 0) + self.timesteps = torch.stack(timesteps, dim=0).clone().detach() + + def prepare(self, num_steps): + # s = num_steps.item() + # torch._check(s >= 1) + # torch._check(s <= 100) + timesteps = self.timesteps[num_steps] + return timesteps diff --git a/sharktank/sharktank/dynamo_exports/flux/te.py b/sharktank/sharktank/dynamo_exports/flux/te.py new file mode 100644 index 000000000..bfb55b1ac --- /dev/null +++ b/sharktank/sharktank/dynamo_exports/flux/te.py @@ -0,0 +1,524 @@ +### This file contains impls for underlying related models (CLIP, T5, etc) + +import torch, math +from torch import nn +from transformers import CLIPTokenizer, T5TokenizerFast +from transformers import T5EncoderModel +from iree.turbine import ops +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from sharktank.layers import T5Config +from sharktank.models import t5 + +CLIP_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3172, + "num_attention_heads": 12, + "num_hidden_layers": 12, +} + + +class ClipTextEncoderModule(torch.nn.Module): + @torch.no_grad() + def __init__( + self, + repo, + precision, + ): + super().__init__() + self.dtype = torch.float16 if precision == "fp16" else torch.float32 + self.clip = SDClipModel( + layer="hidden", + layer_idx=-2, + device="cpu", + dtype=self.dtype, + layer_norm_hidden_state=False, + return_projected_pooled=True, + textmodel_json_config=CLIP_CONFIG, + ) + if precision == "fp16": + self.clip = self.clip.half() + clip_weights = hf_hub_download( + repo_id=repo, + filename="text_encoder/model.safetensors", + ) + with safe_open(clip_weights, framework="pt", device="cpu") as f: + load_into(f, self.clip.transformer, "", "cpu", self.dtype) + + def forward(self, clip_ids): + vec = self.clip(clip_ids)[1] + + return vec + + +################################################################################################# +### Core/Utility +################################################################################################# + + +def attention(q, k, v, heads, mask=None): + """Convenience wrapper around a basic attention operation""" + b, _, dim_head = q.shape + # ops.iree.trace_tensor("attention_q", q[0,0,:5]) + # ops.iree.trace_tensor("attention_k", k[0,0,:5]) + # 
ops.iree.trace_tensor("attention_v", v[0,0,:5]) + dim_head //= heads + q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + # ops.iree.trace_tensor("attention_out", out[0,0,:5]) + return out.transpose(1, 2).reshape(b, -1, heads * dim_head) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + dtype=None, + device=None, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear( + in_features, hidden_features, bias=bias, dtype=dtype, device=device + ) + self.act = act_layer + self.fc2 = nn.Linear( + hidden_features, out_features, bias=bias, dtype=dtype, device=device + ) + + def forward(self, x): + x = self.fc1(x) + # ops.iree.trace_tensor("mlpfx", x[0,0,:5]) + x = self.act(x) + # ops.iree.trace_tensor("mlpact", x[0,0,:5]) + x = self.fc2(x) + # ops.iree.trace_tensor("mlpanotherfc", x[0,0,:5]) + return x + + +def load_into(f, model, prefix, device, dtype=None): + """Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module.""" + for key in f.keys(): + if key.startswith(prefix) and not key.startswith("loss."): + path = key[len(prefix) :].split(".") + obj = model + for p in path: + if obj is list: + obj = obj[int(p)] + else: + obj = getattr(obj, p, None) + if obj is None: + print( + f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model" + ) + break + if obj is None: + continue + try: + tensor = f.get_tensor(key).to(device=device) + if dtype is not None: + tensor = tensor.to(dtype=dtype) + obj.requires_grad_(False) + obj.set_(tensor) + except Exception as e: + print(f"Failed to load key '{key}' in safetensors file: {e}") + raise e + + +################################################################################################# +### CLIP +################################################################################################# + + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear( + embed_dim, embed_dim, bias=True, dtype=dtype, device=device + ) + self.k_proj = nn.Linear( + embed_dim, embed_dim, bias=True, dtype=dtype, device=device + ) + self.v_proj = nn.Linear( + embed_dim, embed_dim, bias=True, dtype=dtype, device=device + ) + self.out_proj = nn.Linear( + embed_dim, embed_dim, bias=True, dtype=dtype, device=device + ) + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + + +class CLIPLayer(torch.nn.Module): + def __init__( + self, + embed_dim, + heads, + intermediate_size, + intermediate_activation, + dtype, + device, + ): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + self.mlp = Mlp( + 
embed_dim, + intermediate_size, + embed_dim, + act_layer=ACTIVATIONS[intermediate_activation], + dtype=dtype, + device=device, + ) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__( + self, + num_layers, + embed_dim, + heads, + intermediate_size, + intermediate_activation, + dtype, + device, + ): + super().__init__() + self.layers = torch.nn.ModuleList( + [ + CLIPLayer( + embed_dim, + heads, + intermediate_size, + intermediate_activation, + dtype, + device, + ) + for i in range(num_layers) + ] + ) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__( + self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None + ): + super().__init__() + self.token_embedding = torch.nn.Embedding( + vocab_size, embed_dim, dtype=dtype, device=device + ) + self.position_embedding = torch.nn.Embedding( + num_positions, embed_dim, dtype=dtype, device=device + ) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.encoder = CLIPEncoder( + num_layers, + embed_dim, + heads, + intermediate_size, + intermediate_activation, + dtype, + device, + ) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward( + self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True + ): + x = self.embeddings(input_tokens) + causal_mask = ( + torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) + .fill_(float("-inf")) + .triu_(1) + ) + x, i = self.encoder( + x, mask=causal_mask, intermediate_output=intermediate_output + ) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[ + torch.arange(x.shape[0], device=x.device), + input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), + ] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear( + embed_dim, embed_dim, bias=False, dtype=dtype, device=device + ) + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return 
(x[0], x[1], out, x[2]) + + +class SDTokenizer: + def __init__( + self, + max_length=77, + pad_with_end=True, + tokenizer=None, + has_start_token=True, + pad_to_max_length=True, + min_length=None, + ): + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer("")["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + def tokenize_with_weights(self, text: str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(" ") + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend( + [ + (t, 1) + for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1] + ] + ) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + return [batch] + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self): + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() + + def tokenize_with_weights(self, text: str | list[str]): + out = {} + if isinstance(text, list): + text = text[0] + out["g"] = self.clip_g.tokenize_with_weights(text) + out["l"] = self.clip_l.tokenize_with_weights(text) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) + for k, v in out.items(): + out[k] = torch.tensor(v, dtype=torch.int64, device="cpu") + return out + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + tokens = token_weight_pairs[:, :, 0] + out, pooled = self(tokens) + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDClipModel(torch.nn.Module): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + device="cpu", + max_length=77, + layer="last", + layer_idx=None, + textmodel_json_config=None, + dtype=None, + model_class=CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, + layer_norm_hidden_state=True, + return_projected_pooled=True, + ): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): 
+ param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = ( + self.layer, + self.layer_idx, + self.return_projected_pooled, + ) + + def encode_token_weights(self, token_weight_pairs): + pass + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get( + "projected_pooled", self.return_projected_pooled + ) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, token_weight_pairs): + # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + tokens = token_weight_pairs[:, :, 0] + # backup_embeds = self.transformer.get_input_embeddings() + # device = backup_embeds.weight.device + # tokens = torch.LongTensor(tokens).to(device) + outputs = self.transformer( + tokens, + intermediate_output=self.layer_idx, + final_layer_norm_intermediate=self.layer_norm_hidden_state, + ) + # self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if ( + not self.return_projected_pooled + and len(outputs) >= 4 + and outputs[3] is not None + ): + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + out, pooled = z.float(), pooled_output + if pooled is not None: + first_pooled = pooled[0:1].cpu() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cpu(), first_pooled + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + + def __init__( + self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None + ): + if layer == "penultimate": + layer = "hidden" + layer_idx = -2 + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 0}, + layer_norm_hidden_state=False, + ) From 0dff3e5351a0dd763676779f581d4b6bdd52c982 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 2 Dec 2024 11:39:36 -0600 Subject: [PATCH 04/20] Add README --- sharktank/sharktank/dynamo_exports/flux/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 sharktank/sharktank/dynamo_exports/flux/README.md diff --git a/sharktank/sharktank/dynamo_exports/flux/README.md b/sharktank/sharktank/dynamo_exports/flux/README.md new file mode 100644 index 000000000..db989c0e8 --- /dev/null +++ b/sharktank/sharktank/dynamo_exports/flux/README.md @@ -0,0 +1,8 @@ +# Flux.1 dynamo exports + +### Quick Start + +All the exports in this directory are done through `export.py`, with the CLI syntax as follows: +```shell +python sharktank/sharktank/dynamo_exports/flux/export.py --model="flux-dev" --component= --precision= +``` From ec92d2fafcd13db1b9ff8211b761533198dfd42e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Mon, 2 Dec 2024 16:07:51 -0600 Subject: [PATCH 05/20] Remove commented code. 
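The dropped lines were already commented out; `FluxScheduler.prepare` remains a plain row
lookup into the schedule table precomputed in `__init__` (row i holds the (i+1)-value
schedule for i steps, zero-padded to length 100), with no runtime bounds check, so callers
keep `num_steps` within 1..99. A minimal usage sketch, assuming the import path follows the
file location:

```python
import torch
from sharktank.dynamo_exports.flux.scheduler import FluxScheduler  # assumed import path

# Builds the padded schedule table once, up front.
sched = FluxScheduler(max_length=512, torch_dtype=torch.float32, is_schnell=False)

num_steps = torch.tensor(28, dtype=torch.int64)  # caller keeps this within 1..99
timesteps = sched.prepare(num_steps)             # row lookup: sched.timesteps[28]
print(timesteps.shape)                           # torch.Size([100])
```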
--- sharktank/sharktank/dynamo_exports/flux/scheduler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sharktank/sharktank/dynamo_exports/flux/scheduler.py b/sharktank/sharktank/dynamo_exports/flux/scheduler.py index 397f7b815..c89b52971 100644 --- a/sharktank/sharktank/dynamo_exports/flux/scheduler.py +++ b/sharktank/sharktank/dynamo_exports/flux/scheduler.py @@ -46,8 +46,5 @@ def __init__(self, max_length, torch_dtype, is_schnell=False): self.timesteps = torch.stack(timesteps, dim=0).clone().detach() def prepare(self, num_steps): - # s = num_steps.item() - # torch._check(s >= 1) - # torch._check(s <= 100) timesteps = self.timesteps[num_steps] return timesteps From 58fa0f1932c239e6e0d7aae735b3b85cb35abae4 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 4 Dec 2024 12:26:44 -0600 Subject: [PATCH 06/20] Initial commit of flux server. --- sharktank/sharktank/dynamo_exports/flux/ae.py | 25 +- .../sharktank/dynamo_exports/flux/export.py | 7 +- sharktank/sharktank/dynamo_exports/flux/te.py | 2 +- shortfin/python/shortfin_apps/flux/README.md | 30 + .../python/shortfin_apps/flux/__init__.py | 7 + shortfin/python/shortfin_apps/flux/_deps.py | 22 + .../shortfin_apps/flux/components/builders.py | 307 ++++++++ .../flux/components/config_artifacts.py | 112 +++ .../flux/components/config_struct.py | 107 +++ .../shortfin_apps/flux/components/generate.py | 102 +++ .../flux/components/io_struct.py | 79 ++ .../shortfin_apps/flux/components/manager.py | 48 ++ .../shortfin_apps/flux/components/messages.py | 190 +++++ .../shortfin_apps/flux/components/metrics.py | 51 ++ .../shortfin_apps/flux/components/service.py | 734 ++++++++++++++++++ .../flux/components/tokenizer.py | 84 ++ .../flux/examples/flux_dev_config_mixed.json | 36 + .../flux/examples/flux_flags_gfx942.txt | 28 + .../flux/examples/flux_request_bs2.json | 18 + shortfin/python/shortfin_apps/flux/server.py | 424 ++++++++++ .../shortfin_apps/flux/simple_client.py | 246 ++++++ 21 files changed, 2647 insertions(+), 12 deletions(-) create mode 100644 shortfin/python/shortfin_apps/flux/README.md create mode 100644 shortfin/python/shortfin_apps/flux/__init__.py create mode 100644 shortfin/python/shortfin_apps/flux/_deps.py create mode 100644 shortfin/python/shortfin_apps/flux/components/builders.py create mode 100644 shortfin/python/shortfin_apps/flux/components/config_artifacts.py create mode 100644 shortfin/python/shortfin_apps/flux/components/config_struct.py create mode 100644 shortfin/python/shortfin_apps/flux/components/generate.py create mode 100644 shortfin/python/shortfin_apps/flux/components/io_struct.py create mode 100644 shortfin/python/shortfin_apps/flux/components/manager.py create mode 100644 shortfin/python/shortfin_apps/flux/components/messages.py create mode 100644 shortfin/python/shortfin_apps/flux/components/metrics.py create mode 100644 shortfin/python/shortfin_apps/flux/components/service.py create mode 100644 shortfin/python/shortfin_apps/flux/components/tokenizer.py create mode 100644 shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json create mode 100644 shortfin/python/shortfin_apps/flux/examples/flux_flags_gfx942.txt create mode 100644 shortfin/python/shortfin_apps/flux/examples/flux_request_bs2.json create mode 100644 shortfin/python/shortfin_apps/flux/server.py create mode 100644 shortfin/python/shortfin_apps/flux/simple_client.py diff --git a/sharktank/sharktank/dynamo_exports/flux/ae.py b/sharktank/sharktank/dynamo_exports/flux/ae.py index ad05cc16c..a0618d63c 100644 --- 
a/sharktank/sharktank/dynamo_exports/flux/ae.py +++ b/sharktank/sharktank/dynamo_exports/flux/ae.py @@ -2,6 +2,7 @@ import torch from einops import rearrange from dataclasses import dataclass +import math # This Flux AE implementation is copied from https://github.com/black-forest-labs/flux. @@ -17,6 +18,8 @@ class AutoEncoderParams: z_channels: int scale_factor: float shift_factor: float + height: int + width: int def swish(x: Tensor) -> Tensor: @@ -325,14 +328,8 @@ def __init__(self, params: AutoEncoderParams): self.scale_factor = params.scale_factor self.shift_factor = params.shift_factor - - @property - def device(self) -> torch.device: - return next(self.parameters()).device - - @property - def dtype(self) -> torch.dtype: - return next(self.parameters()).dtype + self.height = params.height + self.width = params.width def encode(self, x: Tensor) -> Tensor: z = self.reg(self.encoder(x)) @@ -340,8 +337,16 @@ def encode(self, x: Tensor) -> Tensor: return z def decode(self, z: Tensor) -> Tensor: - z = z / self.scale_factor + self.shift_factor - return self.decoder(z) + d_in = rearrange( + z, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(self.height / 16), + w=math.ceil(self.width / 16), + ph=2, + pw=2, + ) + d_in = d_in / self.scale_factor + self.shift_factor + return self.decoder(d_in) def forward(self, x: Tensor) -> Tensor: return self.decode(self.encode(x)) diff --git a/sharktank/sharktank/dynamo_exports/flux/export.py b/sharktank/sharktank/dynamo_exports/flux/export.py index abdb93db4..7d82cab25 100644 --- a/sharktank/sharktank/dynamo_exports/flux/export.py +++ b/sharktank/sharktank/dynamo_exports/flux/export.py @@ -41,6 +41,8 @@ class ModelSpec: z_channels=16, scale_factor=0.3611, shift_factor=0.1159, + height=1024, + width=1024, ), ), "flux-schnell": ModelSpec( @@ -55,6 +57,8 @@ class ModelSpec: z_channels=16, scale_factor=0.3611, shift_factor=0.1159, + height=1024, + width=1024, ), ), } @@ -126,10 +130,11 @@ def get_te_model_and_inputs( def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width): dtype = torch_dtypes[precision] aeparams = fluxconfigs[hf_model_name].ae_params + aeparams.height = height + aeparams.width = width ae = AutoEncoder(params=aeparams).to(dtype) latents_shape = ( batch_size, - 16, int(height * width / 256), 64, ) diff --git a/sharktank/sharktank/dynamo_exports/flux/te.py b/sharktank/sharktank/dynamo_exports/flux/te.py index bfb55b1ac..59d728d2c 100644 --- a/sharktank/sharktank/dynamo_exports/flux/te.py +++ b/sharktank/sharktank/dynamo_exports/flux/te.py @@ -47,7 +47,7 @@ def __init__( load_into(f, self.clip.transformer, "", "cpu", self.dtype) def forward(self, clip_ids): - vec = self.clip(clip_ids)[1] + vec = self.clip(clip_ids) return vec diff --git a/shortfin/python/shortfin_apps/flux/README.md b/shortfin/python/shortfin_apps/flux/README.md new file mode 100644 index 000000000..a0ff7b809 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/README.md @@ -0,0 +1,30 @@ +# Flux.1 Server and CLI + +This directory contains a [Flux](https://blackforestlabs.ai/#get-flux) inference server, CLI and support components. More information about FLUX.1 on [huggingface](https://huggingface.co/black-forest-labs/FLUX.1-dev). + +## Install + +For [nightly releases](../../../../docs/nightly_releases.md) +For our [stable release](../../../../docs/user_guide.md) + +## Start Flux Server +The server will prepare runtime artifacts for you. + +By default, the port is set to 8000. 
If you would like to change this, use `--port` in each of the following commands. + +You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`. + +``` +python -m shortfin_apps.flux.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single" +``` + - Wait until your server outputs: +``` +INFO - Application startup complete. +INFO - Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` +## Run the Flux Client + + - Run a CLI client in a separate shell: +``` +python -m shortfin_apps.flux.simple_client --interactive +``` diff --git a/shortfin/python/shortfin_apps/flux/__init__.py b/shortfin/python/shortfin_apps/flux/__init__.py new file mode 100644 index 000000000..4a168079c --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from . import _deps diff --git a/shortfin/python/shortfin_apps/flux/_deps.py b/shortfin/python/shortfin_apps/flux/_deps.py new file mode 100644 index 000000000..92bd089ec --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/_deps.py @@ -0,0 +1,22 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from shortfin.support.deps import ShortfinDepNotFoundError + +try: + import transformers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "transformers") from e + +try: + import tokenizers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "tokenizers") from e + +try: + import dataclasses_json +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e diff --git a/shortfin/python/shortfin_apps/flux/components/builders.py b/shortfin/python/shortfin_apps/flux/components/builders.py new file mode 100644 index 000000000..7a2c2ce60 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/builders.py @@ -0,0 +1,307 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * +from iree.build.executor import FileNamespace, BuildAction, BuildContext, BuildFile +import itertools +import os +import urllib +import shortfin.array as sfnp +import copy + +from shortfin_apps.flux.components.config_struct import ModelParams + +this_dir = os.path.dirname(os.path.abspath(__file__)) +parent = os.path.dirname(this_dir) +default_config_json = os.path.join(parent, "examples", "flux_dev_config_mixed.json") + +dtype_to_filetag = { + "bfloat16": "bf16", + "float32": "fp32", + "float16": "fp16", + sfnp.float32: "fp32", + sfnp.bfloat16: "bf16", +} + +ARTIFACT_VERSION = "12032024" +SDXL_BUCKET = ( + f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/" +) +SDXL_WEIGHTS_BUCKET = ( + "https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/weights/" +) + + +def filter_by_model(filenames, model): + if not model: + return filenames + filtered = [] + for i in filenames: + if model == "t5xxl" and i == "google__t5_v1_1_xxl_encoder_fp32.irpa": + filtered.extend([i]) + if model.lower() in i.lower(): + filtered.extend([i]) + return filtered + + +def get_mlir_filenames(model_params: ModelParams, model=None): + mlir_filenames = [] + file_stems = get_file_stems(model_params) + for stem in file_stems: + mlir_filenames.extend([stem + ".mlir"]) + return filter_by_model(mlir_filenames, model) + + +def get_vmfb_filenames( + model_params: ModelParams, model=None, target: str = "amdgpu-gfx942" +): + vmfb_filenames = [] + file_stems = get_file_stems(model_params) + for stem in file_stems: + vmfb_filenames.extend([stem + "_" + target + ".vmfb"]) + return filter_by_model(vmfb_filenames, model) + + +def get_params_filenames(model_params: ModelParams, model=None, splat: bool = False): + params_filenames = [] + base = "flux_dev" if not model_params.is_schnell else model_params.base_model_name + modnames = ["clip", "sampler", "vae"] + mod_precs = [ + dtype_to_filetag[model_params.clip_dtype], + dtype_to_filetag[model_params.sampler_dtype], + dtype_to_filetag[model_params.vae_dtype], + ] + if splat == "True": + for idx, mod in enumerate(modnames): + params_filenames.extend( + ["_".join([mod, "splat", f"{mod_precs[idx]}.irpa"])] + ) + else: + for idx, mod in enumerate(modnames): + params_filenames.extend([base + "_" + mod + "_" + mod_precs[idx] + ".irpa"]) + + # this is a hack + params_filenames.extend(["google__t5_v1_1_xxl_encoder_fp32.irpa"]) + + return filter_by_model(params_filenames, model) + + +def get_file_stems(model_params: ModelParams): + file_stems = [] + base = ["flux_dev" if not model_params.is_schnell else model_params.base_model_name] + mod_names = { + "clip": "clip", + "t5xxl": "t5xxl", + "sampler": "sampler", + "vae": "vae", + } + for mod, modname in mod_names.items(): + ord_params = [ + base, + [modname], + ] + bsizes = [] + for bs in getattr(model_params, f"{mod}_batch_sizes", [1]): + bsizes.extend([f"bs{bs}"]) + ord_params.extend([bsizes]) + if mod in ["sampler"]: + ord_params.extend([[str(model_params.max_seq_len)]]) + elif mod == "clip": + ord_params.extend([[str(model_params.clip_max_seq_len)]]) + if mod in ["sampler", "vae"]: + dims = [] + for dim_pair in model_params.dims: + dim_pair_str = [str(d) for d in dim_pair] + dims.extend(["x".join(dim_pair_str)]) + ord_params.extend([dims]) + + dtype_str = dtype_to_filetag[ + getattr(model_params, f"{mod}_dtype", sfnp.float32) + ] + ord_params.extend([[dtype_str]]) + for x in list(itertools.product(*ord_params)): + 
file_stems.extend(["_".join(x)]) + return file_stems + + +def get_url_map(filenames: list[str], bucket: str): + file_map = {} + for filename in filenames: + file_map[filename] = f"{bucket}{filename}" + return file_map + + +def needs_update(ctx): + stamp = ctx.allocate_file("version.txt") + stamp_path = stamp.get_fs_path() + if os.path.exists(stamp_path): + with open(stamp_path, "r") as s: + ver = s.read() + if ver != ARTIFACT_VERSION: + return True + else: + with open(stamp_path, "w") as s: + s.write(ARTIFACT_VERSION) + return True + return False + + +def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + needed = True + if os.path.exists(out_file): + if url: + needed = False + # needed = not is_valid_size(out_file, url) + if not needed: + return False + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + return True + + +def needs_compile(filename, target, ctx): + vmfb_name = f"{filename}_{target}.vmfb" + namespace = FileNamespace.BIN + return needs_file(vmfb_name, ctx, namespace=namespace) + + +def get_cached_vmfb(filename, target, ctx): + vmfb_name = f"{filename}_{target}.vmfb" + return ctx.file(vmfb_name) + + +def is_valid_size(file_path, url): + if not url: + return True + with urllib.request.urlopen(url) as response: + content_length = response.getheader("Content-Length") + local_size = get_file_size(str(file_path)) + if content_length: + content_length = int(content_length) + if content_length != local_size: + return False + return True + + +def get_file_size(file_path): + """Gets the size of a local file in bytes as an integer.""" + + file_stats = os.stat(file_path) + return file_stats.st_size + + +def fetch_http_check_size(*, name: str, url: str) -> BuildFile: + context = BuildContext.current() + output_file = context.allocate_file(name) + action = FetchHttpWithCheckAction( + url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor + ) + output_file.deps.add(action) + return output_file + + +class FetchHttpWithCheckAction(BuildAction): + def __init__(self, url: str, output_file: BuildFile, **kwargs): + super().__init__(**kwargs) + self.url = url + self.output_file = output_file + + def _invoke(self, retries=4): + path = self.output_file.get_fs_path() + self.executor.write_status(f"Fetching URL: {self.url} -> {path}") + try: + urllib.request.urlretrieve(self.url, str(path)) + except urllib.error.HTTPError as e: + if retries > 0: + retries -= 1 + self._invoke(retries=retries) + else: + raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None + local_size = get_file_size(str(path)) + try: + with urllib.request.urlopen(self.url) as response: + content_length = response.getheader("Content-Length") + if content_length: + content_length = int(content_length) + if content_length != local_size: + raise IOError( + f"Size of downloaded artifact does not match content-length header! 
{content_length} != {local_size}" + ) + except IOError: + if retries > 0: + retries -= 1 + self._invoke(retries=retries) + + +@entrypoint(description="Retreives a set of SDXL submodels.") +def flux( + model_json=cl_arg( + "model-json", + default=default_config_json, + help="Local config filepath", + ), + target=cl_arg( + "target", + default="gfx942", + help="IREE target architecture.", + ), + splat=cl_arg( + "splat", default=False, type=str, help="Download empty weights (for testing)" + ), + build_preference=cl_arg( + "build-preference", + default="precompiled", + help="Sets preference for artifact generation method: [compile, precompiled]", + ), + model=cl_arg("model", type=str, help="Submodel to fetch/compile for."), +): + model_params = ModelParams.load_json(model_json) + ctx = executor.BuildContext.current() + update = needs_update(ctx) + + mlir_bucket = SDXL_BUCKET + "mlir/" + vmfb_bucket = SDXL_BUCKET + "vmfbs/" + if "gfx" in target: + target = "amdgpu-" + target + + mlir_filenames = get_mlir_filenames(model_params, model) + mlir_urls = get_url_map(mlir_filenames, mlir_bucket) + for f, url in mlir_urls.items(): + if update or needs_file(f, ctx, url): + fetch_http(name=f, url=url) + + vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target) + vmfb_urls = get_url_map(vmfb_filenames, vmfb_bucket) + if build_preference == "compile": + for idx, f in enumerate(copy.deepcopy(vmfb_filenames)): + # We return .vmfb file stems for the compile builder. + file_stem = "_".join(f.split("_")[:-1]) + if needs_compile(file_stem, target, ctx): + for mlirname in mlir_filenames: + if file_stem in mlirname: + mlir_source = mlirname + break + obj = compile(name=file_stem, source=mlir_source) + vmfb_filenames[idx] = obj[0] + else: + vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx) + else: + for f, url in vmfb_urls.items(): + if update or needs_file(f, ctx, url): + fetch_http(name=f, url=url) + + params_filenames = get_params_filenames(model_params, model=model, splat=splat) + params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) + for f, url in params_urls.items(): + if needs_file(f, ctx, url): + fetch_http_check_size(name=f, url=url) + filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames] + return filenames + + +if __name__ == "__main__": + iree_build_main() diff --git a/shortfin/python/shortfin_apps/flux/components/config_artifacts.py b/shortfin/python/shortfin_apps/flux/components/config_artifacts.py new file mode 100644 index 000000000..8c779ebc6 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/config_artifacts.py @@ -0,0 +1,112 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * +from iree.build.executor import FileNamespace +import os + +ARTIFACT_VERSION = "12032024" +SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/configs/" + + +def get_url_map(filenames: list[str], bucket: str): + file_map = {} + for filename in filenames: + file_map[filename] = f"{bucket}{filename}" + return file_map + + +def needs_update(ctx): + stamp = ctx.allocate_file("version.txt") + stamp_path = stamp.get_fs_path() + if os.path.exists(stamp_path): + with open(stamp_path, "r") as s: + ver = s.read() + if ver != ARTIFACT_VERSION: + return True + else: + with open(stamp_path, "w") as s: + s.write(ARTIFACT_VERSION) + return True + return False + + +def needs_file(filename, ctx, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + if os.path.exists(out_file): + needed = False + else: + # name_path = "bin" if namespace == FileNamespace.BIN else "" + # if name_path: + # filename = os.path.join(name_path, filename) + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + needed = True + return needed + + +@entrypoint(description="Retreives a set of SDXL configuration files.") +def sdxlconfig( + target=cl_arg( + "target", + default="gfx942", + help="IREE target architecture.", + ), + model=cl_arg("model", type=str, default="sdxl", help="Model architecture"), + topology=cl_arg( + "topology", + type=str, + default=None, + help="System topology configfile keyword", + ), +): + ctx = executor.BuildContext.current() + update = needs_update(ctx) + + # model_config_filenames = [f"{model}_config_i8.json"] + # model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET) + # for f, url in model_config_urls.items(): + # if update or needs_file(f, ctx): + # fetch_http(name=f, url=url) + + if topology: + topology_config_filenames = [f"topology_config_{topology}.txt"] + topology_config_urls = get_url_map( + topology_config_filenames, SDXL_CONFIG_BUCKET + ) + for f, url in topology_config_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + # flagfile_filenames = [f"{model}_flagfile_{target}.txt"] + # flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET) + # for f, url in flagfile_urls.items(): + # if update or needs_file(f, ctx): + # fetch_http(name=f, url=url) + + tuning_filenames = ( + [f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else [] + ) + tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) + for f, url in tuning_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + filenames = [ + # *model_config_filenames, + # *flagfile_filenames, + *tuning_filenames, + ] + if topology: + filenames.extend( + [ + *topology_config_filenames, + ] + ) + return filenames + + +if __name__ == "__main__": + iree_build_main() diff --git a/shortfin/python/shortfin_apps/flux/components/config_struct.py b/shortfin/python/shortfin_apps/flux/components/config_struct.py new file mode 100644 index 000000000..29bc651c0 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/config_struct.py @@ -0,0 +1,107 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Configuration objects. 
+ +Parameters that are intrinsic to a specific model. + +Typically represented in something like a Huggingface config.json, +we extend the configuration to enumerate inference boundaries of some given set of compiled modules. +""" + +from dataclasses import dataclass +from pathlib import Path + +from dataclasses_json import dataclass_json, Undefined + +import shortfin.array as sfnp + +str_to_dtype = { + "int8": sfnp.int8, + "float16": sfnp.float16, + "bfloat16": sfnp.bfloat16, + "float32": sfnp.float32, +} + + +@dataclass_json(undefined=Undefined.RAISE) +@dataclass +class ModelParams: + """Parameters for a specific set of compiled SD submodels, sufficient to do batching / + invocations.""" + + # Batch sizes that each stage is compiled for. These are expected to be + # functions exported from the model with suffixes of "_bs{batch_size}". Must + # be in ascending order. + clip_batch_sizes: list[int] + + t5xxl_batch_sizes: list[int] + + sampler_batch_sizes: list[int] + + vae_batch_sizes: list[int] + + # Height and Width, respectively, for which sampler and VAE are compiled. e.g. [[512, 512], [1024, 1024]] + dims: list[list[int]] + + base_model_name: str = "flux_dev" + clip_max_seq_len: int = 77 + clip_module_name: str = "compiled_flux_text_encoder" + clip_fn_name: str = "encode_prompts" + clip_dtype: sfnp.DType = sfnp.bfloat16 + + max_seq_len: int = 512 + t5xxl_module_name: str = "module" + t5xxl_fn_name: str = "forward_bs4" + t5xxl_dtype: sfnp.DType = sfnp.bfloat16 + + # Channel dim of latents. + num_latents_channels: int = 16 + + sampler_module_name: str = "" + sampler_fn_name: str = "main_graph" + sampler_dtype: sfnp.DType = sfnp.float32 + + vae_module_name: str = "compiled_vae" + vae_fn_name: str = "decode" + vae_dtype: sfnp.DType = sfnp.float32 + + # Whether model is "schnell" (fast) or not. This is roughly equivalent to "turbo" from SDXL. + # It cuts batch dims in half for sampling/encoding and removes negative prompt functionality. + is_schnell: bool = False + + # ABI of the module. + module_abi_version: int = 1 + + @property + def max_clip_batch_size(self) -> int: + return self.clip_batch_sizes[-1] + + @property + def max_sampler_batch_size(self) -> int: + return self.sampler_batch_sizes[-1] + + @property + def max_vae_batch_size(self) -> int: + return self.vae_batch_sizes[-1] + + @property + def all_batch_sizes(self) -> list: + return [self.clip_batch_sizes, self.sampler_batch_sizes, self.vae_batch_sizes] + + @property + def max_batch_size(self): + return max(self.all_batch_sizes) + + @staticmethod + def load_json(path: Path | str): + with open(path, "rt") as f: + json_text = f.read() + raw_params = ModelParams.from_json(json_text) + for i in ["sampler_dtype", "t5xxl_dtype", "clip_dtype", "vae_dtype"]: + if isinstance(i, str): + setattr(raw_params, i, str_to_dtype[getattr(raw_params, i)]) + return raw_params diff --git a/shortfin/python/shortfin_apps/flux/components/generate.py b/shortfin/python/shortfin_apps/flux/components/generate.py new file mode 100644 index 000000000..1c3560a5d --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/generate.py @@ -0,0 +1,102 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +import json + +import shortfin as sf + +# TODO: Have a generic "Responder" interface vs just the concrete impl. 
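+# The request JSON handled by these processes maps onto io_struct.GenerateReqInput.
+# A minimal body (illustrative sketch; field names come from that dataclass, values
+# are placeholders) looks like:
+#   {"prompt": ["a photo of a cat"], "height": [1024], "width": [1024],
+#    "steps": [20], "guidance_scale": [3.5], "seed": [0], "output_type": ["base64"]}
+# The one-shot response assembled below is {"images": [...]}, one entry per image.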
+from shortfin.interop.fastapi import FastAPIResponder + +from .io_struct import GenerateReqInput +from .messages import InferenceExecRequest +from .service import GenerateService +from .metrics import measure + +logger = logging.getLogger("shortfin-flux.generate") + + +class GenerateImageProcess(sf.Process): + """Process instantiated for every image generation. + + This process breaks the sequence into individual inference and sampling + steps, submitting them to the batcher and marshaling final + results. + + Responsible for a single image. + """ + + def __init__( + self, + client: "ClientGenerateBatchProcess", + gen_req: GenerateReqInput, + index: int, + ): + super().__init__(fiber=client.fiber) + self.client = client + self.gen_req = gen_req + self.index = index + self.result_image = None + + async def run(self): + exec = InferenceExecRequest.from_batch(self.gen_req, self.index) + self.client.batcher.submit(exec) + await exec.done + self.result_image = exec.result_image + + +class ClientGenerateBatchProcess(sf.Process): + """Process instantiated for handling a batch from a client. + + This takes care of several responsibilities: + + * Tokenization + * Random Latents Generation + * Splitting the batch into GenerateImageProcesses + * Streaming responses + * Final responses + """ + + __slots__ = [ + "batcher", + "complete_infeed", + "gen_req", + "responder", + ] + + def __init__( + self, + service: GenerateService, + gen_req: GenerateReqInput, + responder: FastAPIResponder, + ): + super().__init__(fiber=service.fibers[0]) + self.gen_req = gen_req + self.responder = responder + self.batcher = service.batcher + self.complete_infeed = self.system.create_queue() + + async def run(self): + logger.debug("Started ClientBatchGenerateProcess: %r", self) + try: + # Launch all individual generate processes and wait for them to finish. + gen_processes = [] + for index in range(self.gen_req.num_output_images): + gen_process = GenerateImageProcess(self, self.gen_req, index) + gen_processes.append(gen_process) + gen_process.launch() + + await asyncio.gather(*gen_processes) + + # TODO: stream image outputs + logging.debug("Responding to one shot batch") + response_data = {"images": [p.result_image for p in gen_processes]} + json_str = json.dumps(response_data) + self.responder.send_response(json_str) + finally: + self.responder.ensure_response() diff --git a/shortfin/python/shortfin_apps/flux/components/io_struct.py b/shortfin/python/shortfin_apps/flux/components/io_struct.py new file mode 100644 index 000000000..73e77316f --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/io_struct.py @@ -0,0 +1,79 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, Optional, Union +from dataclasses import dataclass +import uuid + + +@dataclass +class GenerateReqInput: + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None + # The input negative prompt. It can be a single prompt or a batch of prompts. + neg_prompt: Optional[Union[List[str], str]] = None + # Output image dimensions per prompt. + height: Optional[Union[List[int], int]] = None + width: Optional[Union[List[int], int]] = None + # The number of inference steps; one int per prompt. 
+ steps: Optional[Union[List[int], int]] = None + # The classifier-free-guidance scale for denoising; one float per prompt. + guidance_scale: Optional[Union[List[float], float]] = None + # The seed for random latents generation; one int per prompt. + seed: Optional[Union[List[int], int]] = None + # Token ids: only used in place of prompt. + input_ids: Optional[Union[List[List[int]], List[int]]] = None + # Negative token ids: only used in place of negative prompt. + neg_input_ids: Optional[Union[List[List[int]], List[int]]] = None + # Output image format. Defaults to base64. One string ("PIL", "base64") + output_type: Optional[List[str]] = None + # The request id. + rid: Optional[Union[List[str], str]] = None + + def post_init(self): + if (self.prompt is None and self.input_ids is None) or ( + self.prompt is not None and self.input_ids is not None + ): + raise ValueError("Either text or input_ids should be provided.") + + if isinstance(self.prompt, str): + self.prompt = [str] + + self.num_output_images = ( + len(self.prompt) if self.prompt is not None else len(self.input_ids) + ) + + batchable_args = [ + self.prompt, + self.neg_prompt, + self.height, + self.width, + self.steps, + self.guidance_scale, + self.seed, + self.input_ids, + self.neg_input_ids, + ] + for arg in batchable_args: + if isinstance(arg, list): + if len(arg) != self.num_output_images and len(arg) != 1: + raise ValueError( + f"Batchable arguments should either be singular or as many as the full batch ({self.num_output_images})." + ) + if self.rid is None: + self.rid = [uuid.uuid4().hex for _ in range(self.num_output_images)] + else: + if not isinstance(self.rid, list): + raise ValueError("The rid should be a list.") + if self.output_type is None: + self.output_type = ["base64"] * self.num_output_images + # Temporary restrictions + heights = [self.height] if not isinstance(self.height, list) else self.height + widths = [self.width] if not isinstance(self.width, list) else self.width + if any(dim != 1024 for dim in [*heights, *widths]): + raise ValueError( + "Currently, only 1024x1024 output image size is supported." + ) diff --git a/shortfin/python/shortfin_apps/flux/components/manager.py b/shortfin/python/shortfin_apps/flux/components/manager.py new file mode 100644 index 000000000..37f090b3f --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/manager.py @@ -0,0 +1,48 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import threading + +import shortfin as sf +from shortfin.interop.support.device_setup import get_selected_devices + +logger = logging.getLogger("shortfin-flux.manager") + + +class SystemManager: + def __init__(self, device="local-task", device_ids=None, async_allocs=True): + if any(x in device for x in ["local-task", "cpu"]): + self.ls = sf.host.CPUSystemBuilder().create_system() + elif any(x in device for x in ["hip", "amdgpu"]): + sb = sf.SystemBuilder( + system_type="amdgpu", amdgpu_async_allocations=async_allocs + ) + if device_ids: + sb.visible_devices = sb.available_devices + sb.visible_devices = get_selected_devices(sb, device_ids) + self.ls = sb.create_system() + logger.info(f"Created local system with {self.ls.device_names} devices") + # TODO: Come up with an easier bootstrap thing than manually + # running a thread. 
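+ # The thread hosts run(), which drains the "command" queue until it is
+ # closed by shutdown(); start() is what actually launches the thread.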
+ self.t = threading.Thread(target=lambda: self.ls.run(self.run())) + self.command_queue = self.ls.create_queue("command") + self.command_writer = self.command_queue.writer() + + def start(self): + logger.info("Starting system manager") + self.t.start() + + def shutdown(self): + logger.info("Shutting down system manager") + self.command_queue.close() + self.ls.shutdown() + + async def run(self): + reader = self.command_queue.reader() + while command := await reader(): + ... + logger.info("System manager command processor stopped") diff --git a/shortfin/python/shortfin_apps/flux/components/messages.py b/shortfin/python/shortfin_apps/flux/components/messages.py new file mode 100644 index 000000000..8646da1f6 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/messages.py @@ -0,0 +1,190 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from enum import Enum + +import logging + +import shortfin as sf +import shortfin.array as sfnp + +from .io_struct import GenerateReqInput + +logger = logging.getLogger("shortfin-sd.messages") + + +class InferencePhase(Enum): + # Tokenize prompt, negative prompt and get latents, timesteps, time ids, guidance scale as device arrays + PREPARE = 1 + # Run CLIP to encode tokenized prompts into text embeddings + ENCODE = 2 + # Run UNet to denoise the random sample + DENOISE = 3 + # Run VAE to decode the denoised latents into an image. + DECODE = 4 + # Postprocess VAE outputs. + POSTPROCESS = 5 + + +class InferenceExecRequest(sf.Message): + """ + Generalized request passed for an individual phase of image generation. + + Used for individual image requests. Bundled as lists by the batcher for inference processes, + and inputs joined for programs with bs>1. + + Inference execution processes are responsible for writing their outputs directly to the appropriate attributes here. + """ + + def __init__( + self, + prompt: str | None = None, + neg_prompt: str | None = None, + height: int | None = None, + width: int | None = None, + steps: int | None = None, + guidance_scale: float | sfnp.device_array | None = None, + seed: int | None = None, + clip_input_ids: list[list[int]] | None = None, + t5xxl_input_ids: list[list[int]] | None = None, + sample: sfnp.device_array | None = None, + txt: sfnp.device_array | None = None, + vec: sfnp.device_array | None = None, + img_ids: sfnp.device_array | None = None, + txt_ids: sfnp.device_array | None = None, + timesteps: sfnp.device_array | None = None, + denoised_latents: sfnp.device_array | None = None, + image_array: sfnp.device_array | None = None, + ): + super().__init__() + self.print_debug = True + + self.phases = {} + self.phase = None + self.height = height + self.width = width + + # Phase inputs: + # Prep phase. + self.prompt = prompt + self.neg_prompt = neg_prompt + self.height = height + self.width = width + self.seed = seed + + # Encode phase. + # This is a list of sequenced positive and negative token ids and pooler token ids. + self.clip_input_ids = clip_input_ids + self.t5xxl_input_ids = t5xxl_input_ids + self.sample = sample + + # Denoise phase. + self.img = None + self.txt = txt + self.vec = vec + self.img_ids = img_ids + self.txt_ids = txt_ids + self.steps = steps + self.timesteps = timesteps + self.guidance_scale = guidance_scale + + # Decode phase. + self.denoised_latents = denoised_latents + + # Postprocess. 
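+ # Holds the planar image produced by the decode phase; the postprocess
+ # phase converts it into the base64-encoded result_image.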
+ self.image_array = image_array + + self.result_image = None + self.img_metadata = None + + self.done = sf.VoidFuture() + + # Response control. + # Move the result array to the host and sync to ensure data is + # available. + self.return_host_array: bool = True + + self.post_init() + + @staticmethod + def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest": + gen_inputs = [ + "prompt", + "neg_prompt", + "height", + "width", + "steps", + "guidance_scale", + "seed", + ] + rec_inputs = {} + for item in gen_inputs: + received = getattr(gen_req, item, None) + if isinstance(received, list): + if index >= (len(received)): + if len(received) == 1: + rec_input = received[0] + else: + logging.error( + "Inputs in request must be singular or as many as the list of prompts." + ) + else: + rec_input = received[index] + else: + rec_input = received + rec_inputs[item] = rec_input + return InferenceExecRequest(**rec_inputs) + + def post_init(self): + """Determines necessary inference phases and tags them with static program parameters.""" + for p in reversed(list(InferencePhase)): + required, metadata = self.check_phase(p) + p_data = {"required": required, "metadata": metadata} + self.phases[p] = p_data + if not required: + if p not in [ + InferencePhase.ENCODE, + InferencePhase.PREPARE, + ]: + break + self.phase = p + + def check_phase(self, phase: InferencePhase): + match phase: + case InferencePhase.POSTPROCESS: + return True, None + case InferencePhase.DECODE: + required = not self.image_array + meta = [self.width, self.height] + return required, meta + case InferencePhase.DENOISE: + required = not self.denoised_latents + meta = [self.width, self.height, self.steps] + return required, meta + case InferencePhase.ENCODE: + p_results = [ + self.txt, + self.vec, + ] + required = any([inp is None for inp in p_results]) + return required, None + case InferencePhase.PREPARE: + p_results = [self.sample, self.clip_input_ids, self.t5xxl_input_ids] + required = any([inp is None for inp in p_results]) + return required, None + + def reset(self, phase: InferencePhase): + """Resets all per request state in preparation for an subsequent execution.""" + self.phase = None + self.phases = None + self.done = sf.VoidFuture() + self.return_host_array = True + + +class StrobeMessage(sf.Message): + """Sent to strobe a queue with fake activity (generate a wakeup).""" + + ... diff --git a/shortfin/python/shortfin_apps/flux/components/metrics.py b/shortfin/python/shortfin_apps/flux/components/metrics.py new file mode 100644 index 000000000..62e855698 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/metrics.py @@ -0,0 +1,51 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import time +from typing import Any +import functools + +logger = logging.getLogger("shortfin-sd.metrics") + + +def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"): + assert callable(fn) or fn is None + + def _decorator(func): + @functools.wraps(func) + async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any: + start = time.time() + ret = await func(*args, **kwargs) + duration = time.time() - start + if type == "exec": + batch_size = len(getattr(args[0], "exec_requests", [])) + log_duration_str(duration, task=task, batch_size=batch_size) + if type == "throughput": + if isinstance(num_items, str): + items = getattr(args[0].gen_req, num_items) + else: + items = str(num_items) + log_throughput(duration, items, freq, label) + return ret + + return wrapped_fn_async + + return _decorator(fn) if callable(fn) else _decorator + + +def log_throughput(duration, num_items, freq, label) -> str: + sps = str(float(num_items) / duration) * freq + freq_str = "second" if freq == 1 else f"{freq} seconds" + logger.info(f"THROUGHPUT: {sps} {label} per {freq_str}") + + +def log_duration_str(duration: float, task, batch_size=0) -> str: + """Get human readable duration string from start time""" + if batch_size > 0: + task = f"{task} (batch size {batch_size})" + duration_str = f"{round(duration * 1e3)}ms" + logger.info(f"Completed {task} in {duration_str}") diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py new file mode 100644 index 000000000..6fef52939 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/service.py @@ -0,0 +1,734 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +import numpy as np +from tqdm.auto import tqdm +from pathlib import Path +from PIL import Image +import base64 + +import shortfin as sf +import shortfin.array as sfnp + +from .config_struct import ModelParams +from .manager import SystemManager +from .messages import InferenceExecRequest, InferencePhase, StrobeMessage +from .tokenizer import Tokenizer +from .metrics import measure + +logger = logging.getLogger("shortfin-flux.service") + +prog_isolations = { + "none": sf.ProgramIsolation.NONE, + "per_fiber": sf.ProgramIsolation.PER_FIBER, + "per_call": sf.ProgramIsolation.PER_CALL, +} + + +class GenerateService: + """Top level service interface for image generation.""" + + inference_programs: dict[str, sf.Program] + + inference_functions: dict[str, dict[str, sf.ProgramFunction]] + + def __init__( + self, + *, + name: str, + sysman: SystemManager, + clip_tokenizers: list[Tokenizer], + t5xxl_tokenizers: list[Tokenizer], + model_params: ModelParams, + fibers_per_device: int, + workers_per_device: int = 1, + prog_isolation: str = "per_fiber", + show_progress: bool = False, + trace_execution: bool = False, + ): + self.name = name + + # Application objects. 
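+ # Per-worker program and function tables are keyed by worker index and
+ # populated in start() once modules and parameters have been loaded.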
+ self.sysman = sysman + self.clip_tokenizers = clip_tokenizers + self.t5xxl_tokenizers = t5xxl_tokenizers + self.model_params = model_params + self.inference_parameters: dict[str, list[sf.BaseProgramParameters]] = {} + self.inference_modules: dict[str, sf.ProgramModule] = {} + self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {} + self.inference_programs: dict[int, dict[str, sf.Program]] = {} + self.trace_execution = trace_execution + self.show_progress = show_progress + + self.prog_isolation = prog_isolations[prog_isolation] + + self.workers_per_device = workers_per_device + self.fibers_per_device = fibers_per_device + if fibers_per_device % workers_per_device != 0: + raise ValueError( + "Currently, fibers_per_device must be divisible by workers_per_device" + ) + self.fibers_per_worker = int(fibers_per_device / workers_per_device) + + self.workers = [] + self.fibers = [] + self.idle_fibers = set() + # For each worker index we create one on each device, and add their fibers to the idle set. + # This roughly ensures that the first picked fibers are distributed across available devices. + for i in range(self.workers_per_device): + for idx, device in enumerate(self.sysman.ls.devices): + worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}") + self.workers.append(worker) + for idx, device in enumerate(self.sysman.ls.devices): + for i in range(self.fibers_per_device): + tgt_worker = self.workers[i % len(self.workers)] + fiber = sysman.ls.create_fiber(tgt_worker, devices=[device]) + self.fibers.append(fiber) + self.idle_fibers.add(fiber) + for idx in range(len(self.workers)): + self.inference_programs[idx] = {} + self.inference_functions[idx] = { + "clip": {}, + "t5xxl": {}, + "denoise": {}, + "decode": {}, + } + # Scope dependent objects. + self.batcher = BatcherProcess(self) + + def get_worker_index(self, fiber): + if fiber not in self.fibers: + raise ValueError("A worker was requested from a rogue fiber.") + fiber_idx = self.fibers.index(fiber) + worker_idx = int( + (fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker + ) + return worker_idx + + def load_inference_module(self, vmfb_path: Path, component: str = None): + if not self.inference_modules.get(component): + self.inference_modules[component] = [] + self.inference_modules[component].append( + sf.ProgramModule.load(self.sysman.ls, vmfb_path) + ) + + def load_inference_parameters( + self, + *paths: Path, + parameter_scope: str, + format: str = "", + component: str = None, + ): + p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope) + for path in paths: + logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path) + p.load(path, format=format) + if not self.inference_parameters.get(component): + self.inference_parameters[component] = [] + self.inference_parameters[component].append(p) + + def start(self): + # Initialize programs. 
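+ # Each component's parameter archives are exposed through a parameter
+ # provider module and combined with its loaded VMFB modules into one
+ # sf.Program per worker, bound to that worker's devices.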
+ for component in self.inference_modules: + logger.info(f"Loading component: {component}") + component_modules = [ + sf.ProgramModule.parameter_provider( + self.sysman.ls, *self.inference_parameters.get(component, []) + ), + *self.inference_modules[component], + ] + + for worker_idx, worker in enumerate(self.workers): + worker_devices = self.fibers[ + worker_idx * (self.fibers_per_worker) + ].raw_devices + logger.info( + f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}" + ) + self.inference_programs[worker_idx][component] = sf.Program( + modules=component_modules, + devices=worker_devices, + isolation=self.prog_isolation, + trace_execution=self.trace_execution, + ) + + for worker_idx, worker in enumerate(self.workers): + for bs in self.model_params.clip_batch_sizes: + self.inference_functions[worker_idx]["clip"][ + bs + ] = self.inference_programs[worker_idx]["clip"][ + f"{self.model_params.clip_module_name}.encode_prompts" + ] + for bs in self.model_params.t5xxl_batch_sizes: + self.inference_functions[worker_idx]["t5xxl"][ + bs + ] = self.inference_programs[worker_idx]["t5xxl"][ + f"{self.model_params.t5xxl_module_name}.forward_bs4" + ] + self.inference_functions[worker_idx]["denoise"] = {} + for bs in self.model_params.sampler_batch_sizes: + self.inference_functions[worker_idx]["denoise"][bs] = { + "sampler": self.inference_programs[worker_idx]["sampler"][ + f"{self.model_params.sampler_module_name}.{self.model_params.sampler_fn_name}" + ], + } + self.inference_functions[worker_idx]["decode"] = {} + for bs in self.model_params.vae_batch_sizes: + self.inference_functions[worker_idx]["decode"][ + bs + ] = self.inference_programs[worker_idx]["vae"][ + f"{self.model_params.vae_module_name}.decode" + ] + self.batcher.launch() + + def shutdown(self): + self.batcher.shutdown() + + def __repr__(self): + modules = [ + f" {key} : {value}" for key, value in self.inference_modules.items() + ] + params = [ + f" {key} : {value}" for key, value in self.inference_parameters.items() + ] + # For python 3.11 since we can't have \ in the f"" expression. + new_line = "\n" + return ( + f"ServiceManager(" + f"\n INFERENCE DEVICES : \n" + f" {self.sysman.ls.devices}\n" + f"\n MODEL PARAMS : \n" + f"{self.model_params}" + f"\n SERVICE PARAMS : \n" + f" fibers per device : {self.fibers_per_device}\n" + f" program isolation mode : {self.prog_isolation}\n" + f"\n INFERENCE MODULES : \n" + f"{new_line.join(modules)}\n" + f"\n INFERENCE PARAMETERS : \n" + f"{new_line.join(params)}\n" + f")" + ) + + +######################################################################################## +# Batcher +######################################################################################## + + +class BatcherProcess(sf.Process): + """The batcher is a persistent process responsible for flighting incoming work + into batches. 
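+
+ Requests are grouped by matching phase metadata; a background strober
+ periodically wakes the process so that partially filled batches are still
+ flighted after a short delay.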
+ """ + + STROBE_SHORT_DELAY = 0.5 + STROBE_LONG_DELAY = 1 + + def __init__(self, service: GenerateService): + super().__init__(fiber=service.fibers[0]) + self.service = service + self.batcher_infeed = self.system.create_queue() + self.pending_requests: set[InferenceExecRequest] = set() + self.strobe_enabled = True + self.strobes: int = 0 + self.ideal_batch_size: int = max(service.model_params.max_batch_size) + self.num_fibers = len(service.fibers) + + def shutdown(self): + self.batcher_infeed.close() + + def submit(self, request: StrobeMessage | InferenceExecRequest): + self.batcher_infeed.write_nodelay(request) + + async def _background_strober(self): + while not self.batcher_infeed.closed: + await asyncio.sleep( + BatcherProcess.STROBE_SHORT_DELAY + if len(self.pending_requests) > 0 + else BatcherProcess.STROBE_LONG_DELAY + ) + if self.strobe_enabled: + self.submit(StrobeMessage()) + + async def run(self): + strober_task = asyncio.create_task(self._background_strober()) + reader = self.batcher_infeed.reader() + while item := await reader(): + self.strobe_enabled = False + if isinstance(item, InferenceExecRequest): + self.pending_requests.add(item) + elif isinstance(item, StrobeMessage): + self.strobes += 1 + else: + logger.error("Illegal message received by batcher: %r", item) + + self.board_flights() + + self.strobe_enabled = True + await strober_task + + def board_flights(self): + waiting_count = len(self.pending_requests) + if waiting_count == 0: + return + if waiting_count < self.ideal_batch_size and self.strobes < 2: + logger.info("Waiting a bit longer to fill flight") + return + self.strobes = 0 + batches = self.sort_batches() + for batch in batches.values(): + # Assign the batch to the next idle fiber. + if len(self.service.idle_fibers) == 0: + return + fiber = self.service.idle_fibers.pop() + fiber_idx = self.service.fibers.index(fiber) + worker_idx = self.service.get_worker_index(fiber) + logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})") + self.board(batch["reqs"], fiber=fiber) + if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(fiber) + + def sort_batches(self): + """Files pending requests into sorted batches suitable for program invocations.""" + reqs = self.pending_requests + next_key = 0 + batches = {} + for req in reqs: + is_sorted = False + req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()] + + for idx_key, data in batches.items(): + if not isinstance(data, dict): + logger.error( + "Expected to find a dictionary containing a list of requests and their shared metadatas." 
+ ) + if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size: + # Batch is full + next_key = idx_key + 1 + continue + elif data["meta"] == req_metas: + batches[idx_key]["reqs"].extend([req]) + is_sorted = True + break + else: + next_key = idx_key + 1 + if not is_sorted: + batches[next_key] = { + "reqs": [req], + "meta": req_metas, + } + return batches + + def board(self, request_bundle, fiber): + pending = request_bundle + if len(pending) == 0: + return + exec_process = InferenceExecutorProcess(self.service, fiber) + for req in pending: + if len(exec_process.exec_requests) >= self.ideal_batch_size: + break + exec_process.exec_requests.append(req) + if exec_process.exec_requests: + for flighted_request in exec_process.exec_requests: + self.pending_requests.remove(flighted_request) + exec_process.launch() + + +######################################################################################## +# Inference Executors +######################################################################################## + + +class InferenceExecutorProcess(sf.Process): + """Executes a stable diffusion inference batch""" + + def __init__( + self, + service: GenerateService, + fiber, + ): + super().__init__(fiber=fiber) + self.service = service + self.worker_index = self.service.get_worker_index(fiber) + self.exec_requests: list[InferenceExecRequest] = [] + + @measure(type="exec", task="inference process") + async def run(self): + try: + phase = None + for req in self.exec_requests: + if phase: + if phase != req.phase: + logger.error("Executor process recieved disjoint batch.") + phase = req.phase + phases = self.exec_requests[0].phases + req_count = len(self.exec_requests) + device0 = self.fiber.device(0) + if phases[InferencePhase.PREPARE]["required"]: + await self._prepare(device=device0, requests=self.exec_requests) + if phases[InferencePhase.ENCODE]["required"]: + await self._clip(device=device0, requests=self.exec_requests) + await self._t5xxl(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DENOISE]["required"]: + await self._denoise(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DECODE]["required"]: + await self._decode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.POSTPROCESS]["required"]: + await self._postprocess(device=device0, requests=self.exec_requests) + await device0 + for i in range(req_count): + req = self.exec_requests[i] + breakpoint() + req.done.set_success() + if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(self.fiber) + + except Exception: + logger.exception("Fatal error in image generation") + # TODO: Cancel and set error correctly + for req in self.exec_requests: + req.done.set_success() + + async def _prepare(self, device, requests): + for request in requests: + # Tokenize prompts and negative prompts. We tokenize in bs1 for now and join later. 
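+ # Positive and negative ids are gathered per tokenizer and concatenated so
+ # the encode phase sees both halves needed for classifier-free guidance.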
+ clip_input_ids_list = [] + clip_neg_ids_list = [] + for tokenizer in self.service.clip_tokenizers: + input_ids = tokenizer.encode(request.prompt) + clip_input_ids_list.append(input_ids) + neg_ids = tokenizer.encode(request.neg_prompt) + clip_neg_ids_list.append(neg_ids) + clip_ids_list = [*clip_input_ids_list, *clip_neg_ids_list] + + request.clip_input_ids = clip_ids_list + + t5xxl_input_ids_list = [] + t5xxl_neg_ids_list = [] + for tokenizer in self.service.t5xxl_tokenizers: + input_ids = tokenizer.encode(request.prompt) + t5xxl_input_ids_list.append(input_ids) + neg_ids = tokenizer.encode(request.neg_prompt) + t5xxl_neg_ids_list.append(neg_ids) + t5xxl_ids_list = [*t5xxl_input_ids_list, *t5xxl_neg_ids_list] + + request.t5xxl_input_ids = t5xxl_ids_list + + # Generate random sample latents. + seed = request.seed + channels = self.service.model_params.num_latents_channels + latents_shape = [ + 1, + (requests[0].height) * (requests[0].width) // 256, + 64, + ] + # latents_shape = ( + # 1, + # channels, + # request.height // 8, + # request.width // 8, + # ) + + # Create and populate sample device array. + generator = sfnp.RandomGenerator(seed) + request.sample = sfnp.device_array.for_device( + device, latents_shape, self.service.model_params.sampler_dtype + ) + + sample_host = request.sample.for_transfer() + with sample_host.map(discard=True) as m: + m.fill(bytes(1)) + + sfnp.fill_randn(sample_host, generator=generator) + + request.sample.copy_from(sample_host) + await device + return + + async def _clip(self, device, requests): + req_bs = len(requests) + entrypoints = self.service.inference_functions[self.worker_index]["clip"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._clip(device, [request]) + return + for bs, fn in entrypoints.items(): + if bs == req_bs: + break + + # Prepare tokenized input ids for CLIP inference + + clip_inputs = [ + sfnp.device_array.for_device( + device, + [req_bs, self.service.model_params.clip_max_seq_len, 2], + sfnp.sint64, + ), + ] + host_arrs = [None] + for idx, arr in enumerate(clip_inputs): + host_arrs[idx] = arr.for_transfer() + for i in range(req_bs): + with host_arrs[idx].view(i).map(write=True, discard=True) as m: + + num_ids = len(requests[i].clip_input_ids) + np_arr = requests[i].clip_input_ids[idx % (num_ids - 1)].input_ids + + m.fill(np_arr) + clip_inputs[idx].copy_from(host_arrs[idx]) + + # Encode tokenized inputs. 
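+ # The CLIP entrypoint's first output becomes each request's vec
+ # conditioning; a cfg_mult-sized view of the batched result is kept per request.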
+ logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), + ) + (vec, _) = await fn(*clip_inputs, fiber=self.fiber) + + await device + for i in range(req_bs): + cfg_mult = 2 + requests[i].vec = vec.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) + + return + + async def _t5xxl(self, device, requests): + req_bs = len(requests) + entrypoints = self.service.inference_functions[self.worker_index]["t5xxl"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._t5xxl(device, [request]) + return + for bs, fn in entrypoints.items(): + if bs == req_bs: + break + + # Prepare tokenized input ids for t5xxl inference + + t5xxl_inputs = [ + sfnp.device_array.for_device( + device, [4, self.service.model_params.max_seq_len], sfnp.sint64 + ), + ] + host_arrs = [None] + for idx, arr in enumerate(t5xxl_inputs): + host_arrs[idx] = arr.for_transfer() + for i in range(req_bs): + np_arr = requests[i].t5xxl_input_ids[idx].input_ids + for rep in range(4): + with host_arrs[idx].view(rep).map(write=True, discard=True) as m: + m.fill(np_arr) + t5xxl_inputs[idx].copy_from(host_arrs[idx]) + + # Encode tokenized inputs. + logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(t5xxl_inputs)]), + ) + await device + (txt,) = await fn(*t5xxl_inputs, fiber=self.fiber) + await device + for i in range(req_bs): + cfg_mult = 2 + requests[i].txt = txt.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) + + return + + async def _denoise(self, device, requests): + req_bs = len(requests) + step_count = requests[0].steps + cfg_mult = 2 if not self.service.model_params.is_schnell else 1 + # Produce denoised latents + entrypoints = self.service.inference_functions[self.worker_index]["denoise"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._denoise(device, [request]) + return + for bs, fns in entrypoints.items(): + if bs == req_bs: + break + + # Get shape of batched latents. + # This assumes all requests are dense at this point. + img_shape = [ + req_bs * cfg_mult, + (requests[0].height) * (requests[0].width) // 256, + 64, + ] + # Assume we are doing classifier-free guidance + txt_shape = [ + req_bs * cfg_mult, + self.service.model_params.max_seq_len, + 4096, + ] + vec_shape = [ + req_bs * cfg_mult, + 768, + ] + denoise_inputs = { + "img": sfnp.device_array.for_device( + device, img_shape, self.service.model_params.sampler_dtype + ), + "txt": sfnp.device_array.for_device( + device, txt_shape, self.service.model_params.sampler_dtype + ), + "vec": sfnp.device_array.for_device( + device, vec_shape, self.service.model_params.sampler_dtype + ), + "step": sfnp.device_array.for_device(device, [1], sfnp.int64), + "num_steps": sfnp.device_array.for_device(device, [1], sfnp.int64), + "guidance_scale": sfnp.device_array.for_device( + device, [req_bs], self.service.model_params.sampler_dtype + ), + } + # Send guidance scale to device. + gs_host = denoise_inputs["guidance_scale"].for_transfer() + sample_host = sfnp.device_array.for_host( + device, img_shape, self.service.model_params.sampler_dtype + ) + for i in range(req_bs): + cfg_dim = i * cfg_mult + with gs_host.view(i).map(write=True, discard=True) as m: + # TODO: do this without numpy + np_arr = np.asarray(requests[i].guidance_scale, dtype="float32") + + m.fill(np_arr) + + # Reshape and batch sample latent inputs on device. + # Currently we just generate random latents in the desired shape. Rework for img2img. 
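+ # Each request's bs1 sample is replicated cfg_mult times into its slice of
+ # the batched img tensor, alongside its txt and vec embeddings.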
+ req_samp = requests[i].sample + for rep in range(cfg_mult): + sample_host.view(slice(cfg_dim + rep, cfg_dim + rep + 1)).copy_from( + req_samp + ) + + denoise_inputs["img"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( + sample_host + ) + + # Batch t5xxl hidden states. + txt = requests[i].txt + denoise_inputs["txt"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( + txt + ) + + # Batch CLIP projections. + vec = requests[i].vec + denoise_inputs["vec"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( + vec + ) + + denoise_inputs["guidance_scale"].copy_from(gs_host) + + ns_host = denoise_inputs["num_steps"].for_transfer() + with ns_host.map(write=True) as m: + ns_host.items = [step_count] + + denoise_inputs["num_steps"].copy_from(ns_host) + + for i, t in tqdm( + enumerate(range(step_count)), + disable=(not self.service.show_progress), + desc=f"DENOISE (bs{req_bs})", + ): + s_host = denoise_inputs["step"].for_transfer() + with s_host.map(write=True) as m: + s_host.items = [i] + denoise_inputs["step"].copy_from(s_host) + + logger.debug( + "INVOKE %r", + fns["sampler"], + ) + (noise_pred,) = await fns["sampler"]( + *denoise_inputs.values(), fiber=self.fiber + ) + + denoise_inputs["img"].copy_from(noise_pred) + + for idx, req in enumerate(requests): + req.denoised_latents = sfnp.device_array.for_device( + device, img_shape, self.service.model_params.vae_dtype + ) + req.denoised_latents.copy_from(denoise_inputs["img"].view(idx * cfg_mult)) + return + + async def _decode(self, device, requests): + req_bs = len(requests) + # Decode latents to images + entrypoints = self.service.inference_functions[self.worker_index]["decode"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._decode(device, [request]) + return + for bs, fn in entrypoints.items(): + if bs == req_bs: + break + await device + latents_shape = [ + req_bs, + (requests[0].height) * (requests[0].width) // 256, + 64, + ] + latents = sfnp.device_array.for_device( + device, latents_shape, self.service.model_params.vae_dtype + ) + for i in range(req_bs): + latents.view(i).copy_from(requests[i].denoised_latents) + + await device + # Decode the denoised latents. 
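+ # The VAE entrypoint maps packed latents [bs, (h*w)/256, 64] to planar
+ # images [bs, 3, height, width].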
+ logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n 0: {latents.shape}"]), + ) + (image,) = await fn(latents, fiber=self.fiber) + + await device + images_shape = [ + req_bs, + 3, + requests[0].height, + requests[0].width, + ] + images_host = sfnp.device_array.for_host( + device, images_shape, self.service.model_params.vae_dtype + ) + images_host.copy_from(image) + await device + for idx, req in enumerate(requests): + req.image_array = images_host.view(idx) + return + + async def _postprocess(self, device, requests): + # Process output images + for req in requests: + image_shape = [ + 1, + 3, + req.height, + req.width, + ] + images_planar = sfnp.device_array.for_host( + device, image_shape, self.service.model_params.vae_dtype + ) + images_planar.copy_from(req.image_array) + for j in range(3): + data = [0.3 + j * 0.1 for _ in range(req.height * req.width)] + images_planar.view(0, j).items = data + permuted = sfnp.transpose(images_planar, (0, 2, 3, 1)) + breakpoint() + cast_image = sfnp.multiply(127.5, (sfnp.add(permuted, 1.0))) + image = sfnp.round(cast_image, dtype=sfnp.uint8) + + image_bytes = bytes(image.map(read=True)) + + image = base64.b64encode(image_bytes).decode("utf-8") + req.result_image = image + return diff --git a/shortfin/python/shortfin_apps/flux/components/tokenizer.py b/shortfin/python/shortfin_apps/flux/components/tokenizer.py new file mode 100644 index 000000000..54dec295b --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/tokenizer.py @@ -0,0 +1,84 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from transformers import CLIPTokenizer, T5Tokenizer, BatchEncoding + +import shortfin as sf +import shortfin.array as sfnp + + +class Tokenizer: + def __init__( + self, + raw_tk: CLIPTokenizer, + max_length: int = 77, + pad_id: int = 0, + attn_mask=False, + ): + self.pad_id = pad_id + self._raw = raw_tk + self.max_length = max_length + self.return_attention_mask = attn_mask + + @staticmethod + def from_pretrained(name: str, subfolder: str) -> "Tokenizer": + if subfolder == "tokenizer_2": + raw_tk = T5Tokenizer.from_pretrained(name, subfolder=subfolder) + max_length = 512 + else: + raw_tk = CLIPTokenizer.from_pretrained(name, subfolder=subfolder) + max_length = 77 + return Tokenizer(raw_tk, max_length=max_length) + + def encode(self, texts: list[str]): + """Encodes a batch of texts, applying no padding.""" + return self._raw( + texts, + padding="max_length", + max_length=self.max_length, + truncation=True, + return_tensors="np", + return_attention_mask=self.return_attention_mask, + ) + + def encoding_length(self, enc: BatchEncoding) -> int: + """Gets the length of an encoding.""" + return len(enc.input_ids) + + def encodings_to_array( + self, + device: sf.ScopedDevice, + encs: dict[str, BatchEncoding], + batch_seq_len: int, + *, + dtype: sfnp.DType = sfnp.int32, + ): + """Creates a device_array with the contents of a batch of encodings. + + It is expected that the user has called post_process_encodings with + the same batch_seq_len in order to properly truncate/pad. 
+ """ + ary = sfnp.device_array.for_host( + device, [len(encs.input_ids), batch_seq_len], dtype + ) + for i, ids in enumerate(encs.input_ids): + ary.view(i).items = ids + return ary + + def attention_masks_to_array( + self, + device: sf.ScopedDevice, + encs: list[BatchEncoding], + batch_seq_len: int, + *, + dtype: sfnp.DType = sfnp.int32, + ): + ary = sfnp.device_array.for_host( + device, [len(encs.attention_mask), batch_seq_len], dtype + ) + for i, enc in enumerate(encs.attention_mask): + ary.view(i).items = enc + return ary diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json new file mode 100644 index 000000000..cbe44680d --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json @@ -0,0 +1,36 @@ +{ + "base_model_name": "flux_dev", + "is_schnell": false, + "num_latents_channels": 16, + "max_seq_len": 512, + "clip_max_seq_len": 77, + "clip_batch_sizes": [ + 1 + ], + "clip_dtype": "bfloat16", + "clip_module_name": "compiled_flux_text_encoder", + "t5xxl_batch_sizes": [ + 1 + ], + "t5xxl_dtype": "float32", + "t5xxl_module_name": "module", + "t5xxl_fn_name": "forward_bs4", + "sampler_batch_sizes": [ + 1 + ], + "sampler_dtype": "float32", + "sampler_module_name": "module", + "sampler_fn_name": "main_graph", + "vae_batch_sizes": [ + 1 + ], + "vae_dtype": "float32", + "vae_module_name": "compiled_flux_auto_encoder", + "vae_fn_name": "decode", + "dims": [ + [ + 1024, + 1024 + ] + ] +} diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_flags_gfx942.txt b/shortfin/python/shortfin_apps/flux/examples/flux_flags_gfx942.txt new file mode 100644 index 000000000..2b882566e --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/examples/flux_flags_gfx942.txt @@ -0,0 +1,28 @@ +all +--iree-hal-target-backends=rocm +--iree-hip-target=gfx942 +--iree-execution-model=async-external +--iree-global-opt-propagate-transposes=1 +--iree-opt-const-eval=0 +--iree-opt-outer-dim-concat=1 +--iree-opt-aggressively-propagate-transposes=1 +--iree-dispatch-creation-enable-aggressive-fusion +--iree-codegen-llvmgpu-use-vector-distribution=1 +--iree-llvmgpu-enable-prefetch=1 +--iree-codegen-gpu-native-math-precision=1 +--iree-hip-legacy-sync=0 +--iree-opt-data-tiling=0 +--iree-vm-target-truncate-unsupported-floats +clip +--iree-hal-force-indirect-command-buffers +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +t5xxl +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +sampler +--iree-hal-force-indirect-command-buffers +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +--iree-dispatch-creation-enable-fuse-horizontal-contractions=1 +vae +--iree-hal-force-indirect-command-buffers +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), 
iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +--iree-dispatch-creation-enable-fuse-horizontal-contractions=1 diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_request_bs2.json b/shortfin/python/shortfin_apps/flux/examples/flux_request_bs2.json new file mode 100644 index 000000000..0ded22888 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/examples/flux_request_bs2.json @@ -0,0 +1,18 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution", + "height": 1024, + "width": 1024, + "steps": 20, + "guidance_scale": [ + 7.5, + 7.9 + ], + "seed": 0, + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/flux/server.py b/shortfin/python/shortfin_apps/flux/server.py new file mode 100644 index 000000000..e6eaef1d3 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/server.py @@ -0,0 +1,424 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any +import argparse +import logging +from pathlib import Path +import sys +import os +import copy +import subprocess +from contextlib import asynccontextmanager +import uvicorn + +# Import first as it does dep checking and reporting. +from shortfin.interop.fastapi import FastAPIResponder +from shortfin.support.logging_setup import native_handler + +from fastapi import FastAPI, Request, Response + +from .components.generate import ClientGenerateBatchProcess +from .components.config_struct import ModelParams +from .components.io_struct import GenerateReqInput +from .components.manager import SystemManager +from .components.service import GenerateService +from .components.tokenizer import Tokenizer + + +logger = logging.getLogger("shortfin-flux") +logger.addHandler(native_handler) +logger.propagate = False + +THIS_DIR = Path(__file__).resolve().parent + +UVICORN_LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "()": "uvicorn.logging.DefaultFormatter", + "format": "[{asctime}] {message}", + "datefmt": "%Y-%m-%d %H:%M:%S", + "style": "{", + "use_colors": True, + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "loggers": { + "uvicorn": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, +} + + +@asynccontextmanager +async def lifespan(app: FastAPI): + sysman.start() + try: + for service_name, service in services.items(): + logger.info("Initializing service '%s':", service_name) + logger.info(str(service)) + service.start() + except: + sysman.shutdown() + raise + yield + try: + for service_name, service in services.items(): + logger.info("Shutting down service '%s'", service_name) + service.shutdown() + finally: + sysman.shutdown() + + +sysman: SystemManager +services: dict[str, Any] = {} +app = FastAPI(lifespan=lifespan) + + +@app.get("/health") +async def health() -> Response: + return Response(status_code=200) + + +async def generate_request(gen_req: GenerateReqInput, request: Request): + 
service = services["sd"] + gen_req.post_init() + responder = FastAPIResponder(request) + ClientGenerateBatchProcess(service, gen_req, responder).launch() + return await responder.response + + +app.post("/generate")(generate_request) +app.put("/generate")(generate_request) + + +def configure_sys(args) -> SystemManager: + # Setup system (configure devices, etc). + model_config, topology_config, flagfile, tuning_spec, args = get_configs(args) + sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) + return sysman, model_config, flagfile, tuning_spec + + +def configure_service(args, sysman, model_config, flagfile, tuning_spec): + # Setup each service we are hosting. + clip_tokenizers = [ + Tokenizer.from_pretrained(args.tokenizer_source, subfolder="tokenizer") + ] + t5xxl_tokenizers = [ + Tokenizer.from_pretrained(args.tokenizer_source, subfolder="tokenizer_2") + ] + + model_params = ModelParams.load_json(model_config) + vmfbs, params = get_modules(args, model_config, flagfile, tuning_spec) + + sm = GenerateService( + name="sd", + sysman=sysman, + clip_tokenizers=clip_tokenizers, + t5xxl_tokenizers=t5xxl_tokenizers, + model_params=model_params, + fibers_per_device=args.fibers_per_device, + workers_per_device=args.workers_per_device, + prog_isolation=args.isolation, + show_progress=args.show_progress, + trace_execution=args.trace_execution, + ) + for key, vmfblist in vmfbs.items(): + for vmfb in vmfblist: + sm.load_inference_module(vmfb, component=key) + for key, datasets in params.items(): + sm.load_inference_parameters(*datasets, parameter_scope="model", component=key) + services[sm.name] = sm + return sysman + + +def get_configs(args): + # Returns one set of config artifacts. + modelname = "flux" + model_config = args.model_config if args.model_config else None + topology_config = None + tuning_spec = None + flagfile = args.flagfile if args.flagfile else None + cfg_builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "config_artifacts.py"), + f"--target={args.target}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + ] + if args.topology: + cfg_builder_args.extend( + [ + f"--topology={args.topology}", + ] + ) + outs = subprocess.check_output(cfg_builder_args).decode() + outs_paths = outs.splitlines() + for i in outs_paths: + if "flux_config" in i and not args.model_config: + model_config = i + elif "topology" in i and args.topology: + topology_config = i + elif "flagfile" in i and not args.flagfile: + flagfile = i + elif "attention_and_matmul_spec" in i and args.use_tuned: + tuning_spec = i + + if args.use_tuned and args.tuning_spec: + tuning_spec = os.path.abspath(args.tuning_spec) + + if topology_config: + with open(topology_config, "r") as f: + contents = [line.rstrip() for line in f] + for spec in contents: + if "--" in spec: + arglist = spec.strip("--").split("=") + arg = arglist[0] + if len(arglist) > 2: + value = arglist[1:] + for val in value: + try: + val = int(val) + except ValueError: + val = val + elif len(arglist) == 2: + value = arglist[-1] + try: + value = int(value) + except ValueError: + value = value + else: + # It's a boolean arg. + value = True + setattr(args, arg, value) + else: + # It's an env var. 
+ arglist = spec.split("=") + os.environ[arglist[0]] = arglist[1] + return model_config, topology_config, flagfile, tuning_spec, args + + +def get_modules(args, model_config, flagfile, td_spec): + # TODO: Move this out of server entrypoint + vmfbs = {"clip": [], "t5xxl": [], "sampler": [], "vae": []} + params = {"clip": [], "t5xxl": [], "sampler": [], "vae": []} + model_flags = copy.deepcopy(vmfbs) + model_flags["all"] = args.compile_flags + + if flagfile: + with open(flagfile, "r") as f: + contents = [line.rstrip() for line in f] + flagged_model = "all" + for elem in contents: + match = [keyw in elem for keyw in model_flags.keys()] + if any(match): + flagged_model = elem + else: + model_flags[flagged_model].extend([elem]) + if td_spec: + model_flags["sampler"].extend( + [f"--iree-codegen-transform-dialect-library={td_spec}"] + ) + + filenames = [] + for modelname in vmfbs.keys(): + ireec_args = model_flags["all"] + model_flags[modelname] + ireec_extra_args = " ".join(ireec_args) + builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "builders.py"), + f"--model-json={model_config}", + f"--target={args.target}", + f"--splat={args.splat}", + f"--build-preference={args.build_preference}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + f"--iree-hal-target-device={args.device}", + f"--iree-hip-target={args.target}", + f"--iree-compile-extra-args={ireec_extra_args}", + ] + logger.info(f"Preparing runtime artifacts for {modelname}...") + logger.info( + "COMMAND LINE EQUIVALENT: " + " ".join([str(argn) for argn in builder_args]) + ) + output = subprocess.check_output(builder_args).decode() + + output_paths = output.splitlines() + filenames.extend(output_paths) + for name in filenames: + for key in vmfbs.keys(): + if key == "t5xxl" and all(x in name.lower() for x in ["xxl", "irpa"]): + params[key].extend([name]) + if key in name.lower(): + if any(x in name for x in [".irpa", ".safetensors", ".gguf"]): + params[key].extend([name]) + elif "vmfb" in name: + vmfbs[key].extend([name]) + return vmfbs, params + + +def main(argv, log_config=UVICORN_LOG_CONFIG): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" + ) + parser.add_argument( + "--device", + type=str, + required=True, + choices=["local-task", "hip", "amdgpu"], + help="Primary inferencing device", + ) + parser.add_argument( + "--target", + type=str, + required=False, + default="gfx942", + choices=["gfx942", "gfx1100", "gfx90a"], + help="Primary inferencing device LLVM target arch.", + ) + parser.add_argument( + "--device_ids", + type=str, + nargs="*", + default=None, + help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0", + ) + parser.add_argument( + "--tokenizer_source", + type=Path, + default="black-forest-labs/FLUX.1-dev", + help="HF repo from which to load tokenizer(s).", + ) + parser.add_argument( + "--model_config", type=Path, help="Path to the model config file." 
+ ) + parser.add_argument( + "--workers_per_device", + type=int, + default=1, + help="Concurrency control -- how many fibers are created per device to run inference.", + ) + parser.add_argument( + "--fibers_per_device", + type=int, + default=1, + help="Concurrency control -- how many fibers are created per device to run inference.", + ) + parser.add_argument( + "--isolation", + type=str, + default="per_call", + choices=["per_fiber", "per_call", "none"], + help="Concurrency control -- How to isolate programs.", + ) + parser.add_argument( + "--show_progress", + action="store_true", + help="enable tqdm progress for sampler iterations.", + ) + parser.add_argument( + "--trace_execution", + action="store_true", + help="Enable tracing of program modules.", + ) + parser.add_argument( + "--amdgpu_async_allocations", + action="store_true", + help="Enable asynchronous allocations for amdgpu device contexts.", + ) + parser.add_argument( + "--splat", + action="store_true", + help="Use splat (empty) parameter files, usually for testing.", + ) + parser.add_argument( + "--build_preference", + type=str, + choices=["compile", "precompiled"], + default="precompiled", + help="Specify preference for builder artifact generation.", + ) + parser.add_argument( + "--compile_flags", + type=str, + nargs="*", + default=[], + help="extra compile flags for all compile actions. For fine-grained control, use flagfiles.", + ) + parser.add_argument( + "--flagfile", + type=Path, + help="Path to a flagfile to use for SDXL. If not specified, will use latest flagfile from azure.", + ) + parser.add_argument( + "--artifacts_dir", + type=Path, + default=None, + help="Path to local artifacts cache.", + ) + parser.add_argument( + "--tuning_spec", + type=str, + default=None, + help="Path to transform dialect spec if compiling an executable with tunings.", + ) + parser.add_argument( + "--topology", + type=str, + default=None, + choices=["spx_single", "cpx_single", "spx_multi", "cpx_multi"], + help="Use one of four known performant preconfigured device/fiber topologies.", + ) + parser.add_argument( + "--use_tuned", + type=int, + default=1, + help="Use tunings for attention and matmul ops. 0 to disable.", + ) + args = parser.parse_args(argv) + if not args.artifacts_dir: + home = Path.home() + artdir = home / ".cache" / "shark" + args.artifacts_dir = str(artdir) + else: + args.artifacts_dir = Path(args.artifacts_dir).resolve() + + global sysman + sysman, model_config, flagfile, tuning_spec = configure_sys(args) + configure_service(args, sysman, model_config, flagfile, tuning_spec) + uvicorn.run( + app, + host=args.host, + port=args.port, + log_config=log_config, + timeout_keep_alive=args.timeout_keep_alive, + ) + + +if __name__ == "__main__": + logging.root.setLevel(logging.INFO) + main( + sys.argv[1:], + # Make logging defer to the default shortfin logging config. + log_config=UVICORN_LOG_CONFIG, + ) diff --git a/shortfin/python/shortfin_apps/flux/simple_client.py b/shortfin/python/shortfin_apps/flux/simple_client.py new file mode 100644 index 000000000..42a3e02f3 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/simple_client.py @@ -0,0 +1,246 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from datetime import datetime as dt +import os +import sys +import time +import json +import argparse +import base64 +import asyncio +import aiohttp +import requests + +from PIL import Image + +sample_request = { + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + ], + "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], + "height": [1024], + "width": [1024], + "steps": [20], + "guidance_scale": [3], + "seed": [0], + "output_type": ["base64"], + "rid": ["string"], +} + + +def bytes_to_img(in_bytes, outputdir, idx=0, width=1024, height=1024): + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image = Image.frombytes( + mode="RGB", size=(width, height), data=base64.b64decode(in_bytes) + ) + if not os.path.isdir(outputdir): + os.mkdir(outputdir) + im_path = os.path.join(outputdir, f"shortfin_flux_output_{timestamp}_{idx}.png") + image.save(im_path) + print(f"Saved to {im_path}") + + +def get_batched(request, arg, idx): + if isinstance(request[arg], list): + # some args are broadcasted to each prompt, hence overriding idx for single-item entries + if len(request[arg]) == 1: + indexed = request[arg][0] + else: + indexed = request[arg][idx] + else: + indexed = request[arg] + return indexed + + +async def send_request(session, rep, args, data): + print("Sending request batch #", rep) + url = f"{args.host}:{args.port}/generate" + start = time.time() + async with session.post(url, json=data) as response: + end = time.time() + # Check if the response was successful + if response.status == 200: + response.raise_for_status() # Raise an error for bad responses + res_json = await response.json(content_type=None) + if args.save: + for idx, item in enumerate(res_json["images"]): + width = get_batched(data, "width", idx) + height = get_batched(data, "height", idx) + print("Saving response as image...") + bytes_to_img( + item.encode("utf-8"), args.outputdir, idx, width, height + ) + latency = end - start + print("Responses processed.") + return latency, len(data["prompt"]) + print(f"Error: Received {response.status} from server") + raise Exception + + +async def static(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if not args.file: + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + start = time.time() + + async for i in async_range(args.reps): + pending.append(asyncio.create_task(send_request(session, i, args, data))) + await asyncio.sleep(1) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + latency, num_samples = await task + latencies.append(latency) + sample_counts.append(num_samples) + end = time.time() + if not any(i is None for i in [latencies, sample_counts]): + total_num_samples = sum(sample_counts) + sps = str(total_num_samples / (end - start)) + # Until we have better measurements, don't report the throughput that includes saving images. 
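+ # This figure covers the full wall-clock window across all in-flight
+ # requests, including the staggered one-second send delays.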
+ if not args.save: + print(f"Average throughput: {sps} samples per second") + else: + raise ValueError("Received error response from server.") + + +async def interactive(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if not args.file: + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + while True: + prompt = await ainput("Enter a prompt: ") + data["prompt"] = [prompt] + data["steps"] = [args.steps] + print("Sending request with prompt: ", data["prompt"]) + + async for i in async_range(args.reps): + pending.append( + asyncio.create_task(send_request(session, i, args, data)) + ) + await asyncio.sleep( + 1 + ) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + _, _ = await task + pending = [] + if any(i is None for i in [latencies, sample_counts]): + raise ValueError("Received error response from server.") + + +async def ainput(prompt: str) -> str: + return await asyncio.to_thread(input, f"{prompt} ") + + +async def async_range(count): + for i in range(count): + yield i + await asyncio.sleep(0.0) + + +def check_health(url): + ready = False + print("Waiting for server.", end=None) + while not ready: + try: + if requests.get(f"{url}/health", timeout=20).status_code == 200: + print("Successfully connected to server.") + ready = True + return + time.sleep(2) + print(".", end=None) + except: + time.sleep(2) + print(".", end=None) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument( + "--file", + type=str, + default=None, + help="A non-default request to send to the server.", + ) + p.add_argument( + "--reps", + type=int, + default=1, + help="Number of times to duplicate each request in one second intervals.", + ) + p.add_argument( + "--save", + action=argparse.BooleanOptionalAction, + default=True, + help="Save images. To disable, use --no-save", + ) + p.add_argument( + "--outputdir", + type=str, + default="gen_imgs", + help="Directory to which images get saved.", + ) + p.add_argument( + "--host", type=str, default="http://0.0.0.0", help="Server host address." + ) + p.add_argument("--port", type=str, default="8000", help="Server port") + p.add_argument( + "--steps", + type=int, + default="20", + help="Number of inference steps. More steps usually means a better image. Interactive only.", + ) + p.add_argument( + "--interactive", + action="store_true", + help="Start as an example CLI client instead of sending static requests.", + ) + args = p.parse_args() + check_health(f"{args.host}:{args.port}") + if args.interactive: + asyncio.run(interactive(args)) + else: + asyncio.run(static(args)) + + +if __name__ == "__main__": + main() From c75addaa566a1ff9e6fa48089ce91a92ec4649be Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 4 Dec 2024 12:27:30 -0600 Subject: [PATCH 07/20] Remove breakpoints. 
--- shortfin/python/shortfin_apps/flux/components/service.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py index 6fef52939..6d723bbe6 100644 --- a/shortfin/python/shortfin_apps/flux/components/service.py +++ b/shortfin/python/shortfin_apps/flux/components/service.py @@ -382,7 +382,6 @@ async def run(self): await device0 for i in range(req_count): req = self.exec_requests[i] - breakpoint() req.done.set_success() if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER: self.service.idle_fibers.add(self.fiber) @@ -723,7 +722,6 @@ async def _postprocess(self, device, requests): data = [0.3 + j * 0.1 for _ in range(req.height * req.width)] images_planar.view(0, j).items = data permuted = sfnp.transpose(images_planar, (0, 2, 3, 1)) - breakpoint() cast_image = sfnp.multiply(127.5, (sfnp.add(permuted, 1.0))) image = sfnp.round(cast_image, dtype=sfnp.uint8) From a45383cf720568b8bed662bb34283cf851cc4f29 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 6 Dec 2024 15:11:39 -0600 Subject: [PATCH 08/20] Fixup model exports and service. --- sharktank/sharktank/dynamo_exports/flux/ae.py | 2 +- .../sharktank/dynamo_exports/flux/export.py | 62 +- sharktank/sharktank/dynamo_exports/flux/te.py | 535 +----------------- .../shortfin_apps/flux/components/service.py | 111 +++- .../flux/examples/flux_dev_config_mixed.json | 2 +- .../shortfin_apps/flux/simple_client.py | 6 +- 6 files changed, 152 insertions(+), 566 deletions(-) diff --git a/sharktank/sharktank/dynamo_exports/flux/ae.py b/sharktank/sharktank/dynamo_exports/flux/ae.py index a0618d63c..b9363090e 100644 --- a/sharktank/sharktank/dynamo_exports/flux/ae.py +++ b/sharktank/sharktank/dynamo_exports/flux/ae.py @@ -346,7 +346,7 @@ def decode(self, z: Tensor) -> Tensor: pw=2, ) d_in = d_in / self.scale_factor + self.shift_factor - return self.decoder(d_in) + return self.decoder(d_in).clamp(-1, 1) def forward(self, x: Tensor) -> Tensor: return self.decode(self.encode(x)) diff --git a/sharktank/sharktank/dynamo_exports/flux/export.py b/sharktank/sharktank/dynamo_exports/flux/export.py index 7d82cab25..7ecd7c089 100644 --- a/sharktank/sharktank/dynamo_exports/flux/export.py +++ b/sharktank/sharktank/dynamo_exports/flux/export.py @@ -7,6 +7,9 @@ import os import re from dataclasses import dataclass +import math + +from einops import rearrange from iree.compiler.ir import Context from iree.turbine.aot import * @@ -16,7 +19,9 @@ import torch from diffusers.models.transformers import FluxTransformer2DModel -from te import ClipTextEncoderModule +from diffusers.models.autoencoders import AutoencoderKL +from te import HFEmbedder +from transformers import CLIPTextModel from ae import AutoEncoder, AutoEncoderParams from scheduler import FluxScheduler from mmdit import get_flux_transformer_model @@ -107,17 +112,14 @@ def get_te_model_and_inputs( ): match component: case "clip": - # te = CLIPTextModel.from_pretrained( - # model_repo_map[hf_model_name], - # subfolder="text_encoder" - # ) - te = ClipTextEncoderModule( - model_repo_map[hf_model_name], torch_dtypes[precision] + te = HFEmbedder( + "openai/clip-vit-large-patch14", + max_length=77, + torch_dtype=torch.float32, ) clip_ids_shape = ( batch_size, 77, - 2, ) input_args = [ torch.ones(clip_ids_shape, dtype=torch.int64), @@ -127,12 +129,34 @@ def get_te_model_and_inputs( return None, None +class FluxAEWrapper(torch.nn.Module): + def __init__(self, height=1024, width=1024): + 
super().__init__() + self.ae = AutoencoderKL.from_pretrained( + "black-forest-labs/FLUX.1-dev", subfolder="vae" + ) + self.height = height + self.width = width + + def forward(self, z): + d_in = rearrange( + z, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(self.height / 16), + w=math.ceil(self.width / 16), + ph=2, + pw=2, + ) + d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor + return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1) + + def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width): dtype = torch_dtypes[precision] aeparams = fluxconfigs[hf_model_name].ae_params aeparams.height = height aeparams.width = width - ae = AutoEncoder(params=aeparams).to(dtype) + ae = FluxAEWrapper(height, width) latents_shape = ( batch_size, int(height * width / 256), @@ -252,14 +276,14 @@ class CompiledFluxTextEncoder(CompiledModule): fxb = FxProgramsBuilder(model) - @fxb.export_program( - args=(encode_inputs,), - ) - def _encode( - module, - inputs, - ): - return module.encode(*inputs) + # @fxb.export_program( + # args=(encode_inputs,), + # ) + # def _encode( + # module, + # inputs, + # ): + # return module.encode(*inputs) @fxb.export_program( args=(decode_inputs,), @@ -268,10 +292,10 @@ def _decode( module, inputs, ): - return module.decode(*inputs) + return module.forward(*inputs) class CompiledFluxAutoEncoder(CompiledModule): - encode = _encode + # encode = _encode decode = _decode if external_weights: diff --git a/sharktank/sharktank/dynamo_exports/flux/te.py b/sharktank/sharktank/dynamo_exports/flux/te.py index 59d728d2c..74d17a665 100644 --- a/sharktank/sharktank/dynamo_exports/flux/te.py +++ b/sharktank/sharktank/dynamo_exports/flux/te.py @@ -1,524 +1,29 @@ -### This file contains impls for underlying related models (CLIP, T5, etc) +from torch import Tensor, nn +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer -import torch, math -from torch import nn -from transformers import CLIPTokenizer, T5TokenizerFast -from transformers import T5EncoderModel -from iree.turbine import ops -from huggingface_hub import hf_hub_download -from safetensors import safe_open -from sharktank.layers import T5Config -from sharktank.models import t5 - -CLIP_CONFIG = { - "hidden_act": "quick_gelu", - "hidden_size": 768, - "intermediate_size": 3172, - "num_attention_heads": 12, - "num_hidden_layers": 12, -} - - -class ClipTextEncoderModule(torch.nn.Module): - @torch.no_grad() - def __init__( - self, - repo, - precision, - ): - super().__init__() - self.dtype = torch.float16 if precision == "fp16" else torch.float32 - self.clip = SDClipModel( - layer="hidden", - layer_idx=-2, - device="cpu", - dtype=self.dtype, - layer_norm_hidden_state=False, - return_projected_pooled=True, - textmodel_json_config=CLIP_CONFIG, - ) - if precision == "fp16": - self.clip = self.clip.half() - clip_weights = hf_hub_download( - repo_id=repo, - filename="text_encoder/model.safetensors", - ) - with safe_open(clip_weights, framework="pt", device="cpu") as f: - load_into(f, self.clip.transformer, "", "cpu", self.dtype) - - def forward(self, clip_ids): - vec = self.clip(clip_ids) - - return vec - - -################################################################################################# -### Core/Utility -################################################################################################# - - -def attention(q, k, v, heads, mask=None): - """Convenience wrapper around a basic attention operation""" - b, _, dim_head = q.shape - 
# ops.iree.trace_tensor("attention_q", q[0,0,:5]) - # ops.iree.trace_tensor("attention_k", k[0,0,:5]) - # ops.iree.trace_tensor("attention_v", v[0,0,:5]) - dim_head //= heads - q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) - out = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False - ) - # ops.iree.trace_tensor("attention_out", out[0,0,:5]) - return out.transpose(1, 2).reshape(b, -1, heads * dim_head) - - -class Mlp(nn.Module): - """MLP as used in Vision Transformer, MLP-Mixer and related networks""" - - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - bias=True, - dtype=None, - device=None, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - - self.fc1 = nn.Linear( - in_features, hidden_features, bias=bias, dtype=dtype, device=device - ) - self.act = act_layer - self.fc2 = nn.Linear( - hidden_features, out_features, bias=bias, dtype=dtype, device=device - ) - - def forward(self, x): - x = self.fc1(x) - # ops.iree.trace_tensor("mlpfx", x[0,0,:5]) - x = self.act(x) - # ops.iree.trace_tensor("mlpact", x[0,0,:5]) - x = self.fc2(x) - # ops.iree.trace_tensor("mlpanotherfc", x[0,0,:5]) - return x - - -def load_into(f, model, prefix, device, dtype=None): - """Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module.""" - for key in f.keys(): - if key.startswith(prefix) and not key.startswith("loss."): - path = key[len(prefix) :].split(".") - obj = model - for p in path: - if obj is list: - obj = obj[int(p)] - else: - obj = getattr(obj, p, None) - if obj is None: - print( - f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model" - ) - break - if obj is None: - continue - try: - tensor = f.get_tensor(key).to(device=device) - if dtype is not None: - tensor = tensor.to(dtype=dtype) - obj.requires_grad_(False) - obj.set_(tensor) - except Exception as e: - print(f"Failed to load key '{key}' in safetensors file: {e}") - raise e - - -################################################################################################# -### CLIP -################################################################################################# - - -class CLIPAttention(torch.nn.Module): - def __init__(self, embed_dim, heads, dtype, device): - super().__init__() - self.heads = heads - self.q_proj = nn.Linear( - embed_dim, embed_dim, bias=True, dtype=dtype, device=device - ) - self.k_proj = nn.Linear( - embed_dim, embed_dim, bias=True, dtype=dtype, device=device - ) - self.v_proj = nn.Linear( - embed_dim, embed_dim, bias=True, dtype=dtype, device=device - ) - self.out_proj = nn.Linear( - embed_dim, embed_dim, bias=True, dtype=dtype, device=device - ) - - def forward(self, x, mask=None): - q = self.q_proj(x) - k = self.k_proj(x) - v = self.v_proj(x) - out = attention(q, k, v, self.heads, mask) - return self.out_proj(out) - - -ACTIVATIONS = { - "quick_gelu": lambda a: a * torch.sigmoid(1.702 * a), - "gelu": torch.nn.functional.gelu, -} - - -class CLIPLayer(torch.nn.Module): - def __init__( - self, - embed_dim, - heads, - intermediate_size, - intermediate_activation, - dtype, - device, - ): - super().__init__() - self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) - self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - # self.mlp = 
CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) - self.mlp = Mlp( - embed_dim, - intermediate_size, - embed_dim, - act_layer=ACTIVATIONS[intermediate_activation], - dtype=dtype, - device=device, - ) - - def forward(self, x, mask=None): - x += self.self_attn(self.layer_norm1(x), mask) - x += self.mlp(self.layer_norm2(x)) - return x - - -class CLIPEncoder(torch.nn.Module): - def __init__( - self, - num_layers, - embed_dim, - heads, - intermediate_size, - intermediate_activation, - dtype, - device, - ): - super().__init__() - self.layers = torch.nn.ModuleList( - [ - CLIPLayer( - embed_dim, - heads, - intermediate_size, - intermediate_activation, - dtype, - device, - ) - for i in range(num_layers) - ] - ) - - def forward(self, x, mask=None, intermediate_output=None): - if intermediate_output is not None: - if intermediate_output < 0: - intermediate_output = len(self.layers) + intermediate_output - intermediate = None - for i, l in enumerate(self.layers): - x = l(x, mask) - if i == intermediate_output: - intermediate = x.clone() - return x, intermediate - - -class CLIPEmbeddings(torch.nn.Module): - def __init__( - self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None - ): +# Copied from https://github.com/black-forest-labs/flux +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): super().__init__() - self.token_embedding = torch.nn.Embedding( - vocab_size, embed_dim, dtype=dtype, device=device - ) - self.position_embedding = torch.nn.Embedding( - num_positions, embed_dim, dtype=dtype, device=device - ) - - def forward(self, input_tokens): - return self.token_embedding(input_tokens) + self.position_embedding.weight - - -class CLIPTextModel_(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - num_layers = config_dict["num_hidden_layers"] - embed_dim = config_dict["hidden_size"] - heads = config_dict["num_attention_heads"] - intermediate_size = config_dict["intermediate_size"] - intermediate_activation = config_dict["hidden_act"] - super().__init__() - self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) - self.encoder = CLIPEncoder( - num_layers, - embed_dim, - heads, - intermediate_size, - intermediate_activation, - dtype, - device, - ) - self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) - - def forward( - self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True - ): - x = self.embeddings(input_tokens) - causal_mask = ( - torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) - .fill_(float("-inf")) - .triu_(1) - ) - x, i = self.encoder( - x, mask=causal_mask, intermediate_output=intermediate_output - ) - x = self.final_layer_norm(x) - if i is not None and final_layer_norm_intermediate: - i = self.final_layer_norm(i) - pooled_output = x[ - torch.arange(x.shape[0], device=x.device), - input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), - ] - return x, i, pooled_output - - -class CLIPTextModel(torch.nn.Module): - def __init__(self, config_dict, dtype, device): - super().__init__() - self.num_layers = config_dict["num_hidden_layers"] - self.text_model = CLIPTextModel_(config_dict, dtype, device) - embed_dim = config_dict["hidden_size"] - self.text_projection = nn.Linear( - embed_dim, embed_dim, bias=False, dtype=dtype, device=device - ) - self.text_projection.weight.copy_(torch.eye(embed_dim)) - self.dtype = dtype - - def get_input_embeddings(self): - return 
self.text_model.embeddings.token_embedding - - def set_input_embeddings(self, embeddings): - self.text_model.embeddings.token_embedding = embeddings - - def forward(self, *args, **kwargs): - x = self.text_model(*args, **kwargs) - out = self.text_projection(x[2]) - return (x[0], x[1], out, x[2]) - - -class SDTokenizer: - def __init__( - self, - max_length=77, - pad_with_end=True, - tokenizer=None, - has_start_token=True, - pad_to_max_length=True, - min_length=None, - ): - self.tokenizer = tokenizer + self.is_clip = version.startswith("openai") self.max_length = max_length - self.min_length = min_length - empty = self.tokenizer("")["input_ids"] - if has_start_token: - self.tokens_start = 1 - self.start_token = empty[0] - self.end_token = empty[1] - else: - self.tokens_start = 0 - self.start_token = None - self.end_token = empty[0] - self.pad_with_end = pad_with_end - self.pad_to_max_length = pad_to_max_length - vocab = self.tokenizer.get_vocab() - self.inv_vocab = {v: k for k, v in vocab.items()} - self.max_word_length = 8 + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" - def tokenize_with_weights(self, text: str): - """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" - if self.pad_with_end: - pad_token = self.end_token - else: - pad_token = 0 - batch = [] - if self.start_token is not None: - batch.append((self.start_token, 1.0)) - to_tokenize = text.replace("\n", " ").split(" ") - to_tokenize = [x for x in to_tokenize if x != ""] - for word in to_tokenize: - batch.extend( - [ - (t, 1) - for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1] - ] + if self.is_clip: + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( + version, **hf_kwargs ) - batch.append((self.end_token, 1.0)) - if self.pad_to_max_length: - batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) - if self.min_length is not None and len(batch) < self.min_length: - batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) - return [batch] - - -class SDXLClipGTokenizer(SDTokenizer): - def __init__(self, tokenizer): - super().__init__(pad_with_end=False, tokenizer=tokenizer) - - -class SD3Tokenizer: - def __init__(self): - clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") - self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) - self.clip_g = SDXLClipGTokenizer(clip_tokenizer) - self.t5xxl = T5XXLTokenizer() - - def tokenize_with_weights(self, text: str | list[str]): - out = {} - if isinstance(text, list): - text = text[0] - out["g"] = self.clip_g.tokenize_with_weights(text) - out["l"] = self.clip_l.tokenize_with_weights(text) - out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) - for k, v in out.items(): - out[k] = torch.tensor(v, dtype=torch.int64, device="cpu") - return out - - -class ClipTokenWeightEncoder: - def encode_token_weights(self, token_weight_pairs): - # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - tokens = token_weight_pairs[:, :, 0] - out, pooled = self(tokens) - if pooled is not None: - first_pooled = pooled[0:1].cpu() - else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2).cpu(), first_pooled - - -class SDClipModel(torch.nn.Module): - """Uses the CLIP transformer encoder for text (from huggingface)""" - - LAYERS = ["last", "pooled", "hidden"] - - def __init__( - self, - device="cpu", - max_length=77, - layer="last", - 
layer_idx=None, - textmodel_json_config=None, - dtype=None, - model_class=CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, - layer_norm_hidden_state=True, - return_projected_pooled=True, - ): - super().__init__() - assert layer in self.LAYERS - self.transformer = model_class(textmodel_json_config, dtype, device) - self.num_layers = self.transformer.num_layers - self.max_length = max_length - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - self.layer = layer - self.layer_idx = None - self.special_tokens = special_tokens - self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.layer_norm_hidden_state = layer_norm_hidden_state - self.return_projected_pooled = return_projected_pooled - if layer == "hidden": - assert layer_idx is not None - assert abs(layer_idx) < self.num_layers - self.set_clip_options({"layer": layer_idx}) - self.options_default = ( - self.layer, - self.layer_idx, - self.return_projected_pooled, - ) - - def encode_token_weights(self, token_weight_pairs): - pass - - def set_clip_options(self, options): - layer_idx = options.get("layer", self.layer_idx) - self.return_projected_pooled = options.get( - "projected_pooled", self.return_projected_pooled - ) - if layer_idx is None or abs(layer_idx) > self.num_layers: - self.layer = "last" else: - self.layer = "hidden" - self.layer_idx = layer_idx - - def forward(self, token_weight_pairs): - # tokens = list(map(lambda a: a[0], token_weight_pairs[0])) - tokens = token_weight_pairs[:, :, 0] - # backup_embeds = self.transformer.get_input_embeddings() - # device = backup_embeds.weight.device - # tokens = torch.LongTensor(tokens).to(device) - outputs = self.transformer( - tokens, - intermediate_output=self.layer_idx, - final_layer_norm_intermediate=self.layer_norm_hidden_state, - ) - # self.transformer.set_input_embeddings(backup_embeds) - if self.layer == "last": - z = outputs[0] - else: - z = outputs[1] - pooled_output = None - if len(outputs) >= 3: - if ( - not self.return_projected_pooled - and len(outputs) >= 4 - and outputs[3] is not None - ): - pooled_output = outputs[3].float() - elif outputs[2] is not None: - pooled_output = outputs[2].float() - out, pooled = z.float(), pooled_output - if pooled is not None: - first_pooled = pooled[0:1].cpu() - else: - first_pooled = pooled - output = [out[0:1]] - return torch.cat(output, dim=-2).cpu(), first_pooled - + self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( + version, **hf_kwargs + ) -class SDXLClipG(SDClipModel): - """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + self.hf_module = self.hf_module.eval().requires_grad_(False) - def __init__( - self, config, device="cpu", layer="penultimate", layer_idx=None, dtype=None - ): - if layer == "penultimate": - layer = "hidden" - layer_idx = -2 - super().__init__( - device=device, - layer=layer, - layer_idx=layer_idx, - textmodel_json_config=config, - dtype=dtype, - special_tokens={"start": 49406, "end": 49407, "pad": 0}, - layer_norm_hidden_state=False, + def forward(self, input_ids) -> Tensor: + outputs = self.hf_module( + input_ids=input_ids, + attention_mask=None, + output_hidden_states=False, ) + return outputs[self.output_key] diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py index 6d723bbe6..5d1f58f03 100644 --- a/shortfin/python/shortfin_apps/flux/components/service.py +++ b/shortfin/python/shortfin_apps/flux/components/service.py 
@@ -6,9 +6,12 @@ import asyncio import logging +import math +import torch import numpy as np from tqdm.auto import tqdm from pathlib import Path +from typing import Callable from PIL import Image import base64 @@ -21,6 +24,8 @@ from .tokenizer import Tokenizer from .metrics import measure +from einops import rearrange + logger = logging.getLogger("shortfin-flux.service") prog_isolations = { @@ -30,6 +35,37 @@ } +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + class GenerateService: """Top level service interface for image generation.""" @@ -420,9 +456,10 @@ async def _prepare(self, device, requests): # Generate random sample latents. seed = request.seed channels = self.service.model_params.num_latents_channels + image_seq_len = (request.height) * (request.width) // 256 latents_shape = [ 1, - (requests[0].height) * (requests[0].width) // 256, + image_seq_len, 64, ] # latents_shape = ( @@ -446,6 +483,11 @@ async def _prepare(self, device, requests): request.sample.copy_from(sample_host) await device + request.timesteps = get_schedule( + request.steps, + image_seq_len, + shift=not self.service.model_params.is_schnell, + ) return async def _clip(self, device, requests): @@ -464,7 +506,7 @@ async def _clip(self, device, requests): clip_inputs = [ sfnp.device_array.for_device( device, - [req_bs, self.service.model_params.clip_max_seq_len, 2], + [req_bs, self.service.model_params.clip_max_seq_len], sfnp.sint64, ), ] @@ -486,12 +528,12 @@ async def _clip(self, device, requests): fn, "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), ) - (vec, _) = await fn(*clip_inputs, fiber=self.fiber) + (vec,) = await fn(*clip_inputs, fiber=self.fiber) await device for i in range(req_bs): cfg_mult = 2 - requests[i].vec = vec.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) + requests[i].vec = vec.view(slice(i, (i + 1))) return @@ -580,7 +622,9 @@ async def _denoise(self, device, requests): device, vec_shape, self.service.model_params.sampler_dtype ), "step": sfnp.device_array.for_device(device, [1], sfnp.int64), - "num_steps": sfnp.device_array.for_device(device, [1], sfnp.int64), + "timesteps": sfnp.device_array.for_device( + device, [100], self.service.model_params.sampler_dtype + ), "guidance_scale": sfnp.device_array.for_device( device, [req_bs], self.service.model_params.sampler_dtype ), @@ -618,17 +662,20 @@ async def _denoise(self, device, requests): # Batch CLIP projections. 
vec = requests[i].vec - denoise_inputs["vec"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( - vec - ) + for nc in range(2): + denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) denoise_inputs["guidance_scale"].copy_from(gs_host) - - ns_host = denoise_inputs["num_steps"].for_transfer() - with ns_host.map(write=True) as m: - ns_host.items = [step_count] - - denoise_inputs["num_steps"].copy_from(ns_host) + await device + ts_host = denoise_inputs["timesteps"].for_transfer() + with ts_host.map(write=True) as m: + m.fill(float(1)) + for tstep in range(len(requests[0].timesteps)): + with ts_host.view(tstep).map(write=True, discard=True) as m: + m.fill(np.asarray(requests[0].timesteps[tstep], dtype="float32")) + + denoise_inputs["timesteps"].copy_from(ts_host) + await device for i, t in tqdm( enumerate(range(step_count)), @@ -640,14 +687,25 @@ async def _denoise(self, device, requests): s_host.items = [i] denoise_inputs["step"].copy_from(s_host) - logger.debug( + logger.info( "INVOKE %r", fns["sampler"], ) + await device + # np_arrs = {} + # host_arrs = {} + # for key, value in denoise_inputs.items(): + # host_arrs[key] = denoise_inputs[key].for_transfer() + # host_arrs[key].copy_from(denoise_inputs[key]) + # await device + # np_arrs[key] = np.array(host_arrs[key]) + # for key, value in np_arrs.items(): + # np.save(f"{key}.npy", value) + (noise_pred,) = await fns["sampler"]( *denoise_inputs.values(), fiber=self.fiber ) - + await device denoise_inputs["img"].copy_from(noise_pred) for idx, req in enumerate(requests): @@ -671,7 +729,7 @@ async def _decode(self, device, requests): await device latents_shape = [ req_bs, - (requests[0].height) * (requests[0].width) // 256, + (requests[0].height * requests[0].width) // 256, 64, ] latents = sfnp.device_array.for_device( @@ -688,7 +746,6 @@ async def _decode(self, device, requests): "".join([f"\n 0: {latents.shape}"]), ) (image,) = await fn(latents, fiber=self.fiber) - await device images_shape = [ req_bs, @@ -709,23 +766,23 @@ async def _postprocess(self, device, requests): # Process output images for req in requests: image_shape = [ - 1, 3, req.height, req.width, ] + out_shape = [req.height, req.width, 3] images_planar = sfnp.device_array.for_host( device, image_shape, self.service.model_params.vae_dtype ) images_planar.copy_from(req.image_array) - for j in range(3): - data = [0.3 + j * 0.1 for _ in range(req.height * req.width)] - images_planar.view(0, j).items = data - permuted = sfnp.transpose(images_planar, (0, 2, 3, 1)) - cast_image = sfnp.multiply(127.5, (sfnp.add(permuted, 1.0))) - image = sfnp.round(cast_image, dtype=sfnp.uint8) - - image_bytes = bytes(image.map(read=True)) + permuted = sfnp.device_array.for_host( + device, out_shape, self.service.model_params.vae_dtype + ) + out = sfnp.device_array.for_host(device, out_shape, sfnp.uint8) + sfnp.transpose(images_planar, (1, 2, 0), out=permuted) + permuted = sfnp.multiply(127.5, (sfnp.add(permuted, 1.0))) + out = sfnp.round(permuted, dtype=sfnp.uint8) + image_bytes = bytes(out.map(read=True)) image = base64.b64encode(image_bytes).decode("utf-8") req.result_image = image diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json index cbe44680d..46a20af78 100644 --- a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json +++ b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json @@ -7,7 +7,7 @@ "clip_batch_sizes": [ 1 ], - "clip_dtype": "bfloat16", + 
"clip_dtype": "float32", "clip_module_name": "compiled_flux_text_encoder", "t5xxl_batch_sizes": [ 1 diff --git a/shortfin/python/shortfin_apps/flux/simple_client.py b/shortfin/python/shortfin_apps/flux/simple_client.py index 42a3e02f3..9382ddd9d 100644 --- a/shortfin/python/shortfin_apps/flux/simple_client.py +++ b/shortfin/python/shortfin_apps/flux/simple_client.py @@ -19,13 +19,13 @@ sample_request = { "prompt": [ - " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " A mountain with a halo cloud over it, Death Mountain, spooky, Zelda", ], "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], "height": [1024], "width": [1024], - "steps": [20], - "guidance_scale": [3], + "steps": [50], + "guidance_scale": [3.5], "seed": [0], "output_type": ["base64"], "rid": ["string"], From a7592c0ebba5153e6b04b42412845a0c4f2d3b42 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Fri, 3 Jan 2025 11:01:34 -0600 Subject: [PATCH 09/20] dtype flexibility for flux. --- .../sharktank/dynamo_exports/flux/export.py | 7 +- .../shortfin_apps/flux/components/service.py | 85 +++++++++++++++---- .../shortfin_apps/flux/simple_client.py | 4 +- 3 files changed, 75 insertions(+), 21 deletions(-) diff --git a/sharktank/sharktank/dynamo_exports/flux/export.py b/sharktank/sharktank/dynamo_exports/flux/export.py index 7ecd7c089..9002aae18 100644 --- a/sharktank/sharktank/dynamo_exports/flux/export.py +++ b/sharktank/sharktank/dynamo_exports/flux/export.py @@ -130,10 +130,11 @@ def get_te_model_and_inputs( class FluxAEWrapper(torch.nn.Module): - def __init__(self, height=1024, width=1024): + def __init__(self, height=1024, width=1024, precision="fp32"): super().__init__() + dtype = torch_dtypes[precision] self.ae = AutoencoderKL.from_pretrained( - "black-forest-labs/FLUX.1-dev", subfolder="vae" + "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtypes=dtype ) self.height = height self.width = width @@ -156,7 +157,7 @@ def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width) aeparams = fluxconfigs[hf_model_name].ae_params aeparams.height = height aeparams.width = width - ae = FluxAEWrapper(height, width) + ae = FluxAEWrapper(height, width, precision).to(dtype) latents_shape = ( batch_size, int(height * width / 256), diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py index 5d1f58f03..b758536d3 100644 --- a/shortfin/python/shortfin_apps/flux/components/service.py +++ b/shortfin/python/shortfin_apps/flux/components/service.py @@ -24,8 +24,6 @@ from .tokenizer import Tokenizer from .metrics import measure -from einops import rearrange - logger = logging.getLogger("shortfin-flux.service") prog_isolations = { @@ -475,13 +473,26 @@ async def _prepare(self, device, requests): device, latents_shape, self.service.model_params.sampler_dtype ) - sample_host = request.sample.for_transfer() + sample_host = sfnp.device_array.for_host( + device, latents_shape, sfnp.float32 + ) with sample_host.map(discard=True) as m: m.fill(bytes(1)) - sfnp.fill_randn(sample_host, generator=generator) + if self.service.model_params.sampler_dtype != sfnp.float32: + sample_transfer = request.sample.for_transfer() + sfnp.convert( + sample_host, + dtype=self.service.model_params.sampler_dtype, + out=sample_transfer, + ) + + request.sample.copy_from(sample_transfer) + # sample_debug = torch.frombuffer(sample_transfer.items, dtype=torch.bfloat16) + # print(sample_debug) + 
else: + request.sample.copy_from(sample_host) - request.sample.copy_from(sample_host) await device request.timesteps = get_schedule( request.steps, @@ -634,13 +645,11 @@ async def _denoise(self, device, requests): sample_host = sfnp.device_array.for_host( device, img_shape, self.service.model_params.sampler_dtype ) + guidance_float = sfnp.device_array.for_host(device, [req_bs], sfnp.float32) + for i in range(req_bs): + guidance_float.view(i).items = [requests[i].guidance_scale] cfg_dim = i * cfg_mult - with gs_host.view(i).map(write=True, discard=True) as m: - # TODO: do this without numpy - np_arr = np.asarray(requests[i].guidance_scale, dtype="float32") - - m.fill(np_arr) # Reshape and batch sample latent inputs on device. # Currently we just generate random latents in the desired shape. Rework for img2img. @@ -664,16 +673,24 @@ async def _denoise(self, device, requests): vec = requests[i].vec for nc in range(2): denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) - + sfnp.convert( + guidance_float, dtype=self.service.model_params.sampler_dtype, out=gs_host + ) denoise_inputs["guidance_scale"].copy_from(gs_host) await device ts_host = denoise_inputs["timesteps"].for_transfer() - with ts_host.map(write=True) as m: + ts_float = sfnp.device_array.for_host( + device, denoise_inputs["timesteps"].shape, dtype=sfnp.float32 + ) + with ts_float.map(write=True) as m: m.fill(float(1)) for tstep in range(len(requests[0].timesteps)): - with ts_host.view(tstep).map(write=True, discard=True) as m: + with ts_float.view(tstep).map(write=True, discard=True) as m: m.fill(np.asarray(requests[0].timesteps[tstep], dtype="float32")) + sfnp.convert( + ts_float, dtype=self.service.model_params.sampler_dtype, out=ts_host + ) denoise_inputs["timesteps"].copy_from(ts_host) await device @@ -712,7 +729,33 @@ async def _denoise(self, device, requests): req.denoised_latents = sfnp.device_array.for_device( device, img_shape, self.service.model_params.vae_dtype ) - req.denoised_latents.copy_from(denoise_inputs["img"].view(idx * cfg_mult)) + if ( + self.service.model_params.vae_dtype + != self.service.model_params.sampler_dtype + ): + pred_shape = [ + 1, + (requests[0].height) * (requests[0].width) // 256, + 64, + ] + denoised_inter = sfnp.device_array.for_host( + device, pred_shape, dtype=self.service.model_params.vae_dtype + ) + denoised_host = sfnp.device_array.for_host( + device, pred_shape, dtype=self.service.model_params.sampler_dtype + ) + denoised_host.copy_from(denoise_inputs["img"].view(idx * cfg_mult)) + await device + sfnp.convert( + denoised_host, + dtype=self.service.model_params.vae_dtype, + out=denoised_inter, + ) + req.denoised_latents.copy_from(denoised_inter) + else: + req.denoised_latents.copy_from( + denoise_inputs["img"].view(idx * cfg_mult) + ) return async def _decode(self, device, requests): @@ -726,7 +769,6 @@ async def _decode(self, device, requests): for bs, fn in entrypoints.items(): if bs == req_bs: break - await device latents_shape = [ req_bs, (requests[0].height * requests[0].width) // 256, @@ -735,10 +777,16 @@ async def _decode(self, device, requests): latents = sfnp.device_array.for_device( device, latents_shape, self.service.model_params.vae_dtype ) + # latents_host = sfnp.device_array.for_host( + # device, latents_shape, self.service.model_params.vae_dtype + # ) + # latents_host.copy_from(latents) + # print(latents_host) + # lat_arr = np.array(latents_host, dtype="float32") + # np.save("vae_in.npy", lat_arr) for i in range(req_bs): 
latents.view(i).copy_from(requests[i].denoised_latents) - await device # Decode the denoised latents. logger.debug( "INVOKE %r: %s", @@ -756,7 +804,12 @@ async def _decode(self, device, requests): images_host = sfnp.device_array.for_host( device, images_shape, self.service.model_params.vae_dtype ) + await device images_host.copy_from(image) + # await device + # print(images_host) + # img_arr = np.array(images_host, dtype="float32") + # np.save("vae_out.npy", img_arr) await device for idx, req in enumerate(requests): req.image_array = images_host.view(idx) diff --git a/shortfin/python/shortfin_apps/flux/simple_client.py b/shortfin/python/shortfin_apps/flux/simple_client.py index 9382ddd9d..cfe284e21 100644 --- a/shortfin/python/shortfin_apps/flux/simple_client.py +++ b/shortfin/python/shortfin_apps/flux/simple_client.py @@ -19,12 +19,12 @@ sample_request = { "prompt": [ - " A mountain with a halo cloud over it, Death Mountain, spooky, Zelda", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", ], "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], "height": [1024], "width": [1024], - "steps": [50], + "steps": [2], "guidance_scale": [3.5], "seed": [0], "output_type": ["base64"], From 62d552d6be189ad77f61300f37c9ad309cf7bb7e Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 8 Jan 2025 12:15:48 -0600 Subject: [PATCH 10/20] Add onnx export of mmdit and rename dynamo_exports dir to torch_exports --- .../flux/README.md | 0 .../flux/ae.py | 0 .../flux/export.py | 0 .../flux/mmdit.py | 0 .../torch_exports/flux/mmdit_onnx.py | 333 ++++++++++++++++++ .../flux/scheduler.py | 0 .../flux/te.py | 0 7 files changed, 333 insertions(+) rename sharktank/sharktank/{dynamo_exports => torch_exports}/flux/README.md (100%) rename sharktank/sharktank/{dynamo_exports => torch_exports}/flux/ae.py (100%) rename sharktank/sharktank/{dynamo_exports => torch_exports}/flux/export.py (100%) rename sharktank/sharktank/{dynamo_exports => torch_exports}/flux/mmdit.py (100%) create mode 100644 sharktank/sharktank/torch_exports/flux/mmdit_onnx.py rename sharktank/sharktank/{dynamo_exports => torch_exports}/flux/scheduler.py (100%) rename sharktank/sharktank/{dynamo_exports => torch_exports}/flux/te.py (100%) diff --git a/sharktank/sharktank/dynamo_exports/flux/README.md b/sharktank/sharktank/torch_exports/flux/README.md similarity index 100% rename from sharktank/sharktank/dynamo_exports/flux/README.md rename to sharktank/sharktank/torch_exports/flux/README.md diff --git a/sharktank/sharktank/dynamo_exports/flux/ae.py b/sharktank/sharktank/torch_exports/flux/ae.py similarity index 100% rename from sharktank/sharktank/dynamo_exports/flux/ae.py rename to sharktank/sharktank/torch_exports/flux/ae.py diff --git a/sharktank/sharktank/dynamo_exports/flux/export.py b/sharktank/sharktank/torch_exports/flux/export.py similarity index 100% rename from sharktank/sharktank/dynamo_exports/flux/export.py rename to sharktank/sharktank/torch_exports/flux/export.py diff --git a/sharktank/sharktank/dynamo_exports/flux/mmdit.py b/sharktank/sharktank/torch_exports/flux/mmdit.py similarity index 100% rename from sharktank/sharktank/dynamo_exports/flux/mmdit.py rename to sharktank/sharktank/torch_exports/flux/mmdit.py diff --git a/sharktank/sharktank/torch_exports/flux/mmdit_onnx.py b/sharktank/sharktank/torch_exports/flux/mmdit_onnx.py new file mode 100644 index 000000000..214162f8a --- /dev/null +++ b/sharktank/sharktank/torch_exports/flux/mmdit_onnx.py 
@@ -0,0 +1,333 @@
+import os
+import math
+from typing import Callable
+
+import torch
+from einops import repeat, rearrange
+from diffusers import FluxTransformer2DModel
+
+from iree.turbine.aot import *
+
+
+def get_local_path(local_dir, model_dir):
+    model_local_dir = os.path.join(local_dir, model_dir)
+    if not os.path.exists(model_local_dir):
+        os.makedirs(model_local_dir)
+    return model_local_dir
+
+
+def time_shift(mu: float, sigma: float, t: torch.Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+
+def get_lin_function(
+    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+) -> Callable[[float], float]:
+    m = (y2 - y1) / (x2 - x1)
+    b = y1 - m * x1
+    return lambda x: m * x + b
+
+
+def get_schedule(
+    num_steps: int,
+    image_seq_len: int,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+    shift: bool = True,
+) -> torch.Tensor:
+    # extra step for zero
+    timesteps = torch.linspace(1, 0, num_steps + 1)
+
+    # shifting the schedule to favor high timesteps for higher signal images
+    if shift:
+        # estimate mu based on linear estimation between two points
+        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
+        timesteps = time_shift(mu, 1.0, timesteps)
+
+    return timesteps
+
+
+class FluxScheduler(torch.nn.Module):
+    def __init__(self, max_length, torch_dtype, is_schnell=False):
+        super().__init__()
+        self.is_schnell = is_schnell
+        self.max_length = max_length
+        timesteps = [torch.empty((100), dtype=torch_dtype, requires_grad=False)] * 100
+        for i in range(1, 100):
+            schedule = get_schedule(i, max_length, shift=not self.is_schnell)
+            timesteps[i] = torch.nn.functional.pad(schedule, (0, 99 - i), "constant", 0)
+        self.timesteps = torch.stack(timesteps, dim=0).clone().detach()
+
+    def prepare(self, num_steps):
+        timesteps = self.timesteps[num_steps]
+        return timesteps
+
+
+class FluxModelCFG(torch.nn.Module):
+    def __init__(
+        self,
+        torch_dtype,
+        model_id="flux-dev",
+        batch_size=1,
+        max_length=512,
+        height=1024,
+        width=1024,
+    ):
+        super().__init__()
+        self.mmdit = FluxTransformer2DModel.from_single_file(
+            "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+        ).to(torch_dtype)
+        self.batch_size = batch_size * 2
+        img_ids = torch.zeros(height // 16, width // 16, 3)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(height // 16)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(width // 16)[None, :]
+        self.img_ids = rearrange(img_ids, "h w c -> (h w) c")
+        self.txt_ids = torch.zeros(max_length, 3)
+
+    def forward(self, img, txt, vec, step, timesteps, guidance_scale):
+        guidance_vec = guidance_scale.repeat(self.batch_size)
+        t_curr = torch.index_select(timesteps, 0, step)
+        t_prev = torch.index_select(timesteps, 0, step + 1)
+        t_vec = t_curr.repeat(self.batch_size)
+
+        pred = self.mmdit(
+            hidden_states=img,
+            img_ids=self.img_ids,
+            encoder_hidden_states=txt,
+            txt_ids=self.txt_ids,
+            pooled_projections=vec,
+            timestep=t_vec,
+            guidance=guidance_vec,
+            return_dict=False,
+        )[0]
+        pred_uncond, pred = torch.chunk(pred, 2, dim=0)
+        pred = pred_uncond + guidance_scale * (pred - pred_uncond)
+        img = img + (t_prev - t_curr) * pred
+        return img
+
+
+class FluxModelSchnell(torch.nn.Module):
+    def __init__(
+        self,
+        torch_dtype,
+        model_id="flux-schnell",
+        batch_size=1,
+        max_length=512,
+        height=1024,
+        width=1024,
+    ):
+        super().__init__()
+        if "schnell" in model_id:
+            self.mmdit = FluxTransformer2DModel.from_single_file(
+                "https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/flux1-schnell.safetensors"
+            ).to(torch_dtype)
+        elif "dev" in model_id:
+            self.mmdit = FluxTransformer2DModel.from_single_file(
+                "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+            ).to(torch_dtype)
+        img_ids = torch.zeros(height // 16, width // 16, 3)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(height // 16)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(width // 16)[None, :]
+        self.img_ids = repeat(img_ids, "h w c -> (h w) c")
+        self.txt_ids = torch.zeros(max_length, 3)
+
+    def forward(self, img, txt, vec, step, timesteps, guidance_scale):
+        guidance_vec = guidance_scale.repeat(img.shape[0])  # batch dim taken from the latent input
+        t_curr = torch.index_select(timesteps, 0, step)
+        t_prev = torch.index_select(timesteps, 0, step + 1)
+        t_vec = t_curr.repeat(img.shape[0])
+
+        pred = self.mmdit(
+            hidden_states=img,
+            img_ids=self.img_ids,
+            encoder_hidden_states=txt,
+            txt_ids=self.txt_ids,
+            pooled_projections=vec,
+            timestep=t_vec,
+            guidance=guidance_vec,
+            return_dict=False,
+        )[0]
+        img = img + (t_prev - t_curr) * pred
+        return img
+
+
+@torch.no_grad()
+def get_flux_sampler_model(
+    local_dir,
+    hf_model_path,
+    img_height=1024,
+    img_width=1024,
+    compression_factor=8,
+    max_len=512,
+    model_dir="transformer",
+    torch_dtype=torch.float32,
+    bs=1,
+    cfg_mode=True,
+):
+
+    transformer_local_dir = get_local_path(local_dir, model_dir)
+    onnx_file = "model.onnx"
+    onnx_path = os.path.join(transformer_local_dir, onnx_file)
+    if os.path.exists(onnx_path):
+        return onnx_path
+    latent_h, latent_w = (
+        img_height // compression_factor,
+        img_width // compression_factor,
+    )
+
+    if "schnell" in hf_model_path or cfg_mode == False:
+        model = FluxModelSchnell(torch_dtype=torch_dtype, model_id=hf_model_path)
+        config = model.mmdit.config
+        sample_inputs = (
+            torch.randn(
+                bs,
+                (latent_h // 2) * (latent_w // 2),
+                config["in_channels"],
+                dtype=torch_dtype,
+            ),
+            torch.randn(bs, max_len, config["joint_attention_dim"], dtype=torch_dtype),
+            torch.randn(bs, config["pooled_projection_dim"], dtype=torch_dtype),
+            torch.tensor([0.0], dtype=torch.int64),
+            torch.randn(100, dtype=torch_dtype),
+            torch.empty(bs, dtype=torch_dtype),
+        )
+    else:
+        model = FluxModelCFG(torch_dtype=torch_dtype, model_id=hf_model_path)
+        config = model.mmdit.config
+        cfg_bs = bs * 2
+        sample_inputs = (
+            torch.randn(
+                cfg_bs,
+                (latent_h // 2) * (latent_w // 2),
+                config["in_channels"],
+                dtype=torch_dtype,
+            ),
+            torch.randn(
+                cfg_bs, max_len, config["joint_attention_dim"], dtype=torch_dtype
+            ),
+            torch.randn(cfg_bs, config["pooled_projection_dim"], dtype=torch_dtype),
+            torch.tensor([0.0], dtype=torch.int64),
+            torch.randn(100, dtype=torch_dtype),
+            torch.randn(bs, dtype=torch_dtype),
+        )
+
+    input_names = ["img", "txt", "vec", "step", "timesteps", "guidance_scale"]
+
+    if not os.path.isfile(onnx_path):
+        output_names = ["latent"]
+
+        with torch.inference_mode():
+            torch.onnx.export(
+                model,
+                sample_inputs,
+                onnx_path,
+                export_params=True,
+                input_names=input_names,
+                output_names=output_names,
+                do_constant_folding=False,
+            )
+
+    assert os.path.isfile(onnx_path)
+
+    return onnx_path
+
+
+def do_onnx_import(args, model_dir="transformer"):
+    if args.save_params_to:
+        params_path = args.save_params_to
+    else:
+        params_path = None
+    mlir_path = args.save_mlir_to
+    onnx_model_path = os.path.join(args.path, model_dir, "model.onnx")
+    process_args = [
+        "python",
+        "-m",
+        "iree.compiler.tools.import_onnx",
+        onnx_model_path,
+        
"-o", + mlir_path, + "--externalize-params", + "--large-model", + "--num-elements-threshold=32", + ] + if params_path: + process_args.extend(["--save-params-to", params_path]) + + subprocess.run(process_args) + return mlir_path, params_path + + +if __name__ == "__main__": + import argparse + import subprocess + + parser = argparse.ArgumentParser(description="Flux Sampler ONNX export") + + parser.add_argument( + "--hf_model_id", + type=str, + default="black-forest-labs/FLUX.1-dev", + choices=["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev"], + help="Model name", + ) + parser.add_argument("--path", type=str, default=".") + parser.add_argument( + "--dtype", + type=str, + default="float32", + choices=["float32", "bfloat16"], + help="Precision with which to export the model.", + ) + parser.add_argument( + "--height", + type=int, + default=1024, + ) + parser.add_argument( + "--width", + type=int, + default=1024, + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + ) + parser.add_argument( + "--cfg_mode", + type=int, + default=1, + choices=[0, 1], + help="Whether or not to use CFG mode (batch dim -> 2, enables conditioning, flux-dev/pro only)", + ) + parser.add_argument( + "--save_mlir_to", + type=str, + default=None, + ) + parser.add_argument( + "--save_params_to", + type=str, + default=None, + ) + args = parser.parse_args() + torch_dtypes = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + } + model_dir = "transformer" + + onnx_path = get_flux_sampler_model( + args.path, + args.hf_model_id, + img_height=args.height, + img_width=args.width, + compression_factor=8, + max_len=512, + model_dir=model_dir, + torch_dtype=torch_dtypes[args.dtype], + bs=args.batch_size, + cfg_mode=args.cfg_mode, + ) + if args.save_mlir_to or args.save_params_to: + mlir_path, params_path = do_onnx_import(args, model_dir=model_dir) diff --git a/sharktank/sharktank/dynamo_exports/flux/scheduler.py b/sharktank/sharktank/torch_exports/flux/scheduler.py similarity index 100% rename from sharktank/sharktank/dynamo_exports/flux/scheduler.py rename to sharktank/sharktank/torch_exports/flux/scheduler.py diff --git a/sharktank/sharktank/dynamo_exports/flux/te.py b/sharktank/sharktank/torch_exports/flux/te.py similarity index 100% rename from sharktank/sharktank/dynamo_exports/flux/te.py rename to sharktank/sharktank/torch_exports/flux/te.py From cfd351959ce164b9363ac7909cead2ec5d762ffb Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 8 Jan 2025 16:39:49 -0600 Subject: [PATCH 11/20] Update README --- shortfin/python/shortfin_apps/flux/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/shortfin/python/shortfin_apps/flux/README.md b/shortfin/python/shortfin_apps/flux/README.md index a0ff7b809..cf4dd2544 100644 --- a/shortfin/python/shortfin_apps/flux/README.md +++ b/shortfin/python/shortfin_apps/flux/README.md @@ -14,8 +14,9 @@ By default, the port is set to 8000. If you would like to change this, use `--po You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`. 
+From a source checkout of shortfin: ``` -python -m shortfin_apps.flux.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single" +python -m shortfin_apps.flux.server --model_config=./python/shortfin_apps/flux/examples/flux_dev_config_mixed.json --device=amdgpu --fibers_per_device=1 --workers_per_device=1 --isolation="per_fiber" --flagfile=./python/shortfin_apps/flux/examples/flux_flags_gfx942.txt --build_preference=precompiled ``` - Wait until your server outputs: ``` From 0e65537296b0d8d5f736af120aeb96a91b329f7c Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Tue, 28 Jan 2025 20:26:04 +0000 Subject: [PATCH 12/20] Updates that get everything but VAE working --- sharktank/sharktank/models/flux/flux.py | 1 + .../sharktank/torch_exports/flux/export.py | 38 +++- .../sharktank/torch_exports/flux/mmdit.py | 178 ++++++------------ sharktank/sharktank/torch_exports/flux/te.py | 19 +- .../flux/components/config_struct.py | 6 +- .../shortfin_apps/flux/components/service.py | 49 ++++- .../flux/examples/flux_dev_config_mixed.json | 8 +- 7 files changed, 150 insertions(+), 149 deletions(-) diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py index e8003bb68..219911fb3 100644 --- a/sharktank/sharktank/models/flux/flux.py +++ b/sharktank/sharktank/models/flux/flux.py @@ -188,6 +188,7 @@ def forward( "Didn't get guidance strength for guidance distilled model." ) vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) txt = self.txt_in(txt) diff --git a/sharktank/sharktank/torch_exports/flux/export.py b/sharktank/sharktank/torch_exports/flux/export.py index 9002aae18..5f307ce0a 100644 --- a/sharktank/sharktank/torch_exports/flux/export.py +++ b/sharktank/sharktank/torch_exports/flux/export.py @@ -25,7 +25,9 @@ from ae import AutoEncoder, AutoEncoderParams from scheduler import FluxScheduler from mmdit import get_flux_transformer_model - +from sharktank.models.vae.model import VaeDecoderModel +#from sharktank.models.flux import FluxParams, FluxModelV1 +from sharktank.types.theta import Theta, Dataset, torch_module_to_theta @dataclass class ModelSpec: @@ -126,16 +128,30 @@ def get_te_model_and_inputs( ] return te, input_args case "t5xxl": - return None, None + te = HFEmbedder( + "t5xxl", + max_length=512, + torch_dtype=torch.float32, + ) + clip_ids_shape = ( + batch_size, + 512, + ) + input_args = [ + torch.ones(clip_ids_shape, dtype=torch.int64), + ] + return te, input_args class FluxAEWrapper(torch.nn.Module): def __init__(self, height=1024, width=1024, precision="fp32"): super().__init__() dtype = torch_dtypes[precision] - self.ae = AutoencoderKL.from_pretrained( - "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtypes=dtype - ) + #self.ae = AutoencoderKL.from_pretrained( + # "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtypes=dtype + #) + dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/vae.irpa") + self.ae = VaeDecoderModel.from_dataset(dataset) self.height = height self.width = width @@ -148,8 +164,9 @@ def forward(self, z): ph=2, pw=2, ) - d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor - return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1) + #d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor + #return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1) + return self.ae.forward(d_in) def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width): @@ -218,9 
+235,10 @@ def export_flux_model( add_ops=decomp_list, ): if component == "mmdit": - model, sample_inputs, _ = get_flux_model_and_inputs( + model, sample_inputs = get_flux_model_and_inputs( hf_model_name, precision, batch_size, max_length, height, width ) + print(sample_inputs) fxb = FxProgramsBuilder(model) @@ -345,6 +363,10 @@ def get_filename(args): return create_safe_name( args.model, f"clip_bs{args.batch_size}_77_{args.precision}" ) + case "t5xxl": + return create_safe_name( + args.model, f"t5xxl_bs{args.batch_size}_256_{args.precision}" + ) case "scheduler": return create_safe_name( args.model, diff --git a/sharktank/sharktank/torch_exports/flux/mmdit.py b/sharktank/sharktank/torch_exports/flux/mmdit.py index d4329f89d..e47a0c5c4 100644 --- a/sharktank/sharktank/torch_exports/flux/mmdit.py +++ b/sharktank/sharktank/torch_exports/flux/mmdit.py @@ -4,6 +4,8 @@ from diffusers import FluxTransformer2DModel from typing import Callable from iree.turbine.aot import * +from sharktank.models.flux.flux import FluxModelV1, FluxParams +from sharktank.types.theta import Theta, Dataset, torch_module_to_theta def get_local_path(local_dir, model_dir): @@ -13,74 +15,45 @@ def get_local_path(local_dir, model_dir): return model_local_dir -class FluxModelCFG(torch.nn.Module): - def __init__(self, torch_dtype): - super().__init__() - self.mmdit = FluxTransformer2DModel.from_single_file( - "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" - ).to(torch_dtype) - def forward( +class FluxDenoiseStepModel(torch.nn.Module): + def __init__( self, - hidden_states, - encoder_hidden_states, - pooled_projections, - img_ids, - txt_ids, - guidance_vec, - t_vec, - t_curr, - t_prev, - cfg_scale, + theta, + params, + batch_size=1, + max_length=512, + height=1024, + width=1024, ): - pred = self.mmdit( - hidden_states=hidden_states, - img_ids=img_ids, - encoder_hidden_states=encoder_hidden_states, - txt_ids=txt_ids, - pooled_projections=pooled_projections, - timestep=t_vec, - guidance=guidance_vec, - return_dict=False, - )[0] - pred_uncond, pred = torch.chunk(pred, 2, dim=0) - pred = pred_uncond + cfg_scale * (pred - pred_uncond) - hidden_states = hidden_states + (t_prev - t_curr) * pred - return hidden_states - - -class FluxModelSchnell(torch.nn.Module): - def __init__(self, torch_dtype): super().__init__() - self.mmdit = FluxTransformer2DModel.from_single_file( - "https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/flux1-schnell.safetensors" - ).to(torch_dtype) + self.mmdit = FluxModelV1(theta=theta, params=params) + self.batch_size = batch_size + img_ids = torch.zeros(height // 16, width // 16, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(height // 16)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(width // 16)[None, :] + self.img_ids = img_ids.reshape(1, height * width // 256, 3) + self.txt_ids = torch.zeros(1, max_length, 3) + + def forward(self, img, txt, vec, step, timesteps, guidance_scale): + guidance_vec = guidance_scale.repeat(self.batch_size) + t_curr = torch.index_select(timesteps, 0, step) + t_prev = torch.index_select(timesteps, 0, step + 1) + t_vec = t_curr.repeat(self.batch_size) - def forward( - self, - hidden_states, - encoder_hidden_states, - pooled_projections, - img_ids, - txt_ids, - guidance_vec, - t_vec, - t_curr, - t_prev, - cfg_scale, - ): pred = self.mmdit( - hidden_states=hidden_states, - img_ids=img_ids, - encoder_hidden_states=encoder_hidden_states, - txt_ids=txt_ids, - pooled_projections=pooled_projections, - 
timestep=t_vec, + img=img, + img_ids=self.img_ids, + txt=txt, + txt_ids=self.txt_ids, + y=vec, + timesteps=t_vec, guidance=guidance_vec, - return_dict=False, - )[0] - hidden_states = hidden_states + (t_prev - t_curr) * pred - return hidden_states + ) + #pred_uncond, pred = torch.chunk(pred, 2, dim=0) + #pred = pred_uncond + guidance_scale * (pred - pred_uncond) + img = img + (t_prev - t_curr) * pred + return img @torch.no_grad() @@ -93,68 +66,29 @@ def get_flux_transformer_model( torch_dtype=torch.float32, bs=1, ): - - latent_h, latent_w = ( - img_height // compression_factor, - img_width // compression_factor, + #transformer_dataset = Dataset.load(transformer_path) + #transformer_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/transformer/model.irpa") + transformer_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/transformer.irpa") + model = FluxDenoiseStepModel(theta=transformer_dataset.root_theta, params=FluxParams.from_hugging_face_properties(transformer_dataset.properties)) + #model = FluxModelV1(theta=transformer_dataset.root_theta, params=FluxParams.from_hugging_face_properties(transformer_dataset.properties)) + #dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/transformer.irpa") + #transformer_params = FluxParams.from_hugging_face_properties(transformer_dataset.properties) + #model = FluxModelV1( + # theta=transformer_dataset.root_theta, + # params=transformer_params + #) + sample_args, sample_kwargs = model.mmdit.sample_inputs() + sample_inputs = ( + sample_kwargs["img"], + #sample_kwargs["img_ids"], + sample_kwargs["txt"], + #sample_kwargs["txt_ids"], + sample_kwargs["y"], + torch.full((bs,), 1, dtype=torch.int64), + torch.full((100,), 1, dtype=torch_dtype), # TODO: non-dev timestep sizes + sample_kwargs["guidance"], ) - - if "schnell" in hf_model_path: - model = FluxModelSchnell(torch_dtype=torch_dtype) - config = model.mmdit.config - sample_inputs = ( - torch.randn( - bs, - (latent_h // 2) * (latent_w // 2), - config["in_channels"], - dtype=torch_dtype, - ), - torch.randn(bs, max_len, config["joint_attention_dim"], dtype=torch_dtype), - torch.randn(bs, config["pooled_projection_dim"], dtype=torch_dtype), - torch.randn((latent_h // 2) * (latent_w // 2), 3, dtype=torch_dtype), - torch.randn(max_len, 3, dtype=torch_dtype), - torch.tensor([1.0] * bs, dtype=torch_dtype), - torch.tensor([1.0] * bs, dtype=torch_dtype), - torch.tensor([1.0], dtype=torch_dtype), - torch.tensor([1.0], dtype=torch_dtype), - torch.tensor([1.0] * bs, dtype=torch_dtype), - ) - else: - model = FluxModelCFG(torch_dtype=torch_dtype) - config = model.mmdit.config - cfg_bs = bs * 2 - sample_inputs = ( - torch.randn( - cfg_bs, - (latent_h // 2) * (latent_w // 2), - config["in_channels"], - dtype=torch_dtype, - ), - torch.randn( - cfg_bs, max_len, config["joint_attention_dim"], dtype=torch_dtype - ), - torch.randn(cfg_bs, config["pooled_projection_dim"], dtype=torch_dtype), - torch.randn((latent_h // 2) * (latent_w // 2), 3, dtype=torch_dtype), - torch.randn(max_len, 3, dtype=torch_dtype), - torch.tensor([1.0] * bs, dtype=torch_dtype), - torch.tensor([1.0] * cfg_bs, dtype=torch_dtype), - torch.tensor([1.0], dtype=torch_dtype), - torch.tensor([1.0], dtype=torch_dtype), - torch.tensor([1.0] * bs, dtype=torch_dtype), - ) - - input_names = [ - "hidden_states", - "encoder_hidden_states", - "pooled_projections", - "img_ids", - "txt_ids", - "guidance_vec", - "t_curr", - "t_prev", - "cfg_scale", - ] - return model, sample_inputs, input_names + return model, 
sample_inputs # if not os.path.isfile(onnx_path): # output_names = ["latent"] diff --git a/sharktank/sharktank/torch_exports/flux/te.py b/sharktank/sharktank/torch_exports/flux/te.py index 74d17a665..bd5bba640 100644 --- a/sharktank/sharktank/torch_exports/flux/te.py +++ b/sharktank/sharktank/torch_exports/flux/te.py @@ -1,5 +1,10 @@ +import torch from torch import Tensor, nn -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + +from sharktank.types.theta import Theta, Dataset, torch_module_to_theta +from transformers import CLIPTextModel +from sharktank.models.clip import ClipTextModel, ClipTextConfig +from sharktank.models.t5 import T5Encoder, T5Config # Copied from https://github.com/black-forest-labs/flux class HFEmbedder(nn.Module): @@ -13,10 +18,18 @@ def __init__(self, version: str, max_length: int, **hf_kwargs): self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( version, **hf_kwargs ) + #theta = torch_module_to_theta(self.hf_module) + config = ClipTextConfig.from_hugging_face_clip_text_model_config(self.hf_module.config) + config.dtype = torch.float32 + dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/clip.irpa") + self.hf_module = ClipTextModel(theta=dataset.root_theta, config=config) else: - self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( - version, **hf_kwargs + t5_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/t5.irpa") + t5_config = T5Config.from_gguf_properties( + t5_dataset.properties, + feed_forward_proj="gated-gelu", ) + self.hf_module = T5Encoder(theta=t5_dataset.root_theta, config=t5_config) self.hf_module = self.hf_module.eval().requires_grad_(False) diff --git a/shortfin/python/shortfin_apps/flux/components/config_struct.py b/shortfin/python/shortfin_apps/flux/components/config_struct.py index 29bc651c0..d99f8f793 100644 --- a/shortfin/python/shortfin_apps/flux/components/config_struct.py +++ b/shortfin/python/shortfin_apps/flux/components/config_struct.py @@ -55,13 +55,15 @@ class ModelParams: max_seq_len: int = 512 t5xxl_module_name: str = "module" - t5xxl_fn_name: str = "forward_bs4" + t5xxl_fn_name: str = "encode_prompts" t5xxl_dtype: sfnp.DType = sfnp.bfloat16 # Channel dim of latents. 
num_latents_channels: int = 16 - sampler_module_name: str = "" + #sampler_module_name: str = "compiled_flux_transformer" + #sampler_fn_name: str = "run_forward" + sampler_module_name: str = "module" sampler_fn_name: str = "main_graph" sampler_dtype: sfnp.DType = sfnp.float32 diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py index b758536d3..04136cc5e 100644 --- a/shortfin/python/shortfin_apps/flux/components/service.py +++ b/shortfin/python/shortfin_apps/flux/components/service.py @@ -202,7 +202,7 @@ def start(self): self.inference_functions[worker_idx]["t5xxl"][ bs ] = self.inference_programs[worker_idx]["t5xxl"][ - f"{self.model_params.t5xxl_module_name}.forward_bs4" + f"{self.model_params.t5xxl_module_name}.encode_prompts" ] self.inference_functions[worker_idx]["denoise"] = {} for bs in self.model_params.sampler_batch_sizes: @@ -513,7 +513,6 @@ async def _clip(self, device, requests): break # Prepare tokenized input ids for CLIP inference - clip_inputs = [ sfnp.device_array.for_device( device, @@ -539,12 +538,19 @@ async def _clip(self, device, requests): fn, "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), ) + await device (vec,) = await fn(*clip_inputs, fiber=self.fiber) - await device + for i in range(req_bs): - cfg_mult = 2 + cfg_mult = 1 requests[i].vec = vec.view(slice(i, (i + 1))) + + await device + a = vec.for_transfer() + a.copy_from(vec) + await device + print(torch.frombuffer(a.items, dtype=torch.float32).shape) return @@ -560,10 +566,10 @@ async def _t5xxl(self, device, requests): break # Prepare tokenized input ids for t5xxl inference - + bs_or_something =1 t5xxl_inputs = [ sfnp.device_array.for_device( - device, [4, self.service.model_params.max_seq_len], sfnp.sint64 + device, [bs_or_something, self.service.model_params.max_seq_len], sfnp.sint64 ), ] host_arrs = [None] @@ -571,7 +577,7 @@ async def _t5xxl(self, device, requests): host_arrs[idx] = arr.for_transfer() for i in range(req_bs): np_arr = requests[i].t5xxl_input_ids[idx].input_ids - for rep in range(4): + for rep in range(bs_or_something): with host_arrs[idx].view(rep).map(write=True, discard=True) as m: m.fill(np_arr) t5xxl_inputs[idx].copy_from(host_arrs[idx]) @@ -586,7 +592,7 @@ async def _t5xxl(self, device, requests): (txt,) = await fn(*t5xxl_inputs, fiber=self.fiber) await device for i in range(req_bs): - cfg_mult = 2 + cfg_mult = 1 requests[i].txt = txt.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) return @@ -594,7 +600,8 @@ async def _t5xxl(self, device, requests): async def _denoise(self, device, requests): req_bs = len(requests) step_count = requests[0].steps - cfg_mult = 2 if not self.service.model_params.is_schnell else 1 + #cfg_mult = 2 if not self.service.model_params.is_schnell else 1 + cfg_mult = 1 # Produce denoised latents entrypoints = self.service.inference_functions[self.worker_index]["denoise"] if req_bs not in list(entrypoints.keys()): @@ -618,6 +625,7 @@ async def _denoise(self, device, requests): self.service.model_params.max_seq_len, 4096, ] + print(req_bs * cfg_mult, 768) vec_shape = [ req_bs * cfg_mult, 768, @@ -641,15 +649,27 @@ async def _denoise(self, device, requests): ), } # Send guidance scale to device. 
+ print("hi") + await device gs_host = denoise_inputs["guidance_scale"].for_transfer() sample_host = sfnp.device_array.for_host( device, img_shape, self.service.model_params.sampler_dtype ) guidance_float = sfnp.device_array.for_host(device, [req_bs], sfnp.float32) + print("hi") + await device + + #for key, value in denoise_inputs.items(): + #host_arrs[key] = denoise_inputs[key].for_transfer() + #host_arrs[key].copy_from(denoise_inputs[key]) + #await device + #print(torch.frombuffer(host_arrs[key].items, dtype=torch.float32).shape) for i in range(req_bs): guidance_float.view(i).items = [requests[i].guidance_scale] cfg_dim = i * cfg_mult + print("hi") + await device # Reshape and batch sample latent inputs on device. # Currently we just generate random latents in the desired shape. Rework for img2img. @@ -658,21 +678,30 @@ async def _denoise(self, device, requests): sample_host.view(slice(cfg_dim + rep, cfg_dim + rep + 1)).copy_from( req_samp ) + print("hi") + await device denoise_inputs["img"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( sample_host ) + print("hi") + await device # Batch t5xxl hidden states. txt = requests[i].txt denoise_inputs["txt"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( txt ) + print("hi") + await device # Batch CLIP projections. vec = requests[i].vec - for nc in range(2): + #for nc in range(cfg_mult): + for nc in range(1): denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) + print("hi") + await device sfnp.convert( guidance_float, dtype=self.service.model_params.sampler_dtype, out=gs_host ) diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json index 46a20af78..4c107452e 100644 --- a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json +++ b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json @@ -13,14 +13,14 @@ 1 ], "t5xxl_dtype": "float32", - "t5xxl_module_name": "module", - "t5xxl_fn_name": "forward_bs4", + "t5xxl_module_name": "compiled_flux_text_encoder_2", + "t5xxl_fn_name": "encode_prompts", "sampler_batch_sizes": [ 1 ], "sampler_dtype": "float32", - "sampler_module_name": "module", - "sampler_fn_name": "main_graph", + "sampler_module_name": "compiled_flux_transformer", + "sampler_fn_name": "run_forward", "vae_batch_sizes": [ 1 ], From 04c1a0e22fa4db3927f5721bf2a635ced8e71cb0 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Tue, 28 Jan 2025 23:37:53 +0000 Subject: [PATCH 13/20] Simplify export code to one file --- sharktank/sharktank/torch_exports/flux/ae.py | 352 ------------------ .../sharktank/torch_exports/flux/export.py | 205 ++++++++-- .../sharktank/torch_exports/flux/mmdit.py | 116 ------ .../torch_exports/flux/mmdit_onnx.py | 333 ----------------- .../sharktank/torch_exports/flux/scheduler.py | 50 --- sharktank/sharktank/torch_exports/flux/te.py | 42 --- 6 files changed, 181 insertions(+), 917 deletions(-) delete mode 100644 sharktank/sharktank/torch_exports/flux/ae.py delete mode 100644 sharktank/sharktank/torch_exports/flux/mmdit.py delete mode 100644 sharktank/sharktank/torch_exports/flux/mmdit_onnx.py delete mode 100644 sharktank/sharktank/torch_exports/flux/scheduler.py delete mode 100644 sharktank/sharktank/torch_exports/flux/te.py diff --git a/sharktank/sharktank/torch_exports/flux/ae.py b/sharktank/sharktank/torch_exports/flux/ae.py deleted file mode 100644 index b9363090e..000000000 --- a/sharktank/sharktank/torch_exports/flux/ae.py +++ /dev/null @@ -1,352 +0,0 @@ -from torch 
import Tensor, nn -import torch -from einops import rearrange -from dataclasses import dataclass -import math - -# This Flux AE implementation is copied from https://github.com/black-forest-labs/flux. - - -@dataclass -class AutoEncoderParams: - resolution: int - in_channels: int - ch: int - out_ch: int - ch_mult: list[int] - num_res_blocks: int - z_channels: int - scale_factor: float - shift_factor: float - height: int - width: int - - -def swish(x: Tensor) -> Tensor: - return x * torch.sigmoid(x) - - -class AttnBlock(nn.Module): - def __init__(self, in_channels: int): - super().__init__() - self.in_channels = in_channels - - self.norm = nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - - self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) - self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) - self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) - self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) - - def attention(self, h_: Tensor) -> Tensor: - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - b, c, h, w = q.shape - q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() - k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() - v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() - h_ = nn.functional.scaled_dot_product_attention(q, k, v) - - return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) - - def forward(self, x: Tensor) -> Tensor: - return x + self.proj_out(self.attention(x)) - - -class ResnetBlock(nn.Module): - def __init__(self, in_channels: int, out_channels: int): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - - self.norm1 = nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - self.conv1 = nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - self.norm2 = nn.GroupNorm( - num_groups=32, num_channels=out_channels, eps=1e-6, affine=True - ) - self.conv2 = nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if self.in_channels != self.out_channels: - self.nin_shortcut = nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h = x - h = self.norm1(h) - h = swish(h) - h = self.conv1(h) - - h = self.norm2(h) - h = swish(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - - return x + h - - -class Downsample(nn.Module): - def __init__(self, in_channels: int): - super().__init__() - # no asymmetric padding in torch conv, must do it ourselves - self.conv = nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) - - def forward(self, x: Tensor): - pad = (0, 1, 0, 1) - x = nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - return x - - -class Upsample(nn.Module): - def __init__(self, in_channels: int): - super().__init__() - self.conv = nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x: Tensor): - x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - x = self.conv(x) - return x - - -class Encoder(nn.Module): - def __init__( - self, - resolution: int, - in_channels: int, - ch: int, - ch_mult: list[int], - num_res_blocks: int, - z_channels: int, - ): - super().__init__() - self.ch = ch - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - 
self.resolution = resolution - self.in_channels = in_channels - # downsampling - self.conv_in = nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - block_in = self.ch - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for _ in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) - block_in = block_out - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) - - # end - self.norm_out = nn.GroupNorm( - num_groups=32, num_channels=block_in, eps=1e-6, affine=True - ) - self.conv_out = nn.Conv2d( - block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x: Tensor) -> Tensor: - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1]) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - # end - h = self.norm_out(h) - h = swish(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - ch: int, - out_ch: int, - ch_mult: list[int], - num_res_blocks: int, - in_channels: int, - resolution: int, - z_channels: int, - ): - super().__init__() - self.ch = ch - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.ffactor = 2 ** (self.num_resolutions - 1) - - # compute in_ch_mult, block_in and curr_res at lowest res - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - - # z to block_in - self.conv_in = nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for _ in range(self.num_res_blocks + 1): - block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) - block_in = block_out - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = nn.GroupNorm( - num_groups=32, num_channels=block_in, eps=1e-6, affine=True - ) - self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) - - 
def forward(self, z: Tensor) -> Tensor: - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h) - h = self.mid.attn_1(h) - h = self.mid.block_2(h) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = swish(h) - h = self.conv_out(h) - return h - - -class DiagonalGaussian(nn.Module): - def __init__(self, sample: bool = True, chunk_dim: int = 1): - super().__init__() - self.sample = sample - self.chunk_dim = chunk_dim - - def forward(self, z: Tensor) -> Tensor: - mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) - if self.sample: - std = torch.exp(0.5 * logvar) - return mean + std * torch.randn_like(mean) - else: - return mean - - -class AutoEncoder(nn.Module): - def __init__(self, params: AutoEncoderParams): - super().__init__() - self.encoder = Encoder( - resolution=params.resolution, - in_channels=params.in_channels, - ch=params.ch, - ch_mult=params.ch_mult, - num_res_blocks=params.num_res_blocks, - z_channels=params.z_channels, - ) - self.decoder = Decoder( - resolution=params.resolution, - in_channels=params.in_channels, - ch=params.ch, - out_ch=params.out_ch, - ch_mult=params.ch_mult, - num_res_blocks=params.num_res_blocks, - z_channels=params.z_channels, - ) - self.reg = DiagonalGaussian() - - self.scale_factor = params.scale_factor - self.shift_factor = params.shift_factor - self.height = params.height - self.width = params.width - - def encode(self, x: Tensor) -> Tensor: - z = self.reg(self.encoder(x)) - z = self.scale_factor * (z - self.shift_factor) - return z - - def decode(self, z: Tensor) -> Tensor: - d_in = rearrange( - z, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=math.ceil(self.height / 16), - w=math.ceil(self.width / 16), - ph=2, - pw=2, - ) - d_in = d_in / self.scale_factor + self.shift_factor - return self.decoder(d_in).clamp(-1, 1) - - def forward(self, x: Tensor) -> Tensor: - return self.decode(self.encode(x)) diff --git a/sharktank/sharktank/torch_exports/flux/export.py b/sharktank/sharktank/torch_exports/flux/export.py index 5f307ce0a..77738883a 100644 --- a/sharktank/sharktank/torch_exports/flux/export.py +++ b/sharktank/sharktank/torch_exports/flux/export.py @@ -8,6 +8,8 @@ import re from dataclasses import dataclass import math +import torch +from typing import Callable from einops import rearrange @@ -16,19 +18,16 @@ from iree.turbine.dynamo.passes import ( DEFAULT_DECOMPOSITIONS, ) -import torch -from diffusers.models.transformers import FluxTransformer2DModel -from diffusers.models.autoencoders import AutoencoderKL -from te import HFEmbedder from transformers import CLIPTextModel -from ae import AutoEncoder, AutoEncoderParams -from scheduler import FluxScheduler -from mmdit import get_flux_transformer_model +from sharktank.models.clip import ClipTextModel, ClipTextConfig +from sharktank.models.t5 import T5Encoder, T5Config +from sharktank.models.flux.flux import FluxModelV1, FluxParams from sharktank.models.vae.model import VaeDecoderModel -#from sharktank.models.flux import FluxParams, FluxModelV1 from sharktank.types.theta import Theta, Dataset, torch_module_to_theta + + @dataclass class ModelSpec: ae_params: AutoEncoderParams @@ -100,6 +99,72 @@ def create_safe_name(hf_model_name, model_name_str=""): return safe_name +class 
FluxDenoiseStepModel(torch.nn.Module): + def __init__( + self, + theta, + params, + batch_size=1, + max_length=512, + height=1024, + width=1024, + ): + super().__init__() + self.mmdit = FluxModelV1(theta=theta, params=params) + self.batch_size = batch_size + img_ids = torch.zeros(height // 16, width // 16, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(height // 16)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(width // 16)[None, :] + self.img_ids = img_ids.reshape(1, height * width // 256, 3) + self.txt_ids = torch.zeros(1, max_length, 3) + + def forward(self, img, txt, vec, step, timesteps, guidance_scale): + guidance_vec = guidance_scale.repeat(self.batch_size) + t_curr = torch.index_select(timesteps, 0, step) + t_prev = torch.index_select(timesteps, 0, step + 1) + t_vec = t_curr.repeat(self.batch_size) + + pred = self.mmdit( + img=img, + img_ids=self.img_ids, + txt=txt, + txt_ids=self.txt_ids, + y=vec, + timesteps=t_vec, + guidance=guidance_vec, + ) + # TODO: Use guidance scale + # pred_uncond, pred = torch.chunk(pred, 2, dim=0) + # pred = pred_uncond + guidance_scale * (pred - pred_uncond) + img = img + (t_prev - t_curr) * pred + return img + + +@torch.no_grad() +def get_flux_transformer_model( + hf_model_path, + img_height=1024, + img_width=1024, + compression_factor=8, + max_len=512, + torch_dtype=torch.float32, + bs=1, +): + # DNS: refactor file to select datatype + transformer_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/transformer.irpa") + model = FluxDenoiseStepModel(theta=transformer_dataset.root_theta, params=FluxParams.from_hugging_face_properties(transformer_dataset.properties)) + sample_args, sample_kwargs = model.mmdit.sample_inputs() + sample_inputs = ( + sample_kwargs["img"], + sample_kwargs["txt"], + sample_kwargs["y"], + torch.full((bs,), 1, dtype=torch.int64), + torch.full((100,), 1, dtype=torch_dtype), # TODO: non-dev timestep sizes + sample_kwargs["guidance"], + ) + return model, sample_inputs + + def get_flux_model_and_inputs( hf_model_name, precision, batch_size, max_length, height, width ): @@ -108,6 +173,40 @@ def get_flux_model_and_inputs( hf_model_name, height, width, 8, max_length, dtype, batch_size ) +# Copied from https://github.com/black-forest-labs/flux +class HFEmbedder(nn.Module): + def __init__(self, version: str, max_length: int, **hf_kwargs): + super().__init__() + self.is_clip = version.startswith("openai") + self.max_length = max_length + self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" + + if self.is_clip: + self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( + version, **hf_kwargs + ) + # DNS: Refactor to not rely on huggingface + config = ClipTextConfig.from_hugging_face_clip_text_model_config(self.hf_module.config) + config.dtype = torch.float32 + dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/clip.irpa") + self.hf_module = ClipTextModel(theta=dataset.root_theta, config=config) + else: + t5_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/t5.irpa") + t5_config = T5Config.from_gguf_properties( + t5_dataset.properties, + feed_forward_proj="gated-gelu", + ) + self.hf_module = T5Encoder(theta=t5_dataset.root_theta, config=t5_config) + + self.hf_module = self.hf_module.eval().requires_grad_(False) + + def forward(self, input_ids) -> Tensor: + outputs = self.hf_module( + input_ids=input_ids, + attention_mask=None, + output_hidden_states=False, + ) + return outputs[self.output_key] def get_te_model_and_inputs( 
hf_model_name, component, precision, batch_size, max_length @@ -135,7 +234,7 @@ def get_te_model_and_inputs( ) clip_ids_shape = ( batch_size, - 512, + 512, #DNS ) input_args = [ torch.ones(clip_ids_shape, dtype=torch.int64), @@ -147,9 +246,6 @@ class FluxAEWrapper(torch.nn.Module): def __init__(self, height=1024, width=1024, precision="fp32"): super().__init__() dtype = torch_dtypes[precision] - #self.ae = AutoencoderKL.from_pretrained( - # "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtypes=dtype - #) dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/vae.irpa") self.ae = VaeDecoderModel.from_dataset(dataset) self.height = height @@ -195,6 +291,52 @@ def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width) return ae, encode_inputs, decode_inputs +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # eastimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps + + +class FluxScheduler(torch.nn.Module): + def __init__(self, max_length, torch_dtype, is_schnell=False): + super().__init__() + self.is_schnell = is_schnell + self.max_length = max_length + timesteps = [torch.empty((100), dtype=torch_dtype, requires_grad=False)] * 100 + for i in range(1, 100): + schedule = get_schedule(i, max_length, shift=not self.is_schnell) + timesteps[i] = torch.nn.functional.pad(schedule, (0, 99 - i), "constant", 0) + self.timesteps = torch.stack(timesteps, dim=0).clone().detach() + + def prepare(self, num_steps): + timesteps = self.timesteps[num_steps] + return timesteps + def get_scheduler_model_and_inputs(hf_model_name, max_length, precision): is_schnell = "schnell" in hf_model_name mod = FluxScheduler( @@ -203,8 +345,6 @@ def get_scheduler_model_and_inputs(hf_model_name, max_length, precision): is_schnell=is_schnell, ) sample_inputs = (torch.empty(1, dtype=torch.int64),) - # tdim = torch.export.Dim("timesteps") - # dynamic_inputs = {"timesteps": {0: tdim}} return mod, sample_inputs @@ -262,7 +402,7 @@ class CompiledFluxTransformer(CompiledModule): module = CompiledModule.get_mlir_module(inst) - elif component in ["clip", "t5xxl"]: + elif component == "clip": model, sample_inputs = get_te_model_and_inputs( hf_model_name, component, precision, batch_size, max_length ) @@ -287,6 +427,32 @@ class CompiledFluxTextEncoder(CompiledModule): inst = CompiledFluxTextEncoder(context=Context(), import_to="IMPORT") + module = CompiledModule.get_mlir_module(inst) + elif component == "t5xxl": + model, sample_inputs = get_te_model_and_inputs( + hf_model_name, component, precision, batch_size, max_length + ) + + fxb = FxProgramsBuilder(model) + + @fxb.export_program( + args=(sample_inputs,), + ) + def _forward( + module, + inputs, + ): + return module.forward(*inputs) + + class CompiledFluxTextEncoder2(CompiledModule): + encode_prompts = _forward + + 
if external_weights: + externalize_module_parameters(model) + save_module_parameters(external_weight_path, model) + + inst = CompiledFluxTextEncoder(context=Context(), import_to="IMPORT") + module = CompiledModule.get_mlir_module(inst) elif component == "vae": model, encode_inputs, decode_inputs = get_ae_model_and_inputs( @@ -295,15 +461,6 @@ class CompiledFluxTextEncoder(CompiledModule): fxb = FxProgramsBuilder(model) - # @fxb.export_program( - # args=(encode_inputs,), - # ) - # def _encode( - # module, - # inputs, - # ): - # return module.encode(*inputs) - @fxb.export_program( args=(decode_inputs,), ) diff --git a/sharktank/sharktank/torch_exports/flux/mmdit.py b/sharktank/sharktank/torch_exports/flux/mmdit.py deleted file mode 100644 index e47a0c5c4..000000000 --- a/sharktank/sharktank/torch_exports/flux/mmdit.py +++ /dev/null @@ -1,116 +0,0 @@ -import os -import torch -import math -from diffusers import FluxTransformer2DModel -from typing import Callable -from iree.turbine.aot import * -from sharktank.models.flux.flux import FluxModelV1, FluxParams -from sharktank.types.theta import Theta, Dataset, torch_module_to_theta - - -def get_local_path(local_dir, model_dir): - model_local_dir = os.path.join(local_dir, model_dir) - if not os.path.exists(model_local_dir): - os.makedirs(model_local_dir) - return model_local_dir - - - -class FluxDenoiseStepModel(torch.nn.Module): - def __init__( - self, - theta, - params, - batch_size=1, - max_length=512, - height=1024, - width=1024, - ): - super().__init__() - self.mmdit = FluxModelV1(theta=theta, params=params) - self.batch_size = batch_size - img_ids = torch.zeros(height // 16, width // 16, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(height // 16)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(width // 16)[None, :] - self.img_ids = img_ids.reshape(1, height * width // 256, 3) - self.txt_ids = torch.zeros(1, max_length, 3) - - def forward(self, img, txt, vec, step, timesteps, guidance_scale): - guidance_vec = guidance_scale.repeat(self.batch_size) - t_curr = torch.index_select(timesteps, 0, step) - t_prev = torch.index_select(timesteps, 0, step + 1) - t_vec = t_curr.repeat(self.batch_size) - - pred = self.mmdit( - img=img, - img_ids=self.img_ids, - txt=txt, - txt_ids=self.txt_ids, - y=vec, - timesteps=t_vec, - guidance=guidance_vec, - ) - #pred_uncond, pred = torch.chunk(pred, 2, dim=0) - #pred = pred_uncond + guidance_scale * (pred - pred_uncond) - img = img + (t_prev - t_curr) * pred - return img - - -@torch.no_grad() -def get_flux_transformer_model( - hf_model_path, - img_height=1024, - img_width=1024, - compression_factor=8, - max_len=512, - torch_dtype=torch.float32, - bs=1, -): - #transformer_dataset = Dataset.load(transformer_path) - #transformer_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/transformer/model.irpa") - transformer_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/transformer.irpa") - model = FluxDenoiseStepModel(theta=transformer_dataset.root_theta, params=FluxParams.from_hugging_face_properties(transformer_dataset.properties)) - #model = FluxModelV1(theta=transformer_dataset.root_theta, params=FluxParams.from_hugging_face_properties(transformer_dataset.properties)) - #dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/transformer.irpa") - #transformer_params = FluxParams.from_hugging_face_properties(transformer_dataset.properties) - #model = FluxModelV1( - # theta=transformer_dataset.root_theta, - # params=transformer_params - #) - sample_args, 
sample_kwargs = model.mmdit.sample_inputs() - sample_inputs = ( - sample_kwargs["img"], - #sample_kwargs["img_ids"], - sample_kwargs["txt"], - #sample_kwargs["txt_ids"], - sample_kwargs["y"], - torch.full((bs,), 1, dtype=torch.int64), - torch.full((100,), 1, dtype=torch_dtype), # TODO: non-dev timestep sizes - sample_kwargs["guidance"], - ) - return model, sample_inputs - - # if not os.path.isfile(onnx_path): - # output_names = ["latent"] - # dynamic_axes = { - # 'hidden_states': {0: 'B', 1: 'latent_dim'}, - # 'encoder_hidden_states': {0: 'B',1: 'L'}, - # 'pooled_projections': {0: 'B'}, - # 'timestep': {0: 'B'}, - # 'img_ids': {0: 'latent_dim'}, - # 'txt_ids': {0: 'L'}, - # 'guidance': {0: 'B'}, - # } - - # with torch.inference_mode(): - # torch.onnx.export( - # model, - # sample_inputs, - # onnx_path, - # export_params=True, - # input_names=input_names, - # output_names=output_names) - - # assert os.path.isfile(onnx_path) - - # return onnx_path diff --git a/sharktank/sharktank/torch_exports/flux/mmdit_onnx.py b/sharktank/sharktank/torch_exports/flux/mmdit_onnx.py deleted file mode 100644 index 214162f8a..000000000 --- a/sharktank/sharktank/torch_exports/flux/mmdit_onnx.py +++ /dev/null @@ -1,333 +0,0 @@ -import os -import math -from typing import Callable - -import torch -from einops import repeat, rearrange -from diffusers import FluxTransformer2DModel - -from iree.turbine.aot import * - - -def get_local_path(local_dir, model_dir): - model_local_dir = os.path.join(local_dir, model_dir) - if not os.path.exists(model_local_dir): - os.makedirs(model_local_dir) - return model_local_dir - - -def time_shift(mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - -def get_lin_function( - x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 -) -> Callable[[float], float]: - m = (y2 - y1) / (x2 - x1) - b = y1 - m * x1 - return lambda x: m * x + b - - -def get_schedule( - num_steps: int, - image_seq_len: int, - base_shift: float = 0.5, - max_shift: float = 1.15, - shift: bool = True, -) -> list[float]: - # extra step for zero - timesteps = torch.linspace(1, 0, num_steps + 1) - - # shifting the schedule to favor high timesteps for higher signal images - if shift: - # eastimate mu based on linear estimation between two points - mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) - timesteps = time_shift(mu, 1.0, timesteps) - - return timesteps - - -class FluxScheduler(torch.nn.Module): - def __init__(self, max_length, torch_dtype, is_schnell=False): - super().__init__() - self.is_schnell = is_schnell - self.max_length = max_length - timesteps = [torch.empty((100), dtype=torch_dtype, requires_grad=False)] * 100 - for i in range(1, 100): - schedule = get_schedule(i, max_length, shift=not self.is_schnell) - timesteps[i] = torch.nn.functional.pad(schedule, (0, 99 - i), "constant", 0) - self.timesteps = torch.stack(timesteps, dim=0).clone().detach() - - def prepare(self, num_steps): - timesteps = self.timesteps[num_steps] - return timesteps - - -class FluxModelCFG(torch.nn.Module): - def __init__( - self, - torch_dtype, - model_id="flux-dev", - batch_size=1, - max_length=512, - height=1024, - width=1024, - ): - super().__init__() - self.mmdit = FluxTransformer2DModel.from_single_file( - "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" - ).to(torch_dtype) - self.batch_size = batch_size * 2 - img_ids = torch.zeros(height // 16, width // 16, 3) - img_ids[..., 1] = img_ids[..., 1] + 
torch.arange(height // 16)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(width // 16)[None, :] - self.img_ids = rearrange(img_ids, "h w c -> (h w) c") - self.txt_ids = torch.zeros(max_length, 3) - - def forward(self, img, txt, vec, step, timesteps, guidance_scale): - guidance_vec = guidance_scale.repeat(self.batch_size) - t_curr = torch.index_select(timesteps, 0, step) - t_prev = torch.index_select(timesteps, 0, step + 1) - t_vec = t_curr.repeat(self.batch_size) - - pred = self.mmdit( - hidden_states=img, - img_ids=self.img_ids, - encoder_hidden_states=txt, - txt_ids=self.txt_ids, - pooled_projections=vec, - timestep=t_vec, - guidance=guidance_vec, - return_dict=False, - )[0] - pred_uncond, pred = torch.chunk(pred, 2, dim=0) - pred = pred_uncond + guidance_scale * (pred - pred_uncond) - img = img + (t_prev - t_curr) * pred - return img - - -class FluxModelSchnell(torch.nn.Module): - def __init__( - self, - torch_dtype, - model_id="flux-schnell", - batch_size=1, - max_length=512, - height=1024, - width=1024, - ): - super().__init__() - if "schnell" in model_id: - self.mmdit = FluxTransformer2DModel.from_single_file( - "https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/flux1-schnell.safetensors" - ).to(torch_dtype) - elif "dev" in model_id: - self.mmdit = FluxTransformer2DModel.from_single_file( - "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors" - ).to(torch_dtype) - img_ids = torch.zeros(height // 16, width // 16, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(height // 16)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(width // 16)[None, :] - self.img_ids = repeat(img_ids, "h w c -> (h w) c") - self.txt_ids = torch.zeros(max_length, 3) - - def forward(self, img, txt, vec, step, timesteps, guidance_scale): - guidance_vec = guidance_scale.repeat(self.batch_size) - t_curr = torch.index_select(timesteps, 0, step) - t_prev = torch.index_select(timesteps, 0, step + 1) - t_vec = t_curr.repeat(self.batch_size) - - pred = self.mmdit( - hidden_states=img, - img_ids=self.img_ids, - encoder_hidden_states=txt, - txt_ids=self.txt_ids, - pooled_projections=vec, - timestep=t_vec, - guidance=guidance_vec, - return_dict=False, - )[0] - img = img + (t_prev - t_curr) * pred - return img - - -@torch.no_grad() -def get_flux_sampler_model( - local_dir, - hf_model_path, - img_height=1024, - img_width=1024, - compression_factor=8, - max_len=512, - model_dir="transformer", - torch_dtype=torch.float32, - bs=1, - cfg_mode=True, -): - - transformer_local_dir = get_local_path(local_dir, model_dir) - onnx_file = "model.onnx" - onnx_path = os.path.join(transformer_local_dir, onnx_file) - if os.path.exists(onnx_path): - return onnx_path - latent_h, latent_w = ( - img_height // compression_factor, - img_width // compression_factor, - ) - - if "schnell" in hf_model_path or cfg_mode == False: - model = FluxModelSchnell(torch_dtype=torch_dtype, model_id=hf_model_path) - config = model.mmdit.config - sample_inputs = ( - torch.randn( - bs, - (latent_h // 2) * (latent_w // 2), - config["in_channels"], - dtype=torch_dtype, - ), - torch.randn(bs, max_len, config["joint_attention_dim"], dtype=torch_dtype), - torch.randn(bs, config["pooled_projection_dim"], dtype=torch_dtype), - torch.tensor([0.0], dtype=torch.int64), - torch.randn(100, dtype=torch_dtype), - torch.empty(bs, dtype=torch_dtype), - ) - else: - model = FluxModelCFG(torch_dtype=torch_dtype, model_id=hf_model_path) - config = model.mmdit.config - cfg_bs = bs * 2 - sample_inputs = ( - 
torch.randn( - cfg_bs, - (latent_h // 2) * (latent_w // 2), - config["in_channels"], - dtype=torch_dtype, - ), - torch.randn( - cfg_bs, max_len, config["joint_attention_dim"], dtype=torch_dtype - ), - torch.randn(cfg_bs, config["pooled_projection_dim"], dtype=torch_dtype), - torch.tensor([0.0], dtype=torch.int64), - torch.randn(100, dtype=torch_dtype), - torch.randn(bs, dtype=torch_dtype), - ) - - input_names = ["img", "txt", "vec", "step", "timesteps", "guidance_scale"] - - if not os.path.isfile(onnx_path): - output_names = ["latent"] - - with torch.inference_mode(): - torch.onnx.export( - model, - sample_inputs, - onnx_path, - export_params=True, - input_names=input_names, - output_names=output_names, - do_constant_folding=False, - ) - - assert os.path.isfile(onnx_path) - - return onnx_path - - -def do_onnx_import(args, model_dir="transformer"): - if args.save_params_to: - params_path = args.save_params_to - else: - params_path = None - mlir_path = args.save_mlir_to - onnx_model_path = os.path.join(args.path, model_dir, "model.onnx") - process_args = [ - "python", - "-m", - "iree.compiler.tools.import_onnx", - onnx_model_path, - "-o", - mlir_path, - "--externalize-params", - "--large-model", - "--num-elements-threshold=32", - ] - if params_path: - process_args.extend(["--save-params-to", params_path]) - - subprocess.run(process_args) - return mlir_path, params_path - - -if __name__ == "__main__": - import argparse - import subprocess - - parser = argparse.ArgumentParser(description="Flux Sampler ONNX export") - - parser.add_argument( - "--hf_model_id", - type=str, - default="black-forest-labs/FLUX.1-dev", - choices=["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev"], - help="Model name", - ) - parser.add_argument("--path", type=str, default=".") - parser.add_argument( - "--dtype", - type=str, - default="float32", - choices=["float32", "bfloat16"], - help="Precision with which to export the model.", - ) - parser.add_argument( - "--height", - type=int, - default=1024, - ) - parser.add_argument( - "--width", - type=int, - default=1024, - ) - parser.add_argument( - "--batch_size", - type=int, - default=1, - ) - parser.add_argument( - "--cfg_mode", - type=int, - default=1, - choices=[0, 1], - help="Whether or not to use CFG mode (batch dim -> 2, enables conditioning, flux-dev/pro only)", - ) - parser.add_argument( - "--save_mlir_to", - type=str, - default=None, - ) - parser.add_argument( - "--save_params_to", - type=str, - default=None, - ) - args = parser.parse_args() - torch_dtypes = { - "float32": torch.float32, - "bfloat16": torch.bfloat16, - } - model_dir = "transformer" - - onnx_path = get_flux_sampler_model( - args.path, - args.hf_model_id, - img_height=args.height, - img_width=args.width, - compression_factor=8, - max_len=512, - model_dir=model_dir, - torch_dtype=torch_dtypes[args.dtype], - bs=args.batch_size, - cfg_mode=args.cfg_mode, - ) - if args.save_mlir_to or args.save_params_to: - mlir_path, params_path = do_onnx_import(args, model_dir=model_dir) diff --git a/sharktank/sharktank/torch_exports/flux/scheduler.py b/sharktank/sharktank/torch_exports/flux/scheduler.py deleted file mode 100644 index c89b52971..000000000 --- a/sharktank/sharktank/torch_exports/flux/scheduler.py +++ /dev/null @@ -1,50 +0,0 @@ -import math -import torch -from typing import Callable - - -def time_shift(mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - -def get_lin_function( - x1: float = 256, y1: float = 0.5, x2: float = 4096, 
y2: float = 1.15 -) -> Callable[[float], float]: - m = (y2 - y1) / (x2 - x1) - b = y1 - m * x1 - return lambda x: m * x + b - - -def get_schedule( - num_steps: int, - image_seq_len: int, - base_shift: float = 0.5, - max_shift: float = 1.15, - shift: bool = True, -) -> list[float]: - # extra step for zero - timesteps = torch.linspace(1, 0, num_steps + 1) - - # shifting the schedule to favor high timesteps for higher signal images - if shift: - # eastimate mu based on linear estimation between two points - mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) - timesteps = time_shift(mu, 1.0, timesteps) - - return timesteps - - -class FluxScheduler(torch.nn.Module): - def __init__(self, max_length, torch_dtype, is_schnell=False): - super().__init__() - self.is_schnell = is_schnell - self.max_length = max_length - timesteps = [torch.empty((100), dtype=torch_dtype, requires_grad=False)] * 100 - for i in range(1, 100): - schedule = get_schedule(i, max_length, shift=not self.is_schnell) - timesteps[i] = torch.nn.functional.pad(schedule, (0, 99 - i), "constant", 0) - self.timesteps = torch.stack(timesteps, dim=0).clone().detach() - - def prepare(self, num_steps): - timesteps = self.timesteps[num_steps] - return timesteps diff --git a/sharktank/sharktank/torch_exports/flux/te.py b/sharktank/sharktank/torch_exports/flux/te.py deleted file mode 100644 index bd5bba640..000000000 --- a/sharktank/sharktank/torch_exports/flux/te.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -from torch import Tensor, nn - -from sharktank.types.theta import Theta, Dataset, torch_module_to_theta -from transformers import CLIPTextModel -from sharktank.models.clip import ClipTextModel, ClipTextConfig -from sharktank.models.t5 import T5Encoder, T5Config - -# Copied from https://github.com/black-forest-labs/flux -class HFEmbedder(nn.Module): - def __init__(self, version: str, max_length: int, **hf_kwargs): - super().__init__() - self.is_clip = version.startswith("openai") - self.max_length = max_length - self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" - - if self.is_clip: - self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( - version, **hf_kwargs - ) - #theta = torch_module_to_theta(self.hf_module) - config = ClipTextConfig.from_hugging_face_clip_text_model_config(self.hf_module.config) - config.dtype = torch.float32 - dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/clip.irpa") - self.hf_module = ClipTextModel(theta=dataset.root_theta, config=config) - else: - t5_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/t5.irpa") - t5_config = T5Config.from_gguf_properties( - t5_dataset.properties, - feed_forward_proj="gated-gelu", - ) - self.hf_module = T5Encoder(theta=t5_dataset.root_theta, config=t5_config) - - self.hf_module = self.hf_module.eval().requires_grad_(False) - - def forward(self, input_ids) -> Tensor: - outputs = self.hf_module( - input_ids=input_ids, - attention_mask=None, - output_hidden_states=False, - ) - return outputs[self.output_key] From f6e4a8116e66e532e515611e6cabad82753bc399 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Wed, 29 Jan 2025 02:32:34 +0000 Subject: [PATCH 14/20] Refactor --- .../shortfin_apps/flux/components/service.py | 63 +++---------------- .../flux/examples/flux_dev_config_mixed.json | 1 + 2 files changed, 9 insertions(+), 55 deletions(-) diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py index 
04136cc5e..c539a6779 100644 --- a/shortfin/python/shortfin_apps/flux/components/service.py +++ b/shortfin/python/shortfin_apps/flux/components/service.py @@ -460,12 +460,6 @@ async def _prepare(self, device, requests): image_seq_len, 64, ] - # latents_shape = ( - # 1, - # channels, - # request.height // 8, - # request.width // 8, - # ) # Create and populate sample device array. generator = sfnp.RandomGenerator(seed) @@ -488,8 +482,6 @@ async def _prepare(self, device, requests): ) request.sample.copy_from(sample_transfer) - # sample_debug = torch.frombuffer(sample_transfer.items, dtype=torch.bfloat16) - # print(sample_debug) else: request.sample.copy_from(sample_host) @@ -550,7 +542,6 @@ async def _clip(self, device, requests): a = vec.for_transfer() a.copy_from(vec) await device - print(torch.frombuffer(a.items, dtype=torch.float32).shape) return @@ -566,7 +557,8 @@ async def _t5xxl(self, device, requests): break # Prepare tokenized input ids for t5xxl inference - bs_or_something =1 + # TODO(eagarvey): Refactor + bs_or_something = 1 t5xxl_inputs = [ sfnp.device_array.for_device( device, [bs_or_something, self.service.model_params.max_seq_len], sfnp.sint64 @@ -592,7 +584,7 @@ async def _t5xxl(self, device, requests): (txt,) = await fn(*t5xxl_inputs, fiber=self.fiber) await device for i in range(req_bs): - cfg_mult = 1 + cfg_mult = requests[i].cfg_mult requests[i].txt = txt.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) return @@ -600,8 +592,8 @@ async def _t5xxl(self, device, requests): async def _denoise(self, device, requests): req_bs = len(requests) step_count = requests[0].steps - #cfg_mult = 2 if not self.service.model_params.is_schnell else 1 - cfg_mult = 1 + cfg_mult = self.requests[0].cfg_mult + # Produce denoised latents entrypoints = self.service.inference_functions[self.worker_index]["denoise"] if req_bs not in list(entrypoints.keys()): @@ -649,27 +641,16 @@ async def _denoise(self, device, requests): ), } # Send guidance scale to device. - print("hi") - await device gs_host = denoise_inputs["guidance_scale"].for_transfer() sample_host = sfnp.device_array.for_host( device, img_shape, self.service.model_params.sampler_dtype ) guidance_float = sfnp.device_array.for_host(device, [req_bs], sfnp.float32) - print("hi") await device - #for key, value in denoise_inputs.items(): - #host_arrs[key] = denoise_inputs[key].for_transfer() - #host_arrs[key].copy_from(denoise_inputs[key]) - #await device - #print(torch.frombuffer(host_arrs[key].items, dtype=torch.float32).shape) - for i in range(req_bs): guidance_float.view(i).items = [requests[i].guidance_scale] cfg_dim = i * cfg_mult - print("hi") - await device # Reshape and batch sample latent inputs on device. # Currently we just generate random latents in the desired shape. Rework for img2img. @@ -678,30 +659,21 @@ async def _denoise(self, device, requests): sample_host.view(slice(cfg_dim + rep, cfg_dim + rep + 1)).copy_from( req_samp ) - print("hi") - await device denoise_inputs["img"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( sample_host ) - print("hi") - await device # Batch t5xxl hidden states. txt = requests[i].txt denoise_inputs["txt"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( txt ) - print("hi") - await device # Batch CLIP projections. 
vec = requests[i].vec - #for nc in range(cfg_mult): - for nc in range(1): + for nc in range(cfg_mult): denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) - print("hi") - await device sfnp.convert( guidance_float, dtype=self.service.model_params.sampler_dtype, out=gs_host ) @@ -738,16 +710,6 @@ async def _denoise(self, device, requests): fns["sampler"], ) await device - # np_arrs = {} - # host_arrs = {} - # for key, value in denoise_inputs.items(): - # host_arrs[key] = denoise_inputs[key].for_transfer() - # host_arrs[key].copy_from(denoise_inputs[key]) - # await device - # np_arrs[key] = np.array(host_arrs[key]) - # for key, value in np_arrs.items(): - # np.save(f"{key}.npy", value) - (noise_pred,) = await fns["sampler"]( *denoise_inputs.values(), fiber=self.fiber ) @@ -806,13 +768,7 @@ async def _decode(self, device, requests): latents = sfnp.device_array.for_device( device, latents_shape, self.service.model_params.vae_dtype ) - # latents_host = sfnp.device_array.for_host( - # device, latents_shape, self.service.model_params.vae_dtype - # ) - # latents_host.copy_from(latents) - # print(latents_host) - # lat_arr = np.array(latents_host, dtype="float32") - # np.save("vae_in.npy", lat_arr) + for i in range(req_bs): latents.view(i).copy_from(requests[i].denoised_latents) @@ -830,15 +786,12 @@ async def _decode(self, device, requests): requests[0].height, requests[0].width, ] + await device images_host = sfnp.device_array.for_host( device, images_shape, self.service.model_params.vae_dtype ) await device images_host.copy_from(image) - # await device - # print(images_host) - # img_arr = np.array(images_host, dtype="float32") - # np.save("vae_out.npy", img_arr) await device for idx, req in enumerate(requests): req.image_array = images_host.view(idx) diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json index 4c107452e..93c7dbfae 100644 --- a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json +++ b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json @@ -3,6 +3,7 @@ "is_schnell": false, "num_latents_channels": 16, "max_seq_len": 512, + "cfg_mult": 1, "clip_max_seq_len": 77, "clip_batch_sizes": [ 1 From 7fe1c1d9d8c55bad3fda9be69a91bd566e28fe83 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Fri, 31 Jan 2025 04:25:40 +0000 Subject: [PATCH 15/20] Fix up everything. 
BF16 still failing though --- .../sharktank/layers/configs/llm_configs.py | 2 +- sharktank/sharktank/models/clip/export.py | 2 +- .../sharktank/pipelines/flux/__init__.py | 5 +- sharktank/sharktank/pipelines/flux/export.py | 158 ----- .../sharktank/torch_exports/flux/README.md | 8 - .../sharktank/torch_exports/flux/export.py | 591 ------------------ .../pipelines/flux/flux_pipeline_test.py | 150 ----- .../flux/components/config_struct.py | 3 + .../flux/examples/flux_dev_config_mixed.json | 10 +- 9 files changed, 11 insertions(+), 918 deletions(-) delete mode 100644 sharktank/sharktank/pipelines/flux/export.py delete mode 100644 sharktank/sharktank/torch_exports/flux/README.md delete mode 100644 sharktank/sharktank/torch_exports/flux/export.py diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 55b97fd8e..be00c6d2e 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -290,7 +290,7 @@ class ClipTextConfig: output_hidden_states: bool = False use_return_dict: bool = True dtype: torch.dtype = torch.float32 - + @staticmethod def from_hugging_face_clip_text_model_config( config: "transformers.CLIPTextConfig", # type: ignore diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py index aba0e730d..66fd99420 100644 --- a/sharktank/sharktank/models/clip/export.py +++ b/sharktank/sharktank/models/clip/export.py @@ -55,7 +55,7 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset: def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike, dtype=None): dataset = clip_text_model_to_dataset(model) if dtype: - dataset.root_theta = tdataset.root_theta.transform( + dataset.root_theta = dataset.root_theta.transform( functools.partial(set_float_dtype, dtype=dtype) ) dataset.save(output_path) diff --git a/sharktank/sharktank/pipelines/flux/__init__.py b/sharktank/sharktank/pipelines/flux/__init__.py index f427f423a..f32a3f280 100644 --- a/sharktank/sharktank/pipelines/flux/__init__.py +++ b/sharktank/sharktank/pipelines/flux/__init__.py @@ -1,10 +1,7 @@ """Flux text-to-image generation pipeline.""" from .flux_pipeline import FluxPipeline -from .export import export_flux_pipeline_mlir , export_flux_pipeline_iree_parameters __all__ = [ "FluxPipeline", - "export_flux_pipeline_mlir", - "export_flux_pipeline_iree_parameters", -] \ No newline at end of file +] diff --git a/sharktank/sharktank/pipelines/flux/export.py b/sharktank/sharktank/pipelines/flux/export.py deleted file mode 100644 index a1f9502ef..000000000 --- a/sharktank/sharktank/pipelines/flux/export.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Export utilities for Flux text-to-image pipeline.""" -import functools -from typing import Optional, Union -from pathlib import Path -import torch -from copy import copy -import logging - -from .flux_pipeline import FluxPipeline -from ...types import Dataset, dtype_to_serialized_short_name -from ...transforms.dataset import set_float_dtype -from iree.turbine.aot import FxProgramsBuilder, export -from ...models.t5.export import export_encoder_iree_parameters as export_t5_parameters -from ...models.clip.export import export_clip_text_model_iree_parameters -from ...models.flux.export import export_flux_transformer_iree_parameters -from ...models.vae.model import VaeDecoderModel -from ...models.clip import ClipTextModel, ClipTextConfig -from transformers import CLIPTokenizer, T5Tokenizer, CLIPTextModel as HfCLIPTextModel -from ...models.flux.flux import FluxModelV1, FluxParams - -__all__ = [ - "export_flux_pipeline_mlir", - "export_flux_pipeline_iree_parameters", -] - -def export_flux_pipeline_mlir( - model: Union[FluxPipeline, Path, str], - batch_sizes: list[int], - mlir_output_path: str, - dtype: torch.dtype, -): - """Export Flux pipeline to MLIR format. - - Args: - model: Either the FluxPipeline instance or path to model files - batch_sizes: List of batch sizes to export for - mlir_output_path: Output path for MLIR file - """ - if isinstance(model, (Path, str)): - model_parameter_path = Path(model) / f"exported_parameters_{dtype_to_serialized_short_name(dtype)}" - model = FluxPipeline( - t5_path=str(model_parameter_path / "t5.irpa"), - clip_path=str(model_parameter_path / "clip.irpa"), - transformer_path=str(model_parameter_path / "transformer.irpa"), - ae_path=str(model_parameter_path / "vae.irpa"), - dtype=dtype, - ) - - fxb = FxProgramsBuilder(model) - - for batch_size in batch_sizes: - # Create sample inputs with default dimensions - t5_prompt_ids = torch.zeros((batch_size, 128), dtype=torch.long) - clip_prompt_ids = torch.zeros((batch_size, 77), dtype=torch.long) - latents = model._get_noise( - 1, - 1024, - 1024, - seed=12345, - ) - - @fxb.export_program( - name=f"forward_bs{batch_size}", - args=(t5_prompt_ids, clip_prompt_ids, latents), - dynamic_shapes={}, - strict=False, - ) - def _(model, t5_prompt_ids, clip_prompt_ids, latents): - return model.forward( - t5_prompt_ids=t5_prompt_ids, - clip_prompt_ids=clip_prompt_ids, - latents=latents, - ) - - try: - output = export(fxb) - except Exception as e: - print(f"Error during export: {e}") - print(f"Model dtype: {model.dtype}") - print(f"Latents dtype: {latents.dtype}") - raise - output.save_mlir(mlir_output_path) - -def is_already_exported(output_path: Path) -> bool: - return output_path.exists() - -def export_flux_pipeline_iree_parameters( - model_path_or_dataset: str | Dataset, - output_path: str, - dtype: Optional[torch.dtype] = None, -): - """Export Flux pipeline parameters to IREE format. 
- - Args: - model_path_or_dataset: Path to model files or Dataset instance - output_path: Output path for IREE parameters - dtype: Optional dtype to convert parameters to - """ - # Ensure output_path is a Path object - output_path = Path(output_path) / f"exported_parameters_{dtype_to_serialized_short_name(dtype)}" - output_path.mkdir(parents=True, exist_ok=True) - - # Export T5 parameters - t5_path = Path(model_path_or_dataset) / "text_encoder_2/model.gguf" - t5_output_path = output_path / "t5.irpa" - print("hi") - if not is_already_exported(t5_output_path): - print("hello") - export_t5_parameters(t5_path, str(t5_output_path), dtype) - logging.info(f"Exported T5 parameters to {t5_output_path}") - else: - logging.info(f"Skipped T5 parameter export, already exists at {t5_output_path}") - - # Export CLIP parameters - clip_path = Path(model_path_or_dataset) / "text_encoder/model.irpa" - clip_output_path = output_path / "clip.irpa" - if not is_already_exported(clip_output_path): - clip_dataset = Dataset.load(clip_path) - # TODO: Refactor CLIP to not make the config rely on HuggingFace - hf_clip_model = HfCLIPTextModel.from_pretrained("/data/flux/FLUX.1-dev/text_encoder/") - clip_config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_clip_model.config) - clip_model = ClipTextModel(theta=clip_dataset.root_theta, config=clip_config) - export_clip_text_model_iree_parameters(clip_model, str(clip_output_path)) - logging.info(f"Exported CLIP parameters to {clip_output_path}") - else: - logging.info(f"Skipped CLIP parameter export, already exists at {clip_output_path}") - - # Export FluxTransformer parameters - transformer_path = Path(model_path_or_dataset) / "transformer/model.irpa" - transformer_output_path = output_path / "transformer.irpa" - if not is_already_exported(transformer_output_path): - transformer_dataset = Dataset.load(transformer_path) - transformer_model = FluxModelV1(theta=transformer_dataset.root_theta, params=FluxParams.from_hugging_face_properties(transformer_dataset.properties)) - export_flux_transformer_iree_parameters(transformer_model, str(transformer_output_path), dtype=dtype) - logging.info(f"Exported FluxTransformer parameters to {transformer_output_path}") - else: - logging.info(f"Skipped FluxTransformer parameter export, already exists at {transformer_output_path}") - - # Export VAE parameters - vae_path = Path(model_path_or_dataset) / "vae/model.irpa" - vae_output_path = output_path / "vae.irpa" - if not is_already_exported(vae_output_path): - vae_dataset = Dataset.load(vae_path) - vae_dataset.root_theta = vae_dataset.root_theta.transform( - functools.partial(set_float_dtype, dtype=dtype) - ) - vae_dataset.save(str(vae_output_path)) - logging.info(f"Exported VAE parameters to {vae_output_path}") - else: - logging.info(f"Skipped VAE parameter export, already exists at {vae_output_path}") - - logging.info(f"Completed Flux pipeline parameter export to {output_path}") \ No newline at end of file diff --git a/sharktank/sharktank/torch_exports/flux/README.md b/sharktank/sharktank/torch_exports/flux/README.md deleted file mode 100644 index db989c0e8..000000000 --- a/sharktank/sharktank/torch_exports/flux/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# Flux.1 dynamo exports - -### Quick Start - -All the exports in this directory are done through `export.py`, with the CLI syntax as follows: -```shell -python sharktank/sharktank/dynamo_exports/flux/export.py --model="flux-dev" --component= --precision= -``` diff --git a/sharktank/sharktank/torch_exports/flux/export.py 
b/sharktank/sharktank/torch_exports/flux/export.py deleted file mode 100644 index 77738883a..000000000 --- a/sharktank/sharktank/torch_exports/flux/export.py +++ /dev/null @@ -1,591 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import os -import re -from dataclasses import dataclass -import math -import torch -from typing import Callable - -from einops import rearrange - -from iree.compiler.ir import Context -from iree.turbine.aot import * -from iree.turbine.dynamo.passes import ( - DEFAULT_DECOMPOSITIONS, -) - -from transformers import CLIPTextModel -from sharktank.models.clip import ClipTextModel, ClipTextConfig -from sharktank.models.t5 import T5Encoder, T5Config -from sharktank.models.flux.flux import FluxModelV1, FluxParams -from sharktank.models.vae.model import VaeDecoderModel -from sharktank.types.theta import Theta, Dataset, torch_module_to_theta - - - -@dataclass -class ModelSpec: - ae_params: AutoEncoderParams - ae_path: str | None - - -fluxconfigs = { - "flux-dev": ModelSpec( - ae_path=None, # os.getenv("AE"), - ae_params=AutoEncoderParams( - resolution=256, - in_channels=3, - ch=128, - out_ch=3, - ch_mult=[1, 2, 4, 4], - num_res_blocks=2, - z_channels=16, - scale_factor=0.3611, - shift_factor=0.1159, - height=1024, - width=1024, - ), - ), - "flux-schnell": ModelSpec( - ae_path=None, # os.getenv("AE"), - ae_params=AutoEncoderParams( - resolution=256, - in_channels=3, - ch=128, - out_ch=3, - ch_mult=[1, 2, 4, 4], - num_res_blocks=2, - z_channels=16, - scale_factor=0.3611, - shift_factor=0.1159, - height=1024, - width=1024, - ), - ), -} - -model_repo_map = { - "flux-dev": "black-forest-labs/FLUX.1-dev", - "flux-schnell": "black-forest-labs/FLUX.1-schnell", -} -model_file_map = { - "flux-dev": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors", - "flux-schnell": "https://huggingface.co/black-forest-labs/FLUX.1-schnell/blob/main/flux1-schnell.safetensors", -} - -torch_dtypes = { - "fp16": torch.float16, - "fp32": torch.float32, - "bf16": torch.bfloat16, - "float16": torch.float16, - "float32": torch.float32, -} - - -def create_safe_name(hf_model_name, model_name_str=""): - if not model_name_str: - model_name_str = "" - if model_name_str != "" and (not model_name_str.startswith("_")): - model_name_str = "_" + model_name_str - - safe_name = hf_model_name.split("/")[-1].strip() + model_name_str - safe_name = re.sub("-", "_", safe_name) - safe_name = re.sub("\.", "_", safe_name) - return safe_name - - -class FluxDenoiseStepModel(torch.nn.Module): - def __init__( - self, - theta, - params, - batch_size=1, - max_length=512, - height=1024, - width=1024, - ): - super().__init__() - self.mmdit = FluxModelV1(theta=theta, params=params) - self.batch_size = batch_size - img_ids = torch.zeros(height // 16, width // 16, 3) - img_ids[..., 1] = img_ids[..., 1] + torch.arange(height // 16)[:, None] - img_ids[..., 2] = img_ids[..., 2] + torch.arange(width // 16)[None, :] - self.img_ids = img_ids.reshape(1, height * width // 256, 3) - self.txt_ids = torch.zeros(1, max_length, 3) - - def forward(self, img, txt, vec, step, timesteps, guidance_scale): - guidance_vec = guidance_scale.repeat(self.batch_size) - t_curr = torch.index_select(timesteps, 0, step) - t_prev = torch.index_select(timesteps, 0, step + 1) - t_vec = t_curr.repeat(self.batch_size) - - pred = self.mmdit( - 
img=img, - img_ids=self.img_ids, - txt=txt, - txt_ids=self.txt_ids, - y=vec, - timesteps=t_vec, - guidance=guidance_vec, - ) - # TODO: Use guidance scale - # pred_uncond, pred = torch.chunk(pred, 2, dim=0) - # pred = pred_uncond + guidance_scale * (pred - pred_uncond) - img = img + (t_prev - t_curr) * pred - return img - - -@torch.no_grad() -def get_flux_transformer_model( - hf_model_path, - img_height=1024, - img_width=1024, - compression_factor=8, - max_len=512, - torch_dtype=torch.float32, - bs=1, -): - # DNS: refactor file to select datatype - transformer_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/transformer.irpa") - model = FluxDenoiseStepModel(theta=transformer_dataset.root_theta, params=FluxParams.from_hugging_face_properties(transformer_dataset.properties)) - sample_args, sample_kwargs = model.mmdit.sample_inputs() - sample_inputs = ( - sample_kwargs["img"], - sample_kwargs["txt"], - sample_kwargs["y"], - torch.full((bs,), 1, dtype=torch.int64), - torch.full((100,), 1, dtype=torch_dtype), # TODO: non-dev timestep sizes - sample_kwargs["guidance"], - ) - return model, sample_inputs - - -def get_flux_model_and_inputs( - hf_model_name, precision, batch_size, max_length, height, width -): - dtype = torch_dtypes[precision] - return get_flux_transformer_model( - hf_model_name, height, width, 8, max_length, dtype, batch_size - ) - -# Copied from https://github.com/black-forest-labs/flux -class HFEmbedder(nn.Module): - def __init__(self, version: str, max_length: int, **hf_kwargs): - super().__init__() - self.is_clip = version.startswith("openai") - self.max_length = max_length - self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" - - if self.is_clip: - self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( - version, **hf_kwargs - ) - # DNS: Refactor to not rely on huggingface - config = ClipTextConfig.from_hugging_face_clip_text_model_config(self.hf_module.config) - config.dtype = torch.float32 - dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/clip.irpa") - self.hf_module = ClipTextModel(theta=dataset.root_theta, config=config) - else: - t5_dataset = Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/t5.irpa") - t5_config = T5Config.from_gguf_properties( - t5_dataset.properties, - feed_forward_proj="gated-gelu", - ) - self.hf_module = T5Encoder(theta=t5_dataset.root_theta, config=t5_config) - - self.hf_module = self.hf_module.eval().requires_grad_(False) - - def forward(self, input_ids) -> Tensor: - outputs = self.hf_module( - input_ids=input_ids, - attention_mask=None, - output_hidden_states=False, - ) - return outputs[self.output_key] - -def get_te_model_and_inputs( - hf_model_name, component, precision, batch_size, max_length -): - match component: - case "clip": - te = HFEmbedder( - "openai/clip-vit-large-patch14", - max_length=77, - torch_dtype=torch.float32, - ) - clip_ids_shape = ( - batch_size, - 77, - ) - input_args = [ - torch.ones(clip_ids_shape, dtype=torch.int64), - ] - return te, input_args - case "t5xxl": - te = HFEmbedder( - "t5xxl", - max_length=512, - torch_dtype=torch.float32, - ) - clip_ids_shape = ( - batch_size, - 512, #DNS - ) - input_args = [ - torch.ones(clip_ids_shape, dtype=torch.int64), - ] - return te, input_args - - -class FluxAEWrapper(torch.nn.Module): - def __init__(self, height=1024, width=1024, precision="fp32"): - super().__init__() - dtype = torch_dtypes[precision] - dataset = 
Dataset.load("/data/flux/flux/FLUX.1-dev/exported_parameters_f32/vae.irpa") - self.ae = VaeDecoderModel.from_dataset(dataset) - self.height = height - self.width = width - - def forward(self, z): - d_in = rearrange( - z, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=math.ceil(self.height / 16), - w=math.ceil(self.width / 16), - ph=2, - pw=2, - ) - #d_in = d_in / self.ae.config.scaling_factor + self.ae.config.shift_factor - #return self.ae.decode(d_in, return_dict=False)[0].clamp(-1, 1) - return self.ae.forward(d_in) - - -def get_ae_model_and_inputs(hf_model_name, precision, batch_size, height, width): - dtype = torch_dtypes[precision] - aeparams = fluxconfigs[hf_model_name].ae_params - aeparams.height = height - aeparams.width = width - ae = FluxAEWrapper(height, width, precision).to(dtype) - latents_shape = ( - batch_size, - int(height * width / 256), - 64, - ) - img_shape = ( - 1, - aeparams.in_channels, - int(height), - int(width), - ) - encode_inputs = [ - torch.empty(img_shape, dtype=dtype), - ] - decode_inputs = [ - torch.empty(latents_shape, dtype=dtype), - ] - return ae, encode_inputs, decode_inputs - - -def time_shift(mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) - - -def get_lin_function( - x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 -) -> Callable[[float], float]: - m = (y2 - y1) / (x2 - x1) - b = y1 - m * x1 - return lambda x: m * x + b - - -def get_schedule( - num_steps: int, - image_seq_len: int, - base_shift: float = 0.5, - max_shift: float = 1.15, - shift: bool = True, -) -> list[float]: - # extra step for zero - timesteps = torch.linspace(1, 0, num_steps + 1) - - # shifting the schedule to favor high timesteps for higher signal images - if shift: - # eastimate mu based on linear estimation between two points - mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) - timesteps = time_shift(mu, 1.0, timesteps) - - return timesteps - - -class FluxScheduler(torch.nn.Module): - def __init__(self, max_length, torch_dtype, is_schnell=False): - super().__init__() - self.is_schnell = is_schnell - self.max_length = max_length - timesteps = [torch.empty((100), dtype=torch_dtype, requires_grad=False)] * 100 - for i in range(1, 100): - schedule = get_schedule(i, max_length, shift=not self.is_schnell) - timesteps[i] = torch.nn.functional.pad(schedule, (0, 99 - i), "constant", 0) - self.timesteps = torch.stack(timesteps, dim=0).clone().detach() - - def prepare(self, num_steps): - timesteps = self.timesteps[num_steps] - return timesteps - -def get_scheduler_model_and_inputs(hf_model_name, max_length, precision): - is_schnell = "schnell" in hf_model_name - mod = FluxScheduler( - max_length=max_length, - torch_dtype=torch_dtypes[precision], - is_schnell=is_schnell, - ) - sample_inputs = (torch.empty(1, dtype=torch.int64),) - return mod, sample_inputs - - -@torch.no_grad() -def export_flux_model( - hf_model_name, - component, - batch_size, - height, - width, - precision="fp16", - max_length=512, - compile_to="torch", - external_weights=None, - external_weight_path=None, - decomp_attn=False, -): - dtype = torch_dtypes[precision] - decomp_list = [] - if decomp_attn == True: - decomp_list = [ - torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops.aten.scaled_dot_product_attention, - ] - with decompositions.extend_aot_decompositions( - from_current=True, - add_ops=decomp_list, - ): - if component == "mmdit": - model, 
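The schedule helpers in the file removed above (and the equivalent `_time_shift`/`_get_lin_function` methods kept by the pipeline) boil down to two steps: interpolate `mu` between the points (256, 0.5) and (4096, 1.15) based on the packed sequence length, then warp a linear 1-to-0 timestep schedule toward high timesteps. A minimal sketch of that math, with `mu_for_seq_len` as an illustrative helper name:

```python
import math
import torch

def mu_for_seq_len(seq_len: float, x1=256.0, y1=0.5, x2=4096.0, y2=1.15) -> float:
    # Linear interpolation between (x1, y1) and (x2, y2), as in get_lin_function.
    m = (y2 - y1) / (x2 - x1)
    return m * seq_len + (y1 - m * x1)

def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # Same warp as the time_shift helper above.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

mu = mu_for_seq_len(4096)              # 1.15 for a 1024x1024 image (4096 packed tokens)
timesteps = torch.linspace(1, 0, 5)    # 4 denoising steps plus the terminal zero
print(time_shift(mu, 1.0, timesteps))  # ~[1.00, 0.90, 0.76, 0.51, 0.00]
```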
sample_inputs = get_flux_model_and_inputs( - hf_model_name, precision, batch_size, max_length, height, width - ) - print(sample_inputs) - - fxb = FxProgramsBuilder(model) - - @fxb.export_program( - args=(sample_inputs,), - ) - def _forward( - module, - inputs, - ): - return module.forward(*inputs) - - class CompiledFluxTransformer(CompiledModule): - run_forward = _forward - - if external_weights: - externalize_module_parameters(model) - save_module_parameters(external_weight_path, model) - - inst = CompiledFluxTransformer(context=Context(), import_to="IMPORT") - - module = CompiledModule.get_mlir_module(inst) - - elif component == "clip": - model, sample_inputs = get_te_model_and_inputs( - hf_model_name, component, precision, batch_size, max_length - ) - - fxb = FxProgramsBuilder(model) - - @fxb.export_program( - args=(sample_inputs,), - ) - def _forward( - module, - inputs, - ): - return module.forward(*inputs) - - class CompiledFluxTextEncoder(CompiledModule): - encode_prompts = _forward - - if external_weights: - externalize_module_parameters(model) - save_module_parameters(external_weight_path, model) - - inst = CompiledFluxTextEncoder(context=Context(), import_to="IMPORT") - - module = CompiledModule.get_mlir_module(inst) - elif component == "t5xxl": - model, sample_inputs = get_te_model_and_inputs( - hf_model_name, component, precision, batch_size, max_length - ) - - fxb = FxProgramsBuilder(model) - - @fxb.export_program( - args=(sample_inputs,), - ) - def _forward( - module, - inputs, - ): - return module.forward(*inputs) - - class CompiledFluxTextEncoder2(CompiledModule): - encode_prompts = _forward - - if external_weights: - externalize_module_parameters(model) - save_module_parameters(external_weight_path, model) - - inst = CompiledFluxTextEncoder(context=Context(), import_to="IMPORT") - - module = CompiledModule.get_mlir_module(inst) - elif component == "vae": - model, encode_inputs, decode_inputs = get_ae_model_and_inputs( - hf_model_name, precision, batch_size, height, width - ) - - fxb = FxProgramsBuilder(model) - - @fxb.export_program( - args=(decode_inputs,), - ) - def _decode( - module, - inputs, - ): - return module.forward(*inputs) - - class CompiledFluxAutoEncoder(CompiledModule): - # encode = _encode - decode = _decode - - if external_weights: - externalize_module_parameters(model) - save_module_parameters(external_weight_path, model) - - inst = CompiledFluxAutoEncoder(context=Context(), import_to="IMPORT") - - module = CompiledModule.get_mlir_module(inst) - - elif component == "scheduler": - model, sample_inputs = get_scheduler_model_and_inputs( - hf_model_name, max_length, precision - ) - - fxb = FxProgramsBuilder(model) - - @fxb.export_program( - args=(sample_inputs,), - ) - def _prepare( - module, - inputs, - ): - return module.prepare(*inputs) - - class CompiledFlowScheduler(CompiledModule): - run_prep = _prepare - - inst = CompiledFlowScheduler(context=Context(), import_to="IMPORT") - - module = CompiledModule.get_mlir_module(inst) - - module_str = str(module) - return module_str - - -def get_filename(args): - match args.component: - case "mmdit": - return create_safe_name( - args.model, - f"mmdit_bs{args.batch_size}_{args.max_length}_{args.height}x{args.width}_{args.precision}", - ) - case "clip": - return create_safe_name( - args.model, f"clip_bs{args.batch_size}_77_{args.precision}" - ) - case "t5xxl": - return create_safe_name( - args.model, f"t5xxl_bs{args.batch_size}_256_{args.precision}" - ) - case "scheduler": - return create_safe_name( - args.model, - 
f"scheduler_bs{args.batch_size}_{args.max_length}_{args.precision}", - ) - case "vae": - return create_safe_name( - args.model, - f"vae_bs{args.batch_size}_{args.height}x{args.width}_{args.precision}", - ) - - -if __name__ == "__main__": - import logging - import argparse - - logging.basicConfig(level=logging.DEBUG) - p = argparse.ArgumentParser() - p.add_argument( - "--model", - default="flux-schnell", - choices=["flux-dev", "flux-schnell", "flux-pro"], - ) - p.add_argument( - "--component", - default="mmdit", - choices=["mmdit", "clip", "t5xxl", "scheduler", "vae"], - ) - p.add_argument("--batch_size", default=1) - p.add_argument("--height", default=1024) - p.add_argument("--width", default=1024) - p.add_argument("--precision", default="fp32") - p.add_argument("--max_length", default=512) - p.add_argument("--external_weights", default="irpa") - p.add_argument("--external_weights_file", default=None) - p.add_argument("--decomp_attn", action="store_true") - args = p.parse_args() - - if args.external_weights and not args.external_weights_file: - args.external_weights_file = ( - create_safe_name( - args.model, - args.component + "_" + args.precision, - ) - + "." - + args.external_weights - ) - safe_name = get_filename(args) - mod_str = export_flux_model( - args.model, - args.component, - args.batch_size, - args.height, - args.width, - args.precision, - args.max_length, - "mlir", - args.external_weights, - args.external_weights_file, - args.decomp_attn, - ) - - with open(f"{safe_name}.mlir", "w+") as f: - f.write(mod_str) - print("Saved to", safe_name + ".mlir") diff --git a/sharktank/tests/pipelines/flux/flux_pipeline_test.py b/sharktank/tests/pipelines/flux/flux_pipeline_test.py index ef9237c1e..e0bb696ae 100644 --- a/sharktank/tests/pipelines/flux/flux_pipeline_test.py +++ b/sharktank/tests/pipelines/flux/flux_pipeline_test.py @@ -152,153 +152,3 @@ def testFluxPipelineBF16(self): ) -@pytest.mark.usefixtures("caching", "get_model_artifacts", "path_prefix") -class FluxPipelineIreeTest(TempDirTestBase): - def setUp(self): - super().setUp() - if self.path_prefix is None: - self.path_prefix = f"{self._temp_dir}/" - - def runTestFluxPipelineIreeCompare( - self, - reference_dtype: torch.dtype, - target_dtype: torch.dtype, - atol: Optional[float] = None, - rtol: Optional[float] = None, - ): - """Compare IREE pipeline against eager execution.""" - # Create input tokens - t5_tokenizer = T5Tokenizer.from_pretrained("/data/flux/FLUX.1-dev/tokenizer_2/") - clip_tokenizer = CLIPTokenizer.from_pretrained("/data/flux/FLUX.1-dev/tokenizer/") - - prompt = "a photo of a forest with mist" - t5_prompt_ids = torch.tensor([t5_tokenizer(prompt).input_ids], dtype=torch.long) - clip_prompt_ids = torch.tensor([clip_tokenizer(prompt).input_ids], dtype=torch.long) - # latents = reference_model._get_noise( - # 1, - # 1024, - # 1024, - # seed=12345, - # ).to(dtype=target_dtype) # TODO: it isn't great to be getting this from the reference model - - # input_args = OrderedDict([ - # ("t5_prompt_ids", t5_prompt_ids), - # ("clip_prompt_ids", clip_prompt_ids), - # ("latents", latents) - # ]) - batch_size = t5_prompt_ids.shape[0] - - # Export and compile for IREE - target_dtype_name = dtype_to_serialized_short_name(target_dtype) - target_path_prefix = f"{self.path_prefix}flux_pipeline_{target_dtype_name}" - - parameters_path = "/data/flux/FLUX.1-dev/" - if not self.caching or not os.path.exists(mlir_path): - export_flux_pipeline_iree_parameters( - "/data/flux/FLUX.1-dev/", - parameters_path, - dtype=target_dtype, - ) - - 
mlir_path = f"{target_path_prefix}.mlir" - if not self.caching or not os.path.exists(mlir_path): - export_flux_pipeline_mlir( - parameters_path, - batch_sizes=[batch_size], - mlir_output_path=mlir_path, - dtype=target_dtype - ) - - iree_module_path = f"{target_path_prefix}.vmfb" - if not self.caching or not os.path.exists(iree_module_path): - iree.compiler.compile_file( - mlir_path, - output_file=iree_module_path, - extra_args=[ - "--iree-hal-target-device=hip", - "--iree-hip-target=gfx942", - "--iree-opt-const-eval=false", - "--iree-opt-strip-assertions=true", - "--iree-global-opt-propagate-transposes=true", - "--iree-dispatch-creation-enable-fuse-horizontal-contractions=true", - "--iree-dispatch-creation-enable-aggressive-fusion=true", - "--iree-opt-aggressively-propagate-transposes=true", - "--iree-opt-outer-dim-concat=true", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-llvmgpu-enable-prefetch=true", - "--iree-opt-data-tiling=false", - "--iree-codegen-gpu-native-math-precision=true", - "--iree-codegen-llvmgpu-use-vector-distribution", - "--iree-hip-waves-per-eu=2", - "--iree-execution-model=async-external", - "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics)", - ], - ) - - # Run with IREE - iree_devices = get_iree_devices(driver="hip", device_count=1) - iree_module, iree_vm_context, iree_vm_instance = load_iree_module( - module_path=iree_module_path, - devices=iree_devices, - parameters_path=parameters_path, - ) - iree_args = prepare_iree_module_function_args( - args=flatten_for_iree_signature(input_args), - devices=iree_devices, - ) - iree_result = iree_to_torch( - *run_iree_module_function( - module=iree_module, - vm_context=iree_vm_context, - args=iree_args, - driver="hip", - function_name=f"forward_bs{batch_size}", - trace_path_prefix=f"{target_path_prefix}_iree_", - ) - ) - - # Reference model - reference_model = FluxPipeline( - t5_path="/data/t5-v1_1-xxl/model.gguf", - clip_path="/data/flux/FLUX.1-dev/text_encoder/model.irpa", - transformer_path="/data/flux/FLUX.1-dev/transformer/model.irpa", - ae_path="/data/flux/FLUX.1-dev/vae/model.irpa", - t5_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer_2/", - clip_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer/", - dtype=reference_dtype, - ) - # reference_result = reference_model.forward(t5_prompt_ids, clip_prompt_ids, latents) - - # Reformat the result for direct comparison - iree_result = [ - ops.to(iree_result[i], dtype=reference_result[i].dtype) - for i in range(len(reference_result)) - ] - - - - torch.testing.assert_close(reference_result, iree_result, atol=atol, rtol=rtol) - - @with_flux_data - def testFluxPipelineIreeF32(self): - """Test F32 IREE pipeline against eager execution.""" - self.runTestFluxPipelineIreeCompare( - reference_dtype=torch.float32, - target_dtype=torch.float32, - atol=1e-4, - rtol=2.0e-3, - ) - - @pytest.mark.xfail( - raises=AssertionError, - reason="BF16 vs F32 accuracy needs investigation", - ) - @with_flux_data - def testFluxPipelineIreeBF16(self): - """Test BF16 IREE pipeline against F16 eager execution.""" - self.runTestFluxPipelineIreeCompare( - reference_dtype=torch.float16, - target_dtype=torch.bfloat16, - atol=1e-2, - rtol=1.6e-2, - ) diff --git a/shortfin/python/shortfin_apps/flux/components/config_struct.py b/shortfin/python/shortfin_apps/flux/components/config_struct.py index d99f8f793..0e23285e1 100644 --- a/shortfin/python/shortfin_apps/flux/components/config_struct.py +++ 
b/shortfin/python/shortfin_apps/flux/components/config_struct.py @@ -78,6 +78,9 @@ class ModelParams: # ABI of the module. module_abi_version: int = 1 + # TODO: Understand when this should be a value other than 1 + cfg_mult: int = 1 + @property def max_clip_batch_size(self) -> int: return self.clip_batch_sizes[-1] diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json index 93c7dbfae..343bdea28 100644 --- a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json +++ b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json @@ -8,24 +8,24 @@ "clip_batch_sizes": [ 1 ], - "clip_dtype": "float32", + "clip_dtype": "bfloat16", "clip_module_name": "compiled_flux_text_encoder", "t5xxl_batch_sizes": [ 1 ], - "t5xxl_dtype": "float32", - "t5xxl_module_name": "compiled_flux_text_encoder_2", + "t5xxl_dtype": "bfloat16", + "t5xxl_module_name": "compiled_flux_text_encoder2", "t5xxl_fn_name": "encode_prompts", "sampler_batch_sizes": [ 1 ], - "sampler_dtype": "float32", + "sampler_dtype": "bfloat16", "sampler_module_name": "compiled_flux_transformer", "sampler_fn_name": "run_forward", "vae_batch_sizes": [ 1 ], - "vae_dtype": "float32", + "vae_dtype": "bfloat16", "vae_module_name": "compiled_flux_auto_encoder", "vae_fn_name": "decode", "dims": [ From 53bf17a08b4d2d89c4dbce44761959ca3475ae72 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 6 Feb 2025 04:00:10 +0000 Subject: [PATCH 16/20] Fix grayscale bug and allow mixing different types of decoders with the denoiser --- .../shortfin_apps/flux/components/service.py | 54 ++++++++++++++++--- 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py index c539a6779..3b903e2bd 100644 --- a/shortfin/python/shortfin_apps/flux/components/service.py +++ b/shortfin/python/shortfin_apps/flux/components/service.py @@ -592,7 +592,7 @@ async def _t5xxl(self, device, requests): async def _denoise(self, device, requests): req_bs = len(requests) step_count = requests[0].steps - cfg_mult = self.requests[0].cfg_mult + cfg_mult = requests[0].cfg_mult # Produce denoised latents entrypoints = self.service.inference_functions[self.worker_index]["denoise"] @@ -659,21 +659,59 @@ async def _denoise(self, device, requests): sample_host.view(slice(cfg_dim + rep, cfg_dim + rep + 1)).copy_from( req_samp ) - denoise_inputs["img"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( sample_host ) # Batch t5xxl hidden states. txt = requests[i].txt - denoise_inputs["txt"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( - txt - ) + if ( + self.service.model_params.t5xxl_dtype + != self.service.model_params.sampler_dtype + ): + inter = sfnp.device_array.for_host( + device, txt_shape, dtype=self.service.model_params.sampler_dtype + ) + host = sfnp.device_array.for_host( + device, txt_shape, dtype=self.service.model_params.t5xxl_dtype + ) + host.view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from(txt) + await device + sfnp.convert( + host, + dtype=self.service.model_params.sampler_dtype, + out=inter, + ) + denoise_inputs["txt"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from(inter) + else: + denoise_inputs["txt"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( + txt + ) # Batch CLIP projections. 
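The hunk above repeats one staging pattern whenever an encoder output dtype differs from the sampler dtype: copy into a host array of the source dtype, convert into an intermediate of the sampler dtype, then copy that into the batched denoise input. A minimal sketch of the pattern, using only the `sfnp` calls already present in this hunk (the function and argument names are illustrative):

```python
async def stage_to_sampler_dtype(device, src, shape, src_dtype, sampler_dtype, dst_view):
    # `sfnp` is the shortfin array namespace used throughout this file.
    # Land the encoder output in a host array of its own dtype first.
    host = sfnp.device_array.for_host(device, shape, dtype=src_dtype)
    host.copy_from(src)
    await device
    # Convert into the sampler dtype, then batch it into the denoise input.
    inter = sfnp.device_array.for_host(device, shape, dtype=sampler_dtype)
    sfnp.convert(host, dtype=sampler_dtype, out=inter)
    dst_view.copy_from(inter)
```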
vec = requests[i].vec - for nc in range(cfg_mult): - denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) + if ( + self.service.model_params.t5xxl_dtype + != self.service.model_params.sampler_dtype + ): + for nc in range(cfg_mult): + inter = sfnp.device_array.for_host( + device, vec_shape, dtype=self.service.model_params.sampler_dtype + ) + host = sfnp.device_array.for_host( + device, vec_shape, dtype=self.service.model_params.clip_dtype + ) + host.view(slice(nc, nc + 1)).copy_from(vec) + await device + sfnp.convert( + host, + dtype=self.service.model_params.sampler_dtype, + out=inter, + ) + denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(inter) + else: + for nc in range(cfg_mult): + denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) sfnp.convert( guidance_float, dtype=self.service.model_params.sampler_dtype, out=gs_host ) @@ -778,6 +816,7 @@ async def _decode(self, device, requests): fn, "".join([f"\n 0: {latents.shape}"]), ) + await device (image,) = await fn(latents, fiber=self.fiber) await device images_shape = [ @@ -810,6 +849,7 @@ async def _postprocess(self, device, requests): device, image_shape, self.service.model_params.vae_dtype ) images_planar.copy_from(req.image_array) + await device permuted = sfnp.device_array.for_host( device, out_shape, self.service.model_params.vae_dtype ) From 63d46f42eac72d94a3459d28d1f11fa34d0e7816 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Thu, 6 Feb 2025 22:50:26 +0000 Subject: [PATCH 17/20] Latest fixes --- shortfin/CMakeLists.txt | 9 +- shortfin/dev_me.py | 2 +- shortfin/python/array_host_ops.cc | 109 ++++++++++++++++++ .../shortfin_apps/flux/components/messages.py | 2 + .../flux/examples/flux_dev_config_mixed.json | 4 +- 5 files changed, 122 insertions(+), 4 deletions(-) diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt index 0c69a08a0..c2cbb917c 100644 --- a/shortfin/CMakeLists.txt +++ b/shortfin/CMakeLists.txt @@ -164,6 +164,13 @@ if(SHORTFIN_BUNDLE_DEPS) GIT_REPOSITORY https://github.com/gabime/spdlog.git GIT_TAG 8e5613379f5140fefb0b60412fbf1f5406e7c7f8 # v1.15.0 ) + + ## xsimd: required for bf16 + FetchContent_Declare( + xsimd + GIT_REPOSITORY https://github.com/xtensor-stack/xsimd.git + GIT_TAG 148fa1328c674ab2ee1d03b1460204671ae82a8b # v13.1.0 + ) ## xtl: required for xtensor FetchContent_Declare( @@ -185,7 +192,7 @@ if(SHORTFIN_BUNDLE_DEPS) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSPDLOG_SHARED_LIB -Dspdlog_EXPORTS") message(STATUS "Fetching bundled projects") list(APPEND CMAKE_MESSAGE_INDENT " ") - FetchContent_MakeAvailable(fmt spdlog xtl xtensor) + FetchContent_MakeAvailable(fmt spdlog xsimd xtl xtensor) shortfin_pop_bundled_lib_options() list(POP_BACK CMAKE_MESSAGE_INDENT) else() diff --git a/shortfin/dev_me.py b/shortfin/dev_me.py index c2d95c220..87f18ec7f 100755 --- a/shortfin/dev_me.py +++ b/shortfin/dev_me.py @@ -253,7 +253,7 @@ def configure_mode(env_info: EnvInfo, args): def build_mode(env_info: EnvInfo): print("Building") for build_dir in env_info.configured_dirs: - subprocess.check_call([env_info.cmake_exe, "--build", str(build_dir)]) + subprocess.check_call([env_info.cmake_exe, "--build", str(build_dir), "--verbose"]) if __name__ == "__main__": diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc index 3e2a8ebe3..302cdbd45 100644 --- a/shortfin/python/array_host_ops.cc +++ b/shortfin/python/array_host_ops.cc @@ -12,6 +12,101 @@ #include "xtensor/xsort.hpp" #include "xtl/xhalf_float.hpp" + +#ifndef BFLOAT16_HPP +#define BFLOAT16_HPP + +#include 
+#include +#include +#include + +// A minimal bfloat16 type as a trivial wrapper over a 16-bit value. +struct bfloat16_t { + uint16_t value; + + // Default constructor: zero. + constexpr bfloat16_t() noexcept : value(0) {} + + explicit constexpr bfloat16_t(float f) noexcept { + // reinterpret f as uint32_t + uint32_t temp = std::bit_cast(f); + // drop the lower 16 bits + value = static_cast(temp >> 16); + } + + template && + !std::is_same_v>> + constexpr bfloat16_t(T value) noexcept : bfloat16_t(static_cast(value)) {} + + constexpr operator float() const noexcept { + // shift stored bits to the high half of a 32-bit word + uint32_t temp = static_cast(value) << 16; + return std::bit_cast(temp); + } + + // Arithmetic operators (implemented via conversion to float) + constexpr bfloat16_t operator+(const bfloat16_t& other) const noexcept { + return bfloat16_t(float(*this) + float(other)); + } + constexpr bfloat16_t operator-(const bfloat16_t& other) const noexcept { + return bfloat16_t(float(*this) - float(other)); + } + constexpr bfloat16_t operator*(const bfloat16_t& other) const noexcept { + return bfloat16_t(float(*this) * float(other)); + } + constexpr bfloat16_t operator/(const bfloat16_t& other) const noexcept { + return bfloat16_t(float(*this) / float(other)); + } + + constexpr bfloat16_t& operator+=(const bfloat16_t& other) noexcept { + *this = *this + other; + return *this; + } + constexpr bfloat16_t& operator-=(const bfloat16_t& other) noexcept { + *this = *this - other; + return *this; + } + constexpr bfloat16_t& operator*=(const bfloat16_t& other) noexcept { + *this = *this * other; + return *this; + } + constexpr bfloat16_t& operator/=(const bfloat16_t& other) noexcept { + *this = *this / other; + return *this; + } + + // Comparison operators (using conversion to float) + constexpr bool operator==(const bfloat16_t& other) const noexcept { + return float(*this) == float(other); + } + constexpr bool operator!=(const bfloat16_t& other) const noexcept { + return !(*this == other); + } + constexpr bool operator<(const bfloat16_t& other) const noexcept { + return float(*this) < float(other); + } + constexpr bool operator<=(const bfloat16_t& other) const noexcept { + return float(*this) <= float(other); + } + constexpr bool operator>(const bfloat16_t& other) const noexcept { + return float(*this) > float(other); + } + constexpr bool operator>=(const bfloat16_t& other) const noexcept { + return float(*this) >= float(other); + } +}; + +// Mark bfloat16_t as a trivial, standard-layout type so that xtensor can use it. 
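The conversion this struct performs is plain bit-pattern truncation: keep the high 16 bits of the float32 representation on the way in, and shift them back into the high half on the way out. A minimal Python sketch of the same round-trip, for illustration only:

```python
import struct

def float_to_bf16_bits(f: float) -> int:
    # Reinterpret the float32 bit pattern and keep only the high 16 bits
    # (truncation, no rounding, matching the constructor above).
    (bits,) = struct.unpack("<I", struct.pack("<f", f))
    return bits >> 16

def bf16_bits_to_float(b: int) -> float:
    # Shift the stored 16 bits back into the high half of a float32 word.
    (f,) = struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))
    return f

print(bf16_bits_to_float(float_to_bf16_bits(3.14159)))  # 3.140625: the low mantissa bits are lost
```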
+namespace std { + template<> struct is_trivial : std::true_type {}; + template<> struct is_standard_layout : std::true_type {}; + template<> struct is_trivially_copyable : std::true_type {}; +} + +#endif // BFLOAT16_HPP + + using namespace shortfin::array; namespace shortfin::python { @@ -191,6 +286,7 @@ struct ConvertFunctor { } switch (dtype) { SF_STORE_CASE(float16, half_float::half); + SF_STORE_CASE(bfloat16, bfloat16_t); SF_STORE_CASE(float32, float); SF_STORE_CASE(float64, double); SF_STORE_CASE(uint8, uint8_t); @@ -210,6 +306,7 @@ struct ConvertFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); SF_UNARY_THUNK_CASE(float64, double); SF_UNARY_THUNK_CASE(uint8, uint8_t); @@ -264,6 +361,7 @@ struct ConvertRoundFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); default: throw std::invalid_argument(fmt::format( @@ -308,6 +406,7 @@ struct ConvertCeilFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); default: throw std::invalid_argument(fmt::format( @@ -352,6 +451,7 @@ struct ConvertFloorFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); default: throw std::invalid_argument(fmt::format( @@ -396,6 +496,7 @@ struct ConvertTruncFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); default: throw std::invalid_argument(fmt::format( @@ -525,6 +626,11 @@ half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) { return static_cast(py::cast(py_value)); } +bfloat16_t ConvertPyToEltTy(py::handle py_value, bfloat16_t zero) { + // Python can't cast directly to half so first go to double. 
+ return static_cast(py::cast(py_value)); +} + struct AddFunctor { template static auto Invoke(Lhs &&lhs, Rhs &&rhs) { @@ -610,6 +716,7 @@ device_array ElementwiseOperation(py::handle lhs, py::handle rhs, switch (dtype) { SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t); SF_UNARY_FUNCTION_CASE(float32, float); SF_UNARY_FUNCTION_CASE(float64, double); SF_UNARY_FUNCTION_CASE(uint8, uint8_t); @@ -661,6 +768,7 @@ void BindArrayHostOps(py::module_ &m) { switch (input.dtype()) { SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t); SF_UNARY_FUNCTION_CASE(float32, float); default: throw std::invalid_argument( @@ -690,6 +798,7 @@ void BindArrayHostOps(py::module_ &m) { switch (out.dtype()) { SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t); SF_UNARY_FUNCTION_CASE(float32, float); default: throw std::invalid_argument( diff --git a/shortfin/python/shortfin_apps/flux/components/messages.py b/shortfin/python/shortfin_apps/flux/components/messages.py index 8646da1f6..442c1beb7 100644 --- a/shortfin/python/shortfin_apps/flux/components/messages.py +++ b/shortfin/python/shortfin_apps/flux/components/messages.py @@ -48,6 +48,7 @@ def __init__( steps: int | None = None, guidance_scale: float | sfnp.device_array | None = None, seed: int | None = None, + cfg_mult: int = 1, clip_input_ids: list[list[int]] | None = None, t5xxl_input_ids: list[list[int]] | None = None, sample: sfnp.device_array | None = None, @@ -74,6 +75,7 @@ def __init__( self.height = height self.width = width self.seed = seed + self.cfg_mult = 1 # Encode phase. # This is a list of sequenced positive and negative token ids and pooler token ids. diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json index 343bdea28..688807290 100644 --- a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json +++ b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json @@ -13,8 +13,8 @@ "t5xxl_batch_sizes": [ 1 ], - "t5xxl_dtype": "bfloat16", - "t5xxl_module_name": "compiled_flux_text_encoder2", + "t5xxl_dtype": "float32", + "t5xxl_module_name": "compiled_flux_text_encoder_2", "t5xxl_fn_name": "encode_prompts", "sampler_batch_sizes": [ 1 From b2aa5c88d3297951a2f75a1bcafab3a63581b68c Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 10 Feb 2025 13:37:00 -0800 Subject: [PATCH 18/20] Typo fixes and style correction --- sharktank/sharktank/models/clip/export.py | 2 +- .../sharktank/pipelines/flux/flux_pipeline.py | 46 ++++--------------- .../shortfin_apps/flux/components/builders.py | 12 ++--- .../flux/components/config_artifacts.py | 12 ++--- .../flux/components/config_struct.py | 2 +- 5 files changed, 22 insertions(+), 52 deletions(-) diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py index 66fd99420..1b32d5e49 100644 --- a/sharktank/sharktank/models/clip/export.py +++ b/sharktank/sharktank/models/clip/export.py @@ -52,7 +52,7 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset: return Dataset(properties=model.config.to_properties(), root_theta=model.theta) -def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike, dtype=None): +def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike, dtype: torch.dtype = None): dataset = clip_text_model_to_dataset(model) if 
dtype: dataset.root_theta = dataset.root_theta.transform( diff --git a/sharktank/sharktank/pipelines/flux/flux_pipeline.py b/sharktank/sharktank/pipelines/flux/flux_pipeline.py index 12fc4b9bb..4618c6509 100644 --- a/sharktank/sharktank/pipelines/flux/flux_pipeline.py +++ b/sharktank/sharktank/pipelines/flux/flux_pipeline.py @@ -13,6 +13,7 @@ from torch import Tensor from transformers import CLIPTokenizer, T5Tokenizer, CLIPTextModel as HfCLIPTextModel +from sharktank.layers.base import BaseLayer from sharktank.models.t5 import T5Config, T5Encoder from sharktank.models.clip import ClipTextModel, ClipTextConfig from sharktank.models.flux.flux import FluxModelV1, FluxParams @@ -20,7 +21,7 @@ from sharktank.types import Dataset from sharktank.transforms.dataset import set_float_dtype -class FluxPipeline(nn.Module): +class FluxPipeline(BaseLayer): """Pipeline for text-to-image generation using the Flux model.""" def __init__( @@ -31,12 +32,15 @@ def __init__( ae_path: PathLike, t5_tokenizer_path: Optional[PathLike] = None, clip_tokenizer_path: Optional[PathLike] = None, - device: str = "cuda" if torch.cuda.is_available() else "cpu", + device: str = None, dtype: torch.dtype = torch.bfloat16, ): """Initialize the Flux pipeline.""" super().__init__() - self.device = torch.device(device) + if device: + self.device = torch.device(device) + else: + self.device = torch.get_default_device() self.dtype = dtype if t5_tokenizer_path: self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_tokenizer_path) @@ -49,9 +53,6 @@ def __init__( t5_dataset.properties, feed_forward_proj="gated-gelu", ) - # t5_dataset.root_theta = t5_dataset.root_theta.transform( - # functools.partial(set_float_dtype, dtype=dtype) - # ) self.t5_model = T5Encoder(theta=t5_dataset.root_theta, config=t5_config) self.add_module('t5_model', self.t5_model) self.t5_model.to(device) @@ -61,9 +62,6 @@ def __init__( # TODO: Refactor CLIP to not make the config rely on HuggingFace hf_clip_model = HfCLIPTextModel.from_pretrained("/data/flux/FLUX.1-dev/text_encoder/") clip_config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_clip_model.config) - # clip_dataset.root_theta = clip_dataset.root_theta.transform( - # functools.partial(set_float_dtype, dtype=dtype) - # ) self.clip_model = ClipTextModel(theta=clip_dataset.root_theta, config=clip_config) self.add_module('clip_model', self.clip_model) self.clip_model.to(device) @@ -71,9 +69,6 @@ def __init__( # Load Flux Transformer transformer_dataset = Dataset.load(transformer_path) transformer_params = FluxParams.from_hugging_face_properties(transformer_dataset.properties) - # transformer_dataset.root_theta = transformer_dataset.root_theta.transform( - # functools.partial(set_float_dtype, dtype=dtype) - # ) self.transformer_model = FluxModelV1( theta=transformer_dataset.root_theta, params=transformer_params @@ -83,36 +78,11 @@ def __init__( # Load VAE ae_dataset = Dataset.load(ae_path) - # ae_dataset.root_theta = ae_dataset.root_theta.transform( - # functools.partial(set_float_dtype, dtype=dtype) - # ) self.ae_model = VaeDecoderModel.from_dataset(ae_dataset) self.add_module('ae_model', self.ae_model) self.ae_model.to(device) self._rng = torch.Generator(device="cpu") - - def _get_noise( - self, - num_samples: int, - height: int, - width: int, - seed: Optional[int] = None, - ) -> Tensor: - """Generate initial noise for the diffusion process.""" - if seed is not None: - self._rng.manual_seed(seed) - - return torch.randn( - num_samples, - 16, - # allow for packing - 2 * math.ceil(height / 16), - 2 
* math.ceil(width / 16), - device=self.device, - dtype=self.dtype, - generator=self._rng, - ) def __call__( self, @@ -145,7 +115,7 @@ def __call__( t5_prompt_ids, clip_prompt_ids = self.tokenize_prompt(prompt) if not latents: - latents = self._get_noise( + latents = self.transformer_model._get_noise( 1, height, width, diff --git a/shortfin/python/shortfin_apps/flux/components/builders.py b/shortfin/python/shortfin_apps/flux/components/builders.py index 7a2c2ce60..09d9e388c 100644 --- a/shortfin/python/shortfin_apps/flux/components/builders.py +++ b/shortfin/python/shortfin_apps/flux/components/builders.py @@ -27,10 +27,10 @@ } ARTIFACT_VERSION = "12032024" -SDXL_BUCKET = ( +FLUX_BUCKET = ( f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/" ) -SDXL_WEIGHTS_BUCKET = ( +FLUX_WEIGHTS_BUCKET = ( "https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/weights/" ) @@ -237,7 +237,7 @@ def _invoke(self, retries=4): self._invoke(retries=retries) -@entrypoint(description="Retreives a set of SDXL submodels.") +@entrypoint(description="Retreives a set of FLUX submodels.") def flux( model_json=cl_arg( "model-json", @@ -263,8 +263,8 @@ def flux( ctx = executor.BuildContext.current() update = needs_update(ctx) - mlir_bucket = SDXL_BUCKET + "mlir/" - vmfb_bucket = SDXL_BUCKET + "vmfbs/" + mlir_bucket = FLUX_BUCKET + "mlir/" + vmfb_bucket = FLUX_BUCKET + "vmfbs/" if "gfx" in target: target = "amdgpu-" + target @@ -295,7 +295,7 @@ def flux( fetch_http(name=f, url=url) params_filenames = get_params_filenames(model_params, model=model, splat=splat) - params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) + params_urls = get_url_map(params_filenames, FLUX_WEIGHTS_BUCKET) for f, url in params_urls.items(): if needs_file(f, ctx, url): fetch_http_check_size(name=f, url=url) diff --git a/shortfin/python/shortfin_apps/flux/components/config_artifacts.py b/shortfin/python/shortfin_apps/flux/components/config_artifacts.py index 8c779ebc6..276eb6834 100644 --- a/shortfin/python/shortfin_apps/flux/components/config_artifacts.py +++ b/shortfin/python/shortfin_apps/flux/components/config_artifacts.py @@ -9,7 +9,7 @@ import os ARTIFACT_VERSION = "12032024" -SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/configs/" +FLUX_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/configs/" def get_url_map(filenames: list[str], bucket: str): @@ -48,7 +48,7 @@ def needs_file(filename, ctx, namespace=FileNamespace.GEN): return needed -@entrypoint(description="Retreives a set of SDXL configuration files.") +@entrypoint(description="Retreives a set of FLUX configuration files.") def sdxlconfig( target=cl_arg( "target", @@ -67,7 +67,7 @@ def sdxlconfig( update = needs_update(ctx) # model_config_filenames = [f"{model}_config_i8.json"] - # model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET) + # model_config_urls = get_url_map(model_config_filenames, FLUX_CONFIG_BUCKET) # for f, url in model_config_urls.items(): # if update or needs_file(f, ctx): # fetch_http(name=f, url=url) @@ -75,14 +75,14 @@ def sdxlconfig( if topology: topology_config_filenames = [f"topology_config_{topology}.txt"] topology_config_urls = get_url_map( - topology_config_filenames, SDXL_CONFIG_BUCKET + topology_config_filenames, FLUX_CONFIG_BUCKET ) for f, url in topology_config_urls.items(): if update or needs_file(f, ctx): fetch_http(name=f, url=url) # flagfile_filenames = 
[f"{model}_flagfile_{target}.txt"] - # flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET) + # flagfile_urls = get_url_map(flagfile_filenames, FLUX_CONFIG_BUCKET) # for f, url in flagfile_urls.items(): # if update or needs_file(f, ctx): # fetch_http(name=f, url=url) @@ -90,7 +90,7 @@ def sdxlconfig( tuning_filenames = ( [f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else [] ) - tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) + tuning_urls = get_url_map(tuning_filenames, FLUX_CONFIG_BUCKET) for f, url in tuning_urls.items(): if update or needs_file(f, ctx): fetch_http(name=f, url=url) diff --git a/shortfin/python/shortfin_apps/flux/components/config_struct.py b/shortfin/python/shortfin_apps/flux/components/config_struct.py index 0e23285e1..35bae3814 100644 --- a/shortfin/python/shortfin_apps/flux/components/config_struct.py +++ b/shortfin/python/shortfin_apps/flux/components/config_struct.py @@ -71,7 +71,7 @@ class ModelParams: vae_fn_name: str = "decode" vae_dtype: sfnp.DType = sfnp.float32 - # Whether model is "schnell" (fast) or not. This is roughly equivalent to "turbo" from SDXL. + # Whether model is "schnell" (fast) or not. This is roughly equivalent to "turbo" from FLUX. # It cuts batch dims in half for sampling/encoding and removes negative prompt functionality. is_schnell: bool = False From 98870e384732ec362b7d19f152dfee36db9d5aa3 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 10 Feb 2025 14:35:56 -0800 Subject: [PATCH 19/20] Additional cleanup --- .../sharktank/pipelines/flux/flux_pipeline.py | 47 +++++++++++-------- .../pipelines/flux/flux_pipeline_test.py | 38 +++------------ .../shortfin_apps/flux/components/builders.py | 2 +- 3 files changed, 36 insertions(+), 51 deletions(-) diff --git a/sharktank/sharktank/pipelines/flux/flux_pipeline.py b/sharktank/sharktank/pipelines/flux/flux_pipeline.py index 4618c6509..10b982dde 100644 --- a/sharktank/sharktank/pipelines/flux/flux_pipeline.py +++ b/sharktank/sharktank/pipelines/flux/flux_pipeline.py @@ -122,17 +122,16 @@ def __call__( seed=seed, ) - with torch.inference_mode(): - return self.forward( - t5_prompt_ids, - clip_prompt_ids, - latents, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - seed=seed, - ) + return self.forward( + t5_prompt_ids, + clip_prompt_ids, + latents, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) def forward( @@ -142,7 +141,7 @@ def forward( latents: Tensor, height: int = 1024, width: int = 1024, - num_inference_steps: Optional[int] = 1, # TODO: DO NOT SUBMIT + num_inference_steps: Optional[int] = None, guidance_scale: float = 3.5, seed: Optional[int] = None, ) -> Tensor: @@ -171,9 +170,6 @@ def forward( # Decode latents x = self._unpack(x.to(dtype=self.dtype), height, width) x = self.ae_model(x) - - if torch.cuda.is_available(): - torch.cuda.synchronize() x = x[0] x = x.cpu() @@ -182,7 +178,7 @@ def forward( return x.float() - def _prepare(self, t5, clip, t5_prompt_ids, clip_prompt_ids, img: Tensor) -> dict[str, Tensor]: + def _prepare(self, t5: T5Encoder, clip: ClipTextModel, t5_prompt_ids: Tensor, clip_prompt_ids: Tensor, img: Tensor) -> dict[str, Tensor]: """Prepare inputs for the transformer model. 
Args: @@ -278,7 +274,7 @@ def _denoise( vec: Tensor, # sampling parameters timesteps: list[float], - guidance: float = 4.0, + guidance: float = 3.5, # extra img tokens img_cond: Optional[Tensor] = None, ) -> Tensor: @@ -323,11 +319,24 @@ def tokenize_prompt(self, prompt: str) -> tuple[Tensor, Tensor]: Tuple of (t5_prompt_ids, clip_prompt_ids) tensors """ # T5 tokenization - t5_prompt_ids = [self.t5_tokenizer(p).input_ids for p in [prompt]] + t5_prompt_ids = [self.t5_tokenizer(p, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt",).input_ids for p in [prompt] + ] t5_prompt_ids = torch.tensor(t5_prompt_ids, dtype=torch.long) # CLIP tokenization - clip_prompt_ids = [self.clip_tokenizer(p).input_ids for p in [prompt]] + clip_prompt_ids = [self.clip_tokenizer(p, + truncation=True, + max_length=self.max_length, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt",).input_ids for p in [prompt]] clip_prompt_ids = torch.tensor(clip_prompt_ids, dtype=torch.long) return t5_prompt_ids, clip_prompt_ids diff --git a/sharktank/tests/pipelines/flux/flux_pipeline_test.py b/sharktank/tests/pipelines/flux/flux_pipeline_test.py index e0bb696ae..8870b4c85 100644 --- a/sharktank/tests/pipelines/flux/flux_pipeline_test.py +++ b/sharktank/tests/pipelines/flux/flux_pipeline_test.py @@ -4,41 +4,17 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# DO NOT SUBMIT: REVIEW AND TEST FILE - """Tests for Flux text-to-image pipeline.""" -import functools from typing import Optional -import os -from collections import OrderedDict import pytest import torch from unittest import TestCase import numpy -from transformers import CLIPTokenizer, T5Tokenizer from diffusers import FluxPipeline as ReferenceFluxPipeline -from sharktank.types import Dataset, dtype_to_serialized_short_name -from sharktank.pipelines.flux import ( - FluxPipeline, - export_flux_pipeline_mlir, - export_flux_pipeline_iree_parameters, -) -from sharktank.utils.testing import TempDirTestBase -from sharktank.transforms.dataset import set_float_dtype -from sharktank.utils.iree import ( - get_iree_devices, - load_iree_module, - run_iree_module_function, - prepare_iree_module_function_args, - call_torch_module_function, - flatten_for_iree_signature, - iree_to_torch, -) -from sharktank import ops -import iree.compiler +from sharktank.pipelines.flux import FluxPipeline with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')") @@ -87,13 +63,13 @@ def testFluxPipelineAgainstGolden(self): torch.testing.assert_close(output, reference_output) # TODO: why is this not passing? 
- def runTestFluxPipelineAgainstReference( + def runTestFluxPipelineAgainstHuggingFace( self, dtype: torch.dtype, atol: Optional[float] = None, rtol: Optional[float] = None, ): - """Compare pipeline outputs between different dtypes.""" + """Compare pipeline outputs against a HuggingFace based reference""" # Initialize reference model reference_model = ReferenceFluxPipeline.from_pretrained("/data/flux/FLUX.1-dev/") @@ -138,16 +114,16 @@ def runTestFluxPipelineAgainstReference( torch.testing.assert_close(reference_output, target_output, atol=atol, rtol=rtol) @with_flux_data - def testFluxPipelineF32(self): + def testFluxPipelineF32AgainstHuggingFace(self): """Test F32 pipeline against reference.""" - self.runTestFluxPipelineAgainstReference( + self.runTestFluxPipelineAgainstHuggingFace( dtype=torch.float32, ) @with_flux_data - def testFluxPipelineBF16(self): + def testFluxPipelineBF16AgainstHuggingFace(self): """Test BF16 pipeline against refence.""" - self.runTestFluxPipelineAgainstReference( + self.runTestFluxPipelineAgainstHuggingFace( dtype=torch.bfloat16, ) diff --git a/shortfin/python/shortfin_apps/flux/components/builders.py b/shortfin/python/shortfin_apps/flux/components/builders.py index 09d9e388c..92a4f8045 100644 --- a/shortfin/python/shortfin_apps/flux/components/builders.py +++ b/shortfin/python/shortfin_apps/flux/components/builders.py @@ -26,7 +26,7 @@ sfnp.bfloat16: "bf16", } -ARTIFACT_VERSION = "12032024" +ARTIFACT_VERSION = "02102024" FLUX_BUCKET = ( f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/" ) From 88887dc0b2c281f37115e4d871bf4a200f35e98f Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Mon, 10 Feb 2025 14:48:20 -0800 Subject: [PATCH 20/20] More cleanup --- sharktank/sharktank/pipelines/flux/flux_pipeline.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sharktank/sharktank/pipelines/flux/flux_pipeline.py b/sharktank/sharktank/pipelines/flux/flux_pipeline.py index 10b982dde..4faddc092 100644 --- a/sharktank/sharktank/pipelines/flux/flux_pipeline.py +++ b/sharktank/sharktank/pipelines/flux/flux_pipeline.py @@ -1,17 +1,15 @@ """Flux text-to-image generation pipeline.""" import argparse -import functools import math from os import PathLike from typing import Callable, Optional import torch -import torch.nn as nn from einops import rearrange, repeat from PIL import Image from torch import Tensor -from transformers import CLIPTokenizer, T5Tokenizer, CLIPTextModel as HfCLIPTextModel +from transformers import CLIPTokenizer, T5Tokenizer from sharktank.layers.base import BaseLayer from sharktank.models.t5 import T5Config, T5Encoder @@ -19,7 +17,6 @@ from sharktank.models.flux.flux import FluxModelV1, FluxParams from sharktank.models.vae.model import VaeDecoderModel from sharktank.types import Dataset -from sharktank.transforms.dataset import set_float_dtype class FluxPipeline(BaseLayer): """Pipeline for text-to-image generation using the Flux model.""" @@ -59,9 +56,7 @@ def __init__( # Load CLIP clip_dataset = Dataset.load(clip_path) - # TODO: Refactor CLIP to not make the config rely on HuggingFace - hf_clip_model = HfCLIPTextModel.from_pretrained("/data/flux/FLUX.1-dev/text_encoder/") - clip_config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_clip_model.config) + clip_config = ClipTextConfig.from_properties(clip_dataset.properties) self.clip_model = ClipTextModel(theta=clip_dataset.root_theta, config=clip_config) self.add_module('clip_model', self.clip_model) self.clip_model.to(device) @@ 
-220,6 +215,7 @@ def _prepare(self, t5: T5Encoder, clip: ClipTextModel, t5_prompt_ids: Tensor, cl def _time_shift(self, mu: float, sigma: float, t: Tensor) -> Tensor: """Apply time shift to the schedule.""" return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + def _get_lin_function( self, x1: float = 256, @@ -232,7 +228,6 @@ def _get_lin_function( b = y1 - m * x1 return lambda x: m * x + b - def _get_schedule( self, num_steps: int, @@ -298,7 +293,6 @@ def _denoise( return img - def _unpack(self, x: Tensor, height: int, width: int) -> Tensor: return rearrange( x,
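For the default 1024x1024 path, the rearrange in `_unpack` (the same pattern used by the deleted `FluxAEWrapper`) is pure shape bookkeeping: 4096 packed tokens of 64 channels become a 16-channel latent image at 1/8 resolution. A minimal sketch of the shapes, assuming the default sizes:

```python
import torch
from einops import rearrange

# Packed latents for a 1024x1024 image: 1024 * 1024 / 256 = 4096 tokens, 64 channels each.
packed = torch.zeros(1, 4096, 64)

unpacked = rearrange(
    packed,
    "b (h w) (c ph pw) -> b c (h ph) (w pw)",
    h=1024 // 16,  # 64 patch rows
    w=1024 // 16,  # 64 patch columns
    ph=2,
    pw=2,
)
assert unpacked.shape == (1, 16, 128, 128)  # 16 latent channels at 128x128
```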