Skip to content

Commit

Permalink
ntt/*: optimize for AMD CDNA.
Browse files (browse the repository at this point in the history)
  • Loading branch information
dot-asm committed Jan 28, 2025
1 parent 94705c4 commit f49ef48
Showing 4 changed files with 34 additions and 12 deletions.
19 changes: 16 additions & 3 deletions ntt/kernels.cu
Original file line number Diff line number Diff line change
@@ -64,6 +64,8 @@ void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size)

#pragma unroll 1
do {
(Z_COUNT > warpSize) ? __syncthreads() : __syncwarp();

index_t group_idx = tid >> LG_Z_COUNT;
index_t group_rev = bit_rev(group_idx, lg_domain_size - 2*LG_Z_COUNT);

@@ -91,10 +93,17 @@ void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size)
#pragma unroll
for (uint32_t i = 0; i < Z_COUNT; i++)
regs[i] = in[i * step + base_rev];
} else {
#pragma unroll
for (uint32_t i = 0; i < Z_COUNT; i++)
regs[i].zero();
}

asm("" : "+v"(base_idx));
asm("" : "+v"(base_rev));
#endif

(Z_COUNT > WARP_SZ) ? __syncthreads() : __syncwarp();
(Z_COUNT > warpSize) ? __syncthreads() : __syncwarp();

#pragma unroll
for (uint32_t i = 0; i < Z_COUNT; i++)
@@ -103,20 +112,24 @@ void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size)
if (group_idx == group_rev)
continue;

(Z_COUNT > WARP_SZ) ? __syncthreads() : __syncwarp();
(Z_COUNT > warpSize) ? __syncthreads() : __syncwarp();

#pragma unroll
for (uint32_t i = 0; i < Z_COUNT; i++)
xchg[gid][i][rev] = regs[i];

(Z_COUNT > WARP_SZ) ? __syncthreads() : __syncwarp();
(Z_COUNT > warpSize) ? __syncthreads() : __syncwarp();

#pragma unroll
for (uint32_t i = 0; i < Z_COUNT; i++)
out[i * step + base_idx] = xchg[gid][rev][i];

#ifdef __CUDA_ARCH__
} while (Z_COUNT <= WARP_SZ && (tid += blockDim.x*gridDim.x) < step);
// without "Z_COUNT <= WARP_SZ" compiler spills 128 bytes to stack :-(
#else
} while ((tid += blockDim.x*gridDim.x) < step);
#endif
}

template<class fr_t>
5 changes: 3 additions & 2 deletions ntt/kernels/ct_mixed_radix_wide.cu
Original file line number Diff line number Diff line change
@@ -137,19 +137,20 @@ class CT_launcher {
int stage;
const NTTParameters& ntt_parameters;
const stream_t& stream;
int min_radix;

public:
CT_launcher(fr_t* d_ptr, int lg_dsz, bool intt,
const NTTParameters& params, const stream_t& s)
: d_inout(d_ptr), lg_domain_size(lg_dsz), is_intt(intt), stage(0),
ntt_parameters(params), stream(s)
{}
{ min_radix = lg2(gpu_props(s).warpSize) + 1; }

void step(int iterations)
{
assert(iterations <= 10);

const int radix = iterations < 6 ? 6 : iterations;
const int radix = iterations < min_radix ? min_radix : iterations;

index_t num_threads = (index_t)1 << (lg_domain_size - 1);
index_t block_size = 1 << (radix - 1);
5 changes: 3 additions & 2 deletions ntt/kernels/gs_mixed_radix_wide.cu
Original file line number Diff line number Diff line change
@@ -133,19 +133,20 @@ class GS_launcher {
int stage;
const NTTParameters& ntt_parameters;
const stream_t& stream;
int min_radix;

public:
GS_launcher(fr_t* d_ptr, int lg_dsz, bool innt,
const NTTParameters& params, const stream_t& s)
: d_inout(d_ptr), lg_domain_size(lg_dsz), is_intt(innt), stage(lg_dsz),
ntt_parameters(params), stream(s)
{}
{ min_radix = lg2(gpu_props(s).warpSize) + 1; }

void step(int iterations)
{
assert(iterations <= 10);

const int radix = iterations < 6 ? 6 : iterations;
const int radix = iterations < min_radix ? min_radix : iterations;

index_t num_threads = (index_t)1 << (lg_domain_size - 1);
index_t block_size = 1 << (radix - 1);
17 changes: 12 additions & 5 deletions ntt/ntt.cuh
Original file line number Diff line number Diff line change
@@ -44,15 +44,21 @@ protected:
size_t domain_size = (size_t)1 << lg_domain_size;
// aim to read 4 cache lines of consecutive data per read
const uint32_t Z_COUNT = 256 / sizeof(fr_t);
const uint32_t bsize = Z_COUNT>WARP_SZ ? Z_COUNT : WARP_SZ;
const uint32_t warpSize = gpu_props(stream).warpSize;
const uint32_t bsize = Z_COUNT>warpSize ? Z_COUNT : warpSize;
#ifdef __HIPCC__
const uint32_t lg_switch = 17;
#else
const uint32_t lg_switch = 32;
#endif

if (domain_size <= 1024)
bit_rev_permutation<<<1, domain_size, 0, stream>>>
(d_out, d_inp, lg_domain_size);
else if (domain_size < bsize * Z_COUNT)
bit_rev_permutation<<<domain_size / WARP_SZ, WARP_SZ, 0, stream>>>
bit_rev_permutation<<<domain_size / bsize, bsize, 0, stream>>>
(d_out, d_inp, lg_domain_size);
else if (Z_COUNT > WARP_SZ || lg_domain_size <= 32)
else if (Z_COUNT > warpSize || lg_domain_size <= lg_switch)
bit_rev_permutation_z<Z_COUNT><<<domain_size / Z_COUNT / bsize, bsize,
bsize * Z_COUNT * sizeof(fr_t),
stream>>>
@@ -74,14 +80,15 @@ private:
stream_t& stream)
{
size_t domain_size = (size_t)1 << lg_dsz;
const uint32_t warpSize = gpu_props(stream).warpSize;
const auto gen_powers =
NTTParameters::all(innt)[stream].partial_group_gen_powers;

if (domain_size < WARP_SZ)
if (domain_size < warpSize)
LDE_distribute_powers<<<1, domain_size, 0, stream>>>
(inout, lg_dsz, lg_blowup, bitrev, gen_powers);
else if (lg_dsz < 32)
LDE_distribute_powers<<<domain_size / WARP_SZ, WARP_SZ, 0, stream>>>
LDE_distribute_powers<<<domain_size / warpSize, warpSize, 0, stream>>>
(inout, lg_dsz, lg_blowup, bitrev, gen_powers);
else
LDE_distribute_powers<<<stream.sm_count(), 1024, 0, stream>>>

0 comments on commit f49ef48

Please sign in to comment.