diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 6cf79402e..be00c6d2e 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -225,8 +225,8 @@ def __post_init__(self): def from_gguf_properties(properties: dict[str, Any], **kwargs): assert properties["general.architecture"] == "t5" assert ( - properties["t5.attention.layer_norm_epsilon"] - == properties["t5.attention.layer_norm_rms_epsilon"] + properties["t5.attention.layer_norm_epsilon"] + == properties["t5.attention.layer_norm_rms_epsilon"] ) all_kwargs = {"vocab_size": None, "feed_forward_proj": None} @@ -290,10 +290,10 @@ class ClipTextConfig: output_hidden_states: bool = False use_return_dict: bool = True dtype: torch.dtype = torch.float32 - + @staticmethod def from_hugging_face_clip_text_model_config( - config: "transformers.CLIPTextConfig", + config: "transformers.CLIPTextConfig", # type: ignore ) -> "ClipTextConfig": return ClipTextConfig( vocab_size=config.vocab_size, @@ -314,7 +314,7 @@ def from_hugging_face_clip_text_model_config( dtype=config.torch_dtype or torch.float32, ) - def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig": + def to_hugging_face_clip_text_model_config(self) -> "transformers.CLIPTextConfig": # type: ignore kwargs = self.to_properties() kwargs["torch_dtype"] = kwargs["dtype"] del kwargs["dtype"] diff --git a/sharktank/sharktank/models/clip/export.py b/sharktank/sharktank/models/clip/export.py index 95dbdacad..66fd99420 100644 --- a/sharktank/sharktank/models/clip/export.py +++ b/sharktank/sharktank/models/clip/export.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from typing import Optional, Union import transformers from transformers.models.clip.modeling_clip import ( @@ -18,6 +19,7 @@ from ...layers.configs import ClipTextConfig from .clip import ClipTextModel from iree.turbine.aot import FxProgramsBuilder, export +from sharktank.transforms.dataset import set_float_dtype def hugging_face_clip_attention_to_theta(model: HfCLIPAttention) -> Theta: @@ -50,8 +52,12 @@ def clip_text_model_to_dataset(model: ClipTextModel) -> Dataset: return Dataset(properties=model.config.to_properties(), root_theta=model.theta) -def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike): +def export_clip_text_model_iree_parameters(model: ClipTextModel, output_path: PathLike, dtype=None): dataset = clip_text_model_to_dataset(model) + if dtype: + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) dataset.save(output_path) diff --git a/sharktank/sharktank/models/flux/export.py b/sharktank/sharktank/models/flux/export.py index a63e75af9..6485593c0 100644 --- a/sharktank/sharktank/models/flux/export.py +++ b/sharktank/sharktank/models/flux/export.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from os import PathLike import os from pathlib import Path @@ -14,6 +15,7 @@ from .flux import FluxModelV1, FluxParams from ...types import Dataset from ...utils.hf_datasets import get_dataset +from sharktank.transforms.dataset import set_float_dtype flux_transformer_default_batch_sizes = [1] @@ -27,11 +29,14 @@ def export_flux_transformer_model_mlir( def export_flux_transformer_iree_parameters( - model: FluxModelV1, parameters_output_path: PathLike + model: FluxModelV1, parameters_output_path: PathLike, dtype = None ): model.theta.rename_tensors_to_paths() - # TODO: export properties - dataset = Dataset(root_theta=model.theta, properties={}) + dataset = Dataset(root_theta=model.theta, properties=model.params.to_hugging_face_properties()) + if dtype: + dataset.root_theta = dataset.root_theta.transform( + functools.partial(set_float_dtype, dtype=dtype) + ) dataset.save(parameters_output_path) diff --git a/sharktank/sharktank/models/flux/flux.py b/sharktank/sharktank/models/flux/flux.py index 725c483ca..219911fb3 100644 --- a/sharktank/sharktank/models/flux/flux.py +++ b/sharktank/sharktank/models/flux/flux.py @@ -12,8 +12,8 @@ from typing import Any, Optional from collections import OrderedDict from copy import copy +from dataclasses import dataclass, asdict import math -from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F @@ -49,6 +49,19 @@ class FluxParams: qkv_bias: bool guidance_embed: bool + def to_hugging_face_properties(self) -> dict[str, Any]: + hparams = { + "in_channels": self.in_channels, + "pooled_projection_dim": self.vec_in_dim, + "joint_attention_dim": self.context_in_dim, + "num_attention_heads": self.num_heads, + "num_layers": self.depth, + "num_single_layers": self.depth_single_blocks, + "attention_head_dim": sum(self.axes_dim), + "guidance_embeds": self.guidance_embed + } + return {"hparams": hparams} + @staticmethod def from_hugging_face_properties(properties: dict[str, Any]) -> "FluxParams": p = properties["hparams"] @@ -175,6 +188,7 @@ def forward( "Didn't get guidance strength for guidance distilled model." 
) vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) txt = self.txt_in(txt) diff --git a/sharktank/sharktank/models/vae/model.py b/sharktank/sharktank/models/vae/model.py index d689aaf72..c2746e800 100644 --- a/sharktank/sharktank/models/vae/model.py +++ b/sharktank/sharktank/models/vae/model.py @@ -74,15 +74,6 @@ def forward( "latent_embeds": latent_embeds, }, ) - if not self.hp.use_post_quant_conv: - sample = rearrange( - sample, - "b (h w) (c ph pw) -> b c (h ph) (w pw)", - h=math.ceil(1024 / 16), - w=math.ceil(1024 / 16), - ph=2, - pw=2, - ) sample = sample / self.hp.scaling_factor + self.hp.shift_factor if self.hp.use_post_quant_conv: diff --git a/sharktank/sharktank/pipelines/flux/__init__.py b/sharktank/sharktank/pipelines/flux/__init__.py new file mode 100644 index 000000000..f32a3f280 --- /dev/null +++ b/sharktank/sharktank/pipelines/flux/__init__.py @@ -0,0 +1,7 @@ +"""Flux text-to-image generation pipeline.""" + +from .flux_pipeline import FluxPipeline + +__all__ = [ + "FluxPipeline", +] diff --git a/sharktank/sharktank/pipelines/flux/flux_pipeline.py b/sharktank/sharktank/pipelines/flux/flux_pipeline.py new file mode 100644 index 000000000..12fc4b9bb --- /dev/null +++ b/sharktank/sharktank/pipelines/flux/flux_pipeline.py @@ -0,0 +1,440 @@ +"""Flux text-to-image generation pipeline.""" + +import argparse +import functools +import math +from os import PathLike +from typing import Callable, Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from PIL import Image +from torch import Tensor +from transformers import CLIPTokenizer, T5Tokenizer, CLIPTextModel as HfCLIPTextModel + +from sharktank.models.t5 import T5Config, T5Encoder +from sharktank.models.clip import ClipTextModel, ClipTextConfig +from sharktank.models.flux.flux import FluxModelV1, FluxParams +from sharktank.models.vae.model import VaeDecoderModel +from sharktank.types import Dataset +from sharktank.transforms.dataset import set_float_dtype + +class FluxPipeline(nn.Module): + """Pipeline for text-to-image generation using the Flux model.""" + + def __init__( + self, + t5_path: PathLike, + clip_path: PathLike, + transformer_path: PathLike, + ae_path: PathLike, + t5_tokenizer_path: Optional[PathLike] = None, + clip_tokenizer_path: Optional[PathLike] = None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + dtype: torch.dtype = torch.bfloat16, + ): + """Initialize the Flux pipeline.""" + super().__init__() + self.device = torch.device(device) + self.dtype = dtype + if t5_tokenizer_path: + self.t5_tokenizer = T5Tokenizer.from_pretrained(t5_tokenizer_path) + if clip_tokenizer_path: + self.clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path) + + # Load T5 + t5_dataset = Dataset.load(t5_path) + t5_config = T5Config.from_gguf_properties( + t5_dataset.properties, + feed_forward_proj="gated-gelu", + ) + # t5_dataset.root_theta = t5_dataset.root_theta.transform( + # functools.partial(set_float_dtype, dtype=dtype) + # ) + self.t5_model = T5Encoder(theta=t5_dataset.root_theta, config=t5_config) + self.add_module('t5_model', self.t5_model) + self.t5_model.to(device) + + # Load CLIP + clip_dataset = Dataset.load(clip_path) + # TODO: Refactor CLIP to not make the config rely on HuggingFace + hf_clip_model = HfCLIPTextModel.from_pretrained("/data/flux/FLUX.1-dev/text_encoder/") + clip_config = ClipTextConfig.from_hugging_face_clip_text_model_config(hf_clip_model.config) + # clip_dataset.root_theta = 
clip_dataset.root_theta.transform( + # functools.partial(set_float_dtype, dtype=dtype) + # ) + self.clip_model = ClipTextModel(theta=clip_dataset.root_theta, config=clip_config) + self.add_module('clip_model', self.clip_model) + self.clip_model.to(device) + + # Load Flux Transformer + transformer_dataset = Dataset.load(transformer_path) + transformer_params = FluxParams.from_hugging_face_properties(transformer_dataset.properties) + # transformer_dataset.root_theta = transformer_dataset.root_theta.transform( + # functools.partial(set_float_dtype, dtype=dtype) + # ) + self.transformer_model = FluxModelV1( + theta=transformer_dataset.root_theta, + params=transformer_params + ) + self.add_module('transformer_model', self.transformer_model) + self.transformer_model.to(device) + + # Load VAE + ae_dataset = Dataset.load(ae_path) + # ae_dataset.root_theta = ae_dataset.root_theta.transform( + # functools.partial(set_float_dtype, dtype=dtype) + # ) + self.ae_model = VaeDecoderModel.from_dataset(ae_dataset) + self.add_module('ae_model', self.ae_model) + self.ae_model.to(device) + + self._rng = torch.Generator(device="cpu") + + def _get_noise( + self, + num_samples: int, + height: int, + width: int, + seed: Optional[int] = None, + ) -> Tensor: + """Generate initial noise for the diffusion process.""" + if seed is not None: + self._rng.manual_seed(seed) + + return torch.randn( + num_samples, + 16, + # allow for packing + 2 * math.ceil(height / 16), + 2 * math.ceil(width / 16), + device=self.device, + dtype=self.dtype, + generator=self._rng, + ) + + def __call__( + self, + prompt: str, + height: int = 1024, + width: int = 1024, + latents: Optional[Tensor] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: float = 3.5, + seed: Optional[int] = None, + ) -> Tensor: + """Generate images from a prompt + + Args: + prompt: Text prompt for image generation + height: Height of output image + width: Width of output image + num_inference_steps: Number of denoising steps + guidance_scale: Scale for classifier-free guidance + seed: Random seed for reproducibility + + Returns: + Image tensor + + Raises: + ValueError: If tokenizers are not provided + """ + if not self.t5_tokenizer or not self.clip_tokenizer: + raise ValueError("Tokenizers must be provided to use the __call__ method") + + t5_prompt_ids, clip_prompt_ids = self.tokenize_prompt(prompt) + if not latents: + latents = self._get_noise( + 1, + height, + width, + seed=seed, + ) + + with torch.inference_mode(): + return self.forward( + t5_prompt_ids, + clip_prompt_ids, + latents, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + seed=seed, + ) + + + def forward( + self, + t5_prompt_ids: Tensor, + clip_prompt_ids: Tensor, + latents: Tensor, + height: int = 1024, + width: int = 1024, + num_inference_steps: Optional[int] = 1, # TODO: DO NOT SUBMIT + guidance_scale: float = 3.5, + seed: Optional[int] = None, + ) -> Tensor: + # Set default steps # TODO: Check if dev or schnell + if num_inference_steps is None: + num_inference_steps = 50 + + # Adjust dimensions to be multiples of 16 + height = 16 * (height // 16) + width = 16 * (width // 16) + + # Generate initial noise + x = latents + + # Prepare inputs + inp = self._prepare(self.t5_model, self.clip_model, t5_prompt_ids, clip_prompt_ids, x) + timesteps = self._get_schedule(num_inference_steps, inp["img"].shape[1], shift=True) + + # Denoise + x = self._denoise( + **inp, + timesteps=timesteps, + guidance=guidance_scale, + ) + + # Decode 
latents + x = self._unpack(x.to(dtype=self.dtype), height, width) + x = self.ae_model(x) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + x = x[0] + x = x.cpu() + x = x.clamp(-1, 1) + x = rearrange(x, "c h w -> h w c") + return x.float() + + + def _prepare(self, t5, clip, t5_prompt_ids, clip_prompt_ids, img: Tensor) -> dict[str, Tensor]: + """Prepare inputs for the transformer model. + + Args: + t5: T5 model for text encoding + clip: CLIP model for text encoding + t5_prompt_ids: Tokenized T5 prompt IDs + clip_prompt_ids: Tokenized CLIP prompt IDs + img: Initial noise tensor + + Returns: + Dictionary containing prepared inputs for the transformer + """ + bs, c, h, w = img.shape + + # Prepare image and position IDs + img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) + img_ids = torch.zeros(h // 2, w // 2, 3) + img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + # Get text embeddings + # Process text through T5 + txt = t5(t5_prompt_ids)["last_hidden_state"] + txt_ids = torch.zeros(bs, txt.shape[1], 3, device=img.device) + + # Process text through CLIP + vec = clip(clip_prompt_ids)["pooler_output"] + + # Return prepared inputs + return { + "img": img, + "img_ids": img_ids, + "txt": txt, + "txt_ids": txt_ids, + "vec": vec, + } + + def _time_shift(self, mu: float, sigma: float, t: Tensor) -> Tensor: + """Apply time shift to the schedule.""" + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + def _get_lin_function( + self, + x1: float = 256, + y1: float = 0.5, + x2: float = 4096, + y2: float = 1.15 + ) -> Callable[[float], float]: + """Get linear interpolation function.""" + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + + def _get_schedule( + self, + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, + ) -> list[float]: + """Generate sampling schedule. 
+ + Args: + num_steps: Number of diffusion steps + image_seq_len: Length of the image sequence + base_shift: Base shift value for schedule adjustment + max_shift: Maximum shift value for schedule adjustment + shift: Whether to apply schedule shifting + + Returns: + List of timesteps for the diffusion process + """ + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = self._get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = self._time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + def _denoise( + self, + # model input + img: Tensor, + img_ids: Tensor, + txt: Tensor, + txt_ids: Tensor, + vec: Tensor, + # sampling parameters + timesteps: list[float], + guidance: float = 4.0, + # extra img tokens + img_cond: Optional[Tensor] = None, + ) -> Tensor: + """Denoise the latents through the diffusion process.""" + # this is ignored for schnell + guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=self.dtype) + for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): + t_curr_vec = torch.full((img.shape[0],), t_curr, dtype=self.dtype, device=img.device) + t_prev_vec = torch.full((img.shape[0],), t_prev, dtype=self.dtype, device=img.device) + pred = self.transformer_model( + img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img, + img_ids=img_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t_curr_vec, + guidance=guidance_vec, + ) + print(t_prev, t_curr) + img = img + (t_prev_vec - t_curr_vec) * pred + + return img + + + def _unpack(self, x: Tensor, height: int, width: int) -> Tensor: + return rearrange( + x, + "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=math.ceil(height / 16), + w=math.ceil(width / 16), + ph=2, + pw=2, + ) + + def tokenize_prompt(self, prompt: str) -> tuple[Tensor, Tensor]: + """Tokenize a prompt using T5 and CLIP tokenizers. 
+ + Args: + prompt: Text prompt to tokenize + + Returns: + Tuple of (t5_prompt_ids, clip_prompt_ids) tensors + """ + # T5 tokenization + t5_prompt_ids = [self.t5_tokenizer(p).input_ids for p in [prompt]] + t5_prompt_ids = torch.tensor(t5_prompt_ids, dtype=torch.long) + + # CLIP tokenization + clip_prompt_ids = [self.clip_tokenizer(p).input_ids for p in [prompt]] + clip_prompt_ids = torch.tensor(clip_prompt_ids, dtype=torch.long) + + return t5_prompt_ids, clip_prompt_ids + +def main(): + """Example usage of FluxPipeline.""" + parser = argparse.ArgumentParser(description="Flux text-to-image generation pipeline") + + # Model paths + parser.add_argument("--t5-path", default="/data/t5-v1_1-xxl/model.gguf", + help="Path to T5 model") + parser.add_argument("--clip-path", default="/data/flux/FLUX.1-dev/text_encoder/model.irpa", + help="Path to CLIP model") + parser.add_argument("--transformer-path", default="/data/flux/FLUX.1-dev/transformer/model.irpa", + help="Path to Transformer model") + parser.add_argument("--ae-path", default="/data/flux/FLUX.1-dev/vae/model.irpa", + help="Path to VAE model") + parser.add_argument("--t5-tokenizer-path", default="/data/flux/FLUX.1-dev/tokenizer_2/", + help="Path to T5 tokenizer") + parser.add_argument("--clip-tokenizer-path", default="/data/flux/FLUX.1-dev/tokenizer/", + help="Path to CLIP tokenizer") + + # Generation parameters + parser.add_argument("--height", type=int, default=1024, + help="Height of output image") + parser.add_argument("--width", type=int, default=1024, + help="Width of output image") + parser.add_argument("--num-inference-steps", type=int, default=None, + help="Number of denoising steps") + parser.add_argument("--guidance-scale", type=float, default=3.5, + help="Scale for classifier-free guidance") + parser.add_argument("--seed", type=int, default=None, + help="Random seed for reproducibility") + + # Other parameters + parser.add_argument("--prompt", default="a photo of a forest with mist swirling around the tree trunks. 
" + 'The word "FLUX" is painted over it in big, red brush strokes with visible texture', + help="Text prompt for image generation") + parser.add_argument("--output", default="output.jpg", + help="Output image path") + parser.add_argument("--dtype", default="bfloat16", choices=["float32", "float16", "bfloat16"], + help="Data type for model") + + args = parser.parse_args() + + # Map dtype string to torch dtype + dtype_map = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16 + } + + # Initialize pipeline + pipeline = FluxPipeline( + t5_path=args.t5_path, + clip_path=args.clip_path, + transformer_path=args.transformer_path, + ae_path=args.ae_path, + t5_tokenizer_path=args.t5_tokenizer_path, + clip_tokenizer_path=args.clip_tokenizer_path, + dtype=dtype_map[args.dtype] + ) + + # Generate image + x = pipeline( + prompt=args.prompt, + height=args.height, + width=args.width, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + seed=args.seed + ) + + # Transform and save first image + image = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + image.save(args.output, quality=95, subsampling=0) + + +if __name__ == "__main__": + main() diff --git a/sharktank/sharktank/tools/import_hf_dataset.py b/sharktank/sharktank/tools/import_hf_dataset.py index 8fb5b3f16..46f3f973c 100644 --- a/sharktank/sharktank/tools/import_hf_dataset.py +++ b/sharktank/sharktank/tools/import_hf_dataset.py @@ -48,12 +48,11 @@ def import_hf_dataset( meta_params = {k: v for k, v in config_json.items() if k.startswith("_")} hparams = {k: v for k, v in config_json.items() if not k.startswith("_")} + tensors = [] for params_path in param_paths: with safetensors.safe_open(params_path, framework="pt", device="cpu") as st: - tensors = [ - DefaultPrimitiveTensor( - name=name, data=st.get_tensor(name).to(target_dtype) - ) + tensors += [ + DefaultPrimitiveTensor(name=name, data=st.get_tensor(name)) for name in st.keys() ] diff --git a/sharktank/tests/pipelines/flux/flux_pipeline_test.py b/sharktank/tests/pipelines/flux/flux_pipeline_test.py new file mode 100644 index 000000000..e0bb696ae --- /dev/null +++ b/sharktank/tests/pipelines/flux/flux_pipeline_test.py @@ -0,0 +1,154 @@ +# Copyright 2024 Advanced Micro Devices, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# DO NOT SUBMIT: REVIEW AND TEST FILE + +"""Tests for Flux text-to-image pipeline.""" + +import functools +from typing import Optional +import os +from collections import OrderedDict +import pytest +import torch +from unittest import TestCase +import numpy + +from transformers import CLIPTokenizer, T5Tokenizer +from diffusers import FluxPipeline as ReferenceFluxPipeline + +from sharktank.types import Dataset, dtype_to_serialized_short_name +from sharktank.pipelines.flux import ( + FluxPipeline, + export_flux_pipeline_mlir, + export_flux_pipeline_iree_parameters, +) +from sharktank.utils.testing import TempDirTestBase +from sharktank.transforms.dataset import set_float_dtype +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, + flatten_for_iree_signature, + iree_to_torch, +) +from sharktank import ops +import iree.compiler + +with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')") + +@pytest.mark.usefixtures("get_model_artifacts") +class FluxPipelineEagerTest(TestCase): + def setUp(self): + super().setUp() + torch.random.manual_seed(12345) + torch.no_grad() + + @with_flux_data + def testFluxPipelineAgainstGolden(self): + """Test against golden outputs from the original Flux pipeline.""" + model = FluxPipeline( + t5_path="/data/t5-v1_1-xxl/model.gguf", + clip_path="/data/flux/FLUX.1-dev/text_encoder/model.irpa", + transformer_path="/data/flux/FLUX.1-dev/transformer/model.irpa", + ae_path="/data/flux/FLUX.1-dev/vae/model.irpa", + dtype=torch.bfloat16, + ) + + # Load reference inputs + with open("/data/flux/test_data/t5_prompt_ids.pt", "rb") as f: + t5_prompt_ids = torch.load(f) + with open("/data/flux/test_data/clip_prompt_ids.pt", "rb") as f: + clip_prompt_ids = torch.load(f) + + # Generate output using forward method directly + latents = model._get_noise( + 1, + 1024, + 1024, + seed=12345, + ) + output = model.forward( + t5_prompt_ids, + clip_prompt_ids, + latents=latents, + num_inference_steps=1, + seed=12345, + ) + + # Compare against golden output + with open("/data/flux/test_data/flux_1_step_output.pt", "rb") as f: + reference_output = torch.load(f) + + torch.testing.assert_close(output, reference_output) # TODO: why is this not passing? 
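+        # Hedged note on the TODO above, added as a comment so the patch still applies:
+        # torch.testing.assert_close uses dtype-dependent defaults (roughly rtol=1.6e-2,
+        # atol=1e-5 for bfloat16 in recent PyTorch releases), and the tight absolute
+        # tolerance is hard for an end-to-end one-step pipeline to meet on near-zero
+        # pixels. If the mismatch turns out to be small and broadly distributed, passing
+        # explicit tolerances may be the fix; the values below are illustrative
+        # assumptions, not validated against this model:
+        #
+        # torch.testing.assert_close(
+        #     output, reference_output, atol=1e-2, rtol=1e-2
+        # )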
+ + def runTestFluxPipelineAgainstReference( + self, + dtype: torch.dtype, + atol: Optional[float] = None, + rtol: Optional[float] = None, + ): + """Compare pipeline outputs between different dtypes.""" + # Initialize reference model + reference_model = ReferenceFluxPipeline.from_pretrained("/data/flux/FLUX.1-dev/") + + # Initialize target model + target_model = FluxPipeline( + t5_path="/data/t5-v1_1-xxl/model.gguf", + clip_path="/data/flux/FLUX.1-dev/text_encoder/model.irpa", + transformer_path="/data/flux/FLUX.1-dev/transformer/model.irpa", + ae_path="/data/flux/FLUX.1-dev/vae/model.irpa", + t5_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer_2/", + clip_tokenizer_path="/data/flux/FLUX.1-dev/tokenizer/", + dtype=dtype, + ) + + # Generate outputs using string prompt + prompt = "a photo of a forest with mist" + latents = reference_model._get_noise( + 1, + 1024, + 1024, + seed=12345, + ).to(dtype=dtype) + reference_image_output = reference_model( + prompt=prompt, + height=1024, + width=1024, + latents=latents, + num_inference_steps=1, + guidance_scale=3.5 + ).images[0] + reference_output = torch.tensor(numpy.array(reference_image_output)).to(dtype=dtype) + + target_output = target_model( + prompt=prompt, + height=1024, + width=1024, + latents=latents, + num_inference_steps=1, + guidance_scale=3.5 + ) + + torch.testing.assert_close(reference_output, target_output, atol=atol, rtol=rtol) + + @with_flux_data + def testFluxPipelineF32(self): + """Test F32 pipeline against reference.""" + self.runTestFluxPipelineAgainstReference( + dtype=torch.float32, + ) + + @with_flux_data + def testFluxPipelineBF16(self): + """Test BF16 pipeline against refence.""" + self.runTestFluxPipelineAgainstReference( + dtype=torch.bfloat16, + ) + + diff --git a/shortfin/python/shortfin_apps/flux/README.md b/shortfin/python/shortfin_apps/flux/README.md new file mode 100644 index 000000000..cf4dd2544 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/README.md @@ -0,0 +1,31 @@ +# Flux.1 Server and CLI + +This directory contains a [Flux](https://blackforestlabs.ai/#get-flux) inference server, CLI and support components. More information about FLUX.1 on [huggingface](https://huggingface.co/black-forest-labs/FLUX.1-dev). + +## Install + +For [nightly releases](../../../../docs/nightly_releases.md) +For our [stable release](../../../../docs/user_guide.md) + +## Start Flux Server +The server will prepare runtime artifacts for you. + +By default, the port is set to 8000. If you would like to change this, use `--port` in each of the following commands. + +You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`. + +From a source checkout of shortfin: +``` +python -m shortfin_apps.flux.server --model_config=./python/shortfin_apps/flux/examples/flux_dev_config_mixed.json --device=amdgpu --fibers_per_device=1 --workers_per_device=1 --isolation="per_fiber" --flagfile=./python/shortfin_apps/flux/examples/flux_flags_gfx942.txt --build_preference=precompiled +``` + - Wait until your server outputs: +``` +INFO - Application startup complete. 
+INFO - Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` +## Run the Flux Client + + - Run a CLI client in a separate shell: +``` +python -m shortfin_apps.flux.simple_client --interactive +``` diff --git a/shortfin/python/shortfin_apps/flux/__init__.py b/shortfin/python/shortfin_apps/flux/__init__.py new file mode 100644 index 000000000..4a168079c --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from . import _deps diff --git a/shortfin/python/shortfin_apps/flux/_deps.py b/shortfin/python/shortfin_apps/flux/_deps.py new file mode 100644 index 000000000..92bd089ec --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/_deps.py @@ -0,0 +1,22 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from shortfin.support.deps import ShortfinDepNotFoundError + +try: + import transformers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "transformers") from e + +try: + import tokenizers +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "tokenizers") from e + +try: + import dataclasses_json +except ModuleNotFoundError as e: + raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e diff --git a/shortfin/python/shortfin_apps/flux/components/builders.py b/shortfin/python/shortfin_apps/flux/components/builders.py new file mode 100644 index 000000000..7a2c2ce60 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/builders.py @@ -0,0 +1,307 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * +from iree.build.executor import FileNamespace, BuildAction, BuildContext, BuildFile +import itertools +import os +import urllib +import shortfin.array as sfnp +import copy + +from shortfin_apps.flux.components.config_struct import ModelParams + +this_dir = os.path.dirname(os.path.abspath(__file__)) +parent = os.path.dirname(this_dir) +default_config_json = os.path.join(parent, "examples", "flux_dev_config_mixed.json") + +dtype_to_filetag = { + "bfloat16": "bf16", + "float32": "fp32", + "float16": "fp16", + sfnp.float32: "fp32", + sfnp.bfloat16: "bf16", +} + +ARTIFACT_VERSION = "12032024" +SDXL_BUCKET = ( + f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/" +) +SDXL_WEIGHTS_BUCKET = ( + "https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/weights/" +) + + +def filter_by_model(filenames, model): + if not model: + return filenames + filtered = [] + for i in filenames: + if model == "t5xxl" and i == "google__t5_v1_1_xxl_encoder_fp32.irpa": + filtered.extend([i]) + if model.lower() in i.lower(): + filtered.extend([i]) + return filtered + + +def get_mlir_filenames(model_params: ModelParams, model=None): + mlir_filenames = [] + file_stems = get_file_stems(model_params) + for stem in file_stems: + mlir_filenames.extend([stem + ".mlir"]) + return filter_by_model(mlir_filenames, model) + + +def get_vmfb_filenames( + model_params: ModelParams, model=None, target: str = "amdgpu-gfx942" +): + vmfb_filenames = [] + file_stems = get_file_stems(model_params) + for stem in file_stems: + vmfb_filenames.extend([stem + "_" + target + ".vmfb"]) + return filter_by_model(vmfb_filenames, model) + + +def get_params_filenames(model_params: ModelParams, model=None, splat: bool = False): + params_filenames = [] + base = "flux_dev" if not model_params.is_schnell else model_params.base_model_name + modnames = ["clip", "sampler", "vae"] + mod_precs = [ + dtype_to_filetag[model_params.clip_dtype], + dtype_to_filetag[model_params.sampler_dtype], + dtype_to_filetag[model_params.vae_dtype], + ] + if splat == "True": + for idx, mod in enumerate(modnames): + params_filenames.extend( + ["_".join([mod, "splat", f"{mod_precs[idx]}.irpa"])] + ) + else: + for idx, mod in enumerate(modnames): + params_filenames.extend([base + "_" + mod + "_" + mod_precs[idx] + ".irpa"]) + + # this is a hack + params_filenames.extend(["google__t5_v1_1_xxl_encoder_fp32.irpa"]) + + return filter_by_model(params_filenames, model) + + +def get_file_stems(model_params: ModelParams): + file_stems = [] + base = ["flux_dev" if not model_params.is_schnell else model_params.base_model_name] + mod_names = { + "clip": "clip", + "t5xxl": "t5xxl", + "sampler": "sampler", + "vae": "vae", + } + for mod, modname in mod_names.items(): + ord_params = [ + base, + [modname], + ] + bsizes = [] + for bs in getattr(model_params, f"{mod}_batch_sizes", [1]): + bsizes.extend([f"bs{bs}"]) + ord_params.extend([bsizes]) + if mod in ["sampler"]: + ord_params.extend([[str(model_params.max_seq_len)]]) + elif mod == "clip": + ord_params.extend([[str(model_params.clip_max_seq_len)]]) + if mod in ["sampler", "vae"]: + dims = [] + for dim_pair in model_params.dims: + dim_pair_str = [str(d) for d in dim_pair] + dims.extend(["x".join(dim_pair_str)]) + ord_params.extend([dims]) + + dtype_str = dtype_to_filetag[ + getattr(model_params, f"{mod}_dtype", sfnp.float32) + ] + ord_params.extend([[dtype_str]]) + for x in list(itertools.product(*ord_params)): + 
file_stems.extend(["_".join(x)]) + return file_stems + + +def get_url_map(filenames: list[str], bucket: str): + file_map = {} + for filename in filenames: + file_map[filename] = f"{bucket}{filename}" + return file_map + + +def needs_update(ctx): + stamp = ctx.allocate_file("version.txt") + stamp_path = stamp.get_fs_path() + if os.path.exists(stamp_path): + with open(stamp_path, "r") as s: + ver = s.read() + if ver != ARTIFACT_VERSION: + return True + else: + with open(stamp_path, "w") as s: + s.write(ARTIFACT_VERSION) + return True + return False + + +def needs_file(filename, ctx, url=None, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + needed = True + if os.path.exists(out_file): + if url: + needed = False + # needed = not is_valid_size(out_file, url) + if not needed: + return False + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + return True + + +def needs_compile(filename, target, ctx): + vmfb_name = f"{filename}_{target}.vmfb" + namespace = FileNamespace.BIN + return needs_file(vmfb_name, ctx, namespace=namespace) + + +def get_cached_vmfb(filename, target, ctx): + vmfb_name = f"{filename}_{target}.vmfb" + return ctx.file(vmfb_name) + + +def is_valid_size(file_path, url): + if not url: + return True + with urllib.request.urlopen(url) as response: + content_length = response.getheader("Content-Length") + local_size = get_file_size(str(file_path)) + if content_length: + content_length = int(content_length) + if content_length != local_size: + return False + return True + + +def get_file_size(file_path): + """Gets the size of a local file in bytes as an integer.""" + + file_stats = os.stat(file_path) + return file_stats.st_size + + +def fetch_http_check_size(*, name: str, url: str) -> BuildFile: + context = BuildContext.current() + output_file = context.allocate_file(name) + action = FetchHttpWithCheckAction( + url=url, output_file=output_file, desc=f"Fetch {url}", executor=context.executor + ) + output_file.deps.add(action) + return output_file + + +class FetchHttpWithCheckAction(BuildAction): + def __init__(self, url: str, output_file: BuildFile, **kwargs): + super().__init__(**kwargs) + self.url = url + self.output_file = output_file + + def _invoke(self, retries=4): + path = self.output_file.get_fs_path() + self.executor.write_status(f"Fetching URL: {self.url} -> {path}") + try: + urllib.request.urlretrieve(self.url, str(path)) + except urllib.error.HTTPError as e: + if retries > 0: + retries -= 1 + self._invoke(retries=retries) + else: + raise IOError(f"Failed to fetch URL '{self.url}': {e}") from None + local_size = get_file_size(str(path)) + try: + with urllib.request.urlopen(self.url) as response: + content_length = response.getheader("Content-Length") + if content_length: + content_length = int(content_length) + if content_length != local_size: + raise IOError( + f"Size of downloaded artifact does not match content-length header! 
{content_length} != {local_size}" + ) + except IOError: + if retries > 0: + retries -= 1 + self._invoke(retries=retries) + + +@entrypoint(description="Retreives a set of SDXL submodels.") +def flux( + model_json=cl_arg( + "model-json", + default=default_config_json, + help="Local config filepath", + ), + target=cl_arg( + "target", + default="gfx942", + help="IREE target architecture.", + ), + splat=cl_arg( + "splat", default=False, type=str, help="Download empty weights (for testing)" + ), + build_preference=cl_arg( + "build-preference", + default="precompiled", + help="Sets preference for artifact generation method: [compile, precompiled]", + ), + model=cl_arg("model", type=str, help="Submodel to fetch/compile for."), +): + model_params = ModelParams.load_json(model_json) + ctx = executor.BuildContext.current() + update = needs_update(ctx) + + mlir_bucket = SDXL_BUCKET + "mlir/" + vmfb_bucket = SDXL_BUCKET + "vmfbs/" + if "gfx" in target: + target = "amdgpu-" + target + + mlir_filenames = get_mlir_filenames(model_params, model) + mlir_urls = get_url_map(mlir_filenames, mlir_bucket) + for f, url in mlir_urls.items(): + if update or needs_file(f, ctx, url): + fetch_http(name=f, url=url) + + vmfb_filenames = get_vmfb_filenames(model_params, model=model, target=target) + vmfb_urls = get_url_map(vmfb_filenames, vmfb_bucket) + if build_preference == "compile": + for idx, f in enumerate(copy.deepcopy(vmfb_filenames)): + # We return .vmfb file stems for the compile builder. + file_stem = "_".join(f.split("_")[:-1]) + if needs_compile(file_stem, target, ctx): + for mlirname in mlir_filenames: + if file_stem in mlirname: + mlir_source = mlirname + break + obj = compile(name=file_stem, source=mlir_source) + vmfb_filenames[idx] = obj[0] + else: + vmfb_filenames[idx] = get_cached_vmfb(file_stem, target, ctx) + else: + for f, url in vmfb_urls.items(): + if update or needs_file(f, ctx, url): + fetch_http(name=f, url=url) + + params_filenames = get_params_filenames(model_params, model=model, splat=splat) + params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) + for f, url in params_urls.items(): + if needs_file(f, ctx, url): + fetch_http_check_size(name=f, url=url) + filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames] + return filenames + + +if __name__ == "__main__": + iree_build_main() diff --git a/shortfin/python/shortfin_apps/flux/components/config_artifacts.py b/shortfin/python/shortfin_apps/flux/components/config_artifacts.py new file mode 100644 index 000000000..8c779ebc6 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/config_artifacts.py @@ -0,0 +1,112 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * +from iree.build.executor import FileNamespace +import os + +ARTIFACT_VERSION = "12032024" +SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/{ARTIFACT_VERSION}/configs/" + + +def get_url_map(filenames: list[str], bucket: str): + file_map = {} + for filename in filenames: + file_map[filename] = f"{bucket}{filename}" + return file_map + + +def needs_update(ctx): + stamp = ctx.allocate_file("version.txt") + stamp_path = stamp.get_fs_path() + if os.path.exists(stamp_path): + with open(stamp_path, "r") as s: + ver = s.read() + if ver != ARTIFACT_VERSION: + return True + else: + with open(stamp_path, "w") as s: + s.write(ARTIFACT_VERSION) + return True + return False + + +def needs_file(filename, ctx, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + if os.path.exists(out_file): + needed = False + else: + # name_path = "bin" if namespace == FileNamespace.BIN else "" + # if name_path: + # filename = os.path.join(name_path, filename) + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + needed = True + return needed + + +@entrypoint(description="Retreives a set of SDXL configuration files.") +def sdxlconfig( + target=cl_arg( + "target", + default="gfx942", + help="IREE target architecture.", + ), + model=cl_arg("model", type=str, default="sdxl", help="Model architecture"), + topology=cl_arg( + "topology", + type=str, + default=None, + help="System topology configfile keyword", + ), +): + ctx = executor.BuildContext.current() + update = needs_update(ctx) + + # model_config_filenames = [f"{model}_config_i8.json"] + # model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET) + # for f, url in model_config_urls.items(): + # if update or needs_file(f, ctx): + # fetch_http(name=f, url=url) + + if topology: + topology_config_filenames = [f"topology_config_{topology}.txt"] + topology_config_urls = get_url_map( + topology_config_filenames, SDXL_CONFIG_BUCKET + ) + for f, url in topology_config_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + # flagfile_filenames = [f"{model}_flagfile_{target}.txt"] + # flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET) + # for f, url in flagfile_urls.items(): + # if update or needs_file(f, ctx): + # fetch_http(name=f, url=url) + + tuning_filenames = ( + [f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else [] + ) + tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) + for f, url in tuning_urls.items(): + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + filenames = [ + # *model_config_filenames, + # *flagfile_filenames, + *tuning_filenames, + ] + if topology: + filenames.extend( + [ + *topology_config_filenames, + ] + ) + return filenames + + +if __name__ == "__main__": + iree_build_main() diff --git a/shortfin/python/shortfin_apps/flux/components/config_struct.py b/shortfin/python/shortfin_apps/flux/components/config_struct.py new file mode 100644 index 000000000..0e23285e1 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/config_struct.py @@ -0,0 +1,112 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Configuration objects. 
+ +Parameters that are intrinsic to a specific model. + +Typically represented in something like a Huggingface config.json, +we extend the configuration to enumerate inference boundaries of some given set of compiled modules. +""" + +from dataclasses import dataclass +from pathlib import Path + +from dataclasses_json import dataclass_json, Undefined + +import shortfin.array as sfnp + +str_to_dtype = { + "int8": sfnp.int8, + "float16": sfnp.float16, + "bfloat16": sfnp.bfloat16, + "float32": sfnp.float32, +} + + +@dataclass_json(undefined=Undefined.RAISE) +@dataclass +class ModelParams: + """Parameters for a specific set of compiled SD submodels, sufficient to do batching / + invocations.""" + + # Batch sizes that each stage is compiled for. These are expected to be + # functions exported from the model with suffixes of "_bs{batch_size}". Must + # be in ascending order. + clip_batch_sizes: list[int] + + t5xxl_batch_sizes: list[int] + + sampler_batch_sizes: list[int] + + vae_batch_sizes: list[int] + + # Height and Width, respectively, for which sampler and VAE are compiled. e.g. [[512, 512], [1024, 1024]] + dims: list[list[int]] + + base_model_name: str = "flux_dev" + clip_max_seq_len: int = 77 + clip_module_name: str = "compiled_flux_text_encoder" + clip_fn_name: str = "encode_prompts" + clip_dtype: sfnp.DType = sfnp.bfloat16 + + max_seq_len: int = 512 + t5xxl_module_name: str = "module" + t5xxl_fn_name: str = "encode_prompts" + t5xxl_dtype: sfnp.DType = sfnp.bfloat16 + + # Channel dim of latents. + num_latents_channels: int = 16 + + #sampler_module_name: str = "compiled_flux_transformer" + #sampler_fn_name: str = "run_forward" + sampler_module_name: str = "module" + sampler_fn_name: str = "main_graph" + sampler_dtype: sfnp.DType = sfnp.float32 + + vae_module_name: str = "compiled_vae" + vae_fn_name: str = "decode" + vae_dtype: sfnp.DType = sfnp.float32 + + # Whether model is "schnell" (fast) or not. This is roughly equivalent to "turbo" from SDXL. + # It cuts batch dims in half for sampling/encoding and removes negative prompt functionality. + is_schnell: bool = False + + # ABI of the module. + module_abi_version: int = 1 + + # TODO: Understand when this should be a value other than 1 + cfg_mult: int = 1 + + @property + def max_clip_batch_size(self) -> int: + return self.clip_batch_sizes[-1] + + @property + def max_sampler_batch_size(self) -> int: + return self.sampler_batch_sizes[-1] + + @property + def max_vae_batch_size(self) -> int: + return self.vae_batch_sizes[-1] + + @property + def all_batch_sizes(self) -> list: + return [self.clip_batch_sizes, self.sampler_batch_sizes, self.vae_batch_sizes] + + @property + def max_batch_size(self): + return max(self.all_batch_sizes) + + @staticmethod + def load_json(path: Path | str): + with open(path, "rt") as f: + json_text = f.read() + raw_params = ModelParams.from_json(json_text) + for i in ["sampler_dtype", "t5xxl_dtype", "clip_dtype", "vae_dtype"]: + if isinstance(i, str): + setattr(raw_params, i, str_to_dtype[getattr(raw_params, i)]) + return raw_params diff --git a/shortfin/python/shortfin_apps/flux/components/generate.py b/shortfin/python/shortfin_apps/flux/components/generate.py new file mode 100644 index 000000000..1c3560a5d --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/generate.py @@ -0,0 +1,102 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +import json + +import shortfin as sf + +# TODO: Have a generic "Responder" interface vs just the concrete impl. +from shortfin.interop.fastapi import FastAPIResponder + +from .io_struct import GenerateReqInput +from .messages import InferenceExecRequest +from .service import GenerateService +from .metrics import measure + +logger = logging.getLogger("shortfin-flux.generate") + + +class GenerateImageProcess(sf.Process): + """Process instantiated for every image generation. + + This process breaks the sequence into individual inference and sampling + steps, submitting them to the batcher and marshaling final + results. + + Responsible for a single image. + """ + + def __init__( + self, + client: "ClientGenerateBatchProcess", + gen_req: GenerateReqInput, + index: int, + ): + super().__init__(fiber=client.fiber) + self.client = client + self.gen_req = gen_req + self.index = index + self.result_image = None + + async def run(self): + exec = InferenceExecRequest.from_batch(self.gen_req, self.index) + self.client.batcher.submit(exec) + await exec.done + self.result_image = exec.result_image + + +class ClientGenerateBatchProcess(sf.Process): + """Process instantiated for handling a batch from a client. + + This takes care of several responsibilities: + + * Tokenization + * Random Latents Generation + * Splitting the batch into GenerateImageProcesses + * Streaming responses + * Final responses + """ + + __slots__ = [ + "batcher", + "complete_infeed", + "gen_req", + "responder", + ] + + def __init__( + self, + service: GenerateService, + gen_req: GenerateReqInput, + responder: FastAPIResponder, + ): + super().__init__(fiber=service.fibers[0]) + self.gen_req = gen_req + self.responder = responder + self.batcher = service.batcher + self.complete_infeed = self.system.create_queue() + + async def run(self): + logger.debug("Started ClientBatchGenerateProcess: %r", self) + try: + # Launch all individual generate processes and wait for them to finish. + gen_processes = [] + for index in range(self.gen_req.num_output_images): + gen_process = GenerateImageProcess(self, self.gen_req, index) + gen_processes.append(gen_process) + gen_process.launch() + + await asyncio.gather(*gen_processes) + + # TODO: stream image outputs + logging.debug("Responding to one shot batch") + response_data = {"images": [p.result_image for p in gen_processes]} + json_str = json.dumps(response_data) + self.responder.send_response(json_str) + finally: + self.responder.ensure_response() diff --git a/shortfin/python/shortfin_apps/flux/components/io_struct.py b/shortfin/python/shortfin_apps/flux/components/io_struct.py new file mode 100644 index 000000000..73e77316f --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/io_struct.py @@ -0,0 +1,79 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List, Optional, Union +from dataclasses import dataclass +import uuid + + +@dataclass +class GenerateReqInput: + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None + # The input negative prompt. It can be a single prompt or a batch of prompts. + neg_prompt: Optional[Union[List[str], str]] = None + # Output image dimensions per prompt. 
+ height: Optional[Union[List[int], int]] = None + width: Optional[Union[List[int], int]] = None + # The number of inference steps; one int per prompt. + steps: Optional[Union[List[int], int]] = None + # The classifier-free-guidance scale for denoising; one float per prompt. + guidance_scale: Optional[Union[List[float], float]] = None + # The seed for random latents generation; one int per prompt. + seed: Optional[Union[List[int], int]] = None + # Token ids: only used in place of prompt. + input_ids: Optional[Union[List[List[int]], List[int]]] = None + # Negative token ids: only used in place of negative prompt. + neg_input_ids: Optional[Union[List[List[int]], List[int]]] = None + # Output image format. Defaults to base64. One string ("PIL", "base64") + output_type: Optional[List[str]] = None + # The request id. + rid: Optional[Union[List[str], str]] = None + + def post_init(self): + if (self.prompt is None and self.input_ids is None) or ( + self.prompt is not None and self.input_ids is not None + ): + raise ValueError("Either text or input_ids should be provided.") + + if isinstance(self.prompt, str): + self.prompt = [str] + + self.num_output_images = ( + len(self.prompt) if self.prompt is not None else len(self.input_ids) + ) + + batchable_args = [ + self.prompt, + self.neg_prompt, + self.height, + self.width, + self.steps, + self.guidance_scale, + self.seed, + self.input_ids, + self.neg_input_ids, + ] + for arg in batchable_args: + if isinstance(arg, list): + if len(arg) != self.num_output_images and len(arg) != 1: + raise ValueError( + f"Batchable arguments should either be singular or as many as the full batch ({self.num_output_images})." + ) + if self.rid is None: + self.rid = [uuid.uuid4().hex for _ in range(self.num_output_images)] + else: + if not isinstance(self.rid, list): + raise ValueError("The rid should be a list.") + if self.output_type is None: + self.output_type = ["base64"] * self.num_output_images + # Temporary restrictions + heights = [self.height] if not isinstance(self.height, list) else self.height + widths = [self.width] if not isinstance(self.width, list) else self.width + if any(dim != 1024 for dim in [*heights, *widths]): + raise ValueError( + "Currently, only 1024x1024 output image size is supported." + ) diff --git a/shortfin/python/shortfin_apps/flux/components/manager.py b/shortfin/python/shortfin_apps/flux/components/manager.py new file mode 100644 index 000000000..37f090b3f --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/manager.py @@ -0,0 +1,48 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import threading + +import shortfin as sf +from shortfin.interop.support.device_setup import get_selected_devices + +logger = logging.getLogger("shortfin-flux.manager") + + +class SystemManager: + def __init__(self, device="local-task", device_ids=None, async_allocs=True): + if any(x in device for x in ["local-task", "cpu"]): + self.ls = sf.host.CPUSystemBuilder().create_system() + elif any(x in device for x in ["hip", "amdgpu"]): + sb = sf.SystemBuilder( + system_type="amdgpu", amdgpu_async_allocations=async_allocs + ) + if device_ids: + sb.visible_devices = sb.available_devices + sb.visible_devices = get_selected_devices(sb, device_ids) + self.ls = sb.create_system() + logger.info(f"Created local system with {self.ls.device_names} devices") + # TODO: Come up with an easier bootstrap thing than manually + # running a thread. + self.t = threading.Thread(target=lambda: self.ls.run(self.run())) + self.command_queue = self.ls.create_queue("command") + self.command_writer = self.command_queue.writer() + + def start(self): + logger.info("Starting system manager") + self.t.start() + + def shutdown(self): + logger.info("Shutting down system manager") + self.command_queue.close() + self.ls.shutdown() + + async def run(self): + reader = self.command_queue.reader() + while command := await reader(): + ... + logger.info("System manager command processor stopped") diff --git a/shortfin/python/shortfin_apps/flux/components/messages.py b/shortfin/python/shortfin_apps/flux/components/messages.py new file mode 100644 index 000000000..8646da1f6 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/messages.py @@ -0,0 +1,190 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from enum import Enum + +import logging + +import shortfin as sf +import shortfin.array as sfnp + +from .io_struct import GenerateReqInput + +logger = logging.getLogger("shortfin-sd.messages") + + +class InferencePhase(Enum): + # Tokenize prompt, negative prompt and get latents, timesteps, time ids, guidance scale as device arrays + PREPARE = 1 + # Run CLIP to encode tokenized prompts into text embeddings + ENCODE = 2 + # Run UNet to denoise the random sample + DENOISE = 3 + # Run VAE to decode the denoised latents into an image. + DECODE = 4 + # Postprocess VAE outputs. + POSTPROCESS = 5 + + +class InferenceExecRequest(sf.Message): + """ + Generalized request passed for an individual phase of image generation. + + Used for individual image requests. Bundled as lists by the batcher for inference processes, + and inputs joined for programs with bs>1. + + Inference execution processes are responsible for writing their outputs directly to the appropriate attributes here. 
+ """ + + def __init__( + self, + prompt: str | None = None, + neg_prompt: str | None = None, + height: int | None = None, + width: int | None = None, + steps: int | None = None, + guidance_scale: float | sfnp.device_array | None = None, + seed: int | None = None, + clip_input_ids: list[list[int]] | None = None, + t5xxl_input_ids: list[list[int]] | None = None, + sample: sfnp.device_array | None = None, + txt: sfnp.device_array | None = None, + vec: sfnp.device_array | None = None, + img_ids: sfnp.device_array | None = None, + txt_ids: sfnp.device_array | None = None, + timesteps: sfnp.device_array | None = None, + denoised_latents: sfnp.device_array | None = None, + image_array: sfnp.device_array | None = None, + ): + super().__init__() + self.print_debug = True + + self.phases = {} + self.phase = None + self.height = height + self.width = width + + # Phase inputs: + # Prep phase. + self.prompt = prompt + self.neg_prompt = neg_prompt + self.height = height + self.width = width + self.seed = seed + + # Encode phase. + # This is a list of sequenced positive and negative token ids and pooler token ids. + self.clip_input_ids = clip_input_ids + self.t5xxl_input_ids = t5xxl_input_ids + self.sample = sample + + # Denoise phase. + self.img = None + self.txt = txt + self.vec = vec + self.img_ids = img_ids + self.txt_ids = txt_ids + self.steps = steps + self.timesteps = timesteps + self.guidance_scale = guidance_scale + + # Decode phase. + self.denoised_latents = denoised_latents + + # Postprocess. + self.image_array = image_array + + self.result_image = None + self.img_metadata = None + + self.done = sf.VoidFuture() + + # Response control. + # Move the result array to the host and sync to ensure data is + # available. + self.return_host_array: bool = True + + self.post_init() + + @staticmethod + def from_batch(gen_req: GenerateReqInput, index: int) -> "InferenceExecRequest": + gen_inputs = [ + "prompt", + "neg_prompt", + "height", + "width", + "steps", + "guidance_scale", + "seed", + ] + rec_inputs = {} + for item in gen_inputs: + received = getattr(gen_req, item, None) + if isinstance(received, list): + if index >= (len(received)): + if len(received) == 1: + rec_input = received[0] + else: + logging.error( + "Inputs in request must be singular or as many as the list of prompts." 
+ ) + else: + rec_input = received[index] + else: + rec_input = received + rec_inputs[item] = rec_input + return InferenceExecRequest(**rec_inputs) + + def post_init(self): + """Determines necessary inference phases and tags them with static program parameters.""" + for p in reversed(list(InferencePhase)): + required, metadata = self.check_phase(p) + p_data = {"required": required, "metadata": metadata} + self.phases[p] = p_data + if not required: + if p not in [ + InferencePhase.ENCODE, + InferencePhase.PREPARE, + ]: + break + self.phase = p + + def check_phase(self, phase: InferencePhase): + match phase: + case InferencePhase.POSTPROCESS: + return True, None + case InferencePhase.DECODE: + required = not self.image_array + meta = [self.width, self.height] + return required, meta + case InferencePhase.DENOISE: + required = not self.denoised_latents + meta = [self.width, self.height, self.steps] + return required, meta + case InferencePhase.ENCODE: + p_results = [ + self.txt, + self.vec, + ] + required = any([inp is None for inp in p_results]) + return required, None + case InferencePhase.PREPARE: + p_results = [self.sample, self.clip_input_ids, self.t5xxl_input_ids] + required = any([inp is None for inp in p_results]) + return required, None + + def reset(self, phase: InferencePhase): + """Resets all per request state in preparation for an subsequent execution.""" + self.phase = None + self.phases = None + self.done = sf.VoidFuture() + self.return_host_array = True + + +class StrobeMessage(sf.Message): + """Sent to strobe a queue with fake activity (generate a wakeup).""" + + ... diff --git a/shortfin/python/shortfin_apps/flux/components/metrics.py b/shortfin/python/shortfin_apps/flux/components/metrics.py new file mode 100644 index 000000000..62e855698 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/metrics.py @@ -0,0 +1,51 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import time
+from typing import Any
+import functools
+
+logger = logging.getLogger("shortfin-flux.metrics")
+
+
+def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"):
+    assert callable(fn) or fn is None
+
+    def _decorator(func):
+        @functools.wraps(func)
+        async def wrapped_fn_async(*args: Any, **kwargs: Any) -> Any:
+            start = time.time()
+            ret = await func(*args, **kwargs)
+            duration = time.time() - start
+            if type == "exec":
+                batch_size = len(getattr(args[0], "exec_requests", []))
+                log_duration_str(duration, task=task, batch_size=batch_size)
+            if type == "throughput":
+                if isinstance(num_items, str):
+                    items = getattr(args[0].gen_req, num_items)
+                else:
+                    items = str(num_items)
+                log_throughput(duration, items, freq, label)
+            return ret
+
+        return wrapped_fn_async
+
+    return _decorator(fn) if callable(fn) else _decorator
+
+
+def log_throughput(duration, num_items, freq, label) -> None:
+    sps = f"{float(num_items) / duration * freq:.4f}"
+    freq_str = "second" if freq == 1 else f"{freq} seconds"
+    logger.info(f"THROUGHPUT: {sps} {label} per {freq_str}")
+
+
+def log_duration_str(duration: float, task, batch_size=0) -> None:
+    """Logs a human readable duration for the completed task."""
+    if batch_size > 0:
+        task = f"{task} (batch size {batch_size})"
+    duration_str = f"{round(duration * 1e3)}ms"
+    logger.info(f"Completed {task} in {duration_str}")
diff --git a/shortfin/python/shortfin_apps/flux/components/service.py b/shortfin/python/shortfin_apps/flux/components/service.py
new file mode 100644
index 000000000..c539a6779
--- /dev/null
+++ b/shortfin/python/shortfin_apps/flux/components/service.py
@@ -0,0 +1,824 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +import math +import torch +import numpy as np +from tqdm.auto import tqdm +from pathlib import Path +from typing import Callable +from PIL import Image +import base64 + +import shortfin as sf +import shortfin.array as sfnp + +from .config_struct import ModelParams +from .manager import SystemManager +from .messages import InferenceExecRequest, InferencePhase, StrobeMessage +from .tokenizer import Tokenizer +from .metrics import measure + +logger = logging.getLogger("shortfin-flux.service") + +prog_isolations = { + "none": sf.ProgramIsolation.NONE, + "per_fiber": sf.ProgramIsolation.PER_FIBER, + "per_call": sf.ProgramIsolation.PER_CALL, +} + + +def time_shift(mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_lin_function( + x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 +) -> Callable[[float], float]: + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + +def get_schedule( + num_steps: int, + image_seq_len: int, + base_shift: float = 0.5, + max_shift: float = 1.15, + shift: bool = True, +) -> list[float]: + # extra step for zero + timesteps = torch.linspace(1, 0, num_steps + 1) + + # shifting the schedule to favor high timesteps for higher signal images + if shift: + # estimate mu based on linear estimation between two points + mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) + timesteps = time_shift(mu, 1.0, timesteps) + + return timesteps.tolist() + + +class GenerateService: + """Top level service interface for image generation.""" + + inference_programs: dict[str, sf.Program] + + inference_functions: dict[str, dict[str, sf.ProgramFunction]] + + def __init__( + self, + *, + name: str, + sysman: SystemManager, + clip_tokenizers: list[Tokenizer], + t5xxl_tokenizers: list[Tokenizer], + model_params: ModelParams, + fibers_per_device: int, + workers_per_device: int = 1, + prog_isolation: str = "per_fiber", + show_progress: bool = False, + trace_execution: bool = False, + ): + self.name = name + + # Application objects. + self.sysman = sysman + self.clip_tokenizers = clip_tokenizers + self.t5xxl_tokenizers = t5xxl_tokenizers + self.model_params = model_params + self.inference_parameters: dict[str, list[sf.BaseProgramParameters]] = {} + self.inference_modules: dict[str, sf.ProgramModule] = {} + self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {} + self.inference_programs: dict[int, dict[str, sf.Program]] = {} + self.trace_execution = trace_execution + self.show_progress = show_progress + + self.prog_isolation = prog_isolations[prog_isolation] + + self.workers_per_device = workers_per_device + self.fibers_per_device = fibers_per_device + if fibers_per_device % workers_per_device != 0: + raise ValueError( + "Currently, fibers_per_device must be divisible by workers_per_device" + ) + self.fibers_per_worker = int(fibers_per_device / workers_per_device) + + self.workers = [] + self.fibers = [] + self.idle_fibers = set() + # For each worker index we create one on each device, and add their fibers to the idle set. + # This roughly ensures that the first picked fibers are distributed across available devices. 
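+        # Fibers are assigned to workers round-robin, so each worker ends up
+        # servicing fibers_per_worker fibers on its device.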
+ for i in range(self.workers_per_device): + for idx, device in enumerate(self.sysman.ls.devices): + worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}") + self.workers.append(worker) + for idx, device in enumerate(self.sysman.ls.devices): + for i in range(self.fibers_per_device): + tgt_worker = self.workers[i % len(self.workers)] + fiber = sysman.ls.create_fiber(tgt_worker, devices=[device]) + self.fibers.append(fiber) + self.idle_fibers.add(fiber) + for idx in range(len(self.workers)): + self.inference_programs[idx] = {} + self.inference_functions[idx] = { + "clip": {}, + "t5xxl": {}, + "denoise": {}, + "decode": {}, + } + # Scope dependent objects. + self.batcher = BatcherProcess(self) + + def get_worker_index(self, fiber): + if fiber not in self.fibers: + raise ValueError("A worker was requested from a rogue fiber.") + fiber_idx = self.fibers.index(fiber) + worker_idx = int( + (fiber_idx - fiber_idx % self.fibers_per_worker) / self.fibers_per_worker + ) + return worker_idx + + def load_inference_module(self, vmfb_path: Path, component: str = None): + if not self.inference_modules.get(component): + self.inference_modules[component] = [] + self.inference_modules[component].append( + sf.ProgramModule.load(self.sysman.ls, vmfb_path) + ) + + def load_inference_parameters( + self, + *paths: Path, + parameter_scope: str, + format: str = "", + component: str = None, + ): + p = sf.StaticProgramParameters(self.sysman.ls, parameter_scope=parameter_scope) + for path in paths: + logger.info("Loading parameter fiber '%s' from: %s", parameter_scope, path) + p.load(path, format=format) + if not self.inference_parameters.get(component): + self.inference_parameters[component] = [] + self.inference_parameters[component].append(p) + + def start(self): + # Initialize programs. 
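+        # For each component (e.g. clip, t5xxl, sampler, vae), bundle its parameter
+        # provider with its loaded VMFB modules and build one sf.Program per worker.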
+ for component in self.inference_modules: + logger.info(f"Loading component: {component}") + component_modules = [ + sf.ProgramModule.parameter_provider( + self.sysman.ls, *self.inference_parameters.get(component, []) + ), + *self.inference_modules[component], + ] + + for worker_idx, worker in enumerate(self.workers): + worker_devices = self.fibers[ + worker_idx * (self.fibers_per_worker) + ].raw_devices + logger.info( + f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}" + ) + self.inference_programs[worker_idx][component] = sf.Program( + modules=component_modules, + devices=worker_devices, + isolation=self.prog_isolation, + trace_execution=self.trace_execution, + ) + + for worker_idx, worker in enumerate(self.workers): + for bs in self.model_params.clip_batch_sizes: + self.inference_functions[worker_idx]["clip"][ + bs + ] = self.inference_programs[worker_idx]["clip"][ + f"{self.model_params.clip_module_name}.encode_prompts" + ] + for bs in self.model_params.t5xxl_batch_sizes: + self.inference_functions[worker_idx]["t5xxl"][ + bs + ] = self.inference_programs[worker_idx]["t5xxl"][ + f"{self.model_params.t5xxl_module_name}.encode_prompts" + ] + self.inference_functions[worker_idx]["denoise"] = {} + for bs in self.model_params.sampler_batch_sizes: + self.inference_functions[worker_idx]["denoise"][bs] = { + "sampler": self.inference_programs[worker_idx]["sampler"][ + f"{self.model_params.sampler_module_name}.{self.model_params.sampler_fn_name}" + ], + } + self.inference_functions[worker_idx]["decode"] = {} + for bs in self.model_params.vae_batch_sizes: + self.inference_functions[worker_idx]["decode"][ + bs + ] = self.inference_programs[worker_idx]["vae"][ + f"{self.model_params.vae_module_name}.decode" + ] + self.batcher.launch() + + def shutdown(self): + self.batcher.shutdown() + + def __repr__(self): + modules = [ + f" {key} : {value}" for key, value in self.inference_modules.items() + ] + params = [ + f" {key} : {value}" for key, value in self.inference_parameters.items() + ] + # For python 3.11 since we can't have \ in the f"" expression. + new_line = "\n" + return ( + f"ServiceManager(" + f"\n INFERENCE DEVICES : \n" + f" {self.sysman.ls.devices}\n" + f"\n MODEL PARAMS : \n" + f"{self.model_params}" + f"\n SERVICE PARAMS : \n" + f" fibers per device : {self.fibers_per_device}\n" + f" program isolation mode : {self.prog_isolation}\n" + f"\n INFERENCE MODULES : \n" + f"{new_line.join(modules)}\n" + f"\n INFERENCE PARAMETERS : \n" + f"{new_line.join(params)}\n" + f")" + ) + + +######################################################################################## +# Batcher +######################################################################################## + + +class BatcherProcess(sf.Process): + """The batcher is a persistent process responsible for flighting incoming work + into batches. 
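+
+    Requests accumulate until the ideal batch size is reached or the strobe
+    timeout fires, at which point they are boarded onto idle fibers.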
+ """ + + STROBE_SHORT_DELAY = 0.5 + STROBE_LONG_DELAY = 1 + + def __init__(self, service: GenerateService): + super().__init__(fiber=service.fibers[0]) + self.service = service + self.batcher_infeed = self.system.create_queue() + self.pending_requests: set[InferenceExecRequest] = set() + self.strobe_enabled = True + self.strobes: int = 0 + self.ideal_batch_size: int = max(service.model_params.max_batch_size) + self.num_fibers = len(service.fibers) + + def shutdown(self): + self.batcher_infeed.close() + + def submit(self, request: StrobeMessage | InferenceExecRequest): + self.batcher_infeed.write_nodelay(request) + + async def _background_strober(self): + while not self.batcher_infeed.closed: + await asyncio.sleep( + BatcherProcess.STROBE_SHORT_DELAY + if len(self.pending_requests) > 0 + else BatcherProcess.STROBE_LONG_DELAY + ) + if self.strobe_enabled: + self.submit(StrobeMessage()) + + async def run(self): + strober_task = asyncio.create_task(self._background_strober()) + reader = self.batcher_infeed.reader() + while item := await reader(): + self.strobe_enabled = False + if isinstance(item, InferenceExecRequest): + self.pending_requests.add(item) + elif isinstance(item, StrobeMessage): + self.strobes += 1 + else: + logger.error("Illegal message received by batcher: %r", item) + + self.board_flights() + + self.strobe_enabled = True + await strober_task + + def board_flights(self): + waiting_count = len(self.pending_requests) + if waiting_count == 0: + return + if waiting_count < self.ideal_batch_size and self.strobes < 2: + logger.info("Waiting a bit longer to fill flight") + return + self.strobes = 0 + batches = self.sort_batches() + for batch in batches.values(): + # Assign the batch to the next idle fiber. + if len(self.service.idle_fibers) == 0: + return + fiber = self.service.idle_fibers.pop() + fiber_idx = self.service.fibers.index(fiber) + worker_idx = self.service.get_worker_index(fiber) + logger.debug(f"Sending batch to fiber {fiber_idx} (worker {worker_idx})") + self.board(batch["reqs"], fiber=fiber) + if self.service.prog_isolation != sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(fiber) + + def sort_batches(self): + """Files pending requests into sorted batches suitable for program invocations.""" + reqs = self.pending_requests + next_key = 0 + batches = {} + for req in reqs: + is_sorted = False + req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()] + + for idx_key, data in batches.items(): + if not isinstance(data, dict): + logger.error( + "Expected to find a dictionary containing a list of requests and their shared metadatas." 
+ ) + if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size: + # Batch is full + next_key = idx_key + 1 + continue + elif data["meta"] == req_metas: + batches[idx_key]["reqs"].extend([req]) + is_sorted = True + break + else: + next_key = idx_key + 1 + if not is_sorted: + batches[next_key] = { + "reqs": [req], + "meta": req_metas, + } + return batches + + def board(self, request_bundle, fiber): + pending = request_bundle + if len(pending) == 0: + return + exec_process = InferenceExecutorProcess(self.service, fiber) + for req in pending: + if len(exec_process.exec_requests) >= self.ideal_batch_size: + break + exec_process.exec_requests.append(req) + if exec_process.exec_requests: + for flighted_request in exec_process.exec_requests: + self.pending_requests.remove(flighted_request) + exec_process.launch() + + +######################################################################################## +# Inference Executors +######################################################################################## + + +class InferenceExecutorProcess(sf.Process): + """Executes a stable diffusion inference batch""" + + def __init__( + self, + service: GenerateService, + fiber, + ): + super().__init__(fiber=fiber) + self.service = service + self.worker_index = self.service.get_worker_index(fiber) + self.exec_requests: list[InferenceExecRequest] = [] + + @measure(type="exec", task="inference process") + async def run(self): + try: + phase = None + for req in self.exec_requests: + if phase: + if phase != req.phase: + logger.error("Executor process recieved disjoint batch.") + phase = req.phase + phases = self.exec_requests[0].phases + req_count = len(self.exec_requests) + device0 = self.fiber.device(0) + if phases[InferencePhase.PREPARE]["required"]: + await self._prepare(device=device0, requests=self.exec_requests) + if phases[InferencePhase.ENCODE]["required"]: + await self._clip(device=device0, requests=self.exec_requests) + await self._t5xxl(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DENOISE]["required"]: + await self._denoise(device=device0, requests=self.exec_requests) + if phases[InferencePhase.DECODE]["required"]: + await self._decode(device=device0, requests=self.exec_requests) + if phases[InferencePhase.POSTPROCESS]["required"]: + await self._postprocess(device=device0, requests=self.exec_requests) + await device0 + for i in range(req_count): + req = self.exec_requests[i] + req.done.set_success() + if self.service.prog_isolation == sf.ProgramIsolation.PER_FIBER: + self.service.idle_fibers.add(self.fiber) + + except Exception: + logger.exception("Fatal error in image generation") + # TODO: Cancel and set error correctly + for req in self.exec_requests: + req.done.set_success() + + async def _prepare(self, device, requests): + for request in requests: + # Tokenize prompts and negative prompts. We tokenize in bs1 for now and join later. 
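+            # Each tokenizer yields a positive and a negative encoding; they are
+            # concatenated per request as [positive..., negative...].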
+ clip_input_ids_list = [] + clip_neg_ids_list = [] + for tokenizer in self.service.clip_tokenizers: + input_ids = tokenizer.encode(request.prompt) + clip_input_ids_list.append(input_ids) + neg_ids = tokenizer.encode(request.neg_prompt) + clip_neg_ids_list.append(neg_ids) + clip_ids_list = [*clip_input_ids_list, *clip_neg_ids_list] + + request.clip_input_ids = clip_ids_list + + t5xxl_input_ids_list = [] + t5xxl_neg_ids_list = [] + for tokenizer in self.service.t5xxl_tokenizers: + input_ids = tokenizer.encode(request.prompt) + t5xxl_input_ids_list.append(input_ids) + neg_ids = tokenizer.encode(request.neg_prompt) + t5xxl_neg_ids_list.append(neg_ids) + t5xxl_ids_list = [*t5xxl_input_ids_list, *t5xxl_neg_ids_list] + + request.t5xxl_input_ids = t5xxl_ids_list + + # Generate random sample latents. + seed = request.seed + channels = self.service.model_params.num_latents_channels + image_seq_len = (request.height) * (request.width) // 256 + latents_shape = [ + 1, + image_seq_len, + 64, + ] + + # Create and populate sample device array. + generator = sfnp.RandomGenerator(seed) + request.sample = sfnp.device_array.for_device( + device, latents_shape, self.service.model_params.sampler_dtype + ) + + sample_host = sfnp.device_array.for_host( + device, latents_shape, sfnp.float32 + ) + with sample_host.map(discard=True) as m: + m.fill(bytes(1)) + sfnp.fill_randn(sample_host, generator=generator) + if self.service.model_params.sampler_dtype != sfnp.float32: + sample_transfer = request.sample.for_transfer() + sfnp.convert( + sample_host, + dtype=self.service.model_params.sampler_dtype, + out=sample_transfer, + ) + + request.sample.copy_from(sample_transfer) + else: + request.sample.copy_from(sample_host) + + await device + request.timesteps = get_schedule( + request.steps, + image_seq_len, + shift=not self.service.model_params.is_schnell, + ) + return + + async def _clip(self, device, requests): + req_bs = len(requests) + entrypoints = self.service.inference_functions[self.worker_index]["clip"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._clip(device, [request]) + return + for bs, fn in entrypoints.items(): + if bs == req_bs: + break + + # Prepare tokenized input ids for CLIP inference + clip_inputs = [ + sfnp.device_array.for_device( + device, + [req_bs, self.service.model_params.clip_max_seq_len], + sfnp.sint64, + ), + ] + host_arrs = [None] + for idx, arr in enumerate(clip_inputs): + host_arrs[idx] = arr.for_transfer() + for i in range(req_bs): + with host_arrs[idx].view(i).map(write=True, discard=True) as m: + + num_ids = len(requests[i].clip_input_ids) + np_arr = requests[i].clip_input_ids[idx % (num_ids - 1)].input_ids + + m.fill(np_arr) + clip_inputs[idx].copy_from(host_arrs[idx]) + + # Encode tokenized inputs. 
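+        # Invoke the CLIP program for this batch size and keep a per-request
+        # view of the returned text projection (vec).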
+ logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), + ) + await device + (vec,) = await fn(*clip_inputs, fiber=self.fiber) + await device + + for i in range(req_bs): + cfg_mult = 1 + requests[i].vec = vec.view(slice(i, (i + 1))) + + await device + a = vec.for_transfer() + a.copy_from(vec) + await device + + return + + async def _t5xxl(self, device, requests): + req_bs = len(requests) + entrypoints = self.service.inference_functions[self.worker_index]["t5xxl"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._t5xxl(device, [request]) + return + for bs, fn in entrypoints.items(): + if bs == req_bs: + break + + # Prepare tokenized input ids for t5xxl inference + # TODO(eagarvey): Refactor + bs_or_something = 1 + t5xxl_inputs = [ + sfnp.device_array.for_device( + device, [bs_or_something, self.service.model_params.max_seq_len], sfnp.sint64 + ), + ] + host_arrs = [None] + for idx, arr in enumerate(t5xxl_inputs): + host_arrs[idx] = arr.for_transfer() + for i in range(req_bs): + np_arr = requests[i].t5xxl_input_ids[idx].input_ids + for rep in range(bs_or_something): + with host_arrs[idx].view(rep).map(write=True, discard=True) as m: + m.fill(np_arr) + t5xxl_inputs[idx].copy_from(host_arrs[idx]) + + # Encode tokenized inputs. + logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(t5xxl_inputs)]), + ) + await device + (txt,) = await fn(*t5xxl_inputs, fiber=self.fiber) + await device + for i in range(req_bs): + cfg_mult = requests[i].cfg_mult + requests[i].txt = txt.view(slice(i * cfg_mult, (i + 1) * cfg_mult)) + + return + + async def _denoise(self, device, requests): + req_bs = len(requests) + step_count = requests[0].steps + cfg_mult = self.requests[0].cfg_mult + + # Produce denoised latents + entrypoints = self.service.inference_functions[self.worker_index]["denoise"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._denoise(device, [request]) + return + for bs, fns in entrypoints.items(): + if bs == req_bs: + break + + # Get shape of batched latents. + # This assumes all requests are dense at this point. + img_shape = [ + req_bs * cfg_mult, + (requests[0].height) * (requests[0].width) // 256, + 64, + ] + # Assume we are doing classifier-free guidance + txt_shape = [ + req_bs * cfg_mult, + self.service.model_params.max_seq_len, + 4096, + ] + print(req_bs * cfg_mult, 768) + vec_shape = [ + req_bs * cfg_mult, + 768, + ] + denoise_inputs = { + "img": sfnp.device_array.for_device( + device, img_shape, self.service.model_params.sampler_dtype + ), + "txt": sfnp.device_array.for_device( + device, txt_shape, self.service.model_params.sampler_dtype + ), + "vec": sfnp.device_array.for_device( + device, vec_shape, self.service.model_params.sampler_dtype + ), + "step": sfnp.device_array.for_device(device, [1], sfnp.int64), + "timesteps": sfnp.device_array.for_device( + device, [100], self.service.model_params.sampler_dtype + ), + "guidance_scale": sfnp.device_array.for_device( + device, [req_bs], self.service.model_params.sampler_dtype + ), + } + # Send guidance scale to device. 
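+        # Per-request inputs are staged in host arrays, converted to the sampler
+        # dtype where needed, then copied into the batched device arrays.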
+ gs_host = denoise_inputs["guidance_scale"].for_transfer() + sample_host = sfnp.device_array.for_host( + device, img_shape, self.service.model_params.sampler_dtype + ) + guidance_float = sfnp.device_array.for_host(device, [req_bs], sfnp.float32) + await device + + for i in range(req_bs): + guidance_float.view(i).items = [requests[i].guidance_scale] + cfg_dim = i * cfg_mult + + # Reshape and batch sample latent inputs on device. + # Currently we just generate random latents in the desired shape. Rework for img2img. + req_samp = requests[i].sample + for rep in range(cfg_mult): + sample_host.view(slice(cfg_dim + rep, cfg_dim + rep + 1)).copy_from( + req_samp + ) + + denoise_inputs["img"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( + sample_host + ) + + # Batch t5xxl hidden states. + txt = requests[i].txt + denoise_inputs["txt"].view(slice(cfg_dim, cfg_dim + cfg_mult)).copy_from( + txt + ) + + # Batch CLIP projections. + vec = requests[i].vec + for nc in range(cfg_mult): + denoise_inputs["vec"].view(slice(nc, nc + 1)).copy_from(vec) + sfnp.convert( + guidance_float, dtype=self.service.model_params.sampler_dtype, out=gs_host + ) + denoise_inputs["guidance_scale"].copy_from(gs_host) + await device + ts_host = denoise_inputs["timesteps"].for_transfer() + ts_float = sfnp.device_array.for_host( + device, denoise_inputs["timesteps"].shape, dtype=sfnp.float32 + ) + with ts_float.map(write=True) as m: + m.fill(float(1)) + for tstep in range(len(requests[0].timesteps)): + with ts_float.view(tstep).map(write=True, discard=True) as m: + m.fill(np.asarray(requests[0].timesteps[tstep], dtype="float32")) + + sfnp.convert( + ts_float, dtype=self.service.model_params.sampler_dtype, out=ts_host + ) + denoise_inputs["timesteps"].copy_from(ts_host) + await device + + for i, t in tqdm( + enumerate(range(step_count)), + disable=(not self.service.show_progress), + desc=f"DENOISE (bs{req_bs})", + ): + s_host = denoise_inputs["step"].for_transfer() + with s_host.map(write=True) as m: + s_host.items = [i] + denoise_inputs["step"].copy_from(s_host) + + logger.info( + "INVOKE %r", + fns["sampler"], + ) + await device + (noise_pred,) = await fns["sampler"]( + *denoise_inputs.values(), fiber=self.fiber + ) + await device + denoise_inputs["img"].copy_from(noise_pred) + + for idx, req in enumerate(requests): + req.denoised_latents = sfnp.device_array.for_device( + device, img_shape, self.service.model_params.vae_dtype + ) + if ( + self.service.model_params.vae_dtype + != self.service.model_params.sampler_dtype + ): + pred_shape = [ + 1, + (requests[0].height) * (requests[0].width) // 256, + 64, + ] + denoised_inter = sfnp.device_array.for_host( + device, pred_shape, dtype=self.service.model_params.vae_dtype + ) + denoised_host = sfnp.device_array.for_host( + device, pred_shape, dtype=self.service.model_params.sampler_dtype + ) + denoised_host.copy_from(denoise_inputs["img"].view(idx * cfg_mult)) + await device + sfnp.convert( + denoised_host, + dtype=self.service.model_params.vae_dtype, + out=denoised_inter, + ) + req.denoised_latents.copy_from(denoised_inter) + else: + req.denoised_latents.copy_from( + denoise_inputs["img"].view(idx * cfg_mult) + ) + return + + async def _decode(self, device, requests): + req_bs = len(requests) + # Decode latents to images + entrypoints = self.service.inference_functions[self.worker_index]["decode"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._decode(device, [request]) + return + for bs, fn in entrypoints.items(): + if bs == req_bs: + break 
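+        # Gather each request's denoised latents into a single batched device
+        # array for the VAE decode invocation.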
+ latents_shape = [ + req_bs, + (requests[0].height * requests[0].width) // 256, + 64, + ] + latents = sfnp.device_array.for_device( + device, latents_shape, self.service.model_params.vae_dtype + ) + + for i in range(req_bs): + latents.view(i).copy_from(requests[i].denoised_latents) + + # Decode the denoised latents. + logger.debug( + "INVOKE %r: %s", + fn, + "".join([f"\n 0: {latents.shape}"]), + ) + (image,) = await fn(latents, fiber=self.fiber) + await device + images_shape = [ + req_bs, + 3, + requests[0].height, + requests[0].width, + ] + await device + images_host = sfnp.device_array.for_host( + device, images_shape, self.service.model_params.vae_dtype + ) + await device + images_host.copy_from(image) + await device + for idx, req in enumerate(requests): + req.image_array = images_host.view(idx) + return + + async def _postprocess(self, device, requests): + # Process output images + for req in requests: + image_shape = [ + 3, + req.height, + req.width, + ] + out_shape = [req.height, req.width, 3] + images_planar = sfnp.device_array.for_host( + device, image_shape, self.service.model_params.vae_dtype + ) + images_planar.copy_from(req.image_array) + permuted = sfnp.device_array.for_host( + device, out_shape, self.service.model_params.vae_dtype + ) + out = sfnp.device_array.for_host(device, out_shape, sfnp.uint8) + sfnp.transpose(images_planar, (1, 2, 0), out=permuted) + permuted = sfnp.multiply(127.5, (sfnp.add(permuted, 1.0))) + out = sfnp.round(permuted, dtype=sfnp.uint8) + image_bytes = bytes(out.map(read=True)) + + image = base64.b64encode(image_bytes).decode("utf-8") + req.result_image = image + return diff --git a/shortfin/python/shortfin_apps/flux/components/tokenizer.py b/shortfin/python/shortfin_apps/flux/components/tokenizer.py new file mode 100644 index 000000000..54dec295b --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/components/tokenizer.py @@ -0,0 +1,84 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from transformers import CLIPTokenizer, T5Tokenizer, BatchEncoding + +import shortfin as sf +import shortfin.array as sfnp + + +class Tokenizer: + def __init__( + self, + raw_tk: CLIPTokenizer, + max_length: int = 77, + pad_id: int = 0, + attn_mask=False, + ): + self.pad_id = pad_id + self._raw = raw_tk + self.max_length = max_length + self.return_attention_mask = attn_mask + + @staticmethod + def from_pretrained(name: str, subfolder: str) -> "Tokenizer": + if subfolder == "tokenizer_2": + raw_tk = T5Tokenizer.from_pretrained(name, subfolder=subfolder) + max_length = 512 + else: + raw_tk = CLIPTokenizer.from_pretrained(name, subfolder=subfolder) + max_length = 77 + return Tokenizer(raw_tk, max_length=max_length) + + def encode(self, texts: list[str]): + """Encodes a batch of texts, applying no padding.""" + return self._raw( + texts, + padding="max_length", + max_length=self.max_length, + truncation=True, + return_tensors="np", + return_attention_mask=self.return_attention_mask, + ) + + def encoding_length(self, enc: BatchEncoding) -> int: + """Gets the length of an encoding.""" + return len(enc.input_ids) + + def encodings_to_array( + self, + device: sf.ScopedDevice, + encs: dict[str, BatchEncoding], + batch_seq_len: int, + *, + dtype: sfnp.DType = sfnp.int32, + ): + """Creates a device_array with the contents of a batch of encodings. 
+ + It is expected that the user has called post_process_encodings with + the same batch_seq_len in order to properly truncate/pad. + """ + ary = sfnp.device_array.for_host( + device, [len(encs.input_ids), batch_seq_len], dtype + ) + for i, ids in enumerate(encs.input_ids): + ary.view(i).items = ids + return ary + + def attention_masks_to_array( + self, + device: sf.ScopedDevice, + encs: list[BatchEncoding], + batch_seq_len: int, + *, + dtype: sfnp.DType = sfnp.int32, + ): + ary = sfnp.device_array.for_host( + device, [len(encs.attention_mask), batch_seq_len], dtype + ) + for i, enc in enumerate(encs.attention_mask): + ary.view(i).items = enc + return ary diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json new file mode 100644 index 000000000..343bdea28 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/examples/flux_dev_config_mixed.json @@ -0,0 +1,37 @@ +{ + "base_model_name": "flux_dev", + "is_schnell": false, + "num_latents_channels": 16, + "max_seq_len": 512, + "cfg_mult": 1, + "clip_max_seq_len": 77, + "clip_batch_sizes": [ + 1 + ], + "clip_dtype": "bfloat16", + "clip_module_name": "compiled_flux_text_encoder", + "t5xxl_batch_sizes": [ + 1 + ], + "t5xxl_dtype": "bfloat16", + "t5xxl_module_name": "compiled_flux_text_encoder2", + "t5xxl_fn_name": "encode_prompts", + "sampler_batch_sizes": [ + 1 + ], + "sampler_dtype": "bfloat16", + "sampler_module_name": "compiled_flux_transformer", + "sampler_fn_name": "run_forward", + "vae_batch_sizes": [ + 1 + ], + "vae_dtype": "bfloat16", + "vae_module_name": "compiled_flux_auto_encoder", + "vae_fn_name": "decode", + "dims": [ + [ + 1024, + 1024 + ] + ] +} diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_flags_gfx942.txt b/shortfin/python/shortfin_apps/flux/examples/flux_flags_gfx942.txt new file mode 100644 index 000000000..2b882566e --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/examples/flux_flags_gfx942.txt @@ -0,0 +1,28 @@ +all +--iree-hal-target-backends=rocm +--iree-hip-target=gfx942 +--iree-execution-model=async-external +--iree-global-opt-propagate-transposes=1 +--iree-opt-const-eval=0 +--iree-opt-outer-dim-concat=1 +--iree-opt-aggressively-propagate-transposes=1 +--iree-dispatch-creation-enable-aggressive-fusion +--iree-codegen-llvmgpu-use-vector-distribution=1 +--iree-llvmgpu-enable-prefetch=1 +--iree-codegen-gpu-native-math-precision=1 +--iree-hip-legacy-sync=0 +--iree-opt-data-tiling=0 +--iree-vm-target-truncate-unsupported-floats +clip +--iree-hal-force-indirect-command-buffers +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +t5xxl +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +sampler +--iree-hal-force-indirect-command-buffers +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +--iree-dispatch-creation-enable-fuse-horizontal-contractions=1 +vae 
+--iree-hal-force-indirect-command-buffers +--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' +--iree-dispatch-creation-enable-fuse-horizontal-contractions=1 diff --git a/shortfin/python/shortfin_apps/flux/examples/flux_request_bs2.json b/shortfin/python/shortfin_apps/flux/examples/flux_request_bs2.json new file mode 100644 index 000000000..0ded22888 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/examples/flux_request_bs2.json @@ -0,0 +1,18 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal" + ], + "neg_prompt": "Watermark, blurry, oversaturated, low resolution, pollution", + "height": 1024, + "width": 1024, + "steps": 20, + "guidance_scale": [ + 7.5, + 7.9 + ], + "seed": 0, + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/flux/server.py b/shortfin/python/shortfin_apps/flux/server.py new file mode 100644 index 000000000..e6eaef1d3 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/server.py @@ -0,0 +1,424 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any +import argparse +import logging +from pathlib import Path +import sys +import os +import copy +import subprocess +from contextlib import asynccontextmanager +import uvicorn + +# Import first as it does dep checking and reporting. 
+from shortfin.interop.fastapi import FastAPIResponder +from shortfin.support.logging_setup import native_handler + +from fastapi import FastAPI, Request, Response + +from .components.generate import ClientGenerateBatchProcess +from .components.config_struct import ModelParams +from .components.io_struct import GenerateReqInput +from .components.manager import SystemManager +from .components.service import GenerateService +from .components.tokenizer import Tokenizer + + +logger = logging.getLogger("shortfin-flux") +logger.addHandler(native_handler) +logger.propagate = False + +THIS_DIR = Path(__file__).resolve().parent + +UVICORN_LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "()": "uvicorn.logging.DefaultFormatter", + "format": "[{asctime}] {message}", + "datefmt": "%Y-%m-%d %H:%M:%S", + "style": "{", + "use_colors": True, + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "loggers": { + "uvicorn": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, +} + + +@asynccontextmanager +async def lifespan(app: FastAPI): + sysman.start() + try: + for service_name, service in services.items(): + logger.info("Initializing service '%s':", service_name) + logger.info(str(service)) + service.start() + except: + sysman.shutdown() + raise + yield + try: + for service_name, service in services.items(): + logger.info("Shutting down service '%s'", service_name) + service.shutdown() + finally: + sysman.shutdown() + + +sysman: SystemManager +services: dict[str, Any] = {} +app = FastAPI(lifespan=lifespan) + + +@app.get("/health") +async def health() -> Response: + return Response(status_code=200) + + +async def generate_request(gen_req: GenerateReqInput, request: Request): + service = services["sd"] + gen_req.post_init() + responder = FastAPIResponder(request) + ClientGenerateBatchProcess(service, gen_req, responder).launch() + return await responder.response + + +app.post("/generate")(generate_request) +app.put("/generate")(generate_request) + + +def configure_sys(args) -> SystemManager: + # Setup system (configure devices, etc). + model_config, topology_config, flagfile, tuning_spec, args = get_configs(args) + sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) + return sysman, model_config, flagfile, tuning_spec + + +def configure_service(args, sysman, model_config, flagfile, tuning_spec): + # Setup each service we are hosting. 
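+    # Tokenizers are pulled from the Hugging Face repo given by --tokenizer_source;
+    # compiled modules and parameter archives come from the iree.build pipelines in get_modules().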
+ clip_tokenizers = [ + Tokenizer.from_pretrained(args.tokenizer_source, subfolder="tokenizer") + ] + t5xxl_tokenizers = [ + Tokenizer.from_pretrained(args.tokenizer_source, subfolder="tokenizer_2") + ] + + model_params = ModelParams.load_json(model_config) + vmfbs, params = get_modules(args, model_config, flagfile, tuning_spec) + + sm = GenerateService( + name="sd", + sysman=sysman, + clip_tokenizers=clip_tokenizers, + t5xxl_tokenizers=t5xxl_tokenizers, + model_params=model_params, + fibers_per_device=args.fibers_per_device, + workers_per_device=args.workers_per_device, + prog_isolation=args.isolation, + show_progress=args.show_progress, + trace_execution=args.trace_execution, + ) + for key, vmfblist in vmfbs.items(): + for vmfb in vmfblist: + sm.load_inference_module(vmfb, component=key) + for key, datasets in params.items(): + sm.load_inference_parameters(*datasets, parameter_scope="model", component=key) + services[sm.name] = sm + return sysman + + +def get_configs(args): + # Returns one set of config artifacts. + modelname = "flux" + model_config = args.model_config if args.model_config else None + topology_config = None + tuning_spec = None + flagfile = args.flagfile if args.flagfile else None + cfg_builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "config_artifacts.py"), + f"--target={args.target}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + ] + if args.topology: + cfg_builder_args.extend( + [ + f"--topology={args.topology}", + ] + ) + outs = subprocess.check_output(cfg_builder_args).decode() + outs_paths = outs.splitlines() + for i in outs_paths: + if "flux_config" in i and not args.model_config: + model_config = i + elif "topology" in i and args.topology: + topology_config = i + elif "flagfile" in i and not args.flagfile: + flagfile = i + elif "attention_and_matmul_spec" in i and args.use_tuned: + tuning_spec = i + + if args.use_tuned and args.tuning_spec: + tuning_spec = os.path.abspath(args.tuning_spec) + + if topology_config: + with open(topology_config, "r") as f: + contents = [line.rstrip() for line in f] + for spec in contents: + if "--" in spec: + arglist = spec.strip("--").split("=") + arg = arglist[0] + if len(arglist) > 2: + value = arglist[1:] + for val in value: + try: + val = int(val) + except ValueError: + val = val + elif len(arglist) == 2: + value = arglist[-1] + try: + value = int(value) + except ValueError: + value = value + else: + # It's a boolean arg. + value = True + setattr(args, arg, value) + else: + # It's an env var. 
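+                # Lines without "--" are KEY=VALUE pairs exported to the process environment.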
+ arglist = spec.split("=") + os.environ[arglist[0]] = arglist[1] + return model_config, topology_config, flagfile, tuning_spec, args + + +def get_modules(args, model_config, flagfile, td_spec): + # TODO: Move this out of server entrypoint + vmfbs = {"clip": [], "t5xxl": [], "sampler": [], "vae": []} + params = {"clip": [], "t5xxl": [], "sampler": [], "vae": []} + model_flags = copy.deepcopy(vmfbs) + model_flags["all"] = args.compile_flags + + if flagfile: + with open(flagfile, "r") as f: + contents = [line.rstrip() for line in f] + flagged_model = "all" + for elem in contents: + match = [keyw in elem for keyw in model_flags.keys()] + if any(match): + flagged_model = elem + else: + model_flags[flagged_model].extend([elem]) + if td_spec: + model_flags["sampler"].extend( + [f"--iree-codegen-transform-dialect-library={td_spec}"] + ) + + filenames = [] + for modelname in vmfbs.keys(): + ireec_args = model_flags["all"] + model_flags[modelname] + ireec_extra_args = " ".join(ireec_args) + builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "builders.py"), + f"--model-json={model_config}", + f"--target={args.target}", + f"--splat={args.splat}", + f"--build-preference={args.build_preference}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + f"--iree-hal-target-device={args.device}", + f"--iree-hip-target={args.target}", + f"--iree-compile-extra-args={ireec_extra_args}", + ] + logger.info(f"Preparing runtime artifacts for {modelname}...") + logger.info( + "COMMAND LINE EQUIVALENT: " + " ".join([str(argn) for argn in builder_args]) + ) + output = subprocess.check_output(builder_args).decode() + + output_paths = output.splitlines() + filenames.extend(output_paths) + for name in filenames: + for key in vmfbs.keys(): + if key == "t5xxl" and all(x in name.lower() for x in ["xxl", "irpa"]): + params[key].extend([name]) + if key in name.lower(): + if any(x in name for x in [".irpa", ".safetensors", ".gguf"]): + params[key].extend([name]) + elif "vmfb" in name: + vmfbs[key].extend([name]) + return vmfbs, params + + +def main(argv, log_config=UVICORN_LOG_CONFIG): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" + ) + parser.add_argument( + "--device", + type=str, + required=True, + choices=["local-task", "hip", "amdgpu"], + help="Primary inferencing device", + ) + parser.add_argument( + "--target", + type=str, + required=False, + default="gfx942", + choices=["gfx942", "gfx1100", "gfx90a"], + help="Primary inferencing device LLVM target arch.", + ) + parser.add_argument( + "--device_ids", + type=str, + nargs="*", + default=None, + help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0", + ) + parser.add_argument( + "--tokenizer_source", + type=Path, + default="black-forest-labs/FLUX.1-dev", + help="HF repo from which to load tokenizer(s).", + ) + parser.add_argument( + "--model_config", type=Path, help="Path to the model config file." 
+ ) + parser.add_argument( + "--workers_per_device", + type=int, + default=1, + help="Concurrency control -- how many fibers are created per device to run inference.", + ) + parser.add_argument( + "--fibers_per_device", + type=int, + default=1, + help="Concurrency control -- how many fibers are created per device to run inference.", + ) + parser.add_argument( + "--isolation", + type=str, + default="per_call", + choices=["per_fiber", "per_call", "none"], + help="Concurrency control -- How to isolate programs.", + ) + parser.add_argument( + "--show_progress", + action="store_true", + help="enable tqdm progress for sampler iterations.", + ) + parser.add_argument( + "--trace_execution", + action="store_true", + help="Enable tracing of program modules.", + ) + parser.add_argument( + "--amdgpu_async_allocations", + action="store_true", + help="Enable asynchronous allocations for amdgpu device contexts.", + ) + parser.add_argument( + "--splat", + action="store_true", + help="Use splat (empty) parameter files, usually for testing.", + ) + parser.add_argument( + "--build_preference", + type=str, + choices=["compile", "precompiled"], + default="precompiled", + help="Specify preference for builder artifact generation.", + ) + parser.add_argument( + "--compile_flags", + type=str, + nargs="*", + default=[], + help="extra compile flags for all compile actions. For fine-grained control, use flagfiles.", + ) + parser.add_argument( + "--flagfile", + type=Path, + help="Path to a flagfile to use for SDXL. If not specified, will use latest flagfile from azure.", + ) + parser.add_argument( + "--artifacts_dir", + type=Path, + default=None, + help="Path to local artifacts cache.", + ) + parser.add_argument( + "--tuning_spec", + type=str, + default=None, + help="Path to transform dialect spec if compiling an executable with tunings.", + ) + parser.add_argument( + "--topology", + type=str, + default=None, + choices=["spx_single", "cpx_single", "spx_multi", "cpx_multi"], + help="Use one of four known performant preconfigured device/fiber topologies.", + ) + parser.add_argument( + "--use_tuned", + type=int, + default=1, + help="Use tunings for attention and matmul ops. 0 to disable.", + ) + args = parser.parse_args(argv) + if not args.artifacts_dir: + home = Path.home() + artdir = home / ".cache" / "shark" + args.artifacts_dir = str(artdir) + else: + args.artifacts_dir = Path(args.artifacts_dir).resolve() + + global sysman + sysman, model_config, flagfile, tuning_spec = configure_sys(args) + configure_service(args, sysman, model_config, flagfile, tuning_spec) + uvicorn.run( + app, + host=args.host, + port=args.port, + log_config=log_config, + timeout_keep_alive=args.timeout_keep_alive, + ) + + +if __name__ == "__main__": + logging.root.setLevel(logging.INFO) + main( + sys.argv[1:], + # Make logging defer to the default shortfin logging config. + log_config=UVICORN_LOG_CONFIG, + ) diff --git a/shortfin/python/shortfin_apps/flux/simple_client.py b/shortfin/python/shortfin_apps/flux/simple_client.py new file mode 100644 index 000000000..cfe284e21 --- /dev/null +++ b/shortfin/python/shortfin_apps/flux/simple_client.py @@ -0,0 +1,246 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from datetime import datetime as dt +import os +import sys +import time +import json +import argparse +import base64 +import asyncio +import aiohttp +import requests + +from PIL import Image + +sample_request = { + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + ], + "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], + "height": [1024], + "width": [1024], + "steps": [2], + "guidance_scale": [3.5], + "seed": [0], + "output_type": ["base64"], + "rid": ["string"], +} + + +def bytes_to_img(in_bytes, outputdir, idx=0, width=1024, height=1024): + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image = Image.frombytes( + mode="RGB", size=(width, height), data=base64.b64decode(in_bytes) + ) + if not os.path.isdir(outputdir): + os.mkdir(outputdir) + im_path = os.path.join(outputdir, f"shortfin_flux_output_{timestamp}_{idx}.png") + image.save(im_path) + print(f"Saved to {im_path}") + + +def get_batched(request, arg, idx): + if isinstance(request[arg], list): + # some args are broadcasted to each prompt, hence overriding idx for single-item entries + if len(request[arg]) == 1: + indexed = request[arg][0] + else: + indexed = request[arg][idx] + else: + indexed = request[arg] + return indexed + + +async def send_request(session, rep, args, data): + print("Sending request batch #", rep) + url = f"{args.host}:{args.port}/generate" + start = time.time() + async with session.post(url, json=data) as response: + end = time.time() + # Check if the response was successful + if response.status == 200: + response.raise_for_status() # Raise an error for bad responses + res_json = await response.json(content_type=None) + if args.save: + for idx, item in enumerate(res_json["images"]): + width = get_batched(data, "width", idx) + height = get_batched(data, "height", idx) + print("Saving response as image...") + bytes_to_img( + item.encode("utf-8"), args.outputdir, idx, width, height + ) + latency = end - start + print("Responses processed.") + return latency, len(data["prompt"]) + print(f"Error: Received {response.status} from server") + raise Exception + + +async def static(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if not args.file: + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + start = time.time() + + async for i in async_range(args.reps): + pending.append(asyncio.create_task(send_request(session, i, args, data))) + await asyncio.sleep(1) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + latency, num_samples = await task + latencies.append(latency) + sample_counts.append(num_samples) + end = time.time() + if not any(i is None for i in [latencies, sample_counts]): + total_num_samples = sum(sample_counts) + sps = str(total_num_samples / (end - start)) + # Until we have better measurements, don't report the throughput that includes saving images. 
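+            # Throughput counts all returned samples over the wall time of the whole
+            # run, including the one second stagger between request submissions.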
+ if not args.save: + print(f"Average throughput: {sps} samples per second") + else: + raise ValueError("Received error response from server.") + + +async def interactive(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if not args.file: + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + while True: + prompt = await ainput("Enter a prompt: ") + data["prompt"] = [prompt] + data["steps"] = [args.steps] + print("Sending request with prompt: ", data["prompt"]) + + async for i in async_range(args.reps): + pending.append( + asyncio.create_task(send_request(session, i, args, data)) + ) + await asyncio.sleep( + 1 + ) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + _, _ = await task + pending = [] + if any(i is None for i in [latencies, sample_counts]): + raise ValueError("Received error response from server.") + + +async def ainput(prompt: str) -> str: + return await asyncio.to_thread(input, f"{prompt} ") + + +async def async_range(count): + for i in range(count): + yield i + await asyncio.sleep(0.0) + + +def check_health(url): + ready = False + print("Waiting for server.", end=None) + while not ready: + try: + if requests.get(f"{url}/health", timeout=20).status_code == 200: + print("Successfully connected to server.") + ready = True + return + time.sleep(2) + print(".", end=None) + except: + time.sleep(2) + print(".", end=None) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument( + "--file", + type=str, + default=None, + help="A non-default request to send to the server.", + ) + p.add_argument( + "--reps", + type=int, + default=1, + help="Number of times to duplicate each request in one second intervals.", + ) + p.add_argument( + "--save", + action=argparse.BooleanOptionalAction, + default=True, + help="Save images. To disable, use --no-save", + ) + p.add_argument( + "--outputdir", + type=str, + default="gen_imgs", + help="Directory to which images get saved.", + ) + p.add_argument( + "--host", type=str, default="http://0.0.0.0", help="Server host address." + ) + p.add_argument("--port", type=str, default="8000", help="Server port") + p.add_argument( + "--steps", + type=int, + default="20", + help="Number of inference steps. More steps usually means a better image. Interactive only.", + ) + p.add_argument( + "--interactive", + action="store_true", + help="Start as an example CLI client instead of sending static requests.", + ) + args = p.parse_args() + check_health(f"{args.host}:{args.port}") + if args.interactive: + asyncio.run(interactive(args)) + else: + asyncio.run(static(args)) + + +if __name__ == "__main__": + main()