Skip to content

Commit

Permalink
Merge pull request #95 from dongjoon-hyun/DOLPHIN-94
Browse files Browse the repository at this point in the history
[DOLPHIN-94] Show both training/cross-validation error to check overfiting
  • Loading branch information
jsjason committed Aug 30, 2015
2 parents 53cbf18 + 24b7518 commit 194900d
Showing 1 changed file with 21 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ public final class NeuralNetworkTask implements Task {

private static final Logger LOG = Logger.getLogger(NeuralNetworkTask.class.getName());

private final Validator validator;
private final Validator crossValidator;
private final Validator trainingValidator;
private final DataParser<List<Pair<Pair<INDArray, Integer>, Boolean>>> dataParser;
private final NeuralNetwork neuralNetwork;
private final int maxIterations;
Expand Down Expand Up @@ -88,10 +89,17 @@ public void reset() {
/**
* @return the prediction accuracy of model.
*/
public float getStats() {
public float getAccuracy() {
return correctNum / (float) totalNum;
}

/**
* @return the prediction error of model.
*/
public float getError() {
return 1 - getAccuracy();
}

/**
* @return the total number of samples that are used for evaluation.
*/
Expand All @@ -103,13 +111,13 @@ public int getTotalNum() {
@Inject
NeuralNetworkTask(final DataParser<List<Pair<Pair<INDArray, Integer>, Boolean>>> dataParser,
final NeuralNetwork neuralNetwork,
@Parameter(MaxIterations.class) final int maxIterations,
final Validator validator) {
@Parameter(MaxIterations.class) final int maxIterations) {
super();
this.dataParser = dataParser;
this.neuralNetwork = neuralNetwork;
this.maxIterations = maxIterations;
this.validator = validator;
this.trainingValidator = new Validator(neuralNetwork);
this.crossValidator = new Validator(neuralNetwork);
}

/** {@inheritDoc} */
Expand All @@ -125,17 +133,21 @@ public byte[] call(final byte[] bytes) throws Exception {
final int label = data.getFirst().getSecond();
final boolean isValidation = data.getSecond();
if (isValidation) {
validator.validate(input, label);
crossValidator.validate(input, label);
} else {
neuralNetwork.train(input, label);
trainingValidator.validate(input, label);
}
}
LOG.log(Level.INFO, "=========================================================");
LOG.log(Level.INFO, "Iteration: {0}", String.valueOf(i));
LOG.log(Level.INFO, "Result: {0}", String.valueOf(validator.getStats()));
LOG.log(Level.INFO, "# of validation inputs: {0}", String.valueOf(validator.getTotalNum()));
LOG.log(Level.INFO, "Training Error: {0}", String.valueOf(trainingValidator.getError()));
LOG.log(Level.INFO, "Cross Validation Error: {0}", String.valueOf(crossValidator.getError()));
LOG.log(Level.INFO, "# of training inputs: {0}", String.valueOf(trainingValidator.getTotalNum()));
LOG.log(Level.INFO, "# of validation inputs: {0}", String.valueOf(crossValidator.getTotalNum()));
LOG.log(Level.INFO, "=========================================================");
validator.reset();
crossValidator.reset();
trainingValidator.reset();
}

return null;
Expand Down

0 comments on commit 194900d

Please sign in to comment.