From a289884fdeba8a542c9cba1b64c3fdc84e2380df Mon Sep 17 00:00:00 2001
From: Charles Dickens
Date: Thu, 25 Jan 2024 09:56:21 -0800
Subject: [PATCH] Policy gradient reimplementation progress.

---
 .../weight/gradient/GradientDescent.java      |  52 ++--
 .../batchgenerator/BatchGenerator.java        |   6 +
 .../weight/gradient/minimizer/Minimizer.java  |  17 +-
 .../OnlineBinaryCrossEntropy.java             |   4 +-
 .../onlineminimizer/OnlineMinimizer.java      |  18 +-
 .../onlineminimizer/OnlineSquaredError.java   |   2 +-
 .../weight/gradient/optimalvalue/Energy.java  |   8 +-
 .../optimalvalue/StructuredPerceptron.java    |  13 +-
 .../policygradient/PolicyGradient.java        | 247 ++++++++----------
 .../PolicyGradientBinaryCrossEntropy.java     |  10 +-
 .../PolicyGradientSquaredError.java           |   8 +-
 .../java/org/linqs/psl/config/Options.java    |   6 +
 .../linqs/psl/model/predicate/Predicate.java  |   8 +-
 .../PolicyGradientBinaryCrossEntropyTest.java |   1 +
 .../PolicyGradientSquaredErrorTest.java       |   1 +
 15 files changed, 198 insertions(+), 203 deletions(-)

diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java
index d3b9be16b..e3966b78b 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java
@@ -66,10 +66,10 @@ public static enum SymbolicWeightUpdate {
     protected Map ruleIndexMap;
 
     protected float[] weightGradient;
-    protected float[] rvAtomGradient;
-    protected float[] deepAtomGradient;
-    protected float[] MAPRVAtomEnergyGradient;
-    protected float[] MAPDeepAtomEnergyGradient;
+    protected float[] rvGradient;
+    protected float[] deepGradient;
+    protected float[] MAPRVEnergyGradient;
+    protected float[] MAPDeepEnergyGradient;
 
     protected float[] epochStartWeights;
     protected float epochDeepAtomValueMovement;
@@ -135,10 +135,10 @@ public GradientDescent(List rules, Database trainTargetDatabase, Database
         }
 
         weightGradient = new float[mutableRules.size()];
-        rvAtomGradient = null;
-        deepAtomGradient = null;
-        MAPRVAtomEnergyGradient = null;
-        MAPDeepAtomEnergyGradient = null;
+        rvGradient = null;
+        deepGradient = null;
+        MAPRVEnergyGradient = null;
+        MAPDeepEnergyGradient = null;
 
         trainingEvaluationComputePeriod = Options.WLA_GRADIENT_DESCENT_TRAINING_COMPUTE_PERIOD.getInt();
         trainFullTermStore = null;
@@ -224,7 +224,7 @@ protected void validateState() {
     }
 
     protected void initializeFullModels() {
-        trainFullTermStore = (SimpleTermStore)trainInferenceApplication.getTermStore();
+        trainFullTermStore = (SimpleTermStore) trainInferenceApplication.getTermStore();
 
         fullTrainingMap = trainingMap;
 
@@ -281,11 +281,11 @@ protected void initializeBatchWarmStarts() {
     }
 
     protected void initializeGradients() {
-        rvAtomGradient = new float[trainFullMAPAtomValueState.length];
-        deepAtomGradient = new float[trainFullMAPAtomValueState.length];
+        rvGradient = new float[trainFullMAPAtomValueState.length];
+        deepGradient = new float[trainFullMAPAtomValueState.length];
 
-        MAPRVAtomEnergyGradient = new float[trainFullMAPAtomValueState.length];
-        MAPDeepAtomEnergyGradient = new float[trainFullMAPAtomValueState.length];
+        MAPRVEnergyGradient = new float[trainFullMAPAtomValueState.length];
+        MAPDeepEnergyGradient = new float[trainFullMAPAtomValueState.length];
     }
 
     protected void initForLearning() {
@@ -361,10 +361,12 @@ protected void doLearn() {
                 setBatch(batchId);
                 DeepPredicate.predictAllDeepPredicates();
 
+                resetGradients();
+
                 computeIterationStatistics();
 
                 computeTotalWeightGradient();
-                computeTotalAtomGradient();
+                addTotalAtomGradient();
                 if (clipWeightGradient) {
                     clipWeightGradient();
                 }
@@ -499,6 +501,14 @@ protected void setFullModel() {
         }
     }
 
+    protected void resetGradients() {
+        Arrays.fill(weightGradient, 0.0f);
+        Arrays.fill(rvGradient, 0.0f);
+        Arrays.fill(deepGradient, 0.0f);
+        Arrays.fill(MAPRVEnergyGradient, 0.0f);
+        Arrays.fill(MAPDeepEnergyGradient, 0.0f);
+    }
+
     protected void setBatch(int batch) {
         SimpleTermStore batchTermStore = batchGenerator.getBatchTermStore(batch);
         trainDeepModelPredicates = batchGenerator.getBatchDeepModelPredicates(batch);
@@ -705,7 +715,7 @@ protected void weightGradientStep(int epoch) {
 
     protected void atomGradientStep() {
         for (DeepPredicate deepPredicate : deepPredicates) {
-            deepPredicate.fitDeepPredicate(deepAtomGradient);
+            deepPredicate.fitDeepPredicate(deepGradient);
         }
     }
@@ -735,9 +745,9 @@ protected float computeGradientNorm() {
         }
 
         log.trace("Weight Gradient Norm: {}", norm);
-        log.trace("Deep atom Gradient Norm: {}", MathUtils.pNorm(deepAtomGradient, 2));
+        log.trace("Deep atom Gradient Norm: {}", MathUtils.pNorm(deepGradient, 2));
 
-        norm += MathUtils.pNorm(deepAtomGradient, 2);
+        norm += MathUtils.pNorm(deepGradient, 2);
 
         return norm;
     }
@@ -968,8 +978,6 @@ protected float computeRegularization() {
      * Compute the gradient of the regularized learning loss with respect to the weights.
      */
    protected void computeTotalWeightGradient() {
-        Arrays.fill(weightGradient, 0.0f);
-
         if (!symbolicWeightLearning) {
             return;
         }
@@ -989,11 +997,11 @@ protected void computeTotalWeightGradient() {
     protected void addRegularizationWeightGradient() {
         for (int i = 0; i < mutableRules.size(); i++) {
             float logWeight = (float)Math.log(Math.max(mutableRules.get(i).getWeight(), MathUtils.STRICT_EPSILON));
-            weightGradient[i] += 2.0f * l2Regularization * mutableRules.get(i).getWeight()
+            weightGradient[i] += (float) (2.0f * l2Regularization * mutableRules.get(i).getWeight()
                     - logRegularization / Math.max(mutableRules.get(i).getWeight(), MathUtils.STRICT_EPSILON)
-                    + entropyRegularization * (logWeight + 1);
+                    + entropyRegularization * (logWeight + 1));
         }
     }
 
-    protected abstract void computeTotalAtomGradient();
+    protected abstract void addTotalAtomGradient();
 }
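
Two patterns in the GradientDescent changes above are worth spelling out. First, the rename computeTotalAtomGradient() -> addTotalAtomGradient() marks a contract change: the new resetGradients() is now the single place where gradient buffers are zeroed (note the Arrays.fill() dropped from computeTotalWeightGradient()), so every compute/add hook is accumulate-only. Second, addRegularizationWeightGradient() now casts the whole sum to float because Math.log() returns a double; the expression is the derivative d/dw [l2*w^2 - logReg*log(w) + ent*w*log(w)] = 2*l2*w - logReg/w + ent*(log(w) + 1). A minimal, self-contained sketch of both points (illustrative names, not the PSL classes in this patch):

    import java.util.Arrays;

    // Sketch: one learning iteration under the reset-then-accumulate contract.
    // All names are illustrative stand-ins, not the PSL API.
    public class GradientContractSketch {
        public static void main(String[] args) {
            float[] weights = {1.0f, 2.0f, 0.5f};
            float[] weightGradient = new float[weights.length];

            for (int epoch = 0; epoch < 2; epoch++) {
                // resetGradients(): the single place where gradient buffers are zeroed.
                Arrays.fill(weightGradient, 0.0f);

                // computeTotalWeightGradient() / addTotalAtomGradient(): accumulate only (+=).
                addRegularizationGradient(weightGradient, weights, 0.01f, 0.0f, 0.0f);

                System.out.println(Arrays.toString(weightGradient));
            }
        }

        // d/dw [l2*w^2 - logReg*log(w) + ent*w*log(w)] = 2*l2*w - logReg/w + ent*(log(w) + 1).
        // The (float) cast mirrors the patch: Math.log() returns a double.
        private static void addRegularizationGradient(float[] gradient, float[] weights, float l2, float logReg, float ent) {
            for (int i = 0; i < gradient.length; i++) {
                float w = Math.max(weights[i], 1e-6f);
                gradient[i] += (float) (2.0f * l2 * w - logReg / w + ent * (Math.log(w) + 1.0));
            }
        }
    }

Separating reset from accumulation lets subclasses layer gradient contributions without clobbering each other, which the PolicyGradient rework below relies on.
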
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java
index fbaed727e..5019cbe5d 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/batchgenerator/BatchGenerator.java
@@ -26,6 +26,8 @@
 import org.linqs.psl.reasoner.term.SimpleTermStore;
 import org.linqs.psl.util.RandUtils;
 import org.linqs.psl.util.Reflection;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
@@ -37,6 +39,8 @@
  * A batch in this case is a set of terms and corresponding atoms defining a subgraph of the complete factor graph.
  */
 public abstract class BatchGenerator {
+    private static final Logger log = LoggerFactory.getLogger(BatchGenerator.class);
+
     protected InferenceApplication inferenceApplication;
     protected SimpleTermStore fullTermStore;
     protected AtomStore fullTruthAtomStore;
@@ -111,6 +115,8 @@ public TrainingMap getBatchTrainingMap(int index) {
     }
 
     public void generateBatches() {
+        log.trace("Generating batches.");
+
         clear();
 
         generateBatchesInternal();
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java
index 2b3819b41..b48098dcb 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/minimizer/Minimizer.java
@@ -545,7 +545,7 @@ private void computeMAPInferenceStatistics() {
         mapEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
         computeCurrentIncompatibility(mapIncompatibility);
         computeCurrentSquaredIncompatibility(mapSquaredIncompatibility);
-        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVAtomEnergyGradient, MAPDeepAtomEnergyGradient);
+        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVEnergyGradient, MAPDeepEnergyGradient);
     }
 
     /**
@@ -715,10 +715,7 @@ protected void addLearningLossWeightGradient() {
     }
 
     @Override
-    protected void computeTotalAtomGradient() {
-        Arrays.fill(rvAtomGradient, 0.0f);
-        Arrays.fill(deepAtomGradient, 0.0f);
-
+    protected void addTotalAtomGradient() {
         // Energy Loss Gradient.
         for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
             GroundAtom atom = trainInferenceApplication.getTermStore().getAtomStore().getAtom(i);
@@ -727,7 +724,7 @@ protected void computeTotalAtomGradient() {
                 continue;
             }
 
-            deepAtomGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
+            deepGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
         }
 
         // Energy difference constraint gradient.
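
For readers skimming the Minimizer hunks: the atom gradient being accumulated is the derivative of a squared-plus-linear penalty on the energy-difference constraint. With violation c and penalty coefficients rho (squared) and lambda (linear), d/dy [(rho/2)*c^2 + lambda*c] = (rho*c + lambda) * (gradAug - gradMap), which is exactly the two-term sum in the loop below. A self-contained sketch (stand-in values, not the PSL API):

    import java.util.Arrays;

    // Sketch: per-atom gradient of the penalty (rho/2)*c^2 + lambda*c on the energy-difference
    // constraint. augGrad/mapGrad stand in for the augmented and MAP energy gradients.
    public class PenaltyGradientSketch {
        public static void main(String[] args) {
            float[] augGrad = {0.3f, -0.1f};
            float[] mapGrad = {0.1f, 0.2f};
            float[] atomGradient = new float[2];

            float violation = 0.5f;  // c(y): current violation of the energy-difference constraint
            float rho = 2.0f;        // squaredPenaltyCoefficient
            float lambda = 0.1f;     // linearPenaltyCoefficient

            for (int i = 0; i < atomGradient.length; i++) {
                float diff = augGrad[i] - mapGrad[i];  // dc/dy_i
                atomGradient[i] += rho * violation * diff + lambda * diff;
            }
            System.out.println(Arrays.toString(atomGradient));
        }
    }

The OnlineMinimizer hunks further down apply the same rename and the same penalty gradient.
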
@@ -745,12 +742,12 @@ protected void computeTotalAtomGradient() {
                 continue;
             }
 
-            float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVAtomEnergyGradient[i];
-            float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepAtomEnergyGradient[i];
+            float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVEnergyGradient[i];
+            float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepEnergyGradient[i];
 
-            rvAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
+            rvGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
                     + linearPenaltyCoefficient * rvEnergyGradientDifference;
-            deepAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
+            deepGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
                     + linearPenaltyCoefficient * deepAtomEnergyGradientDifference;
         }
     }
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineBinaryCrossEntropy.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineBinaryCrossEntropy.java
index 0481f6948..893f2788a 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineBinaryCrossEntropy.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineBinaryCrossEntropy.java
@@ -54,8 +54,8 @@ protected float computeSupervisedLoss() {
 
             int proxRuleIndex = rvAtomIndexToProxRuleIndex.get(atomIndex);
 
-            supervisedLoss += -1.0f * (observedAtom.getValue() * Math.log(Math.max(proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))
-                    + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT)));
+            supervisedLoss += (float) (-1.0f * (observedAtom.getValue() * Math.log(Math.max(proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))
+                    + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))));
         }
 
         return supervisedLoss;
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineMinimizer.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineMinimizer.java
index 83f6038b4..c6baa19a0 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineMinimizer.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineMinimizer.java
@@ -18,7 +18,6 @@
 package org.linqs.psl.application.learning.weight.gradient.onlineminimizer;
 
 import org.linqs.psl.application.learning.weight.gradient.GradientDescent;
-import org.linqs.psl.application.learning.weight.gradient.batchgenerator.NeuralBatchGenerator;
 import org.linqs.psl.config.Options;
 import org.linqs.psl.database.AtomStore;
 import org.linqs.psl.database.Database;
@@ -474,7 +473,7 @@ private void computeMAPInferenceStatistics() {
         mapEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
         computeCurrentIncompatibility(mapIncompatibility);
         computeCurrentSquaredIncompatibility(mapSquaredIncompatibility);
-        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVAtomEnergyGradient, MAPDeepAtomEnergyGradient);
+        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVEnergyGradient, MAPDeepEnergyGradient);
     }
 
     /**
@@ -634,10 +633,7 @@ protected void addLearningLossWeightGradient() {
     }
 
     @Override
-    protected void computeTotalAtomGradient() {
-        Arrays.fill(rvAtomGradient, 0.0f);
-        Arrays.fill(deepAtomGradient, 0.0f);
-
+    protected void addTotalAtomGradient() {
         // Energy Loss Gradient.
         for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
             GroundAtom atom = trainInferenceApplication.getTermStore().getAtomStore().getAtom(i);
@@ -646,7 +642,7 @@ protected void computeTotalAtomGradient() {
                 continue;
             }
 
-            deepAtomGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
+            deepGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
         }
 
         // Energy difference constraint gradient.
@@ -664,12 +660,12 @@ protected void computeTotalAtomGradient() {
                 continue;
             }
 
-            float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVAtomEnergyGradient[i];
-            float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepAtomEnergyGradient[i];
+            float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVEnergyGradient[i];
+            float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepEnergyGradient[i];
 
-            rvAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
+            rvGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
                     + linearPenaltyCoefficient * rvEnergyGradientDifference;
-            deepAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
+            deepGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
                     + linearPenaltyCoefficient * deepAtomEnergyGradientDifference;
         }
     }
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineSquaredError.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineSquaredError.java
index 10ea1e6dc..f5518bcf6 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineSquaredError.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/onlineminimizer/OnlineSquaredError.java
@@ -51,7 +51,7 @@ protected float computeSupervisedLoss() {
                 continue;
             }
 
-            supervisedLoss += Math.pow(proxRuleObservedAtoms[rvAtomIndexToProxRuleIndex.get(atomIndex)].getValue() - observedAtom.getValue(), 2.0f);
+            supervisedLoss += (float) Math.pow(proxRuleObservedAtoms[rvAtomIndexToProxRuleIndex.get(atomIndex)].getValue() - observedAtom.getValue(), 2.0f);
         }
 
         return supervisedLoss;
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/Energy.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/Energy.java
index 35e9b1222..281653539 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/Energy.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/Energy.java
@@ -20,7 +20,6 @@
 import org.linqs.psl.database.Database;
 import org.linqs.psl.model.rule.Rule;
 
-import java.util.Arrays;
 import
 java.util.List;
 
 /**
@@ -57,10 +56,7 @@ protected void addLearningLossWeightGradient() {
     }
 
     @Override
-    protected void computeTotalAtomGradient() {
-        Arrays.fill(rvAtomGradient, 0.0f);
-        Arrays.fill(deepAtomGradient, 0.0f);
-
-        System.arraycopy(deepLatentAtomGradient, 0, deepAtomGradient, 0, deepLatentAtomGradient.length);
+    protected void addTotalAtomGradient() {
+        System.arraycopy(deepLatentAtomGradient, 0, deepGradient, 0, deepLatentAtomGradient.length);
     }
 }
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/StructuredPerceptron.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/StructuredPerceptron.java
index 5680f6705..2b8abab33 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/StructuredPerceptron.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/optimalvalue/StructuredPerceptron.java
@@ -55,7 +55,7 @@ private void computeMAPInferenceIncompatibility() {
         inTrainingMAPState = true;
 
         computeCurrentIncompatibility(MAPIncompatibility);
-        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVAtomEnergyGradient, MAPDeepAtomEnergyGradient);
+        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVEnergyGradient, MAPDeepEnergyGradient);
     }
 
     @Override
@@ -76,13 +76,10 @@ protected void addLearningLossWeightGradient() {
     }
 
     @Override
-    protected void computeTotalAtomGradient() {
-        Arrays.fill(rvAtomGradient, 0.0f);
-        Arrays.fill(deepAtomGradient, 0.0f);
-
-        for (int i = 0; i < rvAtomGradient.length; i++) {
-            rvAtomGradient[i] = rvLatentAtomGradient[i] - MAPRVAtomEnergyGradient[i];
-            deepAtomGradient[i] = deepLatentAtomGradient[i] - MAPDeepAtomEnergyGradient[i];
+    protected void addTotalAtomGradient() {
+        for (int i = 0; i < rvGradient.length; i++) {
+            rvGradient[i] = rvLatentAtomGradient[i] - MAPRVEnergyGradient[i];
+            deepGradient[i] = deepLatentAtomGradient[i] - MAPDeepEnergyGradient[i];
         }
     }
 }
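
StructuredPerceptron's addTotalAtomGradient() is the classic perceptron difference: the energy gradient at the label-clamped (latent) optimum minus the gradient at the unconstrained MAP optimum, which vanishes wherever the two states already agree. A tiny stand-in sketch (not the PSL API):

    import java.util.Arrays;

    // Sketch: structured perceptron atom update, latent minus MAP energy gradient.
    public class PerceptronGradientSketch {
        public static void main(String[] args) {
            float[] latentGrad = {0.4f, 0.0f, -0.2f};  // gradient at the label-clamped optimum
            float[] mapGrad = {0.4f, 0.3f, -0.5f};     // gradient at the unconstrained MAP optimum
            float[] gradient = new float[3];

            for (int i = 0; i < gradient.length; i++) {
                gradient[i] = latentGrad[i] - mapGrad[i];  // zero where the two states already agree
            }
            System.out.println(Arrays.toString(gradient));  // [0.0, -0.3, 0.3]
        }
    }
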
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java
index 49f09961f..d14958c6a 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java
@@ -8,6 +8,7 @@
 import org.linqs.psl.model.atom.ObservedAtom;
 import org.linqs.psl.model.atom.RandomVariableAtom;
 import org.linqs.psl.model.deep.DeepModelPredicate;
+import org.linqs.psl.model.predicate.DeepPredicate;
 import org.linqs.psl.model.rule.Rule;
 import org.linqs.psl.reasoner.term.ReasonerTerm;
 import org.linqs.psl.reasoner.term.SimpleTermStore;
@@ -31,21 +32,16 @@ public enum DeepAtomPolicyDistribution {
     }
 
     public enum PolicyUpdate {
-        SCORE,
-        REINFORCE,
-        REINFORCE_BASELINE,
+        REINFORCE
     }
 
     private final DeepAtomPolicyDistribution deepAtomPolicyDistribution;
     private final PolicyUpdate policyUpdate;
 
-    private float score;
-    private float[] scores;
-    private float scoreMovingAverage;
-    private float[] sampleProbabilities;
+    private int numSamples;
+    protected int[] actionSampleCounts;
 
     protected float[] initialDeepAtomValues;
-    protected float[] policySampledDeepAtomValues;
 
     protected float latentInferenceEnergy;
     protected float[] latentInferenceIncompatibility;
@@ -53,12 +49,15 @@ public enum PolicyUpdate {
     protected float[] latentInferenceAtomValueState;
     protected List<TermState[]> batchLatentInferenceTermStates;
     protected List<float[]> batchLatentInferenceAtomValueStates;
-    protected float[] rvLatentAtomGradient;
-    protected float[] deepLatentAtomGradient;
+    protected float[] rvLatentEnergyGradient;
+    protected float[] deepLatentEnergyGradient;
+    protected float[] deepSupervisedLossGradient;
+
+    protected float energyLossCoefficient;
+
+    protected float MAPStateSupervisedLoss;
 
     protected float mapEnergy;
-    protected float supervisedLoss;
     protected float[] mapIncompatibility;
 
     public PolicyGradient(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase,
@@ -68,14 +67,9 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t
         deepAtomPolicyDistribution = DeepAtomPolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase());
         policyUpdate = PolicyUpdate.valueOf(Options.POLICY_GRADIENT_POLICY_UPDATE.getString().toUpperCase());
 
-        score = 0.0f;
-        scores = null;
-        scoreMovingAverage = 0.0f;
-
-        sampleProbabilities = null;
-
+        numSamples = Options.POLICY_GRADIENT_NUM_SAMPLES.getInt();
+        actionSampleCounts = null;
         initialDeepAtomValues = null;
-        policySampledDeepAtomValues = null;
 
         latentInferenceEnergy = Float.POSITIVE_INFINITY;
         latentInferenceIncompatibility = new float[mutableRules.size()];
@@ -83,12 +77,15 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t
         latentInferenceAtomValueState = null;
         batchLatentInferenceTermStates = new ArrayList<TermState[]>();
         batchLatentInferenceAtomValueStates = new ArrayList<float[]>();
-        rvLatentAtomGradient = null;
-        deepLatentAtomGradient = null;
+        rvLatentEnergyGradient = null;
+        deepLatentEnergyGradient = null;
+        deepSupervisedLossGradient = null;
+
+        energyLossCoefficient = Options.MINIMIZER_ENERGY_LOSS_COEFFICIENT.getFloat();
+
+        MAPStateSupervisedLoss = Float.POSITIVE_INFINITY;
 
         mapEnergy = Float.POSITIVE_INFINITY;
-        supervisedLoss = Float.POSITIVE_INFINITY;
         mapIncompatibility = new float[mutableRules.size()];
     }
 
@@ -97,17 +94,15 @@ protected void initForLearning() {
         super.initForLearning();
 
         if (symbolicWeightLearning){
-            throw new IllegalArgumentException("Policy Gradient does not support symbolic weight learning.");
+            throw new IllegalArgumentException("Policy Gradient does not currently support symbolic weight learning.");
         }
-
-        scoreMovingAverage = Float.POSITIVE_INFINITY;
     }
 
-    protected abstract void computeSupervisedLoss();
+    protected abstract float computeSupervisedLoss();
 
     @Override
     protected float computeLearningLoss() {
-        return supervisedLoss + energyLossCoefficient * latentInferenceEnergy;
+        return MAPStateSupervisedLoss + energyLossCoefficient * latentInferenceEnergy;
     }
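
For reference, the estimator behind the now REINFORCE-only PolicyUpdate (notation mine, not part of the patch): with a categorical policy pi_theta over a deep atom's categories and supervised loss L,

    \nabla_{\theta} \, \mathbb{E}_{a \sim \pi_{\theta}}[L(a)]
        = \mathbb{E}_{a \sim \pi_{\theta}}\left[ L(a) \, \nabla_{\theta} \log \pi_{\theta}(a) \right]
        \approx \frac{1}{N} \sum_{k=1}^{N} L(a_k) \, \nabla_{\theta} \log \pi_{\theta}(a_k).

Since d log(p)/dp = 1/p, the derivative with respect to the predicted probability of the sampled category is the sampled loss divided by that probability. That is exactly what the per-atom gradient further down accumulates, using the saved initialDeepAtomValues as the probabilities and averaging over numSamples draws (the new policygradient.numsamples option) via actionSampleCounts.
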
@@ -125,8 +120,22 @@ protected void initializeBatchWarmStarts() {
     protected void initializeGradients() {
         super.initializeGradients();
 
-        rvLatentAtomGradient = new float[trainFullMAPAtomValueState.length];
-        deepLatentAtomGradient = new float[trainFullMAPAtomValueState.length];
+        rvLatentEnergyGradient = new float[trainFullMAPAtomValueState.length];
+        deepLatentEnergyGradient = new float[trainFullMAPAtomValueState.length];
+        deepSupervisedLossGradient = new float[trainFullMAPAtomValueState.length];
+
+        actionSampleCounts = new int[trainFullMAPAtomValueState.length];
+    }
+
+    @Override
+    protected void resetGradients() {
+        super.resetGradients();
+
+        Arrays.fill(rvLatentEnergyGradient, 0.0f);
+        Arrays.fill(deepLatentEnergyGradient, 0.0f);
+        Arrays.fill(deepSupervisedLossGradient, 0.0f);
+
+        Arrays.fill(actionSampleCounts, 0);
     }
 
@@ -136,73 +145,72 @@ protected void setBatch(int batch) {
         latentInferenceTermState = batchLatentInferenceTermStates.get(batch);
         latentInferenceAtomValueState = batchLatentInferenceAtomValueStates.get(batch);
 
         initialDeepAtomValues = new float[trainInferenceApplication.getTermStore().getAtomStore().size()];
-        policySampledDeepAtomValues = new float[trainInferenceApplication.getTermStore().getAtomStore().size()];
-        scores = new float[trainInferenceApplication.getTermStore().getAtomStore().size()];
     }
 
     @Override
     protected void computeIterationStatistics() {
-        Arrays.fill(scores, 0.0f);
+        AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
 
         computeMAPInferenceStatistics();
-        computeSupervisedLoss();
+
+        MAPStateSupervisedLoss = computeSupervisedLoss();
+
         computeLatentInferenceStatistics();
 
-        score = computeScore();
+        // Save the initial deep model predictions to reset the deep atom values after computing iteration statistics
+        // and to compute action probabilities.
+        System.arraycopy(atomStore.getAtomValues(), 0, initialDeepAtomValues, 0, atomStore.size());
 
-        if (policyUpdate == PolicyUpdate.SCORE) {
-            return;
+        switch (policyUpdate) {
+            case REINFORCE:
+                addREINFORCESupervisedLossGradient();
+                break;
+            default:
+                throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
+        }
+    }
+
+    private void addREINFORCESupervisedLossGradient() {
+        for (int i = 0; i < numSamples; i++) {
+            sampleAllDeepAtomValues();
+
+            computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);
+
+            addSupervisedLossDeepGradient(computeSupervisedLoss());
+
+            resetAllDeepAtomValues();
         }
 
+        for (int i = 0; i < deepSupervisedLossGradient.length; i++) {
+            if (actionSampleCounts[i] == 0) {
+                deepSupervisedLossGradient[i] = 0.0f;
+                continue;
+            }
+
+            log.trace("Atom: {} Deep Supervised Loss Gradient: {}",
+                    trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), deepSupervisedLossGradient[i]);
+            deepSupervisedLossGradient[i] /= actionSampleCounts[i];
+        }
+    }
+
+    /**
+     * Sample all deep atom values according to a policy parameterized by the deep model predictions.
+     */
+    protected void sampleAllDeepAtomValues() {
         for (DeepModelPredicate deepModelPredicate : trainDeepModelPredicates) {
             Map<String, ArrayList<RandomVariableAtom>> atomIdentiferToCategories = deepModelPredicate.getAtomIdentiferToCategories();
             for (Map.Entry<String, ArrayList<RandomVariableAtom>> entry : atomIdentiferToCategories.entrySet()) {
-                ArrayList<RandomVariableAtom> categories = entry.getValue();
-
-                sampleDeepAtomValues(categories);
-
-                computeMAPInferenceStatistics();
-                computeSupervisedLoss();
-                computeLatentInferenceStatistics();
-
-                float sampleScore = computeScore();
-
-                for (RandomVariableAtom category : categories) {
-                    int atomIndex = trainInferenceApplication.getTermStore().getAtomStore().getAtomIndex(category);
-                    if (policySampledDeepAtomValues[atomIndex] == 1.0f) {
-                        scores[atomIndex] = score - sampleScore;
-//                        log.trace("Deep Atom: {} Score: {}",
-//                                trainInferenceApplication.getTermStore().getAtomStore().getAtom(atomIndex), scores[atomIndex]);
-                    } else {
-                        scores[atomIndex] = 0.0f;
-                    }
-                }
-
-                resetDeepAtomValues(categories);
+                sampleDeepAtomValues(entry.getValue());
             }
         }
 
-        updateScoreMovingAverage();
-
-        computeMAPInferenceStatistics();
-        computeSupervisedLoss();
-        computeLatentInferenceStatistics();
+        inTrainingMAPState = false;
     }
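
The loop above is easier to see in isolation. This sketch mirrors its accumulate-then-normalize steps for a single three-category deep atom group: draw a category from the predicted distribution, add loss/probability at the sampled index, count the draw, and finally divide each entry by the number of draws that selected it (hypothetical stand-in code, not the PSL API):

    import java.util.Arrays;
    import java.util.Random;

    // Sketch: Monte Carlo REINFORCE for one three-category deep atom group.
    public class ReinforceSketch {
        public static void main(String[] args) {
            float[] probabilities = {0.2f, 0.5f, 0.3f};  // deep model predictions (saved initial values)
            float[] lossIfChosen = {1.0f, 0.2f, 0.6f};   // supervised loss observed when that category is sampled
            float[] gradient = new float[3];
            int[] sampleCounts = new int[3];

            Random rand = new Random(42);
            int numSamples = 10;

            for (int k = 0; k < numSamples; k++) {
                int action = sample(probabilities, rand);
                // Accumulate loss / probability only at the sampled category, and count the draw.
                gradient[action] += lossIfChosen[action] / probabilities[action];
                sampleCounts[action]++;
            }

            // Normalize each entry over the draws that actually selected it.
            for (int i = 0; i < gradient.length; i++) {
                if (sampleCounts[i] > 0) {
                    gradient[i] /= sampleCounts[i];
                }
            }
            System.out.println(Arrays.toString(gradient));
        }

        // Draw an index from a categorical distribution by inverting the CDF.
        private static int sample(float[] probabilities, Random rand) {
            float r = rand.nextFloat();
            float cumulative = 0.0f;
            for (int i = 0; i < probabilities.length; i++) {
                cumulative += probabilities[i];
                if (r < cumulative) {
                    return i;
                }
            }
            return probabilities.length - 1;
        }
    }

In the patch itself, each draw additionally costs one warm-started MAP inference, which is why the number of samples is worth tuning.
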
 
     /**
      * Sample the deep atom values according to a policy parameterized by the deep model predictions.
      */
     protected void sampleDeepAtomValues(ArrayList<RandomVariableAtom> categories) {
-        // Save the initial deep model predictions to reset the deep atom values after computing iteration statistics.
-        AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
-
-        Arrays.fill(initialDeepAtomValues, 0.0f);
-        Arrays.fill(policySampledDeepAtomValues, 0.0f);
-
-        for (RandomVariableAtom category : categories) {
-            int atomIndex = atomStore.getAtomIndex(category);
-            initialDeepAtomValues[atomIndex] = atomStore.getAtomValues()[atomIndex];
-        }
-
         switch (deepAtomPolicyDistribution) {
             case CATEGORICAL:
                 sampleCategorical(categories);
@@ -216,8 +224,6 @@ private void sampleCategorical(ArrayList categories) {
         AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
         float[] atomValues = atomStore.getAtomValues();
 
-        sampleProbabilities = new float[atomStore.size()];
-
         // Sample the deep model predictions according to the stochastic categorical policy.
         float[] categoryProbabilities = new float[categories.size()];
         for (int i = 0; i < categories.size(); i++) {
@@ -232,23 +238,31 @@ private void sampleCategorical(ArrayList categories) {
             if (i != sampledCategoryIndex) {
                 categories.get(i).setValue(0.0f);
             } else {
-                sampleProbabilities[atomIndex] = categoryProbabilities[i];
                 categories.get(i).setValue(1.0f);
+                actionSampleCounts[atomIndex]++;
             }
 
-            policySampledDeepAtomValues[atomIndex] = categories.get(i).getValue();
             atomValues[atomIndex] = categories.get(i).getValue();
         }
     }
 
-    private void resetDeepAtomValues(ArrayList<RandomVariableAtom> categories) {
+    /**
+     * Reset all deep atom values to their initial values.
+     */
+    private void resetAllDeepAtomValues() {
         AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
         float[] atomValues = atomStore.getAtomValues();
 
-        for (RandomVariableAtom category : categories) {
-            int atomIndex = atomStore.getAtomIndex(category);
-            category.setValue(initialDeepAtomValues[atomIndex]);
-            atomValues[atomIndex] = initialDeepAtomValues[atomIndex];
+        for (int i = 0; i < atomStore.size(); i++) {
+            GroundAtom atom = atomStore.getAtom(i);
+
+            // Skip atoms that are not DeepAtoms.
+            if (!((atom instanceof RandomVariableAtom) && (atom.getPredicate() instanceof DeepPredicate))) {
+                continue;
+            }
+
+            ((RandomVariableAtom) atom).setValue(initialDeepAtomValues[i]);
+            atomValues[i] = initialDeepAtomValues[i];
         }
     }
 
@@ -262,7 +276,7 @@ private void computeMAPInferenceStatistics() {
 
         mapEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
         computeCurrentIncompatibility(mapIncompatibility);
-        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVAtomEnergyGradient, MAPDeepAtomEnergyGradient);
+        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVEnergyGradient, MAPDeepEnergyGradient);
     }
 
     /**
@@ -278,7 +292,7 @@ protected void computeLatentInferenceStatistics() {
 
         latentInferenceEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
         computeCurrentIncompatibility(latentInferenceIncompatibility);
-        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), rvLatentAtomGradient, deepLatentAtomGradient);
+        trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), rvLatentEnergyGradient, deepLatentEnergyGradient);
 
         unfixLabeledRandomVariables();
     }
@@ -289,21 +303,26 @@ protected void addLearningLossWeightGradient() {
     }
 
     @Override
-    protected void computeTotalAtomGradient() {
-        Arrays.fill(rvAtomGradient, 0.0f);
-        Arrays.fill(deepAtomGradient, 0.0f);
+    protected void addTotalAtomGradient() {
+        for (int i = 0; i < rvGradient.length; i++) {
+            rvGradient[i] = energyLossCoefficient * rvLatentEnergyGradient[i];
+            deepGradient[i] = energyLossCoefficient * deepLatentEnergyGradient[i] + deepSupervisedLossGradient[i];
+        }
+    }
 
+    private void addSupervisedLossDeepGradient(float supervisedLoss) {
         AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
-        for (int i = 0; i < atomStore.size(); i++) {
-            GroundAtom atom = atomStore.getAtom(i);
+
+        for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) {
+            GroundAtom atom = atomStore.getAtom(atomIndex);
 
-            if (atom instanceof ObservedAtom) {
+            // Skip atoms that are not DeepAtoms.
+            if (!((atom instanceof RandomVariableAtom) && (atom.getPredicate() instanceof DeepPredicate))) {
                 continue;
             }
 
             switch (deepAtomPolicyDistribution) {
                 case CATEGORICAL:
-                    computeCategoricalAtomGradient(i);
+                    addCategoricalPolicySupervisedLossGradient(atomIndex, (RandomVariableAtom) atom, supervisedLoss);
                     break;
                 default:
                     throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution);
             }
         }
     }
 
@@ -311,60 +330,22 @@ protected void computeTotalAtomGradient() {
-    private void computeCategoricalAtomGradient(int atomIndex) {
-        if (policySampledDeepAtomValues[atomIndex] == 0.0f) {
-            deepAtomGradient[atomIndex] = 0.0f;
+    private void addCategoricalPolicySupervisedLossGradient(int atomIndex, RandomVariableAtom atom, float score) {
+        // Skip atoms not selected by the policy.
+        if (atom.getValue() == 0.0f) {
             return;
         }
 
         switch (policyUpdate) {
-            case SCORE:
-                deepAtomGradient[atomIndex] = score;
-                break;
             case REINFORCE:
-                deepAtomGradient[atomIndex] -= scores[atomIndex] / sampleProbabilities[atomIndex];
-                break;
-            case REINFORCE_BASELINE:
-                deepAtomGradient[atomIndex] -= (scores[atomIndex] - scoreMovingAverage) / sampleProbabilities[atomIndex];
+                // The initialDeepAtomValues are the action probabilities.
+                deepSupervisedLossGradient[atomIndex] += score / initialDeepAtomValues[atomIndex];
                 break;
             default:
                 throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
         }
-
-//        log.trace("Deep Atom: {} Score: {}, Deep Atom Gradient: {}",
-//                trainInferenceApplication.getTermStore().getAtomStore().getAtom(atomIndex), scores[atomIndex], deepAtomGradient[atomIndex]);
-    }
-
-    private float computeScore() {
-//        log.trace("Latent Inference Energy: " + latentInferenceEnergy + " Supervised Loss: " + supervisedLoss);
-        return energyLossCoefficient * latentInferenceEnergy + supervisedLoss;
-    }
-
-    private void updateScoreMovingAverage() {
-        float scoreAverage = 0.0f;
-        int numScores = 0;
-        for (DeepModelPredicate deepModelPredicate : trainDeepModelPredicates) {
-            Map<String, ArrayList<RandomVariableAtom>> atomIdentiferToCategories = deepModelPredicate.getAtomIdentiferToCategories();
-            for (Map.Entry<String, ArrayList<RandomVariableAtom>> entry : atomIdentiferToCategories.entrySet()) {
-                ArrayList<RandomVariableAtom> categories = entry.getValue();
-
-                for (RandomVariableAtom category : categories) {
-                    int atomIndex = trainInferenceApplication.getTermStore().getAtomStore().getAtomIndex(category);
-                    scoreAverage += scores[atomIndex];
-                    numScores += 1;
-                }
-            }
-        }
-        scoreAverage /= numScores;
-
-        if (!Float.isInfinite(scoreMovingAverage)) {
-            scoreMovingAverage = 0.9f * scoreMovingAverage + 0.1f * scoreAverage;
-        } else {
-            scoreMovingAverage = scoreAverage;
-        }
     }
 
-
     /**
      * Set RandomVariableAtoms with labels to their observed (truth) value.
      * This method relies on random variable atoms and observed atoms
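
Once deepSupervisedLossGradient has been estimated, the new addTotalAtomGradient() above combines it with the latent-inference energy gradient: random-variable atoms receive only the scaled energy term, while deep atoms receive the energy term plus the REINFORCE estimate. A stand-in sketch of that combination (not the PSL API):

    import java.util.Arrays;

    // Sketch: how addTotalAtomGradient() combines the two gradient sources.
    public class TotalGradientSketch {
        public static void main(String[] args) {
            float energyLossCoefficient = 0.1f;
            float[] rvLatentEnergyGradient = {0.2f, -0.4f};
            float[] deepLatentEnergyGradient = {0.1f, 0.3f};
            float[] deepSupervisedLossGradient = {1.5f, -0.7f};

            float[] rvGradient = new float[2];
            float[] deepGradient = new float[2];

            for (int i = 0; i < rvGradient.length; i++) {
                // Random-variable atoms: scaled energy gradient only.
                rvGradient[i] = energyLossCoefficient * rvLatentEnergyGradient[i];
                // Deep atoms: scaled energy gradient plus the REINFORCE supervised-loss estimate.
                deepGradient[i] = energyLossCoefficient * deepLatentEnergyGradient[i] + deepSupervisedLossGradient[i];
            }
            System.out.println(Arrays.toString(deepGradient));
        }
    }
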
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java
index 799f91beb..a6091ca65 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java
@@ -41,10 +41,10 @@ public PolicyGradientBinaryCrossEntropy(List rules, Database trainTargetDa
     }
 
     @Override
-    protected void computeSupervisedLoss() {
+    protected float computeSupervisedLoss() {
         AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
 
-        supervisedLoss = 0.0f;
+        float supervisedLoss = 0.0f;
         int numEvaluatedAtoms = 0;
         for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getLabelMap().entrySet()) {
             RandomVariableAtom randomVariableAtom = entry.getKey();
@@ -56,8 +56,8 @@ protected void computeSupervisedLoss() {
                 continue;
             }
 
-            supervisedLoss += -1.0f * (observedAtom.getValue() * Math.log(Math.max(atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT))
-                    + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT)));
+            supervisedLoss += (float) (-1.0f * (observedAtom.getValue() * Math.log(Math.max(atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT))
+                    + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT))));
 
             numEvaluatedAtoms++;
         }
@@ -65,5 +65,7 @@ protected void computeSupervisedLoss() {
         if (numEvaluatedAtoms > 0) {
             supervisedLoss /= numEvaluatedAtoms;
         }
+
+        return supervisedLoss;
     }
 }
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java
index 77b356d38..0c0a12a71 100644
--- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java
@@ -37,10 +37,10 @@ public PolicyGradientSquaredError(List rules, Database trainTargetDatabase
     }
 
     @Override
-    protected void computeSupervisedLoss() {
+    protected float computeSupervisedLoss() {
         AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
 
-        supervisedLoss = 0.0f;
+        float supervisedLoss = 0.0f;
         int numEvaluatedAtoms = 0;
         for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
             RandomVariableAtom randomVariableAtom = entry.getKey();
@@ -52,12 +52,14 @@ protected void computeSupervisedLoss() {
                 continue;
             }
 
-            supervisedLoss += Math.pow(atomStore.getAtom(atomIndex).getValue() - observedAtom.getValue(), 2.0f);
+            supervisedLoss += (float) Math.pow(atomStore.getAtom(atomIndex).getValue() - observedAtom.getValue(), 2.0f);
             numEvaluatedAtoms++;
         }
 
         if (numEvaluatedAtoms > 0) {
             supervisedLoss /= numEvaluatedAtoms;
         }
+
+        return supervisedLoss;
     }
 }
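
Both loss classes now return a per-atom mean instead of writing to a shared field, and both need the explicit (float) casts since Math.log() and Math.pow() return double. A standalone rendering of the binary cross-entropy they compute (epsilon value and names are mine; the patch uses MathUtils.EPSILON_FLOAT):

    // Sketch: mean binary cross-entropy over labeled atoms, as computed above.
    public class BceSketch {
        private static final float EPSILON = 1.0e-8f;

        public static void main(String[] args) {
            float[] predictions = {0.9f, 0.2f, 0.6f};
            float[] labels = {1.0f, 0.0f, 1.0f};
            System.out.println(meanBinaryCrossEntropy(predictions, labels));
        }

        private static float meanBinaryCrossEntropy(float[] predictions, float[] labels) {
            float loss = 0.0f;
            for (int i = 0; i < predictions.length; i++) {
                // Clamp away from 0 and 1 so the logs stay finite, and cast back since Math.log() returns double.
                loss += (float) -(labels[i] * Math.log(Math.max(predictions[i], EPSILON))
                        + (1.0f - labels[i]) * Math.log(Math.max(1.0f - predictions[i], EPSILON)));
            }
            return (predictions.length > 0) ? (loss / predictions.length) : 0.0f;
        }
    }
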
diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java
index 2224c35fb..ec8066a07 100644
--- a/psl-core/src/main/java/org/linqs/psl/config/Options.java
+++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java
@@ -715,6 +715,12 @@ public class Options {
             + " not having the atom initially in the database."
     );
 
+    public static final Option POLICY_GRADIENT_NUM_SAMPLES = new Option(
+        "policygradient.numsamples",
+        10,
+        "The number of samples to use to estimate each gradient."
+    );
+
     public static final Option POLICY_GRADIENT_POLICY_DISTRIBUTION = new Option(
         "policygradient.policydistribution",
         PolicyGradient.DeepAtomPolicyDistribution.CATEGORICAL.toString(),
diff --git a/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java b/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java
index d4c31e030..eaae9b8ff 100644
--- a/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java
+++ b/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java
@@ -124,15 +124,17 @@ public Map getPredicateOptions() {
     public void setPredicateOption(String name, Object option) {
         options.put(name, option);
 
-        if (name.equals("Integer") && Boolean.parseBoolean(option.toString())) {
+        String lowerCaseName = name.toLowerCase();
+
+        if (lowerCaseName.equals("integer") && Boolean.parseBoolean(option.toString())) {
             integer = true;
         }
 
-        if (name.equals("Categorical") && Boolean.parseBoolean(option.toString())) {
+        if (lowerCaseName.equals("categorical") && Boolean.parseBoolean(option.toString())) {
             categorical = true;
         }
 
-        if (name.equals("CategoricalIndexes")) {
+        if (lowerCaseName.equals("categoricalindexes")) {
             categoryIndexes = StringUtils.splitInt(option.toString(), DELIM);
             for (int categoryIndex : categoryIndexes) {
                 identifierIndexes[categoryIndex] = -1;
diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java
index 1fff7e688..c224cf766 100644
--- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java
+++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java
@@ -30,6 +30,7 @@ public void setup() {
         super.setup();
 
         Options.WLA_INFERENCE.set(DualBCDInference.class.getName());
+        Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.set(false);
     }
 
     @Override
diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredErrorTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredErrorTest.java
index 9433b0457..624d6b143 100644
--- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredErrorTest.java
+++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredErrorTest.java
@@ -30,6 +30,7 @@ public void setup() {
         super.setup();
 
         Options.WLA_INFERENCE.set(DualBCDInference.class.getName());
+        Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.set(false);
     }
 
     @Override
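
Taken together with the Predicate change (option names like "Integer", "integer", and "INTEGER" now match case-insensitively), a typical configuration of the reimplemented learner looks like this (a usage sketch; the two option constants are the ones touched in this patch, everything else assumes the existing PSL Options API):

    import org.linqs.psl.config.Options;

    // Sketch: configuring the policy-gradient learner before training.
    public class PolicyGradientConfigSketch {
        public static void main(String[] args) {
            // More samples lower the variance of each REINFORCE estimate at the cost of
            // one extra MAP inference per sample (default introduced above: 10).
            Options.POLICY_GRADIENT_NUM_SAMPLES.set(25);

            // Policy gradient trains only the deep predicates, so the tests switch
            // symbolic (rule-weight) learning off.
            Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.set(false);
        }
    }
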