diff --git a/sat/arguments.py b/sat/arguments.py
index 65a187f..9e3b0e0 100755
--- a/sat/arguments.py
+++ b/sat/arguments.py
@@ -41,9 +41,11 @@ def add_model_config_args(parser):
     group.add_argument("--prompts", type=str, help="Multiple prompts separated by semicolons")
     group.add_argument("--ref-token-idx", type=int, nargs="+", help="Reference token indices")
     group.add_argument("--cur-token-idx", type=int, nargs="+", help="Current token indices")
+    group.add_argument("--reweight-token-idx", type=int, nargs="+", help="Reweight token indices")
+    group.add_argument("--reweight-scale", type=float, help="Reweight scale")
     group.add_argument("--is-run-isolated", type=bool, default=False, help="If running isolated video for comparison")
     group.add_argument("--single-prompt-length", type=int, default=0, help="Length of single prompt")
-
+
     return parser
 
@@ -300,7 +302,7 @@ def generate_output_path(args):
         'attn_map_step_idx', 'attn_map_layer_idx', 'mask_save_dir',
         'overlap_size', 'num_transition_blocks', 'longer_mid_segment',
         'ref_token_idx', 'cur_token_idx', 'prompts',
-        'single_prompt_length', 'is_edit'
+        'single_prompt_length', 'is_edit', 'reweight_token_idx', 'reweight_scale'
     ]
 
     clean_args = {k: getattr(args, k) for k in save_keys if hasattr(args, k)}
@@ -364,7 +366,7 @@ def process_config_to_args(args):
         'mask_save_dir', 'ref_token_idx', 'cur_token_idx',
         'attn_map_step_idx', 'attn_map_layer_idx', 'thres',
         'num_prompts', 'num_transition_blocks', 'longer_mid_segment',
-        'is_edit'
+        'is_edit', 'reweight_token_idx', 'reweight_scale'
     ]
 
     for param in params_to_register:
diff --git a/sat/configs/inference.yaml b/sat/configs/inference.yaml
index 4b782b1..dad699b 100755
--- a/sat/configs/inference.yaml
+++ b/sat/configs/inference.yaml
@@ -45,3 +45,6 @@ args:
 
   # Only for single prompt case
   single_prompt_length: 0
+  # Only for reweight case
+  reweight_token_idx: 0
+  reweight_scale: -5
\ No newline at end of file
diff --git a/sat/dit_video_concat.py b/sat/dit_video_concat.py
index 1a4746d..bd60d0b 100755
--- a/sat/dit_video_concat.py
+++ b/sat/dit_video_concat.py
@@ -27,6 +27,36 @@
     timestep_embedding,
 )
 from sat.ops.layernorm import LayerNorm, RMSNorm
+
+
+# Copied from https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html:
+def scaled_dot_product_attention_map(query, key, value, attn_mask=None, dropout_p=0.0,
+                                     is_causal=False, scale=None, enable_gqa=False):
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias.to(query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias += attn_mask
+
+    if enable_gqa:
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return attn_weight, value
+
+
 def concatenate_frames(frames):
     """
     Concatenate multiple frames horizontally into a single image
@@ -80,7 +110,13 @@ def set_self_attn_mask(self, mask):
     def reset_cross_attns(self):
         self.cross_attn_sum = None
         self.cross_attn_count = 0
-
+    def reset_video_self_attn(self):
+        self.video_self_attn_sum = None
+        self.video_self_attn_count = 0
+    def reset_text_self_attn(self):
+        self.text_self_attn_sum = None
+        self.text_self_attn_count = 0
+
     def reset_self_attn_mask(self):
         self.self_attn_mask = None
 
@@ -850,6 +886,8 @@ def __init__(
         num_transition_blocks=None,
         longer_mid_segment=None,
         is_edit=False,
+        reweight_token_idx=None,
+        reweight_scale=-5,
     ):
         super().__init__()
         print("BaseAdaLNMixin Init")
@@ -878,6 +916,8 @@ def __init__(
             ]
         )
         self.is_edit = is_edit
+        self.reweight_token_idx = reweight_token_idx
+        self.reweight_scale = reweight_scale
         self.cur_step = 0
         self.cur_layer = 0
         self.num_prompts = num_prompts
@@ -1057,6 +1097,8 @@ def __init__(
         num_transition_blocks=None,
         longer_mid_segment=None,
         is_edit=False,
+        reweight_token_idx=None,
+        reweight_scale=-5,
     ):
         super().__init__()
         print("VisualizeAdaLNMixin Init")
@@ -1085,6 +1127,8 @@ def __init__(
             ]
         )
         self.is_edit = is_edit
+        self.reweight_token_idx = reweight_token_idx
+        self.reweight_scale = reweight_scale
         self.cur_step = 0
         self.cur_layer = 0
         self.num_prompts = num_prompts
@@ -1223,7 +1267,7 @@ def attention_fn(
         }
         # hardcode here
         step_idx = [25]
-        layer_idx = [15]
+        layer_idx = [15, 20, 25, 29]
         frame_indices = [
             self.compressed_num_frames // 2,  # middle frame
         ]
@@ -1276,7 +1320,9 @@ def attention_fn(
         return attn_output
 
     def after_total_layers(self):
-        pass
+        self.attn_controller.reset_cross_attns()
+        self.attn_controller.reset_video_self_attn()
+        self.attn_controller.reset_text_self_attn()
 
 
@@ -1311,6 +1357,8 @@ def __init__(
         num_transition_blocks=None,
         longer_mid_segment=None,
         is_edit=False,
+        reweight_token_idx=None,
+        reweight_scale=-5,
     ):
         super().__init__()
         print("KVSharingAdaLNMixin Init")
@@ -1339,6 +1387,8 @@ def __init__(
             ]
        )
         self.is_edit = is_edit
+        self.reweight_token_idx = reweight_token_idx
+        self.reweight_scale = reweight_scale
         self.num_prompts = num_prompts
         self.num_transition_blocks = num_transition_blocks
         self.longer_mid_segment = longer_mid_segment
@@ -1555,6 +1605,8 @@ def __init__(
         num_transition_blocks=None,
         longer_mid_segment=None,
         is_edit=False,
+        reweight_token_idx=None,
+        reweight_scale=-5,
     ):
         super().__init__()
         print("KVSharingMaskGuidedAdaLNMixin Init")
@@ -1582,15 +1634,16 @@ def __init__(
                 for _ in range(num_layers)
             ]
         )
-
+        self.is_edit = is_edit
+        self.reweight_token_idx = reweight_token_idx
+        self.reweight_scale = reweight_scale
         self.cur_step = 0
         self.cur_layer = 0
         self.num_prompts = num_prompts
         self.num_transition_blocks = num_transition_blocks
         self.longer_mid_segment = longer_mid_segment
         self.count_segment = 0
-        self.is_edit = is_edit
-
+
         self.end_step = end_step
         self.end_layer = end_layer
         self.start_step = start_step
@@ -1900,6 +1953,278 @@ def attn_batch(self, q, k, v, attention_mask, ref_token_idx, attention_dropout,
 
     def after_total_layers(self):
         self.attn_controller.reset_cross_attns()
+        self.attn_controller.reset_video_self_attn()
+        self.attn_controller.reset_text_self_attn()
+
+
+class ReWeightAdaLNMixin(BaseMixin):
+    def __init__(
+        self,
+        width,
+        height,
+        hidden_size,
+        num_layers,
+        time_embed_dim,
+        compressed_num_frames,
+        text_length,
+        qk_ln=True,
+        hidden_size_head=None,
+        elementwise_affine=True,
+        start_step=2,
+        start_layer=5,
+        layer_idx=None,
+        step_idx=None,
+        overlap_size=6,
+        sampling_num_frames=13,
+        end_step=50,
+        end_layer=30,
+        mask_save_dir=None,
+        ref_token_idx=None,
+        cur_token_idx=None,
+        attn_map_step_idx=None,
+        attn_map_layer_idx=None,
+        thres=0.1,
+        num_prompts=None,
+        num_transition_blocks=None,
+        longer_mid_segment=None,
+        is_edit=False,
+        reweight_token_idx=None,
+        reweight_scale=-5,
+    ):
+        super().__init__()
+        print("ReWeightAdaLNMixin Init")
+        self.num_layers = num_layers
+        self.width = width
+        self.height = height
+        self.compressed_num_frames = compressed_num_frames
+        self.text_length = text_length
+
+        self.adaLN_modulations = nn.ModuleList(
+            [nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
+        )
+
+        self.qk_ln = qk_ln
+        if qk_ln:
+            self.query_layernorm_list = nn.ModuleList(
+                [
+                    LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
+                    for _ in range(num_layers)
+                ]
+            )
+            self.key_layernorm_list = nn.ModuleList(
+                [
+                    LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
+                    for _ in range(num_layers)
+                ]
+            )
+        self.is_edit = is_edit
+        self.reweight_token_idx = reweight_token_idx
+        self.reweight_scale = reweight_scale
+        self.num_prompts = num_prompts
+        self.num_transition_blocks = num_transition_blocks
+        self.longer_mid_segment = longer_mid_segment
+        self.count_segment = 0
+
+        self.cur_step = 0
+        self.cur_layer = 0
+
+        self.end_step = end_step
+        self.end_layer = end_layer
+        self.start_step = start_step
+        self.start_layer = start_layer
+        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.end_layer))
+        self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+        self.overlap_size = overlap_size
+        self.sampling_num_frames = sampling_num_frames
+
+    def layer_forward(
+        self,
+        hidden_states,
+        mask,
+        *args,
+        **kwargs,
+    ):
+        text_length = kwargs["text_length"]
+        # hidden_states (b,(n_t+t*n_i),d)
+        text_hidden_states = hidden_states[:, :text_length]  # (b,n,d)
+        img_hidden_states = hidden_states[:, text_length:]  # (b,(t n),d)
+        layer = self.transformer.layers[kwargs["layer_id"]]
+        adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]]
+
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            text_shift_msa,
+            text_scale_msa,
+            text_gate_msa,
+            text_shift_mlp,
+            text_scale_mlp,
+            text_gate_mlp,
+        ) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1)
+        gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
+            gate_msa.unsqueeze(1),
+            gate_mlp.unsqueeze(1),
+            text_gate_msa.unsqueeze(1),
+            text_gate_mlp.unsqueeze(1),
+        )
+
+        # self full attention (b,(t n),d)
+        img_attention_input = layer.input_layernorm(img_hidden_states)
+        text_attention_input = layer.input_layernorm(text_hidden_states)
+        img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
+        text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
+
+        # text_attention_input[-1] = text_attention_input[-1] + 0.2
+
+        attention_input = torch.cat((text_attention_input, img_attention_input), dim=1)  # (b,n_t+t*n_i,d)
+        attention_output = layer.attention(attention_input, mask, **kwargs)
+        text_attention_output = attention_output[:, :text_length]  # (b,n,d)
+        img_attention_output = attention_output[:, text_length:]  # (b,(t n),d)
+
+        if self.transformer.layernorm_order == "sandwich":
+            text_attention_output = layer.third_layernorm(text_attention_output)
+            img_attention_output = layer.third_layernorm(img_attention_output)
+        img_hidden_states = img_hidden_states + gate_msa * img_attention_output  # (b,(t n),d)
+        # text_gate_msa = text_gate_msa + 4
+        text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output  # (b,n,d)
+
+        # mlp (b,(t n),d)
+        img_mlp_input = layer.post_attention_layernorm(img_hidden_states)  # vision (b,(t n),d)
+        text_mlp_input = layer.post_attention_layernorm(text_hidden_states)  # language (b,n,d)
+        img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
+        text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
+        mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1)  # (b,(n_t+t*n_i),d)
+        mlp_output = layer.mlp(mlp_input, **kwargs)
+        img_mlp_output = mlp_output[:, text_length:]  # vision (b,(t n),d)
+        text_mlp_output = mlp_output[:, :text_length]  # language (b,n,d)
+        if self.transformer.layernorm_order == "sandwich":
+            text_mlp_output = layer.fourth_layernorm(text_mlp_output)
+            img_mlp_output = layer.fourth_layernorm(img_mlp_output)
+
+        img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output  # vision (b,(t n),d)
+        text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output  # language (b,n,d)
+
+        hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1)  # (b,(n_t+t*n_i),d)
+
+        self.cur_layer += 1
+        if self.cur_layer == self.num_layers:
+            self.cur_layer = 0
+            self.cur_step += 1
+            self.after_total_layers()
+
+        return hidden_states
+
+    def reinit(self, parent_model=None):
+        self.cur_step = 0
+        self.cur_layer = 0
+        for layer in self.adaLN_modulations:
+            nn.init.constant_(layer[-1].weight, 0)
+            nn.init.constant_(layer[-1].bias, 0)
+
+    @non_conflict
+    def attention_fn(
+        self,
+        query_layer,
+        key_layer,
+        value_layer,
+        attention_mask,
+        attention_dropout=None,
+        log_attention_weights=None,
+        scaling_attention_score=True,
+        old_impl=attention_fn_default,
+        **kwargs,
+    ):
+        if self.qk_ln:
+            query_layernorm = self.query_layernorm_list[kwargs["layer_id"]]
+            key_layernorm = self.key_layernorm_list[kwargs["layer_id"]]
+            query_layer = query_layernorm(query_layer)
+            key_layer = key_layernorm(key_layer)
+
+        if self.cur_step in self.step_idx and self.cur_layer in self.layer_idx:
+            qu_s, qu_t, qc_s, qc_t = query_layer.chunk(4)
+            ku_s, ku_t, kc_s, kc_t = key_layer.chunk(4)
+            vu_s, vu_t, vc_s, vc_t = value_layer.chunk(4)
+
+            # source branch
+            out_u_s = old_impl(qu_s, ku_s, vu_s, attention_mask, attention_dropout, log_attention_weights, scaling_attention_score, **kwargs)
+            out_c_s = old_impl(qc_s, kc_s, vc_s, attention_mask, attention_dropout, log_attention_weights, scaling_attention_score, **kwargs)
+            # no kv-sharing
+            if self.cur_step >= 2:  # hardcoded here; later steps only need a similar layout
+                out_u_t = self.attn_batch(qu_t, ku_t, vu_t, attention_mask, attention_dropout, log_attention_weights, scaling_attention_score, **kwargs)
+                out_c_t = self.attn_batch(qc_t, kc_t, vc_t, attention_mask, attention_dropout, log_attention_weights, scaling_attention_score, **kwargs)
+            else:  # kv-sharing for the initial steps
+                out_u_t = self.attn_batch(qu_t, ku_s, vu_s, attention_mask, attention_dropout, log_attention_weights, scaling_attention_score, **kwargs)
+                out_c_t = self.attn_batch(qc_t, kc_s, vc_s, attention_mask, attention_dropout, log_attention_weights, scaling_attention_score, **kwargs)
+            out = torch.cat([out_u_s, out_u_t, out_c_s, out_c_t], dim=0)
+            return out
+
+        return old_impl(
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask,
+            attention_dropout=attention_dropout,
+            log_attention_weights=log_attention_weights,
+            scaling_attention_score=scaling_attention_score,
+            **kwargs,
+        )
+
+    def attn_batch(self, q, k, v, attention_mask, attention_dropout, log_attention_weights, scaling_attention_score, **kwargs):
+        # Ensure the input tensors are contiguous
+        q = q.contiguous()
+        k = k.contiguous()
+        v = v.contiguous()
+
+        token_idx = self.reweight_token_idx
+        scale = self.reweight_scale
+        text_length = kwargs["text_length"]
+
+        patch_tokens_per_frame = self.height * self.width
+        compressed_num_frames = self.compressed_num_frames
+
+        '''First version: add/subtract a bias via scaled_dot_product_attention's attn_mask instead of
+        multiplying/dividing, because materializing the full attention map without scaled_dot_product_attention
+        can run out of GPU memory when memory is limited.'''
+        # Only create a (seq_len, seq_len) mask and let PyTorch broadcast it over the batch and head dimensions
+
+        # seq_len = q.shape[-2]
+        # attn_bias = torch.zeros((seq_len, seq_len),
+        #                         dtype=q.dtype,
+        #                         device=q.device)
+
+        # # Strengthen/weaken the attention from video tokens to specific text tokens
+        # attn_bias[text_length:, token_idx] += scale
+        # # Strengthen/weaken the attention from specific text tokens to video tokens
+        # attn_bias[token_idx, text_length:] += scale
+
+        # attn_output = torch.nn.functional.scaled_dot_product_attention(
+        #     q, k, v,
+        #     attn_mask=attn_bias,
+        #     dropout_p=0.0,
+        #     is_causal=False
+        # )
+
+        '''Second version: multiply/divide the attention weights directly, which uses ~71 GB of GPU memory.'''
+        attn_weight, value = scaled_dot_product_attention_map(
+            q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        attn_weight[:, :, text_length:, token_idx] *= scale  # video to text
+        attn_weight[:, :, token_idx, text_length:] *= scale  # text to video
+
+        attn_output = attn_weight @ value
+
+        return attn_output
+
+    def after_total_layers(self):
+        pass
+
+
 str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
diff --git a/sat/inference_case_configs/reweight/bear_0.yaml b/sat/inference_case_configs/reweight/bear_0.yaml
new file mode 100755
index 0000000..d4ae560
--- /dev/null
+++ b/sat/inference_case_configs/reweight/bear_0.yaml
@@ -0,0 +1,20 @@
+args:
+  seed: 42
+  output_dir: outputs/reweight_case/bear_0
+  sampling_fps: 8
+  reweight_token_idx: 0  # pink
+  reweight_scale: 0
+  # See attn_batch of class ReWeightAdaLNMixin in dit_video_concat.py
+  # for more details about reweight_token_idx and reweight_scale.
+  # With version 1, reweight_scale is added/subtracted;
+  # with version 2, reweight_scale is multiplied/divided.
+  start_step: 0
+  end_step: 50
+  start_layer: 0
+  end_layer: 30
+  adaln_mixin_names:
+    - 'ReWeightAdaLNMixin'
+
+  prompts:
+    - "pink teddy bear wearing a cute pink bow tie"
+    - "pink teddy bear wearing a cute pink bow tie"
diff --git a/sat/inference_case_configs/reweight/field_0p5.yaml b/sat/inference_case_configs/reweight/field_0p5.yaml
new file mode 100755
index 0000000..4b7a7d8
--- /dev/null
+++ b/sat/inference_case_configs/reweight/field_0p5.yaml
@@ -0,0 +1,17 @@
+args:
+  seed: 42
+  output_dir: outputs/reweight_case/field_0p5
+  sampling_fps: 8
+  reweight_token_idx: 5  # night
+  reweight_scale: 0.5
+
+  start_step: 0
+  end_step: 50
+  start_layer: 0
+  end_layer: 30
+  adaln_mixin_names:
+    - 'ReWeightAdaLNMixin'
+
+  prompts:
+    - "a field of flowers at night"
+    - "a field of flowers at night"
diff --git a/sat/run_reweight_video.sh b/sat/run_reweight_video.sh
new file mode 100755
index 0000000..67a0aaa
--- /dev/null
+++ b/sat/run_reweight_video.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+
+environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
+
+
+inference_case_config="inference_case_configs/reweight/bear_0.yaml"
+run_cmd="$environs python sample_video_edit.py --base configs/cogvideox_2b.yaml configs/inference.yaml --custom-config $inference_case_config"
+echo ${run_cmd}
+eval ${run_cmd}
+
+
+echo "DONE on `hostname`"
\ No newline at end of file
diff --git a/sat/sample_video.py b/sat/sample_video.py
index d56e4c2..cd5f64e 100755
--- a/sat/sample_video.py
+++ b/sat/sample_video.py
@@ -189,7 +189,7 @@ def process_multi_prompt_video_with_adaln(model, args, c_total, uc_total,
 
         save_path = os.path.join(
             args.output_dir,
-            "MultiPrompt_"+ adaln_name,
+            "MultiPrompt",
         )
         if mpu.get_model_parallel_rank() == 0:
             save_video_as_grid_and_mp4(samples, save_path, fps=args.sampling_fps)
diff --git a/sat/sample_video_edit.py b/sat/sample_video_edit.py
index 3c14008..fd9e98d 100644
--- a/sat/sample_video_edit.py
+++ b/sat/sample_video_edit.py
@@ -258,9 +258,9 @@ def sampling_main(args, model_cls):
 
         next_adaln_name = AdaLNMixin_NAMES[(i + 1) % len(AdaLNMixin_NAMES)]
 
-        model.switch_adaln_layer(next_adaln_name)
-        load_checkpoint(model, args)
-        model.eval()
+        # model.switch_adaln_layer(next_adaln_name)
+        # load_checkpoint(model, args)
+        # model.eval()
 
 
 if __name__ == "__main__":
     if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
diff --git a/sat/sample_video_visualize.py b/sat/sample_video_visualize.py
index 1f7c503..c27f1ee 100644
--- a/sat/sample_video_visualize.py
+++ b/sat/sample_video_visualize.py
@@ -258,9 +258,9 @@ def sampling_main(args, model_cls):
 
         next_adaln_name = AdaLNMixin_NAMES[(i + 1) % len(AdaLNMixin_NAMES)]
 
-        model.switch_adaln_layer(next_adaln_name)
-        load_checkpoint(model, args)
-        model.eval()
+        # model.switch_adaln_layer(next_adaln_name)
+        # load_checkpoint(model, args)
+        # model.eval()
 
 
 if __name__ == "__main__":
     if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
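
Note (not part of the diff): the reweight configs above hard-code reweight_token_idx by hand (e.g. index 0 for "pink" in bear_0.yaml, index 5 for "night" in field_0p5.yaml). A minimal sketch for finding that index is below, assuming the index refers to a position in the T5-tokenized prompt used by the text encoder; the tokenizer checkpoint name is an assumption, and the exact offset or special-token handling should be checked against the repo's own text-processing code.

# Hypothetical helper, not part of the PR: print token positions for a prompt
# so reweight_token_idx can be chosen. Assumes a T5-style tokenizer; the
# checkpoint name below is an assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xxl")
prompt = "pink teddy bear wearing a cute pink bow tie"
tokens = tokenizer.convert_ids_to_tokens(tokenizer(prompt).input_ids)
for idx, tok in enumerate(tokens):
    print(idx, tok)  # pick the index of the word to reweight, e.g. "pink"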