diff --git a/gradio_demo/app.py b/gradio_demo/app.py
index 4b7bd02..a6f3192 100755
--- a/gradio_demo/app.py
+++ b/gradio_demo/app.py
@@ -311,4 +311,4 @@ def get_example():
     gr.Markdown(article)
 
 
-demo.launch()
+demo.launch()
\ No newline at end of file
diff --git a/gradio_demo/app_v2.py b/gradio_demo/app_v2.py
index ee6e1a4..8fb54fb 100644
--- a/gradio_demo/app_v2.py
+++ b/gradio_demo/app_v2.py
@@ -85,6 +85,7 @@ def generate_image(
     style_strength_ratio,
     num_outputs,
     guidance_scale,
+    pag_scale,
     seed,
     use_doodle,
     sketch_image,
@@ -162,6 +163,8 @@ def generate_image(
         start_merge_step=start_merge_step,
         generator=generator,
         guidance_scale=guidance_scale,
+        pag_scale=pag_scale,
+        pag_applied_layers=['mid'],
         id_embeds=id_embeds,
         image=sketch_image,
         adapter_conditioning_scale=adapter_conditioning_scale,
@@ -368,6 +371,13 @@ def get_example():
                 step=0.1,
                 value=5,
             )
+            pag_scale = gr.Slider(
+                label="PAG scale",
+                minimum=0.0,
+                maximum=10.0,
+                step=0.1,
+                value=3.0,
+            )
             seed = gr.Slider(
                 label="Seed",
                 minimum=0,
@@ -394,6 +404,7 @@ def get_example():
                 style_strength_ratio,
                 num_outputs,
                 guidance_scale,
+                pag_scale,
                 seed,
                 enable_doodle,
                 sketch_image,
diff --git a/photomaker/pipeline.py b/photomaker/pipeline.py
index fde8512..15dfcbe 100755
--- a/photomaker/pipeline.py
+++ b/photomaker/pipeline.py
@@ -23,6 +23,7 @@
 import PIL
 import torch
+import torch.nn.functional as F
 from transformers import CLIPImageProcessor
 
 from safetensors import safe_open
@@ -46,6 +47,7 @@
     scale_lora_layers,
     unscale_lora_layers,
 )
+from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -67,6 +69,229 @@
     List[torch.FloatTensor],
 ]
 
+
+class PAGIdentitySelfAttnProcessor:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + + # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) + hidden_states_ptb = value + + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class PAGCFGIdentitySelfAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 
2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + hidden_states_ptb = value + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): @@ -486,6 +711,34 @@ def encode_prompt_with_trigger_word( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, class_tokens_mask + @property + def pag_scale(self): + return self._pag_scale + + @property + def do_adversarial_guidance(self): + return self._pag_scale > 0 + + @property + def pag_adaptive_scaling(self): + return self._pag_adaptive_scaling + + @property + def do_pag_adaptive_scaling(self): + return self._pag_adaptive_scaling > 0 + + @property + def pag_drop_rate(self): + return self._pag_drop_rate + + @property + def pag_applied_layers(self): + return self._pag_applied_layers + + @property + def pag_applied_layers_index(self): + return self._pag_applied_layers_index + @torch.no_grad() def __call__( self, @@ -498,6 +751,11 @@ def __call__( sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, + pag_scale: float = 0.0, + pag_adaptive_scaling: float = 0.0, + pag_drop_rate: float = 0.5, + pag_applied_layers: List[str] = ['mid'], #['down', 'mid', 'up'] + pag_applied_layers_index: List[str] = None, #['d4', 'd5', 'm0'] negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -608,6 +866,11 @@ def __call__( self._denoising_end = denoising_end self._interrupt = False + self._pag_scale = pag_scale + self._pag_adaptive_scaling = pag_adaptive_scaling + self._pag_drop_rate = pag_drop_rate + self._pag_applied_layers = pag_applied_layers + self._pag_applied_layers_index = pag_applied_layers_index # if prompt_embeds is not None and class_tokens_mask is None: raise ValueError( @@ -753,8 +1016,17 @@ def __call__( else: negative_add_time_ids = add_time_ids - if self.do_classifier_free_guidance: + # if self.do_classifier_free_guidance: + # add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + #cfg + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) @@ -796,6 +1068,23 @@ def __call__( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) + # 13. 
Create down mid and up layer lists + if self.do_adversarial_guidance: + down_layers = [] + mid_layers = [] + up_layers = [] + for name, module in self.unet.named_modules(): + if 'attn1' in name and 'to' not in name: + layer_type = name.split('.')[0].split('_')[0] + if layer_type == 'down': + down_layers.append(module) + elif layer_type == 'mid': + mid_layers.append(module) + elif layer_type == 'up': + up_layers.append(module) + else: + raise ValueError(f"Invalid layer type: {layer_type}") + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -803,24 +1092,98 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + #cfg + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 2) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 2) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 3) + #no + else: + latent_model_input = latents + + # change attention layer in UNet if use PAG + if self.do_adversarial_guidance: + + if self.do_classifier_free_guidance: + replace_processor = PAGCFGIdentitySelfAttnProcessor() + else: + replace_processor = PAGIdentitySelfAttnProcessor() + + if(self.pag_applied_layers_index): + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + elif(self.pag_applied_layers): + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = replace_processor + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = replace_processor + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. 
If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - if i <= start_merge_step: - current_prompt_embeds = torch.cat( - [negative_prompt_embeds, prompt_embeds_text_only], dim=0 - ) if self.do_classifier_free_guidance else prompt_embeds_text_only - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0 - ) if self.do_classifier_free_guidance else pooled_prompt_embeds_text_only + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + if i <= start_merge_step: + current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_text_only], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0) + else: + current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + if i <= start_merge_step: + current_prompt_embeds = torch.cat([prompt_embeds_text_only, prompt_embeds_text_only], dim=0) + add_text_embeds = torch.cat([pooled_prompt_embeds_text_only, pooled_prompt_embeds_text_only], dim=0) + else: + current_prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + if i <= start_merge_step: + current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_text_only, prompt_embeds_text_only], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only, pooled_prompt_embeds_text_only], dim=0) + else: + current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + #nothing else: - current_prompt_embeds = torch.cat( - [negative_prompt_embeds, prompt_embeds], dim=0 - ) if self.do_classifier_free_guidance else prompt_embeds - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0 - ) if self.do_classifier_free_guidance else pooled_prompt_embeds + if i <= start_merge_step: + current_prompt_embeds = prompt_embeds_text_only + add_text_embeds = pooled_prompt_embeds_text_only + else: + current_prompt_embeds = prompt_embeds + add_text_embeds = pooled_prompt_embeds added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} if ip_adapter_image is not None or ip_adapter_image_embeds is not None: @@ -838,9 +1201,34 @@ def __call__( )[0] # perform guidance - if self.do_classifier_free_guidance: + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + # pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + noise_pred_original, noise_pred_perturb = noise_pred.chunk(2) + + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + + noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb) + + # both + elif 
self.do_classifier_free_guidance and self.do_adversarial_guidance: + + noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3) + + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + + noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb) + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf @@ -928,4 +1316,43 @@ def __call__( if not return_dict: return (image,) + #Change the attention layers back to original ones after PAG was applied + if self.do_adversarial_guidance: + if(self.pag_applied_layers_index): + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + elif(self.pag_applied_layers): + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file diff --git a/photomaker/pipeline_t2i_adapter.py b/photomaker/pipeline_t2i_adapter.py index 3aecff0..b6d496b 100755 --- a/photomaker/pipeline_t2i_adapter.py +++ b/photomaker/pipeline_t2i_adapter.py @@ -24,6 +24,7 @@ import numpy as np import PIL.Image import torch +import torch.nn.functional as F from transformers import ( CLIPImageProcessor, CLIPTextModel, @@ -45,6 +46,7 @@ LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, + Attention, ) from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.schedulers import KarrasDiffusionSchedulers @@ -69,7 +71,230 @@ PhotoMakerIDEncoder_CLIPInsightfaceExtendtoken, # PhotoMaker v2 ) +class PAGIdentitySelfAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + + # hidden_states_ptb = 
torch.zeros(value.shape).to(value.get_device()) + hidden_states_ptb = value + + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + +class PAGCFGIdentitySelfAttnProcessor: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # chunk + hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) + hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) + + # original path + batch_size, sequence_length, _ = hidden_states_org.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states_org) + key = attn.to_k(hidden_states_org) + value = attn.to_v(hidden_states_org) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_org = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_org = hidden_states_org.to(query.dtype) + + # linear proj + 
hidden_states_org = attn.to_out[0](hidden_states_org) + # dropout + hidden_states_org = attn.to_out[1](hidden_states_org) + + if input_ndim == 4: + hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # perturbed path (identity attention) + batch_size, sequence_length, _ = hidden_states_ptb.shape + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) + + value = attn.to_v(hidden_states_ptb) + hidden_states_ptb = value + hidden_states_ptb = hidden_states_ptb.to(query.dtype) + + # linear proj + hidden_states_ptb = attn.to_out[0](hidden_states_ptb) + # dropout + hidden_states_ptb = attn.to_out[1](hidden_states_ptb) + + if input_ndim == 4: + hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) + + # cat + hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ @@ -517,6 +742,34 @@ def encode_prompt_with_trigger_word( @property def interrupt(self): return self._interrupt + + @property + def pag_scale(self): + return self._pag_scale + + @property + def do_adversarial_guidance(self): + return self._pag_scale > 0 + + @property + def pag_adaptive_scaling(self): + return self._pag_adaptive_scaling + + @property + def do_pag_adaptive_scaling(self): + return self._pag_adaptive_scaling > 0 + + @property + def pag_drop_rate(self): + return self._pag_drop_rate + + @property + def pag_applied_layers(self): + return self._pag_applied_layers + + @property + def pag_applied_layers_index(self): + return self._pag_applied_layers_index @torch.no_grad() def __call__( @@ -531,6 +784,11 @@ def __call__( sigmas: List[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, + pag_scale: float = 0.0, + pag_adaptive_scaling: float = 0.0, + pag_drop_rate: float = 0.5, + pag_applied_layers: List[str] = ['mid'], #['down', 'mid', 'up'] + pag_applied_layers_index: List[str] = None, #['d4', 'd5', 'm0'] negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -627,6 +885,11 @@ def __call__( ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip + self._pag_scale = pag_scale + self._pag_adaptive_scaling = pag_adaptive_scaling + self._pag_drop_rate = pag_drop_rate + self._pag_applied_layers = pag_applied_layers + self._pag_applied_layers_index = pag_applied_layers_index # if prompt_embeds is not None and class_tokens_mask is None: @@ -807,8 +1070,15 @@ def __call__( else: negative_add_time_ids = add_time_ids - if self.do_classifier_free_guidance: + #cfg + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + #pag + elif not self.do_classifier_free_guidance 
and self.do_adversarial_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) @@ -827,28 +1097,117 @@ def __call__( num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) timesteps = timesteps[:num_inference_steps] + # 13. Create down mid and up layer lists + if self.do_adversarial_guidance: + down_layers = [] + mid_layers = [] + up_layers = [] + for name, module in self.unet.named_modules(): + if 'attn1' in name and 'to' not in name: + layer_type = name.split('.')[0].split('_')[0] + if layer_type == 'down': + down_layers.append(module) + elif layer_type == 'mid': + mid_layers.append(module) + elif layer_type == 'up': + up_layers.append(module) + else: + raise ValueError(f"Invalid layer type: {layer_type}") + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + #cfg + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 2) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 2) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + latent_model_input = torch.cat([latents] * 3) + #no + else: + latent_model_input = latents + + # change attention layer in UNet if use PAG + if self.do_adversarial_guidance: + + if self.do_classifier_free_guidance: + replace_processor = PAGCFGIdentitySelfAttnProcessor() + else: + replace_processor = PAGIdentitySelfAttnProcessor() + + if(self.pag_applied_layers_index): + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = replace_processor + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + elif(self.pag_applied_layers): + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = replace_processor + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = replace_processor + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = replace_processor + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. 
If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - if i <= start_merge_step: - current_prompt_embeds = torch.cat( - [negative_prompt_embeds, prompt_embeds_text_only], dim=0 - ) if self.do_classifier_free_guidance else prompt_embeds_text_only - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0 - ) if self.do_classifier_free_guidance else pooled_prompt_embeds_text_only + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: + if i <= start_merge_step: + current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_text_only], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0) + else: + current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + #pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + if i <= start_merge_step: + current_prompt_embeds = torch.cat([prompt_embeds_text_only, prompt_embeds_text_only], dim=0) + add_text_embeds = torch.cat([pooled_prompt_embeds_text_only, pooled_prompt_embeds_text_only], dim=0) + else: + current_prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + #both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + if i <= start_merge_step: + current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds_text_only, prompt_embeds_text_only], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only, pooled_prompt_embeds_text_only], dim=0) + else: + current_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + #nothing else: - current_prompt_embeds = torch.cat( - [negative_prompt_embeds, prompt_embeds], dim=0 - ) if self.do_classifier_free_guidance else prompt_embeds - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0 - ) if self.do_classifier_free_guidance else pooled_prompt_embeds + if i <= start_merge_step: + current_prompt_embeds = prompt_embeds_text_only + add_text_embeds = pooled_prompt_embeds_text_only + else: + current_prompt_embeds = prompt_embeds + add_text_embeds = pooled_prompt_embeds if i < int(num_inference_steps * adapter_conditioning_factor) and (use_adapter): down_intrablock_additional_residuals = [state.clone() for state in adapter_state] @@ -872,9 +1231,34 @@ def __call__( )[0] # perform guidance - if self.do_classifier_free_guidance: + if self.do_classifier_free_guidance and not self.do_adversarial_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + # pag + elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: + noise_pred_original, noise_pred_perturb = noise_pred.chunk(2) + + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + + noise_pred = 
noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb) + + # both + elif self.do_classifier_free_guidance and self.do_adversarial_guidance: + + noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3) + + signal_scale = self.pag_scale + if self.do_pag_adaptive_scaling: + signal_scale = self.pag_scale - self.pag_adaptive_scaling * (1000-t) + if signal_scale<0: + signal_scale = 0 + + noise_pred = noise_pred_text + (self.guidance_scale-1.0) * (noise_pred_text - noise_pred_uncond) + signal_scale * (noise_pred_text - noise_pred_text_perturb) + if self.do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf @@ -915,4 +1299,43 @@ def __call__( if not return_dict: return (image,) + #Change the attention layers back to original ones after PAG was applied + if self.do_adversarial_guidance: + if(self.pag_applied_layers_index): + drop_layers = self.pag_applied_layers_index + for drop_layer in drop_layers: + layer_number = int(drop_layer[1:]) + try: + if drop_layer[0] == 'd': + down_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'm': + mid_layers[layer_number].processor = AttnProcessor2_0() + elif drop_layer[0] == 'u': + up_layers[layer_number].processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_layer[0]}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." + ) + elif(self.pag_applied_layers): + drop_full_layers = self.pag_applied_layers + for drop_full_layer in drop_full_layers: + try: + if drop_full_layer == "down": + for down_layer in down_layers: + down_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "mid": + for mid_layer in mid_layers: + mid_layer.processor = AttnProcessor2_0() + elif drop_full_layer == "up": + for up_layer in up_layers: + up_layer.processor = AttnProcessor2_0() + else: + raise ValueError(f"Invalid layer type: {drop_full_layer}") + except IndexError: + raise ValueError( + f"Invalid layer index: {drop_full_layer}. Available layers are: down, mid and up. If you need to specify each layer index, you can use `pag_applied_layers_index`" + ) + return StableDiffusionXLPipelineOutput(images=image) \ No newline at end of file
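Usage note (not part of the patch): the sketch below shows how the PAG arguments introduced in this diff (pag_scale, pag_adaptive_scaling, pag_applied_layers, pag_applied_layers_index) might be passed to the PhotoMaker SDXL pipeline. The base model ID, adapter-loading call, example image path, and the input_id_images argument follow the existing PhotoMaker README and are assumptions here, not part of this change.

# Hedged usage sketch, assuming the PhotoMaker v1 loading flow from the project README.
import torch
from diffusers import EulerDiscreteScheduler
from diffusers.utils import load_image
from photomaker import PhotoMakerStableDiffusionXLPipeline

pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    "SG161222/RealVisXL_V3.0",          # illustrative base model
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_photomaker_adapter(
    "TencentARC/PhotoMaker", weight_name="photomaker-v1.bin", trigger_word="img"
)  # assumption: adapter loaded as in the README
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

id_images = [load_image("./examples/newton_man/newton_0.jpg")]  # photos of one identity (illustrative path)

images = pipe(
    prompt="a photo of a man img wearing a suit",   # prompt must contain the trigger word
    input_id_images=id_images,
    negative_prompt="lowres, blurry, worst quality",
    num_inference_steps=50,
    start_merge_step=10,
    guidance_scale=5.0,                 # CFG; combined with PAG when both are enabled
    pag_scale=3.0,                      # added in this diff: strength of perturbed-attention guidance
    pag_applied_layers=['mid'],         # added in this diff: replace self-attention in the mid block
    # pag_applied_layers_index=['m0'],  # alternative: target individual attn1 layers
).images
images[0].save("out.png")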