diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 886186fab..aedc70210 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -318,6 +318,51 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c return res; } +float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, const float* partialSums, const float* aMagnitude, int clusterCount, float bMagnitude) { + __m512 sum = _mm512_setzero_ps(); + __m512 vaMagnitude = _mm512_setzero_ps(); + int i = 0; + int limit = baseOffsetsLength - (baseOffsetsLength % 16); + __m512i indexRegister = initialIndexRegister; + __m512i scale = _mm512_set1_epi32(clusterCount); + + + for (; i < limit; i += 16) { + // Load and convert baseOffsets to integers + __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); + __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); + + indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); + // Scale the baseOffsets by the cluster count + __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); + + // Compute the offset base by multiplying 'i' with clusterCount and broadcasting to all lanes + __m512i offsetBase = _mm512_set1_epi32(i * clusterCount); + + // Calculate the final convOffsets by adding the scaled offsets and the offset base + __m512i convOffsets = _mm512_add_epi32(scaledOffsets, offsetBase); + + // Gather and sum values for partial sums and a magnitude + __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); + sum = _mm512_add_ps(sum, partialSumVals); + + __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4); + vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals); + } + + // Reduce sums + float sumResult = _mm512_reduce_add_ps(sum); + float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude); + + // Handle the remaining elements + for (; i < baseOffsetsLength; i++) { + int offset = clusterCount * i + baseOffsets[i]; + sumResult += partialSums[offset]; + aMagnitudeResult += aMagnitude[offset]; + } + + return sumResult / sqrtf(aMagnitudeResult * bMagnitude); +} void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index a5410ef5f..1b96a0a8e 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results); void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results); float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength); +float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude); void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances);