Skip to content

Commit

Permalink
Implement avx2 version of poly_Rq_to_S3 for hps2048509
Browse files Browse the repository at this point in the history
Explanation:
Straightforward conversion of the C function but using avx2
to operate on 16 coefficients at a time rather than one.

Results:
On an Intel i5-8250u using gcc 8.3.0 the reference poly_Rq_to_S3
takes about 2300 cycles on average whereas the avx2 version takes about
205 cycles on average. This is about an 11 times speedup.
  • Loading branch information
OussamaDanba committed Jun 3, 2019
1 parent 08529b6 commit c5bd4ae
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 19 deletions.
2 changes: 1 addition & 1 deletion avx2-hps2048509/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ OBJS = square_1_509_patience.s \
square_63_509_shufbytes.s \
square_126_509_shufbytes.s \
square_252_509_shufbytes.s
OBJS += poly_rq_mul.s poly_r2_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s
OBJS += poly_rq_mul.s poly_r2_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s poly_rq_to_s3.s

all: test/test_polymul \
test/test_ntru \
Expand Down
1 change: 1 addition & 0 deletions avx2-hps2048509/Makefile-NIST
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ LDFLAGS=-lcrypto

SOURCES = crypto_sort.c djbsort/sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c poly_r2_inv.c poly_s3_inv.c \
sample.c verify.c rng.c PQCgenKAT_kem.c poly_rq_mul.s poly_r2_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s \
poly_rq_to_s3.s \
square_1_509_patience.s \
square_3_509_patience.s \
square_6_509_patience.s \
Expand Down
96 changes: 96 additions & 0 deletions avx2-hps2048509/asmgen/poly_rq_to_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@

from math import ceil

# Short alias: every line of generated assembly is written to stdout via p().
p = print


def mod3(a, r=13, t=14, c=15):
    """Emit AVX2 code that reduces every 16-bit word of %ymm<a> modulo 3.

    All arguments are ymm register numbers.

    a -- input register; it is overwritten (used as scratch) by the sequence.
    r -- register that holds the result, each word in {0, 1, 2}.
         Must differ from `a`, since `r` is written before `a` is last read.
    t -- scratch register (holds r - 3, later the masked r).
    c -- scratch register (holds the sign mask of r - 3).

    The generated code references the memory constants mask_ff, mask_f and
    mask_3, which therefore have to be defined in the emitted assembly.
    """
    # One round of "r = (r >> 2) + (r & 0x3)"; since 4 == 1 (mod 3) this keeps
    # r mod 3 unchanged while shrinking r.
    fold_two_bits = [
        "vpand mask_3, %ymm{}, %ymm{}".format(r, a),
        "vpsrlw $2, %ymm{}, %ymm{}".format(r, r),
        "vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r),
    ]

    # After the folds `a` is no longer needed as input and serves as scratch.
    scratch = a

    asm = [
        # r = (a >> 8) + (a & 0xff);  r mod 255 == a mod 255
        "vpsrlw $8, %ymm{}, %ymm{}".format(a, r),
        "vpand mask_ff, %ymm{}, %ymm{}".format(a, a),
        "vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r),
        # r = (r >> 4) + (r & 0xf);  r' mod 15 == r mod 15
        "vpand mask_f, %ymm{}, %ymm{}".format(r, a),
        "vpsrlw $4, %ymm{}, %ymm{}".format(r, r),
        "vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r),
    ]

    # Two rounds of the 2-bit fold bring r down to a value below 6.
    asm += fold_two_bits + fold_two_bits

    asm += [
        # t = r - 3
        "vpsubw mask_3, %ymm{}, %ymm{}".format(r, t),
        # c = t >> 15 (arithmetic): all-ones where t is negative, else zero
        "vpsraw $15, %ymm{}, %ymm{}".format(t, c),
        # r = (c & r) ^ (~c & t): keep r where r < 3, otherwise take r - 3
        "vpandn %ymm{}, %ymm{}, %ymm{}".format(t, c, scratch),
        "vpand %ymm{}, %ymm{}, %ymm{}".format(c, r, t),
        "vpxor %ymm{}, %ymm{}, %ymm{}".format(t, scratch, r),
    ]

    for instruction in asm:
        p(instruction)


if __name__ == '__main__':
    # Generates the AVX2 assembly for poly_Rq_to_S3 on stdout.
    # The routine reads the input polynomial through %rsi and writes the
    # result through %rdi, 16 coefficients (one ymm register) at a time.
    words_per_vec = 16
    bytes_per_vec = 32
    n_vecs = ceil(509 / words_per_vec)

    p(".data")
    p(".align 32")

    # Each table fills exactly one 32-byte vector: the listed directives are
    # repeated once per 16-bit lane.
    constant_tables = [
        ("const_3_repeating", [".word 0x3"]),
        # vpshufb control: every word takes byte 8 as its low byte and zero
        # (index 255 has the top bit set) as its high byte.
        ("shuf_b8_to_low_doubleword", [".byte 8", ".byte 255"]),
        ("mask_ff", [".word 0xff"]),
        ("mask_f", [".word 0xf"]),
        ("mask_3", [".word 0x03"]),
    ]
    for label, lane_directives in constant_tables:
        p("{}:".format(label))
        for _ in range(words_per_vec):
            for directive in lane_directives:
                p(directive)

    p(".text")
    p(".global poly_Rq_to_S3")
    p(".att_syntax prefix")

    p("poly_Rq_to_S3:")

    # ymm register allocation
    reg_tmp = 0
    reg_in = 1
    reg_threes = 3
    reg_last = 4
    reg_out = 5

    def center(src, scratch, dst):
        # dst = src + (((src >> 10) ^ 3) << 11), i.e. the same centering as
        # the C reference "((a >> (LOGQ-1)) ^ 3) << LOGQ; += a" with LOGQ = 11.
        p("vpsrlw $10, %ymm{}, %ymm{}".format(src, scratch))
        p("vpxor %ymm{}, %ymm{}, %ymm{}".format(reg_threes, scratch, scratch))
        p("vpsllw $11, %ymm{}, %ymm{}".format(scratch, scratch))
        p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(src, scratch, dst))

    p("vmovdqa const_3_repeating, %ymm{}".format(reg_threes))

    # First handle the final vector on its own to obtain 2 * mod3(last
    # coefficient), which is added to every coefficient below.
    p("vmovdqa {}(%rsi), %ymm{}".format((n_vecs - 1) * bytes_per_vec, reg_last))
    center(reg_last, reg_tmp, reg_last)

    mod3(reg_last, reg_out)
    p("vpsllw $1, %ymm{}, %ymm{}".format(reg_out, reg_last))
    # The final vector starts at coefficient 496; coefficient 508 is word 12,
    # i.e. byte 8 of the upper 128-bit lane. Move that lane down, broadcast
    # byte 8 into every word, then copy the low lane into the high lane.
    p("vextracti128 $1, %ymm{}, %xmm{}".format(reg_last, reg_last))
    p("vpshufb shuf_b8_to_low_doubleword, %ymm{}, %ymm{}".format(reg_last, reg_last))
    p("vinserti128 $1, %xmm{}, %ymm{}, %ymm{}".format(reg_last, reg_last, reg_last))

    # Main pass: center, add the broadcast term, reduce mod 3, store.
    for offset in range(0, n_vecs * bytes_per_vec, bytes_per_vec):
        p("vmovdqa {}(%rsi), %ymm{}".format(offset, reg_in))
        center(reg_in, reg_tmp, reg_tmp)
        p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(reg_last, reg_tmp, reg_tmp))
        mod3(reg_tmp, reg_out)
        p("vmovdqa %ymm{}, {}(%rdi)".format(reg_out, offset))

    p("ret")
19 changes: 1 addition & 18 deletions avx2-hps2048509/poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
extern void poly_Rq_mul(poly *r, const poly *a, const poly *b);
extern void poly_S3_mul(poly *r, const poly *a, const poly *b);
extern void poly_Rq_mul_x_minus_1(poly *r, const poly *a);
extern void poly_Rq_to_S3(poly *r, const poly *a);

uint16_t mod3(uint16_t a)
{
Expand Down Expand Up @@ -56,24 +57,6 @@ void poly_lift(poly *r, const poly *a)
poly_Z3_to_Zq(r);
}

/*
 * Reduce a polynomial with coefficients mod Q to one with coefficients mod 3,
 * reduced mod (3, Phi).
 *
 * r: output polynomial; on return every coefficient is in {0,1,2}.
 * a: input polynomial; coefficients are assumed to be in [0, Q-1].
 *
 * r and a may point to the same polynomial: each centered coefficient is
 * computed from the input in a single expression before it is stored, so the
 * input value is never overwritten while it is still needed.
 */
void poly_Rq_to_S3(poly *r, const poly *a)
{
  int i;

  /* Center coeffs around 3Q: [0, Q-1] -> [3Q - Q/2, 3Q + Q/2).
   * The top bit of the input selects the offset: 3Q for the lower half of
   * the range, 2Q (= 3Q - Q) for the upper half. */
  for(i=0; i<NTRU_N; i++)
  {
    r->coeffs[i] = a->coeffs[i] + (((a->coeffs[i] >> (NTRU_LOGQ-1)) ^ 3) << NTRU_LOGQ);
  }

  /* Reduce mod (3, Phi): the last coefficient is reduced first because the
   * loop below adds twice its value to every coefficient. */
  r->coeffs[NTRU_N-1] = mod3(r->coeffs[NTRU_N-1]);
  for(i=0; i<NTRU_N; i++)
  {
    r->coeffs[i] = mod3(r->coeffs[i] + 2*r->coeffs[NTRU_N-1]);
  }
}

static void poly_R2_inv_to_Rq_inv(poly *r, const poly *ai, const poly *a)
{
#if NTRU_Q <= 256 || NTRU_Q >= 65536
Expand Down

0 comments on commit c5bd4ae

Please sign in to comment.