-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUV_Encoders.py
41 lines (33 loc) · 1.44 KB
/
UV_Encoders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
class UV_Encoder(nn.Module):
    """Encode users or items by fusing each node's own embedding with an aggregated history.

    For every node in a batch, the node's history (``history_uv_lists``) and the
    aligned ratings (``history_r_lists``) are looked up and handed to
    ``aggregator``. The aggregated neighbourhood representation is concatenated
    with the node's own embedding row and passed through a two-layer ReLU MLP
    with dropout between the layers.
    """

    def __init__(self, features, embed_dim, history_uv_lists, history_r_lists,
                 aggregator, cuda="cpu", uv=True, dropout_p=0.5):
        """Store the lookup tables and build the fusion MLP.

        Args:
            features: module whose ``.weight`` matrix is indexed by node id;
                presumably an ``nn.Embedding`` with ``embed_dim`` columns (TODO confirm).
            embed_dim: output width. ``linear0`` expects the concatenated
                self/neighbour features to have ``2 * embed_dim`` columns.
            history_uv_lists: container indexed by ``int(node)`` giving that node's
                history (presumably neighbour ids in the user-item graph; verify).
            history_r_lists: container indexed the same way, giving the ratings
                aligned with ``history_uv_lists``.
            aggregator: object exposing ``forward(nodes, history_uv, history_r)``
                that returns the neighbourhood representation for ``nodes``.
            cuda: device identifier stored as ``self.device`` (not used in this class).
            uv: flag stored as ``self.uv`` (not used in this class); presumably
                distinguishes the user side from the item side - verify against callers.
            dropout_p: dropout probability applied between the two linear layers.
                Defaults to 0.5, the value that was previously hard-coded.
        """
        super().__init__()
        self.features = features
        self.uv = uv
        self.history_uv_lists = history_uv_lists
        self.history_r_lists = history_r_lists
        self.aggregator = aggregator
        self.embed_dim = embed_dim
        self.device = cuda
        # Fusion MLP: [self || neighbours] (2*d) -> 2*d -> d.
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)
        self.linear0 = nn.Linear(2 * self.embed_dim, 2 * self.embed_dim)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(self, nodes):
        """Return ReLU-activated embeddings with one row per node and ``embed_dim`` columns.

        Args:
            nodes: re-iterable batch of node ids, usable both with ``int()`` and
                as an index into ``features.weight`` (e.g. a 1-D LongTensor).
        """
        node_ids = [int(node) for node in nodes]
        history_uv = [self.history_uv_lists[i] for i in node_ids]
        history_r = [self.history_r_lists[i] for i in node_ids]

        # Neighbourhood representation from the user-item interaction network.
        # NOTE(review): calling .forward directly bypasses nn.Module hooks if the
        # aggregator is a Module; kept as-is because its type is not visible here.
        neigh_feats = self.aggregator.forward(nodes, history_uv, history_r)
        # Each node's own embedding row: the self-connection of the encoder.
        self_feats = self.features.weight[nodes]

        combined = torch.cat([self_feats, neigh_feats], dim=1)
        combined = F.relu(self.linear0(combined))
        combined = self.dropout(combined)  # only active in train() mode
        return F.relu(self.linear1(combined))