From 8be98a28801a87270db542d5e415ac1cb5fae3bb Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Wed, 7 Feb 2024 11:25:28 +0100 Subject: [PATCH 1/8] ntt/parameters.cuh: pack device twiddles to an array. --- ntt/kernels/ct_mixed_radix_narrow.cu | 20 ++------------------ ntt/kernels/ct_mixed_radix_wide.cu | 7 +------ ntt/kernels/gs_mixed_radix_narrow.cu | 20 ++------------------ ntt/kernels/gs_mixed_radix_wide.cu | 16 +++++----------- ntt/parameters.cuh | 22 +++++++++++++--------- 5 files changed, 23 insertions(+), 62 deletions(-) diff --git a/ntt/kernels/ct_mixed_radix_narrow.cu b/ntt/kernels/ct_mixed_radix_narrow.cu index b01dce9..ca933ca 100644 --- a/ntt/kernels/ct_mixed_radix_narrow.cu +++ b/ntt/kernels/ct_mixed_radix_narrow.cu @@ -207,28 +207,12 @@ public: assert(num_blocks == (unsigned int)num_blocks); - fr_t* d_radixX_twiddles = nullptr; - - switch (radix) { - case 7: - d_radixX_twiddles = ntt_parameters.radix7_twiddles; - break; - case 8: - d_radixX_twiddles = ntt_parameters.radix8_twiddles; - break; - case 9: - d_radixX_twiddles = ntt_parameters.radix9_twiddles; - break; - case 10: - d_radixX_twiddles = ntt_parameters.radix10_twiddles; - break; - } - const int Z_COUNT = 256/8/sizeof(fr_t); size_t shared_sz = sizeof(fr_t) << (radix - 1); + #define NTT_ARGUMENTS radix, lg_domain_size, stage, iterations, \ d_inout, ntt_parameters.partial_twiddles, \ - ntt_parameters.radix6_twiddles, d_radixX_twiddles, \ + ntt_parameters.twiddles[0], ntt_parameters.twiddles[radix-6], \ is_intt, domain_size_inverse[lg_domain_size] if (num_blocks < Z_COUNT) diff --git a/ntt/kernels/ct_mixed_radix_wide.cu b/ntt/kernels/ct_mixed_radix_wide.cu index 8934218..9d66577 100644 --- a/ntt/kernels/ct_mixed_radix_wide.cu +++ b/ntt/kernels/ct_mixed_radix_wide.cu @@ -163,7 +163,6 @@ public: assert(num_blocks == (unsigned int)num_blocks); - fr_t* d_radixX_twiddles = nullptr; fr_t* d_intermediate_twiddles = nullptr; unsigned int intermediate_twiddle_shift = 0; @@ -172,7 +171,7 @@ public: #define NTT_ARGUMENTS radix, lg_domain_size, stage, iterations, \ d_inout, ntt_parameters.partial_twiddles, \ - ntt_parameters.radix6_twiddles, d_radixX_twiddles, \ + ntt_parameters.twiddles[0], ntt_parameters.twiddles[radix-6], \ d_intermediate_twiddles, intermediate_twiddle_shift, \ is_intt, domain_size_inverse[lg_domain_size] @@ -198,7 +197,6 @@ public: } break; case 7: - d_radixX_twiddles = ntt_parameters.radix7_twiddles; switch (stage) { case 0: _CT_NTT<0><<>>(NTT_ARGUMENTS); @@ -214,7 +212,6 @@ public: } break; case 8: - d_radixX_twiddles = ntt_parameters.radix8_twiddles; switch (stage) { case 0: _CT_NTT<0><<>>(NTT_ARGUMENTS); @@ -230,7 +227,6 @@ public: } break; case 9: - d_radixX_twiddles = ntt_parameters.radix9_twiddles; switch (stage) { case 0: _CT_NTT<0><<>>(NTT_ARGUMENTS); @@ -246,7 +242,6 @@ public: } break; case 10: - d_radixX_twiddles = ntt_parameters.radix10_twiddles; switch (stage) { case 0: _CT_NTT<0><<>>(NTT_ARGUMENTS); diff --git a/ntt/kernels/gs_mixed_radix_narrow.cu b/ntt/kernels/gs_mixed_radix_narrow.cu index 365debc..e148436 100644 --- a/ntt/kernels/gs_mixed_radix_narrow.cu +++ b/ntt/kernels/gs_mixed_radix_narrow.cu @@ -210,28 +210,12 @@ public: assert(num_blocks == (unsigned int)num_blocks); - fr_t* d_radixX_twiddles = nullptr; - - switch (radix) { - case 7: - d_radixX_twiddles = ntt_parameters.radix7_twiddles; - break; - case 8: - d_radixX_twiddles = ntt_parameters.radix8_twiddles; - break; - case 9: - d_radixX_twiddles = ntt_parameters.radix9_twiddles; - break; - case 10: - d_radixX_twiddles = ntt_parameters.radix10_twiddles; - break; - } - const int Z_COUNT = 256/8/sizeof(fr_t); size_t shared_sz = sizeof(fr_t) << (radix - 1); + #define NTT_ARGUMENTS radix, lg_domain_size, stage, iterations, \ d_inout, ntt_parameters.partial_twiddles, \ - ntt_parameters.radix6_twiddles, d_radixX_twiddles, \ + ntt_parameters.twiddles[0], ntt_parameters.twiddles[radix-6], \ is_intt, domain_size_inverse[lg_domain_size] if (num_blocks < Z_COUNT) diff --git a/ntt/kernels/gs_mixed_radix_wide.cu b/ntt/kernels/gs_mixed_radix_wide.cu index 00ccac2..23d7336 100644 --- a/ntt/kernels/gs_mixed_radix_wide.cu +++ b/ntt/kernels/gs_mixed_radix_wide.cu @@ -159,19 +159,17 @@ public: assert(num_blocks == (unsigned int)num_blocks); - fr_t* d_radixX_twiddles = nullptr; fr_t* d_intermediate_twiddles = nullptr; int intermediate_twiddle_shift = 0; #define NTT_CONFIGURATION \ num_blocks, block_size, sizeof(fr_t) * block_size, stream - #define NTT_ARGUMENTS \ - radix, lg_domain_size, stage, iterations, d_inout, \ - ntt_parameters.partial_twiddles, ntt_parameters.radix6_twiddles, \ - d_radixX_twiddles, d_intermediate_twiddles, \ - intermediate_twiddle_shift, \ - is_intt, domain_size_inverse[lg_domain_size] + #define NTT_ARGUMENTS radix, lg_domain_size, stage, iterations, \ + d_inout, ntt_parameters.partial_twiddles, \ + ntt_parameters.twiddles[0], ntt_parameters.twiddles[radix-6], \ + d_intermediate_twiddles, intermediate_twiddle_shift, \ + is_intt, domain_size_inverse[lg_domain_size] switch (radix) { case 6: @@ -195,7 +193,6 @@ public: } break; case 7: - d_radixX_twiddles = ntt_parameters.radix7_twiddles; switch (stage) { case 7: _GS_NTT<0><<>>(NTT_ARGUMENTS); @@ -211,7 +208,6 @@ public: } break; case 8: - d_radixX_twiddles = ntt_parameters.radix8_twiddles; switch (stage) { case 8: _GS_NTT<0><<>>(NTT_ARGUMENTS); @@ -227,7 +223,6 @@ public: } break; case 9: - d_radixX_twiddles = ntt_parameters.radix9_twiddles; switch (stage) { case 9: _GS_NTT<0><<>>(NTT_ARGUMENTS); @@ -243,7 +238,6 @@ public: } break; case 10: - d_radixX_twiddles = ntt_parameters.radix10_twiddles; switch (stage) { case 10: _GS_NTT<0><<>>(NTT_ARGUMENTS); diff --git a/ntt/parameters.cuh b/ntt/parameters.cuh index d036c7c..2a70171 100644 --- a/ntt/parameters.cuh +++ b/ntt/parameters.cuh @@ -168,8 +168,7 @@ private: public: fr_t (*partial_twiddles)[WINDOW_SIZE]; - fr_t* radix6_twiddles, * radix7_twiddles, * radix8_twiddles, - * radix9_twiddles, * radix10_twiddles; + fr_t* twiddles[5]; fr_t (*partial_group_gen_powers)[WINDOW_SIZE]; // for LDE @@ -196,15 +195,19 @@ public: const size_t blob_sz = 64 + 128 + 256 + 512 + 32; + fr_t* radix6_twiddles; CUDA_OK(cudaGetSymbolAddress((void**)&radix6_twiddles, inverse ? inverse_radix6_twiddles : forward_radix6_twiddles)); - radix7_twiddles = (fr_t*)gpu.Dmalloc(blob_sz * sizeof(fr_t)); - radix8_twiddles = radix7_twiddles + 64; - radix9_twiddles = radix8_twiddles + 128; - radix10_twiddles = radix9_twiddles + 256; + fr_t* blob = (fr_t*)gpu.Dmalloc(blob_sz * sizeof(fr_t)); - generate_all_twiddles<<>>(radix7_twiddles, + twiddles[0] = radix6_twiddles; + twiddles[1] = blob; /* radix7_twiddles */ + twiddles[2] = twiddles[1] + 64; /* radix8_twiddles */ + twiddles[3] = twiddles[2] + 128; /* radix9_twiddles */ + twiddles[4] = twiddles[3] + 256; /* radix10_twiddles */ + + generate_all_twiddles<<>>(blob, roots[6], roots[7], roots[8], @@ -212,7 +215,8 @@ public: roots[10]); CUDA_OK(cudaGetLastError()); - CUDA_OK(cudaMemcpyAsync(radix6_twiddles, radix10_twiddles + 512, + /* copy to the constant segment */ + CUDA_OK(cudaMemcpyAsync(radix6_twiddles, twiddles[4] + 512, 32 * sizeof(fr_t), cudaMemcpyDeviceToDevice, gpu)); @@ -258,7 +262,7 @@ public: gpu.Dfree(radix6_twiddles_6); #endif - gpu.Dfree(radix7_twiddles); + gpu.Dfree(twiddles[1]); cudaSetDevice(current_id); } From 2000968db77f625e9dd1d9755060c01ceee46050 Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Wed, 7 Feb 2024 11:39:12 +0100 Subject: [PATCH 2/8] ntt/kernels/*_wide.cu: rethink |intermediate_twiddle_shift|. --- ntt/kernels/ct_mixed_radix_wide.cu | 18 +++++++++--------- ntt/kernels/gs_mixed_radix_wide.cu | 16 ++++++++-------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ntt/kernels/ct_mixed_radix_wide.cu b/ntt/kernels/ct_mixed_radix_wide.cu index 9d66577..550e8c5 100644 --- a/ntt/kernels/ct_mixed_radix_wide.cu +++ b/ntt/kernels/ct_mixed_radix_wide.cu @@ -54,13 +54,13 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size, } else if (intermediate_mul == 2) { unsigned int diff_mask = (1 << (iterations - 1)) - 1; unsigned int thread_ntt_idx = (tid & diff_mask) * 2; - unsigned int nbits = intermediate_twiddle_shift + iterations; + unsigned int nbits = intermediate_twiddle_shift; index_t root_idx0 = bit_rev(thread_ntt_idx, nbits); index_t root_idx1 = bit_rev(thread_ntt_idx + 1, nbits); - fr_t t0 = d_intermediate_twiddles[(thread_ntt_pos << radix) + root_idx0]; - fr_t t1 = d_intermediate_twiddles[(thread_ntt_pos << radix) + root_idx1]; + fr_t t0 = d_intermediate_twiddles[(thread_ntt_pos << nbits) + root_idx0]; + fr_t t1 = d_intermediate_twiddles[(thread_ntt_pos << nbits) + root_idx1]; r0 *= t0; r1 *= t1; @@ -182,12 +182,12 @@ public: _CT_NTT<0><<>>(NTT_ARGUMENTS); break; case 6: - intermediate_twiddle_shift = std::max(12 - lg_domain_size, 0); + intermediate_twiddle_shift = 6; d_intermediate_twiddles = ntt_parameters.radix6_twiddles_6; _CT_NTT<2><<>>(NTT_ARGUMENTS); break; case 12: - intermediate_twiddle_shift = std::max(18 - lg_domain_size, 0); + intermediate_twiddle_shift = 6; d_intermediate_twiddles = ntt_parameters.radix6_twiddles_12; _CT_NTT<2><<>>(NTT_ARGUMENTS); break; @@ -202,7 +202,7 @@ public: _CT_NTT<0><<>>(NTT_ARGUMENTS); break; case 7: - intermediate_twiddle_shift = std::max(14 - lg_domain_size, 0); + intermediate_twiddle_shift = 7; d_intermediate_twiddles = ntt_parameters.radix7_twiddles_7; _CT_NTT<2><<>>(NTT_ARGUMENTS); break; @@ -217,7 +217,7 @@ public: _CT_NTT<0><<>>(NTT_ARGUMENTS); break; case 8: - intermediate_twiddle_shift = std::max(16 - lg_domain_size, 0); + intermediate_twiddle_shift = 8; d_intermediate_twiddles = ntt_parameters.radix8_twiddles_8; _CT_NTT<2><<>>(NTT_ARGUMENTS); break; @@ -232,7 +232,7 @@ public: _CT_NTT<0><<>>(NTT_ARGUMENTS); break; case 9: - intermediate_twiddle_shift = std::max(18 - lg_domain_size, 0); + intermediate_twiddle_shift = 9; d_intermediate_twiddles = ntt_parameters.radix9_twiddles_9; _CT_NTT<2><<>>(NTT_ARGUMENTS); break; @@ -260,7 +260,7 @@ public: CUDA_OK(cudaGetLastError()); - stage += radix; + stage += iterations; } }; diff --git a/ntt/kernels/gs_mixed_radix_wide.cu b/ntt/kernels/gs_mixed_radix_wide.cu index 23d7336..d18e999 100644 --- a/ntt/kernels/gs_mixed_radix_wide.cu +++ b/ntt/kernels/gs_mixed_radix_wide.cu @@ -99,13 +99,13 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size, index_t thread_ntt_pos = (tid & inp_mask) >> (iterations - 1); unsigned int diff_mask = (1 << (iterations - 1)) - 1; unsigned int thread_ntt_idx = (tid & diff_mask) * 2; - unsigned int nbits = intermediate_twiddle_shift + iterations; + unsigned int nbits = intermediate_twiddle_shift; index_t root_idx0 = bit_rev(thread_ntt_idx, nbits); index_t root_idx1 = bit_rev(thread_ntt_idx + 1, nbits); - fr_t t0 = d_intermediate_twiddles[(thread_ntt_pos << radix) + root_idx0]; - fr_t t1 = d_intermediate_twiddles[(thread_ntt_pos << radix) + root_idx1]; + fr_t t0 = d_intermediate_twiddles[(thread_ntt_pos << nbits) + root_idx0]; + fr_t t1 = d_intermediate_twiddles[(thread_ntt_pos << nbits) + root_idx1]; r0 *= t0; r1 *= t1; @@ -178,12 +178,12 @@ public: _GS_NTT<0><<>>(NTT_ARGUMENTS); break; case 12: - intermediate_twiddle_shift = std::max(12 - lg_domain_size, 0); + intermediate_twiddle_shift = 6; d_intermediate_twiddles = ntt_parameters.radix6_twiddles_6; _GS_NTT<2><<>>(NTT_ARGUMENTS); break; case 18: - intermediate_twiddle_shift = std::max(18 - lg_domain_size, 0); + intermediate_twiddle_shift = 6; d_intermediate_twiddles = ntt_parameters.radix6_twiddles_12; _GS_NTT<2><<>>(NTT_ARGUMENTS); break; @@ -198,7 +198,7 @@ public: _GS_NTT<0><<>>(NTT_ARGUMENTS); break; case 14: - intermediate_twiddle_shift = std::max(14 - lg_domain_size, 0); + intermediate_twiddle_shift = 7; d_intermediate_twiddles = ntt_parameters.radix7_twiddles_7; _GS_NTT<2><<>>(NTT_ARGUMENTS); break; @@ -213,7 +213,7 @@ public: _GS_NTT<0><<>>(NTT_ARGUMENTS); break; case 16: - intermediate_twiddle_shift = std::max(16 - lg_domain_size, 0); + intermediate_twiddle_shift = 8; d_intermediate_twiddles = ntt_parameters.radix8_twiddles_8; _GS_NTT<2><<>>(NTT_ARGUMENTS); break; @@ -228,7 +228,7 @@ public: _GS_NTT<0><<>>(NTT_ARGUMENTS); break; case 18: - intermediate_twiddle_shift = std::max(18 - lg_domain_size, 0); + intermediate_twiddle_shift = 9; d_intermediate_twiddles = ntt_parameters.radix9_twiddles_9; _GS_NTT<2><<>>(NTT_ARGUMENTS); break; From 3868842c25b1d28ea0189772e6f504d219454af5 Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Wed, 7 Feb 2024 12:01:04 +0100 Subject: [PATCH 3/8] ntt/kernels/*_wide.cu: refactor the launcher switch. This maximizes kernel<2> invocations. --- ntt/kernels/ct_mixed_radix_wide.cu | 66 +++++++++--------------------- ntt/kernels/gs_mixed_radix_wide.cu | 66 +++++++++--------------------- 2 files changed, 40 insertions(+), 92 deletions(-) diff --git a/ntt/kernels/ct_mixed_radix_wide.cu b/ntt/kernels/ct_mixed_radix_wide.cu index 550e8c5..0a32faf 100644 --- a/ntt/kernels/ct_mixed_radix_wide.cu +++ b/ntt/kernels/ct_mixed_radix_wide.cu @@ -175,84 +175,58 @@ public: d_intermediate_twiddles, intermediate_twiddle_shift, \ is_intt, domain_size_inverse[lg_domain_size] - switch (radix) { + switch (stage) { + case 0: + _CT_NTT<0><<>>(NTT_ARGUMENTS); + break; case 6: - switch (stage) { - case 0: - _CT_NTT<0><<>>(NTT_ARGUMENTS); - break; - case 6: + if (iterations <= 6) { intermediate_twiddle_shift = 6; d_intermediate_twiddles = ntt_parameters.radix6_twiddles_6; _CT_NTT<2><<>>(NTT_ARGUMENTS); - break; - case 12: - intermediate_twiddle_shift = 6; - d_intermediate_twiddles = ntt_parameters.radix6_twiddles_12; - _CT_NTT<2><<>>(NTT_ARGUMENTS); - break; - default: + } else { _CT_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; case 7: - switch (stage) { - case 0: - _CT_NTT<0><<>>(NTT_ARGUMENTS); - break; - case 7: + if (iterations <= 7) { intermediate_twiddle_shift = 7; d_intermediate_twiddles = ntt_parameters.radix7_twiddles_7; _CT_NTT<2><<>>(NTT_ARGUMENTS); - break; - default: + } else { _CT_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; case 8: - switch (stage) { - case 0: - _CT_NTT<0><<>>(NTT_ARGUMENTS); - break; - case 8: + if (iterations <= 8) { intermediate_twiddle_shift = 8; d_intermediate_twiddles = ntt_parameters.radix8_twiddles_8; _CT_NTT<2><<>>(NTT_ARGUMENTS); - break; - default: + } else { _CT_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; case 9: - switch (stage) { - case 0: - _CT_NTT<0><<>>(NTT_ARGUMENTS); - break; - case 9: + if (iterations <= 9) { intermediate_twiddle_shift = 9; d_intermediate_twiddles = ntt_parameters.radix9_twiddles_9; _CT_NTT<2><<>>(NTT_ARGUMENTS); - break; - default: + } else { _CT_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; - case 10: - switch (stage) { - case 0: - _CT_NTT<0><<>>(NTT_ARGUMENTS); - break; - default: + case 12: + if (iterations <= 6) { + intermediate_twiddle_shift = 6; + d_intermediate_twiddles = ntt_parameters.radix6_twiddles_12; + _CT_NTT<2><<>>(NTT_ARGUMENTS); + } else { _CT_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; default: - assert(false); + _CT_NTT<1><<>>(NTT_ARGUMENTS); + break; } #undef NTT_CONFIGURATION diff --git a/ntt/kernels/gs_mixed_radix_wide.cu b/ntt/kernels/gs_mixed_radix_wide.cu index d18e999..829b77f 100644 --- a/ntt/kernels/gs_mixed_radix_wide.cu +++ b/ntt/kernels/gs_mixed_radix_wide.cu @@ -171,84 +171,58 @@ public: d_intermediate_twiddles, intermediate_twiddle_shift, \ is_intt, domain_size_inverse[lg_domain_size] - switch (radix) { + switch (stage - iterations) { + case 0: + _GS_NTT<0><<>>(NTT_ARGUMENTS); + break; case 6: - switch (stage) { - case 6: - _GS_NTT<0><<>>(NTT_ARGUMENTS); - break; - case 12: + if (iterations <= 6) { intermediate_twiddle_shift = 6; d_intermediate_twiddles = ntt_parameters.radix6_twiddles_6; _GS_NTT<2><<>>(NTT_ARGUMENTS); - break; - case 18: - intermediate_twiddle_shift = 6; - d_intermediate_twiddles = ntt_parameters.radix6_twiddles_12; - _GS_NTT<2><<>>(NTT_ARGUMENTS); - break; - default: + } else { _GS_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; case 7: - switch (stage) { - case 7: - _GS_NTT<0><<>>(NTT_ARGUMENTS); - break; - case 14: + if (iterations <= 7) { intermediate_twiddle_shift = 7; d_intermediate_twiddles = ntt_parameters.radix7_twiddles_7; _GS_NTT<2><<>>(NTT_ARGUMENTS); - break; - default: + } else { _GS_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; case 8: - switch (stage) { - case 8: - _GS_NTT<0><<>>(NTT_ARGUMENTS); - break; - case 16: + if (iterations <= 8) { intermediate_twiddle_shift = 8; d_intermediate_twiddles = ntt_parameters.radix8_twiddles_8; _GS_NTT<2><<>>(NTT_ARGUMENTS); - break; - default: + } else { _GS_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; case 9: - switch (stage) { - case 9: - _GS_NTT<0><<>>(NTT_ARGUMENTS); - break; - case 18: + if (iterations <= 9) { intermediate_twiddle_shift = 9; d_intermediate_twiddles = ntt_parameters.radix9_twiddles_9; _GS_NTT<2><<>>(NTT_ARGUMENTS); - break; - default: + } else { _GS_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; - case 10: - switch (stage) { - case 10: - _GS_NTT<0><<>>(NTT_ARGUMENTS); - break; - default: + case 12: + if (iterations <= 6) { + intermediate_twiddle_shift = 6; + d_intermediate_twiddles = ntt_parameters.radix6_twiddles_12; + _GS_NTT<2><<>>(NTT_ARGUMENTS); + } else { _GS_NTT<1><<>>(NTT_ARGUMENTS); - break; } break; default: - assert(false); + _GS_NTT<1><<>>(NTT_ARGUMENTS); + break; } #undef NTT_CONFIGURATION From 3a63f6fc3cb2ba1ab5fde0c452bdb0086711d15f Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Wed, 7 Feb 2024 12:41:35 +0100 Subject: [PATCH 4/8] ntt/kernels/*: deduplicate [GS|CT]_NTT launchers. --- ntt/kernels/ct_mixed_radix_narrow.cu | 28 ------------- ntt/kernels/ct_mixed_radix_wide.cu | 28 ------------- ntt/kernels/gs_mixed_radix_narrow.cu | 31 -------------- ntt/kernels/gs_mixed_radix_wide.cu | 31 -------------- ntt/ntt.cuh | 60 ++++++++++++++++++++++++++++ 5 files changed, 60 insertions(+), 118 deletions(-) diff --git a/ntt/kernels/ct_mixed_radix_narrow.cu b/ntt/kernels/ct_mixed_radix_narrow.cu index ca933ca..801d046 100644 --- a/ntt/kernels/ct_mixed_radix_narrow.cu +++ b/ntt/kernels/ct_mixed_radix_narrow.cu @@ -227,31 +227,3 @@ public: stage += iterations; } }; - -void CT_NTT(fr_t* d_inout, const int lg_domain_size, bool intt, - const NTTParameters& ntt_parameters, const stream_t& stream) -{ - CT_launcher params{d_inout, lg_domain_size, intt, ntt_parameters, stream}; - - if (lg_domain_size <= std::min(10, MAX_LG_DOMAIN_SIZE)) { - params.step(lg_domain_size); - } else if (lg_domain_size <= std::min(17, MAX_LG_DOMAIN_SIZE)) { - params.step(lg_domain_size / 2 + lg_domain_size % 2); - params.step(lg_domain_size / 2); - } else if (lg_domain_size <= std::min(30, MAX_LG_DOMAIN_SIZE)) { - int step = lg_domain_size / 3; - int rem = lg_domain_size % 3; - params.step(step); - params.step(step + (lg_domain_size == 29 ? 1 : 0)); - params.step(step + (lg_domain_size == 29 ? 1 : rem)); - } else if (lg_domain_size <= std::min(32, MAX_LG_DOMAIN_SIZE)) { - int step = lg_domain_size / 4; - int rem = lg_domain_size % 4; - params.step(step); - params.step(step); - params.step(step); - params.step(step + rem); - } else { - assert(false); - } -} diff --git a/ntt/kernels/ct_mixed_radix_wide.cu b/ntt/kernels/ct_mixed_radix_wide.cu index 0a32faf..6bee49e 100644 --- a/ntt/kernels/ct_mixed_radix_wide.cu +++ b/ntt/kernels/ct_mixed_radix_wide.cu @@ -237,31 +237,3 @@ public: stage += iterations; } }; - -void CT_NTT(fr_t* d_inout, const int lg_domain_size, bool intt, - const NTTParameters& ntt_parameters, const stream_t& stream) -{ - CT_launcher params{d_inout, lg_domain_size, intt, ntt_parameters, stream}; - - if (lg_domain_size <= 10) { - params.step(lg_domain_size); - } else if (lg_domain_size <= 17) { - params.step(lg_domain_size / 2 + lg_domain_size % 2); - params.step(lg_domain_size / 2); - } else if (lg_domain_size <= 30) { - int step = lg_domain_size / 3; - int rem = lg_domain_size % 3; - params.step(step); - params.step(step + (lg_domain_size == 29 ? 1 : 0)); - params.step(step + (lg_domain_size == 29 ? 1 : rem)); - } else if (lg_domain_size <= 40) { - int step = lg_domain_size / 4; - int rem = lg_domain_size % 4; - params.step(step); - params.step(step + (rem > 2)); - params.step(step + (rem > 1)); - params.step(step + (rem > 0)); - } else { - assert(false); - } -} diff --git a/ntt/kernels/gs_mixed_radix_narrow.cu b/ntt/kernels/gs_mixed_radix_narrow.cu index e148436..798f911 100644 --- a/ntt/kernels/gs_mixed_radix_narrow.cu +++ b/ntt/kernels/gs_mixed_radix_narrow.cu @@ -230,34 +230,3 @@ public: stage -= iterations; } }; - -void GS_NTT(fr_t* d_inout, const int lg_domain_size, bool intt, - const NTTParameters& ntt_parameters, const stream_t& stream) -{ - GS_launcher params{d_inout, lg_domain_size, intt, ntt_parameters, stream}; - - if (lg_domain_size <= std::min(10, MAX_LG_DOMAIN_SIZE)) { - params.step(lg_domain_size); - } else if (lg_domain_size <= std::min(12, MAX_LG_DOMAIN_SIZE)) { - params.step(lg_domain_size - 6); - params.step(6); - } else if (lg_domain_size <= std::min(18, MAX_LG_DOMAIN_SIZE)) { - params.step(lg_domain_size / 2 + lg_domain_size % 2); - params.step(lg_domain_size / 2); - } else if (lg_domain_size <= std::min(30, MAX_LG_DOMAIN_SIZE)) { - int step = lg_domain_size / 3; - int rem = lg_domain_size % 3; - params.step(step + (lg_domain_size == 29 ? 1 : rem)); - params.step(step + (lg_domain_size == 29 ? 1 : 0)); - params.step(step); - } else if (lg_domain_size <= std::min(32, MAX_LG_DOMAIN_SIZE)) { - int step = lg_domain_size / 4; - int rem = lg_domain_size % 4; - params.step(step + rem); - params.step(step); - params.step(step); - params.step(step); - } else { - assert(false); - } -} diff --git a/ntt/kernels/gs_mixed_radix_wide.cu b/ntt/kernels/gs_mixed_radix_wide.cu index 829b77f..34b2d3a 100644 --- a/ntt/kernels/gs_mixed_radix_wide.cu +++ b/ntt/kernels/gs_mixed_radix_wide.cu @@ -233,34 +233,3 @@ public: stage -= iterations; } }; - -void GS_NTT(fr_t* d_inout, const int lg_domain_size, const bool is_intt, - const NTTParameters& ntt_parameters, const stream_t& stream) -{ - GS_launcher params{d_inout, lg_domain_size, is_intt, ntt_parameters, stream}; - - if (lg_domain_size <= 10) { - params.step(lg_domain_size); - } else if (lg_domain_size <= 12) { - params.step(lg_domain_size - 6); - params.step(6); - } else if (lg_domain_size <= 18) { - params.step(lg_domain_size / 2 + lg_domain_size % 2); - params.step(lg_domain_size / 2); - } else if (lg_domain_size <= 30) { - int step = lg_domain_size / 3; - int rem = lg_domain_size % 3; - params.step(step + (rem > 0)); - params.step(step + (rem > 1)); - params.step(step); - } else if (lg_domain_size <= 40) { - int step = lg_domain_size / 4; - int rem = lg_domain_size % 4; - params.step(step + (rem > 0)); - params.step(step + (rem > 1)); - params.step(step + (rem > 2)); - params.step(step); - } else { - assert(false); - } -} diff --git a/ntt/ntt.cuh b/ntt/ntt.cuh index d140314..7b0dc6a 100644 --- a/ntt/ntt.cuh +++ b/ntt/ntt.cuh @@ -77,6 +77,66 @@ private: CUDA_OK(cudaGetLastError()); } + static void CT_NTT(fr_t* d_inout, const int lg_domain_size, bool intt, + const NTTParameters& ntt_parameters, + const stream_t& stream) + { + CT_launcher params{d_inout, lg_domain_size, intt, ntt_parameters, stream}; + + if (lg_domain_size <= 10) { + params.step(lg_domain_size); + } else if (lg_domain_size <= 17) { + int step = lg_domain_size / 2; + params.step(step + lg_domain_size % 2); + params.step(step); + } else if (lg_domain_size <= 30) { + int step = lg_domain_size / 3; + int rem = lg_domain_size % 3; + params.step(step); + params.step(step + (lg_domain_size == 29 ? 1 : 0)); + params.step(step + (lg_domain_size == 29 ? 1 : rem)); + } else if (lg_domain_size <= 40) { + int step = lg_domain_size / 4; + int rem = lg_domain_size % 4; + params.step(step); + params.step(step + (rem > 2)); + params.step(step + (rem > 1)); + params.step(step + (rem > 0)); + } else { + assert(false); + } + } + + static void GS_NTT(fr_t* d_inout, const int lg_domain_size, const bool is_intt, + const NTTParameters& ntt_parameters, + const stream_t& stream) + { + GS_launcher params{d_inout, lg_domain_size, is_intt, ntt_parameters, stream}; + + if (lg_domain_size <= 10) { + params.step(lg_domain_size); + } else if (lg_domain_size <= 17) { + int step = lg_domain_size / 2; + params.step(step); + params.step(step + lg_domain_size % 2); + } else if (lg_domain_size <= 30) { + int step = lg_domain_size / 3; + int rem = lg_domain_size % 3; + params.step(step + (lg_domain_size == 29 ? 1 : rem)); + params.step(step + (lg_domain_size == 29 ? 1 : 0)); + params.step(step); + } else if (lg_domain_size <= 40) { + int step = lg_domain_size / 4; + int rem = lg_domain_size % 4; + params.step(step + (rem > 0)); + params.step(step + (rem > 1)); + params.step(step + (rem > 2)); + params.step(step); + } else { + assert(false); + } + } + protected: static void NTT_internal(fr_t* d_inout, uint32_t lg_domain_size, InputOutputOrder order, Direction direction, From d2b406e845bfc127a92ec46d322c4fbb772067b3 Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Tue, 13 Feb 2024 13:54:45 +0100 Subject: [PATCH 5/8] util/vec2d_t.hpp: let application control dim_x type. --- util/vec2d_t.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/util/vec2d_t.hpp b/util/vec2d_t.hpp index ad33ea6..ac237fc 100644 --- a/util/vec2d_t.hpp +++ b/util/vec2d_t.hpp @@ -12,19 +12,19 @@ # define __device__ #endif -template class vec2d_t { - uint32_t dim_x; +template class vec2d_t { + dim_t dim_x; bool owned; T* ptr; public: __host__ __device__ - vec2d_t(T* data, uint32_t x) : dim_x(x), owned(false), ptr(data) {} - vec2d_t(void* data, uint32_t x) : dim_x(x), owned(false), ptr((T*)data) {} - vec2d_t(uint32_t x, size_t y) : dim_x(x), owned(true), ptr(new T[x*y]) {} + vec2d_t(T* data, dim_t x) : dim_x(x), owned(false), ptr(data) {} + vec2d_t(void* data, dim_t x) : dim_x(x), owned(false), ptr((T*)data) {} + vec2d_t(dim_t x, size_t y) : dim_x(x), owned(true), ptr(new T[x*y]) {} vec2d_t() : dim_x(0), owned(false), ptr(nullptr) {} #ifndef __CUDA_ARCH__ - vec2d_t(const vec2d_t& other) { *this = other; } + vec2d_t(const vec2d_t& other) { *this = other; owned = false; } ~vec2d_t() { if (owned) delete[] ptr; } inline vec2d_t& operator=(const vec2d_t& other) @@ -44,7 +44,7 @@ template class vec2d_t { inline T* operator[](size_t y) const { return ptr + dim_x*y; } #ifndef NDEBUG - inline uint32_t x() { return dim_x; } + inline dim_t x() { return dim_x; } #endif }; From 65357d28a5a0eedee5a9a9440b9d99e84090b671 Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Tue, 13 Feb 2024 13:56:16 +0100 Subject: [PATCH 6/8] ff/bb31_t.cuh: add heptaroot() method. --- ff/bb31_t.cuh | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/ff/bb31_t.cuh b/ff/bb31_t.cuh index fb2a04f..9eea5d9 100644 --- a/ff/bb31_t.cuh +++ b/ff/bb31_t.cuh @@ -304,7 +304,7 @@ public: } private: - static inline bb31_t sqr_n_mul(bb31_t s, uint32_t n, bb31_t m) + static inline bb31_t sqr_n(bb31_t s, uint32_t n) { #if 0 #pragma unroll 2 @@ -315,9 +315,9 @@ private: while (n--) { uint32_t tmp[2], red; - asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;" + asm("mul.lo.u32 %0, %2, %2; mul.hi.u32 %1, %2, %2;" : "=r"(tmp[0]), "=r"(tmp[1]) - : "r"(s.val), "r"(s.val)); + : "r"(s.val)); asm("mul.lo.u32 %0, %1, %2;" : "=r"(red) : "r"(tmp[0]), "r"(M)); asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %4;" : "+r"(tmp[0]), "=r"(s.val) @@ -327,6 +327,12 @@ private: final_sub(s.val); } #endif + return s; + } + + static inline bb31_t sqr_n_mul(bb31_t s, uint32_t n, bb31_t m) + { + s = sqr_n(s, n); s.mul(m); return s; @@ -354,6 +360,23 @@ public: inline bb31_t& operator/=(const bb31_t a) { return *this *= a.reciprocal(); } + inline bb31_t heptaroot() const + { + bb31_t x03, x18, x1b, ret = *this; + + x03 = sqr_n_mul(ret, 1, ret); // 0b11 + x18 = sqr_n(x03, 3); // 0b11000 + x1b = x18*x03; // 0b11011 + ret = x18*x1b; // 0b110011 + ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011 + ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011 + ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011011011 + ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011011011011011 + ret = sqr_n_mul(ret, 1, *this); // 0b1100110110110110110110110110111 + + return ret; + } + inline void shfl_bfly(uint32_t laneMask) { val = __shfl_xor_sync(0xFFFFFFFF, val, laneMask); } }; From 655c638e1c6b69a366261bc63c97fdc73bdd445b Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Tue, 13 Feb 2024 17:43:20 +0100 Subject: [PATCH 7/8] poc/ntt-cuda/tests/ntt.rs: make new tests tolerable in debug build. --- poc/ntt-cuda/tests/ntt.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/poc/ntt-cuda/tests/ntt.rs b/poc/ntt-cuda/tests/ntt.rs index bb8c1ba..3cae9b1 100644 --- a/poc/ntt-cuda/tests/ntt.rs +++ b/poc/ntt-cuda/tests/ntt.rs @@ -16,7 +16,7 @@ fn gl64_self_consistency() { fr % 0xffffffff00000001 } - for lg_domain_size in 1..28 { + for lg_domain_size in 1..24 + 4 * !cfg!(debug_assertions) as i32 { let domain_size = 1usize << lg_domain_size; let v: Vec = (0..domain_size).map(|_| random_fr()).collect(); @@ -52,7 +52,7 @@ fn bb31_self_consistency() { fr % 0x78000001 } - for lg_domain_size in 1..27 { + for lg_domain_size in 1..24 + 4 * !cfg!(debug_assertions) as i32 { let domain_size = 1usize << lg_domain_size; let v: Vec = (0..domain_size).map(|_| random_fr()).collect(); From 28272bec0be9cfa8cd42470059b8c70668ff34f1 Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Tue, 13 Feb 2024 17:44:42 +0100 Subject: [PATCH 8/8] rust/Cargo.toml: bump the version. --- rust/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 2758a96..4b57b03 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sppark" -version = "0.1.5" +version = "0.1.6" edition = "2021" description = "Zero-knowledge template library" repository = "https://github.com/supranational/sppark"