Skip to content

Commit

Permalink
Some cleanup on flux integration (#829)
Browse files Browse the repository at this point in the history
Merges the diffusers reference implementations for the SDXL and Flux VAE models.

Also renames the exported function from `forward` to `decode` to avoid
confusion with the VAE `encode` function to be added in the future.
  • Loading branch information
IanNod authored Jan 15, 2025
1 parent d508b48 commit 0da9f25
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 52 deletions.
8 changes: 6 additions & 2 deletions sharktank/sharktank/models/vae/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ python -m sharktank.models.punet.tools.import_hf_dataset \
```

# Run Vae decoder model eager mode
# Sample SDXL command
```
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu --dtype=float32
```
# Sample Flux command to run through IREE and compare against the Hugging Face diffusers torch model
```
python -m sharktank.models.vae.tools.run_vae --irpa-file ~/models/vae.irpa --device cpu --compare_vs_torch --dtype=float32 --sharktank_config=flux --torch_model=black-forest-labs/FLUX.1-dev
```

## License

Significant portions of this implementation were derived from diffusers,
Expand Down
68 changes: 34 additions & 34 deletions sharktank/sharktank/models/vae/tools/diffuser_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@ def __init__(
self,
hf_model_name,
custom_vae="",
height=1024,
width=1024,
flux=False,
):
super().__init__()
self.vae = None
self.height = height
self.width = width
self.flux = flux
if custom_vae in ["", None]:
self.vae = AutoencoderKL.from_pretrained(
hf_model_name,
Expand All @@ -44,43 +50,37 @@ def __init__(
custom_vae,
subfolder="vae",
)
self.shift_factor = (
0.0
if self.vae.config.shift_factor is None
else self.vae.config.shift_factor
)

def decode(self, inp):
# The reference vae decode does not do scaling and leaves it for the sdxl pipeline. We integrate it into vae for pipeline performance so using the hardcoded values from the config.json here
img = 1 / 0.13025 * inp
if self.flux:
inp = rearrange(
inp,
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=math.ceil(self.height / 16),
w=math.ceil(self.width / 16),
ph=2,
pw=2,
)
img = inp / self.vae.config.scaling_factor + self.shift_factor
x = self.vae.decode(img, return_dict=False)[0]
return (x / 2 + 0.5).clamp(0, 1)
if self.flux:
return x.clamp(-1, 1)
else:
return (x / 2 + 0.5).clamp(0, 1)


def run_torch_vae(hf_model_name, example_input):
vae_model = VaeModel(hf_model_name)
def run_torch_vae(
hf_model_name,
example_input,
height=1024,
width=1024,
flux=False,
dtype=torch.float32,
):
vae_model = VaeModel(hf_model_name, height=height, width=width, flux=flux).to(dtype)
return vae_model.decode(example_input)


# TODO Remove and integrate with VaeModel
class FluxAEWrapper(torch.nn.Module):
    """Reference Flux VAE decoder built on the diffusers ``AutoencoderKL``.

    Loads the ``vae`` subfolder of ``black-forest-labs/FLUX.1-dev`` with
    ``torch_dtype=torch.bfloat16`` and remembers the target image size, which
    is used to unpack the latent token sequence in ``forward``.
    """

    def __init__(self, height=1024, width=1024):
        super().__init__()
        self.height = height
        self.width = width
        self.ae = AutoencoderKL.from_pretrained(
            "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
        )

    def forward(self, z):
        """Unpack the latent sequence ``z``, rescale it and decode it.

        The result of the autoencoder's ``decode`` is clamped to ``[-1, 1]``.
        """
        # The token axis is split into a ceil(height/16) x ceil(width/16)
        # grid; each token's channel axis carries a 2x2 spatial patch.
        grid_h = math.ceil(self.height / 16)
        grid_w = math.ceil(self.width / 16)
        latents = rearrange(
            z,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=grid_h,
            w=grid_w,
            ph=2,
            pw=2,
        )

        # Apply the scaling/shift taken from the autoencoder's own config.
        cfg = self.ae.config
        latents = latents / cfg.scaling_factor + cfg.shift_factor

        decoded = self.ae.decode(latents, return_dict=False)[0]
        return decoded.clamp(-1, 1)


def run_flux_vae(example_input, dtype):
    """Run ``example_input`` through a 1024x1024 ``FluxAEWrapper`` cast to ``dtype``."""
    # TODO add support for other height/width sizes
    wrapper = FluxAEWrapper(1024, 1024)
    wrapper = wrapper.to(dtype)
    return wrapper.forward(example_input)
12 changes: 7 additions & 5 deletions sharktank/sharktank/models/vae/tools/run_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def export_vae(model, sample_inputs, decomp_attn):
fxb = FxProgramsBuilder(model)

@fxb.export_program(
name=f"forward",
name=f"decode",
args=tuple(torch.unsqueeze(sample_inputs, 0)),
strict=False,
)
Expand Down Expand Up @@ -86,7 +86,7 @@ def main(argv):
parser.add_argument(
"--torch_model",
default="stabilityai/stable-diffusion-xl-base-1.0",
help="HF reference model id",
help="HF reference model id, currently tested with stabilityai/stable-diffusion-xl-base-1.0 and black-forest-labs/FLUX.1-dev",
)

parser.add_argument(
Expand Down Expand Up @@ -141,12 +141,14 @@ def main(argv):
intermediates_saver.save_file(args.save_intermediates_path)

if args.compare_vs_torch:
from .diffuser_ref import run_torch_vae, run_flux_vae
from .diffuser_ref import run_torch_vae

if args.sharktank_config == "flux":
diffusers_results = run_flux_vae(inputs, torch.bfloat16)
diffusers_results = run_torch_vae(
args.torch_model, inputs, flux=True, dtype=dtype
)
elif args.sharktank_config == "sdxl":
run_torch_vae(args.torch_model, inputs)
run_torch_vae(args.torch_model, inputs, flux=False, dtype=dtype)
print("diffusers results:", diffusers_results)
torch.testing.assert_close(diffusers_results, results)

Expand Down
29 changes: 18 additions & 11 deletions sharktank/tests/models/vae/vae_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from sharktank.types import Dataset
from sharktank.models.vae.model import VaeDecoderModel
from sharktank.models.vae.tools.diffuser_ref import run_torch_vae, run_flux_vae
from sharktank.models.vae.tools.diffuser_ref import run_torch_vae
from sharktank.models.vae.tools.run_vae import export_vae
from sharktank.models.vae.tools.sample_data import get_random_inputs

Expand Down Expand Up @@ -166,7 +166,7 @@ def testVaeIreeVsHuggingFace(self):
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
function_name="forward",
function_name="decode",
)[0].to_host()
# TODO: Verify these numerics are good or if tolerances are too loose
# TODO: Upload IR on passing tests to keep https://github.com/iree-org/iree/blob/main/experimental/regression_suite/shark-test-suite-models/sdxl/test_vae.py at latest
Expand Down Expand Up @@ -194,7 +194,7 @@ def testVaeIreeVsHuggingFace(self):
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
function_name="forward",
function_name="decode",
)[0].to_host()
# TODO: Upload IR on passing tests
torch.testing.assert_close(
Expand Down Expand Up @@ -237,7 +237,9 @@ def setUp(self):
def testCompareBF16EagerVsHuggingface(self):
dtype = torch.bfloat16
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux")
ref_results = run_flux_vae(inputs, dtype)
ref_results = run_torch_vae(
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, dtype
)

ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu")
Expand All @@ -249,7 +251,9 @@ def testCompareBF16EagerVsHuggingface(self):
def testCompareF32EagerVsHuggingface(self):
dtype = torch.float32
inputs = get_random_inputs(dtype=dtype, device="cpu", bs=1, config="flux")
ref_results = run_flux_vae(inputs, dtype)
ref_results = run_torch_vae(
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, dtype
)

ds = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
model = VaeDecoderModel.from_dataset(ds).to(device="cpu", dtype=dtype)
Expand All @@ -262,8 +266,9 @@ def testVaeIreeVsHuggingFace(self):
inputs = get_random_inputs(
dtype=torch.float32, device="cpu", bs=1, config="flux"
)
ref_results = run_flux_vae(inputs.to(dtype), dtype)
ref_results_f32 = run_flux_vae(inputs, torch.float32)
ref_results = run_torch_vae(
"black-forest-labs/FLUX.1-dev", inputs, 1024, 1024, True, torch.float32
)

ds = Dataset.load("{self._temp_dir}/flux_vae_bf16.irpa", file_type="irpa")
ds_f32 = Dataset.load("{self._temp_dir}/flux_vae_f32.irpa", file_type="irpa")
Expand Down Expand Up @@ -324,12 +329,14 @@ def testVaeIreeVsHuggingFace(self):
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
function_name="forward",
function_name="decode",
)[0]
)

# TODO verify these numerics
torch.testing.assert_close(ref_results, iree_result, atol=3.3e-2, rtol=4e5)
torch.testing.assert_close(
ref_results.to(torch.bfloat16), iree_result, atol=3.3e-2, rtol=4e5
)

iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
module_path="{self._temp_dir}/flux_vae_f32.vmfb",
Expand All @@ -349,11 +356,11 @@ def testVaeIreeVsHuggingFace(self):
vm_context=iree_vm_context,
args=iree_args,
driver="hip",
function_name="forward",
function_name="decode",
)[0]
)

torch.testing.assert_close(ref_results_f32, iree_result_f32)
torch.testing.assert_close(ref_results, iree_result_f32)


if __name__ == "__main__":
Expand Down

0 comments on commit 0da9f25

Please sign in to comment.