diff --git a/src/main/java/edu/snu/reef/dolphin/neuralnet/NeuralNetworkTask.java b/src/main/java/edu/snu/reef/dolphin/neuralnet/NeuralNetworkTask.java index 6247a8e..2d99618 100644 --- a/src/main/java/edu/snu/reef/dolphin/neuralnet/NeuralNetworkTask.java +++ b/src/main/java/edu/snu/reef/dolphin/neuralnet/NeuralNetworkTask.java @@ -38,7 +38,8 @@ public final class NeuralNetworkTask implements Task { private static final Logger LOG = Logger.getLogger(NeuralNetworkTask.class.getName()); - private final Validator validator; + private final Validator crossValidator; + private final Validator trainingValidator; private final DataParser, Boolean>>> dataParser; private final NeuralNetwork neuralNetwork; private final int maxIterations; @@ -88,10 +89,17 @@ public void reset() { /** * @return the prediction accuracy of model. */ - public float getStats() { + public float getAccuracy() { return correctNum / (float) totalNum; } + /** + * @return the prediction error of model. + */ + public float getError() { + return 1 - getAccuracy(); + } + /** * @return the total number of samples that are used for evaluation. */ @@ -103,13 +111,13 @@ public int getTotalNum() { @Inject NeuralNetworkTask(final DataParser, Boolean>>> dataParser, final NeuralNetwork neuralNetwork, - @Parameter(MaxIterations.class) final int maxIterations, - final Validator validator) { + @Parameter(MaxIterations.class) final int maxIterations) { super(); this.dataParser = dataParser; this.neuralNetwork = neuralNetwork; this.maxIterations = maxIterations; - this.validator = validator; + this.trainingValidator = new Validator(neuralNetwork); + this.crossValidator = new Validator(neuralNetwork); } /** {@inheritDoc} */ @@ -125,17 +133,21 @@ public byte[] call(final byte[] bytes) throws Exception { final int label = data.getFirst().getSecond(); final boolean isValidation = data.getSecond(); if (isValidation) { - validator.validate(input, label); + crossValidator.validate(input, label); } else { neuralNetwork.train(input, label); + trainingValidator.validate(input, label); } } LOG.log(Level.INFO, "========================================================="); LOG.log(Level.INFO, "Iteration: {0}", String.valueOf(i)); - LOG.log(Level.INFO, "Result: {0}", String.valueOf(validator.getStats())); - LOG.log(Level.INFO, "# of validation inputs: {0}", String.valueOf(validator.getTotalNum())); + LOG.log(Level.INFO, "Training Error: {0}", String.valueOf(trainingValidator.getError())); + LOG.log(Level.INFO, "Cross Validation Error: {0}", String.valueOf(crossValidator.getError())); + LOG.log(Level.INFO, "# of training inputs: {0}", String.valueOf(trainingValidator.getTotalNum())); + LOG.log(Level.INFO, "# of validation inputs: {0}", String.valueOf(crossValidator.getTotalNum())); LOG.log(Level.INFO, "========================================================="); - validator.reset(); + crossValidator.reset(); + trainingValidator.reset(); } return null;