-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy pathbase_options.py
160 lines (136 loc) · 8.25 KB
/
base_options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import argparse
import os
import pickle
import torch
import data
import models
class BaseOptions:
    """Options shared by training and test time.

    Besides declaring the common command-line flags, this class implements
    helpers to parse, print, and save the options, and it lets the selected
    model and dataset classes extend the parser through their
    ``modify_commandline_options`` functions (looked up via
    ``models.get_option_setter`` / ``data.get_option_setter``).

    NOTE(review): ``--model`` (read in ``gather_options``) and ``--log_dir``
    (read in ``print_options``) are not registered in ``initialize``; they
    are presumably added by subclasses — confirm before reusing this class.
    """

    def __init__(self):
        """Create the option holder in training mode.

        Only ``isTrain`` is set here. It is forwarded to the model/dataset
        option setters, copied onto the parsed options, and decides whether
        ``print_options`` also saves the options to disk. Test-time
        subclasses presumably set it to ``False`` (not visible here).
        """
        self.isTrain = True

    def initialize(self, parser):
        """Register the options common to training and test on *parser*.

        Args:
            parser: an ``argparse.ArgumentParser`` to extend.

        Returns:
            The same parser, with the basic, model, dataset and
            evaluation-metric options added.
        """
        # basic parameters
        parser.add_argument('--dataroot', required=True,
                            help='path to images (should have subfolders trainA, trainB, valA, valB, train, val, etc)')
        # Kept as a string here; parse() converts it to a list of ints.
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--seed', type=int, default=233, help='random seed')
        # model parameters
        parser.add_argument('--input_nc', type=int, default=3,
                            help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=3,
                            help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--norm', type=str, default='instance',
                            help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal',
                            help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02,
                            help='scaling factor for normal, xavier and orthogonal.')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='aligned',
                            help='chooses how datasets are loaded. [unaligned | aligned | single]')
        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
        parser.add_argument('--serial_batches', action='store_true',
                            help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--aspect_ratio', type=float, default=1.0,
                            help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
        # NOTE(review): -1 presumably means "no limit" to the dataset code — confirm.
        parser.add_argument('--max_dataset_size', type=int, default=-1,
                            help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop',
                            help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true',
                            help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--display_winsize', type=int, default=256,
                            help='display window size for both visdom and HTML')
        parser.add_argument('--load_in_memory', action='store_true',
                            help='whether you will load the data into the memory to skip the IO.')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--config_set', type=str, default=None,
                            help='the name of the configuration set for the set of subnet configurations.')
        parser.add_argument('--config_str', type=str, default=None,
                            help='the configuration string for a specific subnet in the supernet')
        # evaluation metric parameters
        parser.add_argument('--drn_path', type=str, default='drn-d-105_ms_cityscapes.pth',
                            help='the path to the pre-trained drn path to compute mIoU')
        parser.add_argument('--cityscapes_path', type=str, default='database/cityscapes-origin',
                            help='the original cityscapes dataset path (not the pix2pix preprocessed one)')
        parser.add_argument('--table_path', type=str, default='datasets/val_table.txt',
                            help='the path to the mapping table (generated by datasets/prepare_cityscapes_dataset.py)')
        parser.add_argument('--deeplabv2_path', type=str, default='deeplabv2_resnet101_msc-cocostuff164k-100000.pth',
                            help='the path to the pre-trained deeplabv2 path to compute the coco scores')
        parser.add_argument('--calibration_load_in_memory', action='store_true',
                            help='whether you will load your calibration data into the memory to skip the IO.')
        parser.add_argument('--calibration_meta_path', type=str, default=None,
                            help='the path to the calibration meta file')
        return parser

    def gather_options(self):
        """Build the full parser in stages and parse the command line.

        1. Register the basic options and parse the known ones.
        2. Let the chosen model add/override options, then parse again so
           the dataset lookup sees any new defaults.
        3. Let the chosen dataset add/override options.

        The final parser is kept on ``self.parser`` (``print_options`` uses
        it to look up defaults). The return value comes from a strict
        ``parse_args()``, so unknown flags are rejected only at this step.

        Returns:
            The parsed ``argparse.Namespace``.
        """
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)
        # get the basic options; unknown flags are tolerated because the
        # model/dataset setters have not registered theirs yet
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        # NOTE(review): --model is not registered in initialize(); a subclass
        # presumably adds it.
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print every option and, in training mode, save them to disk.

        Options whose value differs from the parser default are annotated
        with ``[default: ...]``. When ``self.isTrain`` is true the same text
        is written to ``<opt.log_dir>/opt.txt`` and the namespace is pickled
        to ``<opt.log_dir>/opt.pkl`` (the directory is created if missing).

        Args:
            opt: the parsed ``argparse.Namespace``; requires ``self.parser``
                to have been set (normally by ``gather_options``).
        """
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = '\t[default: %s]' % str(default) if value != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(key), str(value), comment))
        lines.append('----------------- End -------------------')
        message = '\n'.join(lines)
        print(message)

        # save to the disk
        if self.isTrain:
            # NOTE(review): --log_dir is not registered here; presumably a
            # training subclass adds it.
            expr_dir = opt.log_dir
            os.makedirs(expr_dir, exist_ok=True)
            with open(os.path.join(expr_dir, 'opt.txt'), 'wt', encoding='utf-8') as opt_file:
                opt_file.write(message)
                opt_file.write('\n')
            with open(os.path.join(expr_dir, 'opt.pkl'), 'wb') as opt_file:
                pickle.dump(opt, opt_file)

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Convert a comma-separated GPU id string to a sorted list of ints.

        Negative ids (``-1`` selects the CPU) are dropped. Empty tokens, such
        as the one left by a trailing comma, are ignored instead of raising
        ``ValueError``.
        """
        ids = (int(token) for token in gpu_ids.split(',') if token.strip())
        return sorted(gpu_id for gpu_id in ids if gpu_id >= 0)

    def parse(self, verbose=True):
        """Parse, post-process, optionally print/save, and return the options.

        Args:
            verbose: when true, call ``print_options`` (which also saves the
                options to disk in training mode).

        Returns:
            The parsed ``argparse.Namespace``, also stored on ``self.opt``.
            Its ``isTrain`` is copied from this object and its ``gpu_ids``
            is a sorted list of non-negative ints. When that list is
            non-empty, the first id is made the current CUDA device.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test
        # These two flags are not registered here (presumably by some
        # model/dataset setters); derive semantic_nc only when both exist.
        if hasattr(opt, 'contain_dontcare_label') and hasattr(opt, 'no_instance'):
            opt.semantic_nc = (opt.input_nc
                               + (1 if opt.contain_dontcare_label else 0)
                               + (0 if opt.no_instance else 1))
        # Printed/saved before the gpu_ids conversion below, so opt.pkl keeps
        # the raw comma-separated string.
        if verbose:
            self.print_options(opt)

        # set gpu ids
        opt.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])
        self.opt = opt
        return self.opt