-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from OussamaDanba/avx2-hps2048509_optimizations
More avx2 optimizations for hps2048509
- Loading branch information
Showing
11 changed files
with
1,836 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from math import ceil | ||
|
||
p = print

# Number of coefficients in the polynomial handled by the generated routine.
_NTRU_N = 509
# Number of 16-bit coefficients held by one 256-bit ymm register.
_WORDS_PER_YMM = 16


def _emit_word_table(label, words):
    """Print ``label:`` followed by one ``.word`` directive per entry."""
    p("{}:".format(label))
    for word in words:
        p(".word {}".format(word))


def _emit_data_section():
    """Print the 32-byte-aligned constant tables used by the routine."""
    p(".data")
    p(".align 32")

    # 2047 == 2**11 - 1, i.e. reduction modulo 2048 in every 16-bit word.
    _emit_word_table("mask_mod2048", [2047] * 16)
    _emit_word_table("mask_mod2048_omit_lowest", [0] + [2047] * 15)
    _emit_word_table("mask_mod2048_only_lowest", [2047] + [0] * 15)

    # vpshufb control vector: bytes 10 and 11 (16-bit word 5) are moved to
    # word 0; every other control byte is 255, whose set top bit makes
    # vpshufb write a zero byte.
    p("shuf_5_to_0_zerorest:")
    for byte_index in (2 * 5, 2 * 5 + 1):
        p(".byte {}".format(byte_index))
    for _ in range(30):
        p(".byte 255")


def _emit_routine():
    """Print the ``.text`` section containing ``poly_Rq_mul_x_minus_1``.

    The generated code reads 16-bit coefficients a[] through %rsi and writes
    r[] through %rdi, with r[k] = (a[k-1] - a[k]) & 2047 and a[508] taking
    the place of a[-1] for k == 0.
    """
    p(".text")
    p(".global poly_Rq_mul_x_minus_1")
    p(".att_syntax prefix")

    p("poly_Rq_mul_x_minus_1:")

    # Index of the highest 16-word block (31 for 509 coefficients).
    top_block = (_NTRU_N + _WORDS_PER_YMM - 1) // _WORDS_PER_YMM - 1

    a_imin1 = 0
    t0 = 1
    t1 = 4
    # Blocks are processed from the top down; block 0 is handled separately
    # below because a load starting at word -1 would be out of bounds.
    for i in range(top_block, 0, -1):
        # Unaligned load shifted down by one word, so lane j holds a[16*i+j-1].
        p("vmovdqu {}(%rsi), %ymm{}".format((i * 16 - 1) * 2, a_imin1))
        p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(i * 32, a_imin1, t0))
        p("vpand mask_mod2048, %ymm{}, %ymm{}".format(t0, t0))
        p("vmovdqa %ymm{}, {}(%rdi)".format(t0, i * 32))
        if i == top_block:
            # a_imin1 now holds coefficients 495..510 (0-based). The
            # wrap-around term for r[0] is a[508], which is word 13 of the
            # register, i.e. word 5 of the upper 128-bit half.
            p("vextracti128 $1, %ymm{}, %xmm{}".format(a_imin1, t1))
            p("vpshufb shuf_5_to_0_zerorest, %ymm{}, %ymm{}".format(t1, t1))
            p("vpsubw {}(%rsi), %ymm{}, %ymm{}".format(0, t1, t1))
            p("vpand mask_mod2048_only_lowest, %ymm{}, %ymm{}".format(t1, t1))

    # r[1]..r[15] cannot use the shifted vmovdqu, so the shift by one word is
    # built in registers: a 16-bit left shift inside every quadword, combined
    # with the top word of each quadword rotated into the next quadword.
    t2 = 0  # reuses the register of a_imin1, which is no longer needed
    t3 = 2
    t4 = 3
    p("vmovdqa {}(%rsi), %ymm{}".format(0, t4))
    p("vpsrlq $48, %ymm{}, %ymm{}".format(t4, t2))
    p("vpermq ${}, %ymm{}, %ymm{}".format(0b10010011, t2, t2))
    p("vpsllq $16, %ymm{}, %ymm{}".format(t4, t3))
    p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t2, t3, t3))
    p("vpsubw %ymm{}, %ymm{}, %ymm{}".format(t4, t3, t4))
    # Word 0 of this difference is meaningless; drop it and splice in the
    # value computed in t1 instead.
    p("vpand mask_mod2048_omit_lowest, %ymm{}, %ymm{}".format(t4, t4))
    p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t4, t1, t4))
    p("vmovdqa %ymm{}, {}(%rdi)".format(t4, 0))

    p("ret")


def main():
    """Print the complete assembly source for ``poly_Rq_mul_x_minus_1``."""
    _emit_data_section()
    _emit_routine()


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
p = print


def mod3(a, r=13, t=14, c=15):
    """Print AVX2 code that reduces every 16-bit word of ``%ymm<a>`` modulo 3.

    The words are treated as unsigned. They are folded down in steps that
    preserve the residue modulo 3 (mod 255, then mod 15, then twice mod 3),
    followed by one branch-free conditional subtraction of 3.

    Args:
        a: number of the ymm register holding the input. It is used as
            scratch space and does not hold the input afterwards.
        r: number of the ymm register that receives the result.
        t: number of a scratch ymm register.
        c: number of a scratch ymm register.

    Raises:
        ValueError: if any two of ``a``, ``r``, ``t`` and ``c`` are equal.
            The emitted sequence keeps all four live at the same time, so
            aliased registers would produce silently wrong code.

    The generated code references the constants ``mask_ff``, ``mask_f`` and
    ``mask_3``, which the caller has to emit.
    """
    if len({a, r, t, c}) != 4:
        raise ValueError(
            "mod3 needs four distinct registers, got a={}, r={}, t={}, c={}"
            .format(a, r, t, c))

    # r = (a >> 8) + (a & 0xff); // r mod 255 == a mod 255
    p("vpsrlw $8, %ymm{}, %ymm{}".format(a, r))
    p("vpand mask_ff, %ymm{}, %ymm{}".format(a, a))
    p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

    # r = (r >> 4) + (r & 0xf); // r' mod 15 == r mod 15
    p("vpand mask_f, %ymm{}, %ymm{}".format(r, a))
    p("vpsrlw $4, %ymm{}, %ymm{}".format(r, r))
    p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

    # r = (r >> 2) + (r & 0x3); // r' mod 3 == r mod 3
    # Done twice so that r ends up small enough for a single subtraction.
    for _ in range(2):
        p("vpand mask_3, %ymm{}, %ymm{}".format(r, a))
        p("vpsrlw $2, %ymm{}, %ymm{}".format(r, r))
        p("vpaddw %ymm{}, %ymm{}, %ymm{}".format(r, a, r))

    # t = r - 3;
    p("vpsubw mask_3, %ymm{}, %ymm{}".format(r, t))
    # c = t >> 15; t is signed, so shift arithmetic: c is all-ones iff r < 3
    p("vpsraw $15, %ymm{}, %ymm{}".format(t, c))

    # The input register is free by now and serves as a temporary.
    tmp = a
    # return (c&r) ^ (~c&t);
    p("vpandn %ymm{}, %ymm{}, %ymm{}".format(t, c, tmp))
    p("vpand %ymm{}, %ymm{}, %ymm{}".format(c, r, t))
    p("vpxor %ymm{}, %ymm{}, %ymm{}".format(t, tmp, r))
|
||
from math import ceil | ||
|
||
# Reuses poly_Rq_mul to do the multiplication. This works because none of the
# coefficients can reach 2048 (they are at most 509*4 = 2036).
# A presumably faster alternative would be to apply Karatsuba recursively five times,
# which yields 243 multiplications of 16-coefficient polynomials. Those multiplications
# could use a bitsliced implementation of the schoolbook method, since the coefficients are two bits wide.
# TODO: If poly_S3_mul shows up (significantly) in profiling, consider the alternative implementation.
if __name__ == '__main__':
    # Constant tables: every mask is replicated over the 16 words of a ymm
    # register. These are the constants referenced by the code mod3() prints.
    p(".data")
    p(".align 32")

    mask_tables = (
        ("mask_ff", "0xff"),
        ("mask_f", "0xf"),
        ("mask_3", "0x03"),
    )
    for label, literal in mask_tables:
        p(label + ":")
        for _ in range(16):
            p(".word " + literal)

    for directive in (".text", ".global poly_S3_mul", ".att_syntax prefix"):
        p(directive)

    p("poly_S3_mul:")
    p("call poly_Rq_mul")
    # The code below keeps addressing the result through %rdi, i.e. it relies
    # on poly_Rq_mul leaving the result pointer there.

    top_coeff = 0
    scratch = 1
    # The coefficient with 0-based index 508 is word 12 of the 32-byte block
    # at offset 31*32, which is the lowest word of that block's quadword 3.
    p("vmovdqa {}(%rdi), %ymm{}".format(31 * 32, top_coeff))
    # Bring quadword 3 down to quadword 0.
    p("vpermq ${}, %ymm{}, %ymm{}".format(0b00000011, top_coeff, top_coeff))
    # Shifting the doubleword left by 17 drops its upper word and leaves
    # twice the lower word in the upper half.
    p("vpslld $17, %ymm{}, %ymm{}".format(top_coeff, top_coeff))
    # Copy that value into the lower half as well ...
    p("vpsrld $16, %ymm{}, %ymm{}".format(top_coeff, scratch))
    p("vpor %ymm{}, %ymm{}, %ymm{}".format(top_coeff, scratch, top_coeff))
    # ... and replicate the doubleword across the whole register.
    p("vbroadcastss %xmm{}, %ymm{}".format(top_coeff, top_coeff))

    result = 2
    # One 32-byte block at a time: add the broadcast value, reduce modulo 3
    # and store the result back in place.
    for offset in range(0, ceil(509 / 16) * 32, 32):
        p("vpaddw {}(%rdi), %ymm{}, %ymm{}".format(offset, top_coeff, scratch))
        mod3(scratch, result)
        p("vmovdqa %ymm{}, {}(%rdi)".format(result, offset))

    p("ret")
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#include <stdint.h> | ||
#include "crypto_sort.h" | ||
#include "djbsort/int32_sort.h" | ||
|
||
// The avx2 version of djbsort (version 20180729) is used for sorting. | ||
// The source can be found at https://sorting.cr.yp.to/ | ||
/*
 * Sort the n 32-bit integers stored at array, in place.
 *
 * Thin adapter that forwards to int32_sort (the djbsort routine named in the
 * comment above), which takes an int32_t pointer and an element count.
 * NOTE(review): presumably the resulting order is ascending - confirm against
 * the djbsort documentation.
 */
void crypto_sort(void *array, long long n)
{
  int32_t *elements = (int32_t *)array;

  int32_sort(elements, n);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
/*
 * int32_MINMAX(a,b): branch-free compare-and-exchange of two 32-bit values.
 *
 * Afterwards a holds the smaller and b the larger of the two values (signed
 * comparison, cmovg). Both arguments must be modifiable lvalues, because they
 * are read-write ("+r") asm operands.
 *
 *   cmpl   b,a     flags <- a - b
 *   mov    a,tmp   remember the old a
 *   cmovg  b,a     if a > b: a <- b
 *   cmovg  tmp,b   if a > b: b <- old a
 */
#define int32_MINMAX(a,b) \
do { \
  int32 temp1; \
  asm( \
    "cmpl %[minmax_hi],%[minmax_lo]\n\t" \
    "mov %[minmax_lo],%[minmax_tmp]\n\t" \
    "cmovg %[minmax_hi],%[minmax_lo]\n\t" \
    "cmovg %[minmax_tmp],%[minmax_hi]\n\t" \
    : [minmax_lo] "+r"(a), [minmax_hi] "+r"(b), [minmax_tmp] "=r"(temp1) \
    : \
    : "cc" \
  ); \
} while(0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#if !defined(int32_sort_H)
#define int32_sort_H

#include <stdint.h>

/* The generic int32_sort names are aliases for the djbsort_int32 symbols. */
#define int32_sort                djbsort_int32
#define int32_sort_implementation djbsort_int32_implementation
#define int32_sort_version        djbsort_int32_version
#define int32_sort_compiler       djbsort_int32_compiler

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Descriptive strings exported next to the sort routine.
 * NOTE(review): their contents are defined outside this header; presumably
 * they name the implementation, its version and the compiler used.
 */
extern const char int32_sort_implementation[] __attribute__((visibility("default")));
extern const char int32_sort_version[] __attribute__((visibility("default")));
extern const char int32_sort_compiler[] __attribute__((visibility("default")));

/* Sort routine: takes a pointer to int32_t elements and a long long count. */
extern void int32_sort(int32_t *, long long) __attribute__((visibility("default")));

#ifdef __cplusplus
}
#endif

#endif /* int32_sort_H */
Oops, something went wrong.