forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path rrn.py
60 lines (48 loc) · 1.72 KB
/
rrn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Recurrent Relational Network (RRN) module.
References:
- Recurrent Relational Networks
- Paper: https://arxiv.org/abs/1711.08028
- Original Code: https://github.com/rasmusbergpalm/recurrent-relational-networks
"""
import torch
from torch import nn
import dgl.function as fn
class RRNLayer(nn.Module):
    """One message-passing step of a Recurrent Relational Network.

    Each call to :meth:`forward` updates the graph ``g`` in place:

    1. every edge gets a message ``e = msg_layer(cat(h_src, h_dst))``;
    2. dropout (probability ``edge_drop``) is applied to those messages;
    3. incoming edge messages are summed per node into ``ndata['m']``;
    4. ``node_update_func`` is applied to every node.

    Args:
        msg_layer: callable applied to the source and destination ``'h'``
            features concatenated along the last dimension; its output is
            stored as the edge feature ``'e'``.
        node_update_func: callable that receives DGL's node batch; the dict
            it returns is written back as node features. Presumably it
            produces a new ``'h'`` (``RRN`` reads ``ndata['h']`` after each
            step) -- TODO confirm against the concrete update used.
        edge_drop: dropout probability passed to ``nn.Dropout``.
    """

    def __init__(self, msg_layer, node_update_func, edge_drop):
        super().__init__()
        self.msg_layer = msg_layer
        self.node_update_func = node_update_func
        self.edge_dropout = nn.Dropout(edge_drop)

    def forward(self, g):
        """Run one propagation step on ``g`` (mutated in place; returns None)."""
        # Per-edge messages, stored by get_msg as edata['e'].
        g.apply_edges(self.get_msg)
        # Regularize the messages before they are aggregated.
        dropped = self.edge_dropout(g.edata['e'])
        g.edata['e'] = dropped
        # Copy each edge message to its destination and sum into ndata['m'].
        g.update_all(
            message_func=fn.copy_e('e', 'msg'),
            reduce_func=fn.sum('msg', 'm'),
        )
        g.apply_nodes(self.node_update)

    def get_msg(self, edges):
        """Edge UDF: return ``{'e': msg_layer(cat(src 'h', dst 'h'))}``."""
        pair = torch.cat((edges.src['h'], edges.dst['h']), dim=-1)
        return {'e': self.msg_layer(pair)}

    def node_update(self, nodes):
        """Node UDF: delegate to the user-supplied ``node_update_func``."""
        return self.node_update_func(nodes)
class RRN(nn.Module):
    """Recurrent Relational Network: one shared RRNLayer applied repeatedly.

    A single :class:`RRNLayer` instance (hence a single set of weights) is
    run ``num_steps`` times on the graph, refining node states recurrently.

    Args:
        msg_layer: message network forwarded to ``RRNLayer``.
        node_update_func: node-update callable forwarded to ``RRNLayer``.
        num_steps: number of propagation steps per forward pass.
        edge_drop: edge-message dropout probability forwarded to ``RRNLayer``.
    """

    def __init__(self,
                 msg_layer,
                 node_update_func,
                 num_steps,
                 edge_drop):
        super().__init__()
        self.num_steps = num_steps
        self.rrn_layer = RRNLayer(msg_layer, node_update_func, edge_drop)

    def forward(self, g, get_all_outputs=True):
        """Run ``num_steps`` propagation steps on ``g``.

        ``g`` is updated in place by the shared layer; ``g.ndata['h']`` is
        read after each step.

        Args:
            g: the graph to propagate over.
            get_all_outputs: if True, return the node states after every
                step; otherwise return only the final ones.

        Returns:
            If ``get_all_outputs`` is True: ``g.ndata['h']`` after each step,
            stacked along a new leading dimension (num_steps x n_nodes x
            h_dim). When no step runs (``num_steps <= 0``), this is an empty
            tensor of shape ``(0, *g.ndata['h'].shape)`` rather than an error.
            Otherwise: ``g.ndata['h']`` after the last step (n_nodes x h_dim).
        """
        outputs = []
        for _ in range(self.num_steps):
            self.rrn_layer(g)
            if get_all_outputs:
                outputs.append(g.ndata['h'])

        if not get_all_outputs:
            return g.ndata['h']  # n_nodes x h_dim

        if not outputs:
            # torch.stack([]) raises; keep the stacked layout with 0 steps.
            h = g.ndata['h']
            return h.new_empty((0, *h.shape))

        return torch.stack(outputs, 0)  # num_steps x n_nodes x h_dim