From 7dae88be8d8910495a6414186d9e0b5b9c449a20 Mon Sep 17 00:00:00 2001 From: OussamaDanba Date: Wed, 1 May 2019 16:03:28 +0200 Subject: [PATCH 1/5] Implement avx2 version of poly_S3_mul for hps2048509 Explanation: poly_Rq_mul is reused to do the multiplication. The MODQ in poly_Rq_mul does nothing due to the coefficients never going above 2036 (509*4). After the multiplication the last coefficient is fetched and multiplied by two before being added to all the other coefficients. Finally all coefficients are reduced modulo 3. Results: On an Intel i5-8250u using gcc 8.3.0 the reference poly_S3_mul takes about 287000 cycles on average whereas the avx2 version takes about 3700 cycles on average. This is about a 77 times speedup. --- avx2-hps2048509/Makefile | 2 +- avx2-hps2048509/asmgen/poly_s3_mul.py | 81 +++++++++++++++++++++++++++ avx2-hps2048509/poly.c | 17 +----- 3 files changed, 83 insertions(+), 17 deletions(-) create mode 100644 avx2-hps2048509/asmgen/poly_s3_mul.py diff --git a/avx2-hps2048509/Makefile b/avx2-hps2048509/Makefile index 57c7b8d..f21a695 100644 --- a/avx2-hps2048509/Makefile +++ b/avx2-hps2048509/Makefile @@ -4,7 +4,7 @@ CFLAGS = -Wall -Wextra -Wpedantic -O3 -fomit-frame-pointer -march=native -no-pie SOURCES = crypto_sort.c poly.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c HEADERS = crypto_sort.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h -OBJS = poly_rq_mul.s +OBJS = poly_rq_mul.s poly_s3_mul.s all: test/test_polymul \ test/test_ntru \ diff --git a/avx2-hps2048509/asmgen/poly_s3_mul.py b/avx2-hps2048509/asmgen/poly_s3_mul.py new file mode 100644 index 0000000..145b5f9 --- /dev/null +++ b/avx2-hps2048509/asmgen/poly_s3_mul.py @@ -0,0 +1,81 @@ +p = print + +def mod3(a, r=13, t=14, c=15): + # r = (a >> 8) + (a & 0xff); // r mod 255 == a mod 255 + p("vpsrlw $8, %ymm{}, %ymm{}".format(a, r)) + p("vpand mask_ff, %ymm{}, %ymm{}".format(a, a)) + p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r)) + + # r = (r >> 4) + (r & 0xf); // r' mod 15 == r mod 15 + p("vpand mask_f, %ymm{}, %ymm{}".format(r, a)) + p("vpsrlw $4, %ymm{}, %ymm{}".format(r, r)) + p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r)) + + # r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3 + # r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3 + for _ in range(2): + p("vpand mask_3, %ymm{}, %ymm{}".format(r, a)) + p("vpsrlw $2, %ymm{}, %ymm{}".format(r, r)) + p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r)) + + # t = r - 3; + p("vpsubw mask_3, %ymm{}, %ymm{}".format(r, t)) + # c = t >> 15; t is signed, so shift arithmetic + p("vpsraw $15, %ymm{}, %ymm{}".format(t, c)) + + tmp = a + # return (c&r) ^ (~c&t); + p("vpandn %ymm{}, %ymm{}, %ymm{}".format(t, c, tmp)) + p("vpand %ymm{}, %ymm{}, %ymm{}".format(c, r, t)) + p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t, tmp, r)) + +from math import ceil + +# Reuses poly_Rq_mul to do the multiplication. This works out since none of the +# coefficients ever exceed 2048 (they can be 509*4=2036 at most). +# A presumably faster alternative implementation would be to do Karatsuba recursively five times +# such that you have 243 multiplications of 16 coefficient polynomials. These multiplications +# can be done using a bitsliced implementation of the schoolbook method since the coefficients are two bits. +# TODO: If poly_S3_mul shows up (significantly) in profiling consider the alternative implementation. 
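+# A note on the reduction used above: mod3() is exact for every 16-bit input. Each fold
+# keeps the value unchanged modulo 3 (256 == 1 mod 255, 16 == 1 mod 15, 4 == 1 mod 3, and
+# 3 divides both 255 and 15), and after the final two folds the intermediate is at most 5,
+# so the single conditional subtraction of 3 at the end suffices.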
+if __name__ == '__main__': + p(".data") + p(".align 32") + + p("mask_ff:") + for i in range(16): + p(".word 0xff") + p("mask_f:") + for i in range(16): + p(".word 0xf") + p("mask_3:") + for i in range(16): + p(".word 0x03") + + p(".text") + p(".global poly_S3_mul") + p(".att_syntax prefix") + + p("poly_S3_mul:") + p("call poly_Rq_mul") + # result pointer *r is still in %rdi + + N_min_1 = 0 + t = 1 + # NTRU_N is in 509th element; 13th word of 32nd register + p("vmovdqa {}(%rdi), %ymm{}".format(31*32, N_min_1)) + p("vpermq ${}, %ymm{}, %ymm{}".format(int('00000011', 2), N_min_1, N_min_1)) + # move into high 16 in doubleword (to clear high 16) and multiply by two + p("vpslld $17, %ymm{}, %ymm{}".format(N_min_1, N_min_1)) + # clone into bottom 16 + p("vpsrld $16, %ymm{}, %ymm{}".format(N_min_1, t)) + p("vpor %ymm{}, %ymm{}, %ymm{}".format(N_min_1, t, N_min_1)) + # and now it's everywhere in N_min_1 + p("vbroadcastss %xmm{}, %ymm{}".format(N_min_1, N_min_1)) + + retval = 2 + for i in range(ceil(509 / 16)): + p("vpaddw {}(%rdi), %ymm{}, %ymm{}".format(i * 32, N_min_1, t)) + mod3(t, retval) + p("vmovdqa %ymm{}, {}(%rdi)".format(retval, i*32)) + + p("ret") diff --git a/avx2-hps2048509/poly.c b/avx2-hps2048509/poly.c index 215ab13..9399c7c 100644 --- a/avx2-hps2048509/poly.c +++ b/avx2-hps2048509/poly.c @@ -3,6 +3,7 @@ #include "verify.h" extern void poly_Rq_mul(poly *r, const poly *a, const poly *b); +extern void poly_S3_mul(poly *r, const poly *a, const poly *b); uint16_t mod3(uint16_t a) { @@ -44,22 +45,6 @@ void poly_Sq_mul(poly *r, const poly *a, const poly *b) r->coeffs[i] = MODQ(r->coeffs[i] - r->coeffs[NTRU_N-1]); } -void poly_S3_mul(poly *r, const poly *a, const poly *b) -{ - int k,i; - - for(k=0; kcoeffs[k] = 0; - for(i=1; icoeffs[k] += a->coeffs[k+i] * b->coeffs[NTRU_N-i]; - for(i=0; icoeffs[k] += a->coeffs[k-i] * b->coeffs[i]; - } - for(k=0; kcoeffs[k] = mod3(r->coeffs[k] + 2*r->coeffs[NTRU_N-1]); -} - void poly_Rq_mul_x_minus_1(poly *r, const poly *a) { int i; From aeb0d284cfbb7f5ff30347056f29c6889ca2718c Mon Sep 17 00:00:00 2001 From: OussamaDanba Date: Thu, 2 May 2019 12:49:58 +0200 Subject: [PATCH 2/5] Implement avx2 version of poly_Rq_mul_x_minus_1 for hps2048509 Explanation: Straightforward conversion of the C function but using avx2 to operate on 16 coefficients at a time rather than one. Some special handling is done for the first coefficient (since the last coefficient is needed for that one) and the following 15 coefficients (can't operate on 16 coefficients anymore). Results: On an Intel i5-8250u using gcc 8.3.0 the reference poly_Rq_mul_x_minus_1 takes about 530 cycles on average whereas the avx2 version takes about 40 cycles on average. This is about a 13 times speedup. 
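For reference, a minimal C-intrinsics sketch of the per-block step that the generated
assembly performs for blocks 1..31 (block 0 additionally wraps around to a[NTRU_N-1]);
the function name is illustrative and it assumes the padded 512-entry coefficient
arrays used by this avx2 implementation:

    #include <immintrin.h>
    #include <stdint.h>

    /* r[j] = (a[j-1] - a[j]) mod 2048 for the 16 coefficients j = 16*i .. 16*i+15 */
    static void rq_mul_x_minus_1_block(uint16_t *r, const uint16_t *a, int i)
    {
        __m256i prev = _mm256_loadu_si256((const __m256i *)(a + 16*i - 1)); /* a[16i-1 .. 16i+14] */
        __m256i cur  = _mm256_loadu_si256((const __m256i *)(a + 16*i));     /* a[16i   .. 16i+15] */
        __m256i diff = _mm256_sub_epi16(prev, cur);                         /* wraps mod 2^16 */
        __m256i res  = _mm256_and_si256(diff, _mm256_set1_epi16(2047));     /* reduce mod q = 2048 */
        _mm256_storeu_si256((__m256i *)(r + 16*i), res);
    }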
--- avx2-hps2048509/Makefile | 2 +- .../asmgen/poly_rq_mul_x_minus_1.py | 65 +++++++++++++++++++ avx2-hps2048509/poly.c | 12 +--- 3 files changed, 67 insertions(+), 12 deletions(-) create mode 100644 avx2-hps2048509/asmgen/poly_rq_mul_x_minus_1.py diff --git a/avx2-hps2048509/Makefile b/avx2-hps2048509/Makefile index f21a695..d668948 100644 --- a/avx2-hps2048509/Makefile +++ b/avx2-hps2048509/Makefile @@ -4,7 +4,7 @@ CFLAGS = -Wall -Wextra -Wpedantic -O3 -fomit-frame-pointer -march=native -no-pie SOURCES = crypto_sort.c poly.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c HEADERS = crypto_sort.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h -OBJS = poly_rq_mul.s poly_s3_mul.s +OBJS = poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s all: test/test_polymul \ test/test_ntru \ diff --git a/avx2-hps2048509/asmgen/poly_rq_mul_x_minus_1.py b/avx2-hps2048509/asmgen/poly_rq_mul_x_minus_1.py new file mode 100644 index 0000000..33d476d --- /dev/null +++ b/avx2-hps2048509/asmgen/poly_rq_mul_x_minus_1.py @@ -0,0 +1,65 @@ +from math import ceil + +p = print + +if __name__ == '__main__': + p(".data") + p(".align 32") + + p("mask_mod2048:") + for i in range(16): + p(".word 2047") + + p("mask_mod2048_omit_lowest:") + p(".word 0") + for i in range(15): + p(".word 2047") + + p("mask_mod2048_only_lowest:") + p(".word 2047") + for i in range(15): + p(".word 0") + + p("shuf_5_to_0_zerorest:") + for i in range(2): + p(".byte {}".format((i + 2*5) % 16)) + for i in range(30): + p(".byte 255") + + p(".text") + p(".global poly_Rq_mul_x_minus_1") + p(".att_syntax prefix") + + p("poly_Rq_mul_x_minus_1:") + + a_imin1 = 0 + t0 = 1 + t1 = 4 + for i in range(ceil(509 / 16)-1, 0, -1): + p("vmovdqu {}(%rsi), %ymm{}".format((i*16 - 1) * 2, a_imin1)) + p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(i * 32, a_imin1, t0)) + p("vpand mask_mod2048, %ymm{}, %ymm{}".format(t0, t0)) + p("vmovdqa %ymm{}, {}(%rdi)".format(t0, i*32)) + if i == ceil(509 / 16)-1: + # a_imin1 now contains 495 to 510 inclusive; + # we need 509 for [0], which is at position 14 + p("vextracti128 $1, %ymm{}, %xmm{}".format(a_imin1, t1)) + p("vpshufb shuf_5_to_0_zerorest, %ymm{}, %ymm{}".format(t1, t1)) + p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(0, t1, t1)) + p("vpand mask_mod2048_only_lowest, %ymm{}, %ymm{}".format(t1, t1)) + + # and now we still need to fix [1] to [15], which we cannot vmovdqu + t2 = 0 + t3 = 2 + t4 = 3 + p("vmovdqa {}(%rsi), %ymm{}".format(0, t4)) + p("vpsrlq $48, %ymm{}, %ymm{}".format(t4, t2)) + p("vpermq ${}, %ymm{}, %ymm{}".format(int('10010011', 2), t2, t2)) + p("vpsllq $16, %ymm{}, %ymm{}".format(t4, t3)) + p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t2, t3, t3)) + p("vpsubw %ymm{}, %ymm{}, %ymm{}".format(t4, t3, t4)) + p("vpand mask_mod2048_omit_lowest, %ymm{}, %ymm{}".format(t4, t4)) + p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t4, t1, t4)) + p("vmovdqa %ymm{}, {}(%rdi)".format(t4, 0)) + + p("ret") diff --git a/avx2-hps2048509/poly.c b/avx2-hps2048509/poly.c index 9399c7c..6966de8 100644 --- a/avx2-hps2048509/poly.c +++ b/avx2-hps2048509/poly.c @@ -4,6 +4,7 @@ extern void poly_Rq_mul(poly *r, const poly *a, const poly *b); extern void poly_S3_mul(poly *r, const poly *a, const poly *b); +extern void poly_Rq_mul_x_minus_1(poly *r, const poly *a); uint16_t mod3(uint16_t a) { @@ -45,17 +46,6 @@ void poly_Sq_mul(poly *r, const poly *a, const poly *b) r->coeffs[i] = MODQ(r->coeffs[i] - r->coeffs[NTRU_N-1]); } -void poly_Rq_mul_x_minus_1(poly *r, const poly *a) -{ - int i; - uint16_t last_coeff = 
a->coeffs[NTRU_N-1]; - - for (i = NTRU_N - 1; i > 0; i--) { - r->coeffs[i] = MODQ(a->coeffs[i-1] + (NTRU_Q - a->coeffs[i])); - } - r->coeffs[0] = MODQ(last_coeff + (NTRU_Q - a->coeffs[0])); -} - #ifdef NTRU_HPS void poly_lift(poly *r, const poly *a) { From 441583b847b024b06a69206203c360a3901109a7 Mon Sep 17 00:00:00 2001 From: OussamaDanba Date: Thu, 9 May 2019 12:08:32 +0200 Subject: [PATCH 3/5] Import avx2 version of djbsort to be used for crypto_sort On an Intel i5-8250u using gcc 8.3.0 the reference sample_fixed_type (of which the majority of cycles is spent in crypto_sort) takes about 28000 on average whereas using the avx2 version of djbsort for crypto_sort takes about 3000 cycles on average. This is about a 9 times speedup. As a result ntru_encaps is about 1.5 times faster. --- avx2-hps2048509/Makefile | 4 +- avx2-hps2048509/crypto_sort.c | 11 +- avx2-hps2048509/djbsort/int32_minmax_x86.c | 13 + avx2-hps2048509/djbsort/int32_sort.h | 25 + avx2-hps2048509/djbsort/sort.c | 1184 ++++++++++++++++++++ 5 files changed, 1234 insertions(+), 3 deletions(-) mode change 120000 => 100644 avx2-hps2048509/crypto_sort.c create mode 100644 avx2-hps2048509/djbsort/int32_minmax_x86.c create mode 100644 avx2-hps2048509/djbsort/int32_sort.h create mode 100644 avx2-hps2048509/djbsort/sort.c diff --git a/avx2-hps2048509/Makefile b/avx2-hps2048509/Makefile index d668948..f95b6c5 100644 --- a/avx2-hps2048509/Makefile +++ b/avx2-hps2048509/Makefile @@ -1,8 +1,8 @@ CC = /usr/bin/cc CFLAGS = -Wall -Wextra -Wpedantic -O3 -fomit-frame-pointer -march=native -no-pie -SOURCES = crypto_sort.c poly.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c -HEADERS = crypto_sort.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h +SOURCES = crypto_sort.c djbsort/sort.c poly.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c +HEADERS = crypto_sort.h djbsort/int32_sort.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h OBJS = poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s diff --git a/avx2-hps2048509/crypto_sort.c b/avx2-hps2048509/crypto_sort.c deleted file mode 120000 index 9c48f3a..0000000 --- a/avx2-hps2048509/crypto_sort.c +++ /dev/null @@ -1 +0,0 @@ -../ref-common/crypto_sort.c \ No newline at end of file diff --git a/avx2-hps2048509/crypto_sort.c b/avx2-hps2048509/crypto_sort.c new file mode 100644 index 0000000..722058f --- /dev/null +++ b/avx2-hps2048509/crypto_sort.c @@ -0,0 +1,10 @@ +#include +#include "crypto_sort.h" +#include "djbsort/int32_sort.h" + +// The avx2 version of djbsort (version 20180729) is used for sorting. 
+// The source can be found at https://sorting.cr.yp.to/ +void crypto_sort(void *array,long long n) +{ + int32_sort(array, n); +} diff --git a/avx2-hps2048509/djbsort/int32_minmax_x86.c b/avx2-hps2048509/djbsort/int32_minmax_x86.c new file mode 100644 index 0000000..c5f3006 --- /dev/null +++ b/avx2-hps2048509/djbsort/int32_minmax_x86.c @@ -0,0 +1,13 @@ +#define int32_MINMAX(a,b) \ +do { \ + int32 temp1; \ + asm( \ + "cmpl %1,%0\n\t" \ + "mov %0,%2\n\t" \ + "cmovg %1,%0\n\t" \ + "cmovg %2,%1\n\t" \ + : "+r"(a), "+r"(b), "=r"(temp1) \ + : \ + : "cc" \ + ); \ +} while(0) diff --git a/avx2-hps2048509/djbsort/int32_sort.h b/avx2-hps2048509/djbsort/int32_sort.h new file mode 100644 index 0000000..d0f5e58 --- /dev/null +++ b/avx2-hps2048509/djbsort/int32_sort.h @@ -0,0 +1,25 @@ +#ifndef int32_sort_H +#define int32_sort_H + +#include + +#define int32_sort djbsort_int32 +#define int32_sort_implementation djbsort_int32_implementation +#define int32_sort_version djbsort_int32_version +#define int32_sort_compiler djbsort_int32_compiler + +#ifdef __cplusplus +extern "C" { +#endif + +extern void int32_sort(int32_t *,long long) __attribute__((visibility("default"))); + +extern const char int32_sort_implementation[] __attribute__((visibility("default"))); +extern const char int32_sort_version[] __attribute__((visibility("default"))); +extern const char int32_sort_compiler[] __attribute__((visibility("default"))); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/avx2-hps2048509/djbsort/sort.c b/avx2-hps2048509/djbsort/sort.c new file mode 100644 index 0000000..ca81bf6 --- /dev/null +++ b/avx2-hps2048509/djbsort/sort.c @@ -0,0 +1,1184 @@ +#include "int32_sort.h" +#define int32 int32_t + +#include +#include "int32_minmax_x86.c" + +typedef __m256i int32x8; +#define int32x8_load(z) _mm256_loadu_si256((__m256i *) (z)) +#define int32x8_store(z,i) _mm256_storeu_si256((__m256i *) (z),(i)) +#define int32x8_min _mm256_min_epi32 +#define int32x8_max _mm256_max_epi32 + +#define int32x8_MINMAX(a,b) \ +do { \ + int32x8 c = int32x8_min(a,b); \ + b = int32x8_max(a,b); \ + a = c; \ +} while(0) + +__attribute__((noinline)) +static void minmax_vector(int32 *x,int32 *y,long long n) +{ + if (n < 8) { + while (n > 0) { + int32_MINMAX(*x,*y); + ++x; + ++y; + --n; + } + return; + } + if (n & 7) { + int32x8 x0 = int32x8_load(x + n - 8); + int32x8 y0 = int32x8_load(y + n - 8); + int32x8_MINMAX(x0,y0); + int32x8_store(x + n - 8,x0); + int32x8_store(y + n - 8,y0); + n &= ~7; + } + do { + int32x8 x0 = int32x8_load(x); + int32x8 y0 = int32x8_load(y); + int32x8_MINMAX(x0,y0); + int32x8_store(x,x0); + int32x8_store(y,y0); + x += 8; + y += 8; + n -= 8; + } while(n); +} + +/* stages 8,4,2,1 of size-16 bitonic merging */ +__attribute__((noinline)) +static void merge16_finish(int32 *x,int32x8 x0,int32x8 x1,int flagdown) +{ + int32x8 b0,b1,c0,c1,mask; + + int32x8_MINMAX(x0,x1); + + b0 = _mm256_permute2x128_si256(x0,x1,0x20); /* A0123B0123 */ + b1 = _mm256_permute2x128_si256(x0,x1,0x31); /* A4567B4567 */ + + int32x8_MINMAX(b0,b1); + + c0 = _mm256_unpacklo_epi64(b0,b1); /* A0145B0145 */ + c1 = _mm256_unpackhi_epi64(b0,b1); /* A2367B2367 */ + + int32x8_MINMAX(c0,c1); + + b0 = _mm256_unpacklo_epi32(c0,c1); /* A0213B0213 */ + b1 = _mm256_unpackhi_epi32(c0,c1); /* A4657B4657 */ + + c0 = _mm256_unpacklo_epi64(b0,b1); /* A0246B0246 */ + c1 = _mm256_unpackhi_epi64(b0,b1); /* A1357B1357 */ + + int32x8_MINMAX(c0,c1); + + b0 = _mm256_unpacklo_epi32(c0,c1); /* A0123B0123 */ + b1 = _mm256_unpackhi_epi32(c0,c1); /* A4567B4567 */ + + x0 = 
_mm256_permute2x128_si256(b0,b1,0x20); /* A01234567 */ + x1 = _mm256_permute2x128_si256(b0,b1,0x31); /* A01234567 */ + + if (flagdown) { + mask = _mm256_set1_epi32(-1); + x0 ^= mask; + x1 ^= mask; + } + + int32x8_store(&x[0],x0); + int32x8_store(&x[8],x1); +} + +/* stages 64,32 of bitonic merging; n is multiple of 128 */ +__attribute__((noinline)) +static void int32_twostages_32(int32 *x,long long n) +{ + long long i; + + while (n > 0) { + for (i = 0;i < 32;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+32]); + int32x8 x2 = int32x8_load(&x[i+64]); + int32x8 x3 = int32x8_load(&x[i+96]); + + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + + int32x8_store(&x[i],x0); + int32x8_store(&x[i+32],x1); + int32x8_store(&x[i+64],x2); + int32x8_store(&x[i+96],x3); + } + x += 128; + n -= 128; + } +} + +/* stages 4q,2q,q of bitonic merging */ +__attribute__((noinline)) +static long long int32_threestages(int32 *x,long long n,long long q) +{ + long long k,i; + + for (k = 0;k + 8*q <= n;k += 8*q) + for (i = k;i < k + q;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8 x2 = int32x8_load(&x[i+2*q]); + int32x8 x3 = int32x8_load(&x[i+3*q]); + int32x8 x4 = int32x8_load(&x[i+4*q]); + int32x8 x5 = int32x8_load(&x[i+5*q]); + int32x8 x6 = int32x8_load(&x[i+6*q]); + int32x8 x7 = int32x8_load(&x[i+7*q]); + + int32x8_MINMAX(x0,x4); + int32x8_MINMAX(x1,x5); + int32x8_MINMAX(x2,x6); + int32x8_MINMAX(x3,x7); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x4,x6); + int32x8_MINMAX(x5,x7); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_MINMAX(x4,x5); + int32x8_MINMAX(x6,x7); + + int32x8_store(&x[i],x0); + int32x8_store(&x[i+q],x1); + int32x8_store(&x[i+2*q],x2); + int32x8_store(&x[i+3*q],x3); + int32x8_store(&x[i+4*q],x4); + int32x8_store(&x[i+5*q],x5); + int32x8_store(&x[i+6*q],x6); + int32x8_store(&x[i+7*q],x7); + } + + return k; +} + +/* n is a power of 2; n >= 8; if n == 8 then flagdown */ +__attribute__((noinline)) +static void int32_sort_2power(int32 *x,long long n,int flagdown) +{ long long p,q,i,j,k; + int32x8 mask; + + if (n == 8) { + int32 x0 = x[0]; + int32 x1 = x[1]; + int32 x2 = x[2]; + int32 x3 = x[3]; + int32 x4 = x[4]; + int32 x5 = x[5]; + int32 x6 = x[6]; + int32 x7 = x[7]; + + /* odd-even sort instead of bitonic sort */ + + int32_MINMAX(x1,x0); + int32_MINMAX(x3,x2); + int32_MINMAX(x2,x0); + int32_MINMAX(x3,x1); + int32_MINMAX(x2,x1); + + int32_MINMAX(x5,x4); + int32_MINMAX(x7,x6); + int32_MINMAX(x6,x4); + int32_MINMAX(x7,x5); + int32_MINMAX(x6,x5); + + int32_MINMAX(x4,x0); + int32_MINMAX(x6,x2); + int32_MINMAX(x4,x2); + + int32_MINMAX(x5,x1); + int32_MINMAX(x7,x3); + int32_MINMAX(x5,x3); + + int32_MINMAX(x2,x1); + int32_MINMAX(x4,x3); + int32_MINMAX(x6,x5); + + x[0] = x0; + x[1] = x1; + x[2] = x2; + x[3] = x3; + x[4] = x4; + x[5] = x5; + x[6] = x6; + x[7] = x7; + return; + } + + if (n == 16) { + int32x8 x0,x1,b0,b1,c0,c1; + + x0 = int32x8_load(&x[0]); + x1 = int32x8_load(&x[8]); + + mask = _mm256_set_epi32(0,0,-1,-1,0,0,-1,-1); + + x0 ^= mask; /* A01234567 */ + x1 ^= mask; /* B01234567 */ + + b0 = _mm256_unpacklo_epi32(x0,x1); /* AB0AB1AB4AB5 */ + b1 = _mm256_unpackhi_epi32(x0,x1); /* AB2AB3AB6AB7 */ + + c0 = _mm256_unpacklo_epi64(b0,b1); /* AB0AB2AB4AB6 */ + c1 = _mm256_unpackhi_epi64(b0,b1); /* AB1AB3AB5AB7 */ + + int32x8_MINMAX(c0,c1); + + mask = _mm256_set_epi32(0,0,-1,-1,-1,-1,0,0); + c0 ^= mask; + c1 ^= mask; + + b0 = 
_mm256_unpacklo_epi32(c0,c1); /* A01B01A45B45 */ + b1 = _mm256_unpackhi_epi32(c0,c1); /* A23B23A67B67 */ + + int32x8_MINMAX(b0,b1); + + x0 = _mm256_unpacklo_epi64(b0,b1); /* A01234567 */ + x1 = _mm256_unpackhi_epi64(b0,b1); /* B01234567 */ + + b0 = _mm256_unpacklo_epi32(x0,x1); /* AB0AB1AB4AB5 */ + b1 = _mm256_unpackhi_epi32(x0,x1); /* AB2AB3AB6AB7 */ + + c0 = _mm256_unpacklo_epi64(b0,b1); /* AB0AB2AB4AB6 */ + c1 = _mm256_unpackhi_epi64(b0,b1); /* AB1AB3AB5AB7 */ + + int32x8_MINMAX(c0,c1); + + b0 = _mm256_unpacklo_epi32(c0,c1); /* A01B01A45B45 */ + b1 = _mm256_unpackhi_epi32(c0,c1); /* A23B23A67B67 */ + + b0 ^= mask; + b1 ^= mask; + + c0 = _mm256_permute2x128_si256(b0,b1,0x20); /* A01B01A23B23 */ + c1 = _mm256_permute2x128_si256(b0,b1,0x31); /* A45B45A67B67 */ + + int32x8_MINMAX(c0,c1); + + b0 = _mm256_permute2x128_si256(c0,c1,0x20); /* A01B01A45B45 */ + b1 = _mm256_permute2x128_si256(c0,c1,0x31); /* A23B23A67B67 */ + + int32x8_MINMAX(b0,b1); + + x0 = _mm256_unpacklo_epi64(b0,b1); /* A01234567 */ + x1 = _mm256_unpackhi_epi64(b0,b1); /* B01234567 */ + + b0 = _mm256_unpacklo_epi32(x0,x1); /* AB0AB1AB4AB5 */ + b1 = _mm256_unpackhi_epi32(x0,x1); /* AB2AB3AB6AB7 */ + + c0 = _mm256_unpacklo_epi64(b0,b1); /* AB0AB2AB4AB6 */ + c1 = _mm256_unpackhi_epi64(b0,b1); /* AB1AB3AB5AB7 */ + + int32x8_MINMAX(c0,c1); + + b0 = _mm256_unpacklo_epi32(c0,c1); /* A01B01A45B45 */ + b1 = _mm256_unpackhi_epi32(c0,c1); /* A23B23A67B67 */ + + x0 = _mm256_unpacklo_epi64(b0,b1); /* A01234567 */ + x1 = _mm256_unpackhi_epi64(b0,b1); /* B01234567 */ + + mask = _mm256_set1_epi32(-1); + if (flagdown) x1 ^= mask; + else x0 ^= mask; + + merge16_finish(x,x0,x1,flagdown); + return; + } + + if (n == 32) { + int32x8 x0,x1,x2,x3; + + int32_sort_2power(x,16,1); + int32_sort_2power(x + 16,16,0); + + x0 = int32x8_load(&x[0]); + x1 = int32x8_load(&x[8]); + x2 = int32x8_load(&x[16]); + x3 = int32x8_load(&x[24]); + + if (flagdown) { + mask = _mm256_set1_epi32(-1); + x0 ^= mask; + x1 ^= mask; + x2 ^= mask; + x3 ^= mask; + } + + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + + merge16_finish(x,x0,x1,flagdown); + merge16_finish(x + 16,x2,x3,flagdown); + return; + } + + p = n>>3; + for (i = 0;i < p;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x2 = int32x8_load(&x[i+2*p]); + int32x8 x4 = int32x8_load(&x[i+4*p]); + int32x8 x6 = int32x8_load(&x[i+6*p]); + + /* odd-even stage instead of bitonic stage */ + + int32x8_MINMAX(x4,x0); + int32x8_MINMAX(x6,x2); + int32x8_MINMAX(x2,x0); + int32x8_MINMAX(x6,x4); + int32x8_MINMAX(x2,x4); + + int32x8_store(&x[i],x0); + int32x8_store(&x[i+2*p],x2); + int32x8_store(&x[i+4*p],x4); + int32x8_store(&x[i+6*p],x6); + + int32x8 x1 = int32x8_load(&x[i+p]); + int32x8 x3 = int32x8_load(&x[i+3*p]); + int32x8 x5 = int32x8_load(&x[i+5*p]); + int32x8 x7 = int32x8_load(&x[i+7*p]); + + int32x8_MINMAX(x1,x5); + int32x8_MINMAX(x3,x7); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x5,x7); + int32x8_MINMAX(x5,x3); + + int32x8_store(&x[i+p],x1); + int32x8_store(&x[i+3*p],x3); + int32x8_store(&x[i+5*p],x5); + int32x8_store(&x[i+7*p],x7); + } + + if (n >= 128) { + int flip, flipflip; + + mask = _mm256_set1_epi32(-1); + + for (j = 0;j < n;j += 32) { + int32x8 x0 = int32x8_load(&x[j]); + int32x8 x1 = int32x8_load(&x[j+16]); + x0 ^= mask; + x1 ^= mask; + int32x8_store(&x[j],x0); + int32x8_store(&x[j+16],x1); + } + + p = 8; + for (;;) { /* for p in [8, 16, ..., n/16] */ + q = p>>1; + while (q >= 128) { + int32_threestages(x,n,q >> 2); + q >>= 3; + } + if (q == 64) { + int32_twostages_32(x,n); + q = 16; + } + if (q == 32) { 
+ q = 8; + for (k = 0;k < n;k += 8*q) + for (i = k;i < k + q;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8 x2 = int32x8_load(&x[i+2*q]); + int32x8 x3 = int32x8_load(&x[i+3*q]); + int32x8 x4 = int32x8_load(&x[i+4*q]); + int32x8 x5 = int32x8_load(&x[i+5*q]); + int32x8 x6 = int32x8_load(&x[i+6*q]); + int32x8 x7 = int32x8_load(&x[i+7*q]); + + int32x8_MINMAX(x0,x4); + int32x8_MINMAX(x1,x5); + int32x8_MINMAX(x2,x6); + int32x8_MINMAX(x3,x7); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x4,x6); + int32x8_MINMAX(x5,x7); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_MINMAX(x4,x5); + int32x8_MINMAX(x6,x7); + + int32x8_store(&x[i],x0); + int32x8_store(&x[i+q],x1); + int32x8_store(&x[i+2*q],x2); + int32x8_store(&x[i+3*q],x3); + int32x8_store(&x[i+4*q],x4); + int32x8_store(&x[i+5*q],x5); + int32x8_store(&x[i+6*q],x6); + int32x8_store(&x[i+7*q],x7); + } + q = 4; + } + if (q == 16) { + q = 8; + for (k = 0;k < n;k += 4*q) + for (i = k;i < k + q;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8 x2 = int32x8_load(&x[i+2*q]); + int32x8 x3 = int32x8_load(&x[i+3*q]); + + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + + int32x8_store(&x[i],x0); + int32x8_store(&x[i+q],x1); + int32x8_store(&x[i+2*q],x2); + int32x8_store(&x[i+3*q],x3); + } + q = 4; + } + if (q == 8) + for (k = 0;k < n;k += q + q) { + int32x8 x0 = int32x8_load(&x[k]); + int32x8 x1 = int32x8_load(&x[k+q]); + + int32x8_MINMAX(x0,x1); + + int32x8_store(&x[k],x0); + int32x8_store(&x[k+q],x1); + } + + q = n>>3; + flip = (p<<1 == q); + flipflip = !flip; + for (j = 0;j < q;j += p + p) { + for (k = j;k < j + p + p;k += p) { + for (i = k;i < k + p;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8 x2 = int32x8_load(&x[i+2*q]); + int32x8 x3 = int32x8_load(&x[i+3*q]); + int32x8 x4 = int32x8_load(&x[i+4*q]); + int32x8 x5 = int32x8_load(&x[i+5*q]); + int32x8 x6 = int32x8_load(&x[i+6*q]); + int32x8 x7 = int32x8_load(&x[i+7*q]); + + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_MINMAX(x4,x5); + int32x8_MINMAX(x6,x7); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x4,x6); + int32x8_MINMAX(x5,x7); + int32x8_MINMAX(x0,x4); + int32x8_MINMAX(x1,x5); + int32x8_MINMAX(x2,x6); + int32x8_MINMAX(x3,x7); + + if (flip) { + x0 ^= mask; + x1 ^= mask; + x2 ^= mask; + x3 ^= mask; + x4 ^= mask; + x5 ^= mask; + x6 ^= mask; + x7 ^= mask; + } + + int32x8_store(&x[i],x0); + int32x8_store(&x[i+q],x1); + int32x8_store(&x[i+2*q],x2); + int32x8_store(&x[i+3*q],x3); + int32x8_store(&x[i+4*q],x4); + int32x8_store(&x[i+5*q],x5); + int32x8_store(&x[i+6*q],x6); + int32x8_store(&x[i+7*q],x7); + } + flip ^= 1; + } + flip ^= flipflip; + } + + if (p<<4 == n) break; + p <<= 1; + } + } + + for (p = 4;p >= 1;p >>= 1) { + int32 *z = x; + int32 *target = x + n; + if (p == 4) { + mask = _mm256_set_epi32(0,0,0,0,-1,-1,-1,-1); + while (z != target) { + int32x8 x0 = int32x8_load(&z[0]); + int32x8 x1 = int32x8_load(&z[8]); + x0 ^= mask; + x1 ^= mask; + int32x8_store(&z[0],x0); + int32x8_store(&z[8],x1); + z += 16; + } + } else if (p == 2) { + mask = _mm256_set_epi32(0,0,-1,-1,-1,-1,0,0); + while (z != target) { + int32x8 x0 = int32x8_load(&z[0]); + int32x8 x1 = int32x8_load(&z[8]); + x0 ^= mask; + x1 ^= mask; + int32x8 b0 = _mm256_permute2x128_si256(x0,x1,0x20); + int32x8 b1 = _mm256_permute2x128_si256(x0,x1,0x31); + int32x8_MINMAX(b0,b1); + int32x8 c0 = 
_mm256_permute2x128_si256(b0,b1,0x20); + int32x8 c1 = _mm256_permute2x128_si256(b0,b1,0x31); + int32x8_store(&z[0],c0); + int32x8_store(&z[8],c1); + z += 16; + } + } else { /* p == 1 */ + mask = _mm256_set_epi32(0,-1,-1,0,0,-1,-1,0); + while (z != target) { + int32x8 x0 = int32x8_load(&z[0]); + int32x8 x1 = int32x8_load(&z[8]); + x0 ^= mask; + x1 ^= mask; + int32x8 b0 = _mm256_permute2x128_si256(x0,x1,0x20); /* A0123B0123 */ + int32x8 b1 = _mm256_permute2x128_si256(x0,x1,0x31); /* A4567B4567 */ + int32x8 c0 = _mm256_unpacklo_epi64(b0,b1); /* A0145B0145 */ + int32x8 c1 = _mm256_unpackhi_epi64(b0,b1); /* A2367B2367 */ + int32x8_MINMAX(c0,c1); + int32x8 d0 = _mm256_unpacklo_epi64(c0,c1); /* A0123B0123 */ + int32x8 d1 = _mm256_unpackhi_epi64(c0,c1); /* A4567B4567 */ + int32x8_MINMAX(d0,d1); + int32x8 e0 = _mm256_permute2x128_si256(d0,d1,0x20); + int32x8 e1 = _mm256_permute2x128_si256(d0,d1,0x31); + int32x8_store(&z[0],e0); + int32x8_store(&z[8],e1); + z += 16; + } + } + + q = n>>4; + while (q >= 128 || q == 32) { + int32_threestages(x,n,q>>2); + q >>= 3; + } + while (q >= 16) { + q >>= 1; + for (j = 0;j < n;j += 4*q) + for (k = j;k < j + q;k += 8) { + int32x8 x0 = int32x8_load(&x[k]); + int32x8 x1 = int32x8_load(&x[k+q]); + int32x8 x2 = int32x8_load(&x[k+2*q]); + int32x8 x3 = int32x8_load(&x[k+3*q]); + + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + + int32x8_store(&x[k],x0); + int32x8_store(&x[k+q],x1); + int32x8_store(&x[k+2*q],x2); + int32x8_store(&x[k+3*q],x3); + } + q >>= 1; + } + if (q == 8) { + for (j = 0;j < n;j += 2*q) { + int32x8 x0 = int32x8_load(&x[j]); + int32x8 x1 = int32x8_load(&x[j+q]); + + int32x8_MINMAX(x0,x1); + + int32x8_store(&x[j],x0); + int32x8_store(&x[j+q],x1); + } + } + + q = n>>3; + for (k = 0;k < q;k += 8) { + int32x8 x0 = int32x8_load(&x[k]); + int32x8 x1 = int32x8_load(&x[k+q]); + int32x8 x2 = int32x8_load(&x[k+2*q]); + int32x8 x3 = int32x8_load(&x[k+3*q]); + int32x8 x4 = int32x8_load(&x[k+4*q]); + int32x8 x5 = int32x8_load(&x[k+5*q]); + int32x8 x6 = int32x8_load(&x[k+6*q]); + int32x8 x7 = int32x8_load(&x[k+7*q]); + + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_MINMAX(x4,x5); + int32x8_MINMAX(x6,x7); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x4,x6); + int32x8_MINMAX(x5,x7); + int32x8_MINMAX(x0,x4); + int32x8_MINMAX(x1,x5); + int32x8_MINMAX(x2,x6); + int32x8_MINMAX(x3,x7); + + int32x8_store(&x[k],x0); + int32x8_store(&x[k+q],x1); + int32x8_store(&x[k+2*q],x2); + int32x8_store(&x[k+3*q],x3); + int32x8_store(&x[k+4*q],x4); + int32x8_store(&x[k+5*q],x5); + int32x8_store(&x[k+6*q],x6); + int32x8_store(&x[k+7*q],x7); + } + } + + /* everything is still masked with _mm256_set_epi32(0,-1,0,-1,0,-1,0,-1); */ + mask = _mm256_set1_epi32(-1); + + for (i = 0;i < n;i += 64) { + int32x8 a0 = int32x8_load(&x[i]); + int32x8 a1 = int32x8_load(&x[i+8]); + int32x8 a2 = int32x8_load(&x[i+16]); + int32x8 a3 = int32x8_load(&x[i+24]); + int32x8 a4 = int32x8_load(&x[i+32]); + int32x8 a5 = int32x8_load(&x[i+40]); + int32x8 a6 = int32x8_load(&x[i+48]); + int32x8 a7 = int32x8_load(&x[i+56]); + + int32x8 b0 = _mm256_unpacklo_epi32(a0,a1); /* AB0AB1AB4AB5 */ + int32x8 b1 = _mm256_unpackhi_epi32(a0,a1); /* AB2AB3AB6AB7 */ + int32x8 b2 = _mm256_unpacklo_epi32(a2,a3); /* CD0CD1CD4CD5 */ + int32x8 b3 = _mm256_unpackhi_epi32(a2,a3); /* CD2CD3CD6CD7 */ + int32x8 b4 = _mm256_unpacklo_epi32(a4,a5); /* EF0EF1EF4EF5 */ + int32x8 b5 = _mm256_unpackhi_epi32(a4,a5); /* EF2EF3EF6EF7 */ + int32x8 b6 = 
_mm256_unpacklo_epi32(a6,a7); /* GH0GH1GH4GH5 */ + int32x8 b7 = _mm256_unpackhi_epi32(a6,a7); /* GH2GH3GH6GH7 */ + + int32x8 c0 = _mm256_unpacklo_epi64(b0,b2); /* ABCD0ABCD4 */ + int32x8 c1 = _mm256_unpacklo_epi64(b1,b3); /* ABCD2ABCD6 */ + int32x8 c2 = _mm256_unpackhi_epi64(b0,b2); /* ABCD1ABCD5 */ + int32x8 c3 = _mm256_unpackhi_epi64(b1,b3); /* ABCD3ABCD7 */ + int32x8 c4 = _mm256_unpacklo_epi64(b4,b6); /* EFGH0EFGH4 */ + int32x8 c5 = _mm256_unpacklo_epi64(b5,b7); /* EFGH2EFGH6 */ + int32x8 c6 = _mm256_unpackhi_epi64(b4,b6); /* EFGH1EFGH5 */ + int32x8 c7 = _mm256_unpackhi_epi64(b5,b7); /* EFGH3EFGH7 */ + + if (flagdown) { + c2 ^= mask; + c3 ^= mask; + c6 ^= mask; + c7 ^= mask; + } else { + c0 ^= mask; + c1 ^= mask; + c4 ^= mask; + c5 ^= mask; + } + + int32x8 d0 = _mm256_permute2x128_si256(c0,c4,0x20); /* ABCDEFGH0 */ + int32x8 d1 = _mm256_permute2x128_si256(c2,c6,0x20); /* ABCDEFGH1 */ + int32x8 d2 = _mm256_permute2x128_si256(c1,c5,0x20); /* ABCDEFGH2 */ + int32x8 d3 = _mm256_permute2x128_si256(c3,c7,0x20); /* ABCDEFGH5 */ + int32x8 d4 = _mm256_permute2x128_si256(c0,c4,0x31); /* ABCDEFGH4 */ + int32x8 d5 = _mm256_permute2x128_si256(c2,c6,0x31); /* ABCDEFGH3 */ + int32x8 d6 = _mm256_permute2x128_si256(c1,c5,0x31); /* ABCDEFGH6 */ + int32x8 d7 = _mm256_permute2x128_si256(c3,c7,0x31); /* ABCDEFGH7 */ + + int32x8_MINMAX(d0,d1); + int32x8_MINMAX(d2,d3); + int32x8_MINMAX(d4,d5); + int32x8_MINMAX(d6,d7); + int32x8_MINMAX(d0,d2); + int32x8_MINMAX(d1,d3); + int32x8_MINMAX(d4,d6); + int32x8_MINMAX(d5,d7); + int32x8_MINMAX(d0,d4); + int32x8_MINMAX(d1,d5); + int32x8_MINMAX(d2,d6); + int32x8_MINMAX(d3,d7); + + int32x8 e0 = _mm256_unpacklo_epi32(d0,d1); + int32x8 e1 = _mm256_unpackhi_epi32(d0,d1); + int32x8 e2 = _mm256_unpacklo_epi32(d2,d3); + int32x8 e3 = _mm256_unpackhi_epi32(d2,d3); + int32x8 e4 = _mm256_unpacklo_epi32(d4,d5); + int32x8 e5 = _mm256_unpackhi_epi32(d4,d5); + int32x8 e6 = _mm256_unpacklo_epi32(d6,d7); + int32x8 e7 = _mm256_unpackhi_epi32(d6,d7); + + int32x8 f0 = _mm256_unpacklo_epi64(e0,e2); + int32x8 f1 = _mm256_unpacklo_epi64(e1,e3); + int32x8 f2 = _mm256_unpackhi_epi64(e0,e2); + int32x8 f3 = _mm256_unpackhi_epi64(e1,e3); + int32x8 f4 = _mm256_unpacklo_epi64(e4,e6); + int32x8 f5 = _mm256_unpacklo_epi64(e5,e7); + int32x8 f6 = _mm256_unpackhi_epi64(e4,e6); + int32x8 f7 = _mm256_unpackhi_epi64(e5,e7); + + int32x8 g0 = _mm256_permute2x128_si256(f0,f4,0x20); + int32x8 g1 = _mm256_permute2x128_si256(f2,f6,0x20); + int32x8 g2 = _mm256_permute2x128_si256(f1,f5,0x20); + int32x8 g3 = _mm256_permute2x128_si256(f3,f7,0x20); + int32x8 g4 = _mm256_permute2x128_si256(f0,f4,0x31); + int32x8 g5 = _mm256_permute2x128_si256(f2,f6,0x31); + int32x8 g6 = _mm256_permute2x128_si256(f1,f5,0x31); + int32x8 g7 = _mm256_permute2x128_si256(f3,f7,0x31); + + int32x8_store(&x[i],g0); + int32x8_store(&x[i+8],g1); + int32x8_store(&x[i+16],g2); + int32x8_store(&x[i+24],g3); + int32x8_store(&x[i+32],g4); + int32x8_store(&x[i+40],g5); + int32x8_store(&x[i+48],g6); + int32x8_store(&x[i+56],g7); + } + + q = n>>4; + while (q >= 128 || q == 32) { + q >>= 2; + for (j = 0;j < n;j += 8*q) + for (i = j;i < j + q;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8 x2 = int32x8_load(&x[i+2*q]); + int32x8 x3 = int32x8_load(&x[i+3*q]); + int32x8 x4 = int32x8_load(&x[i+4*q]); + int32x8 x5 = int32x8_load(&x[i+5*q]); + int32x8 x6 = int32x8_load(&x[i+6*q]); + int32x8 x7 = int32x8_load(&x[i+7*q]); + int32x8_MINMAX(x0,x4); + int32x8_MINMAX(x1,x5); + int32x8_MINMAX(x2,x6); + int32x8_MINMAX(x3,x7); 
+ int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x4,x6); + int32x8_MINMAX(x5,x7); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_MINMAX(x4,x5); + int32x8_MINMAX(x6,x7); + int32x8_store(&x[i],x0); + int32x8_store(&x[i+q],x1); + int32x8_store(&x[i+2*q],x2); + int32x8_store(&x[i+3*q],x3); + int32x8_store(&x[i+4*q],x4); + int32x8_store(&x[i+5*q],x5); + int32x8_store(&x[i+6*q],x6); + int32x8_store(&x[i+7*q],x7); + } + q >>= 1; + } + while (q >= 16) { + q >>= 1; + for (j = 0;j < n;j += 4*q) + for (i = j;i < j + q;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8 x2 = int32x8_load(&x[i+2*q]); + int32x8 x3 = int32x8_load(&x[i+3*q]); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_store(&x[i],x0); + int32x8_store(&x[i+q],x1); + int32x8_store(&x[i+2*q],x2); + int32x8_store(&x[i+3*q],x3); + } + q >>= 1; + } + if (q == 8) + for (j = 0;j < n;j += q + q) { + int32x8 x0 = int32x8_load(&x[j]); + int32x8 x1 = int32x8_load(&x[j+q]); + int32x8_MINMAX(x0,x1); + int32x8_store(&x[j],x0); + int32x8_store(&x[j+q],x1); + } + + q = n>>3; + for (i = 0;i < q;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8 x2 = int32x8_load(&x[i+2*q]); + int32x8 x3 = int32x8_load(&x[i+3*q]); + int32x8 x4 = int32x8_load(&x[i+4*q]); + int32x8 x5 = int32x8_load(&x[i+5*q]); + int32x8 x6 = int32x8_load(&x[i+6*q]); + int32x8 x7 = int32x8_load(&x[i+7*q]); + + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_MINMAX(x4,x5); + int32x8_MINMAX(x6,x7); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x4,x6); + int32x8_MINMAX(x5,x7); + int32x8_MINMAX(x0,x4); + int32x8_MINMAX(x1,x5); + int32x8_MINMAX(x2,x6); + int32x8_MINMAX(x3,x7); + + int32x8 b0 = _mm256_unpacklo_epi32(x0,x4); /* AE0AE1AE4AE5 */ + int32x8 b1 = _mm256_unpackhi_epi32(x0,x4); /* AE2AE3AE6AE7 */ + int32x8 b2 = _mm256_unpacklo_epi32(x1,x5); /* BF0BF1BF4BF5 */ + int32x8 b3 = _mm256_unpackhi_epi32(x1,x5); /* BF2BF3BF6BF7 */ + int32x8 b4 = _mm256_unpacklo_epi32(x2,x6); /* CG0CG1CG4CG5 */ + int32x8 b5 = _mm256_unpackhi_epi32(x2,x6); /* CG2CG3CG6CG7 */ + int32x8 b6 = _mm256_unpacklo_epi32(x3,x7); /* DH0DH1DH4DH5 */ + int32x8 b7 = _mm256_unpackhi_epi32(x3,x7); /* DH2DH3DH6DH7 */ + + int32x8 c0 = _mm256_unpacklo_epi64(b0,b4); /* AECG0AECG4 */ + int32x8 c1 = _mm256_unpacklo_epi64(b1,b5); /* AECG2AECG6 */ + int32x8 c2 = _mm256_unpackhi_epi64(b0,b4); /* AECG1AECG5 */ + int32x8 c3 = _mm256_unpackhi_epi64(b1,b5); /* AECG3AECG7 */ + int32x8 c4 = _mm256_unpacklo_epi64(b2,b6); /* BFDH0BFDH4 */ + int32x8 c5 = _mm256_unpacklo_epi64(b3,b7); /* BFDH2BFDH6 */ + int32x8 c6 = _mm256_unpackhi_epi64(b2,b6); /* BFDH1BFDH5 */ + int32x8 c7 = _mm256_unpackhi_epi64(b3,b7); /* BFDH3BFDH7 */ + + int32x8 d0 = _mm256_permute2x128_si256(c0,c4,0x20); /* AECGBFDH0 */ + int32x8 d1 = _mm256_permute2x128_si256(c1,c5,0x20); /* AECGBFDH2 */ + int32x8 d2 = _mm256_permute2x128_si256(c2,c6,0x20); /* AECGBFDH1 */ + int32x8 d3 = _mm256_permute2x128_si256(c3,c7,0x20); /* AECGBFDH3 */ + int32x8 d4 = _mm256_permute2x128_si256(c0,c4,0x31); /* AECGBFDH4 */ + int32x8 d5 = _mm256_permute2x128_si256(c1,c5,0x31); /* AECGBFDH6 */ + int32x8 d6 = _mm256_permute2x128_si256(c2,c6,0x31); /* AECGBFDH5 */ + int32x8 d7 = _mm256_permute2x128_si256(c3,c7,0x31); /* AECGBFDH7 */ + + if (flagdown) { + d0 ^= mask; + d1 ^= mask; + d2 ^= mask; + d3 ^= mask; + d4 ^= mask; + d5 ^= mask; + d6 ^= mask; + d7 ^= mask; + } + + int32x8_store(&x[i],d0); + 
int32x8_store(&x[i+q],d4); + int32x8_store(&x[i+2*q],d1); + int32x8_store(&x[i+3*q],d5); + int32x8_store(&x[i+4*q],d2); + int32x8_store(&x[i+5*q],d6); + int32x8_store(&x[i+6*q],d3); + int32x8_store(&x[i+7*q],d7); + } +} + +void int32_sort(int32 *x,long long n) +{ long long q,i,j; + + if (n <= 8) { + if (n == 8) { + int32_MINMAX(x[0],x[1]); + int32_MINMAX(x[1],x[2]); + int32_MINMAX(x[2],x[3]); + int32_MINMAX(x[3],x[4]); + int32_MINMAX(x[4],x[5]); + int32_MINMAX(x[5],x[6]); + int32_MINMAX(x[6],x[7]); + } + if (n >= 7) { + int32_MINMAX(x[0],x[1]); + int32_MINMAX(x[1],x[2]); + int32_MINMAX(x[2],x[3]); + int32_MINMAX(x[3],x[4]); + int32_MINMAX(x[4],x[5]); + int32_MINMAX(x[5],x[6]); + } + if (n >= 6) { + int32_MINMAX(x[0],x[1]); + int32_MINMAX(x[1],x[2]); + int32_MINMAX(x[2],x[3]); + int32_MINMAX(x[3],x[4]); + int32_MINMAX(x[4],x[5]); + } + if (n >= 5) { + int32_MINMAX(x[0],x[1]); + int32_MINMAX(x[1],x[2]); + int32_MINMAX(x[2],x[3]); + int32_MINMAX(x[3],x[4]); + } + if (n >= 4) { + int32_MINMAX(x[0],x[1]); + int32_MINMAX(x[1],x[2]); + int32_MINMAX(x[2],x[3]); + } + if (n >= 3) { + int32_MINMAX(x[0],x[1]); + int32_MINMAX(x[1],x[2]); + } + if (n >= 2) { + int32_MINMAX(x[0],x[1]); + } + return; + } + + if (!(n & (n - 1))) { + int32_sort_2power(x,n,0); + return; + } + + q = 8; + while (q < n - q) q += q; + /* n > q >= 8 */ + + if (q <= 128) { /* n <= 256 */ + int32x8 y[32]; + for (i = q>>3;i < q>>2;++i) y[i] = _mm256_set1_epi32(0x7fffffff); + for (i = 0;i < n;++i) i[(int32 *) y] = x[i]; + int32_sort_2power((int32 *) y,2*q,0); + for (i = 0;i < n;++i) x[i] = i[(int32 *) y]; + return; + } + + int32_sort_2power(x,q,1); + int32_sort(x + q,n - q); + + while (q >= 64) { + q >>= 2; + j = int32_threestages(x,n,q); + minmax_vector(x + j,x + j + 4*q,n - 4*q - j); + if (j + 4*q <= n) { + for (i = j;i < j + q;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8 x2 = int32x8_load(&x[i+2*q]); + int32x8 x3 = int32x8_load(&x[i+3*q]); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_store(&x[i],x0); + int32x8_store(&x[i+q],x1); + int32x8_store(&x[i+2*q],x2); + int32x8_store(&x[i+3*q],x3); + } + j += 4*q; + } + minmax_vector(x + j,x + j + 2*q,n - 2*q - j); + if (j + 2*q <= n) { + for (i = j;i < j + q;i += 8) { + int32x8 x0 = int32x8_load(&x[i]); + int32x8 x1 = int32x8_load(&x[i+q]); + int32x8_MINMAX(x0,x1); + int32x8_store(&x[i],x0); + int32x8_store(&x[i+q],x1); + } + j += 2*q; + } + minmax_vector(x + j,x + j + q,n - q - j); + q >>= 1; + } + if (q == 32) { + j = 0; + for (;j + 64 <= n;j += 64) { + int32x8 x0 = int32x8_load(&x[j]); + int32x8 x1 = int32x8_load(&x[j+8]); + int32x8 x2 = int32x8_load(&x[j+16]); + int32x8 x3 = int32x8_load(&x[j+24]); + int32x8 x4 = int32x8_load(&x[j+32]); + int32x8 x5 = int32x8_load(&x[j+40]); + int32x8 x6 = int32x8_load(&x[j+48]); + int32x8 x7 = int32x8_load(&x[j+56]); + int32x8_MINMAX(x0,x4); + int32x8_MINMAX(x1,x5); + int32x8_MINMAX(x2,x6); + int32x8_MINMAX(x3,x7); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x4,x6); + int32x8_MINMAX(x5,x7); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8_MINMAX(x4,x5); + int32x8_MINMAX(x6,x7); + int32x8 a0 = _mm256_permute2x128_si256(x0,x1,0x20); + int32x8 a1 = _mm256_permute2x128_si256(x0,x1,0x31); + int32x8 a2 = _mm256_permute2x128_si256(x2,x3,0x20); + int32x8 a3 = _mm256_permute2x128_si256(x2,x3,0x31); + int32x8 a4 = _mm256_permute2x128_si256(x4,x5,0x20); + int32x8 a5 = _mm256_permute2x128_si256(x4,x5,0x31); + 
int32x8 a6 = _mm256_permute2x128_si256(x6,x7,0x20); + int32x8 a7 = _mm256_permute2x128_si256(x6,x7,0x31); + int32x8_MINMAX(a0,a1); + int32x8_MINMAX(a2,a3); + int32x8_MINMAX(a4,a5); + int32x8_MINMAX(a6,a7); + int32x8 b0 = _mm256_permute2x128_si256(a0,a1,0x20); + int32x8 b1 = _mm256_permute2x128_si256(a0,a1,0x31); + int32x8 b2 = _mm256_permute2x128_si256(a2,a3,0x20); + int32x8 b3 = _mm256_permute2x128_si256(a2,a3,0x31); + int32x8 b4 = _mm256_permute2x128_si256(a4,a5,0x20); + int32x8 b5 = _mm256_permute2x128_si256(a4,a5,0x31); + int32x8 b6 = _mm256_permute2x128_si256(a6,a7,0x20); + int32x8 b7 = _mm256_permute2x128_si256(a6,a7,0x31); + int32x8 c0 = _mm256_unpacklo_epi64(b0,b1); + int32x8 c1 = _mm256_unpackhi_epi64(b0,b1); + int32x8 c2 = _mm256_unpacklo_epi64(b2,b3); + int32x8 c3 = _mm256_unpackhi_epi64(b2,b3); + int32x8 c4 = _mm256_unpacklo_epi64(b4,b5); + int32x8 c5 = _mm256_unpackhi_epi64(b4,b5); + int32x8 c6 = _mm256_unpacklo_epi64(b6,b7); + int32x8 c7 = _mm256_unpackhi_epi64(b6,b7); + int32x8_MINMAX(c0,c1); + int32x8_MINMAX(c2,c3); + int32x8_MINMAX(c4,c5); + int32x8_MINMAX(c6,c7); + int32x8 d0 = _mm256_unpacklo_epi32(c0,c1); + int32x8 d1 = _mm256_unpackhi_epi32(c0,c1); + int32x8 d2 = _mm256_unpacklo_epi32(c2,c3); + int32x8 d3 = _mm256_unpackhi_epi32(c2,c3); + int32x8 d4 = _mm256_unpacklo_epi32(c4,c5); + int32x8 d5 = _mm256_unpackhi_epi32(c4,c5); + int32x8 d6 = _mm256_unpacklo_epi32(c6,c7); + int32x8 d7 = _mm256_unpackhi_epi32(c6,c7); + int32x8 e0 = _mm256_unpacklo_epi64(d0,d1); + int32x8 e1 = _mm256_unpackhi_epi64(d0,d1); + int32x8 e2 = _mm256_unpacklo_epi64(d2,d3); + int32x8 e3 = _mm256_unpackhi_epi64(d2,d3); + int32x8 e4 = _mm256_unpacklo_epi64(d4,d5); + int32x8 e5 = _mm256_unpackhi_epi64(d4,d5); + int32x8 e6 = _mm256_unpacklo_epi64(d6,d7); + int32x8 e7 = _mm256_unpackhi_epi64(d6,d7); + int32x8_MINMAX(e0,e1); + int32x8_MINMAX(e2,e3); + int32x8_MINMAX(e4,e5); + int32x8_MINMAX(e6,e7); + int32x8 f0 = _mm256_unpacklo_epi32(e0,e1); + int32x8 f1 = _mm256_unpackhi_epi32(e0,e1); + int32x8 f2 = _mm256_unpacklo_epi32(e2,e3); + int32x8 f3 = _mm256_unpackhi_epi32(e2,e3); + int32x8 f4 = _mm256_unpacklo_epi32(e4,e5); + int32x8 f5 = _mm256_unpackhi_epi32(e4,e5); + int32x8 f6 = _mm256_unpacklo_epi32(e6,e7); + int32x8 f7 = _mm256_unpackhi_epi32(e6,e7); + int32x8_store(&x[j],f0); + int32x8_store(&x[j+8],f1); + int32x8_store(&x[j+16],f2); + int32x8_store(&x[j+24],f3); + int32x8_store(&x[j+32],f4); + int32x8_store(&x[j+40],f5); + int32x8_store(&x[j+48],f6); + int32x8_store(&x[j+56],f7); + } + minmax_vector(x + j,x + j + 32,n - 32 - j); + goto continue16; + } + if (q == 16) { + j = 0; + continue16: + for (;j + 32 <= n;j += 32) { + int32x8 x0 = int32x8_load(&x[j]); + int32x8 x1 = int32x8_load(&x[j+8]); + int32x8 x2 = int32x8_load(&x[j+16]); + int32x8 x3 = int32x8_load(&x[j+24]); + int32x8_MINMAX(x0,x2); + int32x8_MINMAX(x1,x3); + int32x8_MINMAX(x0,x1); + int32x8_MINMAX(x2,x3); + int32x8 a0 = _mm256_permute2x128_si256(x0,x1,0x20); + int32x8 a1 = _mm256_permute2x128_si256(x0,x1,0x31); + int32x8 a2 = _mm256_permute2x128_si256(x2,x3,0x20); + int32x8 a3 = _mm256_permute2x128_si256(x2,x3,0x31); + int32x8_MINMAX(a0,a1); + int32x8_MINMAX(a2,a3); + int32x8 b0 = _mm256_permute2x128_si256(a0,a1,0x20); + int32x8 b1 = _mm256_permute2x128_si256(a0,a1,0x31); + int32x8 b2 = _mm256_permute2x128_si256(a2,a3,0x20); + int32x8 b3 = _mm256_permute2x128_si256(a2,a3,0x31); + int32x8 c0 = _mm256_unpacklo_epi64(b0,b1); + int32x8 c1 = _mm256_unpackhi_epi64(b0,b1); + int32x8 c2 = _mm256_unpacklo_epi64(b2,b3); + int32x8 c3 = 
_mm256_unpackhi_epi64(b2,b3); + int32x8_MINMAX(c0,c1); + int32x8_MINMAX(c2,c3); + int32x8 d0 = _mm256_unpacklo_epi32(c0,c1); + int32x8 d1 = _mm256_unpackhi_epi32(c0,c1); + int32x8 d2 = _mm256_unpacklo_epi32(c2,c3); + int32x8 d3 = _mm256_unpackhi_epi32(c2,c3); + int32x8 e0 = _mm256_unpacklo_epi64(d0,d1); + int32x8 e1 = _mm256_unpackhi_epi64(d0,d1); + int32x8 e2 = _mm256_unpacklo_epi64(d2,d3); + int32x8 e3 = _mm256_unpackhi_epi64(d2,d3); + int32x8_MINMAX(e0,e1); + int32x8_MINMAX(e2,e3); + int32x8 f0 = _mm256_unpacklo_epi32(e0,e1); + int32x8 f1 = _mm256_unpackhi_epi32(e0,e1); + int32x8 f2 = _mm256_unpacklo_epi32(e2,e3); + int32x8 f3 = _mm256_unpackhi_epi32(e2,e3); + int32x8_store(&x[j],f0); + int32x8_store(&x[j+8],f1); + int32x8_store(&x[j+16],f2); + int32x8_store(&x[j+24],f3); + } + minmax_vector(x + j,x + j + 16,n - 16 - j); + goto continue8; + } + /* q == 8 */ + j = 0; + continue8: + for (;j + 16 <= n;j += 16) { + int32x8 x0 = int32x8_load(&x[j]); + int32x8 x1 = int32x8_load(&x[j+8]); + int32x8_MINMAX(x0,x1); + int32x8_store(&x[j],x0); + int32x8_store(&x[j+8],x1); + int32x8 a0 = _mm256_permute2x128_si256(x0,x1,0x20); /* x0123y0123 */ + int32x8 a1 = _mm256_permute2x128_si256(x0,x1,0x31); /* x4567y4567 */ + int32x8_MINMAX(a0,a1); + int32x8 b0 = _mm256_permute2x128_si256(a0,a1,0x20); /* x01234567 */ + int32x8 b1 = _mm256_permute2x128_si256(a0,a1,0x31); /* y01234567 */ + int32x8 c0 = _mm256_unpacklo_epi64(b0,b1); /* x01y01x45y45 */ + int32x8 c1 = _mm256_unpackhi_epi64(b0,b1); /* x23y23x67y67 */ + int32x8_MINMAX(c0,c1); + int32x8 d0 = _mm256_unpacklo_epi32(c0,c1); /* x02x13x46x57 */ + int32x8 d1 = _mm256_unpackhi_epi32(c0,c1); /* y02y13y46y57 */ + int32x8 e0 = _mm256_unpacklo_epi64(d0,d1); /* x02y02x46y46 */ + int32x8 e1 = _mm256_unpackhi_epi64(d0,d1); /* x13y13x57y57 */ + int32x8_MINMAX(e0,e1); + int32x8 f0 = _mm256_unpacklo_epi32(e0,e1); /* x01234567 */ + int32x8 f1 = _mm256_unpackhi_epi32(e0,e1); /* y01234567 */ + int32x8_store(&x[j],f0); + int32x8_store(&x[j+8],f1); + } + minmax_vector(x + j,x + j + 8,n - 8 - j); + if (j + 8 <= n) { + int32_MINMAX(x[j],x[j+4]); + int32_MINMAX(x[j+1],x[j+5]); + int32_MINMAX(x[j+2],x[j+6]); + int32_MINMAX(x[j+3],x[j+7]); + int32_MINMAX(x[j],x[j+2]); + int32_MINMAX(x[j+1],x[j+3]); + int32_MINMAX(x[j],x[j+1]); + int32_MINMAX(x[j+2],x[j+3]); + int32_MINMAX(x[j+4],x[j+6]); + int32_MINMAX(x[j+5],x[j+7]); + int32_MINMAX(x[j+4],x[j+5]); + int32_MINMAX(x[j+6],x[j+7]); + j += 8; + } + minmax_vector(x + j,x + j + 4,n - 4 - j); + if (j + 4 <= n) { + int32_MINMAX(x[j],x[j+2]); + int32_MINMAX(x[j+1],x[j+3]); + int32_MINMAX(x[j],x[j+1]); + int32_MINMAX(x[j+2],x[j+3]); + j += 4; + } + if (j + 3 <= n) + int32_MINMAX(x[j],x[j+2]); + if (j + 2 <= n) + int32_MINMAX(x[j],x[j+1]); +} From 3d726b1995af3fa14fb5b84e09aa236c6a7a4108 Mon Sep 17 00:00:00 2001 From: OussamaDanba Date: Tue, 14 May 2019 12:49:02 +0200 Subject: [PATCH 4/5] Include missing files in Makefile-NIST --- avx2-hps2048509/Makefile-NIST | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/avx2-hps2048509/Makefile-NIST b/avx2-hps2048509/Makefile-NIST index d7d3d9a..8e84bd4 100644 --- a/avx2-hps2048509/Makefile-NIST +++ b/avx2-hps2048509/Makefile-NIST @@ -2,10 +2,10 @@ CC=/usr/bin/gcc CFLAGS=-O3 -fomit-frame-pointer -march=native -no-pie LDFLAGS=-lcrypto -SOURCES = crypto_sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c sample.c verify.c rng.c PQCgenKAT_kem.c \ - poly_rq_mul.s +SOURCES = crypto_sort.c djbsort/sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c sample.c verify.c rng.c \ + 
PQCgenKAT_kem.c poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s -HEADERS = api.h crypto_sort.h fips202.h kem.h poly.h owcpa.h params.h sample.h verify.h rng.h +HEADERS = api.h crypto_sort.h djbsort/int32_sort.h fips202.h kem.h poly.h owcpa.h params.h sample.h verify.h rng.h PQCgenKAT_kem: $(HEADERS) $(SOURCES) $(CC) $(CFLAGS) -o $@ $(SOURCES) $(LDFLAGS) From d47376c792a1d72f538955a7f535f9efb84eaa81 Mon Sep 17 00:00:00 2001 From: OussamaDanba Date: Wed, 15 May 2019 16:20:42 +0200 Subject: [PATCH 5/5] Implement avx2 version of poly_S3_inv for hps2048509 Explanation: An explanation of the method can be found in the paper "Fast constant-time gcd computation and modular inversion". There are no large changes from the case study described in that paper except that every polynomial is stored as four 256-bit vectors rather than six since it fits within that space for hps2048509. This reduces the amount of some vector operations. Results: On an Intel i5-8250u using gcc 8.3.0 the reference poly_S3_inv takes about 1526800 cycles on average whereas the avx2 version takes about 23250 cycles on average. This is about a 65 times speedup. --- avx2-hps2048509/Makefile | 4 +- avx2-hps2048509/Makefile-NIST | 6 +- avx2-hps2048509/poly.c | 79 +----- avx2-hps2048509/poly_s3_inv.c | 441 ++++++++++++++++++++++++++++++++++ avx2-hps2048509/poly_s3_inv.h | 8 + 5 files changed, 455 insertions(+), 83 deletions(-) create mode 100644 avx2-hps2048509/poly_s3_inv.c create mode 100644 avx2-hps2048509/poly_s3_inv.h diff --git a/avx2-hps2048509/Makefile b/avx2-hps2048509/Makefile index f95b6c5..866558c 100644 --- a/avx2-hps2048509/Makefile +++ b/avx2-hps2048509/Makefile @@ -1,8 +1,8 @@ CC = /usr/bin/cc CFLAGS = -Wall -Wextra -Wpedantic -O3 -fomit-frame-pointer -march=native -no-pie -SOURCES = crypto_sort.c djbsort/sort.c poly.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c -HEADERS = crypto_sort.h djbsort/int32_sort.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h +SOURCES = crypto_sort.c djbsort/sort.c poly.c poly_s3_inv.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c +HEADERS = crypto_sort.h djbsort/int32_sort.h poly_s3_inv.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h OBJS = poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s diff --git a/avx2-hps2048509/Makefile-NIST b/avx2-hps2048509/Makefile-NIST index 8e84bd4..8c5dfba 100644 --- a/avx2-hps2048509/Makefile-NIST +++ b/avx2-hps2048509/Makefile-NIST @@ -2,10 +2,10 @@ CC=/usr/bin/gcc CFLAGS=-O3 -fomit-frame-pointer -march=native -no-pie LDFLAGS=-lcrypto -SOURCES = crypto_sort.c djbsort/sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c sample.c verify.c rng.c \ - PQCgenKAT_kem.c poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s +SOURCES = crypto_sort.c djbsort/sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c poly_s3_inv.c sample.c verify.c \ + rng.c PQCgenKAT_kem.c poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s -HEADERS = api.h crypto_sort.h djbsort/int32_sort.h fips202.h kem.h poly.h owcpa.h params.h sample.h verify.h rng.h +HEADERS = api.h crypto_sort.h djbsort/int32_sort.h fips202.h kem.h poly.h poly_s3_inv.h owcpa.h params.h sample.h verify.h rng.h PQCgenKAT_kem: $(HEADERS) $(SOURCES) $(CC) $(CFLAGS) -o $@ $(SOURCES) $(LDFLAGS) diff --git a/avx2-hps2048509/poly.c b/avx2-hps2048509/poly.c index 6966de8..9312be3 100644 --- a/avx2-hps2048509/poly.c +++ b/avx2-hps2048509/poly.c @@ -1,6 +1,7 @@ #include "poly.h" #include "fips202.h" #include "verify.h" +#include 
"poly_s3_inv.h" extern void poly_Rq_mul(poly *r, const poly *a, const poly *b); extern void poly_S3_mul(poly *r, const poly *a, const poly *b); @@ -281,81 +282,3 @@ void poly_Rq_inv(poly *r, const poly *a) poly_R2_inv(&ai2, a); poly_R2_inv_to_Rq_inv(r, &ai2, a); } - -void poly_S3_inv(poly *r, const poly *a) -{ - /* Schroeppel--Orman--O'Malley--Spatscheck - * "Almost Inverse" algorithm as described - * by Silverman in NTRU Tech Report #14 */ - // with several modifications to make it run in constant-time - int i, j; - uint16_t k = 0; - uint16_t degf = NTRU_N-1; - uint16_t degg = NTRU_N-1; - int sign, fsign = 0, t, swap; - int done = 0; - poly b, c, f, g; - poly *temp_r = &f; - - /* b(X) := 1 */ - for(i=1; icoeffs[i]; - - /* g(X) := 1 + X + X^2 + ... + X^{N-1} */ - for(i=0; i> 1) | sign) & !done & ((degf - degg) >> 15); - - cswappoly(&f, &g, swap); - cswappoly(&b, &c, swap); - t = (degf ^ degg) & (-swap); - degf ^= t; - degg ^= t; - - POLY_S3_FMADD(i, f, g, sign*(!done)); - POLY_S3_FMADD(i, b, c, sign*(!done)); - - poly_divx(&f, !done); - poly_mulx(&c, !done); - degf -= !done; - k += !done; - - done = 1 - (((uint16_t)-degf) >> 15); - } - - fsign = f.coeffs[0]; - k = k - NTRU_N*((uint16_t)(NTRU_N - k - 1) >> 15); - - /* Return X^{N-k} * b(X) */ - /* This is a k-coefficient rotation. We do this by looking at the binary - representation of k, rotating for every power of 2, and performing a cmov - if the respective bit is set. */ - for (i = 0; i < NTRU_N; i++) - r->coeffs[i] = mod3(fsign * b.coeffs[i]); - - for (i = 0; i < 10; i++) { - for (j = 0; j < NTRU_N; j++) { - temp_r->coeffs[j] = r->coeffs[(j + (1 << i)) % NTRU_N]; - } - cmov((unsigned char *)&(r->coeffs), - (unsigned char *)&(temp_r->coeffs), sizeof(uint16_t) * NTRU_N, k & 1); - k >>= 1; - } - - /* Reduce modulo Phi_n */ - for(i=0; icoeffs[i] = mod3(r->coeffs[i] + 2*r->coeffs[NTRU_N-1]); -} diff --git a/avx2-hps2048509/poly_s3_inv.c b/avx2-hps2048509/poly_s3_inv.c new file mode 100644 index 0000000..240457a --- /dev/null +++ b/avx2-hps2048509/poly_s3_inv.c @@ -0,0 +1,441 @@ +#include "poly_s3_inv.h" +#include "poly.h" +#include + +#define C 508 // Amount of input coefficients + +typedef signed char small; + +typedef __m256i vec256; + +static inline void vec256_frombits(vec256 *v,const small *b) +{ + int i; + + for (i = 0;i < 3;++i) { + vec256 b0 = _mm256_loadu_si256((vec256 *) b); b += 32; /* 0,1,...,31 */ + vec256 b1 = _mm256_loadu_si256((vec256 *) b); b += 32; /* 32,33,... */ + vec256 b2 = _mm256_loadu_si256((vec256 *) b); b += 32; + vec256 b3 = _mm256_loadu_si256((vec256 *) b); b += 32; + vec256 b4 = _mm256_loadu_si256((vec256 *) b); b += 32; + vec256 b5 = _mm256_loadu_si256((vec256 *) b); b += 32; + vec256 b6 = _mm256_loadu_si256((vec256 *) b); b += 32; + vec256 b7 = _mm256_loadu_si256((vec256 *) b); b += 32; + + vec256 c0 = _mm256_unpacklo_epi32(b0,b1); /* 0 1 2 3 32 33 34 35 4 5 6 7 36 37 38 39 ... 55 */ + vec256 c1 = _mm256_unpackhi_epi32(b0,b1); /* 8 9 10 11 40 41 42 43 ... 
63 */ + vec256 c2 = _mm256_unpacklo_epi32(b2,b3); + vec256 c3 = _mm256_unpackhi_epi32(b2,b3); + vec256 c4 = _mm256_unpacklo_epi32(b4,b5); + vec256 c5 = _mm256_unpackhi_epi32(b4,b5); + vec256 c6 = _mm256_unpacklo_epi32(b6,b7); + vec256 c7 = _mm256_unpackhi_epi32(b6,b7); + + vec256 d0 = c0 | _mm256_slli_epi32(c1,2); /* 0 8, 1 9, 2 10, 3 11, 32 40, 33 41, ..., 55 63 */ + vec256 d2 = c2 | _mm256_slli_epi32(c3,2); + vec256 d4 = c4 | _mm256_slli_epi32(c5,2); + vec256 d6 = c6 | _mm256_slli_epi32(c7,2); + + vec256 e0 = _mm256_unpacklo_epi64(d0,d2); + vec256 e2 = _mm256_unpackhi_epi64(d0,d2); + vec256 e4 = _mm256_unpacklo_epi64(d4,d6); + vec256 e6 = _mm256_unpackhi_epi64(d4,d6); + + vec256 f0 = e0 | _mm256_slli_epi32(e2,1); + vec256 f4 = e4 | _mm256_slli_epi32(e6,1); + + vec256 g0 = _mm256_permute2x128_si256(f0,f4,0x20); + vec256 g4 = _mm256_permute2x128_si256(f0,f4,0x31); + + vec256 h = g0 | _mm256_slli_epi32(g4,4); + +#define TRANSPOSE _mm256_set_epi8( 31,27,23,19, 30,26,22,18, 29,25,21,17, 28,24,20,16, 15,11,7,3, 14,10,6,2, 13,9,5,1, 12,8,4,0 ) + h = _mm256_shuffle_epi8(h,TRANSPOSE); + h = _mm256_permute4x64_epi64(h,0xd8); + h = _mm256_shuffle_epi32(h,0xd8); + + *v++ = h; + } +} + +static inline void vec256_tobits(const vec256 *v,small *b) +{ + int i; + + for (i = 0;i < 3;++i) { + vec256 h = *v++; + + h = _mm256_shuffle_epi32(h,0xd8); + h = _mm256_permute4x64_epi64(h,0xd8); + h = _mm256_shuffle_epi8(h,TRANSPOSE); + + vec256 g0 = h & _mm256_set1_epi8(15); + vec256 g4 = _mm256_srli_epi32(h,4) & _mm256_set1_epi8(15); + + vec256 f0 = _mm256_permute2x128_si256(g0,g4,0x20); + vec256 f4 = _mm256_permute2x128_si256(g0,g4,0x31); + + vec256 e0 = f0 & _mm256_set1_epi8(5); + vec256 e2 = _mm256_srli_epi32(f0,1) & _mm256_set1_epi8(5); + vec256 e4 = f4 & _mm256_set1_epi8(5); + vec256 e6 = _mm256_srli_epi32(f4,1) & _mm256_set1_epi8(5); + + vec256 d0 = _mm256_unpacklo_epi32(e0,e2); + vec256 d2 = _mm256_unpackhi_epi32(e0,e2); + vec256 d4 = _mm256_unpacklo_epi32(e4,e6); + vec256 d6 = _mm256_unpackhi_epi32(e4,e6); + + vec256 c0 = d0 & _mm256_set1_epi8(1); + vec256 c1 = _mm256_srli_epi32(d0,2) & _mm256_set1_epi8(1); + vec256 c2 = d2 & _mm256_set1_epi8(1); + vec256 c3 = _mm256_srli_epi32(d2,2) & _mm256_set1_epi8(1); + vec256 c4 = d4 & _mm256_set1_epi8(1); + vec256 c5 = _mm256_srli_epi32(d4,2) & _mm256_set1_epi8(1); + vec256 c6 = d6 & _mm256_set1_epi8(1); + vec256 c7 = _mm256_srli_epi32(d6,2) & _mm256_set1_epi8(1); + + vec256 b0 = _mm256_unpacklo_epi64(c0,c1); + vec256 b1 = _mm256_unpackhi_epi64(c0,c1); + vec256 b2 = _mm256_unpacklo_epi64(c2,c3); + vec256 b3 = _mm256_unpackhi_epi64(c2,c3); + vec256 b4 = _mm256_unpacklo_epi64(c4,c5); + vec256 b5 = _mm256_unpackhi_epi64(c4,c5); + vec256 b6 = _mm256_unpacklo_epi64(c6,c7); + vec256 b7 = _mm256_unpackhi_epi64(c6,c7); + + _mm256_storeu_si256((vec256 *) b,b0); b += 32; + _mm256_storeu_si256((vec256 *) b,b1); b += 32; + _mm256_storeu_si256((vec256 *) b,b2); b += 32; + _mm256_storeu_si256((vec256 *) b,b3); b += 32; + _mm256_storeu_si256((vec256 *) b,b4); b += 32; + _mm256_storeu_si256((vec256 *) b,b5); b += 32; + _mm256_storeu_si256((vec256 *) b,b6); b += 32; + _mm256_storeu_si256((vec256 *) b,b7); b += 32; + } +} + +static inline void reverse(small *srev,const small *s) +{ + int i; + for (i = 0;i < 512;++i) + srev[i] = s[511-i]; +} + +static void vec256_init(vec256 *G0,vec256 *G1,const small *s) +{ + int i; + small srev[512+(512-C)]; + small si; + small g0[512]; + small g1[512]; + + reverse(srev,s); + for (i = C;i < 512;++i) srev[i+512-C] = 0; + + for (i = 0;i < 512;++i) { + 
si = srev[i+512-C]; + g0[i] = si & 1; + g1[i] = (si >> 1) & 1; + } + + vec256_frombits(G0,g0); + vec256_frombits(G1,g1); +} + +static void vec256_final(small *out,const vec256 *V0,const vec256 *V1) +{ + int i; + small v0[512]; + small v1[512]; + small v[512]; + small vrev[512+(512-C)]; + + vec256_tobits(V0,v0); + vec256_tobits(V1,v1); + + for (i = 0;i < 512;++i) + v[i] = v0[i] + v1[i]; + + reverse(vrev,v); + for (i = 512;i < 512+(512-C);++i) vrev[i] = 0; + + for (i = 0;i < 512;++i) out[i] = vrev[i+512-C]; +} + +static inline int negative_mask(int x) +{ + return x >> 31; +} + +static inline void vec256_swap(vec256 *f,vec256 *g,int len,vec256 mask) +{ + vec256 flip; + int i; + + for (i = 0;i < len;++i) { + flip = mask & (f[i] ^ g[i]); + f[i] ^= flip; + g[i] ^= flip; + } +} + +static inline void vec256_scale(vec256 *f0,vec256 *f1,const vec256 c0,const vec256 c1) +{ + int i; + + for (i = 0;i < 3;++i) { + vec256 f0i = f0[i]; + vec256 f1i = f1[i]; + + f0i &= c0; + f1i ^= c1; + f1i &= f0i; + + f0[i] = f0i; + f1[i] = f1i; + } +} + +static inline void vec256_eliminate(vec256 *f0,vec256 *f1,vec256 *g0,vec256 *g1,int len,const vec256 c0,const vec256 c1) +{ + int i; + + for (i = 0;i < len;++i) { + vec256 f0i = f0[i]; + vec256 f1i = f1[i]; + vec256 g0i = g0[i]; + vec256 g1i = g1[i]; + vec256 t; + + f0i &= c0; + f1i ^= c1; + f1i &= f0i; + + t = g0i ^ f0i; + g0[i] = t | (g1i ^ f1i); + g1[i] = (g1i ^ f0i) & (f1i ^ t); + } +} + +static inline int vec256_bit0mask(vec256 *f) +{ + return -(_mm_cvtsi128_si32(_mm256_castsi256_si128(f[0])) & 1); +} + +static inline void vec256_divx_1(vec256 *f) +{ + unsigned long long f0 = _mm_cvtsi128_si64(_mm256_castsi256_si128(f[0])); + + f0 = f0 >> 1; + + f[0] = _mm256_blend_epi32(f[0],_mm256_set_epi64x(0,0,0,f0),0x3); + + f[0] = _mm256_permute4x64_epi64(f[0],0x39); +} + +static inline void vec256_divx_2(vec256 *f) +{ + unsigned long long f0 = _mm_cvtsi128_si64(_mm256_castsi256_si128(f[0])); + unsigned long long f1 = _mm_cvtsi128_si64(_mm256_castsi256_si128(f[1])); + + f0 = (f0 >> 1) | (f1 << 63); + f1 = f1 >> 1; + + f[0] = _mm256_blend_epi32(f[0],_mm256_set_epi64x(0,0,0,f0),0x3); + f[1] = _mm256_blend_epi32(f[1],_mm256_set_epi64x(0,0,0,f1),0x3); + + f[0] = _mm256_permute4x64_epi64(f[0],0x39); + f[1] = _mm256_permute4x64_epi64(f[1],0x39); +} + +static inline void vec256_timesx_1(vec256 *f) +{ + unsigned long long f0; + + f[0] = _mm256_permute4x64_epi64(f[0],0x93); + + f0 = _mm_cvtsi128_si64(_mm256_castsi256_si128(f[0])); + + f0 = f0 << 1; + + f[0] = _mm256_blend_epi32(f[0],_mm256_set_epi64x(0,0,0,f0),0x3); +} + +static inline void vec256_timesx_2(vec256 *f) +{ + unsigned long long f0,f1; + + f[0] = _mm256_permute4x64_epi64(f[0],0x93); + f[1] = _mm256_permute4x64_epi64(f[1],0x93); + + f0 = _mm_cvtsi128_si64(_mm256_castsi256_si128(f[0])); + f1 = _mm_cvtsi128_si64(_mm256_castsi256_si128(f[1])); + + f1 = (f1 << 1) | (f0 >> 63); + f0 = f0 << 1; + + f[0] = _mm256_blend_epi32(f[0],_mm256_set_epi64x(0,0,0,f0),0x3); + f[1] = _mm256_blend_epi32(f[1],_mm256_set_epi64x(0,0,0,f1),0x3); +} + +static int r3_recip(small *out,const small *s) +{ + vec256 F0[2]; + vec256 F1[2]; + vec256 G0[2]; + vec256 G1[2]; + vec256 V0[2]; + vec256 V1[2]; + vec256 R0[2]; + vec256 R1[2]; + vec256 c0vec,c1vec; + int loop; + int c0,c1; + int minusdelta = -1; + int swapmask; + vec256 swapvec; + + vec256_init(G0,G1,s); + + F0[0] = _mm256_set_epi32(-1,-1,-1,-1,-1,-1,-1,-1); + F0[1] = _mm256_set_epi32(2147483647,-1,2147483647,-1,2147483647,-1,-1,-1); + F1[0] = F1[1] = _mm256_set1_epi32(0); + + V0[1] = V0[0] = 
_mm256_set1_epi32(0);
+  V1[1] = V1[0] = _mm256_set1_epi32(0);
+
+  R0[0] = _mm256_set_epi32(0,0,0,0,0,0,0,1);
+  R0[1] = _mm256_set1_epi32(0);
+  R1[1] = R1[0] = _mm256_set1_epi32(0);
+
+  for (loop = 0;loop < 256;++loop) {
+    vec256_timesx_1(V0);
+    vec256_timesx_1(V1);
+    swapmask = negative_mask(minusdelta) & vec256_bit0mask(G0);
+
+    c0 = vec256_bit0mask(F0) & vec256_bit0mask(G0);
+    c1 = vec256_bit0mask(F1) ^ vec256_bit0mask(G1);
+    c1 &= c0;
+
+    minusdelta ^= swapmask & (minusdelta ^ -minusdelta);
+    minusdelta -= 1;
+
+    swapvec = _mm256_set1_epi32(swapmask);
+    vec256_swap(F0,G0,2,swapvec);
+    vec256_swap(F1,G1,2,swapvec);
+
+    c0vec = _mm256_set1_epi32(c0);
+    c1vec = _mm256_set1_epi32(c1);
+
+    vec256_eliminate(F0,F1,G0,G1,2,c0vec,c1vec);
+    vec256_divx_2(G0);
+    vec256_divx_2(G1);
+
+    vec256_swap(V0,R0,1,swapvec);
+    vec256_swap(V1,R1,1,swapvec);
+    vec256_eliminate(V0,V1,R0,R1,1,c0vec,c1vec);
+  }
+
+  for (loop = 256;loop < C;++loop) {
+    vec256_timesx_2(V0);
+    vec256_timesx_2(V1);
+    swapmask = negative_mask(minusdelta) & vec256_bit0mask(G0);
+
+    c0 = vec256_bit0mask(F0) & vec256_bit0mask(G0);
+    c1 = vec256_bit0mask(F1) ^ vec256_bit0mask(G1);
+    c1 &= c0;
+
+    minusdelta ^= swapmask & (minusdelta ^ -minusdelta);
+    minusdelta -= 1;
+
+    swapvec = _mm256_set1_epi32(swapmask);
+    vec256_swap(F0,G0,2,swapvec);
+    vec256_swap(F1,G1,2,swapvec);
+
+    c0vec = _mm256_set1_epi32(c0);
+    c1vec = _mm256_set1_epi32(c1);
+
+    vec256_eliminate(F0,F1,G0,G1,2,c0vec,c1vec);
+    vec256_divx_2(G0);
+    vec256_divx_2(G1);
+
+    vec256_swap(V0,R0,2,swapvec);
+    vec256_swap(V1,R1,2,swapvec);
+    vec256_eliminate(V0,V1,R0,R1,2,c0vec,c1vec);
+  }
+
+  for (loop = C-1;loop > 0;--loop) {
+    vec256_timesx_2(V0);
+    vec256_timesx_2(V1);
+    swapmask = negative_mask(minusdelta) & vec256_bit0mask(G0);
+
+    c0 = vec256_bit0mask(F0) & vec256_bit0mask(G0);
+    c1 = vec256_bit0mask(F1) ^ vec256_bit0mask(G1);
+    c1 &= c0;
+
+    minusdelta ^= swapmask & (minusdelta ^ -minusdelta);
+    minusdelta -= 1;
+
+    swapvec = _mm256_set1_epi32(swapmask);
+    vec256_swap(F0,G0,1,swapvec);
+    vec256_swap(F1,G1,1,swapvec);
+
+    c0vec = _mm256_set1_epi32(c0);
+    c1vec = _mm256_set1_epi32(c1);
+
+    vec256_eliminate(F0,F1,G0,G1,1,c0vec,c1vec);
+    vec256_divx_1(G0);
+    vec256_divx_1(G1);
+
+    vec256_swap(V0,R0,2,swapvec);
+    vec256_swap(V1,R1,2,swapvec);
+    vec256_eliminate(V0,V1,R0,R1,2,c0vec,c1vec);
+  }
+
+  c0vec = _mm256_set1_epi32(vec256_bit0mask(F0));
+  c1vec = _mm256_set1_epi32(vec256_bit0mask(F1));
+  vec256_scale(V0,V1,c0vec,c1vec);
+
+  vec256_final(out,V0,V1);
+  return negative_mask(minusdelta);
+}
+
+// This code is based on crypto_core/invhrss701/faster from SUPERCOP. The code was written as a case study
+// for the paper "Fast constant-time gcd computation and modular inversion" by Daniel J. Bernstein and Bo-Yin Yang.
+// Their implementation of poly_S3_inv outperformed the implementation in hrss701 by a factor of 1.7. The expectation
+// is that their method will also be faster for hps2048509, which is why it is used here.
+// There are two changes:
+// 1 - There are 508 input coefficients rather than 700. As a result the main loop has to do fewer iterations.
+// 2 - Since there are fewer than 513 coefficients, all polynomials can be stored as four 256-bit vectors
+//     rather than six. Two vectors hold the bottom bits and two vectors hold the top bits. A reduced vector
+//     count means fewer vector operations are needed per polynomial (less swapping, for example).
+// Everything else (such as the data format) is kept intact.
+// See https://gcd.cr.yp.to/papers.html#safegcd for details on how the code works (and mathematical background). +void poly_S3_inv(poly *r_out, const poly *a) { + const unsigned char *in = (void*) a; + unsigned char *out = (void*) r_out; + + small intop = in[2*508]&3; + small input[512]; + small output[512]; + int i; + + intop = 3 - intop; /* 0 1 2 3 */ + intop &= (intop-3)>>5; /* 0 1 2 */ + intop += 1; /* 0 1 2 3, offset by 1 */ + + for (i = 0;i < 508;++i) { + small x = in[2*i]&3; /* 0 1 2 3 */ + x += intop; /* 0 1 2 3 4 5 6, offset by 1 */ + x = (x&3)+(x>>2); /* 0 1 2 3, offset by 1 */ + x &= (x-3)>>5; /* 0 1 2, offset by 1 */ + input[i] = x - 1; + } + + r3_recip(output,input); + + for (i = 0;i < 512;++i) { + out[2*i] = output[i]; + out[2*i+1] = 0; + } +} diff --git a/avx2-hps2048509/poly_s3_inv.h b/avx2-hps2048509/poly_s3_inv.h new file mode 100644 index 0000000..be6d40c --- /dev/null +++ b/avx2-hps2048509/poly_s3_inv.h @@ -0,0 +1,8 @@ +#ifndef POLY_S3_INV_H +#define POLY_S3_INV_H + +#include "poly.h" + +void poly_S3_inv(poly *r, const poly *a); + +#endif
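A note on the constant-time swap pattern in the patch above: vec256_swap never branches on secret data. The swap decision is expanded into an all-zero or all-ones mask (swapvec), the masked XOR-difference of the operands is computed, and that difference is XORed back into both of them. A minimal scalar sketch of the same idiom, using uint64_t words instead of 256-bit vectors (the function and variable names here are illustrative, not part of the patch):

    #include <stdint.h>
    #include <stdio.h>

    /* Constant-time conditional swap: swapmask is either 0 (keep) or
       all ones (swap), mirroring swapmask/swapvec in r3_recip.
       The same instructions execute in both cases. */
    static void cond_swap64(uint64_t *f, uint64_t *g, uint64_t swapmask)
    {
      uint64_t flip = swapmask & (*f ^ *g);  /* 0, or the bit difference */
      *f ^= flip;
      *g ^= flip;
    }

    int main(void)
    {
      uint64_t a = 0x1234, b = 0x5678;
      cond_swap64(&a, &b, 0);             /* keeps a and b unchanged */
      cond_swap64(&a, &b, ~(uint64_t)0);  /* swaps a and b */
      printf("%llx %llx\n", (unsigned long long)a, (unsigned long long)b);
      return 0;
    }

The same masked-select idiom drives the delta update in the divstep loops: minusdelta ^= swapmask & (minusdelta ^ -minusdelta) picks between minusdelta and -minusdelta without a branch, before the decrement.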
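The input conversion in poly_S3_inv is also branch-free: each packed coefficient in {0,1,2} has the top coefficient subtracted modulo 3 and is then lifted to {-1,0,1}, with the /* 0 1 2 3 */ comments tracking the possible value range after each step. A scalar, self-checking sketch of the same arithmetic (lift_mod3 and the exhaustive test are illustrative only; the >> on negative values assumes arithmetic right shift, as the patch itself does):

    #include <assert.h>
    #include <stdio.h>

    /* Branch-free ((a - top) mod 3), lifted to {-1,0,1}. */
    static int lift_mod3(int a, int top)  /* a, top in {0,1,2} */
    {
      int t = 3 - top;        /* 1 2 3 */
      t &= (t - 3) >> 5;      /* 0 1 2, equals -top mod 3 */
      t += 1;                 /* offset by 1 */

      int x = a + t;          /* 1 .. 5, offset by 1 */
      x = (x & 3) + (x >> 2); /* 1 2 3, offset by 1 */
      x &= (x - 3) >> 5;      /* 0 1 2, offset by 1 */
      return x - 1;           /* -1 0 1 */
    }

    int main(void)
    {
      for (int a = 0; a < 3; ++a)
        for (int top = 0; top < 3; ++top) {
          int want = (a - top) % 3;  /* reference: reduce mod 3, then lift */
          if (want > 1) want -= 3;
          if (want < -1) want += 3;
          assert(lift_mod3(a, top) == want);
        }
      puts("ok");
      return 0;
    }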