forked from somewacko/deconvfaces
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfacegen.py
executable file
·126 lines (95 loc) · 4.29 KB
/
facegen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
"""
fg
CLI for training and interfacing with FaceGen models.
"""
import argparse
import sys
import types
# ---- Available commands
def train():
    """
    Command to train a FaceGen model.

    Invoked as ``fg train <data> [<args>]``. Parses the arguments that follow
    the command word (``sys.argv[2:]``) and forwards them to
    ``facegen.train.train_model``. argparse exits the process on ``-h`` or on
    invalid arguments (status 2).
    """
    parser = argparse.ArgumentParser(
        # Only sys.argv[2:] is parsed here, so the command word is part of
        # how the user types this command; prog also prefixes error messages.
        prog = "fg train",
        description = "Trains a FaceGen model using the Radboud Face Database",
        usage = "fg train <data> [<args>]",
    )
    parser.add_argument('data', type=str, help=
        "Directory where RaFD data lives.")
    parser.add_argument('-o', '--output', type=str, default='output', help=
        "Directory to output results to.")
    parser.add_argument('-m', '--model', type=str, default='', help=
        "The model to load. If none specified, a new model will be made instead.")
    parser.add_argument('-b', '--batch-size', type=int, default=16, help=
        "Batch size to use while training.")
    parser.add_argument('-e', '--num-epochs', type=int, default=100, help=
        "The number of epochs to train.")
    parser.add_argument('-opt', '--optimizer', type=str, default='adam', help=
        "Optimizer to use, must be a valid optimizer included in Keras.")
    parser.add_argument('-d', '--deconv-layers', type=int, default=5, help=
        "The number of deconvolution layers to include in the model.")
    parser.add_argument('-k', '--kernels-per-layer', type=int, nargs='+', help=
        "The number of kernels to include in each layer.")
    parser.add_argument('-v', '--visualize', action='store_true', help=
        "Output intermediate results after each epoch.")

    # Skip "fg" and "train": everything after the command word is ours.
    args = parser.parse_args(sys.argv[2:])

    # Imported only after parsing succeeds, so "-h" and usage errors never
    # load the training module.
    import facegen.train

    if args.deconv_layers > 6:
        # Diagnostics belong on stderr, not mixed into regular output.
        print("Warning: Having more than 6 deconv layers will create images "
              "larger than the original data! (and may not fit in memory)",
              file=sys.stderr)

    # NOTE(review): the number of values given to -k is not checked against
    # --deconv-layers here; confirm what train_model expects.
    facegen.train.train_model(args.data, args.output, args.model,
        batch_size = args.batch_size,
        num_epochs = args.num_epochs,
        optimizer = args.optimizer,
        deconv_layers = args.deconv_layers,
        kernels_per_layer = args.kernels_per_layer,
        generate_intermediate = args.visualize,
        verbose = True,
    )
def generate():
    """
    Command to generate faces with a FaceGen model.

    Invoked as ``fg generate -m <model> -o <output> -f <gen-file> [<args>]``.
    Parses the arguments that follow the command word (``sys.argv[2:]``) and
    forwards them to ``facegen.generate.generate_from_yaml``. argparse exits
    the process on ``-h`` or on invalid/missing arguments (status 2).
    """
    parser = argparse.ArgumentParser(
        # Only sys.argv[2:] is parsed here, so the command word is part of
        # how the user types this command; prog also prefixes error messages.
        prog = "fg generate",
        description = "Generate faces using a trained FaceGen model.",
        usage = "fg generate -m <model> -o <output> -f <gen-file> [<args>]",
    )
    parser.add_argument('-m', '--model', type=str, required=True, help=
        "Model definition file to use.")
    parser.add_argument('-o', '--output', type=str, required=True, help=
        "Directory to output results to.")
    parser.add_argument('-f', '--gen-file', type=str, required=True, help=
        "YAML file that specifies the parameters to generate.")
    # '--batch-size' matches the spelling used by the train command;
    # '--batch_size' is kept so existing invocations keep working.
    parser.add_argument('-b', '--batch-size', '--batch_size', type=int,
        default=64, help=
        "Batch size to use while generating images.")
    parser.add_argument('-ext', '--extension', type=str, default='jpg', help=
        "Image file extension to use when saving images.")

    # Skip "fg" and "generate": everything after the command word is ours.
    args = parser.parse_args(sys.argv[2:])

    # The default ('jpg') has no leading dot; normalise '.png' to 'png' so a
    # user-supplied extension follows the same convention.
    extension = args.extension.lstrip('.')

    # Imported only after parsing succeeds, so "-h" and usage errors never
    # load the generation module.
    import facegen.generate

    facegen.generate.generate_from_yaml(args.gen_file, args.model, args.output,
        batch_size=args.batch_size, extension=extension)
# ---- Command-line invocation
if __name__ == '__main__':
    # Every public function defined in this file is exposed as a command.
    # globals() is snapshotted so the namespace cannot change while it is
    # scanned, and __module__ is checked so functions brought in by imports
    # are never exposed as commands.
    cmd_dict = {
        name: obj
        for name, obj in list(globals().items())
        if isinstance(obj, types.FunctionType)
        and obj.__module__ == __name__
        and not name.startswith('_')
    }
    cmd_names = sorted(cmd_dict)

    parser = argparse.ArgumentParser(
        description = "Generate faces using a deconvolution network.",
        usage = "fg <command> [<args>]"
    )
    parser.add_argument('command', type=str, help=
        "Command to run. Available commands: {}.".format(', '.join(cmd_names)))

    # Parse only the command word; each command parses sys.argv[2:] itself.
    # Slicing (not indexing) lets argparse report a missing command with a
    # usage error instead of crashing with IndexError when run with no args.
    args = parser.parse_args(sys.argv[1:2])

    cmd = cmd_dict.get(args.command)
    if cmd is None:
        # Only emit ANSI colour codes when stderr is an interactive terminal.
        if sys.stderr.isatty():
            red, reset = '\033[91m', '\033[0m'
        else:
            red, reset = '', ''
        sys.stderr.write(red)
        sys.stderr.write("\nInvalid command {}!\n\n".format(args.command))
        sys.stderr.write(reset)
        sys.stderr.flush()
        parser.print_help()
        # Non-zero status (argparse's usage-error code) so scripts can detect it.
        sys.exit(2)

    cmd()