add PQ.refine for when you already built PQ once for a similar set of vectors #209

Merged (4 commits) on Feb 8, 2024
@@ -529,7 +529,7 @@ private int approximateMedioid() {
var node = it.nextInt();
VectorUtil.addInPlace(centroid, (float[]) vc.vectorValue(node));
}
VectorUtil.divInPlace(centroid, graph.size());
VectorUtil.scale(centroid, 1.0f / graph.size());
NodeSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(vc.vectorValue(i), (T) centroid);
int ep = graph.entry();
var result = gs.searchInternal(scoreFunction, null, beamWidth, 0.0f, 0.0f, ep, Bits.ALL);
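
For context, approximateMedioid averages every vector in the graph and then searches for the stored node closest to that mean; the only change in this hunk is that the averaging step now multiplies by the reciprocal instead of dividing in place. A minimal sketch of that averaging under the renamed helper, not part of the diff (a `java.util.List` import and the jvector `VectorUtil` shown in this PR are assumed):

```java
// Illustrative sketch: average a set of vectors with the renamed scale() helper,
// multiplying once by 1/n instead of dividing every element.
static float[] averageOf(List<float[]> vectors) {
    float[] centroid = new float[vectors.get(0).length];
    for (float[] v : vectors) {
        VectorUtil.addInPlace(centroid, v);            // accumulate the component-wise sum
    }
    VectorUtil.scale(centroid, 1.0f / vectors.size()); // divide once via the reciprocal
    return centroid;
}
```
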
@@ -48,20 +48,7 @@ public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv)
}

public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv, ForkJoinPool parallelExecutor) {
// limit the number of vectors we train on
var P = min(1.0f, ProductQuantization.MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
var ravvCopy = ravv.threadLocalSupplier();
var vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
.filter(i -> ThreadLocalRandom.current().nextFloat() < P)
.mapToObj(targetOrd -> {
var localRavv = ravvCopy.get();
float[] v = localRavv.vectorValue(targetOrd);
return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
})
.collect(Collectors.toList()))
.join();

// compute the centroid of the training set
var vectors = ProductQuantization.extractTrainingVectors(ravv, parallelExecutor);
float[] globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
return new BinaryQuantization(globalCentroid);
}
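
A hedged usage sketch of the slimmed-down entry point; only the compute signature visible in this diff is relied on, and a `java.util.concurrent.ForkJoinPool` import is assumed:

```java
// Sketch, not part of the diff: build a BinaryQuantization after the refactor.
// Training-vector sampling is now shared with ProductQuantization via
// extractTrainingVectors, so both quantizers subsample the index the same way.
static BinaryQuantization buildBinaryQuantization(RandomAccessVectorValues<float[]> ravv) {
    return BinaryQuantization.compute(ravv, ForkJoinPool.commonPool());
}
```
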
@@ -21,18 +21,21 @@
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.concurrent.ThreadLocalRandom;

/**
* A KMeans++ implementation for float vectors. Optimizes to use SIMD vector instructions if available.
*/
public class KMeansPlusPlusClusterer {
// number of centroids to compute
private final int k;
private final BiFunction<float[], float[], Float> distanceFunction;
private final Random random;
// the points to train on
private final float[][] points;
// the cluster each point is assigned to
private final int[] assignments;
// the centroids of each cluster
private final float[][] centroids;
// the number of points assigned to each cluster
private final int[] centroidDenoms;
private final float[][] centroidNums;

@@ -41,24 +44,25 @@ public class KMeansPlusPlusClusterer {
* maximum iterations, and distance function.
*
* @param k number of clusters.
* @param distanceFunction a function to compute the distance between two points.
*/
public KMeansPlusPlusClusterer(float[][] points, int k, BiFunction<float[], float[], Float> distanceFunction) {
if (k <= 0) {
throw new IllegalArgumentException("Number of clusters must be positive.");
}
if (k > points.length) {
throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
}
public KMeansPlusPlusClusterer(float[][] points, int k) {
this(points, chooseInitialCentroids(points, k));
}

/**
* Constructs a KMeansPlusPlusFloatClusterer with the specified number of clusters,
* maximum iterations, and distance function.
* <p>
* The initial centroids provided as a parameter are copied before modification.
*/
public KMeansPlusPlusClusterer(float[][] points, float[][] centroids) {
this.points = points;
this.k = k;
this.distanceFunction = distanceFunction;
random = new Random();
this.k = centroids.length;
this.centroids = Arrays.stream(centroids).map(float[]::clone).toArray(float[][]::new);
centroidDenoms = new int[k];
centroidNums = new float[k][points[0].length];
centroids = chooseInitialCentroids(points);
assignments = new int[points.length];

initializeAssignedPoints();
}
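
To make the two construction paths concrete, here is a hedged sketch of how a caller might use them; `KMeansPlusPlusClusterer` and `cluster(int)` come from this diff, while the helper names and iteration counts are illustrative:

```java
// Sketch, not part of the diff; cluster and iteration counts are illustrative.
static float[][] freshCodebook(float[][] subvectors, int k) {
    // Cold start: chooseInitialCentroids performs the k-means++ seeding.
    var clusterer = new KMeansPlusPlusClusterer(subvectors, k);
    return clusterer.cluster(6);
}

static float[][] refinedCodebook(float[][] subvectors, float[][] previousCodebook) {
    // Warm start: reuse centroids trained on an earlier, similar dataset.
    // The constructor clones previousCodebook, so the caller's copy is untouched.
    var clusterer = new KMeansPlusPlusClusterer(subvectors, previousCodebook);
    return clusterer.cluster(1); // a single Lloyd's round is often enough when refining
}
```

The warm-start path is exactly what `ProductQuantization.refine` uses further down in this PR, passing `codebooks[m]` as the initial centroids.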

@@ -91,10 +95,17 @@ public int clusterOnce() {
* across the data and not initialized too closely to each other, leading to better
* convergence and potentially improved final clusterings.
*
* @param points a list of points from which centroids are chosen.
* @return an array of initial centroids.
*/
private float[][] chooseInitialCentroids(float[][] points) {
private static float[][] chooseInitialCentroids(float[][] points, int k) {
if (k <= 0) {
throw new IllegalArgumentException("Number of clusters must be positive.");
}
if (k > points.length) {
throw new IllegalArgumentException(String.format("Number of clusters %d cannot exceed number of points %d", k, points.length));
}

var random = ThreadLocalRandom.current();
float[][] centroids = new float[k][];
float[] distances = new float[points.length];
Arrays.fill(distances, Float.MAX_VALUE);
@@ -103,7 +114,7 @@ private float[][] chooseInitialCentroids(float[][] points) {
float[] firstCentroid = points[random.nextInt(points.length)];
centroids[0] = firstCentroid;
for (int i = 0; i < points.length; i++) {
float distance1 = distanceFunction.apply(points[i], firstCentroid);
float distance1 = VectorUtil.squareDistance(points[i], firstCentroid);
distances[i] = Math.min(distances[i], distance1);
}

@@ -133,11 +144,14 @@ private float[][] chooseInitialCentroids(float[][] points) {

// Update distances, but only if the new centroid provides a closer distance
for (int j = 0; j < points.length; j++) {
float newDistance = distanceFunction.apply(points[j], nextCentroid);
float newDistance = VectorUtil.squareDistance(points[j], nextCentroid);
distances[j] = Math.min(distances[j], newDistance);
}
}

for (float[] centroid : centroids) {
assertFinite(centroid);
}
return centroids;
}
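
The elided middle of this method picks each subsequent centroid with probability proportional to its squared distance from the nearest centroid chosen so far, which is what makes the seeding "spread out across the data". A generic sketch of that D^2-weighted step, shown for illustration only and not taken from the project's code:

```java
// Generic k-means++ selection; distances[i] holds the squared distance from
// point i to its nearest already-chosen centroid.
static int pickNextCentroid(float[] distances) {
    var random = ThreadLocalRandom.current();
    float total = 0;
    for (float d : distances) {
        total += d;
    }
    float r = random.nextFloat() * total;
    float cumulative = 0;
    for (int i = 0; i < distances.length; i++) {
        cumulative += distances[i];
        if (cumulative >= r) {
            return i; // far-away points are proportionally more likely to be picked
        }
    }
    return distances.length - 1; // guard against floating-point rounding
}
```
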

@@ -148,7 +162,7 @@ private float[][] chooseInitialCentroids(float[][] points) {
private void initializeAssignedPoints() {
for (int i = 0; i < points.length; i++) {
float[] point = points[i];
var newAssignment = getNearestCluster(point, centroids);
var newAssignment = getNearestCluster(point);
centroidDenoms[newAssignment] = centroidDenoms[newAssignment] + 1;
VectorUtil.addInPlace(centroidNums[newAssignment], point);
assignments[i] = newAssignment;
@@ -168,7 +182,7 @@ private int updateAssignedPoints() {
for (int i = 0; i < points.length; i++) {
float[] point = points[i];
var oldAssignment = assignments[i];
var newAssignment = getNearestCluster(point, centroids);
var newAssignment = getNearestCluster(point);

if (newAssignment != oldAssignment) {
centroidDenoms[oldAssignment] = centroidDenoms[oldAssignment] - 1;
@@ -186,12 +200,12 @@ private int updateAssignedPoints() {
/**
* Return the index of the closest centroid to the given point
*/
private int getNearestCluster(float[] point, float[][] centroids) {
private int getNearestCluster(float[] point) {
float minDistance = Float.MAX_VALUE;
int nearestCluster = 0;

for (int i = 0; i < k; i++) {
float distance = distanceFunction.apply(point, centroids[i]);
float distance = VectorUtil.squareDistance(point, centroids[i]);
if (distance < minDistance) {
minDistance = distance;
nearestCluster = i;
@@ -201,17 +215,30 @@ private int getNearestCluster(float[] point) {
return nearestCluster;
}

@SuppressWarnings({"AssertWithSideEffects", "ConstantConditions"})
private static void assertFinite(float[] vector) {
boolean assertsEnabled = false;
assert assertsEnabled = true;

if (assertsEnabled) {
for (float v : vector) {
assert Float.isFinite(v) : "vector " + Arrays.toString(vector) + " contains non-finite value";
}
}
}

/**
* Calculates centroids from centroidNums/centroidDenoms updated during point assignment
*/
private void updateCentroids() {
var random = ThreadLocalRandom.current();
for (int i = 0; i < centroids.length; i++) {
var denom = centroidDenoms[i];
if (denom == 0) {
centroids[i] = points[random.nextInt(points.length)];
} else {
centroids[i] = Arrays.copyOf(centroidNums[i], centroidNums[i].length);
VectorUtil.divInPlace(centroids[i], centroidDenoms[i]);
VectorUtil.scale(centroids[i], 1.0f / centroidDenoms[i]);
}
}
}
@@ -225,8 +252,12 @@ public static float[] centroidOf(List<float[]> points) {
}

float[] centroid = VectorUtil.sum(points);
VectorUtil.divInPlace(centroid, points.size());
VectorUtil.scale(centroid, 1.0f / points.size());

return centroid;
}

public float[][] getCentroids() {
return centroids;
}
}
@@ -83,20 +83,10 @@ public static ProductQuantization compute(
int M,
boolean globallyCenter,
ForkJoinPool simdExecutor,
ForkJoinPool parallelExecutor) {
// limit the number of vectors we train on
var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
var ravvCopy = ravv.threadLocalSupplier();
ForkJoinPool parallelExecutor)
{
var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M);
var vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
.filter(i -> ThreadLocalRandom.current().nextFloat() < P)
.mapToObj(targetOrd -> {
var localRavv = ravvCopy.get();
float[] v = localRavv.vectorValue(targetOrd);
return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
})
.collect(Collectors.toList()))
.join();
var vectors = extractTrainingVectors(ravv, parallelExecutor);

// subtract the centroid from each training vector
float[] globalCentroid;
@@ -110,10 +100,64 @@
}

// derive the codebooks
var codebooks = createCodebooks(vectors, M, subvectorSizesAndOffsets, simdExecutor);
var codebooks = createCodebooks(vectors, M, subvectorSizesAndOffsets, simdExecutor, parallelExecutor);
return new ProductQuantization(codebooks, globalCentroid);
}

static List<float[]> extractTrainingVectors(RandomAccessVectorValues<float[]> ravv, ForkJoinPool parallelExecutor) {
// limit the number of vectors we train on
var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
var ravvCopy = ravv.threadLocalSupplier();
return parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
.filter(i -> ThreadLocalRandom.current().nextFloat() < P)
.mapToObj(targetOrd -> {
var localRavv = ravvCopy.get();
float[] v = localRavv.vectorValue(targetOrd);
return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
})
.collect(Collectors.toList()))
.join();
}
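
This helper keeps training cost bounded: each ordinal is kept independently with probability P = min(1, MAX_PQ_TRAINING_SET_SIZE / size), so the expected sample size is about min(size, MAX_PQ_TRAINING_SET_SIZE), with a little run-to-run variance from the coin flips. A small sketch of that arithmetic:

```java
// Illustration of the sampling math only. maxTrainingSetSize stands in for
// MAX_PQ_TRAINING_SET_SIZE, whose actual value is defined elsewhere in the class
// and is not shown in this diff.
static long expectedTrainingVectors(long indexSize, long maxTrainingSetSize) {
    double p = Math.min(1.0, maxTrainingSetSize / (double) indexSize);
    return Math.round(p * indexSize); // roughly min(indexSize, maxTrainingSetSize)
}
```
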

/**
* Create a new PQ by fine-tuning this one with the data in `ravv`
*/
public ProductQuantization refine(RandomAccessVectorValues<float[]> ravv) {
return refine(ravv, 1, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
}

/**
* Create a new PQ by fine-tuning this one with the data in `ravv`
*
* @param lloydsRounds number of Lloyd's iterations to run against
* the new data. Suggested values are 1 or 2.
*/
public ProductQuantization refine(RandomAccessVectorValues<float[]> ravv,
int lloydsRounds,
ForkJoinPool simdExecutor,
ForkJoinPool parallelExecutor)
{
if (lloydsRounds < 0) {
throw new IllegalArgumentException("lloydsRounds must be non-negative");
}

var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M);
var vectorsMutable = extractTrainingVectors(ravv, parallelExecutor);
if (globalCentroid != null) {
var vectors = vectorsMutable;
vectorsMutable = simdExecutor.submit(() -> vectors.stream().parallel().map(v -> VectorUtil.sub(v, globalCentroid)).collect(Collectors.toList())).join();
}
var vectors = vectorsMutable; // "effectively final" to make the closure happy

var refinedCodebooks = simdExecutor.submit(() -> IntStream.range(0, M).parallel().mapToObj(m -> {
float[][] subvectors = extractSubvectors(vectors, m, subvectorSizesAndOffsets, parallelExecutor);
var clusterer = new KMeansPlusPlusClusterer(subvectors, codebooks[m]);
return clusterer.cluster(lloydsRounds);
}).toArray(float[][][]::new)).join();

return new ProductQuantization(refinedCodebooks, globalCentroid);
}
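
Putting the new API together: train a PQ once, then cheaply adapt it when the vector set changes. A hedged end-to-end sketch, not part of the diff; `refine(...)` comes from this PR, and the parameter list of `compute(...)` follows what is visible in this hunk (its leading RandomAccessVectorValues argument is inferred from the method body). M = 16 subspaces is illustrative.

```java
// Sketch: build a PQ on the original vectors, then fine-tune it later.
static ProductQuantization buildThenRefine(RandomAccessVectorValues<float[]> original,
                                           RandomAccessVectorValues<float[]> updated) {
    // Full k-means training on the original vector set.
    ProductQuantization pq = ProductQuantization.compute(
            original, 16, true, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());

    // When a similar-but-updated vector set arrives, fine-tune the existing
    // codebooks instead of re-running the full training.
    return pq.refine(updated); // convenience overload: 1 Lloyd's round, default pools
}
```

For explicit control, the four-argument overload `pq.refine(updated, lloydsRounds, simdExecutor, parallelExecutor)` from this diff exposes the iteration count and both pools.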

ProductQuantization(float[][][] codebooks, float[] globalCentroid)
{
this.codebooks = codebooks;
@@ -213,19 +257,28 @@ private static String arraySummary(float[] a) {
return "[" + String.join(", ", b) + "]";
}

static float[][][] createCodebooks(List<float[]> vectors, int M, int[][] subvectorSizeAndOffset, ForkJoinPool simdExecutor) {
return simdExecutor.submit(() -> IntStream.range(0, M).parallel()
.mapToObj(m -> {
float[][] subvectors = vectors.stream().parallel()
.map(vector -> getSubVector(vector, m, subvectorSizeAndOffset))
.toArray(float[][]::new);
var clusterer = new KMeansPlusPlusClusterer(subvectors, CLUSTERS, VectorUtil::squareDistance);
return clusterer.cluster(K_MEANS_ITERATIONS);
})
.toArray(float[][][]::new))
.join();
static float[][][] createCodebooks(List<float[]> vectors,
int M,
int[][] subvectorSizeAndOffset,
ForkJoinPool simdExecutor,
ForkJoinPool parallelExecutor)
{
return simdExecutor.submit(() -> IntStream.range(0, M).parallel().mapToObj(m -> {
float[][] subvectors = extractSubvectors(vectors, m, subvectorSizeAndOffset, parallelExecutor);
var clusterer = new KMeansPlusPlusClusterer(subvectors, CLUSTERS);
return clusterer.cluster(K_MEANS_ITERATIONS);
}).toArray(float[][][]::new)).join();
}


/** extract float[] subvectors corresponding to the m'th subspace, in parallel */
private static float[][] extractSubvectors(List<float[]> vectors, int m, int[][] subvectorSizeAndOffset, ForkJoinPool parallelExecutor) {
return parallelExecutor.submit(() -> {
return vectors.parallelStream()
.map(vector -> getSubVector(vector, m, subvectorSizeAndOffset))
.toArray(float[][]::new);
}).join();
}
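
For readers new to PQ: each training vector is cut into M contiguous slices, and codebook m is trained only on slice m. `getSubVector` itself is not shown in this diff; the sketch below illustrates the kind of slicing it implies, with the [size, offset] layout of `subvectorSizeAndOffset` treated as an assumption:

```java
// Generic illustration of per-subspace slicing; not the project's getSubVector.
static float[] sliceSubvector(float[] vector, int m, int[][] subvectorSizeAndOffset) {
    int size = subvectorSizeAndOffset[m][0];   // length of the m'th subvector (assumed layout)
    int offset = subvectorSizeAndOffset[m][1]; // starting index in the full vector (assumed layout)
    return Arrays.copyOfRange(vector, offset, offset + size);
}
```
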

static int closetCentroidIndex(float[] subvector, float[][] codebook) {
int index = 0;
float minDist = Integer.MAX_VALUE;
@@ -226,9 +226,9 @@ public float sum(float[] vector) {
}

@Override
public void divInPlace(float[] vector, float divisor) {
public void scale(float[] vector, float multiplier) {
for (int i = 0; i < vector.length; i++) {
vector[i] /= divisor;
vector[i] *= multiplier;
}
}
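
This rename is behavior-preserving at the call sites changed above: dividing each element by d becomes one reciprocal computation followed by per-element multiplies, which is typically cheaper since floating-point division has higher latency than multiplication. A quick sketch of the equivalence, not part of the diff (a `java.util.Arrays` import and the static `VectorUtil` facade used elsewhere in this PR are assumed):

```java
// Sketch: dividing by d and multiplying by 1/d agree up to float rounding;
// with d = 4 the reciprocal is exact, so the results match exactly here.
static void scaleMatchesDivision() {
    float[] a = {2f, 4f, 6f};
    float[] b = a.clone();
    float divisor = 4f;
    for (int i = 0; i < a.length; i++) {
        a[i] /= divisor;                 // old divInPlace behavior
    }
    VectorUtil.scale(b, 1.0f / divisor); // new scale-based call sites
    assert Arrays.equals(a, b);          // both are {0.5f, 1.0f, 1.5f}
}
```
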
