Skip to content

Commit

Permalink
Only evaluate on batches with a non-empty training map.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Nov 30, 2023
1 parent 936aa5d commit c51799d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -516,13 +516,18 @@ protected void setValidationModel() {
}

protected void runTrainingEvaluation(int epoch) {
int numBatches = 0;
int numEvaluatedBatches = 0;
float totalTrainingEvaluation = 0.0f;

DeepPredicate.evalModeAllDeepPredicates();

int batchId = batchGenerator.epochStart();
while (!batchGenerator.isEpochComplete()) {
if (batchGenerator.getBatchTrainingMap(batchId).getLabelMap().size() <= 0) {
batchId = batchGenerator.nextBatch();
continue;
}

setBatch(batchId);
DeepPredicate.predictAllDeepPredicates();
DeepPredicate.evalAllDeepPredicates();
Expand All @@ -536,11 +541,11 @@ protected void runTrainingEvaluation(int epoch) {

batchId = batchGenerator.nextBatch();

numBatches++;
numEvaluatedBatches++;
}
batchGenerator.epochEnd();

currentTrainingEvaluationMetric = totalTrainingEvaluation / numBatches;
currentTrainingEvaluationMetric = totalTrainingEvaluation / numEvaluatedBatches;

if (currentTrainingEvaluationMetric > bestTrainingEvaluationMetric) {
lastTrainingImprovementEpoch = epoch;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.SimpleTermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.Reflection;

Expand All @@ -37,6 +38,8 @@
* A batch in this case is a set of terms and corresponding atoms defining a subgraph of the complete factor graph.
*/
public abstract class BatchGenerator {
private static final Logger log = Logger.getLogger(BatchGenerator.class);

protected InferenceApplication inferenceApplication;
protected SimpleTermStore<? extends ReasonerTerm> fullTermStore;
protected AtomStore fullTruthAtomStore;
Expand Down Expand Up @@ -136,6 +139,8 @@ public void generateBatches() {
for (int i = 0; i < numBatchTermStores(); i++) {
batchPermutation.add(i);
}

log.info("Generated " + numBatchTermStores() + " batches.");
}

protected abstract void generateBatchesInternal();
Expand Down

0 comments on commit c51799d

Please sign in to comment.