Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More avx2 optimizations for hps2048509 #3

Merged
merged 5 commits into from
May 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions avx2-hps2048509/Makefile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Toolchain: compiler path and flags (all warnings, -O3, tuned for the build
# host via -march=native, position-independent executables disabled).
CC = /usr/bin/cc
CFLAGS = -Wall -Wextra -Wpedantic -O3 -fomit-frame-pointer -march=native -no-pie

# NOTE(review): SOURCES and HEADERS are each assigned twice below. make keeps
# the last assignment, so the second pair (which adds djbsort/sort.c,
# poly_s3_inv.c and their headers) is the one in effect.
SOURCES = crypto_sort.c poly.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c
HEADERS = crypto_sort.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h
SOURCES = crypto_sort.c djbsort/sort.c poly.c poly_s3_inv.c pack3.c packq.c fips202.c randombytes.c sample.c verify.c owcpa.c kem.c
HEADERS = crypto_sort.h djbsort/int32_sort.h poly_s3_inv.h params.h poly.h randombytes.h sample.h verify.h owcpa.h kem.h

# Assembly sources. Same pattern as above: the second OBJS assignment (adding
# poly_s3_mul.s and poly_rq_mul_x_minus_1.s) overrides the first.
OBJS = poly_rq_mul.s
OBJS = poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s

all: test/test_polymul \
test/test_ntru \
Expand Down
6 changes: 3 additions & 3 deletions avx2-hps2048509/Makefile-NIST
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ CC=/usr/bin/gcc
# Optimised build flags; the KAT generator links against libcrypto.
CFLAGS=-O3 -fomit-frame-pointer -march=native -no-pie
LDFLAGS=-lcrypto

# NOTE(review): SOURCES is assigned twice; make keeps the last assignment, so
# the second list (adding djbsort/sort.c, poly_s3_inv.c, poly_s3_mul.s and
# poly_rq_mul_x_minus_1.s) is the one in effect.
SOURCES = crypto_sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c sample.c verify.c rng.c PQCgenKAT_kem.c \
poly_rq_mul.s
SOURCES = crypto_sort.c djbsort/sort.c fips202.c kem.c owcpa.c pack3.c packq.c poly.c poly_s3_inv.c sample.c verify.c \
rng.c PQCgenKAT_kem.c poly_rq_mul.s poly_s3_mul.s poly_rq_mul_x_minus_1.s

# Likewise HEADERS: the second assignment (adding djbsort/int32_sort.h and
# poly_s3_inv.h) overrides the first.
HEADERS = api.h crypto_sort.h fips202.h kem.h poly.h owcpa.h params.h sample.h verify.h rng.h
HEADERS = api.h crypto_sort.h djbsort/int32_sort.h fips202.h kem.h poly.h poly_s3_inv.h owcpa.h params.h sample.h verify.h rng.h

# Single-step build: compile and link every source into the KAT generator.
# Listing the headers as prerequisites forces a rebuild when any of them changes.
PQCgenKAT_kem: $(HEADERS) $(SOURCES)
$(CC) $(CFLAGS) -o $@ $(SOURCES) $(LDFLAGS)
Expand Down
65 changes: 65 additions & 0 deletions avx2-hps2048509/asmgen/poly_rq_mul_x_minus_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from math import ceil

p = print

# Parameters of the hps2048509 parameter set this generator targets.
NTRU_N = 509                # number of coefficients in a polynomial
WORDS_PER_YMM = 16          # 16-bit coefficients per 256-bit ymm register
BYTES_PER_WORD = 2
BYTES_PER_YMM = WORDS_PER_YMM * BYTES_PER_WORD
COEFF_MASK = 2047           # keeps a 16-bit word reduced modulo 2048


def emit_constants():
    """Print the .data section: three word masks and one vpshufb table."""
    p(".data")
    p(".align 32")

    # All sixteen words kept modulo 2048.
    p("mask_mod2048:")
    for _ in range(WORDS_PER_YMM):
        p(".word {}".format(COEFF_MASK))

    # Word 0 cleared, words 1..15 kept modulo 2048.
    p("mask_mod2048_omit_lowest:")
    p(".word 0")
    for _ in range(WORDS_PER_YMM - 1):
        p(".word {}".format(COEFF_MASK))

    # Only word 0 kept modulo 2048, every other word cleared.
    p("mask_mod2048_only_lowest:")
    p(".word {}".format(COEFF_MASK))
    for _ in range(WORDS_PER_YMM - 1):
        p(".word 0")

    # vpshufb control: move word 5 (bytes 10 and 11) of the source lane into
    # word 0; a selector byte with its top bit set (255) zeroes the output.
    p("shuf_5_to_0_zerorest:")
    for i in range(2):
        p(".byte {}".format((i + 2*5) % 16))
    for _ in range(30):
        p(".byte 255")


def emit_poly_rq_mul_x_minus_1():
    """Print the complete assembly source of poly_Rq_mul_x_minus_1.

    The generated routine reads coefficients a[] through %rsi and writes
    r[] through %rdi, with r[k] = (a[k-1] - a[k]) mod 2048 for k >= 1 and
    r[0] = (a[NTRU_N-1] - a[0]) mod 2048.  Aligned moves (vmovdqa) are
    emitted for the stores and for the load of a[0..15], so both buffers
    are expected to be 32-byte aligned.
    """
    emit_constants()

    p(".text")
    p(".global poly_Rq_mul_x_minus_1")
    p(".att_syntax prefix")

    p("poly_Rq_mul_x_minus_1:")

    a_imin1 = 0
    t0 = 1
    t1 = 4
    last_chunk = ceil(NTRU_N / WORDS_PER_YMM) - 1
    # Walk the 16-word chunks from the top down to chunk 1.  Chunk 0 has no
    # a[-1] to load, so it is handled separately after the loop.
    for i in range(last_chunk, 0, -1):
        # Unaligned load shifted down by one word: a[16i-1 .. 16i+14].
        p("vmovdqu {}(%rsi), %ymm{}".format(
            (i*WORDS_PER_YMM - 1) * BYTES_PER_WORD, a_imin1))
        p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(
            i * BYTES_PER_YMM, a_imin1, t0))
        p("vpand mask_mod2048, %ymm{}, %ymm{}".format(t0, t0))
        p("vmovdqa %ymm{}, {}(%rdi)".format(t0, i * BYTES_PER_YMM))
        if i == last_chunk:
            # a_imin1 now holds a[495] .. a[510] (0-based).  r[0] needs
            # a[NTRU_N-1] = a[508], which is word 13 of the register, i.e.
            # word 5 of the upper 128-bit lane.
            p("vextracti128 $1, %ymm{}, %xmm{}".format(a_imin1, t1))
            p("vpshufb shuf_5_to_0_zerorest, %ymm{}, %ymm{}".format(t1, t1))
            p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(0, t1, t1))
            p("vpand mask_mod2048_only_lowest, %ymm{}, %ymm{}".format(t1, t1))

    # and now we still need to fix [1] to [15], which we cannot vmovdqu
    # (that load would start one word before the buffer).  Instead the
    # chunk is rotated up by one word inside the register.
    t2 = 0
    t3 = 2
    t4 = 3
    p("vmovdqa {}(%rsi), %ymm{}".format(0, t4))
    # Top word of every quadword, moved to the bottom of the next quadword.
    p("vpsrlq $48, %ymm{}, %ymm{}".format(t4, t2))
    p("vpermq ${}, %ymm{}, %ymm{}".format(int('10010011', 2), t2, t2))
    # Remaining three words of every quadword, shifted up by one word.
    p("vpsllq $16, %ymm{}, %ymm{}".format(t4, t3))
    p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t2, t3, t3))
    p("vpsubw %ymm{}, %ymm{}, %ymm{}".format(t4, t3, t4))
    p("vpand mask_mod2048_omit_lowest, %ymm{}, %ymm{}".format(t4, t4))
    # Merge in r[0], computed during the first loop iteration.
    p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t4, t1, t4))
    p("vmovdqa %ymm{}, {}(%rdi)".format(t4, 0))

    p("ret")


if __name__ == '__main__':
    emit_poly_rq_mul_x_minus_1()
81 changes: 81 additions & 0 deletions avx2-hps2048509/asmgen/poly_s3_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
p = print


def mod3(a, r=13, t=14, c=15):
    """Print AVX2 code reducing every 16-bit word of %ymm<a> modulo 3.

    The emitted sequence is branch-free: it folds the word down using
    256 = 16 = 4 = 1 (mod 3) until the value is at most 5, then subtracts
    3 when that does not go negative.  Each result word is 0, 1 or 2.

    Args:
        a: number of the ymm register holding the input.  It is also used
           as scratch, so its contents are destroyed.
        r: number of the ymm register that receives the result.
        t: number of a scratch ymm register.
        c: number of a scratch ymm register.

    Raises:
        ValueError: if the four register numbers are not pairwise distinct;
            the emitted code would overwrite a value it still needs.

    The code references the labels mask_ff, mask_f and mask_3, which the
    caller has to provide in the data section.
    """
    if len({a, r, t, c}) != 4:
        raise ValueError(
            "mod3 needs four distinct ymm registers, got "
            "a={}, r={}, t={}, c={}".format(a, r, t, c))

    # r = (a >> 8) + (a & 0xff); // r mod 255 == a mod 255
    p("vpsrlw $8, %ymm{}, %ymm{}".format(a, r))
    p("vpand mask_ff, %ymm{}, %ymm{}".format(a, a))
    p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

    # r = (r >> 4) + (r & 0xf); // r' mod 15 == r mod 15
    p("vpand mask_f, %ymm{}, %ymm{}".format(r, a))
    p("vpsrlw $4, %ymm{}, %ymm{}".format(r, r))
    p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

    # Twice: r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3
    # One round leaves values up to 13, the second brings them down to <= 5.
    for _ in range(2):
        p("vpand mask_3, %ymm{}, %ymm{}".format(r, a))
        p("vpsrlw $2, %ymm{}, %ymm{}".format(r, r))
        p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

    # t = r - 3;
    p("vpsubw mask_3, %ymm{}, %ymm{}".format(r, t))
    # c = t >> 15; t is signed, so shift arithmetic (c is all-ones iff t < 0)
    p("vpsraw $15, %ymm{}, %ymm{}".format(t, c))

    # The input register is dead by now and doubles as a temporary.
    tmp = a
    # return (c&r) ^ (~c&t);
    p("vpandn %ymm{}, %ymm{}, %ymm{}".format(t, c, tmp))
    p("vpand %ymm{}, %ymm{}, %ymm{}".format(c, r, t))
    p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t, tmp, r))

from math import ceil

# Reuses poly_Rq_mul to do the multiplication. This works out since none of the
# coefficients ever reach 2048 (they can be 509*4=2036 at most), so the
# reduction modulo 2048 done by poly_Rq_mul never changes a value.
# A presumably faster alternative implementation would be to do Karatsuba recursively five times
# such that you have 243 multiplications of 16 coefficient polynomials. These multiplications
# can be done using a bitsliced implementation of the schoolbook method since the coefficients are two bits.
# TODO: If poly_S3_mul shows up (significantly) in profiling consider the alternative implementation.

# Parameters of the hps2048509 parameter set this generator targets.
NTRU_N = 509            # number of coefficients in a polynomial
WORDS_PER_YMM = 16      # 16-bit coefficients per 256-bit ymm register
BYTES_PER_YMM = 32


def emit_poly_s3_mul():
    """Print the complete assembly source of poly_S3_mul.

    The generated routine calls poly_Rq_mul and then post-processes the
    product in place through %rdi: twice the coefficient r[NTRU_N-1] is
    added to every coefficient and each sum is reduced modulo 3 with the
    sequence emitted by mod3().
    """
    p(".data")
    p(".align 32")

    # Constants referenced by the code mod3() emits.
    p("mask_ff:")
    for _ in range(WORDS_PER_YMM):
        p(".word 0xff")
    p("mask_f:")
    for _ in range(WORDS_PER_YMM):
        p(".word 0xf")
    p("mask_3:")
    for _ in range(WORDS_PER_YMM):
        p(".word 0x03")

    p(".text")
    p(".global poly_S3_mul")
    p(".att_syntax prefix")

    p("poly_S3_mul:")
    p("call poly_Rq_mul")
    # result pointer *r is still in %rdi
    # NOTE(review): this relies on poly_Rq_mul leaving %rdi untouched, which
    # the calling convention alone does not promise -- confirm against
    # poly_rq_mul.s whenever that routine changes.

    n_chunks = ceil(NTRU_N / WORDS_PER_YMM)

    N_min_1 = 0
    t = 1
    # Coefficient r[NTRU_N-1] = r[508] is the 509th element: the 13th word
    # (index 12) of the 32nd 256-bit chunk, i.e. the lowest word of that
    # chunk's top quadword.
    p("vmovdqa {}(%rdi), %ymm{}".format((n_chunks - 1) * BYTES_PER_YMM, N_min_1))
    # bring the top quadword down so that r[508] sits in word 0
    p("vpermq ${}, %ymm{}, %ymm{}".format(int('00000011', 2), N_min_1, N_min_1))
    # move into high 16 in doubleword (to clear high 16) and multiply by two
    p("vpslld $17, %ymm{}, %ymm{}".format(N_min_1, N_min_1))
    # clone into bottom 16
    p("vpsrld $16, %ymm{}, %ymm{}".format(N_min_1, t))
    p("vpor %ymm{}, %ymm{}, %ymm{}".format(N_min_1, t, N_min_1))
    # and now it's everywhere in N_min_1
    p("vbroadcastss %xmm{}, %ymm{}".format(N_min_1, N_min_1))

    retval = 2
    for i in range(n_chunks):
        p("vpaddw {}(%rdi), %ymm{}, %ymm{}".format(i * BYTES_PER_YMM, N_min_1, t))
        mod3(t, retval)
        p("vmovdqa %ymm{}, {}(%rdi)".format(retval, i * BYTES_PER_YMM))

    p("ret")


if __name__ == '__main__':
    emit_poly_s3_mul()
1 change: 0 additions & 1 deletion avx2-hps2048509/crypto_sort.c

This file was deleted.

10 changes: 10 additions & 0 deletions avx2-hps2048509/crypto_sort.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include <stdint.h>
#include "crypto_sort.h"
#include "djbsort/int32_sort.h"

/*
 * crypto_sort: hand the n elements starting at `array` to int32_sort.
 *
 * Sorting is delegated to the avx2 version of djbsort (version 20180729);
 * its source can be found at https://sorting.cr.yp.to/
 */
void crypto_sort(void *array, long long n)
{
  int32_t *elements = array;

  int32_sort(elements, n);
}
13 changes: 13 additions & 0 deletions avx2-hps2048509/djbsort/int32_minmax_x86.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
 * int32_MINMAX(a,b): compare-exchange of two 32-bit lvalues using x86
 * conditional moves instead of a branch.
 *
 *   cmpl  %1,%0   sets the flags from a - b
 *   mov   %0,%2   saves the old a in temp1 (mov leaves the flags alone)
 *   cmovg %1,%0   if a > b (signed): a = b
 *   cmovg %2,%1   if a > b (signed): b = old a
 *
 * Afterwards a holds the smaller and b the larger of the two values.
 * Both arguments are read-write register operands ("+r"), temp1 is a
 * write-only scratch register, and the flags are declared clobbered.
 *
 * NOTE(review): the type `int32` is not declared in this file; it is
 * presumably a typedef supplied by whichever file includes this one.
 */
#define int32_MINMAX(a,b) \
do { \
int32 temp1; \
asm( \
"cmpl %1,%0\n\t" \
"mov %0,%2\n\t" \
"cmovg %1,%0\n\t" \
"cmovg %2,%1\n\t" \
: "+r"(a), "+r"(b), "=r"(temp1) \
: \
: "cc" \
); \
} while(0)
25 changes: 25 additions & 0 deletions avx2-hps2048509/djbsort/int32_sort.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
 * Public interface of the djbsort int32 sorting routine.
 */
#ifndef int32_sort_H
#define int32_sort_H

#include <stdint.h>

/*
 * The short names used in source code map onto djbsort_-prefixed symbols,
 * so the linker-visible names are djbsort_int32, djbsort_int32_version, ...
 */
#define int32_sort djbsort_int32
#define int32_sort_implementation djbsort_int32_implementation
#define int32_sort_version djbsort_int32_version
#define int32_sort_compiler djbsort_int32_compiler

/* Give the declarations C linkage when included from C++. */
#ifdef __cplusplus
extern "C" {
#endif

/*
 * Sort routine: takes a pointer to int32_t elements and a long long
 * (presumably the element count -- confirm against djbsort/sort.c).
 * Declared with default symbol visibility.
 */
extern void int32_sort(int32_t *,long long) __attribute__((visibility("default")));

/* Strings describing the build: implementation, version and compiler. */
extern const char int32_sort_implementation[] __attribute__((visibility("default")));
extern const char int32_sort_version[] __attribute__((visibility("default")));
extern const char int32_sort_compiler[] __attribute__((visibility("default")));

#ifdef __cplusplus
}
#endif

#endif
Loading