From c907b932c8044b279225fccea6fc80a209b94075 Mon Sep 17 00:00:00 2001
From: Andy Polyakov
Date: Wed, 8 Jan 2025 17:39:26 +0100
Subject: [PATCH] polynomial/div_by_x_minus_z.cuh: improve performance.

This is achieved by halving the number of grid synchronizations, which
in turn allows reducing the computational complexity.
---
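
What the kernel computes is in-place synthetic division: dividing a
polynomial by (x - z) is a Horner-style scan over the coefficients from
the highest degree down, one dependent multiply-add per coefficient. The
sketch below is a minimal CPU reference model of that recurrence and of
the N = 2 batching this patch introduces; it is not part of the patch,
and the toy field and the names fe, div_ref and div_ref2 are purely
illustrative. Each coefficient pair is folded with z independently of
the running carry, and the serial carry chain then advances by z^2 (the
kernel's z_n) in half as many dependent steps, which is what lets the
kernel make do with half as many grid synchronizations per pass.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy field: integers mod 2^31-1, a stand-in for the real fr_t class.
struct fe {
    static constexpr uint64_t P = 0x7fffffff;
    uint64_t v = 0;
    fe() = default;
    explicit fe(uint64_t x) : v(x % P) {}
    fe& operator+=(fe b) { v += b.v; if (v >= P) v -= P; return *this; }
    fe& operator*=(fe b) { v = (v * b.v) % P; return *this; }
    friend fe operator*(fe a, fe b) { return a *= b; }
    bool operator==(fe b) const { return v == b.v; }
};

// Plain Horner scan, one dependent multiply-add per coefficient, from
// the highest degree down. Afterwards p[1..] holds the quotient of
// p(x)/(x-z) and p[0] the remainder p(z).
static void div_ref(std::vector<fe>& p, fe z)
{
    fe carry{0};
    for (size_t i = p.size(); i-- > 0;) {
        carry *= z;         // scale the running Horner state by z
        carry += p[i];      // absorb the next coefficient
        p[i] = carry;
    }
}

// N = 2 batched variant: each pair is folded with z independently of
// the carry, and the dependent chain advances by z^2 (the kernel's
// z_n), so it is half as long. Even length keeps the sketch short.
static void div_ref2(std::vector<fe>& p, fe z)
{
    fe z2 = z * z, carry{0};
    assert(p.size() % 2 == 0);
    for (size_t i = p.size(); i >= 2; i -= 2) {
        fe hi = p[i-1], lo = p[i-2];
        fe fold = lo;
        fold += hi * z;         // pair-local work, parallel across pairs
        hi += carry * z;        // c[i-1] = a[i-1] + z*c[i]
        fold += carry * z2;     // c[i-2] = a[i-2] + z*a[i-1] + z^2*c[i]
        p[i-1] = hi;
        p[i-2] = carry = fold;  // only this assignment is serial
    }
}

int main()
{
    std::vector<fe> p, q;
    for (uint64_t k = 0; k < 16; k++)
        p.push_back(fe{3*k*k + k + 5});
    q = p;
    fe z{42};
    div_ref(p, z);
    div_ref2(q, z);
    for (size_t k = 0; k < p.size(); k++)
        assert(p[k] == q[k]);
    std::printf("models agree; remainder p(z) = %llu\n",
                (unsigned long long)p[0].v);
    return 0;
}

In the kernel, one thread takes the place of one loop iteration here,
and the serial carry chain is replaced by warp-, block- and grid-level
carry propagation (madd_up, the xchg[] exchanges and the grid sync).
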
 polynomial/div_by_x_minus_z.cuh | 185 +++++++++++++++++++++-----------
 1 file changed, 121 insertions(+), 64 deletions(-)

diff --git a/polynomial/div_by_x_minus_z.cuh b/polynomial/div_by_x_minus_z.cuh
index 44c46da..a4384a2 100644
--- a/polynomial/div_by_x_minus_z.cuh
+++ b/polynomial/div_by_x_minus_z.cuh
@@ -14,9 +14,12 @@
 #endif
 #include <cooperative_groups.h>
 
-template<class fr_t, int BSZ, bool rotate> __global__ __launch_bounds__(BSZ)
+template<class fr_t, int BSZ, int N, bool rotate>
+__global__ __launch_bounds__(BSZ)
 void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
 {
+    static_assert(!rotate || N <= 2, "unsupported template parameter value");
+
     struct my {
         __device__ __forceinline__
         static void madd_up(fr_t& coeff, fr_t& z_pow, uint32_t limit = WARP_SZ)
@@ -59,13 +62,14 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
     assert(blockDim.x%WARP_SZ == 0 && gridDim.x <= blockDim.x);
 #endif
 
-    const uint32_t tid = threadIdx.x + blockDim.x*blockIdx.x;
+    const uint32_t tidx = N * (threadIdx.x + blockDim.x*blockIdx.x);
     const uint32_t laneid = threadIdx.x % WARP_SZ;
     const uint32_t warpid = threadIdx.x / WARP_SZ;
     const uint32_t nwarps = blockDim.x / WARP_SZ;
 
     extern __shared__ int xchg_div_by_x_minus_z[];
     fr_t* xchg = reinterpret_cast<fr_t*>(xchg_div_by_x_minus_z);
+    static __shared__ fr_t z_pow_carry[WARP_SZ], z_top_block, z_top_carry, z_n;
 
     /*
      * Calculate ascending powers of |z| in ascending threads across
@@ -76,6 +80,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
      * implied elsewhere, gridDim.x <= blockDim.x.]
      */
     fr_t z_pow = z;
+    if (N > 1)
+        z_pow ^= N;
     z_pow = my::mult_up(z_pow);
 
     fr_t z_pow_warp = z_pow;    // z^(laneid+1)
@@ -86,30 +92,40 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
         z_pow_block = shfl_idx(z_pow, warpid - 1);
         z_pow_block *= z_pow_warp;
     }
-    fr_t z_top_block = shfl_idx(z_pow, nwarps - 1);
+    z_pow = shfl_idx(z_pow, nwarps - 1);
+
+    if (threadIdx.x == 0) {
+        z_n = z_pow_warp;
+        z_top_block = z_pow;
+    }
     fr_t z_pow_grid = z_pow_block;  // z^(blockDim.x*blockIdx.x+threadIdx.x+1)
 
     if (blockIdx.x != 0) {
-        z_pow = z_top_block;
         z_pow = my::mult_up(z_pow, min(WARP_SZ, gridDim.x));
         z_pow_grid = shfl_idx(z_pow, (blockIdx.x - 1)%WARP_SZ);
+
+        // Offload z^(z_top_block*(laneid+1)) to the shared memory to
+        // alleviate register pressure.
+        if (warpid == 0)
+            z_pow_carry[laneid] = z_pow;
+
         if (blockIdx.x > WARP_SZ) {
             z_pow = shfl_idx(z_pow, WARP_SZ - 1);
             z_pow = my::mult_up(z_pow, (gridDim.x + WARP_SZ - 1)/WARP_SZ);
             z_pow = shfl_idx(z_pow, (blockIdx.x - 1)/WARP_SZ - 1);
             z_pow_grid *= z_pow;
         }
+
+        if (threadIdx.x == 0)
+            z_top_carry = z_pow_grid;
+
         z_pow_grid *= z_pow_block;
     }
 
-    // Calculate z^(z_top_block*(laneid+1)) and offload it to the shared
-    // memory to alleviate register pressure.
-    fr_t& z_pow_carry = xchg[max(blockDim.x/WARP_SZ, gridDim.x) + laneid];
-    if (gridDim.x > WARP_SZ && warpid == 0)
-        z_pow_carry = my::mult_up(z_pow = z_top_block);
+    __syncthreads();
 
 #if 0
-    auto check = z^(tid+1);
+    auto check = z^(tidx+N);
     check -= z_pow_grid;
     assert(check.is_zero());
 #endif
@@ -147,58 +163,70 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
         __device__ const fr_t& operator[](size_t i) const { return *(p - i); }
     };
     rev_ptr_t inout{d_inout, len};
-    fr_t coeff, carry_over, prefetch;
-    uint32_t stride = blockDim.x*gridDim.x;
+    fr_t coeff[N], prefetch;
+    uint32_t stride = N*blockDim.x*gridDim.x;
     size_t idx;
 
     auto __grid = cooperative_groups::this_grid();
 
-    if (tid < len)
-        prefetch = inout[tid];
+    if (tidx < len)
+        prefetch = inout[tidx];
 
     for (size_t chunk = 0; chunk < len; chunk += stride) {
-        idx = chunk + tid;
+        idx = chunk + tidx;
 
-        bool tail_sync = false;
+        #pragma unroll
+        for (int i = 1; i < N; i++) {
+            if (idx + i < len)
+                coeff[i] = inout[idx + i];
+        }
+        coeff[0] = prefetch;
 
-        if (sizeof(fr_t) <= 32) {
-            coeff = prefetch;
+        if (idx + stride < len)
+            prefetch = inout[idx + stride];
 
-            if (idx + stride < len)
-                prefetch = inout[idx + stride];
+        z_pow = z;
+        #pragma unroll
+        for (int i = 1; i < N; i++)
+            coeff[i] += coeff[i-1] * z_pow;
 
-            my::madd_up(coeff, z_pow = z);
+        bool tail_sync = false;
+
+        if (N>1 || sizeof(fr_t) <= 32) {
+            my::madd_up(coeff[N-1], z_pow = z_n);
 
             if (laneid == WARP_SZ-1)
-                xchg[warpid] = coeff;
+                xchg[warpid] = coeff[N-1];
 
             __syncthreads();
 
-            carry_over = xchg[laneid];
+            fr_t carry_over = xchg[laneid];
             my::madd_up(carry_over, z_pow, nwarps);
 
             if (warpid != 0) {
                 carry_over = shfl_idx(carry_over, warpid - 1);
                 carry_over *= z_pow_warp;
-                coeff += carry_over;
+                coeff[N-1] += carry_over;
             }
 
+            carry_over.zero();
+
             size_t remaining = len - chunk;
-            if (gridDim.x > 1 && remaining > blockDim.x) {
-                tail_sync = remaining <= 2*stride - blockDim.x;
+            if (gridDim.x > 1 && remaining > N*blockDim.x) {
+                tail_sync = remaining <= 2*stride - N*blockDim.x;
                 uint32_t bias = tail_sync ? 0 : stride;
-                size_t grid_idx = chunk + (blockIdx.x*blockDim.x + bias
-                                           + (rotate && blockIdx.x == 0));
+                size_t grid_idx = chunk + (blockIdx.x*N*blockDim.x + bias
+                                           + N*(rotate && blockIdx.x == 0));
 
                 if (threadIdx.x == blockDim.x-1 && grid_idx < len)
-                    inout[grid_idx] = coeff;
+                    inout[grid_idx] = coeff[N-1];
 
                 __grid.sync();
                 __syncthreads();
 
                 if (blockIdx.x != 0) {
-                    grid_idx = chunk + (threadIdx.x*blockDim.x + bias
-                                        + (rotate && threadIdx.x == 0));
+                    grid_idx = chunk + (threadIdx.x*N*blockDim.x + bias
+                                        + N*(rotate && threadIdx.x == 0));
 
                     if (threadIdx.x < gridDim.x && grid_idx < len)
                         carry_over = inout[grid_idx];
@@ -218,7 +246,7 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
 
                     if (warpid != 0) {
                         temp = shfl_idx(temp, warpid - 1);
-                        temp *= (z_pow = z_pow_carry);
+                        temp *= (z_pow = z_pow_carry[laneid]);
                         carry_over += temp;
                     }
                 }
@@ -229,25 +257,44 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
                 __syncthreads();
 
                 carry_over = xchg[blockIdx.x-1];
-                carry_over *= z_pow_block;
-                coeff += carry_over;
+                coeff[N-1] += carry_over * z_pow_block;
             }
         }
 
         if (chunk != 0) {
-            carry_over = inout[chunk - !rotate];
-            carry_over *= z_pow_grid;
-            coeff += carry_over;
+            fr_t carry = inout[chunk - !rotate];
+            coeff[N-1] += carry * z_pow_grid;
+
+            if (N > 1) {
+                if (blockIdx.x == 0)
+                    carry_over = carry;
+                else
+                    carry_over += carry * (z_pow = z_top_carry);
+            }
         }
-    } else {    // ~14KB loop size with 256-bit field, yet unused...
-        fr_t acc, z_pow_adjust;
-        acc = prefetch;
+        if (N > 1) {
+            if (laneid == WARP_SZ-1)
+                xchg[warpid] = coeff[N-1];
+
+            __syncthreads();
+
+            fr_t carry = shfl_up(coeff[N-1], 1);
+
+            if (laneid == 0 && warpid != 0)
+                carry_over = xchg[warpid-1];
 
-        if (idx + stride < len)
-            prefetch = inout[idx + stride];
+            carry = fr_t::csel(carry_over, carry, laneid == 0);
 
-        z_pow = z;
+            z_pow = z;
+            #pragma unroll
+            for (int i = 0; i < N-1; i++)
+                coeff[i] += (carry *= z_pow);
+        }
+    } else {    // ~14KB loop size with 256-bit field, yet unused...
+        fr_t z_pow_adjust, carry_over, acc = coeff[N-1];
+
+        z_pow = z_n;
         uint32_t limit = WARP_SZ;
         uint32_t adjust = 0;
         int pc = -1;
 
@@ -259,12 +306,12 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
                 acc = shfl_idx(acc, adjust - 1);
         tail_mul:
                 acc *= z_pow_adjust;
-                coeff += acc;
+                coeff[N-1] += acc;
             }
 
             switch (++pc) {
             case 0:
-                coeff = acc;
+                coeff[N-1] = acc;
 
                 if (laneid == WARP_SZ-1)
                     xchg[warpid] = acc;
@@ -278,20 +325,20 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
                 z_pow_adjust = z_pow_warp;
                 break;
             case 1:
-                if (gridDim.x > 1 && len - chunk > blockDim.x) {
-                    tail_sync = len - chunk <= 2*stride - blockDim.x;
+                if (gridDim.x > 1 && len - chunk > N*blockDim.x) {
+                    tail_sync = len - chunk <= 2*stride - N*blockDim.x;
                     uint32_t bias = tail_sync ? 0 : stride;
-                    size_t xchg_idx = chunk + (blockIdx.x*blockDim.x + bias
-                                               + (rotate && blockIdx.x == 0));
+                    size_t xchg_idx = chunk + (blockIdx.x*N*blockDim.x + bias
+                                               + N*(rotate && blockIdx.x == 0));
 
                     if (threadIdx.x == blockDim.x-1 && xchg_idx < len)
-                        inout[xchg_idx] = coeff;
+                        inout[xchg_idx] = coeff[N-1];
                     __grid.sync();
                     __syncthreads();
 
                     if (blockIdx.x != 0) {
-                        xchg_idx = chunk + (threadIdx.x*blockDim.x + bias
-                                            + (rotate && threadIdx.x == 0));
+                        xchg_idx = chunk + (threadIdx.x*N*blockDim.x + bias
+                                            + N*(rotate && threadIdx.x == 0));
 
                         if (threadIdx.x < gridDim.x && xchg_idx < len)
                             acc = inout[xchg_idx];
@@ -306,8 +353,8 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
 
             case 2:     // blockIdx.x != 0
-                carry_over = coeff;
-                coeff = acc;
+                carry_over = coeff[N-1];
+                coeff[N-1] = acc;
 
                 if (gridDim.x > WARP_SZ) {
                     if (laneid == WARP_SZ-1)
@@ -319,16 +366,16 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
 
                     limit = (gridDim.x + WARP_SZ - 1)/WARP_SZ;
                     adjust = warpid;
-                    z_pow_adjust = z_pow_carry;
+                    z_pow_adjust = z_pow_carry[laneid];
                     break;
                 }   // else fall through
             case 3:     // blockIdx.x != 0
                 if (threadIdx.x < gridDim.x)
-                    xchg[threadIdx.x] = coeff;
+                    xchg[threadIdx.x] = coeff[N-1];
 
                 __syncthreads();
 
-                coeff = carry_over;
+                coeff[N-1] = carry_over;
                 acc = xchg[blockIdx.x-1];
                 z_pow_adjust = z_pow_block;
                 pc = 3;
@@ -356,12 +403,22 @@ void d_div_by_x_minus_z(fr_t d_inout[], size_t len, fr_t z)
             __syncthreads();
         }
 
-        if (idx < len - rotate)
-            inout[idx + rotate] = coeff;
+        #pragma unroll
+        for (int i = 0; i < N; i++) {
+            if (idx + i < len - rotate)
+                inout[idx + i + rotate] = coeff[i];
+        }
     }
 
-    if (rotate && idx == len - 1)
-        inout[0] = coeff;
+    if (rotate) {
+        if (N == 1) {
+            if (idx == len - 1)
+                inout[0] = coeff[0];
+        } else {    // only N==2 supported for the moment
+            if (idx == len - 2 + (len&1))
+                inout[0] = fr_t::csel(coeff[0], coeff[1], len&1);
+        }
+    }
 }
 
 template<bool rotate = false, class fr_t>
@@ -371,6 +428,7 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
     if (gridDim <= 0)
         gridDim = s.sm_count();
 
+    constexpr int N = 2;
     constexpr int BSZ = sizeof(fr_t) <= 16 ? 1024 : 0;
     int blockDim = BSZ;
 
@@ -379,7 +437,7 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
 
     if (saved_blockDim == 0) {
         cudaFuncAttributes attr;
-        CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, BSZ, rotate>));
+        CUDA_OK(cudaFuncGetAttributes(&attr, d_div_by_x_minus_z<fr_t, BSZ, N, rotate>));
        saved_blockDim = attr.maxThreadsPerBlock;
         assert(saved_blockDim%WARP_SZ == 0);
     }
@@ -399,9 +457,8 @@ void div_by_x_minus_z(fr_t d_inout[], size_t len, const fr_t& z,
         gridDim = 1;
 
     size_t sharedSz = sizeof(fr_t) * max(blockDim/WARP_SZ, gridDim);
-    sharedSz += sizeof(fr_t) * WARP_SZ;
 
-    s.launch_coop(d_div_by_x_minus_z<fr_t, BSZ, rotate>,
+    s.launch_coop(d_div_by_x_minus_z<fr_t, BSZ, N, rotate>,
                   {gridDim, blockDim, sharedSz},
                   d_inout, len, z);
 }
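
On the launcher side, N is pinned to 2 and the dynamic shared-memory
size no longer reserves the extra WARP_SZ field elements: the
z_pow_carry scratch that previously lived at the tail of xchg[] is now
the kernel's static __shared__ array. A hypothetical call site follows;
it is not from the patch, stream_t and fr_t come from the surrounding
codebase, and routing rotate through a leading, defaulted template
parameter is an assumption of this sketch.

// Illustrative only: divide d_coeffs by (x - z) in place on stream s.
template<class fr_t>
void divide_in_place(fr_t d_coeffs[], size_t len, const fr_t& z,
                     const stream_t& s)
{
    div_by_x_minus_z(d_coeffs, len, z, s);          // default layout
    div_by_x_minus_z<true>(d_coeffs, len, z, s);    // rotated variant
}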