From c51799d8702bc6987b2fcdf8ea688355a42d51fa Mon Sep 17 00:00:00 2001 From: Charles Dickens Date: Thu, 30 Nov 2023 14:24:39 -0800 Subject: [PATCH] Only evaluate on batches with a non-empty training map. --- .../learning/weight/gradient/GradientDescent.java | 11 ++++++++--- .../gradient/batchgenerator/BatchGenerator.java | 5 +++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index 67d58d893..6aa4e8f4d 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -516,13 +516,18 @@ protected void setValidationModel() { } protected void runTrainingEvaluation(int epoch) { - int numBatches = 0; + int numEvaluatedBatches = 0; float totalTrainingEvaluation = 0.0f; DeepPredicate.evalModeAllDeepPredicates(); int batchId = batchGenerator.epochStart(); while (!batchGenerator.isEpochComplete()) { + if (batchGenerator.getBatchTrainingMap(batchId).getLabelMap().size() <= 0) { + batchId = batchGenerator.nextBatch(); + continue; + } + setBatch(batchId); DeepPredicate.predictAllDeepPredicates(); DeepPredicate.evalAllDeepPredicates(); @@ -536,11 +541,11 @@ protected void runTrainingEvaluation(int epoch) { batchId = batchGenerator.nextBatch(); - numBatches++; + numEvaluatedBatches++; } batchGenerator.epochEnd(); - currentTrainingEvaluationMetric = totalTrainingEvaluation / numBatches; + currentTrainingEvaluationMetric = totalTrainingEvaluation / numEvaluatedBatches; if (currentTrainingEvaluationMetric > bestTrainingEvaluationMetric) { lastTrainingImprovementEpoch = epoch; diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java index fbaed727e..ad0363e9b 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java @@ -24,6 +24,7 @@ import org.linqs.psl.model.predicate.DeepPredicate; import org.linqs.psl.reasoner.term.ReasonerTerm; import org.linqs.psl.reasoner.term.SimpleTermStore; +import org.linqs.psl.util.Logger; import org.linqs.psl.util.RandUtils; import org.linqs.psl.util.Reflection; @@ -37,6 +38,8 @@ * A batch in this case is a set of terms and corresponding atoms defining a subgraph of the complete factor graph. */ public abstract class BatchGenerator { + private static final Logger log = Logger.getLogger(BatchGenerator.class); + protected InferenceApplication inferenceApplication; protected SimpleTermStore fullTermStore; protected AtomStore fullTruthAtomStore; @@ -136,6 +139,8 @@ public void generateBatches() { for (int i = 0; i < numBatchTermStores(); i++) { batchPermutation.add(i); } + + log.info("Generated " + numBatchTermStores() + " batches."); } protected abstract void generateBatchesInternal();