diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java index 573c00b5..7d4ee04f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java @@ -181,26 +181,26 @@ private static VectorFloat chooseInitialCentroids(VectorFloat[] points, in float[] distances = new float[points.length]; Arrays.fill(distances, Float.MAX_VALUE); + VectorFloat distancesVector = vectorTypeSupport.createFloatVector(distances); + int distancesLength = points.length; // Choose the first centroid randomly VectorFloat firstCentroid = points[random.nextInt(points.length)]; centroids.copyFrom(firstCentroid, 0, 0, firstCentroid.length()); + VectorFloat newDistancesVector = vectorTypeSupport.createFloatVector(points.length); for (int i = 0; i < points.length; i++) { - float distance1 = squareL2Distance(points[i], firstCentroid); - distances[i] = Math.min(distances[i], distance1); + newDistancesVector.set(i, squareL2Distance(points[i], firstCentroid)); } + VectorUtil.minInPlace(distancesVector, newDistancesVector); // For each subsequent centroid for (int i = 1; i < k; i++) { - float totalDistance = 0; - for (float distance : distances) { - totalDistance += distance; - } + float totalDistance = VectorUtil.sum(distancesVector); float r = random.nextFloat() * totalDistance; int selectedIdx = -1; - for (int j = 0; j < distances.length; j++) { - r -= distances[j]; + for (int j = 0; j < distancesLength; j++) { + r -= distancesVector.get(j); if (r < 1e-6) { selectedIdx = j; break; @@ -215,10 +215,11 @@ private static VectorFloat chooseInitialCentroids(VectorFloat[] points, in centroids.copyFrom(nextCentroid, 0, i * nextCentroid.length(), nextCentroid.length()); // Update distances, but only if the new centroid 
provides a closer distance - for (int j = 0; j < points.length; j++) { - float newDistance = squareL2Distance(points[j], nextCentroid); - distances[j] = Math.min(distances[j], newDistance); + newDistancesVector.zero(); + for (int j = 0; j < distancesLength; j++) { + newDistancesVector.set(j, squareL2Distance(points[j], nextCentroid)); } + VectorUtil.minInPlace(distancesVector, newDistancesVector); } assertFinite(centroids); return centroids; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java index db169766..20313c3c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java @@ -541,4 +541,11 @@ public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValu return squaredSum; } + + @Override + public void minInPlace(VectorFloat v1, VectorFloat v2) { + for (int i = 0; i < v1.length(); i++) { + v1.set(i, Math.min(v1.get(i), v2.get(i))); + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index 0b847a34..875c78c4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -238,4 +238,8 @@ public static float nvqLoss(VectorFloat vector, float growthRate, float midpo public static float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { return impl.nvqUniformLoss(vector, minValue, maxValue, nBits); } + + public static void minInPlace(VectorFloat distances1, VectorFloat distances2) { + impl.minInPlace(distances1, distances2); + } } diff --git 
a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index 2aa3be5d..e8d9d7ba 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -300,4 +300,7 @@ default float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCoun * @param nBits the number of bits per dimension */ float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits); + + /** Computes the lane-wise minimum of v1 and v2 in place: each element of v1 is replaced by the smaller of itself and the corresponding element of v2, while v2 is only read. The vectors are expected to have the same length; the SIMD-backed implementations (VectorSimdOps, SimdOps) throw IllegalArgumentException otherwise. */ + void minInPlace(VectorFloat v1, VectorFloat v2); } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index a0c070cf..2db601fe 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -225,4 +225,9 @@ public float nvqLoss(VectorFloat vector, float growthRate, float midpoint, fl public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { return VectorSimdOps.nvqUniformLoss((MemorySegmentVectorFloat) vector, minValue, maxValue, nBits); } + + @Override + public void minInPlace(VectorFloat v1, VectorFloat v2) { + VectorSimdOps.minInPlace((MemorySegmentVectorFloat) v1, (MemorySegmentVectorFloat) v2); + } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java index 37b99909..c6f01074 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java +++ 
b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java @@ -598,6 +598,26 @@ public static int hammingDistance(long[] a, long[] b) { return res; } + static void minInPlace(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v2) { + if (v1.length() != v2.length()) { + throw new IllegalArgumentException("Vectors must have the same length"); + } + + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length()); + + // Process the vectorized part: whole SPECIES_PREFERRED-sized chunks, with the lane-wise min written back into v1 + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var a = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN); + var b = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v2.get(), v2.offset(i), ByteOrder.LITTLE_ENDIAN); + a.min(b).intoMemorySegment(v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN); + } + + // Process the tail: the remaining elements past the loop bound, one element at a time + for (int i = vectorizedLength; i < v1.length(); i++) { + v1.set(i, Math.min(v1.get(i), v2.get(i))); + } + } + public static float max(MemorySegmentVectorFloat vector) { var accum = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, -Float.MAX_VALUE); int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index f18266e5..40ee56cf 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -77,6 +77,11 @@ public void addInPlace(VectorFloat v1, float value) { SimdOps.addInPlace((ArrayVectorFloat)v1, value); } + @Override + public void minInPlace(VectorFloat v1, VectorFloat v2) { + SimdOps.minInPlace((ArrayVectorFloat)v1, (ArrayVectorFloat)v2); + } + @Override public void subInPlace(VectorFloat 
v1, VectorFloat v2) { SimdOps.subInPlace((ArrayVectorFloat) v1, (ArrayVectorFloat) v2); diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index dd053c3d..0263960a 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -538,6 +538,26 @@ static void addInPlace(ArrayVectorFloat v1, float value) { } } + static void minInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) { + if (v1.length() != v2.length()) { + throw new IllegalArgumentException("Vectors must have the same length"); + } + + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length()); + + // Process the vectorized part: whole SPECIES_PREFERRED-sized chunks, with the lane-wise min written back into v1 + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1.get(), i); + var b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2.get(), i); + a.min(b).intoArray(v1.get(), i); + } + + // Process the tail: the remaining elements past the loop bound, one element at a time + for (int i = vectorizedLength; i < v1.length(); i++) { + v1.set(i, Math.min(v1.get(i), v2.get(i))); + } + } + static void subInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) { if (v1.length() != v2.length()) { throw new IllegalArgumentException("Vectors must have the same length");