diff --git a/USAGE.md b/USAGE.md
index e3b1ec4..b07f228 100644
--- a/USAGE.md
+++ b/USAGE.md
@@ -5,7 +5,7 @@
 Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
 
 ### Setup
-We only test our code in the following environment.
+We only tested our code in the following environment.
 - OS: Ubuntu 16.04
 - CUDA: 9.1
 - **Python 2 from Anaconda2**
diff --git a/demo.py b/demo.py
index 13371ee..9a93795 100644
--- a/demo.py
+++ b/demo.py
@@ -2,20 +2,11 @@
 Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
 Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
 """
-import os
-import torch
-from torch.autograd import Variable
-import torchvision.transforms as transforms
-import torchvision.utils as utils
+
 import argparse
-import time
-import numpy as np
-import cv2
-from PIL import Image
-from photo_wct import PhotoWCT
-from photo_smooth import Propagator
-from smooth_filter import smooth_filter
+
+import process_stylization
+from photo_wct import PhotoWCT
 
 parser = argparse.ArgumentParser(description='Photorealistic Image Stylization')
 parser.add_argument('--vgg1', default='./models/vgg_normalised_conv1_1_mask.t7', help='Path to the VGG conv1_1')
@@ -37,48 +28,13 @@
 
 # Load model
 p_wct = PhotoWCT(args)
-p_pro = Propagator()
 p_wct.cuda(0)
 
-content_image_path = args.content_image_path
-content_seg_path = args.content_seg_path
-style_image_path = args.style_image_path
-style_seg_path = args.style_seg_path
-output_image_path = args.output_image_path
-
-# Load image
-cont_img = Image.open(content_image_path).convert('RGB')
-styl_img = Image.open(style_image_path).convert('RGB')
-try:
-    cont_seg = Image.open(content_seg_path)
-    styl_seg = Image.open(style_seg_path)
-except:
-    cont_seg = []
-    styl_seg = []
-
-cont_img = transforms.ToTensor()(cont_img).unsqueeze(0)
-styl_img = transforms.ToTensor()(styl_img).unsqueeze(0)
-cont_img = Variable(cont_img.cuda(0), volatile=True)
-styl_img = Variable(styl_img.cuda(0), volatile=True)
-
-cont_seg = np.asarray(cont_seg)
-styl_seg = np.asarray(styl_seg)
-
-start_style_time = time.time()
-stylized_img = p_wct.transform(cont_img, styl_img, cont_seg, styl_seg)
-end_style_time = time.time()
-print('Elapsed time in stylization: %f' % (end_style_time - start_style_time))
-utils.save_image(stylized_img.data.cpu().float(), output_image_path, nrow=1)
-
-start_propagation_time = time.time()
-out_img = p_pro.process(output_image_path, content_image_path)
-end_propagation_time = time.time()
-print('Elapsed time in propagation: %f' % (end_propagation_time - start_propagation_time))
-cv2.imwrite(output_image_path, out_img)
-
-start_postprocessing_time = time.time()
-out_img = smooth_filter(output_image_path, content_image_path, f_radius=15, f_edge=1e-1)
-end_postprocessing_time = time.time()
-print('Elapsed time in post processing: %f' % (end_postprocessing_time - start_postprocessing_time))
-
-out_img.save(output_image_path)
+process_stylization.stylization(
+    p_wct=p_wct,
+    content_image_path=args.content_image_path,
+    style_image_path=args.style_image_path,
+    content_seg_path=args.content_seg_path,
+    style_seg_path=args.style_seg_path,
+    output_image_path=args.output_image_path,
+)
diff --git a/models.py b/models.py
index cd7d9f8..a029852 100644
--- a/models.py
+++ b/models.py
@@ -3,21 +3,20 @@
 Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
""" import torch.nn as nn -import torch class VGGEncoder1(nn.Module): def __init__(self, vgg1): super(VGGEncoder1, self).__init__() # 224 x 224 self.conv1 = nn.Conv2d(3, 3, 1, 1, 0) - self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float()) - self.conv1.bias = torch.nn.Parameter(vgg1.get(0).bias.float()) + self.conv1.weight = nn.Parameter(vgg1.get(0).weight.float()) + self.conv1.bias = nn.Parameter(vgg1.get(0).bias.float()) # 224 x 224 self.reflect_pad1 = nn.ReflectionPad2d((1, 1, 1, 1)) # 226 x 226 self.conv2 = nn.Conv2d(3, 64, 3, 1, 0) - self.conv2.weight = torch.nn.Parameter(vgg1.get(2).weight.float()) - self.conv2.bias = torch.nn.Parameter(vgg1.get(2).bias.float()) + self.conv2.weight = nn.Parameter(vgg1.get(2).weight.float()) + self.conv2.bias = nn.Parameter(vgg1.get(2).bias.float()) self.relu = nn.ReLU(inplace=True) # 224 x 224 @@ -35,8 +34,8 @@ def __init__(self, d1): self.reflect_pad2 = nn.ReflectionPad2d((1, 1, 1, 1)) # 226 x 226 self.conv3 = nn.Conv2d(64, 3, 3, 1, 0) - self.conv3.weight = torch.nn.Parameter(d1.get(1).weight.float()) - self.conv3.bias = torch.nn.Parameter(d1.get(1).bias.float()) + self.conv3.weight = nn.Parameter(d1.get(1).weight.float()) + self.conv3.bias = nn.Parameter(d1.get(1).bias.float()) # 224 x 224 def forward(self, x): @@ -50,21 +49,21 @@ def __init__(self, vgg): super(VGGEncoder2, self).__init__() # 224 x 224 self.conv1 = nn.Conv2d(3, 3, 1, 1, 0) - self.conv1.weight = torch.nn.Parameter(vgg.get(0).weight.float()) - self.conv1.bias = torch.nn.Parameter(vgg.get(0).bias.float()) + self.conv1.weight = nn.Parameter(vgg.get(0).weight.float()) + self.conv1.bias = nn.Parameter(vgg.get(0).bias.float()) self.reflect_pad2 = nn.ReflectionPad2d((1, 1, 1, 1)) # 226 x 226 self.conv2 = nn.Conv2d(3, 64, 3, 1, 0) - self.conv2.weight = torch.nn.Parameter(vgg.get(2).weight.float()) - self.conv2.bias = torch.nn.Parameter(vgg.get(2).bias.float()) + self.conv2.weight = nn.Parameter(vgg.get(2).weight.float()) + self.conv2.bias = nn.Parameter(vgg.get(2).bias.float()) self.relu2 = nn.ReLU(inplace=True) # 224 x 224 self.reflect_pad3 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv3 = nn.Conv2d(64, 64, 3, 1, 0) - self.conv3.weight = torch.nn.Parameter(vgg.get(5).weight.float()) - self.conv3.bias = torch.nn.Parameter(vgg.get(5).bias.float()) + self.conv3.weight = nn.Parameter(vgg.get(5).weight.float()) + self.conv3.bias = nn.Parameter(vgg.get(5).bias.float()) self.relu3 = nn.ReLU(inplace=True) # 224 x 224 @@ -73,8 +72,8 @@ def __init__(self, vgg): self.reflect_pad4 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv4 = nn.Conv2d(64, 128, 3, 1, 0) - self.conv4.weight = torch.nn.Parameter(vgg.get(9).weight.float()) - self.conv4.bias = torch.nn.Parameter(vgg.get(9).bias.float()) + self.conv4.weight = nn.Parameter(vgg.get(9).weight.float()) + self.conv4.bias = nn.Parameter(vgg.get(9).bias.float()) self.relu4 = nn.ReLU(inplace=True) # 112 x 112 @@ -99,8 +98,8 @@ def __init__(self, d): # decoder self.reflect_pad5 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv5 = nn.Conv2d(128, 64, 3, 1, 0) - self.conv5.weight = torch.nn.Parameter(d.get(1).weight.float()) - self.conv5.bias = torch.nn.Parameter(d.get(1).bias.float()) + self.conv5.weight = nn.Parameter(d.get(1).weight.float()) + self.conv5.bias = nn.Parameter(d.get(1).bias.float()) self.relu5 = nn.ReLU(inplace=True) # 112 x 112 @@ -110,15 +109,15 @@ def __init__(self, d): self.reflect_pad6 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv6 = nn.Conv2d(64, 64, 3, 1, 0) - self.conv6.weight = torch.nn.Parameter(d.get(5).weight.float()) - 
self.conv6.bias = torch.nn.Parameter(d.get(5).bias.float()) + self.conv6.weight = nn.Parameter(d.get(5).weight.float()) + self.conv6.bias = nn.Parameter(d.get(5).bias.float()) self.relu6 = nn.ReLU(inplace=True) # 224 x 224 self.reflect_pad7 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv7 = nn.Conv2d(64, 3, 3, 1, 0) - self.conv7.weight = torch.nn.Parameter(d.get(8).weight.float()) - self.conv7.bias = torch.nn.Parameter(d.get(8).bias.float()) + self.conv7.weight = nn.Parameter(d.get(8).weight.float()) + self.conv7.bias = nn.Parameter(d.get(8).bias.float()) def forward(self, x, pool_idx, pool): out = self.reflect_pad5(x) @@ -138,21 +137,21 @@ def __init__(self, vgg): super(VGGEncoder3, self).__init__() # 224 x 224 self.conv1 = nn.Conv2d(3, 3, 1, 1, 0) - self.conv1.weight = torch.nn.Parameter(vgg.get(0).weight.float()) - self.conv1.bias = torch.nn.Parameter(vgg.get(0).bias.float()) + self.conv1.weight = nn.Parameter(vgg.get(0).weight.float()) + self.conv1.bias = nn.Parameter(vgg.get(0).bias.float()) self.reflect_pad1 = nn.ReflectionPad2d((1, 1, 1, 1)) # 226 x 226 self.conv2 = nn.Conv2d(3, 64, 3, 1, 0) - self.conv2.weight = torch.nn.Parameter(vgg.get(2).weight.float()) - self.conv2.bias = torch.nn.Parameter(vgg.get(2).bias.float()) + self.conv2.weight = nn.Parameter(vgg.get(2).weight.float()) + self.conv2.bias = nn.Parameter(vgg.get(2).bias.float()) self.relu2 = nn.ReLU(inplace=True) # 224 x 224 self.reflect_pad3 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv3 = nn.Conv2d(64, 64, 3, 1, 0) - self.conv3.weight = torch.nn.Parameter(vgg.get(5).weight.float()) - self.conv3.bias = torch.nn.Parameter(vgg.get(5).bias.float()) + self.conv3.weight = nn.Parameter(vgg.get(5).weight.float()) + self.conv3.bias = nn.Parameter(vgg.get(5).bias.float()) self.relu3 = nn.ReLU(inplace=True) # 224 x 224 @@ -161,15 +160,15 @@ def __init__(self, vgg): self.reflect_pad4 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv4 = nn.Conv2d(64, 128, 3, 1, 0) - self.conv4.weight = torch.nn.Parameter(vgg.get(9).weight.float()) - self.conv4.bias = torch.nn.Parameter(vgg.get(9).bias.float()) + self.conv4.weight = nn.Parameter(vgg.get(9).weight.float()) + self.conv4.bias = nn.Parameter(vgg.get(9).bias.float()) self.relu4 = nn.ReLU(inplace=True) # 112 x 112 self.reflect_pad5 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv5 = nn.Conv2d(128, 128, 3, 1, 0) - self.conv5.weight = torch.nn.Parameter(vgg.get(12).weight.float()) - self.conv5.bias = torch.nn.Parameter(vgg.get(12).bias.float()) + self.conv5.weight = nn.Parameter(vgg.get(12).weight.float()) + self.conv5.bias = nn.Parameter(vgg.get(12).bias.float()) self.relu5 = nn.ReLU(inplace=True) # 112 x 112 @@ -178,8 +177,8 @@ def __init__(self, vgg): self.reflect_pad6 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv6 = nn.Conv2d(128, 256, 3, 1, 0) - self.conv6.weight = torch.nn.Parameter(vgg.get(16).weight.float()) - self.conv6.bias = torch.nn.Parameter(vgg.get(16).bias.float()) + self.conv6.weight = nn.Parameter(vgg.get(16).weight.float()) + self.conv6.bias = nn.Parameter(vgg.get(16).bias.float()) self.relu6 = nn.ReLU(inplace=True) # 56 x 56 @@ -211,8 +210,8 @@ def __init__(self, d): # decoder self.reflect_pad7 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv7 = nn.Conv2d(256, 128, 3, 1, 0) - self.conv7.weight = torch.nn.Parameter(d.get(1).weight.float()) - self.conv7.bias = torch.nn.Parameter(d.get(1).bias.float()) + self.conv7.weight = nn.Parameter(d.get(1).weight.float()) + self.conv7.bias = nn.Parameter(d.get(1).bias.float()) self.relu7 = nn.ReLU(inplace=True) # 56 x 56 @@ -222,15 +221,15 @@ def 
__init__(self, d): self.reflect_pad8 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv8 = nn.Conv2d(128, 128, 3, 1, 0) - self.conv8.weight = torch.nn.Parameter(d.get(5).weight.float()) - self.conv8.bias = torch.nn.Parameter(d.get(5).bias.float()) + self.conv8.weight = nn.Parameter(d.get(5).weight.float()) + self.conv8.bias = nn.Parameter(d.get(5).bias.float()) self.relu8 = nn.ReLU(inplace=True) # 112 x 112 self.reflect_pad9 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv9 = nn.Conv2d(128, 64, 3, 1, 0) - self.conv9.weight = torch.nn.Parameter(d.get(8).weight.float()) - self.conv9.bias = torch.nn.Parameter(d.get(8).bias.float()) + self.conv9.weight = nn.Parameter(d.get(8).weight.float()) + self.conv9.bias = nn.Parameter(d.get(8).bias.float()) self.relu9 = nn.ReLU(inplace=True) self.unpool2 = nn.MaxUnpool2d(kernel_size=2, stride=2) @@ -238,14 +237,14 @@ def __init__(self, d): self.reflect_pad10 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv10 = nn.Conv2d(64, 64, 3, 1, 0) - self.conv10.weight = torch.nn.Parameter(d.get(12).weight.float()) - self.conv10.bias = torch.nn.Parameter(d.get(12).bias.float()) + self.conv10.weight = nn.Parameter(d.get(12).weight.float()) + self.conv10.bias = nn.Parameter(d.get(12).bias.float()) self.relu10 = nn.ReLU(inplace=True) self.reflect_pad11 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv11 = nn.Conv2d(64, 3, 3, 1, 0) - self.conv11.weight = torch.nn.Parameter(d.get(15).weight.float()) - self.conv11.bias = torch.nn.Parameter(d.get(15).bias.float()) + self.conv11.weight = nn.Parameter(d.get(15).weight.float()) + self.conv11.bias = nn.Parameter(d.get(15).bias.float()) def forward(self, x, pool_idx, pool1, pool_idx2, pool2): out = self.reflect_pad7(x) @@ -273,21 +272,21 @@ def __init__(self, vgg): # vgg # 224 x 224 self.conv1 = nn.Conv2d(3, 3, 1, 1, 0) - self.conv1.weight = torch.nn.Parameter(vgg.get(0).weight.float()) - self.conv1.bias = torch.nn.Parameter(vgg.get(0).bias.float()) + self.conv1.weight = nn.Parameter(vgg.get(0).weight.float()) + self.conv1.bias = nn.Parameter(vgg.get(0).bias.float()) self.reflect_pad1 = nn.ReflectionPad2d((1, 1, 1, 1)) # 226 x 226 self.conv2 = nn.Conv2d(3, 64, 3, 1, 0) - self.conv2.weight = torch.nn.Parameter(vgg.get(2).weight.float()) - self.conv2.bias = torch.nn.Parameter(vgg.get(2).bias.float()) + self.conv2.weight = nn.Parameter(vgg.get(2).weight.float()) + self.conv2.bias = nn.Parameter(vgg.get(2).bias.float()) self.relu2 = nn.ReLU(inplace=True) # 224 x 224 self.reflect_pad3 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv3 = nn.Conv2d(64, 64, 3, 1, 0) - self.conv3.weight = torch.nn.Parameter(vgg.get(5).weight.float()) - self.conv3.bias = torch.nn.Parameter(vgg.get(5).bias.float()) + self.conv3.weight = nn.Parameter(vgg.get(5).weight.float()) + self.conv3.bias = nn.Parameter(vgg.get(5).bias.float()) self.relu3 = nn.ReLU(inplace=True) # 224 x 224 @@ -296,15 +295,15 @@ def __init__(self, vgg): self.reflect_pad4 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv4 = nn.Conv2d(64, 128, 3, 1, 0) - self.conv4.weight = torch.nn.Parameter(vgg.get(9).weight.float()) - self.conv4.bias = torch.nn.Parameter(vgg.get(9).bias.float()) + self.conv4.weight = nn.Parameter(vgg.get(9).weight.float()) + self.conv4.bias = nn.Parameter(vgg.get(9).bias.float()) self.relu4 = nn.ReLU(inplace=True) # 112 x 112 self.reflect_pad5 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv5 = nn.Conv2d(128, 128, 3, 1, 0) - self.conv5.weight = torch.nn.Parameter(vgg.get(12).weight.float()) - self.conv5.bias = torch.nn.Parameter(vgg.get(12).bias.float()) + self.conv5.weight = 
nn.Parameter(vgg.get(12).weight.float()) + self.conv5.bias = nn.Parameter(vgg.get(12).bias.float()) self.relu5 = nn.ReLU(inplace=True) # 112 x 112 @@ -313,29 +312,29 @@ def __init__(self, vgg): self.reflect_pad6 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv6 = nn.Conv2d(128, 256, 3, 1, 0) - self.conv6.weight = torch.nn.Parameter(vgg.get(16).weight.float()) - self.conv6.bias = torch.nn.Parameter(vgg.get(16).bias.float()) + self.conv6.weight = nn.Parameter(vgg.get(16).weight.float()) + self.conv6.bias = nn.Parameter(vgg.get(16).bias.float()) self.relu6 = nn.ReLU(inplace=True) # 56 x 56 self.reflect_pad7 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv7 = nn.Conv2d(256, 256, 3, 1, 0) - self.conv7.weight = torch.nn.Parameter(vgg.get(19).weight.float()) - self.conv7.bias = torch.nn.Parameter(vgg.get(19).bias.float()) + self.conv7.weight = nn.Parameter(vgg.get(19).weight.float()) + self.conv7.bias = nn.Parameter(vgg.get(19).bias.float()) self.relu7 = nn.ReLU(inplace=True) # 56 x 56 self.reflect_pad8 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv8 = nn.Conv2d(256, 256, 3, 1, 0) - self.conv8.weight = torch.nn.Parameter(vgg.get(22).weight.float()) - self.conv8.bias = torch.nn.Parameter(vgg.get(22).bias.float()) + self.conv8.weight = nn.Parameter(vgg.get(22).weight.float()) + self.conv8.bias = nn.Parameter(vgg.get(22).bias.float()) self.relu8 = nn.ReLU(inplace=True) # 56 x 56 self.reflect_pad9 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv9 = nn.Conv2d(256, 256, 3, 1, 0) - self.conv9.weight = torch.nn.Parameter(vgg.get(25).weight.float()) - self.conv9.bias = torch.nn.Parameter(vgg.get(25).bias.float()) + self.conv9.weight = nn.Parameter(vgg.get(25).weight.float()) + self.conv9.bias = nn.Parameter(vgg.get(25).bias.float()) self.relu9 = nn.ReLU(inplace=True) # 56 x 56 @@ -344,8 +343,8 @@ def __init__(self, vgg): self.reflect_pad10 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv10 = nn.Conv2d(256, 512, 3, 1, 0) - self.conv10.weight = torch.nn.Parameter(vgg.get(29).weight.float()) - self.conv10.bias = torch.nn.Parameter(vgg.get(29).bias.float()) + self.conv10.weight = nn.Parameter(vgg.get(29).weight.float()) + self.conv10.bias = nn.Parameter(vgg.get(29).bias.float()) self.relu10 = nn.ReLU(inplace=True) # 28 x 28 @@ -424,8 +423,8 @@ def __init__(self, d): # decoder self.reflect_pad11 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv11 = nn.Conv2d(512, 256, 3, 1, 0) - self.conv11.weight = torch.nn.Parameter(d.get(1).weight.float()) - self.conv11.bias = torch.nn.Parameter(d.get(1).bias.float()) + self.conv11.weight = nn.Parameter(d.get(1).weight.float()) + self.conv11.bias = nn.Parameter(d.get(1).bias.float()) self.relu11 = nn.ReLU(inplace=True) # 28 x 28 @@ -434,29 +433,29 @@ def __init__(self, d): self.reflect_pad12 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv12 = nn.Conv2d(256, 256, 3, 1, 0) - self.conv12.weight = torch.nn.Parameter(d.get(5).weight.float()) - self.conv12.bias = torch.nn.Parameter(d.get(5).bias.float()) + self.conv12.weight = nn.Parameter(d.get(5).weight.float()) + self.conv12.bias = nn.Parameter(d.get(5).bias.float()) self.relu12 = nn.ReLU(inplace=True) # 56 x 56 self.reflect_pad13 = nn.ReflectionPad2d((1, 1, 1, 1)) self.conv13 = nn.Conv2d(256, 256, 3, 1, 0) - self.conv13.weight = torch.nn.Parameter(d.get(8).weight.float()) - self.conv13.bias = torch.nn.Parameter(d.get(8).bias.float()) + self.conv13.weight = nn.Parameter(d.get(8).weight.float()) + self.conv13.bias = nn.Parameter(d.get(8).bias.float()) self.relu13 = nn.ReLU(inplace=True) # 56 x 56 self.reflect_pad14 = nn.ReflectionPad2d((1, 1, 1, 1)) 
         self.conv14 = nn.Conv2d(256, 256, 3, 1, 0)
-        self.conv14.weight = torch.nn.Parameter(d.get(11).weight.float())
-        self.conv14.bias = torch.nn.Parameter(d.get(11).bias.float())
+        self.conv14.weight = nn.Parameter(d.get(11).weight.float())
+        self.conv14.bias = nn.Parameter(d.get(11).bias.float())
         self.relu14 = nn.ReLU(inplace=True)
         # 56 x 56
 
         self.reflect_pad15 = nn.ReflectionPad2d((1, 1, 1, 1))
         self.conv15 = nn.Conv2d(256, 128, 3, 1, 0)
-        self.conv15.weight = torch.nn.Parameter(d.get(14).weight.float())
-        self.conv15.bias = torch.nn.Parameter(d.get(14).bias.float())
+        self.conv15.weight = nn.Parameter(d.get(14).weight.float())
+        self.conv15.bias = nn.Parameter(d.get(14).bias.float())
         self.relu15 = nn.ReLU(inplace=True)
         # 56 x 56
@@ -465,15 +464,15 @@ def __init__(self, d):
 
         self.reflect_pad16 = nn.ReflectionPad2d((1, 1, 1, 1))
         self.conv16 = nn.Conv2d(128, 128, 3, 1, 0)
-        self.conv16.weight = torch.nn.Parameter(d.get(18).weight.float())
-        self.conv16.bias = torch.nn.Parameter(d.get(18).bias.float())
+        self.conv16.weight = nn.Parameter(d.get(18).weight.float())
+        self.conv16.bias = nn.Parameter(d.get(18).bias.float())
         self.relu16 = nn.ReLU(inplace=True)
         # 112 x 112
 
         self.reflect_pad17 = nn.ReflectionPad2d((1, 1, 1, 1))
         self.conv17 = nn.Conv2d(128, 64, 3, 1, 0)
-        self.conv17.weight = torch.nn.Parameter(d.get(21).weight.float())
-        self.conv17.bias = torch.nn.Parameter(d.get(21).bias.float())
+        self.conv17.weight = nn.Parameter(d.get(21).weight.float())
+        self.conv17.bias = nn.Parameter(d.get(21).bias.float())
         self.relu17 = nn.ReLU(inplace=True)
         # 112 x 112
@@ -482,15 +481,15 @@ def __init__(self, d):
 
         self.reflect_pad18 = nn.ReflectionPad2d((1, 1, 1, 1))
         self.conv18 = nn.Conv2d(64, 64, 3, 1, 0)
-        self.conv18.weight = torch.nn.Parameter(d.get(25).weight.float())
-        self.conv18.bias = torch.nn.Parameter(d.get(25).bias.float())
+        self.conv18.weight = nn.Parameter(d.get(25).weight.float())
+        self.conv18.bias = nn.Parameter(d.get(25).bias.float())
         self.relu18 = nn.ReLU(inplace=True)
         # 224 x 224
 
         self.reflect_pad19 = nn.ReflectionPad2d((1, 1, 1, 1))
         self.conv19 = nn.Conv2d(64, 3, 3, 1, 0)
-        self.conv19.weight = torch.nn.Parameter(d.get(28).weight.float())
-        self.conv19.bias = torch.nn.Parameter(d.get(28).bias.float())
+        self.conv19.weight = nn.Parameter(d.get(28).weight.float())
+        self.conv19.bias = nn.Parameter(d.get(28).bias.float())
 
     def forward(self, x, pool_idx, pool1, pool_idx2, pool2, pool_idx3, pool3):
         # decoder
diff --git a/photo_smooth.py b/photo_smooth.py
index 1c48ff8..221c4a3 100644
--- a/photo_smooth.py
+++ b/photo_smooth.py
@@ -3,15 +3,14 @@
 Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
""" from __future__ import division -import torch import torch.nn as nn import scipy.misc import numpy as np import scipy.sparse import scipy.sparse.linalg from numpy.lib.stride_tricks import as_strided -import cv2 +from PIL import Image class Propagator(nn.Module): def __init__(self, beta=0.9999): @@ -45,12 +44,10 @@ def process(self, initImg, contentImg): V[:,1] = solver(B[:,1]) V[:,2] = solver(B[:,2]) V = V*(1-self.beta) - V = np.reshape(V,(h1,w1,k)) - V[V > 1] = 1 - V[V < 0] = 0 + V = V.reshape(h1,w1,k) V = V[2:2+h,2:2+w,:] - img = np.uint8(V*255) - img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + img = Image.fromarray(np.uint8(np.clip(V * 255., 0, 255.))) return img # Returns sparse matting laplacian diff --git a/photo_wct.py b/photo_wct.py index cd0138a..16c2d78 100644 --- a/photo_wct.py +++ b/photo_wct.py @@ -6,8 +6,10 @@ import torch import torch.nn as nn from torch.utils.serialization import load_lua + import numpy as np -import cv2 + +from PIL import Image class PhotoWCT(nn.Module): def __init__(self, args): @@ -121,12 +123,12 @@ def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg): ccsF = target_feature.float().unsqueeze(0) return ccsF - t_cont_seg = cv2.resize(cont_seg, (cont_w, cont_h), interpolation = cv2.INTER_NEAREST) - t_styl_seg = cv2.resize(styl_seg, (styl_w, styl_h), interpolation = cv2.INTER_NEAREST) + t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST)) + t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST)) for l in self.label_set: if self.label_indicator[l]==0: - continue; + continue cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l) styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l) if cont_mask[0].size <= 0 or styl_mask[0].size <= 0 : @@ -184,8 +186,4 @@ def __wct_core(self, cont_feat, styl_feat): return targetFeature def __large_dff(self, a, b): - if (a / b >= 100): - return True - if (b / a >= 100): - return True - return False \ No newline at end of file + return a / b >= 100 or b / a >= 100 \ No newline at end of file diff --git a/process_stylization.py b/process_stylization.py new file mode 100644 index 0000000..edd56a0 --- /dev/null +++ b/process_stylization.py @@ -0,0 +1,64 @@ +""" +Copyright (C) 2018 NVIDIA Corporation. All rights reserved. +Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
+""" + +from __future__ import print_function + +import time + +import numpy as np +from PIL import Image +from torch.autograd import Variable +import torchvision.transforms as transforms +import torchvision.utils as utils + +from photo_smooth import Propagator +from smooth_filter import smooth_filter + +# Load Propagator +p_pro = Propagator() + + +class Timer: + def __init__(self, msg): + self.msg = msg + self.start_time = None + + def __enter__(self): + self.start_time = time.time() + + def __exit__(self, exc_type, exc_value, exc_tb): + print(self.msg % (time.time() - self.start_time)) + + +def stylization(p_wct, content_image_path, style_image_path, content_seg_path, style_seg_path, output_image_path): + # Load image + cont_img = Image.open(content_image_path).convert('RGB') + styl_img = Image.open(style_image_path).convert('RGB') + try: + cont_seg = Image.open(content_seg_path) + styl_seg = Image.open(style_seg_path) + except: + cont_seg = [] + styl_seg = [] + + cont_img = transforms.ToTensor()(cont_img).unsqueeze(0) + styl_img = transforms.ToTensor()(styl_img).unsqueeze(0) + cont_img = Variable(cont_img.cuda(0), volatile=True) + styl_img = Variable(styl_img.cuda(0), volatile=True) + + cont_seg = np.asarray(cont_seg) + styl_seg = np.asarray(styl_seg) + + with Timer("Elapsed time in stylization: %f"): + stylized_img = p_wct.transform(cont_img, styl_img, cont_seg, styl_seg) + utils.save_image(stylized_img.data.cpu().float(), output_image_path, nrow=1) + + with Timer("Elapsed time in propagation: %f"): + out_img = p_pro.process(output_image_path, content_image_path) + out_img.save(output_image_path) + + with Timer("Elapsed time in post processing: %f"): + out_img = smooth_filter(output_image_path, content_image_path, f_radius=15, f_edge=1e-1) + out_img.save(output_image_path) diff --git a/process_stylization_examples.py b/process_stylization_examples.py index a024cf7..fe24c21 100644 --- a/process_stylization_examples.py +++ b/process_stylization_examples.py @@ -3,18 +3,10 @@ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
""" import os -import torch -from torch.autograd import Variable -import torchvision.transforms as transforms -import torchvision.utils as utils import argparse -import time -import numpy as np -import cv2 -from PIL import Image + from photo_wct import PhotoWCT -from photo_smooth import Propagator -from smooth_filter import smooth_filter +import process_stylization parser = argparse.ArgumentParser(description='Photorealistic Image Stylization') parser.add_argument('--vgg1', default='./models/vgg_normalised_conv1_1_mask.t7', help='Path to the VGG conv1_1') @@ -30,61 +22,32 @@ args = parser.parse_args() folder = 'examples' -cont_img_folder = os.path.join(folder,'content_img') -cont_seg_folder = os.path.join(folder,'content_seg') -styl_img_folder = os.path.join(folder,'style_img') -styl_seg_folder = os.path.join(folder,'style_seg') -outp_img_folder = os.path.join(folder,'results') +cont_img_folder = os.path.join(folder, 'content_img') +cont_seg_folder = os.path.join(folder, 'content_seg') +styl_img_folder = os.path.join(folder, 'style_img') +styl_seg_folder = os.path.join(folder, 'style_seg') +outp_img_folder = os.path.join(folder, 'results') cont_img_list = [f for f in os.listdir(cont_img_folder) if os.path.isfile(os.path.join(cont_img_folder, f))] cont_img_list.sort() # Load model p_wct = PhotoWCT(args) -p_pro = Propagator() p_wct.cuda(0) for f in cont_img_list: - print("Process " + f) - - content_image_path = os.path.join(cont_img_folder, f) - content_seg_path = os.path.join(cont_seg_folder, f).replace(".png",".pgm") - style_image_path = os.path.join(styl_img_folder, f) - style_seg_path = os.path.join(styl_seg_folder, f).replace(".png",".pgm") - output_image_path = os.path.join(outp_img_folder, f) - - # Load image - cont_img = Image.open(content_image_path).convert('RGB') - styl_img = Image.open(style_image_path).convert('RGB') - try: - cont_seg = Image.open(content_seg_path) - styl_seg = Image.open(style_seg_path) - except: - cont_seg = [] - styl_seg = [] - - cont_img = transforms.ToTensor()(cont_img).unsqueeze(0) - styl_img = transforms.ToTensor()(styl_img).unsqueeze(0) - cont_img = Variable(cont_img.cuda(0), volatile=True) - styl_img = Variable(styl_img.cuda(0), volatile=True) - - cont_seg = np.asarray(cont_seg) - styl_seg = np.asarray(styl_seg) - - start_style_time = time.time() - stylized_img = p_wct.transform(cont_img, styl_img, cont_seg, styl_seg) - end_style_time = time.time() - print('Elapsed time in stylization: %f' % (end_style_time - start_style_time)) - utils.save_image(stylized_img.data.cpu().float(), output_image_path, nrow=1) - - start_propagation_time = time.time() - out_img = p_pro.process(output_image_path, content_image_path) - end_propagation_time = time.time() - print('Elapsed time in propagation: %f' % (end_propagation_time - start_propagation_time)) - cv2.imwrite(output_image_path, out_img) - - start_postprocessing_time = time.time() - out_img = smooth_filter(output_image_path, content_image_path, f_radius=15,f_edge=1e-1) - end_postprocessing_time = time.time() - print('Elapsed time in post processing: %f' % (end_postprocessing_time - start_postprocessing_time)) - - out_img.save(output_image_path) \ No newline at end of file + print("Process " + f) + + content_image_path = os.path.join(cont_img_folder, f) + content_seg_path = os.path.join(cont_seg_folder, f).replace(".png", ".pgm") + style_image_path = os.path.join(styl_img_folder, f) + style_seg_path = os.path.join(styl_seg_folder, f).replace(".png", ".pgm") + output_image_path = os.path.join(outp_img_folder, f) + 
+    process_stylization.stylization(
+        p_wct=p_wct,
+        content_image_path=content_image_path,
+        style_image_path=style_image_path,
+        content_seg_path=content_seg_path,
+        style_seg_path=style_seg_path,
+        output_image_path=output_image_path,
+    )
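
After this change, every entry point reduces to building a `PhotoWCT` model and handing paths to `process_stylization.stylization()`, which runs the three timed stages (stylization, propagation, post-processing) and overwrites the file at `output_image_path` after each stage. A minimal caller sketch follows; it mirrors the new demo.py, but the image paths are made up for illustration, and `args` must still carry the full set of model-path attributes that demo.py's parser defines (only `--vgg1` is visible in this diff).

```python
# Hypothetical driver for the refactored pipeline (illustration, not part of the patch).
import argparse

import process_stylization
from photo_wct import PhotoWCT

parser = argparse.ArgumentParser(description='Photorealistic Image Stylization')
parser.add_argument('--vgg1', default='./models/vgg_normalised_conv1_1_mask.t7', help='Path to the VGG conv1_1')
# ... the remaining model-path arguments from demo.py go here ...
args = parser.parse_args()

p_wct = PhotoWCT(args)  # loads the Torch7 .t7 weights named in args
p_wct.cuda(0)           # stylization() moves inputs to GPU 0, so the model must live there too

# Paths below are hypothetical. Segmentation maps are optional: stylization()
# falls back to [] when the segmentation files cannot be opened.
process_stylization.stylization(
    p_wct=p_wct,
    content_image_path='examples/content_img/in1.png',
    style_image_path='examples/style_img/in1.png',
    content_seg_path='examples/content_seg/in1.pgm',
    style_seg_path='examples/style_seg/in1.pgm',
    output_image_path='examples/results/in1.png',
)
```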