From f859b6c6acbd93062292c702d780425025a1fe3c Mon Sep 17 00:00:00 2001 From: luigiba Date: Tue, 13 Aug 2019 12:53:08 +0200 Subject: [PATCH] Refactoring --- .idea/workspace.xml | 12 +++++------- Config.py | 8 +++++--- Model.py | 2 +- TransE.py | 1 + __pycache__/Config.cpython-36.pyc | Bin 26063 -> 26267 bytes __pycache__/Model.cpython-36.pyc | Bin 3932 -> 3929 bytes __pycache__/TransE.cpython-36.pyc | Bin 2230 -> 2227 bytes .../distribute_training.cpython-36.pyc | Bin 7780 -> 7859 bytes base/Corrupt.h | 10 ---------- base/Test.h | 2 -- commands.txt | 6 +++--- distribute_training.py | 3 ++- res_spark/README.md | 1 - run_dbpedia.sh | 15 +++++++++++++-- test.py | 6 +++--- test_1.py | 7 ++++++- 16 files changed, 39 insertions(+), 34 deletions(-) delete mode 100644 res_spark/README.md diff --git a/.idea/workspace.xml b/.idea/workspace.xml index aa0ecbe..855e1df 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -4,16 +4,14 @@ - - + + + - - - - - + + diff --git a/Config.py b/Config.py index b055d84..ef6ea5a 100644 --- a/Config.py +++ b/Config.py @@ -175,6 +175,11 @@ def init(self): if self.valid_triple_classification: self.init_valid_triple_classification() + def set_n_threads_LP(self, n): + self.N_THREADS_LP = n + self.lp_res = [] + for _ in range(self.N_THREADS_LP): self.lp_res.append({}) + def set_mini_batch(self): tot = None @@ -402,7 +407,6 @@ def test_step(self, test_h, test_t, test_r): return predict - def test_lp_range(self, index, lef, rig): current_lp_res = { 'r_tot' : 0.0, 'r_filter_tot' : 0.0, 'r_tot_constrain' : 0.0, 'r_filter_tot_constrain' : 0.0, @@ -452,7 +456,6 @@ def test_lp_range(self, index, lef, rig): with open(self.test_log_path+"thread"+str(index), 'r') as f: last_i = int(f.readline()) print("Restoring test results from index {}".format(last_i)) - lef = last_i + 1 for key in current_lp_res.keys(): current_lp_res[key] = float(f.readline()) @@ -582,7 +585,6 @@ def test_lp_range(self, index, lef, rig): self.lp_res[index] = current_lp_res - def test(self): with self.graph.as_default(): with self.sess.as_default(): diff --git a/Model.py b/Model.py index 08d2337..ec16f4e 100644 --- a/Model.py +++ b/Model.py @@ -67,7 +67,7 @@ def input_def(self): self.negative_t = tf.transpose(tf.reshape(self.batch_t[config.batch_size:config.batch_seq_size], [config.negative_ent + config.negative_rel, -1]), perm = [1, 0]) self.negative_r = tf.transpose(tf.reshape(self.batch_r[config.batch_size:config.batch_seq_size], [config.negative_ent + config.negative_rel, -1]), perm = [1, 0]) self.negative_y = tf.transpose(tf.reshape(self.batch_y[config.batch_size:config.batch_seq_size], [config.negative_ent + config.negative_rel, -1]), perm = [1, 0]) - + self.predict_h = tf.placeholder(tf.int64, [None]) self.predict_t = tf.placeholder(tf.int64, [None]) self.predict_r = tf.placeholder(tf.int64, [None]) diff --git a/TransE.py b/TransE.py index 8810365..91fef0b 100644 --- a/TransE.py +++ b/TransE.py @@ -55,4 +55,5 @@ def predict_def(self): predict_h_e = tf.nn.embedding_lookup(self.ent_embeddings, predict_h) predict_t_e = tf.nn.embedding_lookup(self.ent_embeddings, predict_t) predict_r_e = tf.nn.embedding_lookup(self.rel_embeddings, predict_r) + ##--## self.predict = tf.reduce_mean(self._calc(predict_h_e, predict_t_e, predict_r_e), 1, keep_dims = False) \ No newline at end of file diff --git a/__pycache__/Config.cpython-36.pyc b/__pycache__/Config.cpython-36.pyc index a9a9aac98d0e66469ad54a194e3556d8ab475ed5..916bfcc2f982529725e962900e4729bb8cd5c278 100644 GIT binary patch delta 2761 zcmai$3v`sl70183N!HCKOCFG9^WJQpOEwRZ5DgCj6F^&vm_SK4jq8&6NS2Z;&TeBM zfmkIy)%M8aT9ucGEw#19M^~!tIeN5dmE!}o1#GukLTRmlt;SYQPd(*y{9QMMy+q)m%k-RPQ9Wo5l@R2Zl7D2uQF|_01$<>ET=Id!iPo5bhEGS#9|d_gQ%o+>dct`{qm z$0J6>4WeCHGew73sp_i5D)B{i_lnJ;Q+!F))QBy@C%TkXE50ncRb8F9QTRE}vMMJc zzN*S|#i$q*x2P8Lgr&pq+#20tIX`{CV#GfQ>V~UTUEz9#5gIZ4Mz5hrLq!qGh_TMN zo!X4<^v$qoT|!Wm$++3M+7N~cNXtEmvRU<%R{~c7Eef6lS*LC>8jy21Q0i=pVed*hkxSaJws$ z{-q|A!%eQ%3%+lgDCt@ zcPgbRUk(?Q_EM&{_LXj-EG{VPqP@D!_OevURW(mgA-9(wr4lZwc&>K85>Gw=8~`2! zzN2tKcVCJ*v>_ywn1Oy%9@JAVt#nca&#SyU{qoK{Q8|}txS;A_-sM%F#FbV2ipRwt zSEW-UpZDZ*rRQ5&Q#F2wk9cO&T-AoDm3PkEeR(SrA%Dn=su$7{zODLx$7Atn{|Np^ zxyD;gkLdZX^`_Bs9`rs+hqQ!CaKz_cl_GZVk@tD|x@(F&k zwwz41)OFFgj$B`uWbF?LbA2#6QfH^!PuCHvwzN}|4-jGd09BX*e zs#AD~?`>Q_uj;C^jVUz7A2sg(q-sx-*ZLZQ|4P?SG}&o;^jyajv?|{Dm?*iw5KLV$LcNBWnr9SU3kbhP;83TR``~f%x{2urS_z?KJ z+I5C}3+CT|4}eMFZQxJ9S>TKUr(e@tEFJ3Dl+yqlI&sW`m8N_yex|j7e!`QjUhfE; zn}KA&t`M`DL&0z)EOSvv2uGxR3U$MLU0XSw;80rw{gfYUYfkw!+{b|9e7dcgp60a0 zW!C55h;qx~Gqjsamvj;F=#tY!(VnGUHa>FwJt=xsdVR#f72Q|JWX#ka9bfS>)oN9= zT3T+c^c1RYY8sw1uc4lAYA@&4+LQTQdo^$C$W6HM{)ERT&&O9^8nCf5!{tTK8sM{=D3&xBSRWKG zL`COGD!{V!Tw_TADFYkhuc&qKo?yQ(+k@qLZ2D~_s}U_#Ax2-3dixwlL%oNh`pJ9T zH-nBvpZA@l?t0bEW4|bFT#;D_p~LHU;R*C_fUgm_WUv_Z>7is3DyAGPt9u5WiO%dk zLH_?SSRaP?KZD6;45F7e^>BgG85tX-DS|x_zckkeBI*!4AKmJI(UPSV)a%eG*9#m$ zK6u-j7Wc)#(9LJXTRz+u&uu(Ie0GhSrn94Gb?Rf#TK#2EzxlkOXKC3c%zXeD1#Sht z2HXbR4vg_cPvz)cFt-6afSmvX-v;gh?ge%MyMaBxKHz>}8t`4|fjVF|a3ydR&;nfCoZ8aIoJ9plQCE0S!IqF{ zu{YULZB@2Bo6DADo57PJYo}@5X(qDOY+h!!jT()%VVlF2g|9?!PpX8Nde8an5py6M dl55w8Zi&B^@I=Z~;AcPq=1~EZ^Eb?;{|0Ld&tw1q delta 2462 zcmai#2~?C-6vw|i!Z2)#D6+^71B`64nwc00;*yGj0)vR7?-MYc(Kiz*E?6$djwcPT zWvPjlsHv$q9qk-Fny0cfm&7F-OT`s6bE}T0PS(3`NQ;9t48QMw_uhB!d;j;nZ&n?J z<6nb&nzy${Ul>vrQ0KnX_aZ2Ysu+POMhydF2UcCT`CjcSEStSBAB>T#Kg*e~7^7G& z8z8RHERPKoS3R4~2C=~+BZduOFA7gAo56;%VZsx~YS?g=FDwHa!A6Q$4_3em#WkMI zWTRP;h$OICteBMuE0GylsfZ=9F|15nli65SjzOAHon|u|ds$@kVVVNPFvq@7^QenG z+rKBkPUHb~a1dPrt#AmZ2U_3=wglSX__9!cv;~zx6E+3Sg3~y>`vN$Nr@NbAK4yn? z!NI{{a1JL0SG3+=4GxF%s12!f-g|jiRa^OFOGUZG+>A>?rookEp&YKGHMAOT;>}Pi z+{UqCR=A5zVN2ivn!=5+L{`p^aP6p?ZRM~4?$b6C8{+yv z95%-_Jf+AgLk1+_3B$&A;|=&mj||8Z@$PuO$8^ZTqWG09`y_4$JKjp{4F#xATAkcx%Fb$Hs8MUOTB|C|4i~lB zG+ES3rg~fPZc-{3F)DeQ_daJI>EDU#lJzhSzf3NL1Jupaoi=}pg{hO7J-ep|^nKjO z;Z~rE(iXE*&Kgn1;cDZiDzmY3)s(3gF2<=gQ+Wk9O|5jQ-GGZy5~0$*D@6}s2WTRn z5k4n8uFJom+j{&hHOu#)FpaJSrfQWtD{R5RX?m?3dRVHlD6Jb^#XbWfOzd(!Lw4vou5&>o}W-#C|MQ?0ggt4*D{*uEoUwl|!>r2fUVXDH(|;XL67 zf`jlA;Wpt%!gqvQgqwsLgiC~G!Zm>k@h?7i`|+#d!p{=EC0rz2Abd~wnec#cUr}Nd zevZt0gu8@0ge!#0gcibejL&&C)Kk2sxEG-(IVDoa@)o`joR^adALFi^1TAe2pNTCw z;o4nf&cKejdN_sYxv8)lD{?bEPmsHba1uA=#={;wl^dlzT&*!j=K0puwEH z%b>Bx4J_5UOOZ$}d@X+k%|+qA5Prqz5oPwpBMyNUxNl?{?kI?l zCgGeq%8ihF&{lr@vmi0{IweX)q^45Nc9cU$L6FbC6jr!9&Mq9-_5d2F-Uf6OhWgAP zGn3FNoNiP^R=8N8!_{gtahtQ!L_{4J3D0a{FU_ zajXzLe?kg=I_mY15ORhRX!{&s991E3D$E%7q$q&bNbs z>Uqd{y~O2%ags*Lr=w9Gt?JUQaMHfG^a7M6i9CY`cdAEeQj26NA^XYB)RgxmUkZU# zZEfa&O4*91QlQ;tjXZRp&;wZprouk>HC>6NPOV0i{J&386o z4xyGXm+&fK9$^7tA+BYy>QXY_B-9a*u$=G~VFlqG!n=f3gw=$Og!c*S335Ff$=XD4 z{)@jCNY{Dk;&ef} zaBNhI0z7rKN~Uh6&Rge4e_SKd3|@R2hEMbhlW%`IG2Dyr2_c9|MiZj2aAMwH!wI7I diff --git a/__pycache__/Model.cpython-36.pyc b/__pycache__/Model.cpython-36.pyc index e602c8f7ca80eb9a20effb0578449b4191b9be79..712d6be23ef22f26d7481460028e0f031115022b 100644 GIT binary patch delta 37 tcmca3cTzgdV%p2f$Tax|`#k{V^bAq} delta 41 xcmdlixJ{7Ln3tC;(bgoEXCvnuMn>bw-x%dt<2Q>l?PX2f<@96O0`q4B1|kGDi3v$8V_evy2?^je;IlBp>6kHi z4;4qnBoi@5#Y`ya@uf+VDrLH)a+@?MQl!nK$lH^GRY2}$?|u9BdvCw@cJKG9cfL3k z3I(FflJq+)n- zln?QPXk&bsN6^OK{H#-mcz@E@4D#NKfgc!LcI$yrpq>@@n!=h3p>yFfZj!`tNpAMa zy%I9^PYTMl4cbx7m)?(|pP@bVx^Sy%3D2kwMMdBKse*ri_zZcK`_b zMy83K6@$hphu~Z@P9r{zb{g$8+B0a+pgoKB?7Cv-F!wnTK+lI_-yJ4$4*i1JAK1E` zx7sSt8wHWKsv<82k_1CyBgBFJc^v0LcNBknDUNlRsUttLq53Ezq@;{>-l!QDENU<# z&kG5RlB<@=i!$;md0SYmTWpYYXxpsy5>d0Zm3*6YMf8$(*n`<+DeI!hVFfc)iD!L$ z$;Ti0_+!MfO3j#)_1OhvUG9tcq%Y!1UqpK^q=+Ghj8F0VnJ5^a^HPfNbc%@L9(EWu zT(ai>kA-V>NAZ~pGJ{>Rl#XK7uoI)Q6KKmKD!U4#jgv7^>P|#SL^mXri;^A1n-TW| zzeLWs-~7Ct+ANE>NN*5iJri;NE_7MXkX6B>j62|SB}+xtujoI7{&d%GUEMn%e%NW# zZNovSI{06v94gm$AsP!uZtMUQt@W! zb_~TQFitho$T{8)j#eYCHZSuY9`^6{UExtn^ z)MPPB{-{jVALo{=o2IqCc}HJuKJZp5k-u6f@jk~FxhSnK!JcGXsM zZO!{WH3^YQ+0>D>%T}!|^D9K=jhzt+gS%?;{_4%9OEVQ{xf7YO!6&%bqTn_%kH$fv zVH$-1XjEew$T!H+a9oKZisPV6haiOAzcQ-+WRhu<8Dj&BC^u X&Xu9y&g7CI&=iGf;Xql-X{mn!#H}>O delta 1146 zcmZ`&OH30{6n*c_RLWp0l>TYyw55fr_ya|2QBcG$1vF7Z{497%oSLdbV`_3IU5rbk zH4_sTtX;S>vvT9g7}v&?3pcWI?I$rVyzjLHOrXj1o_F6l_nrGb?)YjcJL23;CNr;Vmk*P+)qV3Ve2`6_f6BQ~J~yv8WvQH%#M9>jPE;~|V=7{~6YRuWZ*LBq@u z=v-1DX|rpJ)o-ONz2@P%kpaC{1Rb_T2`WN;#HJg^VH`usZm2fJ>J!AZoTC$BLjy0OwoJGDVF-I7Sf=}i?GRO~d);Qv`j8|Qd7MtD3%AhV8b z1=|&@^ODurf`*e9Y*kWcLg8_7d|2%*R{#T6z6A_)eQW zi&_IPV9eGM5VI-YB@T&_D}innA=yppQQ$tCRa5-A*Te*{*T`yy0gbl&|Gi8A+hY(I^E@hslGw|uLRGj)7~ zUzEjK-MoRwyy+h8W7L^S6zF}Yni%Y8h)4RDI4g-qisGk^o{UqkGoD=ZYgA)q9j6Pl oTnIa_le2Wm`I9U=GyO4Icb59|RCk{CN4r#2@%pvASNF#L0SiMPUjP6A diff --git a/base/Corrupt.h b/base/Corrupt.h index 41cbdf4..6ee0c1f 100644 --- a/base/Corrupt.h +++ b/base/Corrupt.h @@ -120,16 +120,6 @@ INT corrupt(INT h, INT r){ INT rr = tail_rig[r]; INT t; - //EDIT -// while (ll < rr){ -// t = tail_type[ll]; -// if (not _find(h, t, r)) { -// return t; -// } -// ll++; -// } -// return corrupt_head(0, h, r); - INT loop = 0; while(1) { t = tail_type[rand(ll, rr)]; diff --git a/base/Test.h b/base/Test.h index 4474caa..c5c7023 100644 --- a/base/Test.h +++ b/base/Test.h @@ -113,13 +113,11 @@ INT* testTail(INT index, REAL *con) { if (value < minimal) { r_s += 1; - if (value < r_min_s){ r_min_s = value; r_min = j; } - if (not _find(h, j, r)){ r_filter_s += 1; diff --git a/commands.txt b/commands.txt index 95b7a40..f5839ca 100644 --- a/commands.txt +++ b/commands.txt @@ -37,9 +37,9 @@ os.environ["WORK_DIR_PREFIX"] = "/content/OpenKEonSpark" os.environ["SPARK_HOME"] = "/content/spark-2.1.1-bin-hadoop2.7" #execute -!bash $WORK_DIR_PREFIX/run_dbpedia.sh 5 64 "TransE" -!bash $WORK_DIR_PREFIX/run_dbpedia.sh 10 64 "TransE" -!bash $WORK_DIR_PREFIX/run_dbpedia.sh 15 64 "TransE" +!bash $WORK_DIR_PREFIX/run_dbpedia.sh 5 64 "TransE" 0.0001 +!bash $WORK_DIR_PREFIX/run_dbpedia.sh 10 64 "TransE" 0.0001 +!bash $WORK_DIR_PREFIX/run_dbpedia.sh 15 64 "TransE" 0.0001 diff --git a/distribute_training.py b/distribute_training.py index bbd9fd6..a93baad 100644 --- a/distribute_training.py +++ b/distribute_training.py @@ -232,7 +232,8 @@ def main_fun(argv, ctx): if (task_index == 0) and (not sess.should_stop()) and (g >= to_reach_step): - to_reach_step += stopping_step + while (g >= to_reach_step): + to_reach_step += stopping_step ################## ACCURACY ################## feed_dict[trainModel.predict_h] = con.valid_pos_h diff --git a/res_spark/README.md b/res_spark/README.md deleted file mode 100644 index 8b13789..0000000 --- a/res_spark/README.md +++ /dev/null @@ -1 +0,0 @@ - diff --git a/run_dbpedia.sh b/run_dbpedia.sh index 8394f8c..f93637e 100644 --- a/run_dbpedia.sh +++ b/run_dbpedia.sh @@ -2,6 +2,7 @@ echo "====================================== Params ============================ echo "$1" echo "$2" echo "$3" +echo "$4" echo "====================================== Clearning res_spark directory ======================================" rm /home/luigi/IdeaProjects/OpenKE_new_Spark/res_spark/* @@ -18,6 +19,12 @@ m=$((n-1)) for i in `seq 0 $m` do + if [ -f /content/drive/My\ Drive/DBpedia/$n/$i/model/thread0 ]; then + echo "====================================== Test for batch $i ======================================" + python3 $WORK_DIR_PREFIX/test.py $i $n $2 $3 1 | tee /content/drive/My\ Drive/DBpedia/$n/$i/res.txt + continue + fi + if [ -f /content/drive/My\ Drive/DBpedia/$n/$i/res.txt ]; then echo "Batch $i already done; Skipping batch $i" continue @@ -53,7 +60,7 @@ do --cluster_size $SPARK_WORKER_INSTANCES --num_ps 1 --num_gpus 1 --cpp_lib_path $WORK_DIR_PREFIX/release/Base.so \ --input_path /content/drive/My\ Drive/DBpedia/$n/$i/ \ --output_path $WORK_DIR_PREFIX/res_spark \ - --alpha 0.0001 --optimizer SGD --train_times 50 --ent_neg_rate 1 --embedding_dimension $2 --margin 1.0 --model $3 + --alpha $4 --optimizer SGD --train_times 50 --ent_neg_rate 1 --embedding_dimension $2 --margin 1.0 --model $3 echo "====================================== Copying model for batch $i ======================================" @@ -61,7 +68,11 @@ do echo "====================================== Test for batch $i ======================================" - python3 $WORK_DIR_PREFIX/test.py $i $n $2 $3 | tee /content/drive/My\ Drive/DBpedia/$n/$i/res.txt + if [ $i -eq $m ]; then + python3 $WORK_DIR_PREFIX/test.py $i $n $2 $3 1 | tee /content/drive/My\ Drive/DBpedia/$n/$i/res.txt + else + python3 $WORK_DIR_PREFIX/test.py $i $n $2 $3 0 | tee /content/drive/My\ Drive/DBpedia/$n/$i/res.txt + fi done diff --git a/test.py b/test.py index fe549c0..46c02b2 100644 --- a/test.py +++ b/test.py @@ -6,14 +6,14 @@ import sys for arg in sys.argv: - print(arg) - print(type(arg)) + print(type(arg), arg) print("\n") n = sys.argv[1] max = sys.argv[2] dim = sys.argv[3] model = sys.argv[4] +lp = sys.argv[5] @@ -31,7 +31,7 @@ def get_ckpt(p): con = Config(cpp_lib_path='/content/OpenKEonSpark/release/Base.so') con.set_in_path(dataset_path) -con.set_test_link_prediction(True) +con.set_test_link_prediction(bool(lp)) con.set_test_triple_classification(True) con.set_dimension(int(dim)) con.init() diff --git a/test_1.py b/test_1.py index 7cb8f95..ee865e6 100644 --- a/test_1.py +++ b/test_1.py @@ -1,5 +1,6 @@ from Config import Config from TransE import TransE +from TransH import TransH import sys # import os @@ -15,8 +16,9 @@ def get_ckpt(p): #/home/luigi/IdeaProjects/OpenKE_new_Spark/benchmarks/DBpedia dataset_path = '/home/luigi/files/stuff/Done/DBpedia/5/0/' -# dataset_path = '/home/luigi/files/stuff/superuser/9/0/' +# dataset_path = '/home/luigi/files/stuff/superuser/9/1/' path = dataset_path + 'model/' +# path = '/home/luigi/IdeaProjects/OpenKEonSpark/res_spark/' print(path) ckpt = get_ckpt(path) @@ -30,7 +32,10 @@ def get_ckpt(p): con.set_model_and_session(TransE) con.set_import_files(path+ckpt) con.set_test_log_path(path) +con.set_n_threads_LP(1) con.test() + +con.predict_tail_entity(349585, 5, 10) # for i in range(0,100): # con.predict_tail_entity(i,0,1) # print(con.acc)