forked from RenYurui/StructureFlow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
83 lines (64 loc) · 2.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import torch
import argparse
import shutil
from src.config import Config
from src.structure_flow import StructureFlow
def main(mode=None):
    r"""Start the model in the requested mode.

    Args:
        mode: ``'train'``, ``'test'`` or ``'eval'``. When ``None``, the mode is
            read from the loaded config's ``MODE`` entry instead.

    Raises:
        ValueError: if neither ``mode`` nor the config names a supported mode.
    """
    supported_modes = ('train', 'test', 'eval')
    # Reject a bad explicit mode before load_config() creates output folders.
    if mode is not None and mode not in supported_modes:
        raise ValueError('unsupported mode {!r}; expected one of {}'.format(
            mode, supported_modes))

    config = load_config(mode)
    if mode is None:
        # NOTE(review): assumes Config exposes the YAML ``MODE`` key as an
        # attribute (missing/None when absent) -- confirm in src/config.py.
        mode = getattr(config, 'MODE', None)
        if mode not in supported_modes:
            raise ValueError(
                'no mode given and config MODE is {!r}; expected one of {}'
                .format(mode, supported_modes))
    config.MODE = mode

    # CUDA_VISIBLE_DEVICES takes a comma-separated id list ("0,1"); joining
    # with '' would merge [0, 1] into "01". A single id (int, or an already
    # formatted string) is accepted too. Set before any torch.cuda call.
    gpus = config.GPU
    if isinstance(gpus, (int, str)):
        gpus = [gpus]
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(gpu) for gpu in gpus)

    if torch.cuda.is_available():
        config.DEVICE = torch.device("cuda")
        torch.backends.cudnn.benchmark = True  # cudnn auto-tuner
    else:
        config.DEVICE = torch.device("cpu")

    model = StructureFlow(config)

    if mode == 'train':
        print('\nstart training...\n')
        model.train()
    elif mode == 'test':
        print('\nstart test...\n')
        model.test()
    else:  # 'eval' -- the only remaining supported mode
        print('\nstart eval...\n')
        model.eval()
def load_config(mode=None):
    r"""Parse command-line options, build the Config and prepare output folders.

    Args:
        mode: run mode. ``'test'`` additionally registers the test-only
            ``--input/--mask/--structure/--output/--model`` options, and
            ``'train'`` copies the config file into the output folder.

    Returns:
        The ``Config`` built from the parsed options and ``mode``.
    """
    parser = argparse.ArgumentParser()
    # --name is joined into the output path below, so it cannot be omitted
    # (os.path.join(path, None) would raise TypeError).
    parser.add_argument('--name', type=str, required=True, help='output model name.')
    parser.add_argument('--config', type=str, default='model_config.yaml', help='Path to the config file.')
    parser.add_argument('--path', type=str, default='./results', help='outputs path')
    parser.add_argument("--resume_all", action="store_true", help='load model from checkpoints')
    parser.add_argument("--remove_log", action="store_true", help='remove previous tensorboard log files')

    if mode == 'test':
        parser.add_argument('--input', type=str, help='path to the input image files')
        parser.add_argument('--mask', type=str, help='path to the mask files')
        parser.add_argument('--structure', type=str, help='path to the structure files')
        parser.add_argument('--output', type=str, help='path to the output directory')
        parser.add_argument('--model', type=int, default=3, help='which model to test')

    opts = parser.parse_args()
    config = Config(opts, mode)

    output_dir = os.path.join(opts.path, opts.name)
    perpare_sub_floder(output_dir)

    if mode == 'train':
        # Keep a copy of the config used for this run next to its outputs.
        config_dir = os.path.join(output_dir, 'config.yaml')
        try:
            shutil.copyfile(opts.config, config_dir)
        except shutil.SameFileError:
            # Resuming with the config already saved in output_dir: nothing to copy.
            pass
    return config
def perpare_sub_floder(output_path, sub_dirs=('images', 'checkpoints')):
    r"""Make sure the per-run sub-folders exist under ``output_path``.

    ``output_path`` itself is created as well if it is missing. A message is
    printed for every sub-folder that did not exist yet.

    Args:
        output_path: run output directory.
        sub_dirs: names of the sub-folders to create; defaults to
            ``images`` and ``checkpoints``.
    """
    for sub_dir in sub_dirs:
        dir_path = os.path.join(output_path, sub_dir)
        if not os.path.isdir(dir_path):
            print("Creating directory: {}".format(dir_path))
        # exist_ok avoids a FileExistsError if another process creates the
        # folder between the check above and this call.
        os.makedirs(dir_path, exist_ok=True)
if __name__ == '__main__':
    # Runs only when this file is executed as a script, not when imported.
    # No mode is chosen here, so main() is called with mode=None.
    main(mode=None)