# test_model.py
# Runs a saved image-enhancement model over full-size test photos and writes the results to visual_results/.
from scipy import misc
import numpy as np
import tensorflow as tf
from models import resnet
import utils
import os
import sys
import time
# Parse the command-line options that select the model and the test data.
(phone, dped_dir, test_subset,
 iteration, resolution, use_gpu) = utils.process_test_model_args(sys.argv)

# Fetch the table of available resolutions, then the height / width / size
# values for the requested phone and resolution.
res_sizes = utils.get_resolutions()
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_SIZE = utils.get_specified_res(res_sizes, phone, resolution)

# Expose zero GPU devices to TensorFlow only when use_gpu is exactly the
# string "false"; otherwise no explicit session config is passed (None).
if use_gpu == "false":
    config = tf.ConfigProto(device_count={'GPU': 0})
else:
    config = None

# Graph input: rows of IMAGE_SIZE floats, reshaped into
# [batch, IMAGE_HEIGHT, IMAGE_WIDTH, 3] images for the network.
x_ = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE])
x_image = tf.reshape(x_, shape=[-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])

# Graph output: the enhanced image produced by the resnet model.
enhanced = resnet(x_image)
def _enhance_photo(sess, photo_path):
    """Run one test photo through the network and return the enhanced image.

    Args:
        sess: TensorFlow session into which the model weights have already
            been restored.
        photo_path: path of the image file to process.

    Returns:
        The network output reshaped to [IMAGE_HEIGHT, IMAGE_WIDTH, 3].
    """
    # NOTE: misc.imread / misc.imresize / misc.imsave have been removed from
    # recent SciPy releases; this script needs a SciPy version that still
    # ships them.
    # Resize to the phone's entry in res_sizes, convert to float16 and divide by 255.
    image = np.float16(misc.imresize(misc.imread(photo_path), res_sizes[phone])) / 255

    # Crop to the requested resolution if necessary, then flatten to the
    # [1, IMAGE_SIZE] row expected by the x_ placeholder.
    image_crop = utils.extract_crop(image, resolution, phone, res_sizes)
    image_crop_2d = np.reshape(image_crop, [1, IMAGE_SIZE])

    enhanced_2d = sess.run(enhanced, feed_dict={x_: image_crop_2d})
    return np.reshape(enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])


with tf.Session(config=config) as sess:

    phone_name = phone.replace("_orig", "")
    test_dir = dped_dir + phone_name + "/test_data/full_size_test_images/"

    # os.listdir returns entries in arbitrary order; sort them so that the
    # "small" subset and the processing order are reproducible between runs.
    test_photos = sorted(f for f in os.listdir(test_dir) if os.path.isfile(test_dir + f))

    if test_subset == "small":
        # use the first five images only
        test_photos = test_photos[0:5]

    # A single Saver is enough for every restore below; constructing a new
    # one per checkpoint would keep adding save/restore ops to the graph.
    saver = tf.train.Saver()

    if phone.endswith("_orig"):

        # load the pre-trained original model
        saver.restore(sess, "models_orig/" + phone)

        for photo in test_photos:

            print("Testing original " + phone_name + " model, processing image " + photo)
            enhanced_image = _enhance_photo(sess, test_dir + photo)

            # save the result as a .jpg image
            photo_name = photo.rsplit(".", 1)[0]
            misc.imsave("visual_results/" + phone + "_" + photo_name + "_enhanced.jpg", enhanced_image)

    else:

        # The file count is halved, i.e. each saved model is presumably
        # stored as two files in models/ (TODO confirm against the training script).
        num_saved_models = int(len([f for f in os.listdir("models/") if f.startswith(phone + "_iteration")]) / 2)

        if iteration == "all":
            # iterations 1000, 2000, ..., (num_saved_models - 1) * 1000
            iterations = np.arange(1, num_saved_models) * 1000
        else:
            iterations = [int(iteration)]

        for i in iterations:

            # load the model saved at this iteration
            saver.restore(sess, "models/" + phone + "_iteration_" + str(i) + ".ckpt")

            for photo in test_photos:

                # time.clock() was removed in Python 3.8 and measured CPU
                # time on Unix; perf_counter() gives elapsed wall-clock time.
                start = time.perf_counter()

                print("iteration " + str(i) + ", processing image " + photo)
                enhanced_image = _enhance_photo(sess, test_dir + photo)
                photo_name = photo.rsplit(".", 1)[0]

                # covers image loading, cropping and the forward pass
                elapsed = time.perf_counter() - start
                print("Time used:", elapsed)

                # save the result as a .jpg image
                misc.imsave("visual_results/" + phone + "_" + photo_name + "_iteration_" + str(i) + "_enhanced.jpg", enhanced_image)