From 52cebb1f4513b9eafca832da559513670de0a295 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tarek=20Ziad=C3=A9?= Date: Mon, 2 Dec 2024 10:09:49 +0100 Subject: [PATCH] added a firefox matmul backend --- build.sh | 15 +- cmake/onnxruntime_webassembly.cmake | 1 + include/onnxruntime/gemmology.h | 1390 ++++++++ include/onnxruntime/gemmology_fwd.h | 282 ++ .../arch/generic/xsimd_generic_arithmetic.hpp | 241 ++ .../arch/generic/xsimd_generic_complex.hpp | 108 + .../arch/generic/xsimd_generic_details.hpp | 316 ++ .../arch/generic/xsimd_generic_logical.hpp | 208 ++ .../xsimd/arch/generic/xsimd_generic_math.hpp | 2499 +++++++++++++++ .../arch/generic/xsimd_generic_memory.hpp | 672 ++++ .../arch/generic/xsimd_generic_rounding.hpp | 72 + .../arch/generic/xsimd_generic_trigo.hpp | 969 ++++++ include/onnxruntime/xsimd/arch/xsimd_avx.hpp | 1820 +++++++++++ include/onnxruntime/xsimd/arch/xsimd_avx2.hpp | 1021 ++++++ .../onnxruntime/xsimd/arch/xsimd_avx512bw.hpp | 701 ++++ .../onnxruntime/xsimd/arch/xsimd_avx512cd.hpp | 28 + .../onnxruntime/xsimd/arch/xsimd_avx512dq.hpp | 212 ++ .../onnxruntime/xsimd/arch/xsimd_avx512er.hpp | 20 + .../onnxruntime/xsimd/arch/xsimd_avx512f.hpp | 2167 +++++++++++++ .../xsimd/arch/xsimd_avx512ifma.hpp | 20 + .../onnxruntime/xsimd/arch/xsimd_avx512pf.hpp | 20 + .../xsimd/arch/xsimd_avx512vbmi.hpp | 20 + .../xsimd/arch/xsimd_avx512vnni_avx512bw.hpp | 20 + .../arch/xsimd_avx512vnni_avx512vbmi.hpp | 20 + .../onnxruntime/xsimd/arch/xsimd_avxvnni.hpp | 20 + .../xsimd/arch/xsimd_constants.hpp | 391 +++ .../onnxruntime/xsimd/arch/xsimd_emulated.hpp | 771 +++++ .../onnxruntime/xsimd/arch/xsimd_fma3_avx.hpp | 80 + .../xsimd/arch/xsimd_fma3_avx2.hpp | 46 + .../onnxruntime/xsimd/arch/xsimd_fma3_sse.hpp | 79 + include/onnxruntime/xsimd/arch/xsimd_fma4.hpp | 79 + .../onnxruntime/xsimd/arch/xsimd_generic.hpp | 23 + .../xsimd/arch/xsimd_generic_fwd.hpp | 44 + .../xsimd/arch/xsimd_i8mm_neon64.hpp | 17 + include/onnxruntime/xsimd/arch/xsimd_isa.hpp | 130 + include/onnxruntime/xsimd/arch/xsimd_neon.hpp | 2813 +++++++++++++++++ .../onnxruntime/xsimd/arch/xsimd_neon64.hpp | 1536 +++++++++ include/onnxruntime/xsimd/arch/xsimd_rvv.hpp | 1500 +++++++++ .../onnxruntime/xsimd/arch/xsimd_scalar.hpp | 1223 +++++++ include/onnxruntime/xsimd/arch/xsimd_sse2.hpp | 1763 +++++++++++ include/onnxruntime/xsimd/arch/xsimd_sse3.hpp | 64 + .../onnxruntime/xsimd/arch/xsimd_sse4_1.hpp | 339 ++ .../onnxruntime/xsimd/arch/xsimd_sse4_2.hpp | 44 + .../onnxruntime/xsimd/arch/xsimd_ssse3.hpp | 175 + include/onnxruntime/xsimd/arch/xsimd_sve.hpp | 1148 +++++++ include/onnxruntime/xsimd/arch/xsimd_wasm.hpp | 1703 ++++++++++ .../onnxruntime/xsimd/config/xsimd_arch.hpp | 238 ++ .../onnxruntime/xsimd/config/xsimd_config.hpp | 462 +++ .../onnxruntime/xsimd/config/xsimd_cpuid.hpp | 262 ++ .../onnxruntime/xsimd/config/xsimd_inline.hpp | 23 + .../onnxruntime/xsimd/math/xsimd_rem_pio2.hpp | 719 +++++ .../xsimd/memory/xsimd_aligned_allocator.hpp | 349 ++ .../xsimd/memory/xsimd_alignment.hpp | 91 + .../xsimd/types/xsimd_all_registers.hpp | 52 + include/onnxruntime/xsimd/types/xsimd_api.hpp | 2700 ++++++++++++++++ .../xsimd/types/xsimd_avx2_register.hpp | 39 + .../xsimd/types/xsimd_avx512bw_register.hpp | 47 + .../xsimd/types/xsimd_avx512cd_register.hpp | 47 + .../xsimd/types/xsimd_avx512dq_register.hpp | 47 + .../xsimd/types/xsimd_avx512er_register.hpp | 47 + .../xsimd/types/xsimd_avx512f_register.hpp | 73 + .../xsimd/types/xsimd_avx512ifma_register.hpp | 47 + .../xsimd/types/xsimd_avx512pf_register.hpp | 47 + 
.../xsimd/types/xsimd_avx512vbmi_register.hpp | 47 + .../xsimd_avx512vnni_avx512bw_register.hpp | 50 + .../xsimd_avx512vnni_avx512vbmi_register.hpp | 50 + .../xsimd/types/xsimd_avx_register.hpp | 60 + .../xsimd/types/xsimd_avxvnni_register.hpp | 39 + .../onnxruntime/xsimd/types/xsimd_batch.hpp | 1492 +++++++++ .../xsimd/types/xsimd_batch_constant.hpp | 300 ++ .../xsimd/types/xsimd_emulated_register.hpp | 80 + .../xsimd/types/xsimd_fma3_avx2_register.hpp | 45 + .../xsimd/types/xsimd_fma3_avx_register.hpp | 45 + .../xsimd/types/xsimd_fma3_sse_register.hpp | 45 + .../xsimd/types/xsimd_fma4_register.hpp | 41 + .../xsimd/types/xsimd_generic_arch.hpp | 47 + .../types/xsimd_i8mm_neon64_register.hpp | 50 + .../xsimd/types/xsimd_neon64_register.hpp | 51 + .../xsimd/types/xsimd_neon_register.hpp | 154 + .../xsimd/types/xsimd_register.hpp | 94 + .../xsimd/types/xsimd_rvv_register.hpp | 497 +++ .../xsimd/types/xsimd_sse2_register.hpp | 59 + .../xsimd/types/xsimd_sse3_register.hpp | 44 + .../xsimd/types/xsimd_sse4_1_register.hpp | 43 + .../xsimd/types/xsimd_sse4_2_register.hpp | 43 + .../xsimd/types/xsimd_ssse3_register.hpp | 43 + .../xsimd/types/xsimd_sve_register.hpp | 156 + .../onnxruntime/xsimd/types/xsimd_traits.hpp | 324 ++ .../onnxruntime/xsimd/types/xsimd_utils.hpp | 530 ++++ .../xsimd/types/xsimd_wasm_register.hpp | 59 + include/onnxruntime/xsimd/xsimd.hpp | 69 + .../contrib_ops/cpu/cpu_contrib_kernels.cc | 6 +- .../quantization/firefox_matmul_integer.cc | 236 ++ .../cpu/quantization/firefox_matmul_integer.h | 309 ++ onnxruntime/core/framework/session_state.cc | 3 + .../core/graph/contrib_ops/contrib_defs.cc | 51 + onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../firefox_matmul_integer_test.cc | 50 + .../test/framework/inference_session_test.cc | 3 + onnxruntime/wasm/pre-jsep.js | 60 +- onnxruntime/wasm/pre.js | 91 +- 101 files changed, 37449 insertions(+), 39 deletions(-) create mode 100644 include/onnxruntime/gemmology.h create mode 100644 include/onnxruntime/gemmology_fwd.h create mode 100644 include/onnxruntime/xsimd/arch/generic/xsimd_generic_arithmetic.hpp create mode 100644 include/onnxruntime/xsimd/arch/generic/xsimd_generic_complex.hpp create mode 100644 include/onnxruntime/xsimd/arch/generic/xsimd_generic_details.hpp create mode 100644 include/onnxruntime/xsimd/arch/generic/xsimd_generic_logical.hpp create mode 100644 include/onnxruntime/xsimd/arch/generic/xsimd_generic_math.hpp create mode 100644 include/onnxruntime/xsimd/arch/generic/xsimd_generic_memory.hpp create mode 100644 include/onnxruntime/xsimd/arch/generic/xsimd_generic_rounding.hpp create mode 100644 include/onnxruntime/xsimd/arch/generic/xsimd_generic_trigo.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx2.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512bw.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512cd.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512dq.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512er.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512f.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512ifma.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512pf.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512vbmi.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512vnni_avx512bw.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_avx512vnni_avx512vbmi.hpp create mode 100644 
include/onnxruntime/xsimd/arch/xsimd_avxvnni.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_constants.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_emulated.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_fma3_avx.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_fma3_avx2.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_fma3_sse.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_fma4.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_generic.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_generic_fwd.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_i8mm_neon64.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_isa.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_neon.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_neon64.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_rvv.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_scalar.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_sse2.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_sse3.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_sse4_1.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_sse4_2.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_ssse3.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_sve.hpp create mode 100644 include/onnxruntime/xsimd/arch/xsimd_wasm.hpp create mode 100644 include/onnxruntime/xsimd/config/xsimd_arch.hpp create mode 100644 include/onnxruntime/xsimd/config/xsimd_config.hpp create mode 100644 include/onnxruntime/xsimd/config/xsimd_cpuid.hpp create mode 100644 include/onnxruntime/xsimd/config/xsimd_inline.hpp create mode 100644 include/onnxruntime/xsimd/math/xsimd_rem_pio2.hpp create mode 100644 include/onnxruntime/xsimd/memory/xsimd_aligned_allocator.hpp create mode 100644 include/onnxruntime/xsimd/memory/xsimd_alignment.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_all_registers.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_api.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx2_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512bw_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512cd_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512dq_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512er_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512f_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512ifma_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512pf_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512vbmi_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512vnni_avx512bw_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx512vnni_avx512vbmi_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avx_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_avxvnni_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_batch.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_batch_constant.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_emulated_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_fma3_avx2_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_fma3_avx_register.hpp create mode 100644 
include/onnxruntime/xsimd/types/xsimd_fma3_sse_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_fma4_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_generic_arch.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_i8mm_neon64_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_neon64_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_neon_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_rvv_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_sse2_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_sse3_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_sse4_1_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_sse4_2_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_ssse3_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_sve_register.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_traits.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_utils.hpp create mode 100644 include/onnxruntime/xsimd/types/xsimd_wasm_register.hpp create mode 100644 include/onnxruntime/xsimd/xsimd.hpp create mode 100644 onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h create mode 100644 onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc diff --git a/build.sh b/build.sh index bf799ac8b7211..0b293effe6330 100755 --- a/build.sh +++ b/build.sh @@ -1,21 +1,24 @@ #!/bin/bash # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
+set -ex # Get directory this script is in -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" OS=$(uname -s) if [ "$OS" = "Darwin" ]; then - DIR_OS="MacOS" + DIR_OS="MacOS" else - DIR_OS="Linux" + DIR_OS="Linux" fi if [[ "$*" == *"--ios"* ]]; then - DIR_OS="iOS" + DIR_OS="iOS" elif [[ "$*" == *"--android"* ]]; then - DIR_OS="Android" + DIR_OS="Android" fi -python3 $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@" +PYTHON="${PYTHON:-python3}" + +$PYTHON $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@" diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 66268cefac9ef..3a5575b163b35 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -382,6 +382,7 @@ jsepDownload:_pp_") "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" + "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() diff --git a/include/onnxruntime/gemmology.h b/include/onnxruntime/gemmology.h new file mode 100644 index 0000000000000..332afe166870d --- /dev/null +++ b/include/onnxruntime/gemmology.h @@ -0,0 +1,1390 @@ +#ifndef GEMMOLOGY_H +#define GEMMOLOGY_H + +#include "gemmology_fwd.h" + +#include +#include +#include + +#ifdef GEMMOLOGY_WITH_STD_THREAD +#include +#include +#endif + +#include "xsimd/xsimd.hpp" + +namespace gemmology { + +namespace { + +// +// Arch specific implementation of various elementary operations +// + +namespace kernel { + +#ifdef __AVX512BW__ +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm512_unpacklo_epi8(first, second), + _mm512_unpackhi_epi8(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm512_unpacklo_epi16(first, second), + _mm512_unpackhi_epi16(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm512_unpacklo_epi32(first, second), + _mm512_unpackhi_epi32(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm512_unpacklo_epi64(first, second), + _mm512_unpackhi_epi64(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm512_packs_epi16(first, second); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm512_packs_epi32(first, second); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm512_madd_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm512_maddubs_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm512_madd_epi16(x, y); +} + +template +inline xsimd::batch +PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567, + xsimd::kernel::requires_arch) { + // Form [0th 128-bit register of pack0123, 0st 128-bit 
register of pack4567, + // 2nd 128-bit register of pack0123, 2nd 128-bit register of pack4567] + __m512i mix0 = + _mm512_mask_permutex_epi64(pack0123, 0xcc, pack4567, (0 << 4) | (1 << 6)); + // Form [1st 128-bit register of pack0123, 1st 128-bit register of pack4567, + // 3rd 128-bit register of pack0123, 3rd 128-bit register of pack4567] + __m512i mix1 = + _mm512_mask_permutex_epi64(pack4567, 0x33, pack0123, 2 | (3 << 2)); + __m512i added = _mm512_add_epi32(mix0, mix1); + // Now we have 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7. + // Fold register over itself. + return _mm256_add_epi32(_mm512_castsi512_si256(added), + _mm512_extracti64x4_epi64(added, 1)); +} +#endif + +#ifdef __AVX2__ +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm256_unpacklo_epi8(first, second), + _mm256_unpackhi_epi8(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm256_unpacklo_epi16(first, second), + _mm256_unpackhi_epi16(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm256_unpacklo_epi32(first, second), + _mm256_unpackhi_epi32(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {_mm256_unpacklo_epi64(first, second), + _mm256_unpackhi_epi64(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm256_packs_epi16(first, second); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm256_packs_epi32(first, second); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm256_madd_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm256_maddubs_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm256_maddubs_epi16(xsimd::abs(x), _mm256_sign_epi8(y, x)); +} + +template +inline xsimd::batch +PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567, + xsimd::kernel::requires_arch) { + // This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f + __m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21); + // This instruction generates 1f 2f 3f 4f 5s 6s 7s 8s + __m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0); + return _mm256_add_epi32(rev, blended); +} + +template +inline xsimd::batch Pack0123(xsimd::batch sum0, + xsimd::batch sum1, + xsimd::batch sum2, + xsimd::batch sum3, + xsimd::kernel::requires_arch) { + auto pack01 = _mm256_hadd_epi32(sum0, sum1); + auto pack23 = _mm256_hadd_epi32(sum2, sum3); + return _mm256_hadd_epi32(pack01, pack23); +} + +#ifdef __AVXVNNI__ + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch) { + return _mm256_dpbusd_avx_epi32(z, x, y); +} +#endif + +#ifdef __AVX512VNNI__ + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch>) { + return _mm512_dpbusd_epi32(z, x, y); +} + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, 
+ xsimd::kernel::requires_arch>) { + return _mm512_dpbusd_epi32(z, x, y); +} +#endif + +#endif + +#ifdef __SSSE3__ + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm_maddubs_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm_maddubs_epi16(xsimd::abs(x), _mm_sign_epi8(y, x)); +} + +template +inline xsimd::batch Pack0123(xsimd::batch sum0, + xsimd::batch sum1, + xsimd::batch sum2, + xsimd::batch sum3, + xsimd::kernel::requires_arch) { + auto pack01 = _mm_hadd_epi32(sum0, sum1); + auto pack23 = _mm_hadd_epi32(sum2, sum3); + return _mm_hadd_epi32(pack01, pack23); +} +#endif + +#ifdef __SSE2__ +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm_packs_epi16(first, second); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return _mm_packs_epi32(first, second); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + return _mm_madd_epi16(x, y); +} + +template +inline xsimd::batch +madd(xsimd::batch a, xsimd::batch b, + xsimd::kernel::requires_arch) { + // Adapted from + // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2 + // a = 0x00 0x01 0xFE 0x04 ... + // b = 0x00 0x02 0x80 0x84 ... + + // To extend signed 8-bit value, MSB has to be set to 0xFF + __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128()); + + // sign_mask_b = 0x00 0x00 0xFF 0xFF ... + + // Unpack positives with 0x00, negatives with 0xFF + __m128i a_epi16_l = _mm_unpacklo_epi8(a, _mm_setzero_si128()); + __m128i a_epi16_h = _mm_unpackhi_epi8(a, _mm_setzero_si128()); + __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b); + __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b); + + // Here - valid 16-bit signed integers corresponding to the 8-bit input + // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ... + + // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts + __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l); + __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h); + + // Now go back from 32-bit values to 16-bit values & signed saturate + return _mm_packs_epi32(madd_epi32_l, madd_epi32_h); +} + +template +inline xsimd::batch +madd(xsimd::batch a, xsimd::batch b, + xsimd::kernel::requires_arch) { + // adapted + // https://stackoverflow.com/questions/19957709/how-to-achieve-8bit-madd-using-sse2 + // a = 0x00 0x01 0xFE 0x04 ... + // b = 0x00 0x02 0x80 0x84 ... 
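// --------------------------------------------------------------------------
// [Editorial illustration, not part of this patch] Scalar reference for the
// 8-bit `madd` kernels in this section, assuming the pmaddubsw-style
// semantics these intrinsics emulate: each 16-bit output lane is the
// saturated sum of two adjacent 8-bit products. The helper name below is
// hypothetical.
#include <cstdint>

inline int16_t madd_u8s8_reference(uint8_t a0, int8_t b0,
                                   uint8_t a1, int8_t b1) {
  // a comes from the unsigned (activation) side, b from the signed (weight)
  // side; widen before multiplying so the products cannot overflow.
  int32_t sum = static_cast<int32_t>(a0) * b0 + static_cast<int32_t>(a1) * b1;
  // Saturate to int16_t, as _mm_maddubs_epi16 (and the SSE2 emulation built
  // from _mm_madd_epi16 + _mm_packs_epi32) does.
  if (sum > 32767) sum = 32767;
  if (sum < -32768) sum = -32768;
  return static_cast<int16_t>(sum);
}
// --------------------------------------------------------------------------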
+ + // To extend signed 8-bit value, MSB has to be set to 0xFF + __m128i sign_mask_a = _mm_cmplt_epi8(a, _mm_setzero_si128()); + __m128i sign_mask_b = _mm_cmplt_epi8(b, _mm_setzero_si128()); + + // sign_mask_a = 0x00 0x00 0xFF 0x00 ... + // sign_mask_b = 0x00 0x00 0xFF 0xFF ... + + // Unpack positives with 0x00, negatives with 0xFF + __m128i a_epi16_l = _mm_unpacklo_epi8(a, sign_mask_a); + __m128i a_epi16_h = _mm_unpackhi_epi8(a, sign_mask_a); + __m128i b_epi16_l = _mm_unpacklo_epi8(b, sign_mask_b); + __m128i b_epi16_h = _mm_unpackhi_epi8(b, sign_mask_b); + + // Here - valid 16-bit signed integers corresponding to the 8-bit input + // a_epi16_l = 0x00 0x00 0x01 0x00 0xFE 0xFF 0x04 0x00 ... + + // Get the a[i] * b[i] + a[i+1] * b[i+1] for both low and high parts + __m128i madd_epi32_l = _mm_madd_epi16(a_epi16_l, b_epi16_l); + __m128i madd_epi32_h = _mm_madd_epi16(a_epi16_h, b_epi16_h); + + // Now go back from 32-bit values to 16-bit values & signed saturate + return _mm_packs_epi32(madd_epi32_l, madd_epi32_h); +} + +template +inline std::tuple, xsimd::batch> +PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567, + xsimd::kernel::requires_arch) { + return {pack0123, pack4567}; +} + +#endif + +#if __ARM_ARCH >= 7 +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {xsimd::zip_lo(first, second), xsimd::zip_hi(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + + return vcombine_s8(vqmovn_s16(first), vqmovn_s16(second)); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return vcombine_s16(vqmovn_s32(first), vqmovn_s32(second)); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + + int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y)); + int32x4_t high = vmull_s16(vget_high_s16(x), vget_high_s16(y)); + + int32x2_t low_sum = vpadd_s32(vget_low_s32(low), vget_high_s32(low)); + int32x2_t high_sum = vpadd_s32(vget_low_s32(high), vget_high_s32(high)); + + return vcombine_s32(low_sum, high_sum); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + + // This would be much simpler if x86 would choose to zero extend OR sign + // extend, not both. This could probably be optimized better. + + // Zero extend x + int16x8_t x_odd = + vreinterpretq_s16_u16(vshrq_n_u16(vreinterpretq_u16_u8(x), 8)); + int16x8_t x_even = vreinterpretq_s16_u16( + vbicq_u16(vreinterpretq_u16_u8(x), vdupq_n_u16(0xff00))); + + // Sign extend by shifting left then shifting right. 
+ int16x8_t y_even = vshrq_n_s16(vshlq_n_s16(vreinterpretq_s16_s8(y), 8), 8); + int16x8_t y_odd = vshrq_n_s16(vreinterpretq_s16_s8(y), 8); + + // multiply + int16x8_t prod1 = vmulq_s16(x_even, y_even); + int16x8_t prod2 = vmulq_s16(x_odd, y_odd); + + // saturated add + return vqaddq_s16(prod1, prod2); +} + +template +inline xsimd::batch +madd(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y)); + int16x8_t high = vmull_s8(vget_high_s8(x), vget_high_s8(y)); + + int16x4_t low_sum = vpadd_s16(vget_low_s16(low), vget_high_s16(low)); + int16x4_t high_sum = vpadd_s16(vget_low_s16(high), vget_high_s16(high)); + + return vcombine_s16(low_sum, high_sum); +} + +template +inline std::tuple, xsimd::batch> +PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567, + xsimd::kernel::requires_arch) { + return {pack0123, pack4567}; +} +#endif + +#ifdef __aarch64__ +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second, + xsimd::kernel::requires_arch) { + return {vzip1q_s8(first, second), vzip2q_s8(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {vzip1q_s16(first, second), vzip2q_s16(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {vzip1q_s32(first, second), vzip2q_s32(first, second)}; +} + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return {vzip1q_s64(first, second), vzip2q_s64(first, second)}; +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + + return vqmovn_high_s16(vqmovn_s16(first), second); +} + +template +xsimd::batch +deinterleave(xsimd::batch first, + xsimd::batch second, + xsimd::kernel::requires_arch) { + return vqmovn_high_s32(vqmovn_s32(first), second); +} + +#ifdef __ARM_FEATURE_MATMUL_INT8 +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch>) { + return vusdotq_s32(z, x, y); +} +#endif + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch) { + int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))), + vmovl_s8(vget_low_s8(y))); + int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))), + vmovl_s8(vget_high_s8(y))); + return vpadalq_s16(vpadalq_s16(z, tl), th); +} + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::kernel::requires_arch) { + int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))), + vmovl_s8(vget_low_s8(y))); + int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))), + vmovl_s8(vget_high_s8(y))); + return vpadalq_s16(vpaddlq_s16(tl), th); +} + +template +inline xsimd::batch Pack0123(xsimd::batch sum0, + xsimd::batch sum1, + xsimd::batch sum2, + xsimd::batch sum3, + xsimd::kernel::requires_arch) { + auto pack01 = vpaddq_s32(sum0, sum1); + auto pack23 = vpaddq_s32(sum2, sum3); + return vpaddq_s32(pack01, pack23); +} + +#endif + +template +inline xsimd::batch +maddw(xsimd::batch x, xsimd::batch y, + xsimd::batch z, + xsimd::kernel::requires_arch) { + return z + madd(xsimd::batch(1), madd(x, y, Arch{}), Arch{}); +} + +template +inline xsimd::batch +maddw(xsimd::batch x, 
xsimd::batch y, + xsimd::kernel::requires_arch) { + return maddw(x, y, xsimd::batch(0), Arch{}); +} + +} // namespace kernel + +// +// Generic dispatcher for interleave, deinterleave madd and PermuteSummer +// + +template +std::tuple, xsimd::batch> +interleave(xsimd::batch first, xsimd::batch second) { + return kernel::interleave(first, second, Arch{}); +} + +template +xsimd::batch deinterleave(xsimd::batch first, + xsimd::batch second) { + return kernel::deinterleave(first, second, Arch{}); +} +template +xsimd::batch deinterleave(xsimd::batch first, + xsimd::batch second) { + return kernel::deinterleave(first, second, Arch{}); +} + +template +inline xsimd::batch madd(xsimd::batch x, + xsimd::batch y) { + return kernel::madd(x, y, Arch{}); +} +template +inline xsimd::batch madd(xsimd::batch x, + xsimd::batch y) { + return kernel::madd(x, y, Arch{}); +} +template +inline xsimd::batch madd(xsimd::batch x, + xsimd::batch y) { + return kernel::madd(x, y, Arch{}); +} +template +inline xsimd::batch maddw(xsimd::batch x, + xsimd::batch y, + xsimd::batch z + ) { + return kernel::maddw(x, y, z, Arch{}); +} +template +inline xsimd::batch maddw(xsimd::batch x, + xsimd::batch y + ) { + return kernel::maddw(x, y, Arch{}); +} + +template +inline auto PermuteSummer(xsimd::batch pack0123, + xsimd::batch pack4567) + -> decltype(kernel::PermuteSummer(pack0123, pack4567, Arch{})) { + return kernel::PermuteSummer(pack0123, pack4567, Arch{}); +} + + +namespace kernel { + + template + inline xsimd::batch Pack0123(xsimd::batch sum0, + xsimd::batch sum1, + xsimd::batch sum2, + xsimd::batch sum3, + xsimd::kernel::requires_arch) { + + std::tie(sum0, sum1) = interleave(sum0, sum1, Arch{}); + auto pack01 = sum0 + sum1; + std::tie(sum2, sum3) = interleave(sum2, sum3, Arch{}); + auto pack23 = sum2 + sum3; + + auto packed = interleave(xsimd::bitwise_cast(pack01), + xsimd::bitwise_cast(pack23), + Arch{}); + return xsimd::bitwise_cast(std::get<0>(packed)) + + xsimd::bitwise_cast(std::get<1>(packed)); + } +} + +template +inline xsimd::batch Pack0123(xsimd::batch sum0, + xsimd::batch sum1, + xsimd::batch sum2, + xsimd::batch sum3) { + return kernel::Pack0123(sum0, sum1, sum2, sum3, Arch{}); +} + +template +static inline xsimd::batch +quantize(xsimd::batch input, + xsimd::batch quant_mult) { + return xsimd::nearbyint_as_int(input * quant_mult); +} + +template +inline xsimd::batch +QuantizerGrab(const float *input, xsimd::batch quant_mult_reg) { + return quantize(xsimd::batch::load_unaligned(input), + quant_mult_reg); +} + +#ifdef __AVX512BW__ +inline __m512 Concat(const __m256 first, const __m256 second) { + // INTGEMM_AVX512DQ but that goes with INTGEMM_AVX512BW anyway. + return _mm512_insertf32x8(_mm512_castps256_ps512(first), second, 1); +} + +// Like QuantizerGrab, but allows 32-byte halves (i.e. 8 columns) to be +// controlled independently. +/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set + * INTGEMM_AVX512BW */ +inline __m512i QuantizerGrabHalves(const float *input0, const float *input1, + const __m512 quant_mult_reg) { + __m512 appended = Concat(_mm256_loadu_ps(input0), _mm256_loadu_ps(input1)); + appended = _mm512_mul_ps(appended, quant_mult_reg); + return _mm512_cvtps_epi32(appended); +} +#else +template +inline xsimd::batch +QuantizerGrabHalves(const float *input0, const float *input1, + xsimd::batch quant_mult_reg); +#endif + +/* Read 8 floats at a time from input0, input1, input2, and input3. Quantize + * them to 8-bit by multiplying with quant_mult_reg then rounding. 
Concatenate + * the result into one register and return it. + */ +class QuantizeTile8 { + template struct Tiler { + static constexpr uint32_t get(std::size_t i, std::size_t n) { + size_t factor = xsimd::batch::size / 4; + return (i % factor) * 4 + i / factor; + } + }; + +public: + template + static inline xsimd::batch + Consecutive(xsimd::batch quant_mult, const float *input) { + return Tile(quant_mult, input + 0 * xsimd::batch::size, + input + 1 * xsimd::batch::size, + input + 2 * xsimd::batch::size, + input + 3 * xsimd::batch::size); + } + + template + static inline xsimd::batch + ConsecutiveU(xsimd::batch quant_mult, const float *input) { + return TileU(quant_mult, input + 0 * xsimd::batch::size, + input + 1 * xsimd::batch::size, + input + 2 * xsimd::batch::size, + input + 3 * xsimd::batch::size); + } + + template + static inline xsimd::batch + ConsecutiveWithWrapping(xsimd::batch quant_mult, + const float *input, size_t cols_left, size_t cols, + size_t row_step) { + using batchf32 = xsimd::batch; + const float *inputs[4]; + for (size_t i = 0; i < std::size(inputs); ++i) { + while (cols_left < batchf32::size) { + input += cols * (row_step - 1); + cols_left += cols; + } + inputs[i] = input; + input += batchf32::size; + cols_left -= batchf32::size; + } + return Tile(quant_mult, inputs[0], inputs[1], inputs[2], inputs[3]); + } + + template + static inline xsimd::batch + ForReshape(xsimd::batch quant_mult, const float *input, + size_t cols) { + using batchf32 = xsimd::batch; + using batch8 = xsimd::batch; + using batch16 = xsimd::batch; + using batch32 = xsimd::batch; + + // Put higher rows in the second half of the register. These will jumble + // around in the same way then conveniently land in the right place. + if constexpr (batchf32::size == 16) { + const batch8 neg127(-127); + // In reverse order: grabbing the first 32-bit values from each 128-bit + // register, then the second 32-bit values, etc. Grab 4 registers at a + // time in 32-bit format. + batch32 g0 = + QuantizerGrabHalves(input + 0 * cols, input + 2 * cols, quant_mult); + batch32 g1 = + QuantizerGrabHalves(input + 16 * cols, input + 18 * cols, quant_mult); + batch32 g2 = + QuantizerGrabHalves(input + 32 * cols, input + 34 * cols, quant_mult); + batch32 g3 = + QuantizerGrabHalves(input + 48 * cols, input + 50 * cols, quant_mult); + + // Pack 32-bit to 16-bit. + batch16 packed0 = deinterleave(g0, g1); + batch16 packed1 = deinterleave(g2, g3); + // Pack 16-bit to 8-bit. + batch8 packed = deinterleave(packed0, packed1); + // Ban -128. + packed = xsimd::max(packed, neg127); + + return xsimd::bitwise_cast( + xsimd::swizzle(xsimd::bitwise_cast(packed), + xsimd::make_batch_constant>())); + } else if constexpr (batchf32::size == 8) + return Tile(quant_mult, input, input + 2 * cols, input + 16 * cols, + input + 18 * cols); + else if constexpr (batchf32::size == 4) + // Skip a row. + return Tile(quant_mult, input, input + 4, input + 2 * cols, + input + 2 * cols + 4); + else + return {}; + } + + template + static inline xsimd::batch + Tile(xsimd::batch quant_mult, const float *input0, + const float *input1, const float *input2, const float *input3) { + using batch8 = xsimd::batch; + using batch16 = xsimd::batch; + using batch32 = xsimd::batch; + + const batch8 neg127(-127); + // Grab 4 registers at a time in 32-bit format. 
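// --------------------------------------------------------------------------
// [Editorial illustration, not part of this patch] Worked example of the
// Tiler mask used by the swizzles in ForReshape/Tile/TileU. For an 8-lane
// 32-bit register (e.g. AVX2), factor = 8 / 4 = 2 and get(i, n) yields the
// mask {0, 4, 1, 5, 2, 6, 3, 7}: lane i of the swizzled result takes lane
// mask[i] of the packed register, turning the interleaved column order
// 0 2 4 6 1 3 5 7 back into 0 1 2 3 4 5 6 7. A standalone re-statement of
// the same formula (the real Tiler is templated on the batch type; this
// helper name is hypothetical):
constexpr unsigned tiler_get_reference(unsigned i, unsigned lanes) {
  const unsigned factor = lanes / 4;   // 2 for an 8-lane (AVX2) register
  return (i % factor) * 4 + i / factor;
}
static_assert(tiler_get_reference(1, 8) == 4,
              "lane 1 of the swizzled result pulls column 1");
// --------------------------------------------------------------------------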
+ batch32 g0 = QuantizerGrab(input0, quant_mult); + batch32 g1 = QuantizerGrab(input1, quant_mult); + batch32 g2 = QuantizerGrab(input2, quant_mult); + batch32 g3 = QuantizerGrab(input3, quant_mult); + // Pack 32-bit to 16-bit. + batch16 packed0 = deinterleave(g0, g1); + batch16 packed1 = deinterleave(g2, g3); + // Pack 16-bit to 8-bit. + batch8 packed = deinterleave(packed0, packed1); + // Ban -128. + packed = xsimd::max(packed, neg127); + + if constexpr (batch32::size == 4) + return packed; + // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 + // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7 + // Technically this could be removed so long as the rows are bigger than 16 + // and the values are only used for GEMM. + return xsimd::bitwise_cast( + xsimd::swizzle(xsimd::bitwise_cast(packed), + xsimd::make_batch_constant>())); + } + +private: + // A version that produces uint8_ts + template + static inline xsimd::batch + TileU(xsimd::batch quant_mult, const float *input0, + const float *input1, const float *input2, const float *input3) { + using batch8 = xsimd::batch; + using batch16 = xsimd::batch; + using batch32 = xsimd::batch; + + const batch8 neg127 = -127; + const batch8 pos127 = +127; + // Grab 4 registers at a time in 32-bit format. + batch32 g0 = QuantizerGrab(input0, quant_mult); + batch32 g1 = QuantizerGrab(input1, quant_mult); + batch32 g2 = QuantizerGrab(input2, quant_mult); + batch32 g3 = QuantizerGrab(input3, quant_mult); + // Pack 32-bit to 16-bit. + batch16 packed0 = deinterleave(g0, g1); + batch16 packed1 = deinterleave(g2, g3); + // Pack 16-bit to 8-bit. + batch8 packed = deinterleave(packed0, packed1); + // Ban -128. + packed = xsimd::max(packed, neg127); // Could be removed if we use +128 + packed = packed + pos127; + if (batch32::size == 4) + return xsimd::bitwise_cast(packed); + // Currently in 0 1 2 3 8 9 10 11 16 17 18 19 24 25 26 27 4 5 6 7 12 13 14 + // 15 20 21 22 23 28 29 30 31 Or as 32-bit integers 0 2 4 6 1 3 5 7 + // Technically this could be removed so long as the rows are bigger than 16 + // and the values are only used for GEMM. 
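// --------------------------------------------------------------------------
// [Editorial illustration, not part of this patch] Scalar equivalent of the
// per-element quantization performed by Tile above: scale by quant_mult,
// round to nearest (a stand-in for xsimd::nearbyint_as_int), saturate through
// the 32->16->8 bit packs, and ban -128 so the signed path can always negate
// the value. TileU additionally adds 127 to shift the result into unsigned
// range. The helper name below is hypothetical.
#include <cmath>
#include <cstdint>

inline int8_t quantize_reference(float value, float quant_mult) {
  int q = static_cast<int>(std::nearbyint(value * quant_mult));
  if (q > 127) q = 127;     // the saturating packs clamp the high side
  if (q < -127) q = -127;   // max(packed, -127) bans -128
  return static_cast<int8_t>(q);
}
// --------------------------------------------------------------------------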
+ return xsimd::bitwise_cast( + xsimd::swizzle(xsimd::bitwise_cast(packed), + xsimd::make_batch_constant>())); + } +}; + +template +inline void Transpose16InLane( + xsimd::batch &r0, xsimd::batch &r1, + xsimd::batch &r2, xsimd::batch &r3, + xsimd::batch &r4, xsimd::batch &r5, + xsimd::batch &r6, xsimd::batch &r7) { + /* r0: columns 0 1 2 3 4 5 6 7 from row 0 + r1: columns 0 1 2 3 4 5 6 7 from row 1*/ + auto r0_16 = xsimd::bitwise_cast(r0); + auto r1_16 = xsimd::bitwise_cast(r1); + auto r2_16 = xsimd::bitwise_cast(r2); + auto r3_16 = xsimd::bitwise_cast(r3); + auto r4_16 = xsimd::bitwise_cast(r4); + auto r5_16 = xsimd::bitwise_cast(r5); + auto r6_16 = xsimd::bitwise_cast(r6); + auto r7_16 = xsimd::bitwise_cast(r7); + + std::tie(r0_16, r1_16) = interleave(r0_16, r1_16); + std::tie(r2_16, r3_16) = interleave(r2_16, r3_16); + std::tie(r4_16, r5_16) = interleave(r4_16, r5_16); + std::tie(r6_16, r7_16) = interleave(r6_16, r7_16); + /* r0: columns 0 0 1 1 2 2 3 3 from rows 0 and 1 + r1: columns 4 4 5 5 6 6 7 7 from rows 0 and 1 + r2: columns 0 0 1 1 2 2 3 3 from rows 2 and 3 + r3: columns 4 4 5 5 6 6 7 7 from rows 2 and 3 + r4: columns 0 0 1 1 2 2 3 3 from rows 4 and 5 + r5: columns 4 4 5 5 6 6 7 7 from rows 4 and 5 + r6: columns 0 0 1 1 2 2 3 3 from rows 6 and 7 + r7: columns 4 4 5 5 6 6 7 7 from rows 6 and 7*/ + auto r0_32 = xsimd::bitwise_cast(r0_16); + auto r2_32 = xsimd::bitwise_cast(r2_16); + auto r1_32 = xsimd::bitwise_cast(r1_16); + auto r3_32 = xsimd::bitwise_cast(r3_16); + auto r4_32 = xsimd::bitwise_cast(r4_16); + auto r6_32 = xsimd::bitwise_cast(r6_16); + auto r5_32 = xsimd::bitwise_cast(r5_16); + auto r7_32 = xsimd::bitwise_cast(r7_16); + + std::tie(r0_32, r2_32) = interleave(r0_32, r2_32); + std::tie(r1_32, r3_32) = interleave(r1_32, r3_32); + std::tie(r4_32, r6_32) = interleave(r4_32, r6_32); + std::tie(r5_32, r7_32) = interleave(r5_32, r7_32); + /* r0: columns 0 0 0 0 1 1 1 1 from rows 0, 1, 2, and 3 + r1: columns 4 4 4 4 5 5 5 5 from rows 0, 1, 2, and 3 + r2: columns 2 2 2 2 3 3 3 3 from rows 0, 1, 2, and 3 + r3: columns 6 6 6 6 7 7 7 7 from rows 0, 1, 2, and 3 + r4: columns 0 0 0 0 1 1 1 1 from rows 4, 5, 6, and 7 + r5: columns 4 4 4 4 5 5 5 5 from rows 4, 5, 6, and 7 + r6: columns 2 2 2 2 3 3 3 3 from rows 4, 5, 6, and 7 + r7: columns 6 6 6 6 7 7 7 7 from rows 4, 5, 6, and 7*/ + + auto r0_64 = xsimd::bitwise_cast(r0_32); + auto r2_64 = xsimd::bitwise_cast(r2_32); + auto r1_64 = xsimd::bitwise_cast(r1_32); + auto r3_64 = xsimd::bitwise_cast(r3_32); + auto r4_64 = xsimd::bitwise_cast(r4_32); + auto r6_64 = xsimd::bitwise_cast(r6_32); + auto r5_64 = xsimd::bitwise_cast(r5_32); + auto r7_64 = xsimd::bitwise_cast(r7_32); + + std::tie(r0_64, r4_64) = interleave(r0_64, r4_64); + std::tie(r1_64, r5_64) = interleave(r1_64, r5_64); + std::tie(r2_64, r6_64) = interleave(r2_64, r6_64); + std::tie(r3_64, r7_64) = interleave(r3_64, r7_64); + + r0 = xsimd::bitwise_cast(r0_64); + r1 = xsimd::bitwise_cast(r1_64); + r2 = xsimd::bitwise_cast(r2_64); + r3 = xsimd::bitwise_cast(r3_64); + r4 = xsimd::bitwise_cast(r4_64); + r5 = xsimd::bitwise_cast(r5_64); + r6 = xsimd::bitwise_cast(r6_64); + r7 = xsimd::bitwise_cast(r7_64); + /* r0: columns 0 0 0 0 0 0 0 0 from rows 0 through 7 + r1: columns 4 4 4 4 4 4 4 4 from rows 0 through 7 + r2: columns 2 2 2 2 2 2 2 2 from rows 0 through 7 + r3: columns 6 6 6 6 6 6 6 6 from rows 0 through 7 + r4: columns 1 1 1 1 1 1 1 1 from rows 0 through 7 + r5: columns 5 5 5 5 5 5 5 5 from rows 0 through 7*/ + /* Empirically gcc is able to remove these movs and just rename 
the outputs of + * Interleave64. */ + std::swap(r1, r4); + std::swap(r3, r6); +} + +template +void SelectColumnsOfB(const xsimd::batch *input, + xsimd::batch *output, + size_t rows_bytes /* number of bytes in a row */, + const IntegerTy *cols_begin, const IntegerTy *cols_end) { + using batch8 = xsimd::batch; + /* Do columns for multiples of 8.*/ + size_t register_rows = rows_bytes / batch8::size; + const batch8 *starts[8]; + for (; cols_begin != cols_end; cols_begin += 8) { + for (size_t k = 0; k < 8; ++k) { + starts[k] = + input + (cols_begin[k] & 7) + (cols_begin[k] & ~7) * register_rows; + } + for (size_t r = 0; r < register_rows; ++r) { + for (size_t k = 0; k < 8; ++k) { + *(output++) = *starts[k]; + starts[k] += 8; + } + } + } +} + +} // namespace + +namespace callbacks { +template +xsimd::batch Unquantize::operator()(xsimd::batch total, size_t, size_t, + size_t) { + return xsimd::batch_cast(total) * unquant_mult; +} + +template +std::tuple, xsimd::batch> Unquantize::operator()( + std::tuple, xsimd::batch> total, + size_t, size_t, size_t) { + return std::make_tuple( + xsimd::batch_cast(std::get<0>(total)) * unquant_mult, + xsimd::batch_cast(std::get<1>(total)) * unquant_mult); +} + +template +xsimd::batch AddBias::operator()(xsimd::batch total, size_t, + size_t col_idx, size_t) { + return total + xsimd::batch::load_aligned(bias_addr + col_idx); +} + +template +std::tuple, xsimd::batch> +AddBias::operator()( + std::tuple, xsimd::batch> total, + size_t, size_t col_idx, size_t) { + return std::make_tuple( + std::get<0>(total) + xsimd::batch::load_aligned(bias_addr + col_idx + 0), + std::get<1>(total) + + xsimd::batch::load_aligned(bias_addr + col_idx + + xsimd::batch::size)); +} + +template +void Write::operator()(xsimd::batch result, size_t row_idx, + size_t col_idx, size_t col_size) { + result.store_aligned(output_addr + row_idx * col_size + col_idx); +} + +template +void Write::operator()(xsimd::batch result, size_t row_idx, + size_t col_idx, size_t col_size) { + xsimd::bitwise_cast(result).store_aligned( + output_addr + row_idx * col_size + col_idx); +} + +template +void Write::operator()( + std::tuple, xsimd::batch> result, + size_t row_idx, size_t col_idx, size_t col_size) { + std::get<0>(result).store_aligned(output_addr + row_idx * col_size + col_idx + + 0); + std::get<1>(result).store_aligned(output_addr + row_idx * col_size + col_idx + + xsimd::batch::size); +} + +template +void Write::operator()( + std::tuple, xsimd::batch> result, + size_t row_idx, size_t col_idx, size_t col_size) { + xsimd::bitwise_cast(std::get<0>(result)) + .store_aligned(output_addr + row_idx * col_size + col_idx + 0); + xsimd::bitwise_cast(std::get<1>(result)) + .store_aligned(output_addr + row_idx * col_size + col_idx + + xsimd::batch::size); +} + +template +void UnquantizeAndWrite::operator()(T const &total, size_t row_idx, + size_t col_idx, size_t col_size) { + auto unquantized = unquantize(total, row_idx, col_idx, col_size); + write(unquantized, row_idx, col_idx, col_size); +} + +template +void UnquantizeAndAddBiasAndWrite::operator()(T const &total, size_t row_idx, + size_t col_idx, size_t col_size) { + auto unquantized = unquantize(total, row_idx, col_idx, col_size); + auto bias_added = add_bias(unquantized, row_idx, col_idx, col_size); + write(bias_added, row_idx, col_idx, col_size); +} +} // namespace callbacks + +template +void Engine::QuantizeU(const float *input, uint8_t *output, + float quant_mult, size_t size) { + using batch8 = xsimd::batch; + + xsimd::batch q(quant_mult); + const float 
*end = input + size; + for (; input != end; input += batch8::size, output += batch8::size) { + auto tile = QuantizeTile8::ConsecutiveU(q, input); + tile.store_aligned(output); + } +} + +template +void Engine::Quantize(const float *const input, int8_t *const output, + float quant_mult, size_t size) { + using batch8 = xsimd::batch; + + const std::size_t kBatch = batch8::size; + const std::size_t fast_end = size & ~(kBatch - 1); + + xsimd::batch q(quant_mult); + for (std::size_t i = 0; i < fast_end; i += kBatch) { + auto tile = QuantizeTile8::Consecutive(q, input + i); + tile.store_aligned(output + i); + } + + std::size_t overhang = size & (kBatch - 1); + if (!overhang) + return; + /* Each does size(xsimd::batch) / 32 == kBatch / 4 floats at a + * time. If we're allowed to read one of them, then we can read the whole + * register. + */ + const float *inputs[4]; + std::size_t i; + for (i = 0; i < (overhang + (kBatch / 4) - 1) / (kBatch / 4); ++i) { + inputs[i] = &input[fast_end + i * (kBatch / 4)]; + } + /* These will be clipped off. */ + for (; i < 4; ++i) { + inputs[i] = &input[fast_end]; + } + auto result = + QuantizeTile8::Tile(q, inputs[0], inputs[1], inputs[2], inputs[3]); + alignas(Arch::alignment()) int8_t buffer[kBatch]; + result.store_aligned(buffer); + std::memcpy(output + (size & ~(kBatch - 1)), buffer, overhang); +} + +template +template +void Engine::SelectColumnsB(const int8_t *input, int8_t *output, + size_t rows, const IntegerTy *cols_begin, + const IntegerTy *cols_end) { + using batch8 = xsimd::batch; + SelectColumnsOfB(reinterpret_cast(input), + reinterpret_cast(output), rows, cols_begin, + cols_end); +} + +template +void Engine::PrepareBTransposed(const float *input, int8_t *output, + float quant_mult, size_t cols, + size_t rows) { + using batch8 = xsimd::batch; + const size_t RegisterElemsInt = batch8::size; + const size_t kColStride = 8; + + xsimd::batch q(quant_mult); + auto *output_it = reinterpret_cast(output); + size_t r = 0; + size_t c = 0; + while (r < rows) { + for (size_t ri = 0; ri < 8; ++ri) + *output_it++ = QuantizeTile8::ConsecutiveWithWrapping( + q, input + (r + ri) * cols + c, cols - c, cols, 8); + c += RegisterElemsInt; + while (c >= cols) { + r += kColStride; + c -= cols; + } + } +} + +template +void Engine::PrepareBQuantizedTransposed(const int8_t *input, + int8_t *output, size_t cols, + size_t rows) { + using batch8 = xsimd::batch; + const size_t RegisterElems = batch8::size; + const size_t kColStride = 8; + + auto *output_it = reinterpret_cast(output); + for (size_t r = 0; r < rows; r += kColStride) + for (size_t c = 0; c < cols; c += RegisterElems) + for (size_t ri = 0; ri < 8; ++ri) + *output_it++ = + *reinterpret_cast(input + (r + ri) * cols + c); +} + +template +void Engine::PrepareB(const float *input, int8_t *output_shadow, + float quant_mult, size_t rows, size_t cols) { + using batch8 = xsimd::batch; + + xsimd::batch q(quant_mult); + /* Currently all multipliers have a stride of 8 columns.*/ + const size_t kColStride = 8; + auto *output = reinterpret_cast(output_shadow); + for (size_t c = 0; c < cols; c += kColStride) { + for (size_t r = 0; r < rows; r += sizeof(*output), output += 8) { + output[0] = + QuantizeTile8::ForReshape(q, input + cols * (r + 0) + c, cols); + output[1] = + QuantizeTile8::ForReshape(q, input + cols * (r + 1) + c, cols); + output[2] = + QuantizeTile8::ForReshape(q, input + cols * (r + 4) + c, cols); + output[3] = + QuantizeTile8::ForReshape(q, input + cols * (r + 5) + c, cols); + output[4] = + QuantizeTile8::ForReshape(q, 
input + cols * (r + 8) + c, cols); + output[5] = + QuantizeTile8::ForReshape(q, input + cols * (r + 9) + c, cols); + output[6] = + QuantizeTile8::ForReshape(q, input + cols * (r + 12) + c, cols); + output[7] = + QuantizeTile8::ForReshape(q, input + cols * (r + 13) + c, cols); + std::tie(output[0], output[1]) = + interleave(xsimd::bitwise_cast(output[0]), + xsimd::bitwise_cast(output[1])); + std::tie(output[2], output[3]) = + interleave(xsimd::bitwise_cast(output[2]), + xsimd::bitwise_cast(output[3])); + std::tie(output[4], output[5]) = + interleave(xsimd::bitwise_cast(output[4]), + xsimd::bitwise_cast(output[5])); + std::tie(output[6], output[7]) = + interleave(xsimd::bitwise_cast(output[6]), + xsimd::bitwise_cast(output[7])); + Transpose16InLane(output[0], output[1], output[2], output[3], output[4], + output[5], output[6], output[7]); + } + } +} + +template +void Engine::PrepareA(const float *input, int8_t *output, + float quant_mult, size_t rows, size_t cols) { + Quantize(input, output, quant_mult, rows * cols); +} + +template +void Engine::Shift::PrepareA(const float *input, uint8_t *output, + float quant_mult, size_t rows, size_t cols) { + QuantizeU(input, output, quant_mult, rows * cols); +} + +struct SequentialExecutionEngine { + + template + inline void operator()(size_t Start, size_t End, size_t Stride, F&& f) { + for(size_t i = Start; i < End; i += Stride) { + f(i); + } + } + +}; + +template +template +void Engine::Shift::Multiply(const uint8_t *A, const int8_t *B, + size_t A_rows, size_t width, size_t B_cols, + Callback callback, ExecutionEngine& engine) { + + using batch8 = xsimd::batch; + using ubatch8 = xsimd::batch; + using batch32 = xsimd::batch; + + engine(0, B_cols, 8, [A, B, A_rows, width, B_cols, &callback](size_t B0_colidx) { + const size_t simd_width = width / batch8::size; + const auto *B0_col = + reinterpret_cast(B) + simd_width * B0_colidx; + /* Process one row of A at a time. Doesn't seem to be faster to do multiple + * rows of A at once.*/ + for (size_t A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { + const auto *A_row = + reinterpret_cast(A + A_rowidx * width); + /* These will be packed 16-bit integers containing sums for each row of B + multiplied by the row of A. Iterate over shared (inner) dimension.*/ + /* Upcast to 32-bit and horizontally add. 
Seems a bit faster if this is + * declared here.*/ + size_t k = 0; + ubatch8 a = *(A_row + k); + batch32 isum0 = maddw(a, *(B0_col + k * 8)); + batch32 isum1 = maddw(a, *(B0_col + k * 8 + 1)); + batch32 isum2 = maddw(a, *(B0_col + k * 8 + 2)); + batch32 isum3 = maddw(a, *(B0_col + k * 8 + 3)); + batch32 isum4 = maddw(a, *(B0_col + k * 8 + 4)); + batch32 isum5 = maddw(a, *(B0_col + k * 8 + 5)); + batch32 isum6 = maddw(a, *(B0_col + k * 8 + 6)); + batch32 isum7 = maddw(a, *(B0_col + k * 8 + 7)); + for (k = 1; k < simd_width; ++k) { + a = *(A_row + k); + /* Multiply 8-bit, horizontally add to packed 16-bit integers.*/ + /* Upcast to 32-bit and horizontally add.*/ + isum0 = maddw(a, *(B0_col + k * 8 + 0), isum0); + isum1 = maddw(a, *(B0_col + k * 8 + 1), isum1); + isum2 = maddw(a, *(B0_col + k * 8 + 2), isum2); + isum3 = maddw(a, *(B0_col + k * 8 + 3), isum3); + isum4 = maddw(a, *(B0_col + k * 8 + 4), isum4); + isum5 = maddw(a, *(B0_col + k * 8 + 5), isum5); + isum6 = maddw(a, *(B0_col + k * 8 + 6), isum6); + isum7 = maddw(a, *(B0_col + k * 8 + 7), isum7); + } + /* Reduce sums within 128-bit lanes.*/ + auto pack0123 = Pack0123(isum0, isum1, isum2, isum3); + auto pack4567 = Pack0123(isum4, isum5, isum6, isum7); + /*The specific implementation may need to reduce further.*/ + auto total = PermuteSummer(pack0123, pack4567); + callback(total, A_rowidx, B0_colidx, B_cols); + } + }); +} + +template +template +void Engine::Shift::PrepareBias(const int8_t *B, size_t width, + size_t B_cols, Callback C) { + using batch8 = xsimd::batch; + const size_t simd_width = width / batch8::size; + xsimd::batch a(1); + for (size_t j = 0; j < B_cols; j += 8) { + /*Process one row of A at a time. Doesn't seem to be faster to do multiple + * rows of A at once.*/ + const int8_t *B_j = B + j * width; + + /* Rather than initializing as zeros and adding, just initialize the + * first.*/ + /* These will be packed 16-bit integers containing sums for each column of + * B multiplied by the row of A.*/ + /* Upcast to 32-bit and horizontally add. 
Seems a bit faster if this is + * declared here.*/ + auto isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size])); + auto isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size])); + auto isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size])); + auto isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size])); + auto isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size])); + auto isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size])); + auto isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size])); + auto isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size])); + + B_j += 8 * batch8::size; + + for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) { + isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]), isum0); + isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]), isum1); + isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]), isum2); + isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]), isum3); + isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]), isum4); + isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]), isum5); + isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]), isum6); + isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]), isum7); + } + + auto pack0123 = Pack0123(isum0, isum1, isum2, isum3); + auto pack4567 = Pack0123(isum4, isum5, isum6, isum7); + + auto total = PermuteSummer(pack0123, pack4567); + C(total, 0, j, B_cols); + } +} + +} // namespace gemmology + +#endif diff --git a/include/onnxruntime/gemmology_fwd.h b/include/onnxruntime/gemmology_fwd.h new file mode 100644 index 0000000000000..ba5a2490ed879 --- /dev/null +++ b/include/onnxruntime/gemmology_fwd.h @@ -0,0 +1,282 @@ +/*************************************************************** + * _ * + * | | * + * __ _ ___ _ __ ___ _ __ ___ ___ | | ___ __ _ _ _ * + * / _` |/ _ \ '_ ` _ \| '_ ` _ \ / _ \| |/ _ \ / _` | | | | * + * | (_| | __/ | | | | | | | | | | (_) | | (_) | (_| | |_| | * + * \__, |\___|_| |_| |_|_| |_| |_|\___/|_|\___/ \__, |\__, | * + * __/ | __/ | __/ | * + * |___/ |___/ |___/ * + * * + * version 0.1 * + ***************************************************************/ + +#ifndef GEMMOLOGY_FWD_H +#define GEMMOLOGY_FWD_H + +#include +#include +#include +#include "xsimd/xsimd.hpp" + +#ifdef GEMMOLOGY_WITH_STD_THREAD +#include +#include +#endif + +namespace gemmology { + +struct SequentialExecutionEngine; + +#ifdef GEMMOLOGY_WITH_STD_THREAD +struct StdThreadExecutionEngine { + + StdThreadExecutionEngine(size_t PoolSize) : MaxPoolSize(PoolSize) { + Pool.reserve(PoolSize - 1); + } + + template + inline void operator()(size_t Start, size_t End, size_t Stride, F&& f) { + const size_t NbIter = (End - Start) / Stride; + const size_t NbThread = std::min(NbIter, MaxPoolSize); + const size_t Chunk = (NbIter / NbThread) * Stride; + + size_t Curr = Start, Next = Start; + + for(size_t threadID = 0; threadID < NbThread - 1; ++threadID) { + Next += Chunk; + Pool.emplace_back([=]() { + for(size_t i = Curr; i < Next; i += Stride) { + f(i); + }; + }); + Curr = Next; + } + + for(size_t i = Next; i < End; i += Stride) { + f(i); + }; + for(size_t threadID = 0; threadID < Pool.size(); ++threadID) { + Pool[threadID].join(); + } + Pool.clear(); + } + + private: + const size_t MaxPoolSize; + std::vector Pool; + +}; + +#endif + +#ifdef _OPENMP +struct OpenMPExecutionEngine { + + template + inline void operator()(size_t Start, size_t End, size_t Stride, F&& f) { +#pragma omp parallel for + 
for(size_t i = Start; i < End; i += Stride) { + f(i); + } + } + +}; +#endif + +namespace callbacks { + +struct Unquantize { + float unquant_mult; + template + xsimd::batch operator()(xsimd::batch total, size_t, size_t, size_t); + template + std::tuple, xsimd::batch> operator()( + std::tuple, xsimd::batch> + total, + size_t, size_t, size_t); +}; + +struct AddBias { + const float *bias_addr; + template + xsimd::batch operator()(xsimd::batch total, size_t, size_t col_idx, + size_t); + template + std::tuple, xsimd::batch> + operator()( + std::tuple, xsimd::batch> total, + size_t, size_t col_idx, size_t); +}; + +struct Write { + float *output_addr; + + Write(float *o) : output_addr(o) {} + + template + void operator()(xsimd::batch result, size_t row_idx, + size_t col_idx, size_t col_size); + template + void operator()(xsimd::batch result, size_t row_idx, + size_t col_idx, size_t col_size); + + template + void operator()( + std::tuple, xsimd::batch> result, + size_t row_idx, size_t col_idx, size_t col_size); + + template + void operator()( + std::tuple, xsimd::batch> + result, + size_t row_idx, size_t col_idx, size_t col_size); +}; + +struct UnquantizeAndWrite { + + Unquantize unquantize; + Write write; + + UnquantizeAndWrite(float factor, float *output) + : unquantize{factor}, write{output} {} + + template + void operator()(T const &total, size_t row_idx, size_t col_idx, + size_t col_size); +}; + +struct UnquantizeAndAddBiasAndWrite { + + Unquantize unquantize; + AddBias add_bias; + Write write; + + UnquantizeAndAddBiasAndWrite(float factor, const float *bias, float *output) + : unquantize{factor}, add_bias{bias}, write{output} {} + + template + void operator()(T const &total, size_t row_idx, size_t col_idx, + size_t col_size); +}; + +} // namespace callbacks + +// +// Arch-specific implementation of each routine +// +template struct Engine { + + static void QuantizeU(const float *input, uint8_t *output, float quant_mult, + size_t size); + + static void Quantize(const float *const input, int8_t *const output, + float quant_mult, size_t size); + + template + static void SelectColumnsB(const int8_t *input, int8_t *output, size_t rows, + const IntegerTy *cols_begin, + const IntegerTy *cols_end); + + static void PrepareBTransposed(const float *input, int8_t *output, + float quant_mult, size_t cols, size_t rows); + + static void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output, + size_t cols, size_t rows); + + static void PrepareB(const float *input, int8_t *output_shadow, + float quant_mult, size_t rows, size_t cols); + + static void PrepareA(const float *input, int8_t *output, float quant_mult, + size_t rows, size_t cols); + + struct Shift { + + static void PrepareA(const float *input, uint8_t *output, float quant_mult, + size_t rows, size_t cols); + + template + static void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows, + size_t width, size_t B_cols, Callback callback, + ExecutionEngine& engine); + + template + static void PrepareBias(const int8_t *B, size_t width, size_t B_cols, + Callback C); + }; +}; + +// +// Top-level wrappers that mostly match intgemm API +// + +template +inline void QuantizeU(const float *input, uint8_t *output, float quant_mult, + size_t size) { + return Engine::QuantizeU(input, output, quant_mult, size); +} + +template +inline void Quantize(const float *const input, int8_t *const output, + float quant_mult, size_t size) { + return Engine::Quantize(input, output, quant_mult, size); +} + +template +inline void SelectColumnsB(const int8_t *input, 
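Shift::Multiply distributes work over blocks of output rows through the ExecutionEngine's operator()(Start, End, Stride, f) shown above, so any type with a compatible call operator can be plugged in. A minimal sketch of a conforming engine (sequential; the forward-declared SequentialExecutionEngine in gemmology.h presumably does the same):

#include <cstddef>

// Runs the row loop in the calling thread; drop-in replacement for the
// threaded or OpenMP engines above.
struct MySequentialEngine {
  template <class F>
  void operator()(std::size_t Start, std::size_t End, std::size_t Stride, F&& f) {
    for (std::size_t i = Start; i < End; i += Stride)
      f(i); // one block of output rows per invocation
  }
};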
int8_t *output, size_t rows, + const IntegerTy *cols_begin, + const IntegerTy *cols_end) { + return Engine::SelectColumnsB(input, output, rows, cols_begin, + cols_end); +} + +template +inline void PrepareBTransposed(const float *input, int8_t *output, + float quant_mult, size_t cols, size_t rows) { + return Engine::PrepareBTransposed(input, output, quant_mult, cols, + rows); +} + +template +inline void PrepareBQuantizedTransposed(const int8_t *input, int8_t *output, + size_t cols, size_t rows) { + return Engine::PrepareBQuantizedTransposed(input, output, cols, rows); +} + +template +inline void PrepareB(const float *input, int8_t *output_shadow, + float quant_mult, size_t rows, size_t cols) { + return Engine::PrepareB(input, output_shadow, quant_mult, rows, cols); +} + +template +inline void PrepareA(const float *input, int8_t *output, float quant_mult, + size_t rows, size_t cols) { + return Engine::PrepareA(input, output, quant_mult, rows, cols); +} + +namespace Shift { + +template +inline void PrepareA(const float *input, uint8_t *output, float quant_mult, + size_t rows, size_t cols) { + return Engine::Shift::PrepareA(input, output, quant_mult, rows, cols); +} + +template +inline void Multiply(const uint8_t *A, const int8_t *B, size_t A_rows, + size_t width, size_t B_cols, Callback C, ExecutionEngine&& engine={}) { + return Engine::Shift::Multiply(A, B, A_rows, width, B_cols, C, engine); +} + +template +inline void PrepareBias(const int8_t *B, size_t width, size_t B_cols, + Callback C) { + return Engine::Shift::PrepareBias(B, width, B_cols, C); +} + +} // namespace Shift + +} // namespace gemmology + +#endif diff --git a/include/onnxruntime/xsimd/arch/generic/xsimd_generic_arithmetic.hpp b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_arithmetic.hpp new file mode 100644 index 0000000000000..e7916b0d43641 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_arithmetic.hpp @@ -0,0 +1,241 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
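Putting the wrappers together, a typical int8 GEMM with the shift trick looks roughly like the sketch below. The quantization multipliers, the -127 correction factor passed to PrepareBias and the explicit Arch template argument follow the intgemm conventions these wrappers are said to mostly match; the exact template parameter lists are not visible in this rendering, so treat the call syntax as an approximation rather than the definitive API.

#include <cstddef>
#include <cstdint>
#include <vector>
#include "gemmology.h"

using Arch = xsimd::default_arch; // assumption: wrappers take the xsimd arch as first parameter

void int8_gemm(const float* A, const float* B, const float* bias, float* out,
               std::size_t A_rows, std::size_t width, std::size_t B_cols) {
  const float a_quant = 127.0f / 2.0f;              // assumed scale: 127 / max|A|
  const float b_quant = 127.0f / 2.0f;              // assumed scale: 127 / max|B|
  const float unquant = 1.0f / (a_quant * b_quant);

  std::vector<uint8_t> A_q(A_rows * width);
  std::vector<int8_t> B_q(width * B_cols);
  gemmology::Shift::PrepareA<Arch>(A, A_q.data(), a_quant, A_rows, width);
  gemmology::PrepareB<Arch>(B, B_q.data(), b_quant, width, B_cols);

  // The +127 shift applied to A adds 127 * sum(B column) to every output;
  // PrepareBias folds the opposite correction into the bias vector.
  std::vector<float> corrected_bias(bias, bias + B_cols);
  gemmology::Shift::PrepareBias<Arch>(
      B_q.data(), width, B_cols,
      gemmology::callbacks::UnquantizeAndAddBiasAndWrite(-127.0f * unquant, bias,
                                                         corrected_bias.data()));

  gemmology::Shift::Multiply<Arch>(
      A_q.data(), B_q.data(), A_rows, width, B_cols,
      gemmology::callbacks::UnquantizeAndAddBiasAndWrite(unquant, corrected_bias.data(), out));
}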
* + ****************************************************************************/ + +#ifndef XSIMD_GENERIC_ARITHMETIC_HPP +#define XSIMD_GENERIC_ARITHMETIC_HPP + +#include +#include +#include + +#include "./xsimd_generic_details.hpp" + +namespace xsimd +{ + + namespace kernel + { + + using namespace types; + + // bitwise_lshift + template ::value, void>::type*/> + XSIMD_INLINE batch bitwise_lshift(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::apply([](T x, T y) noexcept + { return x << y; }, + self, other); + } + + // bitwise_rshift + template ::value, void>::type*/> + XSIMD_INLINE batch bitwise_rshift(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::apply([](T x, T y) noexcept + { return x >> y; }, + self, other); + } + + // decr + template + XSIMD_INLINE batch decr(batch const& self, requires_arch) noexcept + { + return self - T(1); + } + + // decr_if + template + XSIMD_INLINE batch decr_if(batch const& self, Mask const& mask, requires_arch) noexcept + { + return select(mask, decr(self), self); + } + + // div + template ::value, void>::type> + XSIMD_INLINE batch div(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::apply([](T x, T y) noexcept -> T + { return x / y; }, + self, other); + } + + // fma + template + XSIMD_INLINE batch fma(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return x * y + z; + } + + template + XSIMD_INLINE batch, A> fma(batch, A> const& x, batch, A> const& y, batch, A> const& z, requires_arch) noexcept + { + auto res_r = fms(x.real(), y.real(), fms(x.imag(), y.imag(), z.real())); + auto res_i = fma(x.real(), y.imag(), fma(x.imag(), y.real(), z.imag())); + return { res_r, res_i }; + } + + // fms + template + XSIMD_INLINE batch fms(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return x * y - z; + } + + template + XSIMD_INLINE batch, A> fms(batch, A> const& x, batch, A> const& y, batch, A> const& z, requires_arch) noexcept + { + auto res_r = fms(x.real(), y.real(), fma(x.imag(), y.imag(), z.real())); + auto res_i = fma(x.real(), y.imag(), fms(x.imag(), y.real(), z.imag())); + return { res_r, res_i }; + } + + // fnma + template + XSIMD_INLINE batch fnma(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return -x * y + z; + } + + template + XSIMD_INLINE batch, A> fnma(batch, A> const& x, batch, A> const& y, batch, A> const& z, requires_arch) noexcept + { + auto res_r = -fms(x.real(), y.real(), fma(x.imag(), y.imag(), z.real())); + auto res_i = -fma(x.real(), y.imag(), fms(x.imag(), y.real(), z.imag())); + return { res_r, res_i }; + } + + // fnms + template + XSIMD_INLINE batch fnms(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return -x * y - z; + } + + template + XSIMD_INLINE batch, A> fnms(batch, A> const& x, batch, A> const& y, batch, A> const& z, requires_arch) noexcept + { + auto res_r = -fms(x.real(), y.real(), fms(x.imag(), y.imag(), z.real())); + auto res_i = -fma(x.real(), y.imag(), fma(x.imag(), y.real(), z.imag())); + return { res_r, res_i }; + } + + // hadd + template ::value, void>::type*/> + XSIMD_INLINE T hadd(batch const& self, requires_arch) noexcept + { + alignas(A::alignment()) T buffer[batch::size]; + self.store_aligned(buffer); + T res = 0; + for (T val : buffer) + { + res += val; + } + return res; + } + + // incr + template + XSIMD_INLINE batch incr(batch const& self, requires_arch) noexcept + { + return self + T(1); + } + 
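The generic kernels above (hadd in particular) fall back to the same pattern whenever no native instruction exists: spill the register to an aligned buffer, run the scalar operation, and reload. A stand-alone illustration of that pattern, assuming xsimd is on the include path:

#include <xsimd/xsimd.hpp>

// Same result as xsimd::reduce_add(v), written out the way the generic
// hadd kernel does it.
template <class T, class A>
T horizontal_add_fallback(xsimd::batch<T, A> const& v) {
  alignas(A::alignment()) T buffer[xsimd::batch<T, A>::size];
  v.store_aligned(buffer);        // spill the register
  T sum = T(0);
  for (T x : buffer)
    sum += x;                     // scalar reduction
  return sum;
}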
+ // incr_if + template + XSIMD_INLINE batch incr_if(batch const& self, Mask const& mask, requires_arch) noexcept + { + return select(mask, incr(self), self); + } + + // mul + template ::value, void>::type*/> + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::apply([](T x, T y) noexcept -> T + { return x * y; }, + self, other); + } + + // rotl + template + XSIMD_INLINE batch rotl(batch const& self, STy other, requires_arch) noexcept + { + constexpr auto N = std::numeric_limits::digits; + return (self << other) | (self >> (N - other)); + } + + // rotr + template + XSIMD_INLINE batch rotr(batch const& self, STy other, requires_arch) noexcept + { + constexpr auto N = std::numeric_limits::digits; + return (self >> other) | (self << (N - other)); + } + + // sadd + template + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch) noexcept + { + return add(self, other); // no saturated arithmetic on floating point numbers + } + template ::value, void>::type*/> + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + auto mask = (other >> (8 * sizeof(T) - 1)); + auto self_pos_branch = min(std::numeric_limits::max() - other, self); + auto self_neg_branch = max(std::numeric_limits::min() - other, self); + return other + select(batch_bool(mask.data), self_neg_branch, self_pos_branch); + } + else + { + const auto diffmax = std::numeric_limits::max() - self; + const auto mindiff = min(diffmax, other); + return self + mindiff; + } + } + template + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch) noexcept + { + return add(self, other); // no saturated arithmetic on floating point numbers + } + + // ssub + template + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch) noexcept + { + return sub(self, other); // no saturated arithmetic on floating point numbers + } + template ::value, void>::type*/> + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + return sadd(self, -other); + } + else + { + const auto diff = min(self, other); + return self - diff; + } + } + template + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch) noexcept + { + return sub(self, other); // no saturated arithmetic on floating point numbers + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/generic/xsimd_generic_complex.hpp b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_complex.hpp new file mode 100644 index 0000000000000..812c592aec03c --- /dev/null +++ b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_complex.hpp @@ -0,0 +1,108 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
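The saturated kernels above avoid branches by clamping each lane against its remaining headroom before adding. A scalar rendition of the same logic for 8-bit operands (helper names are illustrative):

#include <algorithm>
#include <cstdint>
#include <limits>

// Signed case: the vector code selects on the sign bit of `other`; here a
// plain branch makes the clamping explicit.
inline int8_t sadd_scalar(int8_t self, int8_t other) {
  constexpr int maxv = std::numeric_limits<int8_t>::max();
  constexpr int minv = std::numeric_limits<int8_t>::min();
  if (other >= 0)
    return int8_t(std::min(maxv - other, int(self)) + other); // keep self+other <= 127
  return int8_t(std::max(minv - other, int(self)) + other);   // keep self+other >= -128
}

// Unsigned case: add at most the remaining headroom.
inline uint8_t sadd_scalar(uint8_t self, uint8_t other) {
  const uint8_t diffmax = uint8_t(std::numeric_limits<uint8_t>::max() - self);
  return uint8_t(self + std::min(diffmax, other));
}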
* + ****************************************************************************/ + +#ifndef XSIMD_GENERIC_COMPLEX_HPP +#define XSIMD_GENERIC_COMPLEX_HPP + +#include + +#include "./xsimd_generic_details.hpp" + +namespace xsimd +{ + + namespace kernel + { + + using namespace types; + + // real + template + XSIMD_INLINE batch real(batch const& self, requires_arch) noexcept + { + return self; + } + + template + XSIMD_INLINE batch real(batch, A> const& self, requires_arch) noexcept + { + return self.real(); + } + + // imag + template + XSIMD_INLINE batch imag(batch const& /*self*/, requires_arch) noexcept + { + return batch(T(0)); + } + + template + XSIMD_INLINE batch imag(batch, A> const& self, requires_arch) noexcept + { + return self.imag(); + } + + // arg + template + XSIMD_INLINE real_batch_type_t> arg(batch const& self, requires_arch) noexcept + { + return atan2(imag(self), real(self)); + } + + // conj + template + XSIMD_INLINE complex_batch_type_t> conj(batch const& self, requires_arch) noexcept + { + return { real(self), -imag(self) }; + } + + // norm + template + XSIMD_INLINE real_batch_type_t> norm(batch const& self, requires_arch) noexcept + { + return { fma(real(self), real(self), imag(self) * imag(self)) }; + } + + // proj + template + XSIMD_INLINE complex_batch_type_t> proj(batch const& self, requires_arch) noexcept + { + using batch_type = complex_batch_type_t>; + using real_batch = typename batch_type::real_batch; + using real_value_type = typename real_batch::value_type; + auto cond = xsimd::isinf(real(self)) || xsimd::isinf(imag(self)); + return select(cond, + batch_type(constants::infinity(), + copysign(real_batch(real_value_type(0)), imag(self))), + batch_type(self)); + } + + template + XSIMD_INLINE batch_bool isnan(batch, A> const& self, requires_arch) noexcept + { + return batch_bool(isnan(self.real()) || isnan(self.imag())); + } + + template + XSIMD_INLINE batch_bool isinf(batch, A> const& self, requires_arch) noexcept + { + return batch_bool(isinf(self.real()) || isinf(self.imag())); + } + + template + XSIMD_INLINE batch_bool isfinite(batch, A> const& self, requires_arch) noexcept + { + return batch_bool(isfinite(self.real()) && isfinite(self.imag())); + } + } +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/generic/xsimd_generic_details.hpp b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_details.hpp new file mode 100644 index 0000000000000..a9af608c88c56 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_details.hpp @@ -0,0 +1,316 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_GENERIC_DETAILS_HPP +#define XSIMD_GENERIC_DETAILS_HPP + +#include + +#include "../../math/xsimd_rem_pio2.hpp" +#include "../../types/xsimd_generic_arch.hpp" +#include "../../types/xsimd_utils.hpp" +#include "../xsimd_constants.hpp" + +namespace xsimd +{ + // Forward declaration. Should we put them in a separate file? 
+ template + XSIMD_INLINE batch abs(batch const& self) noexcept; + template + XSIMD_INLINE batch abs(batch, A> const& self) noexcept; + template + XSIMD_INLINE bool any(batch_bool const& self) noexcept; + template + XSIMD_INLINE batch atan2(batch const& self, batch const& other) noexcept; + template + XSIMD_INLINE batch batch_cast(batch const&, batch const& out) noexcept; + template + XSIMD_INLINE batch bitofsign(batch const& self) noexcept; + template + XSIMD_INLINE batch bitwise_cast(batch const& self) noexcept; + template + XSIMD_INLINE batch cos(batch const& self) noexcept; + template + XSIMD_INLINE batch cosh(batch const& self) noexcept; + template + XSIMD_INLINE batch exp(batch const& self) noexcept; + template + XSIMD_INLINE batch fma(batch const& x, batch const& y, batch const& z) noexcept; + template + XSIMD_INLINE batch fms(batch const& x, batch const& y, batch const& z) noexcept; + template + XSIMD_INLINE batch frexp(const batch& x, const batch, A>& e) noexcept; + template + XSIMD_INLINE batch horner(const batch& self) noexcept; + template + XSIMD_INLINE batch hypot(const batch& self) noexcept; + template + XSIMD_INLINE batch_bool is_even(batch const& self) noexcept; + template + XSIMD_INLINE batch_bool is_flint(batch const& self) noexcept; + template + XSIMD_INLINE batch_bool is_odd(batch const& self) noexcept; + template + XSIMD_INLINE typename batch::batch_bool_type isinf(batch const& self) noexcept; + template + XSIMD_INLINE typename batch::batch_bool_type isfinite(batch const& self) noexcept; + template + XSIMD_INLINE typename batch::batch_bool_type isnan(batch const& self) noexcept; + template + XSIMD_INLINE batch ldexp(const batch& x, const batch, A>& e) noexcept; + template + XSIMD_INLINE batch log(batch const& self) noexcept; + template + XSIMD_INLINE batch nearbyint(batch const& self) noexcept; + template + XSIMD_INLINE batch, A> nearbyint_as_int(const batch& x) noexcept; + template + XSIMD_INLINE T reduce_add(batch const&) noexcept; + template + XSIMD_INLINE batch select(batch_bool const&, batch const&, batch const&) noexcept; + template + XSIMD_INLINE batch, A> select(batch_bool const&, batch, A> const&, batch, A> const&) noexcept; + template + XSIMD_INLINE batch sign(batch const& self) noexcept; + template + XSIMD_INLINE batch signnz(batch const& self) noexcept; + template + XSIMD_INLINE batch sin(batch const& self) noexcept; + template + XSIMD_INLINE batch sinh(batch const& self) noexcept; + template + XSIMD_INLINE std::pair, batch> sincos(batch const& self) noexcept; + template + XSIMD_INLINE batch sqrt(batch const& self) noexcept; + template + XSIMD_INLINE batch tan(batch const& self) noexcept; + template + XSIMD_INLINE batch, A> to_float(batch const& self) noexcept; + template + XSIMD_INLINE batch, A> to_int(batch const& self) noexcept; + template + XSIMD_INLINE batch trunc(batch const& self) noexcept; + + namespace kernel + { + + namespace detail + { + template + XSIMD_INLINE batch apply(F&& func, batch const& self, batch const& other) noexcept + { + constexpr std::size_t size = batch::size; + alignas(A::alignment()) T self_buffer[size]; + alignas(A::alignment()) T other_buffer[size]; + self.store_aligned(&self_buffer[0]); + other.store_aligned(&other_buffer[0]); + for (std::size_t i = 0; i < size; ++i) + { + self_buffer[i] = func(self_buffer[i], other_buffer[i]); + } + return batch::load_aligned(self_buffer); + } + + template + XSIMD_INLINE batch apply_transform(F&& func, batch const& self) noexcept + { + static_assert(batch::size == batch::size, + "Source and 
destination sizes must match"); + constexpr std::size_t src_size = batch::size; + constexpr std::size_t dest_size = batch::size; + alignas(A::alignment()) T self_buffer[src_size]; + alignas(A::alignment()) U other_buffer[dest_size]; + self.store_aligned(&self_buffer[0]); + for (std::size_t i = 0; i < src_size; ++i) + { + other_buffer[i] = func(self_buffer[i]); + } + return batch::load_aligned(other_buffer); + } + } + + // some generic fast_cast conversion + namespace detail + { + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return bitwise_cast(self); + } + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return bitwise_cast(self); + } + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return bitwise_cast(self); + } + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return bitwise_cast(self); + } + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return bitwise_cast(self); + } + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return bitwise_cast(self); + } + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return bitwise_cast(self); + } + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return bitwise_cast(self); + } + + // Provide a generic uint32_t -> float cast only if we have a + // non-generic int32_t -> float fast_cast + template const&>(), std::declval const&>(), A {}))> + XSIMD_INLINE batch fast_cast(batch const& v, batch const&, requires_arch) noexcept + { + // see https://stackoverflow.com/questions/34066228/how-to-perform-uint32-float-conversion-with-sse + batch msk_lo(0xFFFF); + batch cnst65536f(65536.0f); + + auto v_lo = batch_cast(v & msk_lo); /* extract the 16 lowest significant bits of self */ + auto v_hi = batch_cast(v >> 16); /* 16 most significant bits of v */ + auto v_lo_flt = batch_cast(v_lo); /* No rounding */ + auto v_hi_flt = batch_cast(v_hi); /* No rounding */ + v_hi_flt = cnst65536f * v_hi_flt; /* No rounding */ + return v_hi_flt + v_lo_flt; /* Rounding may occur here, mul and add may fuse to fma for haswell and newer */ + } + + // Provide a generic float -> uint32_t cast only if we have a + // non-generic float -> int32_t fast_cast + template const&>(), std::declval const&>(), A {}))> + XSIMD_INLINE batch fast_cast(batch const& v, batch const&, requires_arch) noexcept + { + auto is_large = v >= batch(1u << 31); + auto small_v = bitwise_cast(batch_cast(v)); + auto large_v = bitwise_cast( + batch_cast(v - batch(1u << 31)) + ^ batch(1u << 31)); + return bitwise_cast(select(is_large, large_v, small_v)); + } + } + + namespace detail + { + // Generic conversion handling machinery. Each architecture must define + // conversion function when such conversions exits in the form of + // intrinsic. 
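The uint32_t to float fast_cast above works because each 16-bit half converts exactly through the signed int32 path; only the final addition can round. A scalar equivalent:

#include <cstdint>

inline float u32_to_float(uint32_t v) {
  const float lo = static_cast<float>(static_cast<int32_t>(v & 0xFFFFu)); // exact
  const float hi = static_cast<float>(static_cast<int32_t>(v >> 16));     // exact
  return 65536.0f * hi + lo; // single rounding step, may fuse into an FMA
}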
Then we use that information to automatically decide whether + // to use scalar or vector conversion when doing load / store / batch_cast + struct with_fast_conversion + { + }; + struct with_slow_conversion + { + }; + + template + struct conversion_type_impl + { + using type = with_slow_conversion; + }; + + using xsimd::detail::void_t; + + template + struct conversion_type_impl&>(), + std::declval&>(), + std::declval()))>> + { + using type = with_fast_conversion; + }; + + template + using conversion_type = typename conversion_type_impl::type; + } + + namespace detail + { + /* origin: boost/simdfunction/horn.hpp*/ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE B coef() noexcept + { + using value_type = typename B::value_type; + return B(bit_cast(as_unsigned_integer_t(c))); + } + template + XSIMD_INLINE B horner(const B&) noexcept + { + return B(typename B::value_type(0.)); + } + + template + XSIMD_INLINE B horner(const B&) noexcept + { + return coef(); + } + + template + XSIMD_INLINE B horner(const B& self) noexcept + { + return fma(self, horner(self), coef()); + } + + /* origin: boost/simdfunction/horn1.hpp*/ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE B horner1(const B&) noexcept + { + return B(1.); + } + + template + XSIMD_INLINE B horner1(const B& x) noexcept + { + return x + detail::coef(); + } + + template + XSIMD_INLINE B horner1(const B& x) noexcept + { + return fma(x, horner1(x), detail::coef()); + } + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/generic/xsimd_generic_logical.hpp b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_logical.hpp new file mode 100644 index 0000000000000..4f5dd8e4bd04e --- /dev/null +++ b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_logical.hpp @@ -0,0 +1,208 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
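The horner/coef helpers above store polynomial coefficients as the raw bit pattern of their IEEE-754 value, decode them with a bit cast, and fold the polynomial with fma. A scalar sketch of the same idea for three coefficients:

#include <cmath>
#include <cstdint>
#include <cstring>

// Same role as xsimd's bit_cast of an encoded coefficient.
inline float coef_from_bits(uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

// p(x) = c0 + x*(c1 + x*c2), folded with fma like the vector horner.
inline float horner3(float x, uint32_t c0, uint32_t c1, uint32_t c2) {
  float acc = coef_from_bits(c2);
  acc = std::fma(x, acc, coef_from_bits(c1));
  return std::fma(x, acc, coef_from_bits(c0));
}
// e.g. horner3(x, 0x3f800000, 0x3f000000, 0x3e800000) evaluates 1 + 0.5*x + 0.25*x^2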
* + ****************************************************************************/ + +#ifndef XSIMD_GENERIC_LOGICAL_HPP +#define XSIMD_GENERIC_LOGICAL_HPP + +#include "./xsimd_generic_details.hpp" + +#include + +namespace xsimd +{ + + namespace kernel + { + + using namespace types; + + // count + template + XSIMD_INLINE size_t count(batch_bool const& self, requires_arch) noexcept + { + uint64_t m = self.mask(); + XSIMD_IF_CONSTEXPR(batch_bool::size < 14) + { + // https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSet64 + return (m * 0x200040008001ULL & 0x111111111111111ULL) % 0xf; + } + else + { +#if defined __has_builtin +#if __has_builtin(__builtin_popcountg) +#define builtin_popcount(v) __builtin_popcountg(v) +#endif +#endif + +#ifdef builtin_popcount + return builtin_popcount(m); +#else + // FIXME: we could do better by dispatching to the appropriate + // popcount instruction depending on the arch... + XSIMD_IF_CONSTEXPR(batch_bool::size <= 32) + { + uint32_t m32 = static_cast(m); + // https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel + m32 = m32 - ((m32 >> 1) & 0x55555555); // reuse input as temporary + m32 = (m32 & 0x33333333) + ((m32 >> 2) & 0x33333333); // temp + return (((m32 + (m32 >> 4)) & 0xF0F0F0F) * 0x1010101) >> 24; // count + } + else + { + // https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel + m = m - ((m >> 1) & (uint64_t) ~(uint64_t)0 / 3); // temp + m = (m & (uint64_t) ~(uint64_t)0 / 15 * 3) + ((m >> 2) & (uint64_t) ~(uint64_t)0 / 15 * 3); // temp + m = (m + (m >> 4)) & (uint64_t) ~(uint64_t)0 / 255 * 15; // temp + return (m * ((uint64_t) ~(uint64_t)0 / 255)) >> (sizeof(uint64_t) - 1) * CHAR_BIT; // count + } +#endif + } + } + + // from mask + template + XSIMD_INLINE batch_bool from_mask(batch_bool const&, uint64_t mask, requires_arch) noexcept + { + alignas(A::alignment()) bool buffer[batch_bool::size]; + // This is inefficient but should never be called. It's just a + // temporary implementation until arm support is added. 
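count() above uses the classic SWAR popcount when no builtin is available (plus a multiply-and-modulus shortcut for masks narrower than 14 bits). The 32-bit fallback, written out as a stand-alone function:

#include <cstdint>

inline uint32_t popcount32(uint32_t m) {
  m = m - ((m >> 1) & 0x55555555u);                              // 2-bit partial sums
  m = (m & 0x33333333u) + ((m >> 2) & 0x33333333u);              // 4-bit partial sums
  return (((m + (m >> 4)) & 0x0F0F0F0Fu) * 0x01010101u) >> 24;   // fold the byte sums
}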
+ for (size_t i = 0; i < batch_bool::size; ++i) + buffer[i] = mask & (1ull << i); + return batch_bool::load_aligned(buffer); + } + + // ge + template + XSIMD_INLINE batch_bool ge(batch const& self, batch const& other, requires_arch) noexcept + { + return other <= self; + } + + // gt + template + XSIMD_INLINE batch_bool gt(batch const& self, batch const& other, requires_arch) noexcept + { + return other < self; + } + + // is_even + template + XSIMD_INLINE batch_bool is_even(batch const& self, requires_arch) noexcept + { + return is_flint(self * T(0.5)); + } + + // is_flint + template + XSIMD_INLINE batch_bool is_flint(batch const& self, requires_arch) noexcept + { + auto frac = select(isnan(self - self), constants::nan>(), self - trunc(self)); + return frac == T(0.); + } + + // is_odd + template + XSIMD_INLINE batch_bool is_odd(batch const& self, requires_arch) noexcept + { + return is_even(self - T(1.)); + } + + // isinf + template ::value, void>::type> + XSIMD_INLINE batch_bool isinf(batch const&, requires_arch) noexcept + { + return batch_bool(false); + } + template + XSIMD_INLINE batch_bool isinf(batch const& self, requires_arch) noexcept + { + return abs(self) == std::numeric_limits::infinity(); + } + template + XSIMD_INLINE batch_bool isinf(batch const& self, requires_arch) noexcept + { + return abs(self) == std::numeric_limits::infinity(); + } + + // isfinite + template ::value, void>::type> + XSIMD_INLINE batch_bool isfinite(batch const&, requires_arch) noexcept + { + return batch_bool(true); + } + template + XSIMD_INLINE batch_bool isfinite(batch const& self, requires_arch) noexcept + { + return (self - self) == 0.f; + } + template + XSIMD_INLINE batch_bool isfinite(batch const& self, requires_arch) noexcept + { + return (self - self) == 0.; + } + + // isnan + template ::value, void>::type> + XSIMD_INLINE batch_bool isnan(batch const&, requires_arch) noexcept + { + return batch_bool(false); + } + + // le + template ::value, void>::type> + XSIMD_INLINE batch_bool le(batch const& self, batch const& other, requires_arch) noexcept + { + return (self < other) || (self == other); + } + + // neq + template + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch) noexcept + { + return !(other == self); + } + + // logical_and + template + XSIMD_INLINE batch logical_and(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::apply([](T x, T y) noexcept + { return x && y; }, + self, other); + } + + // logical_or + template + XSIMD_INLINE batch logical_or(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::apply([](T x, T y) noexcept + { return x || y; }, + self, other); + } + + // mask + template + XSIMD_INLINE uint64_t mask(batch_bool const& self, requires_arch) noexcept + { + alignas(A::alignment()) bool buffer[batch_bool::size]; + self.store_aligned(buffer); + // This is inefficient but should never be called. It's just a + // temporary implementation until arm support is added. 
+ uint64_t res = 0; + for (size_t i = 0; i < batch_bool::size; ++i) + if (buffer[i]) + res |= 1ul << i; + return res; + } + } +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/generic/xsimd_generic_math.hpp b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_math.hpp new file mode 100644 index 0000000000000..b8db7f805d141 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_math.hpp @@ -0,0 +1,2499 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_GENERIC_MATH_HPP +#define XSIMD_GENERIC_MATH_HPP + +#include "../xsimd_scalar.hpp" +#include "./xsimd_generic_details.hpp" +#include "./xsimd_generic_trigo.hpp" + +#include + +namespace xsimd +{ + + namespace kernel + { + + using namespace types; + // abs + template + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept + { + if (std::is_unsigned::value) + return self; + else + { + auto sign = bitofsign(self); + auto inv = self ^ sign; + return inv - sign; + } + } + + template + XSIMD_INLINE batch abs(batch, A> const& z, requires_arch) noexcept + { + return hypot(z.real(), z.imag()); + } + + // avg + namespace detail + { + template + XSIMD_INLINE batch avg(batch const& x, batch const& y, std::true_type, std::false_type) noexcept + { + return (x & y) + ((x ^ y) >> 1); + } + + template + XSIMD_INLINE batch avg(batch const& x, batch const& y, std::true_type, std::true_type) noexcept + { + // Inspired by + // https://stackoverflow.com/questions/5697500/take-the-average-of-two-signed-numbers-in-c + auto t = (x & y) + ((x ^ y) >> 1); + auto t_u = bitwise_cast::type>(t); + auto avg = t + (bitwise_cast(t_u >> (8 * sizeof(T) - 1)) & (x ^ y)); + return avg; + } + + template + XSIMD_INLINE batch avg(batch const& x, batch const& y, std::false_type, std::true_type) noexcept + { + return (x + y) / 2; + } + } + + template + XSIMD_INLINE batch avg(batch const& x, batch const& y, requires_arch) noexcept + { + return detail::avg(x, y, typename std::is_integral::type {}, typename std::is_signed::type {}); + } + + // avgr + namespace detail + { + template + XSIMD_INLINE batch avgr(batch const& x, batch const& y, std::true_type) noexcept + { + constexpr unsigned shift = 8 * sizeof(T) - 1; + auto adj = std::is_signed::value ? 
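Two idioms from the kernels above, in scalar form: abs negates through the sign mask without branching, and avg computes the unsigned mean without risking overflow in the intermediate sum.

#include <cstdint>

inline int32_t abs_branchless(int32_t x) {
  const int32_t sign = x >> 31;     // all ones if negative, else zero
  return (x ^ sign) - sign;         // conditional two's-complement negate
}

inline uint32_t avg_no_overflow(uint32_t x, uint32_t y) {
  return (x & y) + ((x ^ y) >> 1);  // common bits + half of the differing bits
}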
((x ^ y) & 0x1) : (((x ^ y) << shift) >> shift); + return ::xsimd::kernel::avg(x, y, A {}) + adj; + } + + template + XSIMD_INLINE batch avgr(batch const& x, batch const& y, std::false_type) noexcept + { + return ::xsimd::kernel::avg(x, y, A {}); + } + } + + template + XSIMD_INLINE batch avgr(batch const& x, batch const& y, requires_arch) noexcept + { + return detail::avgr(x, y, typename std::is_integral::type {}); + } + + // batch_cast + template + XSIMD_INLINE batch batch_cast(batch const& self, batch const&, requires_arch) noexcept + { + return self; + } + + namespace detail + { + template + XSIMD_INLINE batch batch_cast(batch const& self, batch const& out, requires_arch, with_fast_conversion) noexcept + { + return fast_cast(self, out, A {}); + } + template + XSIMD_INLINE batch batch_cast(batch const& self, batch const&, requires_arch, with_slow_conversion) noexcept + { + static_assert(!std::is_same::value, "there should be no conversion for this type combination"); + using batch_type_in = batch; + using batch_type_out = batch; + static_assert(batch_type_in::size == batch_type_out::size, "compatible sizes"); + alignas(A::alignment()) T_in buffer_in[batch_type_in::size]; + alignas(A::alignment()) T_out buffer_out[batch_type_out::size]; + self.store_aligned(&buffer_in[0]); + std::copy(std::begin(buffer_in), std::end(buffer_in), std::begin(buffer_out)); + return batch_type_out::load_aligned(buffer_out); + } + + } + + template + XSIMD_INLINE batch batch_cast(batch const& self, batch const& out, requires_arch) noexcept + { + return detail::batch_cast(self, out, A {}, detail::conversion_type {}); + } + + // bitofsign + template + XSIMD_INLINE batch bitofsign(batch const& self, requires_arch) noexcept + { + static_assert(std::is_integral::value, "int type implementation"); + if (std::is_unsigned::value) + return batch(0); + else + return self >> (T)(8 * sizeof(T) - 1); + } + + template + XSIMD_INLINE batch bitofsign(batch const& self, requires_arch) noexcept + { + return self & constants::signmask>(); + } + template + XSIMD_INLINE batch bitofsign(batch const& self, requires_arch) noexcept + { + return self & constants::signmask>(); + } + + // bitwise_cast + template + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return self; + } + + // cbrt + /* origin: boost/simd/arch/common/simd/function/cbrt.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. 
+ * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch cbrt(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type z = abs(self); +#ifndef XSIMD_NO_DENORMALS + auto denormal = z < constants::smallestposval(); + z = select(denormal, z * constants::twotonmb(), z); + batch_type f = select(denormal, constants::twotonmbo3(), batch_type(1.)); +#endif + const batch_type CBRT2(bit_cast(0x3fa14518)); + const batch_type CBRT4(bit_cast(0x3fcb2ff5)); + const batch_type CBRT2I(bit_cast(0x3f4b2ff5)); + const batch_type CBRT4I(bit_cast(0x3f214518)); + using i_type = as_integer_t; + i_type e; + batch_type x = frexp(z, e); + x = detail::horner(x); + auto flag = e >= i_type(0); + i_type e1 = abs(e); + i_type rem = e1; + e1 /= i_type(3); + rem -= e1 * i_type(3); + e = e1 * sign(e); + const batch_type cbrt2 = select(batch_bool_cast(flag), CBRT2, CBRT2I); + const batch_type cbrt4 = select(batch_bool_cast(flag), CBRT4, CBRT4I); + batch_type fact = select(batch_bool_cast(rem == i_type(1)), cbrt2, batch_type(1.)); + fact = select(batch_bool_cast(rem == i_type(2)), cbrt4, fact); + x = ldexp(x * fact, e); + x -= (x - z / (x * x)) * batch_type(1.f / 3.f); +#ifndef XSIMD_NO_DENORMALS + x = (x | bitofsign(self)) * f; +#else + x = x | bitofsign(self); +#endif +#ifndef XSIMD_NO_INFINITIES + return select(self == batch_type(0.) || isinf(self), self, x); +#else + return select(self == batch_type(0.), self, x); +#endif + } + + template + XSIMD_INLINE batch cbrt(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type z = abs(self); +#ifndef XSIMD_NO_DENORMALS + auto denormal = z < constants::smallestposval(); + z = select(denormal, z * constants::twotonmb(), z); + batch_type f = select(denormal, constants::twotonmbo3(), batch_type(1.)); +#endif + const batch_type CBRT2(bit_cast(int64_t(0x3ff428a2f98d728b))); + const batch_type CBRT4(bit_cast(int64_t(0x3ff965fea53d6e3d))); + const batch_type CBRT2I(bit_cast(int64_t(0x3fe965fea53d6e3d))); + const batch_type CBRT4I(bit_cast(int64_t(0x3fe428a2f98d728b))); + using i_type = as_integer_t; + i_type e; + batch_type x = frexp(z, e); + x = detail::horner(x); + auto flag = e >= typename i_type::value_type(0); + i_type e1 = abs(e); + i_type rem = e1; + e1 /= i_type(3); + rem -= e1 * i_type(3); + e = e1 * sign(e); + const batch_type cbrt2 = select(batch_bool_cast(flag), CBRT2, CBRT2I); + const batch_type cbrt4 = select(batch_bool_cast(flag), CBRT4, CBRT4I); + batch_type fact = select(batch_bool_cast(rem == i_type(1)), cbrt2, batch_type(1.)); + fact = select(batch_bool_cast(rem == i_type(2)), cbrt4, fact); + x = ldexp(x * fact, e); + x -= (x - z / (x * x)) * batch_type(1. / 3.); + x -= (x - z / (x * x)) * batch_type(1. / 3.); +#ifndef XSIMD_NO_DENORMALS + x = (x | bitofsign(self)) * f; +#else + x = x | bitofsign(self); +#endif +#ifndef XSIMD_NO_INFINITIES + return select(self == batch_type(0.) 
|| isinf(self), self, x); +#else + return select(self == batch_type(0.), self, x); +#endif + } + + // clip + template + XSIMD_INLINE batch clip(batch const& self, batch const& lo, batch const& hi, requires_arch) noexcept + { + return min(hi, max(self, lo)); + } + + // copysign + template ::value, void>::type> + XSIMD_INLINE batch copysign(batch const& self, batch const& other, requires_arch) noexcept + { + return abs(self) | bitofsign(other); + } + + // erf + + namespace detail + { + /* origin: boost/simd/arch/common/detail/generic/erf_kernel.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + struct erf_kernel; + + template + struct erf_kernel> + { + using batch_type = batch; + // computes erf(a0)/a0 + // x is sqr(a0) and 0 <= abs(a0) <= 2/3 + static XSIMD_INLINE batch_type erf1(const batch_type& x) noexcept + { + return detail::horner(x); + } + + // computes erfc(x)*exp(sqr(x)) + // x >= 2/3 + static XSIMD_INLINE batch_type erfc2(const batch_type& x) noexcept + { + return detail::horner(x); + } + + static XSIMD_INLINE batch_type erfc3(const batch_type& x) noexcept + { + return (batch_type(1.) - x) * detail::horner(x); + } + }; + + template + struct erf_kernel> + { + using batch_type = batch; + // computes erf(a0)/a0 + // x is sqr(a0) and 0 <= abs(a0) <= 0.65 + static XSIMD_INLINE batch_type erf1(const batch_type& x) noexcept + { + return detail::horner(x) + / detail::horner(x); + } + + // computes erfc(x)*exp(x*x) + // 0.65 <= abs(x) <= 2.2 + static XSIMD_INLINE batch_type erfc2(const batch_type& x) noexcept + { + return detail::horner(x) + / detail::horner(x); + } + + // computes erfc(x)*exp(x*x) + // 2.2 <= abs(x) <= 6 + static XSIMD_INLINE batch_type erfc3(const batch_type& x) noexcept + { + return detail::horner(x) + / detail::horner(x); + } + + // computes erfc(rx)*exp(rx*rx) + // x >= 6 rx = 1/x + static XSIMD_INLINE batch_type erfc4(const batch_type& x) noexcept + { + return detail::horner(x); + } + }; + } + /* origin: boost/simd/arch/common/simd/function/erf.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + + template + XSIMD_INLINE batch erf(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + batch_type r1(0.); + auto test1 = x < batch_type(2.f / 3.f); + if (any(test1)) + { + r1 = self * detail::erf_kernel::erf1(x * x); + if (all(test1)) + return r1; + } + batch_type z = x / (batch_type(1.) + x); + z -= batch_type(0.4f); + batch_type r2 = batch_type(1.) 
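copysign above is pure bit manipulation: clear the sign bit of the magnitude argument and OR in the sign bit of the other. The scalar equivalent for single precision:

#include <cstdint>
#include <cstring>

inline float copysign_bits(float x, float y) {
  uint32_t xb, yb;
  std::memcpy(&xb, &x, sizeof xb);
  std::memcpy(&yb, &y, sizeof yb);
  const uint32_t out = (xb & 0x7FFFFFFFu) | (yb & 0x80000000u); // magnitude of x, sign of y
  float r;
  std::memcpy(&r, &out, sizeof r);
  return r; // matches std::copysign(x, y)
}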
- exp(-x * x) * detail::erf_kernel::erfc2(z); + r2 = select(self < batch_type(0.), -r2, r2); + r1 = select(test1, r1, r2); +#ifndef XSIMD_NO_INFINITIES + r1 = select(xsimd::isinf(self), sign(self), r1); +#endif + return r1; + } + + template + XSIMD_INLINE batch erf(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + batch_type xx = x * x; + batch_type lim1(0.65); + batch_type lim2(2.2); + auto test1 = x < lim1; + batch_type r1(0.); + if (any(test1)) + { + r1 = self * detail::erf_kernel::erf1(xx); + if (all(test1)) + return r1; + } + auto test2 = x < lim2; + auto test3 = test2 && !test1; + batch_type ex = exp(-xx); + if (any(test3)) + { + batch_type z = batch_type(1.) - ex * detail::erf_kernel::erfc2(x); + batch_type r2 = select(self < batch_type(0.), -z, z); + r1 = select(test1, r1, r2); + if (all(test1 || test3)) + return r1; + } + batch_type z = batch_type(1.) - ex * detail::erf_kernel::erfc3(x); + z = select(self < batch_type(0.), -z, z); +#ifndef XSIMD_NO_INFINITIES + z = select(xsimd::isinf(self), sign(self), z); +#endif + return select(test2, r1, z); + } + + // erfc + template + XSIMD_INLINE batch erfc(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + auto test0 = self < batch_type(0.); + batch_type r1(0.); + auto test1 = 3.f * x < 2.f; + batch_type z = x / (batch_type(1.) + x); + if (any(test1)) + { + r1 = detail::erf_kernel::erfc3(z); + if (all(test1)) + return select(test0, batch_type(2.) - r1, r1); + } + + z -= batch_type(0.4f); + batch_type r2 = exp(-x * x) * detail::erf_kernel::erfc2(z); + r1 = select(test1, r1, r2); +#ifndef XSIMD_NO_INFINITIES + r1 = select(x == constants::infinity(), batch_type(0.), r1); +#endif + return select(test0, batch_type(2.) - r1, r1); + } + + template + XSIMD_INLINE batch erfc(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + batch_type xx = x * x; + batch_type lim1(0.65); + batch_type lim2(2.2); + auto test0 = self < batch_type(0.); + auto test1 = x < lim1; + batch_type r1(0.); + if (any(test1)) + { + r1 = batch_type(1.) - x * detail::erf_kernel::erf1(xx); + if (all(test1)) + return select(test0, batch_type(2.) - r1, r1); + } + auto test2 = x < lim2; + auto test3 = test2 && !test1; + batch_type ex = exp(-xx); + if (any(test3)) + { + batch_type z = ex * detail::erf_kernel::erfc2(x); + r1 = select(test1, r1, z); + if (all(test1 || test3)) + return select(test0, batch_type(2.) - r1, r1); + } + batch_type z = ex * detail::erf_kernel::erfc3(x); + r1 = select(test2, r1, z); +#ifndef XSIMD_NO_INFINITIES + r1 = select(x == constants::infinity(), batch_type(0.), r1); +#endif + return select(test0, batch_type(2.) - r1, r1); + } + + // estrin + namespace detail + { + + template + struct estrin + { + B x; + + template + XSIMD_INLINE B operator()(const Ts&... 
coefs) noexcept + { + return eval(coefs...); + } + + private: + XSIMD_INLINE B eval(const B& c0) noexcept + { + return c0; + } + + XSIMD_INLINE B eval(const B& c0, const B& c1) noexcept + { + return fma(x, c1, c0); + } + + template + XSIMD_INLINE B eval(::xsimd::detail::index_sequence, const Tuple& tuple) + { + return estrin { x * x }(std::get(tuple)...); + } + + template + XSIMD_INLINE B eval(const std::tuple& tuple) noexcept + { + return eval(::xsimd::detail::make_index_sequence(), tuple); + } + + template + XSIMD_INLINE B eval(const std::tuple& tuple, const B& c0) noexcept + { + return eval(std::tuple_cat(tuple, std::make_tuple(eval(c0)))); + } + + template + XSIMD_INLINE B eval(const std::tuple& tuple, const B& c0, const B& c1) noexcept + { + return eval(std::tuple_cat(tuple, std::make_tuple(eval(c0, c1)))); + } + + template + XSIMD_INLINE B eval(const std::tuple& tuple, const B& c0, const B& c1, const Ts&... coefs) noexcept + { + return eval(std::tuple_cat(tuple, std::make_tuple(eval(c0, c1))), coefs...); + } + + template + XSIMD_INLINE B eval(const B& c0, const B& c1, const Ts&... coefs) noexcept + { + return eval(std::make_tuple(eval(c0, c1)), coefs...); + } + }; + } + + template + XSIMD_INLINE batch estrin(const batch& self) noexcept + { + using batch_type = batch; + return detail::estrin { self }(detail::coef()...); + } + + // exp + /* origin: boost/simd/arch/common/detail/simd/expo_base.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + namespace detail + { + enum exp_reduction_tag + { + exp_tag, + exp2_tag, + exp10_tag + }; + + template + struct exp_reduction_base; + + template + struct exp_reduction_base + { + static constexpr B maxlog() noexcept + { + return constants::maxlog(); + } + + static constexpr B minlog() noexcept + { + return constants::minlog(); + } + }; + + template + struct exp_reduction_base + { + static constexpr B maxlog() noexcept + { + return constants::maxlog10(); + } + + static constexpr B minlog() noexcept + { + return constants::minlog10(); + } + }; + + template + struct exp_reduction_base + { + static constexpr B maxlog() noexcept + { + return constants::maxlog2(); + } + + static constexpr B minlog() noexcept + { + return constants::minlog2(); + } + }; + + template + struct exp_reduction; + + template + struct exp_reduction : exp_reduction_base, exp_tag> + { + using batch_type = batch; + static XSIMD_INLINE batch_type approx(const batch_type& x) noexcept + { + batch_type y = detail::horner(x); + return ++fma(y, x * x, x); + } + + static XSIMD_INLINE batch_type reduce(const batch_type& a, batch_type& x) noexcept + { + batch_type k = nearbyint(constants::invlog_2() * a); + x = fnma(k, constants::log_2hi(), a); + x = fnma(k, constants::log_2lo(), x); + return k; + } + }; + + template + struct exp_reduction : exp_reduction_base, exp10_tag> + { + using batch_type = batch; + static XSIMD_INLINE batch_type approx(const batch_type& x) noexcept + { + return ++(detail::horner(x) + * x); + } + + static XSIMD_INLINE batch_type reduce(const batch_type& a, batch_type& x) noexcept + { + batch_type k = nearbyint(constants::invlog10_2() * a); + x = fnma(k, constants::log10_2hi(), a); + x -= k * constants::log10_2lo(); + return k; + } + }; + + template + struct exp_reduction : exp_reduction_base, exp2_tag> + { + using batch_type = batch; + 
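estrin above evaluates a polynomial by pairing coefficients into (c0 + c1*x) terms and recursing on x*x, which exposes more instruction-level parallelism than a plain Horner chain at the same operation count. For four coefficients the scheme reduces to:

#include <cmath>

// p(x) = (c0 + c1*x) + (c2 + c3*x) * x^2
inline double estrin4(double x, double c0, double c1, double c2, double c3) {
  const double x2 = x * x;
  const double lo = std::fma(c1, x, c0);  // c0 + c1*x
  const double hi = std::fma(c3, x, c2);  // c2 + c3*x
  return std::fma(hi, x2, lo);            // the two halves can be computed in parallel
}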
static XSIMD_INLINE batch_type approx(const batch_type& x) noexcept + { + batch_type y = detail::horner(x); + return ++fma(y, x * x, x * constants::log_2()); + } + + static XSIMD_INLINE batch_type reduce(const batch_type& a, batch_type& x) noexcept + { + batch_type k = nearbyint(a); + x = (a - k); + return k; + } + }; + + template + struct exp_reduction : exp_reduction_base, exp_tag> + { + using batch_type = batch; + static XSIMD_INLINE batch_type approx(const batch_type& x) noexcept + { + batch_type t = x * x; + return fnma(t, + detail::horner(t), + x); + } + + static XSIMD_INLINE batch_type reduce(const batch_type& a, batch_type& hi, batch_type& lo, batch_type& x) noexcept + { + batch_type k = nearbyint(constants::invlog_2() * a); + hi = fnma(k, constants::log_2hi(), a); + lo = k * constants::log_2lo(); + x = hi - lo; + return k; + } + + static XSIMD_INLINE batch_type finalize(const batch_type& x, const batch_type& c, const batch_type& hi, const batch_type& lo) noexcept + { + return batch_type(1.) - (((lo - (x * c) / (batch_type(2.) - c)) - hi)); + } + }; + + template + struct exp_reduction : exp_reduction_base, exp10_tag> + { + using batch_type = batch; + static XSIMD_INLINE batch_type approx(const batch_type& x) noexcept + { + batch_type xx = x * x; + batch_type px = x * detail::horner(xx); + batch_type x2 = px / (detail::horner1(xx) - px); + return ++(x2 + x2); + } + + static XSIMD_INLINE batch_type reduce(const batch_type& a, batch_type&, batch_type&, batch_type& x) noexcept + { + batch_type k = nearbyint(constants::invlog10_2() * a); + x = fnma(k, constants::log10_2hi(), a); + x = fnma(k, constants::log10_2lo(), x); + return k; + } + + static XSIMD_INLINE batch_type finalize(const batch_type&, const batch_type& c, const batch_type&, const batch_type&) noexcept + { + return c; + } + }; + + template + struct exp_reduction : exp_reduction_base, exp2_tag> + { + using batch_type = batch; + static XSIMD_INLINE batch_type approx(const batch_type& x) noexcept + { + batch_type t = x * x; + return fnma(t, + detail::horner(t), + x); + } + + static XSIMD_INLINE batch_type reduce(const batch_type& a, batch_type&, batch_type&, batch_type& x) noexcept + { + batch_type k = nearbyint(a); + x = (a - k) * constants::log_2(); + return k; + } + + static XSIMD_INLINE batch_type finalize(const batch_type& x, const batch_type& c, const batch_type&, const batch_type&) noexcept + { + return batch_type(1.) + x + x * c / (batch_type(2.) 
- c); + } + }; + + template + XSIMD_INLINE batch exp(batch const& self) noexcept + { + using batch_type = batch; + using reducer_t = exp_reduction; + batch_type x; + batch_type k = reducer_t::reduce(self, x); + x = reducer_t::approx(x); + x = select(self <= reducer_t::minlog(), batch_type(0.), ldexp(x, to_int(k))); + x = select(self >= reducer_t::maxlog(), constants::infinity(), x); + return x; + } + + template + XSIMD_INLINE batch exp(batch const& self) noexcept + { + using batch_type = batch; + using reducer_t = exp_reduction; + batch_type hi, lo, x; + batch_type k = reducer_t::reduce(self, hi, lo, x); + batch_type c = reducer_t::approx(x); + c = reducer_t::finalize(x, c, hi, lo); + c = select(self <= reducer_t::minlog(), batch_type(0.), ldexp(c, to_int(k))); + c = select(self >= reducer_t::maxlog(), constants::infinity(), c); + return c; + } + } + + template + XSIMD_INLINE batch exp(batch const& self, requires_arch) noexcept + { + return detail::exp(self); + } + + template + XSIMD_INLINE batch, A> exp(batch, A> const& self, requires_arch) noexcept + { + using batch_type = batch, A>; + auto isincos = sincos(self.imag()); + return exp(self.real()) * batch_type(std::get<1>(isincos), std::get<0>(isincos)); + } + + // exp10 + template + XSIMD_INLINE batch exp10(batch const& self, requires_arch) noexcept + { + return detail::exp(self); + } + + // exp2 + template + XSIMD_INLINE batch exp2(batch const& self, requires_arch) noexcept + { + return detail::exp(self); + } + + // expm1 + namespace detail + { + /* origin: boost/simd/arch/common/detail/generic/expm1_kernel.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + static XSIMD_INLINE batch expm1(const batch& a) noexcept + { + using batch_type = batch; + batch_type k = nearbyint(constants::invlog_2() * a); + batch_type x = fnma(k, constants::log_2hi(), a); + x = fnma(k, constants::log_2lo(), x); + batch_type hx = x * batch_type(0.5); + batch_type hxs = x * hx; + batch_type r = detail::horner(hxs); + batch_type t = fnma(r, hx, batch_type(3.)); + batch_type e = hxs * ((r - t) / (batch_type(6.) - x * t)); + e = fms(x, e, hxs); + using i_type = as_integer_t; + i_type ik = to_int(k); + batch_type two2mk = ::xsimd::bitwise_cast((constants::maxexponent() - ik) << constants::nmb()); + batch_type y = batch_type(1.) - two2mk - (e - x); + return ldexp(y, ik); + } + + template + static XSIMD_INLINE batch expm1(const batch& a) noexcept + { + using batch_type = batch; + batch_type k = nearbyint(constants::invlog_2() * a); + batch_type hi = fnma(k, constants::log_2hi(), a); + batch_type lo = k * constants::log_2lo(); + batch_type x = hi - lo; + batch_type hxs = x * x * batch_type(0.5); + batch_type r = detail::horner(hxs); + batch_type t = batch_type(3.) - r * batch_type(0.5) * x; + batch_type e = hxs * ((r - t) / (batch_type(6) - x * t)); + batch_type c = (hi - x) - lo; + e = (x * (e - c) - c) - hxs; + using i_type = as_integer_t; + i_type ik = to_int(k); + batch_type two2mk = ::xsimd::bitwise_cast((constants::maxexponent() - ik) << constants::nmb()); + batch_type ct1 = batch_type(1.) 
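The exp() kernels above all follow the same shape: pick k = nearbyint(x / ln 2), reduce to r = x - k*ln 2 (with ln 2 split into hi/lo parts for accuracy), approximate e^r with a polynomial, scale by 2^k via ldexp, and clamp against minlog/maxlog. A scalar sketch with a deliberately short Taylor polynomial (illustrative accuracy only; the kernel uses higher-order minimax coefficients and the hi/lo split):

#include <cmath>

inline float exp_reduced(float a) {
  const float inv_ln2 = 1.4426950408889634f;
  const float ln2 = 0.6931471805599453f;
  const float k = std::nearbyint(a * inv_ln2);
  const float r = a - k * ln2;                      // |r| <= ln2/2
  const float p = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6 + r * (1.0f / 24))));
  return std::ldexp(p, static_cast<int>(k));        // p * 2^k
}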
- two2mk - (e - x); + batch_type ct2 = ++(x - (e + two2mk)); + batch_type y = select(k < batch_type(20.), ct1, ct2); + return ldexp(y, ik); + } + + } + + template + XSIMD_INLINE batch expm1(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + return select(self < constants::logeps(), + batch_type(-1.), + select(self > constants::maxlog(), + constants::infinity(), + detail::expm1(self))); + } + + template + XSIMD_INLINE batch, A> expm1(const batch, A>& z, requires_arch) noexcept + { + using batch_type = batch, A>; + using real_batch = typename batch_type::real_batch; + real_batch isin = sin(z.imag()); + real_batch rem1 = expm1(z.real()); + real_batch re = rem1 + 1.; + real_batch si = sin(z.imag() * 0.5); + return { rem1 - 2. * re * si * si, re * isin }; + } + + // polar + template + XSIMD_INLINE batch, A> polar(const batch& r, const batch& theta, requires_arch) noexcept + { + auto sincosTheta = sincos(theta); + return { r * sincosTheta.second, r * sincosTheta.first }; + } + + // fdim + template + XSIMD_INLINE batch fdim(batch const& self, batch const& other, requires_arch) noexcept + { + return fmax(batch(0), self - other); + } + + // fmod + template + XSIMD_INLINE batch fmod(batch const& self, batch const& other, requires_arch) noexcept + { + return fnma(trunc(self / other), other, self); + } + + // frexp + /* origin: boost/simd/arch/common/simd/function/ifrexp.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch frexp(const batch& self, batch, A>& exp, requires_arch) noexcept + { + using batch_type = batch; + using int_type = as_integer_t; + using i_type = batch; + i_type m1f = constants::mask1frexp(); + i_type r1 = m1f & ::xsimd::bitwise_cast(self); + batch_type x = self & ::xsimd::bitwise_cast(~m1f); + exp = (r1 >> constants::nmb()) - constants::maxexponentm1(); + exp = select(batch_bool_cast(self != batch_type(0.)), exp, i_type(typename i_type::value_type(0))); + return select((self != batch_type(0.)), x | ::xsimd::bitwise_cast(constants::mask2frexp()), batch_type(0.)); + } + + // from bool + template + XSIMD_INLINE batch from_bool(batch_bool const& self, requires_arch) noexcept + { + return batch(self.data) & batch(1); + } + + // horner + template + XSIMD_INLINE batch horner(const batch& self) noexcept + { + return detail::horner, Coefs...>(self); + } + + // hypot + template + XSIMD_INLINE batch hypot(batch const& self, batch const& other, requires_arch) noexcept + { + return sqrt(fma(self, self, other * other)); + } + + // ipow + template + XSIMD_INLINE batch ipow(batch const& self, ITy other, requires_arch) noexcept + { + return ::xsimd::detail::ipow(self, other); + } + + // ldexp + /* origin: boost/simd/arch/common/simd/function/ldexp.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. 
+ * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch ldexp(const batch& self, const batch, A>& other, requires_arch) noexcept + { + using batch_type = batch; + using itype = as_integer_t; + itype ik = other + constants::maxexponent(); + ik = ik << constants::nmb(); + return self * ::xsimd::bitwise_cast(ik); + } + + // lgamma + template + XSIMD_INLINE batch lgamma(batch const& self, requires_arch) noexcept; + + namespace detail + { + /* origin: boost/simd/arch/common/detail/generic/gammaln_kernel.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + static XSIMD_INLINE batch gammalnB(const batch& x) noexcept + { + return horner, + 0x3ed87730, // 4.227843421859038E-001 + 0x3ea51a64, // 3.224669577325661E-001, + 0xbd89f07e, // -6.735323259371034E-002, + 0x3ca89ed8, // 2.058355474821512E-002, + 0xbbf164fd, // -7.366775108654962E-003, + 0x3b3ba883, // 2.863437556468661E-003, + 0xbaabeab1, // -1.311620815545743E-003, + 0x3a1ebb94 // 6.055172732649237E-004 + >(x); + } + + template + static XSIMD_INLINE batch gammalnC(const batch& x) noexcept + { + return horner, + 0xbf13c468, // -5.772156501719101E-001 + 0x3f528d34, // 8.224670749082976E-001, + 0xbecd27a8, // -4.006931650563372E-001, + 0x3e8a898b, // 2.705806208275915E-001, + 0xbe53c04f, // -2.067882815621965E-001, + 0x3e2d4dab, // 1.692415923504637E-001, + 0xbe22d329, // -1.590086327657347E-001, + 0x3e0c3c4f // 1.369488127325832E-001 + >(x); + } + + template + static XSIMD_INLINE batch gammaln2(const batch& x) noexcept + { + return horner, + 0x3daaaa94, // 8.333316229807355E-002f + 0xbb358701, // -2.769887652139868E-003f, + 0x3a31fd69 // 6.789774945028216E-004f + >(x); + } + + template + static XSIMD_INLINE batch gammaln1(const batch& x) noexcept + { + return horner, + 0xc12a0c675418055eull, // -8.53555664245765465627E5 + 0xc13a45890219f20bull, // -1.72173700820839662146E6, + 0xc131bc82f994db51ull, // -1.16237097492762307383E6, + 0xc1143d73f89089e5ull, // -3.31612992738871184744E5, + 0xc0e2f234355bb93eull, // -3.88016315134637840924E4, + 0xc09589018ff36761ull // -1.37825152569120859100E3 + >(x) + / horner, + 0xc13ece4b6a11e14aull, // -2.01889141433532773231E6 + 0xc1435255892ff34cull, // -2.53252307177582951285E6, + 0xc131628671950043ull, // -1.13933444367982507207E6, + 0xc10aeb84b9744c9bull, // -2.20528590553854454839E5, + 0xc0d0aa0d7b89d757ull, // -1.70642106651881159223E4, + 0xc075fd0d1cf312b2ull, // -3.51815701436523470549E2, + 0x3ff0000000000000ull // 1.00000000000000000000E0 + >(x); + } + + template + static XSIMD_INLINE batch gammalnA(const batch& x) noexcept + { + return horner, + 0x3fb555555555554bull, // 8.33333333333331927722E-2 + 0xbf66c16c16b0a5a1ull, // -2.77777777730099687205E-3, + 0x3f4a019f20dc5ebbull, // 7.93650340457716943945E-4, + 0xbf437fbdb580e943ull, // -5.95061904284301438324E-4, + 0x3f4a985027336661ull // 8.11614167470508450300E-4 + >(x); + } + + /* origin: boost/simd/arch/common/simd/function/gammaln.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. 
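frexp and ldexp above operate directly on the exponent field rather than calling libm: frexp strips the exponent and substitutes that of 0.5, while ldexp builds 2^n from its bits and multiplies. Scalar models for single precision (no zero, denormal or overflow handling, matching the vector kernels):

#include <cstdint>
#include <cstring>

// Returns m in [0.5, 1) and e such that x == m * 2^e, for normal finite x.
inline float frexp_bits(float x, int32_t* e) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  *e = int32_t((bits & 0x7F800000u) >> 23) - 126;             // biased exponent minus 126
  const uint32_t mant = (bits & ~0x7F800000u) | 0x3F000000u;  // keep sign/mantissa, force exponent of 0.5
  float m;
  std::memcpy(&m, &mant, sizeof m);
  return m;
}

// x * 2^n, valid while n + 127 stays in [1, 254].
inline float ldexp_bits(float x, int32_t n) {
  const uint32_t bits = static_cast<uint32_t>(n + 127) << 23; // exponent field of 2^n
  float two_n;
  std::memcpy(&two_n, &bits, sizeof two_n);
  return x * two_n;
}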
+ * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + struct lgamma_impl; + + template + struct lgamma_impl> + { + using batch_type = batch; + static XSIMD_INLINE batch_type compute(const batch_type& a) noexcept + { + auto inf_result = (a <= batch_type(0.)) && is_flint(a); + batch_type x = select(inf_result, constants::nan(), a); + batch_type q = abs(x); +#ifndef XSIMD_NO_INFINITIES + inf_result = (x == constants::infinity()) || inf_result; +#endif + auto ltza = a < batch_type(0.); + batch_type r(0); + batch_type r1 = other(q); + if (any(ltza)) + { + r = select(inf_result, constants::infinity(), negative(q, r1)); + if (all(ltza)) + return r; + } + batch_type r2 = select(ltza, r, r1); + return select(a == constants::minusinfinity(), constants::nan(), select(inf_result, constants::infinity(), r2)); + } + + private: + static XSIMD_INLINE batch_type negative(const batch_type& q, const batch_type& w) noexcept + { + batch_type p = floor(q); + batch_type z = q - p; + auto test2 = z < batch_type(0.5); + z = select(test2, z - batch_type(1.), z); + z = q * sin(z, trigo_pi_tag()); + return -log(constants::invpi() * abs(z)) - w; + } + + static XSIMD_INLINE batch_type other(const batch_type& x) noexcept + { + auto xlt650 = (x < batch_type(6.5)); + batch_type r0x = x; + batch_type r0z = x; + batch_type r0s = batch_type(1.); + batch_type r1 = batch_type(0.); + batch_type p = constants::nan(); + if (any(xlt650)) + { + batch_type z = batch_type(1.); + batch_type tx = select(xlt650, x, batch_type(0.)); + batch_type nx = batch_type(0.); + const batch_type _075 = batch_type(0.75); + const batch_type _150 = batch_type(1.50); + const batch_type _125 = batch_type(1.25); + const batch_type _250 = batch_type(2.50); + auto xge150 = (x >= _150); + auto txgt250 = (tx > _250); + + // x >= 1.5 + while (any(xge150 && txgt250)) + { + nx = select(txgt250, nx - batch_type(1.), nx); + tx = select(txgt250, x + nx, tx); + z = select(txgt250, z * tx, z); + txgt250 = (tx > _250); + } + r0x = select(xge150, x + nx - batch_type(2.), x); + r0z = select(xge150, z, r0z); + r0s = select(xge150, batch_type(1.), r0s); + + // x >= 1.25 && x < 1.5 + auto xge125 = (x >= _125); + auto xge125t = xge125 && !xge150; + if (any(xge125)) + { + r0x = select(xge125t, x - batch_type(1.), r0x); + r0z = select(xge125t, z * x, r0z); + r0s = select(xge125t, batch_type(-1.), r0s); + } + + // x >= 0.75 && x < 1.5 + batch_bool kernelC(false); + auto xge075 = (x >= _075); + auto xge075t = xge075 && !xge125; + if (any(xge075t)) + { + kernelC = xge075t; + r0x = select(xge075t, x - batch_type(1.), x); + r0z = select(xge075t, batch_type(1.), r0z); + r0s = select(xge075t, batch_type(-1.), r0s); + p = gammalnC(r0x); + } + + // tx < 1.5 && x < 0.75 + auto txlt150 = (tx < _150) && !xge075; + if (any(txlt150)) + { + auto orig = txlt150; + while (any(txlt150)) + { + z = select(txlt150, z * tx, z); + nx = select(txlt150, nx + batch_type(1.), nx); + tx = select(txlt150, x + nx, tx); + txlt150 = (tx < _150) && !xge075; + } + r0x = select(orig, r0x + nx - batch_type(2.), r0x); + r0z = select(orig, z, r0z); + r0s = select(orig, batch_type(-1.), r0s); + } + p = select(kernelC, p, gammalnB(r0x)); + if (all(xlt650)) + return fma(r0x, p, r0s * log(abs(r0z))); + } + r0z = select(xlt650, abs(r0z), x); + batch_type m = log(r0z); + r1 = fma(r0x, p, r0s * m); + batch_type r2 = fma(x - batch_type(0.5), m, constants::logsqrt2pi() - x); + r2 += gammaln2(batch_type(1.) 
/ (x * x)) / x; + return select(xlt650, r1, r2); + } + }; + + template + struct lgamma_impl> + { + using batch_type = batch; + + static XSIMD_INLINE batch_type compute(const batch_type& a) noexcept + { + auto inf_result = (a <= batch_type(0.)) && is_flint(a); + batch_type x = select(inf_result, constants::nan(), a); + batch_type q = abs(x); +#ifndef XSIMD_NO_INFINITIES + inf_result = (q == constants::infinity()); +#endif + auto test = (a < batch_type(-34.)); + batch_type r = constants::nan(); + if (any(test)) + { + r = large_negative(q); + if (all(test)) + return select(inf_result, constants::nan(), r); + } + batch_type r1 = other(a); + batch_type r2 = select(test, r, r1); + return select(a == constants::minusinfinity(), constants::nan(), select(inf_result, constants::infinity(), r2)); + } + + private: + // FIXME: cannot mark this one as XSIMD_INLINE because there's a + // recursive loop on `lgamma'. + static inline batch_type large_negative(const batch_type& q) noexcept + { + batch_type w = lgamma(q); + batch_type p = floor(q); + batch_type z = q - p; + auto test2 = (z < batch_type(0.5)); + z = select(test2, z - batch_type(1.), z); + z = q * sin(z, trigo_pi_tag()); + z = abs(z); + return constants::logpi() - log(z) - w; + } + + static XSIMD_INLINE batch_type other(const batch_type& xx) noexcept + { + batch_type x = xx; + auto test = (x < batch_type(13.)); + batch_type r1 = batch_type(0.); + if (any(test)) + { + batch_type z = batch_type(1.); + batch_type p = batch_type(0.); + batch_type u = select(test, x, batch_type(0.)); + auto test1 = (u >= batch_type(3.)); + while (any(test1)) + { + p = select(test1, p - batch_type(1.), p); + u = select(test1, x + p, u); + z = select(test1, z * u, z); + test1 = (u >= batch_type(3.)); + } + + auto test2 = (u < batch_type(2.)); + while (any(test2)) + { + z = select(test2, z / u, z); + p = select(test2, p + batch_type(1.), p); + u = select(test2, x + p, u); + test2 = (u < batch_type(2.)); + } + + z = abs(z); + x += p - batch_type(2.); + r1 = x * gammaln1(x) + log(z); + if (all(test)) + return r1; + } + batch_type r2 = fma(xx - batch_type(0.5), log(xx), constants::logsqrt2pi() - xx); + batch_type p = batch_type(1.) / (xx * xx); + r2 += gammalnA(p) / xx; + return select(test, r1, r2); + } + }; + } + + template + XSIMD_INLINE batch lgamma(batch const& self, requires_arch) noexcept + { + return detail::lgamma_impl>::compute(self); + } + + // log + /* origin: boost/simd/arch/common/simd/function/log.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch log(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + using int_type = as_integer_t; + using i_type = batch; + batch_type x = self; + i_type k(0); + auto isnez = (self != batch_type(0.)); +#ifndef XSIMD_NO_DENORMALS + auto test = (self < constants::smallestposval()) && isnez; + if (any(test)) + { + k = select(batch_bool_cast(test), k - i_type(23), k); + x = select(test, x * batch_type(8388608ul), x); + } +#endif + i_type ix = ::xsimd::bitwise_cast(x); + ix += 0x3f800000 - 0x3f3504f3; + k += (ix >> 23) - 0x7f; + ix = (ix & i_type(0x007fffff)) + 0x3f3504f3; + x = ::xsimd::bitwise_cast(ix); + batch_type f = --x; + batch_type s = f / (batch_type(2.) 
+ f); + batch_type z = s * s; + batch_type w = z * z; + batch_type t1 = w * detail::horner(w); + batch_type t2 = z * detail::horner(w); + batch_type R = t2 + t1; + batch_type hfsq = batch_type(0.5) * f * f; + batch_type dk = to_float(k); + batch_type r = fma(dk, constants::log_2hi(), fma(s, (hfsq + R), dk * constants::log_2lo()) - hfsq + f); +#ifndef XSIMD_NO_INFINITIES + batch_type zz = select(isnez, select(self == constants::infinity(), constants::infinity(), r), constants::minusinfinity()); +#else + batch_type zz = select(isnez, r, constants::minusinfinity()); +#endif + return select(!(self >= batch_type(0.)), constants::nan(), zz); + } + + template + XSIMD_INLINE batch log(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + using int_type = as_integer_t; + using i_type = batch; + + batch_type x = self; + i_type hx = ::xsimd::bitwise_cast(x) >> 32; + i_type k(0); + auto isnez = (self != batch_type(0.)); +#ifndef XSIMD_NO_DENORMALS + auto test = (self < constants::smallestposval()) && isnez; + if (any(test)) + { + k = select(batch_bool_cast(test), k - i_type(54), k); + x = select(test, x * batch_type(18014398509481984ull), x); + } +#endif + hx += 0x3ff00000 - 0x3fe6a09e; + k += (hx >> 20) - 0x3ff; + batch_type dk = to_float(k); + hx = (hx & i_type(0x000fffff)) + 0x3fe6a09e; + x = ::xsimd::bitwise_cast(hx << 32 | (i_type(0xffffffff) & ::xsimd::bitwise_cast(x))); + + batch_type f = --x; + batch_type hfsq = batch_type(0.5) * f * f; + batch_type s = f / (batch_type(2.) + f); + batch_type z = s * s; + batch_type w = z * z; + + batch_type t1 = w * detail::horner(w); + batch_type t2 = z * detail::horner(w); + batch_type R = t2 + t1; + batch_type r = fma(dk, constants::log_2hi(), fma(s, (hfsq + R), dk * constants::log_2lo()) - hfsq + f); +#ifndef XSIMD_NO_INFINITIES + batch_type zz = select(isnez, select(self == constants::infinity(), constants::infinity(), r), constants::minusinfinity()); +#else + batch_type zz = select(isnez, r, constants::minusinfinity()); +#endif + return select(!(self >= batch_type(0.)), constants::nan(), zz); + } + + template + XSIMD_INLINE batch, A> log(const batch, A>& z, requires_arch) noexcept + { + return batch, A>(log(abs(z)), atan2(z.imag(), z.real())); + } + + // log2 + template + XSIMD_INLINE batch log2(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + using int_type = as_integer_t; + using i_type = batch; + batch_type x = self; + i_type k(0); + auto isnez = (self != batch_type(0.)); +#ifndef XSIMD_NO_DENORMALS + auto test = (self < constants::smallestposval()) && isnez; + if (any(test)) + { + k = select(batch_bool_cast(test), k - i_type(25), k); + x = select(test, x * batch_type(33554432ul), x); + } +#endif + i_type ix = ::xsimd::bitwise_cast(x); + ix += 0x3f800000 - 0x3f3504f3; + k += (ix >> 23) - 0x7f; + ix = (ix & i_type(0x007fffff)) + 0x3f3504f3; + x = ::xsimd::bitwise_cast(ix); + batch_type f = --x; + batch_type s = f / (batch_type(2.) 
+ f); + batch_type z = s * s; + batch_type w = z * z; + batch_type t1 = w * detail::horner(w); + batch_type t2 = z * detail::horner(w); + batch_type R = t1 + t2; + batch_type hfsq = batch_type(0.5) * f * f; + batch_type dk = to_float(k); + batch_type r = fma(fms(s, hfsq + R, hfsq) + f, constants::invlog_2(), dk); +#ifndef XSIMD_NO_INFINITIES + batch_type zz = select(isnez, select(self == constants::infinity(), constants::infinity(), r), constants::minusinfinity()); +#else + batch_type zz = select(isnez, r, constants::minusinfinity()); +#endif + return select(!(self >= batch_type(0.)), constants::nan(), zz); + } + + template + XSIMD_INLINE batch log2(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + using int_type = as_integer_t; + using i_type = batch; + batch_type x = self; + i_type hx = ::xsimd::bitwise_cast(x) >> 32; + i_type k(0); + auto isnez = (self != batch_type(0.)); +#ifndef XSIMD_NO_DENORMALS + auto test = (self < constants::smallestposval()) && isnez; + if (any(test)) + { + k = select(batch_bool_cast(test), k - i_type(54), k); + x = select(test, x * batch_type(18014398509481984ull), x); + } +#endif + hx += 0x3ff00000 - 0x3fe6a09e; + k += (hx >> 20) - 0x3ff; + hx = (hx & i_type(0x000fffff)) + 0x3fe6a09e; + x = ::xsimd::bitwise_cast(hx << 32 | (i_type(0xffffffff) & ::xsimd::bitwise_cast(x))); + batch_type f = --x; + batch_type s = f / (batch_type(2.) + f); + batch_type z = s * s; + batch_type w = z * z; + batch_type t1 = w * detail::horner(w); + batch_type t2 = z * detail::horner(w); + batch_type R = t2 + t1; + batch_type hfsq = batch_type(0.5) * f * f; + batch_type hi = f - hfsq; + hi = hi & ::xsimd::bitwise_cast((constants::allbits() << 32)); + batch_type lo = fma(s, hfsq + R, f - hi - hfsq); + batch_type val_hi = hi * constants::invlog_2hi(); + batch_type val_lo = fma(lo + hi, constants::invlog_2lo(), lo * constants::invlog_2hi()); + batch_type dk = to_float(k); + batch_type w1 = dk + val_hi; + val_lo += (dk - w1) + val_hi; + val_hi = w1; + batch_type r = val_lo + val_hi; +#ifndef XSIMD_NO_INFINITIES + batch_type zz = select(isnez, select(self == constants::infinity(), constants::infinity(), r), constants::minusinfinity()); +#else + batch_type zz = select(isnez, r, constants::minusinfinity()); +#endif + return select(!(self >= batch_type(0.)), constants::nan(), zz); + } + + namespace detail + { + template + XSIMD_INLINE batch logN_complex_impl(const batch& z, typename batch::value_type base) noexcept + { + using batch_type = batch; + using rv_type = typename batch_type::value_type; + return log(z) / batch_type(rv_type(base)); + } + } + + template + XSIMD_INLINE batch, A> log2(batch, A> const& self, requires_arch) noexcept + { + return detail::logN_complex_impl(self, std::log(2)); + } + + // log10 + /* origin: FreeBSD /usr/src/lib/msun/src/e_log10f.c */ + /* + * ==================================================== + * Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved. + * + * Developed at SunPro, a Sun Microsystems, Inc. business. + * Permission to use, copy, modify, and distribute this + * software is freely granted, provided that this notice + * is preserved. 
+ * ==================================================== + */ + template + XSIMD_INLINE batch log10(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + const batch_type + ivln10hi(4.3432617188e-01f), + ivln10lo(-3.1689971365e-05f), + log10_2hi(3.0102920532e-01f), + log10_2lo(7.9034151668e-07f); + using int_type = as_integer_t; + using i_type = batch; + batch_type x = self; + i_type k(0); + auto isnez = (self != batch_type(0.)); +#ifndef XSIMD_NO_DENORMALS + auto test = (self < constants::smallestposval()) && isnez; + if (any(test)) + { + k = select(batch_bool_cast(test), k - i_type(25), k); + x = select(test, x * batch_type(33554432ul), x); + } +#endif + i_type ix = ::xsimd::bitwise_cast(x); + ix += 0x3f800000 - 0x3f3504f3; + k += (ix >> 23) - 0x7f; + ix = (ix & i_type(0x007fffff)) + 0x3f3504f3; + x = ::xsimd::bitwise_cast(ix); + batch_type f = --x; + batch_type s = f / (batch_type(2.) + f); + batch_type z = s * s; + batch_type w = z * z; + batch_type t1 = w * detail::horner(w); + batch_type t2 = z * detail::horner(w); + batch_type R = t2 + t1; + batch_type dk = to_float(k); + batch_type hfsq = batch_type(0.5) * f * f; + batch_type hibits = f - hfsq; + hibits &= ::xsimd::bitwise_cast(i_type(0xfffff000)); + batch_type lobits = fma(s, hfsq + R, f - hibits - hfsq); + batch_type r = fma(dk, log10_2hi, + fma(hibits, ivln10hi, + fma(lobits, ivln10hi, + fma(lobits + hibits, ivln10lo, dk * log10_2lo)))); +#ifndef XSIMD_NO_INFINITIES + batch_type zz = select(isnez, select(self == constants::infinity(), constants::infinity(), r), constants::minusinfinity()); +#else + batch_type zz = select(isnez, r, constants::minusinfinity()); +#endif + return select(!(self >= batch_type(0.)), constants::nan(), zz); + } + + template + XSIMD_INLINE batch log10(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + const batch_type + ivln10hi(4.34294481878168880939e-01), + ivln10lo(2.50829467116452752298e-11), + log10_2hi(3.01029995663611771306e-01), + log10_2lo(3.69423907715893078616e-13); + using int_type = as_integer_t; + using i_type = batch; + batch_type x = self; + i_type hx = ::xsimd::bitwise_cast(x) >> 32; + i_type k(0); + auto isnez = (self != batch_type(0.)); +#ifndef XSIMD_NO_DENORMALS + auto test = (self < constants::smallestposval()) && isnez; + if (any(test)) + { + k = select(batch_bool_cast(test), k - i_type(54), k); + x = select(test, x * batch_type(18014398509481984ull), x); + } +#endif + hx += 0x3ff00000 - 0x3fe6a09e; + k += (hx >> 20) - 0x3ff; + hx = (hx & i_type(0x000fffff)) + 0x3fe6a09e; + x = ::xsimd::bitwise_cast(hx << 32 | (i_type(0xffffffff) & ::xsimd::bitwise_cast(x))); + batch_type f = --x; + batch_type dk = to_float(k); + batch_type s = f / (batch_type(2.) 
+ f); + batch_type z = s * s; + batch_type w = z * z; + batch_type t1 = w * detail::horner(w); + batch_type t2 = z * detail::horner(w); + batch_type R = t2 + t1; + batch_type hfsq = batch_type(0.5) * f * f; + batch_type hi = f - hfsq; + hi = hi & ::xsimd::bitwise_cast(constants::allbits() << 32); + batch_type lo = f - hi - hfsq + s * (hfsq + R); + batch_type val_hi = hi * ivln10hi; + batch_type y = dk * log10_2hi; + batch_type val_lo = dk * log10_2lo + (lo + hi) * ivln10lo + lo * ivln10hi; + batch_type w1 = y + val_hi; + val_lo += (y - w1) + val_hi; + val_hi = w1; + batch_type r = val_lo + val_hi; +#ifndef XSIMD_NO_INFINITIES + batch_type zz = select(isnez, select(self == constants::infinity(), constants::infinity(), r), constants::minusinfinity()); +#else + batch_type zz = select(isnez, r, constants::minusinfinity()); +#endif + return select(!(self >= batch_type(0.)), constants::nan(), zz); + } + + template + XSIMD_INLINE batch, A> log10(const batch, A>& z, requires_arch) noexcept + { + return detail::logN_complex_impl(z, std::log(10)); + } + + // log1p + /* origin: boost/simd/arch/common/simd/function/log1p.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch log1p(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + using int_type = as_integer_t; + using i_type = batch; + const batch_type uf = self + batch_type(1.); + auto isnez = (uf != batch_type(0.)); + i_type iu = ::xsimd::bitwise_cast(uf); + iu += 0x3f800000 - 0x3f3504f3; + i_type k = (iu >> 23) - 0x7f; + iu = (iu & i_type(0x007fffff)) + 0x3f3504f3; + batch_type f = --(::xsimd::bitwise_cast(iu)); + batch_type s = f / (batch_type(2.) + f); + batch_type z = s * s; + batch_type w = z * z; + batch_type t1 = w * detail::horner(w); + batch_type t2 = z * detail::horner(w); + batch_type R = t2 + t1; + batch_type hfsq = batch_type(0.5) * f * f; + batch_type dk = to_float(k); + /* correction term ~ log(1+x)-log(u), avoid underflow in c/u */ + batch_type c = select(batch_bool_cast(k >= i_type(2)), batch_type(1.) - (uf - self), self - (uf - batch_type(1.))) / uf; + batch_type r = fma(dk, constants::log_2hi(), fma(s, (hfsq + R), dk * constants::log_2lo() + c) - hfsq + f); +#ifndef XSIMD_NO_INFINITIES + batch_type zz = select(isnez, select(self == constants::infinity(), constants::infinity(), r), constants::minusinfinity()); +#else + batch_type zz = select(isnez, r, constants::minusinfinity()); +#endif + return select(!(uf >= batch_type(0.)), constants::nan(), zz); + } + + template + XSIMD_INLINE batch log1p(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + using int_type = as_integer_t; + using i_type = batch; + const batch_type uf = self + batch_type(1.); + auto isnez = (uf != batch_type(0.)); + i_type hu = ::xsimd::bitwise_cast(uf) >> 32; + hu += 0x3ff00000 - 0x3fe6a09e; + i_type k = (hu >> 20) - 0x3ff; + /* correction term ~ log(1+x)-log(u), avoid underflow in c/u */ + batch_type c = select(batch_bool_cast(k >= i_type(2)), batch_type(1.) - (uf - self), self - (uf - batch_type(1.))) / uf; + hu = (hu & i_type(0x000fffff)) + 0x3fe6a09e; + batch_type f = ::xsimd::bitwise_cast((hu << 32) | (i_type(0xffffffff) & ::xsimd::bitwise_cast(uf))); + f = --f; + batch_type hfsq = batch_type(0.5) * f * f; + batch_type s = f / (batch_type(2.) 
+ f); + batch_type z = s * s; + batch_type w = z * z; + batch_type t1 = w * detail::horner(w); + batch_type t2 = z * detail::horner(w); + batch_type R = t2 + t1; + batch_type dk = to_float(k); + batch_type r = fma(dk, constants::log_2hi(), fma(s, hfsq + R, dk * constants::log_2lo() + c) - hfsq + f); +#ifndef XSIMD_NO_INFINITIES + batch_type zz = select(isnez, select(self == constants::infinity(), constants::infinity(), r), constants::minusinfinity()); +#else + batch_type zz = select(isnez, r, constants::minusinfinity()); +#endif + return select(!(uf >= batch_type(0.)), constants::nan(), zz); + } + + template + XSIMD_INLINE batch, A> log1p(batch, A> const& self, requires_arch) noexcept + { + using batch_type = batch, A>; + using real_batch = typename batch_type::real_batch; + batch_type u = 1 + self; + batch_type logu = log(u); + return select(u == batch_type(1.), + self, + select(u.real() <= real_batch(0.), + logu, + logu * self / (u - batch_type(1.)))); + } + + // mod + template ::value, void>::type> + XSIMD_INLINE batch mod(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::apply([](T x, T y) noexcept -> T + { return x % y; }, + self, other); + } + + // nearbyint + template ::value, void>::type> + XSIMD_INLINE batch nearbyint(batch const& self, requires_arch) noexcept + { + return self; + } + namespace detail + { + template + XSIMD_INLINE batch nearbyintf(batch const& self) noexcept + { + using batch_type = batch; + batch_type s = bitofsign(self); + batch_type v = self ^ s; + batch_type t2n = constants::twotonmb(); + // Under fast-math, reordering is possible and the compiler optimizes d + // to v. That's not what we want, so prevent compiler optimization here. + // FIXME: it may be better to emit a memory barrier here (?). 
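+            // The trick: for |v| < 2^(mantissa bits), (v + twotonmb) - twotonmb rounds v
+            // to the nearest integer using the FPU's current rounding mode; the select on
+            // (v < t2n) below guards the magnitude. The volatile temporary in the
+            // __FAST_MATH__ branch keeps fast-math from folding the add/sub pair away.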
+#ifdef __FAST_MATH__ + volatile batch_type d0 = v + t2n; + batch_type d = *(batch_type*)(void*)(&d0) - t2n; +#else + batch_type d0 = v + t2n; + batch_type d = d0 - t2n; +#endif + return s ^ select(v < t2n, d, v); + } + } + template + XSIMD_INLINE batch nearbyint(batch const& self, requires_arch) noexcept + { + return detail::nearbyintf(self); + } + template + XSIMD_INLINE batch nearbyint(batch const& self, requires_arch) noexcept + { + return detail::nearbyintf(self); + } + + // nearbyint_as_int + template ::value, void>::type> + XSIMD_INLINE batch nearbyint_as_int(batch const& self, requires_arch) noexcept + { + return self; + } + + // nearbyint_as_int + template + XSIMD_INLINE batch, A> + nearbyint_as_int(batch const& self, requires_arch) noexcept + { + using U = as_integer_t; + return kernel::detail::apply_transform([](float x) noexcept -> U + { return std::nearbyintf(x); }, + self); + } + + template + XSIMD_INLINE batch, A> + nearbyint_as_int(batch const& self, requires_arch) noexcept + { + using U = as_integer_t; + return kernel::detail::apply_transform([](double x) noexcept -> U + { return std::nearbyint(x); }, + self); + } + + // nextafter + namespace detail + { + template ::value> + struct nextafter_kernel + { + using batch_type = batch; + + static XSIMD_INLINE batch_type next(batch_type const& b) noexcept + { + return b; + } + + static XSIMD_INLINE batch_type prev(batch_type const& b) noexcept + { + return b; + } + }; + + template + struct bitwise_cast_batch; + + template + struct bitwise_cast_batch + { + using type = batch; + }; + + template + struct bitwise_cast_batch + { + using type = batch; + }; + + template + struct nextafter_kernel + { + using batch_type = batch; + using int_batch = typename bitwise_cast_batch::type; + using int_type = typename int_batch::value_type; + + static XSIMD_INLINE batch_type next(const batch_type& b) noexcept + { + batch_type n = ::xsimd::bitwise_cast(::xsimd::bitwise_cast(b) + int_type(1)); + return select(b == constants::infinity(), b, n); + } + + static XSIMD_INLINE batch_type prev(const batch_type& b) noexcept + { + batch_type p = ::xsimd::bitwise_cast(::xsimd::bitwise_cast(b) - int_type(1)); + return select(b == constants::minusinfinity(), b, p); + } + }; + } + template + XSIMD_INLINE batch nextafter(batch const& from, batch const& to, requires_arch) noexcept + { + using kernel = detail::nextafter_kernel; + return select(from == to, from, + select(to > from, kernel::next(from), kernel::prev(from))); + } + + // pow + /* origin: boost/simd/arch/common/simd/function/pow.hpp*/ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. 
+ * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch pow(batch const& self, batch const& other, requires_arch) noexcept + { + using batch_type = batch; + const auto zero = batch_type(0.); + auto negself = self < zero; + auto iszeropowpos = self == zero && other >= zero; + auto adj_self = select(iszeropowpos, batch_type(1), abs(self)); + batch_type z = exp(other * log(adj_self)); + z = select(iszeropowpos, zero, z); + z = select(is_odd(other) && negself, -z, z); + auto invalid = negself && !(is_flint(other) || isinf(other)); + return select(invalid, constants::nan(), z); + } + + template + XSIMD_INLINE batch, A> pow(const batch, A>& a, const batch, A>& z, requires_arch) noexcept + { + using cplx_batch = batch, A>; + using real_batch = typename cplx_batch::real_batch; + real_batch absa = abs(a); + real_batch arga = arg(a); + real_batch x = z.real(); + real_batch y = z.imag(); + real_batch r = pow(absa, x); + real_batch theta = x * arga; + real_batch ze(0); + auto cond = (y == ze); + r = select(cond, r, r * exp(-y * arga)); + theta = select(cond, theta, theta + y * log(absa)); + auto sincosTheta = xsimd::sincos(theta); + return select(absa == ze, cplx_batch(ze), cplx_batch(r * sincosTheta.second, r * sincosTheta.first)); + } + + template + inline batch, A> pow(const batch, A>& a, const batch& z, requires_arch) noexcept + { + using cplx_batch = batch, A>; + + auto absa = abs(a); + auto arga = arg(a); + auto r = pow(absa, z); + + auto theta = z * arga; + auto sincosTheta = xsimd::sincos(theta); + return select(absa == 0, cplx_batch(0), cplx_batch(r * sincosTheta.second, r * sincosTheta.first)); + } + + template + inline batch, A> pow(const batch& a, const batch, A>& z, requires_arch) noexcept + { + return pow(batch, A> { a, batch {} }, z); + } + + // reciprocal + template ::value, void>::type> + XSIMD_INLINE batch reciprocal(batch const& self, + requires_arch) noexcept + { + using batch_type = batch; + return div(batch_type(1), self); + } + + // reduce_add + template + XSIMD_INLINE std::complex reduce_add(batch, A> const& self, requires_arch) noexcept + { + return { reduce_add(self.real()), reduce_add(self.imag()) }; + } + + namespace detail + { + template + struct split_high + { + static constexpr T get(T i, T) + { + return i >= N ? 
(i % 2) : i + N; + } + }; + + template + XSIMD_INLINE T reduce(Op, batch const& self, std::integral_constant) noexcept + { + return self.get(0); + } + + template + XSIMD_INLINE T reduce(Op op, batch const& self, std::integral_constant) noexcept + { + using index_type = as_unsigned_integer_t; + batch split = swizzle(self, make_batch_constant>()); + return reduce(op, op(split, self), std::integral_constant()); + } + } + + // reduce_max + template + XSIMD_INLINE T reduce_max(batch const& self, requires_arch) noexcept + { + return detail::reduce([](batch const& x, batch const& y) + { return max(x, y); }, + self, std::integral_constant::size>()); + } + + // reduce_min + template + XSIMD_INLINE T reduce_min(batch const& self, requires_arch) noexcept + { + return detail::reduce([](batch const& x, batch const& y) + { return min(x, y); }, + self, std::integral_constant::size>()); + } + + // remainder + template + XSIMD_INLINE batch remainder(batch const& self, batch const& other, requires_arch) noexcept + { + return fnma(nearbyint(self / other), other, self); + } + template + XSIMD_INLINE batch remainder(batch const& self, batch const& other, requires_arch) noexcept + { + return fnma(nearbyint(self / other), other, self); + } + template ::value, void>::type> + XSIMD_INLINE batch remainder(batch const& self, batch const& other, requires_arch) noexcept + { + auto mod = self % other; + return select(mod <= other / 2, mod, mod - other); + } + + // select + template + XSIMD_INLINE batch, A> select(batch_bool const& cond, batch, A> const& true_br, batch, A> const& false_br, requires_arch) noexcept + { + return { select(cond, true_br.real(), false_br.real()), select(cond, true_br.imag(), false_br.imag()) }; + } + + // sign + template ::value, void>::type> + XSIMD_INLINE batch sign(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type res = select(self > batch_type(0), batch_type(1), batch_type(0)) - select(self < batch_type(0), batch_type(1), batch_type(0)); + return res; + } + + namespace detail + { + template + XSIMD_INLINE batch signf(batch const& self) noexcept + { + using batch_type = batch; + batch_type res = select(self > batch_type(0.f), batch_type(1.f), batch_type(0.f)) - select(self < batch_type(0.f), batch_type(1.f), batch_type(0.f)); +#ifdef XSIMD_NO_NANS + return res; +#else + return select(isnan(self), constants::nan(), res); +#endif + } + } + + template + XSIMD_INLINE batch sign(batch const& self, requires_arch) noexcept + { + return detail::signf(self); + } + template + XSIMD_INLINE batch sign(batch const& self, requires_arch) noexcept + { + return detail::signf(self); + } + template + XSIMD_INLINE batch, A> sign(const batch, A>& z, requires_arch) noexcept + { + using batch_type = batch, A>; + using real_batch = typename batch_type::real_batch; + auto rz = z.real(); + auto iz = z.imag(); + return select(rz != real_batch(0.), + batch_type(sign(rz)), + batch_type(sign(iz))); + } + + // signnz + template ::value, void>::type> + XSIMD_INLINE batch signnz(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + return (self >> (sizeof(T) * 8 - 1)) | batch_type(1.); + } + + namespace detail + { + template + XSIMD_INLINE batch signnzf(batch const& self) noexcept + { + using batch_type = batch; +#ifndef XSIMD_NO_NANS + return select(isnan(self), constants::nan(), batch_type(1.) | (constants::signmask() & self)); +#else + return batch_type(1.) 
| (constants::signmask() & self); +#endif + } + } + + template + XSIMD_INLINE batch signnz(batch const& self, requires_arch) noexcept + { + return detail::signnzf(self); + } + template + XSIMD_INLINE batch signnz(batch const& self, requires_arch) noexcept + { + return detail::signnzf(self); + } + + // sqrt + template + XSIMD_INLINE batch, A> sqrt(batch, A> const& z, requires_arch) noexcept + { + + constexpr T csqrt_scale_factor = std::is_same::value ? 6.7108864e7f : 1.8014398509481984e16; + constexpr T csqrt_scale = std::is_same::value ? 1.220703125e-4f : 7.450580596923828125e-9; + using batch_type = batch, A>; + using real_batch = batch; + real_batch x = z.real(); + real_batch y = z.imag(); + real_batch sqrt_x = sqrt(fabs(x)); + real_batch sqrt_hy = sqrt(0.5 * fabs(y)); + auto cond = (fabs(x) > real_batch(4.) || fabs(y) > real_batch(4.)); + x = select(cond, x * 0.25, x * csqrt_scale_factor); + y = select(cond, y * 0.25, y * csqrt_scale_factor); + real_batch scale = select(cond, real_batch(2.), real_batch(csqrt_scale)); + real_batch r = abs(batch_type(x, y)); + + auto condxp = x > real_batch(0.); + real_batch t0 = select(condxp, xsimd::sqrt(0.5 * (r + x)), xsimd::sqrt(0.5 * (r - x))); + real_batch r0 = scale * fabs((0.5 * y) / t0); + t0 *= scale; + real_batch t = select(condxp, t0, r0); + r = select(condxp, r0, t0); + batch_type resg = select(y < real_batch(0.), batch_type(t, -r), batch_type(t, r)); + real_batch ze(0.); + + return select(y == ze, + select(x == ze, + batch_type(ze, ze), + select(x < ze, batch_type(ze, sqrt_x), batch_type(sqrt_x, ze))), + select(x == ze, + select(y > ze, batch_type(sqrt_hy, sqrt_hy), batch_type(sqrt_hy, -sqrt_hy)), + resg)); + } + + // tgamma + + namespace detail + { + /* origin: boost/simd/arch/common/detail/generic/stirling_kernel.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + struct stirling_kernel; + + template + struct stirling_kernel> + { + using batch_type = batch; + static XSIMD_INLINE batch_type compute(const batch_type& x) noexcept + { + return horner(x); + } + + static XSIMD_INLINE batch_type split_limit() noexcept + { + return batch_type(bit_cast(uint32_t(0x41d628f6))); + } + + static XSIMD_INLINE batch_type large_limit() noexcept + { + return batch_type(bit_cast(uint32_t(0x420c28f3))); + } + }; + + template + struct stirling_kernel> + { + using batch_type = batch; + static XSIMD_INLINE batch_type compute(const batch_type& x) noexcept + { + return horner(x); + } + + static XSIMD_INLINE batch_type split_limit() noexcept + { + return batch_type(bit_cast(uint64_t(0x4061e083ba3443d4))); + } + + static XSIMD_INLINE batch_type large_limit() noexcept + { + return batch_type(bit_cast(uint64_t(0x4065800000000000))); + } + }; + + /* origin: boost/simd/arch/common/simd/function/stirling.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. 
+ * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch stirling(const batch& a) noexcept + { + using batch_type = batch; + const batch_type stirlingsplitlim = stirling_kernel::split_limit(); + const batch_type stirlinglargelim = stirling_kernel::large_limit(); + batch_type x = select(a >= batch_type(0.), a, constants::nan()); + batch_type w = batch_type(1.) / x; + w = fma(w, stirling_kernel::compute(w), batch_type(1.)); + batch_type y = exp(-x); + auto test = (x < stirlingsplitlim); + batch_type z = x - batch_type(0.5); + z = select(test, z, batch_type(0.5) * z); + batch_type v = exp(z * log(abs(x))); + y *= v; + y = select(test, y, y * v); + y *= constants::sqrt_2pi() * w; +#ifndef XSIMD_NO_INFINITIES + y = select(isinf(x), x, y); +#endif + return select(x > stirlinglargelim, constants::infinity(), y); + } + + /* origin: boost/simd/arch/common/detail/generic/gamma_kernel.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + struct tgamma_kernel; + + template + struct tgamma_kernel> + { + using batch_type = batch; + static XSIMD_INLINE batch_type compute(const batch_type& x) noexcept + { + return horner(x); + } + }; + + template + struct tgamma_kernel> + { + using batch_type = batch; + static XSIMD_INLINE batch_type compute(const batch_type& x) noexcept + { + return horner(x) + / horner(x); + } + }; + + /* origin: boost/simd/arch/common/simd/function/gamma.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE B tgamma_large_negative(const B& a) noexcept + { + B st = stirling(a); + B p = floor(a); + B sgngam = select(is_even(p), -B(1.), B(1.)); + B z = a - p; + auto test2 = z < B(0.5); + z = select(test2, z - B(1.), z); + z = a * sin(z, trigo_pi_tag()); + z = abs(z); + return sgngam * constants::pi() / (z * st); + } + + template + XSIMD_INLINE B tgamma_other(const B& a, const BB& test) noexcept + { + B x = select(test, B(2.), a); +#ifndef XSIMD_NO_INFINITIES + auto inf_result = (a == constants::infinity()); + x = select(inf_result, B(2.), x); +#endif + B z = B(1.); + auto test1 = (x >= B(3.)); + while (any(test1)) + { + x = select(test1, x - B(1.), x); + z = select(test1, z * x, z); + test1 = (x >= B(3.)); + } + test1 = (x < B(0.)); + while (any(test1)) + { + z = select(test1, z / x, z); + x = select(test1, x + B(1.), x); + test1 = (x < B(0.)); + } + auto test2 = (x < B(2.)); + while (any(test2)) + { + z = select(test2, z / x, z); + x = select(test2, x + B(1.), x); + test2 = (x < B(2.)); + } + x = z * tgamma_kernel::compute(x - B(2.)); +#ifndef XSIMD_NO_INFINITIES + return select(inf_result, a, x); +#else + return x; +#endif + } + } + + template + XSIMD_INLINE batch tgamma(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + auto nan_result = (self < batch_type(0.) 
&& is_flint(self)); +#ifndef XSIMD_NO_INVALIDS + nan_result = isnan(self) || nan_result; +#endif + batch_type q = abs(self); + auto test = (self < batch_type(-33.)); + batch_type r = constants::nan(); + if (any(test)) + { + r = detail::tgamma_large_negative(q); + if (all(test)) + return select(nan_result, constants::nan(), r); + } + batch_type r1 = detail::tgamma_other(self, test); + batch_type r2 = select(test, r, r1); + return select(self == batch_type(0.), copysign(constants::infinity(), self), select(nan_result, constants::nan(), r2)); + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/generic/xsimd_generic_memory.hpp b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_memory.hpp new file mode 100644 index 0000000000000..fbe1bbc136620 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_memory.hpp @@ -0,0 +1,672 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_GENERIC_MEMORY_HPP +#define XSIMD_GENERIC_MEMORY_HPP + +#include +#include +#include + +#include "../../types/xsimd_batch_constant.hpp" +#include "./xsimd_generic_details.hpp" + +namespace xsimd +{ + template + struct batch_constant; + + template + struct batch_bool_constant; + + namespace kernel + { + + using namespace types; + + // compress + namespace detail + { + template + XSIMD_INLINE batch create_compress_swizzle_mask(I bitmask, ::xsimd::detail::index_sequence) + { + batch swizzle_mask(IT(0)); + alignas(A::alignment()) IT mask_buffer[batch::size] = { Is... }; + size_t inserted = 0; + for (size_t i = 0; i < sizeof...(Is); ++i) + if ((bitmask >> i) & 1u) + std::swap(mask_buffer[inserted++], mask_buffer[i]); + return batch::load_aligned(&mask_buffer[0]); + } + } + + template + XSIMD_INLINE batch + compress(batch const& x, batch_bool const& mask, + kernel::requires_arch) noexcept + { + using IT = as_unsigned_integer_t; + constexpr std::size_t size = batch_bool::size; + auto bitmask = mask.mask(); + auto z = select(mask, x, batch((T)0)); + auto compress_mask = detail::create_compress_swizzle_mask(bitmask, ::xsimd::detail::make_index_sequence()); + return swizzle(z, compress_mask); + } + + // expand + namespace detail + { + template + XSIMD_INLINE batch create_expand_swizzle_mask(I bitmask, ::xsimd::detail::index_sequence) + { + batch swizzle_mask(IT(0)); + IT j = 0; + (void)std::initializer_list { ((swizzle_mask = insert(swizzle_mask, j, index())), (j += ((bitmask >> Is) & 1u)), true)... 
}; + return swizzle_mask; + } + } + + template + XSIMD_INLINE batch + expand(batch const& x, batch_bool const& mask, + kernel::requires_arch) noexcept + { + constexpr std::size_t size = batch_bool::size; + auto bitmask = mask.mask(); + auto swizzle_mask = detail::create_expand_swizzle_mask, A>(bitmask, ::xsimd::detail::make_index_sequence()); + auto z = swizzle(x, swizzle_mask); + return select(mask, z, batch(T(0))); + } + + // extract_pair + template + XSIMD_INLINE batch extract_pair(batch const& self, batch const& other, std::size_t i, requires_arch) noexcept + { + constexpr std::size_t size = batch::size; + assert(i < size && "index in bounds"); + + alignas(A::alignment()) T self_buffer[size]; + self.store_aligned(self_buffer); + + alignas(A::alignment()) T other_buffer[size]; + other.store_aligned(other_buffer); + + alignas(A::alignment()) T concat_buffer[size]; + + for (std::size_t j = 0; j < (size - i); ++j) + { + concat_buffer[j] = other_buffer[i + j]; + if (j < i) + { + concat_buffer[size - 1 - j] = self_buffer[i - 1 - j]; + } + } + return batch::load_aligned(concat_buffer); + } + + // gather + namespace detail + { + // Not using XSIMD_INLINE here as it makes msvc hand got ever on avx512 + template ::type = 0> + inline batch gather(U const* src, batch const& index, + ::xsimd::index I) noexcept + { + return insert(batch {}, static_cast(src[index.get(I)]), I); + } + + template ::type = 0> + inline batch + gather(U const* src, batch const& index, ::xsimd::index I) noexcept + { + static_assert(N <= batch::size, "Incorrect value in recursion!"); + + const auto test = gather(src, index, {}); + return insert(test, static_cast(src[index.get(I)]), I); + } + } // namespace detail + + template + XSIMD_INLINE batch + gather(batch const&, T const* src, batch const& index, + kernel::requires_arch) noexcept + { + static_assert(batch::size == batch::size, + "Index and destination sizes must match"); + + return detail::gather::size - 1, T, A>(src, index, {}); + } + + // Gather with runtime indexes and mismatched strides. + template + XSIMD_INLINE detail::sizes_mismatch_t> + gather(batch const&, U const* src, batch const& index, + kernel::requires_arch) noexcept + { + static_assert(batch::size == batch::size, + "Index and destination sizes must match"); + + return detail::gather::size - 1, T, A>(src, index, {}); + } + + // Gather with runtime indexes and matching strides. 
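+        // When sizeof(T) == sizeof(U), the elements are gathered with the source
+        // type's kernel and converted in one shot via batch_cast below, instead of
+        // the per-lane insert recursion used by the mismatched-stride overload above.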
+ template + XSIMD_INLINE detail::stride_match_t> + gather(batch const&, U const* src, batch const& index, + kernel::requires_arch) noexcept + { + static_assert(batch::size == batch::size, + "Index and destination sizes must match"); + + return batch_cast(kernel::gather(batch {}, src, index, A {})); + } + + // insert + template + XSIMD_INLINE batch insert(batch const& self, T val, index, requires_arch) noexcept + { + struct index_mask + { + static constexpr bool get(size_t index, size_t /* size*/) + { + return index != I; + } + }; + batch tmp(val); + return select(make_batch_bool_constant(), self, tmp); + } + + // get + template + XSIMD_INLINE T get(batch const& self, ::xsimd::index, requires_arch) noexcept + { + alignas(A::alignment()) T buffer[batch::size]; + self.store_aligned(&buffer[0]); + return buffer[I]; + } + + template + XSIMD_INLINE T get(batch_bool const& self, ::xsimd::index, requires_arch) noexcept + { + alignas(A::alignment()) T buffer[batch_bool::size]; + self.store_aligned(&buffer[0]); + return buffer[I]; + } + + template + XSIMD_INLINE auto get(batch, A> const& self, ::xsimd::index, requires_arch) noexcept -> typename batch, A>::value_type + { + alignas(A::alignment()) T buffer[batch, A>::size]; + self.store_aligned(&buffer[0]); + return buffer[I]; + } + + template + XSIMD_INLINE T get(batch const& self, std::size_t i, requires_arch) noexcept + { + alignas(A::alignment()) T buffer[batch::size]; + self.store_aligned(&buffer[0]); + return buffer[i]; + } + + template + XSIMD_INLINE T get(batch_bool const& self, std::size_t i, requires_arch) noexcept + { + alignas(A::alignment()) bool buffer[batch_bool::size]; + self.store_aligned(&buffer[0]); + return buffer[i]; + } + + template + XSIMD_INLINE auto get(batch, A> const& self, std::size_t i, requires_arch) noexcept -> typename batch, A>::value_type + { + using T2 = typename batch, A>::value_type; + alignas(A::alignment()) T2 buffer[batch, A>::size]; + self.store_aligned(&buffer[0]); + return buffer[i]; + } + + // load_aligned + namespace detail + { + template + XSIMD_INLINE batch load_aligned(T_in const* mem, convert, requires_arch, with_fast_conversion) noexcept + { + using batch_type_in = batch; + using batch_type_out = batch; + return fast_cast(batch_type_in::load_aligned(mem), batch_type_out(), A {}); + } + template + XSIMD_INLINE batch load_aligned(T_in const* mem, convert, requires_arch, with_slow_conversion) noexcept + { + static_assert(!std::is_same::value, "there should be a direct load for this type combination"); + using batch_type_out = batch; + alignas(A::alignment()) T_out buffer[batch_type_out::size]; + std::copy(mem, mem + batch_type_out::size, std::begin(buffer)); + return batch_type_out::load_aligned(buffer); + } + } + template + XSIMD_INLINE batch load_aligned(T_in const* mem, convert cvt, requires_arch) noexcept + { + return detail::load_aligned(mem, cvt, A {}, detail::conversion_type {}); + } + + // load_unaligned + namespace detail + { + template + XSIMD_INLINE batch load_unaligned(T_in const* mem, convert, requires_arch, with_fast_conversion) noexcept + { + using batch_type_in = batch; + using batch_type_out = batch; + return fast_cast(batch_type_in::load_unaligned(mem), batch_type_out(), A {}); + } + + template + XSIMD_INLINE batch load_unaligned(T_in const* mem, convert cvt, requires_arch, with_slow_conversion) noexcept + { + static_assert(!std::is_same::value, "there should be a direct load for this type combination"); + return load_aligned(mem, cvt, generic {}, with_slow_conversion {}); + } + } + 
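A minimal sketch of how these conversion kernels are reached from user code, assuming xsimd's templated batch<T>::load_unaligned(U const*) entry point; the include path and function name here are illustrative only, not part of the patch:

    #include <cstdint>
    #include <xsimd/xsimd.hpp>

    // Loading int32_t data into a float batch routes through the kernels above:
    // depending on the architecture it takes the fast_cast path or the buffered
    // slow-conversion fallback.
    xsimd::batch<float> load_as_float(const int32_t* src)
    {
        return xsimd::batch<float>::load_unaligned(src);
    }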
template + XSIMD_INLINE batch load_unaligned(T_in const* mem, convert cvt, requires_arch) noexcept + { + return detail::load_unaligned(mem, cvt, generic {}, detail::conversion_type {}); + } + + // rotate_right + template + XSIMD_INLINE batch rotate_right(batch const& self, requires_arch) noexcept + { + struct rotate_generator + { + static constexpr size_t get(size_t index, size_t size) + { + return (index - N) % size; + } + }; + + return swizzle(self, make_batch_constant, A, rotate_generator>(), A {}); + } + + template + XSIMD_INLINE batch, A> rotate_right(batch, A> const& self, requires_arch) noexcept + { + return { rotate_right(self.real()), rotate_right(self.imag()) }; + } + + // rotate_left + template + XSIMD_INLINE batch rotate_left(batch const& self, requires_arch) noexcept + { + struct rotate_generator + { + static constexpr size_t get(size_t index, size_t size) + { + return (index + N) % size; + } + }; + + return swizzle(self, make_batch_constant, A, rotate_generator>(), A {}); + } + + template + XSIMD_INLINE batch, A> rotate_left(batch, A> const& self, requires_arch) noexcept + { + return { rotate_left(self.real()), rotate_left(self.imag()) }; + } + + // Scatter with runtime indexes. + namespace detail + { + template ::type = 0> + XSIMD_INLINE void scatter(batch const& src, U* dst, + batch const& index, + ::xsimd::index I) noexcept + { + dst[index.get(I)] = static_cast(src.get(I)); + } + + template ::type = 0> + XSIMD_INLINE void + scatter(batch const& src, U* dst, batch const& index, + ::xsimd::index I) noexcept + { + static_assert(N <= batch::size, "Incorrect value in recursion!"); + + kernel::detail::scatter( + src, dst, index, {}); + dst[index.get(I)] = static_cast(src.get(I)); + } + } // namespace detail + + template + XSIMD_INLINE void + scatter(batch const& src, T* dst, + batch const& index, + kernel::requires_arch) noexcept + { + static_assert(batch::size == batch::size, + "Source and index sizes must match"); + kernel::detail::scatter::size - 1, T, A, T, V>( + src, dst, index, {}); + } + + template + XSIMD_INLINE detail::sizes_mismatch_t + scatter(batch const& src, U* dst, + batch const& index, + kernel::requires_arch) noexcept + { + static_assert(batch::size == batch::size, + "Source and index sizes must match"); + kernel::detail::scatter::size - 1, T, A, U, V>( + src, dst, index, {}); + } + + template + XSIMD_INLINE detail::stride_match_t + scatter(batch const& src, U* dst, + batch const& index, + kernel::requires_arch) noexcept + { + static_assert(batch::size == batch::size, + "Source and index sizes must match"); + const auto tmp = batch_cast(src); + kernel::scatter(tmp, dst, index, A {}); + } + + // shuffle + namespace detail + { + constexpr bool is_swizzle_fst(size_t) + { + return true; + } + template + constexpr bool is_swizzle_fst(size_t bsize, ITy index, ITys... indices) + { + return index < bsize && is_swizzle_fst(bsize, indices...); + } + constexpr bool is_swizzle_snd(size_t) + { + return true; + } + template + constexpr bool is_swizzle_snd(size_t bsize, ITy index, ITys... indices) + { + return index >= bsize && is_swizzle_snd(bsize, indices...); + } + + constexpr bool is_zip_lo(size_t) + { + return true; + } + + template + constexpr bool is_zip_lo(size_t, ITy) + { + return false; + } + + template + constexpr bool is_zip_lo(size_t bsize, ITy0 index0, ITy1 index1, ITys... 
indices) + { + return index0 == (bsize - (sizeof...(indices) + 2)) && index1 == (2 * bsize - (sizeof...(indices) + 2)) && is_zip_lo(bsize, indices...); + } + + constexpr bool is_zip_hi(size_t) + { + return true; + } + + template + constexpr bool is_zip_hi(size_t, ITy) + { + return false; + } + + template + constexpr bool is_zip_hi(size_t bsize, ITy0 index0, ITy1 index1, ITys... indices) + { + return index0 == (bsize / 2 + bsize - (sizeof...(indices) + 2)) && index1 == (bsize / 2 + 2 * bsize - (sizeof...(indices) + 2)) && is_zip_hi(bsize, indices...); + } + + constexpr bool is_select(size_t) + { + return true; + } + + template + constexpr bool is_select(size_t bsize, ITy index, ITys... indices) + { + return (index < bsize ? index : index - bsize) == (bsize - sizeof...(ITys)) && is_select(bsize, indices...); + } + + } + + template + XSIMD_INLINE batch shuffle(batch const& x, batch const& y, batch_constant, requires_arch) noexcept + { + constexpr size_t bsize = sizeof...(Indices); + static_assert(bsize == batch::size, "valid shuffle"); + + // Detect common patterns + XSIMD_IF_CONSTEXPR(detail::is_swizzle_fst(bsize, Indices...)) + { + return swizzle(x, batch_constant= bsize) ? 0 /* never happens */ : Indices)...>()); + } + + XSIMD_IF_CONSTEXPR(detail::is_swizzle_snd(bsize, Indices...)) + { + return swizzle(y, batch_constant= bsize) ? (Indices - bsize) : 0 /* never happens */)...>()); + } + + XSIMD_IF_CONSTEXPR(detail::is_zip_lo(bsize, Indices...)) + { + return zip_lo(x, y); + } + + XSIMD_IF_CONSTEXPR(detail::is_zip_hi(bsize, Indices...)) + { + return zip_hi(x, y); + } + + XSIMD_IF_CONSTEXPR(detail::is_select(bsize, Indices...)) + { + return select(batch_bool_constant(), x, y); + } + +#if defined(__has_builtin) && !defined(XSIMD_WITH_EMULATED) +#if __has_builtin(__builtin_shufflevector) +#define builtin_shuffle __builtin_shufflevector +#endif +#endif + +#if defined(builtin_shuffle) + typedef T vty __attribute__((__vector_size__(sizeof(batch)))); + return (typename batch::register_type)builtin_shuffle((vty)x.data, (vty)y.data, Indices...); + +// FIXME: my experiments show that GCC only correctly optimizes this builtin +// starting at GCC 13, where it already has __builtin_shuffle_vector +// +// #elif __has_builtin(__builtin_shuffle) || GCC >= 6 +// typedef ITy integer_vector_type __attribute__((vector_size(sizeof(batch)))); +// return __builtin_shuffle(x.data, y.data, integer_vector_type{Indices...}); +#else + // Use a generic_pattern. It is suboptimal but clang optimizes this + // pretty well. + batch x_lane = swizzle(x, batch_constant= bsize) ? (Indices - bsize) : Indices)...>()); + batch y_lane = swizzle(y, batch_constant= bsize) ? 
(Indices - bsize) : Indices)...>()); + batch_bool_constant select_x_lane; + return select(select_x_lane, x_lane, y_lane); +#endif + } + + // store + template + XSIMD_INLINE void store(batch_bool const& self, bool* mem, requires_arch) noexcept + { + using batch_type = batch; + constexpr auto size = batch_bool::size; + alignas(A::alignment()) T buffer[size]; + kernel::store_aligned(&buffer[0], batch_type(self), A {}); + for (std::size_t i = 0; i < size; ++i) + mem[i] = bool(buffer[i]); + } + + // store_aligned + template + XSIMD_INLINE void store_aligned(T_out* mem, batch const& self, requires_arch) noexcept + { + static_assert(!std::is_same::value, "there should be a direct store for this type combination"); + alignas(A::alignment()) T_in buffer[batch::size]; + store_aligned(&buffer[0], self); + std::copy(std::begin(buffer), std::end(buffer), mem); + } + + // store_unaligned + template + XSIMD_INLINE void store_unaligned(T_out* mem, batch const& self, requires_arch) noexcept + { + static_assert(!std::is_same::value, "there should be a direct store for this type combination"); + return store_aligned(mem, self, generic {}); + } + + // swizzle + template + XSIMD_INLINE batch, A> swizzle(batch, A> const& self, batch_constant mask, requires_arch) noexcept + { + return { swizzle(self.real(), mask), swizzle(self.imag(), mask) }; + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + constexpr size_t size = batch::size; + alignas(A::alignment()) T self_buffer[size]; + store_aligned(&self_buffer[0], self); + + alignas(A::alignment()) ITy mask_buffer[size]; + store_aligned(&mask_buffer[0], mask); + + alignas(A::alignment()) T out_buffer[size]; + for (size_t i = 0; i < size; ++i) + out_buffer[i] = self_buffer[mask_buffer[i]]; + return batch::load_aligned(out_buffer); + } + + template + XSIMD_INLINE batch, A> swizzle(batch, A> const& self, batch mask, requires_arch) noexcept + { + return { swizzle(self.real(), mask), swizzle(self.imag(), mask) }; + } + + // load_complex_aligned + namespace detail + { + template + XSIMD_INLINE batch, A> load_complex(batch const& /*hi*/, batch const& /*lo*/, requires_arch) noexcept + { + static_assert(std::is_same::value, "load_complex not implemented for the required architecture"); + } + + template + XSIMD_INLINE batch complex_high(batch, A> const& /*src*/, requires_arch) noexcept + { + static_assert(std::is_same::value, "complex_high not implemented for the required architecture"); + } + + template + XSIMD_INLINE batch complex_low(batch, A> const& /*src*/, requires_arch) noexcept + { + static_assert(std::is_same::value, "complex_low not implemented for the required architecture"); + } + } + + template + XSIMD_INLINE batch, A> load_complex_aligned(std::complex const* mem, convert>, requires_arch) noexcept + { + using real_batch = batch; + T_in const* buffer = reinterpret_cast(mem); + real_batch hi = real_batch::load_aligned(buffer), + lo = real_batch::load_aligned(buffer + real_batch::size); + return detail::load_complex(hi, lo, A {}); + } + + // load_complex_unaligned + template + XSIMD_INLINE batch, A> load_complex_unaligned(std::complex const* mem, convert>, requires_arch) noexcept + { + using real_batch = batch; + T_in const* buffer = reinterpret_cast(mem); + real_batch hi = real_batch::load_unaligned(buffer), + lo = real_batch::load_unaligned(buffer + real_batch::size); + return detail::load_complex(hi, lo, A {}); + } + + // store_complex_aligned + template + XSIMD_INLINE void store_complex_aligned(std::complex* 
dst, batch, A> const& src, requires_arch) noexcept + { + using real_batch = batch; + real_batch hi = detail::complex_high(src, A {}); + real_batch lo = detail::complex_low(src, A {}); + T_out* buffer = reinterpret_cast(dst); + lo.store_aligned(buffer); + hi.store_aligned(buffer + real_batch::size); + } + + // store_compelx_unaligned + template + XSIMD_INLINE void store_complex_unaligned(std::complex* dst, batch, A> const& src, requires_arch) noexcept + { + using real_batch = batch; + real_batch hi = detail::complex_high(src, A {}); + real_batch lo = detail::complex_low(src, A {}); + T_out* buffer = reinterpret_cast(dst); + lo.store_unaligned(buffer); + hi.store_unaligned(buffer + real_batch::size); + } + + // transpose + template + XSIMD_INLINE void transpose(batch* matrix_begin, batch* matrix_end, requires_arch) noexcept + { + assert((matrix_end - matrix_begin == batch::size) && "correctly sized matrix"); + (void)matrix_end; + alignas(A::alignment()) T scratch_buffer[batch::size * batch::size]; + for (size_t i = 0; i < batch::size; ++i) + { + matrix_begin[i].store_aligned(&scratch_buffer[i * batch::size]); + } + // FIXME: this is super naive we can probably do better. + for (size_t i = 0; i < batch::size; ++i) + { + for (size_t j = 0; j < i; ++j) + { + std::swap(scratch_buffer[i * batch::size + j], + scratch_buffer[j * batch::size + i]); + } + } + for (size_t i = 0; i < batch::size; ++i) + { + matrix_begin[i] = batch::load_aligned(&scratch_buffer[i * batch::size]); + } + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/generic/xsimd_generic_rounding.hpp b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_rounding.hpp new file mode 100644 index 0000000000000..daf7b58ea718d --- /dev/null +++ b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_rounding.hpp @@ -0,0 +1,72 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
 *
+ ****************************************************************************/
+
+#ifndef XSIMD_GENERIC_ROUNDING_HPP
+#define XSIMD_GENERIC_ROUNDING_HPP
+
+#include "./xsimd_generic_details.hpp"
+
+namespace xsimd
+{
+
+    namespace kernel
+    {
+
+        using namespace types;
+
+        // ceil
+        template <class A, class T>
+        XSIMD_INLINE batch<T, A> ceil(batch<T, A> const& self, requires_arch<generic>) noexcept
+        {
+            batch<T, A> truncated_self = trunc(self);
+            return select(truncated_self < self, truncated_self + 1, truncated_self);
+        }
+
+        // floor
+        template <class A, class T>
+        XSIMD_INLINE batch<T, A> floor(batch<T, A> const& self, requires_arch<generic>) noexcept
+        {
+            batch<T, A> truncated_self = trunc(self);
+            return select(truncated_self > self, truncated_self - 1, truncated_self);
+        }
+
+        // round
+        template <class A, class T>
+        XSIMD_INLINE batch<T, A> round(batch<T, A> const& self, requires_arch<generic>) noexcept
+        {
+            auto v = abs(self);
+            auto c = ceil(v);
+            auto cp = select(c - 0.5 > v, c - 1, c);
+            return select(v > constants::maxflint<batch<T, A>>(), self, copysign(cp, self));
+        }
+
+        // trunc
+        template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
+        XSIMD_INLINE batch<T, A> trunc(batch<T, A> const& self, requires_arch<generic>) noexcept
+        {
+            return self;
+        }
+        template <class A>
+        XSIMD_INLINE batch<float, A> trunc(batch<float, A> const& self, requires_arch<generic>) noexcept
+        {
+            return select(abs(self) < constants::maxflint<batch<float, A>>(), to_float(to_int(self)), self);
+        }
+        template <class A>
+        XSIMD_INLINE batch<double, A> trunc(batch<double, A> const& self, requires_arch<generic>) noexcept
+        {
+            return select(abs(self) < constants::maxflint<batch<double, A>>(), to_float(to_int(self)), self);
+        }
+
+    }
+
+}
+
+#endif
diff --git a/include/onnxruntime/xsimd/arch/generic/xsimd_generic_trigo.hpp b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_trigo.hpp
new file mode 100644
index 0000000000000..b1bb68f25e9f9
--- /dev/null
+++ b/include/onnxruntime/xsimd/arch/generic/xsimd_generic_trigo.hpp
@@ -0,0 +1,969 @@
+/***************************************************************************
+ * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and *
+ * Martin Renou *
+ * Copyright (c) QuantStack *
+ * Copyright (c) Serge Guelton *
+ * *
+ * Distributed under the terms of the BSD 3-Clause License. *
+ * *
+ * The full license is in the file LICENSE, distributed with this software. *
+ ****************************************************************************/
+
+#ifndef XSIMD_GENERIC_TRIGO_HPP
+#define XSIMD_GENERIC_TRIGO_HPP
+
+#include "./xsimd_generic_details.hpp"
+
+#include <array>
+
+namespace xsimd
+{
+
+    namespace kernel
+    {
+        /* origin: boost/simd/arch/common/detail/simd/trig_base.hpp */
+        /*
+         * ====================================================
+         * copyright 2016 NumScale SAS
+         *
+         * Distributed under the Boost Software License, Version 1.0.
+ * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + + using namespace types; + + // acos + template + XSIMD_INLINE batch acos(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + auto x_larger_05 = x > batch_type(0.5); + x = select(x_larger_05, sqrt(fma(batch_type(-0.5), x, batch_type(0.5))), self); + x = asin(x); + x = select(x_larger_05, x + x, x); + x = select(self < batch_type(-0.5), constants::pi() - x, x); + return select(x_larger_05, x, constants::pio2() - x); + } + template + XSIMD_INLINE batch, A> acos(const batch, A>& z, requires_arch) noexcept + { + using batch_type = batch, A>; + using real_batch = typename batch_type::real_batch; + batch_type tmp = asin(z); + return { constants::pio2() - tmp.real(), -tmp.imag() }; + } + + // acosh + /* origin: boost/simd/arch/common/simd/function/acosh.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch acosh(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = self - batch_type(1.); + auto test = x > constants::oneotwoeps(); + batch_type z = select(test, self, x + sqrt(x + x + x * x)); + batch_type l1pz = log1p(z); + return select(test, l1pz + constants::log_2(), l1pz); + } + template + XSIMD_INLINE batch, A> acosh(const batch, A>& z, requires_arch) noexcept + { + using batch_type = batch, A>; + batch_type w = acos(z); + w = batch_type(-w.imag(), w.real()); + return w; + } + + // asin + template + XSIMD_INLINE batch asin(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + batch_type sign = bitofsign(self); + auto x_larger_05 = x > batch_type(0.5); + batch_type z = select(x_larger_05, batch_type(0.5) * (batch_type(1.) - x), x * x); + x = select(x_larger_05, sqrt(z), x); + batch_type z1 = detail::horner(z); + z1 = fma(z1, z * x, x); + z = select(x_larger_05, constants::pio2() - (z1 + z1), z1); + return z ^ sign; + } + template + XSIMD_INLINE batch asin(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + auto small_cond = x < constants::sqrteps(); + batch_type ct1 = batch_type(bit_cast(int64_t(0x3fe4000000000000))); + batch_type zz1 = batch_type(1.) - x; + batch_type vp = zz1 * detail::horner(zz1) / detail::horner1(zz1); + zz1 = sqrt(zz1 + zz1); + batch_type z = constants::pio4() - zz1; + zz1 = fms(zz1, vp, constants::pio_2lo()); + z = z - zz1; + zz1 = z + constants::pio4(); + batch_type zz2 = self * self; + z = zz2 * detail::horner(zz2) / detail::horner1(zz2); + zz2 = fma(x, z, x); + return select(x > batch_type(1.), constants::nan(), + select(small_cond, x, + select(x > ct1, zz1, zz2)) + ^ bitofsign(self)); + } + template + XSIMD_INLINE batch, A> asin(const batch, A>& z, requires_arch) noexcept + { + using batch_type = batch, A>; + using real_batch = typename batch_type::real_batch; + real_batch x = z.real(); + real_batch y = z.imag(); + + batch_type ct(-y, x); + batch_type zz(real_batch(1.) 
- (x - y) * (x + y), -2 * x * y); + zz = log(ct + sqrt(zz)); + batch_type resg(zz.imag(), -zz.real()); + + return select(y == real_batch(0.), + select(fabs(x) > real_batch(1.), + batch_type(constants::pio2(), real_batch(0.)), + batch_type(asin(x), real_batch(0.))), + resg); + } + + // asinh + /* origin: boost/simd/arch/common/simd/function/asinh.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + namespace detail + { + template ::value, void>::type> + XSIMD_INLINE batch + average(const batch& x1, const batch& x2) noexcept + { + return (x1 & x2) + ((x1 ^ x2) >> 1); + } + + template + XSIMD_INLINE batch + averagef(const batch& x1, const batch& x2) noexcept + { + using batch_type = batch; + return fma(x1, batch_type(0.5), x2 * batch_type(0.5)); + } + template + XSIMD_INLINE batch average(batch const& x1, batch const& x2) noexcept + { + return averagef(x1, x2); + } + template + XSIMD_INLINE batch average(batch const& x1, batch const& x2) noexcept + { + return averagef(x1, x2); + } + } + template + XSIMD_INLINE batch asinh(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + auto lthalf = x < batch_type(0.5); + batch_type x2 = x * x; + batch_type bts = bitofsign(self); + batch_type z(0.); + if (any(lthalf)) + { + z = detail::horner(x2) + * x; + if (all(lthalf)) + return z ^ bts; + } + batch_type tmp = select(x > constants::oneosqrteps(), x, detail::average(x, hypot(batch_type(1.), x))); +#ifndef XSIMD_NO_NANS + return select(isnan(self), constants::nan(), select(lthalf, z, log(tmp) + constants::log_2()) ^ bts); +#else + return select(lthalf, z, log(tmp) + constants::log_2()) ^ bts; +#endif + } + template + XSIMD_INLINE batch asinh(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + auto test = x > constants::oneosqrteps(); + batch_type z = select(test, x - batch_type(1.), x + x * x / (batch_type(1.) 
+ hypot(batch_type(1.), x))); +#ifndef XSIMD_NO_INFINITIES + z = select(x == constants::infinity(), x, z); +#endif + batch_type l1pz = log1p(z); + z = select(test, l1pz + constants::log_2(), l1pz); + return bitofsign(self) ^ z; + } + template + XSIMD_INLINE batch, A> asinh(const batch, A>& z, requires_arch) noexcept + { + using batch_type = batch, A>; + batch_type w = asin(batch_type(-z.imag(), z.real())); + w = batch_type(w.imag(), -w.real()); + return w; + } + + // atan + namespace detail + { + template + static XSIMD_INLINE batch kernel_atan(const batch& x, const batch& recx) noexcept + { + using batch_type = batch; + const auto flag1 = x < constants::tan3pio8(); + const auto flag2 = (x >= batch_type(bit_cast((uint32_t)0x3ed413cd))) && flag1; + batch_type yy = select(flag1, batch_type(0.), constants::pio2()); + yy = select(flag2, constants::pio4(), yy); + batch_type xx = select(flag1, x, -recx); + xx = select(flag2, (x - batch_type(1.)) / (x + batch_type(1.)), xx); + const batch_type z = xx * xx; + batch_type z1 = detail::horner(z); + z1 = fma(xx, z1 * z, xx); + z1 = select(flag2, z1 + constants::pio_4lo(), z1); + z1 = select(!flag1, z1 + constants::pio_2lo(), z1); + return yy + z1; + } + template + static XSIMD_INLINE batch kernel_atan(const batch& x, const batch& recx) noexcept + { + using batch_type = batch; + const auto flag1 = x < constants::tan3pio8(); + const auto flag2 = (x >= constants::tanpio8()) && flag1; + batch_type yy = select(flag1, batch_type(0.), constants::pio2()); + yy = select(flag2, constants::pio4(), yy); + batch_type xx = select(flag1, x, -recx); + xx = select(flag2, (x - batch_type(1.)) / (x + batch_type(1.)), xx); + batch_type z = xx * xx; + z *= detail::horner(z) + / detail::horner1(z); + z = fma(xx, z, xx); + z = select(flag2, z + constants::pio_4lo(), z); + z = z + select(flag1, batch_type(0.), constants::pio_2lo()); + return yy + z; + } + } + template + XSIMD_INLINE batch atan(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + const batch_type absa = abs(self); + const batch_type x = detail::kernel_atan(absa, batch_type(1.) / absa); + return x ^ bitofsign(self); + } + template + XSIMD_INLINE batch, A> atan(const batch, A>& z, requires_arch) noexcept + { + using batch_type = batch, A>; + using real_batch = typename batch_type::real_batch; + real_batch x = z.real(); + real_batch y = z.imag(); + real_batch x2 = x * x; + real_batch one(1.); + real_batch a = one - x2 - (y * y); + real_batch w = 0.5 * atan2(2. * x, a); + real_batch num = y + one; + num = x2 + num * num; + real_batch den = y - one; + den = x2 + den * den; + batch_type res = select((x == real_batch(0.)) && (y == real_batch(1.)), + batch_type(real_batch(0.), constants::infinity()), + batch_type(w, 0.25 * log(num / den))); + return res; + } + + // atanh + /* origin: boost/simd/arch/common/simd/function/acosh.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch atanh(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + batch_type t = x + x; + batch_type z = batch_type(1.) 
- x; + auto test = x < batch_type(0.5); + batch_type tmp = select(test, x, t) / z; + return bitofsign(self) ^ (batch_type(0.5) * log1p(select(test, fma(t, tmp, t), tmp))); + } + template + XSIMD_INLINE batch, A> atanh(const batch, A>& z, requires_arch) noexcept + { + using batch_type = batch, A>; + batch_type w = atan(batch_type(-z.imag(), z.real())); + w = batch_type(w.imag(), -w.real()); + return w; + } + + // atan2 + template + XSIMD_INLINE batch atan2(batch const& self, batch const& other, requires_arch) noexcept + { + using batch_type = batch; + const batch_type q = abs(self / other); + const batch_type z = detail::kernel_atan(q, batch_type(1.) / q); + return select(other > batch_type(0.), z, constants::pi() - z) * signnz(self); + } + + // cos + namespace detail + { + template + XSIMD_INLINE batch quadrant(const batch& x) noexcept + { + return x & batch(3); + } + + template + XSIMD_INLINE batch quadrant(const batch& x) noexcept + { + return to_float(quadrant(to_int(x))); + } + + template + XSIMD_INLINE batch quadrant(const batch& x) noexcept + { + using batch_type = batch; + batch_type a = x * batch_type(0.25); + return (a - floor(a)) * batch_type(4.); + } + /* origin: boost/simd/arch/common/detail/simd/f_trig_evaluation.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + + template + XSIMD_INLINE batch cos_eval(const batch& z) noexcept + { + using batch_type = batch; + batch_type y = detail::horner(z); + return batch_type(1.) + fma(z, batch_type(-0.5), y * z * z); + } + + template + XSIMD_INLINE batch sin_eval(const batch& z, const batch& x) noexcept + { + using batch_type = batch; + batch_type y = detail::horner(z); + return fma(y * z, x, x); + } + + template + static XSIMD_INLINE batch base_tancot_eval(const batch& z) noexcept + { + using batch_type = batch; + batch_type zz = z * z; + batch_type y = detail::horner(zz); + return fma(y, zz * z, z); + } + + template + static XSIMD_INLINE batch tan_eval(const batch& z, const BB& test) noexcept + { + using batch_type = batch; + batch_type y = base_tancot_eval(z); + return select(test, y, -batch_type(1.) / y); + } + + template + static XSIMD_INLINE batch cot_eval(const batch& z, const BB& test) noexcept + { + using batch_type = batch; + batch_type y = base_tancot_eval(z); + return select(test, batch_type(1.) / y, -y); + } + + /* origin: boost/simd/arch/common/detail/simd/d_trig_evaluation.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + static XSIMD_INLINE batch cos_eval(const batch& z) noexcept + { + using batch_type = batch; + batch_type y = detail::horner(z); + return batch_type(1.) 
- y * z; + } + + template + static XSIMD_INLINE batch sin_eval(const batch& z, const batch& x) noexcept + { + using batch_type = batch; + batch_type y = detail::horner(z); + return fma(y * z, x, x); + } + + template + static XSIMD_INLINE batch base_tancot_eval(const batch& z) noexcept + { + using batch_type = batch; + batch_type zz = z * z; + batch_type num = detail::horner(zz); + batch_type den = detail::horner1(zz); + return fma(z, (zz * (num / den)), z); + } + + template + static XSIMD_INLINE batch tan_eval(const batch& z, const BB& test) noexcept + { + using batch_type = batch; + batch_type y = base_tancot_eval(z); + return select(test, y, -batch_type(1.) / y); + } + + template + static XSIMD_INLINE batch cot_eval(const batch& z, const BB& test) noexcept + { + using batch_type = batch; + batch_type y = base_tancot_eval(z); + return select(test, batch_type(1.) / y, -y); + } + /* origin: boost/simd/arch/common/detail/simd/trig_reduction.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + + struct trigo_radian_tag + { + }; + struct trigo_pi_tag + { + }; + + template + struct trigo_reducer + { + static XSIMD_INLINE B reduce(const B& x, B& xr) noexcept + { + if (all(x <= constants::pio4())) + { + xr = x; + return B(0.); + } + else if (all(x <= constants::pio2())) + { + auto test = x > constants::pio4(); + xr = x - constants::pio2_1(); + xr -= constants::pio2_2(); + xr -= constants::pio2_3(); + xr = select(test, xr, x); + return select(test, B(1.), B(0.)); + } + else if (all(x <= constants::twentypi())) + { + B xi = nearbyint(x * constants::twoopi()); + xr = fnma(xi, constants::pio2_1(), x); + xr -= xi * constants::pio2_2(); + xr -= xi * constants::pio2_3(); + return quadrant(xi); + } + else if (all(x <= constants::mediumpi())) + { + B fn = nearbyint(x * constants::twoopi()); + B r = x - fn * constants::pio2_1(); + B w = fn * constants::pio2_1t(); + B t = r; + w = fn * constants::pio2_2(); + r = t - w; + w = fn * constants::pio2_2t() - ((t - r) - w); + t = r; + w = fn * constants::pio2_3(); + r = t - w; + w = fn * constants::pio2_3t() - ((t - r) - w); + xr = r - w; + return quadrant(fn); + } + else + { + static constexpr std::size_t size = B::size; + using value_type = typename B::value_type; + alignas(B) std::array tmp; + alignas(B) std::array txr; + alignas(B) std::array args; + x.store_aligned(args.data()); + + for (std::size_t i = 0; i < size; ++i) + { + double arg = args[i]; + if (arg == std::numeric_limits::infinity()) + { + tmp[i] = 0.; + txr[i] = std::numeric_limits::quiet_NaN(); + } + else + { + double y[2]; + std::int32_t n = ::xsimd::detail::__ieee754_rem_pio2(arg, y); + tmp[i] = value_type(n & 3); + txr[i] = value_type(y[0]); + } + } + xr = B::load_aligned(&txr[0]); + B res = B::load_aligned(&tmp[0]); + return res; + } + } + }; + + template + struct trigo_reducer + { + static XSIMD_INLINE B reduce(const B& x, B& xr) noexcept + { + B xi = nearbyint(x * B(2.)); + B x2 = x - xi * B(0.5); + xr = x2 * constants::pi(); + return quadrant(xi); + } + }; + + } + template + XSIMD_INLINE batch cos(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + const batch_type x = abs(self); + batch_type xr = constants::nan(); + const batch_type n = detail::trigo_reducer::reduce(x, xr); + auto tmp = select(n >= batch_type(2.), batch_type(1.), 
batch_type(0.)); + auto swap_bit = fma(batch_type(-2.), tmp, n); + auto sign_bit = select((swap_bit ^ tmp) != batch_type(0.), constants::signmask(), batch_type(0.)); + const batch_type z = xr * xr; + const batch_type se = detail::sin_eval(z, xr); + const batch_type ce = detail::cos_eval(z); + const batch_type z1 = select(swap_bit != batch_type(0.), se, ce); + return z1 ^ sign_bit; + } + + template + XSIMD_INLINE batch, A> cos(batch, A> const& z, requires_arch) noexcept + { + return { cos(z.real()) * cosh(z.imag()), -sin(z.real()) * sinh(z.imag()) }; + } + + // cosh + + /* origin: boost/simd/arch/common/simd/function/cosh.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + + template + XSIMD_INLINE batch cosh(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type x = abs(self); + auto test1 = x > (constants::maxlog() - constants::log_2()); + batch_type fac = select(test1, batch_type(0.5), batch_type(1.)); + batch_type tmp = exp(x * fac); + batch_type tmp1 = batch_type(0.5) * tmp; + return select(test1, tmp1 * tmp, detail::average(tmp, batch_type(1.) / tmp)); + } + template + XSIMD_INLINE batch, A> cosh(const batch, A>& z, requires_arch) noexcept + { + auto x = z.real(); + auto y = z.imag(); + return { cosh(x) * cos(y), sinh(x) * sin(y) }; + } + + // sin + namespace detail + { + template + XSIMD_INLINE batch sin(batch const& self, Tag = Tag()) noexcept + { + using batch_type = batch; + const batch_type x = abs(self); + batch_type xr = constants::nan(); + const batch_type n = detail::trigo_reducer::reduce(x, xr); + auto tmp = select(n >= batch_type(2.), batch_type(1.), batch_type(0.)); + auto swap_bit = fma(batch_type(-2.), tmp, n); + auto sign_bit = bitofsign(self) ^ select(tmp != batch_type(0.), constants::signmask(), batch_type(0.)); + const batch_type z = xr * xr; + const batch_type se = detail::sin_eval(z, xr); + const batch_type ce = detail::cos_eval(z); + const batch_type z1 = select(swap_bit == batch_type(0.), se, ce); + return z1 ^ sign_bit; + } + } + + template + XSIMD_INLINE batch sin(batch const& self, requires_arch) noexcept + { + return detail::sin(self); + } + + template + XSIMD_INLINE batch, A> sin(batch, A> const& z, requires_arch) noexcept + { + return { sin(z.real()) * cosh(z.imag()), cos(z.real()) * sinh(z.imag()) }; + } + + // sincos + template + XSIMD_INLINE std::pair, batch> sincos(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + const batch_type x = abs(self); + batch_type xr = constants::nan(); + const batch_type n = detail::trigo_reducer::reduce(x, xr); + auto tmp = select(n >= batch_type(2.), batch_type(1.), batch_type(0.)); + auto swap_bit = fma(batch_type(-2.), tmp, n); + const batch_type z = xr * xr; + const batch_type se = detail::sin_eval(z, xr); + const batch_type ce = detail::cos_eval(z); + auto sin_sign_bit = bitofsign(self) ^ select(tmp != batch_type(0.), constants::signmask(), batch_type(0.)); + const batch_type sin_z1 = select(swap_bit == batch_type(0.), se, ce); + auto cos_sign_bit = select((swap_bit ^ tmp) != batch_type(0.), constants::signmask(), batch_type(0.)); + const batch_type cos_z1 = select(swap_bit != batch_type(0.), se, ce); + return std::make_pair(sin_z1 ^ sin_sign_bit, cos_z1 ^ cos_sign_bit); + } + + template + XSIMD_INLINE std::pair, A>, batch, A>> + 
sincos(batch, A> const& z, requires_arch) noexcept + { + using batch_type = batch, A>; + using real_batch = typename batch_type::real_batch; + real_batch rcos = cos(z.real()); + real_batch rsin = sin(z.real()); + real_batch icosh = cosh(z.imag()); + real_batch isinh = sinh(z.imag()); + return std::make_pair(batch_type(rsin * icosh, rcos * isinh), batch_type(rcos * icosh, -rsin * isinh)); + } + + // sinh + namespace detail + { + /* origin: boost/simd/arch/common/detail/generic/sinh_kernel.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch sinh_kernel(batch const& self) noexcept + { + using batch_type = batch; + batch_type sqr_self = self * self; + return detail::horner(sqr_self) + * self; + } + + template + XSIMD_INLINE batch sinh_kernel(batch const& self) noexcept + { + using batch_type = batch; + batch_type sqrself = self * self; + return fma(self, (detail::horner(sqrself) + / detail::horner1(sqrself)) + * sqrself, + self); + } + } + /* origin: boost/simd/arch/common/simd/function/sinh.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch sinh(batch const& a, requires_arch) noexcept + { + using batch_type = batch; + batch_type half(0.5); + batch_type x = abs(a); + auto lt1 = x < batch_type(1.); + batch_type bts = bitofsign(a); + batch_type z(0.); + if (any(lt1)) + { + z = detail::sinh_kernel(x); + if (all(lt1)) + return z ^ bts; + } + auto test1 = x > (constants::maxlog() - constants::log_2()); + batch_type fac = select(test1, half, batch_type(1.)); + batch_type tmp = exp(x * fac); + batch_type tmp1 = half * tmp; + batch_type r = select(test1, tmp1 * tmp, tmp1 - half / tmp); + return select(lt1, z, r) ^ bts; + } + template + XSIMD_INLINE batch, A> sinh(const batch, A>& z, requires_arch) noexcept + { + auto x = z.real(); + auto y = z.imag(); + return { sinh(x) * cos(y), cosh(x) * sin(y) }; + } + + // tan + template + XSIMD_INLINE batch tan(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + const batch_type x = abs(self); + batch_type xr = constants::nan(); + const batch_type n = detail::trigo_reducer::reduce(x, xr); + auto tmp = select(n >= batch_type(2.), batch_type(1.), batch_type(0.)); + auto swap_bit = fma(batch_type(-2.), tmp, n); + auto test = (swap_bit == batch_type(0.)); + const batch_type y = detail::tan_eval(xr, test); + return y ^ bitofsign(self); + } + template + XSIMD_INLINE batch, A> tan(batch, A> const& z, requires_arch) noexcept + { + using batch_type = batch, A>; + using real_batch = typename batch_type::real_batch; + real_batch d = cos(2 * z.real()) + cosh(2 * z.imag()); + batch_type winf(constants::infinity(), constants::infinity()); + real_batch wreal = sin(2 * z.real()) / d; + real_batch wimag = sinh(2 * z.imag()); + batch_type wres = select(isinf(wimag), batch_type(wreal, real_batch(1.)), batch_type(wreal, wimag / d)); + return select(d == real_batch(0.), winf, wres); + } + + // tanh + namespace detail + { + /* origin: boost/simd/arch/common/detail/generic/tanh_kernel.hpp */ + /* + * 
==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + struct tanh_kernel; + + template + struct tanh_kernel> + { + using batch_type = batch; + static XSIMD_INLINE batch_type tanh(const batch_type& x) noexcept + { + batch_type sqrx = x * x; + return fma(detail::horner(sqrx) + * sqrx, + x, x); + } + + static XSIMD_INLINE batch_type cotanh(const batch_type& x) noexcept + { + return batch_type(1.) / tanh(x); + } + }; + + template + struct tanh_kernel> + { + using batch_type = batch; + static XSIMD_INLINE batch_type tanh(const batch_type& x) noexcept + { + batch_type sqrx = x * x; + return fma(sqrx * p(sqrx) / q(sqrx), x, x); + } + + static XSIMD_INLINE batch_type cotanh(const batch_type& x) noexcept + { + batch_type sqrx = x * x; + batch_type qval = q(sqrx); + return qval / (x * fma(p(sqrx), sqrx, qval)); + } + + static XSIMD_INLINE batch_type p(const batch_type& x) noexcept + { + return detail::horner(x); + } + + static XSIMD_INLINE batch_type q(const batch_type& x) noexcept + { + return detail::horner1(x); + } + }; + + } + /* origin: boost/simd/arch/common/simd/function/tanh.hpp */ + /* + * ==================================================== + * copyright 2016 NumScale SAS + * + * Distributed under the Boost Software License, Version 1.0. + * (See copy at http://boost.org/LICENSE_1_0.txt) + * ==================================================== + */ + template + XSIMD_INLINE batch tanh(batch const& self, requires_arch) noexcept + { + using batch_type = batch; + batch_type one(1.); + batch_type x = abs(self); + auto test = x < (batch_type(5.) / batch_type(8.)); + batch_type bts = bitofsign(self); + batch_type z = one; + if (any(test)) + { + z = detail::tanh_kernel::tanh(x); + if (all(test)) + return z ^ bts; + } + batch_type r = fma(batch_type(-2.), one / (one + exp(x + x)), one); + return select(test, z, r) ^ bts; + } + template + XSIMD_INLINE batch, A> tanh(const batch, A>& z, requires_arch) noexcept + { + using real_batch = typename batch, A>::real_batch; + auto x = z.real(); + auto y = z.imag(); + real_batch two(2); + auto d = cosh(two * x) + cos(two * y); + return { sinh(two * x) / d, sin(two * y) / d }; + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx.hpp new file mode 100644 index 0000000000000..116ea7762472e --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx.hpp @@ -0,0 +1,1820 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
* + ****************************************************************************/ + +#ifndef XSIMD_AVX_HPP +#define XSIMD_AVX_HPP + +#include +#include +#include + +#include "../types/xsimd_avx_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + using namespace types; + + // fwd + template + XSIMD_INLINE batch insert(batch const& self, T val, index, requires_arch) noexcept; + + namespace detail + { + XSIMD_INLINE void split_avx(__m256i val, __m128i& low, __m128i& high) noexcept + { + low = _mm256_castsi256_si128(val); + high = _mm256_extractf128_si256(val, 1); + } + XSIMD_INLINE void split_avx(__m256 val, __m128& low, __m128& high) noexcept + { + low = _mm256_castps256_ps128(val); + high = _mm256_extractf128_ps(val, 1); + } + XSIMD_INLINE void split_avx(__m256d val, __m128d& low, __m128d& high) noexcept + { + low = _mm256_castpd256_pd128(val); + high = _mm256_extractf128_pd(val, 1); + } + XSIMD_INLINE __m256i merge_sse(__m128i low, __m128i high) noexcept + { + return _mm256_insertf128_si256(_mm256_castsi128_si256(low), high, 1); + } + XSIMD_INLINE __m256 merge_sse(__m128 low, __m128 high) noexcept + { + return _mm256_insertf128_ps(_mm256_castps128_ps256(low), high, 1); + } + XSIMD_INLINE __m256d merge_sse(__m128d low, __m128d high) noexcept + { + return _mm256_insertf128_pd(_mm256_castpd128_pd256(low), high, 1); + } + template + XSIMD_INLINE __m256i fwd_to_sse(F f, __m256i self) noexcept + { + __m128i self_low, self_high; + split_avx(self, self_low, self_high); + __m128i res_low = f(self_low); + __m128i res_high = f(self_high); + return merge_sse(res_low, res_high); + } + template + XSIMD_INLINE __m256i fwd_to_sse(F f, __m256i self, __m256i other) noexcept + { + __m128i self_low, self_high, other_low, other_high; + split_avx(self, self_low, self_high); + split_avx(other, other_low, other_high); + __m128i res_low = f(self_low, other_low); + __m128i res_high = f(self_high, other_high); + return merge_sse(res_low, res_high); + } + template + XSIMD_INLINE __m256i fwd_to_sse(F f, __m256i self, int32_t other) noexcept + { + __m128i self_low, self_high; + split_avx(self, self_low, self_high); + __m128i res_low = f(self_low, other); + __m128i res_high = f(self_high, other); + return merge_sse(res_low, res_high); + } + } + + // abs + template + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept + { + __m256 sign_mask = _mm256_set1_ps(-0.f); // -0.f = 1 << 31 + return _mm256_andnot_ps(sign_mask, self); + } + template + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept + { + __m256d sign_mask = _mm256_set1_pd(-0.f); // -0.f = 1 << 31 + return _mm256_andnot_pd(sign_mask, self); + } + + // add + template ::value, void>::type> + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return add(batch(s), batch(o)); }, + self, other); + } + template + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_add_ps(self, other); + } + template + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_add_pd(self, other); + } + + // all + template + XSIMD_INLINE bool all(batch_bool const& self, requires_arch) noexcept + { + return _mm256_testc_ps(self, batch_bool(true)) != 0; + } + template + XSIMD_INLINE bool all(batch_bool const& self, requires_arch) noexcept + { + return _mm256_testc_pd(self, batch_bool(true)) != 0; + } + template ::value, void>::type> + XSIMD_INLINE 
bool all(batch_bool const& self, requires_arch) noexcept + { + return _mm256_testc_si256(self, batch_bool(true)) != 0; + } + + // any + template + XSIMD_INLINE bool any(batch_bool const& self, requires_arch) noexcept + { + return !_mm256_testz_ps(self, self); + } + template + XSIMD_INLINE bool any(batch_bool const& self, requires_arch) noexcept + { + return !_mm256_testz_pd(self, self); + } + template ::value, void>::type> + XSIMD_INLINE bool any(batch_bool const& self, requires_arch) noexcept + { + return !_mm256_testz_si256(self, self); + } + + // batch_bool_cast + template + XSIMD_INLINE batch_bool batch_bool_cast(batch_bool const& self, batch_bool const&, requires_arch) noexcept + { + return { bitwise_cast(batch(self.data)).data }; + } + + // bitwise_and + template + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_and_ps(self, other); + } + template + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_and_pd(self, other); + } + + template + XSIMD_INLINE batch_bool bitwise_and(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_and_ps(self, other); + } + template + XSIMD_INLINE batch_bool bitwise_and(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_and_pd(self, other); + } + + template ::value, void>::type> + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_and(batch(s), batch(o)); }, + self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool bitwise_and(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_and(batch(s), batch(o)); }, + self, other); + } + + // bitwise_andnot + template + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_andnot_ps(other, self); + } + template + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_andnot_pd(other, self); + } + + template + XSIMD_INLINE batch_bool bitwise_andnot(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_andnot_ps(other, self); + } + template + XSIMD_INLINE batch_bool bitwise_andnot(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_andnot_pd(other, self); + } + + template ::value, void>::type> + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_andnot(batch(s), batch(o)); }, + self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool bitwise_andnot(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_andnot(batch(s), batch(o)); }, + self, other); + } + + // bitwise_lshift + template ::value, void>::type> + XSIMD_INLINE batch bitwise_lshift(batch const& self, int32_t other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, int32_t o) noexcept + { return bitwise_lshift(batch(s), o, sse4_2 {}); }, + self, other); + } + + // bitwise_not + template ::value, void>::type> + XSIMD_INLINE batch 
bitwise_not(batch const& self, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s) noexcept + { return bitwise_not(batch(s), sse4_2 {}); }, + self); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool bitwise_not(batch_bool const& self, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s) noexcept + { return bitwise_not(batch_bool(s), sse4_2 {}); }, + self); + } + + // bitwise_or + template + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_or_ps(self, other); + } + template + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_or_pd(self, other); + } + template + XSIMD_INLINE batch_bool bitwise_or(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_or_ps(self, other); + } + template + XSIMD_INLINE batch_bool bitwise_or(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_or_pd(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_or(batch(s), batch(o)); }, + self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool bitwise_or(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_or(batch_bool(s), batch_bool(o)); }, + self, other); + } + + // bitwise_rshift + template ::value, void>::type> + XSIMD_INLINE batch bitwise_rshift(batch const& self, int32_t other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, int32_t o) noexcept + { return bitwise_rshift(batch(s), o, sse4_2 {}); }, + self, other); + } + + // bitwise_xor + template + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_xor_ps(self, other); + } + template + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_xor_pd(self, other); + } + template + XSIMD_INLINE batch_bool bitwise_xor(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_xor_ps(self, other); + } + template + XSIMD_INLINE batch_bool bitwise_xor(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_xor_pd(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_xor(batch(s), batch(o), sse4_2 {}); }, + self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_xor(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_xor(batch_bool(s), batch_bool(o), sse4_2 {}); }, + self, other); + } + + // bitwise_cast + template ::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm256_castsi256_ps(self); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm256_castsi256_pd(self); + } + template ::type>::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch 
const& self, batch const&, requires_arch) noexcept + { + return batch(self.data); + } + template + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm256_castps_pd(self); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm256_castps_si256(self); + } + template + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm256_castpd_ps(self); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm256_castpd_si256(self); + } + + // bitwise_not + template + XSIMD_INLINE batch bitwise_not(batch const& self, requires_arch) noexcept + { + return _mm256_xor_ps(self, _mm256_castsi256_ps(_mm256_set1_epi32(-1))); + } + template + XSIMD_INLINE batch bitwise_not(batch const& self, requires_arch) noexcept + { + return _mm256_xor_pd(self, _mm256_castsi256_pd(_mm256_set1_epi32(-1))); + } + template + XSIMD_INLINE batch_bool bitwise_not(batch_bool const& self, requires_arch) noexcept + { + return _mm256_xor_ps(self, _mm256_castsi256_ps(_mm256_set1_epi32(-1))); + } + template + XSIMD_INLINE batch_bool bitwise_not(batch_bool const& self, requires_arch) noexcept + { + return _mm256_xor_pd(self, _mm256_castsi256_pd(_mm256_set1_epi32(-1))); + } + + // broadcast + template ::value, void>::type> + XSIMD_INLINE batch broadcast(T val, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_set1_epi8(val); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_set1_epi16(val); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_set1_epi32(val); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_set1_epi64x(val); + } + else + { + assert(false && "unsupported"); + return {}; + } + } + template + XSIMD_INLINE batch broadcast(float val, requires_arch) noexcept + { + return _mm256_set1_ps(val); + } + template + XSIMD_INLINE batch broadcast(double val, requires_arch) noexcept + { + return _mm256_set1_pd(val); + } + + // ceil + template + XSIMD_INLINE batch ceil(batch const& self, requires_arch) noexcept + { + return _mm256_ceil_ps(self); + } + template + XSIMD_INLINE batch ceil(batch const& self, requires_arch) noexcept + { + return _mm256_ceil_pd(self); + } + + namespace detail + { + // On clang, _mm256_extractf128_ps is built upon build_shufflevector + // which require index parameter to be a constant + template + XSIMD_INLINE B get_half_complex_f(const B& real, const B& imag) noexcept + { + __m128 tmp0 = _mm256_extractf128_ps(real, index); + __m128 tmp1 = _mm256_extractf128_ps(imag, index); + __m128 tmp2 = _mm_unpackhi_ps(tmp0, tmp1); + tmp0 = _mm_unpacklo_ps(tmp0, tmp1); + __m256 res = real; + res = _mm256_insertf128_ps(res, tmp0, 0); + res = _mm256_insertf128_ps(res, tmp2, 1); + return res; + } + template + XSIMD_INLINE B get_half_complex_d(const B& real, const B& imag) noexcept + { + __m128d tmp0 = _mm256_extractf128_pd(real, index); + __m128d tmp1 = _mm256_extractf128_pd(imag, index); + __m128d tmp2 = _mm_unpackhi_pd(tmp0, tmp1); + tmp0 = _mm_unpacklo_pd(tmp0, tmp1); + __m256d res = real; + res = _mm256_insertf128_pd(res, tmp0, 0); + res = _mm256_insertf128_pd(res, tmp2, 1); + return res; + } + + // complex_low + template + XSIMD_INLINE batch complex_low(batch, A> const& self, requires_arch) noexcept + { + return get_half_complex_f<0>(self.real(), self.imag()); + } + 
template + XSIMD_INLINE batch complex_low(batch, A> const& self, requires_arch) noexcept + { + return get_half_complex_d<0>(self.real(), self.imag()); + } + + // complex_high + template + XSIMD_INLINE batch complex_high(batch, A> const& self, requires_arch) noexcept + { + return get_half_complex_f<1>(self.real(), self.imag()); + } + template + XSIMD_INLINE batch complex_high(batch, A> const& self, requires_arch) noexcept + { + return get_half_complex_d<1>(self.real(), self.imag()); + } + } + + // fast_cast + namespace detail + { + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm256_cvtepi32_ps(self); + } + + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm256_cvttps_epi32(self); + } + } + + // decr_if + template ::value, void>::type> + XSIMD_INLINE batch decr_if(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return self + batch(mask.data); + } + + // div + template + XSIMD_INLINE batch div(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_div_ps(self, other); + } + template + XSIMD_INLINE batch div(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_div_pd(self, other); + } + + // eq + template + XSIMD_INLINE batch_bool eq(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_cmp_ps(self, other, _CMP_EQ_OQ); + } + template + XSIMD_INLINE batch_bool eq(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_cmp_pd(self, other, _CMP_EQ_OQ); + } + template + XSIMD_INLINE batch_bool eq(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return ~(self != other); + } + template + XSIMD_INLINE batch_bool eq(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return ~(self != other); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool eq(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return eq(batch(s), batch(o), sse4_2 {}); }, + self, other); + } + + template ::value, void>::type> + XSIMD_INLINE batch_bool eq(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return ~(self != other); + } + + // floor + template + XSIMD_INLINE batch floor(batch const& self, requires_arch) noexcept + { + return _mm256_floor_ps(self); + } + template + XSIMD_INLINE batch floor(batch const& self, requires_arch) noexcept + { + return _mm256_floor_pd(self); + } + + // from_mask + template + XSIMD_INLINE batch_bool from_mask(batch_bool const&, uint64_t mask, requires_arch) noexcept + { + alignas(A::alignment()) static const uint64_t lut32[] = { + 0x0000000000000000ul, + 0x00000000FFFFFFFFul, + 0xFFFFFFFF00000000ul, + 0xFFFFFFFFFFFFFFFFul, + }; + assert(!(mask & ~0xFFul) && "inbound mask"); + return _mm256_castsi256_ps(_mm256_setr_epi64x(lut32[mask & 0x3], lut32[(mask >> 2) & 0x3], lut32[(mask >> 4) & 0x3], lut32[mask >> 6])); + } + template + XSIMD_INLINE batch_bool from_mask(batch_bool const&, uint64_t mask, requires_arch) noexcept + { + alignas(A::alignment()) static const uint64_t lut64[][4] = { + { 0x0000000000000000ul, 0x0000000000000000ul, 0x0000000000000000ul, 0x0000000000000000ul }, + { 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul, 0x0000000000000000ul, 0x0000000000000000ul }, + { 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul, 0x0000000000000000ul }, 
+ { 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul, 0x0000000000000000ul }, + { 0x0000000000000000ul, 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul }, + { 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul }, + { 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul }, + { 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul }, + { 0x0000000000000000ul, 0x0000000000000000ul, 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul }, + { 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul, 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul }, + { 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul }, + { 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul }, + { 0x0000000000000000ul, 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul }, + { 0xFFFFFFFFFFFFFFFFul, 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul }, + { 0x0000000000000000ul, 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul }, + { 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul, 0xFFFFFFFFFFFFFFFFul }, + }; + assert(!(mask & ~0xFul) && "inbound mask"); + return _mm256_castsi256_pd(_mm256_load_si256((const __m256i*)lut64[mask])); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool from_mask(batch_bool const&, uint64_t mask, requires_arch) noexcept + { + alignas(A::alignment()) static const uint32_t lut32[] = { + 0x00000000, + 0x000000FF, + 0x0000FF00, + 0x0000FFFF, + 0x00FF0000, + 0x00FF00FF, + 0x00FFFF00, + 0x00FFFFFF, + 0xFF000000, + 0xFF0000FF, + 0xFF00FF00, + 0xFF00FFFF, + 0xFFFF0000, + 0xFFFF00FF, + 0xFFFFFF00, + 0xFFFFFFFF, + }; + alignas(A::alignment()) static const uint64_t lut64[] = { + 0x0000000000000000ul, + 0x000000000000FFFFul, + 0x00000000FFFF0000ul, + 0x00000000FFFFFFFFul, + 0x0000FFFF00000000ul, + 0x0000FFFF0000FFFFul, + 0x0000FFFFFFFF0000ul, + 0x0000FFFFFFFFFFFFul, + 0xFFFF000000000000ul, + 0xFFFF00000000FFFFul, + 0xFFFF0000FFFF0000ul, + 0xFFFF0000FFFFFFFFul, + 0xFFFFFFFF00000000ul, + 0xFFFFFFFF0000FFFFul, + 0xFFFFFFFFFFFF0000ul, + 0xFFFFFFFFFFFFFFFFul, + }; + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + assert(!(mask & ~0xFFFFFFFFul) && "inbound mask"); + return _mm256_setr_epi32(lut32[mask & 0xF], lut32[(mask >> 4) & 0xF], + lut32[(mask >> 8) & 0xF], lut32[(mask >> 12) & 0xF], + lut32[(mask >> 16) & 0xF], lut32[(mask >> 20) & 0xF], + lut32[(mask >> 24) & 0xF], lut32[mask >> 28]); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + assert(!(mask & ~0xFFFFul) && "inbound mask"); + return _mm256_setr_epi64x(lut64[mask & 0xF], lut64[(mask >> 4) & 0xF], lut64[(mask >> 8) & 0xF], lut64[(mask >> 12) & 0xF]); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_castps_si256(from_mask(batch_bool {}, mask, avx {})); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_castpd_si256(from_mask(batch_bool {}, mask, avx {})); + } + } + + // haddp + template + XSIMD_INLINE batch haddp(batch const* row, requires_arch) noexcept + { + // row = (a,b,c,d,e,f,g,h) + // tmp0 = (a0+a1, a2+a3, b0+b1, b2+b3, a4+a5, a6+a7, b4+b5, b6+b7) + __m256 tmp0 = _mm256_hadd_ps(row[0], row[1]); + // tmp1 = (c0+c1, c2+c3, d1+d2, d2+d3, c4+c5, c6+c7, d4+d5, d6+d7) + __m256 tmp1 = _mm256_hadd_ps(row[2], row[3]); + // tmp1 = (a0+a1+a2+a3, b0+b1+b2+b3, c0+c1+c2+c3, d0+d1+d2+d3, + // a4+a5+a6+a7, b4+b5+b6+b7, c4+c5+c6+c7, d4+d5+d6+d7) + tmp1 = _mm256_hadd_ps(tmp0, tmp1); + // tmp0 = 
(e0+e1, e2+e3, f0+f1, f2+f3, e4+e5, e6+e7, f4+f5, f6+f7) + tmp0 = _mm256_hadd_ps(row[4], row[5]); + // tmp2 = (g0+g1, g2+g3, h0+h1, h2+h3, g4+g5, g6+g7, h4+h5, h6+h7) + __m256 tmp2 = _mm256_hadd_ps(row[6], row[7]); + // tmp2 = (e0+e1+e2+e3, f0+f1+f2+f3, g0+g1+g2+g3, h0+h1+h2+h3, + // e4+e5+e6+e7, f4+f5+f6+f7, g4+g5+g6+g7, h4+h5+h6+h7) + tmp2 = _mm256_hadd_ps(tmp0, tmp2); + // tmp0 = (a0+a1+a2+a3, b0+b1+b2+b3, c0+c1+c2+c3, d0+d1+d2+d3, + // e4+e5+e6+e7, f4+f5+f6+f7, g4+g5+g6+g7, h4+h5+h6+h7) + tmp0 = _mm256_blend_ps(tmp1, tmp2, 0b11110000); + // tmp1 = (a4+a5+a6+a7, b4+b5+b6+b7, c4+c5+c6+c7, d4+d5+d6+d7, + // e0+e1+e2+e3, f0+f1+f2+f3, g0+g1+g2+g3, h0+h1+h2+h3) + tmp1 = _mm256_permute2f128_ps(tmp1, tmp2, 0x21); + return _mm256_add_ps(tmp0, tmp1); + } + template + XSIMD_INLINE batch haddp(batch const* row, requires_arch) noexcept + { + // row = (a,b,c,d) + // tmp0 = (a0+a1, b0+b1, a2+a3, b2+b3) + __m256d tmp0 = _mm256_hadd_pd(row[0], row[1]); + // tmp1 = (c0+c1, d0+d1, c2+c3, d2+d3) + __m256d tmp1 = _mm256_hadd_pd(row[2], row[3]); + // tmp2 = (a0+a1, b0+b1, c2+c3, d2+d3) + __m256d tmp2 = _mm256_blend_pd(tmp0, tmp1, 0b1100); + // tmp1 = (a2+a3, b2+b3, c2+c3, d2+d3) + tmp1 = _mm256_permute2f128_pd(tmp0, tmp1, 0x21); + return _mm256_add_pd(tmp1, tmp2); + } + + // incr_if + template ::value, void>::type> + XSIMD_INLINE batch incr_if(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return self - batch(mask.data); + } + + // insert + template ::value, void>::type> + XSIMD_INLINE batch insert(batch const& self, T val, index pos, requires_arch) noexcept + { +#if !defined(_MSC_VER) || _MSC_VER > 1900 + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_insert_epi8(self, val, I); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_insert_epi16(self, val, I); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_insert_epi32(self, val, I); + } + else + { + return insert(self, val, pos, generic {}); + } +#endif + return insert(self, val, pos, generic {}); + } + + // isnan + template + XSIMD_INLINE batch_bool isnan(batch const& self, requires_arch) noexcept + { + return _mm256_cmp_ps(self, self, _CMP_UNORD_Q); + } + template + XSIMD_INLINE batch_bool isnan(batch const& self, requires_arch) noexcept + { + return _mm256_cmp_pd(self, self, _CMP_UNORD_Q); + } + + // le + template + XSIMD_INLINE batch_bool le(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_cmp_ps(self, other, _CMP_LE_OQ); + } + template + XSIMD_INLINE batch_bool le(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_cmp_pd(self, other, _CMP_LE_OQ); + } + + // load_aligned + template ::value, void>::type> + XSIMD_INLINE batch load_aligned(T const* mem, convert, requires_arch) noexcept + { + return _mm256_load_si256((__m256i const*)mem); + } + template + XSIMD_INLINE batch load_aligned(float const* mem, convert, requires_arch) noexcept + { + return _mm256_load_ps(mem); + } + template + XSIMD_INLINE batch load_aligned(double const* mem, convert, requires_arch) noexcept + { + return _mm256_load_pd(mem); + } + + namespace detail + { + // load_complex + template + XSIMD_INLINE batch, A> load_complex(batch const& hi, batch const& lo, requires_arch) noexcept + { + using batch_type = batch; + __m128 tmp0 = _mm256_extractf128_ps(hi, 0); + __m128 tmp1 = _mm256_extractf128_ps(hi, 1); + __m128 tmp_real = _mm_shuffle_ps(tmp0, tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + __m128 tmp_imag = _mm_shuffle_ps(tmp0, tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + 
batch_type real = _mm256_castps128_ps256(tmp_real); + batch_type imag = _mm256_castps128_ps256(tmp_imag); + + tmp0 = _mm256_extractf128_ps(lo, 0); + tmp1 = _mm256_extractf128_ps(lo, 1); + tmp_real = _mm_shuffle_ps(tmp0, tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + tmp_imag = _mm_shuffle_ps(tmp0, tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + real = _mm256_insertf128_ps(real, tmp_real, 1); + imag = _mm256_insertf128_ps(imag, tmp_imag, 1); + return { real, imag }; + } + template + XSIMD_INLINE batch, A> load_complex(batch const& hi, batch const& lo, requires_arch) noexcept + { + using batch_type = batch; + __m128d tmp0 = _mm256_extractf128_pd(hi, 0); + __m128d tmp1 = _mm256_extractf128_pd(hi, 1); + batch_type real = _mm256_castpd128_pd256(_mm_unpacklo_pd(tmp0, tmp1)); + batch_type imag = _mm256_castpd128_pd256(_mm_unpackhi_pd(tmp0, tmp1)); + + tmp0 = _mm256_extractf128_pd(lo, 0); + tmp1 = _mm256_extractf128_pd(lo, 1); + __m256d re_tmp1 = _mm256_insertf128_pd(real, _mm_unpacklo_pd(tmp0, tmp1), 1); + __m256d im_tmp1 = _mm256_insertf128_pd(imag, _mm_unpackhi_pd(tmp0, tmp1), 1); + real = _mm256_blend_pd(real, re_tmp1, 12); + imag = _mm256_blend_pd(imag, im_tmp1, 12); + return { real, imag }; + } + } + + // load_unaligned + template ::value, void>::type> + XSIMD_INLINE batch load_unaligned(T const* mem, convert, requires_arch) noexcept + { + return _mm256_loadu_si256((__m256i const*)mem); + } + template + XSIMD_INLINE batch load_unaligned(float const* mem, convert, requires_arch) noexcept + { + return _mm256_loadu_ps(mem); + } + template + XSIMD_INLINE batch load_unaligned(double const* mem, convert, requires_arch) noexcept + { + return _mm256_loadu_pd(mem); + } + + // lt + template + XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_cmp_ps(self, other, _CMP_LT_OQ); + } + template + XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_cmp_pd(self, other, _CMP_LT_OQ); + } + + template ::value, void>::type> + XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return lt(batch(s), batch(o)); }, + self, other); + } + + // mask + template ::value, void>::type> + XSIMD_INLINE uint64_t mask(batch_bool const& self, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1 || sizeof(T) == 2) + { + __m128i self_low, self_high; + detail::split_avx(self, self_low, self_high); + return mask(batch_bool(self_low), sse4_2 {}) | (mask(batch_bool(self_high), sse4_2 {}) << (128 / (8 * sizeof(T)))); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_movemask_ps(_mm256_castsi256_ps(self)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_movemask_pd(_mm256_castsi256_pd(self)); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + template + XSIMD_INLINE uint64_t mask(batch_bool const& self, requires_arch) noexcept + { + return _mm256_movemask_ps(self); + } + + template + XSIMD_INLINE uint64_t mask(batch_bool const& self, requires_arch) noexcept + { + return _mm256_movemask_pd(self); + } + + // max + template + XSIMD_INLINE batch max(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_max_ps(self, other); + } + template + XSIMD_INLINE batch max(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_max_pd(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch max(batch 
const& self, batch const& other, requires_arch) noexcept + { + return select(self > other, self, other); + } + + // min + template + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_min_ps(self, other); + } + template + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_min_pd(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch) noexcept + { + return select(self <= other, self, other); + } + + // mul + template + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_mul_ps(self, other); + } + template + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_mul_pd(self, other); + } + + // nearbyint + template + XSIMD_INLINE batch nearbyint(batch const& self, requires_arch) noexcept + { + return _mm256_round_ps(self, _MM_FROUND_TO_NEAREST_INT); + } + template + XSIMD_INLINE batch nearbyint(batch const& self, requires_arch) noexcept + { + return _mm256_round_pd(self, _MM_FROUND_TO_NEAREST_INT); + } + + // nearbyint_as_int + template + XSIMD_INLINE batch nearbyint_as_int(batch const& self, + requires_arch) noexcept + { + return _mm256_cvtps_epi32(self); + } + + // neg + template ::value, void>::type> + XSIMD_INLINE batch neg(batch const& self, requires_arch) noexcept + { + return 0 - self; + } + template + batch neg(batch const& self, requires_arch) + { + return _mm256_xor_ps(self, _mm256_castsi256_ps(_mm256_set1_epi32(0x80000000))); + } + template + XSIMD_INLINE batch neg(batch const& self, requires_arch) noexcept + { + return _mm256_xor_pd(self, _mm256_castsi256_pd(_mm256_set1_epi64x(0x8000000000000000))); + } + + // neq + template + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_cmp_ps(self, other, _CMP_NEQ_UQ); + } + template + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_cmp_pd(self, other, _CMP_NEQ_UQ); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch) noexcept + { + return ~(self == other); + } + + template + XSIMD_INLINE batch_bool neq(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_xor_ps(self, other); + } + template + XSIMD_INLINE batch_bool neq(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_xor_pd(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool neq(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_castps_si256(_mm256_xor_ps(_mm256_castsi256_ps(self.data), _mm256_castsi256_ps(other.data))); + } + + // reciprocal + template + XSIMD_INLINE batch reciprocal(batch const& self, + kernel::requires_arch) noexcept + { + return _mm256_rcp_ps(self); + } + + // reduce_add + template + XSIMD_INLINE float reduce_add(batch const& rhs, requires_arch) noexcept + { + // Warning about _mm256_hadd_ps: + // _mm256_hadd_ps(a,b) gives + // (a0+a1,a2+a3,b0+b1,b2+b3,a4+a5,a6+a7,b4+b5,b6+b7). 
Hence we can't + // rely on a naive use of this method + // rhs = (x0, x1, x2, x3, x4, x5, x6, x7) + // tmp = (x4, x5, x6, x7, x0, x1, x2, x3) + __m256 tmp = _mm256_permute2f128_ps(rhs, rhs, 1); + // tmp = (x4+x0, x5+x1, x6+x2, x7+x3, x0+x4, x1+x5, x2+x6, x3+x7) + tmp = _mm256_add_ps(rhs, tmp); + // tmp = (x4+x0+x5+x1, x6+x2+x7+x3, -, -, -, -, -, -) + tmp = _mm256_hadd_ps(tmp, tmp); + // tmp = (x4+x0+x5+x1+x6+x2+x7+x3, -, -, -, -, -, -, -) + tmp = _mm256_hadd_ps(tmp, tmp); + return _mm_cvtss_f32(_mm256_extractf128_ps(tmp, 0)); + } + template + XSIMD_INLINE double reduce_add(batch const& rhs, requires_arch) noexcept + { + // rhs = (x0, x1, x2, x3) + // tmp = (x2, x3, x0, x1) + __m256d tmp = _mm256_permute2f128_pd(rhs, rhs, 1); + // tmp = (x2+x0, x3+x1, -, -) + tmp = _mm256_add_pd(rhs, tmp); + // tmp = (x2+x0+x3+x1, -, -, -) + tmp = _mm256_hadd_pd(tmp, tmp); + return _mm_cvtsd_f64(_mm256_extractf128_pd(tmp, 0)); + } + template ::value, void>::type> + XSIMD_INLINE T reduce_add(batch const& self, requires_arch) noexcept + { + __m128i low, high; + detail::split_avx(self, low, high); + batch blow(low), bhigh(high); + return reduce_add(blow) + reduce_add(bhigh); + } + + // reduce_max + template ::type> + XSIMD_INLINE T reduce_max(batch const& self, requires_arch) noexcept + { + constexpr auto mask = detail::shuffle(1, 0); + batch step = _mm256_permute2f128_si256(self, self, mask); + batch acc = max(self, step); + __m128i low = _mm256_castsi256_si128(acc); + return reduce_max(batch(low)); + } + + // reduce_min + template ::type> + XSIMD_INLINE T reduce_min(batch const& self, requires_arch) noexcept + { + constexpr auto mask = detail::shuffle(1, 0); + batch step = _mm256_permute2f128_si256(self, self, mask); + batch acc = min(self, step); + __m128i low = _mm256_castsi256_si128(acc); + return reduce_min(batch(low)); + } + + // rsqrt + template + XSIMD_INLINE batch rsqrt(batch const& val, requires_arch) noexcept + { + return _mm256_rsqrt_ps(val); + } + template + XSIMD_INLINE batch rsqrt(batch const& val, requires_arch) noexcept + { + return _mm256_cvtps_pd(_mm_rsqrt_ps(_mm256_cvtpd_ps(val))); + } + + // sadd + template ::value, void>::type> + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + auto mask = (other >> (8 * sizeof(T) - 1)); + auto self_pos_branch = min(std::numeric_limits::max() - other, self); + auto self_neg_branch = max(std::numeric_limits::min() - other, self); + return other + select(batch_bool(mask.data), self_neg_branch, self_pos_branch); + } + else + { + const auto diffmax = std::numeric_limits::max() - self; + const auto mindiff = min(diffmax, other); + return self + mindiff; + } + } + + // select + template + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + return _mm256_blendv_ps(false_br, true_br, cond); + } + template + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + return _mm256_blendv_pd(false_br, true_br, cond); + } + template ::value, void>::type> + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + __m128i cond_low, cond_hi; + detail::split_avx(cond, cond_low, cond_hi); + + __m128i true_low, true_hi; + detail::split_avx(true_br, true_low, true_hi); + + __m128i false_low, false_hi; + detail::split_avx(false_br, false_low, false_hi); + + __m128i res_low = 
select(batch_bool(cond_low), batch(true_low), batch(false_low), sse4_2 {}); + __m128i res_hi = select(batch_bool(cond_hi), batch(true_hi), batch(false_hi), sse4_2 {}); + return detail::merge_sse(res_low, res_hi); + } + template ::value, void>::type> + XSIMD_INLINE batch select(batch_bool_constant const&, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + return select(batch_bool { Values... }, true_br, false_br, avx2 {}); + } + + template + XSIMD_INLINE batch select(batch_bool_constant const&, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + constexpr auto mask = batch_bool_constant::mask(); + return _mm256_blend_ps(false_br, true_br, mask); + } + + template + XSIMD_INLINE batch select(batch_bool_constant const&, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + constexpr auto mask = batch_bool_constant::mask(); + return _mm256_blend_pd(false_br, true_br, mask); + } + + // set + template + XSIMD_INLINE batch set(batch const&, requires_arch, Values... values) noexcept + { + static_assert(sizeof...(Values) == batch::size, "consistent init"); + return _mm256_setr_ps(values...); + } + + template + XSIMD_INLINE batch set(batch const&, requires_arch, Values... values) noexcept + { + static_assert(sizeof...(Values) == batch::size, "consistent init"); + return _mm256_setr_pd(values...); + } + template ::value, void>::type> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3) noexcept + { + return _mm256_set_epi64x(v3, v2, v1, v0); + } + template ::value, void>::type> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7) noexcept + { + return _mm256_setr_epi32(v0, v1, v2, v3, v4, v5, v6, v7); + } + template ::value, void>::type> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15) noexcept + { + return _mm256_setr_epi16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + } + template ::value, void>::type> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15, + T v16, T v17, T v18, T v19, T v20, T v21, T v22, T v23, T v24, T v25, T v26, T v27, T v28, T v29, T v30, T v31) noexcept + { + return _mm256_setr_epi8(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31); + } + + template ::value, void>::type> + XSIMD_INLINE batch_bool set(batch_bool const&, requires_arch, Values... values) noexcept + { + return set(batch(), A {}, static_cast(values ? -1LL : 0LL)...).data; + } + + template + XSIMD_INLINE batch_bool set(batch_bool const&, requires_arch, Values... values) noexcept + { + static_assert(sizeof...(Values) == batch_bool::size, "consistent init"); + return _mm256_castsi256_ps(set(batch(), A {}, static_cast(values ? -1LL : 0LL)...).data); + } + + template + XSIMD_INLINE batch_bool set(batch_bool const&, requires_arch, Values... values) noexcept + { + static_assert(sizeof...(Values) == batch_bool::size, "consistent init"); + return _mm256_castsi256_pd(set(batch(), A {}, static_cast(values ? 
-1LL : 0LL)...).data); + } + + // shuffle + template + XSIMD_INLINE batch shuffle(batch const& x, batch const& y, batch_constant mask, requires_arch) noexcept + { + constexpr uint32_t smask = detail::mod_shuffle(I0, I1, I2, I3); + // shuffle within lane + if (I4 == (I0 + 4) && I5 == (I1 + 4) && I6 == (I2 + 4) && I7 == (I3 + 4) && I0 < 4 && I1 < 4 && I2 >= 8 && I2 < 12 && I3 >= 8 && I3 < 12) + return _mm256_shuffle_ps(x, y, smask); + + // shuffle within opposite lane + if (I4 == (I0 + 4) && I5 == (I1 + 4) && I6 == (I2 + 4) && I7 == (I3 + 4) && I2 < 4 && I3 < 4 && I0 >= 8 && I0 < 12 && I1 >= 8 && I1 < 12) + return _mm256_shuffle_ps(y, x, smask); + + return shuffle(x, y, mask, generic {}); + } + + template + XSIMD_INLINE batch shuffle(batch const& x, batch const& y, batch_constant mask, requires_arch) noexcept + { + constexpr uint32_t smask = (I0 & 0x1) | ((I1 & 0x1) << 1) | ((I2 & 0x1) << 2) | ((I3 & 0x1) << 3); + // shuffle within lane + if (I0 < 2 && I1 >= 4 && I1 < 6 && I2 >= 2 && I2 < 4 && I3 >= 6) + return _mm256_shuffle_pd(x, y, smask); + + // shuffle within opposite lane + if (I1 < 2 && I0 >= 4 && I0 < 6 && I3 >= 2 && I3 < 4 && I2 >= 6) + return _mm256_shuffle_pd(y, x, smask); + + return shuffle(x, y, mask, generic {}); + } + + // slide_left + template + XSIMD_INLINE batch slide_left(batch const& x, requires_arch) noexcept + { + constexpr unsigned BitCount = N * 8; + if (BitCount == 0) + { + return x; + } + if (BitCount >= 256) + { + return batch(T(0)); + } + if (BitCount > 128) + { + constexpr unsigned M = (BitCount - 128) / 8; + __m128i low = _mm256_castsi256_si128(x); + auto y = _mm_slli_si128(low, M); + __m256i zero = _mm256_setzero_si256(); + return _mm256_insertf128_si256(zero, y, 1); + } + if (BitCount == 128) + { + __m128i low = _mm256_castsi256_si128(x); + __m256i zero = _mm256_setzero_si256(); + return _mm256_insertf128_si256(zero, low, 1); + } + // shifting by [0, 128[ bits + constexpr unsigned M = BitCount / 8; + + __m128i low = _mm256_castsi256_si128(x); + auto ylow = _mm_slli_si128(low, M); + auto zlow = _mm_srli_si128(low, 16 - M); + + __m128i high = _mm256_extractf128_si256(x, 1); + auto yhigh = _mm_slli_si128(high, M); + + __m256i res = _mm256_castsi128_si256(ylow); + return _mm256_insertf128_si256(res, _mm_or_si128(yhigh, zlow), 1); + } + + // slide_right + template + XSIMD_INLINE batch slide_right(batch const& x, requires_arch) noexcept + { + constexpr unsigned BitCount = N * 8; + if (BitCount == 0) + { + return x; + } + if (BitCount >= 256) + { + return batch(T(0)); + } + if (BitCount > 128) + { + constexpr unsigned M = (BitCount - 128) / 8; + __m128i high = _mm256_extractf128_si256(x, 1); + __m128i y = _mm_srli_si128(high, M); + __m256i zero = _mm256_setzero_si256(); + return _mm256_insertf128_si256(zero, y, 0); + } + if (BitCount == 128) + { + __m128i high = _mm256_extractf128_si256(x, 1); + return _mm256_castsi128_si256(high); + } + // shifting by [0, 128[ bits + constexpr unsigned M = BitCount / 8; + + __m128i low = _mm256_castsi256_si128(x); + auto ylow = _mm_srli_si128(low, M); + + __m128i high = _mm256_extractf128_si256(x, 1); + auto yhigh = _mm_srli_si128(high, M); + auto zhigh = _mm_slli_si128(high, 16 - M); + + __m256i res = _mm256_castsi128_si256(_mm_or_si128(ylow, zhigh)); + return _mm256_insertf128_si256(res, yhigh, 1); + } + + // sqrt + template + XSIMD_INLINE batch sqrt(batch const& val, requires_arch) noexcept + { + return _mm256_sqrt_ps(val); + } + template + XSIMD_INLINE batch sqrt(batch const& val, requires_arch) noexcept + { + return 
_mm256_sqrt_pd(val); + } + + // ssub + template ::value, void>::type> + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + return sadd(self, -other); + } + else + { + const auto diff = min(self, other); + return self - diff; + } + } + + // store_aligned + template ::value, void>::type> + XSIMD_INLINE void store_aligned(T* mem, batch const& self, requires_arch) noexcept + { + return _mm256_store_si256((__m256i*)mem, self); + } + template ::value, void>::type> + XSIMD_INLINE void store_aligned(T* mem, batch_bool const& self, requires_arch) noexcept + { + return _mm256_store_si256((__m256i*)mem, self); + } + template + XSIMD_INLINE void store_aligned(float* mem, batch const& self, requires_arch) noexcept + { + return _mm256_store_ps(mem, self); + } + template + XSIMD_INLINE void store_aligned(double* mem, batch const& self, requires_arch) noexcept + { + return _mm256_store_pd(mem, self); + } + + // store_unaligned + template ::value, void>::type> + XSIMD_INLINE void store_unaligned(T* mem, batch const& self, requires_arch) noexcept + { + return _mm256_storeu_si256((__m256i*)mem, self); + } + template ::value, void>::type> + XSIMD_INLINE void store_unaligned(T* mem, batch_bool const& self, requires_arch) noexcept + { + return _mm256_storeu_si256((__m256i*)mem, self); + } + template + XSIMD_INLINE void store_unaligned(float* mem, batch const& self, requires_arch) noexcept + { + return _mm256_storeu_ps(mem, self); + } + template + XSIMD_INLINE void store_unaligned(double* mem, batch const& self, requires_arch) noexcept + { + return _mm256_storeu_pd(mem, self); + } + + // sub + template ::value, void>::type> + XSIMD_INLINE batch sub(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return sub(batch(s), batch(o)); }, + self, other); + } + template + XSIMD_INLINE batch sub(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_sub_ps(self, other); + } + template + XSIMD_INLINE batch sub(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_sub_pd(self, other); + } + + // swizzle (dynamic mask) + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + // duplicate low and high part of input + __m256 hi = _mm256_castps128_ps256(_mm256_extractf128_ps(self, 1)); + __m256 hi_hi = _mm256_insertf128_ps(self, _mm256_castps256_ps128(hi), 0); + + __m256 low = _mm256_castps128_ps256(_mm256_castps256_ps128(self)); + __m256 low_low = _mm256_insertf128_ps(self, _mm256_castps256_ps128(low), 1); + + // normalize mask + batch half_mask = mask % 4; + + // permute within each lane + __m256 r0 = _mm256_permutevar_ps(low_low, half_mask); + __m256 r1 = _mm256_permutevar_ps(hi_hi, half_mask); + + // mask to choose the right lane + batch_bool blend_mask = mask >= 4; + + // blend the two permutes + return _mm256_blendv_ps(r0, r1, batch_bool_cast(blend_mask)); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + // duplicate low and high part of input + __m256d hi = _mm256_castpd128_pd256(_mm256_extractf128_pd(self, 1)); + __m256d hi_hi = _mm256_insertf128_pd(self, _mm256_castpd256_pd128(hi), 0); + + __m256d low = _mm256_castpd128_pd256(_mm256_castpd256_pd128(self)); + __m256d low_low = _mm256_insertf128_pd(self, _mm256_castpd256_pd128(low), 1); + + // normalize mask + batch half_mask = -(mask & 1); + + // permute 
within each lane + __m256d r0 = _mm256_permutevar_pd(low_low, half_mask); + __m256d r1 = _mm256_permutevar_pd(hi_hi, half_mask); + + // mask to choose the right lane + batch_bool blend_mask = mask >= 2; + + // blend the two permutes + return _mm256_blendv_pd(r0, r1, batch_bool_cast(blend_mask)); + } + + template = 0> + XSIMD_INLINE batch swizzle(batch const& self, batch const& mask, requires_arch) noexcept + { + return bitwise_cast( + swizzle(bitwise_cast(self), mask)); + } + + template = 0> + XSIMD_INLINE batch + swizzle(batch const& self, batch const& mask, requires_arch) noexcept + { + return bitwise_cast( + swizzle(bitwise_cast(self), mask)); + } + + // swizzle (constant mask) + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant, requires_arch) noexcept + { + // duplicate low and high part of input + __m256 hi = _mm256_castps128_ps256(_mm256_extractf128_ps(self, 1)); + __m256 hi_hi = _mm256_insertf128_ps(self, _mm256_castps256_ps128(hi), 0); + + __m256 low = _mm256_castps128_ps256(_mm256_castps256_ps128(self)); + __m256 low_low = _mm256_insertf128_ps(self, _mm256_castps256_ps128(low), 1); + + // normalize mask + batch_constant half_mask; + + // permute within each lane + __m256 r0 = _mm256_permutevar_ps(low_low, half_mask.as_batch()); + __m256 r1 = _mm256_permutevar_ps(hi_hi, half_mask.as_batch()); + + // mask to choose the right lane + batch_bool_constant= 4), (V1 >= 4), (V2 >= 4), (V3 >= 4), (V4 >= 4), (V5 >= 4), (V6 >= 4), (V7 >= 4)> blend_mask; + + // blend the two permutes + constexpr auto mask = blend_mask.mask(); + return _mm256_blend_ps(r0, r1, mask); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant, requires_arch) noexcept + { + // duplicate low and high part of input + __m256d hi = _mm256_castpd128_pd256(_mm256_extractf128_pd(self, 1)); + __m256d hi_hi = _mm256_insertf128_pd(self, _mm256_castpd256_pd128(hi), 0); + + __m256d low = _mm256_castpd128_pd256(_mm256_castpd256_pd128(self)); + __m256d low_low = _mm256_insertf128_pd(self, _mm256_castpd256_pd128(low), 1); + + // normalize mask + batch_constant half_mask; + + // permute within each lane + __m256d r0 = _mm256_permutevar_pd(low_low, half_mask.as_batch()); + __m256d r1 = _mm256_permutevar_pd(hi_hi, half_mask.as_batch()); + + // mask to choose the right lane + batch_bool_constant= 2), (V1 >= 2), (V2 >= 2), (V3 >= 2)> blend_mask; + + // blend the two permutes + constexpr auto mask = blend_mask.mask(); + return _mm256_blend_pd(r0, r1, mask); + } + template = 0> + XSIMD_INLINE batch swizzle(batch const& self, + batch_constant const& mask, + requires_arch) noexcept + { + return bitwise_cast( + swizzle(bitwise_cast(self), mask)); + } + + template = 0> + XSIMD_INLINE batch + swizzle(batch const& self, + batch_constant const& mask, + requires_arch) noexcept + { + return bitwise_cast( + swizzle(bitwise_cast(self), mask)); + } + // transpose + template + XSIMD_INLINE void transpose(batch* matrix_begin, batch* matrix_end, requires_arch) noexcept + { + assert((matrix_end - matrix_begin == batch::size) && "correctly sized matrix"); + (void)matrix_end; + // See + // https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2 + auto r0 = matrix_begin[0], r1 = matrix_begin[1], + r2 = matrix_begin[2], r3 = matrix_begin[3], + r4 = matrix_begin[4], r5 = matrix_begin[5], + r6 = matrix_begin[6], r7 = matrix_begin[7]; + + auto t0 = _mm256_unpacklo_ps(r0, r1); + auto t1 = _mm256_unpackhi_ps(r0, r1); + auto t2 = _mm256_unpacklo_ps(r2, r3); + auto t3 = _mm256_unpackhi_ps(r2, 
r3); + auto t4 = _mm256_unpacklo_ps(r4, r5); + auto t5 = _mm256_unpackhi_ps(r4, r5); + auto t6 = _mm256_unpacklo_ps(r6, r7); + auto t7 = _mm256_unpackhi_ps(r6, r7); + + r0 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(1, 0, 1, 0)); + r1 = _mm256_shuffle_ps(t0, t2, _MM_SHUFFLE(3, 2, 3, 2)); + r2 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(1, 0, 1, 0)); + r3 = _mm256_shuffle_ps(t1, t3, _MM_SHUFFLE(3, 2, 3, 2)); + r4 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(1, 0, 1, 0)); + r5 = _mm256_shuffle_ps(t4, t6, _MM_SHUFFLE(3, 2, 3, 2)); + r6 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(1, 0, 1, 0)); + r7 = _mm256_shuffle_ps(t5, t7, _MM_SHUFFLE(3, 2, 3, 2)); + + matrix_begin[0] = _mm256_permute2f128_ps(r0, r4, 0x20); + matrix_begin[1] = _mm256_permute2f128_ps(r1, r5, 0x20); + matrix_begin[2] = _mm256_permute2f128_ps(r2, r6, 0x20); + matrix_begin[3] = _mm256_permute2f128_ps(r3, r7, 0x20); + matrix_begin[4] = _mm256_permute2f128_ps(r0, r4, 0x31); + matrix_begin[5] = _mm256_permute2f128_ps(r1, r5, 0x31); + matrix_begin[6] = _mm256_permute2f128_ps(r2, r6, 0x31); + matrix_begin[7] = _mm256_permute2f128_ps(r3, r7, 0x31); + } + + template + XSIMD_INLINE void transpose(batch* matrix_begin, batch* matrix_end, requires_arch) noexcept + { + return transpose(reinterpret_cast*>(matrix_begin), reinterpret_cast*>(matrix_end), A {}); + } + template + XSIMD_INLINE void transpose(batch* matrix_begin, batch* matrix_end, requires_arch) noexcept + { + return transpose(reinterpret_cast*>(matrix_begin), reinterpret_cast*>(matrix_end), A {}); + } + + template + XSIMD_INLINE void transpose(batch* matrix_begin, batch* matrix_end, requires_arch) noexcept + { + assert((matrix_end - matrix_begin == batch::size) && "correctly sized matrix"); + (void)matrix_end; + auto r0 = matrix_begin[0], r1 = matrix_begin[1], + r2 = matrix_begin[2], r3 = matrix_begin[3]; + + auto t0 = _mm256_unpacklo_pd(r0, r1); // r00 r10 r01 r11 + auto t1 = _mm256_unpackhi_pd(r0, r1); // r02 r12 r03 r13 + auto t2 = _mm256_unpacklo_pd(r2, r3); // r20 r30 r21 r31 + auto t3 = _mm256_unpackhi_pd(r2, r3); // r22 r32 r23 r33 + + matrix_begin[0] = _mm256_permute2f128_pd(t0, t2, 0x20); + matrix_begin[1] = _mm256_permute2f128_pd(t1, t3, 0x20); + matrix_begin[2] = _mm256_permute2f128_pd(t0, t2, 0x31); + matrix_begin[3] = _mm256_permute2f128_pd(t1, t3, 0x31); + } + + template + XSIMD_INLINE void transpose(batch* matrix_begin, batch* matrix_end, requires_arch) noexcept + { + return transpose(reinterpret_cast*>(matrix_begin), reinterpret_cast*>(matrix_end), A {}); + } + template + XSIMD_INLINE void transpose(batch* matrix_begin, batch* matrix_end, requires_arch) noexcept + { + return transpose(reinterpret_cast*>(matrix_begin), reinterpret_cast*>(matrix_end), A {}); + } + + // trunc + template + XSIMD_INLINE batch trunc(batch const& self, requires_arch) noexcept + { + return _mm256_round_ps(self, _MM_FROUND_TO_ZERO); + } + template + XSIMD_INLINE batch trunc(batch const& self, requires_arch) noexcept + { + return _mm256_round_pd(self, _MM_FROUND_TO_ZERO); + } + + // zip_hi + template ::value, void>::type> + XSIMD_INLINE batch zip_hi(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1 || sizeof(T) == 2) + { + // extract high word + __m128i self_hi = _mm256_extractf128_si256(self, 1); + __m128i other_hi = _mm256_extractf128_si256(other, 1); + + // interleave + __m128i res_lo, res_hi; + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + res_lo = _mm_unpacklo_epi8(self_hi, other_hi); + res_hi = _mm_unpackhi_epi8(self_hi, other_hi); + } + else + { + 
res_lo = _mm_unpacklo_epi16(self_hi, other_hi); + res_hi = _mm_unpackhi_epi16(self_hi, other_hi); + } + + // fuse + return _mm256_castps_si256( + _mm256_insertf128_ps( + _mm256_castsi256_ps(_mm256_castsi128_si256(res_lo)), + _mm_castsi128_ps(res_hi), + 1)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + auto lo = _mm256_unpacklo_ps(_mm256_castsi256_ps(self), _mm256_castsi256_ps(other)); + auto hi = _mm256_unpackhi_ps(_mm256_castsi256_ps(self), _mm256_castsi256_ps(other)); + return _mm256_castps_si256(_mm256_permute2f128_ps(lo, hi, 0x31)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + auto lo = _mm256_unpacklo_pd(_mm256_castsi256_pd(self), _mm256_castsi256_pd(other)); + auto hi = _mm256_unpackhi_pd(_mm256_castsi256_pd(self), _mm256_castsi256_pd(other)); + return _mm256_castpd_si256(_mm256_permute2f128_pd(lo, hi, 0x31)); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + template + XSIMD_INLINE batch zip_hi(batch const& self, batch const& other, requires_arch) noexcept + { + auto lo = _mm256_unpacklo_ps(self, other); + auto hi = _mm256_unpackhi_ps(self, other); + return _mm256_permute2f128_ps(lo, hi, 0x31); + } + template + XSIMD_INLINE batch zip_hi(batch const& self, batch const& other, requires_arch) noexcept + { + auto lo = _mm256_unpacklo_pd(self, other); + auto hi = _mm256_unpackhi_pd(self, other); + return _mm256_permute2f128_pd(lo, hi, 0x31); + } + + // zip_lo + template ::value, void>::type> + XSIMD_INLINE batch zip_lo(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1 || sizeof(T) == 2) + { + // extract low word + __m128i self_lo = _mm256_extractf128_si256(self, 0); + __m128i other_lo = _mm256_extractf128_si256(other, 0); + + // interleave + __m128i res_lo, res_hi; + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + res_lo = _mm_unpacklo_epi8(self_lo, other_lo); + res_hi = _mm_unpackhi_epi8(self_lo, other_lo); + } + else + { + res_lo = _mm_unpacklo_epi16(self_lo, other_lo); + res_hi = _mm_unpackhi_epi16(self_lo, other_lo); + } + + // fuse + return _mm256_castps_si256( + _mm256_insertf128_ps( + _mm256_castsi256_ps(_mm256_castsi128_si256(res_lo)), + _mm_castsi128_ps(res_hi), + 1)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + auto lo = _mm256_unpacklo_ps(_mm256_castsi256_ps(self), _mm256_castsi256_ps(other)); + auto hi = _mm256_unpackhi_ps(_mm256_castsi256_ps(self), _mm256_castsi256_ps(other)); + return _mm256_castps_si256(_mm256_insertf128_ps(lo, _mm256_castps256_ps128(hi), 1)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + auto lo = _mm256_unpacklo_pd(_mm256_castsi256_pd(self), _mm256_castsi256_pd(other)); + auto hi = _mm256_unpackhi_pd(_mm256_castsi256_pd(self), _mm256_castsi256_pd(other)); + return _mm256_castpd_si256(_mm256_insertf128_pd(lo, _mm256_castpd256_pd128(hi), 1)); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + + template + XSIMD_INLINE batch zip_lo(batch const& self, batch const& other, requires_arch) noexcept + { + auto lo = _mm256_unpacklo_ps(self, other); + auto hi = _mm256_unpackhi_ps(self, other); + return _mm256_insertf128_ps(lo, _mm256_castps256_ps128(hi), 1); + } + template + XSIMD_INLINE batch zip_lo(batch const& self, batch const& other, requires_arch) noexcept + { + auto lo = _mm256_unpacklo_pd(self, other); + auto hi = _mm256_unpackhi_pd(self, other); + return _mm256_insertf128_pd(lo, _mm256_castpd256_pd128(hi), 1); + } + } +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx2.hpp 
b/include/onnxruntime/xsimd/arch/xsimd_avx2.hpp new file mode 100644 index 0000000000000..506299a0dd8e8 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx2.hpp @@ -0,0 +1,1021 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX2_HPP +#define XSIMD_AVX2_HPP + +#include +#include + +#include "../types/xsimd_avx2_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + using namespace types; + + // abs + template ::value, void>::type> + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_abs_epi8(self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_abs_epi16(self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_abs_epi32(self); + } + else + { + return abs(self, avx {}); + } + } + return self; + } + + // add + template ::value, void>::type> + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_add_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_add_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_add_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_add_epi64(self, other); + } + else + { + return add(self, other, avx {}); + } + } + + // avgr + template ::value, void>::type> + XSIMD_INLINE batch avgr(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_avg_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_avg_epu16(self, other); + } + else + { + return avgr(self, other, generic {}); + } + } + + // avg + template ::value, void>::type> + XSIMD_INLINE batch avg(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + auto adj = ((self ^ other) << 7) >> 7; + return avgr(self, other, A {}) - adj; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + auto adj = ((self ^ other) << 15) >> 15; + return avgr(self, other, A {}) - adj; + } + else + { + return avg(self, other, generic {}); + } + } + + // bitwise_and + template ::value, void>::type> + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_and_si256(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool bitwise_and(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_and_si256(self, other); + } + + // bitwise_andnot + template ::value, void>::type> + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_andnot_si256(other, self); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool bitwise_andnot(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_andnot_si256(other, self); + } + + // bitwise_not + template ::value, void>::type> + XSIMD_INLINE batch bitwise_not(batch 
const& self, requires_arch) noexcept + { + return _mm256_xor_si256(self, _mm256_set1_epi32(-1)); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool bitwise_not(batch_bool const& self, requires_arch) noexcept + { + return _mm256_xor_si256(self, _mm256_set1_epi32(-1)); + } + + // bitwise_lshift + template ::value, void>::type> + XSIMD_INLINE batch bitwise_lshift(batch const& self, int32_t other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_slli_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_slli_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_slli_epi64(self, other); + } + else + { + return bitwise_lshift(self, other, avx {}); + } + } + + template ::value, void>::type> + XSIMD_INLINE batch bitwise_lshift(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_sllv_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_sllv_epi64(self, other); + } + else + { + return bitwise_lshift(self, other, avx {}); + } + } + + // bitwise_or + template ::value, void>::type> + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_or_si256(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool bitwise_or(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_or_si256(self, other); + } + + // bitwise_rshift + template ::value, void>::type> + XSIMD_INLINE batch bitwise_rshift(batch const& self, int32_t other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + __m256i sign_mask = _mm256_set1_epi16((0xFF00 >> other) & 0x00FF); + __m256i cmp_is_negative = _mm256_cmpgt_epi8(_mm256_setzero_si256(), self); + __m256i res = _mm256_srai_epi16(self, other); + return _mm256_or_si256( + detail::fwd_to_sse([](__m128i s, __m128i o) noexcept + { return bitwise_and(batch(s), batch(o), sse4_2 {}); }, + sign_mask, cmp_is_negative), + _mm256_andnot_si256(sign_mask, res)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_srai_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_srai_epi32(self, other); + } + else + { + return bitwise_rshift(self, other, avx {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_srli_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_srli_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_srli_epi64(self, other); + } + else + { + return bitwise_rshift(self, other, avx {}); + } + } + } + + template ::value, void>::type> + XSIMD_INLINE batch bitwise_rshift(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_srav_epi32(self, other); + } + else + { + return bitwise_rshift(self, other, avx {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_srlv_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_srlv_epi64(self, other); + } + else + { + return bitwise_rshift(self, other, avx {}); + } + } + } + + // bitwise_xor + template ::value, void>::type> + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm256_xor_si256(self, other); + } + template 
::value, void>::type> + XSIMD_INLINE batch bitwise_xor(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + return _mm256_xor_si256(self, other); + } + + // complex_low + template + XSIMD_INLINE batch complex_low(batch, A> const& self, requires_arch) noexcept + { + __m256d tmp0 = _mm256_permute4x64_pd(self.real(), _MM_SHUFFLE(3, 1, 1, 0)); + __m256d tmp1 = _mm256_permute4x64_pd(self.imag(), _MM_SHUFFLE(1, 2, 0, 0)); + return _mm256_blend_pd(tmp0, tmp1, 10); + } + + // complex_high + template + XSIMD_INLINE batch complex_high(batch, A> const& self, requires_arch) noexcept + { + __m256d tmp0 = _mm256_permute4x64_pd(self.real(), _MM_SHUFFLE(3, 3, 1, 2)); + __m256d tmp1 = _mm256_permute4x64_pd(self.imag(), _MM_SHUFFLE(3, 2, 2, 0)); + return _mm256_blend_pd(tmp0, tmp1, 10); + } + + // fast_cast + namespace detail + { + + template + XSIMD_INLINE batch fast_cast(batch const& x, batch const&, requires_arch) noexcept + { + // from https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx + // adapted to avx + __m256i xH = _mm256_srli_epi64(x, 32); + xH = _mm256_or_si256(xH, _mm256_castpd_si256(_mm256_set1_pd(19342813113834066795298816.))); // 2^84 + __m256i mask = _mm256_setr_epi16(0xFFFF, 0xFFFF, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0x0000, 0x0000, + 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0x0000, 0x0000); + __m256i xL = _mm256_or_si256(_mm256_and_si256(mask, x), _mm256_andnot_si256(mask, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)))); // 2^52 + __m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(19342813118337666422669312.)); // 2^84 + 2^52 + return _mm256_add_pd(f, _mm256_castsi256_pd(xL)); + } + + template + XSIMD_INLINE batch fast_cast(batch const& x, batch const&, requires_arch) noexcept + { + // from https://stackoverflow.com/questions/41144668/how-to-efficiently-perform-double-int64-conversions-with-sse-avx + // adapted to avx + __m256i xH = _mm256_srai_epi32(x, 16); + xH = _mm256_and_si256(xH, _mm256_setr_epi16(0x0000, 0x0000, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0xFFFF, 0xFFFF, 0x0000, 0x0000, 0xFFFF, 0xFFFF)); + xH = _mm256_add_epi64(xH, _mm256_castpd_si256(_mm256_set1_pd(442721857769029238784.))); // 3*2^67 + __m256i mask = _mm256_setr_epi16(0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, + 0xFFFF, 0xFFFF, 0xFFFF, 0x0000, 0xFFFF, 0xFFFF, 0xFFFF, 0x0000); + __m256i xL = _mm256_or_si256(_mm256_and_si256(mask, x), _mm256_andnot_si256(mask, _mm256_castpd_si256(_mm256_set1_pd(0x0010000000000000)))); // 2^52 + __m256d f = _mm256_sub_pd(_mm256_castsi256_pd(xH), _mm256_set1_pd(442726361368656609280.)); // 3*2^67 + 2^52 + return _mm256_add_pd(f, _mm256_castsi256_pd(xL)); + } + } + + // eq + template ::value, void>::type> + XSIMD_INLINE batch_bool eq(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_cmpeq_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_cmpeq_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_cmpeq_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_cmpeq_epi64(self, other); + } + else + { + return eq(self, other, avx {}); + } + } + + // gather + template = 0, detail::enable_sized_integral_t = 0> + XSIMD_INLINE batch gather(batch const&, T const* src, batch const& index, + kernel::requires_arch) noexcept + { + // scatter for this one is AVX512F+AVX512VL + 
return _mm256_i32gather_epi32(reinterpret_cast(src), index, sizeof(T)); + } + + template = 0, detail::enable_sized_integral_t = 0> + XSIMD_INLINE batch gather(batch const&, T const* src, batch const& index, + kernel::requires_arch) noexcept + { + // scatter for this one is AVX512F+AVX512VL + return _mm256_i64gather_epi64(reinterpret_cast(src), index, sizeof(T)); + } + + template = 0> + XSIMD_INLINE batch gather(batch const&, float const* src, + batch const& index, + kernel::requires_arch) noexcept + { + // scatter for this one is AVX512F+AVX512VL + return _mm256_i32gather_ps(src, index, sizeof(float)); + } + + template = 0> + XSIMD_INLINE batch gather(batch const&, double const* src, + batch const& index, + requires_arch) noexcept + { + // scatter for this one is AVX512F+AVX512VL + return _mm256_i64gather_pd(src, index, sizeof(double)); + } + + // gather: handmade conversions + template = 0> + XSIMD_INLINE batch gather(batch const&, double const* src, + batch const& index, + requires_arch) noexcept + { + const batch low(_mm256_i32gather_pd(src, _mm256_castsi256_si128(index.data), sizeof(double))); + const batch high(_mm256_i32gather_pd(src, _mm256_extractf128_si256(index.data, 1), sizeof(double))); + return detail::merge_sse(_mm256_cvtpd_ps(low.data), _mm256_cvtpd_ps(high.data)); + } + + template = 0> + XSIMD_INLINE batch gather(batch const&, double const* src, + batch const& index, + requires_arch) noexcept + { + const batch low(_mm256_i32gather_pd(src, _mm256_castsi256_si128(index.data), sizeof(double))); + const batch high(_mm256_i32gather_pd(src, _mm256_extractf128_si256(index.data, 1), sizeof(double))); + return detail::merge_sse(_mm256_cvtpd_epi32(low.data), _mm256_cvtpd_epi32(high.data)); + } + + // lt + template ::value, void>::type> + XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_cmpgt_epi8(other, self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_cmpgt_epi16(other, self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_cmpgt_epi32(other, self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_cmpgt_epi64(other, self); + } + else + { + return lt(self, other, avx {}); + } + } + else + { + return lt(self, other, avx {}); + } + } + + // load_complex + template + XSIMD_INLINE batch, A> load_complex(batch const& hi, batch const& lo, requires_arch) noexcept + { + using batch_type = batch; + batch_type real = _mm256_castpd_ps( + _mm256_permute4x64_pd( + _mm256_castps_pd(_mm256_shuffle_ps(hi, lo, _MM_SHUFFLE(2, 0, 2, 0))), + _MM_SHUFFLE(3, 1, 2, 0))); + batch_type imag = _mm256_castpd_ps( + _mm256_permute4x64_pd( + _mm256_castps_pd(_mm256_shuffle_ps(hi, lo, _MM_SHUFFLE(3, 1, 3, 1))), + _MM_SHUFFLE(3, 1, 2, 0))); + return { real, imag }; + } + template + XSIMD_INLINE batch, A> load_complex(batch const& hi, batch const& lo, requires_arch) noexcept + { + using batch_type = batch; + batch_type real = _mm256_permute4x64_pd(_mm256_unpacklo_pd(hi, lo), _MM_SHUFFLE(3, 1, 2, 0)); + batch_type imag = _mm256_permute4x64_pd(_mm256_unpackhi_pd(hi, lo), _MM_SHUFFLE(3, 1, 2, 0)); + return { real, imag }; + } + // mask + template ::value, void>::type> + XSIMD_INLINE uint64_t mask(batch_bool const& self, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return 0xFFFFFFFF & (uint64_t)_mm256_movemask_epi8(self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + uint64_t mask8 = 0xFFFFFFFF & 
(uint64_t)_mm256_movemask_epi8(self); + return detail::mask_lut(mask8) | (detail::mask_lut(mask8 >> 8) << 4) | (detail::mask_lut(mask8 >> 16) << 8) | (detail::mask_lut(mask8 >> 24) << 12); + } + else + { + return mask(self, avx {}); + } + } + + // max + template ::value, void>::type> + XSIMD_INLINE batch max(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_max_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_max_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_max_epi32(self, other); + } + else + { + return max(self, other, avx {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_max_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_max_epu16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_max_epu32(self, other); + } + else + { + return max(self, other, avx {}); + } + } + } + + // min + template ::value, void>::type> + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_min_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_min_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_min_epi32(self, other); + } + else + { + return min(self, other, avx {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_min_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_min_epu16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_min_epu32(self, other); + } + else + { + return min(self, other, avx {}); + } + } + } + + // mul + template ::value, void>::type> + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + __m256i mask_hi = _mm256_set1_epi32(0xFF00FF00); + __m256i res_lo = _mm256_mullo_epi16(self, other); + __m256i other_hi = _mm256_srli_epi16(other, 8); + __m256i self_hi = _mm256_and_si256(self, mask_hi); + __m256i res_hi = _mm256_mullo_epi16(self_hi, other_hi); + __m256i res = _mm256_blendv_epi8(res_lo, res_hi, mask_hi); + return res; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_mullo_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_mullo_epi32(self, other); + } + else + { + return mul(self, other, avx {}); + } + } + + // reduce_add + template ::value, void>::type> + XSIMD_INLINE T reduce_add(batch const& self, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + __m256i tmp1 = _mm256_hadd_epi32(self, self); + __m256i tmp2 = _mm256_hadd_epi32(tmp1, tmp1); + __m128i tmp3 = _mm256_extracti128_si256(tmp2, 1); + __m128i tmp4 = _mm_add_epi32(_mm256_castsi256_si128(tmp2), tmp3); + return _mm_cvtsi128_si32(tmp4); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + __m256i tmp1 = _mm256_shuffle_epi32(self, 0x0E); + __m256i tmp2 = _mm256_add_epi64(self, tmp1); + __m128i tmp3 = _mm256_extracti128_si256(tmp2, 1); + __m128i res = _mm_add_epi64(_mm256_castsi256_si128(tmp2), tmp3); +#if defined(__x86_64__) + return _mm_cvtsi128_si64(res); +#else + __m128i m; + _mm_storel_epi64(&m, res); + int64_t i; + std::memcpy(&i, &m, sizeof(i)); + return i; +#endif + } + else + { + return reduce_add(self, avx {}); + } + } 
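The integer reduce_add kernel just above uses the same two-step idea as the float kernel in xsimd_avx.hpp: a pair of hadd passes collapses each 128-bit lane, and a final 128-bit add combines the two lane totals. A minimal standalone sketch of that strategy follows (illustration only, not part of the vendored header; the helper name hsum_epi32_avx2 is invented here and an AVX2-capable compiler is assumed):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

// Horizontal sum of eight int32 lanes, mirroring the hadd-based kernel above.
static int32_t hsum_epi32_avx2(__m256i v)
{
    __m256i t1 = _mm256_hadd_epi32(v, v);   // per lane: (x0+x1, x2+x3, x0+x1, x2+x3)
    __m256i t2 = _mm256_hadd_epi32(t1, t1); // per lane: lane total in every slot
    __m128i hi = _mm256_extracti128_si256(t2, 1);
    __m128i s  = _mm_add_epi32(_mm256_castsi256_si128(t2), hi); // combine both lanes
    return _mm_cvtsi128_si32(s);
}

int main()
{
    alignas(32) int32_t data[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
    __m256i v = _mm256_load_si256(reinterpret_cast<const __m256i*>(data));
    int32_t scalar = 0;
    for (int32_t x : data)
        scalar += x;
    std::printf("simd=%d scalar=%d\n", hsum_epi32_avx2(v), scalar); // both print 36
}

Because _mm256_hadd_epi32, like _mm256_hadd_ps, only adds within each 128-bit lane, the trailing cross-lane _mm_add_epi32 is what completes the reduction; the float kernel in xsimd_avx.hpp addresses the same limitation by permuting the lanes before its hadd passes.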
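The fast_cast specialisations earlier in this file implement the magic-constant trick from the linked Stack Overflow answer: the high and low 32-bit halves of each 64-bit integer are stamped with the exponent patterns of 2^84 and 2^52, and the combined constant 2^84 + 2^52 is subtracted to cancel both tags. A scalar sketch of the unsigned variant (illustration only; u64_to_double is a hypothetical name, not an xsimd entry point):

#include <cstdint>
#include <cstring>
#include <cstdio>

// Scalar restatement of the uint64 -> double conversion used by fast_cast above.
static double u64_to_double(uint64_t x)
{
    uint64_t hi_bits = (x >> 32) | 0x4530000000000000ull;           // tag high half with 2^84
    uint64_t lo_bits = (x & 0xFFFFFFFFull) | 0x4330000000000000ull; // tag low half with 2^52
    double hi, lo;
    std::memcpy(&hi, &hi_bits, sizeof hi); // hi == 2^84 + (x >> 32) * 2^32
    std::memcpy(&lo, &lo_bits, sizeof lo); // lo == 2^52 + (x & 0xFFFFFFFF)
    // Subtracting 2^84 + 2^52 cancels both tags, leaving the value of x
    // (correctly rounded once the two halves are added back together).
    return (hi - 19342813118337666422669312.0) + lo;
}

int main()
{
    uint64_t samples[] = { 0u, 1u, 0xFFFFFFFFull, 0x123456789ABCDEFull, UINT64_MAX };
    for (uint64_t s : samples)
        std::printf("%llu -> %.17g (expected %.17g)\n",
                    (unsigned long long)s, u64_to_double(s), (double)s);
}

The signed variant in the kernel follows the same layout but tags the sign-extended high half with 3 * 2^67, which is why its magic constants differ.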
+ + // rotate_left + template + XSIMD_INLINE batch rotate_left(batch const& self, requires_arch) noexcept + { + return _mm256_alignr_epi8(self, self, N); + } + template + XSIMD_INLINE batch rotate_left(batch const& self, requires_arch) noexcept + { + return bitwise_cast(rotate_left(bitwise_cast(self), avx2 {})); + } + + // sadd + template ::value, void>::type> + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_adds_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_adds_epi16(self, other); + } + else + { + return sadd(self, other, avx {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_adds_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_adds_epu16(self, other); + } + else + { + return sadd(self, other, avx {}); + } + } + } + + // select + template ::value, void>::type> + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_blendv_epi8(false_br, true_br, cond); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_blendv_epi8(false_br, true_br, cond); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_blendv_epi8(false_br, true_br, cond); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_blendv_epi8(false_br, true_br, cond); + } + else + { + return select(cond, true_br, false_br, avx {}); + } + } + template ::value, void>::type> + XSIMD_INLINE batch select(batch_bool_constant const&, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + constexpr int mask = batch_bool_constant::mask(); + // FIXME: for some reason mask here is not considered as an immediate, + // but it's okay for _mm256_blend_epi32 + // case 2: return _mm256_blend_epi16(false_br, true_br, mask); + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_blend_epi32(false_br, true_br, mask); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + constexpr int imask = detail::interleave(mask); + return _mm256_blend_epi32(false_br, true_br, imask); + } + else + { + return select(batch_bool { Values... 
}, true_br, false_br, avx2 {}); + } + } + + // slide_left + template + XSIMD_INLINE batch slide_left(batch const& x, requires_arch) noexcept + { + constexpr unsigned BitCount = N * 8; + if (BitCount == 0) + { + return x; + } + if (BitCount >= 256) + { + return batch(T(0)); + } + if (BitCount > 128) + { + constexpr unsigned M = (BitCount - 128) / 8; + auto y = _mm256_bslli_epi128(x, M); + return _mm256_permute2x128_si256(y, y, 0x28); + } + if (BitCount == 128) + { + return _mm256_permute2x128_si256(x, x, 0x28); + } + // shifting by [0, 128[ bits + constexpr unsigned M = BitCount / 8; + auto y = _mm256_bslli_epi128(x, M); + auto z = _mm256_bsrli_epi128(x, 16 - M); + auto w = _mm256_permute2x128_si256(z, z, 0x28); + return _mm256_or_si256(y, w); + } + + // slide_right + template + XSIMD_INLINE batch slide_right(batch const& x, requires_arch) noexcept + { + constexpr unsigned BitCount = N * 8; + if (BitCount == 0) + { + return x; + } + if (BitCount >= 256) + { + return batch(T(0)); + } + if (BitCount > 128) + { + constexpr unsigned M = (BitCount - 128) / 8; + auto y = _mm256_bsrli_epi128(x, M); + return _mm256_permute2x128_si256(y, y, 0x81); + } + if (BitCount == 128) + { + return _mm256_permute2x128_si256(x, x, 0x81); + } + // shifting by [0, 128[ bits + constexpr unsigned M = BitCount / 8; + auto y = _mm256_bsrli_epi128(x, M); + auto z = _mm256_bslli_epi128(x, 16 - M); + auto w = _mm256_permute2x128_si256(z, z, 0x81); + return _mm256_or_si256(y, w); + } + + // ssub + template ::value, void>::type> + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_subs_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_subs_epi16(self, other); + } + else + { + return ssub(self, other, avx {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_subs_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_subs_epu16(self, other); + } + else + { + return ssub(self, other, avx {}); + } + } + } + + // sub + template ::value, void>::type> + XSIMD_INLINE batch sub(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm256_sub_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm256_sub_epi16(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm256_sub_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm256_sub_epi64(self, other); + } + else + { + return sub(self, other, avx {}); + } + } + + // swizzle (dynamic mask) + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return _mm256_permutevar8x32_ps(self, mask); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + batch broadcaster = { 0, 1, 0, 1, 0, 1, 0, 1 }; + constexpr uint64_t comb = 0x0000000100000001ul * 2; + return bitwise_cast(swizzle(bitwise_cast(self), bitwise_cast(mask * comb) + broadcaster, avx2 {})); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx2 {})); + } + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx2 {})); + } + template + XSIMD_INLINE batch swizzle(batch const& 
self, batch mask, requires_arch) noexcept + { + return _mm256_permutevar8x32_epi32(self, mask); + } + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx2 {})); + } + + // swizzle (constant mask) + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return _mm256_permutevar8x32_ps(self, mask.as_batch()); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant, requires_arch) noexcept + { + constexpr auto mask = detail::shuffle(V0, V1, V2, V3); + return _mm256_permute4x64_pd(self, mask); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant, requires_arch) noexcept + { + constexpr auto mask = detail::shuffle(V0, V1, V2, V3); + return _mm256_permute4x64_epi64(self, mask); + } + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx2 {})); + } + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return _mm256_permutevar8x32_epi32(self, mask.as_batch()); + } + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx2 {})); + } + + // zip_hi + template ::value, void>::type> + XSIMD_INLINE batch zip_hi(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + auto lo = _mm256_unpacklo_epi8(self, other); + auto hi = _mm256_unpackhi_epi8(self, other); + return _mm256_permute2f128_si256(lo, hi, 0x31); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + auto lo = _mm256_unpacklo_epi16(self, other); + auto hi = _mm256_unpackhi_epi16(self, other); + return _mm256_permute2f128_si256(lo, hi, 0x31); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + auto lo = _mm256_unpacklo_epi32(self, other); + auto hi = _mm256_unpackhi_epi32(self, other); + return _mm256_permute2f128_si256(lo, hi, 0x31); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + auto lo = _mm256_unpacklo_epi64(self, other); + auto hi = _mm256_unpackhi_epi64(self, other); + return _mm256_permute2f128_si256(lo, hi, 0x31); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + + // zip_lo + template ::value, void>::type> + XSIMD_INLINE batch zip_lo(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + auto lo = _mm256_unpacklo_epi8(self, other); + auto hi = _mm256_unpackhi_epi8(self, other); + return _mm256_inserti128_si256(lo, _mm256_castsi256_si128(hi), 1); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + auto lo = _mm256_unpacklo_epi16(self, other); + auto hi = _mm256_unpackhi_epi16(self, other); + return _mm256_inserti128_si256(lo, _mm256_castsi256_si128(hi), 1); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + auto lo = _mm256_unpacklo_epi32(self, other); + auto hi = _mm256_unpackhi_epi32(self, other); + return _mm256_inserti128_si256(lo, _mm256_castsi256_si128(hi), 1); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + auto lo = _mm256_unpacklo_epi64(self, other); + auto hi = _mm256_unpackhi_epi64(self, other); + return _mm256_inserti128_si256(lo, _mm256_castsi256_si128(hi), 1); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + } +} + +#endif diff --git 
a/include/onnxruntime/xsimd/arch/xsimd_avx512bw.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512bw.hpp new file mode 100644 index 0000000000000..724ced08776ef --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512bw.hpp @@ -0,0 +1,701 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512BW_HPP +#define XSIMD_AVX512BW_HPP + +#include +#include + +#include "../types/xsimd_avx512bw_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + using namespace types; + + namespace detail + { + template + XSIMD_INLINE batch_bool compare_int_avx512bw(batch const& self, batch const& other) noexcept + { + using register_type = typename batch_bool::register_type; + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return (register_type)_mm512_cmp_epi8_mask(self, other, Cmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return (register_type)_mm512_cmp_epi16_mask(self, other, Cmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return (register_type)_mm512_cmp_epi32_mask(self, other, Cmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return (register_type)_mm512_cmp_epi64_mask(self, other, Cmp); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return (register_type)_mm512_cmp_epu8_mask(self, other, Cmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return (register_type)_mm512_cmp_epu16_mask(self, other, Cmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return (register_type)_mm512_cmp_epu32_mask(self, other, Cmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return (register_type)_mm512_cmp_epu64_mask(self, other, Cmp); + } + } + } + } + + // abs + template ::value, void>::type> + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept + { + if (std::is_unsigned::value) + { + return self; + } + + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_abs_epi8(self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_abs_epi16(self); + } + else + { + return abs(self, avx512dq {}); + } + } + + // add + template ::value, void>::type> + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_add_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_add_epi16(self, other); + } + else + { + return add(self, other, avx512dq {}); + } + } + + // avgr + template ::value, void>::type> + XSIMD_INLINE batch avgr(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_avg_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_avg_epu16(self, other); + } + else + { + return avgr(self, other, generic {}); + } + } + + // avg + template ::value, void>::type> + XSIMD_INLINE batch avg(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + auto adj = ((self ^ other) << 7) >> 7; + return avgr(self, other, A {}) - adj; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + auto adj = ((self ^ other) << 15) >> 
15; + return avgr(self, other, A {}) - adj; + } + else + { + return avg(self, other, generic {}); + } + } + + // bitwise_lshift + template ::value, void>::type> + XSIMD_INLINE batch bitwise_lshift(batch const& self, int32_t other, requires_arch) noexcept + { +#if defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_sllv_epi16(self, _mm512_set1_epi16(other)); +#else + XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_slli_epi16(self, other); +#endif + } + else + { + return bitwise_lshift(self, other, avx512dq {}); + } + } + + // bitwise_rshift + template ::value, void>::type> + XSIMD_INLINE batch bitwise_rshift(batch const& self, int32_t other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + __m512i sign_mask = _mm512_set1_epi16((0xFF00 >> other) & 0x00FF); + __m512i zeros = _mm512_setzero_si512(); + __mmask64 cmp_is_negative_mask = _mm512_cmpgt_epi8_mask(zeros, self); + __m512i cmp_sign_mask = _mm512_mask_blend_epi8(cmp_is_negative_mask, zeros, sign_mask); +#if defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + __m512i res = _mm512_srav_epi16(self, _mm512_set1_epi16(other)); +#else + __m512i res = _mm512_srai_epi16(self, other); +#endif + return _mm512_or_si512(cmp_sign_mask, _mm512_andnot_si512(sign_mask, res)); +#if defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_srav_epi16(self, _mm512_set1_epi16(other)); +#else + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_srai_epi16(self, other); +#endif + } + else + { + return bitwise_rshift(self, other, avx512dq {}); + } + } + else + { +#if defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_srlv_epi16(self, _mm512_set1_epi16(other)); +#else + XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_srli_epi16(self, other); +#endif + } + else + { + return bitwise_rshift(self, other, avx512dq {}); + } + } + } + + // eq + template ::value, void>::type> + XSIMD_INLINE batch_bool eq(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512bw(self, other); + } + + // ge + template ::value, void>::type> + XSIMD_INLINE batch_bool ge(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512bw(self, other); + } + + // gt + template ::value, void>::type> + XSIMD_INLINE batch_bool gt(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512bw(self, other); + } + + // le + template ::value, void>::type> + XSIMD_INLINE batch_bool le(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512bw(self, other); + } + + // lt + template ::value, void>::type> + XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512bw(self, other); + } + + // max + template ::value, void>::type> + XSIMD_INLINE batch max(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_max_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_max_epi16(self, other); + } + else + { + return max(self, other, avx512dq {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_max_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return 
_mm512_max_epu16(self, other); + } + else + { + return max(self, other, avx512dq {}); + } + } + } + + // min + template ::value, void>::type> + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_min_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_min_epi16(self, other); + } + else + { + return min(self, other, avx512dq {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_min_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_min_epu16(self, other); + } + else + { + return min(self, other, avx512dq {}); + } + } + } + + // mul + template ::value, void>::type> + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + __m512i upper = _mm512_and_si512(_mm512_mullo_epi16(self, other), _mm512_srli_epi16(_mm512_set1_epi16(-1), 8)); + __m512i lower = _mm512_slli_epi16(_mm512_mullo_epi16(_mm512_srli_epi16(self, 8), _mm512_srli_epi16(other, 8)), 8); + return _mm512_or_si512(upper, lower); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_mullo_epi16(self, other); + } + else + { + return mul(self, other, avx512dq {}); + } + } + + // neq + template ::value, void>::type> + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512bw(self, other); + } + + // rotate_left + template + XSIMD_INLINE batch rotate_left(batch const& self, requires_arch) noexcept + { + return _mm512_alignr_epi8(self, self, N); + } + template + XSIMD_INLINE batch rotate_left(batch const& self, requires_arch) noexcept + { + return bitwise_cast(rotate_left(bitwise_cast(self), avx2 {})); + } + + // sadd + template ::value, void>::type> + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_adds_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_adds_epi16(self, other); + } + else + { + return sadd(self, other, avx512dq {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_adds_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_adds_epu16(self, other); + } + else + { + return sadd(self, other, avx512dq {}); + } + } + } + + // select + template ::value, void>::type> + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_mask_blend_epi8(cond, false_br.data, true_br.data); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_mask_blend_epi16(cond, false_br.data, true_br.data); + } + else + { + return select(cond, true_br, false_br, avx512dq {}); + } + } + + // slide_left + namespace detail + { + template + constexpr std::array make_slide_perm_hi(::xsimd::detail::index_sequence) + { + return { (Is == 0 ? 8 : Is - 1)... }; + } + + template + constexpr std::array make_slide_left_pattern(::xsimd::detail::index_sequence) + { + return { (Is >= N ? Is - N : 0)... }; + } + template + constexpr std::array make_slide_left_mask(::xsimd::detail::index_sequence) + { + return { (Is >= N ? 0xFFFF : 0x0000)... 
}; + } + } + + template + XSIMD_INLINE batch slide_left(batch const& x, requires_arch) noexcept + { + constexpr unsigned BitCount = N * 8; + if (BitCount == 0) + { + return x; + } + if (BitCount >= 512) + { + return batch(T(0)); + } + batch xx; + if (N & 1) + { + alignas(A::alignment()) uint64_t buffer[8]; + _mm512_store_epi64(&buffer[0], x); + for (int i = 7; i > 0; --i) + buffer[i] = (buffer[i] << 8) | (buffer[i - 1] >> 56); + buffer[0] = buffer[0] << 8; + xx = _mm512_load_epi64(&buffer[0]); + + alignas(A::alignment()) auto slide_perm = detail::make_slide_perm_hi(::xsimd::detail::make_index_sequence<512 / 64>()); + __m512i xl = _mm512_slli_epi64(x, 8); + __m512i xr = _mm512_srli_epi64(x, 56); + xr = _mm512_permutex2var_epi64(xr, _mm512_load_epi64(slide_perm.data()), _mm512_setzero_si512()); + xx = _mm512_or_si512(xr, xl); + if (N == 1) + return xx; + } + else + { + xx = x; + } + alignas(A::alignment()) auto slide_pattern = detail::make_slide_left_pattern(::xsimd::detail::make_index_sequence<512 / 16>()); + alignas(A::alignment()) auto slide_mask = detail::make_slide_left_mask(::xsimd::detail::make_index_sequence<512 / 16>()); + return _mm512_and_si512(_mm512_permutexvar_epi16(_mm512_load_epi32(slide_pattern.data()), xx), _mm512_load_epi32(slide_mask.data())); + } + + // slide_right + namespace detail + { + template + constexpr std::array make_slide_perm_low(::xsimd::detail::index_sequence) + { + return { (Is + 1)... }; + } + + template + constexpr std::array make_slide_right_pattern(::xsimd::detail::index_sequence) + { + return { (Is < (32 - N) ? Is + N : 0)... }; + } + template + constexpr std::array make_slide_right_mask(::xsimd::detail::index_sequence) + { + return { (Is < 32 - N ? 0xFFFF : 0x0000)... }; + } + } + template + XSIMD_INLINE batch slide_right(batch const& x, requires_arch) noexcept + { + constexpr unsigned BitCount = N * 8; + if (BitCount == 0) + { + return x; + } + if (BitCount >= 512) + { + return batch(T(0)); + } + batch xx; + if (N & 1) + { + alignas(A::alignment()) auto slide_perm = detail::make_slide_perm_low(::xsimd::detail::make_index_sequence<512 / 64>()); + __m512i xr = _mm512_srli_epi64(x, 8); + __m512i xl = _mm512_slli_epi64(x, 56); + xl = _mm512_permutex2var_epi64(xl, _mm512_load_epi64(slide_perm.data()), _mm512_setzero_si512()); + xx = _mm512_or_si512(xr, xl); + if (N == 1) + return xx; + } + else + { + xx = x; + } + alignas(A::alignment()) auto slide_pattern = detail::make_slide_right_pattern(::xsimd::detail::make_index_sequence<512 / 16>()); + alignas(A::alignment()) auto slide_mask = detail::make_slide_right_mask(::xsimd::detail::make_index_sequence<512 / 16>()); + return _mm512_and_si512(_mm512_permutexvar_epi16(_mm512_load_epi32(slide_pattern.data()), xx), _mm512_load_epi32(slide_mask.data())); + } + + // ssub + template ::value, void>::type> + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_subs_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_subs_epi16(self, other); + } + else + { + return ssub(self, other, avx512dq {}); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_subs_epu8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_subs_epu16(self, other); + } + else + { + return ssub(self, other, avx512dq {}); + } + } + } + + // sub + template ::value, void>::type> + XSIMD_INLINE batch sub(batch const& self, batch const& other, 
requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_sub_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_sub_epi16(self, other); + } + else + { + return sub(self, other, avx512dq {}); + } + } + + // swizzle (dynamic version) + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return _mm512_permutexvar_epi16(mask, self); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx512bw {})); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return _mm512_shuffle_epi8(self, mask); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx512bw {})); + } + + // swizzle (static version) + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512bw {}); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512bw {}); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512bw {}); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512bw {}); + } + + // zip_hi + template ::value, void>::type> + XSIMD_INLINE batch zip_hi(batch const& self, batch const& other, requires_arch) noexcept + { + __m512i lo, hi; + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + lo = _mm512_unpacklo_epi8(self, other); + hi = _mm512_unpackhi_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + lo = _mm512_unpacklo_epi16(self, other); + hi = _mm512_unpackhi_epi16(self, other); + } + else + { + return zip_hi(self, other, avx512f {}); + } + return _mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(hi, _mm512_extracti32x4_epi32(lo, 2), 0), + _mm512_extracti32x4_epi32(lo, 3), + 2), + _mm512_extracti32x4_epi32(hi, 2), + 1); + } + + // zip_lo + template ::value, void>::type> + XSIMD_INLINE batch zip_lo(batch const& self, batch const& other, requires_arch) noexcept + { + __m512i lo, hi; + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + lo = _mm512_unpacklo_epi8(self, other); + hi = _mm512_unpackhi_epi8(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + lo = _mm512_unpacklo_epi16(self, other); + hi = _mm512_unpackhi_epi16(self, other); + } + else + { + return zip_lo(self, other, avx512f {}); + } + return _mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(lo, _mm512_extracti32x4_epi32(hi, 0), 1), + _mm512_extracti32x4_epi32(hi, 1), + 3), + _mm512_extracti32x4_epi32(lo, 1), + 2); + } + } +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512cd.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512cd.hpp new file mode 100644 index 0000000000000..95f3f1df8f6de --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512cd.hpp @@ -0,0 +1,28 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the 
terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512CD_HPP +#define XSIMD_AVX512CD_HPP + +#include "../types/xsimd_avx512cd_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + // Nothing there yet. + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512dq.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512dq.hpp new file mode 100644 index 0000000000000..4788d19e94823 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512dq.hpp @@ -0,0 +1,212 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512_DQHPP +#define XSIMD_AVX512_D_HPP + +#include "../types/xsimd_avx512dq_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + using namespace types; + + // bitwise_and + template + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_and_ps(self, other); + } + template + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_and_pd(self, other); + } + + // bitwise_andnot + template + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_andnot_ps(other, self); + } + template + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_andnot_pd(other, self); + } + + // bitwise_not + template + XSIMD_INLINE batch bitwise_not(batch const& self, requires_arch) noexcept + { + return _mm512_xor_ps(self, _mm512_castsi512_ps(_mm512_set1_epi32(-1))); + } + template + XSIMD_INLINE batch bitwise_not(batch const& self, requires_arch) noexcept + { + return _mm512_xor_pd(self, _mm512_castsi512_pd(_mm512_set1_epi32(-1))); + } + + // bitwise_or + template + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_or_ps(self, other); + } + template + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_or_pd(self, other); + } + + template + XSIMD_INLINE batch_bool bitwise_or(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return register_type(self.data | other.data); + } + + // bitwise_xor + template + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_xor_ps(self, other); + } + template + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_xor_pd(self, other); + } + + // haddp + template + XSIMD_INLINE batch haddp(batch const* row, requires_arch) noexcept + { + // The following folds over the vector once: + // tmp1 = [a0..8, b0..8] + // tmp2 = [a8..f, b8..f] +#define XSIMD_AVX512_HADDP_STEP1(I, a, b) \ + batch res##I; \ + { \ + auto tmp1 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(1, 0, 1, 0)); \ + auto 
tmp2 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(3, 2, 3, 2)); \ + res##I = _mm512_add_ps(tmp1, tmp2); \ + } + + XSIMD_AVX512_HADDP_STEP1(0, row[0], row[2]); + XSIMD_AVX512_HADDP_STEP1(1, row[4], row[6]); + XSIMD_AVX512_HADDP_STEP1(2, row[1], row[3]); + XSIMD_AVX512_HADDP_STEP1(3, row[5], row[7]); + XSIMD_AVX512_HADDP_STEP1(4, row[8], row[10]); + XSIMD_AVX512_HADDP_STEP1(5, row[12], row[14]); + XSIMD_AVX512_HADDP_STEP1(6, row[9], row[11]); + XSIMD_AVX512_HADDP_STEP1(7, row[13], row[15]); + +#undef XSIMD_AVX512_HADDP_STEP1 + + // The following flds the code and shuffles so that hadd_ps produces the correct result + // tmp1 = [a0..4, a8..12, b0..4, b8..12] (same for tmp3) + // tmp2 = [a5..8, a12..16, b5..8, b12..16] (same for tmp4) + // tmp5 = [r1[0], r1[2], r2[0], r2[2], r1[4], r1[6] ... +#define XSIMD_AVX512_HADDP_STEP2(I, a, b, c, d) \ + batch halfx##I; \ + { \ + auto tmp1 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(2, 0, 2, 0)); \ + auto tmp2 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(3, 1, 3, 1)); \ + \ + auto resx1 = _mm512_add_ps(tmp1, tmp2); \ + \ + auto tmp3 = _mm512_shuffle_f32x4(c, d, _MM_SHUFFLE(2, 0, 2, 0)); \ + auto tmp4 = _mm512_shuffle_f32x4(c, d, _MM_SHUFFLE(3, 1, 3, 1)); \ + \ + auto resx2 = _mm512_add_ps(tmp3, tmp4); \ + \ + auto tmp5 = _mm512_shuffle_ps(resx1, resx2, _MM_SHUFFLE(2, 0, 2, 0)); \ + auto tmp6 = _mm512_shuffle_ps(resx1, resx2, _MM_SHUFFLE(3, 1, 3, 1)); \ + \ + auto resx3 = _mm512_add_ps(tmp5, tmp6); \ + \ + halfx##I = _mm256_hadd_ps(_mm512_extractf32x8_ps(resx3, 0), \ + _mm512_extractf32x8_ps(resx3, 1)); \ + } + + XSIMD_AVX512_HADDP_STEP2(0, res0, res1, res2, res3); + XSIMD_AVX512_HADDP_STEP2(1, res4, res5, res6, res7); + +#undef XSIMD_AVX512_HADDP_STEP2 + + auto concat = _mm512_castps256_ps512(halfx0); + concat = _mm512_insertf32x8(concat, halfx1, 1); + return concat; + } + + // ldexp + template + XSIMD_INLINE batch ldexp(const batch& self, const batch, A>& other, requires_arch) noexcept + { + return _mm512_scalef_pd(self, _mm512_cvtepi64_pd(other)); + } + + // mul + template + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_mullo_epi64(self, other); + } + + template + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_mullo_epi64(self, other); + } + + // nearbyint_as_int + template + XSIMD_INLINE batch nearbyint_as_int(batch const& self, + requires_arch) noexcept + { + return _mm512_cvtpd_epi64(self); + } + + // reduce_add + template + XSIMD_INLINE float reduce_add(batch const& rhs, requires_arch) noexcept + { + __m256 tmp1 = _mm512_extractf32x8_ps(rhs, 1); + __m256 tmp2 = _mm512_extractf32x8_ps(rhs, 0); + __m256 res1 = _mm256_add_ps(tmp1, tmp2); + return reduce_add(batch(res1), avx2 {}); + } + + // convert + namespace detail + { + template + XSIMD_INLINE batch fast_cast(batch const& x, batch const&, requires_arch) noexcept + { + return _mm512_cvtepi64_pd(self); + } + + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_cvttpd_epi64(self); + } + + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512er.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512er.hpp new file mode 100644 index 0000000000000..be02f9850b113 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512er.hpp @@ -0,0 +1,20 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + 
* Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512ER_HPP +#define XSIMD_AVX512ER_HPP + +#include +#include + +#include "../types/xsimd_avx512er_register.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512f.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512f.hpp new file mode 100644 index 0000000000000..c2b485a30e3d3 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512f.hpp @@ -0,0 +1,2167 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512F_HPP +#define XSIMD_AVX512F_HPP + +#include +#include +#include + +#include "../types/xsimd_avx512f_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + using namespace types; + + namespace detail + { + XSIMD_INLINE void split_avx512(__m512 val, __m256& low, __m256& high) noexcept + { + low = _mm512_castps512_ps256(val); + high = _mm512_extractf32x8_ps(val, 1); + } + XSIMD_INLINE void split_avx512(__m512d val, __m256d& low, __m256d& high) noexcept + { + low = _mm512_castpd512_pd256(val); + high = _mm512_extractf64x4_pd(val, 1); + } + XSIMD_INLINE void split_avx512(__m512i val, __m256i& low, __m256i& high) noexcept + { + low = _mm512_castsi512_si256(val); + high = _mm512_extracti64x4_epi64(val, 1); + } + XSIMD_INLINE __m512i merge_avx(__m256i low, __m256i high) noexcept + { + return _mm512_inserti64x4(_mm512_castsi256_si512(low), high, 1); + } + XSIMD_INLINE __m512 merge_avx(__m256 low, __m256 high) noexcept + { + return _mm512_castpd_ps(_mm512_insertf64x4(_mm512_castpd256_pd512(_mm256_castps_pd(low)), _mm256_castps_pd(high), 1)); + } + XSIMD_INLINE __m512d merge_avx(__m256d low, __m256d high) noexcept + { + return _mm512_insertf64x4(_mm512_castpd256_pd512(low), high, 1); + } + template + __m512i fwd_to_avx(F f, __m512i self) + { + __m256i self_low, self_high; + split_avx512(self, self_low, self_high); + __m256i res_low = f(self_low); + __m256i res_high = f(self_high); + return merge_avx(res_low, res_high); + } + template + __m512i fwd_to_avx(F f, __m512i self, __m512i other) + { + __m256i self_low, self_high, other_low, other_high; + split_avx512(self, self_low, self_high); + split_avx512(other, other_low, other_high); + __m256i res_low = f(self_low, other_low); + __m256i res_high = f(self_high, other_high); + return merge_avx(res_low, res_high); + } + template + __m512i fwd_to_avx(F f, __m512i self, int32_t other) + { + __m256i self_low, self_high; + split_avx512(self, self_low, self_high); + __m256i res_low = f(self_low, other); + __m256i res_high = f(self_high, other); + return merge_avx(res_low, res_high); + } + } + namespace detail + { + + XSIMD_INLINE uint32_t morton(uint16_t x, uint16_t y) noexcept + { + + static const unsigned short MortonTable256[256] = { + 0x0000, 0x0001, 0x0004, 0x0005, 0x0010, 0x0011, 0x0014, 0x0015, + 0x0040, 0x0041, 0x0044, 0x0045, 0x0050, 0x0051, 0x0054, 0x0055, + 0x0100, 0x0101, 
0x0104, 0x0105, 0x0110, 0x0111, 0x0114, 0x0115, + 0x0140, 0x0141, 0x0144, 0x0145, 0x0150, 0x0151, 0x0154, 0x0155, + 0x0400, 0x0401, 0x0404, 0x0405, 0x0410, 0x0411, 0x0414, 0x0415, + 0x0440, 0x0441, 0x0444, 0x0445, 0x0450, 0x0451, 0x0454, 0x0455, + 0x0500, 0x0501, 0x0504, 0x0505, 0x0510, 0x0511, 0x0514, 0x0515, + 0x0540, 0x0541, 0x0544, 0x0545, 0x0550, 0x0551, 0x0554, 0x0555, + 0x1000, 0x1001, 0x1004, 0x1005, 0x1010, 0x1011, 0x1014, 0x1015, + 0x1040, 0x1041, 0x1044, 0x1045, 0x1050, 0x1051, 0x1054, 0x1055, + 0x1100, 0x1101, 0x1104, 0x1105, 0x1110, 0x1111, 0x1114, 0x1115, + 0x1140, 0x1141, 0x1144, 0x1145, 0x1150, 0x1151, 0x1154, 0x1155, + 0x1400, 0x1401, 0x1404, 0x1405, 0x1410, 0x1411, 0x1414, 0x1415, + 0x1440, 0x1441, 0x1444, 0x1445, 0x1450, 0x1451, 0x1454, 0x1455, + 0x1500, 0x1501, 0x1504, 0x1505, 0x1510, 0x1511, 0x1514, 0x1515, + 0x1540, 0x1541, 0x1544, 0x1545, 0x1550, 0x1551, 0x1554, 0x1555, + 0x4000, 0x4001, 0x4004, 0x4005, 0x4010, 0x4011, 0x4014, 0x4015, + 0x4040, 0x4041, 0x4044, 0x4045, 0x4050, 0x4051, 0x4054, 0x4055, + 0x4100, 0x4101, 0x4104, 0x4105, 0x4110, 0x4111, 0x4114, 0x4115, + 0x4140, 0x4141, 0x4144, 0x4145, 0x4150, 0x4151, 0x4154, 0x4155, + 0x4400, 0x4401, 0x4404, 0x4405, 0x4410, 0x4411, 0x4414, 0x4415, + 0x4440, 0x4441, 0x4444, 0x4445, 0x4450, 0x4451, 0x4454, 0x4455, + 0x4500, 0x4501, 0x4504, 0x4505, 0x4510, 0x4511, 0x4514, 0x4515, + 0x4540, 0x4541, 0x4544, 0x4545, 0x4550, 0x4551, 0x4554, 0x4555, + 0x5000, 0x5001, 0x5004, 0x5005, 0x5010, 0x5011, 0x5014, 0x5015, + 0x5040, 0x5041, 0x5044, 0x5045, 0x5050, 0x5051, 0x5054, 0x5055, + 0x5100, 0x5101, 0x5104, 0x5105, 0x5110, 0x5111, 0x5114, 0x5115, + 0x5140, 0x5141, 0x5144, 0x5145, 0x5150, 0x5151, 0x5154, 0x5155, + 0x5400, 0x5401, 0x5404, 0x5405, 0x5410, 0x5411, 0x5414, 0x5415, + 0x5440, 0x5441, 0x5444, 0x5445, 0x5450, 0x5451, 0x5454, 0x5455, + 0x5500, 0x5501, 0x5504, 0x5505, 0x5510, 0x5511, 0x5514, 0x5515, + 0x5540, 0x5541, 0x5544, 0x5545, 0x5550, 0x5551, 0x5554, 0x5555 + }; + + uint32_t z = MortonTable256[y >> 8] << 17 | MortonTable256[x >> 8] << 16 | MortonTable256[y & 0xFF] << 1 | MortonTable256[x & 0xFF]; + return z; + } + + template + XSIMD_INLINE batch_bool compare_int_avx512f(batch const& self, batch const& other) noexcept + { + using register_type = typename batch_bool::register_type; + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + // shifting to take sign into account + uint64_t mask_low0 = _mm512_cmp_epi32_mask((batch(self.data) & batch(0x000000FF)) << 24, + (batch(other.data) & batch(0x000000FF)) << 24, + Cmp); + uint64_t mask_low1 = _mm512_cmp_epi32_mask((batch(self.data) & batch(0x0000FF00)) << 16, + (batch(other.data) & batch(0x0000FF00)) << 16, + Cmp); + uint64_t mask_high0 = _mm512_cmp_epi32_mask((batch(self.data) & batch(0x00FF0000)) << 8, + (batch(other.data) & batch(0x00FF0000)) << 8, + Cmp); + uint64_t mask_high1 = _mm512_cmp_epi32_mask((batch(self.data) & batch(0xFF000000)), + (batch(other.data) & batch(0xFF000000)), + Cmp); + uint64_t mask = 0; + for (unsigned i = 0; i < 16; ++i) + { + mask |= (mask_low0 & (uint64_t(1) << i)) << (3 * i + 0); + mask |= (mask_low1 & (uint64_t(1) << i)) << (3 * i + 1); + mask |= (mask_high0 & (uint64_t(1) << i)) << (3 * i + 2); + mask |= (mask_high1 & (uint64_t(1) << i)) << (3 * i + 3); + } + return (register_type)mask; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + // shifting to take sign into account + uint16_t mask_low = _mm512_cmp_epi32_mask((batch(self.data) & batch(0x0000FFFF)) << 16, + (batch(other.data) & batch(0x0000FFFF)) << 16, + Cmp); 
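// Illustrative note (not part of the vendored xsimd sources): AVX-512F has no
// 8/16-bit compare-to-mask instructions (those require AVX512BW), so this
// helper emulates a 16-bit compare with two 32-bit compares -- one on the low
// half of each 32-bit lane, shifted left by 16 so the sign bit is honoured for
// signed Cmp, and one on the high half -- then bit-interleaves the two 16-lane
// masks so that bit i of the result corresponds to element i. A minimal scalar
// sketch of that interleave, matching the even/odd convention of the
// MortonTable256 lookup above:
//
//   uint32_t interleave16(uint16_t lo, uint16_t hi) {
//       uint32_t r = 0;
//       for (int i = 0; i < 16; ++i)
//           r |= ((lo >> i) & 1u) << (2 * i)        // even bits: elements 0, 2, 4, ...
//              | ((hi >> i) & 1u) << (2 * i + 1);   // odd bits: elements 1, 3, 5, ...
//       return r;
//   }
//   // e.g. interleave16(0b11, 0b01) == 0b0111, the same value as morton(3, 1).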
+ uint16_t mask_high = _mm512_cmp_epi32_mask((batch(self.data) & batch(0xFFFF0000)), + (batch(other.data) & batch(0xFFFF0000)), + Cmp); + return static_cast(morton(mask_low, mask_high)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return (register_type)_mm512_cmp_epi32_mask(self, other, Cmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return (register_type)_mm512_cmp_epi64_mask(self, other, Cmp); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + uint64_t mask_low0 = _mm512_cmp_epu32_mask((batch(self.data) & batch(0x000000FF)), (batch(other.data) & batch(0x000000FF)), Cmp); + uint64_t mask_low1 = _mm512_cmp_epu32_mask((batch(self.data) & batch(0x0000FF00)), (batch(other.data) & batch(0x0000FF00)), Cmp); + uint64_t mask_high0 = _mm512_cmp_epu32_mask((batch(self.data) & batch(0x00FF0000)), (batch(other.data) & batch(0x00FF0000)), Cmp); + uint64_t mask_high1 = _mm512_cmp_epu32_mask((batch(self.data) & batch(0xFF000000)), (batch(other.data) & batch(0xFF000000)), Cmp); + uint64_t mask = 0; + for (unsigned i = 0; i < 16; ++i) + { + mask |= (mask_low0 & (uint64_t(1) << i)) << (3 * i + 0); + mask |= (mask_low1 & (uint64_t(1) << i)) << (3 * i + 1); + mask |= (mask_high0 & (uint64_t(1) << i)) << (3 * i + 2); + mask |= (mask_high1 & (uint64_t(1) << i)) << (3 * i + 3); + } + return (register_type)mask; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + uint16_t mask_low = _mm512_cmp_epu32_mask((batch(self.data) & batch(0x0000FFFF)), (batch(other.data) & batch(0x0000FFFF)), Cmp); + uint16_t mask_high = _mm512_cmp_epu32_mask((batch(self.data) & batch(0xFFFF0000)), (batch(other.data) & batch(0xFFFF0000)), Cmp); + return static_cast(morton(mask_low, mask_high)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return (register_type)_mm512_cmp_epu32_mask(self, other, Cmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return (register_type)_mm512_cmp_epu64_mask(self, other, Cmp); + } + } + } + } + + // abs + template + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept + { + __m512 self_asf = (__m512)self; + __m512i self_asi = *reinterpret_cast<__m512i*>(&self_asf); + __m512i res_asi = _mm512_and_epi32(_mm512_set1_epi32(0x7FFFFFFF), self_asi); + return *reinterpret_cast<__m512*>(&res_asi); + } + template + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept + { + __m512d self_asd = (__m512d)self; + __m512i self_asi = *reinterpret_cast<__m512i*>(&self_asd); + __m512i res_asi = _mm512_and_epi64(_mm512_set1_epi64(0x7FFFFFFFFFFFFFFF), + self_asi); + return *reinterpret_cast<__m512d*>(&res_asi); + } + template ::value, void>::type> + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept + { + if (std::is_unsigned::value) + { + return self; + } + + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return detail::fwd_to_avx([](__m256i s) noexcept + { return abs(batch(s)); }, + self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return detail::fwd_to_avx([](__m256i s) noexcept + { return abs(batch(s)); }, + self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_abs_epi32(self); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_abs_epi64(self); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + + // add + template ::value, void>::type> + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return add(batch(s), batch(o)); }, 
+ self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return add(batch(s), batch(o)); }, + self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_add_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_add_epi64(self, other); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + template + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_add_ps(self, other); + } + template + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_add_pd(self, other); + } + + // all + template + XSIMD_INLINE bool all(batch_bool const& self, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return self.data == register_type(-1); + } + + // any + template + XSIMD_INLINE bool any(batch_bool const& self, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return self.data != register_type(0); + } + + // batch_bool_cast + template + XSIMD_INLINE batch_bool batch_bool_cast(batch_bool const& self, batch_bool const&, requires_arch) noexcept + { + return self.data; + } + + // bitwise_and + template + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { +#if defined(_MSC_VER) + return _mm512_and_ps(self, other); +#else + return _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(self), _mm512_castps_si512(other))); +#endif + } + template + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_castsi512_pd(_mm512_and_si512(_mm512_castpd_si512(self), _mm512_castpd_si512(other))); + } + + template ::value, void>::type> + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_and_si512(self, other); + } + + template + XSIMD_INLINE batch_bool bitwise_and(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return register_type(self.data & other.data); + } + + // bitwise_andnot + template + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_castsi512_ps(_mm512_andnot_si512(_mm512_castps_si512(other), _mm512_castps_si512(self))); + } + template + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_castsi512_pd(_mm512_andnot_si512(_mm512_castpd_si512(other), _mm512_castpd_si512(self))); + } + + template ::value, void>::type> + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_andnot_si512(other, self); + } + + template + XSIMD_INLINE batch_bool bitwise_andnot(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return register_type(self.data & ~other.data); + } + + // bitwise_lshift + template ::value, void>::type> + XSIMD_INLINE batch bitwise_lshift(batch const& self, int32_t other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { +#if defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + __m512i tmp = _mm512_sllv_epi32(self, _mm512_set1_epi32(other)); +#else + __m512i tmp = _mm512_slli_epi32(self, other); +#endif + return 
_mm512_and_si512(_mm512_set1_epi8(0xFF << other), tmp); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return detail::fwd_to_avx([](__m256i s, int32_t o) noexcept + { return bitwise_lshift(batch(s), o, avx2 {}); }, + self, other); +#if defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_sllv_epi32(self, _mm512_set1_epi32(other)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_sllv_epi64(self, _mm512_set1_epi64(other)); +#else + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_slli_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_slli_epi64(self, other); +#endif + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + + // bitwise_not + template ::value, void>::type> + XSIMD_INLINE batch bitwise_not(batch const& self, requires_arch) noexcept + { + return _mm512_xor_si512(self, _mm512_set1_epi32(-1)); + } + template + XSIMD_INLINE batch_bool bitwise_not(batch_bool const& self, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return register_type(~self.data); + } + + template + XSIMD_INLINE batch bitwise_not(batch const& self, requires_arch) noexcept + { + return _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(self), _mm512_set1_epi32(-1))); + } + template + XSIMD_INLINE batch bitwise_not(batch const& self, requires_arch) noexcept + { + return _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(self), _mm512_set1_epi32(-1))); + } + + // bitwise_or + template + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_castsi512_ps(_mm512_or_si512(_mm512_castps_si512(self), _mm512_castps_si512(other))); + } + template + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_castsi512_pd(_mm512_or_si512(_mm512_castpd_si512(self), _mm512_castpd_si512(other))); + } + + template + XSIMD_INLINE batch_bool bitwise_or(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return register_type(self.data | other.data); + } + + template ::value, void>::type> + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_or_si512(self, other); + } + + // bitwise_rshift + template ::value, void>::type> + XSIMD_INLINE batch bitwise_rshift(batch const& self, int32_t other, requires_arch) noexcept + { + if (std::is_signed::value) + { +#if defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_srav_epi32(self, _mm512_set1_epi32(other)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_srav_epi64(self, _mm512_set1_epi64(other)); +#else + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_srai_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_srai_epi64(self, other); +#endif + } + else + { + return detail::fwd_to_avx([](__m256i s, int32_t o) noexcept + { return bitwise_rshift(batch(s), o, avx2 {}); }, + self, other); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { +#if defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + __m512i tmp = _mm512_srlv_epi32(self, _mm512_set1_epi32(other)); +#else + __m512i tmp = _mm512_srli_epi32(self, other); +#endif + return _mm512_and_si512(_mm512_set1_epi8(0xFF >> other), tmp); +#if 
defined(XSIMD_AVX512_SHIFT_INTRINSICS_IMM_ONLY) + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_srlv_epi32(self, _mm512_set1_epi32(other)); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_srlv_epi64(self, _mm512_set1_epi64(other)); +#else + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_srli_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_srli_epi64(self, other); +#endif + } + else + { + return detail::fwd_to_avx([](__m256i s, int32_t o) noexcept + { return bitwise_rshift(batch(s), o, avx2 {}); }, + self, other); + } + } + } + + // bitwise_xor + template + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(self), _mm512_castps_si512(other))); + } + template + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(self), _mm512_castpd_si512(other))); + } + + template + XSIMD_INLINE batch_bool bitwise_xor(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return register_type(self.data | other.data); + } + + template ::value, void>::type> + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_xor_si512(self, other); + } + + // bitwise_cast + template ::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_castsi512_ps(self); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_castsi512_pd(self); + } + template ::type>::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return batch(self.data); + } + template + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_castps_pd(self); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_castps_si512(self); + } + template + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_castpd_ps(self); + } + template ::value, void>::type> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_castpd_si512(self); + } + + // broadcast + template ::value, void>::type> + XSIMD_INLINE batch broadcast(T val, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return _mm512_set1_epi8(val); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return _mm512_set1_epi16(val); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_set1_epi32(val); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_set1_epi64(val); + } + else + { + assert(false && "unsupported"); + return {}; + } + } + template + XSIMD_INLINE batch broadcast(float val, requires_arch) noexcept + { + return _mm512_set1_ps(val); + } + template + batch XSIMD_INLINE broadcast(double val, requires_arch) noexcept + { + return _mm512_set1_pd(val); + } + + // ceil + template + XSIMD_INLINE batch ceil(batch const& self, requires_arch) noexcept + { + return _mm512_roundscale_ps(self, _MM_FROUND_TO_POS_INF); + } + template + 
XSIMD_INLINE batch ceil(batch const& self, requires_arch) noexcept + { + return _mm512_roundscale_pd(self, _MM_FROUND_TO_POS_INF); + } + + // compress + template + XSIMD_INLINE batch compress(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_compress_ps(mask.mask(), self); + } + template + XSIMD_INLINE batch compress(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_compress_pd(mask.mask(), self); + } + template + XSIMD_INLINE batch compress(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_compress_epi32(mask.mask(), self); + } + template + XSIMD_INLINE batch compress(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_compress_epi32(mask.mask(), self); + } + template + XSIMD_INLINE batch compress(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_compress_epi64(mask.mask(), self); + } + template + XSIMD_INLINE batch compress(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_compress_epi64(mask.mask(), self); + } + + // convert + namespace detail + { + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_cvtepi32_ps(self); + } + + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_cvttps_epi32(self); + } + + template + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch) noexcept + { + return _mm512_cvtepu32_ps(self); + } + + template + batch fast_cast(batch const& self, batch const&, requires_arch) + { + return _mm512_cvttps_epu32(self); + } + } + + namespace detail + { + // complex_low + template + XSIMD_INLINE batch complex_low(batch, A> const& self, requires_arch) noexcept + { + __m512i idx = _mm512_setr_epi32(0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23); + return _mm512_permutex2var_ps(self.real(), idx, self.imag()); + } + template + XSIMD_INLINE batch complex_low(batch, A> const& self, requires_arch) noexcept + { + __m512i idx = _mm512_setr_epi64(0, 8, 1, 9, 2, 10, 3, 11); + return _mm512_permutex2var_pd(self.real(), idx, self.imag()); + } + + // complex_high + template + XSIMD_INLINE batch complex_high(batch, A> const& self, requires_arch) noexcept + { + __m512i idx = _mm512_setr_epi32(8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31); + return _mm512_permutex2var_ps(self.real(), idx, self.imag()); + } + template + XSIMD_INLINE batch complex_high(batch, A> const& self, requires_arch) noexcept + { + __m512i idx = _mm512_setr_epi64(4, 12, 5, 13, 6, 14, 7, 15); + return _mm512_permutex2var_pd(self.real(), idx, self.imag()); + } + } + + // div + template + XSIMD_INLINE batch div(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_div_ps(self, other); + } + template + XSIMD_INLINE batch div(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_div_pd(self, other); + } + + // eq + template + XSIMD_INLINE batch_bool eq(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_ps_mask(self, other, _CMP_EQ_OQ); + } + template + XSIMD_INLINE batch_bool eq(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_pd_mask(self, other, _CMP_EQ_OQ); + } + + template ::value, void>::type> + XSIMD_INLINE batch_bool eq(batch const& self, batch const& other, 
requires_arch) noexcept + { + return detail::compare_int_avx512f(self, other); + } + template + XSIMD_INLINE batch_bool eq(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return register_type(~self.data ^ other.data); + } + + // expand + template + XSIMD_INLINE batch expand(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_expand_ps(mask.mask(), self); + } + template + XSIMD_INLINE batch expand(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_expand_pd(mask.mask(), self); + } + template + XSIMD_INLINE batch expand(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_expand_epi32(mask.mask(), self); + } + template + XSIMD_INLINE batch expand(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_expand_epi32(mask.mask(), self); + } + template + XSIMD_INLINE batch expand(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_expand_epi64(mask.mask(), self); + } + template + XSIMD_INLINE batch expand(batch const& self, batch_bool const& mask, requires_arch) noexcept + { + return _mm512_maskz_expand_epi64(mask.mask(), self); + } + + // floor + template + XSIMD_INLINE batch floor(batch const& self, requires_arch) noexcept + { + return _mm512_roundscale_ps(self, _MM_FROUND_TO_NEG_INF); + } + template + XSIMD_INLINE batch floor(batch const& self, requires_arch) noexcept + { + return _mm512_roundscale_pd(self, _MM_FROUND_TO_NEG_INF); + } + + // fnma + template + XSIMD_INLINE batch fnma(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return _mm512_fnmadd_ps(x, y, z); + } + + template + XSIMD_INLINE batch fnma(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return _mm512_fnmadd_pd(x, y, z); + } + + // fma + template + XSIMD_INLINE batch fma(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return _mm512_fmadd_ps(x, y, z); + } + + template + XSIMD_INLINE batch fma(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return _mm512_fmadd_pd(x, y, z); + } + + // fms + template + XSIMD_INLINE batch fms(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return _mm512_fmsub_ps(x, y, z); + } + + template + XSIMD_INLINE batch fms(batch const& x, batch const& y, batch const& z, requires_arch) noexcept + { + return _mm512_fmsub_pd(x, y, z); + } + + // from bool + template + XSIMD_INLINE batch from_bool(batch_bool const& self, requires_arch) noexcept + { + return select(self, batch(1), batch(0)); + } + + // from_mask + template + XSIMD_INLINE batch_bool from_mask(batch_bool const&, uint64_t mask, requires_arch) noexcept + { + return static_cast::register_type>(mask); + } + + // gather + template = 0, detail::enable_sized_integral_t = 0> + XSIMD_INLINE batch gather(batch const&, T const* src, batch const& index, + kernel::requires_arch) noexcept + { + return _mm512_i32gather_epi32(index, static_cast(src), sizeof(T)); + } + + template = 0, detail::enable_sized_integral_t = 0> + XSIMD_INLINE batch gather(batch const&, T const* src, batch const& index, + kernel::requires_arch) noexcept + { + return _mm512_i64gather_epi64(index, static_cast(src), sizeof(T)); + } + + template = 0> + XSIMD_INLINE batch gather(batch const&, float const* src, + batch const& index, + kernel::requires_arch) 
noexcept + { + return _mm512_i32gather_ps(index, src, sizeof(float)); + } + + template = 0> + XSIMD_INLINE batch + gather(batch const&, double const* src, batch const& index, + kernel::requires_arch) noexcept + { + return _mm512_i64gather_pd(index, src, sizeof(double)); + } + + // gather: handmade conversions + template = 0> + XSIMD_INLINE batch gather(batch const&, double const* src, + batch const& index, + requires_arch) noexcept + { + const batch low(_mm512_i32gather_pd(_mm512_castsi512_si256(index.data), src, sizeof(double))); + const batch high(_mm512_i32gather_pd(_mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castsi512_pd(index.data), 1)), src, sizeof(double))); + return detail::merge_avx(_mm512_cvtpd_ps(low.data), _mm512_cvtpd_ps(high.data)); + } + + template = 0> + XSIMD_INLINE batch gather(batch const&, double const* src, + batch const& index, + requires_arch) noexcept + { + const batch low(_mm512_i32gather_pd(_mm512_castsi512_si256(index.data), src, sizeof(double))); + const batch high(_mm512_i32gather_pd(_mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castsi512_pd(index.data), 1)), src, sizeof(double))); + return detail::merge_avx(_mm512_cvtpd_epi32(low.data), _mm512_cvtpd_epi32(high.data)); + } + + // ge + template + XSIMD_INLINE batch_bool ge(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_ps_mask(self, other, _CMP_GE_OQ); + } + template + XSIMD_INLINE batch_bool ge(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_pd_mask(self, other, _CMP_GE_OQ); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool ge(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512f(self, other); + } + + // gt + template + XSIMD_INLINE batch_bool gt(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_ps_mask(self, other, _CMP_GT_OQ); + } + template + XSIMD_INLINE batch_bool gt(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_pd_mask(self, other, _CMP_GT_OQ); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool gt(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512f(self, other); + } + + // haddp + template + XSIMD_INLINE batch haddp(batch const* row, requires_arch) noexcept + { + // The following folds over the vector once: + // tmp1 = [a0..8, b0..8] + // tmp2 = [a8..f, b8..f] +#define XSIMD_AVX512_HADDP_STEP1(I, a, b) \ + batch res##I; \ + { \ + auto tmp1 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(1, 0, 1, 0)); \ + auto tmp2 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(3, 2, 3, 2)); \ + res##I = _mm512_add_ps(tmp1, tmp2); \ + } + + XSIMD_AVX512_HADDP_STEP1(0, row[0], row[2]); + XSIMD_AVX512_HADDP_STEP1(1, row[4], row[6]); + XSIMD_AVX512_HADDP_STEP1(2, row[1], row[3]); + XSIMD_AVX512_HADDP_STEP1(3, row[5], row[7]); + XSIMD_AVX512_HADDP_STEP1(4, row[8], row[10]); + XSIMD_AVX512_HADDP_STEP1(5, row[12], row[14]); + XSIMD_AVX512_HADDP_STEP1(6, row[9], row[11]); + XSIMD_AVX512_HADDP_STEP1(7, row[13], row[15]); + +#undef XSIMD_AVX512_HADDP_STEP1 + + // The following flds the code and shuffles so that hadd_ps produces the correct result + // tmp1 = [a0..4, a8..12, b0..4, b8..12] (same for tmp3) + // tmp2 = [a5..8, a12..16, b5..8, b12..16] (same for tmp4) + // tmp5 = [r1[0], r1[2], r2[0], r2[2], r1[4], r1[6] ... 
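// Illustrative note (not part of the vendored xsimd sources): haddp(row)
// returns a batch whose i-th lane holds the horizontal sum of row[i]. Rather
// than doing 16 separate reductions, the XSIMD_AVX512_HADDP_STEP macros in
// this function keep shuffling pairs of rows together and adding, so partial
// sums stay in full-width registers until a final _mm256_hadd_ps. A scalar
// reference of the intended result, assuming 16 rows of 16 floats:
//
//   void haddp_ref(const float row[16][16], float out[16]) {
//       for (int i = 0; i < 16; ++i) {
//           float s = 0.0f;
//           for (int j = 0; j < 16; ++j) s += row[i][j];
//           out[i] = s;
//       }
//   }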
+#define XSIMD_AVX512_HADDP_STEP2(I, a, b, c, d) \ + batch halfx##I; \ + { \ + auto tmp1 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(2, 0, 2, 0)); \ + auto tmp2 = _mm512_shuffle_f32x4(a, b, _MM_SHUFFLE(3, 1, 3, 1)); \ + \ + auto resx1 = _mm512_add_ps(tmp1, tmp2); \ + \ + auto tmp3 = _mm512_shuffle_f32x4(c, d, _MM_SHUFFLE(2, 0, 2, 0)); \ + auto tmp4 = _mm512_shuffle_f32x4(c, d, _MM_SHUFFLE(3, 1, 3, 1)); \ + \ + auto resx2 = _mm512_add_ps(tmp3, tmp4); \ + \ + auto tmp5 = _mm512_shuffle_ps(resx1, resx2, _MM_SHUFFLE(2, 0, 2, 0)); \ + auto tmp6 = _mm512_shuffle_ps(resx1, resx2, _MM_SHUFFLE(3, 1, 3, 1)); \ + \ + auto resx3 = _mm512_add_ps(tmp5, tmp6); \ + \ + halfx##I = _mm256_hadd_ps(_mm256_insertf128_ps(_mm256_castps128_ps256(_mm512_extractf32x4_ps(resx3, 0)), _mm512_extractf32x4_ps(resx3, 1), 1), \ + _mm256_insertf128_ps(_mm256_castps128_ps256(_mm512_extractf32x4_ps(resx3, 2)), _mm512_extractf32x4_ps(resx3, 3), 1)); \ + } + + XSIMD_AVX512_HADDP_STEP2(0, res0, res1, res2, res3); + XSIMD_AVX512_HADDP_STEP2(1, res4, res5, res6, res7); + +#undef XSIMD_AVX512_HADDP_STEP2 + + auto concat = _mm512_castps256_ps512(halfx0); + concat = _mm512_castpd_ps(_mm512_insertf64x4(_mm512_castps_pd(concat), _mm256_castps_pd(halfx1), 1)); + return concat; + } + + template + XSIMD_INLINE batch haddp(batch const* row, requires_arch) noexcept + { +#define step1(I, a, b) \ + batch res##I; \ + { \ + auto tmp1 = _mm512_shuffle_f64x2(a, b, _MM_SHUFFLE(1, 0, 1, 0)); \ + auto tmp2 = _mm512_shuffle_f64x2(a, b, _MM_SHUFFLE(3, 2, 3, 2)); \ + res##I = _mm512_add_pd(tmp1, tmp2); \ + } + + step1(1, row[0], row[2]); + step1(2, row[4], row[6]); + step1(3, row[1], row[3]); + step1(4, row[5], row[7]); + +#undef step1 + + auto tmp5 = _mm512_shuffle_f64x2(res1, res2, _MM_SHUFFLE(2, 0, 2, 0)); + auto tmp6 = _mm512_shuffle_f64x2(res1, res2, _MM_SHUFFLE(3, 1, 3, 1)); + + auto resx1 = _mm512_add_pd(tmp5, tmp6); + + auto tmp7 = _mm512_shuffle_f64x2(res3, res4, _MM_SHUFFLE(2, 0, 2, 0)); + auto tmp8 = _mm512_shuffle_f64x2(res3, res4, _MM_SHUFFLE(3, 1, 3, 1)); + + auto resx2 = _mm512_add_pd(tmp7, tmp8); + + auto tmpx = _mm512_shuffle_pd(resx1, resx2, 0b00000000); + auto tmpy = _mm512_shuffle_pd(resx1, resx2, 0b11111111); + + return _mm512_add_pd(tmpx, tmpy); + } + + // isnan + template + XSIMD_INLINE batch_bool isnan(batch const& self, requires_arch) noexcept + { + return _mm512_cmp_ps_mask(self, self, _CMP_UNORD_Q); + } + template + XSIMD_INLINE batch_bool isnan(batch const& self, requires_arch) noexcept + { + return _mm512_cmp_pd_mask(self, self, _CMP_UNORD_Q); + } + + // ldexp + template + XSIMD_INLINE batch ldexp(const batch& self, const batch, A>& other, requires_arch) noexcept + { + return _mm512_scalef_ps(self, _mm512_cvtepi32_ps(other)); + } + + template + XSIMD_INLINE batch ldexp(const batch& self, const batch, A>& other, requires_arch) noexcept + { + // FIXME: potential data loss here when converting other elements to + // int32 before converting them back to double. 
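// Illustrative note (not part of the vendored xsimd sources): ldexp(x, e)
// computes x * 2^e, and _mm512_scalef_pd does exactly that once the exponents
// are available as doubles. The int64 -> int32 narrowing flagged in the FIXME
// above only loses information for |e| > INT32_MAX, far beyond the exponent
// range an IEEE double can represent (roughly -1074..1023), so for meaningful
// inputs the kernel behaves like this scalar sketch:
//
//   #include <cmath>
//   double ldexp_ref(double x, long long e) {
//       return std::ldexp(x, static_cast<int>(e));
//   }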
+ __m512d adjusted_index = _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(other)); + return _mm512_scalef_pd(self, adjusted_index); + } + + // le + template + XSIMD_INLINE batch_bool le(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_ps_mask(self, other, _CMP_LE_OQ); + } + template + XSIMD_INLINE batch_bool le(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_pd_mask(self, other, _CMP_LE_OQ); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool le(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512f(self, other); + } + + // load_aligned + template ::value, void>::type> + XSIMD_INLINE batch load_aligned(T const* mem, convert, requires_arch) noexcept + { + return _mm512_load_si512((__m512i const*)mem); + } + template + XSIMD_INLINE batch load_aligned(float const* mem, convert, requires_arch) noexcept + { + return _mm512_load_ps(mem); + } + template + XSIMD_INLINE batch load_aligned(double const* mem, convert, requires_arch) noexcept + { + return _mm512_load_pd(mem); + } + + // load_complex + namespace detail + { + template + XSIMD_INLINE batch, A> load_complex(batch const& hi, batch const& lo, requires_arch) noexcept + { + __m512i real_idx = _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30); + __m512i imag_idx = _mm512_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31); + auto real = _mm512_permutex2var_ps(hi, real_idx, lo); + auto imag = _mm512_permutex2var_ps(hi, imag_idx, lo); + return { real, imag }; + } + template + XSIMD_INLINE batch, A> load_complex(batch const& hi, batch const& lo, requires_arch) noexcept + { + __m512i real_idx = _mm512_setr_epi64(0, 2, 4, 6, 8, 10, 12, 14); + __m512i imag_idx = _mm512_setr_epi64(1, 3, 5, 7, 9, 11, 13, 15); + auto real = _mm512_permutex2var_pd(hi, real_idx, lo); + auto imag = _mm512_permutex2var_pd(hi, imag_idx, lo); + return { real, imag }; + } + } + + // load_unaligned + template ::value, void>::type> + XSIMD_INLINE batch load_unaligned(T const* mem, convert, requires_arch) noexcept + { + return _mm512_loadu_si512((__m512i const*)mem); + } + template + XSIMD_INLINE batch load_unaligned(float const* mem, convert, requires_arch) noexcept + { + return _mm512_loadu_ps(mem); + } + template + XSIMD_INLINE batch load_unaligned(double const* mem, convert, requires_arch) noexcept + { + return _mm512_loadu_pd(mem); + } + + // lt + template + XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_ps_mask(self, other, _CMP_LT_OQ); + } + template + XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_pd_mask(self, other, _CMP_LT_OQ); + } + + template ::value, void>::type> + XSIMD_INLINE batch_bool lt(batch const& self, batch const& other, requires_arch) noexcept + { + return detail::compare_int_avx512f(self, other); + } + + // mask + template + XSIMD_INLINE uint64_t mask(batch_bool const& self, requires_arch) noexcept + { + return self.data; + } + + // max + template + XSIMD_INLINE batch max(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_max_ps(self, other); + } + template + XSIMD_INLINE batch max(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_max_pd(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch max(batch const& self, batch const& other, requires_arch) noexcept + { + if 
(std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_max_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_max_epi64(self, other); + } + else + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return max(batch(s), batch(o)); }, + self, other); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_max_epu32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_max_epu64(self, other); + } + else + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return max(batch(s), batch(o)); }, + self, other); + } + } + } + + // min + template + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_min_ps(self, other); + } + template + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_min_pd(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_min_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_min_epi64(self, other); + } + else + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return min(batch(s), batch(o)); }, + self, other); + } + } + else + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_min_epu32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_min_epu64(self, other); + } + else + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return min(batch(s), batch(o)); }, + self, other); + } + } + } + + // mul + template + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_mul_ps(self, other); + } + template + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_mul_pd(self, other); + } + template ::value, void>::type> + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_mullo_epi32(self, other); + } + else + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return mul(batch(s), batch(o)); }, + self, other); + } + } + + // nearbyint + template + XSIMD_INLINE batch nearbyint(batch const& self, requires_arch) noexcept + { + return _mm512_roundscale_round_ps(self, _MM_FROUND_TO_NEAREST_INT, _MM_FROUND_CUR_DIRECTION); + } + template + XSIMD_INLINE batch nearbyint(batch const& self, requires_arch) noexcept + { + return _mm512_roundscale_round_pd(self, _MM_FROUND_TO_NEAREST_INT, _MM_FROUND_CUR_DIRECTION); + } + + // nearbyint_as_int + template + XSIMD_INLINE batch nearbyint_as_int(batch const& self, + requires_arch) noexcept + { + return _mm512_cvtps_epi32(self); + } + + // neg + template + XSIMD_INLINE batch neg(batch const& self, requires_arch) noexcept + { + return 0 - self; + } + + // neq + template + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_ps_mask(self, other, _CMP_NEQ_UQ); + } + template + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_cmp_pd_mask(self, other, _CMP_NEQ_UQ); + } + template ::value, void>::type> + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch) noexcept + { 
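// Illustrative note (not part of the vendored xsimd sources): on AVX-512 a
// batch_bool is a k-mask with one bit per lane, so "not equal" can be derived
// from the eq kernel above by flipping every lane's bit, which is what the
// next statement does. The same idea on a plain 16-lane mask:
//
//   uint16_t neq_from_eq(uint16_t eq_mask) {
//       return static_cast<uint16_t>(~eq_mask);   // complement each lane's bit
//   }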
+ return ~(self == other); + } + + template + XSIMD_INLINE batch_bool neq(batch_bool const& self, batch_bool const& other, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + return register_type(self.data ^ other.data); + } + + // reciprocal + template + XSIMD_INLINE batch + reciprocal(batch const& self, + kernel::requires_arch) noexcept + { + return _mm512_rcp14_ps(self); + } + + template + XSIMD_INLINE batch + reciprocal(batch const& self, + kernel::requires_arch) noexcept + { + return _mm512_rcp14_pd(self); + } + + // reduce_add + template + XSIMD_INLINE float reduce_add(batch const& rhs, requires_arch) noexcept + { + __m128 tmp1 = _mm512_extractf32x4_ps(rhs, 0); + __m128 tmp2 = _mm512_extractf32x4_ps(rhs, 1); + __m128 tmp3 = _mm512_extractf32x4_ps(rhs, 2); + __m128 tmp4 = _mm512_extractf32x4_ps(rhs, 3); + __m128 res1 = _mm_add_ps(tmp1, tmp2); + __m128 res2 = _mm_add_ps(tmp3, tmp4); + __m128 res3 = _mm_add_ps(res1, res2); + return reduce_add(batch(res3), sse4_2 {}); + } + template + XSIMD_INLINE double reduce_add(batch const& rhs, requires_arch) noexcept + { + __m256d tmp1 = _mm512_extractf64x4_pd(rhs, 1); + __m256d tmp2 = _mm512_extractf64x4_pd(rhs, 0); + __m256d res1 = _mm256_add_pd(tmp1, tmp2); + return reduce_add(batch(res1), avx2 {}); + } + template ::value, void>::type> + XSIMD_INLINE T reduce_add(batch const& self, requires_arch) noexcept + { + __m256i low, high; + detail::split_avx512(self, low, high); + batch blow(low), bhigh(high); + return reduce_add(blow, avx2 {}) + reduce_add(bhigh, avx2 {}); + } + + // reduce_max + template ::type> + XSIMD_INLINE T reduce_max(batch const& self, requires_arch) noexcept + { + constexpr batch_constant mask; + batch step = _mm512_permutexvar_epi64(mask.as_batch(), self); + batch acc = max(self, step); + __m256i low = _mm512_castsi512_si256(acc); + return reduce_max(batch(low)); + } + + // reduce_min + template ::type> + XSIMD_INLINE T reduce_min(batch const& self, requires_arch) noexcept + { + constexpr batch_constant mask; + batch step = _mm512_permutexvar_epi64(mask.as_batch(), self); + batch acc = min(self, step); + __m256i low = _mm512_castsi512_si256(acc); + return reduce_min(batch(low)); + } + + // rsqrt + template + XSIMD_INLINE batch rsqrt(batch const& val, requires_arch) noexcept + { + return _mm512_rsqrt14_ps(val); + } + template + XSIMD_INLINE batch rsqrt(batch const& val, requires_arch) noexcept + { + return _mm512_rsqrt14_pd(val); + } + + // sadd + template ::value, void>::type> + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + auto mask = other < 0; + auto self_pos_branch = min(std::numeric_limits::max() - other, self); + auto self_neg_branch = max(std::numeric_limits::min() - other, self); + return other + select(mask, self_neg_branch, self_pos_branch); + } + else + { + const auto diffmax = std::numeric_limits::max() - self; + const auto mindiff = min(diffmax, other); + return self + mindiff; + } + } + + // scatter + template ::value || std::is_same::value, void>::type> + XSIMD_INLINE void scatter(batch const& src, T* dst, + batch const& index, + kernel::requires_arch) noexcept + { + _mm512_i32scatter_epi32(dst, index, src, sizeof(T)); + } + + template ::value || std::is_same::value, void>::type> + XSIMD_INLINE void scatter(batch const& src, T* dst, + batch const& index, + kernel::requires_arch) noexcept + { + _mm512_i64scatter_epi64(dst, index, src, sizeof(T)); + } + + template + XSIMD_INLINE void scatter(batch 
const& src, float* dst, + batch const& index, + kernel::requires_arch) noexcept + { + _mm512_i32scatter_ps(dst, index, src, sizeof(float)); + } + + template + XSIMD_INLINE void scatter(batch const& src, double* dst, + batch const& index, + kernel::requires_arch) noexcept + { + _mm512_i64scatter_pd(dst, index, src, sizeof(double)); + } + + // select + template + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + return _mm512_mask_blend_ps(cond, false_br, true_br); + } + template + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + return _mm512_mask_blend_pd(cond, false_br, true_br); + } + + template ::value, void>::type> + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + alignas(avx2::alignment()) uint8_t buffer[64]; + // FIXME: ultra inefficient + for (int i = 0; i < 64; ++i) + buffer[i] = cond.data & (1ull << i) ? 0xFF : 0; + __m256i cond_low = batch::load_aligned(&buffer[0]); + __m256i cond_hi = batch::load_aligned(&buffer[32]); + + __m256i true_low, true_hi; + detail::split_avx512(true_br, true_low, true_hi); + + __m256i false_low, false_hi; + detail::split_avx512(false_br, false_low, false_hi); + + __m256i res_low = select(batch_bool(cond_low), batch(true_low), batch(false_low), avx2 {}); + __m256i res_hi = select(batch_bool(cond_hi), batch(true_hi), batch(false_hi), avx2 {}); + return detail::merge_avx(res_low, res_hi); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + __m256i cond_low = _mm512_maskz_cvtepi32_epi16((uint64_t)cond.data & 0xFFFF, _mm512_set1_epi32(~0)); + __m256i cond_hi = _mm512_maskz_cvtepi32_epi16((uint64_t)cond.data >> 16, _mm512_set1_epi32(~0)); + + __m256i true_low, true_hi; + detail::split_avx512(true_br, true_low, true_hi); + + __m256i false_low, false_hi; + detail::split_avx512(false_br, false_low, false_hi); + + __m256i res_low = select(batch_bool(cond_low), batch(true_low), batch(false_low), avx2 {}); + __m256i res_hi = select(batch_bool(cond_hi), batch(true_hi), batch(false_hi), avx2 {}); + return detail::merge_avx(res_low, res_hi); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_mask_blend_epi32(cond, false_br, true_br); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_mask_blend_epi64(cond, false_br, true_br); + } + else + { + assert(false && "unsupported arch/type combination"); + return {}; + } + } + + template ::value, void>::type> + XSIMD_INLINE batch select(batch_bool_constant const&, batch const& true_br, batch const& false_br, requires_arch) noexcept + { + return select(batch_bool { Values... 
}, true_br, false_br, avx512f {}); + } + + namespace detail + { + template + using enable_signed_integer_t = typename std::enable_if::value && std::is_signed::value, + int>::type; + + template + using enable_unsigned_integer_t = typename std::enable_if::value && std::is_unsigned::value, + int>::type; + } + + // set + template + XSIMD_INLINE batch set(batch const&, requires_arch, float v0, float v1, float v2, float v3, float v4, float v5, float v6, float v7, float v8, float v9, float v10, float v11, float v12, float v13, float v14, float v15) noexcept + { + return _mm512_setr_ps(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + } + + template + XSIMD_INLINE batch set(batch const&, requires_arch, double v0, double v1, double v2, double v3, double v4, double v5, double v6, double v7) noexcept + { + return _mm512_setr_pd(v0, v1, v2, v3, v4, v5, v6, v7); + } + template ::value, void>::type> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7) noexcept + { + return _mm512_set_epi64(v7, v6, v5, v4, v3, v2, v1, v0); + } + template ::value, void>::type> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, + T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15) noexcept + { + return _mm512_setr_epi32(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15); + } + template = 0> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, + T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15, + T v16, T v17, T v18, T v19, T v20, T v21, T v22, T v23, + T v24, T v25, T v26, T v27, T v28, T v29, T v30, T v31) noexcept + { +#if defined(__clang__) || __GNUC__ + return __extension__(__m512i)(__v32hi) { + v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31 + }; +#else + return _mm512_set_epi16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31); +#endif + } + + template = 0> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, + T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15, + T v16, T v17, T v18, T v19, T v20, T v21, T v22, T v23, + T v24, T v25, T v26, T v27, T v28, T v29, T v30, T v31) noexcept + { +#if defined(__clang__) || __GNUC__ + return __extension__(__m512i)(__v32hu) { + v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31 + }; +#else + return _mm512_set_epi16(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31); +#endif + } + + template = 0> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, + T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15, + T v16, T v17, T v18, T v19, T v20, T v21, T v22, T v23, + T v24, T v25, T v26, T v27, T v28, T v29, T v30, T v31, + T v32, T v33, T v34, T v35, T v36, T v37, T v38, T v39, + T v40, T v41, T v42, T v43, T v44, T v45, T v46, T v47, + T v48, T v49, T v50, T v51, T v52, T v53, T v54, T v55, + T v56, T v57, T v58, T v59, T v60, T v61, T v62, T v63) noexcept + { + +#if defined(__clang__) || __GNUC__ + return __extension__(__m512i)(__v64qi) { + v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, 
v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, + v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, + v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63 + }; +#else + return _mm512_set_epi8(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, + v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, + v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63); +#endif + } + template = 0> + XSIMD_INLINE batch set(batch const&, requires_arch, T v0, T v1, T v2, T v3, T v4, T v5, T v6, T v7, + T v8, T v9, T v10, T v11, T v12, T v13, T v14, T v15, + T v16, T v17, T v18, T v19, T v20, T v21, T v22, T v23, + T v24, T v25, T v26, T v27, T v28, T v29, T v30, T v31, + T v32, T v33, T v34, T v35, T v36, T v37, T v38, T v39, + T v40, T v41, T v42, T v43, T v44, T v45, T v46, T v47, + T v48, T v49, T v50, T v51, T v52, T v53, T v54, T v55, + T v56, T v57, T v58, T v59, T v60, T v61, T v62, T v63) noexcept + { + +#if defined(__clang__) || __GNUC__ + return __extension__(__m512i)(__v64qu) { + v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, + v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, + v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63 + }; +#else + return _mm512_set_epi8(v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, + v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, + v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63); +#endif + } + + template + XSIMD_INLINE batch_bool set(batch_bool const&, requires_arch, Values... values) noexcept + { + static_assert(sizeof...(Values) == batch_bool::size, "consistent init"); + using register_type = typename batch_bool::register_type; + register_type r = 0; + unsigned shift = 0; + (void)std::initializer_list { (r |= register_type(values ? 1 : 0) << (shift++))... 
}; + return r; + } + + // shuffle + template + XSIMD_INLINE batch shuffle(batch const& x, batch const& y, + batch_constant mask, + requires_arch) noexcept + { + constexpr uint32_t smask = (I0 & 0x3) | ((I1 & 0x3) << 2) | ((I2 & 0x3) << 4) | ((I3 & 0x3) << 6); + + // shuffle within lane + if ((I4 == I0 + 4) && (I5 == I1 + 4) && (I6 == I2 + 4) && (I7 == I3 + 4) && (I8 == I0 + 8) && (I9 == I1 + 8) && (I10 == I2 + 8) && (I11 == I3 + 8) && (I12 == I0 + 12) && (I13 == I1 + 12) && (I14 == I2 + 12) && (I15 == I3 + 12) && I0 < 4 && I1 < 4 && I2 >= 16 && I2 < 20 && I3 >= 16 && I3 < 20) + return _mm512_shuffle_ps(x, y, smask); + + // shuffle within opposite lane + if ((I4 == I0 + 4) && (I5 == I1 + 4) && (I6 == I2 + 4) && (I7 == I3 + 4) && (I8 == I0 + 8) && (I9 == I1 + 8) && (I10 == I2 + 8) && (I11 == I3 + 8) && (I12 == I0 + 12) && (I13 == I1 + 12) && (I14 == I2 + 12) && (I15 == I3 + 12) && I2 < 4 && I3 < 4 && I0 >= 16 && I0 < 20 && I1 >= 16 && I1 < 20) + return _mm512_shuffle_ps(y, x, smask); + + return shuffle(x, y, mask, generic {}); + } + + template + XSIMD_INLINE batch shuffle(batch const& x, batch const& y, batch_constant mask, requires_arch) noexcept + { + constexpr uint32_t smask = (I0 & 0x1) | ((I1 & 0x1) << 1) | ((I2 & 0x1) << 2) | ((I3 & 0x1) << 3) | ((I4 & 0x1) << 4) | ((I5 & 0x1) << 5) | ((I6 & 0x1) << 6) | ((I7 & 0x1) << 7); + // shuffle within lane + if (I0 < 2 && I1 >= 8 && I1 < 10 && I2 >= 2 && I2 < 4 && I3 >= 10 && I3 < 12 && I4 >= 4 && I4 < 6 && I5 >= 12 && I5 < 14 && I6 >= 6 && I6 < 8 && I7 >= 14) + return _mm512_shuffle_pd(x, y, smask); + + // shuffle within opposite lane + if (I1 < 2 && I0 >= 8 && I0 < 10 && I3 >= 2 && I3 < 4 && I2 >= 10 && I2 < 12 && I5 >= 4 && I5 < 6 && I4 >= 12 && I4 < 14 && I7 >= 6 && I7 < 8 && I6 >= 14) + return _mm512_shuffle_pd(y, x, smask); + + return shuffle(x, y, mask, generic {}); + } + + // slide_left + template + XSIMD_INLINE batch slide_left(batch const&, requires_arch) noexcept + { + static_assert(N == 0xDEAD, "not implemented yet"); + return {}; + } + + // slide_right + template + XSIMD_INLINE batch slide_right(batch const&, requires_arch) noexcept + { + static_assert(N == 0xDEAD, "not implemented yet"); + return {}; + } + + // sqrt + template + XSIMD_INLINE batch sqrt(batch const& val, requires_arch) noexcept + { + return _mm512_sqrt_ps(val); + } + template + XSIMD_INLINE batch sqrt(batch const& val, requires_arch) noexcept + { + return _mm512_sqrt_pd(val); + } + + // ssub + template ::value, void>::type> + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch) noexcept + { + if (std::is_signed::value) + { + return sadd(self, -other); + } + else + { + const auto diff = min(self, other); + return self - diff; + } + } + + // store + template + XSIMD_INLINE void store(batch_bool const& self, bool* mem, requires_arch) noexcept + { + using register_type = typename batch_bool::register_type; + constexpr auto size = batch_bool::size; + for (std::size_t i = 0; i < size; ++i) + mem[i] = self.data & (register_type(1) << i); + } + + // store_aligned + template ::value, void>::type> + XSIMD_INLINE void store_aligned(T* mem, batch const& self, requires_arch) noexcept + { + return _mm512_store_si512((__m512i*)mem, self); + } + template ::value, void>::type> + XSIMD_INLINE void store_aligned(T* mem, batch_bool const& self, requires_arch) noexcept + { + return _mm512_store_si512((__m512i*)mem, self); + } + template + XSIMD_INLINE void store_aligned(float* mem, batch const& self, requires_arch) noexcept + { + return _mm512_store_ps(mem, 
self); + } + template + XSIMD_INLINE void store_aligned(double* mem, batch const& self, requires_arch) noexcept + { + return _mm512_store_pd(mem, self); + } + + // store_unaligned + template ::value, void>::type> + XSIMD_INLINE void store_unaligned(T* mem, batch const& self, requires_arch) noexcept + { + return _mm512_storeu_si512((__m512i*)mem, self); + } + template ::value, void>::type> + XSIMD_INLINE void store_unaligned(T* mem, batch_bool const& self, requires_arch) noexcept + { + return _mm512_storeu_si512((__m512i*)mem, self); + } + template + XSIMD_INLINE void store_unaligned(float* mem, batch const& self, requires_arch) noexcept + { + return _mm512_storeu_ps(mem, self); + } + template + XSIMD_INLINE void store_unaligned(double* mem, batch const& self, requires_arch) noexcept + { + return _mm512_storeu_pd(mem, self); + } + + // sub + template ::value, void>::type> + XSIMD_INLINE batch sub(batch const& self, batch const& other, requires_arch) noexcept + { + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return sub(batch(s), batch(o)); }, + self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + return detail::fwd_to_avx([](__m256i s, __m256i o) noexcept + { return sub(batch(s), batch(o)); }, + self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + return _mm512_sub_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + return _mm512_sub_epi64(self, other); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + } + template + XSIMD_INLINE batch sub(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_sub_ps(self, other); + } + template + XSIMD_INLINE batch sub(batch const& self, batch const& other, requires_arch) noexcept + { + return _mm512_sub_pd(self, other); + } + + // swizzle (dynamic version) + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return _mm512_permutexvar_ps(mask, self); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return _mm512_permutexvar_pd(mask, self); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return _mm512_permutexvar_epi64(mask, self); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx512f {})); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return _mm512_permutexvar_epi32(mask, self); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx512f {})); + } + + // swizzle (constant version) + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512f {}); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512f {}); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512f {}); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512f {}); + } + + template 
+ XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512f {}); + } + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return swizzle(self, mask.as_batch(), avx512f {}); + } + + namespace detail + { + template + struct is_pair_of_contiguous_indices; + + template + struct is_pair_of_contiguous_indices : std::true_type + { + }; + + template + struct is_pair_of_contiguous_indices : std::conditional<(Idx0 % 2 == 0) && (Idx0 + 1 == Idx1), is_pair_of_contiguous_indices, std::false_type>::type + { + }; + + template + struct fold_batch_constant + { + using type = batch_constant; + }; + + } + + template ::value, void>::type> + XSIMD_INLINE batch swizzle(batch const& self, batch_constant, requires_arch) noexcept + { + constexpr typename detail::fold_batch_constant::type mask32; + return _mm512_permutexvar_epi32(static_cast>(mask32), self); + } + + template + XSIMD_INLINE batch + swizzle(batch const& self, batch_constant, requires_arch) noexcept + { + // FIXME: this sequence is very inefficient, but it's here to catch + // a pattern generated by detail::reduce from xsimd_generic_math.hpp. + // The whole pattern is actually decently folded by GCC and Clang, + // so bare with it. + constexpr batch_constant mask32; + auto tmp = _mm512_permutexvar_epi32(static_cast>(mask32), self); + + alignas(A::alignment()) uint16_t buffer[32]; + _mm512_store_si512((__m512i*)&buffer[0], tmp); + buffer[0] = buffer[1]; + return _mm512_load_si512(&buffer[0]); + } + + template + XSIMD_INLINE batch + swizzle(batch const& self, batch_constant mask, requires_arch) noexcept + { + return bitwise_cast(swizzle(bitwise_cast(self), mask, avx512f {})); + } + + // trunc + template + XSIMD_INLINE batch + trunc(batch const& self, requires_arch) noexcept + { + return _mm512_roundscale_round_ps(self, _MM_FROUND_TO_ZERO, _MM_FROUND_CUR_DIRECTION); + } + template + XSIMD_INLINE batch + trunc(batch const& self, requires_arch) noexcept + { + return _mm512_roundscale_round_pd(self, _MM_FROUND_TO_ZERO, _MM_FROUND_CUR_DIRECTION); + } + + // zip_hi + template ::value, void>::type> + XSIMD_INLINE batch + zip_hi(batch const& self, batch const& other, requires_arch) noexcept + { + __m512i lo, hi; + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + assert(false && "not implemented yet"); + return {}; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + assert(false && "not implemented yet"); + return {}; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + lo = _mm512_unpacklo_epi32(self, other); + hi = _mm512_unpackhi_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + lo = _mm512_unpacklo_epi64(self, other); + hi = _mm512_unpackhi_epi64(self, other); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + return _mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(hi, _mm512_extracti32x4_epi32(lo, 2), 0), + _mm512_extracti32x4_epi32(lo, 3), + 2), + _mm512_extracti32x4_epi32(hi, 2), + 1); + } + template + XSIMD_INLINE batch + zip_hi(batch const& self, batch const& other, requires_arch) noexcept + { + auto lo = _mm512_unpacklo_ps(self, other); + auto hi = _mm512_unpackhi_ps(self, other); + return _mm512_insertf32x4( + _mm512_insertf32x4( + _mm512_insertf32x4(hi, _mm512_extractf32x4_ps(lo, 2), 0), + _mm512_extractf32x4_ps(lo, 3), + 2), + _mm512_extractf32x4_ps(hi, 2), + 1); + } + template + XSIMD_INLINE batch + zip_hi(batch const& self, batch const& other, 
requires_arch) noexcept + { + auto lo = _mm512_castpd_ps(_mm512_unpacklo_pd(self, other)); + auto hi = _mm512_castpd_ps(_mm512_unpackhi_pd(self, other)); + return _mm512_castps_pd(_mm512_insertf32x4( + _mm512_insertf32x4( + _mm512_insertf32x4(hi, _mm512_extractf32x4_ps(lo, 2), 0), + _mm512_extractf32x4_ps(lo, 3), + 2), + _mm512_extractf32x4_ps(hi, 2), + 1)); + } + + // zip_lo + template ::value, void>::type> + XSIMD_INLINE batch + zip_lo(batch const& self, batch const& other, requires_arch) noexcept + { + __m512i lo, hi; + XSIMD_IF_CONSTEXPR(sizeof(T) == 1) + { + assert(false && "not implemented yet"); + return {}; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 2) + { + assert(false && "not implemented yet"); + return {}; + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 4) + { + lo = _mm512_unpacklo_epi32(self, other); + hi = _mm512_unpackhi_epi32(self, other); + } + else XSIMD_IF_CONSTEXPR(sizeof(T) == 8) + { + lo = _mm512_unpacklo_epi64(self, other); + hi = _mm512_unpackhi_epi64(self, other); + } + else + { + assert(false && "unsupported arch/op combination"); + return {}; + } + return _mm512_inserti32x4( + _mm512_inserti32x4( + _mm512_inserti32x4(lo, _mm512_extracti32x4_epi32(hi, 0), 1), + _mm512_extracti32x4_epi32(hi, 1), + 3), + _mm512_extracti32x4_epi32(lo, 1), + 2); + } + template + XSIMD_INLINE batch + zip_lo(batch const& self, batch const& other, requires_arch) noexcept + { + auto lo = _mm512_unpacklo_ps(self, other); + auto hi = _mm512_unpackhi_ps(self, other); + return _mm512_insertf32x4( + _mm512_insertf32x4( + _mm512_insertf32x4(lo, _mm512_extractf32x4_ps(hi, 0), 1), + _mm512_extractf32x4_ps(hi, 1), + 3), + _mm512_extractf32x4_ps(lo, 1), + 2); + } + template + XSIMD_INLINE batch + zip_lo(batch const& self, batch const& other, requires_arch) noexcept + { + auto lo = _mm512_castpd_ps(_mm512_unpacklo_pd(self, other)); + auto hi = _mm512_castpd_ps(_mm512_unpackhi_pd(self, other)); + return _mm512_castps_pd(_mm512_insertf32x4( + _mm512_insertf32x4( + _mm512_insertf32x4(lo, _mm512_extractf32x4_ps(hi, 0), 1), + _mm512_extractf32x4_ps(hi, 1), + 3), + _mm512_extractf32x4_ps(lo, 1), + 2)); + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512ifma.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512ifma.hpp new file mode 100644 index 0000000000000..df382881b0b2e --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512ifma.hpp @@ -0,0 +1,20 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
* + ****************************************************************************/ + +#ifndef XSIMD_AVX512IFMA_HPP +#define XSIMD_AVX512IFMA_HPP + +#include +#include + +#include "../types/xsimd_avx512ifma_register.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512pf.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512pf.hpp new file mode 100644 index 0000000000000..6265c91718fb0 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512pf.hpp @@ -0,0 +1,20 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512PF_HPP +#define XSIMD_AVX512PF_HPP + +#include +#include + +#include "../types/xsimd_avx512pf_register.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512vbmi.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512vbmi.hpp new file mode 100644 index 0000000000000..df382881b0b2e --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512vbmi.hpp @@ -0,0 +1,20 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512VBMI_HPP +#define XSIMD_AVX512VBMI_HPP + +#include +#include + +#include "../types/xsimd_avx512vbmi_register.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512vnni_avx512bw.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512vnni_avx512bw.hpp new file mode 100644 index 0000000000000..b285623d02f69 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512vnni_avx512bw.hpp @@ -0,0 +1,20 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512VNNI_AVX512_BW_HPP +#define XSIMD_AVX512VNNI_AVX512_BW_HPP + +#include +#include + +#include "../types/xsimd_avx512vnni_avx512bw_register.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avx512vnni_avx512vbmi.hpp b/include/onnxruntime/xsimd/arch/xsimd_avx512vnni_avx512vbmi.hpp new file mode 100644 index 0000000000000..a70d30fad5985 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avx512vnni_avx512vbmi.hpp @@ -0,0 +1,20 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. 
* + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVX512VNNI_AVX512VBMI_HPP +#define XSIMD_AVX512VNNI_AVX512VBMI_HPP + +#include +#include + +#include "../types/xsimd_avx512vnni_avx512vbmi_register.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_avxvnni.hpp b/include/onnxruntime/xsimd/arch/xsimd_avxvnni.hpp new file mode 100644 index 0000000000000..a97ba9296c516 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_avxvnni.hpp @@ -0,0 +1,20 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_AVXVNNI_HPP +#define XSIMD_AVXVNNI_HPP + +#include +#include + +#include "../types/xsimd_avxvnni_register.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_constants.hpp b/include/onnxruntime/xsimd/arch/xsimd_constants.hpp new file mode 100644 index 0000000000000..51411d2877465 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_constants.hpp @@ -0,0 +1,391 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
* + ****************************************************************************/ + +#ifndef XSIMD_NUMERICAL_CONSTANT_HPP +#define XSIMD_NUMERICAL_CONSTANT_HPP + +#include + +#include "../types/xsimd_utils.hpp" + +namespace xsimd +{ + + namespace constants + { + +#define XSIMD_DEFINE_CONSTANT(NAME, SINGLE, DOUBLE) \ + template \ + XSIMD_INLINE T NAME() noexcept \ + { \ + return T(NAME()); \ + } \ + template <> \ + XSIMD_INLINE float NAME() noexcept \ + { \ + return SINGLE; \ + } \ + template <> \ + XSIMD_INLINE double NAME() noexcept \ + { \ + return DOUBLE; \ + } + +#define XSIMD_DEFINE_CONSTANT_HEX(NAME, SINGLE, DOUBLE) \ + template \ + XSIMD_INLINE T NAME() noexcept \ + { \ + return T(NAME()); \ + } \ + template <> \ + XSIMD_INLINE float NAME() noexcept \ + { \ + return bit_cast((uint32_t)SINGLE); \ + } \ + template <> \ + XSIMD_INLINE double NAME() noexcept \ + { \ + return bit_cast((uint64_t)DOUBLE); \ + } + +// Under fast-math, GCC might replace signmask (minus zero) by zero +#if defined(__FAST_MATH__) && defined(__GNUC__) && !defined(__clang__) +#pragma GCC push_options +#pragma GCC optimize("signed-zeros") +#endif + XSIMD_DEFINE_CONSTANT(infinity, (std::numeric_limits::infinity()), (std::numeric_limits::infinity())) + XSIMD_DEFINE_CONSTANT(invlog_2, 1.442695040888963407359924681001892137426645954152986f, 1.442695040888963407359924681001892137426645954152986) + XSIMD_DEFINE_CONSTANT_HEX(invlog_2hi, 0x3fb8b000, 0x3ff7154765200000) + XSIMD_DEFINE_CONSTANT_HEX(invlog_2lo, 0xb9389ad4, 0x3de705fc2eefa200) + XSIMD_DEFINE_CONSTANT(invlog10_2, 3.32192809488736234787031942949f, 3.32192809488736234787031942949) + XSIMD_DEFINE_CONSTANT_HEX(invpi, 0x3ea2f983, 0x3fd45f306dc9c883) + XSIMD_DEFINE_CONSTANT(log_2, 0.6931471805599453094172321214581765680755001343602553f, 0.6931471805599453094172321214581765680755001343602553) + XSIMD_DEFINE_CONSTANT_HEX(log_2hi, 0x3f318000, 0x3fe62e42fee00000) + XSIMD_DEFINE_CONSTANT_HEX(log_2lo, 0xb95e8083, 0x3dea39ef35793c76) + XSIMD_DEFINE_CONSTANT_HEX(log10_2hi, 0x3e9a0000, 0x3fd3440000000000) + XSIMD_DEFINE_CONSTANT_HEX(log10_2lo, 0x39826a14, 0x3ed3509f79fef312) + XSIMD_DEFINE_CONSTANT_HEX(logeps, 0xc17f1402, 0xc04205966f2b4f12) + XSIMD_DEFINE_CONSTANT_HEX(logpi, 0x3f928682, 0x3ff250d048e7a1bd) + XSIMD_DEFINE_CONSTANT_HEX(logsqrt2pi, 0x3f6b3f8e, 0x3fed67f1c864beb5) + XSIMD_DEFINE_CONSTANT(maxflint, 16777216.0f, 9007199254740992.0) + XSIMD_DEFINE_CONSTANT(maxlog, 88.3762626647949f, 709.78271289338400) + XSIMD_DEFINE_CONSTANT(maxlog2, 127.0f, 1023.) + XSIMD_DEFINE_CONSTANT(maxlog10, 38.23080825805664f, 308.2547155599167) + XSIMD_DEFINE_CONSTANT_HEX(mediumpi, 0x43490fdb, 0x412921fb54442d18) + XSIMD_DEFINE_CONSTANT(minlog, -88.3762626647949f, -708.3964185322641) + XSIMD_DEFINE_CONSTANT(minlog2, -127.0f, -1023.) 
+ XSIMD_DEFINE_CONSTANT(minlog10, -37.89999771118164f, -308.2547155599167) + XSIMD_DEFINE_CONSTANT(minusinfinity, (-infinity()), (-infinity())) + XSIMD_DEFINE_CONSTANT_HEX(nan, 0xffffffff, 0xffffffffffffffff) + XSIMD_DEFINE_CONSTANT_HEX(oneosqrteps, 0x453504f3, 0x4190000000000000) + XSIMD_DEFINE_CONSTANT_HEX(oneotwoeps, 0x4a800000, 0x4320000000000000) + XSIMD_DEFINE_CONSTANT_HEX(pi, 0x40490fdb, 0x400921fb54442d18) + XSIMD_DEFINE_CONSTANT_HEX(pio_2lo, 0xb33bbd2e, 0x3c91a62633145c07) + XSIMD_DEFINE_CONSTANT_HEX(pio_4lo, 0xb2bbbd2e, 0x3c81a62633145c07) + XSIMD_DEFINE_CONSTANT_HEX(pio2, 0x3fc90fdb, 0x3ff921fb54442d18) + XSIMD_DEFINE_CONSTANT_HEX(pio2_1, 0x3fc90f80, 0x3ff921fb54400000) + XSIMD_DEFINE_CONSTANT_HEX(pio2_1t, 0x37354443, 0x3dd0b4611a626331) + XSIMD_DEFINE_CONSTANT_HEX(pio2_2, 0x37354400, 0x3dd0b4611a600000) + XSIMD_DEFINE_CONSTANT_HEX(pio2_2t, 0x2e85a308, 0x3ba3198a2e037073) + XSIMD_DEFINE_CONSTANT_HEX(pio2_3, 0x2e85a300, 0x3ba3198a2e000000) + XSIMD_DEFINE_CONSTANT_HEX(pio2_3t, 0x248d3132, 0x397b839a252049c1) + XSIMD_DEFINE_CONSTANT_HEX(pio4, 0x3f490fdb, 0x3fe921fb54442d18) + XSIMD_DEFINE_CONSTANT_HEX(signmask, 0x80000000, 0x8000000000000000) + XSIMD_DEFINE_CONSTANT(smallestposval, std::numeric_limits::min(), std::numeric_limits::min()) + XSIMD_DEFINE_CONSTANT_HEX(sqrt_2pi, 0x40206c99, 0x40040d931ff62704) + XSIMD_DEFINE_CONSTANT_HEX(sqrteps, 0x39b504f3, 0x3e50000000000000) + XSIMD_DEFINE_CONSTANT_HEX(tanpio8, 0x3ed413cd, 0x3fda827999fcef31) + XSIMD_DEFINE_CONSTANT_HEX(tan3pio8, 0x401a827a, 0x4003504f333f9de6) + XSIMD_DEFINE_CONSTANT_HEX(twentypi, 0x427b53d1, 0x404f6a7a2955385e) + XSIMD_DEFINE_CONSTANT_HEX(twoopi, 0x3f22f983, 0x3fe45f306dc9c883) + XSIMD_DEFINE_CONSTANT(twotonmb, 8388608.0f, 4503599627370496.0) + XSIMD_DEFINE_CONSTANT_HEX(twotonmbo3, 0x3ba14518, 0x3ed428a2f98d7286) +#if defined(__FAST_MATH__) && defined(__GNUC__) && !defined(__clang__) +#pragma GCC pop_options +#endif + +#undef XSIMD_DEFINE_CONSTANT +#undef XSIMD_DEFINE_CONSTANT_HEX + + template + constexpr T allbits() noexcept; + + template + constexpr as_integer_t mask1frexp() noexcept; + + template + constexpr as_integer_t mask2frexp() noexcept; + + template + constexpr as_integer_t maxexponent() noexcept; + + template + constexpr as_integer_t maxexponentm1() noexcept; + + template + constexpr int32_t nmb() noexcept; + + template + constexpr T zero() noexcept; + + template + constexpr T minvalue() noexcept; + + template + constexpr T maxvalue() noexcept; + + /************************** + * allbits implementation * + **************************/ + + namespace detail + { + template ::value> + struct allbits_impl + { + static constexpr T get_value() noexcept + { + return T(~0); + } + }; + + template + struct allbits_impl + { + static constexpr T get_value() noexcept + { + return nan(); + } + }; + } + + template + XSIMD_INLINE constexpr T allbits() noexcept + { + return T(detail::allbits_impl::get_value()); + } + + /***************************** + * mask1frexp implementation * + *****************************/ + + template + XSIMD_INLINE constexpr as_integer_t mask1frexp() noexcept + { + return as_integer_t(mask1frexp()); + } + + template <> + XSIMD_INLINE constexpr int32_t mask1frexp() noexcept + { + return 0x7f800000; + } + + template <> + XSIMD_INLINE constexpr int64_t mask1frexp() noexcept + { + return 0x7ff0000000000000; + } + + /***************************** + * mask2frexp implementation * + *****************************/ + + template + XSIMD_INLINE constexpr as_integer_t mask2frexp() noexcept + { + return 
as_integer_t(mask2frexp()); + } + + template <> + XSIMD_INLINE constexpr int32_t mask2frexp() noexcept + { + return 0x3f000000; + } + + template <> + XSIMD_INLINE constexpr int64_t mask2frexp() noexcept + { + return 0x3fe0000000000000; + } + + /****************************** + * maxexponent implementation * + ******************************/ + + template + XSIMD_INLINE constexpr as_integer_t maxexponent() noexcept + { + return as_integer_t(maxexponent()); + } + + template <> + XSIMD_INLINE constexpr int32_t maxexponent() noexcept + { + return 127; + } + + template <> + XSIMD_INLINE constexpr int64_t maxexponent() noexcept + { + return 1023; + } + + /****************************** + * maxexponent implementation * + ******************************/ + + template + XSIMD_INLINE constexpr as_integer_t maxexponentm1() noexcept + { + return as_integer_t(maxexponentm1()); + } + + template <> + XSIMD_INLINE constexpr int32_t maxexponentm1() noexcept + { + return 126; + } + + template <> + XSIMD_INLINE constexpr int64_t maxexponentm1() noexcept + { + return 1022; + } + + /********************** + * nmb implementation * + **********************/ + + template + XSIMD_INLINE constexpr int32_t nmb() noexcept + { + return nmb(); + } + + template <> + XSIMD_INLINE constexpr int32_t nmb() noexcept + { + return 23; + } + + template <> + XSIMD_INLINE constexpr int32_t nmb() noexcept + { + return 52; + } + + /*********************** + * zero implementation * + ***********************/ + + template + XSIMD_INLINE constexpr T zero() noexcept + { + return T(typename T::value_type(0)); + } + + /*************************** + * minvalue implementation * + ***************************/ + + namespace detail + { + template + struct minvalue_impl + { + static constexpr T get_value() noexcept + { + return std::numeric_limits::min(); + } + }; + + template + struct minvalue_common + { + static constexpr T get_value() noexcept + { + return std::numeric_limits::min(); + } + }; + + template <> + struct minvalue_impl : minvalue_common + { + }; + template <> + struct minvalue_impl : minvalue_common + { + }; + template <> + struct minvalue_impl : minvalue_common + { + }; + template <> + struct minvalue_impl : minvalue_common + { + }; + template <> + struct minvalue_impl : minvalue_common + { + }; + template <> + struct minvalue_impl : minvalue_common + { + }; + template <> + struct minvalue_impl : minvalue_common + { + }; + template <> + struct minvalue_impl : minvalue_common + { + }; + + template <> + struct minvalue_impl + { + XSIMD_INLINE static float get_value() noexcept + { + return bit_cast((uint32_t)0xff7fffff); + } + }; + + template <> + struct minvalue_impl + { + XSIMD_INLINE static double get_value() noexcept + { + return bit_cast((uint64_t)0xffefffffffffffff); + } + }; + } + + template + constexpr T minvalue() noexcept + { + return T(detail::minvalue_impl::get_value()); + } + + /*************************** + * maxvalue implementation * + ***************************/ + + template + constexpr T maxvalue() noexcept + { + return T(std::numeric_limits::max()); + } + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_emulated.hpp b/include/onnxruntime/xsimd/arch/xsimd_emulated.hpp new file mode 100644 index 0000000000000..2f4585bbb3a0c --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_emulated.hpp @@ -0,0 +1,771 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) 
QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_EMULATED_HPP +#define XSIMD_EMULATED_HPP + +#include +#include +#include +#include + +#include "../arch/xsimd_scalar.hpp" + +#include "../types/xsimd_emulated_register.hpp" +#include "../types/xsimd_utils.hpp" + +namespace xsimd +{ + template + struct batch_bool_constant; + + template + XSIMD_INLINE batch bitwise_cast(batch const& x) noexcept; + + template + struct batch_constant; + + namespace kernel + { + using namespace types; + + // fwd + template + XSIMD_INLINE batch insert(batch const& self, T val, index, requires_arch) noexcept; + template + XSIMD_INLINE batch shuffle(batch const& x, batch const& y, batch_constant, requires_arch) noexcept; + + namespace detail + { + template + auto emulated_apply(F func, Bs const&... bs) -> decltype(func(bs.data[I]...)) + { + return func(bs.data[I]...); + } + + template + auto emulated_apply(F func, ::xsimd::detail::index_sequence, B const& b, Bs const&... bs) -> std::array + { + return { emulated_apply(func, b, bs...)... }; + } + + template + auto emulated_apply(F func, B const& b, Bs const&... bs) -> std::array + { + return emulated_apply(func, ::xsimd::detail::make_index_sequence(), b, bs...); + } + } + + // abs + template ::size> + XSIMD_INLINE batch abs(batch const& self, requires_arch>) noexcept + { + return detail::emulated_apply([](T v) + { return xsimd::abs(v); }, + self); + } + + // add + template ::size> + XSIMD_INLINE batch add(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::add(v0, v1); }, + self, other); + } + + // all + template ::size> + XSIMD_INLINE bool all(batch_bool const& self, requires_arch>) noexcept + { + return std::all_of(self.data.begin(), self.data.end(), [](T v) + { return bool(v); }); + } + + // any + template ::size> + XSIMD_INLINE bool any(batch_bool const& self, requires_arch>) noexcept + { + return std::any_of(self.data.begin(), self.data.end(), [](T v) + { return bool(v); }); + } + + // batch_bool_cast + template ::size> + XSIMD_INLINE batch_bool batch_bool_cast(batch_bool const& self, batch_bool const&, requires_arch>) noexcept + { + return { self.data }; + } + + // bitwise_and + template ::size> + XSIMD_INLINE batch bitwise_and(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::bitwise_and(v0, v1); }, + self, other); + } + + template ::size> + XSIMD_INLINE batch_bool bitwise_and(batch_bool const& self, batch_bool const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](bool v0, bool v1) + { return xsimd::bitwise_and(v0, v1); }, + self, other); + } + + // bitwise_andnot + template ::size> + XSIMD_INLINE batch bitwise_andnot(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::bitwise_andnot(v0, v1); }, + self, other); + } + + template ::size> + XSIMD_INLINE batch_bool bitwise_andnot(batch_bool const& self, batch_bool const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](bool v0, bool v1) + { return xsimd::bitwise_andnot(v0, v1); }, + self, other); + } + + // bitwise_lshift + template ::size> + XSIMD_INLINE batch bitwise_lshift(batch 
const& self, int32_t other, requires_arch>) noexcept + { + return detail::emulated_apply([other](T v) + { return xsimd::bitwise_lshift(v, other); }, + self); + } + + // bitwise_not + template ::size> + XSIMD_INLINE batch bitwise_not(batch const& self, requires_arch>) noexcept + { + return detail::emulated_apply([](T v) + { return xsimd::bitwise_not(v); }, + self); + } + + template ::size> + XSIMD_INLINE batch_bool bitwise_not(batch_bool const& self, requires_arch>) noexcept + { + return detail::emulated_apply([](bool v) + { return xsimd::bitwise_not(v); }, + self); + } + + // bitwise_or + template ::size> + XSIMD_INLINE batch bitwise_or(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::bitwise_or(v0, v1); }, + self, other); + } + + template ::size> + XSIMD_INLINE batch_bool bitwise_or(batch_bool const& self, batch_bool const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](bool v0, bool v1) + { return xsimd::bitwise_or(v0, v1); }, + self, other); + } + + // bitwise_rshift + template ::size> + XSIMD_INLINE batch bitwise_rshift(batch const& self, int32_t other, requires_arch>) noexcept + { + return detail::emulated_apply([other](T v) + { return xsimd::bitwise_rshift(v, other); }, + self); + } + + // bitwise_xor + template ::size> + XSIMD_INLINE batch bitwise_xor(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::bitwise_xor(v0, v1); }, + self, other); + } + + template ::size> + XSIMD_INLINE batch_bool bitwise_xor(batch_bool const& self, batch_bool const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](bool v0, bool v1) + { return xsimd::bitwise_xor(v0, v1); }, + self, other); + } + + // bitwise_cast + template ::size> + XSIMD_INLINE batch bitwise_cast(batch const& self, batch const&, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array result; + char* raw_data = reinterpret_cast(result.data()); + const char* raw_input = reinterpret_cast(self.data.data()); + memcpy(raw_data, raw_input, size * sizeof(T_out)); + return result; + } + + // broadcast + template ::size> + batch XSIMD_INLINE broadcast(T val, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array r; + std::fill(r.begin(), r.end(), val); + return r; + } + +#if 0 + // count + template ::size> + XSIMD_INLINE size_t count(batch_bool const& x, requires_arch>) noexcept + { + uint64_t m = x.mask(); + // https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel + m = m - ((m >> 1) & (uint64_t) ~(uint64_t)0 / 3); // temp + m = (m & (uint64_t) ~(uint64_t)0 / 15 * 3) + ((m >> 2) & (uint64_t) ~(uint64_t)0 / 15 * 3); // temp + m = (m + (m >> 4)) & (uint64_t) ~(uint64_t)0 / 255 * 15; // temp + return (m * ((uint64_t) ~(uint64_t)0 / 255)) >> (sizeof(uint64_t) - 1) * CHAR_BIT; // count + } +#endif + + // store_complex + namespace detail + { + // complex_low + template ::size> + XSIMD_INLINE batch complex_low(batch, A> const& self, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array result; + for (size_t i = 0; i < size / 2; ++i) + { + result[2 * i] = self.real().data[i]; + result[1 + 2 * i] = self.imag().data[i]; + } + return result; + } + // complex_high + template ::size> + XSIMD_INLINE batch complex_high(batch, A> const& self, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array result; + for (size_t i = 0; i < 
size / 2; ++i) + { + result[2 * i] = self.real().data[i + size / 2]; + result[1 + 2 * i] = self.imag().data[i + size / 2]; + } + return result; + } + } + + // decr_if + template ::size> + XSIMD_INLINE batch decr_if(batch const& self, batch_bool const& mask, requires_arch>) noexcept + { + return self - batch(mask.data); + } + + // div + template ::size> + XSIMD_INLINE batch div(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::div(v0, v1); }, + self, other); + } + + // fast_cast + namespace detail + { + template ::size> + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch>) noexcept + { + return detail::emulated_apply([](int32_t v) + { return float(v); }, + self); + } + + template ::size> + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch>) noexcept + { + return detail::emulated_apply([](uint32_t v) + { return float(v); }, + self); + } + + template ::size> + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch>) noexcept + { + return detail::emulated_apply([](int64_t v) + { return double(v); }, + self); + } + + template ::size> + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch>) noexcept + { + return detail::emulated_apply([](uint64_t v) + { return double(v); }, + self); + } + + template ::size> + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch>) noexcept + { + return detail::emulated_apply([](float v) + { return int32_t(v); }, + self); + } + + template ::size> + XSIMD_INLINE batch fast_cast(batch const& self, batch const&, requires_arch>) noexcept + { + return detail::emulated_apply([](double v) + { return int64_t(v); }, + self); + } + } + + // eq + template ::size> + XSIMD_INLINE batch_bool> eq(batch> const& self, batch> const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::eq(v0, v1); }, + self, other); + } + + template ::size> + XSIMD_INLINE batch_bool> eq(batch_bool> const& self, batch_bool> const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](bool v0, bool v1) + { return xsimd::eq(v0, v1); }, + self, other); + } + + // from_bool + template ::size> + XSIMD_INLINE batch from_bool(batch_bool const& self, requires_arch>) noexcept + { + return detail::emulated_apply([](bool v) + { return T(v); }, + self); + } + + // from_mask + template ::size> + XSIMD_INLINE batch_bool from_mask(batch_bool const&, uint64_t mask, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array vmask; + for (size_t i = 0; i < size; ++i) + vmask[i] = (mask >> i) & 1u; + return vmask; + } + + // ge + template ::size> + XSIMD_INLINE batch_bool> ge(batch> const& self, batch> const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::ge(v0, v1); }, + self, other); + } + + // gt + template ::size> + XSIMD_INLINE batch_bool> gt(batch> const& self, batch> const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::gt(v0, v1); }, + self, other); + } + + // haddp + template ::size> + XSIMD_INLINE batch haddp(batch const* row, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array r; + for (size_t i = 0; i < size; ++i) + r[i] = std::accumulate(row[i].data.begin() + 1, row[i].data.end(), row[i].data.front()); + return r; + } + + // incr_if + template ::size> + XSIMD_INLINE batch incr_if(batch 
const& self, batch_bool const& mask, requires_arch>) noexcept + { + return self + batch(mask.data); + } + + // insert + template ::size> + XSIMD_INLINE batch insert(batch const& self, T val, index, requires_arch>) noexcept + { + batch other = self; + other.data[I] = val; + return other; + } + + // isnan + template ::size, class = typename std::enable_if::value, void>::type> + XSIMD_INLINE batch_bool isnan(batch const& self, requires_arch>) noexcept + { + return detail::emulated_apply([](T v) + { return xsimd::isnan(v); }, + self); + } + + // load_aligned + template ::size> + XSIMD_INLINE batch load_aligned(T const* mem, convert, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array res; + std::copy(mem, mem + size, res.begin()); + return res; + } + + // load_unaligned + template ::size> + XSIMD_INLINE batch load_unaligned(T const* mem, convert, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array res; + std::copy(mem, mem + size, res.begin()); + return res; + } + + // load_complex + namespace detail + { + template ::size> + XSIMD_INLINE batch, A> load_complex(batch const& hi, batch const& lo, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array real, imag; + for (size_t i = 0; i < size / 2; ++i) + { + real[i] = hi.data[2 * i]; + imag[i] = hi.data[1 + 2 * i]; + } + for (size_t i = 0; i < size / 2; ++i) + { + real[size / 2 + i] = lo.data[2 * i]; + imag[size / 2 + i] = lo.data[1 + 2 * i]; + } + return { real, imag }; + } + } + + // le + template ::size> + XSIMD_INLINE batch_bool> le(batch> const& self, batch> const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::le(v0, v1); }, + self, other); + } + + // lt + template ::size> + XSIMD_INLINE batch_bool> lt(batch> const& self, batch> const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::lt(v0, v1); }, + self, other); + } + + // mask + template ::size> + XSIMD_INLINE uint64_t mask(batch_bool const& self, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + uint64_t res = 0; + for (size_t i = 0; i < size; ++i) + res |= (self.data[i] ? 
1u : 0u) << i; + return res; + } + + // max + template ::size> + XSIMD_INLINE batch max(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::max(v0, v1); }, + self, other); + } + + // min + template ::size> + XSIMD_INLINE batch min(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::min(v0, v1); }, + self, other); + } + + // mul + template ::size> + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::mul(v0, v1); }, + self, other); + } + + // nearbyint_as_int + template ::size> + XSIMD_INLINE batch, A> nearbyint_as_int(batch const& self, + requires_arch>) noexcept + { + return detail::emulated_apply([](T v) + { return xsimd::nearbyint_as_int(v); }, + self); + } + + // neg + template ::size> + XSIMD_INLINE batch neg(batch const& self, requires_arch>) noexcept + { + return detail::emulated_apply([](T v) + { return xsimd::neg(v); }, + self); + } + + // neq + template ::size> + XSIMD_INLINE batch_bool neq(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::neq(v0, v1); }, + self, other); + } + + template ::size> + XSIMD_INLINE batch_bool neq(batch_bool const& self, batch_bool const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](bool v0, bool v1) + { return xsimd::neq(v0, v1); }, + self, other); + } + + // reduce_add + template ::size> + XSIMD_INLINE T reduce_add(batch const& self, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array buffer; + self.store_unaligned(buffer.data()); + return std::accumulate(buffer.begin() + 1, buffer.end(), *buffer.begin()); + } + + // reduce_max + template ::size> + XSIMD_INLINE T reduce_max(batch const& self, requires_arch>) noexcept + { + return std::accumulate(self.data.begin() + 1, self.data.end(), *self.data.begin(), [](T const& x, T const& y) + { return xsimd::max(x, y); }); + } + + // reduce_min + template ::size> + XSIMD_INLINE T reduce_min(batch const& self, requires_arch>) noexcept + { + return std::accumulate(self.data.begin() + 1, self.data.end(), *self.data.begin(), [](T const& x, T const& y) + { return xsimd::min(x, y); }); + } + + // rsqrt + template ::size> + XSIMD_INLINE batch rsqrt(batch const& self, requires_arch>) noexcept + { + return detail::emulated_apply([](T v) + { return xsimd::rsqrt(v); }, + self); + } + + // select + template ::size> + XSIMD_INLINE batch select(batch_bool const& cond, batch const& true_br, batch const& false_br, requires_arch>) noexcept + { + return detail::emulated_apply([](bool c, T t, T f) + { return xsimd::select(c, t, f); }, + cond, true_br, false_br); + } + + template + XSIMD_INLINE batch select(batch_bool_constant const& cond, batch const& true_br, batch const& false_br, requires_arch::size>>) noexcept + { + constexpr size_t size = batch::size; + static_assert(sizeof...(Values) == size, "consistent init"); + return select((batch_bool)cond, true_br, false_br, emulated<8 * sizeof(T) * size> {}); + } + + // shuffle + template + XSIMD_INLINE batch shuffle(batch const& x, batch const& y, batch_constant mask, requires_arch::size>>) noexcept + { + constexpr size_t size = batch::size; + batch bmask = mask; + std::array res; + for (size_t i = 0; i < size; ++i) + res[i] = bmask.data[i] < size ? 
x.data[bmask.data[i]] : y.data[bmask.data[i] - size]; + return res; + } + + // sqrt + template ::size> + XSIMD_INLINE batch sqrt(batch const& self, requires_arch>) noexcept + { + return detail::emulated_apply([](T v) + { return xsimd::sqrt(v); }, + self); + } + + // slide_left + template ::size> + XSIMD_INLINE batch slide_left(batch const& x, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array result; + char* raw_data = reinterpret_cast(result.data()); + memset(raw_data, 0, M); + memcpy(raw_data + M, reinterpret_cast(x.data.data()), sizeof(T) * result.size() - M); + return result; + } + + // slide_right + template ::size> + XSIMD_INLINE batch slide_right(batch const& x, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + std::array result; + char* raw_data = reinterpret_cast(result.data()); + memcpy(raw_data, reinterpret_cast(x.data.data()) + M, sizeof(T) * result.size() - M); + memset(raw_data + sizeof(T) * result.size() - M, 0, M); + return result; + } + + // sadd + template ::size> + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::sadd(v0, v1); }, + self, other); + } + + // set + template + XSIMD_INLINE batch> set(batch> const&, requires_arch>, Values... values) noexcept + { + static_assert(sizeof...(Values) == batch>::size, "consistent init"); + return { typename batch>::register_type { static_cast(values)... } }; + } + + template + XSIMD_INLINE batch_bool> set(batch_bool> const&, requires_arch>, Values... values) noexcept + { + static_assert(sizeof...(Values) == batch>::size, "consistent init"); + return { std::array { static_cast(values)... } }; + } + + // ssub + template ::size> + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::ssub(v0, v1); }, + self, other); + } + + // store_aligned + template + XSIMD_INLINE void store_aligned(T* mem, batch> const& self, requires_arch>) noexcept + { + std::copy(self.data.begin(), self.data.end(), mem); + } + + // store_unaligned + template + XSIMD_INLINE void store_unaligned(T* mem, batch> const& self, requires_arch>) noexcept + { + std::copy(self.data.begin(), self.data.end(), mem); + } + + // sub + template ::size> + XSIMD_INLINE batch sub(batch const& self, batch const& other, requires_arch>) noexcept + { + return detail::emulated_apply([](T v0, T v1) + { return xsimd::sub(v0, v1); }, + self, other); + } + + // swizzle + + template + XSIMD_INLINE batch swizzle(batch const& self, batch_constant mask, requires_arch::size>>) noexcept + { + constexpr size_t size = batch::size; + batch bmask = mask; + std::array res; + for (size_t i = 0; i < size; ++i) + res[i] = self.data[bmask.data[i]]; + return res; + } + + // zip_hi + template ::size> + XSIMD_INLINE batch zip_hi(batch const& self, batch const& other, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + // Note: irregular behavior for odd numbers. + std::array res; + if (size % 2) + { + for (size_t i = 0; i < size; ++i) + res[i] = (i % 2 ? self : other).data[size / 2 + i / 2]; + } + else + { + for (size_t i = 0; i < size; ++i) + res[i] = (i % 2 ? other : self).data[size / 2 + i / 2]; + } + return res; + } + + // zip_lo + template ::size> + XSIMD_INLINE batch zip_lo(batch const& self, batch const& other, requires_arch>) noexcept + { + constexpr size_t size = batch::size; + // Note: irregular behavior for odd numbers. 
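
Editor's note on the emulated kernels above: almost all of them funnel through detail::emulated_apply, which just runs a scalar callable over every lane of a std::array-backed batch. The standalone sketch below shows the shape of that pattern; the names lanewise_apply and Lanes are illustrative and not part of this patch.

    #include <array>
    #include <cstddef>

    // Lift a scalar binary operation over every lane of an array-backed "batch".
    template <class F, class T, std::size_t Lanes>
    std::array<T, Lanes> lanewise_apply(F&& f,
                                        std::array<T, Lanes> const& a,
                                        std::array<T, Lanes> const& b)
    {
        std::array<T, Lanes> out{};
        for (std::size_t i = 0; i < Lanes; ++i)
            out[i] = f(a[i], b[i]);   // one scalar call per lane, no SIMD needed
        return out;
    }

    // Usage: element-wise max over an 8-lane batch of floats.
    // std::array<float, 8> m = lanewise_apply(
    //     [](float x, float y) { return x < y ? y : x; }, lhs, rhs);
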
+ std::array res; + for (size_t i = 0; i < size; ++i) + res[i] = (i % 2 ? other : self).data[i / 2]; + return res; + } + } +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_fma3_avx.hpp b/include/onnxruntime/xsimd/arch/xsimd_fma3_avx.hpp new file mode 100644 index 0000000000000..99262531476a9 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_fma3_avx.hpp @@ -0,0 +1,80 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_FMA3_AVX_HPP +#define XSIMD_FMA3_AVX_HPP + +#include "../types/xsimd_fma3_avx_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + using namespace types; + + // fnma + template + XSIMD_INLINE batch fnma(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm256_fnmadd_ps(x, y, z); + } + + template + XSIMD_INLINE batch fnma(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm256_fnmadd_pd(x, y, z); + } + + // fnms + template + XSIMD_INLINE batch fnms(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm256_fnmsub_ps(x, y, z); + } + + template + XSIMD_INLINE batch fnms(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm256_fnmsub_pd(x, y, z); + } + + // fma + template + XSIMD_INLINE batch fma(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm256_fmadd_ps(x, y, z); + } + + template + XSIMD_INLINE batch fma(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm256_fmadd_pd(x, y, z); + } + + // fms + template + XSIMD_INLINE batch fms(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm256_fmsub_ps(x, y, z); + } + + template + XSIMD_INLINE batch fms(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm256_fmsub_pd(x, y, z); + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_fma3_avx2.hpp b/include/onnxruntime/xsimd/arch/xsimd_fma3_avx2.hpp new file mode 100644 index 0000000000000..134053951ac63 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_fma3_avx2.hpp @@ -0,0 +1,46 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
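
For reference, the four fused forms wired up in xsimd_fma3_avx.hpp differ only in the signs applied to the product and the addend, and they map one-to-one onto _mm256_fmadd_ps, _mm256_fmsub_ps, _mm256_fnmadd_ps and _mm256_fnmsub_ps (plus the _pd variants for double). A scalar cross-check using only std::fma; the *_ref names are ad hoc and not part of the patch:

    #include <cmath>

    inline float fma_ref(float x, float y, float z)  { return std::fma(x, y, z); }   //  x * y + z   -> _mm256_fmadd_ps
    inline float fms_ref(float x, float y, float z)  { return std::fma(x, y, -z); }  //  x * y - z   -> _mm256_fmsub_ps
    inline float fnma_ref(float x, float y, float z) { return std::fma(-x, y, z); }  // -(x * y) + z -> _mm256_fnmadd_ps
    inline float fnms_ref(float x, float y, float z) { return std::fma(-x, y, -z); } // -(x * y) - z -> _mm256_fnmsub_ps
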
* + ****************************************************************************/ + +#ifndef XSIMD_FMA3_AVX2_HPP +#define XSIMD_FMA3_AVX2_HPP + +#include "../types/xsimd_fma3_avx2_register.hpp" + +// Allow inclusion of xsimd_fma3_avx.hpp +#ifdef XSIMD_FMA3_AVX_HPP +#undef XSIMD_FMA3_AVX_HPP +#define XSIMD_FORCE_FMA3_AVX_HPP +#endif + +// Disallow inclusion of ./xsimd_fma3_avx_register.hpp +#ifndef XSIMD_FMA3_AVX_REGISTER_HPP +#define XSIMD_FMA3_AVX_REGISTER_HPP +#define XSIMD_FORCE_FMA3_AVX_REGISTER_HPP +#endif + +// Include ./xsimd_fma3_avx.hpp but s/avx/avx2 +#define avx avx2 +#include "./xsimd_fma3_avx.hpp" +#undef avx +#undef XSIMD_FMA3_AVX_HPP + +// Carefully restore guards +#ifdef XSIMD_FORCE_FMA3_AVX_HPP +#define XSIMD_FMA3_AVX_HPP +#undef XSIMD_FORCE_FMA3_AVX_HPP +#endif + +#ifdef XSIMD_FORCE_FMA3_AVX_REGISTER_HPP +#undef XSIMD_FMA3_AVX_REGISTER_HPP +#undef XSIMD_FORCE_FMA3_AVX_REGISTER_HPP +#endif + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_fma3_sse.hpp b/include/onnxruntime/xsimd/arch/xsimd_fma3_sse.hpp new file mode 100644 index 0000000000000..9b126166ac048 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_fma3_sse.hpp @@ -0,0 +1,79 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_FMA3_SSE_HPP +#define XSIMD_FMA3_SSE_HPP + +#include "../types/xsimd_fma3_sse_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + using namespace types; + // fnma + template + XSIMD_INLINE batch fnma(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm_fnmadd_ps(x, y, z); + } + + template + XSIMD_INLINE batch fnma(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm_fnmadd_pd(x, y, z); + } + + // fnms + template + XSIMD_INLINE batch fnms(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm_fnmsub_ps(x, y, z); + } + + template + XSIMD_INLINE batch fnms(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm_fnmsub_pd(x, y, z); + } + + // fma + template + XSIMD_INLINE batch fma(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm_fmadd_ps(x, y, z); + } + + template + XSIMD_INLINE batch fma(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm_fmadd_pd(x, y, z); + } + + // fms + template + XSIMD_INLINE batch fms(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm_fmsub_ps(x, y, z); + } + + template + XSIMD_INLINE batch fms(batch const& x, batch const& y, batch const& z, requires_arch>) noexcept + { + return _mm_fmsub_pd(x, y, z); + } + + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_fma4.hpp b/include/onnxruntime/xsimd/arch/xsimd_fma4.hpp new file mode 100644 index 0000000000000..e51c7c52a82c6 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_fma4.hpp @@ -0,0 +1,79 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * 
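
The include-guard juggling in xsimd_fma3_avx2.hpp above re-includes xsimd_fma3_avx.hpp with the token avx rewritten to avx2, so the same kernel bodies are stamped out a second time for the wider register type. A minimal, self-contained illustration of the same trick on hypothetical headers (foo_base.hpp / foo_wide.hpp, guard names invented for this sketch; the real header also pins the register header's guard, which is omitted here):

    // foo_wide.hpp -- hypothetical header reusing foo_base.hpp for a wider arch.
    #ifndef FOO_WIDE_HPP
    #define FOO_WIDE_HPP

    // If foo_base.hpp was already included, drop its guard so it can be
    // included a second time, and remember to restore the guard afterwards.
    #ifdef FOO_BASE_HPP
    #undef FOO_BASE_HPP
    #define FOO_FORCE_BASE_HPP
    #endif

    #define base wide            // textual substitution for the second inclusion
    #include "foo_base.hpp"      // stamps out the base kernels with `wide` types
    #undef base
    #undef FOO_BASE_HPP

    #ifdef FOO_FORCE_BASE_HPP    // carefully restore the original guard state
    #define FOO_BASE_HPP
    #undef FOO_FORCE_BASE_HPP
    #endif

    #endif // FOO_WIDE_HPP
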
+ * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_FMA4_HPP +#define XSIMD_FMA4_HPP + +#include "../types/xsimd_fma4_register.hpp" + +namespace xsimd +{ + + namespace kernel + { + using namespace types; + + // fnma + template + XSIMD_INLINE batch fnma(simd_register const& x, simd_register const& y, simd_register const& z, requires_arch) noexcept + { + return _mm_nmacc_ps(x, y, z); + } + + template + XSIMD_INLINE batch fnma(simd_register const& x, simd_register const& y, simd_register const& z, requires_arch) noexcept + { + return _mm_nmacc_pd(x, y, z); + } + + // fnms + template + XSIMD_INLINE batch fnms(simd_register const& x, simd_register const& y, simd_register const& z, requires_arch) noexcept + { + return _mm_nmsub_ps(x, y, z); + } + + template + XSIMD_INLINE batch fnms(simd_register const& x, simd_register const& y, simd_register const& z, requires_arch) noexcept + { + return _mm_nmsub_pd(x, y, z); + } + + // fma + template + XSIMD_INLINE batch fma(simd_register const& x, simd_register const& y, simd_register const& z, requires_arch) noexcept + { + return _mm_macc_ps(x, y, z); + } + + template + XSIMD_INLINE batch fma(simd_register const& x, simd_register const& y, simd_register const& z, requires_arch) noexcept + { + return _mm_macc_pd(x, y, z); + } + + // fms + template + XSIMD_INLINE batch fms(simd_register const& x, simd_register const& y, simd_register const& z, requires_arch) noexcept + { + return _mm_msub_ps(x, y, z); + } + + template + XSIMD_INLINE batch fms(simd_register const& x, simd_register const& y, simd_register const& z, requires_arch) noexcept + { + return _mm_msub_pd(x, y, z); + } + } + +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_generic.hpp b/include/onnxruntime/xsimd/arch/xsimd_generic.hpp new file mode 100644 index 0000000000000..6403cfb0fc138 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_generic.hpp @@ -0,0 +1,23 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
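
xsimd_fma4.hpp above targets AMD's FMA4 encoding, whose intrinsics are spelled _mm_macc_ps / _mm_msub_ps / _mm_nmacc_ps / _mm_nmsub_ps rather than the FMA3 _mm_fmadd_ps family; the computed values are identical, only the instruction encodings differ. A rough sketch of choosing between them at compile time (fused_madd is an invented helper, and the predefined macros shown are the GCC/Clang ones):

    #if defined(__FMA4__)
    #include <x86intrin.h>   // the FMA4 intrinsics live in the x86 umbrella header
    #else
    #include <immintrin.h>
    #endif

    static inline __m128 fused_madd(__m128 x, __m128 y, __m128 z)
    {
    #if defined(__FMA4__)
        return _mm_macc_ps(x, y, z);             // FMA4 encoding (AMD)
    #elif defined(__FMA__)
        return _mm_fmadd_ps(x, y, z);            // FMA3 encoding
    #else
        return _mm_add_ps(_mm_mul_ps(x, y), z);  // unfused fallback
    #endif
    }
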
* + ****************************************************************************/ + +#ifndef XSIMD_GENERIC_HPP +#define XSIMD_GENERIC_HPP + +#include "./generic/xsimd_generic_arithmetic.hpp" +#include "./generic/xsimd_generic_complex.hpp" +#include "./generic/xsimd_generic_logical.hpp" +#include "./generic/xsimd_generic_math.hpp" +#include "./generic/xsimd_generic_memory.hpp" +#include "./generic/xsimd_generic_rounding.hpp" +#include "./generic/xsimd_generic_trigo.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_generic_fwd.hpp b/include/onnxruntime/xsimd/arch/xsimd_generic_fwd.hpp new file mode 100644 index 0000000000000..02708d60f70b9 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_generic_fwd.hpp @@ -0,0 +1,44 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_GENERIC_FWD_HPP +#define XSIMD_GENERIC_FWD_HPP + +#include "../types/xsimd_batch_constant.hpp" + +#include + +namespace xsimd +{ + namespace kernel + { + // forward declaration + template ::value, void>::type> + XSIMD_INLINE batch abs(batch const& self, requires_arch) noexcept; + template ::value, void>::type> + XSIMD_INLINE batch bitwise_lshift(batch const& self, batch const& other, requires_arch) noexcept; + template ::value, void>::type> + XSIMD_INLINE batch bitwise_rshift(batch const& self, batch const& other, requires_arch) noexcept; + template + XSIMD_INLINE batch_bool gt(batch const& self, batch const& other, requires_arch) noexcept; + template ::value, void>::type> + XSIMD_INLINE batch mul(batch const& self, batch const& other, requires_arch) noexcept; + template ::value, void>::type> + XSIMD_INLINE batch sadd(batch const& self, batch const& other, requires_arch) noexcept; + template ::value, void>::type> + XSIMD_INLINE batch ssub(batch const& self, batch const& other, requires_arch) noexcept; + template ::value, void>::type> + XSIMD_INLINE T hadd(batch const& self, requires_arch) noexcept; + + } +} + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_i8mm_neon64.hpp b/include/onnxruntime/xsimd/arch/xsimd_i8mm_neon64.hpp new file mode 100644 index 0000000000000..5533923020363 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_i8mm_neon64.hpp @@ -0,0 +1,17 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. 
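
xsimd_generic_fwd.hpp exists because the generic kernels, which are included last from xsimd_isa.hpp, call back into architecture-specific overloads that are only defined later; forward declarations plus tag dispatch on a trailing requires_arch parameter make that resolution work. A self-contained sketch of the dispatch idea with invented tag names (generic_arch, sse2_arch, avx_arch), not the xsimd definitions:

    // Architectures as tag types in an inheritance chain.
    struct generic_arch {};
    struct sse2_arch : generic_arch {};
    struct avx_arch : sse2_arch {};

    // Fallback kernel, written once against the base tag.
    inline int add_kernel(int a, int b, generic_arch) { return a + b; }

    // Specialised kernel: any tag derived from sse2_arch prefers this overload,
    // because conversion to a more derived base class ranks higher.
    inline int add_kernel(int a, int b, sse2_arch) { return a + b; /* intrinsics here */ }

    // Front end: forwards the preferred architecture tag.
    template <class Arch = avx_arch>
    int add(int a, int b) { return add_kernel(a, b, Arch{}); }

    // add(1, 2)               -> sse2_arch overload (reached via avx_arch)
    // add<generic_arch>(1, 2) -> generic fallback
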
* + ****************************************************************************/ + +#ifndef XSIMD_I8MM_NEON64_HPP +#define XSIMD_I8MM_NEON64_HPP + +#include "../types/xsimd_i8mm_neon64_register.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_isa.hpp b/include/onnxruntime/xsimd/arch/xsimd_isa.hpp new file mode 100644 index 0000000000000..5b714b299182a --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_isa.hpp @@ -0,0 +1,130 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_ISA_HPP +#define XSIMD_ISA_HPP + +#include "../config/xsimd_arch.hpp" + +#include "./xsimd_generic_fwd.hpp" + +#if XSIMD_WITH_EMULATED +#include "./xsimd_emulated.hpp" +#endif + +#if XSIMD_WITH_SSE2 +#include "./xsimd_sse2.hpp" +#endif + +#if XSIMD_WITH_SSE3 +#include "./xsimd_sse3.hpp" +#endif + +#if XSIMD_WITH_SSSE3 +#include "./xsimd_ssse3.hpp" +#endif + +#if XSIMD_WITH_SSE4_1 +#include "./xsimd_sse4_1.hpp" +#endif + +#if XSIMD_WITH_SSE4_2 +#include "./xsimd_sse4_2.hpp" +#endif + +#if XSIMD_WITH_FMA3_SSE +#include "./xsimd_fma3_sse.hpp" +#endif + +#if XSIMD_WITH_FMA4 +#include "./xsimd_fma4.hpp" +#endif + +#if XSIMD_WITH_AVX +#include "./xsimd_avx.hpp" +#endif + +#if XSIMD_WITH_FMA3_AVX +#include "./xsimd_fma3_avx.hpp" +#endif + +#if XSIMD_WITH_AVXVNNI +#include "./xsimd_avxvnni.hpp" +#endif + +#if XSIMD_WITH_AVX2 +#include "./xsimd_avx2.hpp" +#endif + +#if XSIMD_WITH_FMA3_AVX2 +#include "./xsimd_fma3_avx2.hpp" +#endif + +#if XSIMD_WITH_AVX512F +#include "./xsimd_avx512f.hpp" +#endif + +#if XSIMD_WITH_AVX512BW +#include "./xsimd_avx512bw.hpp" +#endif + +#if XSIMD_WITH_AVX512ER +#include "./xsimd_avx512er.hpp" +#endif + +#if XSIMD_WITH_AVX512PF +#include "./xsimd_avx512pf.hpp" +#endif + +#if XSIMD_WITH_AVX512IFMA +#include "./xsimd_avx512ifma.hpp" +#endif + +#if XSIMD_WITH_AVX512VBMI +#include "./xsimd_avx512vbmi.hpp" +#endif + +#if XSIMD_WITH_AVX512VNNI_AVX512BW +#include "./xsimd_avx512vnni_avx512bw.hpp" +#endif + +#if XSIMD_WITH_AVX512VNNI_AVX512VBMI +#include "./xsimd_avx512vnni_avx512vbmi.hpp" +#endif + +#if XSIMD_WITH_NEON +#include "./xsimd_neon.hpp" +#endif + +#if XSIMD_WITH_NEON64 +#include "./xsimd_neon64.hpp" +#endif + +#if XSIMD_WITH_I8MM_NEON64 +#include "./xsimd_i8mm_neon64.hpp" +#endif + +#if XSIMD_WITH_SVE +#include "./xsimd_sve.hpp" +#endif + +#if XSIMD_WITH_RVV +#include "./xsimd_rvv.hpp" +#endif + +#if XSIMD_WITH_WASM +#include "./xsimd_wasm.hpp" +#endif + +// Must come last to have access to all conversion specializations. +#include "./xsimd_generic.hpp" + +#endif diff --git a/include/onnxruntime/xsimd/arch/xsimd_neon.hpp b/include/onnxruntime/xsimd/arch/xsimd_neon.hpp new file mode 100644 index 0000000000000..2d0a244528667 --- /dev/null +++ b/include/onnxruntime/xsimd/arch/xsimd_neon.hpp @@ -0,0 +1,2813 @@ +/*************************************************************************** + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * + * Martin Renou * + * Copyright (c) QuantStack * + * Copyright (c) Serge Guelton * + * * + * Distributed under the terms of the BSD 3-Clause License. 
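
xsimd_isa.hpp pulls in exactly the backends enabled for the current target and includes the generic implementation last so it can see every conversion specialization. From the consumer side, for instance the gemmology kernels added by this patch, none of that is visible: code is written once against xsimd::batch and picks up whichever architecture was compiled in. A rough usage sketch; axpy is an invented example and the include path assumes the vendored include/onnxruntime layout is on the include path:

    #include <cstddef>
    #include <vector>

    #include "xsimd/xsimd.hpp"

    // y := a * x + y, written once against xsimd::batch with the default arch.
    void axpy(float a, std::vector<float> const& x, std::vector<float>& y)
    {
        using batch = xsimd::batch<float>;            // best architecture selected at compile time
        constexpr std::size_t lanes = batch::size;

        std::size_t i = 0;
        for (; i + lanes <= x.size(); i += lanes)
        {
            batch bx = batch::load_unaligned(&x[i]);
            batch by = batch::load_unaligned(&y[i]);
            (batch(a) * bx + by).store_unaligned(&y[i]);
        }
        for (; i < x.size(); ++i)                     // scalar tail
            y[i] = a * x[i] + y[i];
    }
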
* + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + +#ifndef XSIMD_NEON_HPP +#define XSIMD_NEON_HPP + +#include +#include +#include +#include + +#include "../types/xsimd_neon_register.hpp" +#include "../types/xsimd_utils.hpp" + +// Wrap intrinsics so we can pass them as function pointers +// - OP: intrinsics name prefix, e.g., vorrq +// - RT: type traits to deduce intrinsics return types +#define WRAP_BINARY_UINT_EXCLUDING_64(OP, RT) \ + namespace wrap \ + { \ + XSIMD_INLINE RT OP##_u8(uint8x16_t a, uint8x16_t b) noexcept \ + { \ + return ::OP##_u8(a, b); \ + } \ + XSIMD_INLINE RT OP##_u16(uint16x8_t a, uint16x8_t b) noexcept \ + { \ + return ::OP##_u16(a, b); \ + } \ + XSIMD_INLINE RT OP##_u32(uint32x4_t a, uint32x4_t b) noexcept \ + { \ + return ::OP##_u32(a, b); \ + } \ + } + +#define WRAP_BINARY_INT_EXCLUDING_64(OP, RT) \ + WRAP_BINARY_UINT_EXCLUDING_64(OP, RT) \ + namespace wrap \ + { \ + XSIMD_INLINE RT OP##_s8(int8x16_t a, int8x16_t b) noexcept \ + { \ + return ::OP##_s8(a, b); \ + } \ + XSIMD_INLINE RT OP##_s16(int16x8_t a, int16x8_t b) noexcept \ + { \ + return ::OP##_s16(a, b); \ + } \ + XSIMD_INLINE RT OP##_s32(int32x4_t a, int32x4_t b) noexcept \ + { \ + return ::OP##_s32(a, b); \ + } \ + } + +#define WRAP_BINARY_INT(OP, RT) \ + WRAP_BINARY_INT_EXCLUDING_64(OP, RT) \ + namespace wrap \ + { \ + XSIMD_INLINE RT OP##_u64(uint64x2_t a, uint64x2_t b) noexcept \ + { \ + return ::OP##_u64(a, b); \ + } \ + XSIMD_INLINE RT OP##_s64(int64x2_t a, int64x2_t b) noexcept \ + { \ + return ::OP##_s64(a, b); \ + } \ + } + +#define WRAP_BINARY_FLOAT(OP, RT) \ + namespace wrap \ + { \ + XSIMD_INLINE RT OP##_f32(float32x4_t a, float32x4_t b) noexcept \ + { \ + return ::OP##_f32(a, b); \ + } \ + } + +#define WRAP_UNARY_INT_EXCLUDING_64(OP) \ + namespace wrap \ + { \ + XSIMD_INLINE uint8x16_t OP##_u8(uint8x16_t a) noexcept \ + { \ + return ::OP##_u8(a); \ + } \ + XSIMD_INLINE int8x16_t OP##_s8(int8x16_t a) noexcept \ + { \ + return ::OP##_s8(a); \ + } \ + XSIMD_INLINE uint16x8_t OP##_u16(uint16x8_t a) noexcept \ + { \ + return ::OP##_u16(a); \ + } \ + XSIMD_INLINE int16x8_t OP##_s16(int16x8_t a) noexcept \ + { \ + return ::OP##_s16(a); \ + } \ + XSIMD_INLINE uint32x4_t OP##_u32(uint32x4_t a) noexcept \ + { \ + return ::OP##_u32(a); \ + } \ + XSIMD_INLINE int32x4_t OP##_s32(int32x4_t a) noexcept \ + { \ + return ::OP##_s32(a); \ + } \ + } + +#define WRAP_UNARY_INT(OP) \ + WRAP_UNARY_INT_EXCLUDING_64(OP) \ + namespace wrap \ + { \ + XSIMD_INLINE uint64x2_t OP##_u64(uint64x2_t a) noexcept \ + { \ + return ::OP##_u64(a); \ + } \ + XSIMD_INLINE int64x2_t OP##_s64(int64x2_t a) noexcept \ + { \ + return ::OP##_s64(a); \ + } \ + } + +#define WRAP_UNARY_FLOAT(OP) \ + namespace wrap \ + { \ + XSIMD_INLINE float32x4_t OP##_f32(float32x4_t a) noexcept \ + { \ + return ::OP##_f32(a); \ + } \ + } + +// Dummy identity caster to ease coding +XSIMD_INLINE uint8x16_t vreinterpretq_u8_u8(uint8x16_t arg) noexcept { return arg; } +XSIMD_INLINE int8x16_t vreinterpretq_s8_s8(int8x16_t arg) noexcept { return arg; } +XSIMD_INLINE uint16x8_t vreinterpretq_u16_u16(uint16x8_t arg) noexcept { return arg; } +XSIMD_INLINE int16x8_t vreinterpretq_s16_s16(int16x8_t arg) noexcept { return arg; } +XSIMD_INLINE uint32x4_t vreinterpretq_u32_u32(uint32x4_t arg) noexcept { return arg; } +XSIMD_INLINE int32x4_t vreinterpretq_s32_s32(int32x4_t arg) noexcept { return arg; } +XSIMD_INLINE uint64x2_t 
vreinterpretq_u64_u64(uint64x2_t arg) noexcept { return arg; }
+XSIMD_INLINE int64x2_t vreinterpretq_s64_s64(int64x2_t arg) noexcept { return arg; }
+XSIMD_INLINE float32x4_t vreinterpretq_f32_f32(float32x4_t arg) noexcept { return arg; }
+
+namespace xsimd
+{
+    template
+    struct batch_bool_constant;
+
+    namespace kernel
+    {
+        using namespace types;
+
+        namespace detail
+        {
+            template
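
The wrap:: layer and the identity vreinterpretq_*_* casters above exist so that NEON intrinsics, which the header's own comment notes cannot always be passed as function pointers, behave like ordinary named functions that the dispatch code can hand around as callables. A standalone illustration; demo_wrap and apply3 are invented names, and this only builds on an AArch32/AArch64 target with NEON:

    #include <arm_neon.h>

    namespace demo_wrap
    {
        // Thin named wrapper around the intrinsic so it has a real address
        // and ordinary overload/conversion semantics.
        inline uint8x16_t vaddq_u8(uint8x16_t a, uint8x16_t b) noexcept
        {
            return ::vaddq_u8(a, b);
        }
    }

    // Higher-order helper that takes the operation as an ordinary callable.
    template <class Op>
    uint8x16_t apply3(Op op, uint8x16_t a, uint8x16_t b, uint8x16_t c)
    {
        return op(op(a, b), c);
    }

    // Usage: uint8x16_t s = apply3(demo_wrap::vaddq_u8, x, y, z);
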