From c5bd4ae0f5d1bd0c6fadaecf74598372965a181a Mon Sep 17 00:00:00 2001 From: OussamaDanba Date: Mon, 3 Jun 2019 13:25:54 +0200 Subject: [PATCH] Implement avx2 version of poly_Rq_to_S3 for hps2048509 Explanation: Straightforward conversion of the C function but using avx2 to operate on 16 coefficients at a time rather than one. Results: On an Intel i5-8250u using gcc 8.3.0 the reference poly_Rq_to_S3 takes about 2300 cycles on average whereas the avx2 version takes about 205 cycles on average. This is about an 11 times speedup. --- avx2-hps2048509/Makefile | 2 +- avx2-hps2048509/Makefile-NIST | 1 + avx2-hps2048509/asmgen/poly_rq_to_s3.py | 96 +++++++++++++++++++++++++ avx2-hps2048509/poly.c | 19 +---- 4 files changed, 99 insertions(+), 19 deletions(-) create mode 100644 avx2-hps2048509/asmgen/poly_rq_to_s3.py diff --git a/avx2-hps2048509/Makefile b/avx2-hps2048509/Makefile index 981ff8d..5ded1d2 100644 --- a/avx2-hps2048509/Makefile +++ b/avx2-hps2048509/Makefile @@ -13,7 +13,7 @@ OBJS = square_1_509_patience.s \ square_63_509_shufbytes.s \ square_126_509_shufbytes.s \ square_252_509_shufbytes.s -OBJS += poly_rq_mul.s poly_r2_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s +OBJS += poly_rq_mul.s poly_r2_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s poly_rq_to_s3.s all: test/test_polymul \ test/test_ntru \ diff --git a/avx2-hps2048509/Makefile-NIST b/avx2-hps2048509/Makefile-NIST index 99b7603..231cc17 100644 --- a/avx2-hps2048509/Makefile-NIST +++ b/avx2-hps2048509/Makefile-NIST @@ -4,6 +4,7 @@ LDFLAGS=-lcrypto SOURCES = crypto_sort.c djbsort/sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c poly_r2_inv.c poly_s3_inv.c \ sample.c verify.c rng.c PQCgenKAT_kem.c poly_rq_mul.s poly_r2_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s \ + poly_rq_to_s3.s \ square_1_509_patience.s \ square_3_509_patience.s \ square_6_509_patience.s \ diff --git a/avx2-hps2048509/asmgen/poly_rq_to_s3.py b/avx2-hps2048509/asmgen/poly_rq_to_s3.py new file mode 100644 index 0000000..46605de --- /dev/null +++ b/avx2-hps2048509/asmgen/poly_rq_to_s3.py @@ -0,0 +1,96 @@ + +from math import ceil + +p = print + + +def mod3(a, r=13, t=14, c=15): + # r = (a >> 8) + (a & 0xff); // r mod 255 == a mod 255 + p("vpsrlw $8, %ymm{}, %ymm{}".format(a, r)) + p("vpand mask_ff, %ymm{}, %ymm{}".format(a, a)) + p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r)) + + # r = (r >> 4) + (r & 0xf); // r' mod 15 == r mod 15 + p("vpand mask_f, %ymm{}, %ymm{}".format(r, a)) + p("vpsrlw $4, %ymm{}, %ymm{}".format(r, r)) + p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r)) + + # r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3 + # r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3 + for _ in range(2): + p("vpand mask_3, %ymm{}, %ymm{}".format(r, a)) + p("vpsrlw $2, %ymm{}, %ymm{}".format(r, r)) + p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r)) + + # t = r - 3; + p("vpsubw mask_3, %ymm{}, %ymm{}".format(r, t)) + # c = t >> 15; t is signed, so shift arithmetic + p("vpsraw $15, %ymm{}, %ymm{}".format(t, c)) + + tmp = a + # return (c&r) ^ (~c&t); + p("vpandn %ymm{}, %ymm{}, %ymm{}".format(t, c, tmp)) + p("vpand %ymm{}, %ymm{}, %ymm{}".format(c, r, t)) + p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t, tmp, r)) + + +if __name__ == '__main__': + p(".data") + p(".align 32") + + p("const_3_repeating:") + for i in range(16): + p(".word 0x3") + + p("shuf_b8_to_low_doubleword:") + for j in range(16): + p(".byte 8") + p(".byte 255") + + p("mask_ff:") + for i in range(16): + p(".word 0xff") + + p("mask_f:") + for i in range(16): + p(".word 0xf") + p("mask_3:") + for i in 
range(16):
        p(".word 0x03")

    p(".text")
    p(".global poly_Rq_to_S3")
    p(".att_syntax prefix")

    p("poly_Rq_to_S3:")

    r = 0
    a = 1
    threes = 3
    last = 4
    retval = 5
    p("vmovdqa const_3_repeating, %ymm{}".format(threes))
    p("vmovdqa {}(%rsi), %ymm{}".format((ceil(509 / 16) - 1)*32, last))

    p("vpsrlw $10, %ymm{}, %ymm{}".format(last, r))
    p("vpxor %ymm{}, %ymm{}, %ymm{}".format(threes, r, r))
    p("vpsllw $11, %ymm{}, %ymm{}".format(r, r))
    p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(last, r, last))

    mod3(last, retval)
    p("vpsllw $1, %ymm{}, %ymm{}".format(retval, last))
    p("vextracti128 $1, %ymm{}, %xmm{}".format(last, last))
    p("vpshufb shuf_b8_to_low_doubleword, %ymm{}, %ymm{}".format(last, last))
    p("vinserti128 $1, %xmm{}, %ymm{}, %ymm{}".format(last, last, last))

    for i in range(ceil(509 / 16)):
        p("vmovdqa {}(%rsi), %ymm{}".format(i*32, a))
        p("vpsrlw $10, %ymm{}, %ymm{}".format(a, r))
        p("vpxor %ymm{}, %ymm{}, %ymm{}".format(threes, r, r))
        p("vpsllw $11, %ymm{}, %ymm{}".format(r, r))
        p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(a, r, r))
        p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(last, r, r))
        mod3(r, retval)
        p("vmovdqa %ymm{}, {}(%rdi)".format(retval, i*32))

    p("ret")
diff --git a/avx2-hps2048509/poly.c b/avx2-hps2048509/poly.c
index dbb5237..f6a0953 100644
--- a/avx2-hps2048509/poly.c
+++ b/avx2-hps2048509/poly.c
@@ -7,6 +7,7 @@
 extern void poly_Rq_mul(poly *r, const poly *a, const poly *b);
 extern void poly_S3_mul(poly *r, const poly *a, const poly *b);
 extern void poly_Rq_mul_x_minus_1(poly *r, const poly *a);
+extern void poly_Rq_to_S3(poly *r, const poly *a);
 
 uint16_t mod3(uint16_t a)
 {
@@ -56,24 +57,6 @@ void poly_lift(poly *r, const poly *a)
   poly_Z3_to_Zq(r);
 }
 
-void poly_Rq_to_S3(poly *r, const poly *a)
-{
-  /* NOTE: Assumes input is in [0,Q-1]^N */
-  /* Produces output in {0,1,2}^N */
-  int i;
-
-  /* Center coeffs around 3Q: [0, Q-1] -> [3Q - Q/2, 3Q + Q/2) */
-  for(i=0; i<NTRU_N; i++)
-  {
-    r->coeffs[i] = ((a->coeffs[i] >> (NTRU_LOGQ-1)) ^ 3) << NTRU_LOGQ;
-    r->coeffs[i] += a->coeffs[i];
-  }
-  /* Reduce mod (3, Phi) */
-  r->coeffs[NTRU_N-1] = mod3(r->coeffs[NTRU_N-1]);
-  for(i=0; i<NTRU_N; i++)
-    r->coeffs[i] = mod3(r->coeffs[i] + 2*r->coeffs[NTRU_N-1]);
-}
-
 static void poly_R2_inv_to_Rq_inv(poly *r, const poly *ai, const poly *a)
 {
 #if NTRU_Q <= 256 || NTRU_Q >= 65536
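
Note (reviewer aside, not part of the patch): the generated assembly is easier to follow next to a C rendering of the same idea. The sketch below re-expresses, with AVX2 intrinsics, the coefficient centering and the byte-folding mod-3 reduction that poly_rq_to_s3.py emits (the vpsrlw $10 / vpxor / vpsllw $11 / vpaddw prologue and the vpsrlw $8 / vpand / vpaddw folding chain). The function names center_x16 and mod3_x16 are made up for this note and do not appear in the repository; NTRU_LOGQ is 11 for hps2048509, which is where the shift counts 10 and 11 come from.

#include <immintrin.h>

/* Center 16 coefficients around 3Q, mirroring the removed C code:
 * r = ((a >> (NTRU_LOGQ-1)) ^ 3) << NTRU_LOGQ; r += a;   with NTRU_LOGQ == 11 */
static __m256i center_x16(__m256i a)
{
    const __m256i three = _mm256_set1_epi16(3);
    __m256i r = _mm256_xor_si256(_mm256_srli_epi16(a, 10), three);
    r = _mm256_slli_epi16(r, 11);
    return _mm256_add_epi16(a, r);
}

/* Byte-folding reduction mod 3 on 16 unsigned 16-bit lanes, mirroring mod3() */
static __m256i mod3_x16(__m256i a)
{
    const __m256i mask_ff = _mm256_set1_epi16(0x00ff);
    const __m256i mask_f  = _mm256_set1_epi16(0x000f);
    const __m256i mask_3  = _mm256_set1_epi16(0x0003);
    __m256i r, t, c;

    /* r = (a >> 8) + (a & 0xff);  r mod 255 == a mod 255 */
    r = _mm256_add_epi16(_mm256_srli_epi16(a, 8), _mm256_and_si256(a, mask_ff));
    /* r = (r >> 4) + (r & 0xf);   r' mod 15 == r mod 15 */
    r = _mm256_add_epi16(_mm256_srli_epi16(r, 4), _mm256_and_si256(r, mask_f));
    /* r = (r >> 2) + (r & 0x3);   r' mod 3 == r mod 3, applied twice */
    r = _mm256_add_epi16(_mm256_srli_epi16(r, 2), _mm256_and_si256(r, mask_3));
    r = _mm256_add_epi16(_mm256_srli_epi16(r, 2), _mm256_and_si256(r, mask_3));

    /* t = r - 3; c = t >> 15 (arithmetic); return (c & r) ^ (~c & t),
     * i.e. a constant-time conditional subtraction of 3 */
    t = _mm256_sub_epi16(r, mask_3);
    c = _mm256_srai_epi16(t, 15);
    return _mm256_xor_si256(_mm256_and_si256(c, r),
                            _mm256_andnot_si256(c, t));
}

The remaining piece, the reduction mod (3, Phi), is the prologue before the main loop: it computes twice the mod-3 reduction of the centered last coefficient (coeffs[508]) and broadcasts it to all 16 lanes via vpsllw $1, vextracti128, vpshufb shuf_b8_to_low_doubleword and vinserti128, so the loop only has to add that value to every block before the final mod3.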