Skip to content

Commit

Permalink
Merge branch 'supranational:main' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
winston-h-zhang authored Feb 20, 2024
2 parents efe394a + 28272be commit 50fc7d5
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 300 deletions.
29 changes: 26 additions & 3 deletions ff/bb31_t.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ public:
}

private:
static inline bb31_t sqr_n_mul(bb31_t s, uint32_t n, bb31_t m)
static inline bb31_t sqr_n(bb31_t s, uint32_t n)
{
#if 0
#pragma unroll 2
Expand All @@ -315,9 +315,9 @@ private:
while (n--) {
uint32_t tmp[2], red;

asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;"
asm("mul.lo.u32 %0, %2, %2; mul.hi.u32 %1, %2, %2;"
: "=r"(tmp[0]), "=r"(tmp[1])
: "r"(s.val), "r"(s.val));
: "r"(s.val));
asm("mul.lo.u32 %0, %1, %2;" : "=r"(red) : "r"(tmp[0]), "r"(M));
asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %4;"
: "+r"(tmp[0]), "=r"(s.val)
Expand All @@ -327,6 +327,12 @@ private:
final_sub(s.val);
}
#endif
return s;
}

static inline bb31_t sqr_n_mul(bb31_t s, uint32_t n, bb31_t m)
{
s = sqr_n(s, n);
s.mul(m);

return s;
Expand Down Expand Up @@ -354,6 +360,23 @@ public:
inline bb31_t& operator/=(const bb31_t a)
{ return *this *= a.reciprocal(); }

inline bb31_t heptaroot() const
{
bb31_t x03, x18, x1b, ret = *this;

x03 = sqr_n_mul(ret, 1, ret); // 0b11
x18 = sqr_n(x03, 3); // 0b11000
x1b = x18*x03; // 0b11011
ret = x18*x1b; // 0b110011
ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011
ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011
ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011011011
ret = sqr_n_mul(ret, 6, x1b); // 0b110011011011011011011011011011
ret = sqr_n_mul(ret, 1, *this); // 0b1100110110110110110110110110111

return ret;
}

inline void shfl_bfly(uint32_t laneMask)
{ val = __shfl_xor_sync(0xFFFFFFFF, val, laneMask); }
};
Expand Down
48 changes: 2 additions & 46 deletions ntt/kernels/ct_mixed_radix_narrow.cu
Original file line number Diff line number Diff line change
Expand Up @@ -207,28 +207,12 @@ public:

assert(num_blocks == (unsigned int)num_blocks);

fr_t* d_radixX_twiddles = nullptr;

switch (radix) {
case 7:
d_radixX_twiddles = ntt_parameters.radix7_twiddles;
break;
case 8:
d_radixX_twiddles = ntt_parameters.radix8_twiddles;
break;
case 9:
d_radixX_twiddles = ntt_parameters.radix9_twiddles;
break;
case 10:
d_radixX_twiddles = ntt_parameters.radix10_twiddles;
break;
}

const int Z_COUNT = 256/8/sizeof(fr_t);
size_t shared_sz = sizeof(fr_t) << (radix - 1);

#define NTT_ARGUMENTS radix, lg_domain_size, stage, iterations, \
d_inout, ntt_parameters.partial_twiddles, \
ntt_parameters.radix6_twiddles, d_radixX_twiddles, \
ntt_parameters.twiddles[0], ntt_parameters.twiddles[radix-6], \
is_intt, domain_size_inverse[lg_domain_size]

if (num_blocks < Z_COUNT)
Expand All @@ -243,31 +227,3 @@ public:
stage += iterations;
}
};

void CT_NTT(fr_t* d_inout, const int lg_domain_size, bool intt,
const NTTParameters& ntt_parameters, const stream_t& stream)
{
CT_launcher params{d_inout, lg_domain_size, intt, ntt_parameters, stream};

if (lg_domain_size <= std::min(10, MAX_LG_DOMAIN_SIZE)) {
params.step(lg_domain_size);
} else if (lg_domain_size <= std::min(17, MAX_LG_DOMAIN_SIZE)) {
params.step(lg_domain_size / 2 + lg_domain_size % 2);
params.step(lg_domain_size / 2);
} else if (lg_domain_size <= std::min(30, MAX_LG_DOMAIN_SIZE)) {
int step = lg_domain_size / 3;
int rem = lg_domain_size % 3;
params.step(step);
params.step(step + (lg_domain_size == 29 ? 1 : 0));
params.step(step + (lg_domain_size == 29 ? 1 : rem));
} else if (lg_domain_size <= std::min(32, MAX_LG_DOMAIN_SIZE)) {
int step = lg_domain_size / 4;
int rem = lg_domain_size % 4;
params.step(step);
params.step(step);
params.step(step);
params.step(step + rem);
} else {
assert(false);
}
}
117 changes: 29 additions & 88 deletions ntt/kernels/ct_mixed_radix_wide.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size,
} else if (intermediate_mul == 2) {
unsigned int diff_mask = (1 << (iterations - 1)) - 1;
unsigned int thread_ntt_idx = (tid & diff_mask) * 2;
unsigned int nbits = intermediate_twiddle_shift + iterations;
unsigned int nbits = intermediate_twiddle_shift;

index_t root_idx0 = bit_rev(thread_ntt_idx, nbits);
index_t root_idx1 = bit_rev(thread_ntt_idx + 1, nbits);

fr_t t0 = d_intermediate_twiddles[(thread_ntt_pos << radix) + root_idx0];
fr_t t1 = d_intermediate_twiddles[(thread_ntt_pos << radix) + root_idx1];
fr_t t0 = d_intermediate_twiddles[(thread_ntt_pos << nbits) + root_idx0];
fr_t t1 = d_intermediate_twiddles[(thread_ntt_pos << nbits) + root_idx1];

r0 *= t0;
r1 *= t1;
Expand Down Expand Up @@ -163,7 +163,6 @@ public:

assert(num_blocks == (unsigned int)num_blocks);

fr_t* d_radixX_twiddles = nullptr;
fr_t* d_intermediate_twiddles = nullptr;
unsigned int intermediate_twiddle_shift = 0;

Expand All @@ -172,127 +171,69 @@ public:

#define NTT_ARGUMENTS radix, lg_domain_size, stage, iterations, \
d_inout, ntt_parameters.partial_twiddles, \
ntt_parameters.radix6_twiddles, d_radixX_twiddles, \
ntt_parameters.twiddles[0], ntt_parameters.twiddles[radix-6], \
d_intermediate_twiddles, intermediate_twiddle_shift, \
is_intt, domain_size_inverse[lg_domain_size]

switch (radix) {
switch (stage) {
case 0:
_CT_NTT<0><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
case 6:
switch (stage) {
case 0:
_CT_NTT<0><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
case 6:
intermediate_twiddle_shift = std::max(12 - lg_domain_size, 0);
if (iterations <= 6) {
intermediate_twiddle_shift = 6;
d_intermediate_twiddles = ntt_parameters.radix6_twiddles_6;
_CT_NTT<2><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
case 12:
intermediate_twiddle_shift = std::max(18 - lg_domain_size, 0);
d_intermediate_twiddles = ntt_parameters.radix6_twiddles_12;
_CT_NTT<2><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
default:
} else {
_CT_NTT<1><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
}
break;
case 7:
d_radixX_twiddles = ntt_parameters.radix7_twiddles;
switch (stage) {
case 0:
_CT_NTT<0><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
case 7:
intermediate_twiddle_shift = std::max(14 - lg_domain_size, 0);
if (iterations <= 7) {
intermediate_twiddle_shift = 7;
d_intermediate_twiddles = ntt_parameters.radix7_twiddles_7;
_CT_NTT<2><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
default:
} else {
_CT_NTT<1><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
}
break;
case 8:
d_radixX_twiddles = ntt_parameters.radix8_twiddles;
switch (stage) {
case 0:
_CT_NTT<0><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
case 8:
intermediate_twiddle_shift = std::max(16 - lg_domain_size, 0);
if (iterations <= 8) {
intermediate_twiddle_shift = 8;
d_intermediate_twiddles = ntt_parameters.radix8_twiddles_8;
_CT_NTT<2><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
default:
} else {
_CT_NTT<1><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
}
break;
case 9:
d_radixX_twiddles = ntt_parameters.radix9_twiddles;
switch (stage) {
case 0:
_CT_NTT<0><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
case 9:
intermediate_twiddle_shift = std::max(18 - lg_domain_size, 0);
if (iterations <= 9) {
intermediate_twiddle_shift = 9;
d_intermediate_twiddles = ntt_parameters.radix9_twiddles_9;
_CT_NTT<2><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
default:
} else {
_CT_NTT<1><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
}
break;
case 10:
d_radixX_twiddles = ntt_parameters.radix10_twiddles;
switch (stage) {
case 0:
_CT_NTT<0><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
default:
case 12:
if (iterations <= 6) {
intermediate_twiddle_shift = 6;
d_intermediate_twiddles = ntt_parameters.radix6_twiddles_12;
_CT_NTT<2><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
} else {
_CT_NTT<1><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
}
break;
default:
assert(false);
_CT_NTT<1><<<NTT_CONFIGURATION>>>(NTT_ARGUMENTS);
break;
}

#undef NTT_CONFIGURATION
#undef NTT_ARGUMENTS

CUDA_OK(cudaGetLastError());

stage += radix;
stage += iterations;
}
};

void CT_NTT(fr_t* d_inout, const int lg_domain_size, bool intt,
const NTTParameters& ntt_parameters, const stream_t& stream)
{
CT_launcher params{d_inout, lg_domain_size, intt, ntt_parameters, stream};

if (lg_domain_size <= 10) {
params.step(lg_domain_size);
} else if (lg_domain_size <= 17) {
params.step(lg_domain_size / 2 + lg_domain_size % 2);
params.step(lg_domain_size / 2);
} else if (lg_domain_size <= 30) {
int step = lg_domain_size / 3;
int rem = lg_domain_size % 3;
params.step(step);
params.step(step + (lg_domain_size == 29 ? 1 : 0));
params.step(step + (lg_domain_size == 29 ? 1 : rem));
} else if (lg_domain_size <= 40) {
int step = lg_domain_size / 4;
int rem = lg_domain_size % 4;
params.step(step);
params.step(step + (rem > 2));
params.step(step + (rem > 1));
params.step(step + (rem > 0));
} else {
assert(false);
}
}
51 changes: 2 additions & 49 deletions ntt/kernels/gs_mixed_radix_narrow.cu
Original file line number Diff line number Diff line change
Expand Up @@ -210,28 +210,12 @@ public:

assert(num_blocks == (unsigned int)num_blocks);

fr_t* d_radixX_twiddles = nullptr;

switch (radix) {
case 7:
d_radixX_twiddles = ntt_parameters.radix7_twiddles;
break;
case 8:
d_radixX_twiddles = ntt_parameters.radix8_twiddles;
break;
case 9:
d_radixX_twiddles = ntt_parameters.radix9_twiddles;
break;
case 10:
d_radixX_twiddles = ntt_parameters.radix10_twiddles;
break;
}

const int Z_COUNT = 256/8/sizeof(fr_t);
size_t shared_sz = sizeof(fr_t) << (radix - 1);

#define NTT_ARGUMENTS radix, lg_domain_size, stage, iterations, \
d_inout, ntt_parameters.partial_twiddles, \
ntt_parameters.radix6_twiddles, d_radixX_twiddles, \
ntt_parameters.twiddles[0], ntt_parameters.twiddles[radix-6], \
is_intt, domain_size_inverse[lg_domain_size]

if (num_blocks < Z_COUNT)
Expand All @@ -246,34 +230,3 @@ public:
stage -= iterations;
}
};

void GS_NTT(fr_t* d_inout, const int lg_domain_size, bool intt,
const NTTParameters& ntt_parameters, const stream_t& stream)
{
GS_launcher params{d_inout, lg_domain_size, intt, ntt_parameters, stream};

if (lg_domain_size <= std::min(10, MAX_LG_DOMAIN_SIZE)) {
params.step(lg_domain_size);
} else if (lg_domain_size <= std::min(12, MAX_LG_DOMAIN_SIZE)) {
params.step(lg_domain_size - 6);
params.step(6);
} else if (lg_domain_size <= std::min(18, MAX_LG_DOMAIN_SIZE)) {
params.step(lg_domain_size / 2 + lg_domain_size % 2);
params.step(lg_domain_size / 2);
} else if (lg_domain_size <= std::min(30, MAX_LG_DOMAIN_SIZE)) {
int step = lg_domain_size / 3;
int rem = lg_domain_size % 3;
params.step(step + (lg_domain_size == 29 ? 1 : rem));
params.step(step + (lg_domain_size == 29 ? 1 : 0));
params.step(step);
} else if (lg_domain_size <= std::min(32, MAX_LG_DOMAIN_SIZE)) {
int step = lg_domain_size / 4;
int rem = lg_domain_size % 4;
params.step(step + rem);
params.step(step);
params.step(step);
params.step(step);
} else {
assert(false);
}
}
Loading

0 comments on commit 50fc7d5

Please sign in to comment.