Skip to content

Commit

Permalink
polynomial/prefix_op.cuh: improve performance.
Browse files (browse the repository at this point in the history)
  • Loading branch information
dot-asm committed Jan 8, 2025
1 parent 397d6c7 commit 3efb7f4
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions polynomial/prefix_op.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -110,7 +110,10 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len)

const uint32_t chunk_size = blockDim.x * CHUNK;
const uint32_t blob_size = gridDim.x * chunk_size;
const uint32_t lane_off = laneid + (tid / WARP_SZ) * WARP_SZ * CHUNK;

constexpr bool coalesce = CHUNK/sizeof(fr_t) > 1;
const uint32_t lane_off = (tid / WARP_SZ) * WARP_SZ * CHUNK +
(coalesce ? laneid : laneid*CHUNK);

const Operation op;
const fr_t identity = op.identity();
Expand Down Expand Up @@ -147,7 +150,7 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len)

#pragma unroll
for (int i = do_prefetch; i < CHUNK; i++) {
size_t idx = lane_idx + WARP_SZ * i;
size_t idx = lane_idx + (coalesce ? WARP_SZ*i : i);

if (top == CHUNK && idx >= len)
top = i;
Expand Down Expand Up @@ -179,11 +182,19 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len)
chunk[i] = fr_t::csel(chunk[i], identity, i < top);
#endif

warp::prefix_op(chunk);
if (coalesce) {
warp::prefix_op(chunk);

#pragma unroll
for (int i = 1; i < CHUNK; i++)
chunk[i] = op(chunk[i], shfl_idx(chunk[i-1], WARP_SZ-1));
#pragma unroll
for (int i = 1; i < CHUNK; i++)
chunk[i] = op(chunk[i], shfl_idx(chunk[i-1], WARP_SZ-1));
} else {
#pragma unroll
for (int i = 1; i < CHUNK; i++)
chunk[i] = op(chunk[i], chunk[i-1]);

chunk[CHUNK-1] = warp::prefix_op(chunk[CHUNK-1]);
}

if (laneid == WARP_SZ-1 && warpid < 1024/WARP_SZ-1)
xchg[warpid] = chunk[CHUNK - 1];
Expand All @@ -198,7 +209,16 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len)
warp_carry = warp::prefix_op(warp_carry, nwarps);
warp_carry = shfl_idx(warp_carry, warpid);

chunk[CHUNK - 1] = op(chunk[CHUNK - 1], warp_carry);
if (coalesce) {
chunk[CHUNK - 1] = op(chunk[CHUNK - 1], warp_carry);
} else {
fr_t lane_carry = shfl_up(chunk[CHUNK-1], 1);
lane_carry = fr_t::csel(identity, lane_carry, laneid == 0);

chunk[CHUNK - 1] = op(chunk[CHUNK - 1], warp_carry);

warp_carry = op(warp_carry, lane_carry);
}

fr_t grid_carry_in = grid_carry;

Expand Down Expand Up @@ -296,7 +316,7 @@ void d_prefix_op(OutPtr out, InPtr inp, size_t len)
#pragma unroll
for(int i = 0; i < CHUNK; i++) {
if (i < top)
out[lane_idx + WARP_SZ * i] = chunk[i];
out[lane_idx + (coalesce ? WARP_SZ*i : i)] = chunk[i];
}
}
}
Expand Down

0 comments on commit 3efb7f4

Please sign in to comment.