polynomial/div_by_x_minus_z.cuh: improve performance.
This is achieved by halving the number of grid synchronizations, which
in turn allows reducing the computational complexity.
dot-asm committed Jan 8, 2025
1 parent 3efb7f4 commit c907b93
Showing 1 changed file with 121 additions and 64 deletions.
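
For background on what the changed kernel computes: dividing a polynomial by x - z is synthetic (Horner) division, a single sweep over the coefficients from the highest degree down. The following host-side reference is a minimal illustrative sketch, not code from the repository; only the fr_t name is borrowed from the file.

#include <cstddef>
#include <vector>

// Synthetic division of p(x) by (x - z), coefficients in ascending-degree
// order. On return, p[1..n-1] holds the quotient and p[0] the remainder,
// computed in place, just as the CUDA kernel updates d_inout in place.
// Any field type with operator* and operator+= can stand in for fr_t.
template<class fr_t>
void div_by_x_minus_z_ref(std::vector<fr_t>& p, const fr_t& z)
{
    if (p.empty())
        return;
    for (size_t i = p.size() - 1; i > 0; i--)
        p[i - 1] += z * p[i];   // Horner step, highest degree first
}

The kernel parallelizes this sequential recurrence as a scan in which warp-, block- and grid-level carries are weighted by powers of z, and every grid-wide chunk costs one grid synchronization. With the new N=2 template parameter each thread first folds two adjacent coefficients locally, so cross-grid carry propagation advances in steps of z^2 over half as many positions, which is where the halved synchronization count comes from.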
polynomial/div_by_x_minus_z.cuh: 185 changes (121 additions, 64 deletions)
@@ -14,9 +14,12 @@
#endif
#include <ff/shfl.cuh>

template<class fr_t, bool rotate, int BSZ> __global__ __launch_bounds__(BSZ)
template<class fr_t, int N, bool rotate, int BSZ>
__global__ __launch_bounds__(BSZ)
void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
{
static_assert(!rotate || N <= 2, "unsupported template parameter value");

struct my {
__device__ __forceinline__
static void madd_up(fr_t& coeff, fr_t& z_pow, uint32_t limit = WARP_SZ)
@@ -59,13 +62,14 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
assert(blockDim.x%WARP_SZ == 0 && gridDim.x <= blockDim.x);
#endif

const uint32_t tid = threadIdx.x + blockDim.x*blockIdx.x;
const uint32_t tidx = N * (threadIdx.x + blockDim.x*blockIdx.x);
const uint32_t laneid = threadIdx.x % WARP_SZ;
const uint32_t warpid = threadIdx.x / WARP_SZ;
const uint32_t nwarps = blockDim.x / WARP_SZ;

extern __shared__ int xchg_div_by_x_minus_z[];
fr_t* xchg = reinterpret_cast<decltype(xchg)>(xchg_div_by_x_minus_z);
static __shared__ fr_t z_pow_carry[WARP_SZ], z_top_block, z_top_carry, z_n;

/*
* Calculate ascending powers of |z| in ascending threads across
@@ -76,6 +80,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
* implied elsewhere, gridDim.x <= blockDim.x.]
*/
fr_t z_pow = z;
if (N > 1)
z_pow ^= N;
z_pow = my::mult_up(z_pow);
fr_t z_pow_warp = z_pow; // z^(laneid+1)

@@ -86,30 +92,40 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
z_pow_block = shfl_idx(z_pow, warpid - 1);
z_pow_block *= z_pow_warp;
}
fr_t z_top_block = shfl_idx(z_pow, nwarps - 1);
z_pow = shfl_idx(z_pow, nwarps - 1);

if (threadIdx.x == 0) {
z_n = z_pow_warp;
z_top_block = z_pow;
}

fr_t z_pow_grid = z_pow_block; // z^(blockDim.x*blockIdx.x+threadIdx.x+1)
if (blockIdx.x != 0) {
z_pow = z_top_block;
z_pow = my::mult_up(z_pow, min(WARP_SZ, gridDim.x));
z_pow_grid = shfl_idx(z_pow, (blockIdx.x - 1)%WARP_SZ);

// Offload z^(z_top_block*(laneid+1)) to the shared memory to
// alleviate register pressure.
if (warpid == 0)
z_pow_carry[laneid] = z_pow;

if (blockIdx.x > WARP_SZ) {
z_pow = shfl_idx(z_pow, WARP_SZ - 1);
z_pow = my::mult_up(z_pow, (gridDim.x + WARP_SZ - 1)/WARP_SZ);
z_pow = shfl_idx(z_pow, (blockIdx.x - 1)/WARP_SZ - 1);
z_pow_grid *= z_pow;
}

if (threadIdx.x == 0)
z_top_carry = z_pow_grid;

z_pow_grid *= z_pow_block;
}

// Calculate z^(z_top_block*(laneid+1)) and offload it to the shared
// memory to alleviate register pressure.
fr_t& z_pow_carry = xchg[max(blockDim.x/WARP_SZ, gridDim.x) + laneid];
if (gridDim.x > WARP_SZ && warpid == 0)
z_pow_carry = my::mult_up(z_pow = z_top_block);
__syncthreads();

#if 0
auto check = z^(tid+1);
auto check = z^(tidx+N);
check -= z_pow_grid;
assert(check.is_zero());
#endif
@@ -147,58 +163,70 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
__device__ const fr_t& operator[](size_t i) const { return *(p - i); }
};
rev_ptr_t inout{d_inout, len};
fr_t coeff, carry_over, prefetch;
uint32_t stride = blockDim.x*gridDim.x;
fr_t coeff[N], prefetch;
uint32_t stride = N*blockDim.x*gridDim.x;
size_t idx;
auto __grid = cooperative_groups::this_grid();

if (tid < len)
prefetch = inout[tid];
if (tidx < len)
prefetch = inout[tidx];

for (size_t chunk = 0; chunk < len; chunk += stride) {
idx = chunk + tid;
idx = chunk + tidx;

bool tail_sync = false;
#pragma unroll
for (int i = 1; i < N; i++) {
if (idx + i < len)
coeff[i] = inout[idx + i];
}
coeff[0] = prefetch;

if (sizeof(fr_t) <= 32) {
coeff = prefetch;
if (idx + stride < len)
prefetch = inout[idx + stride];

if (idx + stride < len)
prefetch = inout[idx + stride];
z_pow = z;
#pragma unroll
for (int i = 1; i < N; i++)
coeff[i] += coeff[i-1] * z_pow;

my::madd_up(coeff, z_pow = z);
bool tail_sync = false;

if (N>1 || sizeof(fr_t) <= 32) {
my::madd_up(coeff[N-1], z_pow = z_n);

if (laneid == WARP_SZ-1)
xchg[warpid] = coeff;
xchg[warpid] = coeff[N-1];

__syncthreads();

carry_over = xchg[laneid];
fr_t carry_over = xchg[laneid];

my::madd_up(carry_over, z_pow, nwarps);

if (warpid != 0) {
carry_over = shfl_idx(carry_over, warpid - 1);
carry_over *= z_pow_warp;
coeff += carry_over;
coeff[N-1] += carry_over;
}

carry_over.zero();

size_t remaining = len - chunk;

if (gridDim.x > 1 && remaining > blockDim.x) {
tail_sync = remaining <= 2*stride - blockDim.x;
if (gridDim.x > 1 && remaining > N*blockDim.x) {
tail_sync = remaining <= 2*stride - N*blockDim.x;
uint32_t bias = tail_sync ? 0 : stride;
size_t grid_idx = chunk + (blockIdx.x*blockDim.x + bias
+ (rotate && blockIdx.x == 0));
size_t grid_idx = chunk + (blockIdx.x*N*blockDim.x + bias
+ N*(rotate && blockIdx.x == 0));
if (threadIdx.x == blockDim.x-1 && grid_idx < len)
inout[grid_idx] = coeff;
inout[grid_idx] = coeff[N-1];

__grid.sync();
__syncthreads();

if (blockIdx.x != 0) {
grid_idx = chunk + (threadIdx.x*blockDim.x + bias
+ (rotate && threadIdx.x == 0));
grid_idx = chunk + (threadIdx.x*N*blockDim.x + bias
+ N*(rotate && threadIdx.x == 0));
if (threadIdx.x < gridDim.x && grid_idx < len)
carry_over = inout[grid_idx];

@@ -218,7 +246,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)

if (warpid != 0) {
temp = shfl_idx(temp, warpid - 1);
temp *= (z_pow = z_pow_carry);
temp *= (z_pow = z_pow_carry[laneid]);
carry_over += temp;
}
}
@@ -229,25 +257,44 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
__syncthreads();

carry_over = xchg[blockIdx.x-1];
carry_over *= z_pow_block;
coeff += carry_over;
coeff[N-1] += carry_over * z_pow_block;
}
}

if (chunk != 0) {
carry_over = inout[chunk - !rotate];
carry_over *= z_pow_grid;
coeff += carry_over;
fr_t carry = inout[chunk - !rotate];
coeff[N-1] += carry * z_pow_grid;

if (N > 1) {
if (blockIdx.x == 0)
carry_over = carry;
else
carry_over += carry * (z_pow = z_top_carry);
}
}
} else { // ~14KB loop size with 256-bit field, yet unused...
fr_t acc, z_pow_adjust;

acc = prefetch;
if (N > 1) {
if (laneid == WARP_SZ-1)
xchg[warpid] = coeff[N-1];

__syncthreads();

fr_t carry = shfl_up(coeff[N-1], 1);

if (laneid == 0 && warpid != 0)
carry_over = xchg[warpid-1];

if (idx + stride < len)
prefetch = inout[idx + stride];
carry = fr_t::csel(carry_over, carry, laneid == 0);

z_pow = z;
z_pow = z;
#pragma unroll
for (int i = 0; i < N-1; i++)
coeff[i] += (carry *= z_pow);
}
} else { // ~14KB loop size with 256-bit field, yet unused...
fr_t z_pow_adjust, carry_over, acc = coeff[N-1];

z_pow = z_n;
uint32_t limit = WARP_SZ;
uint32_t adjust = 0;
int pc = -1;
@@ -259,12 +306,12 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
acc = shfl_idx(acc, adjust - 1);
tail_mul:
acc *= z_pow_adjust;
coeff += acc;
coeff[N-1] += acc;
}

switch (++pc) {
case 0:
coeff = acc;
coeff[N-1] = acc;

if (laneid == WARP_SZ-1)
xchg[warpid] = acc;
@@ -278,20 +325,20 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
z_pow_adjust = z_pow_warp;
break;
case 1:
if (gridDim.x > 1 && len - chunk > blockDim.x) {
tail_sync = len - chunk <= 2*stride - blockDim.x;
if (gridDim.x > 1 && len - chunk > N*blockDim.x) {
tail_sync = len - chunk <= 2*stride - N*blockDim.x;
uint32_t bias = tail_sync ? 0 : stride;
size_t xchg_idx = chunk + (blockIdx.x*blockDim.x + bias
+ (rotate && blockIdx.x == 0));
size_t xchg_idx = chunk + (blockIdx.x*N*blockDim.x + bias
+ N*(rotate && blockIdx.x == 0));
if (threadIdx.x == blockDim.x-1 && xchg_idx < len)
inout[xchg_idx] = coeff;
inout[xchg_idx] = coeff[N-1];

__grid.sync();
__syncthreads();

if (blockIdx.x != 0) {
xchg_idx = chunk + (threadIdx.x*blockDim.x + bias
+ (rotate && threadIdx.x == 0));
xchg_idx = chunk + (threadIdx.x*N*blockDim.x + bias
+ N*(rotate && threadIdx.x == 0));
if (threadIdx.x < gridDim.x && xchg_idx < len)
acc = inout[xchg_idx];

@@ -306,8 +353,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
}
break;
case 2: // blockIdx.x != 0
carry_over = coeff;
coeff = acc;
carry_over = coeff[N-1];
coeff[N-1] = acc;

if (gridDim.x > WARP_SZ) {
if (laneid == WARP_SZ-1)
@@ -319,16 +366,16 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)

limit = (gridDim.x + WARP_SZ - 1)/WARP_SZ;
adjust = warpid;
z_pow_adjust = z_pow_carry;
z_pow_adjust = z_pow_carry[laneid];
break;
} // else fall through
case 3: // blockIdx.x != 0
if (threadIdx.x < gridDim.x)
xchg[threadIdx.x] = coeff;
xchg[threadIdx.x] = coeff[N-1];

__syncthreads();

coeff = carry_over;
coeff[N-1] = carry_over;
acc = xchg[blockIdx.x-1];
z_pow_adjust = z_pow_block;
pc = 3;
@@ -356,12 +403,22 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
__syncthreads();
}

if (idx < len - rotate)
inout[idx + rotate] = coeff;
#pragma unroll
for (int i = 0; i < N; i++) {
if (idx + i < len - rotate)
inout[idx + i + rotate] = coeff[i];
}
}

if (rotate && idx == len - 1)
inout[0] = coeff;
if (rotate) {
if (N == 1) {
if (idx == len - 1)
inout[0] = coeff[0];
} else { // only N==2 supported for the moment
if (idx == len - 2 + (len&1))
inout[0] = fr_t::csel(coeff[0], coeff[1], len&1);
}
}
}

template<bool rotate = false, class fr_t, class stream_t>
@@ -371,6 +428,7 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
if (gridDim <= 0)
gridDim = s.sm_count();

constexpr int N = 2;
constexpr int BSZ = sizeof(fr_t) <= 16 ? 1024 : 0;
int blockDim = BSZ;

@@ -379,7 +437,7 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,

if (saved_blockDim == 0) {
cudaFuncAttributes attr;
CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, rotate, BSZ>));
CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, N, rotate, BSZ>));
saved_blockDim = attr.maxThreadsPerBlock;
assert(saved_blockDim%WARP_SZ == 0);
}
@@ -399,9 +457,8 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
gridDim = 1;

size_t sharedSz = sizeof(fr_t) * max(blockDim/WARP_SZ, gridDim);
sharedSz += sizeof(fr_t) * WARP_SZ;

s.launch_coop(d_div_by_x_minus_z<fr_t, rotate, BSZ>,
s.launch_coop(d_div_by_x_minus_z<fr_t, N, rotate, BSZ>,
{gridDim, blockDim, sharedSz},
d_inout, len, z);
}
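
The "ascending powers of |z| in ascending threads" machinery in the kernel is a shuffle-based prefix product, which is what helpers like my::mult_up are assumed to implement (an inference from the hunks above, not spelled out in them). Below is a self-contained sketch of that idea, with plain unsigned arithmetic (wrapping mod 2^32) standing in for the field type fr_t; the kernel name and shape are illustrative.

#include <cstdio>

#define WARP_SZ 32

// After the scan below, lane l of the warp holds z^(l+1): a Hillis-Steele
// inclusive prefix product over the warp, one __shfl_up_sync() per step.
__global__ void warp_powers(unsigned z, unsigned* out)
{
    unsigned lane = threadIdx.x % WARP_SZ;
    unsigned zpow = z;                      // every lane starts with z^1

    for (unsigned s = 1; s < WARP_SZ; s <<= 1) {
        // Multiply in the partial product held s lanes below; lanes with
        // laneid < s have nothing below them and keep their value.
        unsigned lower = __shfl_up_sync(0xffffffff, zpow, s);
        if (lane >= s)
            zpow *= lower;
    }

    out[lane] = zpow;                       // out[l] == z^(l+1) mod 2^32
}

int main()
{
    unsigned* d_out;
    cudaMalloc(&d_out, WARP_SZ * sizeof(unsigned));
    warp_powers<<<1, WARP_SZ>>>(3u, d_out);

    unsigned h_out[WARP_SZ];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("3^1=%u 3^2=%u 3^3=%u\n", h_out[0], h_out[1], h_out[2]);
    cudaFree(d_out);
    return 0;
}

The madd_up helper appears to be the multiply-add flavor of the same scan, propagating per-lane carries of the form coeff += carry * z_pow across a warp.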
