diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java index 622dc6828..1f0b41692 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradient.java @@ -14,7 +14,6 @@ import org.linqs.psl.reasoner.term.SimpleTermStore; import org.linqs.psl.reasoner.term.TermState; import org.linqs.psl.util.Logger; -import org.linqs.psl.util.MathUtils; import org.linqs.psl.util.RandUtils; import java.util.ArrayList; @@ -28,17 +27,20 @@ public abstract class PolicyGradient extends GradientDescent { private static final Logger log = Logger.getLogger(PolicyGradient.class); - public enum PolicyDistribution { - CATEGORICAL, - GUMBEL_SOFTMAX + public enum DeepAtomPolicyDistribution { + CATEGORICAL } - private final PolicyDistribution policyDistribution; + public enum PolicyUpdate { + REINFORCE, + REINFORCE_BASELINE, + } - private float lossMovingAverage; - private float[] sampleProbabilities; + private final DeepAtomPolicyDistribution deepAtomPolicyDistribution; + private final PolicyUpdate policyUpdate; - private final float gumbelSoftmaxTemperature; + private float scoreMovingAverage; + private float[] sampleProbabilities; protected float[] initialDeepAtomValues; protected float[] policySampledDeepAtomValues; @@ -61,13 +63,12 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) { super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation); - policyDistribution = PolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase()); - 
lossMovingAverage = 0.0f; + deepAtomPolicyDistribution = DeepAtomPolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase()); + policyUpdate = PolicyUpdate.valueOf(Options.POLICY_GRADIENT_POLICY_UPDATE.getString().toUpperCase()); + scoreMovingAverage = 0.0f; sampleProbabilities = null; - gumbelSoftmaxTemperature = Options.POLICY_GRADIENT_GUMBEL_SOFTMAX_TEMPERATURE.getFloat(); - initialDeepAtomValues = null; policySampledDeepAtomValues = null; @@ -90,7 +91,7 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t protected void initForLearning() { super.initForLearning(); - lossMovingAverage = 0.0f; + scoreMovingAverage = 0.0f; } protected abstract void computeSupervisedLoss(); @@ -161,15 +162,12 @@ protected void sampleDeepAtomValues() { } } - switch (policyDistribution) { + switch (deepAtomPolicyDistribution) { case CATEGORICAL: sampleCategorical(); break; - case GUMBEL_SOFTMAX: - sampleGumbelSoftmax(); - break; default: - throw new IllegalArgumentException("Unknown policy distribution: " + policyDistribution); + throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution); } } @@ -200,7 +198,6 @@ private void sampleCategorical() { } else { sampleProbabilities[atomIndex] = categoryProbabilities[i]; categories.get(i).setValue(1.0f); -// log.trace("Sampled category: {}.", categories.get(i).toStringWithValue()); } policySampledDeepAtomValues[atomIndex] = categories.get(i).getValue(); @@ -210,43 +207,6 @@ private void sampleCategorical() { } } - private void sampleGumbelSoftmax() { - AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); - float[] atomValues = atomStore.getAtomValues(); - - // Sample the deep model predictions according to the stochastic gumbel softmax policy. 
- for (DeepModelPredicate deepModelPredicate : trainDeepModelPredicates) { - - Map > atomIdentiferToCategories = deepModelPredicate.getAtomIdentiferToCategories(); - for (Map.Entry> entry : atomIdentiferToCategories.entrySet()) { - String atomIdentifier = entry.getKey(); - ArrayList categories = entry.getValue(); - - float[] gumbelSoftmaxSample = new float[categories.size()]; - float categoryProbabilitySum = 0.0f; - for (int i = 0; i < categories.size(); i++) { - float gumbelSample = RandUtils.nextGumbel(); - - gumbelSoftmaxSample[i] = (float) Math.exp((Math.log(Math.max(categories.get(i).getValue(), MathUtils.STRICT_EPSILON)) + gumbelSample) - / gumbelSoftmaxTemperature); - categoryProbabilitySum += gumbelSoftmaxSample[i]; - } - - // Renormalize the probabilities and set the deep atom values. - for (int i = 0; i < categories.size(); i++) { - gumbelSoftmaxSample[i] = gumbelSoftmaxSample[i] / categoryProbabilitySum; - - RandomVariableAtom category = categories.get(i); - int atomIndex = atomStore.getAtomIndex(category); - - category.setValue(gumbelSoftmaxSample[i]); - policySampledDeepAtomValues[atomIndex] = gumbelSoftmaxSample[i]; - atomValues[atomIndex] = gumbelSoftmaxSample[i]; - } - } - } - } - private void resetDeepAtomValues() { AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); float[] atomValues = atomStore.getAtomValues(); @@ -315,36 +275,37 @@ protected void computeTotalAtomGradient() { continue; } - switch (policyDistribution) { + switch (deepAtomPolicyDistribution) { case CATEGORICAL: - computeCategoricalAtomGradient(atom, i); - break; - case GUMBEL_SOFTMAX: - computeGumbelSoftmaxAtomGradient(atom, i); + computeCategoricalAtomGradient(i); break; default: - throw new IllegalArgumentException("Unknown policy distribution: " + policyDistribution); + throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution); } } } - private void computeCategoricalAtomGradient(GroundAtom atom, int atomIndex) { 
+ private void computeCategoricalAtomGradient(int atomIndex) { if (policySampledDeepAtomValues[atomIndex] == 0.0f) { deepAtomGradient[atomIndex] = 0.0f; return; } float score = energyLossCoefficient * latentInferenceEnergy + supervisedLoss; - lossMovingAverage = 0.9f * lossMovingAverage + 0.1f * score; + scoreMovingAverage = 0.9f * scoreMovingAverage + 0.1f * score; /* NOTE(review): this update runs on every call, before the switch below; it is unused when policyUpdate is REINFORCE, and for REINFORCE_BASELINE the baseline already contains the current score. Presumably it should be updated once per learning step; confirm against the caller. */ - deepAtomGradient[atomIndex] += (score - lossMovingAverage) / sampleProbabilities[atomIndex]; - -// log.trace("Atom: {}, Score: {}, Gradient: {}.", atom.toStringWithValue(), score, deepAtomGradient[atomIndex]); + switch (policyUpdate) { + case REINFORCE: + deepAtomGradient[atomIndex] += score / sampleProbabilities[atomIndex]; + break; + case REINFORCE_BASELINE: + deepAtomGradient[atomIndex] += (score - scoreMovingAverage) / sampleProbabilities[atomIndex]; + break; + default: + throw new IllegalArgumentException("Unknown policy update: " + policyUpdate); + } } - private void computeGumbelSoftmaxAtomGradient(GroundAtom atom, int atomIndex) { - // TODO(Charles): Compute the energy and supervised policy gradient for the gumbel softmax policy. - } /** * Set RandomVariableAtoms with labels to their observed (truth) value. diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java new file mode 100644 index 000000000..a081485a8 --- /dev/null +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java @@ -0,0 +1,59 @@ +/* + * This file is part of the PSL software. + * Copyright 2011-2015 University of Maryland + * Copyright 2013-2023 The Regents of the University of California + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.linqs.psl.application.learning.weight.gradient.policygradient; + +import org.linqs.psl.database.AtomStore; +import org.linqs.psl.database.Database; +import org.linqs.psl.model.atom.ObservedAtom; +import org.linqs.psl.model.atom.RandomVariableAtom; +import org.linqs.psl.model.rule.Rule; +import org.linqs.psl.util.MathUtils; + +import java.util.List; +import java.util.Map; + +/** + * Learns parameters for a model by minimizing the binary cross entropy loss function + * using the policy gradient learning framework. + */ +public class PolicyGradientBinaryCrossEntropy extends PolicyGradient { + public PolicyGradientBinaryCrossEntropy(List rules, Database trainTargetDatabase, Database trainTruthDatabase, + Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) { + super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation); + } + + @Override + protected void computeSupervisedLoss() { + AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); + + supervisedLoss = 0.0f; + for (Map.Entry entry: trainingMap.getLabelMap().entrySet()) { + RandomVariableAtom randomVariableAtom = entry.getKey(); + ObservedAtom observedAtom = entry.getValue(); + + int atomIndex = atomStore.getAtomIndex(randomVariableAtom); + if (atomIndex == -1) { + // This atom is not in the current batch. 
+ continue; + } + + supervisedLoss += -1.0f * (observedAtom.getValue() * Math.log(Math.max(atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT)) + + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT))); + } + } +} diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java index fca5ac585..48d78cea5 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java @@ -28,7 +28,7 @@ /** * Learns parameters for a model by minimizing the squared error loss function - * using the minimizer-based learning framework. + * using the policy gradient learning framework. */ public class PolicyGradientSquaredError extends PolicyGradient { public PolicyGradientSquaredError(List rules, Database trainTargetDatabase, Database trainTruthDatabase, diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java index 73af7cf61..a7ec3358c 100644 --- a/psl-core/src/main/java/org/linqs/psl/config/Options.java +++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java @@ -709,19 +709,18 @@ public class Options { + " not having the atom initially in the database." 
); - public static final Option POLICY_GRADIENT_GUMBEL_SOFTMAX_TEMPERATURE = new Option( - "policygradient.gumbelsoftmax.temperature", - 1.0f, - "The temperature parameter for the Gumbel-Softmax distribution.", - Option.FLAG_POSITIVE - ); - public static final Option POLICY_GRADIENT_POLICY_DISTRIBUTION = new Option( "policygradient.policydistribution", - PolicyGradient.PolicyDistribution.CATEGORICAL.toString(), + PolicyGradient.DeepAtomPolicyDistribution.CATEGORICAL.toString(), "The policy distribution to use for policy gradient learning." ); + public static final Option POLICY_GRADIENT_POLICY_UPDATE = new Option( + "policygradient.policyupdate", + PolicyGradient.PolicyUpdate.REINFORCE.toString(), + "The policy update to use for policy gradient learning." + ); + public static final Option POSTGRES_HOST = new Option( "postgres.host", "localhost", diff --git a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java index 0ac222dec..a41e89199 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java +++ b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java @@ -18,7 +18,6 @@ package org.linqs.psl.model.deep; import org.linqs.psl.database.AtomStore; -import org.linqs.psl.model.atom.GroundAtom; import org.linqs.psl.model.atom.QueryAtom; import org.linqs.psl.model.atom.RandomVariableAtom; import org.linqs.psl.model.predicate.Predicate; @@ -44,7 +43,7 @@ public class DeepModelPredicate extends DeepModel { public static final String CONFIG_ENTITY_ARGUMENT_INDEXES = "entity-argument-indexes"; public static final String CONFIG_CLASS_SIZE = "class-size"; - public AtomStore atomStore; + private AtomStore atomStore; private Predicate predicate; private int classSize; diff --git a/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java b/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java index ac335dc93..5c9094abf 
100644 --- a/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java +++ b/psl-core/src/main/java/org/linqs/psl/model/predicate/Predicate.java @@ -124,7 +124,7 @@ public Map getPredicateOptions() { public void setPredicateOption(String name, Object option) { options.put(name, option); - if (name.equals("integer") && Boolean.parseBoolean(option.toString())) { + if (name.equals("Integer") && Boolean.parseBoolean(option.toString())) { integer = true; } @@ -132,7 +132,7 @@ public void setPredicateOption(String name, Object option) { categorical = true; } - if (name.equals("categoricalindexes")) { + if (name.equals("categoricalIndexes")) { categoryIndexes = StringUtils.splitInt(option.toString(), DELIM); for (int categoryIndex : categoryIndexes) { identifierIndexes[categoryIndex] = -1; diff --git a/psl-core/src/main/java/org/linqs/psl/util/RandUtils.java b/psl-core/src/main/java/org/linqs/psl/util/RandUtils.java index e1c8c7d98..0953e90f8 100644 --- a/psl-core/src/main/java/org/linqs/psl/util/RandUtils.java +++ b/psl-core/src/main/java/org/linqs/psl/util/RandUtils.java @@ -213,10 +213,6 @@ public static synchronized int sampleCategorical(float[] probabilities) { return sampledCategory; } - public static synchronized float nextGumbel() { - return (float) (-Math.log(-Math.log(nextDouble()))); - } - /** * Sample from a gamma distribution with the provided shape and scale parameters. 
* See Marsaglia and Tsang (2000a): https://dl.acm.org/doi/10.1145/358407.358414 diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java new file mode 100644 index 000000000..1fff7e688 --- /dev/null +++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java @@ -0,0 +1,54 @@ +/* + * This file is part of the PSL software. + * Copyright 2011-2015 University of Maryland + * Copyright 2013-2023 The Regents of the University of California + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.linqs.psl.application.learning.weight.gradient.policygradient; + +import org.junit.Before; +import org.junit.Test; +import org.linqs.psl.application.inference.mpe.DualBCDInference; +import org.linqs.psl.application.learning.weight.WeightLearningApplication; +import org.linqs.psl.application.learning.weight.WeightLearningTest; +import org.linqs.psl.config.Options; + +public class PolicyGradientBinaryCrossEntropyTest extends WeightLearningTest { + @Before + public void setup() { + super.setup(); + + Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); + } + + @Override + protected WeightLearningApplication getBaseWLA() { + return new PolicyGradientBinaryCrossEntropy(info.model.getRules(), trainTargetDatabase, trainTruthDatabase, + validationTargetDatabase, validationTruthDatabase, false); + } + + @Test + public void DualBCDFriendshipRankTest() { + Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); + + super.friendshipRankTest(); + } + + @Test + public void DistributedDualBCDFriendshipRankTest() { + Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); /* NOTE(review): same inference class as DualBCDFriendshipRankTest, so this test currently duplicates it; presumably a distributed inference class was intended. Confirm. */ + + super.friendshipRankTest(); + } +}