diff --git a/tests/foundationals/latent_diffusion/test_freeu.py b/tests/foundationals/latent_diffusion/test_freeu.py index 786cfe301..b21408412 100644 --- a/tests/foundationals/latent_diffusion/test_freeu.py +++ b/tests/foundationals/latent_diffusion/test_freeu.py @@ -1,9 +1,11 @@ from typing import Iterator import pytest +import torch from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter, FreeUResidualConcatenator +from refiners.fluxion import manual_seed @pytest.fixture(scope="module", params=[True, False]) @@ -39,3 +41,27 @@ def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None: def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None: with pytest.raises(AssertionError): SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9]) + + +def test_freeu_identity_scales() -> None: + manual_seed(0) + text_embedding = torch.randn(1, 77, 768) + timestep = torch.randint(0, 999, size=(1, 1)) + x = torch.randn(1, 4, 32, 32) + + unet = SD1UNet(in_channels=4) + unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s + + with torch.no_grad(): + unet.set_timestep(timestep=timestep) + y_1 = unet(x.clone()) + + freeu = SDFreeUAdapter(unet, backbone_scales=[1.0, 1.0], skip_scales=[1.0, 1.0]) + freeu.inject() + + with torch.no_grad(): + unet.set_timestep(timestep=timestep) + y_2 = unet(x.clone()) + + # The FFT -> inverse FFT sequence (skip features) introduces small numerical differences + assert torch.allclose(y_1, y_2, atol=1e-5)