Skip to content

Commit

Permalink
Clean up policy gradient implementation and add BCE loss.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Dec 4, 2023
1 parent cca7ee5 commit c2a6e69
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.linqs.psl.reasoner.term.SimpleTermStore;
import org.linqs.psl.reasoner.term.TermState;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.RandUtils;

import java.util.ArrayList;
Expand All @@ -28,17 +27,20 @@
public abstract class PolicyGradient extends GradientDescent {
private static final Logger log = Logger.getLogger(PolicyGradient.class);

public enum PolicyDistribution {
CATEGORICAL,
GUMBEL_SOFTMAX
public enum DeepAtomPolicyDistribution {
CATEGORICAL
}

private final PolicyDistribution policyDistribution;
public enum PolicyUpdate {
REINFORCE,
REINFORCE_BASELINE,
}

private float lossMovingAverage;
private float[] sampleProbabilities;
private final DeepAtomPolicyDistribution deepAtomPolicyDistribution;
private final PolicyUpdate policyUpdate;

private final float gumbelSoftmaxTemperature;
private float scoreMovingAverage;
private float[] sampleProbabilities;

protected float[] initialDeepAtomValues;
protected float[] policySampledDeepAtomValues;
Expand All @@ -61,13 +63,12 @@ public PolicyGradient(List<Rule> rules, Database trainTargetDatabase, Database t
Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);

policyDistribution = PolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase());
lossMovingAverage = 0.0f;
deepAtomPolicyDistribution = DeepAtomPolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase());
policyUpdate = PolicyUpdate.valueOf(Options.POLICY_GRADIENT_POLICY_UPDATE.getString().toUpperCase());
scoreMovingAverage = 0.0f;

sampleProbabilities = null;

gumbelSoftmaxTemperature = Options.POLICY_GRADIENT_GUMBEL_SOFTMAX_TEMPERATURE.getFloat();

initialDeepAtomValues = null;
policySampledDeepAtomValues = null;

Expand All @@ -90,7 +91,7 @@ public PolicyGradient(List<Rule> rules, Database trainTargetDatabase, Database t
protected void initForLearning() {
super.initForLearning();

lossMovingAverage = 0.0f;
scoreMovingAverage = 0.0f;
}

protected abstract void computeSupervisedLoss();
Expand Down Expand Up @@ -161,15 +162,12 @@ protected void sampleDeepAtomValues() {
}
}

switch (policyDistribution) {
switch (deepAtomPolicyDistribution) {
case CATEGORICAL:
sampleCategorical();
break;
case GUMBEL_SOFTMAX:
sampleGumbelSoftmax();
break;
default:
throw new IllegalArgumentException("Unknown policy distribution: " + policyDistribution);
throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution);
}
}

Expand Down Expand Up @@ -200,7 +198,6 @@ private void sampleCategorical() {
} else {
sampleProbabilities[atomIndex] = categoryProbabilities[i];
categories.get(i).setValue(1.0f);
// log.trace("Sampled category: {}.", categories.get(i).toStringWithValue());
}

policySampledDeepAtomValues[atomIndex] = categories.get(i).getValue();
Expand All @@ -210,43 +207,6 @@ private void sampleCategorical() {
}
}

private void sampleGumbelSoftmax() {
AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
float[] atomValues = atomStore.getAtomValues();

// Sample the deep model predictions according to the stochastic gumbel softmax policy.
for (DeepModelPredicate deepModelPredicate : trainDeepModelPredicates) {

Map <String, ArrayList<RandomVariableAtom>> atomIdentiferToCategories = deepModelPredicate.getAtomIdentiferToCategories();
for (Map.Entry<String, ArrayList<RandomVariableAtom>> entry : atomIdentiferToCategories.entrySet()) {
String atomIdentifier = entry.getKey();
ArrayList<RandomVariableAtom> categories = entry.getValue();

float[] gumbelSoftmaxSample = new float[categories.size()];
float categoryProbabilitySum = 0.0f;
for (int i = 0; i < categories.size(); i++) {
float gumbelSample = RandUtils.nextGumbel();

gumbelSoftmaxSample[i] = (float) Math.exp((Math.log(Math.max(categories.get(i).getValue(), MathUtils.STRICT_EPSILON)) + gumbelSample)
/ gumbelSoftmaxTemperature);
categoryProbabilitySum += gumbelSoftmaxSample[i];
}

// Renormalize the probabilities and set the deep atom values.
for (int i = 0; i < categories.size(); i++) {
gumbelSoftmaxSample[i] = gumbelSoftmaxSample[i] / categoryProbabilitySum;

RandomVariableAtom category = categories.get(i);
int atomIndex = atomStore.getAtomIndex(category);

category.setValue(gumbelSoftmaxSample[i]);
policySampledDeepAtomValues[atomIndex] = gumbelSoftmaxSample[i];
atomValues[atomIndex] = gumbelSoftmaxSample[i];
}
}
}
}

private void resetDeepAtomValues() {
AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
float[] atomValues = atomStore.getAtomValues();
Expand Down Expand Up @@ -315,36 +275,37 @@ protected void computeTotalAtomGradient() {
continue;
}

switch (policyDistribution) {
switch (deepAtomPolicyDistribution) {
case CATEGORICAL:
computeCategoricalAtomGradient(atom, i);
break;
case GUMBEL_SOFTMAX:
computeGumbelSoftmaxAtomGradient(atom, i);
computeCategoricalAtomGradient(i);
break;
default:
throw new IllegalArgumentException("Unknown policy distribution: " + policyDistribution);
throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution);
}
}
}

private void computeCategoricalAtomGradient(GroundAtom atom, int atomIndex) {
private void computeCategoricalAtomGradient(int atomIndex) {
if (policySampledDeepAtomValues[atomIndex] == 0.0f) {
deepAtomGradient[atomIndex] = 0.0f;
return;
}

float score = energyLossCoefficient * latentInferenceEnergy + supervisedLoss;
lossMovingAverage = 0.9f * lossMovingAverage + 0.1f * score;
scoreMovingAverage = 0.9f * scoreMovingAverage + 0.1f * score;

deepAtomGradient[atomIndex] += (score - lossMovingAverage) / sampleProbabilities[atomIndex];

// log.trace("Atom: {}, Score: {}, Gradient: {}.", atom.toStringWithValue(), score, deepAtomGradient[atomIndex]);
switch (policyUpdate) {
case REINFORCE:
deepAtomGradient[atomIndex] += score / sampleProbabilities[atomIndex];
break;
case REINFORCE_BASELINE:
deepAtomGradient[atomIndex] += (score - scoreMovingAverage) / sampleProbabilities[atomIndex];
break;
default:
throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
}
}

private void computeGumbelSoftmaxAtomGradient(GroundAtom atom, int atomIndex) {
// TODO(Charles): Compute the energy and supervised policy gradient for the gumbel softmax policy.
}

/**
* Set RandomVariableAtoms with labels to their observed (truth) value.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.learning.weight.gradient.policygradient;

import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.MathUtils;

import java.util.List;
import java.util.Map;

/**
 * Learns parameters for a model by minimizing the binary cross-entropy loss function
 * using the policy gradient learning framework.
 */
public class PolicyGradientBinaryCrossEntropy extends PolicyGradient {
    public PolicyGradientBinaryCrossEntropy(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase,
                                            Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
    }

    /**
     * Compute the supervised loss as the total binary cross-entropy between the current
     * atom values and their observed (truth) values over all labeled atoms in the current batch.
     * Predicted values are clamped away from 0 and 1 with MathUtils.EPSILON_FLOAT so the
     * logarithms stay finite.
     */
    @Override
    protected void computeSupervisedLoss() {
        AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

        supervisedLoss = 0.0f;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
            RandomVariableAtom randomVariableAtom = entry.getKey();
            ObservedAtom observedAtom = entry.getValue();

            int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
            if (atomIndex == -1) {
                // This atom is not in the current batch.
                continue;
            }

            // Look the predicted value up once instead of twice per labeled atom.
            float predictedValue = atomStore.getAtom(atomIndex).getValue();
            supervisedLoss += -1.0f * (observedAtom.getValue() * Math.log(Math.max(predictedValue, MathUtils.EPSILON_FLOAT))
                    + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - predictedValue, MathUtils.EPSILON_FLOAT)));
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

/**
* Learns parameters for a model by minimizing the squared error loss function
* using the minimizer-based learning framework.
* using the policy gradient learning framework.
*/
public class PolicyGradientSquaredError extends PolicyGradient {
public PolicyGradientSquaredError(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase,
Expand Down
15 changes: 7 additions & 8 deletions psl-core/src/main/java/org/linqs/psl/config/Options.java
Original file line number Diff line number Diff line change
Expand Up @@ -709,19 +709,18 @@ public class Options {
+ " not having the atom initially in the database."
);

public static final Option POLICY_GRADIENT_GUMBEL_SOFTMAX_TEMPERATURE = new Option(
"policygradient.gumbelsoftmax.temperature",
1.0f,
"The temperature parameter for the Gumbel-Softmax distribution.",
Option.FLAG_POSITIVE
);

public static final Option POLICY_GRADIENT_POLICY_DISTRIBUTION = new Option(
"policygradient.policydistribution",
PolicyGradient.PolicyDistribution.CATEGORICAL.toString(),
PolicyGradient.DeepAtomPolicyDistribution.CATEGORICAL.toString(),
"The policy distribution to use for policy gradient learning."
);

public static final Option POLICY_GRADIENT_POLICY_UPDATE = new Option(
"policygradient.policyupdate",
PolicyGradient.PolicyUpdate.REINFORCE.toString(),
"The policy update to use for policy gradient learning."
);

public static final Option POSTGRES_HOST = new Option(
"postgres.host",
"localhost",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.linqs.psl.model.deep;

import org.linqs.psl.database.AtomStore;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.QueryAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.Predicate;
Expand All @@ -44,7 +43,7 @@ public class DeepModelPredicate extends DeepModel {
public static final String CONFIG_ENTITY_ARGUMENT_INDEXES = "entity-argument-indexes";
public static final String CONFIG_CLASS_SIZE = "class-size";

public AtomStore atomStore;
private AtomStore atomStore;
private Predicate predicate;

private int classSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,15 @@ public Map<String, Object> getPredicateOptions() {
public void setPredicateOption(String name, Object option) {
options.put(name, option);

if (name.equals("integer") && Boolean.parseBoolean(option.toString())) {
if (name.equals("Integer") && Boolean.parseBoolean(option.toString())) {
integer = true;
}

if (name.equals("categorical") && Boolean.parseBoolean(option.toString())) {
categorical = true;
}

if (name.equals("categoricalindexes")) {
if (name.equals("categoricalIndexes")) {
categoryIndexes = StringUtils.splitInt(option.toString(), DELIM);
for (int categoryIndex : categoryIndexes) {
identifierIndexes[categoryIndex] = -1;
Expand Down
4 changes: 0 additions & 4 deletions psl-core/src/main/java/org/linqs/psl/util/RandUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,6 @@ public static synchronized int sampleCategorical(float[] probabilities) {
return sampledCategory;
}

public static synchronized float nextGumbel() {
return (float) (-Math.log(-Math.log(nextDouble())));
}

/**
* Sample from a gamma distribution with the provided shape and scale parameters.
* See Marsaglia and Tsang (2000a): https://dl.acm.org/doi/10.1145/358407.358414
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* This file is part of the PSL software.
* Copyright 2011-2015 University of Maryland
* Copyright 2013-2023 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.linqs.psl.application.learning.weight.gradient.policygradient;

import org.junit.Before;
import org.junit.Test;
import org.linqs.psl.application.inference.mpe.DualBCDInference;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.WeightLearningTest;
import org.linqs.psl.config.Options;

/**
 * Weight learning tests for {@link PolicyGradientBinaryCrossEntropy}.
 */
public class PolicyGradientBinaryCrossEntropyTest extends WeightLearningTest {
    @Before
    public void setup() {
        super.setup();

        // Run weight learning's inner inference with dual BCD for all tests in this class.
        Options.WLA_INFERENCE.set(DualBCDInference.class.getName());
    }

    @Override
    protected WeightLearningApplication getBaseWLA() {
        // runValidation is disabled (final argument is false).
        return new PolicyGradientBinaryCrossEntropy(info.model.getRules(), trainTargetDatabase, trainTruthDatabase,
        validationTargetDatabase, validationTruthDatabase, false);
    }

    @Test
    public void DualBCDFriendshipRankTest() {
        Options.WLA_INFERENCE.set(DualBCDInference.class.getName());

        super.friendshipRankTest();
    }

    @Test
    public void DistributedDualBCDFriendshipRankTest() {
        // NOTE(review): this body is identical to DualBCDFriendshipRankTest above and sets the
        // same (non-distributed) DualBCDInference — presumably a distributed inference class was
        // intended here; confirm against the available inference applications.
        Options.WLA_INFERENCE.set(DualBCDInference.class.getName());

        super.friendshipRankTest();
    }
}

0 comments on commit c2a6e69

Please sign in to comment.