##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Yaoyao Liu
## Tianjin University
## Copyright (c) 2019
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from trainer.meta import MetaTrainer
from trainer.pre import PreTrainer
from data_generator.meta_data_generator import MetaDataGenerator
FLAGS = flags.FLAGS
### Basic options
flags.DEFINE_integer('img_size', 84, 'image size')
flags.DEFINE_integer('device_id', 0, 'GPU device ID to run the job.')
flags.DEFINE_integer('filter_num', 64, 'number of filters to use in cnn models')
flags.DEFINE_float('gpu_rate', 0.9, 'the parameter for the full_gpu_memory_mode')
flags.DEFINE_string('phase', 'meta', 'pre or meta')
flags.DEFINE_string('exp_log_label', 'experiment_results', 'directory for summaries and checkpoints')
flags.DEFINE_string('logdir_base', './logs/', 'directory for logs')
flags.DEFINE_bool('full_gpu_memory_mode', False, 'in this mode, the code occupies GPU memory in advance')
flags.DEFINE_string('backbone_arch', 'conv4', 'network backbone')
### Pre-train phase options
flags.DEFINE_integer('pre_lr_dropstep', 5000, 'the step number to drop pre_lr')
flags.DEFINE_integer('pre_way_num', 5, 'number of classes used in the pre-train phase')
flags.DEFINE_integer('pre_shot_num', 5, 'number of shots for each class in the pre-train phase')
flags.DEFINE_integer('pre_batch_size', 4, 'batch_size for the pre-train phase')
flags.DEFINE_integer('pre_base_epoch', 5, 'number of inner gradient updates during pre-training')
flags.DEFINE_integer('pretrain_iterations', 30000, 'number of pretraining iterations.')
flags.DEFINE_float('pre_lr', 0.001, 'the pretrain learning rate')
flags.DEFINE_float('pre_base_lr', 0.01, 'the pretrain base learning rate')
flags.DEFINE_float('min_pre_lr', 0.0001, 'the minimum pretrain learning rate')
flags.DEFINE_float('pre_lr_droprate', 0.1, 'the rate to drop pre_lr')
### Meta phase options
flags.DEFINE_integer('way_num', 5, 'number of classes (e.g. 5-way classification)')
flags.DEFINE_integer('shot_num', 1, 'number of examples per class (K for K-shot learning)')
flags.DEFINE_integer('metatrain_epite_sample_num', 15, 'number of meta train episode-test samples')
flags.DEFINE_integer('metatest_epite_sample_num', 0, 'number of meta test episode-test samples, 0 means metatest_epite_sample_num=shot_num')
flags.DEFINE_integer('meta_sum_step', 10, 'the step number to write summaries during meta-training')
flags.DEFINE_integer('meta_save_step', 500, 'the step number to save the model')
flags.DEFINE_integer('meta_intrain_val_sample', 600, 'the number of samples used for val during meta-train')
flags.DEFINE_integer('meta_print_step', 100, 'the step number to print the meta-train results')
flags.DEFINE_integer('meta_val_print_step', 100, 'the step number to print the meta-val results during meta-training')
flags.DEFINE_integer('metatrain_iterations', 15000, 'number of meta-train iterations.')
flags.DEFINE_integer('meta_batch_size', 2, 'number of tasks sampled per meta-update')
flags.DEFINE_integer('train_base_epoch_num', 20, 'number of inner gradient updates during training.')
flags.DEFINE_integer('test_base_epoch_num', 100, 'number of inner gradient updates during test.')
flags.DEFINE_integer('lr_drop_step', 5000, 'the step number to drop meta_lr')
flags.DEFINE_integer('test_iter', 1000, 'iteration to load model')
flags.DEFINE_integer('resume_iter', 0, 'iteration to resume meta-training')
flags.DEFINE_integer('nontrainable_layers', -1, 'the index of the first trainable layer (-1 means every layer is trainable)')
flags.DEFINE_float('resume_lr', 0.001, 'meta-lr to use when training is resumed')
flags.DEFINE_float('meta_lr', 0.001, 'the meta learning rate of the generator')
flags.DEFINE_float('lr_drop_rate', 0.5, 'the rate to drop meta_lr')
flags.DEFINE_float('min_meta_lr', 0.0001, 'the min meta learning rate of the generator')
flags.DEFINE_float('base_lr', 1e-3, 'step size alpha for inner gradient update.')
flags.DEFINE_string('metatrain_dir', './data/mini-imagenet/train', 'directory for meta-train set')
flags.DEFINE_string('metaval_dir', './data/mini-imagenet/val', 'directory for meta-val set')
flags.DEFINE_string('metatest_dir', './data/mini-imagenet/test', 'directory for meta-test set')
flags.DEFINE_string('activation', 'leaky_relu', 'leaky_relu, relu, or None')
flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
flags.DEFINE_bool('metatrain', True, 'whether this is the meta-train phase')
flags.DEFINE_bool('base_augmentation', True, 'whether to do data augmentation during base learning')
flags.DEFINE_bool('redo_init', True, 're-build the initialization weights')
flags.DEFINE_bool('from_scratch', False, 'start meta-train from scratch, do not use pre-train weights')
flags.DEFINE_bool('proto_maml', False, 'whether to use proto-maml initialization for fc weights')
flags.DEFINE_bool('stop_grad', False, 'whether to use the first order approximation')
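
# Example invocations (illustrative only; the flag values below are examples,
# not prescribed settings -- any flag defined above can be overridden the same way):
#   python main.py --phase=pre --device_id=0                 # run the pre-train phase
#   python main.py --phase=meta --way_num=5 --shot_num=1     # run 1-shot meta-training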
# Generate experiment key words string
exp_string = 'arch(' + FLAGS.backbone_arch + ')'
exp_string += '.cls(' + str(FLAGS.way_num) + ')'
exp_string += '.shot(' + str(FLAGS.shot_num) + ')'
exp_string += '.meta_batch(' + str(FLAGS.meta_batch_size) + ')'
exp_string += '.base_epoch(' + str(FLAGS.train_base_epoch_num) + ')'
exp_string += '.meta_lr(' + str(FLAGS.meta_lr) + ')'
exp_string += '.base_lr(' + str(FLAGS.base_lr) + ')'
exp_string += '.pre_iterations(' + str(FLAGS.pretrain_iterations) + ')'
exp_string += '.acti(' + str(FLAGS.activation) + ')'
exp_string += '.lr_drop_step(' + str(FLAGS.lr_drop_step) + ')'
exp_string += '.lr_drop_rate(' + str(FLAGS.lr_drop_rate) + ')'
if FLAGS.norm == 'batch_norm':
    exp_string += '.norm(batch)'
elif FLAGS.norm == 'layer_norm':
    exp_string += '.norm(layer)'
elif FLAGS.norm == 'None':
    exp_string += '.norm(none)'
else:
    raise Exception('Norm setting is not recognized')
print('Parameters: ' + exp_string)
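# With the default flag values above, exp_string evaluates to:
# arch(conv4).cls(5).shot(1).meta_batch(2).base_epoch(20).meta_lr(0.001).base_lr(0.001).pre_iterations(30000).acti(leaky_relu).lr_drop_step(5000).lr_drop_rate(0.5).norm(batch)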
# Generate pre-train key words string
pre_save_str = 'arch(' + FLAGS.backbone_arch + ')'
pre_save_str += '.cls(' + str(FLAGS.pre_way_num) + ')'
pre_save_str += '.shot(' + str(FLAGS.pre_shot_num) + ')'
pre_save_str += '.meta_batch(' + str(FLAGS.pre_batch_size) + ')'
pre_save_str += '.base_epoch(' + str(FLAGS.pre_base_epoch) + ')'
pre_save_str += '.meta_lr(' + str(FLAGS.pre_lr) + ')'
pre_save_str += '.base_lr(' + str(FLAGS.pre_base_lr) + ')'
pre_save_str += '.lr_drop_step(' + str(FLAGS.pre_lr_dropstep) + ')'
pre_save_str += '.lr_drop_rate(' + str(FLAGS.pre_lr_droprate) + ')'
pre_string = pre_save_str
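# With the default flag values above, pre_string evaluates to:
# arch(conv4).cls(5).shot(5).meta_batch(4).base_epoch(5).meta_lr(0.001).base_lr(0.01).lr_drop_step(5000).lr_drop_rate(0.1)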
# Generate log folders
logdir = FLAGS.logdir_base + FLAGS.exp_log_label
pretrain_dir = FLAGS.logdir_base + 'pretrain_weights'
if not os.path.exists(FLAGS.logdir_base):
    os.mkdir(FLAGS.logdir_base)
if not os.path.exists(logdir):
    os.mkdir(logdir)
if not os.path.exists(pretrain_dir):
    os.mkdir(pretrain_dir)
# If FLAGS.redo_init is true, delete the previous initialization weights.
if FLAGS.redo_init:
    if os.path.exists('./logs/init_weights'):
        os.system('rm -r ./logs/init_weights')
        print('Init weights have been deleted')
    else:
        print('No init weights')
def main():
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
    # Set GPU device id
    print('Using GPU ' + str(FLAGS.device_id))
    os.environ['CUDA_VISIBLE_DEVICES'] = str(FLAGS.device_id)
    # os.environ['CUDA_VISIBLE_DEVICES'] = "-1"  # uncomment to run on CPU only
    # Select pre-train phase or meta-learning phase
    if FLAGS.phase == 'pre':
        trainer = PreTrainer(pre_string, pretrain_dir)
    elif FLAGS.phase == 'meta':
        trainer = MetaTrainer(exp_string, logdir, pre_string, pretrain_dir)
    else:
        raise Exception('Please set correct phase')
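    # Note: PreTrainer and MetaTrainer are assumed to start their respective
    # training loops inside their constructors, which is why main() does not
    # call a separate run method on the returned trainer.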
if __name__ == "__main__":
    main()