
Commit

small change to neurallp
KiddoZhu committed Jun 4, 2022
1 parent 5f9ce1c commit 957cec7
Showing 2 changed files with 3 additions and 6 deletions.
2 changes: 1 addition & 1 deletion torchdrug/models/gin.py
@@ -29,7 +29,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
         readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
     """
 
-    def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
+    def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
                  short_cut=False, batch_norm=False, activation="relu", concat_hidden=False,
                  readout="sum"):
        super(GraphIsomorphismNetwork, self).__init__()
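In gin.py, the `None` defaults are dropped, making `input_dim` and `hidden_dims` required arguments. A minimal sketch of a constructor call after this change; the dimension values below are illustrative, not taken from the commit:

from torchdrug import models

# input_dim and hidden_dims must now be passed explicitly; omitting them
# raises a TypeError at construction time instead of failing later with None.
model = models.GraphIsomorphismNetwork(input_dim=21, hidden_dims=[256, 256, 256],
                                       batch_norm=True, readout="sum")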
7 changes: 2 additions & 5 deletions torchdrug/models/neurallp.py
@@ -17,7 +17,6 @@ class NeuralLogicProgramming(nn.Module, core.Configurable):
         https://papers.nips.cc/paper/2017/file/0e55666a4ad822e0e34299df3591d979-Paper.pdf
 
     Parameters:
-        num_entity (int): number of entities
         num_relation (int): number of relations
         hidden_dim (int): dimension of hidden units in LSTM
         num_step (int): number of recurrent steps
@@ -26,17 +25,15 @@ class NeuralLogicProgramming(nn.Module, core.Configurable):
 
     eps = 1e-10
 
-    def __init__(self, num_entity, num_relation, hidden_dim, num_step, num_lstm_layer=1):
+    def __init__(self, num_relation, hidden_dim, num_step, num_lstm_layer=1):
         super(NeuralLogicProgramming, self).__init__()
 
         num_relation = int(num_relation)
-        self.num_entity = num_entity
         self.num_relation = num_relation
         self.num_step = num_step
 
         self.query = nn.Embedding(num_relation * 2 + 1, hidden_dim)
         self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_lstm_layer)
-        self.key_linear = nn.Linear(hidden_dim, hidden_dim)
         self.weight_linear = nn.Linear(hidden_dim, num_relation * 2)
         self.linear = nn.Linear(1, 1)
@@ -56,7 +53,7 @@ def get_t_output(self, graph, h_index, r_index):
         query = self.query(q_index)
 
         hidden, hx = self.lstm(query)
-        memory = functional.one_hot(h_index, self.num_entity).unsqueeze(0)
+        memory = functional.one_hot(h_index, graph.num_entity).unsqueeze(0)
 
         for i in range(self.num_step):
             key = hidden[i]
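In neurallp.py, `num_entity` disappears from the constructor: the entity count is now read from the input graph at run time (`graph.num_entity`), and the `key_linear` layer is removed. A minimal sketch of the updated instantiation, with illustrative hyperparameter values not taken from the commit:

from torchdrug import models

# num_entity is no longer a constructor argument; the model takes the entity
# count from whatever graph it is applied to, so one instance is not tied to
# a fixed entity set (as long as the relation vocabulary matches).
model = models.NeuralLogicProgramming(num_relation=18, hidden_dim=128,
                                      num_step=3, num_lstm_layer=1)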
