diff --git a/polynomial/prefix_op.cuh b/polynomial/prefix_op.cuh index 76d33b0..93ffc4e 100644 --- a/polynomial/prefix_op.cuh +++ b/polynomial/prefix_op.cuh @@ -110,7 +110,10 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len) const uint32_t chunk_size = blockDim.x * CHUNK; const uint32_t blob_size = gridDim.x * chunk_size; - const uint32_t lane_off = laneid + (tid / WARP_SZ) * WARP_SZ * CHUNK; + + constexpr bool coalesce = CHUNK/sizeof(fr_t) > 1; + const uint32_t lane_off = (tid / WARP_SZ) * WARP_SZ * CHUNK + + (coalesce ? laneid : laneid*CHUNK); const Operation op; const fr_t identity = op.identity(); @@ -147,7 +150,7 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len) #pragma unroll for (int i = do_prefetch; i < CHUNK; i++) { - size_t idx = lane_idx + WARP_SZ * i; + size_t idx = lane_idx + (coalesce ? WARP_SZ*i : i); if (top == CHUNK && idx >= len) top = i; @@ -179,11 +182,19 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len) chunk[i] = fr_t::csel(chunk[i], identity, i < top); #endif - warp::prefix_op(chunk); + if (coalesce) { + warp::prefix_op(chunk); - #pragma unroll - for (int i = 1; i < CHUNK; i++) - chunk[i] = op(chunk[i], shfl_idx(chunk[i-1], WARP_SZ-1)); + #pragma unroll + for (int i = 1; i < CHUNK; i++) + chunk[i] = op(chunk[i], shfl_idx(chunk[i-1], WARP_SZ-1)); + } else { + #pragma unroll + for (int i = 1; i < CHUNK; i++) + chunk[i] = op(chunk[i], chunk[i-1]); + + chunk[CHUNK-1] = warp::prefix_op(chunk[CHUNK-1]); + } if (laneid == WARP_SZ-1 && warpid < 1024/WARP_SZ-1) xchg[warpid] = chunk[CHUNK - 1]; @@ -198,7 +209,16 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len) warp_carry = warp::prefix_op(warp_carry, nwarps); warp_carry = shfl_idx(warp_carry, warpid); - chunk[CHUNK - 1] = op(chunk[CHUNK - 1], warp_carry); + if (coalesce) { + chunk[CHUNK - 1] = op(chunk[CHUNK - 1], warp_carry); + } else { + fr_t lane_carry = shfl_up(chunk[CHUNK-1], 1); + lane_carry = fr_t::csel(identity, lane_carry, laneid == 0); + + chunk[CHUNK - 1] = op(chunk[CHUNK - 1], warp_carry); + + warp_carry = op(warp_carry, lane_carry); + } fr_t grid_carry_in = grid_carry; @@ -296,7 +316,7 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len) #pragma unroll for(int i = 0; i < CHUNK; i++) { if (i < top) - out[lane_idx + WARP_SZ * i] = chunk[i]; + out[lane_idx + (coalesce ? WARP_SZ*i : i)] = chunk[i]; } } }