-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_dcgan_celeba32.py
402 lines (355 loc) · 16.7 KB
/
train_dcgan_celeba32.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
import argparse
import os
import time
from pathlib import Path
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from gans import Generator32, Discriminator32, weights_init
from train_utils import *
from plot_utils import *
# Select the GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # auto-tune cuDNN conv algorithms; safe here because input sizes are fixed (32x32)
    cudnn.benchmark = True
def get_dataloader(batch_size, num_workers):
    """Build a shuffled DataLoader over the CelebA-32 ImageFolder dataset.

    Args:
        batch_size (int): number of images per batch.
        num_workers (int or None): worker processes for data loading;
            0 means load in the main process, None means one worker per CPU core.

    Returns:
        torch.utils.data.DataLoader: loader over 'data/celeba32' with pinned
        memory when running on GPU.
    """
    tfms = transforms.Compose([
        transforms.ToTensor(),
        # map pixels from [0, 1] to [-1, 1] to match the generator's tanh output range
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # we can use an image folder dataset
    dataset = datasets.ImageFolder(root='data/celeba32', transform=tfms)
    # BUGFIX: the previous `num_workers if num_workers else cpu_count()` treated
    # the legitimate value 0 (load in main process) as "unset"; test for None.
    if num_workers is None:
        num_workers = torch.multiprocessing.cpu_count()
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        pin_memory=device.type == 'cuda',
        num_workers=num_workers)
    return dataloader
def get_noise(batch_size, dim):
    """Sample from the prior distribution of z.

    Returns:
        (torch.Tensor): a standard-normal tensor of shape
        (batch_size, dim, 1, 1), moved to the training device.
    """
    # Sample on CPU then transfer, preserving the original RNG stream.
    return torch.randn(batch_size, dim, 1, 1).to(device)
class DataSupplier():
    """Provides batches of real and fake images plus targets for training GANs.

    Follows the standard GAN labeling convention: 1 for real samples,
    0 for generated (fake) samples.
    """

    REAL_LABEL = 1
    FAKE_LABEL = 0

    def __init__(self, dataloader, net_G):
        """
        Args:
            dataloader (torch.utils.data.DataLoader): source of real images.
            net_G: generator network; must expose a `latent_dim` attribute.
        """
        self.dataloader = dataloader
        self.batch_size = dataloader.batch_size
        self.net_G = net_G
        self.latent_dim = net_G.latent_dim
        self.iterator = iter(dataloader)

    def get_batch_real(self):
        """Returns a batch of real images from the dataloader and training targets
        (iterates infinitely on the dataloader).

        Returns:
            torch.Tensor: tensor data
            torch.Tensor: tensor target vector (all REAL_LABEL)
        """
        try:
            data_real, _ = next(self.iterator)
        except StopIteration:
            # dataloader exhausted: restart it for infinite iteration
            self.iterator = iter(self.dataloader)
            data_real, _ = next(self.iterator)
        # BUGFIX: torch.full with an int fill value returns an int64 tensor on
        # PyTorch >= 1.5, and nn.BCELoss requires float targets.
        target_real = torch.full((data_real.size(0),), self.REAL_LABEL,
                                 dtype=torch.float)
        return data_real.to(device), target_real.to(device)

    def get_batch_fake(self, train_G=False):
        """Returns a batch of generated images and training targets.

        Args:
            train_G (bool): if True, label the fakes as real so the generator
                loss rewards fooling the discriminator.

        Returns:
            torch.Tensor: tensor data
            torch.Tensor: tensor target vector
        """
        z = get_noise(self.batch_size, self.latent_dim)
        data_fake = self.net_G(z)
        if not train_G:
            target_fake = torch.full((data_fake.size(0),), self.FAKE_LABEL,
                                     dtype=torch.float)
        else:
            # if we train the generator G, then set training targets to real
            # to fool the discriminator D.
            target_fake = torch.full((data_fake.size(0),), self.REAL_LABEL,
                                     dtype=torch.float)
        return data_fake.to(device), target_fake.to(device)
def train(net_G, net_D, optimizer_G, optimizer_D, criterion, data_supplier, steps, num_updates_D,
          num_updates_G, writer=None, savepath=None, figures_dir=None, start_step=1):
    """Full GAN training loop.

    Args:
        net_G, net_D: generator and discriminator networks (already on `device`).
        optimizer_G, optimizer_D: optimizers for G and D respectively.
        criterion: loss taking (output, target), e.g. nn.BCELoss.
        data_supplier (DataSupplier): provides real and fake batches.
        steps (int): last global step to run (inclusive).
        num_updates_D (int): discriminator sub-updates per global step.
        num_updates_G (int): generator sub-updates per global step.
        writer (SummaryWriter, optional): tensorboard writer for metrics/images.
        savepath (Path, optional): checkpoint file path; None disables checkpointing.
        figures_dir (Path, optional): directory where image grids are saved
            (only used when checkpointing is enabled).
        start_step (int): first step index, used when resuming from a checkpoint.

    Returns:
        tuple: (images_list, G_losses, D_losses, D_reals, D_fakes) training history.
    """
    print("Training on", 'GPU' if device.type == 'cuda' else 'CPU')
    images_list = []
    D_losses = []
    G_losses = []
    D_reals = []
    D_fakes = []
    # fixed random noise reused at every visualization step so the 14x14 grid
    # of generated images is comparable across training time.
    # BUGFIX: previously read the module-level global `args.latent_dim`, which
    # breaks when this function is used outside the CLI script; the generator
    # itself knows its latent dimension.
    FIXED_NOISE = get_noise(196, net_G.latent_dim)
    tic = time.time()  # start time
    # updates counters for G / D for writing in tensorboard
    updates_cnt_G = 1
    updates_cnt_D = 1
    # checkpointing
    checkpoint = Checkpoint(path=savepath, net_G=net_G, net_D=net_D,
                            optimizer_G=optimizer_G, optimizer_D=optimizer_D,
                            step=start_step) if savepath else None
    for step in range(start_step, steps+1):
        for _ in range(num_updates_D):
            # Update the discriminator network D:
            # maximize for D: log(D(x)) + log(1 - D(G(z)))
            net_D.zero_grad()
            # get batches
            data_real, target_real = data_supplier.get_batch_real()
            data_fake, target_fake = data_supplier.get_batch_fake(False)
            # forward pass
            # note: use detach() on the fake batch in order to update only D
            out_real = net_D(data_real).view(-1)
            out_fake = net_D(data_fake.detach()).view(-1)
            # sum of criterions on real and fake samples
            loss_D = criterion(out_real, target_real) + criterion(out_fake, target_fake)
            # backward pass and parameters update
            loss_D.backward()
            optimizer_D.step()
            optimizer_D.zero_grad()
            # compute and save metrics, log to tensorboard
            avg_D_real = out_real.mean().item()
            avg_D_fake = out_fake.mean().item()
            D_losses.append(loss_D.item())
            D_reals.append(avg_D_real)
            D_fakes.append(avg_D_fake)
            if writer:
                writer.add_scalar("Loss_D", loss_D.item(), updates_cnt_D)
                writer.add_scalar("Mean_Real_D(x)", avg_D_real, updates_cnt_D)
                writer.add_scalar("Mean_Fake_D(G(z))", avg_D_fake, updates_cnt_D)
            updates_cnt_D += 1
        for _ in range(num_updates_G):
            # Update the generator network G:
            # maximize for G: log(D(G(z)))
            net_G.zero_grad()
            # get batches
            # note: fake labels are real for G loss
            data_fake, target_fake = data_supplier.get_batch_fake(True)
            # forward pass
            out_fake = net_D(data_fake).view(-1)
            # criterion
            loss_G = criterion(out_fake, target_fake)
            # backward pass and parameters update
            loss_G.backward()
            optimizer_G.step()
            optimizer_G.zero_grad()
            # compute and save metrics, log to tensorboard
            G_losses.append(loss_G.item())
            if writer:
                writer.add_scalar("Loss_G", loss_G.item(), updates_cnt_G)
            updates_cnt_G += 1
        if step % 25 == 0:
            # log training metrics
            print("[{:5d}/{:5d}]\tLoss_D: {:.4f}\tLoss_G: {:.4f}\tD(x): {:.4f}\tD(G(z)): {:.4f}"
                  .format(step, steps, loss_D.item(), loss_G.item(), avg_D_real,
                          avg_D_fake))
        if step % 100 == 0:
            # generate images from the fixed noise
            with torch.no_grad():
                fake = net_G(FIXED_NOISE).detach().cpu()
            grid = vutils.make_grid(fake, padding=2, normalize=True, nrow=14)
            images_list.append(grid)
            if writer:
                writer.add_image('Generated', grid, step)
            if checkpoint:
                # save checkpoint and the current generated-image grid
                checkpoint.step = step
                checkpoint.save()
                vutils.save_image(grid, figures_dir/'images/step={}.png'.format(step))
    print("\n======> Done. Total time {}s\t".format(time.time() - tic))
    if checkpoint:
        # final checkpoint marks the next step to run on resume
        checkpoint.step = steps + 1
        checkpoint.save(f'_end_step={steps}')
    return images_list, G_losses, D_losses, D_reals, D_fakes
def train_from_checkpoint(checkpoint, criterion, data_supplier, steps, num_updates_D,
                          num_updates_G, writer=None, savepath=None, figures_dir=None):
    """Resume the full training loop from an existing checkpoint.

    Args:
        checkpoint (Checkpoint): loaded checkpoint holding the networks,
            optimizers and the step to resume from.
        (remaining arguments are forwarded unchanged to `train`)

    Returns:
        tuple: the history tuple returned by `train`.
    """
    # Explicit forwarding replaces the previous locals()-snapshot/pop hack,
    # which silently broke if any local variable was introduced in this body.
    return train(checkpoint.net_G, checkpoint.net_D,
                 checkpoint.optimizer_G, checkpoint.optimizer_D,
                 criterion, data_supplier, steps, num_updates_D, num_updates_G,
                 writer=writer, savepath=savepath, figures_dir=figures_dir,
                 start_step=checkpoint.step)
def main(args):
    """Build the data pipeline, networks and optimizers from CLI args, then train.

    Side effects: may create ./checkpoints/<exp>/ and ./figures/<exp>/ directories,
    write tensorboard logs under ./runs/, and save checkpoints, generated-image
    grids and a metrics plot to disk.
    """
    dataloader = get_dataloader(args.batch_size, args.workers)
    # # plot some training images
    # real_batch = next(iter(loader))
    # plt.figure(figsize=(10,10))
    # plt.axis('off')
    # plt.title('Training Images Sample')
    # plt.imshow(np.transpose(vutils.make_grid(real_batch[0][:64], padding=2,
    # normalize=True), (1, 2, 0)))
    # plt.show()
    print(args)
    net_G = Generator32(args.latent_dim, args.num_feature_maps_G).to(device)
    net_D = Discriminator32(args.num_feature_maps_D).to(device)
    # initialize the weights of the networks
    net_G.apply(weights_init)
    net_D.apply(weights_init)
    # create the criterion function for the discriminator:
    # binary-cross entropy loss
    criterion = nn.BCELoss()
    # create optimizers
    optimizer_G = optim.Adam(net_G.parameters(), lr=args.lr_G,
                             betas=(args.beta1_G, 0.999))
    optimizer_D = optim.Adam(net_D.parameters(), lr=args.lr_D,
                             betas=(args.beta1_D, 0.999))
    supplier = DataSupplier(dataloader, net_G)
    # data_real, target_real = supplier.get_batch_real()
    # print(f"Real batch: {tuple(data_real.size())} -> {tuple(target_real.size())}")
    # data_fake, target_fake = supplier.get_batch_fake()
    # print(f"Fake batch (for training D): {tuple(data_fake.size())} -> "
    #       f"{tuple(target_fake.size())}")
    # data_fake, target_fake = supplier.get_batch_fake(True)
    # print(f"Fake batch (for training G): {tuple(data_fake.size())} -> "
    #       f"{tuple(target_fake.size())}")
    # experiment name for tensorboard
    hparams = get_hparams_dict(args,
                               ignore_keys={'no_tensorboard', 'workers', 'epochs',
                                            'from_checkpoint', 'no_save'})
    exp_name = get_experiment_name(prefix='__DCGAN__CelebA-32__', hparams=hparams)
    # path where to save checkpoints
    if args.no_save:
        savepath = None
        figures_dir = None
    else:
        savedir = Path('./checkpoints/')/exp_name
        # NOTE(review): mkdir without exist_ok raises FileExistsError if this
        # experiment was already run — presumably intentional to avoid
        # overwriting results (see commented exist_ok variant below); confirm.
        savedir.mkdir(parents=True)
        savepath = savedir/'checkpt.pt'
        # will store figures (generated images and metrics plots) into a directory
        figures_dir = Path('./figures/')/exp_name
        (figures_dir/'images').mkdir(parents=True)
        # (figures_dir/'images').mkdir(parents=True, exist_ok=True)
    if args.no_tensorboard:
        writer = None
    else:
        # NOTE(review): the writer is never closed; flush_secs=10 limits data
        # loss but a writer.close() after training would be safer.
        writer = SummaryWriter(log_dir='runs/'+ exp_name, flush_secs=10)
    # log sample data and net graph in tensorboard
    #@todo
    if args.from_checkpoint:
        checkpoint = Checkpoint(path=args.from_checkpoint, net_G=net_G, net_D=net_D,
                                optimizer_G=optimizer_G, optimizer_D=optimizer_D,
                                step=None)
        # sd_init = checkpoint.state_dict()
        # # print('Chkpt state dict before load: \n\n{} \n'.format(sd_init))
        checkpoint.load(map_location=device)
        # NOTE(review): hard-coded resume step overrides whatever step the
        # checkpoint stored — looks like a debugging leftover; confirm.
        checkpoint.step = 8001
        # sd_loaded = checkpoint.state_dict()
        # print(checkpoint.net_G.state_dict().keys())
        # # print('Chkpt state dict after load: \n\n{} \n'.format(sd_loaded))
        # assert(not torch.allclose(sd_init['net_G']['model.0.weight'], sd_loaded['net_G']['model.0.weight']))
        # assert(not torch.allclose(sd_init['net_G']['model.3.weight'], sd_loaded['net_G']['model.3.weight']))
        images_list, G_losses, D_losses, D_reals, D_fakes = \
            train_from_checkpoint(checkpoint, criterion, supplier, args.steps,
                                  args.num_updates_D, args.num_updates_G, writer,
                                  savepath, figures_dir)
    else:
        images_list, G_losses, D_losses, D_reals, D_fakes = \
            train(net_G, net_D, optimizer_G, optimizer_D, criterion, supplier,
                  args.steps, args.num_updates_D, args.num_updates_G, writer,
                  savepath, figures_dir)
    if figures_dir:
        show_metrics(D_losses, G_losses, D_reals, D_fakes, savepath=figures_dir/'metrics.svg')
if __name__ == '__main__':
    def parse_args():
        """Parse command-line arguments.

        Returns:
            argparse.Namespace: parsed args with `epochs` and `steps` made
            consistent with each other (202000 approximates the number of
            images in the CelebA training set).
        """
        parser = argparse.ArgumentParser(
            description="Trains a DCGAN on the CelebA-32 dataset")
        # number of workers for dataloader
        parser.add_argument('--workers', default=None, type=int,
                            help="number of workers for dataloader (defaults to None: maximum number of cores)")
        # size of the latent vector z, the generator input
        parser.add_argument('--latent-dim', default=100, type=int,
                            help="size of the latent vector z, the generator input")
        # base size of feature maps in discriminator / generator
        parser.add_argument('--num-feature-maps-D', default=32, type=int,
                            help="base size of feature maps in discriminator")
        parser.add_argument('--num-feature-maps-G', default=32, type=int,
                            help="base size of feature maps in generator")
        # learning rate for the discriminator / generator
        # NOTE(review): default 0.002 is 10x the 0.0002 of the DCGAN paper and
        # of this script's earlier commented-out defaults — confirm intended.
        parser.add_argument('--lr-D', default=0.002, type=float,
                            help="learning rate for the discriminator")
        parser.add_argument('--lr-G', default=0.002, type=float,
                            help="learning rate for the generator")
        # momentum beta1 for the discriminator / generator
        parser.add_argument('--beta1-D', default=0.5, type=float,
                            help="momentum beta1 for the discriminator")
        parser.add_argument('--beta1-G', default=0.5, type=float,
                            help="momentum beta1 for the generator")
        # number of images per batch
        parser.add_argument('--batch-size', default=256, type=int,
                            help="number of images per batch")
        # number of sub-steps of discriminator / generator optim. at each step
        parser.add_argument('--num-updates-D', default=1, type=int,
                            help="number of sub-steps of discriminator optim. at each step")
        parser.add_argument('--num-updates-G', default=1, type=int,
                            help="number of sub-steps of generator optim. at each step")
        # number of global steps in the training loop
        parser.add_argument('--steps', default=8000, type=int,
                            help="number of global steps in the training loop")
        parser.add_argument('--epochs', default=None, type=int,
                            help="number of epochs; leave None if you set the number of steps (i.e. batch updates)")
        # do not log metrics to tensorboard
        parser.add_argument('--no-tensorboard', action='store_true',
                            help="if specified, do not log metrics to tensorboard")
        parser.add_argument('--from-checkpoint', default=None, type=str,
                            help='resume training from the checkpoint at the specified path')
        parser.add_argument('--no-save', action='store_true',
                            help='if specified, do not save checkpoints or figures')
        args = parser.parse_args()
        # derive whichever of epochs/steps the user did not set
        if args.epochs is None:
            args.epochs = (args.steps * args.batch_size) / (args.num_updates_D * 202000)
        else:
            args.steps = int(args.epochs * args.num_updates_D * 202000 / args.batch_size)
        return args

    args = parse_args()
    # seed every RNG in use for reproducibility; the `random` module was
    # imported but previously never seeded
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    main(args)