Skip to content

Commit

Permalink
Evaluation mode
Browse files Browse the repository at this point in the history
  • Loading branch information
luigiba committed Aug 30, 2019
1 parent ae5706a commit 154d225
Show file tree
Hide file tree
Showing 8 changed files with 591 additions and 627 deletions.
7 changes: 0 additions & 7 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

343 changes: 0 additions & 343 deletions Config.py

Large diffs are not rendered by default.

Binary file modified __pycache__/Config.cpython-36.pyc
Binary file not shown.
Binary file modified __pycache__/distribute_training.cpython-36.pyc
Binary file not shown.
504 changes: 392 additions & 112 deletions distribute_training.py

Large diffs are not rendered by default.

269 changes: 179 additions & 90 deletions main_spark.py

Large diffs are not rendered by default.

52 changes: 20 additions & 32 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,54 +4,42 @@
from TransR import TransR
from TransD import TransD
import sys
import os
os.environ['CUDA_VISIBLE_DEVICES']='-1'
for arg in sys.argv: print(arg, type(arg))
target_rel_index = None

for arg in sys.argv:
print(type(arg), arg)
print("\n")

path_to_append = sys.argv[1]
# max = sys.argv[2]
dim = sys.argv[2]
model = sys.argv[3]
lp = sys.argv[4]
dataset_path = sys.argv[1]
model_path = sys.argv[2]
cpp_path = sys.argv[3]
dim = sys.argv[4]
model = sys.argv[5]
if (len(sys.argv) >= 7): target_rel_index = sys.argv[6]



def get_ckpt(p):
ckpt = None
with open(p + "checkpoint", 'r') as f:
with open(os.path.join(p,"checkpoint"), 'r') as f:
first_line = f.readline()
ckpt = first_line.split(':')[1].strip().replace('"', '').split('/')
ckpt = ckpt[len(ckpt) - 1]
return ckpt

dataset_path = '/content/drive/My Drive/DBpedia/{}'.format(path_to_append)
path = dataset_path + 'model/'
ckpt = get_ckpt(path)

con = Config(cpp_lib_path='/content/OpenKEonSpark/release/Base.so')
ckpt = get_ckpt(model_path)
con = Config(cpp_lib_path=cpp_path)
con.set_in_path(dataset_path)
con.set_test_link_prediction(bool(int(lp)))
con.set_test_triple_classification(True)
con.set_dimension(int(dim))
con.init()

if model.lower() == "transe": con.set_model_and_session(TransE)
elif model.lower() == "transh": con.set_model_and_session(TransH)
elif model.lower() == "transr": con.set_model_and_session(TransR)
else: con.set_model_and_session(TransD)

con.set_import_files(os.path.join(model_path, ckpt))

if model.lower() == "transe":
con.set_model_and_session(TransE)
con.set_n_threads_LP(5)
elif model.lower() == "transh":
con.set_model_and_session(TransH)
con.set_n_threads_LP(5)
elif model.lower() == "transr":
con.set_model_and_session(TransR)
con.set_n_threads_LP(2)
else:
con.set_model_and_session(TransD)
con.set_n_threads_LP(2)

con.set_import_files(path+ckpt)
con.set_test_log_path(path)
con.test()
print(con.acc)
con.plot_roc(rel_index=5, fig_name='plot.png')
if target_rel_index != None: con.plot_roc(rel_index=int(target_rel_index), fig_name='plot.png')
43 changes: 0 additions & 43 deletions test_1.py

This file was deleted.

0 comments on commit 154d225

Please sign in to comment.