Policy Gradient progress
dickensc committed Jan 27, 2024
1 parent a289884 commit e22079b
Showing 5 changed files with 193 additions and 16 deletions.
@@ -963,10 +963,14 @@ protected float computeTotalLoss() {
      */
     protected float computeRegularization() {
         float regularization = 0.0f;
-        for (int i = 0; i < mutableRules.size(); i++) {
-            WeightedRule mutableRule = mutableRules.get(i);
-            float logWeight = (float)Math.max(Math.log(mutableRule.getWeight()), Math.log(MathUtils.STRICT_EPSILON));
-            regularization += l2Regularization * (float)Math.pow(mutableRule.getWeight(), 2.0f)
+
+        if (!symbolicWeightLearning) {
+            return regularization;
+        }
+
+        for (WeightedRule mutableRule : mutableRules) {
+            float logWeight = (float) Math.max(Math.log(mutableRule.getWeight()), Math.log(MathUtils.STRICT_EPSILON));
+            regularization += l2Regularization * (float) Math.pow(mutableRule.getWeight(), 2.0f)
                     - logRegularization * logWeight
                     + entropyRegularization * mutableRule.getWeight() * logWeight;
         }
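
For orientation: per rule, the new loop accumulates an L2 penalty, a negative-log penalty, and an entropy term on the rule weight. A minimal standalone sketch of that combined regularizer, with the hyperparameters passed in as plain floats (an illustration under those assumptions, not code from this commit):

    // Combined per-rule regularizer, mirroring the loop above.
    // The max() clamp keeps log(weight) finite for near-zero weights.
    static float ruleRegularization(float weight, float l2Regularization,
            float logRegularization, float entropyRegularization, float epsilon) {
        float logWeight = (float) Math.max(Math.log(weight), Math.log(epsilon));
        return l2Regularization * weight * weight
                - logRegularization * logWeight
                + entropyRegularization * weight * logWeight;
    }
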
@@ -14,6 +14,7 @@
 import org.linqs.psl.reasoner.term.SimpleTermStore;
 import org.linqs.psl.reasoner.term.TermState;
 import org.linqs.psl.util.Logger;
+import org.linqs.psl.util.MathUtils;
 import org.linqs.psl.util.RandUtils;
 
 import java.util.ArrayList;
@@ -32,11 +33,19 @@ public enum DeepAtomPolicyDistribution {
     }
 
     public enum PolicyUpdate {
-        REINFORCE
+        REINFORCE,
+        REINFORCE_BASELINE
     }
 
+    public enum RewardFunction {
+        NEGATIVE_LOSS,
+        NEGATIVE_LOSS_SQUARED,
+        INVERSE_LOSS
+    }
+
     private final DeepAtomPolicyDistribution deepAtomPolicyDistribution;
     private final PolicyUpdate policyUpdate;
+    private final RewardFunction rewardFunction;
 
     private int numSamples;
     protected int[] actionSampleCounts;
@@ -66,6 +75,7 @@ public PolicyGradient(List<Rule> rules, Database trainTargetDatabase, Database t
 
         deepAtomPolicyDistribution = DeepAtomPolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase());
         policyUpdate = PolicyUpdate.valueOf(Options.POLICY_GRADIENT_POLICY_UPDATE.getString().toUpperCase());
+        rewardFunction = RewardFunction.valueOf(Options.POLICY_GRADIENT_REWARD_FUNCTION.getString().toUpperCase());
 
         numSamples = Options.POLICY_GRADIENT_NUM_SAMPLES.getInt();
         actionSampleCounts = null;
@@ -81,7 +91,7 @@ public PolicyGradient(List<Rule> rules, Database trainTargetDatabase, Database t
         deepLatentEnergyGradient = null;
         deepSupervisedLossGradient = null;
 
-        energyLossCoefficient = Options.MINIMIZER_ENERGY_LOSS_COEFFICIENT.getFloat();
+        energyLossCoefficient = Options.POLICY_GRADIENT_ENERGY_LOSS_COEFFICIENT.getFloat();
 
         MAPStateSupervisedLoss = Float.POSITIVE_INFINITY;
 
@@ -164,20 +174,42 @@ protected void computeIterationStatistics() {
 
         switch (policyUpdate) {
             case REINFORCE:
-                addREINFORCESupervisedLossGradient();
+                addREINFORCESupervisedLossGradient(0.0f);
                 break;
+            case REINFORCE_BASELINE:
+                addREINFORCESupervisedLossGradient(MAPStateSupervisedLoss);
+                break;
             default:
                 throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
         }
     }
 
-    private void addREINFORCESupervisedLossGradient() {
+    private void addREINFORCESupervisedLossGradient(float baseline) {
         for (int i = 0; i < numSamples; i++) {
             sampleAllDeepAtomValues();
 
             computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);
 
-            addSupervisedLossDeepGradient(computeSupervisedLoss());
+            float supervisedLoss = computeSupervisedLoss();
+            float reward = 0.0f;
+
+            switch (rewardFunction) {
+                case NEGATIVE_LOSS:
+                    reward = 1.0f - supervisedLoss;
+                    break;
+                case NEGATIVE_LOSS_SQUARED:
+                    reward = (float) Math.pow(1.0f - supervisedLoss, 2.0f);
+                    break;
+                case INVERSE_LOSS:
+                    // The inverse loss may result in a reward of infinity.
+                    // Therefore, we add a small constant to the loss to avoid division by zero.
+                    reward = 1.0f / (supervisedLoss + MathUtils.EPSILON_FLOAT);
+                    break;
+                default:
+                    throw new IllegalArgumentException("Unknown reward function: " + rewardFunction);
+            }
+
+            addPolicyScoreGradient(reward - baseline);
 
             resetAllDeepAtomValues();
         }
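
For reference, the three reward functions map a supervised loss in [0, 1] to a reward. A standalone sketch, where the local enum and epsilon stand in for the commit's RewardFunction and MathUtils.EPSILON_FLOAT:

    enum Reward { NEGATIVE_LOSS, NEGATIVE_LOSS_SQUARED, INVERSE_LOSS }

    static float reward(Reward rewardFunction, float loss) {
        final float epsilon = 1e-6f;  // stand-in for MathUtils.EPSILON_FLOAT
        switch (rewardFunction) {
            case NEGATIVE_LOSS:
                return 1.0f - loss;                    // loss 0.5 -> reward 0.5
            case NEGATIVE_LOSS_SQUARED:
                return (1.0f - loss) * (1.0f - loss);  // loss 0.5 -> reward 0.25
            case INVERSE_LOSS:
                return 1.0f / (loss + epsilon);        // loss 0.5 -> reward ~2.0
            default:
                throw new IllegalArgumentException("Unknown reward function: " + rewardFunction);
        }
    }
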
@@ -188,8 +220,8 @@ private void addREINFORCESupervisedLossGradient() {
                 continue;
             }
 
-            log.trace("Atom: {} Deep Supervised Loss Gradient: {}",
-                    trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), deepSupervisedLossGradient[i]);
+//            log.trace("Atom: {} Deep Supervised Loss Gradient: {}",
+//                    trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), deepSupervisedLossGradient[i] / actionSampleCounts[i]);
             deepSupervisedLossGradient[i] /= actionSampleCounts[i];
         }
     }
@@ -304,13 +336,19 @@ protected void addLearningLossWeightGradient() {
 
     @Override
     protected void addTotalAtomGradient() {
-        for (int i = 0; i < rvGradient.length; i++) {
+        for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
             rvGradient[i] = energyLossCoefficient * rvLatentEnergyGradient[i];
             deepGradient[i] = energyLossCoefficient * deepLatentEnergyGradient[i] + deepSupervisedLossGradient[i];
+
+            if (trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).getPredicate() instanceof DeepPredicate) {
+                log.trace("Atom: {} deepLatentEnergyGradient: {}, deepSupervisedLossGradient: {}",
+                        trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(),
+                        deepLatentEnergyGradient[i], deepSupervisedLossGradient[i]);
+            }
         }
     }
 
-    private void addSupervisedLossDeepGradient(float supervisedLoss) {
+    private void addPolicyScoreGradient(float reward) {
         AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
         for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) {
             GroundAtom atom = atomStore.getAtom(atomIndex);
@@ -322,24 +360,25 @@ private void addSupervisedLossDeepGradient(float supervisedLoss) {
 
             switch (deepAtomPolicyDistribution) {
                 case CATEGORICAL:
-                    addCategoricalPolicySupervisedLossGradient(atomIndex, (RandomVariableAtom) atom, supervisedLoss);
+                    addCategoricalPolicyScoreGradient(atomIndex, (RandomVariableAtom) atom, reward);
                     break;
                 default:
                     throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution);
             }
         }
     }
 
-    private void addCategoricalPolicySupervisedLossGradient(int atomIndex, RandomVariableAtom atom, float score) {
+    private void addCategoricalPolicyScoreGradient(int atomIndex, RandomVariableAtom atom, float reward) {
         // Skip atoms not selected by the policy.
         if (atom.getValue() == 0.0f) {
            return;
         }
 
         switch (policyUpdate) {
             case REINFORCE:
+            case REINFORCE_BASELINE:
                 // The initialDeepAtomValues are the action probabilities.
-                deepSupervisedLossGradient[atomIndex] += score / initialDeepAtomValues[atomIndex];
+                deepSupervisedLossGradient[atomIndex] -= reward / initialDeepAtomValues[atomIndex];
                 break;
             default:
                 throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
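
The update above is the score-function (REINFORCE) estimator: for a categorical policy that selected an action with probability p, the derivative of log pi with respect to that action's probability is 1/p, and subtracting a baseline from the reward reduces the variance of the estimate without changing its expectation. A minimal sketch of the per-atom contribution, assuming p is the initialDeepAtomValues entry for the selected atom:

    // Score-function estimator with a variance-reducing baseline,
    // matching the -= accumulation in addCategoricalPolicyScoreGradient.
    static float policyScoreGradient(float reward, float baseline, float p) {
        return -(reward - baseline) / p;
    }
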
@@ -0,0 +1,66 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.learning.weight.gradient.policygradient;

import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

import java.util.List;
import java.util.Map;

/**
* Learns parameters for a model by minimizing a binary step loss function
* using the policy gradient learning framework.
*/
public class PolicyGradientBinaryStep extends PolicyGradient {
private static final Logger log = Logger.getLogger(PolicyGradientBinaryStep.class);

public PolicyGradientBinaryStep(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase,
Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
}

@Override
protected float computeSupervisedLoss() {
AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

float supervisedLoss = 0.0f;
for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : trainingMap.getLabelMap().entrySet()) {
RandomVariableAtom randomVariableAtom = entry.getKey();
ObservedAtom observedAtom = entry.getValue();

int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
if (atomIndex == -1) {
// This atom is not in the current batch.
continue;
}

if (!MathUtils.equals(observedAtom.getValue(), atomStore.getAtom(atomIndex).getValue())) {
supervisedLoss = 1.0f;
break;
}
}

return supervisedLoss;
}
}
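
The supervised loss above is a zero-one step over the batch: it is 1 as soon as any labeled atom's MAP value differs from its label, and 0 only when every labeled atom matches. A sketch of the same rule over plain arrays (hypothetical stand-ins for the atom store and label map):

    // Zero-one step loss: 1.0f if any prediction differs from its label.
    static float binaryStepLoss(float[] predictions, float[] labels, float tolerance) {
        for (int i = 0; i < predictions.length; i++) {
            if (Math.abs(predictions[i] - labels[i]) > tolerance) {
                return 1.0f;
            }
        }
        return 0.0f;
    }
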
13 changes: 13 additions & 0 deletions psl-core/src/main/java/org/linqs/psl/config/Options.java
@@ -715,6 +715,13 @@ public class Options {
             + " not having the atom initially in the database."
     );
 
+    public static final Option POLICY_GRADIENT_ENERGY_LOSS_COEFFICIENT = new Option(
+        "policygradient.energylosscoefficient",
+        1.0f,
+        "The coefficient of the energy loss term in the policy gradient learning framework.",
+        Option.FLAG_NON_NEGATIVE
+    );
+
     public static final Option POLICY_GRADIENT_NUM_SAMPLES = new Option(
         "policygradient.numsamples",
         10,
@@ -733,6 +740,12 @@ public class Options {
         "The policy update to use for policy gradient learning."
     );
 
+    public static final Option POLICY_GRADIENT_REWARD_FUNCTION = new Option(
+        "policygradient.rewardfunction",
+        PolicyGradient.RewardFunction.NEGATIVE_LOSS.toString(),
+        "The reward function to use for policy gradient learning."
+    );
+
     public static final Option POSTGRES_HOST = new Option(
         "postgres.host",
         "localhost",
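
The new options follow the usual PSL configuration pattern; for example, mirroring the set() calls in the test below (the specific values are illustrative, not recommendations from this commit):

    Options.POLICY_GRADIENT_ENERGY_LOSS_COEFFICIENT.set(1.0f);
    Options.POLICY_GRADIENT_REWARD_FUNCTION.set("NEGATIVE_LOSS_SQUARED");  // or NEGATIVE_LOSS, INVERSE_LOSS
    Options.POLICY_GRADIENT_POLICY_UPDATE.set("REINFORCE_BASELINE");       // or REINFORCE
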
@@ -0,0 +1,55 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.learning.weight.gradient.policygradient;

import org.junit.Before;
import org.junit.Test;
import org.linqs.psl.application.inference.mpe.DualBCDInference;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.WeightLearningTest;
import org.linqs.psl.config.Options;

public class PolicyGradientBinaryStepTest extends WeightLearningTest {
@Before
public void setup() {
super.setup();

Options.WLA_INFERENCE.set(DualBCDInference.class.getName());
Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.set(false);
}

@Override
protected WeightLearningApplication getBaseWLA() {
return new PolicyGradientBinaryStep(info.model.getRules(), trainTargetDatabase, trainTruthDatabase,
validationTargetDatabase, validationTruthDatabase, false);
}

@Test
public void DualBCDFriendshipRankTest() {
Options.WLA_INFERENCE.set(DualBCDInference.class.getName());

super.friendshipRankTest();
}

@Test
public void DistributedDualBCDFriendshipRankTest() {
        Options.WLA_INFERENCE.set(org.linqs.psl.application.inference.mpe.DistributedDualBCDInference.class.getName());

super.friendshipRankTest();
}
}
