From e22079b2a3a641ec4e97f4437552ba8e7317fdac Mon Sep 17 00:00:00 2001 From: Charles Dickens Date: Fri, 26 Jan 2024 18:01:49 -0800 Subject: [PATCH] Policy Gradient progress --- .../weight/gradient/GradientDescent.java | 12 ++-- .../policygradient/PolicyGradient.java | 63 ++++++++++++++---- .../PolicyGradientBinaryStep.java | 66 +++++++++++++++++++ .../java/org/linqs/psl/config/Options.java | 13 ++++ .../PolicyGradientBinaryStepTest.java | 55 ++++++++++++++++ 5 files changed, 193 insertions(+), 16 deletions(-) create mode 100644 psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStep.java create mode 100644 psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStepTest.java diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java index e3966b78b..96daf7de1 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/GradientDescent.java @@ -963,10 +963,14 @@ protected float computeTotalLoss() { */ protected float computeRegularization() { float regularization = 0.0f; - for (int i = 0; i < mutableRules.size(); i++) { - WeightedRule mutableRule = mutableRules.get(i); - float logWeight = (float)Math.max(Math.log(mutableRule.getWeight()), Math.log(MathUtils.STRICT_EPSILON)); - regularization += l2Regularization * (float)Math.pow(mutableRule.getWeight(), 2.0f) + + if (!symbolicWeightLearning) { + return regularization; + } + + for (WeightedRule mutableRule : mutableRules) { + float logWeight = (float) Math.max(Math.log(mutableRule.getWeight()), Math.log(MathUtils.STRICT_EPSILON)); + regularization += l2Regularization * (float) Math.pow(mutableRule.getWeight(), 2.0f) - logRegularization * logWeight + entropyRegularization * mutableRule.getWeight() * logWeight; } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java index d14958c6a..415049a69 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java @@ -14,6 +14,7 @@ import org.linqs.psl.reasoner.term.SimpleTermStore; import org.linqs.psl.reasoner.term.TermState; import org.linqs.psl.util.Logger; +import org.linqs.psl.util.MathUtils; import org.linqs.psl.util.RandUtils; import java.util.ArrayList; @@ -32,11 +33,19 @@ public enum DeepAtomPolicyDistribution { } public enum PolicyUpdate { - REINFORCE + REINFORCE, + REINFORCE_BASELINE + } + + public enum RewardFunction { + NEGATIVE_LOSS, + NEGATIVE_LOSS_SQUARED, + INVERSE_LOSS } private final DeepAtomPolicyDistribution deepAtomPolicyDistribution; private final PolicyUpdate policyUpdate; + private final RewardFunction rewardFunction; private int numSamples; protected int[] actionSampleCounts; @@ -66,6 +75,7 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t deepAtomPolicyDistribution = DeepAtomPolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase()); policyUpdate = 
PolicyUpdate.valueOf(Options.POLICY_GRADIENT_POLICY_UPDATE.getString().toUpperCase()); + rewardFunction = RewardFunction.valueOf(Options.POLICY_GRADIENT_REWARD_FUNCTION.getString().toUpperCase()); numSamples = Options.POLICY_GRADIENT_NUM_SAMPLES.getInt(); actionSampleCounts = null; @@ -81,7 +91,7 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t deepLatentEnergyGradient = null; deepSupervisedLossGradient = null; - energyLossCoefficient = Options.MINIMIZER_ENERGY_LOSS_COEFFICIENT.getFloat(); + energyLossCoefficient = Options.POLICY_GRADIENT_ENERGY_LOSS_COEFFICIENT.getFloat(); MAPStateSupervisedLoss = Float.POSITIVE_INFINITY; @@ -164,20 +174,42 @@ protected void computeIterationStatistics() { switch (policyUpdate) { case REINFORCE: - addREINFORCESupervisedLossGradient(); + addREINFORCESupervisedLossGradient(0.0f); + break; + case REINFORCE_BASELINE: + addREINFORCESupervisedLossGradient(MAPStateSupervisedLoss); break; default: throw new IllegalArgumentException("Unknown policy update: " + policyUpdate); } } - private void addREINFORCESupervisedLossGradient() { + private void addREINFORCESupervisedLossGradient(float baseline) { for (int i = 0; i < numSamples; i++) { sampleAllDeepAtomValues(); computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState); - addSupervisedLossDeepGradient(computeSupervisedLoss()); + float supervisedLoss = computeSupervisedLoss(); + float reward = 0.0f; + + switch (rewardFunction) { + case NEGATIVE_LOSS: + reward = 1.0f - supervisedLoss; + break; + case NEGATIVE_LOSS_SQUARED: + reward = (float) Math.pow(1.0f - supervisedLoss, 2.0f); + break; + case INVERSE_LOSS: + // The inverse loss may result in a reward of infinity. + // Therefore, we add a small constant to the loss to avoid division by zero. 
+ reward = 1.0f / (supervisedLoss + MathUtils.EPSILON_FLOAT); + break; + default: + throw new IllegalArgumentException("Unknown reward function: " + rewardFunction); + } + + addPolicyScoreGradient(reward - baseline); resetAllDeepAtomValues(); } @@ -188,8 +220,8 @@ private void addREINFORCESupervisedLossGradient() { continue; } - log.trace("Atom: {} Deep Supervised Loss Gradient: {}", - trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), deepSupervisedLossGradient[i]); +// log.trace("Atom: {} Deep Supervised Loss Gradient: {}", +// trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), deepSupervisedLossGradient[i] / actionSampleCounts[i]); deepSupervisedLossGradient[i] /= actionSampleCounts[i]; } } @@ -304,13 +336,19 @@ protected void addLearningLossWeightGradient() { @Override protected void addTotalAtomGradient() { - for (int i = 0; i < rvGradient.length; i++) { + for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) { rvGradient[i] = energyLossCoefficient * rvLatentEnergyGradient[i]; deepGradient[i] = energyLossCoefficient * deepLatentEnergyGradient[i] + deepSupervisedLossGradient[i]; + + if (trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).getPredicate() instanceof DeepPredicate) { + log.trace("Atom: {} deepLatentEnergyGradient: {}, deepSupervisedLossGradient: {}", + trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), + deepLatentEnergyGradient[i], deepSupervisedLossGradient[i]); + } } } - private void addSupervisedLossDeepGradient(float supervisedLoss) { + private void addPolicyScoreGradient(float reward) { AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) { GroundAtom atom = atomStore.getAtom(atomIndex); @@ -322,7 +360,7 @@ private void addSupervisedLossDeepGradient(float supervisedLoss) { switch (deepAtomPolicyDistribution) { case CATEGORICAL: - addCategoricalPolicySupervisedLossGradient(atomIndex, (RandomVariableAtom) atom, supervisedLoss); + addCategoricalPolicyScoreGradient(atomIndex, (RandomVariableAtom) atom, reward); break; default: throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution); @@ -330,7 +368,7 @@ private void addSupervisedLossDeepGradient(float supervisedLoss) { } } - private void addCategoricalPolicySupervisedLossGradient(int atomIndex, RandomVariableAtom atom, float score) { + private void addCategoricalPolicyScoreGradient(int atomIndex, RandomVariableAtom atom, float reward) { // Skip atoms not selected by the policy. if (atom.getValue() == 0.0f) { return; @@ -338,8 +376,9 @@ private void addCategoricalPolicySupervisedLossGradient(int atomIndex, RandomVar switch (policyUpdate) { case REINFORCE: + case REINFORCE_BASELINE: // The initialDeepAtomValues are the action probabilities. 
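+                // REINFORCE score function: for a categorical policy, d log p(a) / d p(a) = 1 / p(a),
+                // so the sampled action's (baseline-shifted) reward is scaled by 1 / p(a) below.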
-                deepSupervisedLossGradient[atomIndex] += score / initialDeepAtomValues[atomIndex];
+                deepSupervisedLossGradient[atomIndex] -= reward / initialDeepAtomValues[atomIndex];
                 break;
             default:
                 throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStep.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStep.java
new file mode 100644
index 000000000..abc4f0aa0
--- /dev/null
+++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStep.java
@@ -0,0 +1,66 @@
+/*
+ * This file is part of the PSL software.
+ * Copyright 2011-2015 University of Maryland
+ * Copyright 2013-2023 The Regents of the University of California
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.linqs.psl.application.learning.weight.gradient.policygradient;
+
+import org.linqs.psl.database.AtomStore;
+import org.linqs.psl.database.Database;
+import org.linqs.psl.model.atom.ObservedAtom;
+import org.linqs.psl.model.atom.RandomVariableAtom;
+import org.linqs.psl.model.rule.Rule;
+import org.linqs.psl.util.Logger;
+import org.linqs.psl.util.MathUtils;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Learns parameters for a model by minimizing a binary step (zero-one) loss function
+ * using the policy gradient learning framework.
+ */
+public class PolicyGradientBinaryStep extends PolicyGradient {
+    private static final Logger log = Logger.getLogger(PolicyGradientBinaryStep.class);
+
+    public PolicyGradientBinaryStep(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase,
+                                    Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
+        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
+    }
+
+    @Override
+    protected float computeSupervisedLoss() {
+        AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
+
+        float supervisedLoss = 0.0f;
+        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getLabelMap().entrySet()) {
+            RandomVariableAtom randomVariableAtom = entry.getKey();
+            ObservedAtom observedAtom = entry.getValue();
+
+            int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
+            if (atomIndex == -1) {
+                // This atom is not in the current batch.
+                continue;
+            }
+
+            if (!MathUtils.equals(observedAtom.getValue(), atomStore.getAtom(atomIndex).getValue())) {
+                supervisedLoss = 1.0f;
+                break;
+            }
+        }
+
+        return supervisedLoss;
+    }
+}
diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java
index ec8066a07..1b49a25cd 100644
--- a/psl-core/src/main/java/org/linqs/psl/config/Options.java
+++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java
@@ -715,6 +715,13 @@ public class Options {
             + " not having the atom initially in the database."
     );
 
+    public static final Option POLICY_GRADIENT_ENERGY_LOSS_COEFFICIENT = new Option(
+        "policygradient.energylosscoefficient",
+        1.0f,
+        "The coefficient of the energy loss term in the policy gradient learning framework.",
+        Option.FLAG_NON_NEGATIVE
+    );
+
     public static final Option POLICY_GRADIENT_NUM_SAMPLES = new Option(
         "policygradient.numsamples",
         10,
@@ -733,6 +740,12 @@ public class Options {
         "The policy update to use for policy gradient learning."
     );
 
+    public static final Option POLICY_GRADIENT_REWARD_FUNCTION = new Option(
+        "policygradient.rewardfunction",
+        PolicyGradient.RewardFunction.NEGATIVE_LOSS.toString(),
+        "The reward function to use for policy gradient learning."
+    );
+
     public static final Option POSTGRES_HOST = new Option(
         "postgres.host",
         "localhost",
diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStepTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStepTest.java
new file mode 100644
index 000000000..ff1f1a267
--- /dev/null
+++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStepTest.java
@@ -0,0 +1,55 @@
+/*
+ * This file is part of the PSL software.
+ * Copyright 2011-2015 University of Maryland
+ * Copyright 2013-2023 The Regents of the University of California
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.linqs.psl.application.learning.weight.gradient.policygradient;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.linqs.psl.application.inference.mpe.DualBCDInference;
+import org.linqs.psl.application.learning.weight.WeightLearningApplication;
+import org.linqs.psl.application.learning.weight.WeightLearningTest;
+import org.linqs.psl.config.Options;
+
+public class PolicyGradientBinaryStepTest extends WeightLearningTest {
+    @Before
+    public void setup() {
+        super.setup();
+
+        Options.WLA_INFERENCE.set(DualBCDInference.class.getName());
+        Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.set(false);
+    }
+
+    @Override
+    protected WeightLearningApplication getBaseWLA() {
+        return new PolicyGradientBinaryStep(info.model.getRules(), trainTargetDatabase, trainTruthDatabase,
+                validationTargetDatabase, validationTruthDatabase, false);
+    }
+
+    @Test
+    public void DualBCDFriendshipRankTest() {
+        Options.WLA_INFERENCE.set(DualBCDInference.class.getName());
+
+        super.friendshipRankTest();
+    }
+
+    @Test
+    public void DistributedDualBCDFriendshipRankTest() {
+        Options.WLA_INFERENCE.set(org.linqs.psl.application.inference.mpe.DistributedDualBCDInference.class.getName());
+
+        super.friendshipRankTest();
+    }
+}
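
For reference, the sketch below restates the REINFORCE update introduced in PolicyGradient.java as a small, self-contained program: a sampled action's supervised loss is mapped to a reward (NEGATIVE_LOSS, NEGATIVE_LOSS_SQUARED, or INVERSE_LOSS), an optional baseline is subtracted (REINFORCE_BASELINE uses the MAP-state loss), and the score-function gradient for the sampled action is accumulated and then averaged over samples. Everything below is illustrative only and not part of the PSL API; the class, method names, and the EPSILON constant (a stand-in for MathUtils.EPSILON_FLOAT) are assumptions made for this example.

import java.util.Random;

// Self-contained illustration of the REINFORCE update sketched in this patch. Not part of PSL.
public class ReinforceRewardSketch {
    // Assumed stand-in for MathUtils.EPSILON_FLOAT.
    private static final float EPSILON = 1.0e-6f;

    enum RewardFunction { NEGATIVE_LOSS, NEGATIVE_LOSS_SQUARED, INVERSE_LOSS }

    // Map a supervised loss to a reward, mirroring the reward function switch above.
    static float computeReward(float supervisedLoss, RewardFunction rewardFunction) {
        switch (rewardFunction) {
            case NEGATIVE_LOSS:
                return 1.0f - supervisedLoss;
            case NEGATIVE_LOSS_SQUARED:
                return (float) Math.pow(1.0f - supervisedLoss, 2.0f);
            case INVERSE_LOSS:
                // Add a small constant so a zero loss does not yield an infinite reward.
                return 1.0f / (supervisedLoss + EPSILON);
            default:
                throw new IllegalArgumentException("Unknown reward function: " + rewardFunction);
        }
    }

    // Monte Carlo estimate of the policy score gradient for one categorical deep atom.
    // probabilities are the policy's action probabilities (the "initial deep atom values");
    // lossPerAction is the supervised loss observed when each action is taken.
    static float[] estimateGradient(float[] probabilities, float[] lossPerAction,
            RewardFunction rewardFunction, float baseline, int numSamples, Random rng) {
        float[] gradient = new float[probabilities.length];
        int[] sampleCounts = new int[probabilities.length];

        for (int sample = 0; sample < numSamples; sample++) {
            int action = sampleCategorical(probabilities, rng);
            sampleCounts[action]++;

            float reward = computeReward(lossPerAction[action], rewardFunction);

            // Score-function term: descending this gradient moves probability mass toward high-reward actions.
            gradient[action] -= (reward - baseline) / probabilities[action];
        }

        // Average each entry over the number of times its action was sampled.
        for (int i = 0; i < gradient.length; i++) {
            if (sampleCounts[i] > 0) {
                gradient[i] /= sampleCounts[i];
            }
        }

        return gradient;
    }

    // Draw one action index from a categorical distribution.
    static int sampleCategorical(float[] probabilities, Random rng) {
        float draw = rng.nextFloat();
        float cumulative = 0.0f;
        for (int i = 0; i < probabilities.length; i++) {
            cumulative += probabilities[i];
            if (draw <= cumulative) {
                return i;
            }
        }
        return probabilities.length - 1;
    }

    public static void main(String[] args) {
        float[] probabilities = {0.2f, 0.5f, 0.3f};
        float[] lossPerAction = {1.0f, 0.0f, 1.0f};

        float[] gradient = estimateGradient(probabilities, lossPerAction,
                RewardFunction.NEGATIVE_LOSS, 0.0f, 100, new Random(42));

        for (int i = 0; i < gradient.length; i++) {
            System.out.println("action " + i + ": gradient " + gradient[i]);
        }
    }
}

Because the baseline is computed from the MAP state before sampling, subtracting it leaves the expected gradient unchanged while typically reducing its variance; with a binary step loss as in PolicyGradientBinaryStep, each lossPerAction entry is simply 0.0 or 1.0.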