Skip to content

Commit

Permalink
FIX:ffm predictor
Browse files Browse the repository at this point in the history
scharoun committed May 9, 2018
1 parent 47bbbb9 commit aa26f8a
Showing 7 changed files with 32 additions and 22 deletions.
1 change: 1 addition & 0 deletions bin/predict.sh
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@ resultFileSuffix="_"${model_name}"_"${resultSaveMode}

# maximum number of data-format errors to tolerate
max_error_tol=100
# evaluation metrics, comma-separated: auc,mae,rmse,confusion_matrix
eval_metric="auc,mae"
#value or leafid
predict_type="value"
1 change: 1 addition & 0 deletions config/model/ffm.conf
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ data {
field_delim : "@"
},

max_feature_dim: ???,
// ["0@0.1","1@0.5",...]
y_sampling : [],
assigned : false,
1 change: 1 addition & 0 deletions demo/ffm/binary_classification/ffm.conf
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ data {
field_delim : "@"
},

max_feature_dim: 117,
// ["0@0.1","1@0.5",...]
y_sampling : [],
assigned : false,
1 change: 1 addition & 0 deletions demo/ffm/regression/ffm.conf
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ data {
field_delim : "@"
},

max_feature_dim: 39,
// ["0@0.1","1@0.5",...]
y_sampling : [],
assigned : false,
Original file line number Diff line number Diff line change
@@ -58,6 +58,7 @@ public class FFMModelDataFlow extends ContinuousDataFlow {
private Map<String, Integer> field2IndexMap = new HashMap<>();
private int fieldSize;
private int maxFeatureNum = -1;
private int maxFeatureDim = 100;

private RandomParams randomParams;

@@ -88,6 +89,7 @@ public FFMModelDataFlow(IFileSystem fs,

fieldDelim = config.getString("data.delim.field_delim");
fieldDictPath = config.getString("model.field_dict_path");
maxFeatureDim = config.getInt("data.max_feature_dim");

randomParams = new RandomParams(config, "");

@@ -103,6 +105,7 @@ public FFMModelDataFlow(IFileSystem fs,
@Data
public static class FFMCoreData extends ContinuousCoreData {
private int maxFeatureNum;
private int maxFeatureDim;
private String fieldDelim;
private Map<String, Integer> field2IndexMap;

@@ -320,7 +323,7 @@ protected void loadModel() throws IOException, Mp4jException {
@Override
protected void handleOtherTrainInfo() throws Mp4jException {
this.maxFeatureNum = ((FFMCoreData)threadTrainCoreDatas[0]).getMaxFeatureNum();
LOG_UTILS.importantInfo("train line max feature num:" + maxFeatureNum);
LOG_UTILS.importantInfo("train line max feature num:" + maxFeatureNum + ", config max feature dim:" + maxFeatureDim);
}

@Override
Original file line number Diff line number Diff line change
@@ -327,7 +327,6 @@ public double batchPredictFromFiles(String modelName,
}
}
double predict = predict(fmap, otherinfo);

if (hasLabel) {
label = Float.parseFloat(linfo[0]);
loss += weight * loss(fmap, label, otherinfo); // score not predict?
@@ -388,10 +387,10 @@ public double batchPredictFromFiles(String modelName,
", local error total number:" + errorNum +
", max error tol:" + maxErrorTol +
", has read real lines:" + realcnt +
", weight lines:" + weightCnt);
", weight lines:" + weightCnt, e);
if (errorNum > maxErrorTol) {
LOG.error("[ERROR] error number:" + errorNum +
" > " + "max tol:" + maxErrorTol);
" > " + "max tol:" + maxErrorTol, e);
throw e;
}
}
40 changes: 22 additions & 18 deletions src/main/java/com/fenbi/ytklearn/predictor/FFMOnlinePredictor.java
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ public class FFMOnlinePredictor extends ContinuousOnlinePredictor<float[]> {
private final Map<String, Integer> field2IndexMap = new HashMap<>();
private int fieldSize;

private final ThreadLocal<double[]> assistbuffer = new ThreadLocal<>();
private final ThreadLocal<float[]> assistbuffer = new ThreadLocal<>();
private final ThreadLocal<int[]> fieldbuffer = new ThreadLocal<>();
private final ThreadLocal<float[]> valbuffer = new ThreadLocal<>();

@@ -55,9 +55,9 @@ public FFMOnlinePredictor(String configPath) throws Exception {
List<Integer> klist = config.getIntList("k");
K = klist.get(1);

fieldDelim = config.getString("field_delim");
fieldDelim = config.getString("data.delim.field_delim");
fieldDictPath = config.getString("model.field_dict_path");
maxFeatureNum = config.getInt("max_line_feature_num");
maxFeatureNum = config.getInt("data.max_feature_dim") + 1;

loadModel();
}
@@ -68,9 +68,9 @@ public FFMOnlinePredictor(Reader configReader) throws Exception {
List<Integer> klist = config.getIntList("k");
K = klist.get(1);

fieldDelim = config.getString("field_delim");
fieldDelim = config.getString("data.delim.field_delim");
fieldDictPath = config.getString("model.field_dict_path");
maxFeatureNum = config.getInt("max_line_feature_num");
maxFeatureNum = config.getInt("data.max_feature_dim") + 1;

loadModel();
}
@@ -104,24 +104,24 @@ protected OnlinePredictor loadModel() throws Exception {
}

int cnt = 0;
iterators = fs.read(Arrays.asList(fieldDictPath));
iterators = fs.read(Arrays.asList(modelParams.data_path));
for (Iterator<String> it : iterators) {
while (it.hasNext()) {
String line = it.next();
if (line.trim().length() == 0) {
LOG.error("invalid model line:" + line);
LOG.error("invalid model line(length=0):" + line);
continue;
}
String []info = line.trim().split(modelParams.delim);
if (fieldSize != (info.length - 5) / K) {
LOG.info("invalid model line:" + line);
continue;
}
// if (fieldSize != (info.length - 5) / K) {
// LOG.info("invalid model line:" + line);
// continue;
// }

if (info.length < 2) {
LOG.error("[invalid model line:" + line);
continue;
}
// if (info.length < 2) {
// LOG.error("[invalid model line:" + line);
// continue;
// }


float []w = modelMap.get(info[0]);
@@ -156,9 +156,9 @@ public double score(Map<String, Float> features, Object other) {

int stride = fieldSize * K;

double []assist = assistbuffer.get();
float []assist = assistbuffer.get();
if (assist == null) {
assist = new double[K * fieldSize * (maxFeatureNum + 1)];
assist = new float[K * fieldSize * (maxFeatureNum + 1)];
assistbuffer.set(assist);
}

@@ -183,7 +183,11 @@ public double score(Map<String, Float> features, Object other) {
for (Map.Entry<String, Float> feature : features.entrySet()) {

// field idx
fieldIdxArr[cidx] = field2IndexMap.get(feature.getKey().split(fieldDelim)[0]);
Integer fieldIdx = field2IndexMap.get(feature.getKey().split(fieldDelim)[0]);
if (fieldIdx == null) {
continue;
}
fieldIdxArr[cidx] = fieldIdx;

float val = transform(feature.getKey(), feature.getValue());
// val

0 comments on commit aa26f8a

Please sign in to comment.