Skip to content

Commit

Permalink
Add 'pq' prefix to all forms of decodedCosineSimilarity methods
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeljmarshall committed Nov 21, 2024
1 parent 3d79217 commit f99c746
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ protected float decodedCosine(int node2) {

ByteSequence<?> encoded = cv.get(node2);

return VectorUtil.decodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude);
return VectorUtil.pqDecodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ public static float min(VectorFloat<?> v) {
return impl.min(v);
}

public static float decodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
return impl.decodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
public static float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int
float max(VectorFloat<?> v);
float min(VectorFloat<?> v);

default float decodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
float sum = 0.0f;
float aMag = 0.0f;
Expand Down
2 changes: 1 addition & 1 deletion jvector-native/src/main/c/jvector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c
return res;
}

float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
__m512 sum = _mm512_setzero_ps();
__m512 vaMagnitude = _mm512_setzero_ps();
int i = 0;
Expand Down
2 changes: 1 addition & 1 deletion jvector-native/src/main/c/jvector_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb
void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results);
void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results);
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength);
float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int c
}

@Override
public float decodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
return NativeSimdOps.decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude);
return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, M
}
}

private static class decoded_cosine_similarity_f32_512 {
private static class pq_decoded_cosine_similarity_f32_512 {
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
NativeSimdOps.C_FLOAT,
NativeSimdOps.C_POINTER,
Expand All @@ -464,39 +464,39 @@ private static class decoded_cosine_similarity_f32_512 {
);

public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(
NativeSimdOps.findOrThrow("decoded_cosine_similarity_f32_512"),
NativeSimdOps.findOrThrow("pq_decoded_cosine_similarity_f32_512"),
DESC, Linker.Option.critical(true));
}

/**
* Function descriptor for:
* {@snippet lang=c :
* float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static FunctionDescriptor decoded_cosine_similarity_f32_512$descriptor() {
return decoded_cosine_similarity_f32_512.DESC;
public static FunctionDescriptor pq_decoded_cosine_similarity_f32_512$descriptor() {
return pq_decoded_cosine_similarity_f32_512.DESC;
}

/**
* Downcall method handle for:
* {@snippet lang=c :
* float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static MethodHandle decoded_cosine_similarity_f32_512$handle() {
return decoded_cosine_similarity_f32_512.HANDLE;
public static MethodHandle pq_decoded_cosine_similarity_f32_512$handle() {
return pq_decoded_cosine_similarity_f32_512.HANDLE;
}
/**
* {@snippet lang=c :
* float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static float decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) {
var mh$ = decoded_cosine_similarity_f32_512.HANDLE;
public static float pq_decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) {
var mh$ = pq_decoded_cosine_similarity_f32_512.HANDLE;
try {
if (TRACE_DOWNCALLS) {
traceDowncall("decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
traceDowncall("pq_decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
}
return (float)mh$.invokeExact(baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
} catch (Throwable ex$) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ public void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?
}

@Override
public float decodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
return SimdOps.decodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
return SimdOps.pqDecodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -660,13 +660,13 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra
}
}

public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
public static float pqDecodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
return HAS_AVX512
? decodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude)
: decodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
? pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude)
: pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
}

public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
var sum = FloatVector.zero(FloatVector.SPECIES_512);
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512);
var baseOffsets = encoded.get();
Expand Down Expand Up @@ -705,7 +705,7 @@ public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int cl
return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
}

public static float decodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
var sum = FloatVector.zero(FloatVector.SPECIES_256);
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
var baseOffsets = encoded.get();
Expand Down

0 comments on commit f99c746

Please sign in to comment.