Skip to content

Commit

Permalink
feature: reweight implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
onevfall committed Feb 20, 2025
1 parent 9c335a0 commit 1e001f8
Show file tree
Hide file tree
Showing 9 changed files with 398 additions and 16 deletions.
8 changes: 5 additions & 3 deletions sat/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def add_model_config_args(parser):
group.add_argument("--prompts", type=str, help="Multiple prompts separated by semicolons")
group.add_argument("--ref-token-idx", type=int, nargs="+", help="Reference token indices")
group.add_argument("--cur-token-idx", type=int, nargs="+", help="Current token indices")
group.add_argument("--reweight-token-idx", type=int, nargs="+",help="Reweight token indices")
group.add_argument("--reweight-scale", type=float, help="Reweight scale")
group.add_argument("--is-run-isolated", type=bool, default=False, help="If running isolated video for comparison")
group.add_argument("--single-prompt-length", type=int, default=0, help="Length of single prompt")

return parser


Expand Down Expand Up @@ -300,7 +302,7 @@ def generate_output_path(args):
'attn_map_step_idx', 'attn_map_layer_idx', 'mask_save_dir',
'overlap_size', 'num_transition_blocks', 'longer_mid_segment',
'ref_token_idx', 'cur_token_idx', 'prompts',
'single_prompt_length', 'is_edit'
'single_prompt_length', 'is_edit', 'reweight_token_idx', 'reweight_scale'
]

clean_args = {k: getattr(args, k) for k in save_keys if hasattr(args, k)}
Expand Down Expand Up @@ -364,7 +366,7 @@ def process_config_to_args(args):
'mask_save_dir', 'ref_token_idx', 'cur_token_idx',
'attn_map_step_idx', 'attn_map_layer_idx', 'thres',
'num_prompts', 'num_transition_blocks', 'longer_mid_segment',
'is_edit'
'is_edit', 'reweight_token_idx', 'reweight_scale'
]

for param in params_to_register:
Expand Down
3 changes: 3 additions & 0 deletions sat/configs/inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ args:

# Only for single prompt case
single_prompt_length: 0
# Only for reweight case
reweight_token_idx: 0
reweight_scale: -5
Loading

0 comments on commit 1e001f8

Please sign in to comment.