Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization for latency reduction in Product Quantization #397

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -181,26 +181,26 @@ private static VectorFloat<?> chooseInitialCentroids(VectorFloat<?>[] points, in

float[] distances = new float[points.length];
Arrays.fill(distances, Float.MAX_VALUE);
VectorFloat<?> distancesVector = vectorTypeSupport.createFloatVector(distances);
int distancesLength = points.length;

// Choose the first centroid randomly
VectorFloat<?> firstCentroid = points[random.nextInt(points.length)];
centroids.copyFrom(firstCentroid, 0, 0, firstCentroid.length());
VectorFloat<?> newDistancesVector = vectorTypeSupport.createFloatVector(points.length);
for (int i = 0; i < points.length; i++) {
float distance1 = squareL2Distance(points[i], firstCentroid);
distances[i] = Math.min(distances[i], distance1);
newDistancesVector.set(i, squareL2Distance(points[i], firstCentroid));
}
VectorUtil.minInPlace(distancesVector, newDistancesVector);

// For each subsequent centroid
for (int i = 1; i < k; i++) {
float totalDistance = 0;
for (float distance : distances) {
totalDistance += distance;
}
float totalDistance = VectorUtil.sum(distancesVector);

float r = random.nextFloat() * totalDistance;
int selectedIdx = -1;
for (int j = 0; j < distances.length; j++) {
r -= distances[j];
for (int j = 0; j < distancesLength; j++) {
r -= distancesVector.get(j);
if (r < 1e-6) {
selectedIdx = j;
break;
Expand All @@ -215,10 +215,11 @@ private static VectorFloat<?> chooseInitialCentroids(VectorFloat<?>[] points, in
centroids.copyFrom(nextCentroid, 0, i * nextCentroid.length(), nextCentroid.length());

// Update distances, but only if the new centroid provides a closer distance
for (int j = 0; j < points.length; j++) {
float newDistance = squareL2Distance(points[j], nextCentroid);
distances[j] = Math.min(distances[j], newDistance);
newDistancesVector.zero();
for (int j = 0; j < distancesLength; j++) {
newDistancesVector.set(j, squareL2Distance(points[j], nextCentroid));
}
VectorUtil.minInPlace(distancesVector, newDistancesVector);
}
assertFinite(centroids);
return centroids;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,4 +541,11 @@ public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValu

return squaredSum;
}

@Override
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
for (int i = 0; i < v1.length(); i++) {
v1.set(i, Math.min(v1.get(i), v2.get(i)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,8 @@ public static float nvqLoss(VectorFloat<?> vector, float growthRate, float midpo
public static float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
return impl.nvqUniformLoss(vector, minValue, maxValue, nBits);
}

public static void minInPlace(VectorFloat<?> distances1, VectorFloat<?> distances2) {
impl.minInPlace(distances1, distances2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -300,4 +300,7 @@ default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCoun
* @param nBits the number of bits per dimension
*/
float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits);

/** Calculates the minimum value for every corresponding lane values in v1 and v2, in place (v1 will be modified) */
void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,9 @@ public float nvqLoss(VectorFloat<?> vector, float growthRate, float midpoint, fl
public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValue, int nBits) {
return VectorSimdOps.nvqUniformLoss((MemorySegmentVectorFloat) vector, minValue, maxValue, nBits);
}

@Override
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
VectorSimdOps.minInPlace((MemorySegmentVectorFloat) v1, (MemorySegmentVectorFloat) v2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,26 @@ public static int hammingDistance(long[] a, long[] b) {
return res;
}

static void minInPlace(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v2) {
if (v1.length() != v2.length()) {
throw new IllegalArgumentException("Vectors must have the same length");
}

int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length());

// Process the vectorized part
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
var a = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN);
var b = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v2.get(), v2.offset(i), ByteOrder.LITTLE_ENDIAN);
a.min(b).intoMemorySegment(v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN);
}

// Process the tail
for (int i = vectorizedLength; i < v1.length(); i++) {
v1.set(i, Math.min(v1.get(i), v2.get(i)));
}
}

public static float max(MemorySegmentVectorFloat vector) {
var accum = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, -Float.MAX_VALUE);
int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ public void addInPlace(VectorFloat<?> v1, float value) {
SimdOps.addInPlace((ArrayVectorFloat)v1, value);
}

@Override
public void minInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
SimdOps.minInPlace((ArrayVectorFloat)v1, (ArrayVectorFloat)v2);
}

@Override
public void subInPlace(VectorFloat<?> v1, VectorFloat<?> v2) {
SimdOps.subInPlace((ArrayVectorFloat) v1, (ArrayVectorFloat) v2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,26 @@ static void addInPlace(ArrayVectorFloat v1, float value) {
}
}

static void minInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) {
if (v1.length() != v2.length()) {
throw new IllegalArgumentException("Vectors must have the same length");
}

int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length());

// Process the vectorized part
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1.get(), i);
var b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2.get(), i);
a.min(b).intoArray(v1.get(), i);
}

// Process the tail
for (int i = vectorizedLength; i < v1.length(); i++) {
v1.set(i, Math.min(v1.get(i), v2.get(i)));
}
}

static void subInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) {
if (v1.length() != v2.length()) {
throw new IllegalArgumentException("Vectors must have the same length");
Expand Down