Skip to content

Commit

Permalink
MOD:libsvm convert tool & Online predict for multiclass
Browse files Browse the repository at this point in the history
  • Loading branch information
scharoun committed May 19, 2017
1 parent 412dcb9 commit 16b3e26
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 39 deletions.
19 changes: 16 additions & 3 deletions src/main/java/com/fenbi/ytklearn/dataflow/GBDTCoreData.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.fenbi.mp4j.comm.ThreadCommSlave;
import com.fenbi.mp4j.exception.Mp4jException;
import com.fenbi.ytklearn.data.Constants;
import com.fenbi.ytklearn.exception.YtkLearnException;
import com.fenbi.ytklearn.loss.ILossFunction;
import com.fenbi.ytklearn.utils.CheckUtils;
import lombok.Data;
Expand Down Expand Up @@ -226,9 +227,21 @@ protected boolean yExtract(String line, String[] info) throws Exception {

} else { // multiclass softmax
String[] linfo = info[1].split(coreParams.y_delim);
CheckUtils.check(linfo.length == numTreeInGroup, "[GBDT] label num must equal %d, line: %s", numTreeInGroup, line);
for (int i = 0; i < numTreeInGroup; i++) {
label[i] = Float.parseFloat(linfo[i]);
CheckUtils.check(linfo.length == numTreeInGroup || linfo.length == 1, "[GBDT] label num must equal %d, line: %s", numTreeInGroup, line);

if (linfo.length == 1) {
for (int i = 0; i < numTreeInGroup; i++) {
label[i] = 0;
}
int clazz = Integer.parseInt(linfo[0]);
if (clazz >= numTreeInGroup) {
throw new YtkLearnException("multi classification label must in range [0,K-1]!\n" + line);
}
label[clazz] = 1.0f;
} else {
for (int i = 0; i < numTreeInGroup; i++) {
label[i] = Float.parseFloat(linfo[i]);
}
}

CheckUtils.check(obj.checkLabel(label), "[GBDT] all label sum must equal 1.0, line: %s", line);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ public double batchPredictFromFiles(String modelName,
int realcnt = 0;
double weightCnt = 0;
int errorNum = 0;
int K = modelName.equalsIgnoreCase("multiclass_linear") ? config.getInt("k") : -1;

for (String path : paths) {
String predictPath = path + resultFileSuffix;
BufferedReader reader = new BufferedReader(fs.getReader(path));
Expand Down Expand Up @@ -253,16 +255,34 @@ public double batchPredictFromFiles(String modelName,
double []predicts = predicts(fmap, null); // todo: judge sampleDepdtBaseScore
if (hasLabel) {
if (labels == null) {
labels = new double[linfo.length];
labels = new double[K];
}

if (linfo.length != K && linfo.length != 1) {
throw new Exception("label num must = " + K + ", or = 1, line:" + line);
}
for (int i = 0; i < linfo.length; i++) {
labels[i] = Float.parseFloat(linfo[i]);

if (linfo.length == 1) {
for (int i = 0; i < K; i++) {
labels[i] = 0;
}
int clazz = Integer.parseInt(linfo[0]);
if (clazz >= K) {
throw new YtkLearnException("multi classification label must in range [0,K-1]!\n" + line);
}
labels[clazz] = 1.0f;
} else if (linfo.length == K){
for (int i = 0; i < K; i++) {
labels[i] = Float.parseFloat(linfo[i]);
}
} else {
throw new YtkLearnException("multi classification label error:" + line);
}

loss += weight * loss(fmap, labels, null);
if (needEval) {
if (testData == null) {
testData = new PredictCoreData(null, linfo.length);
testData = new PredictCoreData(null, K);
}
testData.addPredict(predicts, labels, weight);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,21 +379,36 @@ public double batchPredictFromFiles(String modelName,

if (hasLabel) {
if (labels == null) {
labels = new double[linfo.length];
}
for (int i = 0; i < linfo.length; i++) {
labels[i] = Float.parseFloat(linfo[i]);
labels = new double[numTreeInGroup];
}

if (numTreeInGroup == 1) {
loss += weight * loss(fmap, labels[0], sampleDepBasePrediction? otherinfo[0]: null);
} else {
if (numTreeInGroup > 1) {
if (linfo.length == 1) {
for (int i = 0; i < numTreeInGroup; i++) {
labels[i] = 0;
}
int clazz = Integer.parseInt(linfo[0]);
if (clazz >= numTreeInGroup) {
throw new YtkLearnException("multi classification label must in range [0,K-1]!\n" + line);
}
labels[clazz] = 1.0f;
} else if (linfo.length == numTreeInGroup){
for (int i = 0; i < numTreeInGroup; i++) {
labels[i] = Float.parseFloat(linfo[i]);
}
} else {
throw new YtkLearnException("label format error:" + line);
}

loss += weight * loss(fmap, labels, otherinfo);
} else {
labels[0] = Float.parseFloat(linfo[0]);
loss += weight * loss(fmap, labels[0], sampleDepBasePrediction? otherinfo[0]: null);
}

if (needEval) {
if (testData == null) {
testData = new PredictCoreData(null, linfo.length);
testData = new PredictCoreData(null, numTreeInGroup);
}
testData.addPredict(predicts, labels, weight);
}
Expand Down
29 changes: 5 additions & 24 deletions src/main/java/com/fenbi/ytklearn/utils/LibsvmConvertTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,18 @@ public static void main(String []args) {
String inputPath = args[6];
String outputPath = args[7];

Map<Integer, String> kLabelStrMap = new HashMap<>();
BufferedReader reader = null;
PrintWriter writer = null;
int cnt = 0;
String line = "";

int k = 2;
Map<String, Integer> kLabel2IndexMap = new HashMap<>();
Map<String, String> kLabel2LabelMap = new HashMap<>();
if (mode.contains("classification")) {
String []labelinfo = mode.split("@")[1].trim().split(",");
k = labelinfo.length;
for (int i = 0; i < k; i++) {
kLabel2IndexMap.put(labelinfo[i], i);
if (mode.startsWith("binary")) {
kLabel2LabelMap.put(labelinfo[i], i + "");
} else {
kLabel2LabelMap.put(labelinfo[i], createKLabelStr(k, i, y_delim));
}
}
}

Expand All @@ -102,7 +95,7 @@ public static void main(String []args) {
reader = new BufferedReader(fs.getReader(inputPath));
writer = new PrintWriter(fs.getWriter(outputPath));


LOG.info("libsvm format data path:" + inputPath);
while((line = reader.readLine()) != null) {
StringBuilder sb = new StringBuilder();
String []info = line.trim().split("\\s+");
Expand All @@ -126,7 +119,6 @@ public static void main(String []args) {

kcnt[label] ++;
} else if (mode.startsWith("multi_classification")) {
//int label = Integer.parseInt(info[0]);
Integer label = kLabel2IndexMap.get(info[0]);
if (label == null) {
throw new Exception("unknown label:" + info[0]);
Expand All @@ -135,12 +127,7 @@ public static void main(String []args) {
if (label < 0 || label >= k) {
throw new Exception("error libsvm format for mode:" + mode + " - " + line);
}
String kStr = kLabelStrMap.get(label);
if (kStr == null) {
kStr = createKLabelStr(k, label, y_delim);
kLabelStrMap.put(label, kStr);
}
sb.append(kStr).append(x_delim);
sb.append(label).append(x_delim);

kcnt[label] ++;
} else if (mode.startsWith("regression")) {
Expand All @@ -162,27 +149,21 @@ public static void main(String []args) {
}
}

// if (cnt < 10) {
// LOG.info("libsvm format:" + line);
// LOG.info("ytk-learn format:" + sb.toString());
// }

writer.println(sb.toString());

cnt++;
}

LOG.info("convert finished! convert count:" + cnt);
if (mode.contains("classification")) {
LOG.info("classification label stat:" + Arrays.toString(kcnt));
}

if (mode.contains("classification")) {
for (Map.Entry<String, Integer> entry : kLabel2IndexMap.entrySet()) {
LOG.info("libsvm classification label:" + entry.getKey() + " ----> ytklearn classification label:" + kLabel2LabelMap.get(entry.getKey()));
LOG.info("libsvm classification label:" + entry.getKey() + " ----> ytklearn classification label:" + entry.getValue() + ", count:" + kcnt[entry.getValue()]);
}
}

LOG.info("ytk-learn format data path:" + outputPath);


} catch (Exception e) {
LOG.error("error", e);
Expand Down

0 comments on commit 16b3e26

Please sign in to comment.