-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
56 lines (43 loc) · 1.95 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch_geometric.nn import SAGEConv, GATConv, Linear, to_hetero
class GAT(torch.nn.Module):
    """Two-layer graph attention network with a linear skip path per layer.

    Each layer adds a ``GATConv`` message-passing result to a ``Linear``
    projection of that layer's input; a ReLU separates the two layers.
    Input widths are given as ``-1`` and are therefore inferred on the first
    forward pass.

    Args:
        hidden_channels: Output width of the first layer.
        out_channels: Output width of the second layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): lazily inferred source/target input widths, which lets the
        # attention layer take bipartite (source, target) inputs.
        # NOTE(review): self-loops are disabled, presumably because the model
        # is meant for bipartite/heterogeneous edge types — confirm.
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        """Run both attention layers with their linear skip connections.

        Args:
            x: Node features passed to both ``GATConv`` and ``Linear``.
            edge_index: Graph connectivity consumed by the ``GATConv`` layers.

        Returns:
            Node representations of width ``out_channels``.
        """
        attended = self.conv1(x, edge_index)
        hidden = attended + self.lin1(x)
        hidden = hidden.relu()
        attended = self.conv2(hidden, edge_index)
        return attended + self.lin2(hidden)
class GNNEncoder(torch.nn.Module):
    """Two-layer GraphSAGE encoder that turns node features into embeddings.

    Input widths are given as ``-1`` and are therefore inferred on the first
    forward pass. A ReLU is applied between the two convolutions; the second
    convolution's output is returned as-is.

    Args:
        hidden_channels: Output width of the first ``SAGEConv``.
        out_channels: Output width of the second ``SAGEConv`` (the embedding
            width).
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(-1, hidden_channels)
        self.conv2 = SAGEConv(-1, out_channels)

    def forward(self, x, edge_index):
        """Embed nodes with two rounds of SAGE message passing.

        Args:
            x: Node features.
            edge_index: Graph connectivity used by both convolutions.

        Returns:
            Node embeddings of width ``out_channels``.
        """
        h = self.conv1(x, edge_index)
        h = h.relu()
        return self.conv2(h, edge_index)
class EdgeDecoder(torch.nn.Module):
    """MLP that maps (source, destination) node pairs to edge outputs.

    For every pair ``(row[i], col[i])`` in ``edge_label_index``, the source
    node's embedding and the destination node's embedding are concatenated
    and passed through ``Linear -> ReLU -> Linear``.

    Args:
        hidden_channels: Width of the hidden layer. When ``in_channels`` is
            not given it is also used as the width of each node embedding, so
            the first layer expects ``2 * hidden_channels`` inputs.
        edge_features: Number of outputs produced per node pair.
        in_channels: Optional width of each node embedding, for encoders whose
            output width differs from ``hidden_channels``. Defaults to
            ``hidden_channels``.
        src_node_type: Key of ``z_dict`` holding the source-node embeddings.
        dst_node_type: Key of ``z_dict`` holding the destination-node
            embeddings.
    """

    def __init__(self, hidden_channels, edge_features, in_channels=None,
                 src_node_type='customer', dst_node_type='product'):
        super().__init__()
        if in_channels is None:
            in_channels = hidden_channels
        self.src_node_type = src_node_type
        self.dst_node_type = dst_node_type
        # The first layer sees one source and one destination embedding
        # concatenated along the feature dimension.
        self.lin1 = Linear(2 * in_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, edge_features)

    def forward(self, z_dict, edge_label_index):
        """Compute outputs for the node pairs in ``edge_label_index``.

        Args:
            z_dict: Mapping from node type to node-embedding tensor; must
                contain ``src_node_type`` and ``dst_node_type``.
            edge_label_index: Two index rows, unpacked as ``row, col``; ``row``
                indexes source nodes and ``col`` destination nodes (presumably
                a ``(2, num_pairs)`` LongTensor — confirm with callers).

        Returns:
            Tensor with ``edge_features`` values per pair along the last dim.
        """
        row, col = edge_label_index
        z = torch.cat(
            [z_dict[self.src_node_type][row], z_dict[self.dst_node_type][col]],
            dim=-1,
        )
        z = self.lin1(z).relu()
        return self.lin2(z)
class Model(torch.nn.Module):
    """Heterogeneous encoder/decoder for edge-level prediction.

    A ``GNNEncoder`` is converted with ``to_hetero`` (sum aggregation across
    edge types) to embed every node type. An ``EdgeDecoder`` then maps the
    node pairs in ``edge_label_index`` to ``edge_features`` outputs each.

    Args:
        hidden_channels: Width of the encoder's first layer.
        out_channels: Width of the node embeddings produced by the encoder.
        edge_features: Number of outputs per node pair.
        metadata: Graph metadata handed to ``to_hetero`` (presumably
            ``HeteroData.metadata()``, i.e. node types and edge types).
    """

    def __init__(self, hidden_channels, out_channels, edge_features, metadata):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, out_channels)
        self.encoder = to_hetero(self.encoder, metadata, aggr='sum')
        # The encoder emits out_channels-wide embeddings, and EdgeDecoder's
        # first argument sets the embedding width its first layer expects
        # (2 * that width after concatenation). Passing hidden_channels here
        # caused a shape mismatch whenever hidden_channels != out_channels;
        # when the two are equal this is unchanged.
        self.decoder = EdgeDecoder(out_channels, edge_features)

    def forward(self, x_dict, edge_index_dict, edge_label_index, *args, **kwargs):
        """Encode all node types, then decode the requested node pairs.

        Args:
            x_dict: Per-node-type feature tensors for the heterogeneous
                encoder (presumably ``HeteroData.x_dict``).
            edge_index_dict: Per-edge-type connectivity for the encoder
                (presumably ``HeteroData.edge_index_dict``).
            edge_label_index: Node pairs to decode.
            *args, **kwargs: Accepted and ignored.

        Returns:
            The decoder's per-pair outputs.
        """
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)