Skip to content

Commit

Permalink
polynomial/div_by_x_minus_z.cuh: lift limitations on the N parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Jan 9, 2025
1 parent c907b93 commit 088bb40
Showing 1 changed file with 55 additions and 33 deletions.
88 changes: 55 additions & 33 deletions polynomial/div_by_x_minus_z.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ template<class fr_t, int N, bool rotate, int BSZ>
__global__ __launch_bounds__(BSZ)
void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
{
static_assert(!rotate || N <= 2, "unsupported template parameter value");

struct my {
__device__ __forceinline__
static void madd_up(fr_t& coeff, fr_t& z_pow, uint32_t limit = WARP_SZ)
Expand Down Expand Up @@ -152,7 +150,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
*
* If |rotate| is false, the first element of the output is
* the remainder and the rest is the quotient. Otherwise
* the remainder is stored at the end and the quotiend is
* the remainder is stored at the end and the quotient is
* "shifted" toward the beginning of the |d_inout| vector.
*/
class rev_ptr_t {
Expand Down Expand Up @@ -189,17 +187,18 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
for (int i = 1; i < N; i++)
coeff[i] += coeff[i-1] * z_pow;

fr_t carry_over;
bool tail_sync = false;

if (N>1 || sizeof(fr_t) <= 32) {
if (sizeof(fr_t) <= 32) {
my::madd_up(coeff[N-1], z_pow = z_n);

if (laneid == WARP_SZ-1)
xchg[warpid] = coeff[N-1];

__syncthreads();

fr_t carry_over = xchg[laneid];
carry_over = xchg[laneid];

my::madd_up(carry_over, z_pow, nwarps);

Expand Down Expand Up @@ -272,27 +271,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
carry_over += carry * (z_pow = z_top_carry);
}
}

if (N > 1) {
if (laneid == WARP_SZ-1)
xchg[warpid] = coeff[N-1];

__syncthreads();

fr_t carry = shfl_up(coeff[N-1], 1);

if (laneid == 0 && warpid != 0)
carry_over = xchg[warpid-1];

carry = fr_t::csel(carry_over, carry, laneid == 0);

z_pow = z;
#pragma unroll
for (int i = 0; i < N-1; i++)
coeff[i] += (carry *= z_pow);
}
} else { // ~14KB loop size with 256-bit field, yet unused...
fr_t z_pow_adjust, carry_over, acc = coeff[N-1];
fr_t z_pow_adjust, offload, acc = coeff[N-1];

z_pow = z_n;
uint32_t limit = WARP_SZ;
Expand All @@ -305,8 +285,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
if (adjust != 0) {
acc = shfl_idx(acc, adjust - 1);
tail_mul:
acc *= z_pow_adjust;
coeff[N-1] += acc;
coeff[N-1] += acc * z_pow_adjust;
}

switch (++pc) {
Expand All @@ -323,6 +302,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
limit = nwarps;
adjust = warpid;
z_pow_adjust = z_pow_warp;
if (N > 1)
carry_over.zero();
break;
case 1:
if (gridDim.x > 1 && len - chunk > N*blockDim.x) {
Expand Down Expand Up @@ -378,6 +359,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
coeff[N-1] = carry_over;
acc = xchg[blockIdx.x-1];
z_pow_adjust = z_pow_block;
if (N > 1)
carry_over = acc;
pc = 3;
goto tail_mul;
case 4:
Expand All @@ -391,13 +374,51 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
z_pow_adjust = z_pow_grid;
pc = 4;
goto tail_mul;
case 5:
if (N > 1) {
if (blockIdx.x == 0) {
carry_over = acc;
} else {
offload = coeff[N-1];
coeff[N-1] = carry_over;
z_pow_adjust = z_top_carry;
goto tail_mul;
}
}
pc = -1;
break;
case 6:
if (N > 1) {
carry_over = coeff[N-1];
coeff[N-1] = offload;
}
// fall through
default:
pc = -1;
break;
}
} while (pc >= 0);
}

if (N > 1) {
if (laneid == WARP_SZ-1)
xchg[warpid] = coeff[N-1];

__syncthreads();

fr_t carry = shfl_up(coeff[N-1], 1);

if (laneid == 0 && warpid != 0)
carry_over = xchg[warpid-1];

carry = fr_t::csel(carry_over, carry, laneid == 0);

z_pow = z;
#pragma unroll
for (int i = 0; i < N-1; i++)
coeff[i] += (carry *= z_pow);
}

if (tail_sync) {
__grid.sync();
__syncthreads();
Expand All @@ -411,12 +432,13 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
}

if (rotate) {
if (N == 1) {
if (idx == len - 1)
inout[0] = coeff[0];
} else { // only N==2 supported for the moment
if (idx == len - 2 + (len&1))
inout[0] = fr_t::csel(coeff[0], coeff[1], len&1);
int rem = static_cast<int>(--len % N);
if (idx == len - rem) {
#pragma unroll
for (int i = 1; i < N; i++)
coeff[0] = fr_t::csel(coeff[i], coeff[0], i == rem);

inout[0] = coeff[0];
}
}
}
Expand Down

0 comments on commit 088bb40

Please sign in to comment.