From 957cec700a773179a468db2cffc7b7d4ef7ced3c Mon Sep 17 00:00:00 2001
From: Zhaocheng Zhu
Date: Thu, 2 Jun 2022 21:05:37 -0400
Subject: [PATCH] small change to neurallp

---
 torchdrug/models/gin.py      | 2 +-
 torchdrug/models/neurallp.py | 7 ++-----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/torchdrug/models/gin.py b/torchdrug/models/gin.py
index 44238ab3..990bb4b2 100644
--- a/torchdrug/models/gin.py
+++ b/torchdrug/models/gin.py
@@ -29,7 +29,7 @@ class GraphIsomorphismNetwork(nn.Module, core.Configurable):
         readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
     """
 
-    def __init__(self, input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
+    def __init__(self, input_dim, hidden_dims, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False,
                  short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
         super(GraphIsomorphismNetwork, self).__init__()
 
diff --git a/torchdrug/models/neurallp.py b/torchdrug/models/neurallp.py
index ef78c67c..34d97e8b 100644
--- a/torchdrug/models/neurallp.py
+++ b/torchdrug/models/neurallp.py
@@ -17,7 +17,6 @@ class NeuralLogicProgramming(nn.Module, core.Configurable):
         https://papers.nips.cc/paper/2017/file/0e55666a4ad822e0e34299df3591d979-Paper.pdf
 
     Parameters:
-        num_entity (int): number of entities
         num_relation (int): number of relations
         hidden_dim (int): dimension of hidden units in LSTM
         num_step (int): number of recurrent steps
@@ -26,17 +25,15 @@
 
     eps = 1e-10
 
-    def __init__(self, num_entity, num_relation, hidden_dim, num_step, num_lstm_layer=1):
+    def __init__(self, num_relation, hidden_dim, num_step, num_lstm_layer=1):
         super(NeuralLogicProgramming, self).__init__()
 
         num_relation = int(num_relation)
-        self.num_entity = num_entity
         self.num_relation = num_relation
         self.num_step = num_step
 
         self.query = nn.Embedding(num_relation * 2 + 1, hidden_dim)
         self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_lstm_layer)
-        self.key_linear = nn.Linear(hidden_dim, hidden_dim)
         self.weight_linear = nn.Linear(hidden_dim, num_relation * 2)
         self.linear = nn.Linear(1, 1)
 
@@ -56,7 +53,7 @@ def get_t_output(self, graph, h_index, r_index):
         query = self.query(q_index)
 
         hidden, hx = self.lstm(query)
-        memory = functional.one_hot(h_index, self.num_entity).unsqueeze(0)
+        memory = functional.one_hot(h_index, graph.num_entity).unsqueeze(0)
 
         for i in range(self.num_step):
             key = hidden[i]
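
Usage note (not part of the patch): a minimal sketch of how constructor calls change after this commit, assuming the classes are exposed through torchdrug.models as elsewhere in the library; the relation count and dimensions below are illustrative values, not taken from the patch.

    from torchdrug import models

    # After this patch, NeuralLogicProgramming no longer takes num_entity;
    # the entity count is read from graph.num_entity inside get_t_output().
    # (Before: models.NeuralLogicProgramming(num_entity, num_relation, hidden_dim, num_step))
    model = models.NeuralLogicProgramming(num_relation=237, hidden_dim=128, num_step=3)

    # GraphIsomorphismNetwork now requires input_dim and hidden_dims explicitly
    # instead of defaulting them to None.
    gin = models.GraphIsomorphismNetwork(input_dim=66, hidden_dims=[256, 256, 256, 256])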