-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
66 lines (59 loc) · 3.29 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from model.Seq2SeqETOX import TextGeneratorSeq2Seq
import argparse
def _str2bool(value):
    """Parse a command-line boolean option value.

    Accepts, case-insensitively and ignoring surrounding whitespace,
    'true', 't', 'yes', 'y', '1' as True and 'false', 'f', 'no', 'n', '0'
    as False.

    Raises:
        argparse.ArgumentTypeError: for any other value, so a typo such as
            ``--unmodified ture`` is reported as a usage error instead of
            silently meaning False.
    """
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def get_args(argv=None):
    """Parse the command-line options for a single translation run.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]`` as before; passing a
            list lets the parser be used programmatically.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(
        description='Translate one sentence with TextGeneratorSeq2Seq.')
    # Without --text there is nothing to translate, so fail fast with a
    # usage error instead of handing None to the model.
    parser.add_argument('--text', type=str, required=True,
                        help='Source sentence to translate.')
    # The following options are forwarded unchanged to TextGeneratorSeq2Seq.
    parser.add_argument('--target_seq_length', type=int, default=100)
    parser.add_argument('--quality_scale', type=float, default=0.7)
    parser.add_argument('--stepsize', type=float, default=0.7)
    parser.add_argument('--top_size', type=int, default=10)
    parser.add_argument('--attention_change', type=str,
                        default='self_attention_decoder')
    parser.add_argument('--src_lang', type=str, default='eng_Latn',
                        help='Source language code, e.g. eng_Latn.')
    parser.add_argument('--tgt_lang', type=str, default='fra_Latn',
                        help='Target language code, e.g. fra_Latn; also '
                             'selects the toxicity word list file.')
    parser.add_argument('--unmodified', default=False, type=_str2bool,
                        help='Boolean (true/false).')
    parser.add_argument('--update_when_toxic', default=False, type=_str2bool,
                        help='Boolean (true/false).')
    parser.add_argument('--toxicity_method', type=str, default='ETOX')
    parser.add_argument('--beam_size', type=int, default=4,
                        help='Beam size passed to the generator run().')
    return parser.parse_args(argv)
def main(grid_args):
    """Build the toxicity-aware seq2seq translator and translate one text.

    Args:
        grid_args: Mapping holding the parsed command-line options under the
            keys 'text', 'target_seq_length', 'quality_scale', 'stepsize',
            'top_size', 'attention_change', 'src_lang', 'tgt_lang',
            'unmodified', 'update_when_toxic', 'toxicity_method' and
            'beam_size'.

    Returns:
        The value returned by ``TextGeneratorSeq2Seq.run`` for
        ``grid_args['text']`` and ``grid_args['beam_size']``.
    """
    # The toxicity word list is picked by target language; the path is
    # relative to the current working directory.
    word_list_path = './NLLB-200_TWL/{}_twl.txt'.format(grid_args['tgt_lang'])

    # Settings fixed by this script rather than exposed on the command line.
    generator_options = {
        'seed': 0,
        'seq2seq_model': 'nllb600M',
        'num_iterations': 1,
        'grad_norm_factor': 0.9,
        'repetition_penalty': 1.0,
        'end_factor': 1.01,
    }
    # Options copied verbatim from grid_args, looked up in this order.
    passthrough_keys = (
        'target_seq_length',
        'quality_scale',
        'stepsize',
        'top_size',
        'attention_change',
        'src_lang',
        'tgt_lang',
        'unmodified',
        'update_when_toxic',
        'toxicity_method',
    )
    for option_name in passthrough_keys:
        generator_options[option_name] = grid_args[option_name]

    generator = TextGeneratorSeq2Seq(word_list_path, **generator_options)
    return generator.run(grid_args['text'], grid_args['beam_size'])
if __name__ == '__main__':
    args = get_args()
    # Every parsed option is forwarded under its own name. Deriving the
    # mapping from the Namespace keeps it in sync with get_args(), instead of
    # re-listing each flag by hand (where a newly added flag could silently
    # go missing).
    grid_args = vars(args)
    translation = main(grid_args)
    print('Translated sentence: {}'.format(translation))