diff --git a/core/src/main/java/hivemall/factorization/cofactor/CofactorModel.java b/core/src/main/java/hivemall/factorization/cofactor/CofactorModel.java new file mode 100644 index 000000000..f26ce559e --- /dev/null +++ b/core/src/main/java/hivemall/factorization/cofactor/CofactorModel.java @@ -0,0 +1,1139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.factorization.cofactor; + +import hivemall.annotations.VisibleForTesting; +import hivemall.fm.Feature; +import hivemall.fm.StringFeature; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.math.MathUtils; +import it.unimi.dsi.fastutil.objects.Object2DoubleArrayMap; +import it.unimi.dsi.fastutil.objects.Object2DoubleMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.logging.Log; +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.linear.DecompositionSolver; +import org.apache.commons.math3.linear.LUDecomposition; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; +import org.apache.commons.math3.linear.SingularValueDecomposition; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.mapred.Counters; + +public class CofactorModel { + + @Nonnegative + private final int factor; + + // rank matrix initialization + private final RankInitScheme initScheme; + + private double globalBias; + + // storing trainable latent factors and weights + private final Weights theta; + private final Weights beta; + private final Object2DoubleMap betaBias; + private final Weights gamma; + private final Object2DoubleMap gammaBias; + + private final Random[] randU, randI; + + // hyperparameters + @Nonnegative + private final float c0, c1; + private final float lambdaTheta, lambdaBeta, lambdaGamma; + + // validation + private final CofactorizationUDTF.ValidationMetric validationMetric; + private final Feature[] validationProbes; + private final Prediction[] predictions; + private final int numValPerRecord; + private String[] users; + private String[] items; + + // solve + private final RealMatrix B; + private final RealVector A; + + // counters + private Counters.Counter userCounter; + private Counters.Counter itemCounter; + private Counters.Counter skippedUserCounter; + private Counters.Counter skippedItemCounter; + private Counters.Counter thetaTrainableFeaturesCounter; + private 
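In the CoFactor formulation this class follows, `theta` holds user factors, `beta` holds item factors for the interaction part, and `gamma` together with `betaBias`, `gammaBias` and `globalBias` models the item co-occurrence (SPPMI) part; `B` and `A` are scratch buffers reused by the per-row linear solves. A minimal sketch of the two scores these factors are trained to reproduce, assuming plain double[] factor vectors:

static double interactionScore(double[] thetaU, double[] betaI) {
    double s = 0.d;
    for (int k = 0; k < thetaU.length; k++) {
        s += thetaU[k] * betaI[k]; // x_ui ~ theta_u . beta_i
    }
    return s;
}

static double cooccurrenceScore(double[] betaI, double[] gammaJ, double biasI, double biasJ,
        double globalBias) {
    double s = 0.d;
    for (int k = 0; k < betaI.length; k++) {
        s += betaI[k] * gammaJ[k]; // m_ij ~ beta_i . gamma_j + b_i + c_j + w0
    }
    return s + biasI + biasJ + globalBias;
}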
Counters.Counter thetaTotalFeaturesCounter; + private Counters.Counter betaTrainableFeaturesCounter; + private Counters.Counter betaTotalFeaturesCounter; + + protected static class Weights extends Object2ObjectOpenHashMap { + private static final long serialVersionUID = -7048382051969687548L; + + protected Object[] getKey() { + return key; + } + + @Nonnull + String[] getNonnullKeys() { + final String[] keys = new String[size]; + final Object[] k = (Object[]) key; + final int len = k.length; + for (int i = 0, j = 0; i < len; i++) { + final Object ki = k[i]; + if (ki != null) { + keys[j++] = ki.toString(); + } + } + return keys; + } + } + + public enum RankInitScheme { + random /* default */, gaussian; + @Nonnegative + private float maxInitValue; + @Nonnegative + private double initStdDev; + + @Nonnull + public static CofactorModel.RankInitScheme resolve(@Nullable String opt) { + if (opt == null) { + return random; + } else if ("gaussian".equalsIgnoreCase(opt)) { + return gaussian; + } else if ("random".equalsIgnoreCase(opt)) { + return random; + } + return random; + } + + public void setMaxInitValue(float maxInitValue) { + this.maxInitValue = maxInitValue; + } + + public void setInitStdDev(double initStdDev) { + this.initStdDev = initStdDev; + } + } + + private static class Prediction implements Comparable { + + private double prediction; + private int label; + + @Override + public int compareTo(@Nonnull Prediction other) { + // descending order + return -Double.compare(prediction, other.prediction); + } + + } + + public CofactorModel(@Nonnegative final int factor, @Nonnull final RankInitScheme initScheme, + @Nonnegative final float c0, @Nonnegative final float c1, + @Nonnegative final float lambdaTheta, @Nonnegative final float lambdaBeta, + @Nonnegative final float lambdaGamma, final float globalBias, + @Nullable CofactorizationUDTF.ValidationMetric validationMetric, + @Nonnegative final int numValPerRecord, @Nonnull final Log log) { + + // rank init scheme is gaussian + // https://github.com/dawenl/cofactor/blob/master/src/cofacto.py#L98 + this.factor = factor; + this.initScheme = initScheme; + this.globalBias = globalBias; + this.lambdaTheta = lambdaTheta; + this.lambdaBeta = lambdaBeta; + this.lambdaGamma = lambdaGamma; + + this.theta = new Weights(); + this.beta = new Weights(); + this.betaBias = new Object2DoubleArrayMap<>(); + this.betaBias.defaultReturnValue(0.d); + this.gamma = new Weights(); + this.gammaBias = new Object2DoubleArrayMap<>(); + this.gammaBias.defaultReturnValue(0.d); + + this.B = new Array2DRowRealMatrix(this.factor, this.factor); + this.A = new ArrayRealVector(this.factor); + + this.randU = newRandoms(factor, 31L); + this.randI = newRandoms(factor, 41L); + + Preconditions.checkArgument(c0 >= 0.f && c0 <= 1.f); + Preconditions.checkArgument(c1 >= 0.f && c1 <= 1.f); + + this.c0 = c0; + this.c1 = c1; + + if (validationMetric == null) { + this.validationMetric = CofactorizationUDTF.ValidationMetric.AUC; + } else { + this.validationMetric = validationMetric; + } + + this.numValPerRecord = numValPerRecord; + this.validationProbes = new Feature[numValPerRecord]; + this.predictions = new Prediction[numValPerRecord]; + for (int i = 0; i < validationProbes.length; i++) { + validationProbes[i] = new StringFeature("", 0.d); + predictions[i] = new Prediction(); + } + } + + private void initFactorVector(final String key, final Weights weights) throws HiveException { + if (weights.containsKey(key)) { + throw new HiveException(String.format( + "two items or two users cannot have 
same `context` in training set: found duplicate context `%s`", + key)); + } + final double[] v = new double[factor]; + switch (initScheme) { + case random: + uniformFill(v, randI[0], initScheme.maxInitValue); + break; + case gaussian: + gaussianFill(v, randI, initScheme.initStdDev); + break; + default: + throw new IllegalStateException( + "Unsupported rank initialization scheme: " + initScheme); + + } + weights.put(key, v); + } + + @Nullable + private static double[] getFactorVector(String key, Weights weights) { + return weights.get(key); + } + + private static void setFactorVector(final String key, final Weights weights, + final RealVector factorVector) throws HiveException { + final double[] vec = weights.get(key); + if (vec == null) { + throw new HiveException(); + } + copyData(vec, factorVector); + } + + private static double getBias(String key, Object2DoubleMap biases) { + return biases.getDouble(key); + } + + private static void setBias(String key, Object2DoubleMap biases, double value) { + biases.put(key, value); + } + + @Nullable + public double[] getGammaVector(@Nonnull final String key) { + return getFactorVector(key, gamma); + } + + public double getGammaBias(@Nonnull final String key) { + return getBias(key, gammaBias); + } + + public void setGammaBias(@Nonnull final String key, final double value) { + setBias(key, gammaBias, value); + } + + public double getGlobalBias() { + return globalBias; + } + + public void setGlobalBias(final double value) { + globalBias = value; + } + + @Nullable + public double[] getThetaVector(@Nonnull final String key) { + return getFactorVector(key, theta); + } + + @Nullable + public double[] getBetaVector(@Nonnull final String key) { + return getFactorVector(key, beta); + } + + public double getBetaBias(@Nonnull final String key) { + return getBias(key, betaBias); + } + + public void setBetaBias(@Nonnull final String key, final double value) { + setBias(key, betaBias, value); + } + + @Nonnull + public Weights getTheta() { + return theta; + } + + @Nonnull + public Weights getBeta() { + return beta; + } + + @Nonnull + public Weights getGamma() { + return gamma; + } + + @Nonnull + public Object2DoubleMap getBetaBiases() { + return betaBias; + } + + @Nonnull + public Object2DoubleMap getGammaBiases() { + return gammaBias; + } + + public void updateWithUsers(@Nonnull final Map> userToItems) + throws HiveException { + updateTheta(userToItems); + } + + public void updateWithItems(@Nonnull final Map> items, + Map sppmi) throws HiveException { + updateBeta(items, sppmi); + updateGamma(items, sppmi); + updateBetaBias(items, sppmi); + updateGammaBias(items, sppmi); + updateGlobalBias(items, sppmi); + } + + /** + * Update latent factors of the users in the provided mini-batch. 
+ * + * @param samples + */ + private void updateTheta(@Nonnull final Map> samples) + throws HiveException { + // initialize item factors + // items should only be trainable if the dataset contains a major entry for that item (which it may not) + // variable names follow cofacto.py + final double[][] BTBpR = calculateWTWpR(beta, factor, c0, lambdaTheta); + + for (Map.Entry> sample : samples.entrySet()) { + RealVector newThetaVec = + calculateNewThetaVector(sample, beta, factor, B, A, BTBpR, c0, c1); + if (newThetaVec != null) { + setFactorVector(sample.getKey(), theta, newThetaVec); + } else { + skippedUserCounter.increment(1); + } + userCounter.increment(1); + } + } + + @VisibleForTesting + protected RealVector calculateNewThetaVector( + @Nonnull final Map.Entry> sample, @Nonnull final Weights beta, + @Nonnegative final int numFactors, @Nonnull final RealMatrix B, + @Nonnull final RealVector A, @Nonnull final double[][] BTBpR, + @Nonnegative final float c0, @Nonnegative final float c1) throws HiveException { + // filter for trainable items + List trainableItems = filterTrainableFeatures(sample.getValue(), beta); + thetaTotalFeaturesCounter.increment(sample.getValue().size()); + if (trainableItems.isEmpty()) { + return null; + } + thetaTrainableFeaturesCounter.increment(trainableItems.size()); + final double[] a = calculateA(trainableItems, beta, numFactors, c1); + final double[][] delta = + calculateWTWSubsetStrings(trainableItems, beta, numFactors, c1 - c0); + final double[][] b = addInPlace(delta, BTBpR); + // solve and update factors + return solve(B, b, A, a); + } + + /** + * Update latent factors of the items in the provided mini-batch. + */ + private void updateBeta(@Nonnull final Map> items, + Map sppmi) throws HiveException { + // precomputed matrix + final double[][] TTTpR = calculateWTWpR(theta, factor, c0, lambdaBeta); + for (Map.Entry> sample : items.entrySet()) { + RealVector newBetaVec = calculateNewBetaVector(sample, sppmi, theta, gamma, gammaBias, + betaBias, factor, B, A, TTTpR, c0, c1, globalBias); + if (newBetaVec != null) { + setFactorVector(sample.getKey(), beta, newBetaVec); + } else { + skippedItemCounter.increment(1); + } + itemCounter.increment(1); + } + } + + @VisibleForTesting + protected RealVector calculateNewBetaVector( + @Nonnull final Map.Entry> sample, + @Nonnull final Map sppmi, @Nonnull final Weights theta, + @Nonnull final Weights gamma, @Nonnull final Object2DoubleMap gammaBias, + @Nonnull final Object2DoubleMap betaBias, final int numFactors, + @Nonnull final RealMatrix B, @Nonnull final RealVector A, + @Nonnull final double[][] TTTpR, @Nonnegative final float c0, + @Nonnegative final float c1, final double globalBias) throws HiveException { + // filter for trainable users + final List trainableUsers = filterTrainableFeatures(sample.getValue(), theta); + betaTotalFeaturesCounter.increment(sample.getValue().size()); + if (trainableUsers.isEmpty()) { + return null; + } + + betaTrainableFeaturesCounter.increment(trainableUsers.size()); + + final List trainableCooccurringItems = + filterTrainableFeatures(sppmi.get(sample.getKey()), gamma); + final double[] RSD = calculateRSD(sample.getKey(), trainableCooccurringItems, numFactors, + betaBias, gammaBias, gamma, globalBias); + final double[] ApRSD = + addInPlace(calculateA(trainableUsers, theta, numFactors, c1), RSD, 1.f); + + final double[][] GTG = + calculateWTWSubsetFeatures(trainableCooccurringItems, gamma, numFactors, 1.f); + final double[][] delta = + calculateWTWSubsetStrings(trainableUsers, theta, 
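For reference, the closed-form update that `calculateNewThetaVector` implements is the weighted-ALS solve used in cofacto.py, with B_u stacking the beta vectors of the user's trainable items:

\theta_u \leftarrow \big(c_0\, B^\top B + \lambda_\theta I + (c_1 - c_0)\, B_u^\top B_u\big)^{-1} \; c_1\, B_u^\top \mathbf{1}

The first two terms are the `BTBpR` matrix precomputed once per pass over all items; only the (c_1 - c_0) correction and the right-hand side `a` depend on the individual user.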
numFactors, c1 - c0); + // never add into the precomputed `TTTpR` array, only add into temporary arrays like `delta` and `GTG` + final double[][] b = addInPlace(addInPlace(delta, GTG), TTTpR); + + // solve and update factors + return solve(B, b, A, ApRSD); + } + + /** + * Update latent factors of the items in the provided mini-batch. + */ + private void updateGamma(@Nonnull final Map> samples, + Map sppmi) throws HiveException { + for (Map.Entry> sample : samples.entrySet()) { + RealVector newGammaVec = calculateNewGammaVector(sample, sppmi, beta, gammaBias, + betaBias, factor, B, A, lambdaGamma, globalBias); + if (newGammaVec != null) { + setFactorVector(sample.getKey(), gamma, newGammaVec); + } + } + } + + @VisibleForTesting + protected static RealVector calculateNewGammaVector( + @Nonnull final Map.Entry> sample, + @Nonnull final Map sppmi, @Nonnull final Weights beta, + @Nonnull final Object2DoubleMap gammaBias, + @Nonnull final Object2DoubleMap betaBias, @Nonnegative final int numFactors, + @Nonnull final RealMatrix B, @Nonnull final RealVector A, + @Nonnegative final float lambdaGamma, final double globalBias) throws HiveException { + // filter for trainable items + final List trainableCooccurringItems = + filterTrainableFeatures(sppmi.get(sample.getKey()), beta); + if (trainableCooccurringItems.isEmpty()) { + return null; + } + final double[][] b = regularize( + calculateWTWSubsetFeatures(trainableCooccurringItems, beta, numFactors, 1.f), + lambdaGamma); + final double[] rsd = calculateRSD(sample.getKey(), trainableCooccurringItems, numFactors, + gammaBias, betaBias, beta, globalBias); + // solve and update factors + return solve(B, b, A, rsd); + } + + private static double[][] regularize(@Nonnull final double[][] A, final float lambda) { + for (int i = 0; i < A.length; i++) { + A[i][i] += lambda; + } + return A; + } + + private void updateBetaBias(@Nonnull final Map> samples, + Map sppmi) throws HiveException { + for (Map.Entry> sample : samples.entrySet()) { + Double newBetaBias = + calculateNewBias(sample, sppmi, beta, gamma, gammaBias, globalBias); + if (newBetaBias != null) { + setBetaBias(sample.getKey(), newBetaBias); + } + } + } + + public void updateGammaBias(@Nonnull final Map> samples, + Map sppmi) throws HiveException { + for (Map.Entry> sample : samples.entrySet()) { + Double newGammaBias = + calculateNewBias(sample, sppmi, gamma, beta, betaBias, globalBias); + if (newGammaBias != null) { + setGammaBias(sample.getKey(), newGammaBias); + } + } + } + + private void updateGlobalBias(@Nonnull final Map> samples, + Map sppmi) throws HiveException { + Double newGlobalBias = + calculateNewGlobalBias(samples, sppmi, beta, gamma, betaBias, gammaBias); + if (newGlobalBias != null) { + setGlobalBias(newGlobalBias); + } + } + + @Nullable + protected static Double calculateNewGlobalBias(@Nonnull final Map> samples, + @Nonnull final Map sppmi, @Nonnull Weights beta, + @Nonnull Weights gamma, @Nonnull final Object2DoubleMap betaBias, + @Nonnull final Object2DoubleMap gammaBias) throws HiveException { + double newGlobalBias = 0.d; + int numEntriesInSPPMI = 0; + for (Map.Entry> sample : samples.entrySet()) { + // filter for trainable items + final List trainableCooccurringItems = + filterTrainableFeatures(sppmi.get(sample.getKey()), beta); + if (trainableCooccurringItems.isEmpty()) { + continue; + } + numEntriesInSPPMI += trainableCooccurringItems.size(); + newGlobalBias += calculateGlobalBiasRSD(sample.getKey(), trainableCooccurringItems, + beta, gamma, betaBias, gammaBias); + } + if 
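The co-occurrence side solved around this point can be summarized the same way, where S_j is the set of items co-occurring with j in the SPPMI matrix, m_ij the SPPMI value, b and c the beta and gamma biases, and w_0 the global bias:

\gamma_j \leftarrow \Big(\sum_{i \in S_j} \beta_i \beta_i^\top + \lambda_\gamma I\Big)^{-1} \sum_{i \in S_j} \beta_i\,(m_{ij} - b_i - c_j - w_0), \qquad b_i \leftarrow \frac{1}{|S_i|} \sum_{j \in S_i} (m_{ij} - \beta_i^\top \gamma_j - c_j - w_0)

`calculateRSD` builds the weighted residual sums on the right-hand side, and the bias updates are plain residual means.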
(numEntriesInSPPMI == 0) { + return null; + } + return newGlobalBias / numEntriesInSPPMI; + } + + @VisibleForTesting + protected static Double calculateNewBias(@Nonnull final Map.Entry> sample, + @Nonnull final Map sppmi, @Nonnull final Weights beta, + @Nonnull final Weights gamma, @Nonnull final Object2DoubleMap biases, + final double globalBias) throws HiveException { + // filter for trainable items + final List trainableCooccurringItems = + filterTrainableFeatures(sppmi.get(sample.getKey()), beta); + if (trainableCooccurringItems.isEmpty()) { + return null; + } + double rsd = calculateBiasRSD(sample.getKey(), trainableCooccurringItems, beta, gamma, + biases, globalBias); + return rsd / trainableCooccurringItems.size(); + + } + + @VisibleForTesting + protected static double calculateGlobalBiasRSD(@Nonnull final String thisItem, + @Nonnull final List trainableItems, @Nonnull final Weights beta, + @Nonnull final Weights gamma, @Nonnull final Object2DoubleMap betaBias, + @Nonnull final Object2DoubleMap gammaBias) { + double result = 0.d; + final double[] thisFactorVec = getFactorVector(thisItem, beta); + final double thisBias = getBias(thisItem, betaBias); + for (Feature cooccurrence : trainableItems) { + String j = cooccurrence.getFeature(); + final double[] cooccurVec = getFactorVector(j, gamma); + double cooccurBias = getBias(j, gammaBias); + double value = cooccurrence.getValue() - dotProduct(thisFactorVec, cooccurVec) + - thisBias - cooccurBias; + result += value; + } + return result; + } + + @VisibleForTesting + protected static double calculateBiasRSD(@Nonnull final String thisItem, + @Nonnull final List trainableItems, @Nonnull final Weights beta, + @Nonnull final Weights gamma, @Nonnull final Object2DoubleMap biases, + final double globalBias) { + double result = 0.d; + final double[] thisFactorVec = getFactorVector(thisItem, beta); + for (Feature cooccurrence : trainableItems) { + String j = cooccurrence.getFeature(); + final double[] cooccurVec = getFactorVector(j, gamma); + double cooccurBias = getBias(j, biases); + double value = cooccurrence.getValue() - dotProduct(thisFactorVec, cooccurVec) + - cooccurBias - globalBias; + result += value; + } + return result; + } + + @VisibleForTesting + @Nonnull + protected static double[] calculateRSD(@Nonnull final String thisItem, + @Nonnull final List trainableItems, final int numFactors, + @Nonnull final Object2DoubleMap fixedBias, + @Nonnull final Object2DoubleMap changingBias, @Nonnull final Weights weights, + final double globalBias) throws HiveException { + + final double b = getBias(thisItem, fixedBias); + final double[] accumulator = new double[numFactors]; + for (Feature cooccurrence : trainableItems) { + final String j = cooccurrence.getFeature(); + double scale = cooccurrence.getValue() - b - getBias(j, changingBias) - globalBias; + final double[] g = getFactorVector(j, weights); + addInPlace(accumulator, g, scale); + } + return accumulator; + } + + /** + * Calculate W' x W plus regularization matrix + */ + @VisibleForTesting + @Nonnull + protected static double[][] calculateWTWpR(@Nonnull final Weights W, + @Nonnegative final int numFactors, @Nonnegative final float c0, + @Nonnegative final float lambda) { + double[][] WTW = calculateWTW(W, numFactors, c0); + return regularize(WTW, lambda); + } + + private static void checkCondition(final boolean condition, final String errorMessage) + throws HiveException { + if (!condition) { + throw new HiveException(errorMessage); + } + } + + @VisibleForTesting + @Nonnull + protected 
static double[][] addInPlace(@Nonnull final double[][] A, @Nonnull final double[][] B) + throws HiveException { + checkCondition(A.length == A[0].length && A.length == B.length && B.length == B[0].length, + "Array is not square"); + for (int i = 0; i < A.length; i++) { + for (int j = 0; j < A[0].length; j++) { + A[i][j] += B[i][j]; + } + } + return A; + } + + @VisibleForTesting + @Nonnull + protected static List filterTrainableFeatures(@Nonnull final List features, + @Nonnull final Weights weights) { + final List trainableFeatures = new ArrayList<>(); + for (String feature : features) { + if (isTrainable(feature, weights)) { + trainableFeatures.add(feature); + } + } + return trainableFeatures; + } + + @VisibleForTesting + @Nonnull + protected static List filterTrainableFeatures(@Nullable final Feature[] features, + @Nonnull final Weights weights) throws HiveException { + checkCondition(features != null, "features cannot be null"); + final List trainableFeatures = new ArrayList<>(); + String fName; + for (Feature f : features) { + fName = f.getFeature(); + if (isTrainable(fName, weights)) { + trainableFeatures.add(f); + } + } + return trainableFeatures; + } + + @VisibleForTesting + protected static RealVector solve(@Nonnull final RealMatrix B, @Nonnull final double[][] dataB, + @Nonnull final RealVector A, @Nonnull final double[] dataA) throws HiveException { + // b * x = a + // solves for x + copyData(B, dataB); + copyData(A, dataA); + + final LUDecomposition LU = new LUDecomposition(B); + final DecompositionSolver solver = LU.getSolver(); + + if (solver.isNonSingular()) { + return LU.getSolver().solve(A); + } else { + SingularValueDecomposition svd = new SingularValueDecomposition(B); + return svd.getSolver().solve(A); + } + } + + private static void copyData(@Nonnull final RealMatrix dst, @Nonnull final double[][] src) + throws HiveException { + checkCondition( + dst.getRowDimension() == src.length && dst.getColumnDimension() == src[0].length, + "Matrix do not match in size"); + for (int i = 0, rows = dst.getRowDimension(); i < rows; i++) { + final double[] src_i = src[i]; + for (int j = 0, cols = dst.getColumnDimension(); j < cols; j++) { + dst.setEntry(i, j, src_i[j]); + } + } + } + + private static void copyData(@Nonnull final RealVector dst, @Nonnull final double[] src) + throws HiveException { + checkCondition(dst.getDimension() == src.length, "Vector do not match in size"); + for (int i = 0; i < dst.getDimension(); i++) { + dst.setEntry(i, src[i]); + } + } + + private static void copyData(@Nonnull final double[] dst, @Nonnull final RealVector src) + throws HiveException { + checkCondition(dst.length == src.getDimension(), "Vector do not match in size"); + for (int i = 0; i < dst.length; i++) { + dst[i] = src.getEntry(i); + } + } + + @VisibleForTesting + @Nonnull + protected static double[][] calculateWTW(@Nonnull final Weights weights, + @Nonnull final int numFactors, @Nonnull final float constant) { + final double[][] WTW = new double[numFactors][numFactors]; + for (double[] vec : weights.values()) { + for (int i = 0; i < numFactors; i++) { + final double[] WTW_f = WTW[i]; + for (int j = 0; j < numFactors; j++) { + double val = constant * vec[i] * vec[j]; + WTW_f[j] += val; + } + } + } + return WTW; + } + + @VisibleForTesting + @Nonnull + protected static double[][] calculateWTWSubsetStrings(@Nonnull final List subset, + @Nonnull final Weights weights, @Nonnegative final int numFactors, + @Nonnegative final float constant) throws HiveException { + // equivalent to `B_u.T.dot((c1 - 
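`solve` above solves B * x = a with an LU decomposition and falls back to an SVD-based least-squares solution when the system is singular. A self-contained sketch of that strategy with commons-math3 (hypothetical class name and values):

import java.util.Arrays;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularValueDecomposition;

public final class SolveDemo {
    public static void main(String[] args) {
        RealMatrix b = new Array2DRowRealMatrix(new double[][] {{4.d, 1.d}, {1.d, 3.d}});
        RealVector a = new ArrayRealVector(new double[] {1.d, 2.d});

        DecompositionSolver lu = new LUDecomposition(b).getSolver();
        RealVector x = lu.isNonSingular()
                ? lu.solve(a) // exact solution
                : new SingularValueDecomposition(b).getSolver().solve(a); // least-squares fallback
        System.out.println(Arrays.toString(x.toArray())); // [0.0909..., 0.6363...]
    }
}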
c0) * B_u)` in cofacto.py + final double[][] delta = new double[numFactors][numFactors]; + for (String f : subset) { + final double[] vec = getFactorVector(f, weights); + checkCondition(vec != null, "null vector is not allowed"); + for (int i = 0; i < numFactors; i++) { + final double[] delta_f = delta[i]; + for (int j = 0; j < numFactors; j++) { + double val = constant * vec[i] * vec[j]; + delta_f[j] += val; + } + } + } + return delta; + } + + @VisibleForTesting + @Nonnull + protected static double[][] calculateWTWSubsetFeatures(@Nonnull final List subset, + @Nonnull final Weights weights, @Nonnegative final int numFactors, + @Nonnegative final float constant) throws HiveException { + // equivalent to `B_u.T.dot((c1 - c0) * B_u)` in cofacto.py + final double[][] delta = new double[numFactors][numFactors]; + for (Feature f : subset) { + final double[] vec = getFactorVector(f.getFeature(), weights); + checkCondition(vec != null, "null vector is not allowed"); + for (int i = 0; i < numFactors; i++) { + final double[] delta_f = delta[i]; + for (int j = 0; j < numFactors; j++) { + double val = constant * vec[i] * vec[j]; + delta_f[j] += val; + } + } + } + return delta; + } + + @VisibleForTesting + @Nonnull + protected static double[] calculateA(@Nonnull final List items, + @Nonnull final Weights weights, @Nonnegative final int numFactors, + @Nonnegative final float constant) throws HiveException { + // Equivalent to: a = x_u.dot(c1 * B_u) + // x_u is a (1, i) matrix of all ones + // B_u is a (i, F) matrix + // What it does: sums factor n of each item in B_u + final double[] A = new double[numFactors]; + for (String item : items) { + addInPlace(A, getFactorVector(item, weights), 1.d); + } + for (int a = 0; a < A.length; a++) { + A[a] *= constant; + } + return A; + } + + @Nullable + public Double predict(@Nonnull final String user, @Nonnull final String item) { + if (!theta.containsKey(user) || !beta.containsKey(item)) { + return null; + } + final double[] u = getThetaVector(user), i = getBetaVector(item); + return dotProduct(u, i); + } + + @VisibleForTesting + protected static double dotProduct(@Nonnull final double[] u, @Nonnull final double[] v) { + double result = 0.d; + for (int i = 0; i < u.length; i++) { + result += u[i] * v[i]; + } + return result; + } + + // public double calculateLoss(@Nonnull final Map> users, @Nonnull final List>> items) { + // // for speed - can calculate loss on a small subset of the training data + // double mf_loss = calculateMFLoss(users, theta, beta, c0, c1) + calculateMFLoss(items, beta, theta, c0, c1); + // double embed_loss = calculateEmbedLoss(items, beta, gamma, betaBias, gammaBias); + // return mf_loss + embed_loss + sumL2Loss(theta, lambdaTheta) + sumL2Loss(beta, lambdaBeta) + sumL2Loss(gamma, lambdaGamma); + // } + + + @VisibleForTesting + protected static double calculateEmbedLoss( + @Nonnull final List items, + @Nonnull final Weights beta, @Nonnull final Weights gamma, + @Nonnull final Object2DoubleMap betaBias, + @Nonnull final Object2DoubleMap gammaBias) { + double loss = 0.d, val, bBias, gBias; + double[] bFactors, gFactors; + String bKey, gKey; + for (CofactorizationUDTF.TrainingSample item : items) { + bKey = item.context; + bFactors = getFactorVector(bKey, beta); + bBias = getBias(bKey, betaBias); + for (Feature cooccurrence : item.sppmi) { + if (!isTrainable(cooccurrence.getFeature(), beta)) { + continue; + } + gKey = cooccurrence.getFeature(); + gFactors = getFactorVector(gKey, gamma); + gBias = getBias(gKey, gammaBias); + val = 
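Note that `predict` above scores a (user, item) pair with `theta` and `beta` only; `gamma` and the biases shape training through the co-occurrence loss but are not consulted at prediction time. A usage sketch with hypothetical keys, assuming a trained `model`:

Double score = model.predict("user123", "item42");
if (score == null) {
    // the user or the item was never registered, so no factors exist for it
}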
cooccurrence.getValue() - dotProduct(bFactors, gFactors) - bBias - gBias; + loss += val * val; + } + } + return loss; + } + + @VisibleForTesting + protected static double calculateMFLoss( + @Nonnull final List samples, + @Nonnull final Weights contextWeights, @Nonnull final Weights featureWeights, + @Nonnegative final float c0, @Nonnegative final float c1) { + double loss = 0.d, err, predicted, y; + double[] contextFactors, ratedFactors; + + for (CofactorizationUDTF.TrainingSample sample : samples) { + contextFactors = getFactorVector(sample.context, contextWeights); + // all items / users + for (double[] unratedFactors : featureWeights.values()) { + predicted = dotProduct(contextFactors, unratedFactors); + err = (0.d - predicted); + loss += c0 * err * err; + } + // only rated items / users + for (Feature f : sample.features) { + if (!isTrainable(f.getFeature(), featureWeights)) { + continue; + } + ratedFactors = getFactorVector(f.getFeature(), featureWeights); + predicted = dotProduct(contextFactors, ratedFactors); + y = f.getValue(); + err = y - predicted; + loss += (c1 - c0) * err * err; + } + } + return loss; + } + + @VisibleForTesting + protected static double sumL2Loss(@Nonnull final Weights weights, @Nonnegative float lambda) { + double loss = 0.d; + for (double[] v : weights.values()) { + loss += L2Distance(v); + } + return lambda * loss; + } + + @VisibleForTesting + protected static double L2Distance(@Nonnull final double[] vec) { + double result = 0.d; + for (double v : vec) { + result += v * v; + } + return Math.sqrt(result); + } + + /** + * Sample positive and negative validation examples and return a performance metric that should + * be minimized. + * + * @return Validation metric + * @throws HiveException + */ + public Double validate(@Nonnull final String user, @Nonnull final String item) + throws HiveException { + if (!theta.containsKey(user) || !beta.containsKey(item)) { + return null; + } + // limit numPos and numNeg + // int numPos = Math.min(sample.features.length, (int) Math.ceil(this.numValPerRecord * 0.5)); + int numPos = 1; + // int numNeg = Math.min(this.numValPerRecord - numPos, sample.isItem() ? users.length : items.length); + int numNeg = 2; + // if (numPos == 0) { + // throw new HiveException("numPos is 0: sample.features.length = " + sample.features.length + ", ceil = " + (int) Math.ceil(this.numValPerRecord * 0.5)); + // } + // if (numNeg == 0) { + // throw new HiveException("numNeg is 0, users.length = " + users.length + ", items.length = " + items.length); + // } + + // getValidationExamples(numPos, numNeg, sample.features, sample.isItem(), validationProbes, seed); + // if (validationMetric == CofactorizationUDTF.ValidationMetric.AUC) { + // return -calculateAUC(validationProbes, predictions, sample, numPos, numNeg); + // } else { + // return calculateLoss(validationProbes, sample, numPos, numNeg); + // } + return null; + } + + private boolean isPredictable(@Nonnull final String context, final boolean isItem) { + if (isItem) { + return beta.containsKey(context); + } else { + return theta.containsKey(context); + } + } + + /** + * TODO: not implemented + * + * @return + */ + private double calculateLoss(Feature[] validationProbes, + CofactorizationUDTF.TrainingSample sample, int numPos, int numNeg) { + return 0d; + } + + /** + * Calculates area under curve for validation metric. 
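`calculateMFLoss` above uses the usual confidence-weighting trick for implicit feedback: every (context, feature) pair is charged a baseline c_0-weighted error against a zero target, and observed pairs add a (c_1 - c_0)-weighted error against their actual value, so the dense unobserved part never has to be enumerated with its own weight:

L_{MF} = \sum_{u} \Big[ c_0 \sum_{i} (\theta_u^\top \beta_i)^2 + (c_1 - c_0) \sum_{i \in I_u} (x_{ui} - \theta_u^\top \beta_i)^2 \Big]

The embedding loss in `calculateEmbedLoss` is the squared residual of the SPPMI reconstruction, and `sumL2Loss` adds the lambda-weighted norms of the factor matrices.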
+ */ + private double calculateAUC(@Nonnull final Feature[] validationProbes, + @Nonnull final Prediction[] predictions, CofactorizationUDTF.TrainingSample sample, + final int numPos, final int numNeg) throws HiveException { + // make predictions for positive and then negative examples + int nextIdx = fillPredictions(validationProbes, predictions, sample, 0, numPos, 0, 1); + // if (nextIdx == 0) { + // throw new HiveException("nextIdx is 0, no positives in predictions, validation probes = " + Arrays.toString(validationProbes)); + // } + int endIdx = fillPredictions(validationProbes, predictions, sample, nextIdx, + numPos + numNeg, nextIdx, 0); + + // sort in descending order for all filled predictions + Arrays.sort(predictions, 0, endIdx); + + double area = 0d, scorePrev = Double.MIN_VALUE; + int fp = 0, tp = 0; + int fpPrev = 0, tpPrev = 0; + + for (int i = 0; i < endIdx; i++) { + final Prediction p = predictions[i]; + if (p.prediction != scorePrev) { + area += trapezoid(fp, fpPrev, tp, tpPrev); + scorePrev = p.prediction; + fpPrev = fp; + tpPrev = tp; + } + if (p.label == 1) { + tp += 1; + } else { + fp += 1; + } + } + area += trapezoid(fp, fpPrev, tp, tpPrev); + if (tp * fp == 0) { + return 0d; + } + return area / (tp * fp); + } + + /** + * Calculates area of a trapezoid. + */ + private static double trapezoid(final int x1, final int x2, final int y1, final int y2) { + final int base = Math.abs(x1 - x2); + final double height = (y1 + y2) * 0.5; + return base * height; + } + + /** + * Fill an array of predictions. + * + * @return index of the next empty entry in {@code predictions} array + */ + private int fillPredictions(@Nonnull final Feature[] validationProbes, + @Nonnull final Prediction[] predictions, + @Nonnull final CofactorizationUDTF.TrainingSample sample, final int lo, final int hi, + int fillIdx, final int label) { + for (int i = lo; i < hi; i++) { + final Feature pos = validationProbes[i]; + final Double pred; + if (sample.isItem()) { + pred = predict(pos.getFeature(), sample.context); + } else { + pred = predict(sample.context, pos.getFeature()); + } + if (pred == null) { + continue; + } + predictions[fillIdx].prediction = pred; + predictions[fillIdx].label = label; + fillIdx++; + } + return fillIdx; + } + + /** + * Sample positive and negative samples. + * + * @return number of negatives that were successfully sampled + */ + private void getValidationExamples(final int numPos, final int numNeg, + @Nonnull final Feature[] positives, final boolean isContextAnItem, + @Nonnull final Feature[] validationProbes, final int seed) { + final Random rand = new Random(seed); + samplePositives(numPos, positives, validationProbes, rand); + final String[] keys = isContextAnItem ? users : items; + sampleNegatives(numPos, numNeg, validationProbes, keys, rand); + } + + /** + * Samples negative examples. 
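`calculateAUC` sweeps the predictions in descending score order and accumulates trapezoids in (false-positive, true-positive) space, normalizing by tp * fp at the end. For example, if the labels in descending score order are [1, 0, 1, 0] with distinct scores, the sweep accumulates an area of 3 with tp * fp = 4, giving AUC = 0.75, which matches the fraction of correctly ranked positive/negative pairs.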
+ */ + @VisibleForTesting + protected static void sampleNegatives(final int numPos, final int numNeg, + @Nonnull final Feature[] validationProbes, @Nonnull final String[] keys, + @Nonnull final Random rand) { + // sample numPos positive examples without replacement + for (int i = numPos, size = numPos + numNeg; i < size; i++) { + final String negKey = keys[rand.nextInt(keys.length)]; + validationProbes[i].setFeature(negKey); + validationProbes[i].setValue(0.d); + } + } + + private static void samplePositives(final int numPos, @Nonnull final Feature[] positives, + @Nonnull final Feature[] validationProbes, @Nonnull final Random rand) { + // sample numPos positive examples without replacement + for (int i = 0; i < numPos; i++) { + validationProbes[i] = positives[rand.nextInt(positives.length)]; + } + } + + /** + * Add v to u in-place without creating a new RealVector instance. + * + * @param u array to which v will be added + * @param v array containing new values to be added to u + * @param scalar value to multiply each entry in v before adding to u + */ + @VisibleForTesting + @Nonnull + protected static double[] addInPlace(@Nonnull final double[] u, @Nullable final double[] v, + final double scalar) throws HiveException { + checkCondition(v != null, "null vector is not allowed"); + checkCondition(u.length == v.length, "Vector do not match in size"); + for (int i = 0; i < u.length; i++) { + u[i] += scalar * v[i]; + } + return u; + } + + private static boolean isTrainable(@Nonnull final String name, @Nonnull final Weights weights) { + return weights.containsKey(name); + } + + @Nonnull + private static Random[] newRandoms(@Nonnegative final int size, final long seed) { + final Random[] rand = new Random[size]; + for (int i = 0, len = rand.length; i < len; i++) { + rand[i] = new Random(seed + i); + } + return rand; + } + + private static void uniformFill(@Nonnull final double[] a, @Nonnull final Random rand, + final float maxInitValue) { + for (int i = 0, len = a.length; i < len; i++) { + double v = rand.nextDouble() * maxInitValue / len; + a[i] = v; + } + } + + private static void gaussianFill(@Nonnull final double[] a, @Nonnull final Random[] rand, + @Nonnegative final double stddev) { + for (int i = 0, len = a.length; i < len; i++) { + double v = MathUtils.gaussian(0.d, stddev, rand[i]); + a[i] = v; + } + } + + public void registerCounters(Counters.Counter userCounter, Counters.Counter itemCounter, + Counters.Counter skippedUserCounter, Counters.Counter skippedItemCounter, + Counters.Counter thetaTrainable, Counters.Counter thetaTotal, + Counters.Counter betaTrainable, Counters.Counter betaTotal) { + this.userCounter = userCounter; + this.itemCounter = itemCounter; + this.skippedUserCounter = skippedUserCounter; + this.skippedItemCounter = skippedItemCounter; + this.thetaTotalFeaturesCounter = thetaTotal; + this.thetaTrainableFeaturesCounter = thetaTrainable; + this.betaTotalFeaturesCounter = betaTotal; + this.betaTrainableFeaturesCounter = betaTrainable; + } + + public void registerUsers(Set users) throws HiveException { + for (String key : users) { + initFactorVector(key, theta); + } + this.users = theta.getNonnullKeys(); + } + + public void registerItems(Set items) throws HiveException { + for (String key : items) { + initFactorVector(key, beta); + initFactorVector(key, gamma); + } + this.items = beta.getNonnullKeys(); + } + +} diff --git a/core/src/main/java/hivemall/factorization/cofactor/CofactorizationPredictUDF.java 
b/core/src/main/java/hivemall/factorization/cofactor/CofactorizationPredictUDF.java new file mode 100644 index 000000000..8fd911ac2 --- /dev/null +++ b/core/src/main/java/hivemall/factorization/cofactor/CofactorizationPredictUDF.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.factorization.cofactor; + +import java.util.List; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.io.FloatWritable; + +@Description(name = "cofactor_predict", + value = "_FUNC_(array theta, array beta) - Returns the prediction value") +@UDFType(deterministic = true, stateful = false) +public final class CofactorizationPredictUDF extends UDF { + + private static final DoubleWritable ZERO = new DoubleWritable(0.d); + + // reused result variable + private final DoubleWritable result = new DoubleWritable(); + + @Nonnull + public DoubleWritable evaluate(@Nullable List Pu, + @Nullable List Qi) throws HiveException { + if (Pu == null || Qi == null) { + return ZERO; + } + + final int PuSize = Pu.size(); + final int QiSize = Qi.size(); + // workaround for TD + if (PuSize == 0) { + return ZERO; + } else if (QiSize == 0) { + return ZERO; + } + + if (QiSize != PuSize) { + throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize); + } + + double ret = 0.d; + for (int k = 0; k < PuSize; k++) { + FloatWritable Pu_k = Pu.get(k); + if (Pu_k == null) { + continue; + } + FloatWritable Qi_k = Qi.get(k); + if (Qi_k == null) { + continue; + } + ret += Pu_k.get() * Qi_k.get(); + } + result.set(ret); + return result; + } +} diff --git a/core/src/main/java/hivemall/factorization/cofactor/CofactorizationUDTF.java b/core/src/main/java/hivemall/factorization/cofactor/CofactorizationUDTF.java new file mode 100644 index 000000000..098fcb8d7 --- /dev/null +++ b/core/src/main/java/hivemall/factorization/cofactor/CofactorizationUDTF.java @@ -0,0 +1,558 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
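`evaluate` above is a null-tolerant dot product over the forwarded theta and beta rows. A minimal call-level sketch with hypothetical values (in Hive the lists would come from joining the theta and beta outputs of train_cofactor; the caller must propagate HiveException):

static double predictScore() throws HiveException {
    List<FloatWritable> theta = Arrays.asList(new FloatWritable(0.1f), new FloatWritable(0.2f));
    List<FloatWritable> beta = Arrays.asList(new FloatWritable(0.3f), new FloatWritable(0.4f));
    DoubleWritable score = new CofactorizationPredictUDF().evaluate(theta, beta);
    return score.get(); // 0.1f * 0.3f + 0.2f * 0.4f, roughly 0.11
}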
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.factorization.cofactor; + +import hivemall.UDTFWithOptions; +import hivemall.common.ConversionState; +import hivemall.factorization.cofactor.CofactorModel.RankInitScheme; +import hivemall.fm.Feature; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; + +/** + * Cofactorization for implicit and explicit recommendation + */ +@Description(name = "train_cofactor", + value = "_FUNC_(string context, array features, boolean is_validation, boolean is_item, array sppmi [, String options])" + + " - Returns a relation theta, array beta>") +public final class CofactorizationUDTF extends UDTFWithOptions { + private static final Log LOG = LogFactory.getLog(CofactorizationUDTF.class); + + // Option variables + // The number of latent factors + private int factor; + // The scaling hyperparameter for zero entries in the rank matrix + private float c0; + // The scaling hyperparameter for non-zero entries in the rank matrix + private float c1; + // The initial mean rating + private float globalBias; + // Whether update (and return) the mean rating or not + private boolean updateGlobalBias; + // The number of iterations + private int maxIters; + // Whether to use bias clause + private boolean useBiasClause; + // Whether to use normalization + private boolean useL2Norm; + // regularization hyperparameters + private float lambdaTheta; + private float lambdaBeta; + private float lambdaGamma; + + // validation metric + private ValidationMetric validationMetric; + + // Initialization strategy of rank matrix + private RankInitScheme rankInit; + + // Model itself + private CofactorModel model; + + // Variable managing status of learning + private ConversionState 
validationState; + private int numValPerRecord; + + // Input OIs and Context + private PrimitiveObjectInspector userOI; + private PrimitiveObjectInspector itemOI; + + private BooleanObjectInspector isValidationOI; + private ListObjectInspector sppmiOI; + + // Used for iterations + private long numValidations; + private long numTraining; + + // training data + private Map> userToItems; + private Map> itemToUsers; + private Map sppmi; + + // validation + private Random rand; + private double validationRatio; + private List validationUsers; + private List validationItems; + + static class MiniBatch { + @Nonnull + private final List users; + @Nonnull + private final List items; + @Nonnull + private final List validationSamples; + + MiniBatch() { + this.users = new ArrayList<>(); + this.items = new ArrayList<>(); + this.validationSamples = new ArrayList<>(); + } + + void add(TrainingSample sample) { + if (sample.isValidation) { + validationSamples.add(sample); + } else { + if (sample.isItem()) { + items.add(sample); + } else { + users.add(sample); + } + } + } + + void clear() { + users.clear(); + items.clear(); + validationSamples.clear(); + } + + int trainingSize() { + return items.size() + users.size(); + } + + int validationSize() { + return validationSamples.size(); + } + + @Nonnull + List getItems() { + return items; + } + + @Nonnull + List getUsers() { + return users; + } + + @Nonnull + List getValidationSamples() { + return validationSamples; + } + } + + static final class TrainingSample { + @Nonnull + final String context; + @Nonnull + final Feature[] features; + @Nonnull + final Feature[] sppmi; + final boolean isValidation; + + TrainingSample(@Nonnull String context, @Nonnull Feature[] features, boolean isValidation, + @Nullable Feature[] sppmi) { + this.context = context; + this.features = features; + this.sppmi = sppmi; + this.isValidation = isValidation; + } + + boolean isItem() { + return sppmi != null; + } + } + + enum ValidationMetric { + AUC, OBJECTIVE; + + static ValidationMetric resolve(@Nonnull final String opt) { + switch (opt.toLowerCase()) { + case "auc": + return AUC; + case "objective": + case "loss": + return OBJECTIVE; + default: + throw new IllegalArgumentException( + opt + " is not a supported Validation Metric."); + } + } + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "factor", true, "The number of latent factor [default: 10] " + + " Note this is alias for `factors` option."); + opts.addOption("f", "factors", true, "The number of latent factor [default: 10]"); + opts.addOption("lt", "lambda_theta", true, + "The theta regularization factor [default: 1e-5]"); + opts.addOption("lb", "lambda_beta", true, "The beta regularization factor [default: 1e-5]"); + opts.addOption("lg", "lambda_gamma", true, + "The gamma regularization factor [default: 1.0]"); + opts.addOption("c0", "c0", true, + "The scaling hyperparameter for zero entries in the rank matrix [default: 0.1]"); + opts.addOption("c1", "c1", true, + "The scaling hyperparameter for non-zero entries in the rank matrix [default: 1.0]"); + opts.addOption("gb", "global_bias", true, "The global bias [default: 0.0]"); + opts.addOption("update_gb", "update_global_bias", true, + "Whether update (and return) the global bias or not [default: false]"); + opts.addOption("rankinit", true, + "Initialization strategy of rank matrix [random, gaussian] (default: gaussian)"); + opts.addOption("maxval", "max_init_value", true, + "The maximum initial value in the rank matrix 
[default: 1.0]"); + opts.addOption("min_init_stddev", true, + "The minimum standard deviation of initial rank matrix [default: 0.01]"); + opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]"); + opts.addOption("iter", true, + "The number of iterations [default: 1] Alias for `-iterations`"); + opts.addOption("max_iters", "max_iters", true, "The number of iterations [default: 1]"); + opts.addOption("disable_bias", "no_bias", false, "Turn off bias clause"); + // normalization + opts.addOption("disable_norm", "disable_l2norm", false, + "Disable instance-wise L2 normalization"); + // validation + opts.addOption("disable_cv", "disable_cvtest", false, + "Whether to disable convergence check [default: enabled]"); + opts.addOption("cv_rate", "convergence_rate", true, + "Threshold to determine convergence [default: 0.005]"); + opts.addOption("val_metric", "validation_metric", true, + "Metric to use for validation ['auc', 'objective']"); + opts.addOption("val_ratio", "validation_ratio", true, + "Proportion of examples to use as validation data [default: 0.125]"); + opts.addOption("num_val", "num_validation_examples_per_record", true, + "Number of validation examples to use per record [default: 10]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + String rankInitOpt = "gaussian"; + float maxInitValue = 1.f; + double initStdDev = 0.01d; + boolean convergenceCheck = true; + double convergenceRate = 0.005d; + String validationMetricOpt = "auc"; + this.c0 = 0.1f; + this.c1 = 1.0f; + this.lambdaTheta = 1e-5f; + this.lambdaBeta = 1e-5f; + this.lambdaGamma = 1.0f; + this.globalBias = 0.f; + this.maxIters = 1; + this.factor = 10; + this.numValPerRecord = 10; + this.validationRatio = 0.125; + + if (argOIs.length >= 3) { + String rawArgs = HiveUtils.getConstString(argOIs[3]); + cl = parseOptions(rawArgs); + if (cl.hasOption("factors")) { + this.factor = Primitives.parseInt(cl.getOptionValue("factors"), factor); + } else { + this.factor = Primitives.parseInt(cl.getOptionValue("factor"), factor); + } + this.lambdaTheta = + Primitives.parseFloat(cl.getOptionValue("lambda_theta"), lambdaTheta); + this.lambdaBeta = Primitives.parseFloat(cl.getOptionValue("lambda_beta"), lambdaBeta); + this.lambdaGamma = + Primitives.parseFloat(cl.getOptionValue("lambda_gamma"), lambdaGamma); + + this.c0 = Primitives.parseFloat(cl.getOptionValue("c0"), c0); + this.c1 = Primitives.parseFloat(cl.getOptionValue("c1"), c1); + + this.globalBias = Primitives.parseFloat(cl.getOptionValue("global_bias"), globalBias); + this.updateGlobalBias = cl.hasOption("update_global_bias"); + + rankInitOpt = cl.getOptionValue("rankinit", rankInitOpt); + maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), maxInitValue); + initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), initStdDev); + + if (cl.hasOption("iter")) { + this.maxIters = Primitives.parseInt(cl.getOptionValue("iter"), maxIters); + } else { + this.maxIters = Primitives.parseInt(cl.getOptionValue("max_iters"), maxIters); + } + if (maxIters < 1) { + throw new UDFArgumentException( + "'-max_iters' must be greater than or equal to 1: " + maxIters); + } + + convergenceCheck = !cl.hasOption("disable_cvtest"); + convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate); + validationMetricOpt = cl.getOptionValue("validation_metric", validationMetricOpt); + this.numValPerRecord = 
Primitives.parseInt( + cl.getOptionValue("num_validation_examples_per_record"), numValPerRecord); + this.validationRatio = Primitives.parseDouble(cl.getOptionValue("validation_ratio"), + this.validationRatio); + if (this.validationRatio > 1 || this.validationRatio < 0) { + throw new UDFArgumentException("'-validation_ratio' must be between 0.0 and 1.0"); + } + boolean noBias = cl.hasOption("no_bias"); + this.useBiasClause = !noBias; + if (noBias && updateGlobalBias) { + throw new UDFArgumentException("Cannot set both `update_gb` and `no_bias` option"); + } + this.useL2Norm = !cl.hasOption("disable_l2norm"); + } + this.rankInit = RankInitScheme.resolve(rankInitOpt); + rankInit.setMaxInitValue(maxInitValue); + rankInit.setInitStdDev(initStdDev); + this.validationState = new ConversionState(convergenceCheck, convergenceRate); + this.validationMetric = ValidationMetric.resolve(validationMetricOpt); + return cl; + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 3) { + throw new UDFArgumentException( + "_FUNC_ takes 3 arguments: string user, string item, array sppmi [, CONSTANT STRING options]"); + } + this.userOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.itemOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]); + this.sppmiOI = HiveUtils.asListOI(argOIs[2]); + HiveUtils.validateFeatureOI(sppmiOI.getListElementObjectInspector()); + + processOptions(argOIs); + + this.model = new CofactorModel(factor, rankInit, c0, c1, lambdaTheta, lambdaBeta, + lambdaGamma, globalBias, validationMetric, numValPerRecord, LOG); + + userToItems = new HashMap<>(); + itemToUsers = new HashMap<>(); + sppmi = new HashMap<>(); + + validationUsers = new ArrayList<>(); + validationItems = new ArrayList<>(); + + rand = new Random(31); + + List fieldNames = new ArrayList(); + List fieldOIs = new ArrayList(); + fieldNames.add("context"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + fieldNames.add("theta"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableFloatObjectInspector)); + fieldNames.add("beta"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableFloatObjectInspector)); + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + final String user = PrimitiveObjectInspectorUtils.getString(args[0], userOI); + final String item = PrimitiveObjectInspectorUtils.getString(args[1], itemOI); + Feature[] sppmiVec = null; + if (!sppmi.containsKey(item)) { + if (args[2] != null) { + sppmiVec = Feature.parseFeatures(args[2], sppmiOI, null, false); + sppmi.put(item, sppmiVec); + } +// } else { +// throw new HiveException( +// "null sppmi vector provided when item does not exist in sppmi"); +// } + } + recordSample(user, item); + } + + private static void addToMap(@Nonnull final Map> map, + @Nonnull final String key, @Nonnull final String value) { + List values = map.get(key); + final boolean isNewKey = values == null; + if (isNewKey) { + values = new ArrayList<>(); + values.add(value); + map.put(key, values); + } else { + values.add(value); + } + } + + private void recordSample(@Nonnull final String user, @Nonnull final String item) { + // validation data + if (rand.nextDouble() < validationRatio) { + addValidationSample(user, item); + } else { + // train + 
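Putting the options parsed above together, the constant options string passed as the last argument of train_cofactor would look something like the following (an illustrative combination, not a recommended setting):

-factors 20 -c0 0.1 -c1 1.0 -lambda_theta 1e-5 -lambda_beta 1e-5 -lambda_gamma 1.0 -rankinit gaussian -max_iters 5 -validation_ratio 0.1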
addToMap(userToItems, user, item); + addToMap(itemToUsers, item, user); + } + } + + private void addValidationSample(@Nonnull final String user, @Nonnull final String item) { + validationUsers.add(user); + validationItems.add(item); + } + + private void addToSPPMI(@Nonnull final String item, @Nonnull final Feature[] sppmiVec) { + if (sppmi.containsKey(item)) { + return; + } + sppmi.put(item, sppmiVec); + } + + @Override + public void close() throws HiveException { + try { + model.registerUsers(userToItems.keySet()); + model.registerItems(itemToUsers.keySet()); + + final Reporter reporter = getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "iteration"); + + final Counters.Counter userCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "users"); + final Counters.Counter itemCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "items"); + final Counters.Counter skippedUserCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "skippedUsers"); + final Counters.Counter skippedItemCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "skippedItems"); + + final Counters.Counter thetaTotalCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", + "thetaTotalFeaturesCounter"); + final Counters.Counter thetaTrainableCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", + "thetaTrainableFeaturesCounter"); + + final Counters.Counter betaTotalCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", + "betaTotalFeaturesCounter"); + final Counters.Counter betaTrainableCounter = (reporter == null) ? 
null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", + "betaTrainableFeaturesCounter"); + + model.registerCounters(userCounter, itemCounter, skippedUserCounter, skippedItemCounter, + thetaTrainableCounter, thetaTotalCounter, betaTrainableCounter, betaTotalCounter); + + + for (int iteration = 0; iteration < maxIters; iteration++) { + // train the model on a full batch (i.e., all the data) using mini-batch updates + validationState.next(); + reportProgress(reporter); + setCounterValue(iterCounter, iteration); + runTrainingIteration(); + + System.out.println( + "Validation loss: " + validationState.getAverageLoss(numValidations)); + + LOG.info("Performed " + iteration + " iterations of " + + NumberUtils.formatNumber(maxIters) + " with " + numTraining + + " training examples and " + numValidations + " validation examples."); + // + " training examples on a secondary storage (thus " + // + NumberUtils.formatNumber(_t) + " training updates in total), used " + // + _numValidations + " validation examples"); + + if (validationState.isConverged(numTraining)) { + break; + } + } + forwardModel(); + } finally { + this.model = null; + } + } + + private void forwardModel() throws HiveException { + if (model == null) { + return; + } + + final Text id = new Text(); + final FloatWritable[] theta = HiveUtils.newFloatArray(factor, 0.f); + final FloatWritable[] beta = HiveUtils.newFloatArray(factor, 0.f); + final Object[] forwardObj = new Object[] {id, theta, null}; + + int numUsersForwarded = 0, numItemsForwarded = 0; + + for (Map.Entry entry : model.getTheta().entrySet()) { + id.set(entry.getKey()); + copyTo(entry.getValue(), theta); + forward(forwardObj); + numUsersForwarded++; + } + + forwardObj[1] = null; + forwardObj[2] = beta; + for (Map.Entry entry : model.getBeta().entrySet()) { + id.set(entry.getKey()); + copyTo(entry.getValue(), beta); + forward(forwardObj); + numItemsForwarded++; + } + LOG.info("Forwarded the prediction model of " + numUsersForwarded + + " user rows (theta) and " + numItemsForwarded + " item rows (beta).]"); + + } + + private void copyTo(@Nonnull final double[] src, @Nonnull final FloatWritable[] dst) { + for (int k = 0, size = factor; k < size; k++) { + dst[k].set((float) src[k]); + } + } + + private void runTrainingIteration() throws HiveException { + model.updateWithUsers(userToItems); + model.updateWithItems(itemToUsers, sppmi); + // model.validate() + } + + private void validate() throws HiveException { + if (validationUsers.size() != validationItems.size()) { + throw new HiveException("number of validation users and items must be the same"); + } + for (int i = 0, numVal = validationUsers.size(); i < numVal; i++) { + final Double loss = model.validate(validationUsers.get(i), validationUsers.get(i)); + if (loss != null) { + if (!NumberUtils.isFinite(loss)) { + throw new HiveException("Non-finite validation loss encountered"); + } + validationState.incrLoss(loss); + } + } + } +} diff --git a/core/src/main/java/hivemall/fm/Feature.java b/core/src/main/java/hivemall/fm/Feature.java index 2ccf7be0e..bdd752bb9 100644 --- a/core/src/main/java/hivemall/fm/Feature.java +++ b/core/src/main/java/hivemall/fm/Feature.java @@ -73,6 +73,10 @@ public double getValue() { return value; } + public void setValue(final double value) { + this.value = value; + } + public abstract int bytes(); public abstract void writeTo(@Nonnull ByteBuffer dst); @@ -171,7 +175,7 @@ public static Feature[] parseFFMFeatures(@Nonnull final Object arg, } @Nonnull - static Feature parseFeature(@Nonnull final 
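The Feature.java hunk above widens `parseFeature` to public and adds `setValue`, which lets the validation code reuse preallocated probe features. A usage sketch of the "name:value" format it expects (hypothetical feature string, inside a method that declares HiveException):

Feature f = Feature.parseFeature("item123:1.5", false); // string feature "item123" with value 1.5; rejects strings without ':'
f.setValue(0.d); // reused as a negative probe, as in sampleNegatives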
String fv, final boolean asIntFeature) + public static Feature parseFeature(@Nonnull final String fv, final boolean asIntFeature) throws HiveException { final int pos1 = fv.indexOf(':'); if (pos1 == -1) { @@ -382,5 +386,4 @@ public static void l2normalize(@Nonnull final Feature[] features) { f.value *= invNorm; } } - } diff --git a/core/src/main/java/hivemall/mf/CofactorModel.java b/core/src/main/java/hivemall/mf/CofactorModel.java new file mode 100644 index 000000000..863fa0cfc --- /dev/null +++ b/core/src/main/java/hivemall/mf/CofactorModel.java @@ -0,0 +1,1043 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.mf; + +import hivemall.annotations.VisibleForTesting; +import hivemall.fm.Feature; +import hivemall.fm.StringFeature; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.math.MathUtils; +import it.unimi.dsi.fastutil.objects.Object2DoubleArrayMap; +import it.unimi.dsi.fastutil.objects.Object2DoubleMap; +import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import org.apache.commons.logging.Log; +import org.apache.commons.math3.linear.*; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.mapred.Counters; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.util.*; + + +public class CofactorModel { + + public void registerCounters(Counters.Counter userCounter, Counters.Counter itemCounter, Counters.Counter skippedUserCounter, + Counters.Counter skippedItemCounter, Counters.Counter thetaTrainable, Counters.Counter thetaTotal, + Counters.Counter betaTrainable, Counters.Counter betaTotal) { + this.userCounter = userCounter; + this.itemCounter = itemCounter; + this.skippedUserCounter = skippedUserCounter; + this.skippedItemCounter = skippedItemCounter; + this.thetaTotalFeaturesCounter = thetaTotal; + this.thetaTrainableFeaturesCounter = thetaTrainable; + this.betaTotalFeaturesCounter = betaTotal; + this.betaTrainableFeaturesCounter = betaTrainable; + } + + public void registerUsers(Set users) throws HiveException { + for (String key : users) { + initFactorVector(key, theta); + } + this.users = theta.getNonnullKeys(); + } + + public void registerItems(Set items) throws HiveException { + for (String key : items) { + initFactorVector(key, beta); + initFactorVector(key, gamma); + } + this.items = beta.getNonnullKeys(); + } + + public enum RankInitScheme { + random /* default */, gaussian; + @Nonnegative + private float maxInitValue; + @Nonnegative + private double initStdDev; + + @Nonnull + public static CofactorModel.RankInitScheme resolve(@Nullable String opt) { + if (opt == null) { + return random; + } else if ("gaussian".equalsIgnoreCase(opt)) { + return gaussian; + } else if 
("random".equalsIgnoreCase(opt)) { + return random; + } + return random; + } + + public void setMaxInitValue(float maxInitValue) { + this.maxInitValue = maxInitValue; + } + + public void setInitStdDev(double initStdDev) { + this.initStdDev = initStdDev; + } + } + + private static class Prediction implements Comparable { + + private double prediction; + private int label; + @Override + public int compareTo(@Nonnull Prediction other) { + // descending order + return -Double.compare(prediction, other.prediction); + } + + } + + @Nonnegative + private final int factor; + + // rank matrix initialization + private final RankInitScheme initScheme; + + private double globalBias; + + // storing trainable latent factors and weights + private final Weights theta; + private final Weights beta; + private final Object2DoubleMap betaBias; + private final Weights gamma; + private final Object2DoubleMap gammaBias; + + private final Random[] randU, randI; + + // hyperparameters + @Nonnegative + private final float c0, c1; + private final float lambdaTheta, lambdaBeta, lambdaGamma; + + // validation + private final CofactorizationUDTF.ValidationMetric validationMetric; + private final Feature[] validationProbes; + private final Prediction[] predictions; + private final int numValPerRecord; + private String[] users; + private String[] items; + + // solve + private final RealMatrix B; + private final RealVector A; + + // counters + private Counters.Counter userCounter; + private Counters.Counter itemCounter; + private Counters.Counter skippedUserCounter; + private Counters.Counter skippedItemCounter; + private Counters.Counter thetaTrainableFeaturesCounter; + private Counters.Counter thetaTotalFeaturesCounter; + private Counters.Counter betaTrainableFeaturesCounter; + private Counters.Counter betaTotalFeaturesCounter; + + // error message strings + private static final String ARRAY_NOT_SQUARE_ERR = "Array is not square"; + private static final String DIFFERENT_DIMS_ERR = "Matrix, vector or array do not match in size"; + protected static class Weights extends Object2ObjectOpenHashMap { + + protected Object[] getKey() { + return key; + } + + @Nonnull + String[] getNonnullKeys() { + final String[] keys = new String[size]; + final Object[] k = (Object[]) key; + final int len = k.length; + for (int i = 0, j = 0; i < len; i++) { + final Object ki = k[i]; + if (ki != null) { + keys[j++] = ki.toString(); + } + } + return keys; + } + } + + public CofactorModel(@Nonnegative final int factor, @Nonnull final RankInitScheme initScheme, + @Nonnegative final float c0, @Nonnegative final float c1, @Nonnegative final float lambdaTheta, + @Nonnegative final float lambdaBeta, @Nonnegative final float lambdaGamma, final float globalBias, + @Nullable CofactorizationUDTF.ValidationMetric validationMetric, @Nonnegative final int numValPerRecord, + @Nonnull final Log log) { + + // rank init scheme is gaussian + // https://github.com/dawenl/cofactor/blob/master/src/cofacto.py#L98 + this.factor = factor; + this.initScheme = initScheme; + this.globalBias = globalBias; + this.lambdaTheta = lambdaTheta; + this.lambdaBeta = lambdaBeta; + this.lambdaGamma = lambdaGamma; + + this.theta = new Weights(); + this.beta = new Weights(); + this.betaBias = new Object2DoubleArrayMap<>(); + this.betaBias.defaultReturnValue(0.d); + this.gamma = new Weights(); + this.gammaBias = new Object2DoubleArrayMap<>(); + this.gammaBias.defaultReturnValue(0.d); + + this.B = new Array2DRowRealMatrix(this.factor, this.factor); + this.A = new 
ArrayRealVector(this.factor); + + this.randU = newRandoms(factor, 31L); + this.randI = newRandoms(factor, 41L); + + Preconditions.checkArgument(c0 >= 0.f && c0 <= 1.f); + Preconditions.checkArgument(c1 >= 0.f && c1 <= 1.f); + + this.c0 = c0; + this.c1 = c1; + + if (validationMetric == null) { + this.validationMetric = CofactorizationUDTF.ValidationMetric.AUC; + } else { + this.validationMetric = validationMetric; + } + + this.numValPerRecord = numValPerRecord; + this.validationProbes = new Feature[numValPerRecord]; + this.predictions = new Prediction[numValPerRecord]; + for (int i = 0; i < validationProbes.length; i++) { + validationProbes[i] = new StringFeature("", 0.d); + predictions[i] = new Prediction(); + } + } + + private void initFactorVector(final String key, final Weights weights) throws HiveException { + if (weights.containsKey(key)) { + throw new HiveException(String.format("two items or two users cannot have same `context` in training set: found duplicate context `%s`", key)); + } + final double[] v = new double[factor]; + switch (initScheme) { + case random: + uniformFill(v, randI[0], initScheme.maxInitValue); + break; + case gaussian: + gaussianFill(v, randI, initScheme.initStdDev); + break; + default: + throw new IllegalStateException( + "Unsupported rank initialization scheme: " + initScheme); + + } + weights.put(key, v); + } + + @Nullable + private static double[] getFactorVector(String key, Weights weights) { + return weights.get(key); + } + + private static void setFactorVector(final String key, final Weights weights, final RealVector factorVector) throws HiveException { + final double[] vec = weights.get(key); + if (vec == null) { + throw new HiveException(); + } + copyData(vec, factorVector); + } + + private static double getBias(String key, Object2DoubleMap biases) { + return biases.getDouble(key); + } + + private static void setBias(String key, Object2DoubleMap biases, double value) { + biases.put(key, value); + } + + @Nullable + public double[] getGammaVector(@Nonnull final String key) { + return getFactorVector(key, gamma); + } + + public double getGammaBias(@Nonnull final String key) { + return getBias(key, gammaBias); + } + + public void setGammaBias(@Nonnull final String key, final double value) { + setBias(key, gammaBias, value); + } + + public double getGlobalBias() { + return globalBias; + } + + public void setGlobalBias(final double value) { + globalBias = value; + } + + @Nullable + public double[] getThetaVector(@Nonnull final String key) { + return getFactorVector(key, theta); + } + + @Nullable + public double[] getBetaVector(@Nonnull final String key) { + return getFactorVector(key, beta); + } + + public double getBetaBias(@Nonnull final String key) { + return getBias(key, betaBias); + } + + public void setBetaBias(@Nonnull final String key, final double value) { + setBias(key, betaBias, value); + } + + @Nonnull + public Weights getTheta() { + return theta; + } + + @Nonnull + public Weights getBeta() { + return beta; + } + + @Nonnull + public Weights getGamma() { + return gamma; + } + + @Nonnull + public Object2DoubleMap getBetaBiases() { + return betaBias; + } + + @Nonnull + public Object2DoubleMap getGammaBiases() { + return gammaBias; + } + + public void updateWithUsers(@Nonnull final Map> userToItems) throws HiveException { + updateTheta(userToItems); + } + + public void updateWithItems(@Nonnull final Map> items, Map sppmi) throws HiveException { + updateBeta(items, sppmi); + updateGamma(items, sppmi); + updateBetaBias(items, sppmi); + 
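+        // The remaining updates recenter gammaBias and the global bias; like updateBetaBias above,
+        // each bias is refit to the mean residual of the SPPMI regression (see calculateNewBias()
+        // and calculateNewGlobalBias(), which divide the accumulated residual by the number of
+        // trainable co-occurrence entries).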
updateGammaBias(items, sppmi); + updateGlobalBias(items, sppmi); + } + + /** + * Update latent factors of the users in the provided mini-batch. + * @param samples + */ + private void updateTheta(@Nonnull final Map> samples) throws HiveException { + // initialize item factors + // items should only be trainable if the dataset contains a major entry for that item (which it may not) + // variable names follow cofacto.py + final double[][] BTBpR = calculateWTWpR(beta, factor, c0, lambdaTheta); + + for (Map.Entry> sample : samples.entrySet()) { + RealVector newThetaVec = calculateNewThetaVector(sample, beta, factor, B, A, BTBpR, c0, c1); + if (newThetaVec != null) { + setFactorVector(sample.getKey(), theta, newThetaVec); + } else { + skippedUserCounter.increment(1); + } + userCounter.increment(1); + } + } + + @VisibleForTesting + protected static RealVector calculateNewThetaVector(@Nonnull final Map.Entry> sample, @Nonnull final Weights beta, + @Nonnegative final int numFactors, @Nonnull final RealMatrix B, @Nonnull final RealVector A, + @Nonnull final double[][] BTBpR, @Nonnegative final float c0, @Nonnegative final float c1) throws HiveException { + // filter for trainable items + List trainableItems = filterTrainableFeatures(sample.getValue(), beta); +// thetaTotalFeaturesCounter.increment(sample.getValue().size()); + if (trainableItems.isEmpty()) { + return null; + } +// thetaTrainableFeaturesCounter.increment(trainableItems.size()); + final double[] a = calculateA(trainableItems, beta, numFactors, c1); + final double[][] delta = calculateWTWSubsetStrings(trainableItems, beta, numFactors, c1 - c0); + final double[][] b = addInPlace(delta, BTBpR); + // solve and update factors + return solve(B, b, A, a); + } + + /** + * Update latent factors of the items in the provided mini-batch. 
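+     * <p>
+     * In outline (following cofacto.py): for each item, the left-hand side of the normal
+     * equations is the precomputed c0 * Theta^T Theta + lambdaBeta * I, plus a (c1 - c0)
+     * correction over the users who rated the item, plus Gamma_S^T Gamma_S over the trainable
+     * co-occurring items; the right-hand side is c1 times the sum of the raters' theta vectors
+     * plus the SPPMI residual term. The per-item system is then solved in
+     * {@link #calculateNewBetaVector}.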
+ */ + private void updateBeta(@Nonnull final Map> items, @Nonnull final Map sppmi) throws HiveException { + // precomputed matrix + final double[][] TTTpR = calculateWTWpR(theta, factor, c0, lambdaBeta); + for (Map.Entry> sample : items.entrySet()) { + RealVector newBetaVec = calculateNewBetaVector(sample, sppmi, theta, gamma, gammaBias, betaBias, factor, B, A, TTTpR, c0, c1, globalBias); + if (newBetaVec != null) { + setFactorVector(sample.getKey(), beta, newBetaVec); + } else { + skippedItemCounter.increment(1); + } + itemCounter.increment(1); + } + } + + @VisibleForTesting + protected static RealVector calculateNewBetaVector(@Nonnull final Map.Entry> sample, + @Nonnull final Map sppmi, + @Nonnull final Weights theta, + @Nonnull final Weights gamma, @Nonnull final Object2DoubleMap gammaBias, + @Nonnull final Object2DoubleMap betaBias, final int numFactors, @Nonnull final RealMatrix B, + @Nonnull final RealVector A, @Nonnull final double[][] TTTpR, @Nonnegative final float c0, + @Nonnegative final float c1, final double globalBias) throws HiveException { + // filter for trainable users + final List trainableUsers = filterTrainableFeatures(sample.getValue(), theta); +// betaTotalFeaturesCounter.increment(sample.getValue().size()); + if (trainableUsers.isEmpty()) { + return null; + } + +// betaTrainableFeaturesCounter.increment(trainableUsers.size()); + + final List trainableCooccurringItems = filterTrainableFeatures(sppmi.get(sample.getKey()), gamma); + final double[] RSD = calculateRSD(sample.getKey(), trainableCooccurringItems, numFactors, betaBias, gammaBias, gamma, globalBias); + final double[] ApRSD = addInPlace(calculateA(trainableUsers, theta, numFactors, c1), RSD, 1.f); + + final double[][] GTG = calculateWTWSubsetFeatures(trainableCooccurringItems, gamma, numFactors, 1.f); + final double[][] delta = calculateWTWSubsetStrings(trainableUsers, theta, numFactors, c1 - c0); + // never add into the precomputed `TTTpR` array, only add into temporary arrays like `delta` and `GTG` + final double[][] b = addInPlace(addInPlace(delta, GTG), TTTpR); + + // solve and update factors + return solve(B, b, A, ApRSD); + } + + /** + * Update latent factors of the items in the provided mini-batch. 
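+     * <p>
+     * Concretely, this refits the context embedding vectors (gamma) against the SPPMI matrix:
+     * for each item i, the ridge-regularized system Beta_S^T Beta_S + lambdaGamma * I is solved
+     * against the residual sum over trainable co-occurring items j of
+     * (sppmi_ij - gammaBias_i - betaBias_j - globalBias) * beta_j; see
+     * {@link #calculateNewGammaVector}.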
+ */ + private void updateGamma(@Nonnull final Map> samples, Map sppmi) throws HiveException { + for (Map.Entry> sample : samples.entrySet()) { + RealVector newGammaVec = calculateNewGammaVector(sample, sppmi, beta, gammaBias, betaBias, factor, B, A, lambdaGamma, globalBias); + if (newGammaVec != null) { + setFactorVector(sample.getKey(), gamma, newGammaVec); + } + } + } + + @VisibleForTesting + protected static RealVector calculateNewGammaVector(@Nonnull final Map.Entry> sample, @Nonnull final Map sppmi, + @Nonnull final Weights beta, + @Nonnull final Object2DoubleMap gammaBias, @Nonnull final Object2DoubleMap betaBias, + @Nonnegative final int numFactors, @Nonnull final RealMatrix B, @Nonnull final RealVector A, + @Nonnegative final float lambdaGamma, final double globalBias) throws HiveException { + // filter for trainable items + final List trainableCooccurringItems = filterTrainableFeatures(sppmi.get(sample.getKey()), beta); + if (trainableCooccurringItems.isEmpty()) { + return null; + } + final double[][] b = regularize(calculateWTWSubsetFeatures(trainableCooccurringItems, beta, numFactors, 1.f), lambdaGamma); + final double[] rsd = calculateRSD(sample.getKey(), trainableCooccurringItems, numFactors, gammaBias, betaBias, beta, globalBias); + // solve and update factors + return solve(B, b, A, rsd); + } + + private static double[][] regularize(@Nonnull final double[][] A, final float lambda) { + for (int i = 0; i < A.length; i++) { + A[i][i] += lambda; + } + return A; + } + + private void updateBetaBias(@Nonnull final Map> samples, Map sppmi) throws HiveException { + for (Map.Entry> sample : samples.entrySet()) { + Double newBetaBias = calculateNewBias(sample, sppmi, beta, gamma, gammaBias, globalBias); + if (newBetaBias != null) { + setBetaBias(sample.getKey(), newBetaBias); + } + } + } + + public void updateGammaBias(@Nonnull final Map> samples, Map sppmi) throws HiveException { + for (Map.Entry> sample : samples.entrySet()) { + Double newGammaBias = calculateNewBias(sample, sppmi, gamma, beta, betaBias, globalBias); + if (newGammaBias != null) { + setGammaBias(sample.getKey(), newGammaBias); + } + } + } + + private void updateGlobalBias(@Nonnull final Map> samples, Map sppmi) throws HiveException { + Double newGlobalBias = calculateNewGlobalBias(samples, sppmi, beta, gamma, betaBias, gammaBias); + if (newGlobalBias != null) { + setGlobalBias(newGlobalBias); + } + } + + @Nullable + protected static Double calculateNewGlobalBias(@Nonnull final Map> samples, @Nonnull final Map sppmi, + @Nonnull Weights beta, + @Nonnull Weights gamma, @Nonnull final Object2DoubleMap betaBias, + @Nonnull final Object2DoubleMap gammaBias) throws HiveException { + double newGlobalBias = 0.d; + int numEntriesInSPPMI = 0; + for (Map.Entry> sample : samples.entrySet()) { + // filter for trainable items + final List trainableCooccurringItems = filterTrainableFeatures(sppmi.get(sample.getKey()), beta); + if (trainableCooccurringItems.isEmpty()) { + continue; + } + numEntriesInSPPMI += trainableCooccurringItems.size(); + newGlobalBias += calculateGlobalBiasRSD(sample.getKey(), trainableCooccurringItems, beta, gamma, betaBias, gammaBias); + } + if (numEntriesInSPPMI == 0) { + return null; + } + return newGlobalBias / numEntriesInSPPMI; + } + + @VisibleForTesting + protected static Double calculateNewBias(@Nonnull final Map.Entry> sample, + @Nonnull final Map sppmi, + @Nonnull final Weights beta, + @Nonnull final Weights gamma, @Nonnull final Object2DoubleMap biases, + final double globalBias) throws 
HiveException { + // filter for trainable items + final List trainableCooccurringItems = filterTrainableFeatures(sppmi.get(sample.getKey()), beta); + if (trainableCooccurringItems.isEmpty()) { + return null; + } + double rsd = calculateBiasRSD(sample.getKey(), trainableCooccurringItems, beta, gamma, biases, globalBias); + return rsd / trainableCooccurringItems.size(); + + } + + @VisibleForTesting + protected static double calculateGlobalBiasRSD(@Nonnull final String thisItem, @Nonnull final List trainableItems, + @Nonnull final Weights beta, @Nonnull final Weights gamma, + @Nonnull final Object2DoubleMap betaBias, @Nonnull final Object2DoubleMap gammaBias) { + double result = 0.d; + final double[] thisFactorVec = getFactorVector(thisItem, beta); + final double thisBias = getBias(thisItem, betaBias); + for (Feature cooccurrence : trainableItems) { + String j = cooccurrence.getFeature(); + final double[] cooccurVec = getFactorVector(j, gamma); + double cooccurBias = getBias(j, gammaBias); + double value = cooccurrence.getValue() - dotProduct(thisFactorVec, cooccurVec) - thisBias - cooccurBias; + result += value; + } + return result; + } + + @VisibleForTesting + protected static double calculateBiasRSD(@Nonnull final String thisItem, @Nonnull final List trainableItems, @Nonnull final Weights beta, + @Nonnull final Weights gamma, @Nonnull final Object2DoubleMap biases, final double globalBias) { + double result = 0.d; + final double[] thisFactorVec = getFactorVector(thisItem, beta); + for (Feature cooccurrence : trainableItems) { + String j = cooccurrence.getFeature(); + final double[] cooccurVec = getFactorVector(j, gamma); + double cooccurBias = getBias(j, biases); + double value = cooccurrence.getValue() - dotProduct(thisFactorVec, cooccurVec) - cooccurBias - globalBias; + result += value; + } + return result; + } + + @VisibleForTesting + @Nonnull + protected static double[] calculateRSD(@Nonnull final String thisItem, @Nonnull final List trainableItems, final int numFactors, + @Nonnull final Object2DoubleMap fixedBias, @Nonnull final Object2DoubleMap changingBias, + @Nonnull final Weights weights, final double globalBias) throws HiveException { + + final double b = getBias(thisItem, fixedBias); + final double[] accumulator = new double[numFactors]; + for (Feature cooccurrence : trainableItems) { + final String j = cooccurrence.getFeature(); + double scale = cooccurrence.getValue() - b - getBias(j, changingBias) - globalBias; + final double[] g = getFactorVector(j, weights); + addInPlace(accumulator, g, scale); + } + return accumulator; + } + + /** + * Calculate W' x W plus regularization matrix + */ + @VisibleForTesting + @Nonnull + protected static double[][] calculateWTWpR(@Nonnull final Weights W, @Nonnegative final int numFactors, @Nonnegative final float c0, @Nonnegative final float lambda) { + double[][] WTW = calculateWTW(W, numFactors, c0); + return regularize(WTW, lambda); + } + + private static void checkCondition(final boolean condition, final String errorMessage) throws HiveException { + if (!condition) { + throw new HiveException(errorMessage); + } + } + + @VisibleForTesting + @Nonnull + protected static double[][] addInPlace(@Nonnull final double[][] A, @Nonnull final double[][] B) throws HiveException { + checkCondition(A.length == A[0].length && A.length == B.length && B.length == B[0].length, ARRAY_NOT_SQUARE_ERR); + for (int i = 0; i < A.length; i++) { + for (int j = 0; j < A[0].length; j++) { + A[i][j] += B[i][j]; + } + } + return A; + } + + @VisibleForTesting + @Nonnull 
+ protected static List filterTrainableFeatures(@Nonnull final List features, @Nonnull final Weights weights) { + final List trainableFeatures = new ArrayList<>(); + for (String feature : features) { + if (isTrainable(feature, weights)) { + trainableFeatures.add(feature); + } + } + return trainableFeatures; + } + + @VisibleForTesting + @Nonnull + protected static List filterTrainableFeatures(@Nullable final Feature[] features, @Nonnull final Weights weights) throws HiveException { + checkCondition(features != null, "features cannot be null"); + final List trainableFeatures = new ArrayList<>(); + String fName; + for (Feature f : features) { + fName = f.getFeature(); + if (isTrainable(fName, weights)) { + trainableFeatures.add(f); + } + } + return trainableFeatures; + } + + @VisibleForTesting + protected static RealVector solve(@Nonnull final RealMatrix B, @Nonnull final double[][] dataB, @Nonnull final RealVector A, @Nonnull final double[] dataA) throws HiveException { + // b * x = a + // solves for x + copyData(B, dataB); + copyData(A, dataA); + + final LUDecomposition LU = new LUDecomposition(B); + final DecompositionSolver solver = LU.getSolver(); + + if (solver.isNonSingular()) { + return LU.getSolver().solve(A); + } else { + SingularValueDecomposition svd = new SingularValueDecomposition(B); + return svd.getSolver().solve(A); + } + } + + private static void copyData(@Nonnull final RealMatrix dst, @Nonnull final double[][] src) throws HiveException { + checkCondition(dst.getRowDimension() == src.length && dst.getColumnDimension() == src[0].length, DIFFERENT_DIMS_ERR); + for (int i = 0, rows = dst.getRowDimension(); i < rows; i++) { + final double[] src_i = src[i]; + for (int j = 0, cols = dst.getColumnDimension(); j < cols; j++) { + dst.setEntry(i, j, src_i[j]); + } + } + } + + private static void copyData(@Nonnull final RealVector dst, @Nonnull final double[] src) throws HiveException { + checkCondition(dst.getDimension() == src.length, DIFFERENT_DIMS_ERR); + for (int i = 0; i < dst.getDimension(); i++) { + dst.setEntry(i, src[i]); + } + } + + private static void copyData(@Nonnull final double[] dst, @Nonnull final RealVector src) throws HiveException { + checkCondition(dst.length == src.getDimension(), DIFFERENT_DIMS_ERR); + for (int i = 0; i < dst.length; i++) { + dst[i] = src.getEntry(i); + } + } + + @VisibleForTesting + @Nonnull + protected static double[][] calculateWTW(@Nonnull final Weights weights, @Nonnull final int numFactors, @Nonnull final float constant) { + final double[][] WTW = new double[numFactors][numFactors]; + for (double[] vec : weights.values()) { + for (int i = 0; i < numFactors; i++) { + final double[] WTW_f = WTW[i]; + for (int j = 0; j < numFactors; j++) { + double val = constant * vec[i] * vec[j]; + WTW_f[j] += val; + } + } + } + return WTW; + } + + @VisibleForTesting + @Nonnull + protected static double[][] calculateWTWSubsetStrings(@Nonnull final List subset, @Nonnull final Weights weights, @Nonnegative final int numFactors, @Nonnegative final float constant) throws HiveException { + // equivalent to `B_u.T.dot((c1 - c0) * B_u)` in cofacto.py + final double[][] delta = new double[numFactors][numFactors]; + for (String f : subset) { + final double[] vec = getFactorVector(f, weights); + checkCondition(vec != null, "null vector is not allowed"); + for (int i = 0; i < numFactors; i++) { + final double[] delta_f = delta[i]; + for (int j = 0; j < numFactors; j++) { + double val = constant * vec[i] * vec[j]; + delta_f[j] += val; + } + } + } + return delta; + } + 
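+    /*
+     * Illustrative sketch (not invoked anywhere): how the precomputed Gram matrix and the
+     * per-context correction combine when solving the normal equations for one user. The names
+     * `exampleBeta` and `ratedItems` are made up for this example; `B` and `A` are the reusable
+     * solver buffers held by the model instance.
+     *
+     *   double[][] BTBpR = calculateWTWpR(exampleBeta, factor, c0, lambdaTheta);                 // c0 * B^T B + lambda * I, shared by every user in the sweep
+     *   double[][] corr = calculateWTWSubsetStrings(ratedItems, exampleBeta, factor, c1 - c0);   // (c1 - c0) * B_u^T B_u over this user's rated items only
+     *   double[] a = calculateA(ratedItems, exampleBeta, factor, c1);                            // c1 * B_u^T x_u, where x_u is all ones over the rated items
+     *   RealVector thetaU = solve(B, addInPlace(corr, BTBpR), A, a);                             // add into `corr`, never into the shared `BTBpR`
+     */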
+ @VisibleForTesting + @Nonnull + protected static double[][] calculateWTWSubsetFeatures(@Nonnull final List subset, @Nonnull final Weights weights, @Nonnegative final int numFactors, @Nonnegative final float constant) throws HiveException { + // equivalent to `B_u.T.dot((c1 - c0) * B_u)` in cofacto.py + final double[][] delta = new double[numFactors][numFactors]; + for (Feature f : subset) { + final double[] vec = getFactorVector(f.getFeature(), weights); + checkCondition(vec != null, "null vector is not allowed"); + for (int i = 0; i < numFactors; i++) { + final double[] delta_f = delta[i]; + for (int j = 0; j < numFactors; j++) { + double val = constant * vec[i] * vec[j]; + delta_f[j] += val; + } + } + } + return delta; + } + + @VisibleForTesting + @Nonnull + protected static double[] calculateA(@Nonnull final List items, @Nonnull final Weights weights, @Nonnegative final int numFactors, @Nonnegative final float constant) throws HiveException { + // Equivalent to: a = x_u.dot(c1 * B_u) + // x_u is a (1, i) matrix of all ones + // B_u is a (i, F) matrix + // What it does: sums factor n of each item in B_u + final double[] A = new double[numFactors]; + for (String item : items) { + addInPlace(A, getFactorVector(item, weights), 1.d); + } + for (int a = 0; a < A.length; a++) { + A[a] *= constant; + } + return A; + } + + @Nullable + public Double predict(@Nonnull final String user, @Nonnull final String item) { + if (!theta.containsKey(user) || !beta.containsKey(item)) { + return null; + } + final double[] u = getThetaVector(user), i = getBetaVector(item); + return dotProduct(u, i); + } + + @VisibleForTesting + protected static double dotProduct(@Nonnull final double[] u, @Nonnull final double[] v) { + double result = 0.d; + for (int i = 0; i < u.length; i++) { + result += u[i] * v[i]; + } + return result; + } + +// public double calculateLoss(@Nonnull final Map> users, @Nonnull final List>> items) { +// // for speed - can calculate loss on a small subset of the training data +// double mf_loss = calculateMFLoss(users, theta, beta, c0, c1) + calculateMFLoss(items, beta, theta, c0, c1); +// double embed_loss = calculateEmbedLoss(items, beta, gamma, betaBias, gammaBias); +// return mf_loss + embed_loss + sumL2Loss(theta, lambdaTheta) + sumL2Loss(beta, lambdaBeta) + sumL2Loss(gamma, lambdaGamma); +// } + + + @VisibleForTesting + protected static double calculateEmbedLoss(@Nonnull final List items, @Nonnull final Weights beta, + @Nonnull final Weights gamma, @Nonnull final Object2DoubleMap betaBias, + @Nonnull final Object2DoubleMap gammaBias) { + double loss = 0.d, val, bBias, gBias; + double[] bFactors, gFactors; + String bKey, gKey; + for (CofactorizationUDTF.TrainingSample item : items) { + bKey = item.context; + bFactors = getFactorVector(bKey, beta); + bBias = getBias(bKey, betaBias); + for (Feature cooccurrence : item.sppmi) { + if (!isTrainable(cooccurrence.getFeature(), beta)) { + continue; + } + gKey = cooccurrence.getFeature(); + gFactors = getFactorVector(gKey, gamma); + gBias = getBias(gKey, gammaBias); + val = cooccurrence.getValue() - dotProduct(bFactors, gFactors) - bBias - gBias; + loss += val * val; + } + } + return loss; + } + + @VisibleForTesting + protected static double calculateMFLoss(@Nonnull final List samples, @Nonnull final Weights contextWeights, + @Nonnull final Weights featureWeights, @Nonnegative final float c0, @Nonnegative final float c1) { + double loss = 0.d, err, predicted, y; + double[] contextFactors, ratedFactors; + + for (CofactorizationUDTF.TrainingSample 
sample : samples) { + contextFactors = getFactorVector(sample.context, contextWeights); + // all items / users + for (double[] unratedFactors : featureWeights.values()) { + predicted = dotProduct(contextFactors, unratedFactors); + err = (0.d - predicted); + loss += c0 * err * err; + } + // only rated items / users + for (Feature f : sample.features) { + if (!isTrainable(f.getFeature(), featureWeights)) { + continue; + } + ratedFactors = getFactorVector(f.getFeature(), featureWeights); + predicted = dotProduct(contextFactors, ratedFactors); + y = f.getValue(); + err = y - predicted; + loss += (c1 - c0) * err * err; + } + } + return loss; + } + + @VisibleForTesting + protected static double sumL2Loss(@Nonnull final Weights weights, @Nonnegative float lambda) { + double loss = 0.d; + for (double[] v : weights.values()) { + loss += L2Distance(v); + } + return lambda * loss; + } + + @VisibleForTesting + protected static double L2Distance(@Nonnull final double[] vec) { + double result = 0.d; + for (double v : vec) { + result += v * v; + } + return Math.sqrt(result); + } + + /** + * Sample positive and negative validation examples and return a performance metric that + * should be minimized. + * + * @return Validation metric + * @throws HiveException + */ + public Double validate(@Nonnull final String user, @Nonnull final String item) throws HiveException { + if (!theta.containsKey(user) || !beta.containsKey(item)) { + return null; + } + // limit numPos and numNeg +// int numPos = Math.min(sample.features.length, (int) Math.ceil(this.numValPerRecord * 0.5)); + int numPos = 1; +// int numNeg = Math.min(this.numValPerRecord - numPos, sample.isItem() ? users.length : items.length); + int numNeg = 2; +// if (numPos == 0) { +// throw new HiveException("numPos is 0: sample.features.length = " + sample.features.length + ", ceil = " + (int) Math.ceil(this.numValPerRecord * 0.5)); +// } +// if (numNeg == 0) { +// throw new HiveException("numNeg is 0, users.length = " + users.length + ", items.length = " + items.length); +// } + +// getValidationExamples(numPos, numNeg, sample.features, sample.isItem(), validationProbes, seed); +// if (validationMetric == CofactorizationUDTF.ValidationMetric.AUC) { +// return -calculateAUC(validationProbes, predictions, sample, numPos, numNeg); +// } else { +// return calculateLoss(validationProbes, sample, numPos, numNeg); +// } + return null; + } + + private boolean isPredictable(@Nonnull final String context, final boolean isItem) { + if (isItem) { + return beta.containsKey(context); + } else { + return theta.containsKey(context); + } + } + + /** + * TODO: not implemented + * + * @return + */ + private double calculateLoss(Feature[] validationProbes, CofactorizationUDTF.TrainingSample sample, int numPos, int numNeg) { + return 0d; + } + + /** + * Calculates area under curve for validation metric. 
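+     * <p>
+     * Predictions for the sampled positives and negatives are sorted in descending score order
+     * and the ROC area is accumulated with the trapezoid rule, normalized by tp * fp (the number
+     * of positive/negative pairs). Candidates the model cannot score are skipped by
+     * {@link #fillPredictions}, and 0 is returned when either class is empty.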
+ */ + private double calculateAUC(@Nonnull final Feature[] validationProbes, @Nonnull final Prediction[] predictions, CofactorizationUDTF.TrainingSample sample, final int numPos, final int numNeg) throws HiveException { + // make predictions for positive and then negative examples + int nextIdx = fillPredictions(validationProbes, predictions, sample, 0, numPos, 0, 1); +// if (nextIdx == 0) { +// throw new HiveException("nextIdx is 0, no positives in predictions, validation probes = " + Arrays.toString(validationProbes)); +// } + int endIdx = fillPredictions(validationProbes, predictions, sample, nextIdx, numPos + numNeg, nextIdx, 0); + + // sort in descending order for all filled predictions + Arrays.sort(predictions, 0, endIdx); + + double area = 0d, scorePrev = Double.MIN_VALUE; + int fp = 0, tp = 0; + int fpPrev = 0, tpPrev = 0; + + for (int i = 0; i < endIdx; i++) { + final Prediction p = predictions[i]; + if (p.prediction != scorePrev) { + area += trapezoid(fp, fpPrev, tp, tpPrev); + scorePrev = p.prediction; + fpPrev = fp; + tpPrev = tp; + } + if (p.label == 1) { + tp += 1; + } else { + fp += 1; + } + } + area += trapezoid(fp, fpPrev, tp, tpPrev); + if (tp * fp == 0) { + return 0d; + } + return area / (tp * fp); + } + + /** + * Calculates area of a trapezoid. + */ + private static double trapezoid(final int x1, final int x2, final int y1, final int y2) { + final int base = Math.abs(x1 - x2); + final double height = (y1 + y2) * 0.5; + return base * height; + } + + /** + * Fill an array of predictions. + * @return index of the next empty entry in {@code predictions} array + */ + private int fillPredictions(@Nonnull final Feature[] validationProbes, @Nonnull final Prediction[] predictions, @Nonnull final CofactorizationUDTF.TrainingSample sample, + final int lo, final int hi, int fillIdx, final int label) { + for (int i = lo; i < hi; i++) { + final Feature pos = validationProbes[i]; + final Double pred; + if (sample.isItem()) { + pred = predict(pos.getFeature(), sample.context); + } else { + pred = predict(sample.context, pos.getFeature()); + } + if (pred == null) { + continue; + } + predictions[fillIdx].prediction = pred; + predictions[fillIdx].label = label; + fillIdx++; + } + return fillIdx; + } + + /** + * Sample positive and negative samples. + * @return number of negatives that were successfully sampled + */ + private void getValidationExamples(final int numPos, final int numNeg, @Nonnull final Feature[] positives, final boolean isContextAnItem, + @Nonnull final Feature[] validationProbes, final int seed) { + final Random rand = new Random(seed); + samplePositives(numPos, positives, validationProbes, rand); + final String[] keys = isContextAnItem ? users : items; + sampleNegatives(numPos, numNeg, validationProbes, keys, rand); + } + + /** + * Samples negative examples. 
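+     * <p>
+     * Negative probes are drawn uniformly at random, with replacement, from the given key set
+     * (all known users or all known items) and written into positions [numPos, numPos + numNeg)
+     * of {@code validationProbes} with a target value of 0; no de-duplication against the
+     * sampled positives is performed.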
+ */ + @VisibleForTesting + protected static void sampleNegatives(final int numPos, final int numNeg, @Nonnull final Feature[] validationProbes, + @Nonnull final String[] keys, @Nonnull final Random rand) { + // sample numPos positive examples without replacement + for (int i = numPos, size = numPos + numNeg; i < size; i++) { + final String negKey = keys[rand.nextInt(keys.length)]; + validationProbes[i].setFeature(negKey); + validationProbes[i].setValue(0.d); + } + } + + private static void samplePositives(final int numPos, @Nonnull final Feature[] positives, @Nonnull final Feature[] validationProbes, @Nonnull final Random rand) { + // sample numPos positive examples without replacement + for (int i = 0; i < numPos; i++) { + validationProbes[i] = positives[rand.nextInt(positives.length)]; + } + } + + /** + * Add v to u in-place without creating a new RealVector instance. + * + * @param u array to which v will be added + * @param v array containing new values to be added to u + * @param scalar value to multiply each entry in v before adding to u + */ + @VisibleForTesting + @Nonnull + protected static double[] addInPlace(@Nonnull final double[] u, @Nullable final double[] v, final double scalar) throws HiveException { + checkCondition(v != null, "null vector is not allowed"); + checkCondition(u.length == v.length, DIFFERENT_DIMS_ERR); + for (int i = 0; i < u.length; i++) { + u[i] += scalar * v[i]; + } + return u; + } + + private static boolean isTrainable(@Nonnull final String name, @Nonnull final Weights weights) { + return weights.containsKey(name); + } + + @Nonnull + private static Random[] newRandoms(@Nonnegative final int size, final long seed) { + final Random[] rand = new Random[size]; + for (int i = 0, len = rand.length; i < len; i++) { + rand[i] = new Random(seed + i); + } + return rand; + } + + private static void uniformFill(@Nonnull final double[] a, @Nonnull final Random rand, final float maxInitValue) { + for (int i = 0, len = a.length; i < len; i++) { + double v = rand.nextDouble() * maxInitValue / len; + a[i] = v; + } + } + + private static void gaussianFill(@Nonnull final double[] a, @Nonnull final Random[] rand, @Nonnegative final double stddev) { + for (int i = 0, len = a.length; i < len; i++) { + double v = MathUtils.gaussian(0.d, stddev, rand[i]); + a[i] = v; + } + } +} diff --git a/core/src/main/java/hivemall/mf/CofactorizationPredictUDF.java b/core/src/main/java/hivemall/mf/CofactorizationPredictUDF.java new file mode 100644 index 000000000..d6bd696f5 --- /dev/null +++ b/core/src/main/java/hivemall/mf/CofactorizationPredictUDF.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.mf; + +import java.util.List; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.io.FloatWritable; + +@Description(name = "cofactor_predict", + value = "_FUNC_(array theta, array beta) - Returns the prediction value") +@UDFType(deterministic = true, stateful = false) +public final class CofactorizationPredictUDF extends UDF { + + private static final double DEFAULT_RESULT = 0.d; + + @Nonnull + public DoubleWritable evaluate(@Nullable List Pu, @Nullable List Qi) throws HiveException { + if (Pu == null || Qi == null) { + return new DoubleWritable(DEFAULT_RESULT); + } + + final int PuSize = Pu.size(); + final int QiSize = Qi.size(); + // workaround for TD + if (PuSize == 0) { + return new DoubleWritable(DEFAULT_RESULT); + } else if (QiSize == 0) { + return new DoubleWritable(DEFAULT_RESULT); + } + + if (QiSize != PuSize) { + throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize); + } + + double ret = DEFAULT_RESULT; + for (int k = 0; k < PuSize; k++) { + FloatWritable Pu_k = Pu.get(k); + if (Pu_k == null) { + continue; + } + FloatWritable Qi_k = Qi.get(k); + if (Qi_k == null) { + continue; + } + ret += Pu_k.get() * Qi_k.get(); + } + return new DoubleWritable(ret); + } +} diff --git a/core/src/main/java/hivemall/mf/CofactorizationUDTF.java b/core/src/main/java/hivemall/mf/CofactorizationUDTF.java new file mode 100644 index 000000000..8bec53bca --- /dev/null +++ b/core/src/main/java/hivemall/mf/CofactorizationUDTF.java @@ -0,0 +1,529 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.mf; + +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.common.ConversionState; +import hivemall.fm.Feature; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NioStatefulSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; + + +/** + * Cofactorization for implicit and explicit recommendation + */ +@Description(name = "train_cofactor", + value = "_FUNC_(string context, array features, boolean is_validation, boolean is_item, array sppmi [, String options])" + + " - Returns a relation theta, array beta>") +public class CofactorizationUDTF extends UDTFWithOptions { + private static final Log LOG = LogFactory.getLog(CofactorizationUDTF.class); + + // Option variables + // The number of latent factors + private int factor; + // The scaling hyperparameter for zero entries in the rank matrix + private float c0; + // The scaling hyperparameter for non-zero entries in the rank matrix + private float c1; + // The initial mean rating + private float globalBias; + // Whether update (and return) the mean rating or not + private boolean updateGlobalBias; + // The number of iterations + private int maxIters; + // Whether to use bias clause + private boolean useBiasClause; + // Whether to use normalization + private boolean useL2Norm; + // regularization hyperparameters + private float lambdaTheta; + private float lambdaBeta; + private float lambdaGamma; + + // validation metric + private ValidationMetric validationMetric; + + // Initialization strategy of rank matrix + private CofactorModel.RankInitScheme rankInit; + + // Model itself + @VisibleForTesting + protected CofactorModel model; + + // Variable managing status of learning + private ConversionState validationState; + private int numValPerRecord; + + // Input OIs and Context + private PrimitiveObjectInspector userOI; + @VisibleForTesting + protected PrimitiveObjectInspector itemOI; + + private BooleanObjectInspector isValidationOI; + @VisibleForTesting + protected ListObjectInspector sppmiOI; + + // Used for iterations + @VisibleForTesting + protected long numValidations; + protected long numTraining; + + // training data + private Map> userToItems; + private Map> itemToUsers; + private Map sppmi; + + // validation + private Random rand; + private double validationRatio; + private List validationUsers; + 
private List validationItems; + + static class MiniBatch { + private List users; + private List items; + private List validationSamples; + + protected MiniBatch() { + users = new ArrayList<>(); + items = new ArrayList<>(); + validationSamples = new ArrayList<>(); + } + + protected void add(TrainingSample sample) { + if (sample.isValidation) { + validationSamples.add(sample); + } else { + if (sample.isItem()) { + items.add(sample); + } else { + users.add(sample); + } + } + } + + protected void clear() { + users.clear(); + items.clear(); + validationSamples.clear(); + } + + protected int trainingSize() { + return items.size() + users.size(); + } + + protected int validationSize() { + return validationSamples.size(); + } + + protected List getItems() { + return items; + } + + protected List getUsers() { + return users; + } + + public List getValidationSamples() { + return validationSamples; + } + } + + static class TrainingSample { + protected String context; + protected Feature[] features; + protected Feature[] sppmi; + protected boolean isValidation; + + protected TrainingSample(@Nonnull final String context, @Nonnull final Feature[] features, final boolean isValidation, @Nullable final Feature[] sppmi) { + this.context = context; + this.features = features; + this.sppmi = sppmi; + this.isValidation = isValidation; + } + + protected boolean isItem() { + return sppmi != null; + } + } + + enum ValidationMetric { + AUC, OBJECTIVE; + + static ValidationMetric resolve(@Nonnull final String opt) { + switch (opt.toLowerCase()) { + case "auc": + return AUC; + case "objective": + case "loss": + return OBJECTIVE; + default: + throw new IllegalArgumentException(opt + " is not a supported Validation Metric."); + } + } + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "factor", true, "The number of latent factor [default: 10] " + + " Note this is alias for `factors` option."); + opts.addOption("f", "factors", true, "The number of latent factor [default: 10]"); + opts.addOption("lt", "lambda_theta", true, "The theta regularization factor [default: 1e-5]"); + opts.addOption("lb", "lambda_beta", true, "The beta regularization factor [default: 1e-5]"); + opts.addOption("lg", "lambda_gamma", true, "The gamma regularization factor [default: 1.0]"); + opts.addOption("c0", "c0", true, + "The scaling hyperparameter for zero entries in the rank matrix [default: 0.1]"); + opts.addOption("c1", "c1", true, + "The scaling hyperparameter for non-zero entries in the rank matrix [default: 1.0]"); + opts.addOption("gb", "global_bias", true, "The global bias [default: 0.0]"); + opts.addOption("update_gb", "update_global_bias", true, + "Whether update (and return) the global bias or not [default: false]"); + opts.addOption("rankinit", true, + "Initialization strategy of rank matrix [random, gaussian] (default: gaussian)"); + opts.addOption("maxval", "max_init_value", true, + "The maximum initial value in the rank matrix [default: 1.0]"); + opts.addOption("min_init_stddev", true, + "The minimum standard deviation of initial rank matrix [default: 0.01]"); + opts.addOption("iters", "iterations", true, "The number of iterations [default: 1]"); + opts.addOption("iter", true, + "The number of iterations [default: 1] Alias for `-iterations`"); + opts.addOption("max_iters", "max_iters", true, "The number of iterations [default: 1]"); + opts.addOption("disable_bias", "no_bias", false, "Turn off bias clause"); + // normalization + opts.addOption("disable_norm", 
"disable_l2norm", false, "Disable instance-wise L2 normalization"); + // validation + opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: enabled]"); + opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]"); + opts.addOption("val_metric", "validation_metric", true, "Metric to use for validation ['auc', 'objective']"); + opts.addOption("val_ratio", "validation_ratio", true, "Proportion of examples to use as validation data [default: 0.125]"); + opts.addOption("num_val", "num_validation_examples_per_record", true, "Number of validation examples to use per record [default: 10]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + String rankInitOpt = "gaussian"; + float maxInitValue = 1.f; + double initStdDev = 0.01d; + boolean convergenceCheck = true; + double convergenceRate = 0.005d; + String validationMetricOpt = "auc"; + this.c0 = 0.1f; + this.c1 = 1.0f; + this.lambdaTheta = 1e-5f; + this.lambdaBeta = 1e-5f; + this.lambdaGamma = 1.0f; + this.globalBias = 0.f; + this.maxIters = 1; + this.factor = 10; + this.numValPerRecord = 10; + this.validationRatio = 0.125; + + if (argOIs.length >= 3) { + String rawArgs = HiveUtils.getConstString(argOIs[3]); + cl = parseOptions(rawArgs); + if (cl.hasOption("factors")) { + this.factor = Primitives.parseInt(cl.getOptionValue("factors"), factor); + } else { + this.factor = Primitives.parseInt(cl.getOptionValue("factor"), factor); + } + this.lambdaTheta = Primitives.parseFloat(cl.getOptionValue("lambda_theta"), lambdaTheta); + this.lambdaBeta = Primitives.parseFloat(cl.getOptionValue("lambda_beta"), lambdaBeta); + this.lambdaGamma = Primitives.parseFloat(cl.getOptionValue("lambda_gamma"), lambdaGamma); + + this.c0 = Primitives.parseFloat(cl.getOptionValue("c0"), c0); + this.c1 = Primitives.parseFloat(cl.getOptionValue("c1"), c1); + + this.globalBias = Primitives.parseFloat(cl.getOptionValue("global_bias"), globalBias); + this.updateGlobalBias = cl.hasOption("update_global_bias"); + + rankInitOpt = cl.getOptionValue("rankinit", rankInitOpt); + maxInitValue = Primitives.parseFloat(cl.getOptionValue("max_init_value"), maxInitValue); + initStdDev = Primitives.parseDouble(cl.getOptionValue("min_init_stddev"), initStdDev); + + if (cl.hasOption("iter")) { + this.maxIters = Primitives.parseInt(cl.getOptionValue("iter"), maxIters); + } else { + this.maxIters = Primitives.parseInt(cl.getOptionValue("max_iters"), maxIters); + } + if (maxIters < 1) { + throw new UDFArgumentException( + "'-max_iters' must be greater than or equal to 1: " + maxIters); + } + + convergenceCheck = !cl.hasOption("disable_cvtest"); + convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), convergenceRate); + validationMetricOpt = cl.getOptionValue("validation_metric", validationMetricOpt); + this.numValPerRecord = Primitives.parseInt(cl.getOptionValue("num_validation_examples_per_record"), numValPerRecord); + this.validationRatio = Primitives.parseDouble(cl.getOptionValue("validation_ratio"), this.validationRatio); + if (this.validationRatio > 1 || this.validationRatio < 0) { + throw new UDFArgumentException( + "'-validation_ratio' must be between 0.0 and 1.0" + ); + } + boolean noBias = cl.hasOption("no_bias"); + this.useBiasClause = !noBias; + if (noBias && updateGlobalBias) { + throw new UDFArgumentException( + "Cannot set both `update_gb` and `no_bias` option"); + } + 
this.useL2Norm = !cl.hasOption("disable_l2norm"); + } + this.rankInit = CofactorModel.RankInitScheme.resolve(rankInitOpt); + rankInit.setMaxInitValue(maxInitValue); + rankInit.setInitStdDev(initStdDev); + this.validationState = new ConversionState(convergenceCheck, convergenceRate); + this.validationMetric = ValidationMetric.resolve(validationMetricOpt); + return cl; + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 3) { + throw new UDFArgumentException( + "_FUNC_ takes 3 arguments: string user, string item, array sppmi [, CONSTANT STRING options]"); + } + this.userOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.itemOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]); + this.sppmiOI = HiveUtils.asListOI(argOIs[2]); + HiveUtils.validateFeatureOI(sppmiOI.getListElementObjectInspector()); + + processOptions(argOIs); + + this.model = new CofactorModel(factor, rankInit, c0, c1, lambdaTheta, lambdaBeta, lambdaGamma, globalBias, + validationMetric, numValPerRecord, LOG); + + userToItems = new HashMap<>(); + itemToUsers = new HashMap<>(); + sppmi = new HashMap<>(); // item -> SPPMI vector, filled lazily in process() + + validationUsers = new ArrayList<>(); + validationItems = new ArrayList<>(); + + rand = new Random(31); + + List fieldNames = new ArrayList(); + List fieldOIs = new ArrayList(); + fieldNames.add("context"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + fieldNames.add("theta"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableFloatObjectInspector)); + fieldNames.add("beta"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableFloatObjectInspector)); + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public void process(Object[] args) throws HiveException { + final String user = PrimitiveObjectInspectorUtils.getString(args[0], userOI); + final String item = PrimitiveObjectInspectorUtils.getString(args[1], itemOI); + Feature[] sppmiVec = null; + if (!sppmi.containsKey(item)) { + if (args[2] != null) { + sppmiVec = Feature.parseFeatures(args[2], sppmiOI, null, false); + sppmi.put(item, sppmiVec); + } else { + throw new HiveException("null sppmi vector provided when item does not exist in sppmi"); + } + } + recordSample(user, item); + } + + private static void addToMap(@Nonnull final Map> map, @Nonnull final String key, @Nonnull final String value) { + List values = map.get(key); + final boolean isNewKey = values == null; + if (isNewKey) { + values = new ArrayList<>(); + values.add(value); + map.put(key, values); + } else { + values.add(value); + } + } + + private void recordSample(@Nonnull final String user, @Nonnull final String item) { + // validation data + if (rand.nextDouble() < validationRatio) { + addValidationSample(user, item); + } else { + // train + addToMap(userToItems, user, item); + addToMap(itemToUsers, item, user); + } + } + + private void addValidationSample(@Nonnull final String user, @Nonnull final String item) { + validationUsers.add(user); + validationItems.add(item); + } + + private void addToSPPMI(@Nonnull final String item, @Nonnull final Feature[] sppmiVec) { + if (sppmi.containsKey(item)) { + return; + } + sppmi.put(item, sppmiVec); + } + + @Override + public void close() throws HiveException { + try { + model.registerUsers(userToItems.keySet()); + model.registerItems(itemToUsers.keySet()); + + final Reporter reporter = 
getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "iteration"); + + final Counters.Counter userCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "users"); + final Counters.Counter itemCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "items"); + final Counters.Counter skippedUserCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "skippedUsers"); + final Counters.Counter skippedItemCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "skippedItems"); + + final Counters.Counter thetaTotalCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "thetaTotalFeaturesCounter"); + final Counters.Counter thetaTrainableCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "thetaTrainableFeaturesCounter"); + + final Counters.Counter betaTotalCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "betaTotalFeaturesCounter"); + final Counters.Counter betaTrainableCounter = (reporter == null) ? null + : reporter.getCounter("hivemall.mf.Cofactor$Counter", "betaTrainableFeaturesCounter"); + + model.registerCounters(userCounter, itemCounter, skippedUserCounter, skippedItemCounter, thetaTrainableCounter, + thetaTotalCounter, betaTrainableCounter, betaTotalCounter); + + + for (int iteration = 0; iteration < maxIters; iteration++) { + // train the model on a full batch (i.e., all the data) using mini-batch updates + validationState.next(); + reportProgress(reporter); + setCounterValue(iterCounter, iteration); + runTrainingIteration(); + + System.out.println("Validation loss: " + validationState.getAverageLoss(numValidations)); + + LOG.info("Performed " + iteration + " iterations of " + + NumberUtils.formatNumber(maxIters) + " with " + numTraining + " training examples and " + + numValidations + " validation examples."); +// + " training examples on a secondary storage (thus " +// + NumberUtils.formatNumber(_t) + " training updates in total), used " +// + _numValidations + " validation examples"); + + if (validationState.isConverged(numTraining)) { + break; + } + } + forwardModel(); + } finally { + this.model = null; + } + } + + private void forwardModel() throws HiveException { + if (model == null) { + return; + } + + final Text id = new Text(); + final FloatWritable[] theta = HiveUtils.newFloatArray(factor, 0.f); + final FloatWritable[] beta = HiveUtils.newFloatArray(factor, 0.f); + final Object[] forwardObj = new Object[] {id, theta, null}; + + int numUsersForwarded = 0, numItemsForwarded = 0; + + for (Map.Entry entry : model.getTheta().entrySet()) { + id.set(entry.getKey()); + copyTo(entry.getValue(), theta); + forward(forwardObj); + numUsersForwarded++; + } + + forwardObj[1] = null; + forwardObj[2] = beta; + for (Map.Entry entry : model.getBeta().entrySet()) { + id.set(entry.getKey()); + copyTo(entry.getValue(), beta); + forward(forwardObj); + numItemsForwarded++; + } + LOG.info("Forwarded the prediction model of " + numUsersForwarded + " user rows (theta) and " + numItemsForwarded + " item rows (beta).]"); + + } + + private void copyTo(@Nonnull final double[] src, @Nonnull final FloatWritable[] dst) { + for (int k = 0, size = factor; k < size; k++) { + dst[k].set((float) src[k]); + } + } + + private void runTrainingIteration() throws HiveException { + 
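+        // One full ALS sweep: solve all user factors (theta) with the item-side parameters
+        // fixed, then solve the item-side parameters (beta, gamma and the bias terms) with the
+        // user factors fixed. Validation is currently disabled (see the commented-out
+        // model.validate() call below).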
model.updateWithUsers(userToItems); + model.updateWithItems(itemToUsers, sppmi); +// model.validate() + } + + private void validate() throws HiveException { + if (validationUsers.size() != validationItems.size()) { + throw new HiveException("number of validation users and items must be the same"); + } + for (int i = 0, numVal = validationUsers.size(); i < numVal; i++) { + final Double loss = model.validate(validationUsers.get(i), validationItems.get(i)); + if (loss != null) { + if (!NumberUtils.isFinite(loss)) { + throw new HiveException("Non-finite validation loss encountered"); + } + validationState.incrLoss(loss); + } + } + } +} diff --git a/core/src/main/java/hivemall/mf/FactorizedModel.java b/core/src/main/java/hivemall/mf/FactorizedModel.java index c3aa7398e..16bc357e0 100644 --- a/core/src/main/java/hivemall/mf/FactorizedModel.java +++ b/core/src/main/java/hivemall/mf/FactorizedModel.java @@ -33,14 +33,14 @@ public final class FactorizedModel { @Nonnull - private final RatingInitializer ratingInitializer; + protected final RatingInitializer ratingInitializer; @Nonnegative - private final int factor; + protected final int factor; // rank matrix initialization - private final RankInitScheme initScheme; + protected final RankInitScheme initScheme; - private int minIndex, maxIndex; + protected int minIndex, maxIndex; @Nonnull private Rating meanRating; private Int2ObjectMap users; @@ -48,7 +48,7 @@ public final class FactorizedModel { private Int2ObjectMap userBias; private Int2ObjectMap itemBias; - private final Random[] randU, randI; + protected final Random[] randU, randI; public FactorizedModel(@Nonnull RatingInitializer ratingInitializer, @Nonnegative int factor, @Nonnull RankInitScheme initScheme) { @@ -80,9 +80,9 @@ public enum RankInitScheme { random /* default */, gaussian; @Nonnegative - private float maxInitValue; + protected float maxInitValue; @Nonnegative - private double initStdDev; + protected double initStdDev; @Nonnull public static RankInitScheme resolve(@Nullable String opt) { @@ -253,7 +253,7 @@ public void setItemBias(final int i, final float value) { b.setWeight(value); } - private static void uniformFill(final Rating[] a, final Random rand, final float maxInitValue, + protected static void uniformFill(final Rating[] a, final Random rand, final float maxInitValue, final RatingInitializer init) { for (int i = 0, len = a.length; i < len; i++) { float v = rand.nextFloat() * maxInitValue / len; @@ -261,7 +261,7 @@ private static void uniformFill(final Rating[] a, final Random rand, final float } } - private static void gaussianFill(final Rating[] a, final Random[] rand, final double stddev, + protected static void gaussianFill(final Rating[] a, final Random[] rand, final double stddev, final RatingInitializer init) { for (int i = 0, len = a.length; i < len; i++) { float v = (float) MathUtils.gaussian(0.d, stddev, rand[i]); diff --git a/core/src/test/java/hivemall/factorization/cofactor/CofactorModelTest.java b/core/src/test/java/hivemall/factorization/cofactor/CofactorModelTest.java new file mode 100644 index 000000000..69ec2df6f --- /dev/null +++ b/core/src/test/java/hivemall/factorization/cofactor/CofactorModelTest.java @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.factorization.cofactor; + +import hivemall.fm.Feature; +import hivemall.fm.StringFeature; +import it.unimi.dsi.fastutil.objects.Object2DoubleArrayMap; +import it.unimi.dsi.fastutil.objects.Object2DoubleMap; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.junit.Assert; +import org.junit.Test; + +public class CofactorModelTest { + private static final double EPSILON = 1e-3; + private static final int NUM_FACTORS = 2; + + // items + private static final String TOOTHBRUSH = "toothbrush"; + private static final String TOOTHPASTE = "toothpaste"; + private static final String SHAVER = "shaver"; + + // users + private static final String MAKOTO = "makoto"; + private static final String TAKUYA = "takuya"; + private static final String JACKSON = "jackson"; + private static final String ALIEN = "alien"; + + @Test + public void calculateWTW() { + CofactorModel.Weights weights = getTestBeta(); + + double[][] expectedWTW = new double[][] {{0.63, -0.238}, {-0.238, 0.346}}; + + double[][] actualWTW = CofactorModel.calculateWTW(weights, 2, 0.1f); + Assert.assertTrue(matricesAreEqual(actualWTW, expectedWTW)); + } + + @Test + public void calculateA() throws HiveException { + final CofactorModel.Weights itemFactors = getTestBeta(); + final List ratedItems = getUserToItems().get(MAKOTO); + double[] actual = CofactorModel.calculateA(ratedItems, itemFactors, NUM_FACTORS, 0.5f); + double[] expected = new double[] {-0.85, 0.95}; + Assert.assertArrayEquals(expected, actual, EPSILON); + } + + @Test + public void calculateWTWSubsetFeatures() throws HiveException { + CofactorModel.Weights itemFactors = getTestBeta(); + List ratedItems = getUserToItems().get(MAKOTO); + + double[][] actual = + CofactorModel.calculateWTWSubsetStrings(ratedItems, itemFactors, NUM_FACTORS, 0.9f); + double[][] expected = new double[][] {{4.581, -3.033}, {-3.033, 2.385}}; + + Assert.assertTrue(matricesAreEqual(actual, expected)); + } + + @Test + public void calculateNewThetaVector() throws HiveException { + final float c0 = 0.1f, c1 = 1.f, lambdaTheta = 1e-5f; + CofactorModel.Weights itemFactors = getTestBeta(); + + double[][] BTBpR = CofactorModel.calculateWTWpR(itemFactors, NUM_FACTORS, c0, lambdaTheta); + double[][] initialBTBpR = copyArray(BTBpR); + + RealMatrix B = new Array2DRowRealMatrix(NUM_FACTORS, NUM_FACTORS); + RealVector A = new ArrayRealVector(NUM_FACTORS); + + RealVector actual = CofactorModel.calculateNewThetaVector( + new AbstractMap.SimpleEntry<>(MAKOTO, 
getUserToItems().get(MAKOTO)), itemFactors, + NUM_FACTORS, B, A, BTBpR, c0, c1); + Assert.assertNotNull(actual); + + // ensure that TTTpR has not been accidentally changed after one update + Assert.assertTrue(matricesAreEqual(initialBTBpR, BTBpR)); + + double[] expected = new double[] {0.44514062, 1.22886953}; + Assert.assertArrayEquals(expected, actual.toArray(), EPSILON); + } + + @Test + public void calculateRSD() throws HiveException { + double[] actual = + CofactorModel.calculateRSD(TOOTHBRUSH, Arrays.asList(getSPPMI().get(TOOTHBRUSH)), + NUM_FACTORS, getTestBetaBias(), getTestGammaBias(), getTestGamma(), 0.d); + double[] expected = new double[] {2.656, 0.154}; + Assert.assertArrayEquals(expected, actual, EPSILON); + } + + @Test + public void calculateNewBetaVector() throws HiveException { + final float c0 = 0.1f, c1 = 1.f, lambdaBeta = 1e-5f; + + Object2DoubleMap betaBias = getTestBetaBias(); + Object2DoubleMap gammaBias = getTestGammaBias(); + CofactorModel.Weights gamma = getTestGamma(); + CofactorModel.Weights theta = getTestTheta(); + + RealMatrix B = new Array2DRowRealMatrix(NUM_FACTORS, NUM_FACTORS); + RealVector A = new ArrayRealVector(NUM_FACTORS); + + // solve for new weights for toothbrush + final Map> items = getItemToUsers(); + final Map sppmi = getSPPMI(); + final Map.Entry> toothbrush = + new AbstractMap.SimpleEntry<>(TOOTHBRUSH, items.get(TOOTHBRUSH)); + + double[][] TTTpR = CofactorModel.calculateWTWpR(theta, NUM_FACTORS, c0, lambdaBeta); + double[][] initialTTTpR = copyArray(TTTpR); + + // zero bias: solve and update factors + RealVector actual = CofactorModel.calculateNewBetaVector(toothbrush, sppmi, theta, gamma, + gammaBias, betaBias, NUM_FACTORS, B, A, TTTpR, c0, c1, 0.d); + Assert.assertNotNull(actual); + // ensure that TTTpR has not been accidentally changed after one update + Assert.assertTrue(matricesAreEqual(initialTTTpR, TTTpR)); + double[] expected = new double[] {0.23246102, -0.147114}; + Assert.assertArrayEquals(expected, actual.toArray(), EPSILON); + + // non-zero bias: solve and update factors + actual = CofactorModel.calculateNewBetaVector(toothbrush, sppmi, theta, gamma, gammaBias, + betaBias, NUM_FACTORS, B, A, TTTpR, c0, c1, 2.5d); + Assert.assertNotNull(actual); + // ensure that TTTpR has not been accidentally changed after one update + Assert.assertTrue(matricesAreEqual(initialTTTpR, TTTpR)); + expected = new double[] {-0.77140623, -1.19014975}; + Assert.assertArrayEquals(expected, actual.toArray(), EPSILON); + } + + private static Map getSPPMI() { + final Map sppmi = new HashMap<>(); + sppmi.put(TOOTHBRUSH, + new Feature[] {new StringFeature(TOOTHPASTE, 1.22d), new StringFeature(SHAVER, 1.22d)}); + sppmi.put(TOOTHPASTE, + new Feature[] {new StringFeature(TOOTHBRUSH, 1.22d), new StringFeature(SHAVER, 1.35d)}); + sppmi.put(SHAVER, new Feature[] {new StringFeature(TOOTHBRUSH, 1.22d), + new StringFeature(TOOTHPASTE, 1.35d)}); + return sppmi; + } + + @Test + public void calculateNewGlobalBias() throws HiveException { + CofactorModel.Weights beta = getTestBeta(); + CofactorModel.Weights gamma = getTestGamma(); + Object2DoubleMap betaBias = getTestBetaBias(); + Object2DoubleMap gammaBias = getTestGammaBias(); + + Double actual = CofactorModel.calculateNewGlobalBias(getItemToUsers(), getSPPMI(), beta, + gamma, betaBias, gammaBias); + Assert.assertNotNull(actual); + Assert.assertEquals(-0.2667, actual, EPSILON); + } + + private static double[][] copyArray(double[][] A) { + double[][] newA = new double[A.length][A[0].length]; + for (int i = 0; i < 
A.length; i++) { + for (int j = 0; j < A[0].length; j++) { + newA[i][j] = A[i][j]; + } + } + return newA; + } + + @Test + public void calculateNewGammaVector() throws HiveException { + final float lambdaGamma = 1e-5f; + + final Object2DoubleMap betaBias = getTestBetaBias(); + final Object2DoubleMap gammaBias = getTestGammaBias(); + final CofactorModel.Weights beta = getTestBeta(); + + final RealMatrix B = new Array2DRowRealMatrix(NUM_FACTORS, NUM_FACTORS); + final RealVector A = new ArrayRealVector(NUM_FACTORS); + + final Map.Entry> currentItem = + new AbstractMap.SimpleEntry<>(TOOTHBRUSH, getItemToUsers().get(TOOTHBRUSH)); + final Map sppmi = getSPPMI(); + + // zero global bias + RealVector actual = CofactorModel.calculateNewGammaVector(currentItem, sppmi, beta, + gammaBias, betaBias, NUM_FACTORS, B, A, lambdaGamma, 0.d); + Assert.assertNotNull(actual); + double[] expected = new double[] {0.95828914, -1.48234826}; + Assert.assertArrayEquals(expected, actual.toArray(), EPSILON); + + // non-zero global bias + actual = CofactorModel.calculateNewGammaVector(currentItem, sppmi, beta, gammaBias, + betaBias, NUM_FACTORS, B, A, lambdaGamma, 2.5d); + Assert.assertNotNull(actual); + expected = new double[] {0.49037982, -3.68822023}; + Assert.assertArrayEquals(expected, actual.toArray(), EPSILON); + } + + @Test + public void calculateNewBias_forBetaBias_returnsNonNull() throws HiveException { + Object2DoubleMap gammaBias = getTestGammaBias(); + CofactorModel.Weights beta = getTestBeta(); + CofactorModel.Weights gamma = getTestGamma(); + + final Map.Entry> currentItem = + new AbstractMap.SimpleEntry<>(TOOTHBRUSH, getItemToUsers().get(TOOTHBRUSH)); + final Map sppmi = getSPPMI(); + + // zero global bias + Double actual = + CofactorModel.calculateNewBias(currentItem, sppmi, beta, gamma, gammaBias, 0.d); + Assert.assertNotNull(actual); + Assert.assertEquals(-0.235, actual, EPSILON); + + // non-zero global bias + actual = CofactorModel.calculateNewBias(currentItem, sppmi, beta, gamma, gammaBias, 2.5d); + Assert.assertNotNull(actual); + Assert.assertEquals(-2.735, actual, EPSILON); + } + + @Test + public void L2Distance() throws HiveException { + double[] v = new double[] {0.1, 2.3, 5.3}; + double actual = CofactorModel.L2Distance(v); + double expected = 5.7784d; + Assert.assertEquals(actual, expected, EPSILON); + } + + // @Test + // public void calculateMFLoss_allFeaturesAreTrainable() throws HiveException { + // List samples = getSamples_itemAsContext_allUsersInTheta(); + // CofactorModel.Weights beta = getTestBeta(); + // CofactorModel.Weights theta = getTestTheta(); + // double actual = CofactorModel.calculateMFLoss(samples, beta, theta, 0.1f, 1.f); + // double expected = 0.7157; + // Assert.assertEquals(actual, expected, EPSILON); + // } + + @Test + public void calculateMFLoss_oneFeatureNotTrainable() throws HiveException { + // tests case where a user found in the item's feature array + // was not also distributed to the same UDTF instance + List samples = + getSamples_itemAsContext_oneUserNotInTheta(); + CofactorModel.Weights beta = getTestBeta(); + CofactorModel.Weights theta = getTestTheta(); + double actual = CofactorModel.calculateMFLoss(samples, beta, theta, 0.1f, 1.f); + double expected = 0.7157; + Assert.assertEquals(actual, expected, EPSILON); + } + + // @Test + // public void calculateEmbedLoss() { + // List samples = getSamples_itemAsContext_allUsersInTheta(); + // CofactorModel.Weights beta = getTestBeta(); + // CofactorModel.Weights gamma = getTestGamma(); + // Object2DoubleMap betaBias 
= getTestBetaBias(); + // Object2DoubleMap gammaBias = getTestGammaBias(); + // + // double actual = CofactorModel.calculateEmbedLoss(samples, beta, gamma, betaBias, gammaBias); + // double expected = 2.756; + // Assert.assertEquals(expected, actual, EPSILON); + // } + + @Test + public void dotProduct() { + double[] u = new double[] {0.1, 5.1, 3.2}; + double[] v = new double[] {1, 2, 3}; + Assert.assertEquals(CofactorModel.dotProduct(u, v), 19.9, EPSILON); + } + + @Test + public void addInPlaceArray1D() throws HiveException { + double[] u = new double[] {0.1, 5.1, 3.2}; + double[] v = new double[] {1, 2, 3}; + + double[] actual = CofactorModel.addInPlace(u, v, 1.f); + double[] expected = new double[] {1.1, 7.1, 6.2}; + Assert.assertArrayEquals(u, expected, EPSILON); + Assert.assertArrayEquals(actual, expected, EPSILON); + } + + @Test + public void addInPlaceArray2D() throws HiveException { + double[][] u = new double[][] {{0.1, 5.1}, {3.2, 1.2}}; + double[][] v = new double[][] {{1, 2}, {3, 4}}; + + double[][] actual = CofactorModel.addInPlace(u, v); + double[][] expected = new double[][] {{1.1, 7.1}, {6.2, 5.2}}; + Assert.assertTrue(matricesAreEqual(u, expected)); + Assert.assertTrue(matricesAreEqual(actual, expected)); + } + + // @Test + // public void smallTrainingTest_implicitFeedback() throws HiveException { + // final boolean IS_FEEDBACK_EXPLICIT = false; + // CofactorModel.RankInitScheme init = CofactorModel.RankInitScheme.gaussian; + // init.setInitStdDev(1.0f); + // + // CofactorModel model = new CofactorModel(NUM_FACTORS, init, + // 0.1f, 1.f, 1e-5f, 1e-5f, 1.f, 0.f, null, 0, LOG, skippedUserCounter, skippedItemCounter, userCounter, itemCounter); + // int iterations = 5; + // List users = getUserToItems(IS_FEEDBACK_EXPLICIT); + // List items = getItemToUsers(IS_FEEDBACK_EXPLICIT); + // + // // record features + // recordContexts(model, users, false); + // recordContexts(model, items, true); + // + // double prevLoss = Double.MAX_VALUE; + // for (int i = 0; i < iterations; i++) { + // model.updateWithUsers(users); + // model.updateWithItems(items); + // Double loss = model.calculateLoss(users, items); + // Assert.assertNotNull(loss); + // Assert.assertTrue(loss < prevLoss); + // prevLoss = loss; + // } + // + // // assert that the user-item predictions after N iterations is identical to expected predictions + //// String expected = "makoto -> (toothpaste:0.976), (toothbrush:0.942), (shaver:1.076), \n" + + //// "takuya -> (toothpaste:1.001), (toothbrush:-0.167), (shaver:0.173), \n" + + //// "jackson -> (toothpaste:1.031), (toothbrush:0.715), (shaver:0.906), \n"; + // String predictionString = generatePredictionString(model, users, items); + // System.out.println(predictionString); + //// Assert.assertEquals(predictionString, expected); + // } + + + // @Test + // public void calculateAUC() throws HiveException { + // CofactorModel.RankInitScheme init = CofactorModel.RankInitScheme.gaussian; + // init.setInitStdDev(0.01f); + // CofactorModel model = new CofactorModel(NUM_FACTORS, init, 0.1f, 1.f, 1e-5f, 1e-5f, + // 1.f, 0.f, CofactorizationUDTF.ValidationMetric.AUC, 3, LOG, skippedUserCounter, skippedItemCounter, userCounter, itemCounter); + // + // List users = getUserToItems(false); + // List items = getItemToUsers(false); + // + // // record features + // recordContexts(model, users, false); + // recordContexts(model, items, true); + // + // model.finalizeContexts(); + // + // model.validate(items.get(0), 31); + // } + + @Test + public void sampleNegatives() throws HiveException { 
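+ // sampleNegatives is expected to leave the leading positive probes untouched and overwrite the + // remaining slots with items drawn from the supplied item-key array; a fixed seed makes the draws deterministic.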
+ // first validation example is positive, last two examples are negative + final int numVal = 3, numPos = 1, numNeg = numVal - numPos; + final int seed = 31; + final double DUMMY_VALUE = 0d; + final Feature[] validationProbes = new Feature[numVal]; + + validationProbes[0] = new StringFeature("positive", DUMMY_VALUE); + validationProbes[1] = new StringFeature("placeholder", DUMMY_VALUE); + validationProbes[2] = new StringFeature("placeholder", DUMMY_VALUE); + + final String[] items = getTestBeta().getNonnullKeys(); + + CofactorModel.sampleNegatives(numPos, numNeg, validationProbes, items, new Random(seed)); + Assert.assertEquals(validationProbes[0].getFeature(), "positive"); + Assert.assertEquals(validationProbes[1].getFeature(), TOOTHPASTE); + Assert.assertEquals(validationProbes[2].getFeature(), TOOTHBRUSH); + } + + private static String generatePredictionString(CofactorModel model, + List users, + List items) { + StringBuilder predicted = new StringBuilder(); + for (CofactorizationUDTF.TrainingSample user : users) { + predicted.append(user.context).append(" -> "); + for (CofactorizationUDTF.TrainingSample item : items) { + Double score = model.predict(user.context, item.context); + predicted.append("(") + .append(item.context) + .append(":") + .append(String.format("%.3f", score)) + .append("), "); + } + predicted.append('\n'); + } + return predicted.toString(); + } + + private static String mapToString(CofactorModel.Weights weights) { + StringBuilder sb = new StringBuilder(); + for (Map.Entry entry : weights.entrySet()) { + sb.append(entry.getKey() + ": " + arrayToString(entry.getValue(), 3) + ", "); + } + return sb.toString(); + } + + private static String arrayToString(double[] A, int decimals) { + StringBuilder sb = new StringBuilder(); + sb.append('['); + for (int i = 0; i < A.length; i++) { + sb.append(String.format("%." 
+ decimals + "f", A[i])); + if (i != A.length - 1) { + sb.append(", "); + } + } + sb.append(']'); + return sb.toString(); + } + + private static Map> getItemToUsers() { + final Map> items = new HashMap<>(); + items.put(TOOTHBRUSH, Collections.singletonList(MAKOTO)); + items.put(TOOTHPASTE, Arrays.asList(TAKUYA, MAKOTO, JACKSON)); + items.put(SHAVER, Arrays.asList(JACKSON, MAKOTO)); + return items; + } + + private static Map> getUserToItems() { + final Map> users = new HashMap<>(); + users.put(MAKOTO, Arrays.asList(TOOTHBRUSH, SHAVER)); + users.put(TAKUYA, Collections.singletonList(TOOTHPASTE)); + users.put(JACKSON, Arrays.asList(TOOTHPASTE, SHAVER)); + return users; + } + + + private static List getSamples_itemAsContext_oneUserNotInTheta() { + List samples = new ArrayList<>(); + samples.add(new CofactorizationUDTF.TrainingSample(TOOTHBRUSH, + getSuperset_userFeatureVector_implicitFeedback(), false, null)); + return samples; + } + + + private static boolean matricesAreEqual(double[][] A, double[][] B) { + if (A.length != B.length || A[0].length != B[0].length) { + return false; + } + for (int r = 0; r < A.length; r++) { + for (int c = 0; c < A[0].length; c++) { + if (Math.abs(A[r][c] - B[r][c]) > EPSILON) { + return false; + } + } + } + return true; + } + + private static CofactorModel.Weights getTestTheta() { + CofactorModel.Weights weights = new CofactorModel.Weights(); + weights.put(MAKOTO, new double[] {0.8, -0.7}); + weights.put(TAKUYA, new double[] {-0.05, 1.7}); + weights.put(JACKSON, new double[] {1.8, -0.3}); + return weights; + } + + private static CofactorModel.Weights getTestBeta() { + CofactorModel.Weights weights = new CofactorModel.Weights(); + weights.put(TOOTHBRUSH, new double[] {0.5, 0.3}); + weights.put(TOOTHPASTE, new double[] {1.1, 0.9}); + weights.put(SHAVER, new double[] {-2.2, 1.6}); + return weights; + } + + private static CofactorModel.Weights getTestGamma() { + CofactorModel.Weights weights = new CofactorModel.Weights(); + weights.put(TOOTHBRUSH, new double[] {1.3, -0.2}); + weights.put(TOOTHPASTE, new double[] {1.6, 0.1}); + weights.put(SHAVER, new double[] {3.2, -0.4}); + return weights; + } + + private static Object2DoubleMap getTestBetaBias() { + Object2DoubleMap weights = new Object2DoubleArrayMap<>(); + weights.put(TOOTHBRUSH, 0.1); + weights.put(TOOTHPASTE, -1.9); + weights.put(SHAVER, 2.3); + return weights; + } + + private static Object2DoubleMap getTestGammaBias() { + Object2DoubleMap weights = new Object2DoubleArrayMap<>(); + weights.put(TOOTHBRUSH, 3.4); + weights.put(TOOTHPASTE, -0.5); + weights.put(SHAVER, 1.1); + return weights; + } + + private static Feature[] getSubset_userFeatureVector_implicitFeedback() { + // Makoto and Jackson both prefer a particular item + Feature[] f = new Feature[2]; + f[0] = new StringFeature(MAKOTO, 1.d); + f[1] = new StringFeature(JACKSON, 1.d); + return f; + } + + private static Feature[] getSuperset_userFeatureVector_implicitFeedback() { + // Makoto, Jackson and Alien prefer a particular item + Feature[] f = new Feature[3]; + f[0] = new StringFeature(MAKOTO, 1.d); + f[1] = new StringFeature(JACKSON, 1.d); + f[2] = new StringFeature(ALIEN, 1.d); + assert !getTestGamma().containsKey(ALIEN); + return f; + } +} diff --git a/core/src/test/java/hivemall/factorization/cofactor/CofactorizationUDTFTest.java b/core/src/test/java/hivemall/factorization/cofactor/CofactorizationUDTFTest.java new file mode 100644 index 000000000..d36541a7b --- /dev/null +++ 
b/core/src/test/java/hivemall/factorization/cofactor/CofactorizationUDTFTest.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.factorization.cofactor; + +import hivemall.fm.Feature; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.StringUtils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.zip.GZIPInputStream; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.exec.MapredContext; +import org.apache.hadoop.hive.ql.exec.MapredContextAccessor; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.Collector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Before; + +public class CofactorizationUDTFTest { + + CofactorizationUDTF udtf; + + private static class TrainingSample { + private String context; + private List features; + private List sppmi; + private boolean isValidation; + + private TrainingSample() {} + + private Object[] toArray() { + boolean isItem = sppmi != null; + return new Object[] {context, features, isValidation, isItem, sppmi}; + } + } + + private static class TestingSample { + private String user; + private String item; + private double rating; + + private TestingSample() {} + } + + @Before + public void setUp() throws HiveException { + udtf = new CofactorizationUDTF(); + } + + private void initialize(final boolean initMapred, @Nonnull final String options) + throws HiveException { + if (initMapred) { + MapredContext mapredContext = MapredContextAccessor.create(true, null); + udtf.configure(mapredContext); + udtf.setCollector(new Collector() { + @Override + public void collect(Object args) throws HiveException {} + }); + } + + ObjectInspector[] argOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector), + PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, + PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector), + HiveUtils.getConstStringObjectInspector(options)}; + udtf.initialize(argOIs); + } + + // @Test + // public void testTrain() throws HiveException, IOException { + // initialize(true, "-max_iters 5 -factors 100 -c0 0.03 -c1 0.3"); + // + // TrainingSample trainSample = new TrainingSample(); + // + // BufferedReader train = readFile("ml100k-cofactor.trainval.gz"); + // String line; + // while ((line = train.readLine()) != null) { + // parseLine(line, trainSample); + // udtf.process(trainSample.toArray()); + // } + // Assert.assertEquals(udtf.numTraining, 52287); + // Assert.assertEquals(udtf.numValidations, 9227); + // udtf.close(); + // } + + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + InputStream is = CofactorizationUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } + + private static void parseLine(@Nonnull String line, @Nonnull TrainingSample sample) { + String[] cols = StringUtils.split(line, ' '); + Assert.assertNotNull(cols); + Assert.assertTrue(cols.length == 4 || cols.length == 5); + sample.context = cols[0]; + boolean isItem = Integer.parseInt(cols[1]) == 1; + sample.isValidation = Integer.parseInt(cols[2]) == 1; + sample.features = parseFeatures(cols[3]); + sample.sppmi = cols.length == 5 ? parseFeatures(cols[4]) : null; + } + + private static void parseLine(@Nonnull String line, @Nonnull TestingSample sample) { + String[] cols = StringUtils.split(line, ' '); + Assert.assertNotNull(cols); + Assert.assertEquals(cols.length, 3); + sample.user = cols[0]; + sample.item = cols[1]; + sample.rating = Double.parseDouble(cols[2]); + } + + private static List parseFeatures(@Nonnull String string) { + String[] entries = StringUtils.split(string, ','); + List features = new ArrayList<>(); + features.addAll(Arrays.asList(entries)); + return features; + } + + private static boolean featureArraysAreEqual(Feature[] f1, Feature[] f2) { + if (f1 == null && f2 == null) { + return true; + } + if (f1 == null || f2 == null) { + return false; + } + if (f1.length != f2.length) { + return false; + } + for (int i = 0; i < f1.length; i++) { + if (!featuresAreEqual(f1[i], f2[i])) { + return false; + } + } + return true; + } + + private static boolean featuresAreEqual(Feature f1, Feature f2) { + return f1.getFeature().equals(f2.getFeature()) && f1.getValue() == f2.getValue(); + } + + private static Object[] getItemTrainSample() { + return new Object[] {"string1", getDummyFeatures(), false, true, getDummyFeatures()}; + } + + private static Object[] getItemValidationSample() { + return new Object[] {"string1", getDummyFeatures(), true, true, getDummyFeatures()}; + } + + private static Object[] getUserSample() { + return new Object[] {"user", getDummyFeatures(), false, false, null}; + } + + private static List getDummyFeatures() { + List features = new ArrayList<>(); + features.add("feature1:1"); + features.add("feature2:2"); + features.add("feature3:3"); + return features; + } +}
diff --git a/core/src/test/resources/hivemall/factorization/cofactor/ml100k-cofactor.trainval.gz b/core/src/test/resources/hivemall/factorization/cofactor/ml100k-cofactor.trainval.gz new file mode 100644 index 000000000..0342d090e Binary files /dev/null and b/core/src/test/resources/hivemall/factorization/cofactor/ml100k-cofactor.trainval.gz differ diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index 69dcf6934..4b0e41a69 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -878,3 +878,9 @@ CREATE FUNCTION xgboost_predict AS 'hivemall.xgboost.tools.XGBoostPredictUDTF' U DROP FUNCTION xgboost_multiclass_predict; CREATE FUNCTION xgboost_multiclass_predict AS 'hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF' USING JAR '${hivemall_jar}'; + +DROP FUNCTION IF EXISTS train_cofactor; +CREATE FUNCTION train_cofactor as 'hivemall.factorization.cofactor.CofactorizationUDTF' USING JAR '${hivemall_jar}'; + +DROP FUNCTION IF EXISTS cofactor_predict; +CREATE FUNCTION cofactor_predict as 'hivemall.factorization.cofactor.CofactorizationPredictUDF' USING JAR '${hivemall_jar}'; diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive 
index f39aea3be..c7fe042cb 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -885,3 +885,9 @@ log(10, n_docs / max2(1,df_t)) + 1.0; create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE) tf * (log(10, n_docs / max2(1,df_t)) + 1.0); + +drop temporary function if exists train_cofactor; +create temporary function train_cofactor as 'hivemall.factorization.cofactor.CofactorizationUDTF'; + +drop temporary function if exists cofactor_predict; +create temporary function cofactor_predict as 'hivemall.factorization.cofactor.CofactorizationPredictUDF'; diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index 4d46694ba..5ea479ae3 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -838,3 +838,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION bloom_or AS 'hivemall.sketch.bloom.Blo sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS bloom_contains_any") sqlContext.sql("CREATE TEMPORARY FUNCTION bloom_contains_any AS 'hivemall.sketch.bloom.BloomContainsAnyUDF'") + +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS train_cofactor") +sqlContext.sql("CREATE TEMPORARY FUNCTION train_cofactor AS 'hivemall.factorization.cofactor.CofactorizationUDTF'") + +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS cofactor_predict") +sqlContext.sql("CREATE TEMPORARY FUNCTION cofactor_predict AS 'hivemall.factorization.cofactor.CofactorizationPredictUDF'")
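For reference, a minimal usage sketch (not part of the patch) of the newly registered train_cofactor UDTF. The table name cofactor_input, its column names, and the output aliases (id, theta, beta) are assumptions for illustration; the argument order follows the object-inspector setup in CofactorizationUDTFTest (context, features, is_validation, is_item, sppmi, options), and the option flags are the ones exercised in the commented-out testTrain:

-- illustrative HiveQL only; table, columns, and output aliases are hypothetical
SELECT train_cofactor(context, features, is_validation, is_item, sppmi,
         '-factors 100 -max_iters 5 -c0 0.03 -c1 0.3')
       AS (id, theta, beta)
FROM cofactor_input;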