-
Notifications
You must be signed in to change notification settings - Fork 2
/
main_UNSW.py
70 lines (53 loc) · 2.84 KB
/
main_UNSW.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from argparse import ArgumentParser
from UNSW.utils import utils
from UNSW.model import rad
import warnings
# Install an 'ignore' filter for all warning categories, process-wide, as soon
# as this module is imported.
# NOTE(review): this also hides potentially useful warnings from every library
# used afterwards; consider narrowing by category/module.
warnings.filterwarnings('ignore')
def arg_parser():
    """
    Build the command-line parser for the UNSW RAD experiment.

    Every option declares an explicit ``type`` so that a value supplied on the
    command line is converted from a string to the same kind of value as the
    option's default. argparse only runs ``type`` on *string* defaults, so
    the numeric defaults below are returned unchanged.

    :return: a configured ``ArgumentParser``. Note that ``--lambda`` is stored
        under the key ``'lambda'`` (a Python keyword), so it must be read via
        ``vars(args)['lambda']`` rather than attribute access.
    """
    parser = ArgumentParser()
    parser.add_argument('--random_seed', type=int, help='random seed', default=9)
    parser.add_argument('--dataset_dir', type=str, help='please choose dataset directory',
                        default='./UNSW/Datasets/NUSW_small.csv')
    parser.add_argument('--out_dim', type=int, help='output dimensions', default=128)
    parser.add_argument('--lr', type=float, help='learning rate', default=0.001)
    parser.add_argument('--device', type=str, help='device cpu or cuda', default='cuda:0')
    parser.add_argument('--dataset', type=str, help='name of dataset', default='NUSW_small')
    parser.add_argument('--r_ad_alpha', type=float,
                        help='hyper-parameter for the reward of anomaly detection', default=1)
    parser.add_argument('--r_cl_alpha', type=float,
                        help='hyper-parameter for the reward of anomaly classification', default=1)
    # set unseen validation set parameters
    parser.add_argument('--validation_size', type=int,
                        help='set validation size for each class', default=10)
    # set classifier network parameters
    parser.add_argument('--input_dim', type=int, help='input dimensions', default=293)
    # hid_dim0 / hid_dim1: presumably hidden-layer widths of the classifier
    # network (per the section comment above) — confirm against the model code.
    parser.add_argument('--hid_dim0', type=int, default=256)
    parser.add_argument('--hid_dim1', type=int, default=64)
    parser.add_argument('--n_ways', type=int, help='n ways', default=6)
    parser.add_argument('--n_support', type=int, help='n support', default=3)
    parser.add_argument('--n_query', type=int, help='n query', default=3)
    parser.add_argument('--max_epoch', type=int,
                        help='max epoch for prototypical networks', default=10)
    parser.add_argument('--epoch_size', type=int,
                        help='epoch size for each epoch of protonet', default=20)
    # set RAD parameters
    parser.add_argument('--max_episode', type=int,
                        help='max episode for each iterators', default=100)
    parser.add_argument('--max_iterators', type=int,
                        help='max iterators for training RAD model', default=10)
    # dest is 'lambda' (a keyword): access via vars(args)['lambda'].
    parser.add_argument('--lambda', type=float,
                        help='hyper-parameter to balance two rewards', default=1)
    parser.add_argument('--num_samples', type=int,
                        help='number of samples generated each episode', default=5)
    parser.add_argument('--n_min_size', type=int,
                        help='minimum number of final training size', default=10)
    parser.add_argument('--ratio_ab', type=float, help='ratio of abnormal', default=0.1)
    return parser
def main():
    """
    Script entry point: parse the CLI options, prepare the data, train RAD.

    Steps, in order: convert the parsed arguments to a plain dict, pass the
    configured seed to ``utils.set_seed``, split the data with
    ``utils.preprocessing_UNSW``, wrap the seen / sup / unseen (x, y) pairs
    into frames with ``utils.data2df``, then build ``rad.RAD`` from the
    options and call its ``train_rad``. Prints ``done`` when training returns.
    """
    options = vars(arg_parser().parse_args())
    utils.set_seed(options['random_seed'])

    # splitting dataset
    (seen_x, seen_y,
     sup_x, sup_y,
     unseen_x, unseen_y,
     test_x, test_y) = utils.preprocessing_UNSW(options)

    # Convert each (features, labels) pair, keeping the original call order:
    # seen, then sup, then unseen. The test split is passed through as-is.
    pairs = ((seen_x, seen_y), (sup_x, sup_y), (unseen_x, unseen_y))
    df_seen, df_sup, df_unseen = (utils.data2df(xs, ys) for xs, ys in pairs)

    # train_rad takes the frames in (seen, unseen, sup) order.
    rad.RAD(options).train_rad(df_seen, df_unseen, df_sup, test_x, test_y)
    print('done')
# Run the experiment only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()