Skip to content

Commit

Permalink
delta
Browse files Browse the repository at this point in the history
  • Loading branch information
yul091 committed May 6, 2023
1 parent 9dd3a8a commit b1e39e1
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
3 changes: 3 additions & 0 deletions attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def main(args: argparse.Namespace):
att_method = args.attack_strategy
cls_weight = args.cls_weight
eos_weight = args.eos_weight
delta = args.delta
use_combined_loss = args.use_combined_loss
out_dir = args.out_dir

Expand Down Expand Up @@ -332,6 +333,7 @@ def main(args: argparse.Namespace):
select_beams=select_beams,
eos_weight=eos_weight,
cls_weight=cls_weight,
delta=delta,
use_combined_loss=use_combined_loss,
)
elif att_method.lower() == 'pwws':
Expand Down Expand Up @@ -477,6 +479,7 @@ def main(args: argparse.Namespace):
parser.add_argument("--seed", type=int, default=2019, help="Random seed")
parser.add_argument("--eos_weight", type=float, default=0.8, help="Weight for EOS gradient")
parser.add_argument("--cls_weight", type=float, default=0.2, help="Weight for classification gradient")
parser.add_argument("--delta", type=float, default=0.5, help="Threshold for adaptive search strategy")
parser.add_argument("--use_combined_loss", action="store_true", help="Use combined loss")
parser.add_argument("--attack_strategy", "-a", type=str,
default='structure',
Expand Down
3 changes: 2 additions & 1 deletion attacker/DGSlow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ def __init__(
select_beams: int = 1,
eos_weight: float = 0.5,
cls_weight: float = 0.5,
delta: float = 0.5,
use_combined_loss: bool = False,
):
super(StructureAttacker, self).__init__(
device, tokenizer, model, max_len, max_per, task, fitness,
select_beams, eos_weight, cls_weight, use_combined_loss,
select_beams, eos_weight, cls_weight, delta, use_combined_loss,
)
self.filter_words = set(ENGLISH_FILTER_WORDS)
# BERT initialization
Expand Down
6 changes: 4 additions & 2 deletions attacker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def __init__(
select_beams: int = 1,
eos_weight: float = 0.5,
cls_weight: float = 0.5,
delta: float = 0.5,
use_combined_loss: bool = False,
):
super(SlowAttacker, self).__init__(
Expand All @@ -204,6 +205,7 @@ def __init__(
self.sp_token = '<SEP>'
self.select_beam = select_beams
self.fitness = fitness
self.delta = delta
self.eos_weight = eos_weight
self.cls_weight = cls_weight
self.use_combined_loss = use_combined_loss
Expand Down Expand Up @@ -301,8 +303,8 @@ def select_best(
elif self.fitness == 'adaptive':
assert len(pred_len) == len(pred_acc)
q = np.mean([self.sent_encoder.get_sim(prototype, x[1]) for x in new_strings])
preference = (it - 1) * np.exp(q - 1) / (self.max_per - 1)
if preference > 0.5:
preference = it * np.exp(q - 1) / (self.max_per - 1)
if preference > self.delta:
# Random search
rand_i = np.random.choice(len(pred_len), min(self.select_beam, len(pred_len)), replace=False)
return [new_strings[i] for i in rand_i], [pred_len[i] for i in rand_i]
Expand Down

0 comments on commit b1e39e1

Please sign in to comment.