Skip to content

Commit

Permalink
fix issue with avx2 reduce step
Browse files Browse the repository at this point in the history
  • Loading branch information
Kazak Sergey committed Sep 19, 2023
1 parent 68441b9 commit 7cb794d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/packed_index_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ void basic_packutils_test()
__attribute__ ((aligned (64))) uint16_t packed_avx32[512] = {0};
__attribute__ ((aligned (64))) float unpacked_avx32[512] = {0};
#endif
#if defined(USE_AVX2)
#if defined(USE_AVX2)
__attribute__ ((aligned (16))) uint16_t packed_avx16[512] = {0};
__attribute__ ((aligned (16))) float unpacked_avx16[512] = {0};
#endif
Expand Down
5 changes: 2 additions & 3 deletions src/packedlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,11 @@ class PackedAnnoySearcher {
return true;
}
};
typedef std::pair<T, S> qpair_t;
typedef std::vector<qpair_t, RAlloc<qpair_t>> queue_t;
private:
static constexpr S const _K_mask = S(1UL) << S(sizeof(S) * 8 - 1);
static constexpr S const _K_mask_clear = _K_mask - 1;
protected:
typedef std::pair<T, S> qpair_t;
typedef std::vector<qpair_t, RAlloc<qpair_t>> queue_t;
protected:
S _f;
S _s; // Size of each node
Expand Down
21 changes: 9 additions & 12 deletions src/packutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,13 @@ inline float decode_and_euclidean_distance_i16_f32_avx32( uint16_t const *__rest

#if defined(USE_AVX2)

// Horizontal add: returns the sum of all eight single-precision lanes of `x`.
//
// The reduction halves the number of live partial sums at each step
// (8 -> 4 -> 2 -> 1) and reads the final result from lane 0.
inline float _mm256_reduce_add_ps(__m256 x) {
    // 8 -> 4: add the upper 128-bit half onto the lower 128-bit half.
    const __m128 upper_half = _mm256_extractf128_ps(x, 1);
    const __m128 lower_half = _mm256_castps256_ps128(x);
    const __m128 sum4 = _mm_add_ps(upper_half, lower_half);

    // 4 -> 2: movehl copies lanes 2,3 into lanes 0,1, so after the add
    // lane 0 holds (s0 + s2) and lane 1 holds (s1 + s3).
    const __m128 high_pair = _mm_movehl_ps(sum4, sum4);
    const __m128 sum2 = _mm_add_ps(sum4, high_pair);

    // 2 -> 1: broadcast lane 1 (shuffle mask 0x55 selects lane 1 everywhere)
    // and add it to lane 0 with a scalar add.
    const __m128 lane1 = _mm_shuffle_ps(sum2, sum2, 0x55);
    const __m128 total = _mm_add_ss(sum2, lane1);

    return _mm_cvtss_f32(total);
}

inline void pack_float_vector_i16_avx16( float const *__restrict__ x, uint16_t *__restrict__ out, uint32_t d )
{
__m256 mm1 = _mm256_set1_ps(32767.f);
Expand Down Expand Up @@ -290,12 +297,7 @@ inline float decode_and_dot_i16_f32_avx16( uint16_t const *__restrict__ in, floa
d -= 16;
}
msum1 = _mm256_add_ps(msum1, msum2);
// now sum of 8
msum1 = _mm256_hadd_ps (msum1, msum1);
msum1 = _mm256_hadd_ps (msum1, msum1);
// now 0 and 4 left
sum = _mm_cvtss_f32 (_mm256_castps256_ps128(msum1)) +
_mm_cvtss_f32 (_mm256_extractf128_ps(msum1, 1));
sum = _mm256_reduce_add_ps(msum1);

if( d )
{
Expand Down Expand Up @@ -345,12 +347,7 @@ inline float decode_and_euclidean_distance_i16_f32_avx16( uint16_t const *__rest
d -= 16;
}
msum1 = _mm256_add_ps(msum1, msum2);
// now sum of 8
msum1 = _mm256_hadd_ps (msum1, msum1);
msum1 = _mm256_hadd_ps (msum1, msum1);
// now 0 and 4 left
sum = _mm_cvtss_f32 (_mm256_castps256_ps128(msum1)) +
_mm_cvtss_f32 (_mm256_extractf128_ps(msum1, 1));
sum = _mm256_reduce_add_ps(msum1);

// here can be 0/8 left, so do check and calc tail if exists
if( d )
Expand Down

0 comments on commit 7cb794d

Please sign in to comment.