Skip to content

Commit

Permalink
Optimization for latency reduction in Product Quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
AbhijitKulkarni1 committed Feb 28, 2025
1 parent 1f6c5b1 commit 5a97b43
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -181,26 +181,26 @@ private static VectorFloat<?> chooseInitialCentroids(VectorFloat<?>[] points, in

float[] distances = new float[points.length];
Arrays.fill(distances, Float.MAX_VALUE);
VectorFloat<?> distancesVector = vectorTypeSupport.createFloatVector(distances);
int distancesLength = points.length;

// Choose the first centroid randomly
VectorFloat<?> firstCentroid = points[random.nextInt(points.length)];
centroids.copyFrom(firstCentroid, 0, 0, firstCentroid.length());
VectorFloat<?> newDistancesVector = vectorTypeSupport.createFloatVector(points.length);
for (int i = 0; i < points.length; i++) {
float distance1 = squareL2Distance(points[i], firstCentroid);
distances[i] = Math.min(distances[i], distance1);
newDistancesVector.set(i, squareL2Distance(points[i], firstCentroid));
}
VectorUtil.minInPlace(distancesVector, newDistancesVector);

// For each subsequent centroid
for (int i = 1; i < k; i++) {
float totalDistance = 0;
for (float distance : distances) {
totalDistance += distance;
}
float totalDistance = VectorUtil.sum(distancesVector);

float r = random.nextFloat() * totalDistance;
int selectedIdx = -1;
for (int j = 0; j < distances.length; j++) {
r -= distances[j];
for (int j = 0; j < distancesLength; j++) {
r -= distancesVector.get(j);
if (r < 1e-6) {
selectedIdx = j;
break;
Expand All @@ -215,10 +215,11 @@ private static VectorFloat<?> chooseInitialCentroids(VectorFloat<?>[] points, in
centroids.copyFrom(nextCentroid, 0, i * nextCentroid.length(), nextCentroid.length());

// Update distances, but only if the new centroid provides a closer distance
for (int j = 0; j < points.length; j++) {
float newDistance = squareL2Distance(points[j], nextCentroid);
distances[j] = Math.min(distances[j], newDistance);
newDistancesVector.zero();
for (int j = 0; j < distancesLength; j++) {
newDistancesVector.set(j, squareL2Distance(points[j], nextCentroid));
}
VectorUtil.minInPlace(distancesVector, newDistancesVector);
}
assertFinite(centroids);
return centroids;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,4 +541,11 @@ public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValu

return squaredSum;
}

@Override
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
for (int i = 0; i < v1.length(); i++) {
v1.set(i, Math.min(v1.get(i), v2.get(i)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,8 @@ public static float nvqLoss(VectorFloat<?> vector, float growthRate, float midpo
public static float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
return impl.nvqUniformLoss(vector, minValue, maxValue, nBits);
}

public static void minInPlace(VectorFloat<?> distances1, VectorFloat<?> distances2) {
impl.minInPlace(distances1, distances2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,7 @@ default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCoun
* @param nBits the number of bits per dimension
*/
float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits);

/** Calculates the minimum value for every corresponding lane values in v1 and v2, in place (v1 will be modified) */
void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,9 @@ public float nvqLoss(VectorFloat<?> vector, float growthRate, float midpoint, fl
public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
return VectorSimdOps.nvqUniformLoss((MemorySegmentVectorFloat) vector, minValue, maxValue, nBits);
}

@Override
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
VectorSimdOps.minInPlace((MemorySegmentVectorFloat) v1, (MemorySegmentVectorFloat) v2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,26 @@ public static int hammingDistance(long[] a, long[] b) {
return res;
}

static void minInPlace(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v2) {
if (v1.length() != v2.length()) {
throw new IllegalArgumentException("Vectors must have the same length");
}

int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length());

// Process the vectorized part
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
var a = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN);
var b = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v2.get(), v2.offset(i), ByteOrder.LITTLE_ENDIAN);
a.min(b).intoMemorySegment(v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN);
}

// Process the tail
for (int i = vectorizedLength; i < v1.length(); i++) {
v1.set(i, Math.min(v1.get(i), v2.get(i)));
}
}

public static float max(MemorySegmentVectorFloat vector) {
var accum = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, -Float.MAX_VALUE);
int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ public void addInPlace(VectorFloat<?> v1, float value) {
SimdOps.addInPlace((ArrayVectorFloat)v1, value);
}

@Override
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
SimdOps.minInPlace((ArrayVectorFloat)v1, (ArrayVectorFloat)v2);
}

@Override
public void subInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
SimdOps.subInPlace((ArrayVectorFloat) v1, (ArrayVectorFloat) v2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,26 @@ static void addInPlace(ArrayVectorFloat v1, float value) {
}
}

static void minInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) {
if (v1.length() != v2.length()) {
throw new IllegalArgumentException("Vectors must have the same length");
}

int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length());

// Process the vectorized part
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1.get(), i);
var b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2.get(), i);
a.min(b).intoArray(v1.get(), i);
}

// Process the tail
for (int i = vectorizedLength; i < v1.length(); i++) {
v1.set(i, Math.min(v1.get(i), v2.get(i)));
}
}

static void subInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) {
if (v1.length() != v2.length()) {
throw new IllegalArgumentException("Vectors must have the same length");
Expand Down

0 comments on commit 5a97b43

Please sign in to comment.