From ec10163d5719df999058389f3e6a25d5820ff7af Mon Sep 17 00:00:00 2001 From: davide Date: Fri, 20 Aug 2021 00:58:19 +0200 Subject: [PATCH] several fixes --- GAN/components.py | 295 ++++++++++++++++++++++++++++++++++++++ GAN/gan.py | 331 +++++++++++++++++++++++++++++++++++++++++++ GAN/main_gan.py | 175 +++++++++++++++++++++++ GAN/modules.py | 122 ++++++++++++++++ GAN/utils.py | 270 +++++++++++++++++++++++++++++++++++ RL/agent/__init__.py | 0 6 files changed, 1193 insertions(+) create mode 100644 GAN/components.py create mode 100644 GAN/gan.py create mode 100644 GAN/main_gan.py create mode 100644 GAN/modules.py create mode 100644 GAN/utils.py delete mode 100644 RL/agent/__init__.py diff --git a/GAN/components.py b/GAN/components.py new file mode 100644 index 0000000..cbee72c --- /dev/null +++ b/GAN/components.py @@ -0,0 +1,295 @@ +# Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py +import torch +import math +from torch import nn +from core.modules import View, PixelNorm2d +import core.modules +import random +from GAN.modules import EqualizedConv2d + + +def weight_formula(i, idx, speed=2): + if idx > i: + w = max(0, min(1, 1 - abs(i - idx) * speed)) + else: + w = max(0, min(1, speed - abs(i - idx) * speed)) + return w + + +class FirstLayer(nn.Module): + + def __init__(self, latent_dim, size_out, out_channels) -> None: + """ + Args: + latent_dim: Dimension of the latent space + feature_maps: Number of feature maps to use + image_channels: Number of channels of the images from the dataset + """ + super().__init__() + self.size_out = size_out + self.linear = nn.Linear(latent_dim, size_out * size_out * out_channels) + self.out_channels = out_channels + self.bn = nn.Sequential(PixelNorm2d(self.out_channels), nn.LeakyReLU(0.2)) + + def forward(self, noise: torch.Tensor) -> torch.Tensor: + l1 = self.linear(noise.view(noise.shape[0], -1)) + l1_view = l1.view(noise.shape[0], self.out_channels, self.size_out, self.size_out) + return self.bn(l1_view) + + +class DCGANGenerator(nn.Module): + + def __init__(self, latent_dim: int, feature_maps: int, image_channels: int, version: float, size: int, + custom_conv: bool) -> None: + """ + Args: + latent_dim: Dimension of the latent space + feature_maps: Number of feature maps to use + image_channels: Number of channels of the images from the dataset + """ + super().__init__() + self.num_layers = int(math.log2(size)) - 2 + self.version = version + if version == 1: + self.gen = nn.Sequential( + FirstLayer(latent_dim, 4, feature_maps), + self._make_gen_block(feature_maps, feature_maps, custom_conv=custom_conv), # 8x8 + self._make_gen_block(feature_maps, feature_maps // 2, custom_conv=custom_conv), # 16x16 + self._make_gen_block(feature_maps // 2, feature_maps // 4, custom_conv=custom_conv), # 32x32 + self._make_gen_block(feature_maps // 4, feature_maps // 8, scale=1, custom_conv=custom_conv), # 32x32 + self._make_gen_block(feature_maps // 8, image_channels, last_block=True, custom_conv=custom_conv) # 64x64 + ) + elif version == 2.1: + gen_layers = [ + FirstLayer(latent_dim, 4, feature_maps), + self._make_gen_block(feature_maps, feature_maps // 2, custom_conv=custom_conv), # 8x8 + self._make_gen_block(feature_maps // 2, feature_maps // 4, custom_conv=custom_conv), # 16x16 + self._make_gen_block(feature_maps // 4, feature_maps // 4, scale=1, custom_conv=custom_conv), + self._make_gen_block(feature_maps // 4, feature_maps // 8, custom_conv=custom_conv), # 32x32 + self._make_gen_block(feature_maps // 8, feature_maps // 8, scale=1, 
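+                # Size/width bookkeeping for the v2.1 ladder: FirstLayer emits a
+                # 4x4 map, each scale=2 block doubles the side, and scale=1 blocks
+                # such as this one refine features without upsampling, so the trace
+                # is 4 -> 8 -> 16 -> 16 -> 32 -> 32 -> 64 while channel widths go
+                # feature_maps -> /2 -> /4 -> /4 -> /8 -> /8 -> image_channels.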
custom_conv=custom_conv), + self._make_gen_block(feature_maps // 8, image_channels, last_block=True, custom_conv=custom_conv) # 64x64 + ] + self.gen = nn.Sequential(*gen_layers) + elif version == 3: + gen_layers = [nn.Sequential(FirstLayer(latent_dim, 4, feature_maps), + self._make_gen_block(feature_maps, feature_maps, scale=1, custom_conv=custom_conv))] + out_layers = [self._make_gen_block(feature_maps, image_channels, scale=1, last_block=True, custom_conv=custom_conv)] + num_features = feature_maps + for layer in range(self.num_layers): + out_features = num_features if layer <= 3 else num_features // 2 + gen_layers += [nn.Sequential( + self._make_gen_block(num_features, out_features, custom_conv=custom_conv), + self._make_gen_block(out_features, out_features, scale=1, custom_conv=custom_conv))] + out_layers += [ + self._make_gen_block(out_features, image_channels, scale=1, last_block=True, custom_conv=custom_conv) + ] + num_features = out_features + self.gen = nn.ModuleList(gen_layers) + self.out_layers = nn.ModuleList(out_layers) + else: + raise NotImplementedError + + @staticmethod + def _make_gen_block( + in_channels: int, + out_channels: int, + kernel_size: int = 3, + scale: int = 2, + bias: bool = True, + last_block: bool = False, + use_tanh=False, + custom_conv=False + ) -> nn.Sequential: + if custom_conv: + conv = EqualizedConv2d + if not bias: + print('_make_gen_block: setting bias to True') + bias = True + else: + conv = nn.Conv2d + if use_tanh: + last_act = nn.Tanh() + else: + last_act = nn.Identity() + if scale > 1: + upscale = nn.Upsample(scale_factor=scale) + else: + upscale = nn.Identity() + if not last_block: + gen_block = nn.Sequential( + upscale, + conv(in_channels, out_channels, kernel_size, 1, kernel_size // 2, bias=bias), + PixelNorm2d(out_channels), + nn.LeakyReLU(0.2) + ) + else: + gen_block = nn.Sequential( + upscale, + conv(in_channels, out_channels, kernel_size, 1, kernel_size // 2, bias=bias), + last_act + ) + + return gen_block + + def forward(self, noise: torch.Tensor, idx: float): + if self.version < 3: + return torch.clamp(self.gen(noise), -1, 1) + else: + out = None + layer = noise + assert self.num_layers == len(self.gen) - 1 + lower_idx = min(self.num_layers, math.floor(idx)) + higher_idx = min(self.num_layers, math.ceil(idx)) + debug = random.random() < 0.002 and False + sum_w = 0.0 + for l in range(higher_idx + 1): + if debug: + print(l) + layer = self.gen[l](layer) + if lower_idx <= l <= higher_idx: + w = weight_formula(l, idx) + sum_w += w + if out is None: + out = w * self.out_layers[l](layer) + else: + new_out = self.out_layers[l](layer) + out = torch.nn.functional.interpolate(out, scale_factor=2, mode='nearest') + w * new_out + assert sum_w > 0 + return out / sum_w + + +class DCGANDiscriminator(nn.Module): + + def __init__(self, feature_maps: int, image_channels: int, version: float, size: int, use_avg: bool, + norm: str, use_std: bool, custom_conv: bool) -> None: + """ + Args: + feature_maps: Number of feature maps to use + image_channels: Number of channels of the images from the dataset + """ + super().__init__() + self.num_layers = int(math.log2(size)) - 2 + self.version = version + self.use_std = use_std + if self.version < 3: + num_features = feature_maps // (2 ** (self.num_layers - 1)) + self.disc = [ + self._make_disc_block(image_channels, num_features, use_avg=use_avg, norm=norm, custom_conv=custom_conv)] + for l in range(self.num_layers - 1): + # self.disc.append(self._make_disc_block(num_features, num_features, use_avg=use_avg, 
stride=1,norm=norm)) + self.disc.append( + self._make_disc_block(num_features, num_features * 2, use_avg=use_avg, norm=norm, custom_conv=custom_conv)) + num_features *= 2 + assert num_features == feature_maps + self.disc.append(self._make_disc_block(num_features, 1, kernel_size=4, + stride=1, padding=0, last_block=True, custom_conv=custom_conv)) + + self.disc = nn.Sequential(*self.disc) + else: + self.avg2x = nn.AvgPool2d(2) + if use_std: + chan_std = 1 + else: + chan_std = 0 + num_features = feature_maps # // (2 ** (self.num_layers - 3)) + self.disc = [] + self.from_rgb = [self._make_disc_block(image_channels, feature_maps, stride=1, use_avg=use_avg, norm=norm, + custom_conv=custom_conv)] + + self.out = nn.Sequential( + self._make_disc_block(feature_maps + chan_std, feature_maps, stride=1, use_avg=use_avg, norm=norm, + custom_conv=custom_conv), + self._make_disc_block(feature_maps, 1, kernel_size=4, stride=1, padding=0, last_block=True, + custom_conv=custom_conv)) + + for num_l in range(self.num_layers): + num_features //= 2 + self.from_rgb.append(self._make_disc_block(image_channels, num_features, stride=1, use_avg=use_avg, norm=norm, + custom_conv=custom_conv)) + single_disc = [] + single_disc.append(self._make_disc_block(num_features, num_features, stride=1, use_avg=use_avg, norm=norm, + custom_conv=custom_conv)) + single_disc.append( + self._make_disc_block(num_features, num_features * 2, use_avg=use_avg, norm=norm, custom_conv=custom_conv)) + + self.disc.append(nn.Sequential(*single_disc)) + + self.disc = nn.ModuleList(self.disc) + self.from_rgb = nn.ModuleList(self.from_rgb) + + @staticmethod + def _make_disc_block( + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 2, + padding: int = 1, + bias: bool = True, + last_block: bool = False, + use_avg=False, + custom_conv=False, + norm="" + ) -> nn.Sequential: + if use_avg: + stride_conv = 1 + if stride > 1: + downscale = nn.AvgPool2d(stride) + else: + downscale = nn.Identity() + else: + downscale = nn.Identity() + stride_conv = stride + if custom_conv: + conv = EqualizedConv2d + if not bias: + print('_make_gen_block: setting bias to True') + bias = True + else: + conv = nn.Conv2d + if not last_block: + if hasattr(nn, norm): + norm_layer = getattr(nn, norm) + else: + norm_layer = getattr(core.modules, norm) + disc_block = nn.Sequential( + conv(in_channels, out_channels, kernel_size, stride_conv, padding, bias=bias), + norm_layer(out_channels), + nn.LeakyReLU(0.2), + downscale + ) + else: + disc_block = nn.Sequential( + conv(in_channels, out_channels, kernel_size, stride, padding, bias=bias) # , + # nn.Sigmoid(), + ) + + return disc_block + + def forward(self, x, idx): + if self.version < 3: + return self.disc(x).view(x.shape[0], 1) + else: + lower_idx = min(self.num_layers, math.floor(idx)) + higher_idx = min(self.num_layers, math.ceil(idx)) + w1 = weight_formula(lower_idx, idx) + w2 = weight_formula(higher_idx, idx) + sum_w = w1 + w2 + w1 /= sum_w + w2 /= sum_w + if random.random() < 0.001: + print('idx', lower_idx, idx, higher_idx, 'w', w1, w2) + if lower_idx == higher_idx: + o = self.from_rgb[lower_idx](x) + # print(o.shape) + if higher_idx > 0: + o = self.disc[higher_idx - 1](o) + else: + o = w1 * self.from_rgb[lower_idx](self.avg2x(x)) + w2 * self.disc[higher_idx - 1](self.from_rgb[higher_idx](x)) + # print(lower_idx, higher_idx, o.shape) + for l in range(higher_idx - 2, -1, -1): + o = self.disc[l](o) + # print(lower_idx,higher_idx,o.shape) + if self.use_std: + o = core.modules.miniBatchStdDev(o) + o = 
self.out(o) + return o.view(x.shape[0], 1) diff --git a/GAN/gan.py b/GAN/gan.py new file mode 100644 index 0000000..c4f8357 --- /dev/null +++ b/GAN/gan.py @@ -0,0 +1,331 @@ +from typing import Any +import os, math +import pytorch_lightning as pl +import torch +from torch import nn +from GAN.components import DCGANDiscriminator, DCGANGenerator, weight_formula + +from GAN import utils + + + + +class DCGAN(pl.LightningModule): + """ + DCGAN implementation. + + Example:: + + from pl_bolts.models.gan import DCGAN + + m = DCGAN() + Trainer(gpus=2).fit(m) + + Example CLI:: + + # mnist + python dcgan_module.py --gpus 1 + + # cifar10 + python dcgan_module.py --gpus 1 --dataset cifar10 --image_channels 3 + """ + + def __init__( + self, + beta1: float = 0.5, + feature_maps_disc: int = 64, + image_channels: int = 1, + latent_dim: int = 100, + lambda_gp: float = 10, + decay: float = 0.0, + loss: str = "", + length: int = 2, + version: float = None, + l2_loss_weight: float = None, + speed_transition = 40000, + **kwargs: Any, + ) -> None: + """ + Args: + beta1: Beta1 value for Adam optimizer + feature_maps_gen: Number of feature maps to use for the generator + feature_maps_disc: Number of feature maps to use for the discriminator + image_channels: Number of channels of the images from the dataset + latent_dim: Dimension of the latent space + learning_rate: Learning rate + """ + super().__init__() + self.save_hyperparameters() + self.generator = self._get_generator() + self.discriminator = self._get_discriminator() + self.loss = loss + self.version = version + self.args = kwargs + self.img_dim = (image_channels, kwargs['image_size'], kwargs['image_size']) + if loss == "rals": + self.dic_loss_func = self._get_disc_loss_lsregan + self.gen_loss_func = self._get_gen_loss_lsregan + elif loss == "wgangp": + self.dic_loss_func = self._get_disc_loss_wgangp + self.gen_loss_func = self._get_gen_loss_wgangp + elif loss == "dcgan": + self.dic_loss_func = self._get_disc_loss + self.gen_loss_func = self._get_gen_loss + self.length = length + self.decay = decay + self.l2_loss_weight = l2_loss_weight + self.lambda_gp = lambda_gp + self.speed_transition = speed_transition + + @property + def automatic_optimization(self) -> bool: + return False + + def _get_generator(self) -> nn.Module: + generator = DCGANGenerator(self.hparams.latent_dim, self.hparams.feature_maps_gen, + self.hparams.image_channels, self.hparams.version, self.hparams.image_size, self.hparams.custom_conv) + if not self.hparams.custom_conv: + generator.apply(self._weights_init) + return generator + + def _get_discriminator(self) -> nn.Module: + discriminator = DCGANDiscriminator(self.hparams.feature_maps_disc, self.hparams.image_channels, self.hparams.version, + self.hparams.image_size, self.hparams.use_avg, self.hparams.norm_disc, + self.hparams.use_std, self.hparams.custom_conv) + if not self.hparams.custom_conv: + discriminator.apply(self._weights_init) + return discriminator + + @staticmethod + def _weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + torch.nn.init.normal_(m.weight, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + torch.nn.init.normal_(m.weight, 1.0, 0.02) + torch.nn.init.zeros_(m.bias) + + def configure_optimizers(self): + lr = self.hparams.learning_rate + betas = (self.hparams.beta1, 0.99) + opt_disc = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas, weight_decay=self.decay) + opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas, 
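+        # automatic_optimization is overridden to False below, so these two
+        # optimizers are consumed manually: self.optimizers() in training_step
+        # returns them in this (disc, gen) order. beta2 is pinned at 0.99 above,
+        # a low-beta1 / high-beta2 pairing commonly used for WGAN-GP-style and
+        # progressive-growing training.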
weight_decay=self.decay) + return opt_disc, opt_gen + + def forward(self, noise: torch.Tensor) -> torch.Tensor: + """ + Generates an image given input noise + + Example:: + + noise = torch.rand(batch_size, latent_dim) + gan = GAN.load_from_checkpoint(PATH) + img = gan(noise) + """ + noise = noise.view(*noise.shape, 1, 1) + return self.generator(noise, self.get_idx(self.num_scales)) + + def training_step(self, batch, batch_idx): + d_opt, g_opt = self.optimizers() + ratio = 1 + real, _ = batch + self.num_scales = 1 + if self.version >= 3: + self.num_scales = int(math.log2(real.shape[-1])) - 2 + idx = self.get_idx(self.num_scales) + higher_idx = min(self.num_scales, math.ceil(idx)) + size = real.shape[-1] // (2 ** (self.num_scales-higher_idx)) + real = torch.nn.functional.interpolate(real, size=(size, size), mode="area") + if batch_idx % 200 == 0: + print('current size real', size) + if batch_idx % ratio == 0: + d_opt.zero_grad() + d_x = self._disc_step(real) + if self.args['use_tpu']: + self.manual_backward(d_x,d_opt) + else: + self.manual_backward(d_x) + d_opt.step() + + g_opt.zero_grad() + g_x = self._gen_step(real) + if self.args['use_tpu']: + self.manual_backward(g_x,g_opt) + else: + self.manual_backward(g_x) + g_opt.step() + if batch_idx % ratio == 0: + self.log_dict({'g_loss': g_x, 'd_loss': d_x}, prog_bar=True) + + def _disc_step(self, real: torch.Tensor) -> torch.Tensor: + disc_loss = self.dic_loss_func(real) + self.log("loss/disc", disc_loss, on_step=True, on_epoch=True) + return disc_loss + + def _gen_step(self, real: torch.Tensor) -> torch.Tensor: + gen_loss = self.gen_loss_func(real) + self.log("loss/gen", gen_loss, on_step=True, on_epoch=True) + return gen_loss + + def _get_disc_loss(self, real: torch.Tensor) -> torch.Tensor: + # Train with real + real_pred = self.discriminator(real, self.get_idx(self.num_scales)) + real_gt = torch.ones_like(real_pred) + real_loss = torch.nn.functional.binary_cross_entropy_with_logits(real_pred, real_gt) + + # Train with fake + fake_pred = self._get_fake_pred(real) + fake_gt = torch.zeros_like(fake_pred) + fake_loss = torch.nn.functional.binary_cross_entropy_with_logits(fake_pred, fake_gt) + + disc_loss = real_loss + fake_loss + + return disc_loss + + def get_idx(self, max_val): + idx = self.global_step / self.speed_transition + idx = idx ** 0.5 + if idx >= max_val: + idx = max_val + return idx + + def _get_disc_loss_wgangp(self, real: torch.Tensor) -> torch.Tensor: + # Train with real + idx = self.get_idx(self.num_scales) + real_pred = self.discriminator(real, idx) + # Train with fake + fake_pred, fake = self._get_fake_pred(real, True) + + gradient_penalty = utils.compute_gradient_penalty(self.discriminator, real, fake, idx) + disc_loss = (-torch.mean(real_pred) + torch.mean(fake_pred) + self.lambda_gp * gradient_penalty) + + return disc_loss + + def _get_gen_loss_wgangp(self, real: torch.Tensor) -> torch.Tensor: + # Train with real + idx = self.get_idx(self.num_scales) + real_pred = self.discriminator(real, idx) + # Train with fake + fake_pred = self._get_fake_pred(real) + + self.log("loss/realfake_diff", torch.mean(real_pred) - torch.mean(fake_pred), + on_step=True, on_epoch=True) + self.log("loss/real_mean", torch.mean(real_pred), + on_step=True, on_epoch=True) + self.log("loss/fake_mean", torch.mean(fake_pred), + on_step=True, on_epoch=True) + gen_loss = (-torch.mean(fake_pred)) + return gen_loss + + def _get_disc_loss_lsregan(self, real: torch.Tensor) -> torch.Tensor: + if self.version>=3: + num_layers = self.discriminator.num_layers + 
assert num_layers==self.num_scales + + idx = self.get_idx(self.num_scales) + # Train with real + real_pred = self.discriminator(real,idx) + # Train with fake + fake_pred = self._get_fake_pred(real) + disc_loss = torch.mean((real_pred - torch.mean(fake_pred) - 1) ** 2) + \ + torch.mean((fake_pred - torch.mean(real_pred) + 1) ** 2) + if self.l2_loss_weight > 0: + if isinstance(real_pred, list): + l2_loss = 0.0 + sum_w = 0.0 + for i in range(len(real_pred)): + if fake_pred[i] is None: + assert real_pred[i] is None + continue + w = weight_formula(i, idx) + sum_w += w + l2_loss += w * self.l2_loss_weight * (torch.mean(fake_pred[i] ** 2) + torch.mean(real_pred[i] ** 2)) + l2_loss /= sum_w + else: + l2_loss = self.l2_loss_weight * (torch.mean(fake_pred ** 2) + torch.mean(real_pred ** 2)) + return disc_loss / 2 + l2_loss + else: + return disc_loss / 2 + + def _get_gen_loss_lsregan(self, real: torch.Tensor) -> torch.Tensor: + idx = self.get_idx(self.num_scales) + + # Train with real + real_pred = self.discriminator(real,idx) + # Train with fake + fake_pred = self._get_fake_pred(real) + if isinstance(real_pred, list): + gen_loss = 0.0 + sum_w = .0 + for i in range(len(real_pred)): + + if fake_pred[i] is None: + assert real_pred[i] is None + continue + w = weight_formula(i, idx) + sum_w += w + if self.global_step % 100 == 0: + print(2 ** (i + 2), 'idx', idx, 'w', w) + gen_loss += w * (torch.mean((real_pred[i] - torch.mean(fake_pred[i]) + 1) ** 2) + + torch.mean((fake_pred[i] - torch.mean(real_pred[i]) - 1) ** 2)) + self.log("loss/realfake_diff" + str(2 ** (i + 2)), torch.mean(real_pred[i]) - torch.mean(fake_pred[i]), + on_step=True, on_epoch=True) + gen_loss /= sum_w + + else: + self.log("loss/realfake_diff", torch.mean(real_pred) - torch.mean(fake_pred), + on_step=True, on_epoch=True) + self.log("loss/real_mean", torch.mean(real_pred), + on_step=True, on_epoch=True) + self.log("loss/fake_mean", torch.mean(fake_pred), + on_step=True, on_epoch=True) + gen_loss = torch.mean((real_pred - torch.mean(fake_pred) + 1) ** 2) + \ + torch.mean((fake_pred - torch.mean(real_pred) - 1) ** 2) + if self.l2_loss_weight > 0 and False: # disabled for generator + if isinstance(real_pred, list): + l2_loss = 0.0 + if self.global_step % 50 == 0: + print() + sum_w = .0 + for i in range(len(real_pred)): + if fake_pred[i] is None: + assert real_pred[i] is None + continue + w = weight_formula(i, idx) + sum_w += w + + l2_loss += w * self.l2_loss_weight * (torch.mean(fake_pred[i] ** 2) + torch.mean(real_pred[i] ** 2)) + l2_loss /= sum_w + else: + l2_loss = self.l2_loss_weight * (torch.mean(fake_pred ** 2) + torch.mean(real_pred ** 2)) + self.log("loss/l2_loss", l2_loss, on_step=True, on_epoch=True) + return gen_loss / 2 + l2_loss + else: + return gen_loss / 2 + + def _get_gen_loss(self, real: torch.Tensor) -> torch.Tensor: + # Train with fake + fake_pred = self._get_fake_pred(real) + fake_gt = torch.ones_like(fake_pred) + gen_loss = torch.nn.functional.binary_cross_entropy_with_logits(fake_pred, fake_gt) + + return gen_loss + + def _get_fake_pred(self, real: torch.Tensor, return_fake: bool = False) -> torch.Tensor: + if isinstance(real, list): + for r in real: + if r is not None: + batch_size = r.shape[0] + break + else: + batch_size = real.shape[0] + noise = self._get_noise(batch_size, self.hparams.latent_dim) + fake = self(noise) + fake_pred = self.discriminator(fake,self.get_idx(self.num_scales)) + if return_fake: + return fake_pred, fake + else: + return fake_pred + + def _get_noise(self, n_samples: int, latent_dim: int) -> 
torch.Tensor: + return utils.sampler(n_samples, latent_dim, device=self.device, length=self.length) diff --git a/GAN/main_gan.py b/GAN/main_gan.py new file mode 100644 index 0000000..fe7648f --- /dev/null +++ b/GAN/main_gan.py @@ -0,0 +1,175 @@ +from argparse import ArgumentParser +from typing import Any +import time +import os, math +import pytorch_lightning as pl +import torch +from torch import nn +from torch.utils.data import DataLoader +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks import ModelCheckpoint +from GAN.gan import DCGAN +from torchvision import transforms as transform_lib +from torchvision.datasets import LSUN, MNIST, ImageFolder + +from GAN import utils +#import os +#os.environ['CUDA_VISIBLE_DEVICES']='0' + + +def get_args(args_list=None): + parser = ArgumentParser(description='GAN') + parser.add_argument("--batch_size", default=64, type=int) + parser.add_argument("--dataset_type", default="face", type=str, choices=["face", "lsun", "mnist"]) + parser.add_argument("--image_size", default=128, type=int) + parser.add_argument("--num_workers", default=0 if __name__ == "__main__" else 8, type=int) + + parser.add_argument('--dataset', default="/mnt/teradisk/davide/datasets/celeba/", + help='dataset path e.g. https://drive.google.com/open?id=0BxYys69jI14kYVM3aVhKS1VhRUk') + parser.add_argument('--res_dir', default='./', help='result dir') + parser.add_argument('--net_params', default={}, type=dict, help='net_params') + + parser.add_argument("--beta1", default=0.0, type=float) + parser.add_argument("--latent_dim", default=512, type=int) + parser.add_argument("--loss", default="rals", type=str, choices=["rals", "dcgan",'wgangp']) + parser.add_argument("--length", default=1, type=float) + parser.add_argument("--weight_decay", default=0.00000, type=float) + parser.add_argument("--l2_loss_weight", default=1, type=float) + + parser.add_argument("--use_tpu", action="store_true") + parser.add_argument("--use_std", action="store_true") + + parser.add_argument("--custom_conv", action="store_true") + + parser.add_argument("--use_avg", action="store_true") + parser.add_argument("--norm_disc", default="Identity") # e.g., Identity PixelNorm2d BatchNorm2d + + parser.add_argument("--version", default=3, type=float) + parser.add_argument("--name", default="newdiscr") + parser.add_argument("--speed_transition", default=50000, type=float) + + args = parser.parse_args(args_list) + + return args + + +def main(args=None, callback=None, upload_checkpoint=False): + args.feature_maps_gen = args.latent_dim + args.feature_maps_disc = args.latent_dim + pl.seed_everything(1234) + + if args.dataset_type == "face": + dataset = ImageFolder(root=args.dataset, + transform=transform_lib.Compose([ + transform_lib.Resize(args.image_size), + transform_lib.CenterCrop(args.image_size), + transform_lib.ToTensor(), + transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ])) + image_channels = 3 + elif args.dataset_type == "lsun": + transforms = transform_lib.Compose([ + transform_lib.Resize(args.image_size), + transform_lib.CenterCrop(args.image_size), + transform_lib.ToTensor(), + transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]) + dataset = LSUN(root=args.dataset, classes=["bedroom_train"], transform=transforms) + image_channels = 3 + elif args.dataset_type == "mnist": + transforms = transform_lib.Compose([ + transform_lib.Resize(args.image_size), + transform_lib.ToTensor(), + transform_lib.Normalize((0.5,), (0.5,)), + ]) + dataset = MNIST(root=args.dataset, 
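+        # The single-channel (0.5,)/(0.5,) normalization above maps pixels to
+        # [-1, 1], matching the range the generator output is clamped to; the
+        # RGB branches use the 3-tuple (0.5, 0.5, 0.5) for the same reason.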
download=True, transform=transforms) + image_channels = 1 + + dataloader = DataLoader( + dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers + ) + version = 0 + if args.use_tpu: + tpu_cores = 8 + gpus = None + use_tpu_string = "" + else: + tpu_cores = None + gpus = 1 + use_tpu_string = "" + if args.use_avg: + use_avg = "avg" + else: + use_avg = "noavg" + use_std = "with_std" if args.use_std else "no_std" + dirpath = "{}_loss_{}_latent_{}_decay{}_v{}{}_l2_loss_weight{}_{}_{}_b{}_{}".format(args.name, args.loss, args.latent_dim, + args.weight_decay, + args.version, + use_tpu_string, args.l2_loss_weight, + use_avg, args.norm_disc, + args.beta1,use_std) + print(dirpath) + resume_from_checkpoint = os.path.join(args.res_dir, dirpath, + "version_{}".format(version), 'checkpoints', 'last.ckpt') + if not os.path.exists(resume_from_checkpoint): + resume_from_checkpoint = None + if args.loss == "rals": + args.learning_rate = 0.0004 + elif args.loss == "dcgan": + args.learning_rate = 0.0002 + elif args.loss == "wgangp": + if args.custom_conv: + args.learning_rate = 0.001 + else: + args.learning_rate = 0.00025 + print(vars(args)) + + model = DCGAN(**vars(args), image_channels=image_channels) + + callbacks = [ + ModelCheckpoint(filename='last', save_last=True), + utils.TensorboardGenerativeModelImageSampler(length=args.length, num_samples=9, normalize=True, nrow=3), + utils.LatentDimInterpolator(range_start=-args.length, range_end=args.length, + interpolate_epoch_interval=1, + normalize=True, callback=callback) + ] + tb_logger = pl_loggers.TensorBoardLogger(save_dir=args.res_dir, name=dirpath, version=version) + print('starting trainer') + trainer = pl.Trainer(tpu_cores=tpu_cores, gpus=gpus, logger=tb_logger, resume_from_checkpoint=resume_from_checkpoint, + callbacks=callbacks, checkpoint_callback=True, max_epochs=100) + trainer.fit(model, dataloader) + +def get_experiment(name): + experiments = { + 'ral0':"--batch_size 48 --l2_loss_weight 0.1 --version 2.1 --loss rals --image_size 64" + "--weight_decay 0.000001 --beta1 0.5 --use_std --use_avg --norm_disc Identity", + 'ral3':"--batch_size 64 --l2_loss_weight 0. --version 2.1 --loss rals --image_size 64" + "--weight_decay 0.000001 --beta1 0. --use_avg --norm_disc BatchNorm2d", + 'ral4':"--batch_size 64 --l2_loss_weight 0. --version 2.1 --loss rals --image_size 64" + "--weight_decay 0.000001 --beta1 0. --use_avg --norm_disc Identity", + 'ral1':"--batch_size 64 --l2_loss_weight 0.2 --version 3 --loss rals" + "--weight_decay 0.0 --beta1 0.0 --use_std --use_avg --norm_disc Identity", + 'ral2': "--batch_size 64 --l2_loss_weight 0 --version 3 --loss rals" + "--weight_decay 0.0 --beta1 0.0 --use_std --use_avg --norm_disc Identity", + 'wgan1_new': "--batch_size 32 --l2_loss_weight 0 --version 3 --loss wgangp" + " --weight_decay 0.0 --beta1 0. 
--use_std --use_avg --norm_disc Identity --speed_transition 25000", + 'wgan1_slow': "--batch_size 40 --l2_loss_weight 0 --version 3 --loss wgangp" + " --weight_decay 0.0 --beta1 0.0 --use_std --use_avg --norm_disc Identity", + 'wgan1_custom_big': "--batch_size 32 --l2_loss_weight 0 --version 3 --loss wgangp" + " --weight_decay 0.0 --beta1 0.0 --use_std --use_avg --norm_disc Identity --custom_conv --speed_transition 60000", + 'wgan1_custom': "--batch_size 48 --l2_loss_weight 0 --version 3 --loss wgangp" + " --weight_decay 0.0 --beta1 0.0 --use_std --use_avg --norm_disc Identity --custom_conv --speed_transition 50000", + 'wgan_custom': "--batch_size 80 --l2_loss_weight 0 --version 3 --loss wgangp" + " --weight_decay 0.0 --beta1 0 --use_std --use_avg --norm_disc Identity --custom_conv --speed_transition 20000" + } + str_exp = experiments[name] + ' --name ' + name + return str_exp.split(' ') + +if __name__ == "__main__": + base_dir_res = "/home/davide/results/GAN_face"#"/home/davide/Dropbox/Apps/davide_colab/results/GAN_face_gpu/output/"# + base_dir_dataset = "/mnt/teradisk/davide/datasets/celeba/" + list_args = ['--dataset', base_dir_dataset, + '--res_dir', base_dir_res] + list_args += get_experiment('wgan_custom') + args = get_args(list_args) + main(args) diff --git a/GAN/modules.py b/GAN/modules.py new file mode 100644 index 0000000..3867256 --- /dev/null +++ b/GAN/modules.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import math + +import torch.nn as nn + +from numpy import prod + + +class NormalizationLayer(nn.Module): + + def __init__(self): + super(NormalizationLayer, self).__init__() + + def forward(self, x, epsilon=1e-8): + return x * (((x**2).mean(dim=1, keepdim=True) + epsilon).rsqrt()) + + +def Upscale2d(x, factor=2): + assert isinstance(factor, int) and factor >= 1 + if factor == 1: + return x + s = x.size() + x = x.view(-1, s[1], s[2], 1, s[3], 1) + x = x.expand(-1, s[1], s[2], factor, s[3], factor) + x = x.contiguous().view(-1, s[1], s[2] * factor, s[3] * factor) + return x + + +def getLayerNormalizationFactor(x): + r""" + Get He's constant for the given layer + https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf + """ + size = x.weight.size() + fan_in = prod(size[1:]) + + return math.sqrt(2.0 / fan_in) + + +class ConstrainedLayer(nn.Module): + r""" + A handy refactor that allows the user to: + - initialize one layer's bias to zero + - apply He's initialization at runtime + """ + + def __init__(self, + module, + equalized=True, + lrMul=1.0, + initBiasToZero=True): + r""" + equalized (bool): if true, the layer's weight should evolve within + the range (-1, 1) + initBiasToZero (bool): if true, bias will be initialized to zero + """ + + super(ConstrainedLayer, self).__init__() + + self.module = module + self.equalized = equalized + + if initBiasToZero: + self.module.bias.data.fill_(0) + if self.equalized: + self.module.weight.data.normal_(0, 1) + self.module.weight.data /= lrMul + self.weight = getLayerNormalizationFactor(self.module) * lrMul + + def forward(self, x): + + x = self.module(x) + if self.equalized: + x *= self.weight + return x + + +class EqualizedConv2d(ConstrainedLayer): + + def __init__(self, + nChannelsPrevious, + nChannels, + kernelSize, + stride, + padding=0, + bias=True, + **kwargs): + r""" + A nn.Conv2d module with specific constraints + Args: + nChannelsPrevious (int): number of channels in the previous layer + nChannels (int): number of channels of the 
current layer + kernelSize (int): size of the convolutional kernel + padding (int): convolution's padding + bias (bool): with bias ? + """ + + ConstrainedLayer.__init__(self, + nn.Conv2d(nChannelsPrevious, nChannels, + kernelSize, stride, padding=padding, + bias=bias), + **kwargs) + + +class EqualizedLinear(ConstrainedLayer): + + def __init__(self, + nChannelsPrevious, + nChannels, + bias=True, + **kwargs): + r""" + A nn.Linear module with specific constraints + Args: + nChannelsPrevious (int): number of channels in the previous layer + nChannels (int): number of channels of the current layer + bias (bool): with bias ? + """ + + ConstrainedLayer.__init__(self, + nn.Linear(nChannelsPrevious, nChannels, + bias=bias), **kwargs) diff --git a/GAN/utils.py b/GAN/utils.py new file mode 100644 index 0000000..eca4476 --- /dev/null +++ b/GAN/utils.py @@ -0,0 +1,270 @@ +from typing import Optional, Tuple, List +import numpy as np +import os +import torch +from torch import Tensor +from pytorch_lightning import Callback, LightningModule, Trainer +import time,random +import torchvision + +def sampler(batch_size, dim, device, length): + return torch.randn((batch_size, dim), device=device) +# return torch.rand(batch_size, dim, device=device) * 2 * length - length + +class TensorboardGenerativeModelImageSampler(Callback): + """ + Generates images and logs to tensorboard. + Your model must implement the ``forward`` function for generation + + Requirements:: + + # model must have img_dim arg + model.img_dim = (1, 28, 28) + + # model forward must work for sampling + z = torch.rand(batch_size, latent_dim) + img_samples = your_model(z) + + Example:: + + from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler + + trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()]) + """ + + def __init__( + self, + num_samples: int = 3, + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + norm_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: int = 0, + length: int = 2, + ) -> None: + """ + Args: + num_samples: Number of images displayed in the grid. Default: ``3``. + nrow: Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + padding: Amount of padding. Default: ``2``. + normalize: If ``True``, shift the image to the range (0, 1), + by the min and max values specified by :attr:`range`. Default: ``False``. + norm_range: Tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + scale_each: If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value: Value for the padded pixels. Default: ``0``. 
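+            length: Forwarded to ``sampler`` when drawing latents; the active
+                Gaussian sampler ignores it, while the commented-out uniform
+                variant uses it as the half-range. Default: ``2``.
+
+        Example (illustrative wiring, mirroring the callback list built in
+        main_gan.py)::
+
+            sampler_cb = TensorboardGenerativeModelImageSampler(
+                length=1, num_samples=9, normalize=True, nrow=3)
+            trainer = Trainer(callbacks=[sampler_cb])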
+ """ + + super().__init__() + self.num_samples = num_samples + self.nrow = nrow + self.padding = padding + self.normalize = normalize + self.norm_range = norm_range + self.scale_each = scale_each + self.pad_value = pad_value + self.length = length + + def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule,outputs, + batch, batch_idx: int, dataloader_idx: int) -> None: + if trainer.global_step % 500 != 0 or trainer.global_step==0: + return + z = sampler(self.num_samples, pl_module.hparams.latent_dim, pl_module.device, self.length) + + # generate images + with torch.no_grad(): + pl_module.eval() + img = pl_module(z) + if isinstance(img,list): + images = [] + for img_i in img: + if img_i is None: + images.append(None) + else: + images.append(torch.nn.functional.interpolate(img_i, size=(256, 256))) + else: + images = torch.nn.functional.interpolate(img, size=(256,256)) + pl_module.train() + + if isinstance(img, list): + for i, images_i in enumerate(images): + if images_i is None: + continue + if len(images_i.size()) == 2: + img_dim = pl_module.img_dim + images_i = images_i.view(self.num_samples, *img_dim) + grid = torchvision.utils.make_grid( + tensor=images_i, + nrow=self.nrow, + padding=self.padding, + normalize=self.normalize, + range=self.norm_range, + scale_each=self.scale_each, + pad_value=self.pad_value, + ) + str_title = f"{pl_module.__class__.__name__}_images{2**(i+2)}" + trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step) + + else: + if len(images.size()) == 2: + img_dim = pl_module.img_dim + images = images.view(self.num_samples, *img_dim) + + grid = torchvision.utils.make_grid( + tensor=images, + nrow=self.nrow, + padding=self.padding, + normalize=self.normalize, + range=self.norm_range, + scale_each=self.scale_each, + pad_value=self.pad_value, + ) + str_title = f"{pl_module.__class__.__name__}_images" + trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step) + + grid_real = torchvision.utils.make_grid( + tensor=batch[0], + nrow=self.nrow, + padding=self.padding, + normalize=self.normalize, + range=self.norm_range, + scale_each=self.scale_each, + pad_value=self.pad_value + ) + str_title = f"{pl_module.__class__.__name__}_real_images" + trainer.logger.experiment.add_image(str_title, grid_real, global_step=trainer.global_step) + + time.sleep(random.random()) + if not os.path.exists(os.path.join(trainer.log_dir,'images')): + os.makedirs(os.path.join(trainer.log_dir,'images'),exist_ok=True) + torchvision.utils.save_image(grid.cpu(), os.path.join(trainer.log_dir, + 'images', 'sampled{:07d}.png'.format(trainer.global_step))) + +class LatentDimInterpolator(Callback): + """ + Interpolates the latent space for a model by setting all dims to zero and stepping + through the first two dims increasing one unit at a time. 
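+    In practice each logged grid is a ``steps`` x ``steps`` sweep over the
+    first two latent coordinates, with all remaining dimensions held at zero
+    (see ``interpolate_latent_space`` below).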
+ + Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5) + + Example:: + + from pl_bolts.callbacks import LatentDimInterpolator + + Trainer(callbacks=[LatentDimInterpolator()]) + """ + + def __init__( + self, + interpolate_epoch_interval: int = 20, + range_start: int = -1, + range_end: int = 1, + steps: int = 11, + num_samples: int = 2, + normalize: bool = True, + callback=None + ): + """ + Args: + interpolate_epoch_interval: default 20 + range_start: default -5 + range_end: default 5 + steps: number of step between start and end + num_samples: default 2 + normalize: default True (change image to (0, 1) range) + """ + super().__init__() + self.interpolate_epoch_interval = interpolate_epoch_interval + self.range_start = range_start + self.range_end = range_end + self.num_samples = num_samples + self.normalize = normalize + self.steps = steps + self.callback=callback + + def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + if self.callback is not None and trainer.global_step>10: + self.callback(False) + + def on_batch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + if trainer.global_step % 500 != 0 or trainer.global_step==0: + return + if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0: + images = self.interpolate_latent_space( + pl_module, + latent_dim=pl_module.hparams.latent_dim # type: ignore[union-attr] + ) + images = torch.cat(images, dim=0) # type: ignore[assignment] + + num_rows = self.steps + grid = torchvision.utils.make_grid(images, nrow=num_rows, normalize=self.normalize) + str_title = f'{pl_module.__class__.__name__}_latent_space' + trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step) + if not os.path.exists(os.path.join(trainer.log_dir,'images')): + os.makedirs(os.path.join(trainer.log_dir,'images')) + torchvision.utils.save_image(grid.cpu(), os.path.join(trainer.log_dir, + 'images', 'latent{:07d}.png'.format(trainer.global_step))) + def interpolate_latent_space(self, pl_module: LightningModule, latent_dim: int) -> List[Tensor]: + images = [] + with torch.no_grad(): + pl_module.eval() + for z1 in np.linspace(self.range_start, self.range_end, self.steps): + for z2 in np.linspace(self.range_start, self.range_end, self.steps): + # set all dims to zero + z = torch.zeros(self.num_samples, latent_dim, device=pl_module.device) + + # set the fist 2 dims to the value + z[:, 0] = torch.tensor(z1) + z[:, 1] = torch.tensor(z2) + + # sample + # generate images + img = pl_module(z) + if isinstance(img, list): + idx = -1 + img_tmp = img[idx] + while img_tmp is None and abs(idx)