diff --git a/polynomial/div_by_x_minus_z.cuh b/polynomial/div_by_x_minus_z.cuh index a4384a2..93a6fa6 100644 --- a/polynomial/div_by_x_minus_z.cuh +++ b/polynomial/div_by_x_minus_z.cuh @@ -18,8 +18,6 @@ template __global__ __launch_bounds__(BSZ) void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) { - static_assert(!rotate || N <= 2, "unsupported template parameter value"); - struct my { __device__ __forceinline__ static void madd_up(fr_t& coeff, fr_t& z_pow, uint32_t limit = WARP_SZ) @@ -152,7 +150,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) * * If |rotate| is false, the first element of the output is * the remainder and the rest is the quotient. Otherwise - * the remainder is stored at the end and the quotiend is + * the remainder is stored at the end and the quotient is * "shifted" toward the beginning of the |d_inout| vector. */ class rev_ptr_t { @@ -189,9 +187,10 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) for (int i = 1; i < N; i++) coeff[i] += coeff[i-1] * z_pow; + fr_t carry_over; bool tail_sync = false; - if (N>1 || sizeof(fr_t) <= 32) { + if (sizeof(fr_t) <= 32) { my::madd_up(coeff[N-1], z_pow = z_n); if (laneid == WARP_SZ-1) @@ -199,7 +198,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) __syncthreads(); - fr_t carry_over = xchg[laneid]; + carry_over = xchg[laneid]; my::madd_up(carry_over, z_pow, nwarps); @@ -272,27 +271,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) carry_over += carry * (z_pow = z_top_carry); } } - - if (N > 1) { - if (laneid == WARP_SZ-1) - xchg[warpid] = coeff[N-1]; - - __syncthreads(); - - fr_t carry = shfl_up(coeff[N-1], 1); - - if (laneid == 0 && warpid != 0) - carry_over = xchg[warpid-1]; - - carry = fr_t::csel(carry_over, carry, laneid == 0); - - z_pow = z; - #pragma unroll - for (int i = 0; i < N-1; i++) - coeff[i] += (carry *= z_pow); - } } else { // ~14KB loop size with 256-bit field, yet unused... - fr_t z_pow_adjust, carry_over, acc = coeff[N-1]; + fr_t z_pow_adjust, offload, acc = coeff[N-1]; z_pow = z_n; uint32_t limit = WARP_SZ; @@ -305,8 +285,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) if (adjust != 0) { acc = shfl_idx(acc, adjust - 1); tail_mul: - acc *= z_pow_adjust; - coeff[N-1] += acc; + coeff[N-1] += acc * z_pow_adjust; } switch (++pc) { @@ -323,6 +302,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) limit = nwarps; adjust = warpid; z_pow_adjust = z_pow_warp; + if (N > 1) + carry_over.zero(); break; case 1: if (gridDim.x > 1 && len - chunk > N*blockDim.x) { @@ -378,6 +359,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) coeff[N-1] = carry_over; acc = xchg[blockIdx.x-1]; z_pow_adjust = z_pow_block; + if (N > 1) + carry_over = acc; pc = 3; goto tail_mul; case 4: @@ -391,6 +374,25 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) z_pow_adjust = z_pow_grid; pc = 4; goto tail_mul; + case 5: + if (N > 1) { + if (blockIdx.x == 0) { + carry_over = acc; + } else { + offload = coeff[N-1]; + coeff[N-1] = carry_over; + z_pow_adjust = z_top_carry; + goto tail_mul; + } + } + pc = -1; + break; + case 6: + if (N > 1) { + carry_over = coeff[N-1]; + coeff[N-1] = offload; + } + // fall through default: pc = -1; break; @@ -398,6 +400,25 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) } while (pc >= 0); } + if (N > 1) { + if (laneid == WARP_SZ-1) + xchg[warpid] = coeff[N-1]; + + __syncthreads(); + + fr_t carry = shfl_up(coeff[N-1], 1); + + if (laneid == 0 && warpid != 0) + carry_over = xchg[warpid-1]; + + carry = fr_t::csel(carry_over, carry, laneid == 0); + + z_pow = z; + #pragma unroll + for (int i = 0; i < N-1; i++) + coeff[i] += (carry *= z_pow); + } + if (tail_sync) { __grid.sync(); __syncthreads(); @@ -411,12 +432,13 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z) } if (rotate) { - if (N == 1) { - if (idx == len - 1) - inout[0] = coeff[0]; - } else { // only N==2 supported for the moment - if (idx == len - 2 + (len&1)) - inout[0] = fr_t::csel(coeff[0], coeff[1], len&1); + int rem = static_cast(--len % N); + if (idx == len - rem) { + #pragma unroll + for (int i = 1; i < N; i++) + coeff[0] = fr_t::csel(coeff[i], coeff[0], i == rem); + + inout[0] = coeff[0]; } } }