Skip to content

Commit

Permalink
Merge pull request #3 from OussamaDanba/avx2-hps2048509_optimizations
Browse files Browse the repository at this point in the history
More avx2 optimizations for hps2048509
  • Loading branch information
jschanck authored May 15, 2019
2 parents c0b6c57 + d47376c commit 9707e9c
Show file tree
Hide file tree
Showing 11 changed files with 1,836 additions and 112 deletions.
6 changes: 3 additions & 3 deletions avx2-hps2048509/Makefile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
CC = /usr/bin/cc
CFLAGS = -Wall -Wextra -Wpedantic -O3 -fomit-frame-pointer -march=native -no-pie

SOURCES = crypto_sort.c poly.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c
HEADERS = crypto_sort.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h
SOURCES = crypto_sort.c djbsort/sort.c poly.c poly_s3_inv.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c
HEADERS = crypto_sort.h djbsort/int32_sort.h poly_s3_inv.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h

OBJS = poly_rq_mul.s
OBJS = poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s

all: test/test_polymul \
test/test_ntru \
Expand Down
6 changes: 3 additions & 3 deletions avx2-hps2048509/Makefile-NIST
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ CC=/usr/bin/gcc
CFLAGS=-O3 -fomit-frame-pointer -march=native -no-pie
LDFLAGS=-lcrypto

SOURCES = crypto_sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c sample.c verify.c rng.c PQCgenKAT_kem.c \
poly_rq_mul.s
SOURCES = crypto_sort.c djbsort/sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c poly_s3_inv.c sample.c verify.c \
rng.c PQCgenKAT_kem.c poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s

HEADERS = api.h crypto_sort.h fips202.h kem.h poly.h owcpa.h params.h sample.h verify.h rng.h
HEADERS = api.h crypto_sort.h djbsort/int32_sort.h fips202.h kem.h poly.h poly_s3_inv.h owcpa.h params.h sample.h verify.h rng.h

PQCgenKAT_kem: $(HEADERS) $(SOURCES)
$(CC) $(CFLAGS) -o $@ $(SOURCES) $(LDFLAGS)
Expand Down
65 changes: 65 additions & 0 deletions avx2-hps2048509/asmgen/poly_rq_mul_x_minus_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from math import ceil

p = print

if __name__ == '__main__':
p(".data")
p(".align 32")

p("mask_mod2048:")
for i in range(16):
p(".word 2047")

p("mask_mod2048_omit_lowest:")
p(".word 0")
for i in range(15):
p(".word 2047")

p("mask_mod2048_only_lowest:")
p(".word 2047")
for i in range(15):
p(".word 0")

p("shuf_5_to_0_zerorest:")
for i in range(2):
p(".byte {}".format((i + 2*5) % 16))
for i in range(30):
p(".byte 255")

p(".text")
p(".global poly_Rq_mul_x_minus_1")
p(".att_syntax prefix")

p("poly_Rq_mul_x_minus_1:")

a_imin1 = 0
t0 = 1
t1 = 4
for i in range(ceil(509 / 16)-1, 0, -1):
p("vmovdqu {}(%rsi), %ymm{}".format((i*16 - 1) * 2, a_imin1))
p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(i * 32, a_imin1, t0))
p("vpand mask_mod2048, %ymm{}, %ymm{}".format(t0, t0))
p("vmovdqa %ymm{}, {}(%rdi)".format(t0, i*32))
if i == ceil(509 / 16)-1:
# a_imin1 now contains 495 to 510 inclusive;
# we need 509 for [0], which is at position 14
p("vextracti128 $1, %ymm{}, %xmm{}".format(a_imin1, t1))
p("vpshufb shuf_5_to_0_zerorest, %ymm{}, %ymm{}".format(t1, t1))
p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(0, t1, t1))
p("vpand mask_mod2048_only_lowest, %ymm{}, %ymm{}".format(t1, t1))

# and now we still need to fix [1] to [15], which we cannot vmovdqu
t2 = 0
t3 = 2
t4 = 3
p("vmovdqa {}(%rsi), %ymm{}".format(0, t4))
p("vpsrlq $48, %ymm{}, %ymm{}".format(t4, t2))
p("vpermq ${}, %ymm{}, %ymm{}".format(int('10010011', 2), t2, t2))
p("vpsllq $16, %ymm{}, %ymm{}".format(t4, t3))
p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t2, t3, t3))
p("vpsubw %ymm{}, %ymm{}, %ymm{}".format(t4, t3, t4))
p("vpand mask_mod2048_omit_lowest, %ymm{}, %ymm{}".format(t4, t4))
p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t4, t1, t4))
p("vmovdqa %ymm{}, {}(%rdi)".format(t4, 0))

p("ret")
81 changes: 81 additions & 0 deletions avx2-hps2048509/asmgen/poly_s3_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
p = print

def mod3(a, r=13, t=14, c=15):
# r = (a >> 8) + (a & 0xff); // r mod 255 == a mod 255
p("vpsrlw $8, %ymm{}, %ymm{}".format(a, r))
p("vpand mask_ff, %ymm{}, %ymm{}".format(a, a))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# r = (r >> 4) + (r & 0xf); // r' mod 15 == r mod 15
p("vpand mask_f, %ymm{}, %ymm{}".format(r, a))
p("vpsrlw $4, %ymm{}, %ymm{}".format(r, r))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3
# r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3
for _ in range(2):
p("vpand mask_3, %ymm{}, %ymm{}".format(r, a))
p("vpsrlw $2, %ymm{}, %ymm{}".format(r, r))
p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

# t = r - 3;
p("vpsubw mask_3, %ymm{}, %ymm{}".format(r, t))
# c = t >> 15; t is signed, so shift arithmetic
p("vpsraw $15, %ymm{}, %ymm{}".format(t, c))

tmp = a
# return (c&r) ^ (~c&t);
p("vpandn %ymm{}, %ymm{}, %ymm{}".format(t, c, tmp))
p("vpand %ymm{}, %ymm{}, %ymm{}".format(c, r, t))
p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t, tmp, r))

from math import ceil

# Reuses poly_Rq_mul to do the multiplication. This works out since none of the
# coefficients ever exceed 2048 (they can be 509*4=2036 at most).
# A presumably faster alternative implementation would be to do Karatsuba recursively five times
# such that you have 243 multiplications of 16 coefficient polynomials. These multiplications
# can be done using a bitsliced implementation of the schoolbook method since the coefficients are two bits.
# TODO: If poly_S3_mul shows up (significantly) in profiling consider the alternative implementation.
if __name__ == '__main__':
p(".data")
p(".align 32")

p("mask_ff:")
for i in range(16):
p(".word 0xff")
p("mask_f:")
for i in range(16):
p(".word 0xf")
p("mask_3:")
for i in range(16):
p(".word 0x03")

p(".text")
p(".global poly_S3_mul")
p(".att_syntax prefix")

p("poly_S3_mul:")
p("call poly_Rq_mul")
# result pointer *r is still in %rdi

N_min_1 = 0
t = 1
# NTRU_N is in 509th element; 13th word of 32nd register
p("vmovdqa {}(%rdi), %ymm{}".format(31*32, N_min_1))
p("vpermq ${}, %ymm{}, %ymm{}".format(int('00000011', 2), N_min_1, N_min_1))
# move into high 16 in doubleword (to clear high 16) and multiply by two
p("vpslld $17, %ymm{}, %ymm{}".format(N_min_1, N_min_1))
# clone into bottom 16
p("vpsrld $16, %ymm{}, %ymm{}".format(N_min_1, t))
p("vpor %ymm{}, %ymm{}, %ymm{}".format(N_min_1, t, N_min_1))
# and now it's everywhere in N_min_1
p("vbroadcastss %xmm{}, %ymm{}".format(N_min_1, N_min_1))

retval = 2
for i in range(ceil(509 / 16)):
p("vpaddw {}(%rdi), %ymm{}, %ymm{}".format(i * 32, N_min_1, t))
mod3(t, retval)
p("vmovdqa %ymm{}, {}(%rdi)".format(retval, i*32))

p("ret")
1 change: 0 additions & 1 deletion avx2-hps2048509/crypto_sort.c

This file was deleted.

10 changes: 10 additions & 0 deletions avx2-hps2048509/crypto_sort.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include <stdint.h>
#include "crypto_sort.h"
#include "djbsort/int32_sort.h"

// The avx2 version of djbsort (version 20180729) is used for sorting.
// The source can be found at https://sorting.cr.yp.to/
void crypto_sort(void *array,long long n)
{
int32_sort(array, n);
}
13 changes: 13 additions & 0 deletions avx2-hps2048509/djbsort/int32_minmax_x86.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#define int32_MINMAX(a,b) \
do { \
int32 temp1; \
asm( \
"cmpl %1,%0\n\t" \
"mov %0,%2\n\t" \
"cmovg %1,%0\n\t" \
"cmovg %2,%1\n\t" \
: "+r"(a), "+r"(b), "=r"(temp1) \
: \
: "cc" \
); \
} while(0)
25 changes: 25 additions & 0 deletions avx2-hps2048509/djbsort/int32_sort.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef int32_sort_H
#define int32_sort_H

#include <stdint.h>

#define int32_sort djbsort_int32
#define int32_sort_implementation djbsort_int32_implementation
#define int32_sort_version djbsort_int32_version
#define int32_sort_compiler djbsort_int32_compiler

#ifdef __cplusplus
extern "C" {
#endif

extern void int32_sort(int32_t *,long long) __attribute__((visibility("default")));

extern const char int32_sort_implementation[] __attribute__((visibility("default")));
extern const char int32_sort_version[] __attribute__((visibility("default")));
extern const char int32_sort_compiler[] __attribute__((visibility("default")));

#ifdef __cplusplus
}
#endif

#endif
Loading

0 comments on commit 9707e9c

Please sign in to comment.