demo_distributed_gcn.py
# coding=utf-8
import os

# Restrict TensorFlow to these GPU ids; this must be set before TensorFlow is imported.
# MirroredStrategy below will replicate the model across all visible GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import tf_geometric as tfg
from tf_geometric.layers import GCN
from tensorflow.keras.regularizers import L1L2
import tensorflow as tf

# Load the Cora citation graph together with its train/validation/test node indices
graph, (train_index, valid_index, test_index) = tfg.datasets.CoraDataset().load_data()

num_classes = graph.y.max() + 1
drop_rate = 0.5
learning_rate = 1e-2
l2_coef = 5e-4


# custom network: a two-layer GCN with dropout before each layer
class GCNNetwork(tf.keras.Model):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.gcn0 = GCN(16, activation=tf.nn.relu, kernel_regularizer=L1L2(l2=l2_coef))
        self.gcn1 = GCN(num_classes, kernel_regularizer=L1L2(l2=l2_coef))
        self.dropout = tf.keras.layers.Dropout(drop_rate)

    def call(self, inputs, training=None, mask=None):
        x, edge_index = inputs
        h = self.dropout(x, training=training)
        h = self.gcn0([h, edge_index], training=training)
        h = self.dropout(h, training=training)
        h = self.gcn1([h, edge_index], training=training)
        return h
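
# Note (not in the original demo): each GCN layer consumes [node_features, edge_index],
# so call() returns one logit vector per node, i.e. shape [num_nodes, num_classes].
# The restriction to train/test nodes happens later, in the custom loss and in evaluate().
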
# Prepare a generator and a dataset for distributed training.
# The whole graph is a single "batch", repeated forever, so every training step
# runs on the full graph (full-batch training).
def create_batch_generator():
    while True:
        yield (graph.x, graph.edge_index), graph.y


# dataset_fn is invoked by the distribution strategy; because the generator always
# yields the same full-graph batch, each replica receives the whole graph at every step.
def dataset_fn(ctx):
    dataset = tf.data.Dataset.from_generator(
        create_batch_generator,
        output_types=((tf.float32, tf.int32), tf.int32),
        output_shapes=((tf.TensorShape([None, graph.x.shape[1]]), tf.TensorShape([2, None])),
                       tf.TensorShape([None]))
    )
    return dataset


strategy = tf.distribute.MirroredStrategy()
distributed_dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
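
# Optional sanity check (not part of the original demo): confirm how many replicas
# MirroredStrategy created from the visible GPUs. In newer TensorFlow releases the
# non-experimental strategy.distribute_datasets_from_function may be available as a
# drop-in replacement for the experimental call above.
print("num_replicas_in_sync = {}".format(strategy.num_replicas_in_sync))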
# The model will automatically use all seen GPUs defined by "CUDA_VISIBLE_DEVICES" for distributed training.
# Model construction and compilation happen inside the strategy scope so that the
# weights are created as mirrored variables.
with strategy.scope():
    model = GCNNetwork()

    # Custom loss: compute softmax cross-entropy on the training nodes only
    def masked_cross_entropy(y_true, logits):
        y_true = tf.cast(y_true, tf.int32)
        masked_logits = tf.gather(logits, train_index)
        masked_labels = tf.gather(y_true, train_index)
        losses = tf.nn.softmax_cross_entropy_with_logits(
            logits=masked_logits,
            labels=tf.one_hot(masked_labels, depth=num_classes)
        )
        return tf.reduce_mean(losses)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=masked_cross_entropy,
        # run_eagerly=True  # uncomment for easier step-by-step debugging
    )
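
# Optional sketch (not in the original demo): an eager spot-check of the masked loss
# on the untrained model. Building the model inside the strategy scope keeps its
# variables mirrored; with 7 Cora classes the initial loss should land roughly near ln(7).
with strategy.scope():
    initial_logits = model([graph.x, graph.edge_index], training=False)
print("initial masked loss = {}".format(float(masked_cross_entropy(graph.y, initial_logits))))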

# Evaluate test accuracy on the full graph: run a forward pass and compare
# predictions with labels on the test nodes only.
def evaluate():
    logits = model([graph.x, graph.edge_index])
    masked_logits = tf.gather(logits, test_index)
    masked_labels = tf.gather(graph.y, test_index)

    y_pred = tf.argmax(masked_logits, axis=-1, output_type=tf.int32)

    corrects = tf.cast(tf.equal(masked_labels, y_pred), tf.float32)
    accuracy = tf.reduce_mean(corrects)
    return accuracy.numpy()


# Report test accuracy every 20 epochs during training
class EvaluationCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        if epoch % 20 == 0:
            test_accuracy = evaluate()
            print("epoch = {}\ttest_accuracy = {}".format(epoch, test_accuracy))


# The model will automatically use all seen GPUs defined by "CUDA_VISIBLE_DEVICES" for distributed training.
# steps_per_epoch=1 because a single step already covers the full graph.
model.fit(distributed_dataset, steps_per_epoch=1, epochs=201,
          callbacks=[EvaluationCallback()], verbose=2)

test_accuracy = evaluate()
print("final test_accuracy = {}".format(test_accuracy))