Skip to content

Commit

Permalink
fixup! polynomial/div_by_x_minus_z.cuh: add |rotate| template parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Oct 15, 2024
1 parent 86ad180 commit ab3ecbb
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions polynomial/div_by_x_minus_z.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <cooperative_groups.h>
#include <ff/shfl.cuh>

-template<class fr_t, int BSZ, bool rotate> __global__ __launch_bounds__(BSZ)
+template<class fr_t, bool rotate, int BSZ> __global__ __launch_bounds__(BSZ)
void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
{
struct my {
Expand Down Expand Up @@ -145,7 +145,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
size_t idx;
auto __grid = cooperative_groups::this_grid();

-if (tid < stride)
+if (tid < len)
prefetch = inout[tid];

for (size_t chunk = 0; chunk < len; chunk += stride) {
Expand Down Expand Up @@ -174,8 +174,10 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
coeff += carry_over;
}

-if (gridDim.x > 1) {
+if (gridDim.x > 1 && len - chunk > blockDim.x) {
size_t grid_idx = chunk + blockIdx.x*blockDim.x;
+if (rotate)
+grid_idx += blockIdx.x == 0;
if (threadIdx.x == blockDim.x-1 && grid_idx < len)
inout[grid_idx] = coeff;

Expand All @@ -184,6 +186,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)

if (blockIdx.x != 0) {
grid_idx = chunk + threadIdx.x*blockDim.x;
+if (rotate)
+grid_idx += threadIdx.x == 0;
if (threadIdx.x < gridDim.x && grid_idx < len)
carry_over = inout[grid_idx];

Expand Down Expand Up @@ -220,7 +224,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
}

if (chunk != 0) {
-carry_over = inout[chunk - 1];
+carry_over = inout[chunk - 1 + rotate];
carry_over *= z_pow_grid;
coeff += carry_over;
}
Expand All @@ -229,7 +233,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)

acc = prefetch;

-if (idx + stride > len)
+if (idx + stride < len)
prefetch = inout[idx + stride];

z_pow = z;
Expand Down Expand Up @@ -345,7 +349,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
inout[0] = coeff;
}

-template<class fr_t, bool rotate = false, class stream_t>
+template<bool rotate = false, class fr_t, class stream_t>
void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
const stream_t& s)
{
Expand All @@ -356,7 +360,7 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,

if (BSZ == 0) {
cudaFuncAttributes attr;
-CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, BSZ, rotate>));
+CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, rotate, BSZ>));
blockDim = attr.maxThreadsPerBlock;
}

Expand All @@ -374,7 +378,7 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
size_t sharedSz = sizeof(fr_t) * max(blockDim/WARP_SZ, gridDim);
sharedSz += sizeof(fr_t) * WARP_SZ;

-s.launch_coop(d_div_by_x_minus_z<fr_t, BSZ, rotate>,
+s.launch_coop(d_div_by_x_minus_z<fr_t, rotate, BSZ>,
{gridDim, blockDim, sharedSz},
d_inout, len, z);
}
Expand Down

0 comments on commit ab3ecbb

Please sign in to comment.