Policy gradient reimplementation progress.
dickensc committed Jan 25, 2024
1 parent 1b2fa83 commit a289884
Showing 15 changed files with 198 additions and 203 deletions.
======== changed file ========
@@ -66,10 +66,10 @@ public static enum SymbolicWeightUpdate {
     protected Map<WeightedRule, Integer> ruleIndexMap;

     protected float[] weightGradient;
-    protected float[] rvAtomGradient;
-    protected float[] deepAtomGradient;
-    protected float[] MAPRVAtomEnergyGradient;
-    protected float[] MAPDeepAtomEnergyGradient;
+    protected float[] rvGradient;
+    protected float[] deepGradient;
+    protected float[] MAPRVEnergyGradient;
+    protected float[] MAPDeepEnergyGradient;
     protected float[] epochStartWeights;
     protected float epochDeepAtomValueMovement;

@@ -135,10 +135,10 @@ public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database
        }

        weightGradient = new float[mutableRules.size()];
-       rvAtomGradient = null;
-       deepAtomGradient = null;
-       MAPRVAtomEnergyGradient = null;
-       MAPDeepAtomEnergyGradient = null;
+       rvGradient = null;
+       deepGradient = null;
+       MAPRVEnergyGradient = null;
+       MAPDeepEnergyGradient = null;

        trainingEvaluationComputePeriod = Options.WLA_GRADIENT_DESCENT_TRAINING_COMPUTE_PERIOD.getInt();
        trainFullTermStore = null;
@@ -224,7 +224,7 @@ protected void validateState() {
    }

    protected void initializeFullModels() {
-       trainFullTermStore = (SimpleTermStore<? extends ReasonerTerm>)trainInferenceApplication.getTermStore();
+       trainFullTermStore = (SimpleTermStore<? extends ReasonerTerm>) trainInferenceApplication.getTermStore();

        fullTrainingMap = trainingMap;

@@ -281,11 +281,11 @@ protected void initializeBatchWarmStarts() {
    }

    protected void initializeGradients() {
-       rvAtomGradient = new float[trainFullMAPAtomValueState.length];
-       deepAtomGradient = new float[trainFullMAPAtomValueState.length];
+       rvGradient = new float[trainFullMAPAtomValueState.length];
+       deepGradient = new float[trainFullMAPAtomValueState.length];

-       MAPRVAtomEnergyGradient = new float[trainFullMAPAtomValueState.length];
-       MAPDeepAtomEnergyGradient = new float[trainFullMAPAtomValueState.length];
+       MAPRVEnergyGradient = new float[trainFullMAPAtomValueState.length];
+       MAPDeepEnergyGradient = new float[trainFullMAPAtomValueState.length];
    }

    protected void initForLearning() {
@@ -361,10 +361,12 @@ protected void doLearn() {
                setBatch(batchId);
                DeepPredicate.predictAllDeepPredicates();

+               resetGradients();
+
                computeIterationStatistics();

                computeTotalWeightGradient();
-               computeTotalAtomGradient();
+               addTotalAtomGradient();
                if (clipWeightGradient) {
                    clipWeightGradient();
                }
@@ -499,6 +501,14 @@ protected void setFullModel() {
        }
    }

+   protected void resetGradients() {
+       Arrays.fill(weightGradient, 0.0f);
+       Arrays.fill(rvGradient, 0.0f);
+       Arrays.fill(deepGradient, 0.0f);
+       Arrays.fill(MAPRVEnergyGradient, 0.0f);
+       Arrays.fill(MAPDeepEnergyGradient, 0.0f);
+   }
+
    protected void setBatch(int batch) {
        SimpleTermStore<? extends ReasonerTerm> batchTermStore = batchGenerator.getBatchTermStore(batch);
        trainDeepModelPredicates = batchGenerator.getBatchDeepModelPredicates(batch);
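The new resetGradients() pairs with the rename of computeTotalAtomGradient to addTotalAtomGradient elsewhere in this file: the base class now zeroes every shared gradient buffer once per iteration, and the per-loss subclasses only accumulate into the pre-zeroed arrays. A minimal, self-contained sketch of that reset-then-accumulate pattern (illustrative class and method names only, not PSL code):

    import java.util.Arrays;

    // Sketch only: the driver zeroes a shared buffer once per step; contributors add into it.
    abstract class AccumulatorSketch {
        protected float[] gradient = new float[4];

        protected void resetGradients() {
            Arrays.fill(gradient, 0.0f);
        }

        // Contributors accumulate rather than overwrite, so several terms can share one buffer.
        protected abstract void addTotalGradient();

        public float[] step() {
            resetGradients();
            addTotalGradient();
            return gradient;
        }
    }

    class ExampleAccumulator extends AccumulatorSketch {
        @Override
        protected void addTotalGradient() {
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] += 0.5f;  // stand-in for an energy or loss contribution
            }
        }

        public static void main(String[] args) {
            System.out.println(Arrays.toString(new ExampleAccumulator().step()));
        }
    }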
@@ -705,7 +715,7 @@ protected void weightGradientStep(int epoch) {

    protected void atomGradientStep() {
        for (DeepPredicate deepPredicate : deepPredicates) {
-           deepPredicate.fitDeepPredicate(deepAtomGradient);
+           deepPredicate.fitDeepPredicate(deepGradient);
        }
    }

@@ -735,9 +745,9 @@ protected float computeGradientNorm() {
        }

        log.trace("Weight Gradient Norm: {}", norm);
-       log.trace("Deep atom Gradient Norm: {}", MathUtils.pNorm(deepAtomGradient, 2));
+       log.trace("Deep atom Gradient Norm: {}", MathUtils.pNorm(deepGradient, 2));

-       norm += MathUtils.pNorm(deepAtomGradient, 2);
+       norm += MathUtils.pNorm(deepGradient, 2);

        return norm;
    }
@@ -968,8 +978,6 @@ protected float computeRegularization() {
     * Compute the gradient of the regularized learning loss with respect to the weights.
     */
    protected void computeTotalWeightGradient() {
-       Arrays.fill(weightGradient, 0.0f);
-
        if (!symbolicWeightLearning) {
            return;
        }
@@ -989,11 +997,11 @@ protected void computeTotalWeightGradient() {
    protected void addRegularizationWeightGradient() {
        for (int i = 0; i < mutableRules.size(); i++) {
            float logWeight = (float)Math.log(Math.max(mutableRules.get(i).getWeight(), MathUtils.STRICT_EPSILON));
-           weightGradient[i] += 2.0f * l2Regularization * mutableRules.get(i).getWeight()
+           weightGradient[i] += (float) (2.0f * l2Regularization * mutableRules.get(i).getWeight()
                    - logRegularization / Math.max(mutableRules.get(i).getWeight(), MathUtils.STRICT_EPSILON)
-                   + entropyRegularization * (logWeight + 1);
+                   + entropyRegularization * (logWeight + 1));
        }
    }

-   protected abstract void computeTotalAtomGradient();
+   protected abstract void addTotalAtomGradient();
}
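For reference, the parenthesization fix above makes the whole update an explicit float expression; the quantity being accumulated is consistent with a per-weight regularizer of the following form (a reconstruction from the coefficients in the code; the exact computeRegularization definition is not shown in this diff):

    R(w_i) = \lambda_{2} w_i^2 - \lambda_{\log} \log w_i + \lambda_{H} \, w_i \log w_i

whose per-weight derivative is

    \frac{\partial R}{\partial w_i} = 2 \lambda_{2} w_i - \frac{\lambda_{\log}}{w_i} + \lambda_{H} (\log w_i + 1),

with \lambda_{2}, \lambda_{\log}, \lambda_{H} standing for l2Regularization, logRegularization, and entropyRegularization, and w_i clamped at STRICT_EPSILON as in the code.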
======== changed file ========
@@ -26,6 +26,8 @@
 import org.linqs.psl.reasoner.term.SimpleTermStore;
 import org.linqs.psl.util.RandUtils;
 import org.linqs.psl.util.Reflection;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;

 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
@@ -37,6 +39,8 @@
 * A batch in this case is a set of terms and corresponding atoms defining a subgraph of the complete factor graph.
 */
 public abstract class BatchGenerator {
+    private static final Logger log = LoggerFactory.getLogger(BatchGenerator.class);
+
     protected InferenceApplication inferenceApplication;
     protected SimpleTermStore<? extends ReasonerTerm> fullTermStore;
     protected AtomStore fullTruthAtomStore;
@@ -111,6 +115,8 @@ public TrainingMap getBatchTrainingMap(int index) {
    }

    public void generateBatches() {
+       log.trace("Generating batches.");
+
        clear();

        generateBatchesInternal();
======== changed file ========
@@ -545,7 +545,7 @@ private void computeMAPInferenceStatistics() {
        mapEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
        computeCurrentIncompatibility(mapIncompatibility);
        computeCurrentSquaredIncompatibility(mapSquaredIncompatibility);
-       trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVAtomEnergyGradient, MAPDeepAtomEnergyGradient);
+       trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVEnergyGradient, MAPDeepEnergyGradient);
    }

    /**
@@ -715,10 +715,7 @@ protected void addLearningLossWeightGradient() {
    }

    @Override
-   protected void computeTotalAtomGradient() {
-       Arrays.fill(rvAtomGradient, 0.0f);
-       Arrays.fill(deepAtomGradient, 0.0f);
-
+   protected void addTotalAtomGradient() {
        // Energy Loss Gradient.
        for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
            GroundAtom atom = trainInferenceApplication.getTermStore().getAtomStore().getAtom(i);
@@ -727,7 +724,7 @@ protected void computeTotalAtomGradient() {
                continue;
            }

-           deepAtomGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
+           deepGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
        }

        // Energy difference constraint gradient.
@@ -745,12 +742,12 @@ protected void computeTotalAtomGradient() {
                continue;
            }

-           float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVAtomEnergyGradient[i];
-           float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepAtomEnergyGradient[i];
+           float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVEnergyGradient[i];
+           float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepEnergyGradient[i];

-           rvAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
+           rvGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
                    + linearPenaltyCoefficient * rvEnergyGradientDifference;
-           deepAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
+           deepGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
                    + linearPenaltyCoefficient * deepAtomEnergyGradientDifference;
        }
    }
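The energy-difference block above applies the chain rule to a squared-plus-linear penalty on the constraint violation. Writing v for constraintViolation, c for squaredPenaltyCoefficient, and \mu for linearPenaltyCoefficient, and assuming v is the violation of the energy-difference constraint between the augmented (truth-informed) state and the MAP state (its computation sits outside this hunk), the accumulated per-atom gradient is

    \frac{\partial}{\partial x_i} \left( \frac{c}{2} v^2 + \mu v \right) = (c\, v + \mu) \left( \frac{\partial E_{\text{aug}}}{\partial x_i} - \frac{\partial E_{\text{MAP}}}{\partial x_i} \right),

which matches squaredPenaltyCoefficient * constraintViolation * difference + linearPenaltyCoefficient * difference in the code.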
======== changed file ========
@@ -54,8 +54,8 @@ protected float computeSupervisedLoss() {

            int proxRuleIndex = rvAtomIndexToProxRuleIndex.get(atomIndex);

-           supervisedLoss += -1.0f * (observedAtom.getValue() * Math.log(Math.max(proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))
-                   + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT)));
+           supervisedLoss += (float) (-1.0f * (observedAtom.getValue() * Math.log(Math.max(proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))
+                   + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - proxRuleObservedAtoms[proxRuleIndex].getValue(), MathUtils.EPSILON_FLOAT))));
        }

        return supervisedLoss;
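The supervised loss accumulated here is a per-atom binary cross-entropy between the observed truth value y and the corresponding proxy-rule atom value \hat{y}, with both logs clamped at EPSILON_FLOAT so that log 0 is never taken:

    L(y, \hat{y}) = -\left[ y \log \max(\hat{y}, \epsilon) + (1 - y) \log \max(1 - \hat{y}, \epsilon) \right]

The change itself only adds an explicit (float) cast: Math.log returns a double, and the cast makes the narrowing back to the float accumulator explicit rather than implicit.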
======== changed file ========
@@ -18,7 +18,6 @@
 package org.linqs.psl.application.learning.weight.gradient.onlineminimizer;

 import org.linqs.psl.application.learning.weight.gradient.GradientDescent;
-import org.linqs.psl.application.learning.weight.gradient.batchgenerator.NeuralBatchGenerator;
 import org.linqs.psl.config.Options;
 import org.linqs.psl.database.AtomStore;
 import org.linqs.psl.database.Database;
@@ -474,7 +473,7 @@ private void computeMAPInferenceStatistics() {
        mapEnergy = trainInferenceApplication.getReasoner().parallelComputeObjective(trainInferenceApplication.getTermStore()).objective;
        computeCurrentIncompatibility(mapIncompatibility);
        computeCurrentSquaredIncompatibility(mapSquaredIncompatibility);
-       trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVAtomEnergyGradient, MAPDeepAtomEnergyGradient);
+       trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVEnergyGradient, MAPDeepEnergyGradient);
    }

    /**
@@ -634,10 +633,7 @@ protected void addLearningLossWeightGradient() {
    }

    @Override
-   protected void computeTotalAtomGradient() {
-       Arrays.fill(rvAtomGradient, 0.0f);
-       Arrays.fill(deepAtomGradient, 0.0f);
-
+   protected void addTotalAtomGradient() {
        // Energy Loss Gradient.
        for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
            GroundAtom atom = trainInferenceApplication.getTermStore().getAtomStore().getAtom(i);
@@ -646,7 +642,7 @@ protected void computeTotalAtomGradient() {
                continue;
            }

-           deepAtomGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
+           deepGradient[i] += energyLossCoefficient * deepLatentAtomGradient[i];
        }

        // Energy difference constraint gradient.
@@ -664,12 +660,12 @@ protected void computeTotalAtomGradient() {
                continue;
            }

-           float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVAtomEnergyGradient[i];
-           float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepAtomEnergyGradient[i];
+           float rvEnergyGradientDifference = augmentedRVAtomEnergyGradient[i] - MAPRVEnergyGradient[i];
+           float deepAtomEnergyGradientDifference = augmentedDeepAtomEnergyGradient[i] - MAPDeepEnergyGradient[i];

-           rvAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
+           rvGradient[i] += squaredPenaltyCoefficient * constraintViolation * rvEnergyGradientDifference
                    + linearPenaltyCoefficient * rvEnergyGradientDifference;
-           deepAtomGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
+           deepGradient[i] += squaredPenaltyCoefficient * constraintViolation * deepAtomEnergyGradientDifference
                    + linearPenaltyCoefficient * deepAtomEnergyGradientDifference;
        }
    }
======== changed file ========
@@ -51,7 +51,7 @@ protected float computeSupervisedLoss() {
                continue;
            }

-           supervisedLoss += Math.pow(proxRuleObservedAtoms[rvAtomIndexToProxRuleIndex.get(atomIndex)].getValue() - observedAtom.getValue(), 2.0f);
+           supervisedLoss += (float) Math.pow(proxRuleObservedAtoms[rvAtomIndexToProxRuleIndex.get(atomIndex)].getValue() - observedAtom.getValue(), 2.0f);
        }

        return supervisedLoss;
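The analogous change here wraps the squared-error supervised loss in an explicit (float) cast, since Math.pow also returns a double; the loss itself is unchanged:

    L(y, \hat{y}) = (\hat{y} - y)^2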
======== changed file ========
@@ -20,7 +20,6 @@
 import org.linqs.psl.database.Database;
 import org.linqs.psl.model.rule.Rule;

-import java.util.Arrays;
 import java.util.List;

 /**
@@ -57,10 +56,7 @@ protected void addLearningLossWeightGradient() {
    }

    @Override
-   protected void computeTotalAtomGradient() {
-       Arrays.fill(rvAtomGradient, 0.0f);
-       Arrays.fill(deepAtomGradient, 0.0f);
-
-       System.arraycopy(deepLatentAtomGradient, 0, deepAtomGradient, 0, deepLatentAtomGradient.length);
+   protected void addTotalAtomGradient() {
+       System.arraycopy(deepLatentAtomGradient, 0, deepGradient, 0, deepLatentAtomGradient.length);
    }
}
======== changed file ========
@@ -55,7 +55,7 @@ private void computeMAPInferenceIncompatibility() {
        inTrainingMAPState = true;

        computeCurrentIncompatibility(MAPIncompatibility);
-       trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVAtomEnergyGradient, MAPDeepAtomEnergyGradient);
+       trainInferenceApplication.getReasoner().computeOptimalValueGradient(trainInferenceApplication.getTermStore(), MAPRVEnergyGradient, MAPDeepEnergyGradient);
    }

    @Override
@@ -76,13 +76,10 @@ protected void addLearningLossWeightGradient() {
    }

    @Override
-   protected void computeTotalAtomGradient() {
-       Arrays.fill(rvAtomGradient, 0.0f);
-       Arrays.fill(deepAtomGradient, 0.0f);
-
-       for (int i = 0; i < rvAtomGradient.length; i++) {
-           rvAtomGradient[i] = rvLatentAtomGradient[i] - MAPRVAtomEnergyGradient[i];
-           deepAtomGradient[i] = deepLatentAtomGradient[i] - MAPDeepAtomEnergyGradient[i];
+   protected void addTotalAtomGradient() {
+       for (int i = 0; i < rvGradient.length; i++) {
+           rvGradient[i] = rvLatentAtomGradient[i] - MAPRVEnergyGradient[i];
+           deepGradient[i] = deepLatentAtomGradient[i] - MAPDeepEnergyGradient[i];
        }
    }
}
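The rewritten addTotalAtomGradient keeps the same contrastive form it had before the rename: the per-atom gradient is the energy gradient at the latent (truth-constrained) solution minus the energy gradient at the unconstrained MAP solution,

    \frac{\partial L}{\partial x_i} = \left. \frac{\partial E}{\partial x_i} \right|_{\text{latent}} - \left. \frac{\partial E}{\partial x_i} \right|_{\text{MAP}},

assuming rvLatentAtomGradient and deepLatentAtomGradient hold the energy gradient at the latent optimum (they are computed outside this diff).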
(remaining changed files not shown)
