From d534780ed13034f1c11d6e01576c9885c1cf124a Mon Sep 17 00:00:00 2001
From: Jonathan Ellis
Date: Tue, 17 Oct 2023 11:16:10 -0500
Subject: [PATCH 1/4] cleanup VectorUtil:
 - replace divInPlace() with scale()
 - remove add() that had same semantics as addInPlace()
 - l2normalize() operates in-place

---
 .../jvector/graph/GraphIndexBuilder.java     |  2 +-
 .../jvector/pq/KMeansPlusPlusClusterer.java  |  4 +-
 .../vector/DefaultVectorUtilSupport.java     |  4 +-
 .../jbellis/jvector/vector/VectorUtil.java   | 66 ++-----------------
 .../jvector/vector/VectorUtilSupport.java    |  2 +-
 .../io/github/jbellis/jvector/TestUtil.java  |  1 -
 .../vector/PanamaVectorUtilSupport.java      |  4 +-
 .../jbellis/jvector/vector/SimdOps.java      |  6 +-
 8 files changed, 17 insertions(+), 72 deletions(-)

diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
index 04addafa8..63671c1f4 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
@@ -529,7 +529,7 @@ private int approximateMedioid() {
             var node = it.nextInt();
             VectorUtil.addInPlace(centroid, (float[]) vc.vectorValue(node));
         }
-        VectorUtil.divInPlace(centroid, graph.size());
+        VectorUtil.scale(centroid, 1.0f / graph.size());
         NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.vectorValue(i), (T) centroid);
         int ep = graph.entry();
         var result = gs.searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, Bits.ALL);
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
index dd6b60699..8690971e2 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
@@ -211,7 +211,7 @@ private void updateCentroids() {
                 centroids[i] = points[random.nextInt(points.length)];
             } else {
                 centroids[i] = Arrays.copyOf(centroidNums[i], centroidNums[i].length);
-                VectorUtil.divInPlace(centroids[i], centroidDenoms[i]);
+                VectorUtil.scale(centroids[i], 1.0f / centroidDenoms[i]);
             }
         }
     }
@@ -225,7 +225,7 @@ public static float[] centroidOf(List points) {
         }
         float[] centroid = VectorUtil.sum(points);
-        VectorUtil.divInPlace(centroid, points.size());
+        VectorUtil.scale(centroid, 1.0f / points.size());
         return centroid;
     }
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java
index 777f6b289..c4b3aad2d 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java
@@ -226,9 +226,9 @@ public float sum(float[] vector) {
     }

     @Override
-    public void divInPlace(float[] vector, float divisor) {
+    public void scale(float[] vector, float multiplier) {
         for (int i = 0; i < vector.length; i++) {
-            vector[i] /= divisor;
+            vector[i] *= multiplier;
         }
     }
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java
index 0ba43b713..905d55d67 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java
@@ -116,52 +116,14 @@ public static int squareDistance(byte[] a, byte[] b) {
   /**
    * Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
    * thrown for zero vectors.
-   *
-   * @return the input array after normalization
-   */
-  public static float[] l2normalize(float[] v) {
-    l2normalize(v, true);
-    return v;
-  }
-
-  /**
-   * Modifies the argument to be unit length, dividing by its l2-norm.
-   *
-   * @param v the vector to normalize
-   * @param throwOnZero whether to throw an exception when v has all zeros
-   * @return the input array after normalization
-   * @throws IllegalArgumentException when the vector is all zero and throwOnZero is true
    */
-  public static float[] l2normalize(float[] v, boolean throwOnZero) {
-    double squareSum = 0.0f;
-    int dim = v.length;
-    for (float x : v) {
-      squareSum += x * x;
-    }
+  public static void l2normalize(float[] v) {
+    double squareSum = dotProduct(v, v);
     if (squareSum == 0) {
-      if (throwOnZero) {
-        throw new IllegalArgumentException("Cannot normalize a zero-length vector");
-      } else {
-        return v;
-      }
+      throw new IllegalArgumentException("Cannot normalize a zero-length vector");
     }
     double length = Math.sqrt(squareSum);
-    for (int i = 0; i < dim; i++) {
-      v[i] /= length;
-    }
-    return v;
-  }
-
-  /**
-   * Adds the second argument to the first
-   *
-   * @param u the destination
-   * @param v the vector to add to the destination
-   */
-  public static void add(float[] u, float[] v) {
-    for (int i = 0; i < u.length; i++) {
-      u[i] += v[i];
-    }
+    scale(v, (float) (1.0 / length));
   }

   /**
@@ -191,22 +153,6 @@ public static float dotProductScore(byte[] a, byte[] b) {
     return 0.5f + dotProduct(a, b) / denom;
   }

-  /**
-   * Checks if a float vector only has finite components.
-   *
-   * @param v bytes containing a vector
-   * @return the vector for call-chaining
-   * @throws IllegalArgumentException if any component of vector is not finite
-   */
-  public static float[] checkFinite(float[] v) {
-    for (int i = 0; i < v.length; i++) {
-      if (!Float.isFinite(v[i])) {
-        throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]);
-      }
-    }
-    return v;
-  }
-
   public static float[] sum(List vectors) {
     if (vectors.isEmpty()) {
       throw new IllegalArgumentException("Input list cannot be empty");
@@ -219,8 +165,8 @@ public static float sum(float[] vector) {
     return impl.sum(vector);
   }

-  public static void divInPlace(float[] vector, float divisor) {
-    impl.divInPlace(vector, divisor);
+  public static void scale(float[] vector, float multiplier) {
+    impl.scale(vector, multiplier);
   }

   public static void addInPlace(float[] v1, float[] v2) {
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java
index 45c493dca..f5426fadd 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java
@@ -62,7 +62,7 @@ public interface VectorUtilSupport {
   float sum(float[] vector);

-  /** Divide vector by divisor, in place (vector will be modified) */
-  void divInPlace(float[] vector, float divisor);
+  /** Multiply vector by multiplier, in place (vector will be modified) */
+  void scale(float[] vector, float multiplier);

   /** Adds v2 into v1, in place (v1 will be modified) */
   public void addInPlace(float[] v1, float[] v2);
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java
index 215b31357..b4653c5ec 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java
@@ -35,7 +35,6 @@
 import java.nio.file.attribute.BasicFileAttributes;
 import java.util.*;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
index 22a8ff7c2..87416c691 100644
--- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
+++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
@@ -71,8 +71,8 @@ public float sum(float[] vector) {
     }

     @Override
-    public void divInPlace(float[] vector, float divisor) {
-        SimdOps.divInPlace(vector, divisor);
+    public void scale(float[] vector, float multiplier) {
+        SimdOps.scale(vector, multiplier);
     }

     @Override
diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java
index b0ad99119..d054ec667 100644
--- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java
+++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java
@@ -70,19 +70,19 @@ static float[] sum(List vectors) {
         return sum;
     }

-    static void divInPlace(float[] vector, float divisor) {
+    static void scale(float[] vector, float multiplier) {
         int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length);

         // Process the vectorized part
         for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
             var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector, i);
-            var divResult = a.div(divisor);
+            var divResult = a.mul(multiplier);
             divResult.intoArray(vector, i);
         }

         // Process the tail
         for (int i = vectorizedLength; i < vector.length; i++) {
-            vector[i] = vector[i] / divisor;
+            vector[i] = vector[i] * multiplier;
         }
     }
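A quick illustration of the renamed API in this commit: scale() replaces divInPlace() (callers pass a reciprocal), and l2normalize() now mutates its argument and returns void. This is a minimal sketch using only methods visible in the patch; the class and method names are illustrative, not part of the patch.

    import io.github.jbellis.jvector.vector.VectorUtil;

    class VectorUtilMigrationSketch {
        static void example(float[] v, int count) {
            // old: VectorUtil.divInPlace(v, count);  new: pass the reciprocal to scale()
            VectorUtil.scale(v, 1.0f / count);

            // l2normalize() now normalizes in place and returns void;
            // it still throws IllegalArgumentException for an all-zero vector.
            VectorUtil.l2normalize(v);
        }
    }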
From edef9e5c614044020c4b9e14f99c85c888e5d4d4 Mon Sep 17 00:00:00 2001
From: Jonathan Ellis
Date: Tue, 23 Jan 2024 09:14:42 -0600
Subject: [PATCH 2/4] flesh out pq tests a bit

---
 .../jvector/pq/KMeansPlusPlusClusterer.java  |  4 +
 .../jvector/pq/TestProductQuantization.java  | 79 ++++++++++++++++++-
 2 files changed, 80 insertions(+), 3 deletions(-)

diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
index 8690971e2..3767263b7 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
@@ -229,4 +229,8 @@ public static float[] centroidOf(List points) {
         return centroid;
     }
+
+    public float[][] getCentroids() {
+        return centroids;
+    }
 }
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
index 5d6f76422..13f2f03b3 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
@@ -19,29 +19,102 @@
 import com.carrotsearch.randomizedtesting.RandomizedTest;
 import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
 import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.VectorUtil;
 import org.junit.Test;

 import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;

+import static io.github.jbellis.jvector.TestUtil.randomVector;
+import static java.lang.Math.min;
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;

 @ThreadLeakScope(ThreadLeakScope.Scope.NONE)
 public class TestProductQuantization extends RandomizedTest {
     @Test
+    // special cases where each vector maps exactly to a centroid
     public void testPerfectReconstruction() {
-        var vectors = IntStream.range(0,ProductQuantization.CLUSTERS).mapToObj(
-                i -> new float[] {getRandom().nextInt(100000), getRandom().nextInt(100000), getRandom().nextInt(100000) })
+        Random R = getRandom();
+
+        // exactly the same number of random vectors as clusters
+        var v1 = IntStream.range(0, ProductQuantization.CLUSTERS).mapToObj(
+                i -> new float[] { R.nextInt(100_000), R.nextInt(100_000), R.nextInt(100_000) })
                 .collect(Collectors.toList());
+        assertPerfectQuantization(v1);
+
+        // 10x the number of random vectors as clusters (with duplicates)
+        var v2 = v1.stream().flatMap(v -> IntStream.range(0, 10).mapToObj(i -> v))
+                .collect(Collectors.toList());
+        assertPerfectQuantization(v2);
+    }
+
+    private static void assertPerfectQuantization(List vectors) {
         var ravv = new ListRandomAccessVectorValues(vectors, 3);
         var pq = ProductQuantization.compute(ravv, 2, false);
         var encoded = pq.encodeAll(vectors);
         var decodedScratch = new float[3];
-        // if the number of vectors is equal to the number of clusters, we should perfectly reconstruct vectors
         for (int i = 0; i < vectors.size(); i++) {
             pq.decode(encoded[i], decodedScratch);
             assertArrayEquals(Arrays.toString(vectors.get(i)) + "!=" + Arrays.toString(decodedScratch), vectors.get(i), decodedScratch, 0);
         }
     }
+
+    @Test
+    // validate that iterating on our cluster centroids improves the encoding
+    public void testIterativeImprovement() {
+        for (int i = 0; i < 10; i++) {
+            testIterativeImprovementOnce();
+        }
+    }
+
+    public void testIterativeImprovementOnce() {
+        Random R = getRandom();
+        float[][] vectors = generate(ProductQuantization.CLUSTERS + R.nextInt(10*ProductQuantization.CLUSTERS),
+                                     2 + R.nextInt(10),
+                                     1_000 + R.nextInt(10_000));
+
+        var clusterer = new KMeansPlusPlusClusterer(vectors, ProductQuantization.CLUSTERS, VectorUtil::dotProduct);
+        var initialLoss = loss(clusterer, vectors);
+
+        assert clusterer.clusterOnce() > 0;
+        var improvedLoss = loss(clusterer, vectors);
+
+        assertTrue(improvedLoss < initialLoss, "improvedLoss=" + improvedLoss + " initialLoss=" + initialLoss);
+    }
+
+    private static double loss(KMeansPlusPlusClusterer clusterer, float[][] vectors) {
+        var pq = new ProductQuantization(new float[][][] { clusterer.getCentroids() }, null);
+        byte[][] encoded = pq.encodeAll(List.of(vectors));
+
+        var decodedScratch = new float[vectors[0].length];
+        var loss = 0.0;
+        for (int i = 0; i < vectors.length; i++) {
+            pq.decode(encoded[i], decodedScratch);
+            loss += 1 - VectorSimilarityFunction.COSINE.compare(vectors[i], decodedScratch);
+        }
+        return loss;
+    }
+
+    private static float[][] generate(int nClusters, int nDimensions, int nVectors) {
+        Random R = getRandom();
+
+        // generate clusters
+        var clusters = IntStream.range(0, nClusters)
+                .mapToObj(i -> randomVector(R, nDimensions))
+                .collect(Collectors.toList());
+
+        // generate vectors by perturbing clusters
+        return IntStream.range(0, nVectors).mapToObj(__ -> {
+            var cluster = clusters.get(R.nextInt(nClusters));
+            var v = randomVector(R, nDimensions);
+            VectorUtil.scale(v, 0.1f + 0.9f * R.nextFloat());
+            VectorUtil.addInPlace(v, cluster);
+            return v;
+        }).toArray(float[][]::new);
+    }
 }
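The new tests drive the clusterer incrementally instead of only through ProductQuantization.compute(). The sketch below mirrors that usage: build a clusterer over raw vectors, run extra Lloyd's rounds, and read the result back through the new getCentroids() accessor. The loop bound and the assumption that clusterOnce() keeps returning a positive value while assignments are still changing are taken from the test's usage rather than a documented contract, and the next patch in the series removes the distance-function constructor argument shown here.

    import io.github.jbellis.jvector.pq.KMeansPlusPlusClusterer;
    import io.github.jbellis.jvector.pq.ProductQuantization;
    import io.github.jbellis.jvector.vector.VectorUtil;

    class ClustererSketch {
        // Refine cluster centroids with a few extra Lloyd's iterations.
        static float[][] refinedCentroids(float[][] vectors) {
            var clusterer = new KMeansPlusPlusClusterer(vectors, ProductQuantization.CLUSTERS, VectorUtil::dotProduct);
            for (int round = 0; round < 4 && clusterer.clusterOnce() > 0; round++) {
                // keep iterating while the assignments are still moving
            }
            return clusterer.getCentroids();
        }
    }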
From 17b6e3631c167a197553474b98bc8e9af2ce12f8 Mon Sep 17 00:00:00 2001
From: Jonathan Ellis
Date: Sun, 28 Jan 2024 12:20:45 -0600
Subject: [PATCH 3/4] add comments and remove distanceFunction from kmeans --
 non-L2 distances don't make sense (we have no way to easily compute new
 centroids from the points assigned to each cluster)

---
 .../jvector/pq/KMeansPlusPlusClusterer.java  | 43 +++++++++++++------
 .../jvector/pq/ProductQuantization.java      |  2 +-
 .../jvector/pq/TestProductQuantization.java  |  3 +-
 3 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
index 3767263b7..90521fc3a 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
@@ -21,18 +21,23 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Random;
-import java.util.function.BiFunction;

 /**
  * A KMeans++ implementation for float vectors. Optimizes to use SIMD vector instructions if available.
  */
 public class KMeansPlusPlusClusterer {
-    private final int k;
-    private final BiFunction distanceFunction;
     private final Random random;
+
+    // number of centroids to compute
+    private final int k;
+
+    // the points to train on
     private final float[][] points;
+    // the cluster each point is assigned to
     private final int[] assignments;
+    // the centroids of each cluster
     private final float[][] centroids;
+    // the number of points assigned to each cluster
     private final int[] centroidDenoms;
     private final float[][] centroidNums;

@@ -41,9 +46,8 @@ public class KMeansPlusPlusClusterer {
      * maximum iterations, and distance function.
      *
      * @param k number of clusters.
-     * @param distanceFunction a function to compute the distance between two points.
      */
-    public KMeansPlusPlusClusterer(float[][] points, int k, BiFunction distanceFunction) {
+    public KMeansPlusPlusClusterer(float[][] points, int k) {
         if (k <= 0) {
             throw new IllegalArgumentException("Number of clusters must be positive.");
         }
@@ -53,12 +57,12 @@ public KMeansPlusPlusClusterer(float[][] points, int k, BiFunction distanceFunc
         }

         this.points = points;
         this.k = k;
-        this.distanceFunction = distanceFunction;
         random = new Random();

         centroidDenoms = new int[k];
         centroidNums = new float[k][points[0].length];
         centroids = chooseInitialCentroids(points);
         assignments = new int[points.length];
         initializeAssignedPoints();
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
@@ createCodebooks(List vectors, int M, int[][] subvect
                     float[][] subvectors = vectors.stream().parallel()
                             .map(vector -> getSubVector(vector, m, subvectorSizeAndOffset))
                             .toArray(float[][]::new);
-                    var clusterer = new KMeansPlusPlusClusterer(subvectors, CLUSTERS, VectorUtil::squareDistance);
+                    var clusterer = new KMeansPlusPlusClusterer(subvectors, CLUSTERS);
                     return clusterer.cluster(K_MEANS_ITERATIONS);
                 })
                 .toArray(float[][][]::new))
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
index 13f2f03b3..7ab6590f5 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
@@ -30,7 +30,6 @@
 import java.util.stream.IntStream;

 import static io.github.jbellis.jvector.TestUtil.randomVector;
-import static java.lang.Math.min;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -78,7 +77,7 @@ public void testIterativeImprovementOnce() {
                                      2 + R.nextInt(10),
                                      1_000 + R.nextInt(10_000));

-        var clusterer = new KMeansPlusPlusClusterer(vectors, ProductQuantization.CLUSTERS, VectorUtil::dotProduct);
+        var clusterer = new KMeansPlusPlusClusterer(vectors, ProductQuantization.CLUSTERS);
         var initialLoss = loss(clusterer, vectors);

         assert clusterer.clusterOnce() > 0;
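The reasoning in this commit message is worth spelling out: the update step recomputes each centroid as the arithmetic mean of its assigned points, and the mean is the minimizer of the summed squared Euclidean (L2) distance only. For other metrics the minimizer is different (for L1 it is the coordinate-wise median), so a pluggable distance function would silently disagree with the update step. A minimal sketch of that mean update, written with the VectorUtil methods from the first patch; the class and method names are illustrative:

    import io.github.jbellis.jvector.vector.VectorUtil;
    import java.util.List;

    class CentroidUpdateSketch {
        // The arithmetic mean of the assigned points minimizes sum ||x - c||^2 over c.
        static float[] meanOf(List<float[]> assigned, int dim) {
            float[] c = new float[dim];
            for (float[] x : assigned) {
                VectorUtil.addInPlace(c, x);             // component-wise sum
            }
            VectorUtil.scale(c, 1.0f / assigned.size()); // divide by the cluster size
            return c;
        }
    }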
From 6ee8ebdc8b1a007ddbab16bfb168edb8748010cf Mon Sep 17 00:00:00 2001
From: Jonathan Ellis
Date: Thu, 8 Feb 2024 10:56:15 -0600
Subject: [PATCH 4/4] add PQ.refine

---
 .../jvector/pq/BinaryQuantization.java       |  15 +--
 .../jvector/pq/KMeansPlusPlusClusterer.java  |  36 +++---
 .../jvector/pq/ProductQuantization.java      | 105 +++++++++++++-----
 .../jvector/pq/TestProductQuantization.java  |  28 +++++
 4 files changed, 130 insertions(+), 54 deletions(-)

diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java
index 08477a64b..aed5ef32e 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java
@@ -48,20 +48,7 @@ public static BinaryQuantization compute(RandomAccessVectorValues ravv)
     }

     public static BinaryQuantization compute(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
-        // limit the number of vectors we train on
-        var P = min(1.0f, ProductQuantization.MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
-        var ravvCopy = ravv.threadLocalSupplier();
-        var vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
-                .filter(i -> ThreadLocalRandom.current().nextFloat() < P)
-                .mapToObj(targetOrd -> {
-                    var localRavv = ravvCopy.get();
-                    float[] v = localRavv.vectorValue(targetOrd);
-                    return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
-                })
-                .collect(Collectors.toList()))
-                .join();
-
-        // compute the centroid of the training set
+        var vectors = ProductQuantization.extractTrainingVectors(ravv, parallelExecutor);
         float[] globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
         return new BinaryQuantization(globalCentroid);
     }
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
index 90521fc3a..2aebd384a 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java
@@ -21,16 +21,14 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Random;
+import java.util.concurrent.ThreadLocalRandom;

 /**
  * A KMeans++ implementation for float vectors. Optimizes to use SIMD vector instructions if available.
  */
 public class KMeansPlusPlusClusterer {
-    private final Random random;
     // number of centroids to compute
     private final int k;
-
     // the points to train on
     private final float[][] points;
     // the cluster each point is assigned to
     private final int[] assignments;
     // the centroids of each cluster
     private final float[][] centroids;
     // the number of points assigned to each cluster
     private final int[] centroidDenoms;
     private final float[][] centroidNums;

@@ -48,19 +46,21 @@ public class KMeansPlusPlusClusterer {
      * @param k number of clusters.
      */
     public KMeansPlusPlusClusterer(float[][] points, int k) {
-        if (k <= 0) {
-            throw new IllegalArgumentException("Number of clusters must be positive.");
-        }
-        if (k > points.length) {
-            throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
-        }
+        this(points, chooseInitialCentroids(points, k));
+    }

+    /**
+     * Constructs a KMeansPlusPlusClusterer from the given points and initial centroids.
+     *
+     * The initial centroids provided as a parameter are copied before modification.
+     */
+    public KMeansPlusPlusClusterer(float[][] points, float[][] centroids) {
         this.points = points;
-        this.k = k;
-        random = new Random();
+        this.k = centroids.length;
+        this.centroids = Arrays.stream(centroids).map(float[]::clone).toArray(float[][]::new);
         centroidDenoms = new int[k];
         centroidNums = new float[k][points[0].length];
-        centroids = chooseInitialCentroids(points);
         assignments = new int[points.length];
         initializeAssignedPoints();
@@ -95,10 +95,17 @@ public int clusterOnce() {
      * across the data and not initialized too closely to each other, leading to better
      * convergence and potentially improved final clusterings.
      *
-     * @param points a list of points from which centroids are chosen.
      * @return an array of initial centroids.
      */
-    private float[][] chooseInitialCentroids(float[][] points) {
+    private static float[][] chooseInitialCentroids(float[][] points, int k) {
+        if (k <= 0) {
+            throw new IllegalArgumentException("Number of clusters must be positive.");
+        }
+        if (k > points.length) {
+            throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
+        }
+
+        var random = ThreadLocalRandom.current();
         float[][] centroids = new float[k][];
         float[] distances = new float[points.length];
         Arrays.fill(distances, Float.MAX_VALUE);
@@ -224,6 +231,7 @@ private static void assertFinite(float[] vector) {
      * Calculates centroids from centroidNums/centroidDenoms updated during point assignment
      */
     private void updateCentroids() {
+        var random = ThreadLocalRandom.current();
         for (int i = 0; i < centroids.length; i++) {
             var denom = centroidDenoms[i];
             if (denom == 0) {
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
index dae6e262e..887adedba 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
@@ -83,20 +83,10 @@ public static ProductQuantization compute(
             int M,
             boolean globallyCenter,
             ForkJoinPool simdExecutor,
-            ForkJoinPool parallelExecutor) {
-        // limit the number of vectors we train on
-        var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
-        var ravvCopy = ravv.threadLocalSupplier();
+            ForkJoinPool parallelExecutor)
+    {
         var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M);
-        var vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
-                .filter(i -> ThreadLocalRandom.current().nextFloat() < P)
-                .mapToObj(targetOrd -> {
-                    var localRavv = ravvCopy.get();
-                    float[] v = localRavv.vectorValue(targetOrd);
-                    return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
-                })
-                .collect(Collectors.toList()))
-                .join();
+        var vectors = extractTrainingVectors(ravv, parallelExecutor);

         // subtract the centroid from each training vector
         float[] globalCentroid;
@@ -110,10 +100,64 @@ public static ProductQuantization compute(
         }

         // derive the codebooks
-        var codebooks = createCodebooks(vectors, M, subvectorSizesAndOffsets, simdExecutor);
+        var codebooks = createCodebooks(vectors, M, subvectorSizesAndOffsets, simdExecutor, parallelExecutor);
         return new ProductQuantization(codebooks, globalCentroid);
     }

+    static List extractTrainingVectors(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
+        // limit the number of vectors we train on
+        var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
+        var ravvCopy = ravv.threadLocalSupplier();
+        return parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
+                .filter(i -> ThreadLocalRandom.current().nextFloat() < P)
+                .mapToObj(targetOrd -> {
+                    var localRavv = ravvCopy.get();
+                    float[] v = localRavv.vectorValue(targetOrd);
+                    return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
+                })
+                .collect(Collectors.toList()))
+                .join();
+    }
+
+    /**
+     * Create a new PQ by fine-tuning this one with the data in `ravv`
+     */
+    public ProductQuantization refine(RandomAccessVectorValues ravv) {
+        return refine(ravv, 1, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
+    }
+
+    /**
+     * Create a new PQ by fine-tuning this one with the data in `ravv`
+     *
+     * @param lloydsRounds number of Lloyd's iterations to run against
+     *                     the new data.  Suggested values are 1 or 2.
+     */
+    public ProductQuantization refine(RandomAccessVectorValues ravv,
+                                      int lloydsRounds,
+                                      ForkJoinPool simdExecutor,
+                                      ForkJoinPool parallelExecutor)
+    {
+        if (lloydsRounds < 0) {
+            throw new IllegalArgumentException("lloydsRounds must be non-negative");
+        }
+
+        var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M);
+        var vectorsMutable = extractTrainingVectors(ravv, parallelExecutor);
+        if (globalCentroid != null) {
+            var vectors = vectorsMutable;
+            vectorsMutable = simdExecutor.submit(() -> vectors.stream().parallel().map(v -> VectorUtil.sub(v, globalCentroid)).collect(Collectors.toList())).join();
+        }
+        var vectors = vectorsMutable; // "effectively final" to make the closure happy
+
+        var refinedCodebooks = simdExecutor.submit(() -> IntStream.range(0, M).parallel().mapToObj(m -> {
+            float[][] subvectors = extractSubvectors(vectors, m, subvectorSizesAndOffsets, parallelExecutor);
+            var clusterer = new KMeansPlusPlusClusterer(subvectors, codebooks[m]);
+            return clusterer.cluster(lloydsRounds);
+        }).toArray(float[][][]::new)).join();
+
+        return new ProductQuantization(refinedCodebooks, globalCentroid);
+    }
+
     ProductQuantization(float[][][] codebooks, float[] globalCentroid)
     {
         this.codebooks = codebooks;
@@ -213,19 +257,28 @@ private static String arraySummary(float[] a) {
         return "[" + String.join(", ", b) + "]";
     }

-    static float[][][] createCodebooks(List vectors, int M, int[][] subvectorSizeAndOffset, ForkJoinPool simdExecutor) {
-        return simdExecutor.submit(() -> IntStream.range(0, M).parallel()
-                .mapToObj(m -> {
-                    float[][] subvectors = vectors.stream().parallel()
-                            .map(vector -> getSubVector(vector, m, subvectorSizeAndOffset))
-                            .toArray(float[][]::new);
-                    var clusterer = new KMeansPlusPlusClusterer(subvectors, CLUSTERS);
-                    return clusterer.cluster(K_MEANS_ITERATIONS);
-                })
-                .toArray(float[][][]::new))
-                .join();
+    static float[][][] createCodebooks(List vectors,
+                                       int M,
+                                       int[][] subvectorSizeAndOffset,
+                                       ForkJoinPool simdExecutor,
+                                       ForkJoinPool parallelExecutor)
+    {
+        return simdExecutor.submit(() -> IntStream.range(0, M).parallel().mapToObj(m -> {
+            float[][] subvectors = extractSubvectors(vectors, m, subvectorSizeAndOffset, parallelExecutor);
+            var clusterer = new KMeansPlusPlusClusterer(subvectors, CLUSTERS);
+            return clusterer.cluster(K_MEANS_ITERATIONS);
+        }).toArray(float[][][]::new)).join();
     }
-
+
+    /** extract float[] subvectors corresponding to the m'th subspace, in parallel */
+    private static float[][] extractSubvectors(List vectors, int m, int[][] subvectorSizeAndOffset, ForkJoinPool parallelExecutor) {
+        return parallelExecutor.submit(() -> {
+            return vectors.parallelStream()
+                    .map(vector -> getSubVector(vector, m, subvectorSizeAndOffset))
+                    .toArray(float[][]::new);
+        }).join();
+    }
+
     static int closetCentroidIndex(float[] subvector, float[][] codebook) {
         int index = 0;
         float minDist = Integer.MAX_VALUE;
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
index 7ab6590f5..5ba1994ae 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java
@@ -86,6 +86,34 @@ public void testIterativeImprovementOnce() {
         assertTrue(improvedLoss < initialLoss, "improvedLoss=" + improvedLoss + " initialLoss=" + initialLoss);
     }

+    @Test
+    public void testRefine() {
+        Random R = getRandom();
+        float[][] vectors = generate(ProductQuantization.CLUSTERS + R.nextInt(10*ProductQuantization.CLUSTERS),
+                                     2 + R.nextInt(10),
+                                     1_000 + R.nextInt(10_000));
+
+        // generate PQ codebooks from half of the dataset
+        var half1 = Arrays.copyOf(vectors, vectors.length / 2);
+        var ravv1 = new ListRandomAccessVectorValues(List.of(half1), vectors[0].length);
+        var pq1 = ProductQuantization.compute(ravv1, 1, false);
+
+        // refine the codebooks with the other half (so, drawn from the same distribution)
+        int remaining = vectors.length - vectors.length / 2;
+        var half2 = new float[remaining][];
+        System.arraycopy(vectors, vectors.length / 2, half2, 0, remaining);
+        var ravv2 = new ListRandomAccessVectorValues(List.of(half2), vectors[0].length);
+        var pq2 = pq1.refine(ravv2);
+
+        // the refined version should work better
+        var clusterer1 = new KMeansPlusPlusClusterer(half2, pq1.codebooks[0]);
+        var clusterer2 = new KMeansPlusPlusClusterer(half2, pq2.codebooks[0]);
+        var loss1 = loss(clusterer1, half2);
+        var loss2 = loss(clusterer2, half2);
+        assertTrue(loss2 < loss1, "loss1=" + loss1 + " loss2=" + loss2);
+    }
+
     private static double loss(KMeansPlusPlusClusterer clusterer, float[][] vectors) {
         var pq = new ProductQuantization(new float[][][] { clusterer.getCentroids() }, null);
         byte[][] encoded = pq.encodeAll(List.of(vectors));
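Sketched usage of the new refine() entry point, mirroring testRefine() above: train codebooks on an initial sample, then fine-tune them against additional data drawn from the same distribution. The vector lists, dimension, and subspace count below are illustrative; per the patch, the single-argument overload runs one Lloyd's round against the new data, seeded from the existing codebooks.

    import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
    import io.github.jbellis.jvector.pq.ProductQuantization;
    import java.util.List;

    class RefineSketch {
        static ProductQuantization trainThenRefine(List<float[]> initial, List<float[]> additional, int dim) {
            var ravv1 = new ListRandomAccessVectorValues(initial, dim);
            var pq = ProductQuantization.compute(ravv1, 8, false);  // 8 subspaces, no global centering

            var ravv2 = new ListRandomAccessVectorValues(additional, dim);
            return pq.refine(ravv2);                                // one Lloyd's round, seeded from pq's codebooks
        }
    }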