diff --git a/bin/predict.sh b/bin/predict.sh index e0f5471..6926124 100644 --- a/bin/predict.sh +++ b/bin/predict.sh @@ -24,6 +24,7 @@ resultFileSuffix="_"${model_name}"_"${resultSaveMode} # max error data format tolerate number max_error_tol=100 +# auc,mae,rmse,confusion_matrix eval_metric="auc,mae" #value or leafid predict_type="value" diff --git a/config/model/ffm.conf b/config/model/ffm.conf index 6f60bca..91079e5 100644 --- a/config/model/ffm.conf +++ b/config/model/ffm.conf @@ -19,6 +19,7 @@ data { field_delim : "@" }, + max_feature_dim: ???, // ["0@0.1","1@0.5",...] y_sampling : [], assigned : false, diff --git a/demo/ffm/binary_classification/ffm.conf b/demo/ffm/binary_classification/ffm.conf index 14bb71f..04323fb 100644 --- a/demo/ffm/binary_classification/ffm.conf +++ b/demo/ffm/binary_classification/ffm.conf @@ -19,6 +19,7 @@ data { field_delim : "@" }, + max_feature_dim: 117, // ["0@0.1","1@0.5",...] y_sampling : [], assigned : false, diff --git a/demo/ffm/regression/ffm.conf b/demo/ffm/regression/ffm.conf index d9c1793..fe56d82 100644 --- a/demo/ffm/regression/ffm.conf +++ b/demo/ffm/regression/ffm.conf @@ -19,6 +19,7 @@ data { field_delim : "@" }, + max_feature_dim: 39, // ["0@0.1","1@0.5",...] y_sampling : [], assigned : false, diff --git a/src/main/java/com/fenbi/ytklearn/dataflow/FFMModelDataFlow.java b/src/main/java/com/fenbi/ytklearn/dataflow/FFMModelDataFlow.java index c237dca..ee98af0 100644 --- a/src/main/java/com/fenbi/ytklearn/dataflow/FFMModelDataFlow.java +++ b/src/main/java/com/fenbi/ytklearn/dataflow/FFMModelDataFlow.java @@ -58,6 +58,7 @@ public class FFMModelDataFlow extends ContinuousDataFlow { private Map field2IndexMap = new HashMap<>(); private int fieldSize; private int maxFeatureNum = -1; + private int maxFeatureDim = 100; private RandomParams randomParams; @@ -88,6 +89,7 @@ public FFMModelDataFlow(IFileSystem fs, fieldDelim = config.getString("data.delim.field_delim"); fieldDictPath = config.getString("model.field_dict_path"); + maxFeatureDim = config.getInt("data.max_feature_dim"); randomParams = new RandomParams(config, ""); @@ -103,6 +105,7 @@ public FFMModelDataFlow(IFileSystem fs, @Data public static class FFMCoreData extends ContinuousCoreData { private int maxFeatureNum; + private int maxFeatureDim; private String fieldDelim; private Map field2IndexMap; @@ -320,7 +323,7 @@ protected void loadModel() throws IOException, Mp4jException { @Override protected void handleOtherTrainInfo() throws Mp4jException { this.maxFeatureNum = ((FFMCoreData)threadTrainCoreDatas[0]).getMaxFeatureNum(); - LOG_UTILS.importantInfo("train line max feature num:" + maxFeatureNum); + LOG_UTILS.importantInfo("train line max feature num:" + maxFeatureNum + ", config max feature dim:" + maxFeatureDim); } @Override diff --git a/src/main/java/com/fenbi/ytklearn/predictor/ContinuousOnlinePredictor.java b/src/main/java/com/fenbi/ytklearn/predictor/ContinuousOnlinePredictor.java index 8ca8456..f3614a4 100644 --- a/src/main/java/com/fenbi/ytklearn/predictor/ContinuousOnlinePredictor.java +++ b/src/main/java/com/fenbi/ytklearn/predictor/ContinuousOnlinePredictor.java @@ -327,7 +327,6 @@ public double batchPredictFromFiles(String modelName, } } double predict = predict(fmap, otherinfo); - if (hasLabel) { label = Float.parseFloat(linfo[0]); loss += weight * loss(fmap, label, otherinfo); // score not predict? @@ -388,10 +387,10 @@ public double batchPredictFromFiles(String modelName, ", local error total number:" + errorNum + ", max error tol:" + maxErrorTol + ", has read real lines:" + realcnt + - ", weight lines:" + weightCnt); + ", weight lines:" + weightCnt, e); if (errorNum > maxErrorTol) { LOG.error("[ERROR] error number:" + errorNum + - " > " + "max tol:" + maxErrorTol); + " > " + "max tol:" + maxErrorTol, e); throw e; } } diff --git a/src/main/java/com/fenbi/ytklearn/predictor/FFMOnlinePredictor.java b/src/main/java/com/fenbi/ytklearn/predictor/FFMOnlinePredictor.java index 248caaf..9fc5600 100644 --- a/src/main/java/com/fenbi/ytklearn/predictor/FFMOnlinePredictor.java +++ b/src/main/java/com/fenbi/ytklearn/predictor/FFMOnlinePredictor.java @@ -43,7 +43,7 @@ public class FFMOnlinePredictor extends ContinuousOnlinePredictor { private final Map field2IndexMap = new HashMap<>(); private int fieldSize; - private final ThreadLocal assistbuffer = new ThreadLocal<>(); + private final ThreadLocal assistbuffer = new ThreadLocal<>(); private final ThreadLocal fieldbuffer = new ThreadLocal<>(); private final ThreadLocal valbuffer = new ThreadLocal<>(); @@ -55,9 +55,9 @@ public FFMOnlinePredictor(String configPath) throws Exception { List klist = config.getIntList("k"); K = klist.get(1); - fieldDelim = config.getString("field_delim"); + fieldDelim = config.getString("data.delim.field_delim"); fieldDictPath = config.getString("model.field_dict_path"); - maxFeatureNum = config.getInt("max_line_feature_num"); + maxFeatureNum = config.getInt("data.max_feature_dim") + 1; loadModel(); } @@ -68,9 +68,9 @@ public FFMOnlinePredictor(Reader configReader) throws Exception { List klist = config.getIntList("k"); K = klist.get(1); - fieldDelim = config.getString("field_delim"); + fieldDelim = config.getString("data.delim.field_delim"); fieldDictPath = config.getString("model.field_dict_path"); - maxFeatureNum = config.getInt("max_line_feature_num"); + maxFeatureNum = config.getInt("data.max_feature_dim") + 1; loadModel(); } @@ -104,24 +104,24 @@ protected OnlinePredictor loadModel() throws Exception { } int cnt = 0; - iterators = fs.read(Arrays.asList(fieldDictPath)); + iterators = fs.read(Arrays.asList(modelParams.data_path)); for (Iterator it : iterators) { while (it.hasNext()) { String line = it.next(); if (line.trim().length() == 0) { - LOG.error("invalid model line:" + line); + LOG.error("invalid model line(length=0):" + line); continue; } String []info = line.trim().split(modelParams.delim); - if (fieldSize != (info.length - 5) / K) { - LOG.info("invalid model line:" + line); - continue; - } +// if (fieldSize != (info.length - 5) / K) { +// LOG.info("invalid model line:" + line); +// continue; +// } - if (info.length < 2) { - LOG.error("[invalid model line:" + line); - continue; - } +// if (info.length < 2) { +// LOG.error("[invalid model line:" + line); +// continue; +// } float []w = modelMap.get(info[0]); @@ -156,9 +156,9 @@ public double score(Map features, Object other) { int stride = fieldSize * K; - double []assist = assistbuffer.get(); + float []assist = assistbuffer.get(); if (assist == null) { - assist = new double[K * fieldSize * (maxFeatureNum + 1)]; + assist = new float[K * fieldSize * (maxFeatureNum + 1)]; assistbuffer.set(assist); } @@ -183,7 +183,11 @@ public double score(Map features, Object other) { for (Map.Entry feature : features.entrySet()) { // field idx - fieldIdxArr[cidx] = field2IndexMap.get(feature.getKey().split(fieldDelim)[0]); + Integer fieldIdx = field2IndexMap.get(feature.getKey().split(fieldDelim)[0]); + if (fieldIdx == null) { + continue; + } + fieldIdxArr[cidx] = fieldIdx; float val = transform(feature.getKey(), feature.getValue()); // val