Skip to content

Commit

Permalink
Bug fix + plot_roc
Browse files Browse the repository at this point in the history
  • Loading branch information
luigiba committed Aug 30, 2019
1 parent 6208650 commit ae5706a
Show file tree
Hide file tree
Showing 16 changed files with 315 additions and 117 deletions.
57 changes: 56 additions & 1 deletion .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 62 additions & 1 deletion Config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#coding:utf-8
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import os
import ctypes
Expand Down Expand Up @@ -44,6 +45,12 @@ def __init__(self, cpp_lib_path=None, init_new_entities=False):
self.lib.getBestThreshold.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
self.lib.test_triple_classification.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]

#ROC
self.lib.get_n_interval.argtypes = [ctypes.c_int64, ctypes.c_void_p, ctypes.c_void_p]
self.lib.get_n_interval.restype = ctypes.c_int64
self.lib.get_TPFP.argtypes = [ctypes.c_int64, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
self.lib.get_TPFP.restype = ctypes.POINTER( ctypes.c_int64 * 2 )

#set other parameters
self.in_path = None
self.test_log_path = None
Expand All @@ -54,7 +61,7 @@ def __init__(self, cpp_lib_path=None, init_new_entities=False):
self.rel_size = self.hidden_size
self.train_times = 0
self.margin = 1.0
self.nbatches = 100
self.nbatches = 0
self.negative_ent = 1
self.negative_rel = 0
self.workThreads = 8
Expand Down Expand Up @@ -852,6 +859,60 @@ def test(self):
print("\nElapsed test time (seconds): {}".format(test_time_elapsed))


def plot_roc(self, rel_index, fig_name=None):
if self.importName != None:
self.restore_tensorflow()
self.init_triple_classification()

self.lib.getValidBatch(self.valid_pos_h_addr, self.valid_pos_t_addr, self.valid_pos_r_addr, self.valid_neg_h_addr, self.valid_neg_t_addr, self.valid_neg_r_addr)
res_pos_valid = self.test_step(self.valid_pos_h, self.valid_pos_t, self.valid_pos_r)
res_neg_valid = self.test_step(self.valid_neg_h, self.valid_neg_t, self.valid_neg_r)

self.lib.getTestBatch(self.test_pos_h_addr, self.test_pos_t_addr, self.test_pos_r_addr, self.test_neg_h_addr, self.test_neg_t_addr, self.test_neg_r_addr)
res_pos_test = self.test_step(self.test_pos_h, self.test_pos_t, self.test_pos_r)
res_neg_test = self.test_step(self.test_neg_h, self.test_neg_t, self.test_neg_r)

n_intervals = self.lib.get_n_interval(rel_index, res_pos_valid.__array_interface__['data'][0], res_neg_valid.__array_interface__['data'][0])
self.lib.get_TPFP.restype = ctypes.POINTER( ctypes.c_int64 * ((n_intervals+1)*2) )
res = [j for j in self.lib.get_TPFP(rel_index, res_pos_valid.__array_interface__['data'][0], res_neg_valid.__array_interface__['data'][0], res_pos_test.__array_interface__['data'][0], res_neg_test.__array_interface__['data'][0]).contents]

TPR = []
FPR = []

if res[0] != 0 or res[0+n_intervals+1] != 0:
TPR.append(0)
FPR.append(0)


for i in range(0, n_intervals+1):
TPR.append(res[i])
FPR.append(res[i+n_intervals+1])

if TPR[len(TPR)-1] != len(res_pos_test.flatten()) or FPR[len(FPR)-1] != len(res_neg_test.flatten()):
TPR.append(len(res_pos_test.flatten()))
FPR.append(len(res_neg_test.flatten()))


for i in range(len(TPR)): TPR[i] /= TPR[-1]
for i in range(len(FPR)): FPR[i] /= FPR[-1]

auc = np.trapz(TPR, FPR)

plt.figure()
lw=2
plt.plot(FPR, TPR, color='darkorange', lw=lw, label='ROC curve (area = %0.3f)' % auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate (FPR)')
plt.ylabel('True Positive Rate (TPR)')
plt.title('ROC Curve')
plt.legend(loc="lower right")
if fig_name == None or fig_name == '':
plt.show()
else:
plt.savefig(fig_name)


def predict_head_entity(self, t, r, k):
r'''This mothod predicts the top k head entities given tail entity and relation.
Expand Down
Binary file modified __pycache__/Config.cpython-36.pyc
Binary file not shown.
Binary file modified __pycache__/TransD.cpython-36.pyc
Binary file not shown.
Binary file modified __pycache__/TransE.cpython-36.pyc
Binary file not shown.
Binary file modified __pycache__/distribute_training.cpython-36.pyc
Binary file not shown.
2 changes: 2 additions & 0 deletions base/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ void importTrainFiles();
extern "C"
void importOntologyFiles();



struct Parameter {
INT id;
INT *batch_h;
Expand Down
4 changes: 4 additions & 0 deletions base/Setting.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ extern "C"
void setBern(INT con) {
bernFlag = con;
}
/*
============================================================
*/
REAL interval = 0.01;


#endif
62 changes: 60 additions & 2 deletions base/Test.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ void getValidBatch(INT *ph, INT *pt, INT *pr, INT *nh, INT *nt, INT *nr) {
REAL threshEntire;
extern "C"
void getBestThreshold(REAL *relThresh, REAL *score_pos, REAL *score_neg) {
REAL interval = 0.01;
REAL min_score, max_score, bestThresh, tmpThresh, bestAcc, tmpAcc;
INT n_interval, correct, total;

for (INT r = 0; r < relationTotal; r++) {
if (validLef[r] == -1) continue;
total = (validRig[r] - validLef[r] + 1) * 2;
Expand All @@ -319,6 +319,7 @@ void getBestThreshold(REAL *relThresh, REAL *score_pos, REAL *score_neg) {
if(score_neg[i] > max_score) max_score = score_neg[i];
}
n_interval = INT((max_score - min_score)/interval);

for (INT i = 0; i <= n_interval; i++) {
tmpThresh = min_score + i * interval;
correct = 0;
Expand All @@ -345,8 +346,8 @@ extern "C"
//EDIT
void test_triple_classification(REAL *relThresh, REAL *score_pos, REAL *score_neg, REAL *acc_addr) {
testAcc = (REAL *)calloc(relationTotal, sizeof(REAL));
INT TP = 0, TN = 0, FP = 0, FN = 0;
REAL accuracy, precision, recall, fmeasure;
INT TP = 0, TN = 0, FP = 0, FN = 0;

for (INT r = 0; r < relationTotal; r++) {
if (validLef[r] == -1 || testLef[r] ==-1) continue;
Expand Down Expand Up @@ -386,5 +387,62 @@ void test_triple_classification(REAL *relThresh, REAL *score_pos, REAL *score_ne
}


extern "C"
INT get_n_interval(INT r, REAL *score_pos, REAL *score_neg){
REAL min_score, max_score;
INT n_interval, total;
if (validLef[r] == -1) return 0;
total = (validRig[r] - validLef[r] + 1) * 2;
min_score = score_pos[validLef[r]];
if (score_neg[validLef[r]] < min_score) min_score = score_neg[validLef[r]];
max_score = score_pos[validLef[r]];
if (score_neg[validLef[r]] > max_score) max_score = score_neg[validLef[r]];
for (INT i = validLef[r]+1; i <= validRig[r]; i++) {
if(score_pos[i] < min_score) min_score = score_pos[i];
if(score_pos[i] > max_score) max_score = score_pos[i];
if(score_neg[i] < min_score) min_score = score_neg[i];
if(score_neg[i] > max_score) max_score = score_neg[i];
}
return INT((max_score - min_score)/interval);
}


extern "C"
INT* get_TPFP(INT r, REAL *score_pos, REAL *score_neg, REAL *score_pos_test, REAL *score_neg_test) {
REAL min_score, max_score, tmpThresh;
INT n_interval, total;


if (validLef[r] == -1) return 0;
total = (validRig[r] - validLef[r] + 1) * 2;
min_score = score_pos[validLef[r]];
if (score_neg[validLef[r]] < min_score) min_score = score_neg[validLef[r]];
max_score = score_pos[validLef[r]];
if (score_neg[validLef[r]] > max_score) max_score = score_neg[validLef[r]];
for (INT i = validLef[r]+1; i <= validRig[r]; i++) {
if(score_pos[i] < min_score) min_score = score_pos[i];
if(score_pos[i] > max_score) max_score = score_pos[i];
if(score_neg[i] < min_score) min_score = score_neg[i];
if(score_neg[i] > max_score) max_score = score_neg[i];
}
n_interval = INT((max_score - min_score)/interval);


INT* TPFPs = new INT[(n_interval+1)*2];
for (INT i = 0; i <= n_interval; i++) {
INT TP = 0, FP = 0;
tmpThresh = min_score + i * interval;
for (INT i = testLef[r]; i <= testRig[r]; i++) {
if (score_pos_test[i] <= tmpThresh) TP++;
if (score_neg_test[i] <= tmpThresh) FP++;
}
TPFPs[i] = TP;
TPFPs[i + n_interval+1] = FP;
}

return TPFPs;
}



#endif
Loading

0 comments on commit ae5706a

Please sign in to comment.