-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy pathsupernet_options.py
111 lines (96 loc) · 6.73 KB
/
supernet_options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import data
import supernets
from options.base_options import BaseOptions
class SupernetOptions(BaseOptions):
    """Command-line options for supernet training with a teacher/student setup.

    Extends the options registered by ``BaseOptions.initialize`` with logging,
    model, checkpoint-restore, and training/evaluation options, and then lets
    the selected supernet and dataset register their own options through their
    option setters (see ``gather_options``).
    """

    def __init__(self, isTrain=True):
        """Initialize the base options object and record the train/test flag.

        Parameters:
            isTrain (bool) -- stored on ``self.isTrain`` and forwarded to the
                              supernet and dataset option setters.
        """
        super(SupernetOptions, self).__init__()
        self.isTrain = isTrain

    def initialize(self, parser):
        """Register the supernet options on top of the base options.

        Parameters:
            parser -- the ``argparse.ArgumentParser`` to populate.

        Returns:
            The parser returned by ``BaseOptions.initialize`` with the
            additional arguments below registered on it.

        Note that ``--restore_teacher_G_path`` and ``--real_stat_path`` are
        required arguments; parsing fails if either is missing.
        """
        parser = BaseOptions.initialize(self, parser)

        # log parameters
        parser.add_argument('--log_dir', type=str, default='logs/distill',
                            help='specify an experiment directory')
        parser.add_argument('--tensorboard_dir', type=str, default=None,
                            help='tensorboard is saved here')
        parser.add_argument('--print_freq', type=int, default=100,
                            help='frequency of showing training results on console')
        parser.add_argument('--save_latest_freq', type=int, default=20000,
                            help='frequency of evaluating and save the latest model')
        parser.add_argument('--save_epoch_freq', type=int, default=5,
                            help='frequency of saving checkpoints at the end of epoch')
        parser.add_argument('--epoch_base', type=int, default=1,
                            help='the epoch base of the training (used for resuming)')
        parser.add_argument('--iter_base', type=int, default=0,
                            help='the iteration base of the training (used for resuming)')

        # model parameters
        parser.add_argument('--supernet', type=str, default='resnet',
                            help='specify which supernet you want to use [resnet | spade]')
        parser.add_argument('--teacher_netG', type=str, help='specify teacher generator architecture')
        parser.add_argument('--student_netG', type=str, help='specify student generator architecture')
        parser.add_argument('--netD', type=str, default='n_layers',
                            help='specify discriminator architecture [n_layers | pixel]. '
                                 'The basic model is a 70x70 PatchGAN. '
                                 'n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--teacher_ngf', type=int, help='the number of filters of the teacher generator')
        parser.add_argument('--student_ngf', type=int, help='the base number of filters of the student generator')
        parser.add_argument('--ndf', type=int, default=128, help='the base number of discriminator filters')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--gan_mode', type=str, default='hinge', choices=['lsgan', 'vanilla', 'hinge'],
                            help='the type of GAN objective. [vanilla| lsgan | hinge]. '
                                 'vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        parser.add_argument('--restore_teacher_G_path', type=str, required=True,
                            help='the path to restore the teacher generator')
        parser.add_argument('--restore_student_G_path', type=str, default=None,
                            help='the path to restore the student generator')
        parser.add_argument('--restore_A_path', type=str, default=None,
                            help='the path to restore the adaptors for distillation')
        parser.add_argument('--restore_D_path', type=str, default=None,
                            help='the path to restore the discriminator')
        parser.add_argument('--restore_O_path', type=str, default=None,
                            help='the path to restore the optimizer')

        # training parameters
        parser.add_argument('--nepochs', type=int, default=10,
                            help='number of epochs with the initial learning rate')
        parser.add_argument('--nepochs_decay', type=int, default=30,
                            help='number of epochs to linearly decay learning rate to zero')
        parser.add_argument('--niters', type=int, default=1000000000,
                            help='max number of iteration of the training')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--lr_policy', type=str, default='linear',
                            help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_steps', type=int, default=100000,
                            help='multiply by a gamma every lr_decay_steps steps (only for step lr policy)')
        parser.add_argument('--scheduler_counter', type=str, default='epoch', choices=['epoch', 'iter'],
                            help='which counter to use for the scheduler')
        # The help text names --lr_decay_steps, the option actually defined above.
        parser.add_argument('--gamma', type=float, default=0.5,
                            help='multiply by a gamma every lr_decay_steps steps (only for step lr policy)')

        # evaluation parameters
        parser.add_argument('--eval_batch_size', type=int, default=1, help='the evaluation batch size')
        parser.add_argument('--real_stat_path', type=str, required=True,
                            help='the path to load the ground-truth images information to compute FID.')
        parser.add_argument('--no_fid', action='store_true', help='No FID evaluation during training')
        parser.add_argument('--no_mIoU', action='store_true', help='No mIoU evaluation during training '
                                                                   '(sometimes because there are CUDA memory)')
        return parser

    def gather_options(self):
        """Build the full parser in stages and parse the command line.

        The supernet and the dataset each contribute their own options, but
        which ones to add is itself chosen on the command line. The parser is
        therefore built incrementally, using ``parse_known_args`` for the
        intermediate passes so that options not yet registered are tolerated.

        Returns:
            The namespace produced by the final, strict ``parse_args()`` call.

        Side effects:
            Stores the fully populated parser on ``self.parser``.
        """
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # First pass: only needed to learn which supernet was requested.
        opt, _ = parser.parse_known_args()

        # modify supernet-related parser options
        supernet_name = opt.supernet
        supernet_option_setter = supernets.get_option_setter(supernet_name)
        parser = supernet_option_setter(parser, self.isTrain)

        # Second pass: re-parse because the supernet setter may have changed the parser.
        opt, _ = parser.parse_known_args()

        # modify dataset-related parser options
        # NOTE(review): --dataset_mode is not defined in this class; presumably
        # it is registered by BaseOptions.initialize -- confirm.
        dataset_name = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()