diff --git a/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java b/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java index a9246e3a8..39536b211 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java +++ b/psl-core/src/main/java/org/linqs/psl/application/inference/InferenceApplication.java @@ -173,9 +173,7 @@ public double inference(boolean commitAtoms, boolean reset, List batchLatentInferenceAtomValueStates; protected float[] rvLatentEnergyGradient; protected float[] deepLatentEnergyGradient; - protected float[] deepSupervisedLossGradient; + protected float[] deepPolicyGradient; protected float energyLossCoefficient; - protected float MAPStateSupervisedLoss; + protected float MAPStateEvaluation; protected float mapEnergy; protected float[] mapIncompatibility; @@ -76,6 +76,8 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t deepAtomPolicyDistribution = DeepAtomPolicyDistribution.valueOf(Options.POLICY_GRADIENT_POLICY_DISTRIBUTION.getString().toUpperCase()); policyUpdate = PolicyUpdate.valueOf(Options.POLICY_GRADIENT_POLICY_UPDATE.getString().toUpperCase()); rewardFunction = RewardFunction.valueOf(Options.POLICY_GRADIENT_REWARD_FUNCTION.getString().toUpperCase()); + valueFunction = 0.0f; + actionValueFunction = null; numSamples = Options.POLICY_GRADIENT_NUM_SAMPLES.getInt(); actionSampleCounts = null; @@ -89,11 +91,11 @@ public PolicyGradient(List rules, Database trainTargetDatabase, Database t batchLatentInferenceAtomValueStates = new ArrayList(); rvLatentEnergyGradient = null; deepLatentEnergyGradient = null; - deepSupervisedLossGradient = null; + deepPolicyGradient = null; energyLossCoefficient = Options.POLICY_GRADIENT_ENERGY_LOSS_COEFFICIENT.getFloat(); - MAPStateSupervisedLoss = Float.POSITIVE_INFINITY; + MAPStateEvaluation = Float.NEGATIVE_INFINITY; mapEnergy = Float.POSITIVE_INFINITY; mapIncompatibility = new float[mutableRules.size()]; @@ -108,11 +110,11 @@ protected void initForLearning() { } } - protected abstract float computeSupervisedLoss(); + protected abstract float computeReward(); @Override protected float computeLearningLoss() { - return MAPStateSupervisedLoss + energyLossCoefficient * latentInferenceEnergy; + return (float) ((evaluation.getNormalizedMaxRepMetric() - MAPStateEvaluation) + energyLossCoefficient * latentInferenceEnergy); } @Override @@ -132,8 +134,8 @@ protected void initializeGradients() { rvLatentEnergyGradient = new float[trainFullMAPAtomValueState.length]; deepLatentEnergyGradient = new float[trainFullMAPAtomValueState.length]; - deepSupervisedLossGradient = new float[trainFullMAPAtomValueState.length]; - + deepPolicyGradient = new float[trainFullMAPAtomValueState.length]; + actionValueFunction = new float[trainFullMAPAtomValueState.length]; actionSampleCounts = new int[trainFullMAPAtomValueState.length]; } @@ -143,8 +145,8 @@ protected void resetGradients() { Arrays.fill(rvLatentEnergyGradient, 0.0f); Arrays.fill(deepLatentEnergyGradient, 0.0f); - Arrays.fill(deepSupervisedLossGradient, 0.0f); - + Arrays.fill(deepPolicyGradient, 0.0f); + Arrays.fill(actionValueFunction, 0.0f); Arrays.fill(actionSampleCounts, 0); } @@ -164,7 +166,7 @@ protected void computeIterationStatistics() { computeMAPInferenceStatistics(); - MAPStateSupervisedLoss = computeSupervisedLoss(); + MAPStateEvaluation = computeReward(); computeLatentInferenceStatistics(); @@ -172,57 +174,82 @@ protected void computeIterationStatistics() { // and to compute action probabilities. System.arraycopy(atomStore.getAtomValues(), 0, initialDeepAtomValues, 0, atomStore.size()); - switch (policyUpdate) { - case REINFORCE: - addREINFORCESupervisedLossGradient(0.0f); - break; - case REINFORCE_BASELINE: - addREINFORCESupervisedLossGradient(MAPStateSupervisedLoss); - break; - default: - throw new IllegalArgumentException("Unknown policy update: " + policyUpdate); - } + computeValueFunctionEstimates(); + computePolicyGradient(); } - private void addREINFORCESupervisedLossGradient(float baseline) { - for (int i = 0; i < numSamples; i++) { - sampleAllDeepAtomValues(); + private void computePolicyGradient() { + AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); + for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) { + GroundAtom atom = atomStore.getAtom(atomIndex); - computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState); + // Skip atoms that are not DeepAtoms. + if (!((atom instanceof RandomVariableAtom) && (atom.getPredicate() instanceof DeepPredicate))) { + continue; + } - float supervisedLoss = computeSupervisedLoss(); - float reward = 0.0f; + if (actionSampleCounts[atomIndex] == 0) { + deepPolicyGradient[atomIndex] = 0.0f; + continue; + } - switch (rewardFunction) { - case NEGATIVE_LOSS: - reward = 1.0f - supervisedLoss; + switch (policyUpdate) { + case REINFORCE: + deepPolicyGradient[atomIndex] = -1.0f * (actionValueFunction[atomIndex]) / atom.getValue(); break; - case NEGATIVE_LOSS_SQUARED: - reward = (float) Math.pow(1.0f - supervisedLoss, 2.0f); - break; - case INVERSE_LOSS: - // The inverse loss may result in a reward of infinity. - // Therefore, we add a small constant to the loss to avoid division by zero. - reward = 1.0f / (supervisedLoss + MathUtils.EPSILON_FLOAT); + case REINFORCE_BASELINE: + deepPolicyGradient[atomIndex] = -1.0f * (actionValueFunction[atomIndex] - valueFunction) / atom.getValue(); break; default: - throw new IllegalArgumentException("Unknown reward function: " + rewardFunction); + throw new IllegalArgumentException("Unknown policy update: " + policyUpdate); + } + } + + clipPolicyGradient(); + } + + /** + * Clip policy gradient to stabilize learning. + */ + private void clipPolicyGradient() { + AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); + + float gradientMagnitude = MathUtils.pNorm(deepPolicyGradient, maxGradientNorm); + + if (gradientMagnitude > maxGradientMagnitude) { +// log.trace("Clipping policy gradient. Original gradient magnitude: {} exceeds limit: {} in L_{} space.", +// gradientMagnitude, maxGradientMagnitude, maxGradientNorm); + for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) { + deepPolicyGradient[atomIndex] = maxGradientMagnitude * deepPolicyGradient[atomIndex] / gradientMagnitude; } + } + } + + private void computeValueFunctionEstimates() { + valueFunction = 0.0f; - addPolicyScoreGradient(reward - baseline); + for (int i = 0; i < numSamples; i++) { + sampleAllDeepAtomValues(); + + computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState); + + float reward = computeReward(); + addActionValue(reward); + valueFunction += reward; resetAllDeepAtomValues(); } - for (int i = 0; i < deepSupervisedLossGradient.length; i++) { + // Average the value functions. + valueFunction /= numSamples; + + for (int i = 0; i < actionValueFunction.length; i++) { if (actionSampleCounts[i] == 0) { - deepSupervisedLossGradient[i] = 0.0f; + actionValueFunction[i] = 0.0f; continue; } -// log.trace("Atom: {} Deep Supervised Loss Gradient: {}", -// trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), deepSupervisedLossGradient[i] / actionSampleCounts[i]); - deepSupervisedLossGradient[i] /= actionSampleCounts[i]; + actionValueFunction[i] /= actionSampleCounts[i]; } } @@ -338,17 +365,11 @@ protected void addLearningLossWeightGradient() { protected void addTotalAtomGradient() { for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) { rvGradient[i] = energyLossCoefficient * rvLatentEnergyGradient[i]; - deepGradient[i] = energyLossCoefficient * deepLatentEnergyGradient[i] + deepSupervisedLossGradient[i]; - - if (trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).getPredicate() instanceof DeepPredicate) { - log.trace("Atom: {} deepLatentEnergyGradient: {}, deepSupervisedLossGradient: {}", - trainInferenceApplication.getTermStore().getAtomStore().getAtom(i).toStringWithValue(), - deepLatentEnergyGradient[i], deepSupervisedLossGradient[i]); - } + deepGradient[i] = energyLossCoefficient * deepLatentEnergyGradient[i] + deepPolicyGradient[i]; } } - private void addPolicyScoreGradient(float reward) { + private void addActionValue(float reward) { AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); for (int atomIndex = 0; atomIndex < atomStore.size(); atomIndex++) { GroundAtom atom = atomStore.getAtom(atomIndex); @@ -360,7 +381,7 @@ private void addPolicyScoreGradient(float reward) { switch (deepAtomPolicyDistribution) { case CATEGORICAL: - addCategoricalPolicyScoreGradient(atomIndex, (RandomVariableAtom) atom, reward); + addCategoricalActionValue(atomIndex, (RandomVariableAtom) atom, reward); break; default: throw new IllegalArgumentException("Unknown policy distribution: " + deepAtomPolicyDistribution); @@ -368,17 +389,16 @@ private void addPolicyScoreGradient(float reward) { } } - private void addCategoricalPolicyScoreGradient(int atomIndex, RandomVariableAtom atom, float reward) { + private void addCategoricalActionValue(int atomIndex, RandomVariableAtom atom, float reward) { // Skip atoms not selected by the policy. - if (atom.getValue() == 0.0f) { + if (MathUtils.isZero(atom.getValue())) { return; } switch (policyUpdate) { case REINFORCE: case REINFORCE_BASELINE: - // The initialDeepAtomValues are the action probabilities. - deepSupervisedLossGradient[atomIndex] -= reward / initialDeepAtomValues[atomIndex]; + actionValueFunction[atomIndex] += reward; break; default: throw new IllegalArgumentException("Unknown policy update: " + policyUpdate); diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java deleted file mode 100644 index a6091ca65..000000000 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropy.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * This file is part of the PSL software. - * Copyright 2011-2015 University of Maryland - * Copyright 2013-2023 The Regents of the University of California - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.linqs.psl.application.learning.weight.gradient.policygradient; - -import org.linqs.psl.database.AtomStore; -import org.linqs.psl.database.Database; -import org.linqs.psl.model.atom.ObservedAtom; -import org.linqs.psl.model.atom.RandomVariableAtom; -import org.linqs.psl.model.rule.Rule; -import org.linqs.psl.util.Logger; -import org.linqs.psl.util.MathUtils; - -import java.util.List; -import java.util.Map; - -/** - * Learns parameters for a model by minimizing the squared error loss function - * using the policy gradient learning framework. - */ -public class PolicyGradientBinaryCrossEntropy extends PolicyGradient { - private static final Logger log = Logger.getLogger(PolicyGradientBinaryCrossEntropy.class); - - public PolicyGradientBinaryCrossEntropy(List rules, Database trainTargetDatabase, Database trainTruthDatabase, - Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) { - super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation); - } - - @Override - protected float computeSupervisedLoss() { - AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); - - float supervisedLoss = 0.0f; - int numEvaluatedAtoms = 0; - for (Map.Entry entry : trainingMap.getLabelMap().entrySet()) { - RandomVariableAtom randomVariableAtom = entry.getKey(); - ObservedAtom observedAtom = entry.getValue(); - - int atomIndex = atomStore.getAtomIndex(randomVariableAtom); - if (atomIndex == -1) { - // This atom is not in the current batch. - continue; - } - - supervisedLoss += (float) (-1.0f * (observedAtom.getValue() * Math.log(Math.max(atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT)) - + (1.0f - observedAtom.getValue()) * Math.log(Math.max(1.0f - atomStore.getAtom(atomIndex).getValue(), MathUtils.EPSILON_FLOAT)))); - - numEvaluatedAtoms++; - } - - if (numEvaluatedAtoms > 0) { - supervisedLoss /= numEvaluatedAtoms; - } - - return supervisedLoss; - } -} diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStep.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientEvaluation.java similarity index 55% rename from psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStep.java rename to psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientEvaluation.java index abc4f0aa0..1bf7faa78 100644 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStep.java +++ b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientEvaluation.java @@ -17,50 +17,39 @@ */ package org.linqs.psl.application.learning.weight.gradient.policygradient; -import org.linqs.psl.database.AtomStore; import org.linqs.psl.database.Database; -import org.linqs.psl.model.atom.ObservedAtom; -import org.linqs.psl.model.atom.RandomVariableAtom; import org.linqs.psl.model.rule.Rule; import org.linqs.psl.util.Logger; import org.linqs.psl.util.MathUtils; import java.util.List; -import java.util.Map; /** - * Learns parameters for a model by minimizing the squared error loss function + * Learns parameters for a model by minimizing the specified evaluation metric * using the policy gradient learning framework. */ -public class PolicyGradientBinaryStep extends PolicyGradient { - private static final Logger log = Logger.getLogger(PolicyGradientBinaryStep.class); +public class PolicyGradientEvaluation extends PolicyGradient { + private static final Logger log = Logger.getLogger(PolicyGradientEvaluation.class); - public PolicyGradientBinaryStep(List rules, Database trainTargetDatabase, Database trainTruthDatabase, + public PolicyGradientEvaluation(List rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) { super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation); } @Override - protected float computeSupervisedLoss() { - AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); + protected float computeReward() { + evaluation.compute(trainingMap); - float supervisedLoss = 0.0f; - for (Map.Entry entry : trainingMap.getLabelMap().entrySet()) { - RandomVariableAtom randomVariableAtom = entry.getKey(); - ObservedAtom observedAtom = entry.getValue(); + float reward = (float) evaluation.getNormalizedRepMetric(); - int atomIndex = atomStore.getAtomIndex(randomVariableAtom); - if (atomIndex == -1) { - // This atom is not in the current batch. - continue; - } - - if (!MathUtils.equals(observedAtom.getValue(), atomStore.getAtom(atomIndex).getValue())) { - supervisedLoss = 1.0f; + switch (rewardFunction) { + case EVALUATION: + reward = reward; break; - } + default: + throw new IllegalArgumentException("Unknown reward function: " + rewardFunction); } - return supervisedLoss; + return reward; } } diff --git a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java b/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java deleted file mode 100644 index 0c0a12a71..000000000 --- a/psl-core/src/main/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredError.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * This file is part of the PSL software. - * Copyright 2011-2015 University of Maryland - * Copyright 2013-2023 The Regents of the University of California - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.linqs.psl.application.learning.weight.gradient.policygradient; - -import org.linqs.psl.database.AtomStore; -import org.linqs.psl.database.Database; -import org.linqs.psl.model.atom.ObservedAtom; -import org.linqs.psl.model.atom.RandomVariableAtom; -import org.linqs.psl.model.rule.Rule; - -import java.util.List; -import java.util.Map; - -/** - * Learns parameters for a model by minimizing the squared error loss function - * using the policy gradient learning framework. - */ -public class PolicyGradientSquaredError extends PolicyGradient { - public PolicyGradientSquaredError(List rules, Database trainTargetDatabase, Database trainTruthDatabase, - Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) { - super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation); - } - - @Override - protected float computeSupervisedLoss() { - AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore(); - - float supervisedLoss = 0.0f; - int numEvaluatedAtoms = 0; - for (Map.Entry entry: trainingMap.getLabelMap().entrySet()) { - RandomVariableAtom randomVariableAtom = entry.getKey(); - ObservedAtom observedAtom = entry.getValue(); - - int atomIndex = atomStore.getAtomIndex(randomVariableAtom); - if (atomIndex == -1) { - // This atom is not in the current batch. - continue; - } - - supervisedLoss += (float) Math.pow(atomStore.getAtom(atomIndex).getValue() - observedAtom.getValue(), 2.0f); - numEvaluatedAtoms++; - } - - if (numEvaluatedAtoms > 0) { - supervisedLoss /= numEvaluatedAtoms; - } - - return supervisedLoss; - } -} diff --git a/psl-core/src/main/java/org/linqs/psl/config/Options.java b/psl-core/src/main/java/org/linqs/psl/config/Options.java index 1b49a25cd..109b614e3 100644 --- a/psl-core/src/main/java/org/linqs/psl/config/Options.java +++ b/psl-core/src/main/java/org/linqs/psl/config/Options.java @@ -742,7 +742,7 @@ public class Options { public static final Option POLICY_GRADIENT_REWARD_FUNCTION = new Option( "policygradient.rewardfunction", - PolicyGradient.RewardFunction.NEGATIVE_LOSS.toString(), + PolicyGradient.RewardFunction.EVALUATION.toString(), "The reward function to use for policy gradient learning." ); diff --git a/psl-core/src/main/java/org/linqs/psl/evaluation/statistics/CategoricalEvaluator.java b/psl-core/src/main/java/org/linqs/psl/evaluation/statistics/CategoricalEvaluator.java index 81db2504d..12fb997ed 100644 --- a/psl-core/src/main/java/org/linqs/psl/evaluation/statistics/CategoricalEvaluator.java +++ b/psl-core/src/main/java/org/linqs/psl/evaluation/statistics/CategoricalEvaluator.java @@ -194,13 +194,11 @@ private Set getTrueCategoryIndexes(StandardPredicate predicate) { categoryIndexes.add(Integer.valueOf(index)); } - log.trace("True category indexes for {}: [{}].", predicate.getName(), StringUtils.join(", ", categoryIndexes.toArray())); - return categoryIndexes; } /** - * Build up a set that has all the atoms that represet the best categorical assignments. + * Build up a set that has all the atoms that represent the best categorical assignments. */ protected Set getPredictedCategories(TrainingMap trainingMap, StandardPredicate predicate) { // This map will be as deep as the number of category arguments. diff --git a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java index 4b30ead2f..429e28d2d 100644 --- a/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java +++ b/psl-core/src/main/java/org/linqs/psl/model/deep/DeepModelPredicate.java @@ -43,6 +43,8 @@ public class DeepModelPredicate extends DeepModel { public static final String CONFIG_ENTITY_ARGUMENT_INDEXES = "entity-argument-indexes"; public static final String CONFIG_CLASS_SIZE = "class-size"; + private static HashMap> dataMapEntities; + private AtomStore atomStore; private Predicate predicate; @@ -60,6 +62,10 @@ public class DeepModelPredicate extends DeepModel { public DeepModelPredicate(Predicate predicate) { super("DeepModelPredicate"); + if (dataMapEntities == null) { + dataMapEntities = new HashMap>(); + } + this.atomStore = null; this.predicate = predicate; @@ -247,6 +253,9 @@ public void close() { gradients = null; symbolicGradients = null; + dataMapEntities.get(predicate).clear(); + dataMapEntities.remove(predicate); + validAtomIndexes.clear(); validDataIndexes.clear(); @@ -310,17 +319,15 @@ private void validateOptions() { } } - /** - * Read the entities from a file and map to atom indexes. - */ - private int mapEntitiesFromFileToAtoms(String filePath, AtomStore atomStore, int numEntityArgs) { - Constant[] arguments = new Constant[numEntityArgs + 1]; + private void readEntityDataMap(String filePath, int numEntityArgs) { + dataMapEntities.put(predicate, new ArrayList()); + ConstantType type; String line = null; int lineNumber = 0; - int atomIndex = 0; - int dataIndex = 0; + + log.trace("Reading Entity Data Map: {}", filePath); try (BufferedReader reader = FileUtils.getBufferedReader(filePath)) { while ((line = reader.readLine()) != null) { @@ -340,6 +347,7 @@ private int mapEntitiesFromFileToAtoms(String filePath, AtomStore atomStore, int lineNumber, numEntityArgs, predicate.getName())); } + Constant[] arguments = new Constant[numEntityArgs + 1]; // Get constant types for this entity. for (int index = 0; index < arguments.length - 1; index++) { type = predicate.getArgumentType(index); @@ -349,38 +357,55 @@ private int mapEntitiesFromFileToAtoms(String filePath, AtomStore atomStore, int // Add atom index and data index for each class. type = predicate.getArgumentType(arguments.length - 1); - QueryAtom queryAtom = null; - for (int index = 0; index < classSize; index++) { - arguments[arguments.length - 1] = ConstantType.getConstant(String.valueOf(index), type); - - if (index == 0) { - queryAtom = new QueryAtom(predicate, arguments); - } else { - queryAtom.assume(predicate, arguments); - } - - atomIndex = atomStore.getAtomIndex(queryAtom); - if (atomIndex == -1) { - break; - } - validAtomIndexes.add(atomIndex); - } - - // Verify that the entities have atoms for all classes. - if (validAtomIndexes.size() % classSize != 0) { - throw new RuntimeException(String.format( - "Entity found on line (%d) has unspecified class values for predicate %s.", - lineNumber, predicate.getName())); - } + arguments[arguments.length - 1] = ConstantType.getConstant(String.valueOf(0), type); - if (atomIndex != -1) { - validDataIndexes.add(dataIndex); - } - dataIndex++; + dataMapEntities.get(predicate).add(new QueryAtom(predicate, arguments)); } } catch (IOException ex) { throw new RuntimeException("Unable to parse entity data map file: " + filePath, ex); } + } + + /** + * Read the entities from a file and map to atom indexes. + */ + private int mapEntitiesFromFileToAtoms(String filePath, AtomStore atomStore, int numEntityArgs) { + int lineNumber = 0; + int atomIndex = 0; + int dataIndex = 0; + + if (!dataMapEntities.containsKey(predicate)) { + readEntityDataMap(filePath, numEntityArgs); + } + + for (QueryAtom entity : dataMapEntities.get(predicate)) { + Constant[] arguments = (Constant[]) entity.getArguments(); + + // Add atom index and data index for each class. + for (int classIndex = 0; classIndex < classSize; classIndex++) { + arguments[arguments.length - 1] = ConstantType.getConstant(String.valueOf(classIndex), + predicate.getArgumentType(arguments.length - 1)); + entity.assume(predicate, arguments); + + atomIndex = atomStore.getAtomIndex(entity); + if (atomIndex == -1) { + break; + } + validAtomIndexes.add(atomIndex); + } + + // Verify that the entities have atoms for all classes. + if (validAtomIndexes.size() % classSize != 0) { + throw new RuntimeException(String.format( + "Entity found on line (%d) has unspecified class values for predicate %s.", + lineNumber, predicate.getName())); + } + + if (atomIndex != -1) { + validDataIndexes.add(dataIndex); + } + dataIndex++; + } return dataIndex; } diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java deleted file mode 100644 index c224cf766..000000000 --- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryCrossEntropyTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * This file is part of the PSL software. - * Copyright 2011-2015 University of Maryland - * Copyright 2013-2023 The Regents of the University of California - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.linqs.psl.application.learning.weight.gradient.policygradient; - -import org.junit.Before; -import org.junit.Test; -import org.linqs.psl.application.inference.mpe.DualBCDInference; -import org.linqs.psl.application.learning.weight.WeightLearningApplication; -import org.linqs.psl.application.learning.weight.WeightLearningTest; -import org.linqs.psl.config.Options; - -public class PolicyGradientBinaryCrossEntropyTest extends WeightLearningTest { - @Before - public void setup() { - super.setup(); - - Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); - Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.set(false); - } - - @Override - protected WeightLearningApplication getBaseWLA() { - return new PolicyGradientBinaryCrossEntropy(info.model.getRules(), trainTargetDatabase, trainTruthDatabase, - validationTargetDatabase, validationTruthDatabase, false); - } - - @Test - public void DualBCDFriendshipRankTest() { - Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); - - super.friendshipRankTest(); - } - - @Test - public void DistributedDualBCDFriendshipRankTest() { - Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); - - super.friendshipRankTest(); - } -} diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStepTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientEvaluationTest.java similarity index 93% rename from psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStepTest.java rename to psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientEvaluationTest.java index ff1f1a267..f82574e80 100644 --- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientBinaryStepTest.java +++ b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientEvaluationTest.java @@ -24,7 +24,7 @@ import org.linqs.psl.application.learning.weight.WeightLearningTest; import org.linqs.psl.config.Options; -public class PolicyGradientBinaryStepTest extends WeightLearningTest { +public class PolicyGradientEvaluationTest extends WeightLearningTest { @Before public void setup() { super.setup(); @@ -35,7 +35,7 @@ public void setup() { @Override protected WeightLearningApplication getBaseWLA() { - return new PolicyGradientBinaryStep(info.model.getRules(), trainTargetDatabase, trainTruthDatabase, + return new PolicyGradientEvaluation(info.model.getRules(), trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, false); } diff --git a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredErrorTest.java b/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredErrorTest.java deleted file mode 100644 index 624d6b143..000000000 --- a/psl-core/src/test/java/org/linqs/psl/application/learning/weight/gradient/policygradient/PolicyGradientSquaredErrorTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * This file is part of the PSL software. - * Copyright 2011-2015 University of Maryland - * Copyright 2013-2023 The Regents of the University of California - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.linqs.psl.application.learning.weight.gradient.policygradient; - -import org.junit.Before; -import org.junit.Test; -import org.linqs.psl.application.inference.mpe.DualBCDInference; -import org.linqs.psl.application.learning.weight.WeightLearningApplication; -import org.linqs.psl.application.learning.weight.WeightLearningTest; -import org.linqs.psl.config.Options; - -public class PolicyGradientSquaredErrorTest extends WeightLearningTest { - @Before - public void setup() { - super.setup(); - - Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); - Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.set(false); - } - - @Override - protected WeightLearningApplication getBaseWLA() { - return new PolicyGradientSquaredError(info.model.getRules(), trainTargetDatabase, trainTruthDatabase, - validationTargetDatabase, validationTruthDatabase, false); - } - - @Test - public void DualBCDFriendshipRankTest() { - Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); - - super.friendshipRankTest(); - } - - @Test - public void DistributedDualBCDFriendshipRankTest() { - Options.WLA_INFERENCE.set(DualBCDInference.class.getName()); - - super.friendshipRankTest(); - } -} diff --git a/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java b/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java index c6f4b6677..a30e8b8d3 100644 --- a/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java +++ b/psl-java/src/main/java/org/linqs/psl/runtime/Runtime.java @@ -434,7 +434,9 @@ protected void runInferenceInternal(RuntimeConfig config, Model model, RuntimeRe while (runInference) { DeepPredicate.predictAllDeepPredicates(); + log.info("Beginning inference."); inferenceApplication.inference(RuntimeOptions.INFERENCE_COMMIT.getBoolean(), false, evaluations, truthDatabase); + log.info("Inference complete."); if (RuntimeOptions.INFERENCE_OUTPUT_RESULTS.getBoolean()) { String outputDir = RuntimeOptions.INFERENCE_OUTPUT_RESULTS_DIR.getString();