import numpy as np
import scipy.sparse
import tensorflow as tf

import lib_gcnn.graph as graph

class GraphCNN(object):
    """
    A graph CNN for text classification, composed of graph convolutional + max-pooling
    layer(s) and a softmax layer.

    filter_name = Filter name (i.e. "chebyshev", "spline" or "fourier").
    L  = List of graph Laplacians.
    K  = List of filter sizes, i.e. support sizes (no. of hops).
         (Polynomial orders for Chebyshev; K[i] = L[i].shape[0] for non-param Fourier.)
    F  = List of no. of features (per filter).
    P  = List of pooling sizes (per filter).
    FC = List of fully-connected layers.

    Paper for the Chebyshev filter: https://arxiv.org/abs/1606.09375
    Paper for the spline filter: https://arxiv.org/abs/1312.6203
    Code adapted from https://github.com/mdeff/cnn_graph
    """
    def __init__(self, filter_name, L, K, F, P, FC, batch_size, num_vertices, num_classes, l2_reg_lambda):
        # Sanity checks
        assert len(L) >= len(F) == len(K) == len(P)  # verify consistency w.r.t. the no. of GCLs
        assert np.all(np.array(P) >= 1)  # all pool sizes >= 1
        p_log2 = np.where(np.array(P) > 1, np.log2(P), 0)
        assert np.all(np.mod(p_log2, 1) == 0)  # all pool sizes > 1 should be powers of 2
        assert len(L) >= 1 + np.sum(p_log2)  # enough coarsening levels for the pool sizes
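        # e.g. (worked example) P = [4, 2] gives p_log2 = [2, 1], so the
        # assertion above requires len(L) >= 1 + 3 = 4 coarsening levels.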
        # Retrieve the convolutional filter
        assert filter_name in ("chebyshev", "spline", "fourier")
        self.graph_conv = getattr(self, "graph_conv_" + filter_name)

        # Placeholders for input, output, train flag and dropout
        self.input_x = tf.placeholder(tf.float32, [batch_size, num_vertices], name="input_x")
        self.input_y = tf.placeholder(tf.int32, [batch_size], name="input_y")
        self.train_flag = tf.placeholder(tf.bool, name="train_flag")
        self.dropout_keep_prob = tf.placeholder_with_default(1.0, shape=[], name="dropout_keep_prob")

        # Keep track of the L2 regularization loss
        l2_loss = tf.constant(0.0)

        # Keep only the Laplacians actually used: pooling by p_i consumes
        # log2(p_i) coarsening levels, so the intermediate ones are skipped.
        j = 0
        L_tmp = []
        for p_i in P:
            L_tmp.append(L[j])
            j += int(np.log2(p_i)) if p_i > 1 else 0
        L = L_tmp
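        # e.g. (continuing the worked example) P = [4, 2] keeps L[0] and L[2]:
        # pooling by 4 halves the graph twice, so the second conv layer operates
        # on the twice-coarsened Laplacian.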
        # Expand dims for the convolution operation
        x = tf.expand_dims(self.input_x, 2)  # B x V x F=1

        # Graph convolution + max-pooling layer(s)
        for i in range(len(K)):
            with tf.variable_scope("conv-maxpool-{}".format(i)):
                with tf.variable_scope("conv-{}-{}".format(K[i], F[i])):
                    # Graph convolution operation
                    x = self.graph_conv(x, L[i], K[i], F[i])
                    # Add bias & apply non-linearity
                    b = tf.Variable(tf.constant(0.1, shape=[1, 1, F[i]]), name="b")
                    x = tf.nn.relu(x + b, name="relu")
                with tf.variable_scope("maxpool-{}".format(P[i])):
                    # Graph max-pooling operation
                    x = self.graph_max_pool(x, P[i])

        # Add dropout
        with tf.variable_scope("dropout"):
            x = tf.nn.dropout(x, self.dropout_keep_prob)

        # Reshape x for the fully-connected layers
        with tf.variable_scope("reshape"):
            B, V, F = x.get_shape()
            B, V, F = int(B), int(V), int(F)
            x = tf.reshape(x, [B, V * F])
        # Add fully-connected layers (if any)
        for i, num_units in enumerate(FC):
            with tf.variable_scope("fc-{}-{}".format(i, num_units)):
                W = tf.get_variable("W",
                                    shape=[x.get_shape().as_list()[1], num_units],
                                    initializer=tf.contrib.layers.xavier_initializer())
                b = tf.Variable(tf.constant(0.1, shape=[num_units]), name="b")
                l2_loss += tf.nn.l2_loss(W)
                x = tf.nn.xw_plus_b(x, W, b)
                x = tf.layers.batch_normalization(x, training=self.train_flag)
                x = tf.nn.relu(x)
                x = tf.nn.dropout(x, self.dropout_keep_prob)
        # Final (unnormalized) scores and predictions
        with tf.variable_scope("output"):
            W = tf.get_variable("W",
                                shape=[x.get_shape().as_list()[1], num_classes],
                                initializer=tf.contrib.layers.xavier_initializer())
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            self.scores = tf.nn.xw_plus_b(x, W, b, name="scores")
            self.predictions = tf.argmax(self.scores, 1, name="predictions")
            self.predictions = tf.cast(self.predictions, tf.int32)

        # Calculate the mean cross-entropy loss
        with tf.variable_scope("loss"):
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        # Calculate accuracy
        with tf.variable_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, self.input_y)
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
    def graph_conv_chebyshev(self, x, L, K, F_out):
        """
        Graph convolution with Chebyshev polynomial filters (https://arxiv.org/abs/1606.09375).
        Avoids an explicit Fourier transform; costs K sparse matrix-vector products instead.

        K     = Chebyshev polynomial order & support size
        F_out = No. of output features (per vertex)
        B     = Batch size
        V     = No. of vertices
        F_in  = No. of input features (per vertex)
        """
        B, V, F_in = x.get_shape()
        B, V, F_in = int(B), int(V), int(F_in)

        # Rescale the Laplacian and store it as a TF sparse tensor
        # (copy first so the shared L is not modified)
        L = scipy.sparse.csr_matrix(L)
        L = graph.rescale_L(L, lmax=2)
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        L = tf.SparseTensor(indices, L.data, L.shape)
        L = tf.sparse_reorder(L)
        L = tf.cast(L, tf.float32)

        # Transform to the Chebyshev basis via the recurrence
        # T_0(L)x = x, T_1(L)x = Lx, T_k(L)x = 2L T_{k-1}(L)x - T_{k-2}(L)x
        x0 = tf.transpose(x, perm=[1, 2, 0])  # V x F_in x B
        x0 = tf.reshape(x0, [V, F_in * B])  # V x F_in*B
        x = tf.expand_dims(x0, 0)  # 1 x V x F_in*B

        def concat(x, x_):
            x_ = tf.expand_dims(x_, 0)  # 1 x V x F_in*B
            return tf.concat([x, x_], axis=0)  # K x V x F_in*B

        if K > 1:
            x1 = tf.sparse_tensor_dense_matmul(L, x0)
            x = concat(x, x1)
        for k in range(2, K):
            x2 = 2 * tf.sparse_tensor_dense_matmul(L, x1) - x0  # V x F_in*B
            x = concat(x, x2)
            x0, x1 = x1, x2
        x = tf.reshape(x, [K, V, F_in, B])  # K x V x F_in x B
        x = tf.transpose(x, perm=[3, 1, 2, 0])  # B x V x F_in x K
        x = tf.reshape(x, [B * V, F_in * K])  # B*V x F_in*K

        # Compose the F_in features linearly to get F_out features
        W = tf.Variable(tf.truncated_normal([F_in * K, F_out], stddev=0.1), name="W")
        x = tf.matmul(x, W)  # B*V x F_out
        x = tf.reshape(x, [B, V, F_out])  # B x V x F_out
        return x
    def graph_conv_spline(self, x, L, K, F_out):
        """
        Graph convolution with spline filters (https://arxiv.org/abs/1312.6203).
        """
        B, V, F_in = x.get_shape()
        B, V, F_in = int(B), int(V), int(F_in)

        # Fourier basis
        lamb, U = graph.fourier(L)
        U = tf.constant(U.T, dtype=tf.float32)  # V x V

        # Spline basis
        basis = self.bspline_basis(K, lamb, degree=3)  # V x K
        basis = tf.constant(basis, dtype=tf.float32)

        # Weight multiplication
        W = tf.Variable(tf.truncated_normal([K, F_out * F_in], stddev=0.1), name="W")
        W = tf.matmul(basis, W)  # V x F_out*F_in
        W = tf.reshape(W, [V, F_out, F_in])
        return self.filter_in_fourier(x, L, K, F_out, U, W)
    def graph_conv_fourier(self, x, L, K, F_out):
        """
        Graph convolution with non-parametric Fourier filters (https://arxiv.org/abs/1312.6203).
        """
        assert K == L.shape[0]  # artificial but useful to compute the number of parameters

        B, V, F_in = x.get_shape()
        B, V, F_in = int(B), int(V), int(F_in)

        # Fourier basis
        _, U = graph.fourier(L)
        U = tf.constant(U.T, dtype=tf.float32)

        # Weights (one filter coefficient per frequency)
        W = tf.Variable(tf.truncated_normal([V, F_out, F_in], stddev=0.1), name="W")
        return self.filter_in_fourier(x, L, K, F_out, U, W)
    def graph_max_pool(self, x, p):
        """
        Graph max-pooling operation. p must be 1 or a power of 2.
        """
        if p > 1:
            x = tf.expand_dims(x, 3)  # B x V x F x 1
            x = tf.nn.max_pool(x, ksize=[1, p, 1, 1], strides=[1, p, 1, 1], padding="SAME")
            return tf.squeeze(x, [3])  # B x V/p x F
        else:
            return x
    def filter_in_fourier(self, x, L, K, F_out, U, W):
        """
        Filter x in the Fourier domain: project onto the Fourier basis U, apply
        the per-frequency weights W, then transform back to the graph domain.
        """
        B, V, F_in = x.get_shape()
        B, V, F_in = int(B), int(V), int(F_in)
        x = tf.transpose(x, perm=[1, 2, 0])  # V x F_in x B

        # Transform to the Fourier domain
        x = tf.reshape(x, [V, F_in * B])  # V x F_in*B
        x = tf.matmul(U, x)  # V x F_in*B
        x = tf.reshape(x, [V, F_in, B])  # V x F_in x B

        # Filter (batched matmul over the V frequencies)
        x = tf.matmul(W, x)  # V x F_out x B
        x = tf.transpose(x)  # B x F_out x V
        x = tf.reshape(x, [B * F_out, V])  # B*F_out x V

        # Transform back to the graph domain
        x = tf.matmul(x, U)  # B*F_out x V
        x = tf.reshape(x, [B, F_out, V])  # B x F_out x V
        return tf.transpose(x, perm=[0, 2, 1])  # B x V x F_out
    def bspline_basis(self, K, x, degree=3):
        """
        Return the B-spline basis.

        K: Number of control points.
        x: Evaluation points, or number of evenly distributed evaluation points.
        degree: Degree of the spline. Cubic spline by default.
        """
        if np.isscalar(x):
            x = np.linspace(0, 1, x)

        # Evenly distributed, clamped knot vector (end knots repeated `degree` times)
        kv1 = x.min() * np.ones(degree)
        kv2 = np.linspace(x.min(), x.max(), K - degree + 1)
        kv3 = x.max() * np.ones(degree)
        kv = np.concatenate((kv1, kv2, kv3))

        # Cox-de Boor recursion to compute one spline over x
        def cox_deboor(k, d):
            # Base case: the rectangular degree-zero spline
            if d == 0:
                return ((x - kv[k] >= 0) & (x - kv[k + 1] < 0)).astype(int)
            denom1 = kv[k + d] - kv[k]
            term1 = 0
            if denom1 > 0:
                term1 = ((x - kv[k]) / denom1) * cox_deboor(k, d - 1)
            denom2 = kv[k + d + 1] - kv[k + 1]
            term2 = 0
            if denom2 > 0:
                term2 = (-(x - kv[k + d + 1]) / denom2) * cox_deboor(k + 1, d - 1)
            return term1 + term2

        # Compute the basis for each point
        basis = np.column_stack([cox_deboor(k, degree) for k in range(K)])
        basis[-1, -1] = 1  # by convention the last basis function is 1 at x.max()
        return basis
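

if __name__ == "__main__":
    # Minimal smoke-test sketch (my assumption, not part of the original file):
    # build the model on a random 8-vertex graph with a single Chebyshev layer
    # and no pooling (P=[1]), so a single Laplacian suffices and no graph
    # coarsening helpers are needed. Running this only constructs the TF graph;
    # training would additionally need a session, an optimizer and feed dicts.
    num_vertices = 8
    A = scipy.sparse.random(num_vertices, num_vertices, density=0.5, format="csr")
    A = A + A.T  # symmetric adjacency with non-negative weights
    d = np.asarray(A.sum(axis=1)).flatten()
    d_inv_sqrt = np.power(np.maximum(d, 1e-12), -0.5)
    D_inv_sqrt = scipy.sparse.diags(d_inv_sqrt)
    # Normalized Laplacian, eigenvalues in [0, 2] as assumed by rescale_L(lmax=2)
    laplacian = scipy.sparse.identity(num_vertices) - D_inv_sqrt.dot(A).dot(D_inv_sqrt)
    cnn = GraphCNN(filter_name="chebyshev", L=[laplacian.tocsr()], K=[3], F=[4],
                   P=[1], FC=[], batch_size=2, num_vertices=num_vertices,
                   num_classes=2, l2_reg_lambda=0.1)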