Skip to content

Commit

Permalink
adding smooth transition for reals
Browse files Browse the repository at this point in the history
  • Loading branch information
dnitti-psee committed Aug 19, 2021
1 parent ec10163 commit 51d8559
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 39 deletions.
16 changes: 14 additions & 2 deletions GAN/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,24 @@ def training_step(self, batch, batch_idx):
if self.version >= 3:
self.num_scales = int(math.log2(real.shape[-1])) - 2
idx = self.get_idx(self.num_scales)
lower_idx = min(self.num_scales, math.floor(idx))
higher_idx = min(self.num_scales, math.ceil(idx))
w1 = weight_formula(lower_idx, idx)
w2 = weight_formula(higher_idx, idx)
sum_w = w1 + w2
w1 /= sum_w
w2 /= sum_w

real = real[:real.shape[0]//max(1,higher_idx)]
size = real.shape[-1] // (2 ** (self.num_scales-higher_idx))
real = torch.nn.functional.interpolate(real, size=(size, size), mode="area")
if lower_idx != higher_idx:
low_res_real = torch.nn.functional.avg_pool2d(real, (2, 2))
low_res_real = torch.nn.functional.upsample(low_res_real, scale_factor=2, mode='nearest')
real = w1 * low_res_real + w2 * real

if batch_idx % 200 == 0:
print('current size real', size)
print('current size real', real.shape,'idx',idx,'w',w1,w2)
if batch_idx % ratio == 0:
d_opt.zero_grad()
d_x = self._disc_step(real)
Expand Down Expand Up @@ -184,7 +197,6 @@ def _get_disc_loss(self, real: torch.Tensor) -> torch.Tensor:

def get_idx(self, max_val):
idx = self.global_step / self.speed_transition
idx = idx ** 0.5
if idx >= max_val:
idx = max_val
return idx
Expand Down
2 changes: 1 addition & 1 deletion GAN/main_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_experiment(name):
'wgan1_custom': "--batch_size 48 --l2_loss_weight 0 --version 3 --loss wgangp"
" --weight_decay 0.0 --beta1 0.0 --use_std --use_avg --norm_disc Identity --custom_conv --speed_transition 50000",
'wgan_custom': "--batch_size 80 --l2_loss_weight 0 --version 3 --loss wgangp"
" --weight_decay 0.0 --beta1 0 --use_std --use_avg --norm_disc Identity --custom_conv --speed_transition 20000"
" --weight_decay 0.0 --beta1 0 --use_std --use_avg --norm_disc Identity --custom_conv --speed_transition 40000"
}
str_exp = experiments[name] + ' --name ' + name
return str_exp.split(' ')
Expand Down
77 changes: 41 additions & 36 deletions GAN/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import torch
from torch import Tensor
from pytorch_lightning import Callback, LightningModule, Trainer
import time,random
import time, random
import torchvision


def sampler(batch_size, dim, device, length):
return torch.randn((batch_size, dim), device=device)

# return torch.rand(batch_size, dim, device=device) * 2 * length - length

class TensorboardGenerativeModelImageSampler(Callback):
Expand All @@ -33,15 +35,15 @@ class TensorboardGenerativeModelImageSampler(Callback):
"""

def __init__(
self,
num_samples: int = 3,
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
norm_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
length: int = 2,
self,
num_samples: int = 3,
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
norm_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
length: int = 2,
) -> None:
"""
Args:
Expand Down Expand Up @@ -69,25 +71,25 @@ def __init__(
self.pad_value = pad_value
self.length = length

def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule,outputs,
batch, batch_idx: int, dataloader_idx: int) -> None:
if trainer.global_step % 500 != 0 or trainer.global_step==0:
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs,
batch, batch_idx: int, dataloader_idx: int) -> None:
if trainer.global_step % 500 != 0 or trainer.global_step == 0:
return
z = sampler(self.num_samples, pl_module.hparams.latent_dim, pl_module.device, self.length)

# generate images
with torch.no_grad():
pl_module.eval()
img = pl_module(z)
if isinstance(img,list):
if isinstance(img, list):
images = []
for img_i in img:
if img_i is None:
images.append(None)
else:
images.append(torch.nn.functional.interpolate(img_i, size=(256, 256)))
else:
images = torch.nn.functional.interpolate(img, size=(256,256))
images = torch.nn.functional.interpolate(img, size=(256, 256))
pl_module.train()

if isinstance(img, list):
Expand All @@ -106,7 +108,7 @@ def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule,output
scale_each=self.scale_each,
pad_value=self.pad_value,
)
str_title = f"{pl_module.__class__.__name__}_images{2**(i+2)}"
str_title = f"{pl_module.__class__.__name__}_images{2 ** (i + 2)}"
trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)

else:
Expand All @@ -127,7 +129,7 @@ def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule,output
trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)

grid_real = torchvision.utils.make_grid(
tensor=batch[0],
tensor=batch[0][:self.nrow*self.nrow],
nrow=self.nrow,
padding=self.padding,
normalize=self.normalize,
Expand All @@ -139,11 +141,12 @@ def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule,output
trainer.logger.experiment.add_image(str_title, grid_real, global_step=trainer.global_step)

time.sleep(random.random())
if not os.path.exists(os.path.join(trainer.log_dir,'images')):
os.makedirs(os.path.join(trainer.log_dir,'images'),exist_ok=True)
if not os.path.exists(os.path.join(trainer.log_dir, 'images')):
os.makedirs(os.path.join(trainer.log_dir, 'images'), exist_ok=True)
torchvision.utils.save_image(grid.cpu(), os.path.join(trainer.log_dir,
'images', 'sampled{:07d}.png'.format(trainer.global_step)))


class LatentDimInterpolator(Callback):
"""
Interpolates the latent space for a model by setting all dims to zero and stepping
Expand All @@ -159,14 +162,14 @@ class LatentDimInterpolator(Callback):
"""

def __init__(
self,
interpolate_epoch_interval: int = 20,
range_start: int = -1,
range_end: int = 1,
steps: int = 11,
num_samples: int = 2,
normalize: bool = True,
callback=None
self,
interpolate_epoch_interval: int = 20,
range_start: int = -1,
range_end: int = 1,
steps: int = 11,
num_samples: int = 2,
normalize: bool = True,
callback=None
):
"""
Args:
Expand All @@ -184,14 +187,14 @@ def __init__(
self.num_samples = num_samples
self.normalize = normalize
self.steps = steps
self.callback=callback
self.callback = callback

def on_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
if self.callback is not None and trainer.global_step>10:
if self.callback is not None and trainer.global_step > 10:
self.callback(False)

def on_batch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
if trainer.global_step % 500 != 0 or trainer.global_step==0:
if trainer.global_step % 500 != 0 or trainer.global_step == 0:
return
if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0:
images = self.interpolate_latent_space(
Expand All @@ -204,10 +207,11 @@ def on_batch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
grid = torchvision.utils.make_grid(images, nrow=num_rows, normalize=self.normalize)
str_title = f'{pl_module.__class__.__name__}_latent_space'
trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)
if not os.path.exists(os.path.join(trainer.log_dir,'images')):
os.makedirs(os.path.join(trainer.log_dir,'images'))
if not os.path.exists(os.path.join(trainer.log_dir, 'images')):
os.makedirs(os.path.join(trainer.log_dir, 'images'))
torchvision.utils.save_image(grid.cpu(), os.path.join(trainer.log_dir,
'images', 'latent{:07d}.png'.format(trainer.global_step)))
'images', 'latent{:07d}.png'.format(trainer.global_step)))

def interpolate_latent_space(self, pl_module: LightningModule, latent_dim: int) -> List[Tensor]:
images = []
with torch.no_grad():
Expand All @@ -227,14 +231,14 @@ def interpolate_latent_space(self, pl_module: LightningModule, latent_dim: int)
if isinstance(img, list):
idx = -1
img_tmp = img[idx]
while img_tmp is None and abs(idx)<len(img):
while img_tmp is None and abs(idx) < len(img):
idx -= 1
img_tmp = img[idx]
img = img_tmp
if img is None:
img = torch.zeros((self.num_samples, *pl_module.img_dim))
else:
img = torch.nn.functional.interpolate(img, size=(pl_module.img_dim[-2],pl_module.img_dim[-1]))
img = torch.nn.functional.interpolate(img, size=(pl_module.img_dim[-2], pl_module.img_dim[-1]))
if len(img.size()) == 2:
img = img.view(self.num_samples, *pl_module.img_dim)

Expand All @@ -249,7 +253,8 @@ def interpolate_latent_space(self, pl_module: LightningModule, latent_dim: int)
def compute_gradient_penalty(D, real_samples, fake_samples, idx=None):
"""Calculates the gradient penalty loss for WGAN GP"""
# Random weight term for interpolation between real and fake samples
alpha = torch.rand((real_samples.size(0), 1, 1, 1), device=real_samples.device)# Tensor(np.random.random((real_samples.size(0), 1, 1, 1)),device=real_samples.device)
alpha = torch.rand((real_samples.size(0), 1, 1, 1),
device=real_samples.device) # Tensor(np.random.random((real_samples.size(0), 1, 1, 1)),device=real_samples.device)
# Get random interpolation between real and fake samples
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)

Expand Down

0 comments on commit 51d8559

Please sign in to comment.