-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathDeepDream.py
108 lines (86 loc) · 3.94 KB
/
DeepDream.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# As usual, a bit of setup
from __future__ import print_function
import time, os, json
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from lib.classifiers.squeezenet import SqueezeNet
from lib.utils.data_utils import load_tiny_imagenet, load_imagenet_val
from lib.utils.image_utils import preprocess_image, deprocess_image
from lib.utils.image_utils import SQUEEZENET_MEAN, SQUEEZENET_STD
from scipy.ndimage.filters import gaussian_filter1d
# Default plot appearance: 10x8-inch figures, nearest-neighbour image
# interpolation (pixels are not smoothed when displayed) and a grayscale
# colormap. plt.rc(group, **kw) assigns rcParams['<group>.<kw>'] for each kw.
plt.rc('figure', figsize=(10.0, 8.0))
plt.rc('image', interpolation='nearest', cmap='gray')
def get_session():
    """Return a new TF1 ``tf.Session`` with ``gpu_options.allow_growth`` on.

    With ``allow_growth`` enabled, GPU memory is allocated dynamically as the
    session needs it.
    See: https://www.tensorflow.org/tutorials/using_gpu#allowing_gpu_memory_growth
    """
    cfg = tf.ConfigProto()
    cfg.gpu_options.allow_growth = True
    return tf.Session(config=cfg)
# Start from an empty default graph so the model is built into a clean graph,
# then open the session the rest of the script runs everything in.
tf.reset_default_graph()
sess = get_session()

SAVE_PATH = 'lib/datasets/squeezenet.ckpt'
# NOTE(review): a "You need to download SqueezeNet!" existence check on
# SAVE_PATH was disabled here. TF checkpoints are often stored as
# '<prefix>.index' / '<prefix>.data-*' files, so a working check probably
# needs os.path.exists(SAVE_PATH + '.index') -- confirm before re-enabling.
model = SqueezeNet(sess=sess, save_path=SAVE_PATH)

# Five ImageNet validation images; class_names also supplies the plot titles
# used by create_class_visualization.
X_raw, y, class_names = load_imagenet_val(num=5)
X = np.array(list(map(preprocess_image, X_raw)))
#----------------------------Finish Setup----------------------------
def blur_image(X, sigma=1):
    """Smooth ``X`` with a 1-D Gaussian of std ``sigma`` along axes 1 and 2.

    The separable filter runs along axis 1 first, then along axis 2; every
    other axis is left untouched. For the (batch, H, W, C)-style arrays this
    script passes in (assumed layout), those are the two spatial axes.
    Returns a new array; the input is not modified in place.
    """
    blurred = X
    for axis in (1, 2):
        blurred = gaussian_filter1d(blurred, sigma, axis=axis)
    return blurred
def create_class_visualization(target_y, model, **kwargs):
    """
    Generate an image to maximize the score of target_y under a pretrained model.

    Runs gradient ascent on a random-noise image to increase the score
    ``model.classifier[0, target_y]``. Regularization comes from an explicit
    L2 penalty on the image, plus random jitter, clipping and periodic blur as
    implicit regularizers.

    Inputs:
    - target_y: Integer in the range [0, 1000) giving the index of the class
    - model: A pretrained CNN that will be used to generate the image; it must
      expose ``image`` (the input placeholder that is fed) and ``classifier``
      (the class-score tensor)

    Keyword arguments:
    - l2_reg: Strength of L2 regularization on the image (default 1e-3)
    - learning_rate: How big of a step to take (default 25)
    - num_iterations: How many iterations to use (default 100)
    - blur_every: How often to blur the image as an implicit regularizer
      (default 10; 0 disables blurring)
    - max_jitter: How much to jitter the image as an implicit regularizer;
      the maximum roll along axes 1 and 2 (default 16)
    - show_every: How often to show the intermediate result (default 25;
      0 shows only the first and the last iteration)
    - sess: TensorFlow session used to run the graph (default: the
      module-level ``sess``)
    - class_names: Indexable mapping class index -> display name used in plot
      titles (default: the module-level ``class_names``)

    Unrecognized keyword arguments are ignored.

    Returns:
    - X: The generated image as a batch of one, in the preprocessed space fed
      to ``model.image``; view it with ``deprocess_image(X[0])``.
    """
    l2_reg = kwargs.pop('l2_reg', 1e-3)
    learning_rate = kwargs.pop('learning_rate', 25)
    num_iterations = kwargs.pop('num_iterations', 100)
    blur_every = kwargs.pop('blur_every', 10)
    max_jitter = kwargs.pop('max_jitter', 16)
    show_every = kwargs.pop('show_every', 25)
    # Optional overrides so the function is not tied to this script's globals;
    # the module-level objects are only looked up when no override is given.
    session = kwargs.pop('sess', None)
    if session is None:
        session = sess
    names = kwargs.pop('class_names', None)
    if names is None:
        names = class_names

    # Start from uniform random noise in [0, 255), mapped into model input space.
    X = 255 * np.random.rand(224, 224, 3)
    X = preprocess_image(X)[None]

    # Class scores for the single image in the batch.
    scores = model.classifier[0]
    # Gradient of  score[target_y] - (l2_reg / 2) * ||image||^2  w.r.t. the
    # image: the L2 term contributes -l2_reg * image. Built once, reused below.
    grad = tf.gradients(model.classifier[0, target_y], model.image)[0] - l2_reg * model.image

    for t in range(num_iterations):
        # Randomly jitter the image a bit; this gives slightly nicer results
        ox, oy = np.random.randint(-max_jitter, max_jitter + 1, 2)
        X = np.roll(np.roll(X, ox, 1), oy, 2)

        # Fetch the scores and the gradient in one run so the network is
        # evaluated once per step instead of twice on the same input.
        score_val, grad_val = session.run([scores, grad], feed_dict={model.image: X})
        # Gradient-ascent step on the (jittered) image.
        X += learning_rate * grad_val
        print('step:%d,current_label_score:%f,target_label_score:%f' %
              (t, score_val.max(), score_val[target_y]))

        # Undo the jitter
        X = np.roll(np.roll(X, -ox, 1), -oy, 2)

        # As a regularizer, clip to the normalized images of pixel values 0
        # and 1 under (x - MEAN) / STD, and periodically blur.
        X = np.clip(X, -SQUEEZENET_MEAN / SQUEEZENET_STD,
                    (1.0 - SQUEEZENET_MEAN) / SQUEEZENET_STD)
        # A falsy blur_every (0) disables blurring instead of dividing by zero.
        if blur_every and t % blur_every == 0:
            X = blur_image(X, sigma=0.5)

        # Periodically show the image; first and last iterations always shown.
        if (t == 0
                or (show_every and (t + 1) % show_every == 0)
                or t == num_iterations - 1):
            plt.imshow(deprocess_image(X[0]))
            class_name = names[target_y]
            plt.title('%s\nIteration %d / %d' % (class_name, t + 1, num_iterations))
            plt.gcf().set_size_inches(4, 4)
            plt.axis('off')
            plt.show()
    return X
# ImageNet class index 366 is "gorilla"; run 200 ascent steps toward it.
target_y = 366
out = create_class_visualization(
    target_y,
    model,
    num_iterations=200,
)